diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp @@ -405,8 +405,9 @@ if (auto bufferizableOp = options.dynCastBufferizableOp(lastWrite)) if (!bufferizableOp.isMemoryWrite(lastWrite.cast(), *this)) skipCopy = true; - // Do not copy if the copied data is never read. - if (!isValueRead(result)) + // Do not copy if the copied data is never read. (Neither by this op nor by + // any following op.) + if (!bufferizesToMemoryRead(*opOperand) && !isValueRead(result)) skipCopy = true; // Do not copy if this op does not read the data, but writes it. if (bufferizesToMemoryWrite(*opOperand) && diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -1189,3 +1189,39 @@ return %r : tensor } +// ----- + +#accesses = [ + affine_map<(i) -> (i)> +] +#trait = { + indexing_maps = #accesses, + iterator_types = ["parallel"] +} + +// CHECK-LABEL: func @op_is_reading_but_following_ops_are_not +// CHECK-SAME: %[[t0:.*]]: memref {linalg.inplaceable = false}, + %cst : f32) + -> tensor +{ + // Make sure that a copy is inserted here. + // CHECK: %[[ALLOC:.*]] = memref.alloc + // CHECK: linalg.copy(%[[t0]], %[[ALLOC]]) + // CHECK: linalg.generic {{.*}} outs(%[[ALLOC]] : memref + %r0 =linalg.generic #trait outs (%t0 : tensor) { + ^bb(%0: f32) : + %a = arith.addf %cst, %0 : f32 + linalg.yield %a : f32 + } -> (tensor) + + // CHECK: linalg.generic {{.*}} outs(%[[ALLOC]] : memref + %r1 = linalg.generic #trait outs (%r0 : tensor) { + ^bb(%0: f32) : + linalg.yield %cst : f32 + } -> (tensor) + + // CHECK: return %[[ALLOC]] + return %r1 : tensor +}