diff --git a/mlir/include/mlir/Analysis/DataFlowFramework.h b/mlir/include/mlir/Analysis/DataFlowFramework.h --- a/mlir/include/mlir/Analysis/DataFlowFramework.h +++ b/mlir/include/mlir/Analysis/DataFlowFramework.h @@ -235,12 +235,6 @@ /// dependent work items to the back of the queue. void propagateIfChanged(AnalysisState *state, ChangeResult changed); - /// Add a dependency to an analysis state on a child analysis and program - /// point. If the state is updated, the child analysis must be invoked on the - /// given program point again. - void addDependency(AnalysisState *state, DataFlowAnalysis *analysis, - ProgramPoint point); - private: /// The solver's work queue. Work items can be inserted to the front of the /// queue to be processed greedily, speeding up computations that otherwise @@ -294,13 +288,30 @@ /// Print the contents of the analysis state. virtual void print(raw_ostream &os) const = 0; + /// Add a dependency to this analysis state on a program point and an + /// analysis. If this state is updated, the analysis will be invoked on the + /// given program point again (in onUpdate()). + void addDependency(ProgramPoint dependent, DataFlowAnalysis *analysis); + protected: /// This function is called by the solver when the analysis state is updated - /// to optionally enqueue more work items. For example, if a state tracks - /// dependents through the IR (e.g. use-def chains), this function can be - /// implemented to push those dependents on the worklist. - virtual void onUpdate(DataFlowSolver *solver) const {} + /// to enqueue more work items. The default enqueues every dependent added + /// via addDependency(); overrides that push extra dependents (e.g. use-def + /// chains) must call AnalysisState::onUpdate() to preserve that behavior. + virtual void onUpdate(DataFlowSolver *solver) const { + for (const DataFlowSolver::WorkItem &item : dependents) + solver->enqueue(item); + } + + /// The program point to which the state belongs.
+ ProgramPoint point; + +#if LLVM_ENABLE_ABI_BREAKING_CHECKS + /// When compiling with debugging, keep a name for the analysis state. + StringRef debugName; +#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS +private: /// The dependency relations originating from this analysis state. An entry /// `state -> (analysis, point)` is created when `analysis` queries `state` /// when updating `point`. @@ -312,14 +323,6 @@ /// Store the dependents on the analysis state for efficiency. SetVector dependents; - /// The program point to which the state belongs. - ProgramPoint point; - -#if LLVM_ENABLE_ABI_BREAKING_CHECKS - /// When compiling with debugging, keep a name for the analysis state. - StringRef debugName; -#endif // LLVM_ENABLE_ABI_BREAKING_CHECKS - /// Allow the framework to access the dependents. friend class DataFlowSolver; }; diff --git a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/DeadCodeAnalysis.cpp @@ -8,6 +8,7 @@ #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" #include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h" +#include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Interfaces/CallInterfaces.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include @@ -31,6 +32,8 @@ } void Executable::onUpdate(DataFlowSolver *solver) const { + AnalysisState::onUpdate(solver); + if (auto *block = llvm::dyn_cast_if_present(point)) { // Re-invoke the analyses on the block itself.
for (DataFlowAnalysis *analysis : subscribers) diff --git a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp --- a/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/SparseAnalysis.cpp @@ -8,6 +8,7 @@ #include "mlir/Analysis/DataFlow/SparseAnalysis.h" #include "mlir/Analysis/DataFlow/DeadCodeAnalysis.h" +#include "mlir/Analysis/DataFlowFramework.h" #include "mlir/Interfaces/CallInterfaces.h" using namespace mlir; @@ -18,6 +19,8 @@ //===----------------------------------------------------------------------===// void AbstractSparseLattice::onUpdate(DataFlowSolver *solver) const { + AnalysisState::onUpdate(solver); + // Push all users of the value to the queue. for (Operation *user : point.get().getUsers()) for (DataFlowAnalysis *analysis : useDefSubscribers) diff --git a/mlir/lib/Analysis/DataFlowFramework.cpp b/mlir/lib/Analysis/DataFlowFramework.cpp --- a/mlir/lib/Analysis/DataFlowFramework.cpp +++ b/mlir/lib/Analysis/DataFlowFramework.cpp @@ -30,6 +30,19 @@ AnalysisState::~AnalysisState() = default; +void AnalysisState::addDependency(ProgramPoint dependent, + DataFlowAnalysis *analysis) { + auto inserted = dependents.insert({dependent, analysis}); + (void)inserted; + DATAFLOW_DEBUG({ + if (inserted) { + llvm::dbgs() << "Creating dependency between " << debugName << " of " + << point << "\nand analysis " << analysis << " on " + << dependent << "\n"; + } + }); +} + //===----------------------------------------------------------------------===// // ProgramPoint //===----------------------------------------------------------------------===// @@ -97,26 +110,10 @@ DATAFLOW_DEBUG(llvm::dbgs() << "Propagating update to " << state->debugName << " of " << state->point << "\n" << "Value: " << *state << "\n"); - for (const WorkItem &item : state->dependents) - enqueue(item); state->onUpdate(this); } } -void DataFlowSolver::addDependency(AnalysisState *state, - DataFlowAnalysis *analysis, - ProgramPoint point) {
- auto inserted = state->dependents.insert({point, analysis}); - (void)inserted; - DATAFLOW_DEBUG({ - if (inserted) { - llvm::dbgs() << "Creating dependency between " << state->debugName - << " of " << state->point << "\nand " << analysis->debugName - << " on " << point << "\n"; - } - }); -} - //===----------------------------------------------------------------------===// // DataFlowAnalysis //===----------------------------------------------------------------------===// @@ -126,7 +123,7 @@ DataFlowAnalysis::DataFlowAnalysis(DataFlowSolver &solver) : solver(solver) {} void DataFlowAnalysis::addDependency(AnalysisState *state, ProgramPoint point) { - solver.addDependency(state, this, point); + state->addDependency(point, this); } void DataFlowAnalysis::propagateIfChanged(AnalysisState *state,