diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -20,6 +20,7 @@ namespace mlir { namespace detail { class OpToOpPassAdaptor; +struct OpPassManagerImpl; /// The state for a single execution of a pass. This provides a unified /// interface for accessing and initializing necessary state for pass execution. @@ -184,6 +185,11 @@ /// pipeline won't execute. virtual LogicalResult initialize(MLIRContext *context) { return success(); } + /// Indicate if the current pass can be scheduled on the given operation type. + /// This is useful for generic operation passes to add restrictions on the + /// operations they operate on. + virtual bool canScheduleOn(RegisteredOperationName opName) const = 0; + /// Schedule an arbitrary pass pipeline on the provided operation. /// This can be invoke any time in a pass to dynamic schedule more passes. /// The provided operation must be the current one or one nested below. @@ -313,6 +319,9 @@ /// Allow access to 'clone'. friend class OpPassManager; + /// Allow access to 'canScheduleOn'. + friend detail::OpPassManagerImpl; + /// Allow access to 'passState'. friend detail::OpToOpPassAdaptor; @@ -346,6 +355,11 @@ return pass->getOpName() == OpT::getOperationName(); } + /// Indicate if the current pass can be scheduled on the given operation type. + bool canScheduleOn(RegisteredOperationName opName) const final { + return opName.getStringRef() == getOpName(); + } + /// Return the current operation being transformed. OpT getOperation() { return cast(Pass::getOperation()); } @@ -373,6 +387,46 @@ protected: OperationPass(TypeID passID) : Pass(passID) {} OperationPass(const OperationPass &) = default; + + /// Indicate if the current pass can be scheduled on the given operation type. + /// By default, generic operation passes can be scheduled on any operation. + bool canScheduleOn(RegisteredOperationName opName) const override { + return true; + } +}; + +/// Pass to transform an operation that implements the given interface. +/// +/// Interface passes must not: +/// - modify any other operations within the parent region, as other threads +/// may be manipulating them concurrently. +/// - modify any state within the parent operation, this includes adding +/// additional operations. +/// +/// Derived interface passes are expected to provide the following: +/// - A 'void runOnOperation()' method. +/// - A 'StringRef getName() const' method. +/// - A 'std::unique_ptr clonePass() const' method. +template +class InterfacePass : public OperationPass<> { +protected: + using OperationPass::OperationPass; + + /// Indicate if the current pass can be scheduled on the given operation type. + /// For an InterfacePass, this checks if the operation implements the given + /// interface. + bool canScheduleOn(RegisteredOperationName opName) const final { + return opName.hasInterface(); + } + + /// Return the current operation being transformed. + InterfaceT getOperation() { return cast(Pass::getOperation()); } + + /// Query an analysis for the current operation. + template + AnalysisT &getAnalysis() { + return Pass::getAnalysis(); + } }; /// This class provides a CRTP wrapper around a base pass class to define diff --git a/mlir/include/mlir/Pass/PassBase.td b/mlir/include/mlir/Pass/PassBase.td --- a/mlir/include/mlir/Pass/PassBase.td +++ b/mlir/include/mlir/Pass/PassBase.td @@ -92,4 +92,8 @@ class Pass : PassBase">; +// This class represents an mlir::InterfacePass. +class InterfacePass + : PassBase">; + #endif // MLIR_PASS_PASSBASE diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -98,7 +98,7 @@ size_t size() const; /// Return the operation name that this pass manager operates on. - StringAttr getOpName(MLIRContext &context) const; + OperationName getOpName(MLIRContext &context) const; /// Return the operation name that this pass manager operates on. StringRef getOpName() const; diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -80,8 +80,8 @@ namespace mlir { namespace detail { struct OpPassManagerImpl { - OpPassManagerImpl(StringAttr identifier, OpPassManager::Nesting nesting) - : name(identifier.str()), identifier(identifier), + OpPassManagerImpl(OperationName opName, OpPassManager::Nesting nesting) + : name(opName.getStringRef()), opName(opName), initializationGeneration(0), nesting(nesting) {} OpPassManagerImpl(StringRef name, OpPassManager::Nesting nesting) : name(name), initializationGeneration(0), nesting(nesting) {} @@ -102,23 +102,24 @@ /// preserved. void clear(); - /// Coalesce adjacent AdaptorPasses into one large adaptor. This runs - /// recursively through the pipeline graph. - void coalesceAdjacentAdaptorPasses(); + /// Finalize the pass list in preparation for execution. This includes + /// coalescing adjacent pass managers when possible, verifying scheduled + /// passes, etc. + LogicalResult finalizePassList(MLIRContext *ctx); - /// Return the operation name of this pass manager as an identifier. - StringAttr getOpName(MLIRContext &context) { - if (!identifier) - identifier = StringAttr::get(&context, name); - return *identifier; + /// Return the operation name of this pass manager. + OperationName getOpName(MLIRContext &context) { + if (!opName) + opName = OperationName(name, &context); + return *opName; } /// The name of the operation that passes of this pass manager operate on. std::string name; - /// The cached identifier (internalized in the context) for the name of the + /// The cached OperationName (internalized in the context) for the name of the /// operation that passes of this pass manager operate on. - Optional identifier; + Optional opName; /// The set of passes to run as part of this pass manager. std::vector> passes; @@ -173,18 +174,12 @@ void OpPassManagerImpl::clear() { passes.clear(); } -void OpPassManagerImpl::coalesceAdjacentAdaptorPasses() { - // Bail out early if there are no adaptor passes. - if (llvm::none_of(passes, [](std::unique_ptr &pass) { - return isa(pass.get()); - })) - return; - +LogicalResult OpPassManagerImpl::finalizePassList(MLIRContext *ctx) { // Walk the pass list and merge adjacent adaptors. OpToOpPassAdaptor *lastAdaptor = nullptr; - for (auto &passe : passes) { + for (auto &pass : passes) { // Check to see if this pass is an adaptor. - if (auto *currentAdaptor = dyn_cast(passe.get())) { + if (auto *currentAdaptor = dyn_cast(pass.get())) { // If it is the first adaptor in a possible chain, remember it and // continue. if (!lastAdaptor) { @@ -194,25 +189,39 @@ // Otherwise, merge into the existing adaptor and delete the current one. currentAdaptor->mergeInto(*lastAdaptor); - passe.reset(); + pass.reset(); } else if (lastAdaptor) { - // If this pass is not an adaptor, then coalesce and forget any existing + // If this pass is not an adaptor, then finalize and forget any existing // adaptor. for (auto &pm : lastAdaptor->getPassManagers()) - pm.getImpl().coalesceAdjacentAdaptorPasses(); + if (failed(pm.getImpl().finalizePassList(ctx))) + return failure(); lastAdaptor = nullptr; } } - // If there was an adaptor at the end of the manager, coalesce it as well. + // If there was an adaptor at the end of the manager, finalize it as well. if (lastAdaptor) { for (auto &pm : lastAdaptor->getPassManagers()) - pm.getImpl().coalesceAdjacentAdaptorPasses(); + if (failed(pm.getImpl().finalizePassList(ctx))) + return failure(); } - // Now that the adaptors have been merged, erase the empty slot corresponding + // Now that the adaptors have been merged, erase any empty slots corresponding // to the merged adaptors that were nulled-out in the loop above. + Optional opName = + getOpName(*ctx).getRegisteredInfo(); llvm::erase_if(passes, std::logical_not>()); + + // Verify that all of the passes are valid for the operation. + for (std::unique_ptr &pass : passes) { + if (opName && !pass->canScheduleOn(*opName)) { + return emitError(UnknownLoc::get(ctx)) + << "unable to schedule pass '" << pass->getName() + << "' on a PassManager intended to run on '" << name << "'!"; + } + } + return success(); } //===----------------------------------------------------------------------===// @@ -279,7 +288,7 @@ StringRef OpPassManager::getOpName() const { return impl->name; } /// Return the operation name that this pass manager operates on. -StringAttr OpPassManager::getOpName(MLIRContext &context) const { +OperationName OpPassManager::getOpName(MLIRContext &context) const { return impl->getOpName(context); } @@ -367,9 +376,9 @@ "nested under the current operation the pass is processing"; assert(pipeline.getOpName() == root->getName().getStringRef()); - // Before running, make sure to coalesce any adjacent pass adaptors in the - // pipeline. - pipeline.getImpl().coalesceAdjacentAdaptorPasses(); + // Before running, finalize the passes held by the pipeline. + if (failed(pipeline.getImpl().finalizePassList(root->getContext()))) + return failure(); // Initialize the user provided pipeline and execute the pipeline. if (failed(pipeline.initialize(root->getContext(), parentInitGeneration))) @@ -468,7 +477,7 @@ /// Find an operation pass manager that can operate on an operation of the given /// type, or nullptr if one does not exist. static OpPassManager *findPassManagerFor(MutableArrayRef mgrs, - StringAttr name, + OperationName name, MLIRContext &context) { auto *it = llvm::find_if( mgrs, [&](OpPassManager &mgr) { return mgr.getOpName(context) == name; }); @@ -538,8 +547,7 @@ for (auto ®ion : getOperation()->getRegions()) { for (auto &block : region) { for (auto &op : block) { - auto *mgr = findPassManagerFor(mgrs, op.getName().getIdentifier(), - *op.getContext()); + auto *mgr = findPassManagerFor(mgrs, op.getName(), *op.getContext()); if (!mgr) continue; @@ -581,7 +589,7 @@ for (auto &block : region) { for (auto &op : block) { // Add this operation iff the name matches any of the pass managers. - if (findPassManagerFor(mgrs, op.getName().getIdentifier(), *context)) + if (findPassManagerFor(mgrs, op.getName(), *context)) opAMPairs.emplace_back(&op, am.nest(&op)); } } @@ -604,9 +612,8 @@ unsigned pmIndex = it - activePMs.begin(); // Get the pass manager for this operation and execute it. - auto *pm = - findPassManagerFor(asyncExecutors[pmIndex], - opPMPair.first->getName().getIdentifier(), *context); + auto *pm = findPassManagerFor(asyncExecutors[pmIndex], + opPMPair.first->getName(), *context); assert(pm && "expected valid pass manager for operation"); unsigned initGeneration = pm->impl->initializationGeneration; @@ -641,14 +648,10 @@ /// Run the passes within this manager on the provided operation. LogicalResult PassManager::run(Operation *op) { MLIRContext *context = getContext(); - assert(op->getName().getIdentifier() == getOpName(*context) && + assert(op->getName() == getOpName(*context) && "operation has a different name than the PassManager or is from a " "different context"); - // Before running, make sure to coalesce any adjacent pass adaptors in the - // pipeline. - getImpl().coalesceAdjacentAdaptorPasses(); - // Register all dialects for the current pipeline. DialectRegistry dependentDialects; getDependentDialects(dependentDialects); @@ -656,6 +659,10 @@ for (StringRef name : dependentDialects.getDialectNames()) context->getOrLoadDialect(name); + // Before running, make sure to finalize the pipeline pass list. + if (failed(getImpl().finalizePassList(context))) + return failure(); + // Initialize all of the passes within the pass manager with a new generation. llvm::hash_code newInitKey = context->getRegistryHash(); if (newInitKey != initializationKey) { diff --git a/mlir/test/Pass/interface-pass.mlir b/mlir/test/Pass/interface-pass.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Pass/interface-pass.mlir @@ -0,0 +1,8 @@ +// RUN: mlir-opt %s -verify-diagnostics -pass-pipeline='builtin.func(test-interface-pass)' -o /dev/null + +// Test that we run the interface pass on the function. + +// expected-remark@below {{Executing interface pass on operation}} +func @main() { + return +} diff --git a/mlir/test/Pass/invalid-interface-pass.mlir b/mlir/test/Pass/invalid-interface-pass.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/Pass/invalid-interface-pass.mlir @@ -0,0 +1,9 @@ +// RUN: not mlir-opt %s -pass-pipeline='test-interface-pass' 2>&1 | FileCheck %s + +// Test that we emit an error when an interface pass is added to a pass manager it can't be scheduled on. + +// CHECK: unable to schedule pass '{{.*}}' on a PassManager intended to run on 'builtin.module'! + +func @main() { + return +} diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp --- a/mlir/test/lib/Pass/TestPassManager.cpp +++ b/mlir/test/lib/Pass/TestPassManager.cpp @@ -29,6 +29,18 @@ return "Test a function pass in the pass manager"; } }; +class TestInterfacePass + : public PassWrapper> { + void runOnOperation() final { + getOperation()->emitRemark() << "Executing interface pass on operation"; + } + StringRef getArgument() const final { return "test-interface-pass"; } + StringRef getDescription() const final { + return "Test an interface pass (running on FunctionOpInterface) in the " + "pass manager"; + } +}; class TestOptionsPass : public PassWrapper> { public: @@ -128,6 +140,8 @@ PassRegistration(); + PassRegistration(); + PassRegistration(); PassRegistration(); diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp --- a/mlir/unittests/Pass/PassManagerTest.cpp +++ b/mlir/unittests/Pass/PassManagerTest.cpp @@ -81,6 +81,9 @@ InvalidPass() : Pass(TypeID::get(), StringRef("invalid_op")) {} StringRef getName() const override { return "Invalid Pass"; } void runOnOperation() override {} + bool canScheduleOn(RegisteredOperationName opName) const override { + return true; + } /// A clone method to create a copy of this pass. std::unique_ptr clonePass() const override {