diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -71,6 +71,8 @@
                             SmallVector<Operation *> &newOps) = 0;
 };
 
+using PostAnalysisStepList = std::vector<std::unique_ptr<PostAnalysisStep>>;
+
 /// Options for ComprehensiveBufferize.
 struct BufferizationOptions {
   BufferizationOptions();
@@ -107,7 +109,7 @@
   bool testAnalysisOnly = false;
 
   /// Registered post analysis steps.
-  std::vector<std::unique_ptr<PostAnalysisStep>> postAnalysisSteps;
+  PostAnalysisStepList postAnalysisSteps;
 };
 
 /// Specify fine-grain relationship between buffers to enable more analysis.
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
@@ -18,6 +18,7 @@
 
 struct BufferizationOptions;
 struct BufferizationState;
+struct PostAnalysisStep;
 
 /// Bufferize the given op.
 LogicalResult runComprehensiveBufferize(Operation *op,
@@ -25,9 +26,12 @@
 
 /// Bufferize the given function. Does not bufferize the function boundary.
 /// Reuses an existing BufferizationState object.
-LogicalResult runComprehensiveBufferize(Operation *op,
-                                        const BufferizationOptions &options,
-                                        BufferizationState &state);
+/// `extraSteps` run before `options.postAnalysisSteps`. They allow callers to
+/// add post analysis steps without modifying the immutable `options`.
+LogicalResult runComprehensiveBufferize(
+    Operation *op, const BufferizationOptions &options,
+    BufferizationState &state,
+    const std::vector<std::unique_ptr<PostAnalysisStep>> &extraSteps);
 
 } // namespace comprehensive_bufferize
 } // namespace linalg
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -721,12 +721,13 @@
 LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     Operation *op, const BufferizationOptions &options) {
   BufferizationState state(op, options);
-  return runComprehensiveBufferize(op, options, state);
+  PostAnalysisStepList extraSteps;
+  return runComprehensiveBufferize(op, options, state, extraSteps);
 }
 
 LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
     Operation *op, const BufferizationOptions &options,
-    BufferizationState &state) {
+    BufferizationState &state, const PostAnalysisStepList &extraSteps) {
   DominanceInfo domInfo(op);
   BufferizationAliasInfo &aliasInfo = state.aliasInfo;
 
@@ -740,16 +741,23 @@
     return failure();
   equivalenceAnalysis(op, aliasInfo);
 
-  for (const std::unique_ptr<PostAnalysisStep> &step :
-       options.postAnalysisSteps) {
-    SmallVector<Operation *> newOps;
-    if (failed(step->run(op, state, newOps)))
-      return failure();
-    // Analyze ops that were created by the PostAnalysisStep.
-    if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo)))
-      return failure();
-    equivalenceAnalysis(newOps, aliasInfo);
-  }
+  auto runPostAnalysisSteps = [&](const PostAnalysisStepList &steps) {
+    for (const std::unique_ptr<PostAnalysisStep> &step : steps) {
+      SmallVector<Operation *> newOps;
+      if (failed(step->run(op, state, newOps)))
+        return failure();
+      // Analyze ops that were created by the PostAnalysisStep.
+      if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo)))
+        return failure();
+      equivalenceAnalysis(newOps, aliasInfo);
+    }
+    return success();
+  };
+
+  if (failed(runPostAnalysisSteps(extraSteps)))
+    return failure();
+  if (failed(runPostAnalysisSteps(options.postAnalysisSteps)))
+    return failure();
 
   // Annotate operations if we only want to report the analysis.
   if (options.testAnalysisOnly) {
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp
@@ -33,8 +33,9 @@
   /// A map for looking up bufferized function types.
   DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
 
-  /// A mapping of return values to equivalent BlockArguments.
-  DenseMap<Value, BlockArgument> equivalentReturnValToBBArg;
+  /// For each FuncOp, a mapping of ReturnOp OpOperand indices to equivalent
+  /// FuncOp BBArg indices.
+  DenseMap<FuncOp, DenseMap<int64_t, int64_t>> equivalentFuncArgs;
 };
 } // namespace
 
@@ -44,6 +45,48 @@
       StandardOpsDialect::getDialectNamespace());
 }
 
