diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -113,7 +113,9 @@
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/BufferUtils.h"
+#include "mlir/Transforms/Passes.h"
 
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/EquivalenceClasses.h"
@@ -135,6 +137,8 @@
 // Generic helpers.
 //===----------------------------------------------------------------------===//
 
+static bool isaTensor(Type t) { return t.isa<TensorType>(); }
+
 /// Return the FuncOp called by `callOp`.
 static FuncOp getCalledFunction(CallOpInterface callOp) {
   SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
@@ -144,6 +148,20 @@
       SymbolTable::lookupNearestSymbolFrom(callOp, sym));
 }
 
+/// Return the unique ReturnOp that terminates `funcOp`.
+/// Return nullptr if there is no such unique ReturnOp.
+static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
+  ReturnOp returnOp;
+  for (Block &b : funcOp.body()) {
+    if (auto candidateOp = dyn_cast<ReturnOp>(b.getTerminator())) {
+      if (returnOp)
+        return nullptr;
+      returnOp = candidateOp;
+    }
+  }
+  return returnOp;
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific BlockAndValueMapping support with debugging.
 //===----------------------------------------------------------------------===//
@@ -162,7 +180,7 @@
 }
 
 /// Wrapper for better debugging.
-static Value lookup(BlockAndValueMapping &bvm, Value key) {
+static Value lookup(const BlockAndValueMapping &bvm, Value key) {
   // TODO: if key comes from bbArg, forward.
   assert(key.getType().isa<TensorType>());
   Value v = bvm.lookupOrNull(key);
@@ -334,10 +352,8 @@
                 VectorTransferOpInterface,
                 scf::YieldOp>(op)
       // clang-format on
-      || (none_of(op->getResultTypes(),
-                  [](Type t) { return t.isa<TensorType>(); }) &&
-          none_of(op->getOperandTypes(),
-                  [](Type t) { return t.isa<TensorType>(); }));
+      || (none_of(op->getResultTypes(), isaTensor) &&
+          none_of(op->getOperandTypes(), isaTensor));
 }
 
 /// Return the OpResult that may bufferize into the same buffer as `opOperand`
@@ -564,14 +580,22 @@
   /// beginning the alias and equivalence sets only contain `v` itself.
   void createAliasInfoEntry(Value v);
 
+  /// Insert an info entry for `newValue` and merge its alias set with that of
+  /// `alias`.
+  void insertNewBufferAlias(Value newValue, Value alias);
+
+  /// Insert an info entry for `newValue` and merge its alias set with that of
+  /// `alias`. Additionally, merge their equivalence classes.
+  void insertNewBufferEquivalence(Value newValue, Value alias);
+
   /// Return true if the buffer to which `operand` would bufferize aliases a
   /// buffer that is known to not be writeable. This implies that the matching
   /// OpResult cannot be bufferized inplace.
   bool aliasesNonWriteableBuffer(OpOperand &operand) const;
 
   /// Return true if the buffer to which `operand` would bufferize is equivalent
-  /// to some use that would bufferize to a write to a buffer.
-  bool aliasesInPlaceWrite(ExtractSliceOp extractSliceOp) const;
+  /// to some buffer write.
+  bool aliasesInPlaceWrite(Value v) const;
 
   /// Set the inPlace bufferization spec to true.
   /// Merge result's and operand's aliasing sets and iterate to a fixed point.
@@ -606,6 +630,9 @@
   bool isSourceEquivalentToAMatchingExtractSliceOp(
       InsertSliceOp insertSliceOp) const;
 
+  /// Apply `fun` to all the members of the equivalence class of `v`.
+  void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;
+
   /// Print to `os`.
   void print(raw_ostream &os) const;
 
@@ -613,8 +640,9 @@
   void dump() const { print(llvm::errs()); }
 
 private:
-  /// Check aliasInfo for `v` exists and return a reference to it.
+  /// Check that aliasInfo for `v` exists and return a reference to it.
   DenseSet<Value> &getAliasInfoRef(Value v);
+
   const DenseSet<Value> &getAliasInfoRef(Value v) const {
     return const_cast<BufferizationAliasInfo *>(this)->getAliasInfoRef(v);
   }
@@ -727,6 +755,29 @@
   equivalentInfo.insert(v);
 }
 
