diff --git a/mlir/include/mlir/Pass/AnalysisManager.h b/mlir/include/mlir/Pass/AnalysisManager.h --- a/mlir/include/mlir/Pass/AnalysisManager.h +++ b/mlir/include/mlir/Pass/AnalysisManager.h @@ -128,25 +128,20 @@ explicit AnalysisMap(Operation *ir) : ir(ir) {} /// Get an analysis for the current IR unit, computing it if necessary. - template AnalysisT &getAnalysis(PassInstrumentor *pi) { - TypeID id = TypeID::get(); - - typename ConceptMap::iterator it; - bool wasInserted; - std::tie(it, wasInserted) = analyses.try_emplace(id); - - // If we don't have a cached analysis for this function, compute it directly - // and add it to the cache. - if (wasInserted) { - if (pi) - pi->runBeforeAnalysis(getAnalysisName(), id, ir); - - it->second = std::make_unique>(ir); + template + typename std::enable_if::value, + AnalysisT &>::type + getAnalysis(PassInstrumentor *pi) { + return getAnalysisImpl(pi, ir); + } - if (pi) - pi->runAfterAnalysis(getAnalysisName(), id, ir); - } - return static_cast &>(*it->second).analysis; + /// Get an analysis for the current IR unit assuming it is of a specific + /// derived operation type. + template + typename std::enable_if::value, + AnalysisT &>::type + getAnalysis(PassInstrumentor *pi) { + return getAnalysisImpl(pi, cast(ir)); } /// Get a cached analysis instance if one exists, otherwise return null. @@ -175,6 +170,30 @@ } } +private: + template + AnalysisT &getAnalysisImpl(PassInstrumentor *pi, Args &&...args) { + TypeID id = TypeID::get(); + + typename ConceptMap::iterator it; + bool wasInserted; + std::tie(it, wasInserted) = analyses.try_emplace(id); + + // If we don't have a cached analysis for this function, compute it directly + // and add it to the cache. 
+ if (wasInserted) { + if (pi) + pi->runBeforeAnalysis(getAnalysisName(), id, ir); + + it->second = std::make_unique>( + std::forward(args)...); + + if (pi) + pi->runAfterAnalysis(getAnalysisName(), id, ir); + } + return static_cast &>(*it->second).analysis; + } + private: Operation *ir; ConceptMap analyses; @@ -230,11 +249,26 @@ return None; } + // Query for a cached analysis on the given parent operation of a specific + // derived operation type. + template + Optional> + getCachedParentAnalysis(OpT parentOp) const { + return getCachedParentAnalysis(parentOp.getOperation()); + } + // Query for the given analysis for the current operation. template AnalysisT &getAnalysis() { return impl->analyses.getAnalysis(getPassInstrumentor()); } + // Query for the given analysis for the current operation of a specific + // derived operation type. + template + AnalysisT &getAnalysis() { + return impl->analyses.getAnalysis(getPassInstrumentor()); + } + // Query for a cached entry of the given analysis on the current operation. template Optional> getCachedAnalysis() const { @@ -246,6 +280,13 @@ return slice(op).template getAnalysis(); } + /// Query for an analysis of a child operation of a specific derived operation + /// type, constructing it if necessary. + template + AnalysisT &getChildAnalysis(OpT child) { + return slice(child.getOperation()).template getAnalysis(); + } + /// Query for a cached analysis of a child operation, or return null. template Optional> @@ -257,6 +298,14 @@ return it->second->analyses.getCachedAnalysis(); } + /// Query for a cached analysis of a child operation of specific derived + /// operation type, or return null. + template + Optional> + getCachedChildAnalysis(OpT child) const { + return getCachedChildAnalysis(child.getOperation()); + } + /// Get an analysis manager for the given child operation. 
AnalysisManager slice(Operation *op); diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -167,6 +167,13 @@ return getAnalysisManager().getAnalysis(); } + /// Query an analysis for the current ir unit of a specific derived operation + /// type. + template + AnalysisT &getAnalysis() { + return getAnalysisManager().getAnalysis(); + } + /// Query a cached instance of an analysis for the current ir unit if one /// exists. template @@ -187,12 +194,22 @@ getPassState().preservedAnalyses.preserve(id); } - /// Returns the analysis for the parent operation if it exists. + /// Returns the analysis for the given parent operation if it exists. template Optional> getCachedParentAnalysis(Operation *parent) { return getAnalysisManager().getCachedParentAnalysis(parent); } + + /// Returns the analysis for the given parent operation of specific derived + /// operation type if it exists. + template + Optional> + getCachedParentAnalysis(OpT parent) { + return getAnalysisManager().getCachedParentAnalysis(parent); + } + + /// Returns the analysis for the parent operation if it exists. template Optional> getCachedParentAnalysis() { return getAnalysisManager().getCachedParentAnalysis( @@ -206,12 +223,27 @@ return getAnalysisManager().getCachedChildAnalysis(child); } + /// Returns the analysis for the given child operation of specific derived + /// operation type if it exists. + template + Optional> + getCachedChildAnalysis(OpT child) { + return getAnalysisManager().getCachedChildAnalysis(child); + } + /// Returns the analysis for the given child operation, or creates it if it /// doesn't exist. template AnalysisT &getChildAnalysis(Operation *child) { return getAnalysisManager().getChildAnalysis(child); } + /// Returns the analysis for the given child operation of specific derived + /// operation type, or creates it if it doesn't exist. 
+ template + AnalysisT &getChildAnalysis(OpTy child) { + return getAnalysisManager().getChildAnalysis(child); + } + /// Returns the current analysis manager. AnalysisManager getAnalysisManager() { return getPassState().analysisManager; @@ -286,6 +318,13 @@ /// Return the current operation being transformed. OpT getOperation() { return cast(Pass::getOperation()); } + + /// Query an analysis for the current operation of the specific derived + /// operation type. + template + AnalysisT &getAnalysis() { + return Pass::getAnalysis(); + } }; /// Pass to transform an operation. diff --git a/mlir/unittests/Pass/AnalysisManagerTest.cpp b/mlir/unittests/Pass/AnalysisManagerTest.cpp --- a/mlir/unittests/Pass/AnalysisManagerTest.cpp +++ b/mlir/unittests/Pass/AnalysisManagerTest.cpp @@ -9,6 +9,8 @@ #include "mlir/Pass/AnalysisManager.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" #include "gtest/gtest.h" using namespace mlir; @@ -22,6 +24,9 @@ struct OtherAnalysis { OtherAnalysis(Operation *) {} }; +struct OpSpecificAnalysis { + OpSpecificAnalysis(ModuleOp) {} +}; TEST(AnalysisManagerTest, FineGrainModuleAnalysisPreservation) { MLIRContext context; @@ -138,4 +143,18 @@ am.invalidate(pa); EXPECT_TRUE(am.getCachedAnalysis().hasValue()); } + +TEST(AnalysisManagerTest, OpSpecificAnalysis) { + MLIRContext context; + + // Create a module. + OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context))); + ModuleAnalysisManager mam(*module, /*passInstrumentor=*/nullptr); + AnalysisManager am = mam; + + // Query the op-specific analysis for the module and verify that it is cached. 
+ am.getAnalysis(); + EXPECT_TRUE(am.getCachedAnalysis().hasValue()); +} + } // end namespace diff --git a/mlir/unittests/Pass/CMakeLists.txt b/mlir/unittests/Pass/CMakeLists.txt --- a/mlir/unittests/Pass/CMakeLists.txt +++ b/mlir/unittests/Pass/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_unittest(MLIRPassTests AnalysisManagerTest.cpp + PassManagerTest.cpp ) target_link_libraries(MLIRPassTests PRIVATE diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp new file mode 100644 --- /dev/null +++ b/mlir/unittests/Pass/PassManagerTest.cpp @@ -0,0 +1,77 @@ +//===- PassManagerTest.cpp - PassManager unit tests -----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Pass/PassManager.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Function.h" +#include "mlir/Pass/Pass.h" +#include "gtest/gtest.h" + +using namespace mlir; +using namespace mlir::detail; + +namespace { +/// Analysis that operates on any operation. +struct GenericAnalysis { + GenericAnalysis(Operation *op) : is_func(isa(op)) {} + const bool is_func; +}; + +/// Analysis that operates on a specific operation. +struct OpSpecificAnalysis { + OpSpecificAnalysis(FuncOp op) : is_secret(op.getName() == "secret") {} + const bool is_secret; +}; + +/// Simple pass to annotate a FuncOp with the results of analysis. +/// Note: not using FunctionPass as it skips external functions. 
+struct AnnotateFunctionPass + : public PassWrapper> { + void runOnOperation() override { + FuncOp op = getOperation(); + Builder builder(op.getParentOfType()); + + auto &ga = getAnalysis(); + auto &sa = getAnalysis(); + + op.setAttr("is_func", builder.getBoolAttr(ga.is_func)); + op.setAttr("is_secret", builder.getBoolAttr(sa.is_secret)); + } +}; + +TEST(PassManagerTest, OpSpecificAnalysis) { + MLIRContext context; + Builder builder(&context); + + // Create a module with 2 functions. + OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context))); + for (StringRef name : {"secret", "not_secret"}) { + FuncOp func = + FuncOp::create(builder.getUnknownLoc(), name, + builder.getFunctionType(llvm::None, llvm::None)); + module->push_back(func); + } + + // Instantiate and run our pass. + PassManager pm(&context); + pm.addNestedPass(std::make_unique()); + LogicalResult result = pm.run(module.get()); + EXPECT_TRUE(succeeded(result)); + + // Verify that each function got annotated with expected attributes. + for (FuncOp func : module->getOps()) { + ASSERT_TRUE(func.getAttr("is_func").isa()); + EXPECT_TRUE(func.getAttr("is_func").cast().getValue()); + + bool is_secret = func.getName() == "secret"; + ASSERT_TRUE(func.getAttr("is_secret").isa()); + EXPECT_EQ(func.getAttr("is_secret").cast().getValue(), is_secret); + } +} + +} // end namespace