diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h
--- a/mlir/include/mlir/IR/Diagnostics.h
+++ b/mlir/include/mlir/IR/Diagnostics.h
@@ -549,10 +549,10 @@
   /// Emit the given diagnostic information with the held source manager.
   void emitDiagnostic(Location loc, Twine message, DiagnosticSeverity kind);
 
-protected:
   /// Emit the given diagnostic with the held source manager.
   void emitDiagnostic(Diagnostic &diag);
 
+protected:
   /// Get a memory buffer for the given file, or nullptr if no file is
   /// available.
   const llvm::MemoryBuffer *getBufferForFile(StringRef filename);
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -373,6 +373,33 @@
   case StandardAttributes::CallSiteLocation:
     // Process the callee of a callsite location.
     return getFileLineColLoc(loc.cast<CallSiteLoc>().getCallee());
+  case StandardAttributes::FusedLocation:
+    // Use the first sub-location that resolves to a file location; the
+    // leading entry may carry no file information (e.g. a bare name).
+    for (Location subLoc : loc.cast<FusedLoc>().getLocations())
+      if (auto fileLoc = getFileLineColLoc(subLoc))
+        return fileLoc;
+    return llvm::None;
+  default:
+    return llvm::None;
+  }
+}
+
+/// Return a processable CallSiteLoc from the given location, or llvm::None if
+/// the location does not contain a call site. Name locations are looked
+/// through to their child, and fused locations are searched in order so that
+/// a call site is found even when it is not the leading sub-location.
+static Optional<CallSiteLoc> getCallSiteLoc(Location loc) {
+  switch (loc->getKind()) {
+  case StandardAttributes::NameLocation:
+    return getCallSiteLoc(loc.cast<NameLoc>().getChildLoc());
+  case StandardAttributes::CallSiteLocation:
+    return loc.cast<CallSiteLoc>();
+  case StandardAttributes::FusedLocation:
+    for (Location subLoc : loc.cast<FusedLoc>().getLocations())
+      if (auto callLoc = getCallSiteLoc(subLoc))
+        return callLoc;
+    return llvm::None;
   default:
     return llvm::None;
   }
@@ -442,15 +469,15 @@
 
-  // If the diagnostic location was a call site location, then print the call
-  // stack as well.
-  if (auto callLoc = loc.dyn_cast<CallSiteLoc>()) {
+  // If the diagnostic location is, or contains, a call site location, then
+  // print the call stack as well.
+  if (auto callLoc = getCallSiteLoc(loc)) {
     // Print the call stack while valid, or until the limit is reached.
-    Location callerLoc = callLoc.getCaller();
+    Location callerLoc = callLoc->getCaller();
     for (unsigned curDepth = 0; curDepth < callStackLimit; ++curDepth) {
       emitDiagnostic(callerLoc, "called from", DiagnosticSeverity::Note);
-      if ((callLoc = callerLoc.dyn_cast<CallSiteLoc>()))
-        callerLoc = callLoc.getCaller();
-      else
+      callLoc = getCallSiteLoc(callerLoc);
+      if (!callLoc)
         break;
+      callerLoc = callLoc->getCaller();
     }
   }
 
diff --git a/mlir/test/IR/diagnostic-handler.mlir b/mlir/test/IR/diagnostic-handler.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/IR/diagnostic-handler.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s -test-diagnostic-handler | FileCheck %s
+// This test verifies that the diagnostic handler emits the full call stack of a call site nested in a fused location.
+
+func @call_site_loc_in_fused() -> i32 {
+  // CHECK: mysource1: note:
+  // CHECK: mysource2: note: called from
+  // CHECK: mysource3: note: called from
+  %3 = constant 3 : i32 loc(fused[callsite("foo"("mysource1":0:0) at callsite("mysource2":1:0 at "mysource3":2:0)), "bar"])
+  return %3 : i32
+}
diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt
--- a/mlir/test/lib/Transforms/CMakeLists.txt
+++ b/mlir/test/lib/Transforms/CMakeLists.txt
@@ -1,6 +1,7 @@
 add_llvm_library(MLIRTestTransforms
   TestCallGraph.cpp
   TestConstantFold.cpp
+  TestDiagnosticHandler.cpp
   TestLoopFusion.cpp
   TestInlining.cpp
   TestLinalgTransforms.cpp
diff --git a/mlir/test/lib/Transforms/TestDiagnosticHandler.cpp b/mlir/test/lib/Transforms/TestDiagnosticHandler.cpp
new file mode 100644
--- /dev/null
+++ b/mlir/test/lib/Transforms/TestDiagnosticHandler.cpp
@@ -0,0 +1,35 @@
+//===- TestDiagnosticHandler.cpp - Pass to test diagnostic handler --------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/StandardOps/Ops.h"
+#include "mlir/Pass/Pass.h"
+
+using namespace mlir;
+
+namespace {
+/// Pass that emits a note for every op that is neither a function nor a terminator.
+struct TestDiagnosticHandler : public FunctionPass<TestDiagnosticHandler> {
+  // Route each note through a SourceMgrDiagnosticHandler printing to stdout.
+  void runOnFunction() override {
+    auto &os = llvm::outs();
+    llvm::SourceMgr fileSourceMgr;
+    SourceMgrDiagnosticHandler diagHandler(fileSourceMgr, &getContext(), os);
+
+    getFunction().walk([&](Operation *op) {
+      if (isa<FuncOp>(op) || op->isKnownTerminator())
+        return;
+      Diagnostic diag(op->getLoc(), DiagnosticSeverity::Note);
+      diagHandler.emitDiagnostic(diag);
+    });
+  }
+};
+
+} // end anonymous namespace
+
+static PassRegistration<TestDiagnosticHandler> pass("test-diagnostic-handler",
+                                                    "emit diagnostic for ops");