+/// Insert an info entry for `newValue` and merge its alias set with that of
+/// `alias`.
+void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) {
+  assert(aliasInfo.find(alias) != aliasInfo.end() && "Missing alias entry");
+  createAliasInfoEntry(newValue);
+  mergeAliases(newValue, alias);
+  mergeAliasesToFixedPoint();
+  if (auto memrefCastOp = newValue.getDefiningOp<memref::CastOp>())
+    if (aliasInfo.find(memrefCastOp.source()) == aliasInfo.end())
+      insertNewBufferAlias(memrefCastOp.source(), newValue);
+}
+
+/// Insert an info entry for `newValue` and merge its alias set with that of
+/// `alias`. Additionally, merge their equivalence classes.
+void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
+                                                        Value alias) {
+  insertNewBufferAlias(newValue, alias);
+  equivalentInfo.unionSets(newValue, alias);
+  if (auto memrefCastOp = newValue.getDefiningOp<memref::CastOp>())
+    if (aliasInfo.find(memrefCastOp.source()) == aliasInfo.end())
+      insertNewBufferEquivalence(memrefCastOp.source(), newValue);
+}
+
 /// Return true if the buffer to which `operand` would bufferize aliases a
 /// buffer that is known to not be writeable. This implies that the matching
 /// OpResult cannot be bufferized inplace.
@@ -742,13 +793,13 @@
         LDBG("-----------bbArg is writeable -> skip: " << bbArg << '\n');
         continue;
       }
-      LDBG("-----------notWriteable: " << v << '\n');
+      LDBG("-----------notWriteable\n");
       return true;
     }
 
     if (Operation *op = v.getDefiningOp()) {
       if (isa<ConstantOp>(op) || !hasKnownBufferizationAliasingBehavior(op)) {
-        LDBG("-----------notWriteable: " << v << '\n');
+        LDBG("-----------notWriteable\n");
         return true;
       }
     }
@@ -758,12 +809,11 @@
 }
 
 /// Return true if the buffer to which `operand` would bufferize is equivalent
-/// to some use that would bufferize to a write to a buffer.
-bool BufferizationAliasInfo::aliasesInPlaceWrite(
-    ExtractSliceOp extractSliceOp) const {
+/// to some buffer write.
+bool BufferizationAliasInfo::aliasesInPlaceWrite(Value value) const {
   LDBG("----Start aliasesInPlaceWrite\n");
-  LDBG("-------for op: " << *extractSliceOp.getOperation() << '\n');
-  for (Value v : getAliasInfoRef(extractSliceOp.result())) {
+  LDBG("-------for : " << value << '\n');
+  for (Value v : getAliasInfoRef(value)) {
     for (auto &use : v.getUses()) {
       if (bufferizesToMemoryWrite(use, InPlaceSpec::True)) {
         LDBG("-----------wants to bufferize to inPlace write: "
@@ -772,7 +822,7 @@
       }
     }
   }
-  LDBG("----------->extract_slice does not alias an inplace write");
+  LDBG("----------->does not alias an inplace write\n");
   return false;
 }
 
@@ -911,6 +961,16 @@
   return false;
 }
 
+/// Apply `fun` to all the members of the equivalence class of `v`.
+void BufferizationAliasInfo::applyOnEquivalenceClass(
+    Value v, function_ref<void(Value)> fun) const {
+  for (auto it = equivalentInfo.findLeader(v),
+            eit = equivalentInfo.member_end();
+       it != eit; ++it) {
+    fun(*it);
+  }
+}
+
 void BufferizationAliasInfo::print(raw_ostream &os) const {
   os << "\n/========================== AliasInfo "
         "==========================\n";
@@ -1097,6 +1157,21 @@
   return existsInterleavedValueClobber(aliasingRead, aliasingWrite, domInfo);
 }
 
+//===----------------------------------------------------------------------===//
+// Forward declarations.
+//===----------------------------------------------------------------------===//
+
+/// Return the op with Allocate MemoryEffect if `v` is equivalent to such an
+/// op. Return null otherwise.
+static Operation *getEquivalentAlloc(Value value,
+                                     const BufferizationAliasInfo &aliasInfo);
+
+/// Return the first argument of the enclosing FuncOp that is equivalent to `v`.
+/// Return null if no such bbArg can be found.
+static BlockArgument
+getEquivalentEnclosingFuncBBArg(Value v,
+                                const BufferizationAliasInfo &aliasInfo);
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific MemRefType support.
 //===----------------------------------------------------------------------===//
@@ -1150,8 +1225,10 @@
 /// Create an Allocop/DeAllocOp pair, where the AllocOp is after
 /// `shapedValue.getDefiningOp` (or at the top of the block in case of a
 /// bbArg) and the DeallocOp is at the end of the block.
-static Value createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
-                                                     Value shapedValue) {
+static Value
+createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
+                                        Value shapedValue,
+                                        BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -1181,9 +1258,12 @@
         b.create<memref::DimOp>(loc, shapedValue, dim.index()));
   Value allocated =
       b.create<memref::AllocOp>(loc, allocMemRefType, dynShape);
+  aliasInfo.createAliasInfoEntry(allocated);
   Value casted = allocated;
-  if (memRefType != allocMemRefType)
+  if (memRefType != allocMemRefType) {
     casted = b.create<memref::CastOp>(loc, memRefType, allocated);
+    aliasInfo.insertNewBufferEquivalence(casted, allocated);
+  }
   b.setInsertionPoint(allocated.getParentBlock()->getTerminator());
   b.create<memref::DeallocOp>(loc, allocated);
   return casted;
@@ -1204,7 +1284,8 @@
 static LogicalResult
 allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
                           SmallVectorImpl<Value> &resultBuffers,
-                          BlockAndValueMapping &bvm) {
+                          BlockAndValueMapping &bvm,
+                          BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -1228,7 +1309,8 @@
 
     // Otherwise, `op` is not inplaceable and we need to allocate its result.
     Value dimTensor = bvm.lookupOrDefault(output);
-    Value alloc = createNewAllocDeallocPairForShapedValue(b, loc, dimTensor);
+    Value alloc =
+        createNewAllocDeallocPairForShapedValue(b, loc, dimTensor, aliasInfo);
     b.setInsertionPointAfter(alloc.getDefiningOp());
     resultBuffers.push_back(alloc);
 
@@ -1250,7 +1332,7 @@
 /// Generic conversion for any LinalgOp on tensors.
 static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -1259,8 +1341,6 @@
   if (!op.hasTensorSemantics())
     return failure();
 
-  LDBG("bufferize: " << *op << '\n');
-
   b.setInsertionPoint(op);
   Location loc = op.getLoc();
   SmallVector<Value> newInputBuffers;
@@ -1276,7 +1356,8 @@
   }
   SmallVector<Value> newOutputBuffers;
   // Try to allocate new buffers depending on op's inplace semantics.