+/// Return the unique ReturnOp that terminates `funcOp`.
+/// Return nullptr if there is no such unique ReturnOp.
+static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
+  ReturnOp returnOp;
+  for (Block &b : funcOp.body()) {
+    if (auto candidateOp = dyn_cast<ReturnOp>(b.getTerminator())) {
+      if (returnOp)
+        return nullptr;
+      returnOp = candidateOp;
+    }
+  }
+  return returnOp;
+}
+
+namespace {
+/// Store function BlockArguments that are equivalent to a returned value in
+/// the per-FuncOp entry of ModuleBufferizationState::equivalentFuncArgs.
+struct EquivalentFuncOpBBArgsAnalysis : public PostAnalysisStep {
+  LogicalResult run(Operation *op, BufferizationState &state,
+                    SmallVector<Operation *> &newOps) override {
+    ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
+    auto funcOp = cast<FuncOp>(op);
+
+    // Support only single return-terminated block in the function.
+    ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
+    assert(returnOp && "expected func with single return op");
+
+    for (OpOperand &returnVal : returnOp->getOpOperands())
+      if (returnVal.get().getType().isa<RankedTensorType>())
+        for (BlockArgument bbArg : funcOp.getArguments())
+          if (bbArg.getType().isa<RankedTensorType>())
+            if (state.aliasInfo.areEquivalentBufferizedValues(returnVal.get(),
+                                                              bbArg))
+              moduleState
+                  .equivalentFuncArgs[funcOp][returnVal.getOperandNumber()] =
+                  bbArg.getArgNumber();
+
+    return success();
+  }
+};
+} // namespace
+
 static bool isaTensor(Type t) { return t.isa<TensorType>(); }
 
 /// If `value` is a memref::CastOp, return its source. Otherwise, return
@@ -73,20 +116,6 @@
       SymbolTable::lookupNearestSymbolFrom(callOp, sym));
 }
 
-/// Return the unique ReturnOp that terminates `funcOp`.
-/// Return nullptr if there is no such unique ReturnOp.
-static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
-  ReturnOp returnOp;
-  for (Block &b : funcOp.body()) {
-    if (auto candidateOp = dyn_cast<ReturnOp>(b.getTerminator())) {
-      if (returnOp)
-        return nullptr;
-      returnOp = candidateOp;
-    }
-  }
-  return returnOp;
-}
-
 /// Return the FunctionType with `argumentTypes` and `resultTypes` where each
 /// tensor is replaced by the corresponding buffer type.
 /// In order for all the callers to agree, this *must* bufferize to the most
@@ -128,24 +157,6 @@
   return it2.first->second;
 }
 
-/// Store function BlockArguments that are equivalent to a returned value in
-/// the given ModuleBufferizationState.
-static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
-                                           BufferizationState &state) {
-  ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
-
-  // Support only single return-terminated block in the function.
-  ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-  assert(returnOp && "expected func with single return op");
-
-  for (Value returnVal : returnOp.operands())
-    if (returnVal.getType().isa<RankedTensorType>())
-      for (BlockArgument bbArg : funcOp.getArguments())
-        if (bbArg.getType().isa<RankedTensorType>())
-          if (state.aliasInfo.areEquivalentBufferizedValues(returnVal, bbArg))
-            moduleState.equivalentReturnValToBBArg[returnVal] = bbArg;
-}
-
 /// Rewrite the `funcOp` arguments analysis return values and terminator into
 /// buffer form (using the canonical memref layout for now), according to the
 /// inPlace-bufferizable information of the function arguments.
@@ -217,7 +228,8 @@
     }
 
     // If return operand is equivalent to some bbArg, no need to return it.
-    if (moduleState.equivalentReturnValToBBArg.count(returnVal))
+    if (moduleState.equivalentFuncArgs[funcOp].count(
+            returnOperand.getOperandNumber()))
       continue;
 
     // Cast values at the call site if necessary.
@@ -503,12 +515,12 @@
       }
 
       // If return operand is equivalent to some bbArg, no need to return it.
-      Value returnVal = returnOperand.get();
-      if (moduleState.equivalentReturnValToBBArg.count(returnVal)) {
-        BlockArgument bbArg =
-            moduleState.equivalentReturnValToBBArg[returnVal];
+      if (moduleState.equivalentFuncArgs[funcOp].count(
+              returnOperand.getOperandNumber())) {
+        int64_t idx =
+            moduleState
+                .equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
         Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
-        int64_t idx = bbArg.getArgNumber();
         Value buffer = state.lookupBuffer(callOp->getOperand(idx));
         // Add CallOp operand/result equivalence: this is interprocedural
         // info.
@@ -710,11 +722,14 @@
           aliasInfo.setBufferizesToWritableMemory(bbArg);
     }
 
+    // Register extra post analysis steps. These cannot be stored in `options`
+    // because `options` is immutable.
+    PostAnalysisStepList extraSteps;
+    extraSteps.emplace_back(std::make_unique<EquivalentFuncOpBBArgsAnalysis>());
+
     // Analyze and bufferize funcOp.
-    if (failed(runComprehensiveBufferize(funcOp, options, state)))
+    if (failed(runComprehensiveBufferize(funcOp, options, state, extraSteps)))
       return failure();
-
-    populateEquivalentFuncOpBBArgs(funcOp, state);
   }
 
   if (options.testAnalysisOnly)