diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h --- a/mlir/include/mlir/IR/Diagnostics.h +++ b/mlir/include/mlir/IR/Diagnostics.h @@ -45,6 +45,13 @@ Remark, }; +/// Determine if this diagnostic will group the following diagnostics for +/// handling. +enum class DiagnosticKind { + Single, + Group, +}; + //===----------------------------------------------------------------------===// // DiagnosticArgument //===----------------------------------------------------------------------===// @@ -154,31 +161,34 @@ /// to the DiagnosticEngine. It should generally not be constructed directly, /// and instead used transitively via InFlightDiagnostic. class Diagnostic { - using NoteVector = std::vector>; + using DiagVector = std::vector>; - /// This class implements a wrapper iterator around NoteVector::iterator to + /// This class implements a wrapper iterator around DiagVector::iterator to /// implicitly dereference the unique_ptr. - template - class NoteIteratorImpl - : public llvm::mapped_iterator { - static ResultTy &unwrap(NotePtrTy note) { return *note; } + class DiagIteratorImpl + : public llvm::mapped_iterator { + static ResultTy &unwrap(DiagPtrTy note) { return *note; } public: - NoteIteratorImpl(IteratorTy it) - : llvm::mapped_iterator(it, + DiagIteratorImpl(IteratorTy it) + : llvm::mapped_iterator(it, &unwrap) {} }; public: Diagnostic(Location loc, DiagnosticSeverity severity) - : loc(loc), severity(severity) {} + : loc(loc), severity(severity), kind(DiagnosticKind::Single) {} Diagnostic(Diagnostic &&) = default; Diagnostic &operator=(Diagnostic &&) = default; /// Returns the severity of this diagnostic. DiagnosticSeverity getSeverity() const { return severity; } + /// Returns the kind of this diagnostic. + DiagnosticKind getKind() const { return kind; } + /// Returns the source location for this diagnostic. Location getLocation() const { return loc; } @@ -265,30 +275,54 @@ /// diagnostic. Notes may not be attached to other notes. Diagnostic &attachNote(Optional noteLoc = llvm::None); - using note_iterator = NoteIteratorImpl; - using const_note_iterator = NoteIteratorImpl; + /// Addpend the diagnostic which is created after this diagnositc. + void appendNestedDiagnostic(Diagnostic nested); + + using diag_iterator = DiagIteratorImpl; + using const_diag_iterator = DiagIteratorImpl; /// Returns the notes held by this diagnostic. - iterator_range getNotes() { + iterator_range getNotes() { return {notes.begin(), notes.end()}; } - iterator_range getNotes() const { + iterator_range getNotes() const { return {notes.begin(), notes.end()}; } + /// Check if it has nested diagnostics. + bool hasNested() const { return nested.hasValue(); } + + /// Returns the nested diagnostics grouped with this diagnostic. + iterator_range getNested() { + assert(hasNested()); + return {nested->begin(), nested->end()}; + } + iterator_range getNested() const { + assert(hasNested()); + return {nested->begin(), nested->end()}; + } + /// Allow a diagnostic to be converted to 'failure'. operator LogicalResult() const; private: + Diagnostic(Location loc, DiagnosticSeverity severity, DiagnosticKind kind) + : loc(loc), severity(severity), kind(kind) {} Diagnostic(const Diagnostic &rhs) = delete; Diagnostic &operator=(const Diagnostic &rhs) = delete; + // Allow access to the constructor which can assign the diagnostic kind. + friend DiagnosticEngine; + /// The source location. Location loc; /// The severity of this diagnostic. DiagnosticSeverity severity; + /// The kind of this diagnostic. + DiagnosticKind kind; + /// The current list of arguments. SmallVector arguments; @@ -297,7 +331,11 @@ std::vector> strings; /// A list of attached notes. - NoteVector notes; + DiagVector notes; + + /// A list of diagnostics that are nested under this. This is used when it's + /// DiagnosticsKind::Group. + Optional nested; }; inline raw_ostream &operator<<(raw_ostream &os, const Diagnostic &diag) { @@ -437,11 +475,19 @@ /// Erase the registered diagnostic handler with the given identifier. void eraseHandler(HandlerID id); - /// Create a new inflight diagnostic with the given location and severity. - InFlightDiagnostic emit(Location loc, DiagnosticSeverity severity) { + /// Create a new inflight diagnostic with the given location, severity and a + /// default kind. + InFlightDiagnostic emit(Location loc, DiagnosticSeverity severity, + DiagnosticKind kind) { assert(severity != DiagnosticSeverity::Note && "notes should not be emitted directly"); - return InFlightDiagnostic(this, Diagnostic(loc, severity)); + InFlightDiagnostic inFlight(this, Diagnostic(loc, severity, kind)); + if (kind == DiagnosticKind::Group) { + assert(severity == DiagnosticSeverity::Remark && + "Group can be only Remark kind"); + notifyGroup(*inFlight.impl); + } + return inFlight; } /// Emit a diagnostic using the registered issue handler if present, or with @@ -449,6 +495,8 @@ void emit(Diagnostic diag); private: + void notifyGroup(Diagnostic &group); + friend class MLIRContextImpl; DiagnosticEngine(); @@ -468,6 +516,13 @@ InFlightDiagnostic emitRemark(Location loc); InFlightDiagnostic emitRemark(Location loc, const Twine &message); +/// Utility method to emit a remark group using this location. The following +/// diagnostics will be attached to this group diagnostic. Note that this will +/// delay the handling of a reported diagnostic. Please ensure the side effect +/// before using it. +InFlightDiagnostic emitRemarkGroup(Location loc); +InFlightDiagnostic emitRemarkGroup(Location loc, const Twine &message); + /// Overloads of the above emission functions that take an optionally null /// location. If the location is null, no diagnostic is emitted and a failure is /// returned. Given that the provided location may be null, these methods take @@ -526,6 +581,43 @@ MLIRContext *ctx; }; +//===----------------------------------------------------------------------===// +// ScopedDiagInspectorHandler +//===----------------------------------------------------------------------===// + +/// This is a RAII class for inspecting the diagnostic. It can be used to +/// collect the information at a certain program scope. It's mainly for +/// debugging purpose. +class ScopedDiagInspectorHandler { +public: + /// An inspector is only supposed to read the diagnostic so it shouldn't + /// update it. + using InspectorTy = std::function; + + ScopedDiagInspectorHandler(MLIRContext *ctx, InspectorTy inspector) + : ctx(ctx), inspector(inspector) { + auto &diagEngine = ctx->getDiagEngine(); + handlerID = + diagEngine.registerHandler([this](Diagnostic &diag) -> LogicalResult { + (void)this->inspector(diag); + // Inspector is supposed to only read the diagnostic, return failure + // to let other handlers do their job. + return failure(); + }); + } + ~ScopedDiagInspectorHandler(); + +private: + /// The context to erase the handler from. + MLIRContext *ctx; + + /// The inspector registered by the user. + InspectorTy inspector; + + /// The handler id returned by the DiagEngine. + DiagnosticEngine::HandlerID handlerID; +}; + //===----------------------------------------------------------------------===// // SourceMgrDiagnosticHandler //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -22,6 +22,7 @@ #include "llvm/Support/Regex.h" #include "llvm/Support/Signals.h" #include "llvm/Support/SourceMgr.h" +#include "llvm/Support/ThreadLocal.h" #include "llvm/Support/raw_ostream.h" using namespace mlir; @@ -177,6 +178,12 @@ return *notes.back(); } +void Diagnostic::appendNestedDiagnostic(Diagnostic diag) { + if (!nested.hasValue()) + nested.emplace(); + nested->emplace_back(std::make_unique(std::move(diag))); +} + /// Allow a diagnostic to be converted to 'failure'. Diagnostic::operator LogicalResult() const { return failure(); } @@ -202,7 +209,11 @@ } /// Abandons this diagnostic. -void InFlightDiagnostic::abandon() { owner = nullptr; } +void InFlightDiagnostic::abandon() { + assert((!impl.hasValue() || impl->getKind() != DiagnosticKind::Group) && + "Can't abandon group diagnostic"); + owner = nullptr; +} //===----------------------------------------------------------------------===// // DiagnosticEngineImpl @@ -226,6 +237,9 @@ /// This is a unique identifier counter for diagnostic handlers in the /// context. This id starts at 1 to allow for 0 to be used as a sentinel. DiagnosticEngine::HandlerID uniqueHandlerId = 1; + + /// The group diagnostic list. + llvm::sys::ThreadLocal group; }; } // namespace detail } // namespace mlir @@ -233,6 +247,15 @@ /// Emit a diagnostic using the registered issue handle if present, or with /// the default behavior if not. void DiagnosticEngineImpl::emit(Diagnostic diag) { + // If there is a group diagnostic, nested the diag to it. All the nested + // diagnostics will be handled along with the group diagnostic. + if (diag.getKind() == DiagnosticKind::Group) { + group.erase(); + } else if (group.get() != nullptr) { + group.get()->appendNestedDiagnostic(std::move(diag)); + return; + } + llvm::sys::SmartScopedLock lock(mutex); // Try to process the given diagnostic on one of the registered handlers. @@ -253,6 +276,12 @@ // The default behavior for errors is to emit them to stderr. os << diag << '\n'; + + // Emit nested diagnostics. + if (diag.hasNested()) + for (auto &nested : diag.getNested()) + emit(std::move(nested)); + os.flush(); } @@ -279,6 +308,14 @@ impl->handlers.erase(handlerID); } +void DiagnosticEngine::notifyGroup(Diagnostic &group) { + // `group` is thread-local thus we don't need the lock. + + // Currently don't support nested group. + assert(impl->group.get() == nullptr); + impl->group.set(&group); +} + /// Emit a diagnostic using the registered issue handler if present, or with /// the default behavior if not. void DiagnosticEngine::emit(Diagnostic diag) { @@ -290,11 +327,12 @@ /// Helper function used to emit a diagnostic with an optionally empty twine /// message. If the message is empty, then it is not inserted into the /// diagnostic. -static InFlightDiagnostic -emitDiag(Location location, DiagnosticSeverity severity, const Twine &message) { +static InFlightDiagnostic emitDiag(Location location, + DiagnosticSeverity severity, + DiagnosticKind kind, const Twine &message) { MLIRContext *ctx = location->getContext(); auto &diagEngine = ctx->getDiagEngine(); - auto diag = diagEngine.emit(location, severity); + auto diag = diagEngine.emit(location, severity, kind); if (!message.isTriviallyEmpty()) diag << message; @@ -315,7 +353,8 @@ /// Emit an error message using this location. InFlightDiagnostic mlir::emitError(Location loc) { return emitError(loc, {}); } InFlightDiagnostic mlir::emitError(Location loc, const Twine &message) { - return emitDiag(loc, DiagnosticSeverity::Error, message); + return emitDiag(loc, DiagnosticSeverity::Error, DiagnosticKind::Single, + message); } /// Emit a warning message using this location. @@ -323,7 +362,8 @@ return emitWarning(loc, {}); } InFlightDiagnostic mlir::emitWarning(Location loc, const Twine &message) { - return emitDiag(loc, DiagnosticSeverity::Warning, message); + return emitDiag(loc, DiagnosticSeverity::Warning, DiagnosticKind::Single, + message); } /// Emit a remark message using this location. @@ -331,7 +371,18 @@ return emitRemark(loc, {}); } InFlightDiagnostic mlir::emitRemark(Location loc, const Twine &message) { - return emitDiag(loc, DiagnosticSeverity::Remark, message); + return emitDiag(loc, DiagnosticSeverity::Remark, DiagnosticKind::Single, + message); +} + +/// Emit a remark group using this location. The following diagnostics will be +/// attached to this group diagnostic. +InFlightDiagnostic mlir::emitRemarkGroup(Location loc) { + return emitRemarkGroup(loc, {}); +} +InFlightDiagnostic mlir::emitRemarkGroup(Location loc, const Twine &message) { + return emitDiag(loc, DiagnosticSeverity::Remark, DiagnosticKind::Group, + message); } //===----------------------------------------------------------------------===// @@ -343,6 +394,13 @@ ctx->getDiagEngine().eraseHandler(handlerID); } +//===----------------------------------------------------------------------===// +// ScopedDiagInspectorHandler +//===----------------------------------------------------------------------===// +ScopedDiagInspectorHandler::~ScopedDiagInspectorHandler() { + ctx->getDiagEngine().eraseHandler(handlerID); +} + //===----------------------------------------------------------------------===// // SourceMgrDiagnosticHandler //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp --- a/mlir/lib/Transforms/Utils/DialectConversion.cpp +++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp @@ -1482,6 +1482,8 @@ Diagnostic diag(loc, DiagnosticSeverity::Remark); reasonCallback(diag); logger.startLine() << "** Failure : " << diag.str() << "\n"; + // Emit to the DiagEngine so that the inspector can get notified. + loc->getContext()->getDiagEngine().emit(std::move(diag)); }); return failure(); } diff --git a/mlir/test/IR/diagnostic-inspector.mlir b/mlir/test/IR/diagnostic-inspector.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/diagnostic-inspector.mlir @@ -0,0 +1,8 @@ +// RUN: mlir-opt -test-diagnostic-inspector %s +// This test verifies that diagnostic inspector can capture the diagnostics. + +func @test_diagnostic_inspector() { + "test.dummy_0"() : () -> () + "test.dummy_1"() : () -> () + return +} diff --git a/mlir/test/lib/IR/TestDiagnostics.cpp b/mlir/test/lib/IR/TestDiagnostics.cpp --- a/mlir/test/lib/IR/TestDiagnostics.cpp +++ b/mlir/test/lib/IR/TestDiagnostics.cpp @@ -11,12 +11,60 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/Diagnostics.h" #include "mlir/Pass/Pass.h" +#include "llvm/ADT/SmallVector.h" #include "llvm/Support/SourceMgr.h" +#include + using namespace mlir; namespace { +struct TestDiagnosticInspectorPass + : public PassWrapper> { + StringRef getArgument() const final { return "test-diagnostic-inspector"; } + StringRef getDescription() const final { + return "Test diagnostic inspector support."; + } + + TestDiagnosticInspectorPass() {} + TestDiagnosticInspectorPass(const TestDiagnosticInspectorPass &) {} + + void collectOpNames(llvm::SmallVector &opNames, Operation *op) { + InFlightDiagnostic groupDiag = + emitRemarkGroup(getOperation()->getLoc(), "Test OpName Group"); + op->walk([&](Operation *op) { + // Emit the operation name for inspector. + op->emitRemark() << op->getName().getStringRef(); + }); + } + + void runOnOperation() override { + llvm::SmallVector opNames; + ScopedDiagInspectorHandler inspector( + &getContext(), [&](const Diagnostic &diag) -> LogicalResult { + // All the diagnostics will be grouped into one, thus it's supposed + // to have only one diagnostic reported. + assert(opNames.empty()); + assert(diag.getSeverity() == DiagnosticSeverity::Remark); + assert(diag.hasNested()); + for (auto &nested : diag.getNested()) + opNames.push_back(nested.str()); + return success(); + }); + + collectOpNames(opNames, getOperation()); + + // Verify the operation names are captured by the inspector. + auto nameIter = opNames.begin(); + getOperation()->walk([&](Operation *op) { + assert(op->getName().getStringRef().equals(*nameIter)); + ++nameIter; + }); + } +}; + struct TestDiagnosticFilterPass : public PassWrapper> { StringRef getArgument() const final { return "test-diagnostic-filter"; } @@ -63,6 +111,7 @@ namespace test { void registerTestDiagnosticsPass() { PassRegistration{}; + PassRegistration{}; } } // namespace test } // namespace mlir