diff --git a/mlir/include/mlir/IR/Diagnostics.h b/mlir/include/mlir/IR/Diagnostics.h --- a/mlir/include/mlir/IR/Diagnostics.h +++ b/mlir/include/mlir/IR/Diagnostics.h @@ -549,10 +549,10 @@ /// Emit the given diagnostic information with the held source manager. void emitDiagnostic(Location loc, Twine message, DiagnosticSeverity kind); -protected: /// Emit the given diagnostic with the held source manager. void emitDiagnostic(Diagnostic &diag); +protected: /// Get a memory buffer for the given file, or nullptr if no file is /// available. const llvm::MemoryBuffer *getBufferForFile(StringRef filename); diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -380,8 +380,8 @@ } } -/// Return a processable CallSiteLocation from the given location. -static Optional getCallSiteLoc(Location loc) { +/// Return a processable CallSiteLoc from the given location. +static Optional getCallSiteLoc(Location loc) { switch (loc->getKind()) { case StandardAttributes::NameLocation: return getCallSiteLoc(loc.cast().getChildLoc()); @@ -463,10 +463,10 @@ Location callerLoc = callLoc->getCaller(); for (unsigned curDepth = 0; curDepth < callStackLimit; ++curDepth) { emitDiagnostic(callerLoc, "called from", DiagnosticSeverity::Note); - if (auto subLoc = callerLoc.dyn_cast()) - callerLoc = subLoc.getCaller(); - else + callLoc = getCallSiteLoc(callerLoc); + if (!callLoc) break; + callerLoc = callLoc->getCaller(); } } diff --git a/mlir/test/IR/diagnostic-handler.mlir b/mlir/test/IR/diagnostic-handler.mlir new file mode 100644 --- /dev/null +++ b/mlir/test/IR/diagnostic-handler.mlir @@ -0,0 +1,10 @@ +// RUN: mlir-opt %s -test-diagnostic-handler | FileCheck %s +// This test verifies that the diagnostic handler emits the call stack of a callsite location nested in a fused location.
+ +func @call_site_loc_in_fused() -> i32 { + //CHECK: mysource1: note: + //CHECK: mysource2: note: called from + //CHECK: mysource3: note: called from + %3 = constant 3 : i32 loc(fused[callsite("foo"("mysource1":0:0) at callsite("mysource2":1:0 at "mysource3":2:0)), "bar"]) + return %3 : i32 +} diff --git a/mlir/test/lib/Transforms/CMakeLists.txt b/mlir/test/lib/Transforms/CMakeLists.txt --- a/mlir/test/lib/Transforms/CMakeLists.txt +++ b/mlir/test/lib/Transforms/CMakeLists.txt @@ -1,6 +1,7 @@ add_llvm_library(MLIRTestTransforms TestCallGraph.cpp TestConstantFold.cpp + TestDiagnosticHandler.cpp TestLoopFusion.cpp TestInlining.cpp TestLinalgTransforms.cpp diff --git a/mlir/test/lib/Transforms/TestDiagnosticHandler.cpp b/mlir/test/lib/Transforms/TestDiagnosticHandler.cpp new file mode 100644 --- /dev/null +++ b/mlir/test/lib/Transforms/TestDiagnosticHandler.cpp @@ -0,0 +1,36 @@ +//===- TestDiagnosticHandler.cpp - Pass to test diagnostic handler --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/StandardOps/Ops.h" +#include "mlir/IR/Builders.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { +/// Pass that emits a note diagnostic at the location of each unfiltered op.
+struct TestDiagnosticHandler : public FunctionPass { + + void runOnFunction() override { + auto &os = llvm::outs(); + llvm::SourceMgr fileSourceMgr; + SourceMgrDiagnosticHandler diagHandler(fileSourceMgr, &getContext(), os); + + getFunction().walk([&](Operation *op) { + if (isa(op) || op->isKnownTerminator()) + return; + Diagnostic diag(op->getLoc(), DiagnosticSeverity::Note); + diagHandler.emitDiagnostic(diag); + }); + } +}; + +} // end anonymous namespace + +static PassRegistration pass("test-diagnostic-handler", + "emit diagnostic for ops");