diff --git a/mlir/test/Pass/invalid-ir.mlir b/mlir/test/Pass/invalid-ir.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/Pass/invalid-ir.mlir
@@ -0,0 +1,13 @@
+// RUN: not mlir-opt %s -pass-pipeline='func.func(test-pass-create-invalid-ir{signal-pass-failure=true})' -mlir-print-ir-after-failure -mlir-print-custom-assembly-after-failure 2>&1 | FileCheck %s
+// RUN: not mlir-opt %s -pass-pipeline='func.func(test-pass-create-invalid-ir{signal-pass-failure=false})' -mlir-print-ir-after-failure -mlir-print-custom-assembly-after-failure 2>&1 | FileCheck %s
+
+// Test whether we print generically or not on pass failure.
+// Both runs are expected to fail: the first because the pass signals failure,
+// the second because the op it inserts does not verify. `not` inverts the
+// exit code, and `2>&1` makes the IR dump reach FileCheck whichever stream
+// it is written to.
+// CHECK: IR Dump After {{.*}}Failed
+// CHECK: "test.any_attr_of_i32_str"() : () -> ()
+func.func @TestCreateInvalidCallInPass() {
+  return
+}
diff --git a/mlir/test/lib/Pass/TestPassManager.cpp b/mlir/test/lib/Pass/TestPassManager.cpp
--- a/mlir/test/lib/Pass/TestPassManager.cpp
+++ b/mlir/test/lib/Pass/TestPassManager.cpp
@@ -9,8 +9,11 @@
 #include "TestDialect.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/Location.h"
+#include "mlir/IR/OperationSupport.h"
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
+#include "mlir/Pass/PassRegistry.h"
 
 using namespace mlir;
 
@@ -131,6 +134,37 @@
   }
 };
 
+/// A test pass that creates an invalid operation in a function body.
+struct TestInvalidIRPass
+    : public PassWrapper<TestInvalidIRPass,
+                         InterfacePass<FunctionOpInterface>> {
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestInvalidIRPass)
+
+  TestInvalidIRPass() = default;
+  // Explicitly carry the option value over when the pass is copied.
+  TestInvalidIRPass(const TestInvalidIRPass &other) {
+    signalFailure = other.signalFailure;
+  }
+  StringRef getArgument() const final { return "test-pass-create-invalid-ir"; }
+  StringRef getDescription() const final {
+    return "Test pass that adds an invalid operation in a function body";
+  }
+  void getDependentDialects(DialectRegistry &registry) const final {
+    registry.insert<test::TestDialect>();
+  }
+  /// Inserts an invalid op into the function body; may also signal failure.
+  void runOnOperation() final {
+    OpBuilder b(getOperation().getBody());
+    OperationState state(UnknownLoc::get(&getContext()),
+                         "test.any_attr_of_i32_str");
+    b.create(state);
+    if (signalFailure)
+      signalPassFailure();
+  }
+  Option<bool> signalFailure{*this, "signal-pass-failure",
+                             llvm::cl::desc("Trigger a pass failure as well")};
+};
+
 /// A test pass that contains a statistic.
 struct TestStatisticPass
     : public PassWrapper> {
@@ -180,6 +214,7 @@
   PassRegistration();
   PassRegistration();
   PassRegistration();
+  PassRegistration<TestInvalidIRPass>();
   PassRegistration();