diff --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md --- a/mlir/docs/PassManagement.md +++ b/mlir/docs/PassManagement.md @@ -159,7 +159,10 @@ cached to avoid unnecessary recomputation. An analysis in MLIR must adhere to the following: -* Provide a valid constructor taking an `Operation*`. +* Provide a valid constructor taking either an `Operation*` or `Operation*` + and `AnalysisManager &`. + * The provided `AnalysisManager &` should be used to query any necessary + analysis dependencies. * Must not modify the given operation. An analysis may provide additional hooks to control various behavior: @@ -169,7 +172,9 @@ Given a preserved analysis set, the analysis returns true if it should truly be invalidated. This allows for more fine-tuned invalidation in cases where an analysis wasn't explicitly marked preserved, but may be preserved (or -invalidated) based upon other properties such as analyses sets. +invalidated) based upon other properties such as analyses sets. If the analysis +uses any other analysis as a dependency, it must also check if the dependency +was invalidated. ### Querying Analyses @@ -200,6 +205,20 @@ MyOperationAnalysis(Operation *op); }; +struct MyOperationAnalysisWithDependency { + MyOperationAnalysisWithDependency(Operation *op, AnalysisManager &am) { + // Request other analysis as dependency + MyOperationAnalysis &otherAnalysis = am.getAnalysis(); + ... + } + + bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) { + // Check if analysis or its dependency were invalidated + return !pa.isPreserved() || + !pa.isPreserved(); + } +}; + void MyOperationPass::runOnOperation() { // Query MyOperationAnalysis for the current operation. MyOperationAnalysis &myAnalysis = getAnalysis(); @@ -899,6 +918,10 @@ executed, `runAfterPass` will *not* be. * `runBeforeAnalysis` * This callback is run just before an analysis is computed. 
+ * If the analysis requested another analysis as a dependency, the + `runBeforeAnalysis`/`runAfterAnalysis` pair for the dependency can be + called from inside the current `runBeforeAnalysis`/`runAfterAnalysis` + pair. * `runAfterAnalysis` * This callback is run right after an analysis is computed. diff --git a/mlir/include/mlir/Pass/AnalysisManager.h b/mlir/include/mlir/Pass/AnalysisManager.h --- a/mlir/include/mlir/Pass/AnalysisManager.h +++ b/mlir/include/mlir/Pass/AnalysisManager.h @@ -13,10 +13,13 @@ #include "mlir/Pass/PassInstrumentation.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/TypeName.h" namespace mlir { +class AnalysisManager; + //===----------------------------------------------------------------------===// // Analysis Preservation and Concept Modeling //===----------------------------------------------------------------------===// @@ -59,6 +62,16 @@ bool isPreserved(TypeID id) const { return preservedIDs.count(id); } private: + /// Remove the analysis from the preserved set. + template + void unpreserve() { + preservedIDs.erase(TypeID::get()); + } + + /// AnalysisModel needs access to unpreserve(). + template + friend struct AnalysisModel; + /// The set of analyses that are known to be preserved. SmallPtrSet preservedIDs; }; @@ -91,8 +104,9 @@ /// set, returns true if it should truly be invalidated. This allows for more /// fine-tuned invalidation in cases where an analysis wasn't explicitly /// marked preserved, but may be preserved(or invalidated) based upon other - /// properties such as analyses sets. - virtual bool isInvalidated(const PreservedAnalyses &pa) = 0; + /// properties such as analysis sets. Invalidated analyses must also be + /// removed from pa. + virtual bool invalidate(PreservedAnalyses &pa) = 0; }; /// A derived analysis model used to hold a specific analysis object.
@@ -101,9 +115,13 @@ explicit AnalysisModel(Args &&...args) : analysis(std::forward(args)...) {} - /// A hook used to query analyses for invalidation. - bool isInvalidated(const PreservedAnalyses &pa) final { - return analysis_impl::isInvalidated(analysis, pa); + /// A hook used to query analyses for invalidation. Removes invalidated + /// analyses from pa. + bool invalidate(PreservedAnalyses &pa) final { + bool result = analysis_impl::isInvalidated(analysis, pa); + if (result) + pa.unpreserve(); + return result; } /// The actual analysis object. @@ -114,7 +132,7 @@ /// computation, caching, and invalidation of analyses takes place here. class AnalysisMap { /// A mapping between an analysis id and an existing analysis instance. - using ConceptMap = DenseMap>; + using ConceptMap = llvm::MapVector>; /// Utility to return the name of the given analysis class. template static StringRef getAnalysisName() { @@ -129,17 +147,19 @@ /// Get an analysis for the current IR unit, computing it if necessary. template - AnalysisT &getAnalysis(PassInstrumentor *pi) { - return getAnalysisImpl(pi, ir); + AnalysisT &getAnalysis(PassInstrumentor *pi, AnalysisManager &am) { + return getAnalysisImpl(pi, ir, am); } /// Get an analysis for the current IR unit assuming it's of specific derived /// operation type. template - typename std::enable_if::value, - AnalysisT &>::type - getAnalysis(PassInstrumentor *pi) { - return getAnalysisImpl(pi, cast(ir)); + std::enable_if_t< + std::is_constructible::value || + std::is_constructible::value, + AnalysisT &> + getAnalysis(PassInstrumentor *pi, AnalysisManager &am) { + return getAnalysisImpl(pi, cast(ir), am); } /// Get a cached analysis instance if one exists, otherwise return null. @@ -160,30 +180,31 @@ /// Invalidate any cached analyses based upon the given set of preserved /// analyses. void invalidate(const PreservedAnalyses &pa) { + PreservedAnalyses paCopy(pa); // Remove any analyses that were invalidated. 
- for (auto it = analyses.begin(), e = analyses.end(); it != e;) { - auto curIt = it++; - if (curIt->second->isInvalidated(pa)) - analyses.erase(curIt); - } + // As we are using MapVector, order of insertion is preserved and + // dependencies always go before users, so we need only one iteration. + analyses.remove_if( + [&](auto &val) { return val.second->invalidate(paCopy); }); } private: template - AnalysisT &getAnalysisImpl(PassInstrumentor *pi, OpT op) { + AnalysisT &getAnalysisImpl(PassInstrumentor *pi, OpT op, + AnalysisManager &am) { TypeID id = TypeID::get(); - typename ConceptMap::iterator it; - bool wasInserted; - std::tie(it, wasInserted) = analyses.try_emplace(id); - + auto it = analyses.find(id); // If we don't have a cached analysis for this operation, compute it // directly and add it to the cache. - if (wasInserted) { + if (analyses.end() == it) { if (pi) pi->runBeforeAnalysis(getAnalysisName(), id, ir); - it->second = std::make_unique>(op); + bool wasInserted; + std::tie(it, wasInserted) = + analyses.insert({id, constructAnalysis(am, op)}); + assert(wasInserted); if (pi) pi->runAfterAnalysis(getAnalysisName(), id, ir); @@ -191,6 +212,22 @@ return static_cast &>(*it->second).analysis; } + /// Construct analysis using two-argument constructor (OpT, AnalysisManager) + template ::value> * = nullptr> + static auto constructAnalysis(AnalysisManager &am, OpT op) { + return std::make_unique>(op, am); + } + + /// Construct analysis using single-argument constructor (OpT) + template ::value> * = nullptr> + static auto constructAnalysis(AnalysisManager &, OpT op) { + return std::make_unique>(op); + } + Operation *ir; ConceptMap analyses; }; @@ -273,14 +310,15 @@ /// Query for the given analysis for the current operation.
template AnalysisT &getAnalysis() { - return impl->analyses.getAnalysis(getPassInstrumentor()); + return impl->analyses.getAnalysis(getPassInstrumentor(), *this); } /// Query for the given analysis for the current operation of a specific /// derived operation type. template AnalysisT &getAnalysis() { - return impl->analyses.getAnalysis(getPassInstrumentor()); + return impl->analyses.getAnalysis(getPassInstrumentor(), + *this); } /// Query for a cached entry of the given analysis on the current operation. diff --git a/mlir/unittests/Pass/AnalysisManagerTest.cpp b/mlir/unittests/Pass/AnalysisManagerTest.cpp --- a/mlir/unittests/Pass/AnalysisManagerTest.cpp +++ b/mlir/unittests/Pass/AnalysisManagerTest.cpp @@ -159,4 +159,91 @@ EXPECT_TRUE(am.getCachedAnalysis().hasValue()); } +struct AnalysisWithDependency { + AnalysisWithDependency(Operation *, AnalysisManager &am) { + am.getAnalysis(); + } + + bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) { + return !pa.isPreserved() || + !pa.isPreserved(); + } +}; + +TEST(AnalysisManagerTest, DependentAnalysis) { + MLIRContext context; + + // Create a module. + OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context))); + ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr); + AnalysisManager am = mam; + + am.getAnalysis(); + EXPECT_TRUE(am.getCachedAnalysis().hasValue()); + EXPECT_TRUE(am.getCachedAnalysis().hasValue()); + + detail::PreservedAnalyses pa; + pa.preserve(); + am.invalidate(pa); + + EXPECT_FALSE(am.getCachedAnalysis().hasValue()); + EXPECT_FALSE(am.getCachedAnalysis().hasValue()); +} + +struct AnalysisWithNestedDependency { + AnalysisWithNestedDependency(Operation *, AnalysisManager &am) { + am.getAnalysis(); + } + + bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) { + return !pa.isPreserved() || + !pa.isPreserved(); + } +}; + +TEST(AnalysisManagerTest, NestedDependentAnalysis) { + MLIRContext context; + + // Create a module. 
+ OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context))); + ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr); + AnalysisManager am = mam; + + am.getAnalysis(); + EXPECT_TRUE(am.getCachedAnalysis().hasValue()); + EXPECT_TRUE(am.getCachedAnalysis().hasValue()); + EXPECT_TRUE(am.getCachedAnalysis().hasValue()); + + detail::PreservedAnalyses pa; + pa.preserve(); + pa.preserve(); + am.invalidate(pa); + + EXPECT_FALSE(am.getCachedAnalysis().hasValue()); + EXPECT_FALSE(am.getCachedAnalysis().hasValue()); + EXPECT_FALSE(am.getCachedAnalysis().hasValue()); +} + +struct AnalysisWith2Ctors { + AnalysisWith2Ctors(Operation *) { ctor1called = true; } + + AnalysisWith2Ctors(Operation *, AnalysisManager &) { ctor2called = true; } + + bool ctor1called = false; + bool ctor2called = false; +}; + +TEST(AnalysisManagerTest, DependentAnalysis2Ctors) { + MLIRContext context; + + // Create a module. + OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context))); + ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr); + AnalysisManager am = mam; + + auto &an = am.getAnalysis(); + EXPECT_FALSE(an.ctor1called); + EXPECT_TRUE(an.ctor2called); +} + } // end namespace