diff --git a/mlir/include/mlir/Pass/AnalysisUtilities.h b/mlir/include/mlir/Pass/AnalysisUtilities.h
new file mode 100644
--- /dev/null
+++ b/mlir/include/mlir/Pass/AnalysisUtilities.h
@@ -0,0 +1,52 @@
+//===- AnalysisUtilities.h - Analysis Management Utilities ------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_PASS_ANALYSISUTILITIES_H
+#define MLIR_PASS_ANALYSISUTILITIES_H
+
+#include "mlir/Pass/AnalysisManager.h"
+
+namespace mlir {
+
+/// Helper class for analyses with dependencies. Auto-generates a convenient
+/// constructor (Op, Dependencies &...) and an isInvalidated check based on the
+/// template dependencies list.
+///
+/// `Analysis` must be constructible from the operation followed by one
+/// reference per dependency, in the order the dependencies are listed, e.g.
+///   struct Impl { Impl(Operation *op, DepA &a, DepB &b); };
+///   using Wrapped = AnalysisWrapper<Impl, DepA, DepB>;
+template <typename Analysis, typename... Dependencies>
+class AnalysisWrapper : public Analysis {
+public:
+  /// Obtains every dependency from `am` via getAnalysis and forwards them,
+  /// together with `op`, to the constructor of the wrapped analysis.
+  template <typename Op>
+  AnalysisWrapper(Op op, AnalysisManager &am)
+      : Analysis(op, am.getAnalysis<Dependencies>()...) {}
+
+  /// Returns true if this analysis must be invalidated, i.e. if either the
+  /// wrapper itself or any of its dependencies is not preserved in `pa`.
+  /// Invalidating along with the dependencies matters because the wrapped
+  /// analysis may hold references to them.
+  /// NOTE(review): a dependency that is marked preserved but still
+  /// invalidates itself through its own isInvalidated hook is not detected
+  /// here; confirm this cannot happen for the intended dependencies.
+  static bool isInvalidated(const AnalysisManager::PreservedAnalyses &pa) {
+    using SelfType = AnalysisWrapper<Analysis, Dependencies...>;
+    for (bool preserved :
+         {pa.isPreserved<SelfType>(), pa.isPreserved<Dependencies>()...})
+      if (!preserved)
+        return true;
+
+    return false;
+  }
+};
+} // end namespace mlir
+
+#endif // MLIR_PASS_ANALYSISUTILITIES_H
diff --git a/mlir/unittests/Pass/AnalysisManagerTest.cpp b/mlir/unittests/Pass/AnalysisManagerTest.cpp
--- a/mlir/unittests/Pass/AnalysisManagerTest.cpp
+++ b/mlir/unittests/Pass/AnalysisManagerTest.cpp
@@ -9,6 +9,7 @@
 #include "mlir/Pass/AnalysisManager.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/Pass/AnalysisUtilities.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "gtest/gtest.h"
@@ -246,4 +247,58 @@
   EXPECT_TRUE(an.ctor2called);
 }
 
+struct Analysis0Deps {
+  Analysis0Deps(Operation *) {}
+};
+
+using WrappedAnalysis0Deps = AnalysisWrapper<Analysis0Deps>;
+
+TEST(AnalysisManagerTest, AnalysisWrapper0Deps) {
+  MLIRContext context;
+
+  // Create a module.
+  OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
+  ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
+  AnalysisManager am = mam;
+
+  am.getAnalysis<WrappedAnalysis0Deps>();
+  EXPECT_TRUE(am.getCachedAnalysis<WrappedAnalysis0Deps>().hasValue());
+
+  // Nothing is preserved, so the wrapper must be invalidated.
+  detail::PreservedAnalyses pa;
+  am.invalidate(pa);
+
+  EXPECT_FALSE(am.getCachedAnalysis<WrappedAnalysis0Deps>().hasValue());
+}
+
+struct Analysis2Deps {
+  Analysis2Deps(Operation *, MyAnalysis &, OtherAnalysis &) {}
+};
+
+using WrappedAnalysis2Deps =
+    AnalysisWrapper<Analysis2Deps, MyAnalysis, OtherAnalysis>;
+
+TEST(AnalysisManagerTest, AnalysisWrapper2Deps) {
+  MLIRContext context;
+
+  // Create a module.
+  OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context)));
+  ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr);
+  AnalysisManager am = mam;
+
+  // Computing the wrapper also computes and caches its dependencies.
+  am.getAnalysis<WrappedAnalysis2Deps>();
+  EXPECT_TRUE(am.getCachedAnalysis<MyAnalysis>().hasValue());
+  EXPECT_TRUE(am.getCachedAnalysis<OtherAnalysis>().hasValue());
+  EXPECT_TRUE(am.getCachedAnalysis<WrappedAnalysis2Deps>().hasValue());
+
+  // Only the wrapper is preserved; its dependencies are not, so it must
+  // still be invalidated.
+  detail::PreservedAnalyses pa;
+  pa.preserve<WrappedAnalysis2Deps>();
+  am.invalidate(pa);
+
+  EXPECT_FALSE(am.getCachedAnalysis<WrappedAnalysis2Deps>().hasValue());
+}
+
 } // end namespace