diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -91,6 +91,20 @@
   }
 };
 
+/// Lowers `shape.const_size` to a standard constant of `index` type.
+class ConstSizeOpConverter : public OpConversionPattern<shape::ConstSizeOp> {
+public:
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(shape::ConstSizeOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<ConstantIndexOp>(op.getOperation(),
+                                                 op.value().getSExtValue());
+    return success();
+  }
+};
+
 /// Type conversions.
 class ShapeTypeConverter : public TypeConverter {
 public:
@@ -111,9 +125,7 @@
 /// Conversion pass.
 class ConvertShapeToStandardPass
     : public ConvertShapeToStandardBase {
-
   void runOnOperation() override {
-
     // Setup type conversion.
     MLIRContext &ctx = getContext();
     ShapeTypeConverter typeConverter(&ctx);
@@ -146,6 +158,7 @@
   patterns.insert<
       BinaryOpConversion,
       BinaryOpConversion,
+      ConstSizeOpConverter,
       FromExtentTensorOpConversion,
       IndexToSizeOpConversion,
       SizeToIndexOpConversion,
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -75,3 +75,14 @@
   // CHECK-NEXT: muli %[[LHS]], %[[RHS]] : index
   return
 }
+
+// -----
+
+// Convert `const_size` to `constant` op.
+// CHECK-LABEL: @size_const
+func @size_const() -> !shape.size {
+  %c1 = shape.const_size 1
+  return %c1 : !shape.size
+}
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: return %[[C1]] : index