diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -323,9 +323,23 @@ newBlockArgs); // Replace previous terminator with a new one that does not yield anything. - Operation *oldTerminator = newTiledLoopOp.getBody()->getTerminator(); + auto oldTerminator = + cast(newTiledLoopOp.getBody()->getTerminator()); rewriter.setInsertionPointToEnd(newTiledLoopOp.getBody()); - rewriter.create(oldTerminator->getLoc()); + auto newTerminator = + rewriter.create(oldTerminator->getLoc()); + + // Copy buffer of yielded tensor to output buffer. If everything bufferized + // inplace, this copy will fold away. + rewriter.setInsertionPoint(newTerminator); + for (auto it : llvm::zip(oldTerminator.values(), newOutputs)) { + Value output = std::get<1>(it); + Value toMemrefOp = rewriter.create( + newTerminator.getLoc(), output.getType(), std::get<0>(it)); + state.createMemCpy(rewriter, newTerminator.getLoc(), toMemrefOp, output); + } + + // Erase old terminator. rewriter.eraseOp(oldTerminator); // Replace results and delete old op. diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir @@ -576,6 +576,7 @@ %0 = tensor.dim %A, %c0 : tensor // CHECK: linalg.tiled_loop {{.*}} to (%[[M]]) {{.*}} %[[A]]{{.*}}%[[B]]{{.*}}outs{{.*}}%[[c]] + // CHECK-NOT: copy %1 = linalg.tiled_loop (%arg3) = (%c0) to (%0) step (%c3) ins (%arg4 = %A: tensor, %use = %effecting : memref, %arg5 = %B: tensor) outs (%arg6 = %c: tensor) @@ -655,6 +656,40 @@ // ----- +// CHECK: func @tiled_loop_yield_out_of_place( +// CHECK-SAME: %[[A:[a-zA-Z0-9]*]]: memref, +// CHECK-SAME: %[[B:[a-zA-Z0-9]*]]: memref +func @tiled_loop_yield_out_of_place( + %A: tensor {linalg.inplaceable = true}, + %B: tensor {linalg.inplaceable = true}) + -> tensor +{ + %c3 = arith.constant 3 : index + %c0 = arith.constant 0 : index + %f0 = arith.constant 0.0 : f32 + + // CHECK: %[[M:.*]] = memref.dim %[[A]], {{.*}} : memref + %0 = tensor.dim %A, %c0 : tensor + + // CHECK: linalg.tiled_loop {{.*}} to (%[[M]]) {{.*}} outs{{.*}}%[[A]] + %1 = linalg.tiled_loop (%arg3) = (%c0) to (%0) step (%c3) + outs (%arg1 = %A: tensor) + iterators["parallel"] + { + // CHECK-NOT: alloc + // CHECK: linalg.copy(%[[B]], %[[A]]) + linalg.yield %B : tensor + // CHECK: linalg.yield + // CHECK-NOT: tensor + } + + // CHECK: return + // CHECK-NOT: tensor + return %1 : tensor +} + +// ----- + // CHECK: #[[$DYNAMIC:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)> // CHECK: func private @external_func(memref)