diff --git a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
--- a/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
+++ b/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td
@@ -192,6 +192,32 @@
           llvm_unreachable("bufferRelation not implemented");
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Resolve all inplacability conflicts by inserting explicit
+          `bufferization.alloc_tensor` ops. Examples of inplacability conflicts
+          are read-after-write conflicts or writes into non-writable buffers.
+
+          This method should rewrite the IR in such a way that for each tensor
+          OpOperand t, buffer(t) can be directly used during bufferization.
+          Bufferization then no longer has to care about inplacability
+          conflicts.
+
+          This method can query analysis information from the given analysis
+          state.
+        }],
+        /*retType=*/"LogicalResult",
+        /*methodName=*/"resolveConflicts",
+        /*args=*/(ins "RewriterBase &":$rewriter,
+                      "const AnalysisState &":$state),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          auto bufferizableOp =
+              cast<BufferizableOpInterface>($_op.getOperation());
+          return bufferizableOp.resolveTensorOpOperandConflicts(
+              rewriter, state);
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Bufferize this op, i.e., rewrite it into a memref-based equivalent.
@@ -301,6 +327,13 @@
   ];
 
   let extraClassDeclaration = [{
+    /// Resolve out-of-place tensor OpOperands with explicit allocations in the
+    /// form of `bufferization.alloc_tensor` ops. The allocations are inserted
+    /// right before this op; the insertion point of `rewriter` is restored
+    /// before returning.
+    LogicalResult resolveTensorOpOperandConflicts(
+        RewriterBase &rewriter, const AnalysisState &state);
+
     /// Return `true` if the given OpOperand creates an alias but does neither
     /// read nor write. This implies that `bufferizesToMemoryRead` and
     /// `bufferizesToMemoryWrite` must return `false`. This method will never
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp
@@ -18,6 +18,10 @@
 #include "mlir/IR/Value.h"
 #include "llvm/Support/Debug.h"
 
+//===----------------------------------------------------------------------===//
+// BufferizableOpInterface
+//===----------------------------------------------------------------------===//
+
 namespace mlir {
 namespace bufferization {
 
@@ -38,6 +42,37 @@
 constexpr const ::llvm::StringLiteral
     bufferization::BufferizableOpInterface::kInplaceableAttrName;
 
+LogicalResult BufferizableOpInterface::resolveTensorOpOperandConflicts(
+    RewriterBase &rewriter, const AnalysisState &state) {
+  Operation *op = getOperation();
+  // Insert all copies right before `op`. The guard restores the caller's
+  // insertion point, so callers need not set it up beforehand.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(op);
+  for (OpOperand &opOperand : op->getOpOperands()) {
+    Type operandType = opOperand.get().getType();
+    if (!operandType.isa<TensorType>())
+      continue;
+    // In-place OpOperands have no conflict; nothing to do.
+    if (state.isInPlace(opOperand))
+      continue;
+    if (operandType.isa<UnrankedTensorType>())
+      return op->emitError("copies of unranked tensors are not supported");
+    auto tensorType = operandType.dyn_cast<RankedTensorType>();
+    if (!tensorType)
+      continue;
+    // Mark the copy as escaping if any aliasing OpResult is yielded.
+    SmallVector<OpResult> aliasingOpResults =
+        state.getAliasingOpResult(opOperand);
+    bool escape = llvm::any_of(
+        aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); });
+    Value copy = rewriter.create<AllocTensorOp>(
+        op->getLoc(), tensorType, ValueRange(), opOperand.get(), escape);
+    rewriter.updateRootInPlace(op, [&]() { opOperand.set(copy); });
+  }
+  return success();
+}
+
 //===----------------------------------------------------------------------===//
 // OpFilter
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
--- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp
@@ -43,7 +43,7 @@
 LogicalResult
 mlir::bufferization::insertTensorCopies(Operation *op,
                                         const AnalysisState &state) {
-  OpBuilder builder(op->getContext());
+  IRRewriter rewriter(op->getContext());
   WalkResult result = op->walk([&](Operation *op) {
     auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op);
     if (!bufferizableOp)
@@ -55,31 +55,15 @@
       if (allocTensorOp.escape())
         return WalkResult::advance();
       bool escape = state.isTensorYielded(allocTensorOp.result());
-      allocTensorOp.escapeAttr(builder.getBoolAttr(escape));
+      allocTensorOp.escapeAttr(rewriter.getBoolAttr(escape));
       return WalkResult::advance();
     }
 
-    // Find out-of-place tensor OpOperands and resolve them with an explicit
-    // tensor copy in the form of an AllocTensorOp.
-    builder.setInsertionPoint(op);
-    for (OpOperand &opOperand : op->getOpOperands()) {
-      if (opOperand.get().getType().isa<UnrankedTensorType>()) {
-        op->emitError("copies of unranked tensors are not supported");
-        return WalkResult::interrupt();
-      }
-      auto tensorType = opOperand.get().getType().dyn_cast<RankedTensorType>();
-      if (!tensorType)
-        continue;
-      if (state.isInPlace(opOperand))
-        continue;
-      SmallVector<OpResult> aliasingOpResults =
-          state.getAliasingOpResult(opOperand);
-      bool escape = llvm::any_of(
-          aliasingOpResults, [&](Value v) { return state.isTensorYielded(v); });
-      Value copy = builder.create<AllocTensorOp>(
-          op->getLoc(), tensorType, ValueRange(), opOperand.get(), escape);
-      opOperand.set(copy);
-    }
+    // Find inplacability conflicts and resolve them. (Typically with explicit
+    // tensor copies in the form of AllocTensorOps.)
+    rewriter.setInsertionPoint(op);
+    if (failed(bufferizableOp.resolveConflicts(rewriter, state)))
+      return WalkResult::interrupt();
 
     return WalkResult::advance();
   });