diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -370,7 +370,7 @@ //TODO(tpopp): Move the code below and witnesses to a different file. def Shape_AnyOp : Shape_Op<"any", - [NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> { + [Commutative, NoSideEffect, DeclareOpInterfaceMethods<InferTypeOpInterface>]> { let summary = "Return any combination of the input shapes."; let description = [{ This operation takes multiple input shapes and returns some combination of @@ -390,6 +390,8 @@ let results = (outs Shape_ShapeType:$result); let assemblyFormat = "$inputs attr-dict"; + + let hasFolder = 1; } def Shape_AssumingAllOp : Shape_Op<"assuming_all", [Commutative, NoSideEffect]> { diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -101,6 +101,18 @@ // AnyOp //===----------------------------------------------------------------------===// +// TODO: Canonicalization should be implemented for shapes that can be +// determined through mixtures of the known dimensions of the inputs. +OpFoldResult AnyOp::fold(ArrayRef<Attribute> operands) { + // Any constant input is a valid result. Scan from the back because the + // folder moves constants of Commutative ops to the end, but do not rely on + // it: fold() can also see unsorted or zero operands (e.g. via createOrFold). + for (Attribute operand : llvm::reverse(operands)) + if (operand) + return operand; + return nullptr; +} + LogicalResult AnyOp::inferReturnTypes(MLIRContext *context, Optional<Location> location, ValueRange operands, DictionaryAttr attributes, diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -137,3 +137,25 @@ "consume.witness"(%2) : (!shape.witness) -> () return } + +// ----- +// any can be replaced with a constant input if it has one.
+// CHECK-LABEL: func @f +func @f(%arg0 : !shape.shape) -> !shape.shape { + // CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] + // CHECK-NEXT: return %[[CS]] + %0 = shape.const_shape [2, 3, 4] + %1 = shape.any %0, %arg0 + return %1 : !shape.shape +} + + +// ----- +// any is not yet replaced without a constant input. +// CHECK-LABEL: func @f +func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape { + // CHECK-NEXT: %[[ANY:.*]] = shape.any + // CHECK-NEXT: return %[[ANY]] + %1 = shape.any %arg0, %arg1 + return %1 : !shape.shape +}