diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -38,32 +38,6 @@ } }; -class IndexToSizeOpConversion : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(IndexToSizeOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - IndexToSizeOp::Adaptor transformed(operands); - rewriter.replaceOp(op.getOperation(), transformed.arg()); - return success(); - } -}; - -class SizeToIndexOpConversion : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(SizeToIndexOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - SizeToIndexOp::Adaptor transformed(operands); - rewriter.replaceOp(op.getOperation(), transformed.arg()); - return success(); - } -}; - class ConstSizeOpConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -132,9 +106,7 @@ patterns.insert< BinaryOpConversion, BinaryOpConversion, - ConstSizeOpConverter, - IndexToSizeOpConversion, - SizeToIndexOpConversion>(ctx); + ConstSizeOpConverter>(ctx); // clang-format on } diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td @@ -10,3 +10,12 @@ (Shape_ToExtentTensorOp $input), (replaceWithValue $input)>; +// Lower `index_to_size` and `size_to_index` to no-ops: sizes are represented +// as indices after this lowering, so each op is replaced by its own operand. 
+def IndexToSizeOpConversion : + Pat<(Shape_IndexToSizeOp $operand), + (replaceWithValue $operand)>; +def SizeToIndexOpConversion : + Pat<(Shape_SizeToIndexOp $operand), + (replaceWithValue $operand)>; +