diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h
--- a/mlir/include/mlir/Pass/PassManager.h
+++ b/mlir/include/mlir/Pass/PassManager.h
@@ -47,7 +47,8 @@
 /// other OpPassManagers or the top-level PassManager.
 class OpPassManager {
 public:
-  OpPassManager(Identifier name, MLIRContext *context, bool verifyPasses);
+  OpPassManager(Identifier name, bool verifyPasses);
+  OpPassManager(StringRef name, bool verifyPasses);
   OpPassManager(OpPassManager &&rhs);
   OpPassManager(const OpPassManager &rhs);
   ~OpPassManager();
@@ -73,7 +74,7 @@
   OpPassManager &nest(Identifier nestedName);
   OpPassManager &nest(StringRef nestedName);
   template <typename OpT> OpPassManager &nest() {
-    return nest(Identifier::get(OpT::getOperationName(), getContext()));
+    return nest(OpT::getOperationName());
   }
 
   /// Add the given pass to this pass manager. If this pass has a concrete
@@ -89,11 +90,13 @@
   /// Returns the number of passes held by this manager.
   size_t size() const;
 
-  /// Return an instance of the context.
-  MLIRContext *getContext() const;
+  /// Return the operation name that this pass manager operates on, uniqued as
+  /// an Identifier in `context`. The Identifier is created on the first call
+  /// and cached; the cache is bound to the context passed on that call.
+  Identifier getOpName(MLIRContext &context) const;
 
   /// Return the operation name that this pass manager operates on.
-  Identifier getOpName() const;
+  StringRef getOpName() const;
 
   /// Returns the internal implementation instance.
   detail::OpPassManagerImpl &getImpl();
@@ -151,6 +154,9 @@
   LLVM_NODISCARD
   LogicalResult run(ModuleOp module);
 
+  /// Return an instance of the context.
+  MLIRContext *getContext() const { return context; }
+
   /// Enable support for the pass manager to generate a reproducer on the event
   /// of a crash or a pass failure. `outputFile` is a .mlir filename used to
   /// write the generated reproducer. If `genLocalReproducer` is true, the pass
@@ -304,6 +310,9 @@
   runWithCrashRecovery(MutableArrayRef<std::unique_ptr<Pass>> passes,
                        ModuleOp module, AnalysisManager am);
 
+  /// The context this pass manager was constructed with.
+  MLIRContext *context;
+
   /// Flag that specifies if pass statistics should be dumped.
   Optional<PassDisplayMode> passStatisticsMode;
 
diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp
--- a/mlir/lib/Pass/Pass.cpp
+++ b/mlir/lib/Pass/Pass.cpp
@@ -92,8 +92,11 @@
 namespace mlir {
 namespace detail {
 struct OpPassManagerImpl {
-  OpPassManagerImpl(Identifier name, MLIRContext *ctx, bool verifyPasses)
-      : name(name), context(ctx), verifyPasses(verifyPasses) {}
+  OpPassManagerImpl(Identifier identifier, bool verifyPasses)
+      : name(identifier.strref().str()), identifier(identifier),
+        verifyPasses(verifyPasses) {}
+  OpPassManagerImpl(StringRef name, bool verifyPasses)
+      : name(name.str()), verifyPasses(verifyPasses) {}
 
   /// Merge the passes of this pass manager into the one provided.
   void mergeInto(OpPassManagerImpl &rhs);
@@ -101,9 +104,7 @@
   /// Nest a new operation pass manager for the given operation kind under this
   /// pass manager.
   OpPassManager &nest(Identifier nestedName);
-  OpPassManager &nest(StringRef nestedName) {
-    return nest(Identifier::get(nestedName, getContext()));
-  }
+  OpPassManager &nest(StringRef nestedName);
 
   /// Add the given pass to this pass manager. If this pass has a concrete
   /// operation type, it must be the same type as this pass manager.
@@ -117,14 +118,24 @@
   /// pass.
   void splitAdaptorPasses();
 
-  /// Return an instance of the context.
-  MLIRContext *getContext() const { return context; }
+  /// Return the operation name as an Identifier uniqued in `context`. The
+  /// Identifier is created on the first call and cached, so the cache is bound
+  /// to the context passed on that first call.
+  Identifier getOpName(MLIRContext &context) {
+    if (!identifier)
+      identifier = Identifier::get(name, &context);
+    return *identifier;
+  }
 
   /// The name of the operation that passes of this pass manager operate on.
-  Identifier name;
+  /// Stored by value: the StringRef handed to the constructor (or to nest())
+  /// carries no lifetime guarantee, while OpPassManager::getOpName() returns a
+  /// StringRef into this storage for as long as the pass manager lives.
+  std::string name;
 
-  /// The current context for this pass manager
-  MLIRContext *context;
+  /// The cached identifier (internalized in the context) for the name of the
+  /// operation that passes of this pass manager operate on.
+  Optional<Identifier> identifier;
 
   /// Flag that specifies if the IR should be verified after each pass has run.
   bool verifyPasses : 1;
@@ -143,7 +154,14 @@
 }
 
 OpPassManager &OpPassManagerImpl::nest(Identifier nestedName) {
-  OpPassManager nested(nestedName, getContext(), verifyPasses);
+  OpPassManager nested(nestedName, verifyPasses);
+  auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
+  addPass(std::unique_ptr<Pass>(adaptor));
+  return adaptor->getPassManagers().front();
+}
+
+OpPassManager &OpPassManagerImpl::nest(StringRef nestedName) {
+  OpPassManager nested(nestedName, verifyPasses);
   auto *adaptor = new OpToOpPassAdaptor(std::move(nested));
   addPass(std::unique_ptr<Pass>(adaptor));
   return adaptor->getPassManagers().front();
@@ -153,7 +171,7 @@
   // If this pass runs on a different operation than this pass manager, then
   // implicitly nest a pass manager for this operation.
   auto passOpName = pass->getOpName();
-  if (passOpName && passOpName != name.strref())
+  if (passOpName && *passOpName != name)
     return nest(*passOpName).addPass(std::move(pass));
 
   passes.emplace_back(std::move(pass));
@@ -240,14 +258,18 @@
 // OpPassManager
 //===----------------------------------------------------------------------===//
 
-OpPassManager::OpPassManager(Identifier name, MLIRContext *context,
-                             bool verifyPasses)
-    : impl(new OpPassManagerImpl(name, context, verifyPasses)) {}
+OpPassManager::OpPassManager(Identifier name, bool verifyPasses)
+    : impl(new OpPassManagerImpl(name, verifyPasses)) {}
+OpPassManager::OpPassManager(StringRef name, bool verifyPasses)
+    : impl(new OpPassManagerImpl(name, verifyPasses)) {}
 OpPassManager::OpPassManager(OpPassManager &&rhs) : impl(std::move(rhs.impl)) {}
 OpPassManager::OpPassManager(const OpPassManager &rhs) { *this = rhs; }
 OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) {
-  impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->getContext(),
-                                   rhs.impl->verifyPasses));
+  // Self-assignment must be a no-op: resetting `impl` would otherwise destroy
+  // the passes that are about to be cloned from `rhs`.
+  if (this == &rhs)
+    return *this;
+  impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->verifyPasses));
   for (auto &pass : rhs.impl->passes)
     impl->passes.emplace_back(pass->clone());
   return *this;
