diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1504,6 +1504,13 @@
   let arguments = (ins Variadic<Index>:$dynamicExtents);
   let results = (outs AnyRankedTensor:$result);
   let regions = (region SizedRegion<1>:$body);
+
+  let builders = [
+    // Build op and populate its body per callback function.
+    OpBuilder<"OpBuilder &b, OperationState &result, Type resultTy, "
+      "ValueRange dynamicExtents, "
+      "function_ref<void(OpBuilder &, Location, ValueRange)>">,
+  ];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -422,6 +422,7 @@
     return failure();
 
   // For ranked tensor arguments, lower to `tensor_from_elements`.
+  auto loc = op.getLoc();
   ShapeOfOp::Adaptor transformed(operands);
   Value tensor = transformed.arg();
   Type tensorTy = tensor.getType();
@@ -431,7 +432,6 @@
     SmallVector<Value, 8> extentValues;
     RankedTensorType rankedTensorTy = tensorTy.cast<RankedTensorType>();
     int64_t rank = rankedTensorTy.getRank();
-    auto loc = op.getLoc();
     for (int64_t i = 0; i < rank; i++) {
       if (rankedTensorTy.isDynamicDim(i)) {
         Value extent = rewriter.create<DimOp>(loc, tensor, i);
@@ -451,26 +451,17 @@
     return success();
   }
 
-  // Allocate stack memory.
-  auto loc = op.getLoc();
+  // Lower to `dynamic_tensor_from_elements` otherwise.
+  auto *ctx = rewriter.getContext();
   Value rank = rewriter.create<mlir::RankOp>(loc, tensor);
-  Type indexTy = rewriter.getIndexType();
-  Type memTy = MemRefType::get({ShapedType::kDynamicSize}, indexTy);
-  Value mem = rewriter.create<AllocaOp>(loc, memTy, ValueRange{rank});
-
-  // Copy shape extents to stack-allocated memory.
-  Value zero = rewriter.create<ConstantIndexOp>(loc, 0);
-  Value one = rewriter.create<ConstantIndexOp>(loc, 1);
-  rewriter.create<scf::ForOp>(
-      loc, zero, rank, one, llvm::None,
-      [&](OpBuilder &b, Location loc, Value iv, ValueRange args) {
-        Value dim = rewriter.create<DimOp>(loc, tensor, iv);
-        rewriter.create<StoreOp>(loc, dim, mem, ValueRange{iv});
-        rewriter.create<scf::YieldOp>(loc);
+  rewriter.replaceOpWithNewOp<DynamicTensorFromElementsOp>(
+      op, getExtentTensorType(ctx), ValueRange{rank},
+      [&](OpBuilder &b, Location loc, ValueRange args) {
+        Value dim = args.front();
+        Value extent = b.create<DimOp>(loc, tensor, dim);
+        b.create<mlir::YieldOp>(loc, extent);
       });
 
-  // Load extents to tensor value.
-  rewriter.replaceOpWithNewOp<TensorLoadOp>(op.getOperation(), mem);
   return success();
 }
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1694,6 +1694,22 @@
   return success();
 }
 
+void DynamicTensorFromElementsOp::build(
+    OpBuilder &b, OperationState &result, Type resultTy,
+    ValueRange dynamicExtents,
+    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
+  build(b, result, resultTy, dynamicExtents);
+
+  // Build and populate body.
+  OpBuilder::InsertionGuard guard(b);
+  Region *bodyRegion = result.regions.front().get();
+  auto rank = resultTy.cast<RankedTensorType>().getRank();
+  SmallVector<Type, 2> argumentTypes(rank, b.getIndexType());
+  Block *bodyBlock =
+      b.createBlock(bodyRegion, bodyRegion->end(), argumentTypes);
+  bodyBuilder(b, result.location, bodyBlock->getArguments());
+}
+
 //===----------------------------------------------------------------------===//
 // ExtractElementOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -191,14 +191,11 @@
 // CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
 func @shape_of_unranked(%arg : tensor<*xf32>) {
   // CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
-  // CHECK: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xindex>
-  // CHECK: %[[C0:.*]] = constant 0 : index
-  // CHECK: %[[C1:.*]] = constant 1 : index
-  // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
-  // CHECK:   %[[DIM:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
-  // CHECK:   store %[[DIM]], %[[SHAPE_MEM]][%[[I]]] : memref<?xindex>
-  // CHECK: }
-  // CHECK: %[[SHAPE:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xindex>
+  // CHECK: %[[SHAPE:.*]] = dynamic_tensor_from_elements %[[RANK]] {
+  // CHECK: ^bb0(%[[I:.*]]: index):
+  // CHECK:   %[[EXTENT:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
+  // CHECK:   yield %[[EXTENT]] : index
+  // CHECK: } : tensor<?xindex>
   %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
   return
 }
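
Usage note (illustrative, not part of the patch): the new builder overload mirrors the body-builder idiom of `scf::ForOp`, letting callers create a `dynamic_tensor_from_elements` op and populate its region in one call instead of constructing the entry block and terminator by hand. Below is a minimal sketch, assuming `b` (an `OpBuilder`), `loc`, an unranked tensor value `tensor`, and its materialized `rank` are already in scope; `extentTensorTy` and the lambda argument names are locals introduced here for the example:

```cpp
// Result type: a 1D tensor of index values holding one extent per dimension.
Type indexTy = b.getIndexType();
Type extentTensorTy =
    RankedTensorType::get({ShapedType::kDynamicSize}, indexTy);

// The callback receives one index block argument per result dimension and is
// responsible for terminating the body with a yield of the element value.
Value extents = b.create<DynamicTensorFromElementsOp>(
    loc, extentTensorTy, ValueRange{rank},
    [&](OpBuilder &nested, Location nestedLoc, ValueRange ivs) {
      Value extent = nested.create<DimOp>(nestedLoc, tensor, ivs.front());
      nested.create<mlir::YieldOp>(nestedLoc, extent);
    });
```

Since the builder itself creates the entry block with the right number of index arguments (see `DynamicTensorFromElementsOp::build` above), callbacks only emit the body and its `yield` terminator, which is exactly how the reworked `shape.shape_of` lowering uses it.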