diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -156,6 +156,10 @@
           sets and inplace attributes will be set up accordingly before making
           any other bufferization decisions. This method will never be called
           on OpOperands that do not have a tensor type.
+
+          Note: Unranked tensor OpOperands always bufferize in-place. This could
+          be extended in the future. Unranked tensors are used with external
+          functions only.
         }],
         /*retType=*/"bool",
         /*methodName=*/"mustBufferizeInPlace",
@@ -163,7 +167,7 @@
              "const ::mlir::bufferization::AnalysisState &":$state),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
-          return false;
+          return opOperand.get().getType().isa<UnrankedTensorType>();
         }]
       >,
       InterfaceMethod<
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -107,6 +107,10 @@
     tensor = shapedValue;
   } else if (shapedValue.getType().isa<MemRefType>()) {
     tensor = b.create<ToTensorOp>(loc, shapedValue);
+  } else if (shapedValue.getType().isa<UnrankedTensorType>() ||
+             shapedValue.getType().isa<UnrankedMemRefType>()) {
+    return getOwnerOfValue(shapedValue)
+        ->emitError("copying of unranked tensors is not implemented");
   } else {
     llvm_unreachable("expected RankedTensorType or MemRefType");
   }
@@ -175,7 +179,7 @@
     if (state.isInPlace(opOperand))
       continue;
     if (operandType.isa<UnrankedTensorType>())
-      return op->emitError("copies of unranked tensors are not supported");
+      return op->emitError("copying of unranked tensors is not implemented");
 
     SmallVector<OpResult> aliasingOpResults =
         state.getAliasingOpResult(opOperand);
@@ -189,11 +193,14 @@
     if (aliasingOpResults.size() == 1 &&
         !state.bufferizesToMemoryWrite(opOperand) &&
-        state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) {
+        state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1 &&
+        !aliasingOpResults.front().getType().isa<UnrankedTensorType>()) {
       // The op itself does not write but may create exactly one alias. Instead
       // of copying the OpOperand, copy the OpResult. The OpResult can sometimes
       // be smaller than the OpOperand (e.g., in the case of an extract_slice,
-      // where the result is usually a smaller part of the source).
+      // where the result is usually a smaller part of the source). Do not apply
+      // this optimization if the OpResult is an unranked tensor (because those
+      // cannot be copied at the moment).
       outOfPlaceOpResults.push_back(aliasingOpResults.front());
       if (!state.canOmitTensorCopy(opOperand))
         copiedOpResults.insert(aliasingOpResults.front());
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir
@@ -1283,16 +1283,16 @@
 
 // -----
 
-// CHECK-LABEL: func.func private @ext_func(tensor<*xf32> {bufferization.access = "read-write"})
-func.func private @ext_func(%t: tensor<*xf32>)
+// CHECK-LABEL: func.func private @ext_func(tensor<?xf32> {bufferization.access = "read-write"})
+func.func private @ext_func(%t: tensor<?xf32>)
 
 // CHECK: func.func @private_func_read_write(%{{.*}}: tensor<5xf32> {bufferization.access = "read"})
 func.func @private_func_read_write(%t: tensor<5xf32>) -> f32 {
   %c0 = arith.constant 0 : index
   // Bufferizes out-of-place because `ext_func` may modify the buffer.
   // CHECK: tensor.cast {{.*}} {__inplace_operands_attr__ = ["false"]}
-  %0 = tensor.cast %t : tensor<5xf32> to tensor<*xf32>
-  func.call @ext_func(%0) : (tensor<*xf32>) -> ()
+  %0 = tensor.cast %t : tensor<5xf32> to tensor<?xf32>
+  func.call @ext_func(%0) : (tensor<?xf32>) -> ()
   %1 = tensor.extract %t[%c0] : tensor<5xf32>
   return %1 : f32
 }
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir
@@ -315,3 +315,16 @@
   %r = tensor.extract %2[%idx2] : tensor<?xf32>
   return %r : f32
 }
+
+// -----
+
+func.func @copy_of_unranked_tensor(%t: tensor<*xf32>) -> tensor<*xf32> {
+  // Unranked tensor OpOperands always bufferize in-place. With this limitation,
+  // there is no way to bufferize this IR correctly.
+  // expected-error @+1 {{input IR has RaW conflict}}
+  func.call @maybe_writing_func(%t) : (tensor<*xf32>) -> ()
+  return %t : tensor<*xf32>
+}
+
+// This function may write to buffer(%ptr).
+func.func private @maybe_writing_func(%ptr : tensor<*xf32>)
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir
@@ -607,3 +607,21 @@
   // CHECK: return %[[RES]] : vector<4xf32>
   return %0 : vector<4xf32>
 }
+
+// -----
+
+// CHECK-LABEL: func @main(
+func.func @main() {
+  // CHECK: %[[const:.*]] = memref.get_global
+  %t = arith.constant dense<[1.0, 2.0, 3.0]> : tensor<3xf32>
+  // CHECK: %[[alloc:.*]] = memref.alloc
+  // CHECK: memref.copy %[[const]], %[[alloc]]
+  // CHECK: %[[casted:.*]] = memref.cast %[[alloc]] : memref<3xf32> to memref<*xf32>
+  %unranked = tensor.cast %t : tensor<3xf32> to tensor<*xf32>
+  // CHECK: call @maybe_writing_func(%[[casted]])
+  func.call @maybe_writing_func(%unranked) : (tensor<*xf32>) -> ()
+  return
+}
+
+// This function may write to buffer(%ptr).
+func.func private @maybe_writing_func(%ptr : tensor<*xf32>)