Reject arith.select with an elementwise (non-i1) condition during bufferization.

SelectOpInterface::bufferize now emits "only i1 condition values are supported"
when the condition is not a scalar i1. The Arith bufferize test runs with
-verify-diagnostics and gains an @elementwise_select case expecting that error
plus the generic "failed to bufferize op" diagnostic.

diff --git a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -134,6 +134,13 @@
     auto selectOp = cast<arith::SelectOp>(op);
     Location loc = selectOp.getLoc();
 
+    // Elementwise conditions are not supported yet. To bufferize such an op,
+    // it could be lowered to an elementwise "linalg.generic" with a new
+    // "tensor.empty" out tensor, followed by "empty tensor elimination". Such
+    // IR will bufferize.
+    if (!selectOp.getCondition().getType().isInteger(1))
+      return op->emitOpError("only i1 condition values are supported");
+
     // TODO: It would be more efficient to copy the result of the `select` op
     // instead of its OpOperands. In the worst case, 2 copies are inserted at
     // the moment (one for each tensor). When copying the op result, only one
diff --git a/mlir/test/Dialect/Arith/bufferize.mlir b/mlir/test/Dialect/Arith/bufferize.mlir
--- a/mlir/test/Dialect/Arith/bufferize.mlir
+++ b/mlir/test/Dialect/Arith/bufferize.mlir
@@ -1,5 +1,5 @@
-// RUN: mlir-opt %s -arith-bufferize -split-input-file | FileCheck %s
-// RUN: mlir-opt %s -arith-bufferize=alignment=64 -split-input-file | FileCheck --check-prefix=ALIGNED %s
+// RUN: mlir-opt %s -arith-bufferize -split-input-file -verify-diagnostics | FileCheck %s
+// RUN: mlir-opt %s -arith-bufferize=alignment=64 -split-input-file -verify-diagnostics | FileCheck --check-prefix=ALIGNED %s
 
 // CHECK-LABEL: func @index_cast(
 // CHECK-SAME: %[[TENSOR:.*]]: tensor<i32>, %[[SCALAR:.*]]: i32
@@ -96,3 +96,12 @@
   %0 = arith.select %arg0, %arg1, %arg2 : tensor<f32>
   return %0 : tensor<f32>
 }
+
+// -----
+
+func.func @elementwise_select(%arg0: tensor<5xi1>, %arg1: tensor<5xi32>, %arg2: tensor<5xi32>) -> tensor<5xi32> {
+  // expected-error @below{{only i1 condition values are supported}}
+  // expected-error @below{{failed to bufferize op}}
+  %0 = arith.select %arg0, %arg1, %arg2 : tensor<5xi1>, tensor<5xi32>
+  return %0 : tensor<5xi32>
+}