diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
--- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
+++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
@@ -581,6 +581,9 @@
   let arguments = (ins Variadic<Shape_ShapeOrExtentTensorType>:$inputs);
   let results = (outs Shape_ShapeOrExtentTensorType:$result);
 
+  let printer = [{ return ::print(p, *this); }];
+  let parser = [{ return ::parse$cppClass(parser, result); }];
+
   let hasFolder = 1;
 }
 
diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -152,6 +152,32 @@
   return nullptr;
 }
 
+// Parses `shape.any` written as
+//   shape.any %a, %b [{attr-dict}] : type-of-a, type-of-b -> result-type
+static ParseResult parseAnyOp(OpAsmParser &parser, OperationState &result) {
+  // Remember where the operand list starts so that errors reported by
+  // resolveOperands point at it.
+  auto loc = parser.getCurrentLocation();
+  SmallVector<OpAsmParser::OperandType, 4> operands;
+  SmallVector<Type, 4> operandTypes;
+  Type resultTy;
+  if (parser.parseOperandList(operands) ||
+      parser.parseOptionalAttrDict(result.attributes) || parser.parseColon() ||
+      parser.parseTypeList(operandTypes) ||
+      parser.resolveOperands(operands, operandTypes, loc, result.operands) ||
+      parser.parseArrow() || parser.parseType(resultTy))
+    return failure();
+  result.addTypes(resultTy);
+  return success();
+}
+
+// Prints `shape.any` in the form accepted by parseAnyOp.
+static void print(OpAsmPrinter &p, AnyOp op) {
+  p << AnyOp::getOperationName() << " " << op.getOperands();
+  p.printOptionalAttrDict(op.getAttrs());
+  p << " : " << op.getOperandTypes() << " -> " << op.getType();
+}
+
 //===----------------------------------------------------------------------===//
 // AssumingOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -428,7 +428,7 @@
   // CHECK-NEXT: %[[CS:.*]] = shape.const_shape
   // CHECK-NEXT: return %[[CS]]
   %0 = shape.const_shape [2, 3, 4] : !shape.shape
-  %1 = "shape.any"(%0, %arg) : (!shape.shape, !shape.shape) -> !shape.shape
+  %1 = shape.any %0, %arg : !shape.shape, !shape.shape -> !shape.shape
   return %1 : !shape.shape
 }
 
@@ -440,7 +440,7 @@
   // CHECK-NEXT: %[[CS:.*]] = shape.const_shape [2, 3, 4] : tensor<?xindex>
   // CHECK-NEXT: return %[[CS]] : tensor<?xindex>
   %0 = shape.const_shape [2, 3, 4] : tensor<?xindex>
-  %1 = "shape.any"(%0, %arg) : (tensor<?xindex>, tensor<?xindex>) -> tensor<?xindex>
+  %1 = shape.any %0, %arg : tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
   return %1 : tensor<?xindex>
 }
 
@@ -449,9 +449,9 @@
 // Folding of any with partially constant operands is not yet implemented.
 // CHECK-LABEL: func @f
 func @f(%arg0 : !shape.shape, %arg1 : !shape.shape) -> !shape.shape {
-  // CHECK-NEXT: %[[CS:.*]] = "shape.any"
+  // CHECK-NEXT: %[[CS:.*]] = shape.any
   // CHECK-NEXT: return %[[CS]]
-  %1 = "shape.any"(%arg0, %arg1) : (!shape.shape, !shape.shape) -> !shape.shape
+  %1 = shape.any %arg0, %arg1 : !shape.shape, !shape.shape -> !shape.shape
   return %1 : !shape.shape
 }
 
diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir
--- a/mlir/test/Dialect/Shape/ops.mlir
+++ b/mlir/test/Dialect/Shape/ops.mlir
@@ -235,3 +235,26 @@
   %2 = call @shape_equal_shapes(%a, %1) : (!shape.value_shape, !shape.value_shape) -> !shape.shape
   return %2 : !shape.shape
 }
+
+func @any_on_shape(%a : !shape.shape, %b : !shape.shape, %c : !shape.shape)
+    -> !shape.shape {
+  %result = shape.any %a, %b, %c
+      : !shape.shape, !shape.shape, !shape.shape -> !shape.shape
+  return %result : !shape.shape
+}
+
+func @any_on_mixed(%a : tensor<?xindex>,
+                   %b : tensor<?xindex>,
+                   %c : !shape.shape) -> !shape.shape {
+  %result = shape.any %a, %b, %c
+      : tensor<?xindex>, tensor<?xindex>, !shape.shape -> !shape.shape
+  return %result : !shape.shape
+}
+
+func @any_on_extent_tensors(%a : tensor<?xindex>,
+                            %b : tensor<?xindex>,
+                            %c : tensor<?xindex>) -> tensor<?xindex> {
+  %result = shape.any %a, %b, %c
+      : tensor<?xindex>, tensor<?xindex>, tensor<?xindex> -> tensor<?xindex>
+  return %result : tensor<?xindex>
+}