[MLIR][Shape] Add custom assembly format for `shape.any`

Give `shape.any` a declarative assembly format,
  $inputs attr-dict `:` type($inputs) `->` type($result)
so it prints and parses as
  %r = shape.any %a, %b : !shape.shape, !shape.shape -> !shape.shape
instead of the generic quoted-op form.

The existing `shape.any` tests in canonicalize.mlir (including the
CHECK line of the not-yet-folded case) are switched to the custom
syntax. New tests in ops.mlir use the custom syntax with
all-`!shape.shape` operands, mixed extent-tensor/shape operands, and
all-extent-tensor operands.

diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -581,6 +581,8 @@
   let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$inputs);
   let results = (outs Shape_ShapeOrExtentTensorType:$result);
 
+  let assemblyFormat = "$inputs attr-dict `:` type($inputs) `->` type($result)";
+
   let hasFolder = 1;
 }
 
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -428,7 +428,7 @@
   // CHECK-NEXT: %[[CS:.*]] = shape.const_shape
   // CHECK-NEXT: return %[[CS]]
   %0 = shape.const_shape [2, 3, 4] : !shape.shape
-  %1 = "shape.any"(%0, %arg) : (!shape.shape, !shape.shape) -> !shape.shape
+  %1 = shape.any %0, %arg : !shape.shape, !shape.shape -> !shape.shape
   return %1 : !shape.shape
 }
 
@@ -440,7 +440,7 @@
   // CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor<?xindex>
   // CHECK-NEXT: return %[[CS]] : tensor<?xindex>
   %0 = shape.const_shape [2, 3, 4] : tensor<?xindex>
-  %1 = "shape.any"(%0, %arg) : (tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
+  %1 = shape.any %0, %arg : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
   return %1 : tensor<?xindex>
 }
 
@@ -449,9 +449,9 @@
 // Folding of any with partially constant operands is not yet implemented.
 // CHECK-LABEL: func @f
 func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
-  // CHECK-NEXT: %[[CS:.*]] = "shape.any"
+  // CHECK-NEXT: %[[CS:.*]] = shape.any
   // CHECK-NEXT: return %[[CS]]
-  %1 = "shape.any"(%arg0, %arg1) : (!shape.shape, !shape.shape) -> !shape.shape
+  %1 = shape.any %arg0, %arg1 : !shape.shape, !shape.shape -> !shape.shape
   return %1 : !shape.shape
 }
 
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -235,3 +235,26 @@
   %2 = call @shape_equal_shapes(%a, %1) : (!shape.value_shape, !shape.value_shape) -> !shape.shape
   return %2 : !shape.shape
 }
+
+func @any_on_shape(%a : !shape.shape, %b : !shape.shape, %c : !shape.shape)
+    -> !shape.shape {
+  %result = shape.any %a, %b, %c
+      : !shape.shape, !shape.shape, !shape.shape -> !shape.shape
+  return %result : !shape.shape
+}
+
+func @any_on_mixed(%a : tensor<?xindex>,
+                   %b : tensor<?xindex>,
+                   %c : !shape.shape) -> !shape.shape {
+  %result = shape.any %a, %b, %c
+      : tensor<?xindex>, tensor<?xindex>, !shape.shape -> !shape.shape
+  return %result : !shape.shape
+}
+
+func @any_on_extent_tensors(%a : tensor<?xindex>,
+                            %b : tensor<?xindex>,
+                            %c : tensor<?xindex>) -> tensor<?xindex> {
+  %result = shape.any %a, %b, %c
+      : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+  return %result : tensor<?xindex>
+}