diff --git a/mlir/docs/Bufferization.md b/mlir/docs/Bufferization.md --- a/mlir/docs/Bufferization.md +++ b/mlir/docs/Bufferization.md @@ -103,8 +103,8 @@ To simplify this problem, One-Shot Bufferize was designed for ops that are in *destination-passing style*. For every tensor result, such ops have a tensor -operand, who's buffer could be for storing the result of the op in the absence -of other conflicts. We call such tensor operands the *destination*. +operand, whose buffer could be utilized for storing the result of the op in the +absence of other conflicts. We call such tensor operands the *destination*. As an example, consider the following op: `%0 = tensor.insert %cst into %t[%idx] : tensor` diff --git a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td --- a/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td +++ b/mlir/include/mlir/Dialect/Arith/IR/ArithOps.td @@ -1381,6 +1381,9 @@ condition operand. If an i1 is provided as the condition, the entire vector or tensor is chosen. + Zero-ranked tensors (`tensor<i1>`) are not allowed as the condition; such + cases must extract the `i1` element first (e.g. with `tensor.extract`). + Example: ```mlir @@ -1398,7 +1401,7 @@ ``` }]; - let arguments = (ins BoolLike:$condition, + let arguments = (ins ScalarConditionOrMatchingShape:$condition, AnyType:$true_value, AnyType:$false_value); let results = (outs AnyType:$result); diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -1001,6 +1001,15 @@ TensorOf<[allowedType]>.predicate]>, name>; +// Temporary constraint to disallow zero-ranked tensors as the condition of +// arith.select. +class TypeOrContainerOfScalarConditionOrMatchingShape<Type allowedType, + string name> : + TypeConstraint<Or<[allowedType.predicate, + VectorOf<[allowedType]>.predicate, + Non0RankedTensorOf<[allowedType]>.predicate, + UnrankedTensorOf<[allowedType]>.predicate]>, + name>; // Type constraint for bool-like types: bools, vectors of bools, tensors of // bools. 
@@ -1008,6 +1017,10 @@ def BoolLikeOfAnyRank : TypeOrContainerOfAnyRank; +// Similar to BoolLike, but disallows zero-ranked tensors. +def ScalarConditionOrMatchingShape : + TypeOrContainerOfScalarConditionOrMatchingShape<I1, "bool-like">; + // Type constraint for signless-integer-like types: signless integers, indices, // vectors of signless integers or indices, tensors of signless integers. def SignlessIntegerLike : TypeOrContainer, %arg1 : tensor, %arg2 : tensor) -> tensor { + // expected-error @+1 {{'arith.select' op operand #0 must be bool-like, but got 'tensor<i1>'}} + %0 = arith.select %arg0, %arg1, %arg2 : tensor<i1>, tensor + return %0 : tensor +} + +// ----- + +func.func @disallow_zero_rank_tensor_with_ranked_tensor(%arg0 : tensor<i1>, %arg1 : tensor<2xi64>, %arg2 : tensor<2xi64>) -> tensor<2xi64> { + // expected-error @+1 {{'arith.select' op operand #0 must be bool-like, but got 'tensor<i1>'}} + %0 = arith.select %arg0, %arg1, %arg2 : tensor<i1>, tensor<2xi64> + return %0 : tensor<2xi64> +} + +// ----- + +func.func @disallow_zero_rank_tensor_with_dynamic_shaped_tensor(%arg0 : tensor<i1>, %arg1 : tensor<2x?xi64>, %arg2 : tensor<2x?xi64>) -> tensor<2x?xi64> { + // expected-error @+1 {{'arith.select' op operand #0 must be bool-like, but got 'tensor<i1>'}} + %0 = arith.select %arg0, %arg1, %arg2 : tensor<i1>, tensor<2x?xi64> + return %0 : tensor<2x?xi64> +} + +// ----- + +func.func @disallow_zero_rank_tensor_with_vector(%arg0 : tensor<i1>, %arg1 : vector, %arg2 : vector) -> vector { + // expected-error @+1 {{'arith.select' op operand #0 must be bool-like, but got 'tensor<i1>'}} + %0 = arith.select %arg0, %arg1, %arg2 : tensor<i1>, vector + return %0 : vector +}