diff --git a/llvm/include/llvm/ADT/TypeSwitch.h b/llvm/include/llvm/ADT/TypeSwitch.h
--- a/llvm/include/llvm/ADT/TypeSwitch.h
+++ b/llvm/include/llvm/ADT/TypeSwitch.h
@@ -124,6 +124,12 @@
       return std::move(*result);
     return defaultFn(this->value);
   }
+  /// As a default, return the given value if no case matched.
+  LLVM_NODISCARD ResultT Default(ResultT defaultResult) {
+    if (result)
+      return std::move(*result);
+    return defaultResult;
+  }
 
   LLVM_NODISCARD operator ResultT() {
diff --git a/llvm/unittests/ADT/TypeSwitchTest.cpp b/llvm/unittests/ADT/TypeSwitchTest.cpp
--- a/llvm/unittests/ADT/TypeSwitchTest.cpp
+++ b/llvm/unittests/ADT/TypeSwitchTest.cpp
@@ -47,7 +47,7 @@
     return TypeSwitch<Base *, int>(&value)
         .Case([](auto *) { return 0; })
         .Case([](DerivedC *) { return 1; })
-        .Default([](Base *) { return -1; });
+        .Default(-1);
   };
   EXPECT_EQ(0, translate(DerivedA()));
   EXPECT_EQ(0, translate(DerivedB()));
diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h
--- a/mlir/include/mlir/IR/Location.h
+++ b/mlir/include/mlir/IR/Location.h
@@ -20,6 +20,8 @@
 namespace mlir {
 
 class Identifier;
+class Location;
+class WalkResult;
 
 //===----------------------------------------------------------------------===//
 // LocationAttr
@@ -31,6 +33,10 @@
 public:
   using Attribute::Attribute;
 
+  /// Walk this location and every location nested under it, in pre-order.
+  /// The walk stops early, returning interrupt, if `walkFn` interrupts.
+  WalkResult walk(function_ref<WalkResult(Location)> walkFn);
+
   /// Methods for support type inquiry through isa, cast, and dyn_cast.
   static bool classof(Attribute attr);
 };
diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp
--- a/mlir/lib/IR/Diagnostics.cpp
+++ b/mlir/lib/IR/Diagnostics.cpp
@@ -366,21 +366,15 @@
 
 /// Return a processable FileLineColLoc from the given location.
 static Optional<FileLineColLoc> getFileLineColLoc(Location loc) {
-  if (auto nameLoc = loc.dyn_cast<NameLoc>())
-    return getFileLineColLoc(loc.cast<NameLoc>().getChildLoc());
-  if (auto fileLoc = loc.dyn_cast<FileLineColLoc>())
-    return fileLoc;
-  if (auto callLoc = loc.dyn_cast<CallSiteLoc>())
-    return getFileLineColLoc(loc.cast<CallSiteLoc>().getCallee());
-  if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) {
-    for (auto subLoc : loc.cast<FusedLoc>().getLocations()) {
-      if (auto callLoc = getFileLineColLoc(subLoc)) {
-        return callLoc;
-      }
+  Optional<FileLineColLoc> firstFileLoc;
+  loc->walk([&](Location subLoc) {
+    if (FileLineColLoc fileLoc = subLoc.dyn_cast<FileLineColLoc>()) {
+      firstFileLoc = fileLoc;
+      return WalkResult::interrupt();
     }
-    return llvm::None;
-  }
-  return llvm::None;
+    return WalkResult::advance();
+  });
+  return firstFileLoc;
 }
 
 /// Return a processable CallSiteLoc from the given location.
diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp
--- a/mlir/lib/IR/Location.cpp
+++ b/mlir/lib/IR/Location.cpp
@@ -9,7 +9,9 @@
 #include "mlir/IR/Location.h"
 #include "mlir/IR/BuiltinDialect.h"
 #include "mlir/IR/Identifier.h"
+#include "mlir/IR/Visitors.h"
 #include "llvm/ADT/SetVector.h"
+#include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
 using namespace mlir::detail;
@@ -36,6 +38,32 @@
 // LocationAttr
 //===----------------------------------------------------------------------===//
 
+WalkResult LocationAttr::walk(function_ref<WalkResult(Location)> walkFn) {
+  // Visit this location before any nested location (pre-order).
+  if (walkFn(*this).wasInterrupted())
+    return WalkResult::interrupt();
+
+  return TypeSwitch<LocationAttr, WalkResult>(*this)
+      .Case([&](CallSiteLoc callLoc) -> WalkResult {
+        if (callLoc.getCallee()->walk(walkFn).wasInterrupted())
+          return WalkResult::interrupt();
+        return callLoc.getCaller()->walk(walkFn);
+      })
+      .Case([&](FusedLoc fusedLoc) -> WalkResult {
+        for (Location subLoc : fusedLoc.getLocations())
+          if (subLoc->walk(walkFn).wasInterrupted())
+            return WalkResult::interrupt();
+        return WalkResult::advance();
+      })
+      .Case([&](NameLoc nameLoc) -> WalkResult {
+        return nameLoc.getChildLoc()->walk(walkFn);
+      })
+      .Case([&](OpaqueLoc opaqueLoc) -> WalkResult {
+        return opaqueLoc.getFallbackLocation()->walk(walkFn);
+      })
+      .Default(WalkResult::advance());
+}
+
 /// Methods for support type inquiry through isa, cast, and dyn_cast.
 bool LocationAttr::classof(Attribute attr) {
   return attr.isa