diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h @@ -330,10 +330,9 @@ virtual ~PostAnalysisStep() {} /// Run the post analysis step. This function may modify the IR, but must keep - /// `aliasInfo` consistent. Newly created operations and operations that - /// should be re-analyzed must be stored in `newOps`. - virtual LogicalResult run(FuncOp funcOp, BufferizationAliasInfo &aliasInfo, - DominanceInfo &domInfo, + /// `aliasInfo` (inside `state`) consistent. Newly created operations and + /// operations that should be re-analyzed must be stored in `newOps`. + virtual LogicalResult run(FuncOp funcOp, BufferizationState &state, SmallVector &newOps) = 0; }; diff --git a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h --- a/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h +++ b/mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.h @@ -34,7 +34,7 @@ /// * The result of `rewriteFunc` must usually be analyzed for inplacability. /// This analysis can be skipped with `skipAnalysis`. LogicalResult eliminateInitTensors( - FuncOp funcOp, BufferizationAliasInfo &aliasInfo, DominanceInfo &domInfo, + FuncOp funcOp, BufferizationState &state, std::function anchorMatchFunc, std::function rewriteFunc, SmallVector &newOps); @@ -45,8 +45,7 @@ /// (and some other conditions are met). struct InsertSliceAnchoredInitTensorEliminationStep : public InitTensorEliminationStep { - LogicalResult run(FuncOp funcOp, BufferizationAliasInfo &aliasInfo, - DominanceInfo &domInfo, + LogicalResult run(FuncOp funcOp, BufferizationState &state, SmallVector &newOps) override; }; diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp @@ -910,7 +910,7 @@ for (const std::unique_ptr &step : options.postAnalysisSteps) { SmallVector newOps; - if (failed(step->run(funcOp, aliasInfo, domInfo, newOps))) + if (failed(step->run(funcOp, state, newOps))) return failure(); // Analyze ops that were created by the PostAnalysisStep. if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo))) diff --git a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp --- a/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp +++ b/mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp @@ -407,12 +407,12 @@ /// OpOperand, that eventually ends at a single InitTensorOp. LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext:: InitTensorEliminationStep::eliminateInitTensors( - FuncOp funcOp, BufferizationAliasInfo &aliasInfo, - DominanceInfo &domInfo, + FuncOp funcOp, BufferizationState &state, std::function anchorMatchFunc, std::function rewriteFunc, SmallVector &newOps) { OpBuilder b(funcOp->getContext()); + BufferizationAliasInfo &aliasInfo = state.aliasInfo; WalkResult status = funcOp->walk([&](Operation *op) { for (OpOperand &operand : op->getOpOperands()) { @@ -501,17 +501,17 @@ /// out-of-place due to RaW conflicts. LogicalResult mlir::linalg::comprehensive_bufferize::linalg_ext:: InsertSliceAnchoredInitTensorEliminationStep::run( - FuncOp funcOp, BufferizationAliasInfo &aliasInfo, - DominanceInfo &domInfo, SmallVector &newOps) { + FuncOp funcOp, BufferizationState &state, + SmallVector &newOps) { return eliminateInitTensors( - funcOp, aliasInfo, domInfo, + funcOp, state, [&](OpOperand &operand) { auto insertSliceOp = dyn_cast(operand.getOwner()); if (!insertSliceOp) return false; // Only inplace bufferized InsertSliceOps are eligible. - if (!aliasInfo.isInPlace(insertSliceOp->getOpResult(0))) + if (!state.aliasInfo.isInPlace(insertSliceOp->getOpResult(0))) return false; return &operand == &insertSliceOp->getOpOperand(0) /*source*/; },