diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -36,61 +36,73 @@
 };
 
 class FromExtentTensorOpConversion
-    : public OpConversionPattern<shape::FromExtentTensorOp> {
+    : public OpConversionPattern<FromExtentTensorOp> {
 public:
-  using OpConversionPattern<shape::FromExtentTensorOp>::OpConversionPattern;
+  using OpConversionPattern<FromExtentTensorOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(shape::FromExtentTensorOp op, ArrayRef<Value> operands,
+  matchAndRewrite(FromExtentTensorOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    shape::FromExtentTensorOpOperandAdaptor transformed(operands);
+    FromExtentTensorOpOperandAdaptor transformed(operands);
     rewriter.replaceOp(op.getOperation(), transformed.input());
     return success();
   }
 };
 
-class IndexToSizeOpConversion
-    : public OpConversionPattern<shape::IndexToSizeOp> {
+class IndexToSizeOpConversion : public OpConversionPattern<IndexToSizeOp> {
 public:
-  using OpConversionPattern<shape::IndexToSizeOp>::OpConversionPattern;
+  using OpConversionPattern<IndexToSizeOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(shape::IndexToSizeOp op, ArrayRef<Value> operands,
+  matchAndRewrite(IndexToSizeOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    shape::IndexToSizeOpOperandAdaptor transformed(operands);
+    IndexToSizeOpOperandAdaptor transformed(operands);
     rewriter.replaceOp(op.getOperation(), transformed.arg());
     return success();
   }
 };
 
-class SizeToIndexOpConversion
-    : public OpConversionPattern<shape::SizeToIndexOp> {
+class SizeToIndexOpConversion : public OpConversionPattern<SizeToIndexOp> {
 public:
-  using OpConversionPattern<shape::SizeToIndexOp>::OpConversionPattern;
+  using OpConversionPattern<SizeToIndexOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(shape::SizeToIndexOp op, ArrayRef<Value> operands,
+  matchAndRewrite(SizeToIndexOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    shape::SizeToIndexOpOperandAdaptor transformed(operands);
+    SizeToIndexOpOperandAdaptor transformed(operands);
     rewriter.replaceOp(op.getOperation(), transformed.arg());
     return success();
   }
 };
 
 class ToExtentTensorOpConversion
-    : public OpConversionPattern<shape::ToExtentTensorOp> {
+    : public OpConversionPattern<ToExtentTensorOp> {
 public:
-  using OpConversionPattern<shape::ToExtentTensorOp>::OpConversionPattern;
+  using OpConversionPattern<ToExtentTensorOp>::OpConversionPattern;
 
   LogicalResult
-  matchAndRewrite(shape::ToExtentTensorOp op, ArrayRef<Value> operands,
+  matchAndRewrite(ToExtentTensorOp op, ArrayRef<Value> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    shape::ToExtentTensorOpOperandAdaptor transformed(operands);
+    ToExtentTensorOpOperandAdaptor transformed(operands);
     rewriter.replaceOp(op.getOperation(), transformed.input());
     return success();
   }
 };
 
+/// Replaces `shape.const_size` with a standard `constant` of `index` type.
+class ConstSizeOpConversion : public OpConversionPattern<ConstSizeOp> {
+public:
+  using OpConversionPattern<ConstSizeOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(ConstSizeOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(),
+                                                 op.value().getSExtValue());
+    return success();
+  }
+};
+
 /// Type conversions.
 class ShapeTypeConverter : public TypeConverter {
 public:
@@ -100,8 +112,8 @@
 
     // Add default pass-through conversion.
     addConversion([&](Type type) { return type; });
-    addConversion([ctx](shape::SizeType type) { return IndexType::get(ctx); });
-    addConversion([ctx](shape::ShapeType type) {
+    addConversion([ctx](SizeType type) { return IndexType::get(ctx); });
+    addConversion([ctx](ShapeType type) {
       return RankedTensorType::get({ShapedType::kDynamicSize},
                                    IndexType::get(ctx));
     });
@@ -111,9 +123,7 @@
 /// Conversion pass.
 class ConvertShapeToStandardPass
     : public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
-
   void runOnOperation() override {
-    // Setup type conversion.
     MLIRContext &ctx = getContext();
     ShapeTypeConverter typeConverter(&ctx);
 
@@ -146,6 +156,7 @@
     patterns.insert<
         BinaryOpConversion<AddOp, AddIOp>,
         BinaryOpConversion<MulOp, MulIOp>,
+        ConstSizeOpConversion,
         FromExtentTensorOpConversion,
         IndexToSizeOpConversion,
         SizeToIndexOpConversion,
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -75,3 +75,14 @@
   // CHECK-NEXT: muli %[[LHS]], %[[RHS]] : index
   return
 }
+
+// -----
+
+// Convert `const_size` to `constant` op.
+// CHECK-LABEL: @size_const
+func @size_const() -> !shape.size {
+  %c1 = shape.const_size 1
+  return %c1 : !shape.size
+}
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: return %[[C1]] : index