diff --git a/mlir/lib/Pass/Pass.cpp b/mlir/lib/Pass/Pass.cpp --- a/mlir/lib/Pass/Pass.cpp +++ b/mlir/lib/Pass/Pass.cpp @@ -438,6 +438,9 @@ if (!opInfo->hasTrait<OpTrait::IsIsolatedFromAbove>()) return op->emitOpError() << "trying to schedule a pass on an operation not " "marked as 'IsolatedFromAbove'"; + if (!pass->canScheduleOn(*opInfo)) + return op->emitOpError() + << "trying to schedule a pass on an unsupported operation"; // Initialize the pass state with a callback for the pass to dynamically // execute a pipeline on the currently visited operation. diff --git a/mlir/test/Dialect/Transform/test-pass-application.mlir b/mlir/test/Dialect/Transform/test-pass-application.mlir --- a/mlir/test/Dialect/Transform/test-pass-application.mlir +++ b/mlir/test/Dialect/Transform/test-pass-application.mlir @@ -70,3 +70,22 @@ %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op transform.apply_registered_pass "canonicalize" to %1 {options = "top-down=false"} : (!transform.any_op) -> !transform.any_op } + +// ----- + +module { + // expected-error @below {{trying to schedule a pass on an unsupported operation}} + // expected-note @below {{target op}} + func.func @invalid_target_op_type() { + return + } + + transform.sequence failures(propagate) { + ^bb1(%arg1: !transform.any_op): + %1 = transform.structured.match ops{["func.func"]} in %arg1 : (!transform.any_op) -> !transform.any_op + + // func-bufferize can be applied only to ModuleOps. 
+ // expected-error @below {{pass pipeline failed}} + transform.apply_registered_pass "func-bufferize" to %1 : (!transform.any_op) -> !transform.any_op + } +} diff --git a/mlir/unittests/Pass/PassManagerTest.cpp b/mlir/unittests/Pass/PassManagerTest.cpp --- a/mlir/unittests/Pass/PassManagerTest.cpp +++ b/mlir/unittests/Pass/PassManagerTest.cpp @@ -90,9 +90,12 @@ InvalidPass() : Pass(TypeID::get<InvalidPass>(), StringRef("invalid_op")) {} StringRef getName() const override { return "Invalid Pass"; } - void runOnOperation() override {} + void runOnOperation() override { + assert(getOperation()->getName().getStringRef() == "invalid_op" && + "incorrect op type"); + } bool canScheduleOn(RegisteredOperationName opName) const override { - return true; + return opName.getStringRef() == "invalid_op"; } /// A clone method to create a copy of this pass. @@ -137,6 +140,18 @@ result = pm.run(module.get()); EXPECT_TRUE(succeeded(result)); + // Create a pass manager that schedules arbitrary ops, but then run a pass + // that expects a specific op. + PassManager anyOpPm(&context); + // Adding the pass succeeds, because we do not know yet on what op the + // pass will be scheduled. + anyOpPm.addPass(std::make_unique<InvalidPass>()); + result = anyOpPm.run(module.get()); + EXPECT_TRUE(failed(result)); + ASSERT_TRUE(diagnostic.get() != nullptr); + EXPECT_EQ(diagnostic->str(), "'builtin.module' op trying to schedule a pass " + "on an unsupported operation"); + // Check that adding the pass at the top-level triggers a fatal error. ASSERT_DEATH(pm.addPass(std::make_unique<InvalidPass>()), "Can't add pass 'Invalid Pass' restricted to 'invalid_op' on a "