diff --git a/mlir/include/mlir/Pass/PassInstrumentation.h b/mlir/include/mlir/Pass/PassInstrumentation.h --- a/mlir/include/mlir/Pass/PassInstrumentation.h +++ b/mlir/include/mlir/Pass/PassInstrumentation.h @@ -9,12 +9,12 @@ #ifndef MLIR_PASS_PASSINSTRUMENTATION_H_ #define MLIR_PASS_PASSINSTRUMENTATION_H_ +#include "mlir/IR/Identifier.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/TypeID.h" namespace mlir { class Operation; -class OperationName; class Pass; namespace detail { @@ -43,13 +43,13 @@ /// A callback to run before a pass pipeline is executed. This function takes /// the name of the operation type being operated on, and information related /// to the parent that spawned this pipeline. - virtual void runBeforePipeline(const OperationName &name, + virtual void runBeforePipeline(Identifier name, const PipelineParentInfo &parentInfo) {} /// A callback to run after a pass pipeline has executed. This function takes /// the name of the operation type being operated on, and information related /// to the parent that spawned this pipeline. - virtual void runAfterPipeline(const OperationName &name, + virtual void runAfterPipeline(Identifier name, const PipelineParentInfo &parentInfo) {} /// A callback to run before a pass is executed. This function takes a pointer @@ -90,12 +90,12 @@ /// See PassInstrumentation::runBeforePipeline for details. void - runBeforePipeline(const OperationName &name, + runBeforePipeline(Identifier name, const PassInstrumentation::PipelineParentInfo &parentInfo); /// See PassInstrumentation::runAfterPipeline for details. void - runAfterPipeline(const OperationName &name, + runAfterPipeline(Identifier name, const PassInstrumentation::PipelineParentInfo &parentInfo); /// See PassInstrumentation::runBeforePass for details. 
diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -26,9 +26,9 @@ namespace mlir { class AnalysisManager; +class Identifier; class MLIRContext; class ModuleOp; -class OperationName; class Operation; class Pass; class PassInstrumentation; @@ -47,7 +47,7 @@ /// other OpPassManagers or the top-level PassManager. class OpPassManager { public: - OpPassManager(OperationName name, bool verifyPasses); + OpPassManager(Identifier name, MLIRContext *context, bool verifyPasses); OpPassManager(OpPassManager &&rhs); OpPassManager(const OpPassManager &rhs); ~OpPassManager(); @@ -70,10 +70,10 @@ /// Nest a new operation pass manager for the given operation kind under this /// pass manager. - OpPassManager &nest(const OperationName &nestedName); + OpPassManager &nest(Identifier nestedName); OpPassManager &nest(StringRef nestedName); template <typename OpT> OpPassManager &nest() { - return nest(OpT::getOperationName()); + return nest(Identifier::get(OpT::getOperationName(), getContext())); } /// Add the given pass to this pass manager. If this pass has a concrete @@ -93,7 +93,7 @@ MLIRContext *getContext() const; /// Return the operation name that this pass manager operates on. - const OperationName &getOpName() const; + Identifier getOpName() const; /// Returns the internal implementation instance. detail::OpPassManagerImpl &getImpl(); diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -92,17 +92,17 @@ namespace mlir { namespace detail { struct OpPassManagerImpl { - OpPassManagerImpl(OperationName name, bool verifyPasses) - : name(name), verifyPasses(verifyPasses) {} + OpPassManagerImpl(Identifier name, MLIRContext *ctx, bool verifyPasses) + : name(name), context(ctx), verifyPasses(verifyPasses) {} /// Merge the passes of this pass manager into the one provided. 
void mergeInto(OpPassManagerImpl &rhs); /// Nest a new operation pass manager for the given operation kind under this /// pass manager. - OpPassManager &nest(const OperationName &nestedName); + OpPassManager &nest(Identifier nestedName); OpPassManager &nest(StringRef nestedName) { - return nest(OperationName(nestedName, getContext())); + return nest(Identifier::get(nestedName, getContext())); } /// Add the given pass to this pass manager. If this pass has a concrete @@ -118,12 +118,13 @@ void splitAdaptorPasses(); /// Return an instance of the context. - MLIRContext *getContext() const { - return name.getAbstractOperation()->dialect.getContext(); - } + MLIRContext *getContext() const { return context; } /// The name of the operation that passes of this pass manager operate on. - OperationName name; + Identifier name; + + /// The current context for this pass manager + MLIRContext *context; /// Flag that specifies if the IR should be verified after each pass has run. bool verifyPasses : 1; @@ -141,8 +142,8 @@ passes.clear(); } -OpPassManager &OpPassManagerImpl::nest(const OperationName &nestedName) { - OpPassManager nested(nestedName, verifyPasses); +OpPassManager &OpPassManagerImpl::nest(Identifier nestedName) { + OpPassManager nested(nestedName, getContext(), verifyPasses); auto *adaptor = new OpToOpPassAdaptor(std::move(nested)); addPass(std::unique_ptr<Pass>(adaptor)); return adaptor->getPassManagers().front(); @@ -152,7 +153,7 @@ // If this pass runs on a different operation than this pass manager, then // implicitly nest a pass manager for this operation. 
auto passOpName = pass->getOpName(); - if (passOpName && passOpName != name.getStringRef()) + if (passOpName && passOpName != name.strref()) return nest(*passOpName).addPass(std::move(pass)); passes.emplace_back(std::move(pass)); @@ -239,19 +240,14 @@ // OpPassManager //===----------------------------------------------------------------------===// -OpPassManager::OpPassManager(OperationName name, bool verifyPasses) - : impl(new OpPassManagerImpl(name, verifyPasses)) { - assert(name.getAbstractOperation() && - "OpPassManager can only operate on registered operations"); - assert(name.getAbstractOperation()->hasProperty( - OperationProperty::IsolatedFromAbove) && - "OpPassManager only supports operating on operations marked as " - "'IsolatedFromAbove'"); -} +OpPassManager::OpPassManager(Identifier name, MLIRContext *context, + bool verifyPasses) + : impl(new OpPassManagerImpl(name, context, verifyPasses)) {} OpPassManager::OpPassManager(OpPassManager &&rhs) : impl(std::move(rhs.impl)) {} OpPassManager::OpPassManager(const OpPassManager &rhs) { *this = rhs; } OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) { - impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->verifyPasses)); + impl.reset(new OpPassManagerImpl(rhs.impl->name, rhs.impl->getContext(), + rhs.impl->verifyPasses)); for (auto &pass : rhs.impl->passes) impl->passes.emplace_back(pass->clone()); return *this; @@ -275,7 +271,7 @@ /// Nest a new operation pass manager for the given operation kind under this /// pass manager. -OpPassManager &OpPassManager::nest(const OperationName &nestedName) { +OpPassManager &OpPassManager::nest(Identifier nestedName) { return impl->nest(nestedName); } OpPassManager &OpPassManager::nest(StringRef nestedName) { @@ -298,7 +294,7 @@ MLIRContext *OpPassManager::getContext() const { return impl->getContext(); } /// Return the operation name that this pass manager operates on. 
-const OperationName &OpPassManager::getOpName() const { return impl->name; } +Identifier OpPassManager::getOpName() const { return impl->name; } /// Prints out the given passes as the textual representation of a pipeline. static void printAsTextualPipeline(ArrayRef<std::unique_ptr<Pass>> passes, @@ -336,6 +332,14 @@ LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op, AnalysisManager am) { + if (!op->getName().getAbstractOperation()) + return op->emitOpError() + << "trying to schedule a pass on an unregistered operation"; + if (!op->getName().getAbstractOperation()->hasProperty( + OperationProperty::IsolatedFromAbove)) + return op->emitOpError() << "trying to schedule a pass on an operation not " + "marked as 'IsolatedFromAbove'"; + pass->passState.emplace(op, am); // Instrument before the pass has run. @@ -385,7 +389,7 @@ /// Find an operation pass manager that can operate on an operation of the given /// type, or nullptr if one does not exist. static OpPassManager *findPassManagerFor(MutableArrayRef<OpPassManager> mgrs, - const OperationName &name) { + Identifier name) { auto it = llvm::find_if( mgrs, [&](OpPassManager &mgr) { return mgr.getOpName() == name; }); return it == mgrs.end() ? nullptr : &*it; @@ -417,8 +421,8 @@ // After coalescing, sort the pass managers within rhs by name. 
llvm::array_pod_sort(rhs.mgrs.begin(), rhs.mgrs.end(), [](const OpPassManager *lhs, const OpPassManager *rhs) { - return lhs->getOpName().getStringRef().compare( - rhs->getOpName().getStringRef()); + return lhs->getOpName().strref().compare( + rhs->getOpName().strref()); }); } @@ -450,7 +454,7 @@ for (auto &region : getOperation()->getRegions()) { for (auto &block : region) { for (auto &op : block) { - auto *mgr = findPassManagerFor(mgrs, op.getName()); + auto *mgr = findPassManagerFor(mgrs, op.getName().getIdentifier()); if (!mgr) continue; @@ -494,8 +498,8 @@ for (auto &region : getOperation()->getRegions()) { for (auto &block : region) { for (auto &op : block) { - // Add this operation iff the name matches the any of the pass managers. - if (findPassManagerFor(mgrs, op.getName())) + // Add this operation iff the name matches any of the pass managers. + if (findPassManagerFor(mgrs, op.getName().getIdentifier())) opAMPairs.emplace_back(&op, am.nest(&op)); } } @@ -531,7 +535,8 @@ // Get the pass manager for this operation and execute it. auto &it = opAMPairs[nextID]; - auto *pm = findPassManagerFor(pms, it.first->getName()); + auto *pm = + findPassManagerFor(pms, it.first->getName().getIdentifier()); assert(pm && "expected valid pass manager for operation"); if (instrumentor) @@ -732,7 +737,7 @@ //===----------------------------------------------------------------------===// PassManager::PassManager(MLIRContext *ctx, bool verifyPasses) - : OpPassManager(OperationName(ModuleOp::getOperationName(), ctx), + : OpPassManager(Identifier::get(ModuleOp::getOperationName(), ctx), ctx, verifyPasses), passTiming(false), localReproducer(false) {} @@ -870,7 +875,7 @@ /// See PassInstrumentation::runBeforePipeline for details. 
void PassInstrumentor::runBeforePipeline( - const OperationName &name, + Identifier name, const PassInstrumentation::PipelineParentInfo &parentInfo) { llvm::sys::SmartScopedLock instrumentationLock(impl->mutex); for (auto &instr : impl->instrumentations) @@ -879,7 +884,7 @@ /// See PassInstrumentation::runAfterPipeline for details. void PassInstrumentor::runAfterPipeline( - const OperationName &name, + Identifier name, const PassInstrumentation::PipelineParentInfo &parentInfo) { llvm::sys::SmartScopedLock instrumentationLock(impl->mutex); for (auto &instr : llvm::reverse(impl->instrumentations)) diff --git a/mlir/lib/Pass/PassStatistics.cpp b/mlir/lib/Pass/PassStatistics.cpp --- a/mlir/lib/Pass/PassStatistics.cpp +++ b/mlir/lib/Pass/PassStatistics.cpp @@ -116,7 +116,7 @@ // Print each of the children passes. for (OpPassManager &mgr : mgrs) { - auto name = ("'" + mgr.getOpName().getStringRef() + "' Pipeline").str(); + auto name = ("'" + mgr.getOpName().strref() + "' Pipeline").str(); printPassEntry(os, indent, name); for (Pass &pass : mgr.getPasses()) printPass(indent + 2, &pass); diff --git a/mlir/lib/Pass/PassTiming.cpp b/mlir/lib/Pass/PassTiming.cpp --- a/mlir/lib/Pass/PassTiming.cpp +++ b/mlir/lib/Pass/PassTiming.cpp @@ -165,9 +165,9 @@ ~PassTiming() override { print(); } /// Setup the instrumentation hooks. 
- void runBeforePipeline(const OperationName &name, + void runBeforePipeline(Identifier name, const PipelineParentInfo &parentInfo) override; - void runAfterPipeline(const OperationName &name, + void runAfterPipeline(Identifier name, const PipelineParentInfo &parentInfo) override; void runBeforePass(Pass *pass, Operation *) override { startPassTimer(pass); } void runAfterPass(Pass *pass, Operation *) override; @@ -242,15 +242,15 @@ }; } // end anonymous namespace -void PassTiming::runBeforePipeline(const OperationName &name, +void PassTiming::runBeforePipeline(Identifier name, const PipelineParentInfo &parentInfo) { // We don't actually want to time the pipelines, they gather their total // from their held passes. getTimer(name.getAsOpaquePointer(), TimerKind::Pipeline, - [&] { return ("'" + name.getStringRef() + "' Pipeline").str(); }); + [&] { return ("'" + name.strref() + "' Pipeline").str(); }); } -void PassTiming::runAfterPipeline(const OperationName &name, +void PassTiming::runAfterPipeline(Identifier name, const PipelineParentInfo &parentInfo) { // Pop the timer for the pipeline. auto tid = llvm::get_threadid(); diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp --- a/mlir/unittests/Pass/PassManagerTest.cpp +++ b/mlir/unittests/Pass/PassManagerTest.cpp @@ -74,4 +74,47 @@ } } +namespace { +struct InvalidPass : Pass { + InvalidPass() : Pass(TypeID::get<InvalidPass>(), StringRef("invalid_op")) {} + StringRef getName() const override { return "Invalid Pass"; } + void runOnOperation() override {} + + /// A clone method to create a copy of this pass. 
+ std::unique_ptr<Pass> clonePass() const override { + return std::make_unique<InvalidPass>( + *static_cast<const InvalidPass *>(this)); + } +}; +} // anonymous namespace + +TEST(PassManagerTest, InvalidPass) { + MLIRContext context; + + // Create a module + OwningModuleRef module(ModuleOp::create(UnknownLoc::get(&context))); + + // Add a single "invalid_op" operation + OpBuilder builder(&module->getBodyRegion()); + OperationState state(UnknownLoc::get(&context), "invalid_op"); + builder.insert(Operation::create(state)); + + // Register a diagnostic handler to capture the diagnostic so that we can + // check it later. + std::unique_ptr<Diagnostic> diagnostic; + context.getDiagEngine().registerHandler([&](Diagnostic &diag) { + diagnostic.reset(new Diagnostic(std::move(diag))); + }); + + // Instantiate and run our pass. + PassManager pm(&context); + pm.addPass(std::make_unique<InvalidPass>()); + LogicalResult result = pm.run(module.get()); + EXPECT_TRUE(failed(result)); + ASSERT_TRUE(diagnostic.get() != nullptr); + EXPECT_EQ( + diagnostic->str(), + "'invalid_op' op trying to schedule a pass on an unregistered operation"); +} + } // end namespace