diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h
--- a/mlir/include/mlir/IR/Diagnostics.h
+++ b/mlir/include/mlir/IR/Diagnostics.h
@@ -530,9 +530,23 @@
 /// This class is a utility diagnostic handler for use with llvm::SourceMgr.
 class SourceMgrDiagnosticHandler : public ScopedDiagnosticHandler {
 public:
+  /// This type represents a functor used to filter out locations when printing
+  /// a diagnostic. It should return true if the provided location is okay to
+  /// display, false otherwise. Note that if all of the locations in a
+  /// diagnostic are filtered out, the first location is used as the sole
+  /// location.
+  /// TODO: For now we only allow filtering on file locations, but we could
+  /// expand this to other locations in the future. This would just require some
+  /// special handling to avoid needing to unwrap the container locations (e.g.
+  /// if you wanted to check the name of a NameLoc, you wouldn't want to have to
+  /// manually recurse into the name).
+  using ShouldShowLocFn = llvm::unique_function<bool(FileLineColLoc)>;
+
+  SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, MLIRContext *ctx,
+                             raw_ostream &os,
+                             ShouldShowLocFn &&shouldShowLocFn = {});
   SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, MLIRContext *ctx,
-                             raw_ostream &os);
-  SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr, MLIRContext *ctx);
+                             ShouldShowLocFn &&shouldShowLocFn = {});
   ~SourceMgrDiagnosticHandler();
 
   /// Emit the given diagnostic information with the held source manager.
@@ -553,6 +567,10 @@
   /// The output stream to use when printing diagnostics.
   raw_ostream &os;
 
+  /// A functor used when determining if a location for a diagnostic should be
+  /// shown. If null, all locations should be shown.
+  ShouldShowLocFn shouldShowLocFn;
+
 private:
   /// Convert a location into the given memory buffer into an SMLoc.
   llvm::SMLoc convertLocToSMLoc(FileLineColLoc loc);
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -409,17 +409,19 @@
   llvm_unreachable("Unknown DiagnosticSeverity");
 }
 
-SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr,
-                                                       MLIRContext *ctx,
-                                                       raw_ostream &os)
+SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(
+    llvm::SourceMgr &mgr, MLIRContext *ctx, raw_ostream &os,
+    ShouldShowLocFn &&shouldShowLocFn)
     : ScopedDiagnosticHandler(ctx), mgr(mgr), os(os),
+      shouldShowLocFn(std::move(shouldShowLocFn)),
       impl(new SourceMgrDiagnosticHandlerImpl()) {
   setHandler([this](Diagnostic &diag) { emitDiagnostic(diag); });
 }
 
-SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(llvm::SourceMgr &mgr,
-                                                       MLIRContext *ctx)
-    : SourceMgrDiagnosticHandler(mgr, ctx, llvm::errs()) {}
+SourceMgrDiagnosticHandler::SourceMgrDiagnosticHandler(
+    llvm::SourceMgr &mgr, MLIRContext *ctx, ShouldShowLocFn &&shouldShowLocFn)
+    : SourceMgrDiagnosticHandler(mgr, ctx, llvm::errs(),
+                                 std::move(shouldShowLocFn)) {}
 
 SourceMgrDiagnosticHandler::~SourceMgrDiagnosticHandler() {}
 