-  if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm)))
+  if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm,
+                                       aliasInfo)))
     return failure();
 
   // Clone the newly bufferized op.
@@ -1295,11 +1376,125 @@
   return success();
 }
+
+/// In a first approximation, all the function arguments of a FuncOp are marked
+/// inplaceable. For now, it is the responsibility of the `callOp` bufferization
+/// to allow FuncOps that are inplaceable to write inPlace.
+static LogicalResult bufferize(OpBuilder &b, CallOpInterface callOp,
+                               BlockAndValueMapping &bvm,
+                               BufferizationAliasInfo &aliasInfo) {
+  FuncOp funcOp = getCalledFunction(callOp);
+  if (!funcOp)
+    return success();
+
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(callOp);
+
+  // Some operands must be writeable inplace in the callee.
+  SmallVector<Value> inPlaceWriteOperands;
+  int64_t operandNumber = 0;
+  llvm::copy_if(callOp->getOperands(),
+                std::back_inserter(inPlaceWriteOperands), [&](Value v) {
+                  BlockArgument bbArg = funcOp.getArgument(operandNumber++);
+                  return bbArg.getType().isa<TensorType>() &&
+                         aliasInfo.aliasesInPlaceWrite(bbArg);
+                });
+
+  // 1. Rewrite tensor operands as memrefs.
+  // TODO: pull new allocated returned tensors out as new args when possible.
+  SmallVector<Value> newOperands;
+  newOperands.reserve(callOp->getNumOperands());
+  for (OpOperand &opOperand : callOp->getOpOperands()) {
+    Value tensorOperand = opOperand.get();
+    // Non-tensor operands are just copied.
+    if (!tensorOperand.getType().isa<TensorType>()) {
+      newOperands.push_back(tensorOperand);
+      continue;
+    }
+
+    // Tensor operands that have not been bufferized trigger an error.
+    int64_t idx = opOperand.getOperandNumber();
+    Value buffer = bvm.lookupOrNull(tensorOperand);
+    if (!buffer)
+      return callOp->emitError() << " missing buffer for operand #" << idx;
+
+    // Caller / callee type mismatch is handled with a CastOp.
+    auto funcOp = getCalledFunction(callOp);
+    auto memRefType = funcOp.getType().getInput(idx).dyn_cast<MemRefType>();
+    // Since we don't yet have a clear layout story, buffer_cast may
+    // conservatively turn tensors into more dynamic memrefs than
+    // necessary. If the memref type of the callee does not match, introduce
+    // an extra memref.cast that will either canonicalize away or fail
+    // compilation until we can do something better.
+    if (memRefType && buffer.getType() != memRefType) {
+      Value castBuffer =
+          b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
+      aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
+      map(bvm, tensorOperand, castBuffer);
+      buffer = castBuffer;
+    }
+    newOperands.push_back(buffer);
+  }
+
+  // 2. Bodiless CallOp bufferization.
+  // If `callOp` calls a FuncOp without a body, only the arguments have been
+  // bufferized. Such a function cannot return a tensor and its results do not
+  // change as a result of bufferization.
+  if (funcOp.body().empty()) {
+    if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
+      return callOp->emitError() << "Bodiless callee cannot return a tensor";
+    callOp->setOperands(newOperands);
+    LLVM_DEBUG(DBGS() << "Map: inplace CallOp " << *callOp << "\n");
+    return success();
+  }
+
+  // 3. CallOp bufferization.
+  // Clone the CallOp with its attributes; its results may change as a
+  // result of bufferization.
+  if (!isa<CallOp>(callOp.getOperation()))
+    return callOp->emitError() << "expected a CallOp";
+  ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+  if (!returnOp)
+    return funcOp->emitError() << "no unique ReturnOp";
+
+  // For each FuncOp result, keep track of which inplace argument it reuses.
+  SmallVector<Type> returnTypes;
+  for (OpOperand &returnOperand : returnOp->getOpOperands()) {
+    // If the return operand is equivalent to some bbArg, no need to return it.
+    Value returnVal = returnOperand.get();
+    if (BlockArgument bbArg =
+            getEquivalentEnclosingFuncBBArg(returnVal, aliasInfo)) {
+      Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
+      Value buffer = newOperands[bbArg.getArgNumber()];
+      aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
+      map(bvm, oldRes, buffer);
+      // Add a TensorLoadOp that must fold away so we can eventually erase
+      // callOp.
+      Value tensorLoad =
+          b.create<memref::TensorLoadOp>(callOp.getLoc(), buffer);
+      // Replace all uses of the CallOp results so we can erase the CallOp.
+      oldRes.replaceAllUsesWith(tensorLoad);
+      aliasInfo.insertNewBufferEquivalence(tensorLoad, buffer);
+      map(bvm, tensorLoad, buffer);
+      continue;
+    }
+    // TODO: Need to hoist above function boundary. If this is not possible due
+    // to data-dependent sizes, we need a better type than memref.
+    if (Operation *allocOp = getEquivalentAlloc(returnVal, aliasInfo))
+      return allocOp->emitError() << " needs hoist across function boundary\n";
+    int64_t returnIdx = returnOperand.getOperandNumber();
+    return returnOp->emitError() << " bufferize result #" << returnIdx << "\n";
+  }
+  Operation *newCallOp = b.create<CallOp>(callOp.getLoc(), funcOp.sym_name(),
                                          returnTypes, newOperands);
+  newCallOp->setAttrs(callOp->getAttrs());
+  return success();
+}
+
 /// DimOp tensor operand is modified inplace. This allows leaving dead
 /// tensors behind that will get DCE'd.
 static LogicalResult bufferize(OpBuilder &b, memref::DimOp dimOp,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   if (dimOp.memrefOrTensor().getType().isa<TensorType>()) {
     Value v = lookup(bvm, dimOp.memrefOrTensor());
     if (!v)
       return failure();
@@ -1311,13 +1506,11 @@
 
 static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   Location loc = forOp.getLoc();
 
-  LLVM_DEBUG(DBGS() << "bufferize: " << *forOp << "\n");
-
   // If inPlace, just forward the buffer.
   // Otherwise alloc and copy.
   b.setInsertionPoint(forOp);
@@ -1331,7 +1524,8 @@
     Value operandBuffer = lookup(bvm, operand);
     Value resultBuffer = operandBuffer;
     if (getInPlace(opResult) != InPlaceSpec::True) {
-      resultBuffer = createNewAllocDeallocPairForShapedValue(b, loc, operand);
+      resultBuffer =
+          createNewAllocDeallocPairForShapedValue(b, loc, operand, aliasInfo);
       // If the tensor comes from `linalg::InitTensorOp`, the value is
       // unitialized and we do not need to copy.
      // TODO: if the matching bbArg does not bufferize to a read is more
@@ -1350,7 +1544,7 @@
 /// FuncOp always creates TensorToMemRef ops.
 static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPointToStart(&funcOp.body().front());
@@ -1364,9 +1558,10 @@
     Type memRefType = rankedTensorType
                           ? getDynamicMemRefType(rankedTensorType)
                           : getContiguousOrUnrankedMemRefType(tensorType);
-    Value tensorToMemref =
+    Value bufferCast =
         b.create<memref::BufferCastOp>(funcOp.getLoc(), memRefType, bbArg);
-    map(bvm, bbArg, tensorToMemref);
+    aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
+    map(bvm, bbArg, bufferCast);
   }
   return success();
 }
@@ -1374,7 +1569,7 @@
 /// ReturnOp always creates memref::TensorLoadOp.
 static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(returnOp);
@@ -1388,7 +1583,10 @@
     Value v = lookup(bvm, operand.get());
     if (!v)
       return failure();
-    operand.set(b.create<memref::TensorLoadOp>(returnOp.getLoc(), v));
+    Value returnTensor = b.create<memref::TensorLoadOp>(returnOp.getLoc(), v);
+    operand.set(returnTensor);
+    aliasInfo.insertNewBufferEquivalence(returnTensor, v);
+    map(bvm, returnTensor, v);
   }
   return success();
 }
