diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -16,8 +16,7 @@ // respective callers. // // After analyzing a FuncOp, additional information about its bbArgs is -// gathered through PostAnalysisStepFns and stored in -// `ModuleAnalysisState`. +// gathered through PostAnalysisStepFns and stored in `FuncAnalysisState`. // // * `aliasingFuncOpBBArgsAnalysis` determines the equivalent/aliasing bbArgs // for @@ -93,7 +92,7 @@ /// Extra analysis state that is required for bufferization of function /// boundaries. -struct ModuleAnalysisState : public DialectAnalysisState { +struct FuncAnalysisState : public DialectAnalysisState { // Note: Function arguments and/or function return values may disappear during // bufferization. Functions and their CallOps are analyzed and bufferized // separately. To ensure that a CallOp analysis/bufferization can access an @@ -162,26 +161,26 @@ }; } // namespace -/// Get ModuleAnalysisState. -static const ModuleAnalysisState & -getModuleAnalysisState(const AnalysisState &state) { - Optional maybeState = - state.getDialectState( +/// Get FuncAnalysisState. +static const FuncAnalysisState & +getFuncAnalysisState(const AnalysisState &state) { + Optional maybeState = + state.getDialectState( func::FuncDialect::getDialectNamespace()); - assert(maybeState.hasValue() && "ModuleAnalysisState does not exist"); + assert(maybeState.hasValue() && "FuncAnalysisState does not exist"); return **maybeState; } -/// Get or create ModuleAnalysisState. -static ModuleAnalysisState &getModuleAnalysisState(AnalysisState &state) { - return state.getOrCreateDialectState( +/// Get or create FuncAnalysisState. +static FuncAnalysisState &getFuncAnalysisState(AnalysisState &state) { + return state.getOrCreateDialectState( func::FuncDialect::getDialectNamespace()); } /// Return the state (phase) of analysis of the FuncOp. static FuncOpAnalysisState getFuncOpAnalysisState(const AnalysisState &state, FuncOp funcOp) { - const ModuleAnalysisState &moduleState = getModuleAnalysisState(state); + const FuncAnalysisState &moduleState = getFuncAnalysisState(state); auto it = moduleState.analyzedFuncOps.find(funcOp); if (it == moduleState.analyzedFuncOps.end()) return FuncOpAnalysisState::NotAnalyzed; @@ -226,12 +225,12 @@ } /// Store function BlockArguments that are equivalent to/aliasing a returned -/// value in ModuleAnalysisState. +/// value in FuncAnalysisState. static LogicalResult aliasingFuncOpBBArgsAnalysis(Operation *op, AnalysisState &state, BufferizationAliasInfo &aliasInfo, SmallVector &newOps) { - ModuleAnalysisState &moduleState = getModuleAnalysisState(state); + FuncAnalysisState &funcState = getFuncAnalysisState(state); // Support only single return-terminated block in the function. auto funcOp = cast(op); @@ -245,14 +244,13 @@ int64_t returnIdx = returnVal.getOperandNumber(); int64_t bbArgIdx = bbArg.getArgNumber(); if (aliasInfo.areEquivalentBufferizedValues(returnVal.get(), bbArg)) { - moduleState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx; + funcState.equivalentFuncArgs[funcOp][returnIdx] = bbArgIdx; if (state.getOptions().testAnalysisOnly) annotateEquivalentReturnBbArg(returnVal, bbArg); } if (aliasInfo.areAliasingBufferizedValues(returnVal.get(), bbArg)) { - moduleState.aliasingFuncArgs[funcOp][returnIdx].push_back(bbArgIdx); - moduleState.aliasingReturnVals[funcOp][bbArgIdx].push_back( - returnIdx); + funcState.aliasingFuncArgs[funcOp][returnIdx].push_back(bbArgIdx); + funcState.aliasingReturnVals[funcOp][bbArgIdx].push_back(returnIdx); } } @@ -311,15 +309,15 @@ funcOpBbArgReadWriteAnalysis(Operation *op, AnalysisState &state, BufferizationAliasInfo &aliasInfo, SmallVector &newOps) { - ModuleAnalysisState &moduleState = getModuleAnalysisState(state); + FuncAnalysisState &funcState = getFuncAnalysisState(state); auto funcOp = cast(op); // If the function has no body, conservatively assume that all args are // read + written. if (funcOp.getBody().empty()) { for (BlockArgument bbArg : funcOp.getArguments()) { - moduleState.readBbArgs[funcOp].insert(bbArg.getArgNumber()); - moduleState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber()); + funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber()); + funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber()); } return success(); @@ -333,9 +331,9 @@ if (state.getOptions().testAnalysisOnly) annotateFuncArgAccess(funcOp, bbArg, isRead, isWritten); if (isRead) - moduleState.readBbArgs[funcOp].insert(bbArg.getArgNumber()); + funcState.readBbArgs[funcOp].insert(bbArg.getArgNumber()); if (isWritten) - moduleState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber()); + funcState.writtenBbArgs[funcOp].insert(bbArg.getArgNumber()); } return success(); @@ -399,16 +397,16 @@ // TODO: This does not handle cyclic function call graphs etc. static void equivalenceAnalysis(FuncOp funcOp, BufferizationAliasInfo &aliasInfo, - ModuleAnalysisState &moduleState) { + FuncAnalysisState &funcState) { funcOp->walk([&](func::CallOp callOp) { FuncOp calledFunction = getCalledFunction(callOp); assert(calledFunction && "could not retrieved called FuncOp"); // No equivalence info available for the called function. - if (!moduleState.equivalentFuncArgs.count(calledFunction)) + if (!funcState.equivalentFuncArgs.count(calledFunction)) return WalkResult::skip(); - for (auto it : moduleState.equivalentFuncArgs[calledFunction]) { + for (auto it : funcState.equivalentFuncArgs[calledFunction]) { int64_t returnIdx = it.first; int64_t bbargIdx = it.second; Value returnVal = callOp.getResult(returnIdx); @@ -437,8 +435,8 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp, RewriterBase &rewriter, BufferizationState &state) { - const ModuleAnalysisState &moduleState = - getModuleAnalysisState(state.getAnalysisState()); + const FuncAnalysisState &funcState = + getFuncAnalysisState(state.getAnalysisState()); // If nothing to do then we are done. if (!llvm::any_of(funcOp.getFunctionType().getInputs(), isaTensor) && @@ -490,8 +488,8 @@ } // If return operand is equivalent to some bbArg, no need to return it. - auto funcOpIt = moduleState.equivalentFuncArgs.find(funcOp); - if (funcOpIt != moduleState.equivalentFuncArgs.end() && + auto funcOpIt = funcState.equivalentFuncArgs.find(funcOp); + if (funcOpIt != funcState.equivalentFuncArgs.end() && funcOpIt->second.count(returnOperand.getOperandNumber())) continue; @@ -726,9 +724,9 @@ /// Return the index of the bbArg in the given FuncOp that is equivalent to the /// specified return value (if any). -static Optional -getEquivalentFuncArgIdx(FuncOp funcOp, const ModuleAnalysisState &state, - int64_t returnValIdx) { +static Optional getEquivalentFuncArgIdx(FuncOp funcOp, + const FuncAnalysisState &state, + int64_t returnValIdx) { auto funcOpIt = state.equivalentFuncArgs.find(funcOp); if (funcOpIt == state.equivalentFuncArgs.end()) // No equivalence info stores for funcOp. @@ -751,12 +749,12 @@ FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); - const ModuleAnalysisState &moduleState = getModuleAnalysisState(state); + const FuncAnalysisState &funcState = getFuncAnalysisState(state); if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) // FuncOp not analyzed yet. Assume that OpOperand is read. return true; - return moduleState.readBbArgs.lookup(funcOp).contains( + return funcState.readBbArgs.lookup(funcOp).contains( opOperand.getOperandNumber()); } @@ -766,12 +764,12 @@ FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); - const ModuleAnalysisState &moduleState = getModuleAnalysisState(state); + const FuncAnalysisState &funcState = getFuncAnalysisState(state); if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) // FuncOp not analyzed yet. Assume that OpOperand is written. return true; - return moduleState.writtenBbArgs.lookup(funcOp).contains( + return funcState.writtenBbArgs.lookup(funcOp).contains( opOperand.getOperandNumber()); } @@ -780,7 +778,7 @@ func::CallOp callOp = cast(op); FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); - const ModuleAnalysisState &moduleState = getModuleAnalysisState(state); + const FuncAnalysisState &funcState = getFuncAnalysisState(state); if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) { // FuncOp not analyzed yet. Any OpResult may be aliasing. @@ -793,7 +791,7 @@ // Get aliasing results from state. auto aliasingReturnVals = - moduleState.aliasingReturnVals.lookup(funcOp).lookup( + funcState.aliasingReturnVals.lookup(funcOp).lookup( opOperand.getOperandNumber()); SmallVector result; for (int64_t resultIdx : aliasingReturnVals) @@ -807,7 +805,7 @@ func::CallOp callOp = cast(op); FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); - const ModuleAnalysisState &moduleState = getModuleAnalysisState(state); + const FuncAnalysisState &funcState = getFuncAnalysisState(state); if (getFuncOpAnalysisState(state, funcOp) != FuncOpAnalysisState::Analyzed) { // FuncOp not analyzed yet. Any OpOperand may be aliasing. @@ -819,7 +817,7 @@ } // Get aliasing bbArgs from state. - auto aliasingFuncArgs = moduleState.aliasingFuncArgs.lookup(funcOp).lookup( + auto aliasingFuncArgs = funcState.aliasingFuncArgs.lookup(funcOp).lookup( opResult.getResultNumber()); SmallVector result; for (int64_t bbArgIdx : aliasingFuncArgs) @@ -842,8 +840,8 @@ unsigned numOperands = callOp->getNumOperands(); FuncOp funcOp = getCalledFunction(callOp); assert(funcOp && "expected CallOp to a FuncOp"); - const ModuleAnalysisState &moduleState = - getModuleAnalysisState(state.getAnalysisState()); + const FuncAnalysisState &funcState = + getFuncAnalysisState(state.getAnalysisState()); const OneShotBufferizationOptions &options = static_cast(state.getOptions()); @@ -885,7 +883,7 @@ } if (Optional bbArgIdx = - getEquivalentFuncArgIdx(funcOp, moduleState, returnValIdx)) { + getEquivalentFuncArgIdx(funcOp, funcState, returnValIdx)) { // Return operands that are equivalent to some bbArg, are not // returned. FailureOr bufferOrFailure = @@ -1068,11 +1066,11 @@ IRRewriter rewriter(moduleOp.getContext()); OneShotAnalysisState analysisState(moduleOp, options); BufferizationState bufferizationState(analysisState); - ModuleAnalysisState &moduleState = getModuleAnalysisState(analysisState); + FuncAnalysisState &funcState = getFuncAnalysisState(analysisState); BufferizationAliasInfo &aliasInfo = analysisState.getAliasInfo(); - if (failed(getFuncOpsOrderedByCalls(moduleOp, moduleState.orderedFuncOps, - moduleState.callerMap))) + if (failed(getFuncOpsOrderedByCalls(moduleOp, funcState.orderedFuncOps, + funcState.callerMap))) return failure(); // Collect bbArg/return value information after the analysis. @@ -1080,23 +1078,23 @@ options.addPostAnalysisStep(funcOpBbArgReadWriteAnalysis); // Analyze ops. - for (FuncOp funcOp : moduleState.orderedFuncOps) { + for (FuncOp funcOp : funcState.orderedFuncOps) { // No body => no analysis. if (funcOp.getBody().empty()) continue; // Now analyzing function. - moduleState.startFunctionAnalysis(funcOp); + funcState.startFunctionAnalysis(funcOp); // Gather equivalence info for CallOps. - equivalenceAnalysis(funcOp, aliasInfo, moduleState); + equivalenceAnalysis(funcOp, aliasInfo, funcState); // Analyze funcOp. if (failed(analyzeOp(funcOp, analysisState))) return failure(); // Mark op as fully analyzed. - moduleState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed; + funcState.analyzedFuncOps[funcOp] = FuncOpAnalysisState::Analyzed; // Add annotations to function arguments. if (options.testAnalysisOnly) @@ -1107,7 +1105,7 @@ return success(); // Bufferize functions. - for (FuncOp funcOp : moduleState.orderedFuncOps) { + for (FuncOp funcOp : funcState.orderedFuncOps) { // No body => no analysis. if (!funcOp.getBody().empty()) if (failed(bufferizeOp(funcOp, bufferizationState))) @@ -1120,7 +1118,7 @@ } // Check result. - for (FuncOp funcOp : moduleState.orderedFuncOps) { + for (FuncOp funcOp : funcState.orderedFuncOps) { if (!options.allowReturnAllocs && llvm::any_of(funcOp.getFunctionType().getResults(), [](Type t) { return t.isa();