diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h
--- a/mlir/include/mlir/Pass/Pass.h
+++ b/mlir/include/mlir/Pass/Pass.h
@@ -184,10 +184,18 @@
   /// pipeline won't execute.
   virtual LogicalResult initialize(MLIRContext *context) { return success(); }
 
-  /// Indicate if the current pass can be scheduled on the given operation type.
-  /// This is useful for generic operation passes to add restrictions on the
+  /// Indicate if this pass can be scheduled on the given operation type. This
+  /// is useful for generic operation passes to add restrictions on the
   /// operations they operate on.
-  virtual bool canScheduleOn(RegisteredOperationName opName) const = 0;
+  ///
+  /// By default, this pass can be scheduled only on ops whose name matches the
+  /// name that was specified in the constructor. If no op name was specified,
+  /// this pass can be scheduled on any op.
+  virtual bool canScheduleOn(RegisteredOperationName opName) const {
+    if (!getOpName().has_value())
+      return true;
+    return opName.getStringRef() == *getOpName();
+  }
 
   /// Schedule an arbitrary pass pipeline on the provided operation.
   /// This can be invoke any time in a pass to dynamic schedule more passes.
@@ -299,11 +307,11 @@
   virtual void anchor();
 
   /// Represents a unique identifier for the pass.
-  TypeID passID;
+  const TypeID passID;
 
   /// The name of the operation that this pass operates on, or std::nullopt if
   /// this is a generic OperationPass.
-  std::optional<StringRef> opName;
+  const std::optional<StringRef> opName;
 
   /// The current execution state for the pass.
   std::optional<detail::PassExecutionState> passState;
@@ -358,11 +366,6 @@
     return pass->getOpName() == OpT::getOperationName();
   }
 
-  /// Indicate if the current pass can be scheduled on the given operation type.
-  bool canScheduleOn(RegisteredOperationName opName) const final {
-    return opName.getStringRef() == getOpName();
-  }
-
   /// Return the current operation being transformed.
   OpT getOperation() { return cast<OpT>(Pass::getOperation()); }
 
@@ -391,12 +394,6 @@
 protected:
   OperationPass(TypeID passID) : Pass(passID) {}
   OperationPass(const OperationPass &) = default;
-
-  /// Indicate if the current pass can be scheduled on the given operation type.
-  /// By default, generic operation passes can be scheduled on any operation.
-  bool canScheduleOn(RegisteredOperationName opName) const override {
-    return true;
-  }
 };
 
 /// Pass to transform an operation that implements the given interface.
diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp
--- a/mlir/unittests/Pass/PassManagerTest.cpp
+++ b/mlir/unittests/Pass/PassManagerTest.cpp
@@ -94,9 +94,6 @@
     assert(getOperation()->getName().getStringRef() == "invalid_op" &&
            "incorrect op type");
   }
-  bool canScheduleOn(RegisteredOperationName opName) const override {
-    return opName.getStringRef() == "invalid_op";
-  }
 
   /// A clone method to create a copy of this pass.
   std::unique_ptr<Pass> clonePass() const override {