diff --git a/mlir/docs/Diagnostics.md b/mlir/docs/Diagnostics.md --- a/mlir/docs/Diagnostics.md +++ b/mlir/docs/Diagnostics.md @@ -243,6 +243,45 @@ SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context); ``` +#### Filtering Locations + +In some situations, a diagnostic may be emitted with a callsite location in a +very deep call stack in which many frames are unrelated to the user source code. +These situations often arise when the user source code is intertwined with that +of a large framework or library. The context of the diagnostic in these cases is +often obfuscated by the unrelated framework source locations. To help alleviate +this obfuscation, the `SourceMgrDiagnosticHandler` provides support for +filtering which locations are shown to the user. To enable filtering, a user +must simply provide a filter function to the `SourceMgrDiagnosticHandler` on +construction that indicates which locations should be shown. A quick example is +shown below: + +```c++ +// Here we define the functor that controls which locations are shown to the +// user. This functor should return true when a location should be shown, and +// false otherwise. When filtering a container location, such as a NameLoc, this +// function should not recurse into the child location. Recursion into nested +// locations is performed as necessary by the caller. +auto shouldShowFn = [](Location loc) -> bool { + FileLineColLoc fileLoc = loc.dyn_cast<FileLineColLoc>(); + + // We don't perform any filtering on non-file locations. + // Reminder: The caller will recurse into any necessary child locations. + if (!fileLoc) + return true; + + // Don't show file locations that contain our framework code.
+ return !fileLoc.getFilename().strref().contains("my/framework/source/"); +}; + +SourceMgr sourceMgr; +MLIRContext context; +SourceMgrDiagnosticHandler sourceMgrHandler(sourceMgr, &context, shouldShowFn); +``` + +Note: In the case where all locations are filtered out, the first location in +the stack will still be shown. + ### SourceMgr Diagnostic Verifier Handler This handler is a wrapper around a llvm::SourceMgr that is used to verify that diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h --- a/mlir/include/mlir/IR/Diagnostics.h +++ b/mlir/include/mlir/IR/Diagnostics.h @@ -530,9 +530,20 @@ /// This class is a utility diagnostic handler for use with llvm::SourceMgr. class SourceMgrDiagnosticHandler : public ScopedDiagnosticHandler { public: + /// This type represents a functor used to filter out locations when printing + /// a diagnostic. It should return true if the provided location is okay to + /// display, false otherwise. If all locations in a diagnostic are filtered + /// out, the first location is used as the sole location. When deciding + /// whether or not to filter a location, this function should not recurse into + /// any nested location. This recursion is handled automatically by the + /// caller. + using ShouldShowLocFn = llvm::unique_function<bool(Location)>; + + SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, MLIRContext *ctx, + raw_ostream &os, + ShouldShowLocFn &&shouldShowLocFn = {}); SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, MLIRContext *ctx, - raw_ostream &os); - SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, MLIRContext *ctx); + ShouldShowLocFn &&shouldShowLocFn = {}); ~SourceMgrDiagnosticHandler(); /// Emit the given diagnostic information with the held source manager. @@ -547,12 +558,19 @@ /// available. const llvm::MemoryBuffer *getBufferForFile(StringRef filename); + /// Return true if the given location should be shown, false otherwise.
+ bool shouldShowLocation(Location loc); + /// The source manager that we are wrapping. llvm::SourceMgr &mgr; /// The output stream to use when printing diagnostics. raw_ostream &os; + /// A functor used when determining if a location for a diagnostic should be + /// shown. If null, all locations should be shown. + ShouldShowLocFn shouldShowLocFn; + private: /// Convert a location into the given memory buffer into an SMLoc. llvm::SMLoc convertLocToSMLoc(FileLineColLoc loc); diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -16,6 +16,7 @@ #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallString.h" #include "llvm/ADT/StringMap.h" +#include "llvm/ADT/TypeSwitch.h" #include "llvm/Support/Mutex.h" #include "llvm/Support/PrettyStackTrace.h" #include "llvm/Support/Regex.h" @@ -409,17 +410,19 @@ llvm_unreachable("Unknown DiagnosticSeverity"); } -SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, - MLIRContext *ctx, - raw_ostream &os) +SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler( + llvm::SourceMgr &mgr, MLIRContext *ctx, raw_ostream &os, + ShouldShowLocFn &&shouldShowLocFn) : ScopedDiagnosticHandler(ctx), mgr(mgr), os(os), + shouldShowLocFn(std::move(shouldShowLocFn)), impl(new SourceMgrDiagnosticHandlerImpl()) { setHandler([this](Diagnostic &diag) { emitDiagnostic(diag); }); } -SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, - MLIRContext *ctx) - : SourceMgrDiagnosticHandler(mgr, ctx, llvm::errs()) {} +SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler( + llvm::SourceMgr &mgr, MLIRContext *ctx, ShouldShowLocFn &&shouldShowLocFn) + : SourceMgrDiagnosticHandler(mgr, ctx, llvm::errs(), + std::move(shouldShowLocFn)) {} SourceMgrDiagnosticHandler::~SourceMgrDiagnosticHandler() {} @@ -460,17 +463,23 @@ /// Emit the given diagnostic with the held source manager.
void SourceMgrDiagnosticHandler::emitDiagnostic(Diagnostic &diag) { - // Emit the diagnostic. + SmallVector<std::pair<Location, StringRef>> locationStack; + auto addLocToStack = [&](Location loc, StringRef locContext) { + if (shouldShowLocation(loc)) + locationStack.emplace_back(loc, locContext); + }; + + // Add locations to display for this diagnostic. Location loc = diag.getLocation(); - emitDiagnostic(loc, diag.str(), diag.getSeverity()); + addLocToStack(loc, /*locContext=*/{}); - // If the diagnostic location was a call site location, then print the call - // stack as well. + // If the diagnostic location was a call site location, add the call stack as + // well. if (auto callLoc = getCallSiteLoc(loc)) { // Print the call stack while valid, or until the limit is reached. loc = callLoc->getCaller(); for (unsigned curDepth = 0; curDepth < callStackLimit; ++curDepth) { - emitDiagnostic(loc, "called from", DiagnosticSeverity::Note); + addLocToStack(loc, "called from"); if ((callLoc = getCallSiteLoc(loc))) loc = callLoc->getCaller(); else @@ -478,6 +487,17 @@ } } + // If the location stack is empty, use the initial location. + if (locationStack.empty()) { + emitDiagnostic(diag.getLocation(), diag.str(), diag.getSeverity()); + + // Otherwise, use the location stack. + } else { + emitDiagnostic(locationStack.front().first, diag.str(), diag.getSeverity()); + for (auto &it : llvm::drop_begin(locationStack)) + emitDiagnostic(it.first, it.second, DiagnosticSeverity::Note); + } + // Emit each of the notes. Only display the source code if the location is // different from the previous location. for (auto &note : diag.getNotes()) { @@ -495,6 +515,39 @@ return nullptr; } +/// Return true if the given location should be shown, false otherwise.
+bool SourceMgrDiagnosticHandler::shouldShowLocation(Location loc) { + if (!shouldShowLocFn) + return true; + + SmallVector<Location> filterStack(1, loc); + while (!filterStack.empty()) { + Location curLoc = filterStack.pop_back_val(); + if (!shouldShowLocFn(curLoc)) + return false; + + // Recurse into the child locations of some location types. + TypeSwitch<LocationAttr>(curLoc) + .Case([&](CallSiteLoc callLoc) { + // We only recurse into the callee of a call site, as the caller will + // be emitted in a different note on the main diagnostic. + filterStack.push_back(callLoc.getCallee()); + }) + .Case([&](FusedLoc fusedLoc) { + ArrayRef<Location> children = fusedLoc.getLocations(); + filterStack.append(children.begin(), children.end()); + }) + .Case([&](NameLoc nameLoc) { + filterStack.push_back(nameLoc.getChildLoc()); + }) + .Case([&](OpaqueLoc opaqueLoc) { + filterStack.push_back(opaqueLoc.getFallbackLocation()); + }); + } + + return true; +} + /// Get a memory buffer for the given file, or the main file of the source /// manager if one doesn't exist. This always returns non-null. llvm::SMLoc SourceMgrDiagnosticHandler::convertLocToSMLoc(FileLineColLoc loc) { diff --git a/mlir/test/IR/diagnostic-handler-filter.mlir b/mlir/test/IR/diagnostic-handler-filter.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/diagnostic-handler-filter.mlir @@ -0,0 +1,15 @@ +// RUN: mlir-opt %s -test-diagnostic-filter='filters=mysource1' -o - 2>&1 | FileCheck %s +// This test verifies that the diagnostic handler filters locations out of the emitted call stack.
+ +// CHECK-LABEL: Test 'test1' +// CHECK-NEXT: mysource2:1:0: error: test diagnostic +// CHECK-NEXT: mysource3:2:0: note: called from +func private @test1() attributes { + test.loc = loc(callsite("foo"("mysource1":0:0) at callsite("mysource2":1:0 at "mysource3":2:0))) +} + +// CHECK-LABEL: Test 'test2' +// CHECK-NEXT: mysource1:0:0: error: test diagnostic +func private @test2() attributes { + test.loc = loc("mysource1":0:0) +} diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt --- a/mlir/test/lib/IR/CMakeLists.txt +++ b/mlir/test/lib/IR/CMakeLists.txt @@ -1,5 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRTestIR + TestDiagnostics.cpp TestDominance.cpp TestFunc.cpp TestInterfaces.cpp diff --git a/mlir/test/lib/IR/TestDiagnostics.cpp b/mlir/test/lib/IR/TestDiagnostics.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/IR/TestDiagnostics.cpp @@ -0,0 +1,66 @@ +//===- TestDiagnostics.cpp - Test Diagnostic Utilities --------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file contains test passes for the diagnostic utilities, such as the +// location filtering support of the SourceMgrDiagnosticHandler. +// +//===----------------------------------------------------------------------===// + +#include "mlir/IR/SymbolTable.h" +#include "mlir/Pass/Pass.h" +#include "llvm/Support/SourceMgr.h" + +using namespace mlir; + +namespace { +struct TestDiagnosticFilterPass + : public PassWrapper<TestDiagnosticFilterPass, OperationPass<FuncOp>> { + TestDiagnosticFilterPass() {} + TestDiagnosticFilterPass(const TestDiagnosticFilterPass &) {} + + void runOnOperation() override { + llvm::errs() << "Test '" << getOperation().getName() << "'\n"; + + // Build a diagnostic handler that has filtering capabilities.
+ auto filterFn = [&](Location loc) { + // Ignore non-file locations. + FileLineColLoc fileLoc = loc.dyn_cast<FileLineColLoc>(); + if (!fileLoc) + return true; + + // Don't show file locations if their name contains a filter. + return llvm::none_of(filters, [&](StringRef filter) { + return fileLoc.getFilename().strref().contains(filter); + }); + }; + llvm::SourceMgr sourceMgr; + SourceMgrDiagnosticHandler handler(sourceMgr, &getContext(), llvm::errs(), + filterFn); + + // Emit a diagnostic for every operation with a valid loc. + getOperation()->walk([&](Operation *op) { + if (LocationAttr locAttr = op->getAttrOfType<LocationAttr>("test.loc")) + emitError(locAttr, "test diagnostic"); + }); + } + + ListOption<std::string> filters{ + *this, "filters", llvm::cl::MiscFlags::CommaSeparated, + llvm::cl::desc("Specifies the diagnostic file name filters.")}; +}; + +} // end anonymous namespace + +namespace mlir { +namespace test { +void registerTestDiagnosticsPass() { + PassRegistration<TestDiagnosticFilterPass>( + "test-diagnostic-filter", "Test diagnostic filtering support."); +} +} // namespace test +} // namespace mlir diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -68,6 +68,7 @@ void registerTestGpuSerializeToHsacoPass(); void registerTestDataLayoutQuery(); void registerTestDecomposeCallGraphTypes(); +void registerTestDiagnosticsPass(); void registerTestDialect(DialectRegistry &); void registerTestDominancePass(); void registerTestDynamicPipelinePass(); @@ -140,6 +141,7 @@ test::registerTestAliasAnalysisPass(); test::registerTestCallGraphPass(); test::registerTestConstantFold(); + test::registerTestDiagnosticsPass(); #if MLIR_CUDA_CONVERSIONS_ENABLED test::registerTestGpuSerializeToCubinPass(); #endif