diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td --- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td +++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td @@ -193,6 +193,32 @@ llvm_unreachable("bufferRelation not implemented"); }] >, + InterfaceMethod< + /*desc=*/[{ + Resolve all inplacability conflicts by inserting explicit + `bufferization.alloc_tensor` ops. Examples of inplacability conflicts + are read-after-write conflicts or writes into non-writable buffers. + + This method should rewrite the IR in such a way that for each tensor + OpOperand t, buffer(t) can be directly used during bufferization. + Bufferization then no longer has to care about inplacability + conflicts. + + This method can query analysis information from the given analysis + state. + }], + /*retType=*/"LogicalResult", + /*methodName=*/"resolveConflicts", + /*args=*/(ins "RewriterBase &":$rewriter, + "const AnalysisState &":$state), + /*methodBody=*/"", + /*defaultImplementation=*/[{ + auto bufferizableOp = + cast($_op.getOperation()); + return bufferizableOp.resolveTensorOpOperandConflicts( + rewriter, state); + }] + >, InterfaceMethod< /*desc=*/[{ Bufferize this op, i.e., rewrite it into a memref-based equivalent. @@ -302,6 +328,11 @@ ]; let extraClassDeclaration = [{ + /// Resolve out-of-place tensor OpOperands with explicit allocations in the + /// form of `bufferization.alloc_tensor` ops. + LogicalResult resolveTensorOpOperandConflicts( + RewriterBase &rewriter, const AnalysisState &state); + /// Return `true` if the given OpOperand creates an alias but does neither /// read nor write. This implies that `bufferizesToMemoryRead` and /// `bufferizesToMemoryWrite` must return `false`. 
This method will never diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp --- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp @@ -18,6 +18,10 @@ #include "mlir/IR/Value.h" #include "llvm/Support/Debug.h" +//===----------------------------------------------------------------------===// +// BufferizableOpInterface +//===----------------------------------------------------------------------===// + namespace mlir { namespace bufferization { @@ -38,6 +42,28 @@ constexpr const ::llvm::StringLiteral bufferization::BufferizableOpInterface::kInplaceableAttrName; +LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts( + RewriterBase &rewriter, const AnalysisState &state) { + Operation *op = getOperation(); + for (OpOperand &opOperand : op->getOpOperands()) { + if (opOperand.get().getType().isa()) + return op->emitError("copies of unranked tensors are not supported"); + auto tensorType = opOperand.get().getType().dyn_cast(); + if (!tensorType) + continue; + if (state.isInPlace(opOperand)) + continue; + SmallVector aliasingOpResults = + state.getAliasingOpResult(opOperand); + bool escape = llvm::any_of( + aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); }); + Value copy = rewriter.create( + op->getLoc(), tensorType, ValueRange(), opOperand.get(), escape); + rewriter.updateRootInPlace(op, [&]() { opOperand.set(copy); }); + } + return success(); +} + //===----------------------------------------------------------------------===// // OpFilter //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp +++ 
b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp @@ -32,7 +32,7 @@ return failure(); } - OpBuilder builder(op->getContext()); + IRRewriter rewriter(op->getContext()); WalkResult result = op->walk([&](Operation *op) { auto bufferizableOp = options.dynCastBufferizableOp(op); if (!bufferizableOp) @@ -48,33 +48,15 @@ op->emitError("illegal return of allocation detected"); return WalkResult::interrupt(); } - allocTensorOp.escapeAttr(builder.getBoolAttr(escape)); + allocTensorOp.escapeAttr(rewriter.getBoolAttr(escape)); return WalkResult::advance(); } - // Find out-of-place tensor OpOperands and resolve them with an explicit - // tensor copy in the form of an AllocTensorOp. - builder.setInsertionPoint(op); - for (OpOperand &opOperand : op->getOpOperands()) { - if (opOperand.get().getType().isa()) { - op->emitError("copies of unranked tensors are not supported"); - return WalkResult::interrupt(); - } - auto tensorType = opOperand.get().getType().dyn_cast(); - if (!tensorType) - continue; - if (state.isInPlace(opOperand)) - continue; - SmallVector aliasingOpResults = - state.getAliasingOpResult(opOperand); - bool escape = llvm::any_of( - aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); }); - assert((!escape || options.allowReturnAllocs) && - "analysis should have detected illegal alloc return"); - Value copy = builder.create( - op->getLoc(), tensorType, ValueRange(), opOperand.get(), escape); - opOperand.set(copy); - } + // Find inplacability conflicts and resolve them. (Typically with explicit + // tensor copies in the form of AllocTensorOps.) + rewriter.setInsertionPoint(op); + if (failed(bufferizableOp.resolveConflicts(rewriter, state))) + return WalkResult::interrupt(); return WalkResult::advance(); });