diff --git a/mlir/docs/Bufferization.md b/mlir/docs/Bufferization.md --- a/mlir/docs/Bufferization.md +++ b/mlir/docs/Bufferization.md @@ -103,8 +103,8 @@ To simplify this problem, One-Shot Bufferize was designed for ops that are in *destination-passing style*. For every tensor result, such ops have a tensor -operand, who's buffer could be for storing the result of the op in the absence -of other conflicts. We call such tensor operands the *destination*. +operand, whose buffer could be utilized for storing the result of the op in the +absence of other conflicts. We call such tensor operands the *destination*. As an example, consider the following op: `%0 = tensor.insert %cst into %t[%idx] : tensor` diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -1366,6 +1366,7 @@ def SelectOp : Arith_Op<"select", [Pure, AllTypesMatch<["true_value", "false_value", "result"]>, + ScalarConditionOrMatchingShape<["condition", "result"]>, DeclareOpInterfaceMethods, ] # ElementwiseMappable.traits> { let summary = "select operation"; diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -2548,6 +2548,12 @@ class ElementType : StrFunc<"getElementTypeOrSelf($" # name # ")">; +class AnyPred values> : + CPred; + class AllMatchPred values> : CPred values = names; } +class AnyMatchOperatorPred names, string operator> : + AnyPred; + +class AnyMatchOperatorTrait names, string operator, + string summary> : + PredOpTrait< + "any of {" # !interleave(names, ", ") # "} has " # summary, + AnyMatchOperatorPred> { + list values = names; +} + class AllElementCountsMatch names> : AllMatchSameOperatorTrait.result, "element count">; @@ -2695,4 +2712,16 @@ "[this](unsigned i) { return getElementTypeOrSelf(this->getOperand(i)); " "}))">; +class 
AnyScalarTypeMatch names> : + AnyMatchOperatorTrait; + +class ScalarConditionOrMatchingShape names> : + PredOpTrait< + !head(names) # " is scalar or has matching shape", + Or<[AnyScalarTypeMatch<[!head(names)]>.predicate, + AllShapesMatch.predicate]>> { + list values = names; +} + #endif // OP_BASE diff --git a/mlir/test/Dialect/Arith/invalid.mlir b/mlir/test/Dialect/Arith/invalid.mlir --- a/mlir/test/Dialect/Arith/invalid.mlir +++ b/mlir/test/Dialect/Arith/invalid.mlir @@ -753,3 +753,19 @@ %x = arith.constant 1 : i32 } + +// ----- + +func.func @disallow_zero_rank_tensor_with_ranked_tensor(%arg0 : tensor<i1>, %arg1 : tensor<2xi64>, %arg2 : tensor<2xi64>) -> tensor<2xi64> { + // expected-error @+1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}} + %0 = arith.select %arg0, %arg1, %arg2 : tensor<i1>, tensor<2xi64> + return %0 : tensor<2xi64> +} + +// ----- + +func.func @disallow_zero_rank_tensor_with_dynamic_shaped_tensor(%arg0 : tensor<i1>, %arg1 : tensor<2x?xi64>, %arg2 : tensor<2x?xi64>) -> tensor<2x?xi64> { + // expected-error @+1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}} + %0 = arith.select %arg0, %arg1, %arg2 : tensor<i1>, tensor<2x?xi64> + return %0 : tensor<2x?xi64> +} diff --git a/mlir/test/IR/invalid-ops.mlir b/mlir/test/IR/invalid-ops.mlir --- a/mlir/test/IR/invalid-ops.mlir +++ b/mlir/test/IR/invalid-ops.mlir @@ -82,7 +82,7 @@ func.func @func_with_ops(vector<12xi1>, vector<42xi32>, vector<42xi32>) { ^bb0(%cond : vector<12xi1>, %t : vector<42xi32>, %f : vector<42xi32>): - // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}} + // expected-error@+1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}} %r = "arith.select"(%cond, %t, %f) : (vector<12xi1>, vector<42xi32>, vector<42xi32>) -> vector<42xi32> } @@ -90,7 +90,7 @@ func.func @func_with_ops(tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) { 
^bb0(%cond : tensor<12xi1>, %t : 
tensor<42xi32>, %f : tensor<42xi32>): - // expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}} + // expected-error@+1 {{'arith.select' op failed to verify that condition is scalar or has matching shape}} %r = "arith.select"(%cond, %t, %f) : (tensor<12xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32> }