diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -438,7 +438,7 @@
 //===----------------------------------------------------------------------===//
 
 //TODO(tpopp): Move the code below and witnesses to a different file.
-def Shape_AnyOp : Shape_Op<"any", [NoSideEffect]> {
+def Shape_AnyOp : Shape_Op<"any", [Commutative, NoSideEffect]> {
   let summary = "Return any combination of the input shapes";
   let description = [{
     This operation takes multiple input shapes and returns some combination of
@@ -458,6 +458,8 @@
   let results = (outs Shape_ShapeType:$result);
 
   let assemblyFormat = "$inputs attr-dict";
+
+  let hasFolder = 1;
 }
 
 def Shape_AssumingAllOp : Shape_Op<"assuming_all", [Commutative, NoSideEffect]> {
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -101,6 +101,22 @@
 // AnyOp
 //===----------------------------------------------------------------------===//
 
+// Folds shape.any to one of its constant inputs, if it has one. Any input is
+// an acceptable result of shape.any, so a constant input can replace the op.
+// TODO: Canonicalization should be implemented for shapes that can be
+// determined through mixtures of the known dimensions of the inputs.
+OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) {
+  // AnyOp is commutative, so constant operands are presumably already sorted
+  // to the back; scan from the end to hit that case first. Scanning every
+  // operand keeps the fold correct if it runs on an unsorted op, and an op
+  // with no inputs (where operands.back() would assert) simply does not fold.
+  for (Attribute operand : llvm::reverse(operands))
+    if (operand)
+      return operand;
+
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // AssumingOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -245,3 +245,24 @@
   "consume.witness"(%2) : (!shape.witness) -> ()
   return
 }
+
+// -----
+// any can be replaced with a constant input if it has one.
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape) -> !shape.shape {
+  // CHECK-NEXT: %[[CS:.*]] = shape.const_shape
+  // CHECK-NEXT: return %[[CS]]
+  %0 = shape.const_shape [2, 3, 4]
+  %1 = shape.any %0, %arg0
+  return %1 : !shape.shape
+}
+
+// -----
+// any cannot be folded when none of its inputs is constant.
+// CHECK-LABEL: func @f
+func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
+  // CHECK-NEXT: %[[ANY:.*]] = shape.any
+  // CHECK-NEXT: return %[[ANY]]
+  %1 = shape.any %arg0, %arg1
+  return %1 : !shape.shape
+}