diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -38,6 +38,22 @@
   }
 };
 
+/// Lowers `shape.get_size` by replacing it with a `dim` op that queries
+/// dimension 0 of the already-converted shape operand.
+class GetSizeOpConverter : public OpConversionPattern<GetSizeOp> {
+public:
+  using OpConversionPattern<GetSizeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(GetSizeOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    GetSizeOp::Adaptor transformed(operands);
+    rewriter.replaceOpWithNewOp<DimOp>(op.getOperation(), transformed.shape(),
+                                       0);
+    return success();
+  }
+};
+
 class IndexToSizeOpConversion : public OpConversionPattern<IndexToSizeOp> {
 public:
   using OpConversionPattern<IndexToSizeOp>::OpConversionPattern;
@@ -133,6 +149,7 @@
       BinaryOpConversion<AddOp, AddIOp>,
       BinaryOpConversion<MulOp, MulIOp>,
       ConstSizeOpConverter,
+      GetSizeOpConverter,
       IndexToSizeOpConversion,
       SizeToIndexOpConversion>(ctx);
   // clang-format on
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
@@ -9,4 +9,3 @@
 def ToExtentTensorOpConversion : Pat<
     (Shape_ToExtentTensorOp $input),
     (replaceWithValue $input)>;
-
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -86,3 +86,17 @@
 }
 // CHECK: %[[C1:.*]] = constant 1 : index
 // CHECK: return %[[C1]] : index
+
+// -----
+
+// Convert `get_size` to `dim` of the first dimension.
+// CHECK-LABEL: @get_size
+// CHECK-SAME: (%[[SHAPE:.*]]: tensor<?xindex>) -> index
+func @get_size(%shape : !shape.shape) -> !shape.size {
+  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
+  // CHECK-DAG: %[[RESULT:.*]] = dim %[[SHAPE]], %[[C0]]
+  // CHECK-DAG: return %[[RESULT]] : index
+  %size = shape.get_size %shape
+  return %size : !shape.size
+}
+