diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -49,6 +49,33 @@
   }
 };
 
+/// Lowers `shape.get_extent` to `std.dim` on the tensor whose shape is being
+/// queried. This only applies when the shape operand is produced by
+/// `shape.shape_of`; in that case the shape never has to be materialized.
+class GetExtentOpConversion : public OpConversionPattern<shape::GetExtentOp> {
+public:
+  using OpConversionPattern<shape::GetExtentOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(shape::GetExtentOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    shape::GetExtentOpOperandAdaptor transformed(operands);
+
+    // The shape may have any origin (a block argument has no defining op at
+    // all). Only the `shape_of` case is supported, so report a match failure
+    // instead of asserting inside `cast`.
+    auto shapeOfOp = dyn_cast_or_null<shape::ShapeOfOp>(
+        transformed.shape().getDefiningOp());
+    if (!shapeOfOp)
+      return failure();
+
+    rewriter.replaceOpWithNewOp<DimOp>(op.getOperation(),
+                                       rewriter.getIndexType(), shapeOfOp.arg(),
+                                       transformed.dim());
+    return success();
+  }
+};
+
 class IndexToSizeOpConversion
     : public OpConversionPattern<shape::IndexToSizeOp> {
 public:
@@ -63,6 +90,22 @@
   }
 };
 
+/// Erases `shape.shape_of` operations.
+class ShapeOfOpConversion : public OpConversionPattern<shape::ShapeOfOp> {
+public:
+  using OpConversionPattern<shape::ShapeOfOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(shape::ShapeOfOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // The lowerings of all operations that depend on a `!shape.shape` must
+    // not rely on the materialized shape (see `GetExtentOpConversion`), so
+    // the `shape_of` result itself is never needed.
+    rewriter.eraseOp(op.getOperation());
+    return success();
+  }
+};
+
 class SizeToIndexOpConversion
     : public OpConversionPattern<shape::SizeToIndexOp> {
 public:
@@ -147,7 +190,9 @@
       BinaryOpConversion<shape::AddOp, AddIOp>,
       BinaryOpConversion<shape::MulOp, MulIOp>,
       FromExtentTensorOpConversion,
+      GetExtentOpConversion,
       IndexToSizeOpConversion,
+      ShapeOfOpConversion,
       SizeToIndexOpConversion,
       ToExtentTensorOpConversion>(ctx);
   // clang-format on
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -75,3 +75,27 @@
   // CHECK-NEXT: muli %[[LHS]], %[[RHS]] : index
   return
 }
+
+// -----
+
+// Lower `get_extent` so that it does not rely on the result of `shape_of`.
+// CHECK-LABEL: @get_extent
+// CHECK-SAME: (%[[ARG:.*]]: tensor<2x?xf32>, %[[IDX:.*]]: index) -> index
+func @get_extent(%arg : tensor<2x?xf32>, %idx : !shape.size) -> !shape.size {
+  // CHECK-NEXT: %[[RESULT:.*]] = dim %[[ARG]], %[[IDX]] : tensor<2x?xf32>
+  // CHECK-NEXT: return %[[RESULT]]
+  %shape = shape.shape_of %arg : tensor<2x?xf32>
+  %result = shape.get_extent %shape, %idx
+  return %result : !shape.size
+}
+
+// -----
+
+// Erase `shape_of` operation.
+// CHECK-LABEL: @shape_of
+func @shape_of(%arg : tensor<2x?xf32>) {
+  // CHECK-NOT: shape_of
+  %shape = shape.shape_of %arg : tensor<2x?xf32>
+  return
+}
+