diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp --- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp @@ -1058,6 +1058,21 @@ *srcBuffer, subview))) return failure(); + // In case the source was allocated in the same block, make sure that the + // deallocation op (if any) appears after the memcpy. By default, deallocs + // are placed before the terminator, but this does not work for ForallOp + // because the terminator does more than just yielding a value. + // + // Note: This is not a problem for the destination buffer because these are + // assumed to always bufferize in-place. + for (Operation *user : srcBuffer->getUsers()) { + if (hasEffect<MemoryEffects::Free>(user)) { + if (user->getBlock() == parallelCombiningParent->getBlock()) + user->moveBefore(user->getBlock()->getTerminator()); + break; + } + } + // Delete the op. rewriter.eraseOp(op); return success(); diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir --- a/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir +++ b/mlir/test/Dialect/Tensor/one-shot-bufferize.mlir @@ -335,7 +335,7 @@ // CHECK-LABEL: func @dim_not_reading( // CHECK-SAME: %[[t:.*]]: memref<?xf32>, %f: f32, %pos: index) +func.func @dim_not_reading(%t: tensor<?xf32>, %f: f32, %pos: index) -> (tensor<?xf32>, index) { %c0 = arith.constant 0 : index @@ -370,3 +370,31 @@ // in the caller. 
return %casted, %slice : tensor<10xf32>, tensor<?xf32> } + +// ----- + +// CHECK-LABEL: func.func @parallel_insert_slice_source_out_of_place +func.func @parallel_insert_slice_source_out_of_place(%in: tensor<1xf32>, %out: tensor<100xf32>, %f: f32) { + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %num_threads = arith.constant 50 : index + + // CHECK: scf.forall {{.*}} { + %result = scf.forall (%thread_idx) in (%num_threads) shared_outs (%o = %out) -> tensor<100xf32> { + // The tensor.insert must bufferize out-of-place. + // CHECK: memref.alloc + // CHECK: memref.store + %insert = tensor.insert %f into %in[%c0] : tensor<1xf32> + %r = tensor.extract %in[%c0] : tensor<1xf32> + vector.print %r : f32 + + // CHECK: memref.copy + // CHECK: memref.dealloc + scf.forall.in_parallel { + tensor.parallel_insert_slice %insert into %o[%thread_idx][1][1] : + tensor<1xf32> into tensor<100xf32> + } + } + // CHECK: } + return +}