diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -312,22 +312,29 @@ NoSideEffect, TypesMatchWith<"operand types match result element type", "result", "elements", "SmallVector(" - "$_self.cast().getDimSize(0), " + "$_self.cast().getNumElements(), " "$_self.cast().getElementType())"> ]> { string summary = "tensor from elements operation."; string description = [{ - Create a 1D tensor from a range of same-type arguments. + Create an N-D tensor from a range of same-type arguments. The number of + provided `elements` must equal the number of elements in the + result type. The `elements` list the tensor's values in row-major order. Example: ```mlir - tensor.from_elements i_1, ..., i_N : tensor + tensor.from_elements %a, %b, %c, %d, %e, %f : tensor<2x3xindex> ``` + + will result in a tensor + + [[%a, %b, %c], + [%d, %e, %f]] }]; let arguments = (ins Variadic:$elements); - let results = (outs 1DTensorOf<[AnyType]>:$result); + let results = (outs AnyStaticShapeTensor:$result); let assemblyFormat = "$elements attr-dict `:` type($result)"; @@ -336,7 +343,7 @@ let skipDefaultBuilders = 1; let builders = [ - OpBuilder<(ins "Type":$elementType, "ValueRange":$elements)>, + OpBuilder<(ins "Type":$resultType, "ValueRange":$elements)>, // Special case builder for when `elements` has size >=1.
OpBuilder<(ins "ValueRange":$elements)> ]; diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -193,10 +193,10 @@ extentOperands.push_back( rewriter.create(loc, extent.getLimitedValue())); } - Type indexTy = rewriter.getIndexType(); + Type resultTy = + RankedTensorType::get({op.getShape().size()}, rewriter.getIndexType()); Value tensor = - rewriter.create(loc, indexTy, extentOperands); - Type resultTy = RankedTensorType::get({op.getShape().size()}, indexTy); + rewriter.create(loc, resultTy, extentOperands); rewriter.replaceOpWithNewOp(op, resultTy, tensor); return success(); } @@ -569,7 +569,8 @@ // Materialize extent tensor. Value staticExtentTensor = rewriter.create( - loc, rewriter.getIndexType(), extentValues); + loc, RankedTensorType::get({rank}, rewriter.getIndexType()), + extentValues); rewriter.replaceOpWithNewOp(op, op.getType(), staticExtentTensor); return success(); diff --git a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp --- a/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Detensorize.cpp @@ -28,8 +28,8 @@ // A detensored value is converted back by creating a new tensor from its // element(s). - auto createNewTensorOp = builder.create( - loc, inputs[0].getType(), inputs[0]); + auto createNewTensorOp = + builder.create(loc, inputs[0]); // FromElementsOp results in a tensor<1xdtype>, we need to reshape that to // a tensor instead.
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -364,17 +364,17 @@ //===----------------------------------------------------------------------===// void FromElementsOp::build(OpBuilder &builder, OperationState &result, - Type elementType, ValueRange elements) { - Type resultTy = RankedTensorType::get({static_cast(elements.size())}, - elementType); + Type resultType, ValueRange elements) { result.addOperands(elements); - result.addTypes(resultTy); + result.addTypes(resultType); } void FromElementsOp::build(OpBuilder &builder, OperationState &result, ValueRange elements) { assert(!elements.empty() && "expected at least one element"); - build(builder, result, elements.front().getType(), elements); + Type resultType = RankedTensorType::get( + {static_cast(elements.size())}, elements.front().getType()); + build(builder, result, resultType, elements); } OpFoldResult FromElementsOp::fold(ArrayRef operands) { @@ -397,23 +397,32 @@ LogicalResult matchAndRewrite(tensor::ExtractOp extract, PatternRewriter &rewriter) const final { - if (extract.indices().size() != 1) - return failure(); - auto tensorFromElements = extract.tensor().getDefiningOp(); - if (tensorFromElements == nullptr) - return failure(); - - APInt index; - if (!matchPattern(*extract.indices().begin(), m_ConstantInt(&index))) + if (!tensorFromElements) return failure(); + auto tensorType = tensorFromElements.getType().cast(); + auto rank = tensorType.getRank(); + // Linearize the constant indices in row-major order: the stride of + // dimension i is the product of the sizes of dimensions i+1..rank-1. + int64_t flatIndex = 0; + int64_t stride = 1; + for (int64_t i = rank - 1; i >= 0; --i) { + APInt index; + if (!matchPattern(extract.indices()[i], m_ConstantInt(&index))) + return failure(); + // Reject indices outside of their dimension. This also keeps the + // multiplication below from overflowing. + int64_t indexValue = index.getSExtValue(); + if (indexValue < 0 || indexValue >= tensorType.getDimSize(i)) + return failure(); + flatIndex += indexValue * stride; + stride *= tensorType.getDimSize(i); + } // Prevent out of bounds accesses.
This can happen in invalid code that will // never execute. - if (tensorFromElements->getNumOperands() <= index.getZExtValue() || - index.getSExtValue() < 0) + if (tensorFromElements->getNumOperands() <= flatIndex || flatIndex < 0) return failure(); - rewriter.replaceOp(extract, - tensorFromElements.getOperand(index.getZExtValue())); + rewriter.replaceOp(extract, tensorFromElements.getOperand(flatIndex)); return success(); } }; diff --git a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/Bufferize.cpp @@ -19,6 +19,7 @@ #include "mlir/Dialect/StandardOps/IR/Ops.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tensor/Transforms/Passes.h" +#include "mlir/IR/ImplicitLocOpBuilder.h" #include "mlir/Transforms/DialectConversion.h" using namespace mlir; @@ -65,19 +66,65 @@ LogicalResult matchAndRewrite(tensor::FromElementsOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - int numberOfElements = op.elements().size(); - auto resultType = MemRefType::get( - {numberOfElements}, op.getType().cast().getElementType()); - Value result = rewriter.create(op.getLoc(), resultType); - for (auto element : llvm::enumerate(op.elements())) { - Value index = - rewriter.create(op.getLoc(), element.index()); - rewriter.create(op.getLoc(), element.value(), result, - index); + Location loc = op.getLoc(); + auto tensorType = op.getType().cast(); + auto shape = tensorType.getShape(); + + // Allocate a buffer for the result. + auto resultType = + MemRefType::get(tensorType.getShape(), tensorType.getElementType()); + Value buffer = rewriter.create(loc, resultType); + + // Case: tensor<0xelem_type>. + if (op.elements().empty()) { + rewriter.replaceOp(op, {buffer}); + return success(); } - rewriter.replaceOp(op, {result}); + + // Case: tensor.
+ if (shape.empty()) { + rewriter.create(loc, adaptor.elements().front(), buffer); + rewriter.replaceOp(op, {buffer}); + return success(); + } + + // Create constants for the range of possible indices [0, max{shape_i}). + auto maxDim = *std::max_element(shape.begin(), shape.end()); + SmallVector constants; + constants.reserve(maxDim); + for (int64_t i = 0; i < maxDim; ++i) + constants.push_back(rewriter.create(loc, i)); + + // Traverse all `elements` and create `memref.store` ops. + ImplicitLocOpBuilder b(loc, rewriter); + auto elementIt = adaptor.elements().begin(); + SmallVector indices(tensorType.getRank(), constants[0]); + createStores(/*dim=*/0, buffer, shape, constants, elementIt, indices, b); + + rewriter.replaceOp(op, {buffer}); return success(); } + +private: + // Recursively visits every index of `buffer` in row-major order, storing + // the next value from `elementIt` at each one and advancing the iterator. + void createStores(int64_t dim, Value buffer, ArrayRef shape, + ArrayRef constants, ValueRange::iterator &elementIt, + SmallVectorImpl &indices, + ImplicitLocOpBuilder &b) const { + if (dim == static_cast<int64_t>(shape.size()) - 1) { + for (int64_t i = 0; i < shape.back(); ++i) { + indices.back() = constants[i]; + b.create(*elementIt, buffer, indices); + ++elementIt; + } + return; + } + for (int64_t i = 0; i < shape[dim]; ++i) { + indices[dim] = constants[i]; + createStores(dim + 1, buffer, shape, constants, elementIt, indices, b); + } + } }; struct BufferizeGenerateOp : public OpConversionPattern { diff --git a/mlir/test/Dialect/Tensor/bufferize.mlir b/mlir/test/Dialect/Tensor/bufferize.mlir --- a/mlir/test/Dialect/Tensor/bufferize.mlir +++ b/mlir/test/Dialect/Tensor/bufferize.mlir @@ -65,21 +65,116 @@ return %0 : f32 } -// CHECK-LABEL: func @tensor.from_elements( +// CHECK-LABEL: func @tensor.from_elements_no_elements() -> tensor<0xindex> { +// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<0xindex> +// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] +// CHECK: return %[[RET]] : tensor<0xindex> +func
@tensor.from_elements_no_elements() -> tensor<0xindex> { + %0 = tensor.from_elements : tensor<0xindex> + return %0 : tensor<0xindex> +} + +// CHECK-LABEL: func @tensor.from_elements_0d( +// CHECK-SAME: %[[ELEM0:.*]]: index) -> tensor { +// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref +// CHECK: store %[[ELEM0]], %[[MEMREF]] +// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] +// CHECK: return %[[RET]] : tensor +func @tensor.from_elements_0d(%arg0: index) -> tensor { + %0 = tensor.from_elements %arg0 : tensor + return %0 : tensor +} + +// CHECK-LABEL: func @tensor.from_elements_1d( // CHECK-SAME: %[[ELEM0:.*]]: index, // CHECK-SAME: %[[ELEM1:.*]]: index) -> tensor<2xindex> { -// CHECK: %[[MEMREF:.*]] = memref.alloc() +// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<2xindex> // CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]]] // CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]]] // CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]]] // CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] // CHECK: return %[[RET]] : tensor<2xindex> -func @tensor.from_elements(%arg0: index, %arg1: index) -> tensor<2xindex> { +func @tensor.from_elements_1d(%arg0: index, %arg1: index) -> tensor<2xindex> { %0 = tensor.from_elements %arg0, %arg1 : tensor<2xindex> return %0 : tensor<2xindex> } +// CHECK-LABEL: func @tensor.from_elements_2d( +// CHECK-SAME: %[[ELEM0:.*]]: index, %[[ELEM1:.*]]: index) +// CHECK-SAME: -> tensor<3x2xindex> { +// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<3x2xindex> +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C2:.*]] = arith.constant 2 : index +// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C0]], %[[C0]]] +// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C0]], %[[C1]]] +// CHECK: store %[[ELEM0]], %[[MEMREF]][%[[C1]], %[[C0]]] +// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C1]], %[[C1]]] +// CHECK:
store %[[ELEM0]], %[[MEMREF]][%[[C2]], %[[C0]]] +// CHECK: store %[[ELEM1]], %[[MEMREF]][%[[C2]], %[[C1]]] +// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] +// CHECK: return %[[RET]] : tensor<3x2xindex> +func @tensor.from_elements_2d(%arg0: index, %arg1: index) -> tensor<3x2xindex> { + %0 = tensor.from_elements %arg0, %arg1, %arg0, %arg1, %arg0, %arg1 + : tensor<3x2xindex> + return %0 : tensor<3x2xindex> +} + +// CHECK-LABEL: func @tensor.from_elements_3d() + +// CHECK-DAG: %[[F0:.*]] = arith.constant 0.0 +// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00 +// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0 +// CHECK-DAG: %[[F3:.*]] = arith.constant 3.0 +// CHECK-DAG: %[[F4:.*]] = arith.constant 4.0 +// CHECK-DAG: %[[F5:.*]] = arith.constant 5.0 +// CHECK-DAG: %[[F6:.*]] = arith.constant 6.0 +// CHECK-DAG: %[[F7:.*]] = arith.constant 7.0 +// CHECK-DAG: %[[F8:.*]] = arith.constant 8.0 +// CHECK-DAG: %[[F9:.*]] = arith.constant 9.0 +// CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01 +// CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01 + +// CHECK: %[[MEMREF:.*]] = memref.alloc() : memref<3x2x2xf32> + +// CHECK: %[[C0:.*]] = arith.constant 0 : index +// CHECK: %[[C1:.*]] = arith.constant 1 : index +// CHECK: %[[C2:.*]] = arith.constant 2 : index + +// CHECK: store %[[F0]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C0]]] +// CHECK: store %[[F1]], %[[MEMREF]][%[[C0]], %[[C0]], %[[C1]]] +// CHECK: store %[[F2]], %[[MEMREF]][%[[C0]], %[[C1]], %[[C0]]] +// CHECK: store %[[F3]], %[[MEMREF]][%[[C0]], %[[C1]], %[[C1]]] +// CHECK: store %[[F4]], %[[MEMREF]][%[[C1]], %[[C0]], %[[C0]]] +// CHECK: store %[[F5]], %[[MEMREF]][%[[C1]], %[[C0]], %[[C1]]] +// CHECK: store %[[F6]], %[[MEMREF]][%[[C1]], %[[C1]], %[[C0]]] +// CHECK: store %[[F7]], %[[MEMREF]][%[[C1]], %[[C1]], %[[C1]]] +// CHECK: store %[[F8]], %[[MEMREF]][%[[C2]], %[[C0]], %[[C0]]] +// CHECK: store %[[F9]], %[[MEMREF]][%[[C2]], %[[C0]], %[[C1]]] +// CHECK: store %[[F10]], %[[MEMREF]][%[[C2]], %[[C1]],
%[[C0]]] +// CHECK: store %[[F11]], %[[MEMREF]][%[[C2]], %[[C1]], %[[C1]]] + +// CHECK: %[[RET:.*]] = bufferization.to_tensor %[[MEMREF]] +// CHECK: return %[[RET]] : tensor<3x2x2xf32> +func @tensor.from_elements_3d() -> tensor<3x2x2xf32> { + %f0 = arith.constant 0.0 : f32 + %f1 = arith.constant 1.0 : f32 + %f2 = arith.constant 2.0 : f32 + %f3 = arith.constant 3.0 : f32 + %f4 = arith.constant 4.0 : f32 + %f5 = arith.constant 5.0 : f32 + %f6 = arith.constant 6.0 : f32 + %f7 = arith.constant 7.0 : f32 + %f8 = arith.constant 8.0 : f32 + %f9 = arith.constant 9.0 : f32 + %f10 = arith.constant 10.0 : f32 + %f11 = arith.constant 11.0 : f32 + %0 = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11 + : tensor<3x2x2xf32> + return %0 : tensor<3x2x2xf32> +} + // CHECK-LABEL: func @tensor.generate( // CHECK-SAME: %[[ARG:.*]]: tensor<*xf32>, // CHECK-SAME: %[[DYNAMIC_EXTENT:.*]]: index) -> tensor { diff --git a/mlir/test/Dialect/Tensor/canonicalize.mlir b/mlir/test/Dialect/Tensor/canonicalize.mlir --- a/mlir/test/Dialect/Tensor/canonicalize.mlir +++ b/mlir/test/Dialect/Tensor/canonicalize.mlir @@ -135,6 +135,93 @@ // ----- +// CHECK-LABEL: func @extract_from_tensor.from_elements_3d +func @extract_from_tensor.from_elements_3d() + -> (f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32, f32) { + %f0 = arith.constant 0.0 : f32 + %f1 = arith.constant 1.0 : f32 + %f2 = arith.constant 2.0 : f32 + %f3 = arith.constant 3.0 : f32 + %f4 = arith.constant 4.0 : f32 + %f5 = arith.constant 5.0 : f32 + %f6 = arith.constant 6.0 : f32 + %f7 = arith.constant 7.0 : f32 + %f8 = arith.constant 8.0 : f32 + %f9 = arith.constant 9.0 : f32 + %f10 = arith.constant 10.0 : f32 + %f11 = arith.constant 11.0 : f32 + + %tensor = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5,%f6,%f7,%f8,%f9,%f10,%f11 + : tensor<3x2x2xf32> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + %r0 = tensor.extract %tensor[%c0, %c0, %c0] : tensor<3x2x2xf32> +
%r1 = tensor.extract %tensor[%c0, %c0, %c1] : tensor<3x2x2xf32> + %r2 = tensor.extract %tensor[%c0, %c1, %c0] : tensor<3x2x2xf32> + %r3 = tensor.extract %tensor[%c0, %c1, %c1] : tensor<3x2x2xf32> + %r4 = tensor.extract %tensor[%c1, %c0, %c0] : tensor<3x2x2xf32> + %r5 = tensor.extract %tensor[%c1, %c0, %c1] : tensor<3x2x2xf32> + %r6 = tensor.extract %tensor[%c1, %c1, %c0] : tensor<3x2x2xf32> + %r7 = tensor.extract %tensor[%c1, %c1, %c1] : tensor<3x2x2xf32> + %r8 = tensor.extract %tensor[%c2, %c0, %c0] : tensor<3x2x2xf32> + %r9 = tensor.extract %tensor[%c2, %c0, %c1] : tensor<3x2x2xf32> + %r10 = tensor.extract %tensor[%c2, %c1, %c0] : tensor<3x2x2xf32> + %r11 = tensor.extract %tensor[%c2, %c1, %c1] : tensor<3x2x2xf32> + return %r0,%r1,%r2,%r3,%r4,%r5,%r6,%r7,%r8,%r9,%r10,%r11 + : f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32,f32 +} +// CHECK-DAG: %[[F0:.*]] = arith.constant 0.0 +// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0{{0+}}e+00 +// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0 +// CHECK-DAG: %[[F3:.*]] = arith.constant 3.0 +// CHECK-DAG: %[[F4:.*]] = arith.constant 4.0 +// CHECK-DAG: %[[F5:.*]] = arith.constant 5.0 +// CHECK-DAG: %[[F6:.*]] = arith.constant 6.0 +// CHECK-DAG: %[[F7:.*]] = arith.constant 7.0 +// CHECK-DAG: %[[F8:.*]] = arith.constant 8.0 +// CHECK-DAG: %[[F9:.*]] = arith.constant 9.0 +// CHECK-DAG: %[[F10:.*]] = arith.constant 1.0{{0+}}e+01 +// CHECK-DAG: %[[F11:.*]] = arith.constant 1.1{{0+}}e+01 + +// CHECK: return %[[F0]], %[[F1]], %[[F2]], %[[F3]], %[[F4]], %[[F5]], +// CHECK-SAME: %[[F6]], %[[F7]], %[[F8]], %[[F9]], %[[F10]], %[[F11]] + +// ----- + +// CHECK-LABEL: func @extract_from_tensor.from_elements_2d +func @extract_from_tensor.from_elements_2d() -> (f32, f32, f32, f32, f32, f32) { + %f0 = arith.constant 0.0 : f32 + %f1 = arith.constant 1.0 : f32 + %f2 = arith.constant 2.0 : f32 + %f3 = arith.constant 3.0 : f32 + %f4 = arith.constant 4.0 : f32 + %f5 = arith.constant 5.0 : f32 + + %tensor = tensor.from_elements %f0,%f1,%f2,%f3,%f4,%f5 : tensor<2x3xf32> + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %c2 = arith.constant 2 : index + + %r0 = tensor.extract %tensor[%c0, %c0] : tensor<2x3xf32> + %r1 = tensor.extract %tensor[%c0, %c1] : tensor<2x3xf32> + %r2 = tensor.extract %tensor[%c0, %c2] : tensor<2x3xf32> + %r3 = tensor.extract %tensor[%c1, %c0] : tensor<2x3xf32> + %r4 = tensor.extract %tensor[%c1, %c1] : tensor<2x3xf32> + %r5 = tensor.extract %tensor[%c1, %c2] : tensor<2x3xf32> + return %r0,%r1,%r2,%r3,%r4,%r5 : f32,f32,f32,f32,f32,f32 +} +// CHECK-DAG: %[[F0:.*]] = arith.constant 0.0 +// CHECK-DAG: %[[F1:.*]] = arith.constant 1.0 +// CHECK-DAG: %[[F2:.*]] = arith.constant 2.0 +// CHECK-DAG: %[[F3:.*]] = arith.constant 3.0 +// CHECK-DAG: %[[F4:.*]] = arith.constant 4.0 +// CHECK-DAG: %[[F5:.*]] = arith.constant 5.0 +// CHECK: return %[[F0]], %[[F1]], %[[F2]], %[[F3]], %[[F4]], %[[F5]] + +// ----- + // Ensure the optimization doesn't segfault from bad constants // CHECK-LABEL: func @extract_negative_from_tensor.from_elements func @extract_negative_from_tensor.from_elements(%element : index) -> index { diff --git a/mlir/test/Dialect/Tensor/invalid.mlir b/mlir/test/Dialect/Tensor/invalid.mlir --- a/mlir/test/Dialect/Tensor/invalid.mlir +++ b/mlir/test/Dialect/Tensor/invalid.mlir @@ -33,7 +33,7 @@
// ----- func @tensor.from_elements_wrong_result_type() { - // expected-error@+2 {{'result' must be 1D tensor of any type values, but got 'tensor<*xi32>'}} + // expected-error@+2 {{'result' must be statically shaped tensor of any type values, but got 'tensor<*xi32>'}} %c0 = arith.constant 0 : i32 %0 = tensor.from_elements %c0 : tensor<*xi32> return diff --git a/mlir/test/Dialect/Tensor/ops.mlir b/mlir/test/Dialect/Tensor/ops.mlir --- a/mlir/test/Dialect/Tensor/ops.mlir +++ b/mlir/test/Dialect/Tensor/ops.mlir @@ -38,21 +38,26 @@ // CHECK-LABEL: func @tensor.from_elements() { func @tensor.from_elements() { %c0 = "arith.constant"() {value = 0: index} : () -> index - // CHECK: %0 = tensor.from_elements %c0 : tensor<1xindex> + // CHECK: tensor.from_elements %c0 : tensor<1xindex> %0 = tensor.from_elements %c0 : tensor<1xindex> %c1 = "arith.constant"() {value = 1: index} : () -> index - // CHECK: %1 = tensor.from_elements %c0, %c1 : tensor<2xindex> + // CHECK: tensor.from_elements %c0, %c1 : tensor<2xindex> %1 = tensor.from_elements %c0, %c1 : tensor<2xindex> %c0_f32 = "arith.constant"() {value = 0.0: f32} : () -> f32 // CHECK: [[C0_F32:%.*]] = arith.constant - // CHECK: %2 = tensor.from_elements [[C0_F32]] : tensor<1xf32> + // CHECK: tensor.from_elements [[C0_F32]] : tensor<1xf32> %2 = tensor.from_elements %c0_f32 : tensor<1xf32> // CHECK: tensor.from_elements : tensor<0xindex> %3 = tensor.from_elements : tensor<0xindex> + // CHECK: tensor.from_elements %c0, %c1, %c0, %c1, %c0, %c1 : tensor<2x3xindex> + %4 = tensor.from_elements %c0, %c1, %c0, %c1, %c0, %c1 : tensor<2x3xindex> + + // CHECK: tensor.from_elements %c0 : tensor + %5 = tensor.from_elements %c0 : tensor return } diff --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp --- a/mlir/test/lib/Dialect/Test/TestDialect.cpp +++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp @@ -945,14 +945,14 @@ Location loc = getLoc(); shapes.reserve(operands.size()); for (Value operand :
llvm::reverse(operands)) { - auto currShape = llvm::to_vector<4>(llvm::map_range( - llvm::seq( - 0, operand.getType().cast().getRank()), - [&](int64_t dim) -> Value { + auto rank = operand.getType().cast().getRank(); + auto currShape = llvm::to_vector<4>( + llvm::map_range(llvm::seq(0, rank), [&](int64_t dim) -> Value { return builder.createOrFold(loc, operand, dim); })); shapes.push_back(builder.create( - getLoc(), builder.getIndexType(), currShape)); + loc, RankedTensorType::get({rank}, builder.getIndexType()), + currShape)); } return success(); }