@@ -290,11 +312,14 @@
 /// Returns the internal implementation instance.
 OpPassManagerImpl &OpPassManager::getImpl() { return *impl; }
 
-/// Return an instance of the context.
-MLIRContext *OpPassManager::getContext() const { return impl->getContext(); }
+/// Return the operation name that this pass manager operates on.
+StringRef OpPassManager::getOpName() const { return impl->name; }
 
 /// Return the operation name that this pass manager operates on.
-Identifier OpPassManager::getOpName() const { return impl->name; }
+/// The Identifier is uniqued in `context` and cached on first use.
+Identifier OpPassManager::getOpName(MLIRContext &context) const {
+  return impl->getOpName(context);
+}
 
 /// Prints out the given passes as the textual representation of a pipeline.
 static void printAsTextualPipeline(ArrayRef<std::unique_ptr<Pass>> passes,
@@ -389,12 +414,23 @@
 /// Find an operation pass manager that can operate on an operation of the given
 /// type, or nullptr if one does not exist.
 static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
-                                         Identifier name) {
+                                         StringRef name) {
   auto it = llvm::find_if(
       mgrs, [&](OpPassManager &mgr) { return mgr.getOpName() == name; });
   return it == mgrs.end() ? nullptr : &*it;
 }
 
+/// Find an operation pass manager that can operate on an operation of the given
+/// type, or nullptr if one does not exist. Unlike the StringRef overload, this
+/// compares Identifiers uniqued in `context`.
+static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs,
+                                         Identifier name,
+                                         MLIRContext &context) {
+  auto it = llvm::find_if(
+      mgrs, [&](OpPassManager &mgr) { return mgr.getOpName(context) == name; });
+  return it == mgrs.end() ? nullptr : &*it;
+}
+
 OpToOpPassAdaptor::OpToOpPassAdaptor(OpPassManager &&mgr) {
   mgrs.emplace_back(std::move(mgr));
 }