@@ -1400,7 +1598,7 @@
 /// isolation.
 static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   LDBG("bufferize: " << *extractSliceOp << '\n');
 
   // Take a guard before anything else.
@@ -1420,8 +1618,8 @@
   Value alloc;
   auto inPlace = getInPlace(extractSliceOp->getResult(0));
   if (inPlace != InPlaceSpec::True) {
-    alloc = createNewAllocDeallocPairForShapedValue(b, loc,
-                                                    extractSliceOp.result());
+    alloc = createNewAllocDeallocPairForShapedValue(
+        b, loc, extractSliceOp.result(), aliasInfo);
     b.setInsertionPointAfter(alloc.getDefiningOp());
   }
 
@@ -1435,6 +1633,8 @@
   Value subView = b.create<memref::SubViewOp>(
       loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
      extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
+  // Insert new alias.
+  aliasInfo.insertNewBufferAlias(subView, srcMemref);
 
   /// If not inplaceable, copy.
   if (alloc) {
@@ -1448,7 +1648,7 @@
 
 static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   LDBG("bufferize: " << *insertSliceOp << '\n');
 
   // Take a guard before anything else.
@@ -1466,8 +1666,8 @@
     // cloning the whole tensor on every single iteration and is a symptom
     // of a catastrophically bad scheduling decision.
     // TODO: be very loud about it or even consider failing the pass.
-    Value newDstMemref =
-        createNewAllocDeallocPairForShapedValue(b, loc, insertSliceOp.result());
+    Value newDstMemref = createNewAllocDeallocPairForShapedValue(
+        b, loc, insertSliceOp.result(), aliasInfo);
     b.setInsertionPointAfter(newDstMemref.getDefiningOp());
     b.create<CopyOp>(insertSliceOp.getLoc(), dstMemref, newDstMemref);
     dstMemref = newDstMemref;
@@ -1497,6 +1697,8 @@
     Value subView = b.create<memref::SubViewOp>(
         loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
         insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+    // Insert new alias.
+    aliasInfo.insertNewBufferAlias(subView, dstMemref);
     b.create<CopyOp>(insertSliceOp.getLoc(), srcMemref, subView);
   }
 
@@ -1507,7 +1709,7 @@
 
 static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);
@@ -1534,8 +1736,8 @@
   // If transfer_write is not inPlace, allocate a new buffer.
   Value newInputBuffer;
   if (inPlace != InPlaceSpec::True) {
-    newInputBuffer =
-        createNewAllocDeallocPairForShapedValue(b, loc, writeOp.result());
+    newInputBuffer = createNewAllocDeallocPairForShapedValue(
+        b, loc, writeOp.result(), aliasInfo);
     b.setInsertionPointAfter(newInputBuffer.getDefiningOp());
     map(bvm, writeOp.result(), newInputBuffer);
   } else {
@@ -1561,7 +1763,7 @@
 
 static LogicalResult bufferize(OpBuilder &b, scf::YieldOp yieldOp,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(yieldOp);
@@ -1612,7 +1814,7 @@
   // If `extractSliceOp` were to be bufferized inplace, it cannot end up
   // aliasing a write into a non-writeable buffer.
   bool wouldCreateAliasingWriteToNonWriteableBuffer =
-      aliasInfo.aliasesInPlaceWrite(extractSliceOp) &&
+      aliasInfo.aliasesInPlaceWrite(extractSliceOp.result()) &&
       aliasInfo.aliasesNonWriteableBuffer(extractSliceOp->getOpOperand(0));
 
   if (wouldCreateAliasingWriteToNonWriteableBuffer)
@@ -1736,7 +1938,6 @@
       return extractSliceOps.push_back(extractSliceOp);
     if (auto insertSliceOp = dyn_cast<InsertSliceOp>(op))
       return insertSliceOps.push_back(insertSliceOp);
-    auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
     // No tensors => no buffers.
     if (none_of(op->getOperandTypes(), isaTensor) &&
         none_of(op->getResultTypes(), isaTensor))
@@ -1787,12 +1988,12 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Bufferization entry-point.
+// Bufferization entry-point for functions.
 //===----------------------------------------------------------------------===//
 
 static LogicalResult
 bufferizeFuncOpInternals(FuncOp funcOp, BlockAndValueMapping &bvm,
-                         const BufferizationAliasInfo &aliasInfo) {
+                         BufferizationAliasInfo &aliasInfo) {
   LLVM_DEBUG(llvm::dbgs() << "\n\n");
   LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n');
   OpBuilder b(funcOp->getContext());
@@ -1800,42 +2001,53 @@
   if (failed(bufferize(b, funcOp, bvm, aliasInfo)))
     return failure();
   // Walk in PreOrder to ensure ops with regions are handled before their body.
-  WalkResult result = funcOp.walk([&](Operation *op) {
-    LogicalResult status =
-        TypeSwitch<Operation *, LogicalResult>(op)
-            // Skip BufferCast and TensorLoad ops.
-            // clang-format off
-            .Case<memref::BufferCastOp,
-                  memref::TensorLoadOp>(
-                [&](auto) { return success(); })
-            .Case<memref::DimOp,
-                  scf::ForOp,
-                  LinalgOp,
-                  ReturnOp,
-                  ExtractSliceOp,
-                  InsertSliceOp,
-                  VectorTransferOpInterface,
-                  scf::YieldOp>(
-                [&](auto op) {
-                  LDBG("Begin buferize:\n" << op << '\n');
-                  return bufferize(b, op, bvm, aliasInfo);
-                })
-            // clang-format on
-            .Default([&](Operation *op) {
-              auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
-              if (any_of(op->getOperandTypes(), isaTensor) ||
-                  any_of(op->getResultTypes(), isaTensor))
-                return failure();
-              return success();
-            });
-    if (failed(status)) {
-      op->emitError("Failed bufferization");
-      return WalkResult::interrupt();
-    }
-    return WalkResult::advance();
-  });
+  // Since the walk has to be PreOrder, ops cannot be erased during the walk;
+  // collect them and erase them after the walk. This is the case for CallOp.
+  SmallVector<Operation *> toErase;
+  WalkResult result =
+      funcOp.walk([&](Operation *op) -> WalkResult {
+        WalkResult result =
+            TypeSwitch<Operation *, LogicalResult>(op)
+                // Skip BufferCast and TensorLoad ops.
+                // clang-format off
+                .Case<memref::BufferCastOp,
+                      memref::TensorLoadOp>(
+                    [&](auto) { return success(); })
+                .Case<CallOpInterface,
+                      memref::DimOp,
+                      scf::ForOp,
+                      LinalgOp,
+                      ReturnOp,
+                      ExtractSliceOp,
+                      InsertSliceOp,
+                      VectorTransferOpInterface,
+                      scf::YieldOp>(
+                    [&](auto op) {
+                      LDBG("Begin bufferize:\n" << op << '\n');
+                      return bufferize(b, op, bvm, aliasInfo);
+                    })
+                // clang-format on
+                .Default([&](Operation *op) {
+                  auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
+                  if (any_of(op->getOperandTypes(), isaTensor) ||
+                      any_of(op->getResultTypes(), isaTensor))
+                    return failure();
+                  return success();
+                });
+
+        // Register post-walk erasure, if necessary.
+        if (isa<CallOpInterface>(op))
+          if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
+              llvm::any_of(op->getResultTypes(), isaTensor))
+            toErase.push_back(op);
+
+        return result;
+      });
   LDBG("End BufferizeFuncOpInternals:\n" << funcOp << '\n');
+  for (Operation *op : toErase)
+    op->erase();
+
   return failure(result.wasInterrupted());
 }
 
@@ -1884,6 +2096,164 @@
 // Bufferization entry-point for modules.
 //===----------------------------------------------------------------------===//
 
+/// Modify the `funcOp` arg and return types to replace tensor types with
+/// buffer types.
+// TODO: Generalize the use of contiguous MemRef at the function boundary.
+static void bufferizeFunctionArgAndReturnTypes(FuncOp funcOp,
+                                               TypeRange returnTypes) {
+  auto argTypes = llvm::to_vector<4>(
+      llvm::map_range(funcOp.getType().getInputs(), [](Type t) -> Type {
+        // TODO: non-zero address space.
+        // TODO: layout information if relevant.
+        if (auto tensorType = t.dyn_cast<TensorType>())
+          return getContiguousOrUnrankedMemRefType(tensorType);
+        return t;
+      }));
+  funcOp.setType(
+      FunctionType::get(funcOp->getContext(), argTypes, returnTypes));
+}
+
+/// Return the op with Allocate MemoryEffect if `v` is equivalent to such an
+/// op. Return null otherwise.
+static Operation *getEquivalentAlloc(Value value,
+                                     const BufferizationAliasInfo &aliasInfo) {
+  Operation *res = nullptr;
+  aliasInfo.applyOnEquivalenceClass(value, [&](Value v) {
+    if (!res)
+      if (auto interface =
+              dyn_cast_or_null<MemoryEffectOpInterface>(v.getDefiningOp()))
+        if (auto effect =
+                interface.getEffectOnValue<MemoryEffects::Allocate>(v))
+          res = v.getDefiningOp();
+  });
+  return res;
+}
+
+/// Return the first argument of the enclosing FuncOp that is equivalent to `v`.
+/// Return null if no such bbArg can be found.
+static BlockArgument
+getEquivalentEnclosingFuncBBArg(Value v,
+                                const BufferizationAliasInfo &aliasInfo) {
+  Operation *op = v.getParentBlock()->getParentOp();
+  FuncOp funcOp = llvm::dyn_cast<FuncOp>(op);
+  if (!funcOp)
+    funcOp = op->getParentOfType<FuncOp>();
+  for (BlockArgument bbArg : funcOp.getArguments())
+    if (aliasInfo.areEquivalentBufferizedValues(v, bbArg))
+      return bbArg;
+  return nullptr;
+}
+
+/// Search `funcOp` for the following pattern for each result to determine
+/// whether it can fold onto an argument:
+/// ```
+///    func @foo(%A: tensor<...>, ..., %Z: tensor<...>) ->
+///        (tensor<...>, ..., tensor<...>)
+///    {
+///       %p = memref.buffer_cast(%some_arg): ...
+///       ... // uses of %p (read or writes)
+///       %t = memref.tensor_load %p: ...
+///       return ..., %t, ...: ..., tensor<...>, ...
+///    }
+/// ```
+/// Rewrite the `funcOp` arguments, return values, and terminator into buffer
+/// form (using the canonical memref layout for now), according to the
+/// inPlace-bufferizable information of the function arguments.
+static LogicalResult
+bufferizeFuncOpBoundary(FuncOp funcOp, BufferizationAliasInfo &aliasInfo) {
+  LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
+
+  // Corner case: Bodiless FuncOp
+  // ============================
+  // The body of such functions is assumed opaque and we can't know the
+  // bufferization contract they want to enforce atm.
+  // As a consequence, only support functions that don't return any tensor atm.
+  if (funcOp.getBody().empty()) {
+    if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
+      return funcOp->emitError() << "Cannot bufferize bodiless function that "
+                                 << "returns a tensor";
+    bufferizeFunctionArgAndReturnTypes(funcOp, funcOp.getType().getResults());
+    LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary no fun body: " << funcOp);
+    return success();
+  }
+
+  // Support only a single return-terminated block in the function.
+  ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+  if (!returnOp)
+    return funcOp->emitError() << "no unique ReturnOp";
+
+  // 1. For each FuncOp result, keep track of which inplace argument it reuses.
+  SmallVector<Value> returnValues;
+  for (OpOperand &returnOperand : returnOp->getOpOperands()) {
+    // If the return operand is equivalent to some bbArg, no need to return it.
+    Value returnVal = returnOperand.get();
+    if (getEquivalentEnclosingFuncBBArg(returnVal, aliasInfo))
+      continue;
+    // TODO: Need to hoist above function boundary. If this is not possible due
+    // to data-dependent sizes, we need a better type than memref.
+    if (Operation *allocOp = getEquivalentAlloc(returnVal, aliasInfo))
+      return allocOp->emitError() << " needs hoist across function boundary\n";
+    int64_t returnIdx = returnOperand.getOperandNumber();
+    return returnOp->emitError() << " bufferize result #" << returnIdx << "\n";
+  }
+
+  // 2. Rewrite the terminator without the inPlace bufferizable values.
+  OpBuilder(returnOp).create<ReturnOp>(returnOp.getLoc(), returnValues);
+  returnOp->erase();
+
+  // 3. Rewrite the bbArgs.
+  // Iterate on the original `numArgs` and replace them in order.
+  // This guarantees the argument order still matches after the rewrite.
+  Block &frontBlock = funcOp.body().front();
+  unsigned numArgs = frontBlock.getNumArguments();
+  for (unsigned idx = 0; idx < numArgs; ++idx) {
+    auto bbArg = frontBlock.getArgument(0);
+    auto tensorType = bbArg.getType().dyn_cast<TensorType>();
+    if (!tensorType) {
+      frontBlock.addArgument(bbArg.getType());
+      bbArg.replaceAllUsesWith(frontBlock.getArguments().back());
+      frontBlock.eraseArgument(0);
+      continue;
+    }
+
+    // TODO: non-zero address space.
+    // TODO: layout information if relevant.
+    Value memref =
+        frontBlock.addArgument(getContiguousOrUnrankedMemRefType(tensorType));
+    OpBuilder b(funcOp->getContext());
+    b.setInsertionPointToStart(&frontBlock);
+    // Replace all uses of bbArg through a BufferCastOp by a memref::CastOp.
+    for (auto &use : llvm::make_early_inc_range(bbArg.getUses())) {
+      if (auto bufferCastOp = dyn_cast<memref::BufferCastOp>(use.getOwner())) {
+        auto castOp = b.create<memref::CastOp>(
+            funcOp.getLoc(), bufferCastOp.memref().getType(), memref);
+        bufferCastOp.memref().replaceAllUsesWith(castOp);
+        aliasInfo.insertNewBufferEquivalence(castOp.dest(),
+                                             bufferCastOp.memref());
+        continue;
+      }
+    }
+    // Replace all remaining uses by a tensor_load; these *must* canonicalize
+    // away.
+    if (!bbArg.use_empty()) {
+      auto tensorLoadOp =
+          b.create<memref::TensorLoadOp>(funcOp.getLoc(), memref);
+      aliasInfo.insertNewBufferEquivalence(tensorLoadOp, bbArg);
+      bbArg.replaceAllUsesWith(tensorLoadOp);
+    }
+    frontBlock.eraseArgument(0);
+    // TODO: erase aliasInfo entries.
+  }
+
+  // 4. Rewrite the FuncOp type to buffer form.
+  bufferizeFunctionArgAndReturnTypes(funcOp,
+                                     ValueRange{returnValues}.getTypes());
+
+  LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary:\n" << funcOp);
+
+  return success();
+}
+
 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
 /// callee-caller order (i.e. callees without callers first).
 /// Store the map of FuncOp to all its callers in `callerMap`.
@@ -1978,15 +2348,38 @@
       return;
     }
 
-    // TODO: Bufferization phase.
+    // Bufferization phase.
+    if (!testAnalysisOnly) {
+      BlockAndValueMapping tensorToBufferMap;
+      if (failed(
+              bufferizeFuncOpInternals(funcOp, tensorToBufferMap, aliasInfo))) {
+        signalPassFailure();
+        return;
+      }
+    }
   }
   // Don't drop the attributes if we only want to report the analysis.
   if (testAnalysisOnly)
     return;
 
