diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -56,13 +56,12 @@
   TypeID getTypeID() const { return passID; }
 
-  /// Returns the pass info for the specified pass class or null if unknown.
-  static const PassInfo *lookupPassInfo(TypeID passID);
-  template <typename PassT> static const PassInfo *lookupPassInfo() {
-    return lookupPassInfo(TypeID::get<PassT>());
-  }
+  /// Returns the pass info for the specified pass argument or null if unknown.
+  static const PassInfo *lookupPassInfo(StringRef passArg);
 
-  /// Returns the pass info for this pass.
-  const PassInfo *lookupPassInfo() const { return lookupPassInfo(getTypeID()); }
+  /// Returns the pass info for this pass, or null if unknown.
+  const PassInfo *lookupPassInfo() const {
+    return lookupPassInfo(getArgument());
+  }
 
   /// Returns the derived pass name.
   virtual StringRef getName() const = 0;
@@ -76,11 +75,7 @@
 
   /// Returns the command line argument used when registering this pass. Return
   /// an empty string if one does not exist.
-  virtual StringRef getArgument() const {
-    if (const PassInfo *passInfo = lookupPassInfo())
-      return passInfo->getPassArgument();
-    return "";
-  }
+  virtual StringRef getArgument() const { return ""; }
 
   /// Returns the name of the operation that this pass operates on, or None if
   /// this is a generic OperationPass.
diff --git a/mlir/include/mlir/Pass/PassRegistry.h b/mlir/include/mlir/Pass/PassRegistry.h
--- a/mlir/include/mlir/Pass/PassRegistry.h
+++ b/mlir/include/mlir/Pass/PassRegistry.h
@@ -108,7 +108,7 @@
 public:
   /// PassInfo constructor should not be invoked directly, instead use
   /// PassRegistration or registerPass.
-  PassInfo(StringRef arg, StringRef description, TypeID passID,
+  PassInfo(StringRef arg, StringRef description,
            const PassAllocatorFunction &allocator);
 };
 
diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -19,7 +19,11 @@
 using namespace detail;
 
 /// Static mapping of all of the registered passes.
-static llvm::ManagedStatic<DenseMap<TypeID, PassInfo>> passRegistry;
+static llvm::ManagedStatic<llvm::StringMap<PassInfo>> passRegistry;
+
+/// A mapping of the above pass registry entries to the corresponding TypeID
+/// of the pass that they generate.
+static llvm::ManagedStatic<llvm::StringMap<TypeID>> passRegistryTypeIDs;
 
 /// Static mapping of all of the registered pass pipelines.
 static llvm::ManagedStatic<llvm::StringMap<PassPipelineInfo>>
@@ -94,7 +98,7 @@
 // PassInfo
 //===----------------------------------------------------------------------===//
 
-PassInfo::PassInfo(StringRef arg, StringRef description, TypeID passID,
+PassInfo::PassInfo(StringRef arg, StringRef description,
                    const PassAllocatorFunction &allocator)
     : PassRegistryEntry(
           arg, description, buildDefaultRegistryFn(allocator),
@@ -105,18 +109,25 @@
 
 void mlir::registerPass(StringRef arg, StringRef description,
                         const PassAllocatorFunction &function) {
-  // TODO: We should use the 'arg' as the lookup key instead of the pass id.
-  TypeID passID = function()->getTypeID();
-  PassInfo passInfo(arg, description, passID, function);
-  passRegistry->try_emplace(passID, passInfo);
+  PassInfo passInfo(arg, description, function);
+  passRegistry->try_emplace(arg, passInfo);
+
+  // Verify that the registered pass has the same ID as any registered to this
+  // arg before it. Use report_fatal_error, not llvm_unreachable: a conflicting
+  // registration is reachable at runtime and must fail in release builds too.
+  TypeID entryTypeID = function()->getTypeID();
+  auto it = passRegistryTypeIDs->try_emplace(arg, entryTypeID).first;
+  if (it->second != entryTypeID)
+    llvm::report_fatal_error(
+        "pass allocator creates a different pass than previously "
+        "registered for pass " +
+        arg);
 }
 
-/// Returns the pass info for the specified pass class or null if unknown.
-const PassInfo *mlir::Pass::lookupPassInfo(TypeID passID) {
-  auto it = passRegistry->find(passID);
-  if (it == passRegistry->end())
-    return nullptr;
-  return &it->getSecond();
+/// Returns the pass info for the specified pass argument or null if unknown.
+const PassInfo *mlir::Pass::lookupPassInfo(StringRef passArg) {
+  auto it = passRegistry->find(passArg);
+  return it == passRegistry->end() ? nullptr : &it->second;
 }
 
 //===----------------------------------------------------------------------===//
@@ -433,12 +444,8 @@
   }
 
   // If not, then this must be a specific pass name.
-  for (auto &passIt : *passRegistry) {
-    if (passIt.second.getPassArgument() == element.name) {
-      element.registryEntry = &passIt.second;
-      return success();
-    }
-  }
+  if ((element.registryEntry = Pass::lookupPassInfo(element.name)))
+    return success();
 
   // Emit an error for the unknown pass.
   auto *rawLoc = element.name.data();
diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp
--- a/mlir/test/lib/Pass/TestPassManager.cpp
+++ b/mlir/test/lib/Pass/TestPassManager.cpp
@@ -16,9 +16,11 @@
 struct TestModulePass
     : public PassWrapper<TestModulePass, OperationPass<ModuleOp>> {
   void runOnOperation() final {}
+  StringRef getArgument() const final { return "test-module-pass"; }
 };
 struct TestFunctionPass : public PassWrapper<TestFunctionPass, FunctionPass> {
   void runOnFunction() final {}
+  StringRef getArgument() const final { return "test-function-pass"; }
 };
 class TestOptionsPass : public PassWrapper<TestOptionsPass, FunctionPass> {
 public:
@@ -41,6 +43,7 @@
   }
 
   void runOnFunction() final {}
+  StringRef getArgument() const final { return "test-options-pass"; }
 
   ListOption<int> listOption{*this, "list", llvm::cl::MiscFlags::CommaSeparated,
                              llvm::cl::desc("Example list option")};
@@ -56,6 +59,7 @@
 class TestCrashRecoveryPass
     : public PassWrapper<TestCrashRecoveryPass, OperationPass<>> {
   void runOnOperation() final { abort(); }
+  StringRef getArgument() const final { return "test-pass-crash"; }
 };
 
 /// A test pass that always fails to enable testing the failure recovery