// NOTE(review): this is a unified diff (git patch), not a C++ translation
// unit. The patch text below is kept byte-for-byte. It appears to have been
// flattened onto three physical lines, and angle-bracket template/type
// arguments look stripped (e.g. `SmallVector outOfPlaceOpOperands`, `isa()`,
// `rewriter.create(`, `tensor to tensor<5xf32>`) -- TODO confirm against the
// upstream change before trying to apply it.
//
// Summary of the change to
// BufferizableOpInterface::resolveTensorOpOperandConflicts:
//  * An OpBuilder::InsertionGuard is created on `rewriter` at function entry;
//    presumably so the insertion-point changes made below do not leak to the
//    caller -- verify against the OpBuilder documentation.
//  * The operand loop no longer creates copies directly. It only classifies
//    operands that pass the earlier filters (some of which lie in the gap
//    between the two hunks and are not visible here):
//      - if the operand has exactly one aliasing OpResult, the op does not
//        bufferize to a memory write on that operand, and that OpResult has
//        exactly one aliasing OpOperand, the OpResult is queued in
//        `outOfPlaceOpResults`;
//      - otherwise the OpOperand is queued in `outOfPlaceOpOperands`.
//  * Operand copies: the insertion point is set before `op`; one copy is
//    created per queued operand, with `escape` true when any aliasing
//    OpResult satisfies state.isTensorYielded; the operand is then re-pointed
//    at the copy through rewriter.updateRootInPlace.
//  * Result copies: the insertion point is set after `op`; one copy is created
//    per queued result, with `escape` taken from
//    state.isTensorYielded(opResult). The uses of the result are gathered into
//    a vector after the copy has been created, so the copy's own use is in
//    that list; the loop skips it and redirects every other use to the copy.
//  * The removed lines are the `dyn_cast` + `continue` guard on `tensorType`;
//    the new copy loops use `cast` instead.
//
// Summary of the new lit test one-shot-bufferize-tensor-copy-insertion.mlir:
//  * Runs `-tensor-copy-insertion` twice, once with default options (CHECK)
//    and once with "bufferize-function-boundaries allow-return-allocs"
//    (CHECK-FUNC).
//  * For a tensor.extract_slice whose result feeds tensor.insert, it expects a
//    bufferization.alloc_tensor that copies the extract_slice result
//    (tensor<5xf32>), with `escape = false` in the default run and
//    `escape = true` in the CHECK-FUNC run.
//  * It expects the insert to write into that copy, and the function to return
//    the insert result together with the original %t.
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -44,7 +44,12 @@ LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts( RewriterBase &rewriter, const AnalysisState &state) { + OpBuilder::InsertionGuard g(rewriter); Operation *op = getOperation(); + SmallVector outOfPlaceOpOperands; + SmallVector outOfPlaceOpResults; + + // Find all out-of-place OpOperands. for (OpOperand &opOperand : op->getOpOperands()) { Type operandType = opOperand.get().getType(); if (!operandType.isa()) @@ -53,17 +58,52 @@ continue; if (operandType.isa()) return op->emitError("copies of unranked tensors are not supported"); - auto tensorType = operandType.dyn_cast(); - if (!tensorType) - continue; + SmallVector aliasingOpResults = state.getAliasingOpResult(opOperand); + if (aliasingOpResults.size() == 1 && + !state.bufferizesToMemoryWrite(opOperand) && + state.getAliasingOpOperand(aliasingOpResults.front()).size() == 1) { + // The op itself does not write but may create exactly one alias. Instead + // of copying the OpOperand, copy the OpResult. The OpResult can sometimes + // be smaller than the OpOperand (e.g., in the case of an extract_slice, + // where the result is usually a smaller part of the source). + outOfPlaceOpResults.push_back(aliasingOpResults.front()); + } else { + // In all other cases, make a copy of the OpOperand. + outOfPlaceOpOperands.push_back(&opOperand); + } + } + + // Insert copies of OpOperands. 
+ rewriter.setInsertionPoint(op); + for (OpOperand *opOperand : outOfPlaceOpOperands) { + auto tensorType = opOperand->get().getType().cast(); + SmallVector aliasingOpResults = + state.getAliasingOpResult(*opOperand); bool escape = llvm::any_of( aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); }); Value copy = rewriter.create( - op->getLoc(), tensorType, ValueRange(), opOperand.get(), escape); - rewriter.updateRootInPlace(op, [&]() { opOperand.set(copy); }); + op->getLoc(), tensorType, ValueRange(), opOperand->get(), escape); + rewriter.updateRootInPlace(op, [&]() { opOperand->set(copy); }); + } + + // Insert copies of OpResults. + rewriter.setInsertionPointAfter(op); + for (OpResult opResult : outOfPlaceOpResults) { + auto tensorType = opResult.getType().cast(); + bool escape = state.isTensorYielded(opResult); + Value copy = rewriter.create(op->getLoc(), tensorType, + ValueRange(), opResult, escape); + SmallVector uses = llvm::to_vector(llvm::map_range( + opResult.getUses(), [](OpOperand &use) { return &use; })); + for (OpOperand *use : uses) { + // Do not update the alloc_tensor op that we just created. 
+ if (use->getOwner() != copy.getDefiningOp()) + rewriter.updateRootInPlace(use->getOwner(), [&]() { use->set(copy); }); + } } + return success(); } diff --git a/mlir/test/Dialect/Tensor/one-shot-bufferize-tensor-copy-insertion.mlir b/mlir/test/Dialect/Tensor/one-shot-bufferize-tensor-copy-insertion.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Dialect/Tensor/one-shot-bufferize-tensor-copy-insertion.mlir @@ -0,0 +1,18 @@ +// RUN: mlir-opt %s -tensor-copy-insertion -split-input-file | FileCheck %s +// RUN: mlir-opt %s -tensor-copy-insertion="bufferize-function-boundaries allow-return-allocs" -split-input-file | FileCheck %s --check-prefix=CHECK-FUNC + +// CHECK-LABEL: func @extract_slice( +// CHECK-SAME: %[[t:.*]]: tensor +// CHECK-FUNC-LABEL: func @extract_slice( +func.func @extract_slice(%t: tensor, %idx: index, %f: f32) + -> (tensor<5xf32>, tensor) +{ + // CHECK: %[[extract_slice:.*]] = tensor.extract_slice %[[t]][10] [5] [1] + %0 = tensor.extract_slice %t[10][5][1] : tensor to tensor<5xf32> + // CHECK: %[[alloc:.*]] = bufferization.alloc_tensor() copy(%[[extract_slice]]) {escape = false} : tensor<5xf32> + // CHECK-FUNC: bufferization.alloc_tensor() copy(%{{.*}}) {escape = true} : tensor<5xf32> + // CHECK: %[[insert:.*]] = tensor.insert %{{.*}} into %[[alloc]] + %1 = tensor.insert %f into %0[%idx] : tensor<5xf32> + // CHECK: return %[[insert]], %[[t]] + return %1, %t : tensor<5xf32>, tensor +}