Look through NameLoc and FusedLoc wrappers when searching for a call site.

getFileLineColLoc gains a FusedLocation case that recurses into the first
fused location. A new helper, getCallSiteLoc, resolves a Location to a
CallSiteLoc by unwrapping NameLoc (child location) and FusedLoc (first
location); the call-stack printing loop now uses it instead of a direct
dyn_cast, so wrapped call sites still produce "called from" notes.

NOTE(review): the line structure and the template arguments (cast<NameLoc>,
cast<CallSiteLoc>, cast<FusedLoc>, Optional<CallSiteLoc>) were restored from
the surrounding code; the blank context lines at the start and end of the
second hunk are inferred from its line counts. Confirm against upstream.

diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -373,6 +373,22 @@
   case StandardAttributes::CallSiteLocation:
     // Process the callee of a callsite location.
     return getFileLineColLoc(loc.cast<CallSiteLoc>().getCallee());
+  case StandardAttributes::FusedLocation:
+    return getFileLineColLoc(loc.cast<FusedLoc>().getLocations().front());
+  default:
+    return llvm::None;
+  }
+}
+
+/// Return a processable CallSiteLoc from the given location.
+static Optional<CallSiteLoc> getCallSiteLoc(Location loc) {
+  switch (loc->getKind()) {
+  case StandardAttributes::NameLocation:
+    return getCallSiteLoc(loc.cast<NameLoc>().getChildLoc());
+  case StandardAttributes::CallSiteLocation:
+    return loc.cast<CallSiteLoc>();
+  case StandardAttributes::FusedLocation:
+    return getCallSiteLoc(loc.cast<FusedLoc>().getLocations().front());
   default:
     return llvm::None;
   }
@@ -442,15 +458,15 @@
 
   // If the diagnostic location was a call site location, then print the call
   // stack as well.
-  if (auto callLoc = loc.dyn_cast<CallSiteLoc>()) {
+  if (auto callLoc = getCallSiteLoc(loc)) {
     // Print the call stack while valid, or until the limit is reached.
-    Location callerLoc = callLoc.getCaller();
+    Location callerLoc = callLoc->getCaller();
     for (unsigned curDepth = 0; curDepth < callStackLimit; ++curDepth) {
       emitDiagnostic(callerLoc, "called from", DiagnosticSeverity::Note);
-      if ((callLoc = callerLoc.dyn_cast<CallSiteLoc>()))
-        callerLoc = callLoc.getCaller();
-      else
+      callLoc = getCallSiteLoc(callerLoc);
+      if (!callLoc)
         break;
+      callerLoc = callLoc->getCaller();
     }
   }
 