diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -94,6 +94,13 @@
         bufferization::getBufferType(castOp.getResult(), options);
     if (failed(resultMemRefType))
       return failure();
+    // memref.cast does not support unranked-to-unranked casts, so the
+    // cast-compatibility assertion below would fire; report an error instead.
+    if (isa<UnrankedMemRefType>(resultBuffer->getType()) &&
+        isa<UnrankedMemRefType>(*resultMemRefType)) {
+      op->emitError("cannot bufferize unranked tensor cast");
+      return failure();
+    }
 
     // Replace the op with a memref.cast.
     assert(memref::CastOp::areCastCompatible(resultBuffer->getType(),
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-unranked-cast-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-unranked-cast-invalid.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-unranked-cast-invalid.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s -one-shot-bufferize="allow-unknown-ops" -split-input-file -verify-diagnostics
+
+// An unranked-to-unranked tensor.cast has no memref.cast counterpart, so
+// bufferization must emit an error rather than trip the compatibility assert.
+func.func @from_unranked_to_unranked(%arg0: tensor<*xi32>) {
+  // expected-error @+2 {{cannot bufferize unranked tensor cast}}
+  // expected-error @+1 {{failed to bufferize op}}
+  %0 = tensor.cast %arg0 : tensor<*xi32> to tensor<*xi32>
+  return
+}