diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -37,32 +37,6 @@ } }; -class IndexToSizeOpConversion : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(IndexToSizeOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - IndexToSizeOp::Adaptor transformed(operands); - rewriter.replaceOp(op.getOperation(), transformed.arg()); - return success(); - } -}; - -class SizeToIndexOpConversion : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(SizeToIndexOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - SizeToIndexOp::Adaptor transformed(operands); - rewriter.replaceOp(op.getOperation(), transformed.arg()); - return success(); - } -}; - class ConstSizeOpConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -131,9 +105,7 @@ patterns.insert< BinaryOpConversion, BinaryOpConversion, - ConstSizeOpConverter, - IndexToSizeOpConversion, - SizeToIndexOpConversion>(ctx); + ConstSizeOpConverter>(ctx); // clang-format on } diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td @@ -5,6 +5,14 @@ (Shape_FromExtentTensorOp $input), (replaceWithValue $input)>; +def IndexToSizeOpConversion : Pat< + (Shape_IndexToSizeOp $arg), + (replaceWithValue $arg)>; + +def SizeToIndexOpConversion : Pat< + (Shape_SizeToIndexOp $arg), + (replaceWithValue $arg)>; + def ToExtentTensorOpConversion : Pat< (Shape_ToExtentTensorOp $input), (replaceWithValue $input)>;