diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -217,6 +217,34 @@
   return success();
 }
 
+namespace {
+/// Lowers `shape.size_eq` on `index` operands to an equality `cmpi`.
+class SizeEqOpConverter : public OpConversionPattern<SizeEqOp> {
+public:
+  using OpConversionPattern<SizeEqOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SizeEqOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+} // namespace
+
+LogicalResult
+SizeEqOpConverter::matchAndRewrite(SizeEqOp op, ArrayRef<Value> operands,
+                                   ConversionPatternRewriter &rewriter) const {
+  // For now, only error-free `index` operands are supported by this lowering;
+  // any other operand type makes the pattern fail to match.
+  if (!op.lhs().getType().isa<IndexType>() ||
+      !op.rhs().getType().isa<IndexType>())
+    return failure();
+
+  // Build the comparison from the operands handed to us by the rewriter.
+  SizeEqOp::Adaptor transformed(operands);
+  rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::eq, transformed.lhs(),
+                                      transformed.rhs());
+  return success();
+}
+
 namespace {
 class RankOpConverter : public OpConversionPattern<RankOp> {
 public:
@@ -278,6 +306,7 @@
       GetExtentOpConverter,
       RankOpConverter,
       ShapeOfOpConversion,
+      SizeEqOpConverter,
       ToExtentTensorOpConversion>(ctx);
   // clang-format on
 }
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -190,3 +190,16 @@
   // CHECK: return %[[RES]]
   return %casted : tensor<3xindex>
 }
+
+// -----
+
+// Lower `size_eq` to `cmpi`.
+// CHECK-LABEL: @size_eq
+// CHECK-SAME: (%[[A:.*]]: index, %[[B:.*]]: index) -> i1
+func @size_eq(%a : index, %b : index) -> i1 {
+  // CHECK: %[[RESULT:.*]] = cmpi "eq", %[[A]], %[[B]]
+  // CHECK: return %[[RESULT]] : i1
+  %result = shape.size_eq %a, %b : index, index
+  return %result : i1
+}
+