diff --git a/mlir/docs/PassManagement.md b/mlir/docs/PassManagement.md --- a/mlir/docs/PassManagement.md +++ b/mlir/docs/PassManagement.md @@ -293,28 +293,31 @@ ### OpPassManager -An `OpPassManager` is essentially a collection of passes to execute on an -operation of a specific type. This operation type must adhere to the following -requirement: +An `OpPassManager` is essentially a collection of passes anchored to execute on +operations at a given level of nesting. A pass manager, similarly to passes, can be +`op-specific` (anchored on a specific operation type), or `op-agnostic` (not restricted +to any specific operation, and executed on any viable operation type). Operation +types that anchor pass managers must adhere to the following requirement: * Must be registered and marked [`IsolatedFromAbove`](Traits.md/#isolatedfromabove). - * Passes are expected to not modify operations at or above the current + * Passes are expected not to modify operations at or above the current operation being processed. If the operation is not isolated, it may inadvertently modify or traverse the SSA use-list of an operation it is not supposed to. Passes can be added to a pass manager via `addPass`. The pass must either be an `op-specific` pass operating on the same operation type as `OpPassManager`, or -an `op-agnostic` pass. +an `op-agnostic` pass. Pass managers that are `op-agnostic` can only have +`op-agnostic` passes added to them. An `OpPassManager` is generally created by explicitly nesting a pipeline within -another existing `OpPassManager` via the `nest<>` method. This method takes the -operation type that the nested pass manager will operate on. At the top-level, a -`PassManager` acts as an `OpPassManager`. Nesting in this sense, corresponds to -the [structural](Tutorials/UnderstandingTheIRStructure.md) nesting within -[Regions](LangRef.md/#regions) of the IR. +another existing `OpPassManager` via the `nest` or `nestAny` methods. 
The +former method takes the operation type that the nested pass manager will operate on. +The latter method nests an op-agnostic pass manager that may run on any viable +operation type. Nesting, in this sense, corresponds to the [structural](Tutorials/UnderstandingTheIRStructure.md) +nesting within [Regions](LangRef.md/#regions) of the IR. For example, the following `.mlir`: @@ -359,6 +362,12 @@ OpPassManager &nestedFunctionPM = nestedModulePM.nest(); nestedFunctionPM.addPass(std::make_unique()); +// Nest an op-agnostic pass manager. This will operate on any viable +// operation, e.g. func.func, spv.func, spv.module, builtin.module, etc. +OpPassManager &nestedAnyPM = nestedModulePM.nestAny(); +nestedAnyPM.addPass(createCanonicalizePass()); +nestedAnyPM.addPass(createCSEPass()); + // Run the pass manager on the top-level module. ModuleOp m = ...; if (failed(pm.run(m))) @@ -374,6 +383,9 @@ MySPIRVModulePass OpPassManager MyFunctionPass + OpPassManager<> + Canonicalizer + CSE ``` These pipelines are then run over a single operation at a time. This means that, @@ -652,14 +664,17 @@ nested pipeline description. The syntax for this specification is as follows: ```ebnf -pipeline ::= op-name `(` pipeline-element (`,` pipeline-element)* `)` +pipeline ::= op-anchor `(` pipeline-element (`,` pipeline-element)* `)` pipeline-element ::= pipeline | (pass-name | pass-pipeline-name) options? options ::= '{' (key ('=' value)?)+ '}' ``` -* `op-name` - * This corresponds to the mnemonic name of an operation to run passes on, - e.g. `func.func` or `builtin.module`. +* `op-anchor` + * This corresponds to the mnemonic name that anchors the execution of the + pass manager. This is either the name of an operation to run passes on, + e.g. `func.func` or `builtin.module`, or `any`, for op-agnostic pass + managers that execute on any viable operation (i.e. any operation that + can be used to anchor a pass manager). 
* `pass-name` | `pass-pipeline-name` * This corresponds to the argument of a registered pass or pass pipeline, e.g. `cse` or `canonicalize`. @@ -678,7 +693,11 @@ Can also be specified as (via the `-pass-pipeline` flag): ```shell +# Anchor the cse and canonicalize passes on the `func.func` operation. $ mlir-opt foo.mlir -pass-pipeline='func.func(cse,canonicalize),convert-func-to-llvm{use-bare-ptr-memref-call-conv=1}' + +# Anchor the cse and canonicalize passes on "any" viable root operation. +$ mlir-opt foo.mlir -pass-pipeline='any(cse,canonicalize),convert-func-to-llvm{use-bare-ptr-memref-call-conv=1}' ``` In order to support round-tripping a pass to the textual representation using diff --git a/mlir/include/mlir/Pass/PassInstrumentation.h b/mlir/include/mlir/Pass/PassInstrumentation.h --- a/mlir/include/mlir/Pass/PassInstrumentation.h +++ b/mlir/include/mlir/Pass/PassInstrumentation.h @@ -9,11 +9,11 @@ #ifndef MLIR_PASS_PASSINSTRUMENTATION_H_ #define MLIR_PASS_PASSINSTRUMENTATION_H_ -#include "mlir/IR/BuiltinAttributes.h" #include "mlir/Support/LLVM.h" #include "mlir/Support/TypeID.h" namespace mlir { +class OperationName; class Operation; class Pass; @@ -41,16 +41,18 @@ virtual ~PassInstrumentation() = 0; /// A callback to run before a pass pipeline is executed. This function takes - /// the name of the operation type being operated on, and information related - /// to the parent that spawned this pipeline. - virtual void runBeforePipeline(StringAttr name, - const PipelineParentInfo &parentInfo) {} + /// the name of the operation type being operated on, or None if the pipeline + /// is op-agnostic, and information related to the parent that spawned this + /// pipeline. + virtual void runBeforePipeline(Optional name, + const PipelineParentInfo &parentInfo); /// A callback to run after a pass pipeline has executed. This function takes - /// the name of the operation type being operated on, and information related - /// to the parent that spawned this pipeline. 
- virtual void runAfterPipeline(StringAttr name, - const PipelineParentInfo &parentInfo) {} + /// the name of the operation type being operated on, or None if the pipeline + /// is op-agnostic, and information related to the parent that spawned this + /// pipeline. + virtual void runAfterPipeline(Optional name, + const PipelineParentInfo &parentInfo); /// A callback to run before a pass is executed. This function takes a pointer /// to the pass to be executed, as well as the current operation being @@ -90,12 +92,12 @@ /// See PassInstrumentation::runBeforePipeline for details. void - runBeforePipeline(StringAttr name, + runBeforePipeline(Optional name, const PassInstrumentation::PipelineParentInfo &parentInfo); /// See PassInstrumentation::runAfterPipeline for details. void - runAfterPipeline(StringAttr name, + runAfterPipeline(Optional name, const PassInstrumentation::PipelineParentInfo &parentInfo); /// See PassInstrumentation::runBeforePass for details. diff --git a/mlir/include/mlir/Pass/PassManager.h b/mlir/include/mlir/Pass/PassManager.h --- a/mlir/include/mlir/Pass/PassManager.h +++ b/mlir/include/mlir/Pass/PassManager.h @@ -32,7 +32,6 @@ class Pass; class PassInstrumentation; class PassInstrumentor; -class StringAttr; namespace detail { struct OpPassManagerImpl; @@ -45,14 +44,32 @@ // OpPassManager //===----------------------------------------------------------------------===// -/// This class represents a pass manager that runs passes on a specific -/// operation type. This class is not constructed directly, but nested within -/// other OpPassManagers or the top-level PassManager. +/// This class represents a pass manager that runs passes on either a specific +/// operation type, or any isolated operation. This class is generally not +/// constructed directly, but nested within other OpPassManagers or the +/// top-level PassManager. 
class OpPassManager { public: - enum class Nesting { Implicit, Explicit }; - OpPassManager(StringAttr name, Nesting nesting = Nesting::Explicit); + /// This enum represents the nesting behavior of the pass manager. + enum class Nesting { + /// Implicit nesting behavior. This allows for adding passes operating on + /// operations different from this pass manager, in which case a new pass + /// manager is implicitly nested for the operation type of the new pass. + Implicit, + /// Explicit nesting behavior. This requires that any passes added to this + /// pass manager support its operation type. + Explicit + }; + + /// Construct a new op-agnostic ("any") pass manager with the given nesting + /// behavior. This is the same as invoking: + /// `OpPassManager(getAnyOpAnchorName(), nesting)`. + OpPassManager(Nesting nesting = Nesting::Explicit); + + /// Construct a new pass manager with the given anchor operation type and + /// nesting behavior. OpPassManager(StringRef name, Nesting nesting = Nesting::Explicit); + OpPassManager(OperationName name, Nesting nesting = Nesting::Explicit); OpPassManager(OpPassManager &&rhs); OpPassManager(const OpPassManager &rhs); ~OpPassManager(); @@ -78,12 +95,16 @@ /// Nest a new operation pass manager for the given operation kind under this /// pass manager. - OpPassManager &nest(StringAttr nestedName); + OpPassManager &nest(OperationName nestedName); OpPassManager &nest(StringRef nestedName); template OpPassManager &nest() { return nest(OpT::getOperationName()); } + /// Nest a new op-agnostic ("any") pass manager under this pass manager. + /// Note: This is the same as invoking `nest(getAnyOpAnchorName())`. + OpPassManager &nestAny(); + /// Add the given pass to this pass manager. If this pass has a concrete /// operation type, it must be the same type as this pass manager. void addPass(std::unique_ptr pass); @@ -100,11 +121,22 @@ /// Returns the number of passes held by this manager. 
size_t size() const; - /// Return the operation name that this pass manager operates on. - OperationName getOpName(MLIRContext &context) const; + /// Return the operation name that this pass manager operates on, or None if + /// this is an op-agnostic pass manager. + Optional getOpName(MLIRContext &context) const; + + /// Return the operation name that this pass manager operates on, or None if + /// this is an op-agnostic pass manager. + Optional getOpName() const; + + /// Return the name used to anchor this pass manager. This is either the name + /// of an operation, or the result of `getAnyOpAnchorName()` in the case of an + /// op-agnostic pass manager. + StringRef getOpAnchorName() const; - /// Return the operation name that this pass manager operates on. - StringRef getOpName() const; + /// Return the string name used to anchor op-agnostic pass managers that + /// operate generically on any viable operation. + static StringRef getAnyOpAnchorName() { return "any"; } /// Returns the internal implementation instance. detail::OpPassManagerImpl &getImpl(); @@ -177,6 +209,8 @@ /// Create a new pass manager under the given context with a specific nesting /// style. The created pass manager can schedule operations that match /// `operationName`. + /// FIXME: We should make the specification of `builtin.module` explicit here, + /// so that we can have top-level op-agnostic pass managers. 
PassManager(MLIRContext *ctx, Nesting nesting = Nesting::Explicit, StringRef operationName = "builtin.module"); PassManager(MLIRContext *ctx, StringRef operationName) diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -58,7 +58,7 @@ llvm::interleave( adaptor->getPassManagers(), [&](OpPassManager &pm) { - os << pm.getOpName() << "("; + os << pm.getOpAnchorName() << "("; pm.printAsTextualPipeline(os); os << ")"; }, @@ -84,18 +84,39 @@ namespace detail { struct OpPassManagerImpl { OpPassManagerImpl(OperationName opName, OpPassManager::Nesting nesting) - : name(opName.getStringRef()), opName(opName), + : name(opName.getStringRef().str()), opName(opName), initializationGeneration(0), nesting(nesting) {} OpPassManagerImpl(StringRef name, OpPassManager::Nesting nesting) - : name(name), initializationGeneration(0), nesting(nesting) {} + : name(name == OpPassManager::getAnyOpAnchorName() ? "" : name.str()), + initializationGeneration(0), nesting(nesting) {} + OpPassManagerImpl(OpPassManager::Nesting nesting) + : name(""), initializationGeneration(0), nesting(nesting) {} + OpPassManagerImpl(const OpPassManagerImpl &rhs) + : name(rhs.name), opName(rhs.opName), + initializationGeneration(rhs.initializationGeneration), + nesting(rhs.nesting) { + for (auto &pass : rhs.passes) { + auto newPass = pass->clone(); + newPass->threadingSibling = pass.get(); + passes.push_back(std::move(newPass)); + } + } /// Merge the passes of this pass manager into the one provided. void mergeInto(OpPassManagerImpl &rhs); /// Nest a new operation pass manager for the given operation kind under this /// pass manager. 
- OpPassManager &nest(StringAttr nestedName); - OpPassManager &nest(StringRef nestedName); + OpPassManager &nest(OperationName nestedName) { + return nest(OpPassManager(nestedName, nesting)); + } + OpPassManager &nest(StringRef nestedName) { + return nest(OpPassManager(nestedName, nesting)); + } + OpPassManager &nestAny() { return nest(OpPassManager(nesting)); } + + /// Nest the given pass manager under this pass manager. + OpPassManager &nest(OpPassManager &&nested); /// Add the given pass to this pass manager. If this pass has a concrete /// operation type, it must be the same type as this pass manager. @@ -111,12 +132,26 @@ LogicalResult finalizePassList(MLIRContext *ctx); /// Return the operation name of this pass manager. - OperationName getOpName(MLIRContext &context) { - if (!opName) + Optional getOpName(MLIRContext &context) { + if (!name.empty() && !opName) opName = OperationName(name, &context); - return *opName; + return opName; + } + Optional getOpName() const { + return name.empty() ? Optional() : Optional(name); + } + + /// Return the name used to anchor this pass manager. This is either the name + /// of an operation, or the result of `getAnyOpAnchorName()` in the case of an + /// op-agnostic pass manager. + StringRef getOpAnchorName() const { + return getOpName().getValueOr(OpPassManager::getAnyOpAnchorName()); } + /// Indicate if the current pass manager can be scheduled on the given + /// operation type. + bool canScheduleOn(MLIRContext &context, OperationName opName); + /// The name of the operation that passes of this pass manager operate on. 
std::string name; @@ -145,15 +180,7 @@ passes.clear(); } -OpPassManager &OpPassManagerImpl::nest(StringAttr nestedName) { - OpPassManager nested(nestedName, nesting); - auto *adaptor = new OpToOpPassAdaptor(std::move(nested)); - addPass(std::unique_ptr(adaptor)); - return adaptor->getPassManagers().front(); -} - -OpPassManager &OpPassManagerImpl::nest(StringRef nestedName) { - OpPassManager nested(nestedName, nesting); +OpPassManager &OpPassManagerImpl::nest(OpPassManager &&nested) { auto *adaptor = new OpToOpPassAdaptor(std::move(nested)); addPass(std::unique_ptr(adaptor)); return adaptor->getPassManagers().front(); @@ -168,8 +195,8 @@ return nest(*passOpName).addPass(std::move(pass)); llvm::report_fatal_error(llvm::Twine("Can't add pass '") + pass->getName() + "' restricted to '" + *passOpName + - "' on a PassManager intended to run on '" + name + - "', did you intend to nest?"); + "' on a PassManager intended to run on '" + + getOpAnchorName() + "', did you intend to nest?"); } passes.emplace_back(std::move(pass)); @@ -178,6 +205,13 @@ void OpPassManagerImpl::clear() { passes.clear(); } LogicalResult OpPassManagerImpl::finalizePassList(MLIRContext *ctx) { + auto finalizeAdaptor = [ctx](OpToOpPassAdaptor *adaptor) { + for (auto &pm : adaptor->getPassManagers()) + if (failed(pm.getImpl().finalizePassList(ctx))) + return failure(); + return success(); + }; + // Walk the pass list and merge adjacent adaptors. OpToOpPassAdaptor *lastAdaptor = nullptr; for (auto &pass : passes) { @@ -190,61 +224,77 @@ continue; } - // Otherwise, merge into the existing adaptor and delete the current one. - currentAdaptor->mergeInto(*lastAdaptor); - pass.reset(); + // Otherwise, try to merge into the existing adaptor and delete the + // current one. If merging fails, just remember this as the last adaptor. 
+ if (succeeded(currentAdaptor->tryMergeInto(ctx, *lastAdaptor))) + pass.reset(); + else + lastAdaptor = currentAdaptor; } else if (lastAdaptor) { - // If this pass is not an adaptor, then finalize and forget any existing - // adaptor. - for (auto &pm : lastAdaptor->getPassManagers()) - if (failed(pm.getImpl().finalizePassList(ctx))) - return failure(); + // If this pass isn't an adaptor, finalize and forget the last adaptor. + if (failed(finalizeAdaptor(lastAdaptor))) + return failure(); lastAdaptor = nullptr; } } // If there was an adaptor at the end of the manager, finalize it as well. - if (lastAdaptor) { - for (auto &pm : lastAdaptor->getPassManagers()) - if (failed(pm.getImpl().finalizePassList(ctx))) - return failure(); - } + if (lastAdaptor && failed(finalizeAdaptor(lastAdaptor))) + return failure(); // Now that the adaptors have been merged, erase any empty slots corresponding // to the merged adaptors that were nulled-out in the loop above. - Optional opName = - getOpName(*ctx).getRegisteredInfo(); llvm::erase_if(passes, std::logical_not>()); - // Verify that all of the passes are valid for the operation. + // If this is an op-agnostic pass manager, there is nothing left to do. + Optional rawOpName = getOpName(*ctx); + if (!rawOpName) + return success(); + + // Otherwise, verify that all of the passes are valid for the current + // operation anchor. + Optional opName = rawOpName->getRegisteredInfo(); for (std::unique_ptr &pass : passes) { if (opName && !pass->canScheduleOn(*opName)) { return emitError(UnknownLoc::get(ctx)) << "unable to schedule pass '" << pass->getName() - << "' on a PassManager intended to run on '" << name << "'!"; + << "' on a PassManager intended to run on '" << getOpAnchorName() + << "'!"; } } return success(); } +bool OpPassManagerImpl::canScheduleOn(MLIRContext &context, + OperationName opName) { + // If this pass manager is op-specific, we simply check if the provided + // operation name is the same as this one. 
+ Optional pmOpName = getOpName(context); + if (pmOpName) + return pmOpName == opName; + + // Otherwise, this is an op-agnostic pass manager. In this case, we simply + // check if this operation can be scheduled on any pass manager. Any + // additional filtering will happen during pass execution. + Optional registeredInfo = opName.getRegisteredInfo(); + return registeredInfo && + registeredInfo->hasTrait(); +} + //===----------------------------------------------------------------------===// // OpPassManager //===----------------------------------------------------------------------===// -OpPassManager::OpPassManager(StringAttr name, Nesting nesting) - : impl(new OpPassManagerImpl(name, nesting)) {} +OpPassManager::OpPassManager(Nesting nesting) + : impl(new OpPassManagerImpl(nesting)) {} OpPassManager::OpPassManager(StringRef name, Nesting nesting) : impl(new OpPassManagerImpl(name, nesting)) {} +OpPassManager::OpPassManager(OperationName name, Nesting nesting) + : impl(new OpPassManagerImpl(name, nesting)) {} OpPassManager::OpPassManager(OpPassManager &&rhs) : impl(std::move(rhs.impl)) {} OpPassManager::OpPassManager(const OpPassManager &rhs) { *this = rhs; } OpPassManager &OpPassManager::operator=(const OpPassManager &rhs) { - impl = std::make_unique(rhs.impl->name, rhs.impl->nesting); - impl->initializationGeneration = rhs.impl->initializationGeneration; - for (auto &pass : rhs.impl->passes) { - auto newPass = pass->clone(); - newPass->threadingSibling = pass.get(); - impl->passes.push_back(std::move(newPass)); - } + impl = std::make_unique(*rhs.impl); return *this; } @@ -266,12 +316,13 @@ /// Nest a new operation pass manager for the given operation kind under this /// pass manager. 
-OpPassManager &OpPassManager::nest(StringAttr nestedName) { +OpPassManager &OpPassManager::nest(OperationName nestedName) { return impl->nest(nestedName); } OpPassManager &OpPassManager::nest(StringRef nestedName) { return impl->nest(nestedName); } +OpPassManager &OpPassManager::nestAny() { return impl->nestAny(); } /// Add the given pass to this pass manager. If this pass has a concrete /// operation type, it must be the same type as this pass manager. @@ -288,13 +339,19 @@ OpPassManagerImpl &OpPassManager::getImpl() { return *impl; } /// Return the operation name that this pass manager operates on. -StringRef OpPassManager::getOpName() const { return impl->name; } +Optional OpPassManager::getOpName() const { + return impl->getOpName(); +} /// Return the operation name that this pass manager operates on. -OperationName OpPassManager::getOpName(MLIRContext &context) const { +Optional OpPassManager::getOpName(MLIRContext &context) const { return impl->getOpName(context); } +StringRef OpPassManager::getOpAnchorName() const { + return impl->getOpAnchorName(); +} + /// Prints out the given passes as the textual representation of a pipeline. 
static void printAsTextualPipeline(ArrayRef> passes, raw_ostream &os) { @@ -359,12 +416,14 @@ //===----------------------------------------------------------------------===// LogicalResult OpToOpPassAdaptor::run(Pass *pass, Operation *op, - AnalysisManager am, bool verifyPasses, + AnalysisManager am, + bool checkOpNameForSkip, bool verifyPasses, unsigned parentInitGeneration) { - if (!op->isRegistered()) + Optional opInfo = op->getRegisteredInfo(); + if (!opInfo) return op->emitOpError() << "trying to schedule a pass on an unregistered operation"; - if (!op->hasTrait()) + if (!opInfo->hasTrait()) return op->emitOpError() << "trying to schedule a pass on an operation not " "marked as 'IsolatedFromAbove'"; @@ -380,7 +439,8 @@ << "Trying to schedule a dynamic pipeline on an " "operation that isn't " "nested under the current operation the pass is processing"; - assert(pipeline.getOpName() == root->getName().getStringRef()); + assert( + pipeline.getImpl().canScheduleOn(*op->getContext(), root->getName())); // Before running, finalize the passes held by the pipeline. if (failed(pipeline.getImpl().finalizePassList(root->getContext()))) @@ -390,7 +450,7 @@ if (failed(pipeline.initialize(root->getContext(), parentInitGeneration))) return failure(); AnalysisManager nestedAm = root == op ? am : am.nest(root); - return OpToOpPassAdaptor::runPipeline(pipeline.getPasses(), root, nestedAm, + return OpToOpPassAdaptor::runPipeline(pipeline, root, nestedAm, verifyPasses, parentInitGeneration, pi, &parentInfo); }; @@ -401,10 +461,22 @@ pi->runBeforePass(pass, op); // Invoke the virtual runOnOperation method. - if (auto *adaptor = dyn_cast(pass)) + if (auto *adaptor = dyn_cast(pass)) { adaptor->runOnOperation(verifyPasses); - else + } else { + // Check to see if we should skip the execution of this pass. + // TODO: There are various other ways we may want to inject skipping pass + // execution in the future, such additional hooks should be added here. 
+ if (checkOpNameForSkip && !pass->canScheduleOn(*opInfo)) { + // TODO: Should we have a `runAfterPassSkipped` instrumentation callback? + if (pi) + pi->runAfterPass(pass, op); + return success(); + } + + // The pass wasn't skipped, so execute it normally. pass->runOnOperation(); + } bool passFailed = pass->passState->irAndPassFailed.getInt(); // Invalidate any non preserved analyses. @@ -448,9 +520,8 @@ /// Run the given operation and analysis manager on a provided op pass manager. LogicalResult OpToOpPassAdaptor::runPipeline( - iterator_range passes, Operation *op, - AnalysisManager am, bool verifyPasses, unsigned parentInitGeneration, - PassInstrumentor *instrumentor, + OpPassManager &pm, Operation *op, AnalysisManager am, bool verifyPasses, + unsigned parentInitGeneration, PassInstrumentor *instrumentor, const PassInstrumentation::PipelineParentInfo *parentInfo) { assert((!instrumentor || parentInfo) && "expected parent info if instrumentor is provided"); @@ -463,22 +534,33 @@ }); // Run the pipeline over the provided operation. - if (instrumentor) - instrumentor->runBeforePipeline(op->getName().getIdentifier(), *parentInfo); - for (Pass &pass : passes) - if (failed(run(&pass, op, am, verifyPasses, parentInitGeneration))) + if (instrumentor) { + instrumentor->runBeforePipeline(pm.getOpName(*op->getContext()), + *parentInfo); + } + + // If this is a generic pass manager, dynamically check to see if each pass + // can be scheduled on the current operation. 
+ bool checkOpNameForSkip = !pm.getOpName(); + for (Pass &pass : pm.getPasses()) { + if (failed(run(&pass, op, am, checkOpNameForSkip, verifyPasses, + parentInitGeneration))) return failure(); - if (instrumentor) - instrumentor->runAfterPipeline(op->getName().getIdentifier(), *parentInfo); + } + + if (instrumentor) { + instrumentor->runAfterPipeline(pm.getOpName(*op->getContext()), + *parentInfo); + } return success(); } -/// Find an operation pass manager that can operate on an operation of the given -/// type, or nullptr if one does not exist. -static OpPassManager *findPassManagerFor(MutableArrayRef mgrs, - StringRef name) { +/// Find an operation pass manager with the given anchor name, or nullptr if one +/// does not exist. +static OpPassManager * +findPassManagerWithAnchor(MutableArrayRef mgrs, StringRef name) { auto *it = llvm::find_if( - mgrs, [&](OpPassManager &mgr) { return mgr.getOpName() == name; }); + mgrs, [&](OpPassManager &mgr) { return mgr.getOpAnchorName() == name; }); return it == mgrs.end() ? nullptr : &*it; } @@ -487,8 +569,9 @@ static OpPassManager *findPassManagerFor(MutableArrayRef mgrs, OperationName name, MLIRContext &context) { - auto *it = llvm::find_if( - mgrs, [&](OpPassManager &mgr) { return mgr.getOpName(context) == name; }); + auto *it = llvm::find_if(mgrs, [&](OpPassManager &mgr) { + return mgr.getImpl().canScheduleOn(context, name); + }); return it == mgrs.end() ? nullptr : &*it; } @@ -501,12 +584,47 @@ pm.getDependentDialects(dialects); } -/// Merge the current pass adaptor into given 'rhs'. -void OpToOpPassAdaptor::mergeInto(OpToOpPassAdaptor &rhs) { +LogicalResult OpToOpPassAdaptor::tryMergeInto(MLIRContext *ctx, + OpToOpPassAdaptor &rhs) { + // Functor used to check if a pass manager is generic, i.e. op-agnostic. + auto isGenericPM = [&](OpPassManager &pm) { return !pm.getOpName(); }; + + // Functor used to detect if the given generic pass manager will have a + // potential schedule conflict with the given `otherPMs`. 
+ auto hasScheduleConflictWith = [&](OpPassManager &genericPM, + MutableArrayRef otherPMs) { + return llvm::any_of(otherPMs, [&](OpPassManager &pm) { + // A conflict will arise if one of the generic passes can be scheduled on + // a non-generic pass manager's operation name. Generic pass managers + // don't need to be checked against here, given that we merge them + // together. + if (Optional pmOpName = pm.getOpName(*ctx)) { + for (Pass &pass : genericPM.getPasses()) + if (pass.canScheduleOn(*pmOpName->getRegisteredInfo())) + return true; + } + return false; + }); + }; + + // Check that if either adaptor has a generic pass manager, that pm is + // compatible with any non-generic pass managers. + // Check the current adaptor. + auto *genericPMIt = llvm::find_if(mgrs, isGenericPM); + if (genericPMIt != mgrs.end() && + hasScheduleConflictWith(*genericPMIt, rhs.mgrs)) + return failure(); + // Check the rhs adaptor. + genericPMIt = llvm::find_if(rhs.mgrs, isGenericPM); + if (genericPMIt != rhs.mgrs.end() && + hasScheduleConflictWith(*genericPMIt, mgrs)) + return failure(); + for (auto &pm : mgrs) { // If an existing pass manager exists, then merge the given pass manager // into it. - if (auto *existingPM = findPassManagerFor(rhs.mgrs, pm.getOpName())) { + if (auto *existingPM = + findPassManagerWithAnchor(rhs.mgrs, pm.getOpAnchorName())) { pm.getImpl().mergeInto(existingPM->getImpl()); } else { // Otherwise, add the given pass manager to the list. @@ -516,10 +634,19 @@ mgrs.clear(); // After coalescing, sort the pass managers within rhs by name. 
- llvm::array_pod_sort(rhs.mgrs.begin(), rhs.mgrs.end(), - [](const OpPassManager *lhs, const OpPassManager *rhs) { - return lhs->getOpName().compare(rhs->getOpName()); - }); + auto compareFn = [](const OpPassManager *lhs, const OpPassManager *rhs) { + // Always order op-agnostic pass managers last, this simplifies the work + // necessary when picking a pass manager for an operation during execution + // (op-specific pass managers will always be chosen first). + if (Optional lhsName = lhs->getOpName()) { + if (Optional rhsName = rhs->getOpName()) + return lhsName->compare(*rhsName); + return -1; // lhs(op-specific) < rhs(op-agnostic) + } + return 1; // lhs(op-agnostic) > rhs(op-specific) + }; + llvm::array_pod_sort(rhs.mgrs.begin(), rhs.mgrs.end(), compareFn); + return success(); } /// Returns the adaptor pass name. @@ -527,7 +654,7 @@ std::string name = "Pipeline Collection : ["; llvm::raw_string_ostream os(name); llvm::interleaveComma(getPassManagers(), os, [&](OpPassManager &pm) { - os << '\'' << pm.getOpName() << '\''; + os << '\'' << pm.getOpAnchorName() << '\''; }); os << ']'; return os.str(); @@ -561,9 +688,8 @@ // Run the held pipeline over the current operation. unsigned initGeneration = mgr->impl->initializationGeneration; - if (failed(runPipeline(mgr->getPasses(), &op, am.nest(&op), - verifyPasses, initGeneration, instrumentor, - &parentInfo))) + if (failed(runPipeline(*mgr, &op, am.nest(&op), verifyPasses, + initGeneration, instrumentor, &parentInfo))) return signalPassFailure(); } } @@ -626,8 +752,8 @@ unsigned initGeneration = pm->impl->initializationGeneration; LogicalResult pipelineResult = - runPipeline(pm->getPasses(), opPMPair.first, opPMPair.second, - verifyPasses, initGeneration, instrumentor, &parentInfo); + runPipeline(*pm, opPMPair.first, opPMPair.second, verifyPasses, + initGeneration, instrumentor, &parentInfo); // Reset the active bit for this pass manager. 
activePMs[pmIndex].store(false); @@ -645,7 +771,7 @@ PassManager::PassManager(MLIRContext *ctx, Nesting nesting, StringRef operationName) - : OpPassManager(StringAttr::get(ctx, operationName), nesting), context(ctx), + : OpPassManager(OperationName(operationName, ctx), nesting), context(ctx), initializationKey(DenseMapInfo::getTombstoneKey()), passTiming(false), verifyPasses(true) {} @@ -708,7 +834,7 @@ } LogicalResult PassManager::runPasses(Operation *op, AnalysisManager am) { - return OpToOpPassAdaptor::runPipeline(getPasses(), op, am, verifyPasses, + return OpToOpPassAdaptor::runPipeline(*this, op, am, verifyPasses, impl->initializationGeneration); } @@ -788,6 +914,12 @@ PassInstrumentation::~PassInstrumentation() = default; +void PassInstrumentation::runBeforePipeline( + Optional name, const PipelineParentInfo &parentInfo) {} + +void PassInstrumentation::runAfterPipeline( + Optional name, const PipelineParentInfo &parentInfo) {} + //===----------------------------------------------------------------------===// // PassInstrumentor //===----------------------------------------------------------------------===// @@ -809,7 +941,7 @@ /// See PassInstrumentation::runBeforePipeline for details. void PassInstrumentor::runBeforePipeline( - StringAttr name, + Optional name, const PassInstrumentation::PipelineParentInfo &parentInfo) { llvm::sys::SmartScopedLock instrumentationLock(impl->mutex); for (auto &instr : impl->instrumentations) @@ -818,7 +950,7 @@ /// See PassInstrumentation::runAfterPipeline for details. 
void PassInstrumentor::runAfterPipeline( - StringAttr name, + Optional name, const PassInstrumentation::PipelineParentInfo &parentInfo) { llvm::sys::SmartScopedLock instrumentationLock(impl->mutex); for (auto &instr : llvm::reverse(impl->instrumentations)) diff --git a/mlir/lib/Pass/PassDetail.h b/mlir/lib/Pass/PassDetail.h --- a/mlir/lib/Pass/PassDetail.h +++ b/mlir/lib/Pass/PassDetail.h @@ -29,8 +29,8 @@ void runOnOperation(bool verifyPasses); void runOnOperation() override; - /// Merge the current pass adaptor into given 'rhs'. - void mergeInto(OpToOpPassAdaptor &rhs); + /// Try to merge the current pass adaptor into 'rhs'. + LogicalResult tryMergeInto(MLIRContext *ctx, OpToOpPassAdaptor &rhs); /// Returns the pass managers held by this adaptor. MutableArrayRef getPassManagers() { return mgrs; } @@ -59,16 +59,16 @@ /// manager, and is used to initialize any dynamic pass pipelines run by the /// given pass. static LogicalResult run(Pass *pass, Operation *op, AnalysisManager am, - bool verifyPasses, unsigned parentInitGeneration); + bool checkOpNameForSkip, bool verifyPasses, + unsigned parentInitGeneration); /// Run the given operation and analysis manager on a provided op pass /// manager. `parentInitGeneration` is the initialization generation of the /// parent pass manager, and is used to initialize any dynamic pass pipelines /// run by the given passes. static LogicalResult runPipeline( - iterator_range passes, Operation *op, - AnalysisManager am, bool verifyPasses, unsigned parentInitGeneration, - PassInstrumentor *instrumentor = nullptr, + OpPassManager &pm, Operation *op, AnalysisManager am, bool verifyPasses, + unsigned parentInitGeneration, PassInstrumentor *instrumentor = nullptr, const PassInstrumentation::PipelineParentInfo *parentInfo = nullptr); /// A set of adaptors to run. 
diff --git a/mlir/lib/Pass/PassRegistry.cpp b/mlir/lib/Pass/PassRegistry.cpp
--- a/mlir/lib/Pass/PassRegistry.cpp
+++ b/mlir/lib/Pass/PassRegistry.cpp
@@ -43,7 +43,7 @@
       return errorHandler(llvm::Twine("Can't add pass '") + pass->getName() +
                           "' restricted to '" + *pass->getOpName() +
                           "' on a PassManager intended to run on '" +
-                          pm.getOpName() + "', did you intend to nest?");
+                          pm.getOpAnchorName() + "', did you intend to nest?");
     pm.addPass(std::move(pass));
     return result;
   };
diff --git a/mlir/lib/Pass/PassStatistics.cpp b/mlir/lib/Pass/PassStatistics.cpp
--- a/mlir/lib/Pass/PassStatistics.cpp
+++ b/mlir/lib/Pass/PassStatistics.cpp
@@ -120,7 +120,7 @@
   // Print each of the children passes.
   for (OpPassManager &mgr : mgrs) {
-    auto name = ("'" + mgr.getOpName() + "' Pipeline").str();
+    auto name = ("'" + mgr.getOpAnchorName() + "' Pipeline").str();
     printPassEntry(os, indent, name);
     for (Pass &pass : mgr.getPasses())
       printPass(indent + 2, &pass);
diff --git a/mlir/lib/Pass/PassTiming.cpp b/mlir/lib/Pass/PassTiming.cpp
--- a/mlir/lib/Pass/PassTiming.cpp
+++ b/mlir/lib/Pass/PassTiming.cpp
@@ -52,7 +52,7 @@
   // Pipeline
   //===--------------------------------------------------------------------===//
-  void runBeforePipeline(StringAttr name,
+  void runBeforePipeline(Optional<OperationName> name,
                          const PipelineParentInfo &parentInfo) override {
     auto tid = llvm::get_threadid();
     auto &activeTimers = activeThreadTimers[tid];
@@ -68,12 +68,15 @@
     } else {
       parentScope = &activeTimers.back();
     }
-    activeTimers.push_back(parentScope->nest(name.getAsOpaquePointer(), [name] {
-      return ("'" + name.strref() + "' Pipeline").str();
+
+    const void *timerId = name ? name->getAsOpaquePointer() : nullptr;
+    activeTimers.push_back(parentScope->nest(timerId, [name] {
+      return ("'" + (name ?
name->getStringRef() : "any") + "' Pipeline").str();
     }));
   }
-  void runAfterPipeline(StringAttr, const PipelineParentInfo &) override {
+  void runAfterPipeline(Optional<OperationName>,
+                        const PipelineParentInfo &) override {
     auto &activeTimers = activeThreadTimers[llvm::get_threadid()];
     assert(!activeTimers.empty() && "expected active timer");
     activeTimers.pop_back();
diff --git a/mlir/lib/Transforms/Inliner.cpp b/mlir/lib/Transforms/Inliner.cpp
--- a/mlir/lib/Transforms/Inliner.cpp
+++ b/mlir/lib/Transforms/Inliner.cpp
@@ -734,6 +734,7 @@
     return failure();
   // Initialize the default pipeline builder to use the option string.
+  // TODO: Use a generic pass manager for default pipelines, and remove this.
   if (!defaultPipelineStr.empty()) {
     std::string defaultPipelineCopy = defaultPipelineStr;
     defaultPipeline = [=](OpPassManager &pm) {
@@ -747,7 +748,7 @@
   llvm::StringMap<OpPassManager> pipelines;
   for (OpPassManager pipeline : opPipelineList)
     if (!pipeline.empty())
-      pipelines.try_emplace(pipeline.getOpName(), pipeline);
+      pipelines.try_emplace(pipeline.getOpAnchorName(), pipeline);
   opPipelines.assign({std::move(pipelines)});
   return success();
diff --git a/mlir/test/Pass/generic-pipeline.mlir b/mlir/test/Pass/generic-pipeline.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Pass/generic-pipeline.mlir
@@ -0,0 +1,24 @@
+// RUN: mlir-opt %s -verify-diagnostics -pass-pipeline='any(cse, test-interface-pass)' -allow-unregistered-dialect -o /dev/null
+
+// Test that we execute generic pipelines correctly. The `cse` pass is fully generic and should execute
+// on both the module and the func. The `test-interface-pass` filters based on FunctionOpInterface and
+// should only execute on the func.
+
+// expected-remark@below {{Executing interface pass on operation}}
+func @main() -> (i1, i1) {
+  // CHECK-LABEL: func @main
+  // CHECK-NEXT: arith.constant true
+  // CHECK-NEXT: return
+  %true = arith.constant true
+  %true1 = arith.constant true
+  return %true, %true1 : i1, i1
+}
+
+module @module {
+  // CHECK-LABEL: module @module
+  // CHECK-NEXT: arith.constant true
+  // CHECK-NEXT: foo.op
+  %true = arith.constant true
+  %true1 = arith.constant true
+  "foo.op"(%true, %true1) : (i1, i1) -> ()
+}
diff --git a/mlir/test/Pass/pipeline-parsing.mlir b/mlir/test/Pass/pipeline-parsing.mlir
--- a/mlir/test/Pass/pipeline-parsing.mlir
+++ b/mlir/test/Pass/pipeline-parsing.mlir
@@ -1,16 +1,19 @@
 // RUN: mlir-opt %s -mlir-disable-threading -pass-pipeline='builtin.module(test-module-pass,func.func(test-function-pass)),func.func(test-function-pass)' -pass-pipeline="func.func(cse,canonicalize)" -verify-each=false -mlir-timing -mlir-timing-display=tree 2>&1 | FileCheck %s
 // RUN: mlir-opt %s -mlir-disable-threading -test-textual-pm-nested-pipeline -verify-each=false -mlir-timing -mlir-timing-display=tree 2>&1 | FileCheck %s --check-prefix=TEXTUAL_CHECK
+// RUN: mlir-opt %s -mlir-disable-threading -pass-pipeline='builtin.module(test-module-pass),any(test-interface-pass),any(test-interface-pass),func.func(test-function-pass),any(canonicalize),func.func(cse)' -verify-each=false -mlir-timing -mlir-timing-display=tree 2>&1 | FileCheck %s --check-prefix=GENERIC_MERGE_CHECK
 // RUN: not mlir-opt %s -pass-pipeline='builtin.module(test-module-pass' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_1 %s
 // RUN: not mlir-opt %s -pass-pipeline='builtin.module(test-module-pass))' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_2 %s
 // RUN: not mlir-opt %s -pass-pipeline='builtin.module()(' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_3 %s
 // RUN: not mlir-opt %s -pass-pipeline=',' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_4 %s
 // RUN: not mlir-opt %s -pass-pipeline='func.func(test-module-pass)' 2>&1 |
FileCheck --check-prefix=CHECK_ERROR_5 %s
+// RUN: not mlir-opt %s -pass-pipeline='any(test-module-pass)' 2>&1 | FileCheck --check-prefix=CHECK_ERROR_6 %s
 // CHECK_ERROR_1: encountered unbalanced parentheses while parsing pipeline
 // CHECK_ERROR_2: encountered extra closing ')' creating unbalanced parentheses while parsing pipeline
 // CHECK_ERROR_3: expected ',' after parsing pipeline
 // CHECK_ERROR_4: does not refer to a registered pass or pass pipeline
 // CHECK_ERROR_5: Can't add pass '{{.*}}TestModulePass' restricted to 'builtin.module' on a PassManager intended to run on 'func.func', did you intend to nest?
+// CHECK_ERROR_6: Can't add pass '{{.*}}TestModulePass' restricted to 'builtin.module' on a PassManager intended to run on 'any', did you intend to nest?
 func @foo() {
   return
 }
@@ -39,3 +42,19 @@
 // TEXTUAL_CHECK-NEXT: TestModulePass
 // TEXTUAL_CHECK-NEXT: 'func.func' Pipeline
 // TEXTUAL_CHECK-NEXT: TestFunctionPass
+
+// Check that generic pass pipelines are only merged when they aren't
+// going to overlap with op-specific pipelines.
+// GENERIC_MERGE_CHECK: Pipeline Collection : ['builtin.module', 'any']
+// GENERIC_MERGE_CHECK-NEXT: 'any' Pipeline
+// GENERIC_MERGE_CHECK-NEXT: (anonymous namespace)::TestInterfacePass
+// GENERIC_MERGE_CHECK-NEXT: (anonymous namespace)::TestInterfacePass
+// GENERIC_MERGE_CHECK-NEXT: 'builtin.module' Pipeline
+// GENERIC_MERGE_CHECK-NEXT: (anonymous namespace)::TestModulePass
+// GENERIC_MERGE_CHECK-NEXT: 'func.func' Pipeline
+// GENERIC_MERGE_CHECK-NEXT: (anonymous namespace)::TestFunctionPass
+// GENERIC_MERGE_CHECK-NEXT: 'any' Pipeline
+// GENERIC_MERGE_CHECK-NEXT: Canonicalizer
+// GENERIC_MERGE_CHECK-NEXT: 'func.func' Pipeline
+// GENERIC_MERGE_CHECK-NEXT: CSE
+// GENERIC_MERGE_CHECK-NEXT: (A) DominanceInfo