diff --git a/llvm/include/llvm/ADT/TypeSwitch.h b/llvm/include/llvm/ADT/TypeSwitch.h
--- a/llvm/include/llvm/ADT/TypeSwitch.h
+++ b/llvm/include/llvm/ADT/TypeSwitch.h
@@ -124,6 +124,12 @@
return std::move(*result);
return defaultFn(this->value);
}
+ /// Return the stored result if one is set; otherwise return `defaultResult`.
+ LLVM_NODISCARD ResultT Default(ResultT defaultResult) {
+ if (!result)
+ return defaultResult;
+ return std::move(*result);
+ }
LLVM_NODISCARD
operator ResultT() {
diff --git a/llvm/unittests/ADT/TypeSwitchTest.cpp b/llvm/unittests/ADT/TypeSwitchTest.cpp
--- a/llvm/unittests/ADT/TypeSwitchTest.cpp
+++ b/llvm/unittests/ADT/TypeSwitchTest.cpp
@@ -47,7 +47,7 @@
return TypeSwitch(&value)
.Case([](auto *) { return 0; })
.Case([](DerivedC *) { return 1; })
- .Default([](Base *) { return -1; });
+ .Default(-1);
};
EXPECT_EQ(0, translate(DerivedA()));
EXPECT_EQ(0, translate(DerivedB()));
diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h
--- a/mlir/include/mlir/IR/Location.h
+++ b/mlir/include/mlir/IR/Location.h
@@ -20,6 +20,8 @@
namespace mlir {
class Identifier;
+class Location;
+class WalkResult;
//===----------------------------------------------------------------------===//
// LocationAttr
@@ -31,6 +33,9 @@
public:
using Attribute::Attribute;
+ /// Walk this location and all locations nested under it, stopping on interrupt.
+ WalkResult walk(function_ref<WalkResult(Location)> walkFn);
+
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool classof(Attribute attr);
};
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -366,21 +366,15 @@
/// Return a processable FileLineColLoc from the given location.
static Optional<FileLineColLoc> getFileLineColLoc(Location loc) {
- if (auto nameLoc = loc.dyn_cast<NameLoc>())
- return getFileLineColLoc(loc.cast<NameLoc>().getChildLoc());
- if (auto fileLoc = loc.dyn_cast<FileLineColLoc>())
- return fileLoc;
- if (auto callLoc = loc.dyn_cast<CallSiteLoc>())
- return getFileLineColLoc(loc.cast<CallSiteLoc>().getCallee());
- if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
- for (auto subLoc : loc.cast<FusedLoc>().getLocations()) {
- if (auto callLoc = getFileLineColLoc(subLoc)) {
- return callLoc;
- }
+ Optional<FileLineColLoc> firstFileLoc;
+ loc->walk([&](Location loc) {
+ if (FileLineColLoc fileLoc = loc.dyn_cast<FileLineColLoc>()) {
+ firstFileLoc = fileLoc;
+ return WalkResult::interrupt(); // Stop at the first file location found.
}
- return llvm::None;
- }
- return llvm::None;
+ return WalkResult::advance();
+ });
+ return firstFileLoc;
}
/// Return a processable CallSiteLoc from the given location.
diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp
--- a/mlir/lib/IR/Location.cpp
+++ b/mlir/lib/IR/Location.cpp
@@ -9,7 +9,9 @@
#include "mlir/IR/Location.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/Identifier.h"
+#include "mlir/IR/Visitors.h"
#include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/TypeSwitch.h"
using namespace mlir;
using namespace mlir::detail;
@@ -36,6 +38,31 @@
// LocationAttr
//===----------------------------------------------------------------------===//
+WalkResult LocationAttr::walk(function_ref<WalkResult(Location)> walkFn) {
+ if (walkFn(*this).wasInterrupted())
+ return WalkResult::interrupt();
+ // This location was visited above; now recurse into its nested locations.
+ return TypeSwitch<LocationAttr, WalkResult>(*this)
+ .Case([&](CallSiteLoc callLoc) -> WalkResult {
+ if (callLoc.getCallee()->walk(walkFn).wasInterrupted())
+ return WalkResult::interrupt();
+ return callLoc.getCaller()->walk(walkFn);
+ })
+ .Case([&](FusedLoc fusedLoc) -> WalkResult {
+ for (Location subLoc : fusedLoc.getLocations())
+ if (subLoc->walk(walkFn).wasInterrupted())
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ })
+ .Case([&](NameLoc nameLoc) -> WalkResult {
+ return nameLoc.getChildLoc()->walk(walkFn);
+ })
+ .Case([&](OpaqueLoc opaqueLoc) -> WalkResult {
+ return opaqueLoc.getFallbackLocation()->walk(walkFn);
+ })
+ .Default(WalkResult::advance()); // Any other kind has nothing nested.
+}
+
/// Methods for support type inquiry through isa, cast, and dyn_cast.
bool LocationAttr::classof(Attribute attr) {
return attr.isa