diff --git a/mlir/include/mlir/Pass/AnalysisManager.h b/mlir/include/mlir/Pass/AnalysisManager.h --- a/mlir/include/mlir/Pass/AnalysisManager.h +++ b/mlir/include/mlir/Pass/AnalysisManager.h @@ -13,10 +13,13 @@ #include "mlir/Pass/PassInstrumentation.h" #include "mlir/Support/LLVM.h" #include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/Support/TypeName.h" namespace mlir { +class AnalysisManager; + //===----------------------------------------------------------------------===// // Analysis Preservation and Concept Modeling //===----------------------------------------------------------------------===// @@ -59,6 +62,15 @@ bool isPreserved(TypeID id) const { return preservedIDs.count(id); } private: + /// Remove the analysis from preserved set + template + void unpreserve() { + preservedIDs.erase(TypeID::get()); + } + + template + friend struct AnalysisModel; + /// The set of analyses that are known to be preserved. SmallPtrSet preservedIDs; }; @@ -92,7 +104,7 @@ /// fine-tuned invalidation in cases where an analysis wasn't explicitly /// marked preserved, but may be preserved(or invalidated) based upon other /// properties such as analyses sets. - virtual bool isInvalidated(const PreservedAnalyses &pa) = 0; + virtual bool isInvalidated(PreservedAnalyses &pa) = 0; }; /// A derived analysis model used to hold a specific analysis object. @@ -102,8 +114,11 @@ : analysis(std::forward(args)...) {} /// A hook used to query analyses for invalidation. - bool isInvalidated(const PreservedAnalyses &pa) final { - return analysis_impl::isInvalidated(analysis, pa); + bool isInvalidated(PreservedAnalyses &pa) final { + bool result = analysis_impl::isInvalidated(analysis, pa); + if (result) + pa.unpreserve(); + return result; } /// The actual analysis object. @@ -114,7 +129,7 @@ /// computation, caching, and invalidation of analyses takes place here. 
class AnalysisMap { /// A mapping between an analysis id and an existing analysis instance. - using ConceptMap = DenseMap>; + using ConceptMap = llvm::MapVector>; /// Utility to return the name of the given analysis class. template static StringRef getAnalysisName() { @@ -129,17 +144,19 @@ /// Get an analysis for the current IR unit, computing it if necessary. template - AnalysisT &getAnalysis(PassInstrumentor *pi) { - return getAnalysisImpl(pi, ir); + AnalysisT &getAnalysis(PassInstrumentor *pi, AnalysisManager &am) { + return getAnalysisImpl(pi, ir, am); } /// Get an analysis for the current IR unit assuming it's of specific derived /// operation type. template - typename std::enable_if::value, - AnalysisT &>::type - getAnalysis(PassInstrumentor *pi) { - return getAnalysisImpl(pi, cast(ir)); + std::enable_if_t< + std::is_constructible::value || + std::is_constructible::value, + AnalysisT &> + getAnalysis(PassInstrumentor *pi, AnalysisManager &am) { + return getAnalysisImpl(pi, cast(ir), am); } /// Get a cached analysis instance if one exists, otherwise return null. @@ -160,30 +177,31 @@ /// Invalidate any cached analyses based upon the given set of preserved /// analyses. void invalidate(const PreservedAnalyses &pa) { + PreservedAnalyses paCopy(pa); // Remove any analyses that were invalidated. 
- for (auto it = analyses.begin(), e = analyses.end(); it != e;) { - auto curIt = it++; - if (curIt->second->isInvalidated(pa)) - analyses.erase(curIt); - } + // As we are using MapVector, insertion order is preserved and + // dependencies always precede their users, so a single pass suffices. + analyses.remove_if( + [&](auto &val) { return val.second->isInvalidated(paCopy); }); } private: template - AnalysisT &getAnalysisImpl(PassInstrumentor *pi, OpT op) { + AnalysisT &getAnalysisImpl(PassInstrumentor *pi, OpT op, + AnalysisManager &am) { TypeID id = TypeID::get(); - typename ConceptMap::iterator it; - bool wasInserted; - std::tie(it, wasInserted) = analyses.try_emplace(id); - + auto it = analyses.find(id); // If we don't have a cached analysis for this operation, compute it // directly and add it to the cache. - if (wasInserted) { + if (analyses.end() == it) { if (pi) pi->runBeforeAnalysis(getAnalysisName(), id, ir); - it->second = std::make_unique>(op); + bool wasInserted; + std::tie(it, wasInserted) = + analyses.insert({id, constructAnalysis(am, op)}); + assert(wasInserted); if (pi) pi->runAfterAnalysis(getAnalysisName(), id, ir); @@ -191,6 +209,22 @@ return static_cast &>(*it->second).analysis; } + /// Construct the analysis via its (OpT, AnalysisManager &) constructor. + template ::value> * = nullptr> + static auto constructAnalysis(AnalysisManager &am, OpT op) { + return std::make_unique>(op, am); + } + + /// Construct the analysis via its single-argument (OpT) constructor. + template ::value> * = nullptr> + static auto constructAnalysis(AnalysisManager &, OpT op) { + return std::make_unique>(op); + } + Operation *ir; ConceptMap analyses; }; @@ -273,14 +307,15 @@ /// Query for the given analysis for the current operation.
template AnalysisT &getAnalysis() { - return impl->analyses.getAnalysis(getPassInstrumentor()); + return impl->analyses.getAnalysis(getPassInstrumentor(), *this); } /// Query for the given analysis for the current operation of a specific /// derived operation type. template AnalysisT &getAnalysis() { - return impl->analyses.getAnalysis(getPassInstrumentor()); + return impl->analyses.getAnalysis(getPassInstrumentor(), + *this); } /// Query for a cached entry of the given analysis on the current operation. diff --git a/mlir/unittests/Pass/AnalysisManagerTest.cpp b/mlir/unittests/Pass/AnalysisManagerTest.cpp --- a/mlir/unittests/Pass/AnalysisManagerTest.cpp +++ b/mlir/unittests/Pass/AnalysisManagerTest.cpp @@ -159,4 +159,91 @@ EXPECT_TRUE(am.getCachedAnalysis().hasValue()); } +struct AnalysisWithDependency { + AnalysisWithDependency(Operation *, AnalysisManager &am) { + am.getAnalysis(); + } + + bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) { + return !pa.isPreserved() || + !pa.isPreserved(); + } +}; + +TEST(AnalysisManagerTest, DependentAnalysis) { + MLIRContext context; + + // Create a module. + OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context))); + ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr); + AnalysisManager am = mam; + + am.getAnalysis(); + EXPECT_TRUE(am.getCachedAnalysis().hasValue()); + EXPECT_TRUE(am.getCachedAnalysis().hasValue()); + + detail::PreservedAnalyses pa; + pa.preserve(); + am.invalidate(pa); + + EXPECT_FALSE(am.getCachedAnalysis().hasValue()); + EXPECT_FALSE(am.getCachedAnalysis().hasValue()); +} + +struct AnalysisWithNestedDependency { + AnalysisWithNestedDependency(Operation *, AnalysisManager &am) { + am.getAnalysis(); + } + + bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) { + return !pa.isPreserved() || + !pa.isPreserved(); + } +}; + +TEST(AnalysisManagerTest, NestedDependentAnalysis) { + MLIRContext context; + + // Create a module. 
+ OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context))); + ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr); + AnalysisManager am = mam; + + am.getAnalysis(); + EXPECT_TRUE(am.getCachedAnalysis().hasValue()); + EXPECT_TRUE(am.getCachedAnalysis().hasValue()); + EXPECT_TRUE(am.getCachedAnalysis().hasValue()); + + detail::PreservedAnalyses pa; + pa.preserve(); + pa.preserve(); + am.invalidate(pa); + + EXPECT_FALSE(am.getCachedAnalysis().hasValue()); + EXPECT_FALSE(am.getCachedAnalysis().hasValue()); + EXPECT_FALSE(am.getCachedAnalysis().hasValue()); +} + +struct AnalysisWith2Ctors { + AnalysisWith2Ctors(Operation *) { ctor1called = true; } + + AnalysisWith2Ctors(Operation *, AnalysisManager &) { ctor2called = true; } + + bool ctor1called = false; + bool ctor2called = false; +}; + +TEST(AnalysisManagerTest, DependentAnalysis2Ctors) { + MLIRContext context; + + // Create a module. + OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context))); + ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr); + AnalysisManager am = mam; + + auto &an = am.getAnalysis(); + EXPECT_FALSE(an.ctor1called); + EXPECT_TRUE(an.ctor2called); +} + } // end namespace