diff --git a/mlir/include/mlir/Pass/Pass.h b/mlir/include/mlir/Pass/Pass.h --- a/mlir/include/mlir/Pass/Pass.h +++ b/mlir/include/mlir/Pass/Pass.h @@ -184,10 +184,18 @@ /// pipeline won't execute. virtual LogicalResult initialize(MLIRContext *context) { return success(); } - /// Indicate if the current pass can be scheduled on the given operation type. - /// This is useful for generic operation passes to add restrictions on the + /// Indicate if this pass can be scheduled on the given operation type. This + /// is useful for generic operation passes to add restrictions on the /// operations they operate on. - virtual bool canScheduleOn(RegisteredOperationName opName) const = 0; + /// + /// By default, this pass can be scheduled only on ops whose name matches the + /// name that was specified at the constructor. If no op name was specified, + /// this pass can be scheduled on any op. + virtual bool canScheduleOn(RegisteredOperationName opName) const { + if (!getOpName().has_value()) + return true; + return opName.getStringRef() == *getOpName(); + } /// Schedule an arbitrary pass pipeline on the provided operation. /// This can be invoke any time in a pass to dynamic schedule more passes. @@ -299,11 +307,11 @@ virtual void anchor(); /// Represents a unique identifier for the pass. - TypeID passID; + const TypeID passID; /// The name of the operation that this pass operates on, or std::nullopt if /// this is a generic OperationPass. - std::optional opName; + const std::optional opName; /// The current execution state for the pass. std::optional passState; @@ -358,11 +366,6 @@ return pass->getOpName() == OpT::getOperationName(); } - /// Indicate if the current pass can be scheduled on the given operation type. - bool canScheduleOn(RegisteredOperationName opName) const final { - return opName.getStringRef() == getOpName(); - } - /// Return the current operation being transformed. 
OpT getOperation() { return cast(Pass::getOperation()); } @@ -391,12 +394,6 @@ protected: OperationPass(TypeID passID) : Pass(passID) {} OperationPass(const OperationPass &) = default; - - /// Indicate if the current pass can be scheduled on the given operation type. - /// By default, generic operation passes can be scheduled on any operation. - bool canScheduleOn(RegisteredOperationName opName) const override { - return true; - } }; /// Pass to transform an operation that implements the given interface. diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp --- a/mlir/unittests/Pass/PassManagerTest.cpp +++ b/mlir/unittests/Pass/PassManagerTest.cpp @@ -90,9 +90,9 @@ InvalidPass() : Pass(TypeID::get(), StringRef("invalid_op")) {} StringRef getName() const override { return "Invalid Pass"; } - void runOnOperation() override {} - bool canScheduleOn(RegisteredOperationName opName) const override { - return true; + void runOnOperation() override { + assert(getOperation()->getName().getStringRef() == "invalid_op" && + "incorrect op type"); } /// A clone method to create a copy of this pass. @@ -137,6 +137,18 @@ result = pm.run(module.get()); EXPECT_TRUE(succeeded(result)); + // Create a pass manager that schedules arbitrary ops, but then run a pass + // that expects a specific op. + PassManager anyOpPm(&context); + // Adding the pass succeeds, because we do not know yet on what op the + // pass will be scheduled. + anyOpPm.addPass(std::make_unique()); + result = anyOpPm.run(module.get()); + EXPECT_TRUE(failed(result)); + ASSERT_TRUE(diagnostic.get() != nullptr); + EXPECT_EQ(diagnostic->str(), "'builtin.module' op trying to schedule a pass " + "on an unsupported operation"); + // Check that adding the pass at the top-level triggers a fatal error. ASSERT_DEATH(pm.addPass(std::make_unique()), "Can't add pass 'Invalid Pass' restricted to 'invalid_op' on a "