diff --git a/mlir/include/mlir/Pass/AnalysisManager.h b/mlir/include/mlir/Pass/AnalysisManager.h
--- a/mlir/include/mlir/Pass/AnalysisManager.h
+++ b/mlir/include/mlir/Pass/AnalysisManager.h
@@ -17,6 +17,8 @@
 #include "llvm/Support/TypeName.h"
 
 namespace mlir {
+class AnalysisManager;
+
 //===----------------------------------------------------------------------===//
 // Analysis Preservation and Concept Modeling
 //===----------------------------------------------------------------------===//
@@ -129,17 +131,19 @@
 
   /// Get an analysis for the current IR unit, computing it if necessary.
   template <typename AnalysisT>
-  AnalysisT &getAnalysis(PassInstrumentor *pi) {
-    return getAnalysisImpl<AnalysisT, Operation *>(pi, ir);
+  AnalysisT &getAnalysis(PassInstrumentor *pi, AnalysisManager &am) {
+    return getAnalysisImpl<AnalysisT, Operation *>(pi, ir, am);
   }
 
   /// Get an analysis for the current IR unit assuming it's of specific derived
   /// operation type.
   template <typename AnalysisT, typename OpT>
-  typename std::enable_if<std::is_constructible<AnalysisT, OpT>::value,
-                          AnalysisT &>::type
-  getAnalysis(PassInstrumentor *pi) {
-    return getAnalysisImpl<AnalysisT, OpT>(pi, cast<OpT>(ir));
+  typename std::enable_if<
+      std::is_constructible<AnalysisT, OpT>::value ||
+          std::is_constructible<AnalysisT, OpT, AnalysisManager &>::value,
+      AnalysisT &>::type
+  getAnalysis(PassInstrumentor *pi, AnalysisManager &am) {
+    return getAnalysisImpl<AnalysisT, OpT>(pi, cast<OpT>(ir), am);
   }
 
   /// Get a cached analysis instance if one exists, otherwise return null.
@@ -170,20 +174,25 @@
 
 private:
   template <typename AnalysisT, typename OpT>
-  AnalysisT &getAnalysisImpl(PassInstrumentor *pi, OpT op) {
+  AnalysisT &getAnalysisImpl(PassInstrumentor *pi, OpT op,
+                             AnalysisManager &am) {
     TypeID id = TypeID::get<AnalysisT>();
 
-    typename ConceptMap::iterator it;
-    bool wasInserted;
-    std::tie(it, wasInserted) = analyses.try_emplace(id);
-
+    auto it = analyses.find(id);
     // If we don't have a cached analysis for this operation, compute it
     // directly and add it to the cache.
-    if (wasInserted) {
+    if (analyses.end() == it) {
       if (pi)
         pi->runBeforeAnalysis(getAnalysisName<AnalysisT>(), id, ir);
 
-      it->second = std::make_unique<AnalysisModel<AnalysisT>>(op);
+      // Construct the analysis before inserting it: its constructor may query
+      // other analyses through `am`, which inserts into `analyses` and may
+      // invalidate iterators into it. `it` is therefore taken from the result
+      // of the insertion rather than from the lookup above.
+      bool wasInserted;
+      std::tie(it, wasInserted) =
+          analyses.try_emplace(id, constructAnalysis<AnalysisT>(am, op));
+      assert(wasInserted);
 
       if (pi)
         pi->runAfterAnalysis(getAnalysisName<AnalysisT>(), id, ir);
@@ -191,6 +200,25 @@
     return static_cast<AnalysisModel<AnalysisT> &>(*it->second).analysis;
   }
 
+  /// Construct the model of an analysis that has a constructor taking
+  /// `(OpT, AnalysisManager &)`. The analysis manager is forwarded so that the
+  /// analysis can request the analyses it depends on.
+  template <typename AnalysisT, typename OpT,
+            std::enable_if_t<std::is_constructible<
+                AnalysisT, OpT, AnalysisManager &>::value> * = nullptr>
+  static auto constructAnalysis(AnalysisManager &am, OpT op) {
+    return std::make_unique<AnalysisModel<AnalysisT>>(op, am);
+  }
+
+  /// Construct the model of an analysis that has no such constructor; it is
+  /// constructed from `OpT` alone and the analysis manager is not forwarded.
+  template <typename AnalysisT, typename OpT,
+            std::enable_if_t<!std::is_constructible<
+                AnalysisT, OpT, AnalysisManager &>::value> * = nullptr>
+  static auto constructAnalysis(AnalysisManager &, OpT op) {
+    return std::make_unique<AnalysisModel<AnalysisT>>(op);
+  }
+
   Operation *ir;
   ConceptMap analyses;
 };
@@ -273,14 +301,15 @@
 
   /// Query for the given analysis for the current operation.
   template <typename AnalysisT> AnalysisT &getAnalysis() {
-    return impl->analyses.getAnalysis<AnalysisT>(getPassInstrumentor());
+    return impl->analyses.getAnalysis<AnalysisT>(getPassInstrumentor(), *this);
   }
 
   /// Query for the given analysis for the current operation of a specific
   /// derived operation type.
   template <typename AnalysisT, typename OpT>
   AnalysisT &getAnalysis() {
-    return impl->analyses.getAnalysis<AnalysisT, OpT>(getPassInstrumentor());
+    return impl->analyses.getAnalysis<AnalysisT, OpT>(getPassInstrumentor(),
+                                                      *this);
   }
 
   /// Query for a cached entry of the given analysis on the current operation.
diff --git a/mlir/unittests/Pass/AnalysisManagerTest.cpp b/mlir/unittests/Pass/AnalysisManagerTest.cpp
--- a/mlir/unittests/Pass/AnalysisManagerTest.cpp
+++ b/mlir/unittests/Pass/AnalysisManagerTest.cpp
@@ -159,4 +159,60 @@
   EXPECT_TRUE(am.getCachedAnalysis<CustomInvalidatingAnalysis>().hasValue());
 }
 
+/// Analysis that depends on MyAnalysis: it requests it on construction and is
+/// invalidated whenever either analysis is not preserved.
+struct AnalysisWithDependency {
+  AnalysisWithDependency(Operation *, AnalysisManager &am) {
+    am.getAnalysis<MyAnalysis>();
+  }
+
+  bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) {
+    return !pa.isPreserved<AnalysisWithDependency>() ||
+           !pa.isPreserved<MyAnalysis>();
+  }
+};
+
+TEST(AnalysisManagerTest, DependentAnalysis) {
+  MLIRContext context;
+
+  // Create a module.
+  OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
+  ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
+  AnalysisManager am = mam;
+
+  am.getAnalysis<AnalysisWithDependency>();
+  EXPECT_TRUE(am.getCachedAnalysis<AnalysisWithDependency>().hasValue());
+  EXPECT_TRUE(am.getCachedAnalysis<MyAnalysis>().hasValue());
+
+  detail::PreservedAnalyses pa;
+  pa.preserve<AnalysisWithDependency>();
+  am.invalidate(pa);
+
+  EXPECT_FALSE(am.getCachedAnalysis<AnalysisWithDependency>().hasValue());
+  EXPECT_FALSE(am.getCachedAnalysis<MyAnalysis>().hasValue());
+}
+
+/// Analysis with both constructor forms; the AnalysisManager one is preferred.
+struct AnalysisWith2Ctors {
+  AnalysisWith2Ctors(Operation *) { ctor1called = true; }
+
+  AnalysisWith2Ctors(Operation *, AnalysisManager &) { ctor2called = true; }
+
+  bool ctor1called = false;
+  bool ctor2called = false;
+};
+
+TEST(AnalysisManagerTest, DependentAnalysis2Ctors) {
+  MLIRContext context;
+
+  // Create a module.
+  OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
+  ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
+  AnalysisManager am = mam;
+
+  auto &an = am.getAnalysis<AnalysisWith2Ctors>();
+  EXPECT_FALSE(an.ctor1called);
+  EXPECT_TRUE(an.ctor2called);
+}
+
 } // end namespace