diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -237,6 +237,12 @@ /// Return true if the given op has a tensor result or a tensor operand. static bool hasTensorSemantics(Operation *op) { + if (auto funcOp = dyn_cast(op)) { + bool hasTensorArg = any_of(funcOp.getArgumentTypes(), isaTensor); + bool hasTensorResult = any_of(funcOp.getResultTypes(), isaTensor); + return hasTensorArg || hasTensorResult; + } + bool hasTensorResult = any_of(op->getResultTypes(), isaTensor); bool hasTensorOperand = any_of(op->getOperandTypes(), isaTensor); return hasTensorResult || hasTensorOperand; @@ -332,6 +338,12 @@ // Bufferize the op and its nested ops. RewritePatternSet patterns(op->getContext()); patterns.add(patterns.getContext(), bufferizationState); + FrozenRewritePatternSet frozenPatterns(std::move(patterns)); + + // Bufferize op without nested ops. + bool erased = false; + if (failed(applyOpPatternsAndFold(op, frozenPatterns, &erased))) + return failure(); // Bufferize ops top-to-bottom. When creating a new op, we should ideally // know the exact memref type of all operands. Otherwise, we have to use a @@ -345,13 +357,15 @@ GreedyRewriteConfig config; config.useTopDownTraversal = true; + // Bufferize all nested ops (unless the op was erased). // TODO: Perform a preorder walk instead of the greedy pattern rewriter. This // would be more efficient because every bufferization pattern is guaranteed // to apply only a single time (otherwise, an assertion would be triggered). // However, there are restrictions wrt. erasing ops during a preorder walk, // which would likely require a larger refactoring. - if (failed(applyPatternsAndFoldGreedily(op, std::move(patterns), config))) - return failure(); + if (!erased) + if (failed(applyPatternsAndFoldGreedily(op, frozenPatterns, config))) + return failure(); if (failed(checkBufferizationResult(op, bufferizationState.getOptions()))) return failure(); diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -6,7 +6,7 @@ // //===----------------------------------------------------------------------===// // -// Module Bufferization is an extension of Comprehensive Bufferize that +// Module Bufferization is an extension of One-Shot Bufferize that // bufferizes function boundaries. It provides `BufferizableOpInterface` // implementations for FuncOp, CallOp and ReturnOp. // @@ -17,7 +17,7 @@ // // After analyzing a FuncOp, additional information about its bbArgs is // gathered through PostAnalysisStepFns and stored in -// `ModuleAnalysisState`. +// `FuncAnalysisState`. // // * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs // for @@ -87,13 +87,16 @@ using namespace comprehensive_bufferize; using namespace mlir::bufferization; +/// A mapping of FuncOps to their callers. +using FuncCallerMap = DenseMap>; + namespace { /// The state of analysis of a FuncOp. enum class FuncOpAnalysisState { NotAnalyzed, InProgress, Analyzed }; /// Extra analysis state that is required for bufferization of function /// boundaries. -struct ModuleAnalysisState : public DialectAnalysisState { +struct FuncAnalysisState : public DialectAnalysisState { /// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg /// indices. DenseMap> equivalentFuncArgs; @@ -113,35 +116,29 @@ /// Keep track of which FuncOps are fully analyzed or currently being /// analyzed. DenseMap analyzedFuncOps; - - // A list of functions in the order in which they are analyzed + bufferized. - SmallVector orderedFuncOps; - - // A mapping of FuncOps to their callers. - DenseMap> callerMap; }; } // namespace -/// Get ModuleAnalysisState. -static const ModuleAnalysisState & -getModuleAnalysisState(const AnalysisState &state) { - Optional maybeState = - state.getDialectState( +/// Get FuncAnalysisState. +static const FuncAnalysisState & +getFuncAnalysisState(const AnalysisState &state) { + Optional maybeState = + state.getDialectState( func::FuncDialect::getDialectNamespace()); - assert(maybeState.hasValue() && "ModuleAnalysisState does not exist"); + assert(maybeState.hasValue() && "FuncAnalysisState does not exist"); return **maybeState; } -/// Get or create ModuleAnalysisState. -static ModuleAnalysisState &getModuleAnalysisState(AnalysisState &state) { - return state.getOrCreateDialectState( +/// Get or create FuncAnalysisState. +static FuncAnalysisState &getFuncAnalysisState(AnalysisState &state) { + return state.getOrCreateDialectState( func::FuncDialect::getDialectNamespace()); } /// Return the state (phase) of analysis of the FuncOp. static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state, FuncOp funcOp) { - const ModuleAnalysisState &moduleState = getModuleAnalysisState(state); + const FuncAnalysisState &moduleState = getFuncAnalysisState(state); auto it = moduleState.analyzedFuncOps.find(funcOp); if (it == moduleState.analyzedFuncOps.end()) return FuncOpAnalysisState::NotAnalyzed; @@ -186,12 +183,12 @@ } /// Store function BlockArguments that are equivalent to/aliasing a returned -/// value in ModuleAnalysisState. +/// value in FuncAnalysisState. static LogicalResult aliasingFuncOpBBArgsAnalysis(Operation *op, AnalysisState &state, BufferizationAliasInfo &aliasInfo, SmallVector &newOps) { - ModuleAnalysisState &moduleState = getModuleAnalysisState(state); + FuncAnalysisState &moduleState = getFuncAnalysisState(state); // Support only single return-terminated block in the function. auto funcOp = cast(op); @@ -287,7 +284,7 @@ funcOpBbArgReadWriteAnalysis(Operation *op, AnalysisState &state, BufferizationAliasInfo &aliasInfo, SmallVector &newOps) { - ModuleAnalysisState &moduleState = getModuleAnalysisState(state); + FuncAnalysisState &moduleState = getFuncAnalysisState(state); auto funcOp = cast(op); // Initialize data structure. @@ -330,8 +327,6 @@ } } // namespace -static bool isaTensor(Type t) { return t.isa(); } - /// If `value` is a memref::CastOp, return its source. Otherwise, return /// `value` directly. static Value getNonCastedValue(Value value) { @@ -359,26 +354,27 @@ SymbolTable::lookupNearestSymbolFrom(callOp, sym)); } -/// Return the FunctionType with `argumentTypes` and `resultTypes` where each -/// tensor is replaced by the corresponding buffer type. -/// In order for all the callers to agree, this *must* bufferize to the most -/// dynamic buffer type supported. -/// A later pass across all CallOps in the module can decide whether to simplify -/// the types of to version according to some cost model. -static FunctionType -getBufferizedFunctionType(MLIRContext *ctx, TypeRange argumentTypes, - TypeRange resultTypes, - const BufferizationOptions &options) { - auto rewrite = [&](Type t) -> Type { - // TODO: non-zero address space. - // TODO: layout information if relevant. - if (auto tensorType = t.dyn_cast()) - return getMemRefType(tensorType, options); - return t; - }; - auto argTypes = llvm::to_vector<4>(llvm::map_range(argumentTypes, rewrite)); - auto retTypes = llvm::to_vector<4>(llvm::map_range(resultTypes, rewrite)); - return FunctionType::get(ctx, argTypes, retTypes); +/// Return the index-th bufferized function argument type. This assumes that the +/// specified argument is a tensor. If the tensor is ranked, a layout map may be +/// specified by the user. If no layout map is specified, a fully dynamic map is +/// used. +static BaseMemRefType +getBufferizedFunctionArgType(FuncOp funcOp, int64_t index, + const BufferizationOptions &options) { + auto tensorType = funcOp.getType().getInput(index).dyn_cast(); + assert(tensorType && "expected TensorType"); + BaseMemRefType memrefType = getMemRefType(tensorType, options); + + auto layoutAttr = funcOp.getArgAttrOfType( + index, BufferizableOpInterface::kBufferLayoutAttrName); + if (!layoutAttr) + return memrefType; + + auto rankedMemrefType = memrefType.dyn_cast(); + assert(rankedMemrefType && "buffer layout not supported on unranked tensors"); + return MemRefType::get(memrefType.getShape(), memrefType.getElementType(), + layoutAttr.getValue(), + memrefType.getMemorySpaceAsInt()); } /// Gather equivalence info of CallOps. @@ -387,7 +383,7 @@ // TODO: This does not handle cyclic function call graphs etc. static void equivalenceAnalysis(FuncOp funcOp, BufferizationAliasInfo &aliasInfo, - ModuleAnalysisState &moduleState) { + FuncAnalysisState &moduleState) { funcOp->walk([&](func::CallOp callOp) { FuncOp calledFunction = getCalledFunction(callOp); assert(calledFunction && "could not retrieved called FuncOp"); @@ -408,148 +404,22 @@ }); } -/// Rewrite the `funcOp` arguments analysis return values and terminator into -/// buffer form (using the canonical memref layout for now), according to the -/// inPlace-bufferizable information of the function arguments. -/// -/// This relies on a buffer equivalence analysis of each return operand. When a -/// result buffer is equivalent to a BlockArgument of `funcOp`, it can be -/// dropped from the return values and becomes inplaceable at all callers. This -/// assumes all CallOp perform the necessary work to clone operands so as to -/// make them inplaceable. Reliance on this logic will need to be relaxed in the -/// future. -/// -/// Note: Returning a memref currently fails bufferization. If such memrefs -/// originate from an op with an Alloc effect, they could be hoisted in the -/// future. -static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp, - RewriterBase &rewriter, - BufferizationState &state) { - const ModuleAnalysisState &moduleState = - getModuleAnalysisState(state.getAnalysisState()); - - // If nothing to do then we are done. - if (!llvm::any_of(funcOp.getType().getInputs(), isaTensor) && - !llvm::any_of(funcOp.getType().getResults(), isaTensor)) - return success(); - - // Get the bufferized FunctionType for funcOp or construct it if not yet - // available. - // TODO: Atm we have 3 cases: - // 1. if a function is called from within the Module, it must have bufferized - // to inplaceable tensor results. - // 2. if it is bodiless, it must have bufferized and is not allowed to have - // result tensors. - // 3. if it is not called internally, it still must bufferize to inplaceable - // tensor results and we construct it now (e.g. top-level function called - // externally). - // -> Figure out a better layering. - TypeRange resultTypes; - - // Corner case: Bodiless FuncOp - // ============================ - // The body of such functions is assumed opaque and we can't know the - // bufferization contract they want to enforce atm. - // As a consequence, only support functions that don't return any tensor atm. - if (funcOp.getBody().empty()) { - if (llvm::any_of(funcOp.getType().getResults(), isaTensor)) - return funcOp->emitError() << "cannot bufferize bodiless function that " - << "returns a tensor"; - FunctionType bufferizedFuncType = getBufferizedFunctionType( - funcOp.getContext(), funcOp.getType().getInputs(), - funcOp.getType().getResults(), state.getOptions()); - funcOp.setType(bufferizedFuncType); - return success(); - } - - // Support only single return-terminated block in the function. - func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); - assert(returnOp && "expected func with single return op"); - - // 1. For each FuncOp result, keep track of which inplace argument it reuses. - SmallVector returnValues; - for (OpOperand &returnOperand : returnOp->getOpOperands()) { - Value returnVal = returnOperand.get(); - - // If not a renturn tensor type just forward it. - if (!returnVal.getType().isa()) { - returnValues.push_back(returnVal); - continue; - } - - // If return operand is equivalent to some bbArg, no need to return it. - auto funcOpIt = moduleState.equivalentFuncArgs.find(funcOp); - if (funcOpIt != moduleState.equivalentFuncArgs.end() && - funcOpIt->second.count(returnOperand.getOperandNumber())) - continue; - - // Cast values at the call site if necessary. - returnValues.push_back( - getNonCastedValue(*state.getBuffer(rewriter, returnOperand))); - } - - // 2. Rewrite the terminator without the inPlace bufferizable values. - ValueRange retValues{returnValues}; - FunctionType bufferizedFuncType = getBufferizedFunctionType( - funcOp.getContext(), funcOp.getType().getInputs(), retValues.getTypes(), - state.getOptions()); - OpBuilder b(returnOp); - b.create(returnOp.getLoc(), returnValues); - returnOp->erase(); - - // 3. Rewrite the bbArgs. - // Iterate on the original `numArgs` and replace them in order. - // This guarantees the argument order still matches after the rewrite. - Block &frontBlock = funcOp.getBody().front(); - unsigned numArgs = frontBlock.getNumArguments(); - for (unsigned idx = 0; idx < numArgs; ++idx) { - auto bbArg = frontBlock.getArgument(0); - auto tensorType = bbArg.getType().dyn_cast(); - // Non-tensor types are just forwarded. - if (!tensorType) { - frontBlock.addArgument(bbArg.getType(), bbArg.getLoc()); - bbArg.replaceAllUsesWith(frontBlock.getArguments().back()); - frontBlock.eraseArgument(0); - continue; - } - - // Get the buffer type from the bufferized function type. - Type memrefType = bufferizedFuncType.getInput(idx); - Value memref = frontBlock.addArgument(memrefType, bbArg.getLoc()); - OpBuilder b(funcOp->getContext()); - b.setInsertionPointToStart(&frontBlock); - // Replace all uses of bbArg through a ToMemRefOp. - for (auto &use : llvm::make_early_inc_range(bbArg.getUses())) { - if (auto toMemrefOp = - dyn_cast(use.getOwner())) { - if (memref.getType() != toMemrefOp.memref().getType()) { - // Type has changed, insert a cast. - assert(memref::CastOp::areCastCompatible( - memref.getType(), toMemrefOp.memref().getType()) && - "bufferizeFuncOpBoundary: cast incompatible"); - auto castOp = b.create( - funcOp.getLoc(), toMemrefOp.memref().getType(), memref); - toMemrefOp.memref().replaceAllUsesWith(castOp); - } else { - // Type did not change, replace directly. - toMemrefOp.memref().replaceAllUsesWith(memref); - } - } - } - // Replace all remaining uses by a to_tensor. - if (!bbArg.use_empty()) { - auto toTensorOp = - b.create(funcOp.getLoc(), memref); - bbArg.replaceAllUsesWith(toTensorOp); - } - frontBlock.eraseArgument(0); - // TODO: add support to erase aliasInfo entries if deemed necessary. - } +/// Return the index of the bbArg in the given FuncOp that is equivalent to the +/// specified return value (if any). +static Optional getEquivalentFuncArgIdx(FuncOp funcOp, + const FuncAnalysisState &state, + int64_t returnValIdx) { + auto funcOpIt = state.equivalentFuncArgs.find(funcOp); + if (funcOpIt == state.equivalentFuncArgs.end()) + // No equivalence info stores for funcOp. + return None; - // 4. Rewrite the FuncOp type to buffer form. - funcOp.setType(bufferizedFuncType); + auto retValIt = funcOpIt->getSecond().find(returnValIdx); + if (retValIt == funcOpIt->getSecond().end()) + // Return value has no equivalent bbArg. + return None; - return success(); + return retValIt->getSecond(); } /// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by @@ -560,7 +430,7 @@ static LogicalResult getFuncOpsOrderedByCalls(ModuleOp moduleOp, SmallVectorImpl &orderedFuncOps, - DenseMap> &callerMap) { + FuncCallerMap &callerMap) { // For each FuncOp, the set of functions called by it (i.e. the union of // symbols of all nested CallOpInterfaceOp). DenseMap> calledBy; @@ -609,9 +479,8 @@ return success(); } -static void -foreachCaller(const DenseMap> &callerMap, - FuncOp callee, llvm::function_ref doit) { +static void foreachCaller(const FuncCallerMap &callerMap, FuncOp callee, + llvm::function_ref doit) { auto itCallers = callerMap.find(callee); if (itCallers == callerMap.end()) return; @@ -619,116 +488,11 @@ doit(caller); } -/// Postprocess the linalg.buffer_layout annotation across function boundaries. -/// This is a purely mechanical process that may later become part of a -/// separate pass with its own layout assignment heuristic. -static void layoutPostProcessing(ModuleOp moduleOp) { - SmallVector orderedFuncOps; - DenseMap> callerMap; - auto res = getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap); - (void)res; - assert(succeeded(res) && "unexpected getFuncOpsOrderedByCalls failure"); - - for (FuncOp funcOp : orderedFuncOps) { - DenseMap> operandsPerCaller; - foreachCaller(callerMap, funcOp, [&](Operation *caller) { - operandsPerCaller.try_emplace(caller, SmallVector()); - }); - - SmallVector argumentTypes; - // Iterate on each function argument and check it it was marked with a - // desired layout. - for (const auto &it : llvm::enumerate(funcOp.getType().getInputs())) { - int argNumber = it.index(); - Type inputType = it.value(); - auto memrefType = inputType.dyn_cast(); - auto layoutAttr = funcOp.getArgAttrOfType( - argNumber, BufferizableOpInterface::kBufferLayoutAttrName); - AffineMap desiredLayoutMap = - layoutAttr ? layoutAttr.getValue() : AffineMap(); - AffineMap currentLayoutMap = - memrefType ? getStridedLinearLayoutMap(memrefType) : AffineMap(); - if (!memrefType || !layoutAttr || desiredLayoutMap == currentLayoutMap) { - argumentTypes.push_back(inputType); - foreachCaller(callerMap, funcOp, [&](Operation *caller) { - operandsPerCaller.find(caller)->getSecond().push_back( - caller->getOperand(argNumber)); - }); - continue; - } - - // Compute the buffer type with desired layout and add to input argument - // types. - MemRefType desiredMemrefType = MemRefType::get( - memrefType.getShape(), memrefType.getElementType(), desiredLayoutMap); - argumentTypes.push_back(desiredMemrefType); - - // If funcOp's body is not empty, change the bbArg type and propagate. - if (!funcOp.getBody().empty()) { - BlockArgument bbArg = funcOp.getArgument(argNumber); - bbArg.setType(desiredMemrefType); - OpBuilder b(bbArg.getContext()); - b.setInsertionPointToStart(bbArg.getOwner()); - assert(memref::CastOp::areCastCompatible(bbArg.getType(), memrefType) && - "layoutPostProcessing: cast incompatible"); - // Cast back to the original memrefType and let it canonicalize. - Value cast = - b.create(funcOp.getLoc(), memrefType, bbArg); - bbArg.replaceAllUsesExcept(cast, cast.getDefiningOp()); - } - - // Cast to desired buffer type on all callers to `funcOp`. - // TODO: on the callee side, this may even have to trigger a copy to - // change the layout. For now let the memref::CastOp fail to verify in - // such cases. - auto castArg = [&](Operation *caller) { - OpBuilder b(caller); - assert( - memref::CastOp::areCastCompatible( - caller->getOperand(argNumber).getType(), desiredMemrefType) && - "layoutPostProcessing.2: cast incompatible"); - Value newOperand = b.create( - funcOp.getLoc(), desiredMemrefType, caller->getOperand(argNumber)); - operandsPerCaller.find(caller)->getSecond().push_back(newOperand); - }; - foreachCaller(callerMap, funcOp, castArg); - } - - // Set operands with cast buffer on all callers to `funcOp`. - foreachCaller(callerMap, funcOp, [&](Operation *caller) { - caller->setOperands(operandsPerCaller.lookup(caller)); - }); - - // Finally set the funcOp type to update the arguments. - auto newFuncType = FunctionType::get(moduleOp.getContext(), argumentTypes, - funcOp.getType().getResults()); - funcOp.setType(newFuncType); - } -} - namespace mlir { namespace linalg { namespace comprehensive_bufferize { namespace std_ext { -/// Return the index of the bbArg in the given FuncOp that is equivalent to the -/// specified return value (if any). -static Optional -getEquivalentFuncArgIdx(FuncOp funcOp, const ModuleAnalysisState &state, - int64_t returnValIdx) { - auto funcOpIt = state.equivalentFuncArgs.find(funcOp); - if (funcOpIt == state.equivalentFuncArgs.end()) - // No equivalence info stores for funcOp. - return None; - - auto retValIt = funcOpIt->getSecond().find(returnValIdx); - if (retValIt == funcOpIt->getSecond().end()) - // Return value has no equivalent bbArg. - return None; - - return retValIt->getSecond(); -} - struct CallOpInterface : public BufferizableOpInterface::ExternalModel { @@ -738,7 +502,7 @@ FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); - const ModuleAnalysisState &moduleState = getModuleAnalysisState(state); + const FuncAnalysisState &moduleState = getFuncAnalysisState(state); if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) // FuncOp not analyzed yet. Assume that OpOperand is read. return true; @@ -755,7 +519,7 @@ FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); - const ModuleAnalysisState &moduleState = getModuleAnalysisState(state); + const FuncAnalysisState &moduleState = getFuncAnalysisState(state); if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) // FuncOp not analyzed yet. Assume that OpOperand is written. return true; @@ -771,7 +535,7 @@ func::CallOp callOp = cast(op); FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); - const ModuleAnalysisState &moduleState = getModuleAnalysisState(state); + const FuncAnalysisState &moduleState = getFuncAnalysisState(state); if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) { // FuncOp not analyzed yet. Any OpResult may be aliasing. @@ -803,7 +567,7 @@ func::CallOp callOp = cast(op); FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); - const ModuleAnalysisState &moduleState = getModuleAnalysisState(state); + const FuncAnalysisState &moduleState = getFuncAnalysisState(state); if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) { // FuncOp not analyzed yet. Any OpOperand may be aliasing. @@ -833,9 +597,8 @@ return BufferRelation::Equivalent; } - /// In a first approximation, all the function arguments of a FuncOp are - /// marked inplaceable. For now, it is the responsibility of the `callOp` - /// bufferization to allow FuncOp that are inplaceable to write inPlace. + /// All function arguments are writable. It is the responsibility of the + /// CallOp to insert buffer copies where necessary. LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { func::CallOp callOp = cast(op); @@ -843,8 +606,8 @@ unsigned numOperands = callOp->getNumOperands(); FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); - const ModuleAnalysisState &moduleState = - getModuleAnalysisState(state.getAnalysisState()); + const FuncAnalysisState &moduleState = + getFuncAnalysisState(state.getAnalysisState()); const OneShotBufferizationOptions &options = static_cast(state.getOptions()); @@ -878,7 +641,7 @@ for (const auto &it : llvm::enumerate(callOp.getResultTypes())) { unsigned returnValIdx = it.index(); Type returnType = it.value(); - if (!isaTensor(returnType)) { + if (!returnType.isa()) { // Non-tensor values are returned. retValMapping[returnValIdx] = resultTypes.size(); resultTypes.push_back(returnType); @@ -910,11 +673,7 @@ } // 2. Compute bufferized FunctionType. - SmallVector argumentTypes{callOp->getOperandTypes()}; - // Get the bufferized FunctionType for funcOp or construct it if not yet - // available. - FunctionType bufferizedFuncType = getBufferizedFunctionType( - funcOp.getContext(), argumentTypes, resultTypes, options); + FunctionType bufferizedFuncType = funcOp.getType(); // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`. for (OpOperand &opOperand : callOp->getOpOperands()) { @@ -999,6 +758,8 @@ assert(isa(returnOp->getParentOp()) && "only support FuncOp parent for ReturnOp"); #endif // NDEBUG + + // ReturnOps are bufferized as part of FuncOps. return failure(); } }; @@ -1007,7 +768,126 @@ : public BufferizableOpInterface::ExternalModel { LogicalResult bufferize(Operation *op, RewriterBase &rewriter, BufferizationState &state) const { - return failure(); + // Rewrite function bbArgs and return values into buffer form (using the + // canonical memref layout for now). + // + // This relies on a buffer equivalence analysis of each return operand. When + // a result buffer is equivalent to a function bbArg, it is dropped from the + // return values and becomes inplaceable at all callers. + // + // All function bbArgs are writable unless they are explicitly marked as + // read-only. Callers must insert copies when needed. + // + // Note: Returning a memref is possible, but corresponding CallOp + // bufferizations fail unless `allowReturnAllocs`. + auto funcOp = cast(op); + const FuncAnalysisState &moduleState = + getFuncAnalysisState(state.getAnalysisState()); + const BufferizationOptions &options = state.getOptions(); + + // Construct the bufferized function type. + SmallVector argTypes; + for (const auto &it : llvm::enumerate(funcOp.getType().getInputs())) { + Type argType = it.value(); + if (auto tensorType = argType.dyn_cast()) { + argTypes.push_back( + getBufferizedFunctionArgType(funcOp, it.index(), options)); + continue; + } + argTypes.push_back(argType); + } + + // Bodiless functions are assumed opaque and we cannot know the + // bufferization contract they want to enforce. As a consequence, only + // support functions that don't return any tensors atm. + if (funcOp.getBody().empty()) { + FunctionType funcType = funcOp.getType(); + SmallVector retTypes; + for (Type resultType : funcType.getResults()) { + if (resultType.isa()) + return funcOp->emitError() << "cannot bufferize bodiless function " + << "that returns a tensor"; + retTypes.push_back(resultType); + } + funcOp.setType(FunctionType::get(op->getContext(), argTypes, retTypes)); + return success(); + } + + // TODO: Support functions with multiple returns. + func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); + assert(returnOp && "expected func with single return op"); + + // 1. Rewrite the bbArgs. Turn every tensor bbArg into a memref bbArg. + Block &frontBlock = funcOp.getBody().front(); + for (BlockArgument &bbArg : frontBlock.getArguments()) { + auto tensorType = bbArg.getType().dyn_cast(); + // Non-tensor types stay the same. + if (!tensorType) + continue; + + // Collect all uses of the bbArg. + SmallVector bbArgUses; + for (OpOperand &use : bbArg.getUses()) + bbArgUses.push_back(&use); + + // Change the bbArg type to memref. + Type memrefType = + getBufferizedFunctionArgType(funcOp, bbArg.getArgNumber(), options); + bbArg.setType(memrefType); + + // Replace all uses of the original tensor bbArg. + rewriter.setInsertionPointToStart(&frontBlock); + if (!bbArgUses.empty()) { + // Insert to_tensor because the remaining function body has not been + // bufferized yet. + Value toTensorOp = + rewriter.create(funcOp.getLoc(), bbArg); + for (OpOperand *use : bbArgUses) + use->set(toTensorOp); + } + } + + // 2. For each result, keep track of which inplace argument it reuses. + SmallVector returnValues; + for (OpOperand &returnOperand : returnOp->getOpOperands()) { + Value returnVal = returnOperand.get(); + + // If not a tensor type just forward it. + if (!returnVal.getType().isa()) { + returnValues.push_back(returnVal); + continue; + } + + // If return operand is equivalent to some bbArg, no need to return it. + if (Optional equivBbArgIdx = getEquivalentFuncArgIdx( + funcOp, moduleState, returnOperand.getOperandNumber())) { + rewriter.setInsertionPoint(returnOp); + Location loc = returnOp.getLoc(); + Value toMemrefOp = rewriter.create( + loc, getMemRefType(returnVal.getType().cast(), options), + returnVal); + BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx); + // Note: This copy will fold away. It must be inserted here to ensure + // that `returnVal` still has at least one use and does not fold away. + if (failed( + createMemCpy(rewriter, loc, toMemrefOp, equivBbArg, options))) + return funcOp->emitError("could not generate copy for bbArg"); + continue; + } + + // Cast values at the call site if necessary. + returnValues.push_back( + getNonCastedValue(*state.getBuffer(rewriter, returnOperand))); + } + + // 3. Rewrite the terminator without the in-place bufferizable values. + returnOp.operandsMutable().assign(returnValues); + + // 4. Rewrite the FuncOp type to buffer form. + funcOp.setType(FunctionType::get(op->getContext(), argTypes, + ValueRange(returnValues).getTypes())); + + return success(); } /// Return `true` if the given function argument is writable. @@ -1061,16 +941,43 @@ setInPlaceFuncArgument(bbArg, bufferizableOp.isWritable(bbArg, state)); } +/// Fold return values that are memref casts. +static void foldMemRefCasts(FuncOp funcOp) { + if (funcOp.getBody().empty()) + return; + + func::ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); + SmallVector resultTypes; + + for (OpOperand &operand : returnOp->getOpOperands()) { + if (auto castOp = operand.get().getDefiningOp()) { + operand.set(castOp.source()); + resultTypes.push_back(castOp.source().getType()); + } else { + resultTypes.push_back(operand.get().getType()); + } + } + + auto newFuncType = FunctionType::get( + funcOp.getContext(), funcOp.getType().getInputs(), resultTypes); + funcOp.setType(newFuncType); +} + LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize( ModuleOp moduleOp, OneShotBufferizationOptions options) { IRRewriter rewriter(moduleOp.getContext()); OneShotAnalysisState analysisState(moduleOp, options); BufferizationState bufferizationState(analysisState); - ModuleAnalysisState &moduleState = getModuleAnalysisState(analysisState); + FuncAnalysisState &moduleState = getFuncAnalysisState(analysisState); BufferizationAliasInfo &aliasInfo = analysisState.getAliasInfo(); - if (failed(getFuncOpsOrderedByCalls(moduleOp, moduleState.orderedFuncOps, - moduleState.callerMap))) + // A list of functions in the order in which they are analyzed + bufferized. + SmallVector orderedFuncOps; + + // A mapping of FuncOps to their callers. + FuncCallerMap callerMap; + + if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap))) return failure(); // Collect bbArg/return value information after the analysis. @@ -1078,7 +985,7 @@ options.addPostAnalysisStep(funcOpBbArgReadWriteAnalysis); // Analyze ops. - for (FuncOp funcOp : moduleState.orderedFuncOps) { + for (FuncOp funcOp : orderedFuncOps) { // No body => no analysis. if (funcOp.getBody().empty()) continue; @@ -1105,20 +1012,16 @@ return success(); // Bufferize functions. - for (FuncOp funcOp : moduleState.orderedFuncOps) { - // No body => no analysis. - if (!funcOp.getBody().empty()) - if (failed(bufferizeOp(funcOp, bufferizationState))) - return failure(); - + for (FuncOp funcOp : orderedFuncOps) { // Note: It would be good to apply cleanups here but we cannot as aliasInfo // would be invalidated. - if (failed(bufferizeFuncOpBoundary(funcOp, rewriter, bufferizationState))) + if (failed(bufferizeOp(funcOp, bufferizationState))) return failure(); + foldMemRefCasts(funcOp); } // Check result. - for (FuncOp funcOp : moduleState.orderedFuncOps) { + for (FuncOp funcOp : orderedFuncOps) { if (!options.allowReturnAllocs && llvm::any_of(funcOp.getType().getResults(), [](Type t) { return t.isa(); @@ -1132,10 +1035,6 @@ if (failed(finalizeBuffers(moduleOp, options))) return failure(); - // Perform a post-processing pass of layout modification at function boundary - // according to the kBufferLayoutAttrName. - layoutPostProcessing(moduleOp); - // Post-pass cleanup of inplaceable and buffer_layout attributes. moduleOp.walk([&](FuncOp op) { for (BlockArgument bbArg : op.getArguments()) diff --git a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir --- a/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir +++ b/mlir/test/Dialect/Linalg/comprehensive-module-bufferize-invalid.mlir @@ -11,6 +11,7 @@ // ----- +// expected-error @+2 {{op was not bufferized}} // expected-error @+1 {{cannot bufferize bodiless function that returns a tensor}} func private @foo() -> tensor @@ -212,6 +213,7 @@ // ----- +// expected-error @+2 {{op was not bufferized}} // expected-error @+1 {{cannot bufferize bodiless function that returns a tensor}} func private @foo(%t : tensor) -> (f32, tensor, f32) diff --git a/mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir --- a/mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir +++ b/mlir/test/Dialect/Linalg/one-shot-module-bufferize.mlir @@ -468,15 +468,15 @@ // conflict. However, inside `entry`, the writes do cause a conflict because // %A, %B and %C are not inplaceable. This test case shows that this kind of // conflict detection has a "transitive" nature. -// CHECK: %[[ALLOC_C:.*]] = memref.alloc -// CHECK: %[[CASTED_C:.*]] = memref.cast %[[ALLOC_C]] -// CHECK: %[[ALLOC_B:.*]] = memref.alloc -// CHECK: %[[CASTED_B:.*]] = memref.cast %[[ALLOC_B]] -// CHECK: %[[ALLOC_A:.*]] = memref.alloc -// CHECK: memref.copy %[[A]], %[[ALLOC_A]] -// CHECK: memref.copy %[[B]], %[[ALLOC_B]] -// CHECK: memref.copy %[[C]], %[[ALLOC_C]] -// CHECK: %[[CASTED_A:.*]] = memref.cast %[[ALLOC_A]] +// CHECK-DAG: %[[ALLOC_C:.*]] = memref.alloc +// CHECK-DAG: %[[CASTED_C:.*]] = memref.cast %[[ALLOC_C]] +// CHECK-DAG: %[[ALLOC_B:.*]] = memref.alloc +// CHECK-DAG: %[[CASTED_B:.*]] = memref.cast %[[ALLOC_B]] +// CHECK-DAG: %[[ALLOC_A:.*]] = memref.alloc +// CHECK-DAG: %[[CASTED_A:.*]] = memref.cast %[[ALLOC_A]] +// CHECK-DAG: memref.copy %[[A]], %[[ALLOC_A]] +// CHECK-DAG: memref.copy %[[B]], %[[ALLOC_B]] +// CHECK-DAG: memref.copy %[[C]], %[[ALLOC_C]] // CHECK-NEXT: call @callee(%[[CASTED_A]], %[[CASTED_B]], %[[CASTED_C]]) call @callee(%A, %B, %C) : (tensor, tensor, tensor) -> () return