diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h
--- a/mlir/include/mlir/IR/Location.h
+++ b/mlir/include/mlir/IR/Location.h
@@ -20,6 +20,8 @@
 namespace mlir {
 
 class Identifier;
+class Location;
+class WalkResult;
 
 //===----------------------------------------------------------------------===//
 // LocationAttr
@@ -31,6 +33,12 @@
 public:
   using Attribute::Attribute;
 
+  /// Walk all of the locations nested under, and including, this location.
+  /// Locations are visited in pre-order. The walk stops early and returns
+  /// WalkResult::interrupt() as soon as `walkFn` returns interrupt();
+  /// otherwise it returns WalkResult::advance().
+  WalkResult walk(function_ref<WalkResult(Location)> walkFn);
+
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(Attribute attr);
 };
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -366,21 +366,16 @@
 
 /// Return a processable FileLineColLoc from the given location.
 static Optional<FileLineColLoc> getFileLineColLoc(Location loc) {
-  if (auto nameLoc = loc.dyn_cast<NameLoc>())
-    return getFileLineColLoc(loc.cast<NameLoc>().getChildLoc());
-  if (auto fileLoc = loc.dyn_cast<FileLineColLoc>())
-    return fileLoc;
-  if (auto callLoc = loc.dyn_cast<CallSiteLoc>())
-    return getFileLineColLoc(loc.cast<CallSiteLoc>().getCallee());
-  if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
-    for (auto subLoc : loc.cast<FusedLoc>().getLocations()) {
-      if (auto callLoc = getFileLineColLoc(subLoc)) {
-        return callLoc;
-      }
+  // Return the first FileLineColLoc reached by a pre-order walk, if any.
+  Optional<FileLineColLoc> firstFileLoc;
+  loc->walk([&](Location subLoc) {
+    if (FileLineColLoc fileLoc = subLoc.dyn_cast<FileLineColLoc>()) {
+      firstFileLoc = fileLoc;
+      return WalkResult::interrupt();
     }
-    return llvm::None;
-  }
-  return llvm::None;
+    return WalkResult::advance();
+  });
+  return firstFileLoc;
 }
 
 /// Return a processable CallSiteLoc from the given location.
diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp
--- a/mlir/lib/IR/Location.cpp
+++ b/mlir/lib/IR/Location.cpp
@@ -9,6 +9,7 @@
 #include "mlir/IR/Location.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/Identifier.h"
+#include "mlir/IR/Visitors.h"
 #include "llvm/ADT/SetVector.h"
 
 using namespace mlir;
@@ -36,6 +37,33 @@
 // LocationAttr
 //===----------------------------------------------------------------------===//
 
+WalkResult LocationAttr::walk(function_ref<WalkResult(Location)> walkFn) {
+  // Visit this location itself before any nested locations (pre-order).
+  if (walkFn(*this).wasInterrupted())
+    return WalkResult::interrupt();
+
+  // Recurse into nested locations. For a CallSiteLoc the callee is walked
+  // before the caller. An interrupt from any nested walk is returned at once.
+  if (CallSiteLoc callLoc = dyn_cast<CallSiteLoc>()) {
+    if (callLoc.getCallee()->walk(walkFn).wasInterrupted())
+      return WalkResult::interrupt();
+    return callLoc.getCaller()->walk(walkFn);
+  }
+  if (FusedLoc fusedLoc = dyn_cast<FusedLoc>()) {
+    for (Location subLoc : fusedLoc.getLocations())
+      if (subLoc->walk(walkFn).wasInterrupted())
+        return WalkResult::interrupt();
+    return WalkResult::advance();
+  }
+  if (NameLoc nameLoc = dyn_cast<NameLoc>())
+    return nameLoc.getChildLoc()->walk(walkFn);
+  if (OpaqueLoc opaqueLoc = dyn_cast<OpaqueLoc>())
+    return opaqueLoc.getFallbackLocation()->walk(walkFn);
+
+  // Any other location kind is treated as a leaf.
+  return WalkResult::advance();
+}
+
 /// Methods for support type inquiry through isa, cast, and dyn_cast.
 bool LocationAttr::classof(Attribute attr) {
   return attr.isa