Summary of this patch:
  * getFileLineColLoc: for a fused location, return the result for the first
    sub-location that yields one; otherwise return llvm::None.
  * getCallSiteLoc (new): find a CallSiteLoc by looking through name locations
    (via their child location) and fused locations (first sub-location that
    yields one).
  * Call-stack printing now uses getCallSiteLoc instead of a direct dyn_cast,
    both for the diagnostic location and for each caller while walking the
    stack.
  * Adds mlir/test/IR/diagnostic-handler.mlir, which checks the "called from"
    notes for a call site nested inside a fused location.

diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -374,6 +374,32 @@
   case StandardAttributes::CallSiteLocation:
     // Process the callee of a callsite location.
     return getFileLineColLoc(loc.cast<CallSiteLoc>().getCallee());
+  case StandardAttributes::FusedLocation:
+    for (auto subLoc : loc.cast<FusedLoc>().getLocations()) {
+      if (auto callLoc = getFileLineColLoc(subLoc)) {
+        return callLoc;
+      }
+    }
+    return llvm::None;
+  default:
+    return llvm::None;
+  }
+}
+
+/// Return a processable CallSiteLoc from the given location.
+static Optional<CallSiteLoc> getCallSiteLoc(Location loc) {
+  switch (loc->getKind()) {
+  case StandardAttributes::NameLocation:
+    return getCallSiteLoc(loc.cast<NameLoc>().getChildLoc());
+  case StandardAttributes::CallSiteLocation:
+    return loc.cast<CallSiteLoc>();
+  case StandardAttributes::FusedLocation:
+    for (auto subLoc : loc.cast<FusedLoc>().getLocations()) {
+      if (auto callLoc = getCallSiteLoc(subLoc)) {
+        return callLoc;
+      }
+    }
+    return llvm::None;
   default:
     return llvm::None;
   }
@@ -443,15 +469,15 @@
 
   // If the diagnostic location was a call site location, then print the call
   // stack as well.
-  if (auto callLoc = loc.dyn_cast<CallSiteLoc>()) {
+  if (auto callLoc = getCallSiteLoc(loc)) {
     // Print the call stack while valid, or until the limit is reached.
-    Location callerLoc = callLoc.getCaller();
+    Location callerLoc = callLoc->getCaller();
     for (unsigned curDepth = 0; curDepth < callStackLimit; ++curDepth) {
       emitDiagnostic(callerLoc, "called from", DiagnosticSeverity::Note);
-      if ((callLoc = callerLoc.dyn_cast<CallSiteLoc>()))
-        callerLoc = callLoc.getCaller();
-      else
+      callLoc = getCallSiteLoc(callerLoc);
+      if (!callLoc)
         break;
+      callerLoc = callLoc->getCaller();
     }
   }
 
diff --git a/mlir/test/IR/diagnostic-handler.mlir b/mlir/test/IR/diagnostic-handler.mlir
new file mode 100644
--- /dev/null
+++ b/mlir/test/IR/diagnostic-handler.mlir
@@ -0,0 +1,13 @@
+// RUN: not mlir-opt %s -o - 2>&1 | FileCheck %s
+// This test verifies that diagnostic handler can emit the call stack successfully.
+
+// -----
+
+// Emit the first available call stack in the fused location.
+func @constant_out_of_range() {
+  // CHECK: mysource1: error: 'std.constant' op requires attribute's type ('i64') to match op's return type ('i1')
+  // CHECK-NEXT: mysource2: note: called from
+  // CHECK-NEXT: mysource3: note: called from
+  %x = "std.constant"() {value = 100} : () -> i1 loc(fused["bar", callsite("foo"("mysource1":0:0) at callsite("mysource2":1:0 at "mysource3":2:0))])
+  return
+}