diff --git a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Bufferize.cpp
@@ -53,6 +53,10 @@
     auto tensorShape = tensorType.getShape();
     auto memrefType = MemRefType::get(tensorShape, tensorType.getElementType());
     Value resultTensor = outputs[resultIndex];
+    if (isa<InitTensorOp>(resultTensor.getDefiningOp())) {
+      resultBuffers.push_back(resultTensor);
+      continue;
+    }
 
     // Clone output buffers whose value is actually used.
     OpOperand *tiedOpOperand = linalgOp.getOutputOperand(resultIndex);
diff --git a/mlir/test/Dialect/Linalg/bufferize.mlir b/mlir/test/Dialect/Linalg/bufferize.mlir
--- a/mlir/test/Dialect/Linalg/bufferize.mlir
+++ b/mlir/test/Dialect/Linalg/bufferize.mlir
@@ -45,8 +45,9 @@
 // CHECK: #map = affine_map<(d0) -> (d0)>
 // CHECK-LABEL: func @init_tensor(
 // CHECK-SAME:      %[[IN:.*]]: tensor<?xf32>, %[[SIZE:.*]]: index)
-// CHECK:           %[[MEMREF:.*]] = memref.buffer_cast %[[IN]] : memref<?xf32>
-// CHECK:           %[[OUT_BUF:.*]] = memref.alloc(%[[SIZE]]) : memref<?xf32>
+// CHECK-DAG:       %[[MEMREF:.*]] = memref.buffer_cast %[[IN]] : memref<?xf32>
+// CHECK-DAG:       %[[OUT_BUF:.*]] = memref.alloc(%[[SIZE]]) : memref<?xf32>
+// CHECK-NOT:       memref.alloc
 // CHECK:           linalg.generic
 // CHECK-SAME:      ins(%[[MEMREF]] : memref<?xf32>)
 // CHECK-SAME:      outs(%[[OUT_BUF]] : memref<?xf32>) {