NOTE(review): this patch was recovered from a copy in which all line breaks
had been collapsed and every `<...>` span resembling a markup tag (template
arguments, `memref<?xi64>`) had been stripped. The line structure and the
template arguments below were reconstructed from the surrounding code. The
`addLegalDialect<...>` argument list (an unchanged context line) and the exact
whitespace of context lines could not be recovered from the damaged copy, so
confirm them against the upstream tree before applying.
Review fixes relative to the damaged copy:
  * FileCheck capture `%[[DIM:.]]` matched exactly one character; now `.*`.
  * The scf.for body-builder lambda no longer shadows the captured `loc`.
Text above the first `diff --git` header is ignored by `git apply` / `patch`.

diff --git a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
--- a/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
+++ b/mlir/include/mlir/Dialect/StandardOps/IR/Ops.td
@@ -1408,7 +1408,9 @@
 
   let builders = [
     OpBuilder<"OpBuilder &builder, OperationState &result, "
-              "Value memrefOrTensor, int64_t index">
+              "Value memrefOrTensor, int64_t index">,
+    OpBuilder<"OpBuilder &builder, OperationState &result, "
+              "Value memrefOrTensor, Value index">
   ];
 
   let extraClassDeclaration = [{
diff --git a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
--- a/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
+++ b/mlir/lib/Conversion/ShapeToSCF/ShapeToSCF.cpp
@@ -69,6 +69,54 @@
   return success();
 }
 
+namespace {
+/// Converts `shape_of` to for loop for unranked tensors.
+class ShapeOfOpConverter : public OpConversionPattern<ShapeOfOp> {
+public:
+  using OpConversionPattern<ShapeOfOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult
+ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
+                                    ConversionPatternRewriter &rewriter) const {
+  ShapeOfOp::Adaptor transformed(operands);
+  auto tensorVal = transformed.arg();
+  auto tensorTy = tensorVal.getType();
+
+  // For ranked tensors `shape_of` lowers to `std` and the pattern can be
+  // found in the corresponding pass.
+  if (tensorTy.isa<RankedTensorType>())
+    return failure();
+
+  // Allocate stack memory.
+  auto loc = op.getLoc();
+  auto rankVal = rewriter.create<RankOp>(loc, tensorVal);
+  auto i64Ty = rewriter.getI64Type();
+  auto memTy = MemRefType::get({ShapedType::kDynamicSize}, i64Ty);
+  auto memVal = rewriter.create<AllocaOp>(loc, memTy, ValueRange({rankVal}));
+
+  // Copy shape extents to stack-allocated memory.
+  auto zeroVal = rewriter.create<ConstantIndexOp>(loc, 0);
+  auto oneVal = rewriter.create<ConstantIndexOp>(loc, 1);
+  rewriter.create<scf::ForOp>(
+      loc, zeroVal, rankVal, oneVal, ValueRange(),
+      [&](OpBuilder &b, Location nestedLoc, Value iVal, ValueRange args) {
+        auto dimVal = b.create<DimOp>(nestedLoc, tensorVal, iVal);
+        auto dimIntVal = b.create<IndexCastOp>(nestedLoc, dimVal, i64Ty);
+        b.create<StoreOp>(nestedLoc, dimIntVal, memVal, ValueRange({iVal}));
+        b.create<scf::YieldOp>(nestedLoc);
+      });
+
+  // Load extents to tensor value.
+  rewriter.replaceOpWithNewOp<TensorLoadOp>(op.getOperation(), memVal);
+  return success();
+}
+
 namespace {
 struct ConvertShapeToSCFPass
     : public ConvertShapeToSCFBase<ConvertShapeToSCFPass> {
@@ -79,19 +127,23 @@
 void ConvertShapeToSCFPass::runOnFunction() {
   MLIRContext &ctx = getContext();
 
+  // Populate conversion patterns.
   OwningRewritePatternList patterns;
   populateShapeToSCFConversionPatterns(patterns, &ctx);
 
+  // Setup target legality.
   ConversionTarget target(getContext());
   target.addLegalDialect<ShapeDialect, scf::SCFDialect, StandardOpsDialect>();
-  target.addIllegalOp<ReduceOp>();
-  if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
+  target.addIllegalOp<ReduceOp, ShapeOfOp>();
+
+  // Apply conversion.
+  if (failed(applyPartialConversion(getFunction(), target, patterns)))
     signalPassFailure();
 }
 
 void mlir::populateShapeToSCFConversionPatterns(
     OwningRewritePatternList &patterns, MLIRContext *ctx) {
-  patterns.insert<ReduceOpConverter>(ctx);
+  patterns.insert<ReduceOpConverter, ShapeOfOpConverter>(ctx);
 }
 
 std::unique_ptr<FunctionPass> mlir::createConvertShapeToSCFPass() {
diff --git a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
--- a/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
+++ b/mlir/lib/Dialect/StandardOps/IR/Ops.cpp
@@ -1273,8 +1273,13 @@
                   Value memrefOrTensor, int64_t index) {
   auto loc = result.location;
   Value indexValue = builder.create<ConstantIndexOp>(loc, index);
+  build(builder, result, memrefOrTensor, indexValue);
+}
+
+void DimOp::build(OpBuilder &builder, OperationState &result,
+                  Value memrefOrTensor, Value index) {
   auto indexTy = builder.getIndexType();
-  build(builder, result, indexTy, memrefOrTensor, indexValue);
+  build(builder, result, indexTy, memrefOrTensor, index);
 }
 
 Optional<int64_t> DimOp::getConstantIndex() {
diff --git a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
--- a/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
+++ b/mlir/test/Conversion/ShapeToSCF/shape-to-scf.mlir
@@ -26,3 +26,24 @@
 // CHECK-NEXT:   scf.yield [[NEW_ACC]] : !shape.size
 // CHECK-NEXT: }
 // CHECK-NEXT: return [[RESULT]] : !shape.size
+
+// -----
+
+// Lower `shape_of` for unranked tensors.
+// CHECK-LABEL: @shape_of_unranked
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
+func @shape_of_unranked(%arg : tensor<*xf32>) {
+  // CHECK-DAG: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
+  // CHECK-DAG: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xi64>
+  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
+  // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
+  // CHECK-DAG:   %[[DIM:.*]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
+  // CHECK-DAG:   %[[DIM_INT:.*]] = index_cast %[[DIM]] : index to i64
+  // CHECK-DAG:   store %[[DIM_INT]], %[[SHAPE_MEM]][%[[I]]] : memref<?xi64>
+  // CHECK: }
+  // CHECK-DAG: %[[SHAPE:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xi64>
+  %shape = shape.shape_of %arg : tensor<*xf32>
+  return
+}
+