diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h
--- a/mlir/include/mlir/IR/Diagnostics.h
+++ b/mlir/include/mlir/IR/Diagnostics.h
@@ -523,6 +523,59 @@
   MLIRContext *ctx;
 };
 
+//===----------------------------------------------------------------------===//
+// ScopedDiagInspectorHandler
+//===----------------------------------------------------------------------===//
+
+/// This is a RAII class for inspecting diagnostics. While an instance is
+/// alive, its handler calls the user-provided inspector on each diagnostic it
+/// receives from the context's DiagnosticEngine and then reports failure, so
+/// the diagnostic continues on to the remaining handlers. It can be used to
+/// collect information within a certain program scope and is mainly intended
+/// for debugging purposes.
+///
+/// Note that the handler is registered on the context-wide DiagnosticEngine
+/// and nothing restricts it to the current scope or thread, so it can observe
+/// diagnostics emitted by any code sharing the context while it is alive.
+class ScopedDiagInspectorHandler {
+public:
+  /// An inspector is only supposed to read the diagnostic so it shouldn't
+  /// update it. Its result is ignored: the diagnostic is always forwarded to
+  /// the next handler.
+  using InspectorTy = std::function<LogicalResult(const Diagnostic &)>;
+
+  ScopedDiagInspectorHandler(MLIRContext *ctx, InspectorTy inspector)
+      : ctx(ctx), inspector(std::move(inspector)) {
+    auto &diagEngine = ctx->getDiagEngine();
+    handlerID =
+        diagEngine.registerHandler([this](Diagnostic &diag) -> LogicalResult {
+          (void)this->inspector(diag);
+          // Inspector is supposed to only read the diagnostic, return failure
+          // to let other handlers do their job.
+          return failure();
+        });
+  }
+  ~ScopedDiagInspectorHandler();
+
+  // The registered handler captures `this`, so instances must not be copied:
+  // a copy would share the handler ID (whichever instance dies first would
+  // unregister it for both), and copy-assignment would drop the ID of a
+  // handler that still captures the destination's address.
+  ScopedDiagInspectorHandler(const ScopedDiagInspectorHandler &) = delete;
+  ScopedDiagInspectorHandler &
+  operator=(const ScopedDiagInspectorHandler &) = delete;
+
+private:
+  /// The context to erase the handler from.
+  MLIRContext *ctx;
+
+  /// The inspector registered by the user.
+  InspectorTy inspector;
+
+  /// The handler id returned by the DiagEngine.
+  DiagnosticEngine::HandlerID handlerID;
+};
+
 //===----------------------------------------------------------------------===//
 // SourceMgrDiagnosticHandler
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -337,6 +337,13 @@
   ctx->getDiagEngine().eraseHandler(handlerID);
 }
 
+//===----------------------------------------------------------------------===//
+// ScopedDiagInspectorHandler
+//===----------------------------------------------------------------------===//
+ScopedDiagInspectorHandler::~ScopedDiagInspectorHandler() {
+  ctx->getDiagEngine().eraseHandler(handlerID);
+}
+
 //===----------------------------------------------------------------------===//
 // SourceMgrDiagnosticHandler
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1298,6 +1298,8 @@
     Diagnostic diag(loc, DiagnosticSeverity::Remark);
     reasonCallback(diag);
     logger.startLine() << "** Failure : " << diag.str() << "\n";
+    // Emit to the DiagEngine so that the inspector can get notified.
+    loc->getContext()->getDiagEngine().emit(std::move(diag));
   });
   return failure();
 }
diff --git a/mlir/test/IR/diagnostic-inspector.mlir b/mlir/test/IR/diagnostic-inspector.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/IR/diagnostic-inspector.mlir
@@ -0,0 +1,8 @@
+// RUN: mlir-opt -test-diagnostic-inspector %s
+// This test verifies that the diagnostic inspector captures emitted diagnostics.
+
+func @test_diagnostic_inspector() {
+  "test.dummy_0"() : () -> ()
+  "test.dummy_1"() : () -> ()
+  return
+}
diff --git a/mlir/test/lib/IR/TestDiagnostics.cpp b/mlir/test/lib/IR/TestDiagnostics.cpp
--- a/mlir/test/lib/IR/TestDiagnostics.cpp
+++ b/mlir/test/lib/IR/TestDiagnostics.cpp
@@ -11,12 +11,63 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/Diagnostics.h"
 #include "mlir/Pass/Pass.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Support/SourceMgr.h"
 
+#include <string>
+
 using namespace mlir;
 
 namespace {
+struct TestDiagnosticInspectorPass
+    : public PassWrapper<TestDiagnosticInspectorPass, OperationPass<FuncOp>> {
+  StringRef getArgument() const final { return "test-diagnostic-inspector"; }
+  StringRef getDescription() const final {
+    return "Test diagnostic inspector support.";
+  }
+
+  TestDiagnosticInspectorPass() {}
+  TestDiagnosticInspectorPass(const TestDiagnosticInspectorPass &) {}
+
+  void runOnOperation() override {
+    llvm::SmallVector<std::string> opNames;
+    {
+      // Record the message of every diagnostic reaching the inspector. It is
+      // scoped so it is unregistered before the verification below runs.
+      ScopedDiagInspectorHandler inspector(
+          &getContext(), [&](const Diagnostic &diag) -> LogicalResult {
+            opNames.push_back(diag.str());
+            return success();
+          });
+
+      getOperation()->walk([&](Operation *op) {
+        // Emit the operation name for the inspector to capture.
+        op->emitRemark() << op->getName().getStringRef();
+      });
+    }
+
+    // Verify that the inspector captured exactly one name per operation, in
+    // walk order. Report a pass failure rather than `assert`, so the check is
+    // not compiled out in release builds, and bound the iterator so a missing
+    // diagnostic cannot read past the end of `opNames`.
+    auto nameIter = opNames.begin();
+    WalkResult result = getOperation()->walk([&](Operation *op) {
+      if (nameIter == opNames.end() ||
+          op->getName().getStringRef() != *nameIter)
+        return WalkResult::interrupt();
+      ++nameIter;
+      return WalkResult::advance();
+    });
+    if (result.wasInterrupted() || nameIter != opNames.end()) {
+      getOperation()->emitError()
+          << "diagnostic inspector did not capture the expected remarks";
+      signalPassFailure();
+    }
+  }
+};
+
 struct TestDiagnosticFilterPass
     : public PassWrapper<TestDiagnosticFilterPass, OperationPass<FuncOp>> {
   StringRef getArgument() const final { return "test-diagnostic-filter"; }
@@ -63,6 +114,7 @@
 namespace test {
 void registerTestDiagnosticsPass() {
   PassRegistration<TestDiagnosticFilterPass>{};
+  PassRegistration<TestDiagnosticInspectorPass>{};
 }
 } // namespace test
 } // namespace mlir