@@ -460,17 +462,28 @@
 
 /// Emit the given diagnostic with the held source manager.
 void SourceMgrDiagnosticHandler::emitDiagnostic(Diagnostic &diag) {
-  // Emit the diagnostic.
+  SmallVector<Location> locationStack;
+  auto addLocToStack = [&](Location loc) {
+    // If we can grab a file location, check to see if we show it.
+    if (auto fileLoc = getFileLineColLoc(loc)) {
+      if (shouldShowLocFn && !shouldShowLocFn(*fileLoc))
+        return;
+      loc = *fileLoc;
+    }
+    locationStack.push_back(loc);
+  };
+
+  // Add locations to display for this diagnostic.
   Location loc = diag.getLocation();
-  emitDiagnostic(loc, diag.str(), diag.getSeverity());
+  addLocToStack(loc);
 
-  // If the diagnostic location was a call site location, then print the call
-  // stack as well.
+  // If the diagnostic location was a call site location, add the call stack as
+  // well.
   if (auto callLoc = getCallSiteLoc(loc)) {
     // Print the call stack while valid, or until the limit is reached.
     loc = callLoc->getCaller();
     for (unsigned curDepth = 0; curDepth < callStackLimit; ++curDepth) {
-      emitDiagnostic(loc, "called from", DiagnosticSeverity::Note);
+      addLocToStack(loc);
       if ((callLoc = getCallSiteLoc(loc)))
         loc = callLoc->getCaller();
       else
@@ -478,6 +491,17 @@
     }
   }
 
+  // If the location stack is empty, use the initial location.
+  if (locationStack.empty()) {
+    emitDiagnostic(diag.getLocation(), diag.str(), diag.getSeverity());
+
+    // Otherwise, use the location stack.
+  } else {
+    emitDiagnostic(locationStack.front(), diag.str(), diag.getSeverity());
+    for (Location loc : llvm::drop_begin(locationStack))
+      emitDiagnostic(loc, "called from", DiagnosticSeverity::Note);
+  }
+
   // Emit each of the notes. Only display the source code if the location is
   // different from the previous location.
   for (auto &note : diag.getNotes()) {
diff --git a/mlir/test/IR/diagnostic-handler-filter.mlir b/mlir/test/IR/diagnostic-handler-filter.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/IR/diagnostic-handler-filter.mlir
@@ -0,0 +1,15 @@
+// RUN: mlir-opt %s -test-diagnostic-filter='filters=mysource1' -o - 2>&1 | FileCheck %s
+// This test verifies that the diagnostic handler can filter out locations.
+
+// CHECK-LABEL: Test 'test1'
+// CHECK-NEXT: mysource2:1:0: error: test diagnostic
+// CHECK-NEXT: mysource3:2:0: note: called from
+func private @test1() attributes {
+  test.loc = loc(callsite("foo"("mysource1":0:0) at callsite("mysource2":1:0 at "mysource3":2:0)))
+}
+
+// CHECK-LABEL: Test 'test2'
+// CHECK-NEXT: mysource1:0:0: error: test diagnostic
+func private @test2() attributes {
+  test.loc = loc("mysource1":0:0)
+}
diff --git a/mlir/test/lib/IR/CMakeLists.txt b/mlir/test/lib/IR/CMakeLists.txt
--- a/mlir/test/lib/IR/CMakeLists.txt
+++ b/mlir/test/lib/IR/CMakeLists.txt
@@ -1,5 +1,6 @@
 # Exclude tests from libMLIR.so
 add_mlir_library(MLIRTestIR
+  TestDiagnostics.cpp
   TestDominance.cpp
   TestFunc.cpp
   TestInterfaces.cpp
diff --git a/mlir/test/lib/IR/TestDiagnostics.cpp b/mlir/test/lib/IR/TestDiagnostics.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/IR/TestDiagnostics.cpp
@@ -0,0 +1,64 @@
+//===- TestDiagnostics.cpp - Test Diagnostic Utilities --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file contains a test pass that exercises the location filtering
+// support of SourceMgrDiagnosticHandler.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/SymbolTable.h"
+#include "mlir/Pass/Pass.h"
+#include "llvm/Support/SourceMgr.h"
+
+using namespace mlir;
+
+namespace {
+/// This pass emits an error at the location stored in the `test.loc` attribute
+/// of each operation that has one. Diagnostics are printed by a
+/// SourceMgrDiagnosticHandler that hides any file location whose filename
+/// contains one of the strings given by the `filters` option.
+struct TestDiagnosticFilterPass
+    : public PassWrapper<TestDiagnosticFilterPass, OperationPass<FuncOp>> {
+  TestDiagnosticFilterPass() {}
+  TestDiagnosticFilterPass(const TestDiagnosticFilterPass &) {}
+
+  void runOnOperation() override {
+    llvm::errs() << "Test '" << getOperation().getName() << "'\n";
+
+    // Build a diagnostic handler that has filtering capabilities.
+    auto filterFn = [&](FileLineColLoc loc) {
+      return llvm::none_of(filters, [&](StringRef filter) {
+        return loc.getFilename().strref().contains(filter);
+      });
+    };
+    llvm::SourceMgr sourceMgr;
+    SourceMgrDiagnosticHandler handler(sourceMgr, &getContext(), llvm::errs(),
+                                       filterFn);
+
+    // Emit a diagnostic for every operation that has a `test.loc` attribute.
+    getOperation()->walk([&](Operation *op) {
+      if (LocationAttr locAttr = op->getAttrOfType<LocationAttr>("test.loc"))
+        emitError(locAttr, "test diagnostic");
+    });
+  }
+
+  ListOption<std::string> filters{
+      *this, "filters", llvm::cl::MiscFlags::CommaSeparated,
+      llvm::cl::desc("Specifies the diagnostic file name filters.")};
+};
+
+} // end anonymous namespace
+
+namespace mlir {
+namespace test {
+void registerTestDiagnosticsPass() {
+  PassRegistration<TestDiagnosticFilterPass>(
+      "test-diagnostic-filter", "Test diagnostic filtering support.");
+}
+} // namespace test
+} // namespace mlir
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -68,6 +68,7 @@
 void registerTestGpuSerializeToHsacoPass();
 void registerTestDataLayoutQuery();
 void registerTestDecomposeCallGraphTypes();
+void registerTestDiagnosticsPass();
 void registerTestDialect(DialectRegistry &);
 void registerTestDominancePass();
 void registerTestDynamicPipelinePass();
@@ -140,6 +141,7 @@
   test::registerTestAliasAnalysisPass();
   test::registerTestCallGraphPass();
   test::registerTestConstantFold();
+  test::registerTestDiagnosticsPass();
 #if MLIR_CUDA_CONVERSIONS_ENABLED
   test::registerTestGpuSerializeToCubinPass();
 #endif