diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h
@@ -18,6 +18,8 @@
 namespace mlir {
 
 class BlockAndValueMapping;
+class DominanceInfo;
+class FuncOp;
 
 namespace linalg {
 namespace comprehensive_bufferize {
@@ -266,6 +268,20 @@
 /// bufferization is necessary.
 Value getResultBuffer(OpBuilder &b, OpResult result, BufferizationState &state);
 
+/// PostAnalysisSteps can be registered with `BufferizationOptions` and are
+/// executed after the analysis, but before bufferization. They can be used
+/// to implement custom dialect-specific optimizations.
+struct PostAnalysisStep {
+  virtual ~PostAnalysisStep() = default;
+
+  /// Run the post analysis step. This function may modify the IR, but must keep
+  /// `aliasInfo` consistent. Newly created operations and operations that
+  /// should be re-analyzed must be stored in `newOps`.
+  virtual LogicalResult run(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
+                            DominanceInfo &domInfo,
+                            SmallVector<Operation *> &newOps) = 0;
+};
+
 } // namespace comprehensive_bufferize
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
--- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
+++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.h
@@ -63,40 +63,59 @@
 /// Register external models implemented for the `BufferizableOpInterface`.
 void registerBufferizableOpInterfaceExternalModels(DialectRegistry &registry);
 
-/// Try to eliminate InitTensorOps inside `funcOp`.
-///
-/// * `rewriteFunc` generates the replacement for the InitTensorOp.
-/// * Only InitTensorOps that are anchored on a matching OpOperand as per
-///   `anchorMatchFunc` are considered. "Anchored" means that there is a path on
-///   the reverse SSA use-def chain, starting from the OpOperand and always
-///   following the aliasing OpOperand, that eventually ends at a single
-///   InitTensorOp.
-/// * The result of `rewriteFunc` must usually be analyzed for inplacability.
-///   This analysis can be skipped with `skipAnalysis`.
-LogicalResult initTensorElimination(
-    FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo,
-    std::function<bool(OpOperand &)> anchorMatchFunc,
-    std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
-    bool skipAnalysis = false);
-
-/// Try to eliminate InitTensorOps inside funcOp that are anchored on an
-/// InsertSliceOp, i.e., if it is eventually inserted into another tensor
-/// (and some other conditions are met).
-LogicalResult eliminateInsertSliceAnchoredInitTensorOps(
-    FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo);
-
 struct BufferizationOptions {
   BufferizationOptions();
 
+  /// Register a "post analysis" step. Such steps are executed after the
+  /// analysis, but before bufferization.
+  template <typename Step, typename... Args>
+  void addPostAnalysisStep(Args &&...args) {
+    postAnalysisSteps.emplace_back(
+        std::make_unique<Step>(std::forward<Args>(args)...));
+  }
+
   std::unique_ptr<AllocationCallbacks> allocationFns;
   bool allowReturnMemref = false;
   unsigned analysisFuzzerSeed = 0;
   bool testAnalysisOnly = false;
+  std::vector<std::unique_ptr<PostAnalysisStep>> postAnalysisSteps;
 };
 
 LogicalResult runComprehensiveBufferize(ModuleOp moduleOp,
                                         const BufferizationOptions &options);
 
+namespace linalg_ext {
+
+struct InitTensorEliminationStep : public PostAnalysisStep {
+  /// Try to eliminate InitTensorOps inside `funcOp`.
+  ///
+  /// * `rewriteFunc` generates the replacement for the InitTensorOp.
+  /// * Only InitTensorOps that are anchored on a matching OpOperand as per
+  ///   `anchorMatchFunc` are considered. "Anchored" means that there is a path
+  ///   on the reverse SSA use-def chain, starting from the OpOperand and always
+  ///   following the aliasing OpOperand, that eventually ends at a single
+  ///   InitTensorOp.
+  /// * The defining ops of the values returned by `rewriteFunc` are appended
+  ///   to `newOps`, so that they can be analyzed for inplaceability later.
+  LogicalResult eliminateInitTensors(
+      FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo,
+      std::function<bool(OpOperand &)> anchorMatchFunc,
+      std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
+      SmallVector<Operation *> &newOps);
+};
+
+/// Try to eliminate InitTensorOps inside funcOp that are anchored on an
+/// InsertSliceOp, i.e., if it is eventually inserted into another tensor
+/// (and some other conditions are met).
+struct InsertSliceAnchoredInitTensorEliminationStep
+    : public InitTensorEliminationStep {
+  LogicalResult run(FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
+                    DominanceInfo &domInfo,
+                    SmallVector<Operation *> &newOps) override;
+};
+
+} // namespace linalg_ext
+
 } // namespace comprehensive_bufferize
 } // namespace linalg
 } // namespace mlir
diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
--- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
+++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp
@@ -1595,11 +1595,13 @@
 /// OpOperand. "Anchored" means that there is a path on the reverse SSA use-def
 /// chain, starting from the OpOperand and always following the aliasing
 /// OpOperand, that eventually ends at a single InitTensorOp.
-LogicalResult mlir::linalg::comprehensive_bufferize::initTensorElimination(
-    FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo,
-    std::function<bool(OpOperand &)> anchorMatchFunc,
-    std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
-    bool skipAnalysis) {
+LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
+    InitTensorEliminationStep::eliminateInitTensors(
+        FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
+        DominanceInfo &domInfo,
+        std::function<bool(OpOperand &)> anchorMatchFunc,
+        std::function<Value(OpBuilder &, Location, OpOperand &)> rewriteFunc,
+        SmallVector<Operation *> &newOps) {
   OpBuilder b(funcOp->getContext());
 
   WalkResult status = funcOp->walk([&](Operation *op) {
@@ -1647,14 +1649,9 @@
       aliasInfo.unionAliasSets(initTensor, replacement);
       aliasInfo.unionEquivalenceClasses(initTensor, replacement);
 
-      // Run analysis on the newly created op.
-      if (auto opResult = replacement.dyn_cast<OpResult>()) {
-        if (!skipAnalysis) {
-          SmallVector<Operation *> ops(1, replacement.getDefiningOp());
-          if (failed(inPlaceAnalysis(ops, aliasInfo, domInfo)))
-            return WalkResult::interrupt();
-        }
-      }
+      // Register replacement ops.
+      if (Operation *newOp = replacement.getDefiningOp())
+        newOps.push_back(newOp);
     }
 
     // Advance to the next operation.
@@ -1692,11 +1689,11 @@
 ///
 /// Note that the newly inserted ExtractSliceOp may have to bufferize
 /// out-of-place due to RaW conflicts.
-LogicalResult mlir::linalg::comprehensive_bufferize::
-    eliminateInsertSliceAnchoredInitTensorOps(FuncOp funcOp,
-                                              BufferizationAliasInfo &aliasInfo,
-                                              DominanceInfo &domInfo) {
-  return initTensorElimination(
+LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext::
+    InsertSliceAnchoredInitTensorEliminationStep::run(
+        FuncOp funcOp, BufferizationAliasInfo &aliasInfo,
+        DominanceInfo &domInfo, SmallVector<Operation *> &newOps) {
+  return eliminateInitTensors(
       funcOp, aliasInfo, domInfo,
       [&](OpOperand &operand) {
         auto insertSliceOp = dyn_cast<InsertSliceOp>(operand.getOwner());
@@ -1713,7 +1710,8 @@
             loc, insertSliceOp.dest(), insertSliceOp.getMixedOffsets(),
             insertSliceOp.getMixedSizes(), insertSliceOp.getMixedStrides());
         return extractOp.result();
-      });
+      },
+      newOps);
 }
 
 #ifndef NDEBUG
@@ -1793,11 +1791,15 @@
                                options.analysisFuzzerSeed)))
       return failure();
 
-    // Try to eliminate InitTensorOps to avoid new allocations during the
-    // bufferization phase.
-    if (failed(eliminateInsertSliceAnchoredInitTensorOps(funcOp, aliasInfo,
-                                                         domInfo)))
-      return failure();
+    for (const std::unique_ptr<PostAnalysisStep> &step :
+         options.postAnalysisSteps) {
+      SmallVector<Operation *> newOps;
+      if (failed(step->run(funcOp, aliasInfo, domInfo, newOps)))
+        return failure();
+      // Analyze ops that were created by the PostAnalysisStep.
+      if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo)))
+        return failure();
+    }
 
     // Bufferization phase.
     if (!options.testAnalysisOnly) {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
--- a/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferizePass.cpp
@@ -71,6 +71,10 @@
   options.analysisFuzzerSeed = analysisFuzzerSeed;
   options.testAnalysisOnly = testAnalysisOnly;
 
+  // Enable InitTensorOp elimination.
+  options.addPostAnalysisStep<
+      linalg_ext::InsertSliceAnchoredInitTensorEliminationStep>();
+
   ModuleOp moduleOp = getOperation();
   applyEnablingTransformations(moduleOp);
 