diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -90,6 +90,32 @@
   }
 };
 
+/// Lowers `shape.eq` on `!shape.size` operands to `cmpi "eq"` on the
+/// type-converted (`index`) operands. Any other operand types are left for
+/// other patterns.
+class EqOpConverter : public OpConversionPattern<EqOp> {
+public:
+  using OpConversionPattern<EqOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(EqOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    // Only sizes are handled here. Shape equality is lowered to `scf` and the
+    // pattern can be found in the corresponding pass. Both operands are
+    // checked so that a `cmpi` is never built over a non-size operand.
+    if (!op.lhs().getType().isa<SizeType>() ||
+        !op.rhs().getType().isa<SizeType>())
+      return failure();
+
+    // Lower size equality. `operands` are the already-converted values, so
+    // the adaptor yields the `index`-typed replacements of lhs/rhs.
+    EqOp::Adaptor transformed(operands);
+    rewriter.replaceOpWithNewOp<CmpIOp>(op, CmpIPredicate::eq,
+                                        transformed.lhs(), transformed.rhs());
+    return success();
+  }
+};
+
 /// Type conversions.
 class ShapeTypeConverter : public TypeConverter {
 public:
@@ -147,6 +173,7 @@
       BinaryOpConversion<AddOp, AddIOp>,
       BinaryOpConversion<MulOp, MulIOp>,
       ConstSizeOpConverter,
+      EqOpConverter,
       ShapeOfOpConversion>(ctx);
   // clang-format on
 }
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -115,3 +115,16 @@
   %shape = shape.shape_of %arg : tensor<1x5x?xf32>
   return
 }
+
+// -----
+
+// Lower `eq` for sizes.
+// CHECK-LABEL: @eq_size
+// CHECK-SAME: (%[[A:.*]]: index, %[[B:.*]]: index) -> i1
+func @eq_size(%a : !shape.size, %b : !shape.size) -> i1 {
+  // CHECK: %[[RESULT:.*]] = cmpi "eq", %[[A]], %[[B]]
+  // CHECK: return %[[RESULT]] : i1
+  %result = shape.eq %a, %b : !shape.size
+  return %result : i1
+}
+