diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandard.cpp
@@ -23,6 +23,22 @@
 #include "ShapeToStandardPatterns.inc"
 
 /// Conversion patterns.
+class AnyOpConversion : public OpConversionPattern<AnyOp> {
+public:
+  using OpConversionPattern<AnyOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(AnyOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    AnyOp::Adaptor transformed(operands);
+
+    // Replace `any` with its first operand.
+    // Any operand would be a valid substitution.
+    rewriter.replaceOp(op, {transformed.inputs().front()});
+    return success();
+  }
+};
+
 template <typename SrcOpTy, typename DstOpTy>
 class BinaryOpConversion : public OpConversionPattern<SrcOpTy> {
 public:
@@ -181,6 +197,7 @@
   populateWithGenerated(ctx, &patterns);
   // clang-format off
   patterns.insert<
+      AnyOpConversion,
       BinaryOpConversion<AddOp, AddIOp>,
       BinaryOpConversion<MulOp, MulIOp>,
       ConstSizeOpConverter,
diff --git a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
--- a/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
+++ b/mlir/lib/Conversion/ShapeToStandard/ShapeToStandardPatterns.td
@@ -18,4 +18,3 @@
 def SizeToIndexOpConversion : Pat<
     (Shape_SizeToIndexOp $arg),
     (replaceWithValue $arg)>;
-
diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
--- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
+++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir
@@ -158,3 +158,26 @@
   return %result : !shape.size
 }
 
+// -----
+
+// Lower `any` to its first operand.
+// CHECK-LABEL: @any_of_three
+// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>, %[[B:.*]]: tensor<?xindex>, %[[C:.*]]: tensor<?xindex>) -> tensor<?xindex>
+func @any_of_three(%a : !shape.shape, %b : !shape.shape, %c : !shape.shape)
+    -> !shape.shape {
+  // CHECK: return %[[A]] : tensor<?xindex>
+  %result = shape.any %a, %b, %c
+  return %result : !shape.shape
+}
+
+// -----
+
+// Lower `any` to its first operand.
+// CHECK-LABEL: @any_of_one
+// CHECK-SAME: (%[[A:.*]]: tensor<?xindex>) -> tensor<?xindex>
+func @any_of_one(%a : !shape.shape) -> !shape.shape {
+  // CHECK: return %[[A]] : tensor<?xindex>
+  %result = shape.any %a
+  return %result : !shape.shape
+}
+