diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -24,9 +24,6 @@
 /// Extra bufferization state that is required for bufferization of function
 /// boundaries.
 struct ModuleBufferizationState : public DialectBufferizationState {
-  /// A map for looking up bufferized function types.
-  DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
-
   /// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg
   /// indices.
   DenseMap<FuncOp, DenseMap<int64_t, int64_t>> equivalentFuncArgs;
@@ -161,23 +158,6 @@
   return FunctionType::get(ctx, argTypes, retTypes);
 }
 
-/// If an entry for `funcOp` is available in `bufferizedFunctionTypes`, return
-/// it. Otherwise, construct a new entry based on `argumentTypes` and
-/// `resultTypes`.
-// TODO: improve the layering.
-static FunctionType getOrCreateBufferizedFunctionType(
-    FuncOp funcOp, TypeRange argumentTypes, TypeRange resultTypes,
-    DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
-  auto it = bufferizedFunctionTypes.find(funcOp);
-  if (it != bufferizedFunctionTypes.end())
-    return it->second;
-
-  auto it2 = bufferizedFunctionTypes.try_emplace(
-      funcOp, getBufferizedFunctionType(funcOp.getContext(), argumentTypes,
-                                        resultTypes));
-  return it2.first->second;
-}
-
 /// Gather equivalence info of CallOps.
 /// Note: This only adds new equivalence info if `funcOp` was already analyzed.
 // TODO: This does not handle cyclic function call graphs etc.
@@ -250,9 +230,8 @@
     if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
       return funcOp->emitError() << "cannot bufferize bodiless function that "
                                  << "returns a tensor";
-    FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
-        funcOp, funcOp.getType().getInputs(), TypeRange{},
-        moduleState.bufferizedFunctionTypes);
+    FunctionType bufferizedFuncType = getBufferizedFunctionType(
+        funcOp.getContext(), funcOp.getType().getInputs(), TypeRange{});
     funcOp.setType(bufferizedFuncType);
     return success();
   }
@@ -284,9 +263,8 @@
 
   // 2. Rewrite the terminator without the inPlace bufferizable values.
   ValueRange retValues{returnValues};
-  FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
-      funcOp, funcOp.getType().getInputs(), retValues.getTypes(),
-      moduleState.bufferizedFunctionTypes);
+  FunctionType bufferizedFuncType = getBufferizedFunctionType(
+      funcOp.getContext(), funcOp.getType().getInputs(), retValues.getTypes());
   OpBuilder b(returnOp);
   b.create<ReturnOp>(returnOp.getLoc(), returnValues);
   returnOp->erase();
@@ -590,9 +568,8 @@
     SmallVector<Type> argumentTypes{callOp->getOperandTypes()};
     // Get the bufferized FunctionType for funcOp or construct it if not yet
     // available.
-    FunctionType bufferizedFuncType =
-        getOrCreateBufferizedFunctionType(funcOp, argumentTypes, resultTypes,
-                                          moduleState.bufferizedFunctionTypes);
+    FunctionType bufferizedFuncType = getBufferizedFunctionType(
+        funcOp.getContext(), argumentTypes, resultTypes);
 
     // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
     for (OpOperand &opOperand : callOp->getOpOperands()) {