diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h --- a/mlir/include/mlir/IR/Location.h +++ b/mlir/include/mlir/IR/Location.h @@ -65,15 +65,15 @@ /// Type casting utilities on the underlying location. template bool isa() const { - return impl.isa(); + return llvm::isa(*this); } template U dyn_cast() const { - return impl.dyn_cast(); + return llvm::dyn_cast(*this); } template U cast() const { - return impl.cast(); + return llvm::cast(*this); } /// Comparison operators. @@ -170,6 +170,39 @@ PointerLikeTypeTraits::NumLowBitsAvailable; }; +/// The constructors in mlir::Location ensure that the class is a non-nullable +/// wrapper around mlir::LocationAttr. Override default behavior and always +/// return true for isPresent(). +template <> +struct ValueIsPresent { + using UnwrappedType = mlir::Location; + static inline bool isPresent(const mlir::Location &location) { return true; } +}; + +/// Add support for llvm style casts. We provide a cast between To and From if +/// From is mlir::Location or derives from it. +template +struct CastInfo> || + std::is_base_of_v>> + : DefaultDoCastIfPossible> { + + static inline bool isPossible(mlir::Location location) { + /// Return a constant true instead of a dynamic true when casting to self or + /// up the hierarchy. Additionally, all casting info is deferred to the + /// wrapped mlir::LocationAttr instance stored in mlir::Location. + return std::is_same_v> || + isa(static_cast(location)); + } + + static inline To castFailed() { return To(); } + + static inline To doCast(mlir::Location location) { + return To(location->getImpl()); + } +}; + } // namespace llvm #endif diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -768,7 +768,7 @@ auto &attributeAliases = state.symbols.attributeAliasDefinitions; auto locID = TypeID::get(); auto resolveLocation = [&, this](auto &opOrArgument) -> LogicalResult { - auto fwdLoc = opOrArgument.getLoc().template dyn_cast(); + auto fwdLoc = dyn_cast(opOrArgument.getLoc()); if (!fwdLoc || fwdLoc.getUnderlyingTypeID() != locID) return success(); auto locInfo = deferredLocsReferences[fwdLoc.getUnderlyingLocation()]; @@ -776,7 +776,7 @@ if (!attr) return this->emitError(locInfo.loc) << "operation location alias was never defined"; - auto locAttr = attr.dyn_cast(); + auto locAttr = dyn_cast(attr); if (!locAttr) return this->emitError(locInfo.loc) << "expected location, but found '" << attr << "'"; @@ -1930,7 +1930,7 @@ // If this alias can be resolved, do it now. Attribute attr = state.symbols.attributeAliasDefinitions.lookup(identifier); if (attr) { - if (!(loc = attr.dyn_cast())) + if (!(loc = dyn_cast(attr))) return emitError(tok.getLoc()) << "expected location, but found '" << attr << "'"; } else { diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -404,12 +404,12 @@ /// Return a processable CallSiteLoc from the given location. static Optional getCallSiteLoc(Location loc) { - if (auto nameLoc = loc.dyn_cast()) - return getCallSiteLoc(loc.cast().getChildLoc()); - if (auto callLoc = loc.dyn_cast()) + if (auto nameLoc = dyn_cast(loc)) + return getCallSiteLoc(cast(loc).getChildLoc()); + if (auto callLoc = dyn_cast(loc)) return callLoc; - if (auto fusedLoc = loc.dyn_cast()) { - for (auto subLoc : loc.cast().getLocations()) { + if (auto fusedLoc = dyn_cast(loc)) { + for (auto subLoc : cast(loc).getLocations()) { if (auto callLoc = getCallSiteLoc(subLoc)) { return callLoc; } diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp --- a/mlir/lib/IR/Location.cpp +++ b/mlir/lib/IR/Location.cpp @@ -105,7 +105,7 @@ for (auto loc : locs) { // If the location is a fused location we decompose it if it has no // metadata or the metadata is the same as the top level metadata. - if (auto fusedLoc = loc.dyn_cast()) { + if (auto fusedLoc = llvm::dyn_cast(loc)) { if (fusedLoc.getMetadata() == metadata) { // UnknownLoc's have already been removed from FusedLocs so we can // simply add all of the internal locations.