diff --git a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt --- a/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt +++ b/mlir/lib/Conversion/ShapeToStandard/CMakeLists.txt @@ -1,3 +1,7 @@ +set(LLVM_TARGET_DEFINITIONS ShapeToStandardPatterns.td) +mlir_tablegen(ShapeToStandardPatterns.inc -gen-rewriters) +add_public_tablegen_target(ShapeToStandardPatternsIncGen) + add_mlir_conversion_library(MLIRShapeToStandard ShapeToStandard.cpp @@ -6,6 +10,7 @@ DEPENDS MLIRConversionPassIncGen + ShapeToStandardPatternsIncGen LINK_COMPONENTS Core diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp --- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp @@ -19,6 +19,9 @@ namespace { +/// Generated conversion patterns. +#include "ShapeToStandardPatterns.inc" + /// Conversion patterns. template class BinaryOpConversion : public OpConversionPattern { @@ -35,20 +38,6 @@ } }; -class FromExtentTensorOpConversion - : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(FromExtentTensorOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - FromExtentTensorOp::Adaptor transformed(operands); - rewriter.replaceOp(op.getOperation(), transformed.input()); - return success(); - } -}; - class IndexToSizeOpConversion : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -75,20 +64,6 @@ } }; -class ToExtentTensorOpConversion - : public OpConversionPattern { -public: - using OpConversionPattern::OpConversionPattern; - - LogicalResult - matchAndRewrite(ToExtentTensorOp op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - ToExtentTensorOp::Adaptor transformed(operands); - rewriter.replaceOp(op.getOperation(), transformed.input()); - return success(); - } -}; - class ConstSizeOpConverter : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -122,6 +97,7 @@ /// Conversion pass. class ConvertShapeToStandardPass : public ConvertShapeToStandardBase { + void runOnOperation() override { // Setup type conversion. MLIRContext &ctx = getContext(); @@ -151,15 +127,14 @@ void mlir::populateShapeToStandardConversionPatterns( OwningRewritePatternList &patterns, MLIRContext *ctx) { + populateWithGenerated(ctx, &patterns); // clang-format off patterns.insert< BinaryOpConversion, BinaryOpConversion, ConstSizeOpConverter, - FromExtentTensorOpConversion, IndexToSizeOpConversion, - SizeToIndexOpConversion, - ToExtentTensorOpConversion>(ctx); + SizeToIndexOpConversion>(ctx); // clang-format on } diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td new file mode 100644 --- /dev/null +++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td @@ -0,0 +1,12 @@ +include "mlir/Dialect/Shape/IR/ShapeOps.td" +include "mlir/Dialect/StandardOps/IR/Ops.td" + +// Convert `from_extent_tensor` and `to_extent_tensor` to no-ops as shapes will +// be represented as extent tensors. +def FromExtentTensorOpConversion : Pat< + (Shape_FromExtentTensorOp $input), + (replaceWithValue $input)>; +def ToExtentTensorOpConversion : Pat< + (Shape_ToExtentTensorOp $input), + (replaceWithValue $input)>; +