+  for (FuncOp funcOp : orderedFuncOps) {
+    // Note: It would be good to apply cleanups here but we cannot as domInfo
+    // and aliasInfo would be invalidated.
+    if (failed(bufferizeFuncOpBoundary(funcOp, aliasInfo))) {
+      signalPassFailure();
+      return;
+    }
+  }
+
   // Post-pass cleanup of inplaceable attributes.
   moduleOp.walk(
       [&](Operation *op) { op->removeAttr(kInPlaceResultsAttrName); });
+
+  OpPassManager cleanupPipeline(OpPassManager("module"));
+  cleanupPipeline.addPass(createCanonicalizerPass());
+  cleanupPipeline.addPass(createCSEPass());
+  cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
+  (void)runPipeline(cleanupPipeline, moduleOp);
 }
 
 std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass() {
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -0,0 +1,34 @@
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize -split-input-file -verify-each=0
+
+// -----
+
+func @scf_for_with_tensor.insert_slice(
+    %A : tensor<?xf32>, %B : tensor<?xf32>, %C : tensor<4xf32>,
+    %lb : index, %ub : index, %step : index)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  %r0:2 = scf.for %i = %lb to %ub step %step iter_args(%tA = %A, %tB = %B)
+      -> (tensor<?xf32>, tensor<?xf32>)
+  {
+    %ttA = tensor.insert_slice %C into %tA[0][4][1] : tensor<4xf32> into tensor<?xf32>
+    %ttB = tensor.insert_slice %C into %tB[0][4][1] : tensor<4xf32> into tensor<?xf32>
+    scf.yield %ttA, %ttB : tensor<?xf32>, tensor<?xf32>
+  }
+
+  // Swaparoo requires bufferizing the whole function to figure out who's who.
+  return %r0#1, %r0#0: tensor<?xf32>, tensor<?xf32>
+}
+
+func @bar(
+    %A : tensor<?xf32> {linalg.inplaceable = true},
+    %B : tensor<?xf32> {linalg.inplaceable = true},
+    %C : tensor<4xf32> {linalg.inplaceable = true},
+    %lb : index, %ub : index, %step : index)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  %r0:2 = call @scf_for_with_tensor.insert_slice(%A, %B, %C, %lb, %ub, %step) :
+      (tensor<?xf32>, tensor<?xf32>, tensor<4xf32>, index, index, index)
+        -> (tensor<?xf32>, tensor<?xf32>)
+
+  return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
+}