diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp
@@ -114,7 +114,9 @@
 #include "mlir/Dialect/Vector/VectorOps.h"
 #include "mlir/IR/Operation.h"
 #include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/BufferUtils.h"
+#include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/EquivalenceClasses.h"
@@ -136,6 +138,8 @@
 // Generic helpers.
 //===----------------------------------------------------------------------===//
 
+static bool isaTensor(Type t) { return t.isa<TensorType>(); };
+
 /// Return the FuncOp called by `callOp`.
 static FuncOp getCalledFunction(CallOpInterface callOp) {
   SymbolRefAttr sym = callOp.getCallableForCallee().dyn_cast<SymbolRefAttr>();
@@ -145,6 +149,20 @@
       SymbolTable::lookupNearestSymbolFrom(callOp, sym));
 }
 
+/// Return the unique ReturnOp that terminates `funcOp`.
+/// Return nullptr if there is no such unique ReturnOp.
+static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
+  ReturnOp returnOp;
+  for (Block &b : funcOp.body()) {
+    if (auto candidateOp = dyn_cast<ReturnOp>(b.getTerminator())) {
+      if (returnOp)
+        return nullptr;
+      returnOp = candidateOp;
+    }
+  }
+  return returnOp;
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific BlockAndValueMapping support with debugging.
 //===----------------------------------------------------------------------===//
@@ -163,7 +181,7 @@
 }
 
 /// Wrapper for better debugging.
-static Value lookup(BlockAndValueMapping &bvm, Value key) {
+static Value lookup(const BlockAndValueMapping &bvm, Value key) {
   // TODO: if key comes from bbArg, forward.
   assert(key.getType().isa<TensorType>());
   Value v = bvm.lookupOrNull(key);
@@ -338,10 +356,8 @@
                 VectorTransferOpInterface,
                 scf::YieldOp>(op)
       // clang-format on
-         || (none_of(op->getResultTypes(),
-                     [](Type t) { return t.isa<TensorType>(); }) &&
-             none_of(op->getOperandTypes(),
-                     [](Type t) { return t.isa<TensorType>(); }));
+         || (none_of(op->getResultTypes(), isaTensor) &&
+             none_of(op->getOperandTypes(), isaTensor));
 }
 
 /// Return the OpResult that may bufferize into the same buffer as `opOperand`
@@ -568,14 +584,22 @@
   /// beginning the alias and equivalence sets only contain `v` itself.
   void createAliasInfoEntry(Value v);
 
+  /// Insert an info entry for `newValue` and merge its alias set with that of
+  /// `alias`.
+  void insertNewBufferAlias(Value newValue, Value alias);
+
+  /// Insert an info entry for `newValue` and merge its alias set with that of
+  /// `alias`. Additionally, merge their equivalence classes.
+  void insertNewBufferEquivalence(Value newValue, Value alias);
+
   /// Return true if the buffer to which `operand` would bufferize aliases a
   /// buffer that is known to not be writeable. This implies that the matching
   /// OpResult cannot be bufferized inplace.
   bool aliasesNonWriteableBuffer(OpOperand &operand) const;
 
   /// Return true if the buffer to which `operand` would bufferize is equivalent
-  /// to some use that would bufferize to a write to a buffer.
-  bool aliasesInPlaceWrite(ExtractSliceOp extractSliceOp) const;
+  /// to some buffer write.
+  bool aliasesInPlaceWrite(Value v) const;
 
   /// Set the inPlace bufferization spec to true.
   /// Merge result's and operand's aliasing sets and iterate to a fixed point.
@@ -610,6 +634,9 @@
   bool isSourceEquivalentToAMatchingExtractSliceOp(
       InsertSliceOp insertSliceOp) const;
 
+  /// Apply `fun` to all the members of the equivalence class of `v`.
+  void applyOnEquivalenceClass(Value v, function_ref<void(Value)> fun) const;
+
   /// Print to `os`.
   void print(raw_ostream &os) const;
 
@@ -617,8 +644,9 @@
   void dump() const { print(llvm::errs()); }
 
 private:
-  /// Check aliasInfo for `v` exists and return a reference to it.
+  /// Check that aliasInfo for `v` exists and return a reference to it.
   DenseSet<Value> &getAliasInfoRef(Value v);
+  const DenseSet<Value> &getAliasInfoRef(Value v) const { return const_cast<BufferizationAliasInfo *>(this)->getAliasInfoRef(v); }
 
@@ -731,6 +759,29 @@
   equivalentInfo.insert(v);
 }
 
+/// Insert an info entry for `newValue` and merge its alias set with that of
+/// `alias`.
+void BufferizationAliasInfo::insertNewBufferAlias(Value newValue, Value alias) {
+  assert(aliasInfo.find(alias) != aliasInfo.end() && "Missing alias entry");
+  createAliasInfoEntry(newValue);
+  mergeAliases(newValue, alias);
+  mergeAliasesToFixedPoint();
+  if (auto memrefCastOp = newValue.getDefiningOp<memref::CastOp>())
+    if (aliasInfo.find(memrefCastOp.source()) == aliasInfo.end())
+      insertNewBufferAlias(memrefCastOp.source(), newValue);
+}
+
+/// Insert an info entry for `newValue` and merge its alias set with that of
+/// `alias`. Additionally, merge their equivalence classes.
+void BufferizationAliasInfo::insertNewBufferEquivalence(Value newValue,
+                                                        Value alias) {
+  insertNewBufferAlias(newValue, alias);
+  equivalentInfo.unionSets(newValue, alias);
+  if (auto memrefCastOp = newValue.getDefiningOp<memref::CastOp>())
+    if (aliasInfo.find(memrefCastOp.source()) == aliasInfo.end())
+      insertNewBufferEquivalence(memrefCastOp.source(), newValue);
+}
+
 /// Return true if the buffer to which `operand` would bufferize aliases a
 /// buffer that is known to not be writeable. This implies that the matching
 /// OpResult cannot be bufferized inplace.
@@ -746,13 +797,13 @@
         LDBG("-----------bbArg is writeable -> skip: " << bbArg << '\n');
         continue;
       }
-      LDBG("-----------notWriteable: " << v << '\n');
+      LDBG("-----------notWriteable\n");
       return true;
     }
 
     if (Operation *op = v.getDefiningOp()) {
       if (isa<ConstantOp>(op) || !hasKnownBufferizationAliasingBehavior(op)) {
-        LDBG("-----------notWriteable: " << v << '\n');
+        LDBG("-----------notWriteable\n");
        return true;
       }
     }
@@ -762,12 +813,11 @@
 }
 
 /// Return true if the buffer to which `operand` would bufferize is equivalent
-/// to some use that would bufferize to a write to a buffer.
-bool BufferizationAliasInfo::aliasesInPlaceWrite(
-    ExtractSliceOp extractSliceOp) const {
+/// to some buffer write.
+bool BufferizationAliasInfo::aliasesInPlaceWrite(Value value) const {
   LDBG("----Start aliasesInPlaceWrite\n");
-  LDBG("-------for op: " << *extractSliceOp.getOperation() << '\n');
-  for (Value v : getAliasInfoRef(extractSliceOp.result())) {
+  LDBG("-------for : " << value << '\n');
+  for (Value v : getAliasInfoRef(value)) {
     for (auto &use : v.getUses()) {
       if (bufferizesToMemoryWrite(use, InPlaceSpec::True)) {
         LDBG("-----------wants to bufferize to inPlace write: "
@@ -776,7 +826,7 @@
       }
     }
   }
-  LDBG("----------->extract_slice does not alias an inplace write");
+  LDBG("----------->does not alias an inplace write\n");
   return false;
 }
 
@@ -799,10 +849,6 @@
     setInPlaceOpResult(result, InPlaceSpec::False);
 }
 
-static bool isTheSSADef(OpOperand &write, OpOperand &read) {
-  return getInplaceableOpResult(write) == read.get();
-}
-
 /// Return true if it is possible to find an inplace write W among the uses of
 /// aliasInfo[result], and a read R among the uses of aliasInfo[result],
 /// such that W and R interfere.
@@ -915,6 +961,16 @@
   return false;
 }
 
+/// Apply `fun` to all the members of the equivalence class of `v`.
+void BufferizationAliasInfo::applyOnEquivalenceClass(
+    Value v, function_ref<void(Value)> fun) const {
+  for (auto it = equivalentInfo.findLeader(v),
+            eit = equivalentInfo.member_end();
+       it != eit; ++it) {
+    fun(*it);
+  }
+}
+
 void BufferizationAliasInfo::print(raw_ostream &os) const {
   os << "\n/========================== AliasInfo "
         "==========================\n";
@@ -1101,6 +1157,21 @@
   return existsInterleavedValueClobber(aliasingRead, aliasingWrite, domInfo);
 }
 
+//===----------------------------------------------------------------------===//
+// Forward declarations.
+//===----------------------------------------------------------------------===//
+
+/// Return the op with Allocate MemoryEffect if `v` is equivalent to such an
+/// op. Return null otherwise.
+static Operation *getEquivalentAlloc(Value value,
+                                     const BufferizationAliasInfo &aliasInfo);
+
+/// Return the first argument of the enclosing FuncOp that is equivalent to `v`.
+/// Return null if no such bbArg can be found.
+static BlockArgument
+getEquivalentEnclosingFuncBBArg(Value v,
+                                const BufferizationAliasInfo &aliasInfo);
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific MemRefType support.
 //===----------------------------------------------------------------------===//
@@ -1147,6 +1218,48 @@
                          stridedLayout, addressSpace);
 }
 
+/// Return the FunctionType with `argumentTypes` and `resultTypes` where each
+/// tensor is replaced by the corresponding buffer type.
+/// In order for all the callers to agree, this *must* bufferize to the most
+/// dynamic buffer type supported.
+/// A later pass across all CallOps in the module can decide whether to
+/// simplify the types according to some cost model.
+static FunctionType getBufferizedFunctionType(MLIRContext *ctx,
+                                              TypeRange argumentTypes,
+                                              TypeRange resultTypes) {
+  auto rewrite = [](Type t) -> Type {
+    // TODO: non-zero address space.
+    // TODO: layout information if relevant.
+    if (auto rankedTensorType = t.dyn_cast<RankedTensorType>())
+      return getDynamicMemRefType(rankedTensorType);
+    if (auto tensorType = t.dyn_cast<TensorType>())
+      return getContiguousOrUnrankedMemRefType(tensorType);
+    return t;
+  };
+  auto argTypes = llvm::to_vector<4>(llvm::map_range(argumentTypes, rewrite));
+  auto retTypes = llvm::to_vector<4>(llvm::map_range(resultTypes, rewrite));
+  return FunctionType::get(ctx, argTypes, retTypes);
+};
+
+/// If an entry for `funcOp` is available in `bufferizedFunctionTypes`, return
+/// it. Otherwise, construct a new entry based on `argumentTypes` and
+/// `resultTypes`.
+// TODO: improve the layering.
+static FunctionType getOrCreateBufferizedFunctionType(
+    FuncOp funcOp, TypeRange argumentTypes, TypeRange resultTypes,
+    DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
+  auto it = bufferizedFunctionTypes.find(funcOp);
+  if (it != bufferizedFunctionTypes.end())
+    return it->second;
+
+  auto it2 = bufferizedFunctionTypes.try_emplace(
+      funcOp, getBufferizedFunctionType(funcOp.getContext(), argumentTypes,
+                                        resultTypes));
+  assert(it2.second && "try_emplace failed");
+  LDBG("FT: " << funcOp.getType() << " -> " << it2.first->second << "\n");
+  return it2.first->second;
+}
+
 //===----------------------------------------------------------------------===//
 // Bufferization-specific scoped alloc/dealloc insertion support.
 //===----------------------------------------------------------------------===//
@@ -1154,8 +1267,10 @@
 /// Create an AllocOp/DeallocOp pair, where the AllocOp is after
 /// `shapedValue.getDefiningOp` (or at the top of the block in case of a
 /// bbArg) and the DeallocOp is at the end of the block.
-static Value createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
-                                                     Value shapedValue) {
+static Value
+createNewAllocDeallocPairForShapedValue(OpBuilder &b, Location loc,
+                                        Value shapedValue,
+                                        BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -1185,9 +1300,12 @@
         b.create<memref::DimOp>(loc, shapedValue, dim.index()));
   Value allocated =
       b.create<memref::AllocOp>(loc, allocMemRefType, dynShape);
+  aliasInfo.createAliasInfoEntry(allocated);
   Value casted = allocated;
-  if (memRefType != allocMemRefType)
+  if (memRefType != allocMemRefType) {
     casted = b.create<memref::CastOp>(loc, memRefType, allocated);
+    aliasInfo.insertNewBufferEquivalence(casted, allocated);
+  }
   b.setInsertionPoint(allocated.getParentBlock()->getTerminator());
   b.create<memref::DeallocOp>(loc, allocated);
   return casted;
@@ -1208,7 +1326,8 @@
 static LogicalResult
 allocateBuffersForResults(OpBuilder &b, Location loc, LinalgOp op,
                           SmallVectorImpl<Value> &resultBuffers,
-                          BlockAndValueMapping &bvm) {
+                          BlockAndValueMapping &bvm,
+                          BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
 
@@ -1232,7 +1351,8 @@
 
     // Otherwise, `op` is not inplaceable and we need to allocate its result.
     Value dimTensor = bvm.lookupOrDefault(output);
-    Value alloc = createNewAllocDeallocPairForShapedValue(b, loc, dimTensor);
+    Value alloc =
+        createNewAllocDeallocPairForShapedValue(b, loc, dimTensor, aliasInfo);
     b.setInsertionPointAfter(alloc.getDefiningOp());
     resultBuffers.push_back(alloc);
 
@@ -1254,7 +1374,7 @@
 /// Generic conversion for any LinalgOp on tensors.
 static LogicalResult bufferize(OpBuilder &b, LinalgOp op,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
@@ -1263,8 +1383,6 @@
   if (!op.hasTensorSemantics())
     return failure();
 
-  LDBG("bufferize: " << *op << '\n');
-
   b.setInsertionPoint(op);
   Location loc = op.getLoc();
   SmallVector<Value> newInputBuffers;
@@ -1280,7 +1398,8 @@
   }
   SmallVector<Value> newOutputBuffers;
   // Try to allocate new buffers depending on op's inplace semantics.
-  if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm)))
+  if (failed(allocateBuffersForResults(b, loc, op, newOutputBuffers, bvm,
+                                       aliasInfo)))
     return failure();
 
   // Clone the newly bufferized op.
@@ -1299,11 +1418,150 @@
   return success();
 }
 
+/// In a first approximation, all the function arguments of a FuncOp are marked
+/// inplaceable. For now, it is the responsibility of the `callOp` bufferization
+/// to allow FuncOps that are inplaceable to write inPlace.
+static LogicalResult
+bufferize(OpBuilder &b, CallOpInterface callOp, BlockAndValueMapping &bvm,
+          BufferizationAliasInfo &aliasInfo,
+          DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
+  FuncOp funcOp = getCalledFunction(callOp);
+  if (!funcOp)
+    return success();
+
+  // Take a guard before anything else.
+  OpBuilder::InsertionGuard g(b);
+  b.setInsertionPoint(callOp);
+
+  // Only support CallOp for now.
+  if (!isa<CallOp>(callOp.getOperation()))
+    return callOp->emitError() << "expected a CallOp";
+
+  // 1. Filter return types:
+  //    - if the callee is bodiless / external, we cannot inspect it and we
+  //      cannot assume anything. We can just assert that it does not return a
+  //      tensor as this would have to bufferize to "return a memref", whose
+  //      semantics is ill-defined.
+  //    - if the callee has a body, we perform inter-procedural equivalence
+  //      analysis. When successful, a result folds onto an operand. When
+  //      unsuccessful, additional work is needed to either:
+  //        * hoist a result into an inplaceable operand or
+  //        * devise a better representation to truly return a buffer.
+  SmallVector<Type> resultTypes;
+  SmallVector<Value> hoistedArguments;
+  if (funcOp.body().empty()) {
+    if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
+      return callOp->emitError() << "Bodiless callee cannot return a tensor";
+  } else {
+    ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+    if (!returnOp)
+      return funcOp->emitError() << "no unique ReturnOp";
+
+    // For each FuncOp result, keep track of which inplace argument it reuses.
+    for (OpOperand &returnOperand : returnOp->getOpOperands()) {
+      Type returnType = returnOperand.get().getType();
+      if (!isaTensor(returnType)) {
+        resultTypes.push_back(returnType);
+        continue;
+      }
+
+      // If the return operand is equivalent to some bbArg, no need to return it.
+      Value returnVal = returnOperand.get();
+      if (BlockArgument bbArg =
+              getEquivalentEnclosingFuncBBArg(returnVal, aliasInfo)) {
+        Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
+        int64_t idx = bbArg.getArgNumber();
+        Value buffer = bvm.lookupOrNull(callOp->getOperand(idx));
+        if (!buffer)
+          return callOp->emitError() << "operand #" << idx << " not bufferized";
+        // Add CallOp operand/result equivalence: this is interprocedural info.
+        aliasInfo.insertNewBufferEquivalence(oldRes, buffer);
+        map(bvm, oldRes, buffer);
+        // Add a TensorLoadOp to kill all uses of the CallOp return.
+        // Replace all uses of the CallOp results so we can erase the CallOp.
+        // This TensorLoadOp must fold/DCE away or bufferization should be
+        // considered failed.
+        Value tensorLoad =
+            b.create<memref::TensorLoadOp>(callOp.getLoc(), buffer);
+        oldRes.replaceAllUsesWith(tensorLoad);
+        // Add new op equivalence info.
+        aliasInfo.insertNewBufferEquivalence(tensorLoad, buffer);
+        map(bvm, tensorLoad, buffer);
+        continue;
+      }
+
+      // TODO: Need to hoist above function boundary and add to
+      // `hoistedArgumentTypes`.
+      if (Operation *allocOp = getEquivalentAlloc(returnVal, aliasInfo))
+        return allocOp->emitError()
+               << " needs hoist across function boundary\n";
+
+      // Other cases legitimately need to return a tensor; this is currently
+      // not supported. For instance, if hoisting across the function boundary
+      // has failed, it may be due to e.g. data-dependent sizes. In such a
+      // case, we would need a better type than memref.
+      resultTypes.push_back(returnType);
+
+      int64_t returnIdx = returnOperand.getOperandNumber();
+      return returnOp->emitError()
+             << " bufferize result #" << returnIdx << "\n";
+    }
+  }
+
+  // 2. Compute bufferized FunctionType.
+  SmallVector<Type> argumentTypes{callOp->getOperandTypes()};
+  llvm::append_range(argumentTypes, ValueRange{hoistedArguments}.getTypes());
+  // Get the bufferized FunctionType for funcOp or construct it if not yet
+  // available.
+  FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
+      funcOp, argumentTypes, resultTypes, bufferizedFunctionTypes);
+
+  // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
+  SmallVector<Value> newOperands;
+  newOperands.reserve(callOp->getNumOperands());
+  for (OpOperand &opOperand : callOp->getOpOperands()) {
+    Value tensorOperand = opOperand.get();
+    // Non-tensor operands are just copied.
+    if (!tensorOperand.getType().isa<TensorType>()) {
+      newOperands.push_back(tensorOperand);
+      continue;
+    }
+
+    // Tensor operands are guaranteed to have been bufferized.
+    int64_t idx = opOperand.getOperandNumber();
+    Value buffer = bvm.lookupOrNull(tensorOperand);
+    assert(buffer && " missing buffer for operand");
+
+    // Caller / callee type mismatch is handled with a CastOp.
+    auto memRefType = bufferizedFuncType.getInput(idx);
+    // Since we don't yet have a clear layout story, buffer_cast may
+    // conservatively turn tensors into more dynamic memrefs than necessary.
+    // If the memref type of the callee does not match, introduce an extra
+    // memref.cast that will either canonicalize away or fail compilation
+    // until we can do something better.
+    if (buffer.getType() != memRefType) {
+      Value castBuffer =
+          b.create<memref::CastOp>(callOp.getLoc(), memRefType, buffer);
+      // Add new op equivalence info.
+      aliasInfo.insertNewBufferEquivalence(castBuffer, buffer);
+      map(bvm, tensorOperand, castBuffer);
+      buffer = castBuffer;
+    }
+    newOperands.push_back(buffer);
+  }
+
+  // 4. Create the new CallOp.
+  Operation *newCallOp = b.create<CallOp>(callOp.getLoc(), funcOp.sym_name(),
+                                          resultTypes, newOperands);
+  newCallOp->setAttrs(callOp->getAttrs());
+  return success();
+}
+
 /// DimOp tensor operand is modified inplace. This allows leaving dead
 /// tensors behind that will get DCE'd.
 static LogicalResult bufferize(OpBuilder &b, memref::DimOp dimOp,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   if (dimOp.memrefOrTensor().getType().isa<TensorType>()) {
     Value v = lookup(bvm, dimOp.memrefOrTensor());
     if (!v)
       return failure();
@@ -1315,13 +1573,11 @@
 static LogicalResult bufferize(OpBuilder &b, scf::ForOp forOp,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   Location loc = forOp.getLoc();
 
-  LLVM_DEBUG(DBGS() << "bufferize: " << *forOp << "\n");
-
   // If inPlace, just forward the buffer.
   // Otherwise alloc and copy.
   b.setInsertionPoint(forOp);
 
@@ -1335,11 +1591,12 @@
     Value operandBuffer = lookup(bvm, operand);
     Value resultBuffer = operandBuffer;
     if (getInPlace(opResult) != InPlaceSpec::True) {
-      resultBuffer = createNewAllocDeallocPairForShapedValue(b, loc, operand);
+      resultBuffer =
+          createNewAllocDeallocPairForShapedValue(b, loc, operand, aliasInfo);
       // If the tensor comes from `linalg::InitTensorOp`, the value is
       // uninitialized and we do not need to copy.
-      // TODO: if the matching bbArg does not bufferize to a read is more
-      // general.
+      // TODO: "matching bbArg does not bufferize to a read" is a more general
+      // check.
       if (!operand.getDefiningOp<linalg::InitTensorOp>())
         b.create<linalg::CopyOp>(forOp.getLoc(), operandBuffer, resultBuffer);
     }
@@ -1354,7 +1611,7 @@
 /// FuncOp always creates TensorToMemRef ops.
 static LogicalResult bufferize(OpBuilder &b, FuncOp funcOp,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPointToStart(&funcOp.body().front());
@@ -1368,9 +1625,10 @@
     Type memRefType = rankedTensorType
                           ? getDynamicMemRefType(rankedTensorType)
                           : getContiguousOrUnrankedMemRefType(tensorType);
-    Value tensorToMemref =
+    Value bufferCast =
         b.create<memref::BufferCastOp>(funcOp.getLoc(), memRefType, bbArg);
-    map(bvm, bbArg, tensorToMemref);
+    aliasInfo.insertNewBufferEquivalence(bufferCast, bbArg);
+    map(bvm, bbArg, bufferCast);
   }
   return success();
 }
@@ -1378,7 +1636,7 @@
 /// ReturnOp always creates memref::TensorLoadOp.
 static LogicalResult bufferize(OpBuilder &b, ReturnOp returnOp,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(returnOp);
@@ -1392,7 +1650,10 @@
     Value v = lookup(bvm, operand.get());
     if (!v)
       return failure();
-    operand.set(b.create<memref::TensorLoadOp>(returnOp.getLoc(), v));
+    Value returnTensor = b.create<memref::TensorLoadOp>(returnOp.getLoc(), v);
+    operand.set(returnTensor);
+    aliasInfo.insertNewBufferEquivalence(returnTensor, v);
+    map(bvm, returnTensor, v);
   }
   return success();
 }
@@ -1404,7 +1665,7 @@
 /// isolation.
 static LogicalResult bufferize(OpBuilder &b, ExtractSliceOp extractSliceOp,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   LDBG("bufferize: " << *extractSliceOp << '\n');
 
   // Take a guard before anything else.
@@ -1424,8 +1685,8 @@
   Value alloc;
   auto inPlace = getInPlace(extractSliceOp->getResult(0));
   if (inPlace != InPlaceSpec::True) {
-    alloc = createNewAllocDeallocPairForShapedValue(b, loc,
-                                                    extractSliceOp.result());
+    alloc = createNewAllocDeallocPairForShapedValue(
+        b, loc, extractSliceOp.result(), aliasInfo);
     b.setInsertionPointAfter(alloc.getDefiningOp());
   }
 
@@ -1439,6 +1700,8 @@
   Value subView = b.create<memref::SubViewOp>(
       loc, subviewMemRefType, srcMemref, extractSliceOp.getMixedOffsets(),
       extractSliceOp.getMixedSizes(), extractSliceOp.getMixedStrides());
+  // Insert new alias.
+  aliasInfo.insertNewBufferAlias(subView, srcMemref);
 
   /// If not inplaceable, copy.
   if (alloc) {
@@ -1452,7 +1715,7 @@
 
 static LogicalResult bufferize(OpBuilder &b, InsertSliceOp insertSliceOp,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   LDBG("bufferize: " << *insertSliceOp << '\n');
 
   // Take a guard before anything else.
@@ -1470,8 +1733,8 @@
     // cloning the whole tensor on every single iteration and is a symptom
     // of a catastrophically bad scheduling decision.
     // TODO: be very loud about it or even consider failing the pass.
-    Value newDstMemref =
-        createNewAllocDeallocPairForShapedValue(b, loc, insertSliceOp.result());
+    Value newDstMemref = createNewAllocDeallocPairForShapedValue(
+        b, loc, insertSliceOp.result(), aliasInfo);
     b.setInsertionPointAfter(newDstMemref.getDefiningOp());
     b.create<linalg::CopyOp>(insertSliceOp.getLoc(), dstMemref, newDstMemref);
     dstMemref = newDstMemref;
@@ -1501,6 +1764,8 @@
     Value subView = b.create<memref::SubViewOp>(
         loc, subviewMemRefType, dstMemref, insertSliceOp.getMixedOffsets(),
        insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
+    // Insert new alias.
+    aliasInfo.insertNewBufferAlias(subView, dstMemref);
     b.create<linalg::CopyOp>(insertSliceOp.getLoc(), srcMemref, subView);
   }
 
@@ -1511,7 +1776,7 @@
 
 static LogicalResult bufferize(OpBuilder &b, VectorTransferOpInterface op,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(op);
@@ -1520,8 +1785,6 @@
   if (op.getShapedType().isa<MemRefType>())
     return failure();
 
-  LDBG("bufferize: " << *op << '\n');
-
   /// transfer_read from buffer always reads from the bufferized
   /// op.source().
   if (auto readOp = dyn_cast<vector::TransferReadOp>(op.getOperation())) {
@@ -1538,8 +1801,8 @@
   // If transfer_write is not inPlace, allocate a new buffer.
   Value newInputBuffer;
   if (inPlace != InPlaceSpec::True) {
-    newInputBuffer =
-        createNewAllocDeallocPairForShapedValue(b, loc, writeOp.result());
+    newInputBuffer = createNewAllocDeallocPairForShapedValue(
+        b, loc, writeOp.result(), aliasInfo);
     b.setInsertionPointAfter(newInputBuffer.getDefiningOp());
     map(bvm, writeOp.result(), newInputBuffer);
   } else {
@@ -1565,7 +1828,7 @@
 
 static LogicalResult bufferize(OpBuilder &b, scf::YieldOp yieldOp,
                                BlockAndValueMapping &bvm,
-                               const BufferizationAliasInfo &aliasInfo) {
+                               BufferizationAliasInfo &aliasInfo) {
   // Take a guard before anything else.
   OpBuilder::InsertionGuard g(b);
   b.setInsertionPoint(yieldOp);
@@ -1616,7 +1879,7 @@
   // If `extractSliceOp` were to be bufferized inplace, it cannot end up
   // aliasing a write into a non-writeable buffer.
   bool wouldCreateAliasingWriteToNonWriteableBuffer =
-      aliasInfo.aliasesInPlaceWrite(extractSliceOp) &&
+      aliasInfo.aliasesInPlaceWrite(extractSliceOp.result()) &&
      aliasInfo.aliasesNonWriteableBuffer(extractSliceOp->getOpOperand(0));
 
   if (wouldCreateAliasingWriteToNonWriteableBuffer)
@@ -1740,7 +2003,6 @@
     return extractSliceOps.push_back(extractSliceOp);
   if (auto insertSliceOp = dyn_cast<InsertSliceOp>(op))
     return insertSliceOps.push_back(insertSliceOp);
-  auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
   // No tensors => no buffers.
   if (none_of(op->getOperandTypes(), isaTensor) &&
       none_of(op->getResultTypes(), isaTensor))
@@ -1789,12 +2051,12 @@
 }
 
 //===----------------------------------------------------------------------===//
-// Bufferization entry-point.
+// Bufferization entry-point for functions.
 //===----------------------------------------------------------------------===//
 
-static LogicalResult
-bufferizeFuncOpInternals(FuncOp funcOp, BlockAndValueMapping &bvm,
-                         const BufferizationAliasInfo &aliasInfo) {
+static LogicalResult bufferizeFuncOpInternals(
+    FuncOp funcOp, BlockAndValueMapping &bvm, BufferizationAliasInfo &aliasInfo,
+    DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
   LLVM_DEBUG(llvm::dbgs() << "\n\n");
   LDBG("Begin BufferizeFuncOpInternals:\n" << funcOp << '\n');
   OpBuilder b(funcOp->getContext());
@@ -1802,42 +2064,54 @@
   if (failed(bufferize(b, funcOp, bvm, aliasInfo)))
     return failure();
   // Walk in PreOrder to ensure ops with regions are handled before their body.
-  WalkResult result = funcOp.walk([&](Operation *op) {
-    LogicalResult status =
-        TypeSwitch<Operation *, LogicalResult>(op)
-            // Skip BufferCast and TensorLoad ops.
-            // clang-format off
-            .Case<memref::BufferCastOp,
-                  memref::TensorLoadOp>(
-                [&](auto) { return success(); })
-            .Case<memref::DimOp,
-                  scf::ForOp,
-                  LinalgOp,
-                  ReturnOp,
-                  ExtractSliceOp,
-                  InsertSliceOp,
-                  VectorTransferOpInterface,
-                  scf::YieldOp>(
-                [&](auto op) {
-                  LDBG("Begin buferize:\n" << op << '\n');
-                  return bufferize(b, op, bvm, aliasInfo);
-                })
-            // clang-format on
-            .Default([&](Operation *op) {
-              auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
-              if (any_of(op->getOperandTypes(), isaTensor) ||
-                  any_of(op->getResultTypes(), isaTensor))
-                return failure();
-              return success();
-            });
-    if (failed(status)) {
-      op->emitError("Failed bufferization");
-      return WalkResult::interrupt();
-    }
-    return WalkResult::advance();
+  // Since the walk has to be PreOrder, we need to erase ops that require it
+  // separately: this is the case for CallOp.
+  SmallVector<Operation *> toErase;
+  WalkResult result = funcOp.walk([&](Operation *op)
+      -> WalkResult {
+    // clang-format off
+    WalkResult result =
+        TypeSwitch<Operation *, LogicalResult>(op)
+          // Skip BufferCast and TensorLoad ops.
+          .Case<memref::BufferCastOp,
+                memref::TensorLoadOp>([&](auto) { return success(); })
+          .Case<memref::DimOp,
+                scf::ForOp,
+                LinalgOp,
+                ReturnOp,
+                ExtractSliceOp,
+                InsertSliceOp,
+                VectorTransferOpInterface,
+                scf::YieldOp>([&](auto op) {
+            LDBG("Begin bufferize:\n" << op << '\n');
+            return bufferize(b, op, bvm, aliasInfo);
+          })
+          .Case([&](CallOpInterface op) {
+            LDBG("Begin bufferize:\n" << op << '\n');
+            return bufferize(b, op, bvm, aliasInfo, bufferizedFunctionTypes);
+          })
+          .Default([&](Operation *op) {
+            auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
+            if (any_of(op->getOperandTypes(), isaTensor) ||
+                any_of(op->getResultTypes(), isaTensor))
+              return failure();
+            return success();
+          });
+    // clang-format on
+
+    // Register post-walk erasure, if necessary.
+    if (isa<CallOpInterface>(op))
+      if (llvm::any_of(op->getOperandTypes(), isaTensor) ||
+          llvm::any_of(op->getResultTypes(), isaTensor))
+        toErase.push_back(op);
+
+    return result;
   });
   LDBG("End BufferizeFuncOpInternals:\n" << funcOp << '\n');
 
+  for (Operation *op : toErase)
+    op->erase();
+
   return failure(result.wasInterrupted());
 }
 
@@ -1871,7 +2145,9 @@
 
   // Bufferization phase.
   BlockAndValueMapping bvm;
-  if (failed(bufferizeFuncOpInternals(funcOp, bvm, aliasInfo)))
+  DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
+  if (failed(bufferizeFuncOpInternals(funcOp, bvm, aliasInfo,
+                                      bufferizedFunctionTypes)))
     signalPassFailure();
 
   // Post-pass cleanup of inplaceable attributes.
@@ -1886,6 +2162,162 @@
 // Bufferization entry-point for modules.
 //===----------------------------------------------------------------------===//
 
+/// Return the op with Allocate MemoryEffect if `v` is equivalent to such an
+/// op. Return null otherwise.
+static Operation *getEquivalentAlloc(Value value,
+                                     const BufferizationAliasInfo &aliasInfo) {
+  Operation *res = nullptr;
+  aliasInfo.applyOnEquivalenceClass(value, [&](Value v) {
+    if (!res)
+      if (auto interface =
+              dyn_cast_or_null<MemoryEffectOpInterface>(v.getDefiningOp()))
+        if (auto effect =
+                interface.getEffectOnValue<MemoryEffects::Allocate>(value))
+          res = v.getDefiningOp();
+  });
+  return res;
+}
+
+/// Return the first argument of the enclosing FuncOp that is equivalent to `v`.
+/// Return null if no such bbArg can be found.
+static BlockArgument
+getEquivalentEnclosingFuncBBArg(Value v,
+                                const BufferizationAliasInfo &aliasInfo) {
+  Operation *op = v.getParentBlock()->getParentOp();
+  FuncOp funcOp = llvm::dyn_cast<FuncOp>(op);
+  if (!funcOp)
+    funcOp = op->getParentOfType<FuncOp>();
+  for (BlockArgument bbArg : funcOp.getArguments())
+    if (aliasInfo.areEquivalentBufferizedValues(v, bbArg))
+      return bbArg;
+  return nullptr;
+}
+
+/// Rewrite the `funcOp` arguments, return values and terminator into buffer
+/// form (using the canonical memref layout for now), according to the
+/// inPlace-bufferizable information of the function arguments.
+/// This relies on a buffer equivalence analysis of each return operand. When a
+/// result buffer is equivalent to:
+///   1. a BlockArgument of `funcOp`, it can be dropped from the return values
+///      and becomes inplaceable at all callers. This assumes all CallOps
+///      perform the necessary work to clone operands so as to make them
+///      inplaceable. Reliance on this logic will need to be relaxed in the
+///      future.
+///   2. an op with an Alloc effect, this currently fails bufferization but is
+///      a candidate for hoisting and creating a new inplace operand at all
+///      caller sites.
+///   3. if such a hoisting for 2. is not possible (e.g. data-dependent sizes
+///      prevent hoisting), this is currently unsupported and will require a
+///      refcounted buffer type.
+static LogicalResult bufferizeFuncOpBoundary(
+    FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
+    DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
+  LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n");
+
+  // Get the bufferized FunctionType for funcOp or construct it if not yet
+  // available.
+  // TODO: Atm we have 3 cases:
+  // 1. if a function is called from within the Module, it must have bufferized
+  //    to inplaceable tensor results.
+  // 2. if it is bodiless, it must have bufferized and is not allowed to have
+  //    result tensors.
+  // 3. if it is not called internally, it still must bufferize to inplaceable
+  //    tensor results and we construct it now (e.g. top-level function called
+  //    externally).
+  // -> Figure out a better layering.
+  TypeRange resultTypes;
+  FunctionType bufferizedFuncType =
+      getOrCreateBufferizedFunctionType(funcOp, funcOp.getType().getInputs(),
+                                        resultTypes, bufferizedFunctionTypes);
+
+  // Corner case: Bodiless FuncOp
+  // ============================
+  // The body of such functions is assumed opaque and we can't know the
+  // bufferization contract they want to enforce atm.
+  // As a consequence, only support functions that don't return any tensor atm.
+  if (funcOp.getBody().empty()) {
+    if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
+      return funcOp->emitError() << "Cannot bufferize bodiless function that "
+                                 << "returns a tensor";
+    funcOp.setType(bufferizedFuncType);
+    LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary no fun body: " << funcOp);
+    return success();
+  }
+
+  // Support only a single return-terminated block in the function.
+  ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+  if (!returnOp)
+    return funcOp->emitError() << "no unique ReturnOp";
+
+  // 1. For each FuncOp result, keep track of which inplace argument it reuses.
+  SmallVector<Value> returnValues;
+  for (OpOperand &returnOperand : returnOp->getOpOperands()) {
+    // If the return operand is equivalent to some bbArg, no need to return it.
+    Value returnVal = returnOperand.get();
+    if (getEquivalentEnclosingFuncBBArg(returnVal, aliasInfo))
+      continue;
+    // TODO: Need to hoist above function boundary. If this is not possible due
+    // to data-dependent sizes, we need a better type than memref.
+    if (Operation *allocOp = getEquivalentAlloc(returnVal, aliasInfo))
+      return allocOp->emitError() << " needs hoist across function boundary\n";
+    int64_t returnIdx = returnOperand.getOperandNumber();
+    return returnOp->emitError() << " bufferize result #" << returnIdx << "\n";
+  }
+
+  // 2. Rewrite the terminator without the inPlace bufferizable values.
+  OpBuilder(returnOp).create<ReturnOp>(returnOp.getLoc(), returnValues);
+  returnOp->erase();
+
+  // 3. Rewrite the bbArgs.
+  // Iterate on the original `numArgs` and replace them in order.
+  // This guarantees the argument order still matches after the rewrite.
+  Block &frontBlock = funcOp.body().front();
+  unsigned numArgs = frontBlock.getNumArguments();
+  for (unsigned idx = 0; idx < numArgs; ++idx) {
+    auto bbArg = frontBlock.getArgument(0);
+    auto tensorType = bbArg.getType().dyn_cast<TensorType>();
+    // Non-tensor types are just forwarded.
+    if (!tensorType) {
+      frontBlock.addArgument(bbArg.getType());
+      bbArg.replaceAllUsesWith(frontBlock.getArguments().back());
+      frontBlock.eraseArgument(0);
+      continue;
+    }
+
+    // Get the buffer type from the bufferized function type.
+    Type memrefType = bufferizedFuncType.getInput(idx);
+    Value memref = frontBlock.addArgument(memrefType);
+    OpBuilder b(funcOp->getContext());
+    b.setInsertionPointToStart(&frontBlock);
+    // Replace all uses of bbArg through a BufferCastOp by a memref::CastOp.
+    for (auto &use : llvm::make_early_inc_range(bbArg.getUses())) {
+      if (auto bufferCastOp = dyn_cast<memref::BufferCastOp>(use.getOwner())) {
+        auto castOp = b.create<memref::CastOp>(
+            funcOp.getLoc(), bufferCastOp.memref().getType(), memref);
+        bufferCastOp.memref().replaceAllUsesWith(castOp);
+        aliasInfo.insertNewBufferEquivalence(castOp.dest(),
+                                             bufferCastOp.memref());
+        continue;
+      }
+    }
+    // Replace all remaining uses by a tensor_load.
+    if (!bbArg.use_empty()) {
+      auto tensorLoadOp =
+          b.create<memref::TensorLoadOp>(funcOp.getLoc(), memref);
+      aliasInfo.insertNewBufferEquivalence(tensorLoadOp, bbArg);
+      bbArg.replaceAllUsesWith(tensorLoadOp);
+    }
+    frontBlock.eraseArgument(0);
+    // TODO: add support to erase aliasInfo entries if deemed necessary.
+  }
+
+  // 4. Rewrite the FuncOp type to buffer form.
+  funcOp.setType(bufferizedFuncType);
+
+  LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary:\n" << funcOp);
+
+  return success();
+}
+
 /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
 /// callee-caller order (i.e. callees without callers first).
 /// Store the map of FuncOp to all its callers in `callerMap`.
@@ -1948,6 +2380,7 @@
 
   SmallVector<FuncOp> orderedFuncOps;
   DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
+  DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
   if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap)))
     return signalPassFailure();
 
@@ -1980,15 +2413,39 @@
       return;
     }
 
-    // TODO: Bufferization phase.
+    // Bufferization phase.
+    if (!testAnalysisOnly) {
+      BlockAndValueMapping tensorToBufferMap;
+      if (failed(bufferizeFuncOpInternals(funcOp, tensorToBufferMap, aliasInfo,
+                                          bufferizedFunctionTypes))) {
+        signalPassFailure();
+        return;
+      }
+    }
   }
   // Don't drop the attributes if we only want to report the analysis.
   if (testAnalysisOnly)
     return;
 
+  for (FuncOp funcOp : orderedFuncOps) {
+    // Note: It would be good to apply cleanups here but we cannot as aliasInfo
+    // would be invalidated.
+    if (failed(bufferizeFuncOpBoundary(funcOp, aliasInfo,
+                                       bufferizedFunctionTypes))) {
+      signalPassFailure();
+      return;
+    }
+  }
+
   // Post-pass cleanup of inplaceable attributes.
   moduleOp.walk(
       [&](Operation *op) { op->removeAttr(kInPlaceResultsAttrName); });
+
+  OpPassManager cleanupPipeline(OpPassManager("module"));
+  cleanupPipeline.addPass(createCanonicalizerPass());
+  cleanupPipeline.addPass(createCSEPass());
+  cleanupPipeline.addPass(createLoopInvariantCodeMotionPass());
+  (void)runPipeline(cleanupPipeline, moduleOp);
 }
 
 std::unique_ptr<Pass> mlir::createLinalgComprehensiveModuleBufferizePass() {
diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir
@@ -0,0 +1,62 @@
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize -split-input-file | FileCheck %s
+
+// -----
+
+// CHECK: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
+
+// CHECK: func private @some_external_func(memref<?xf32, #[[$DYN_1D_MAP]]>)
+func private @some_external_func(tensor<?xf32>)
+
+// CHECK: func @scf_for_with_tensor_insert_slice(
+// CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #[[$DYN_1D_MAP]]> {linalg.inplaceable = true}
+// CHECK-SAME:   %[[B:[a-zA-Z0-9]*]]: memref<?xf32, #[[$DYN_1D_MAP]]> {linalg.inplaceable = true}
+// CHECK-SAME:   %[[C:[a-zA-Z0-9]*]]: memref<4xf32, #[[$DYN_1D_MAP]]> {linalg.inplaceable = true}
+func @scf_for_with_tensor_insert_slice(
+    %A : tensor<?xf32>, %B : tensor<?xf32>, %C : tensor<4xf32>,
+    %lb : index, %ub : index, %step : index)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+  // CHECK-NEXT: scf.for
+  %r0:2 = scf.for %i = %lb to %ub step %step iter_args(%tA = %A, %tB = %B)
+      -> (tensor<?xf32>, tensor<?xf32>)
+  {
+    // CHECK-NEXT:   %[[SVA:.*]] = memref.subview %[[A]]
+    // CHECK-NEXT:   linalg.copy(%[[C]], %[[SVA]]) : memref<4xf32, #[[$DYN_1D_MAP]]>, memref<4xf32, #[[$DYN_1D_MAP]]>
+    %ttA = tensor.insert_slice %C into %tA[%i][4][1] : tensor<4xf32> into tensor<?xf32>
+
+    // CHECK-NEXT:   %[[SVB:.*]] = memref.subview %[[B]]
+    // CHECK-NEXT:   linalg.copy(%[[C]], %[[SVB]]) : memref<4xf32, #[[$DYN_1D_MAP]]>, memref<4xf32, #[[$DYN_1D_MAP]]>
+    %ttB = tensor.insert_slice %C into %tB[%i][4][1] : tensor<4xf32> into tensor<?xf32>
+
+    // scf.yield is empty and is elided
+    // CHECK-NOT: scf.yield
+    scf.yield %ttA, %ttB : tensor<?xf32>, tensor<?xf32>
+  }
+
+  // Swaparoo requires bufferizing the whole function to figure out who's who.
+  return %r0#1, %r0#0: tensor<?xf32>, tensor<?xf32>
+}
+
+// CHECK: func @bar(
+// CHECK-SAME:   %[[A:[a-zA-Z0-9]*]]: memref<?xf32, #[[$DYN_1D_MAP]]> {linalg.inplaceable = true}
+// CHECK-SAME:   %[[B:[a-zA-Z0-9]*]]: memref<?xf32, #[[$DYN_1D_MAP]]> {linalg.inplaceable = true}
+// CHECK-SAME:   %[[C:[a-zA-Z0-9]*]]: memref<4xf32, #[[$DYN_1D_MAP]]> {linalg.inplaceable = true}
+func @bar(
+    %A : tensor<?xf32> {linalg.inplaceable = true},
+    %B : tensor<?xf32> {linalg.inplaceable = true},
+    %C : tensor<4xf32> {linalg.inplaceable = true},
+    %lb : index, %ub : index, %step : index)
+  -> (tensor<?xf32>, tensor<?xf32>)
+{
+// CHECK-NEXT:   call @scf_for_with_tensor_insert_slice(%[[A]], %[[B]], %[[C]]
+  %r0:2 = call @scf_for_with_tensor_insert_slice(%A, %B, %C, %lb, %ub, %step) :
+      (tensor<?xf32>, tensor<?xf32>, tensor<4xf32>, index, index, index)
+        -> (tensor<?xf32>, tensor<?xf32>)
+
+  // %r0#0 is actually %B after inplaceable results are swapped in the callee.
+// CHECK-NEXT:   call @some_external_func(%[[B]]) : (memref<?xf32, #[[$DYN_1D_MAP]]>) -> ()
+  call @some_external_func(%r0#0) : (tensor<?xf32>) -> ()
+
+// CHECK-NEXT:   return
+  return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
+}
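A minimal sketch (not part of the patch; the function name `@return_arg` is illustrative) of the boundary contract implemented by `bufferizeFuncOpBoundary` and exercised by the test above: a tensor result that the equivalence analysis folds onto a function argument is dropped from the bufferized signature, and tensor arguments bufferize to the most dynamic strided memref so that all callers agree; a caller holding a more static buffer reconciles the types with `memref.cast`.

// Hypothetical input: the returned value is the inplaceable argument itself.
func @return_arg(%A : tensor<?xf32> {linalg.inplaceable = true}) -> tensor<?xf32> {
  return %A : tensor<?xf32>
}
// Expected shape of the bufferized signature, mirroring #[[$DYN_1D_MAP]] above
// (sketch only, not a CHECK from the patch):
//   func @return_arg(%A : memref<?xf32, affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>>)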