@@ -421,8 +457,7 @@
   // After coalescing, sort the pass managers within rhs by name.
   llvm::array_pod_sort(rhs.mgrs.begin(), rhs.mgrs.end(),
                        [](const OpPassManager *lhs, const OpPassManager *rhs) {
-                         return lhs->getOpName().strref().compare(
-                             rhs->getOpName().strref());
+                         return lhs->getOpName().compare(rhs->getOpName());
                        });
 }
 
@@ -457,13 +492,15 @@
-        auto *mgr = findPassManagerFor(mgrs, op.getName().getIdentifier());
+        auto *mgr = findPassManagerFor(mgrs, op.getName().getIdentifier(),
+                                       *op.getContext());
         if (!mgr)
           continue;
+        Identifier opName = mgr->getOpName(*getOperation()->getContext());
 
         // Run the held pipeline over the current operation.
         if (instrumentor)
-          instrumentor->runBeforePipeline(mgr->getOpName(), parentInfo);
+          instrumentor->runBeforePipeline(opName, parentInfo);
         auto result = runPipeline(mgr->getPasses(), &op, am.nest(&op));
         if (instrumentor)
-          instrumentor->runAfterPipeline(mgr->getOpName(), parentInfo);
+          instrumentor->runAfterPipeline(opName, parentInfo);
 
         if (failed(result))
           return signalPassFailure();
@@ -539,12 +576,14 @@
-              findPassManagerFor(pms, it.first->getName().getIdentifier());
+              findPassManagerFor(pms, it.first->getName().getIdentifier(),
+                                 *getOperation()->getContext());
           assert(pm && "expected valid pass manager for operation");
 
+          Identifier opName = pm->getOpName(*getOperation()->getContext());
           if (instrumentor)
-            instrumentor->runBeforePipeline(pm->getOpName(), parentInfo);
+            instrumentor->runBeforePipeline(opName, parentInfo);
           auto pipelineResult =
               runPipeline(pm->getPasses(), it.first, it.second);
           if (instrumentor)
-            instrumentor->runAfterPipeline(pm->getOpName(), parentInfo);
+            instrumentor->runAfterPipeline(opName, parentInfo);
 
           // Drop this thread from being tracked by the diagnostic handler.
           // After this task has finished, the thread may be used outside of
@@ -737,9 +776,9 @@
 //===----------------------------------------------------------------------===//
 
 PassManager::PassManager(MLIRContext *ctx, bool verifyPasses)
-    : OpPassManager(Identifier::get(ModuleOp::getOperationName(), ctx), ctx,
+    : OpPassManager(Identifier::get(ModuleOp::getOperationName(), ctx),
                     verifyPasses),
-      passTiming(false), localReproducer(false) {}
+      context(ctx), passTiming(false), localReproducer(false) {}
 
 PassManager::~PassManager() {}
 
diff --git a/mlir/lib/Pass/PassStatistics.cpp b/mlir/lib/Pass/PassStatistics.cpp
--- a/mlir/lib/Pass/PassStatistics.cpp
+++ b/mlir/lib/Pass/PassStatistics.cpp
@@ -116,7 +116,7 @@
 
       // Print each of the children passes.
       for (OpPassManager &mgr : mgrs) {
-        auto name = ("'" + mgr.getOpName().strref() + "' Pipeline").str();
+        auto name = ("'" + mgr.getOpName() + "' Pipeline").str();
         printPassEntry(os, indent, name);
         for (Pass &pass : mgr.getPasses())
           printPass(indent + 2, &pass);