diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -230,6 +230,13 @@ MemCpyFn memCpyFn; }; +/// Dialect-specific bufferization state. Analysis/bufferization information +/// that is specific to ops from a certain dialect can be stored in derived +/// variants of this struct. +struct DialectBufferizationState { + virtual ~DialectBufferizationState() = default; +}; + /// BufferizationState keeps track of bufferization state and provides access to /// the results of the analysis. struct BufferizationState { @@ -271,6 +278,14 @@ /// Erase all ops that were marked obsolete. void eraseObsoleteOps(); + /// Return dialect-specific bufferization state. + template <typename StateT> StateT &getDialectState(StringRef name) { + // Create state if it does not exist yet. + if (!dialectState.count(name)) + dialectState[name] = std::make_unique<StateT>(); + return static_cast<StateT &>(*dialectState[name]); + } + /// `aliasInfo` keeps track of aliasing and equivalent values. BufferizationAliasInfo aliasInfo; @@ -284,6 +299,9 @@ /// Obsolete ops that should be deleted after bufferization. SmallVector<Operation *> obsoleteOps; + + /// Dialect-specific bufferization state. + DenseMap<StringRef, std::unique_ptr<DialectBufferizationState>> dialectState; }; /// Return the result buffer (memref) for a given OpResult (tensor). 
Allocate diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp @@ -27,11 +27,9 @@ using namespace comprehensive_bufferize; namespace { -/// A specialization of BufferizationState that keeps track of additional -/// state required for bufferization of function boundaries. -struct ModuleBufferizationState : public BufferizationState { - using BufferizationState::BufferizationState; - +/// Extra bufferization state that is required for bufferization of function +/// boundaries. +struct ModuleBufferizationState : public DialectBufferizationState { /// A map for looking up bufferized function types. DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes; @@ -40,6 +38,12 @@ }; } // namespace +static ModuleBufferizationState & +getModuleBufferizationState(BufferizationState &state) { + return state.getDialectState<ModuleBufferizationState>( + StandardOpsDialect::getDialectNamespace()); +} + static bool isaTensor(Type t) { return t.isa<TensorType>(); } /// If `value` is a memref::CastOp, return its source. Otherwise, return @@ -127,7 +131,9 @@ /// Store function BlockArguments that are equivalent to a returned value in /// the given ModuleBufferizationState. static void populateEquivalentFuncOpBBArgs(FuncOp funcOp, - ModuleBufferizationState &state) { + BufferizationState &state) { + ModuleBufferizationState &moduleState = getModuleBufferizationState(state); + // Support only single return-terminated block in the function. 
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); assert(returnOp && "expected func with single return op"); @@ -137,7 +143,7 @@ for (BlockArgument bbArg : funcOp.getArguments()) if (bbArg.getType().isa<RankedTensorType>()) if (state.aliasInfo.areEquivalentBufferizedValues(returnVal, bbArg)) - state.equivalentReturnValToBBArg[returnVal] = bbArg; + moduleState.equivalentReturnValToBBArg[returnVal] = bbArg; } /// Rewrite the `funcOp` arguments analysis return values and terminator into @@ -155,8 +161,9 @@ /// originate from an op with an Alloc effect, they could be hoisted in the /// future. static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp, - ModuleBufferizationState &state) { + BufferizationState &state) { LLVM_DEBUG(DBGS() << "Begin bufferizeFuncOpBoundary:\n" << funcOp << "\n"); + ModuleBufferizationState &moduleState = getModuleBufferizationState(state); BufferizationAliasInfo &aliasInfo = state.aliasInfo; // If nothing to do then we are done. @@ -188,7 +195,7 @@ << "returns a tensor"; FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType( funcOp, funcOp.getType().getInputs(), TypeRange{}, - state.bufferizedFunctionTypes); + moduleState.bufferizedFunctionTypes); funcOp.setType(bufferizedFuncType); LLVM_DEBUG(DBGS() << "End bufferizeFuncOpBoundary no fun body: " << funcOp); return success(); @@ -210,7 +217,7 @@ } // If return operand is equivalent to some bbArg, no need to return it. - if (state.equivalentReturnValToBBArg.count(returnVal)) + if (moduleState.equivalentReturnValToBBArg.count(returnVal)) continue; // Cast values at the call site if necessary. 
@@ -221,7 +228,7 @@ ValueRange retValues{returnValues}; FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType( funcOp, funcOp.getType().getInputs(), retValues.getTypes(), - state.bufferizedFunctionTypes); + moduleState.bufferizedFunctionTypes); OpBuilder b(returnOp); b.create<ReturnOp>(returnOp.getLoc(), returnValues); returnOp->erase(); @@ -474,7 +481,7 @@ FuncOp funcOp = getCalledFunction(callOp); assert(isa<CallOpInterface>(callOp.getOperation()) && funcOp && "expected Callop to a FuncOp"); - auto &moduleState = static_cast<ModuleBufferizationState &>(state); + ModuleBufferizationState &moduleState = getModuleBufferizationState(state); // Take a guard before anything else. OpBuilder::InsertionGuard g(b); @@ -649,7 +656,7 @@ if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap))) return failure(); - ModuleBufferizationState state(moduleOp, *options.allocationFns); + BufferizationState state(moduleOp, *options.allocationFns); BufferizationAliasInfo &aliasInfo = state.aliasInfo; // Interestingly, all function args that are not visible outside of a module