diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -187,15 +187,14 @@
 ///   relation is "equivalent" (TODO: can be relaxed if needed).
 /// * The reverse use-def chain has exactly one end, which is the
 ///   tensor::EmptyOp.
-LogicalResult
-mlir::bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
+template <typename OpTy>
+static LogicalResult insertSliceLikeAnchoredEmptyTensorEliminationStep(
     RewriterBase &rewriter, Operation *op, AnalysisState &state) {
   return eliminateEmptyTensors(
       rewriter, op, state,
       /*anchorMatchFunc=*/
       [&](OpOperand &operand, SmallVector<Value> &neededValues) {
-        auto insertSliceOp =
-            dyn_cast<tensor::InsertSliceOp>(operand.getOwner());
+        auto insertSliceOp = dyn_cast<OpTy>(operand.getOwner());
         if (!insertSliceOp)
           return false;
         if (&operand != &insertSliceOp->getOpOperand(0) /*source*/)
@@ -214,7 +213,7 @@
       },
       /*rewriteFunc=*/
       [](OpBuilder &b, Location loc, OpOperand &operand) {
-        auto insertOp = cast<tensor::InsertSliceOp>(operand.getOwner());
+        auto insertOp = cast<OpTy>(operand.getOwner());
         auto extractOp = b.create<tensor::ExtractSliceOp>(
             loc, insertOp.getSourceType(), insertOp.getDest(),
             insertOp.getMixedOffsets(), insertOp.getMixedSizes(),
@@ -223,6 +222,18 @@
       });
 }
 
+LogicalResult
+mlir::bufferization::insertSliceAnchoredEmptyTensorEliminationStep(
+    RewriterBase &rewriter, Operation *op, AnalysisState &state) {
+  if (failed(insertSliceLikeAnchoredEmptyTensorEliminationStep<
+          tensor::InsertSliceOp>(rewriter, op, state)))
+    return failure();
+  if (failed(insertSliceLikeAnchoredEmptyTensorEliminationStep<
+          tensor::ParallelInsertSliceOp>(rewriter, op, state)))
+    return failure();
+  return success();
+}
+
 namespace {
 struct EmptyTensorElimination
     : public bufferization::impl::EmptyTensorEliminationBase<
diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
--- a/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
+++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-empty-tensor-elimination.mlir
@@ -137,3 +137,35 @@
     : tensor<1x1x128xf32> into tensor<5x6x128xf32>
   return %3 : tensor<5x6x128xf32>
 }
+
+// -----
+
+// CHECK: func @parallel_insert_slice(
+// CHECK-SAME:   %[[FUNC_ARG:[0-9a-zA-Z]*]]: memref<?xf32>
+// CHECK-SAME:   %[[sz:[0-9a-zA-Z]*]]: index
+func.func @parallel_insert_slice(
+    %t: tensor<?xf32> {bufferization.buffer_layout = affine_map<(d0) -> (d0)>, bufferization.writable = true},
+    %sz: index)
+  -> (tensor<?xf32>)
+{
+  %f0 = arith.constant 0.0: f32
+  %c512 = arith.constant 512 : index
+
+  %r1 = scf.foreach_thread (%iv) in (%c512) shared_outs(%o = %t) -> (tensor<?xf32>) {
+    // tensor.empty itself does not alloc but forwards to the insert_slice.
+    // EmptyTensorOpElimination replaces the tensor.empty with an inplace
+    // extract_slice.
+    // CHECK: %[[T_SUBVIEW:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]
+    %a = tensor.empty(%sz) : tensor<?xf32>
+
+    // CHECK: linalg.fill ins({{.*}} : f32) outs(%[[T_SUBVIEW]] : memref<?xf32
+    %f = linalg.fill ins(%f0 : f32) outs(%a : tensor<?xf32>) -> tensor<?xf32>
+
+    // Self-copy canonicalizes away later.
+    scf.foreach_thread.perform_concurrently {
+      tensor.parallel_insert_slice %f into %o[42][%sz][1]: tensor<?xf32> into tensor<?xf32>
+    }
+  }
+
+  return %r1: tensor<?xf32>
+}