diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -104,6 +104,21 @@
   }
 };
 
+/// Lowers `shape.size_eq` to `cmpi "eq"` on the converted (`index`) operands.
+class SizeEqOpConverter : public OpConversionPattern<SizeEqOp> {
+public:
+  using OpConversionPattern<SizeEqOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SizeEqOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    SizeEqOp::Adaptor transformed(operands);
+    rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::eq,
+                                        transformed.lhs(), transformed.rhs());
+    return success();
+  }
+};
+
 /// Type conversions.
 class ShapeTypeConverter : public TypeConverter {
 public:
@@ -162,7 +177,8 @@
       BinaryOpConversion,
       ConstSizeOpConverter,
       RankOpConverter,
-      ShapeOfOpConversion>(ctx);
+      ShapeOfOpConversion,
+      SizeEqOpConverter>(ctx);
   // clang-format on
 }
 
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -127,3 +127,15 @@
   %rank = shape.rank %shape
   return %rank : !shape.size
 }
+
+// -----
+
+// Lower `size_eq` to `cmpi`.
+// CHECK-LABEL: @size_eq
+// CHECK-SAME: (%[[A:.*]]: index, %[[B:.*]]: index) -> i1
+func @size_eq(%a : !shape.size, %b : !shape.size) -> i1 {
+  // CHECK: %[[RESULT:.*]] = cmpi "eq", %[[A]], %[[B]]
+  // CHECK: return %[[RESULT]] : i1
+  %result = shape.size_eq %a, %b
+  return %result : i1
+}