diff --git a/mlir/include/mlir/IR/Location.h b/mlir/include/mlir/IR/Location.h --- a/mlir/include/mlir/IR/Location.h +++ b/mlir/include/mlir/IR/Location.h @@ -65,15 +65,15 @@ /// Type casting utilities on the underlying location. template <typename U> bool isa() const { - return impl.isa<U>(); + return llvm::isa<U>(*this); } template <typename U> U dyn_cast() const { - return impl.dyn_cast<U>(); + return llvm::dyn_cast<U>(*this); } template <typename U> U cast() const { - return impl.cast<U>(); + return llvm::cast<U>(*this); } /// Comparison operators. @@ -170,6 +170,39 @@ PointerLikeTypeTraits<mlir::Attribute>::NumLowBitsAvailable; }; +/// The constructors in mlir::Location ensure that the class is a non-nullable +/// wrapper around mlir::LocationAttr. Override default behavior and always +/// return true for isPresent(). +template <> +struct ValueIsPresent<mlir::Location> { + using UnwrappedType = mlir::Location; + static inline bool isPresent(const mlir::Location &location) { return true; } +}; + +/// Add support for LLVM-style casts. We provide a cast between To and From if +/// From is mlir::Location or derives from it. +template <typename To, typename From> +struct CastInfo<To, From, + std::enable_if_t< + std::is_same_v<mlir::Location, std::remove_const_t<From>> || + std::is_base_of_v<mlir::Location, From>>> + : DefaultDoCastIfPossible<To, From, CastInfo<To, From>> { + + static inline bool isPossible(mlir::Location location) { + // Return a constant true instead of a dynamic true when casting to self; + // any other cast (including an upcast) is checked dynamically against the + // wrapped mlir::LocationAttr instance stored in mlir::Location.
+ return std::is_same_v<To, std::remove_const_t<From>> || + isa<To>(static_cast<mlir::LocationAttr>(location)); + } + + static inline To castFailed() { return To(); } + + static inline To doCast(mlir::Location location) { + return To(location->getImpl()); + } +}; + } // namespace llvm #endif diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp --- a/mlir/lib/AsmParser/Parser.cpp +++ b/mlir/lib/AsmParser/Parser.cpp @@ -768,7 +768,7 @@ auto &attributeAliases = state.symbols.attributeAliasDefinitions; auto locID = TypeID::get<DeferredLocInfo *>(); auto resolveLocation = [&, this](auto &opOrArgument) -> LogicalResult { - auto fwdLoc = opOrArgument.getLoc().template dyn_cast<OpaqueLoc>(); + auto fwdLoc = dyn_cast<OpaqueLoc>(opOrArgument.getLoc()); if (!fwdLoc || fwdLoc.getUnderlyingTypeID() != locID) return success(); auto locInfo = deferredLocsReferences[fwdLoc.getUnderlyingLocation()]; @@ -776,7 +776,7 @@ if (!attr) return this->emitError(locInfo.loc) << "operation location alias was never defined"; - auto locAttr = attr.dyn_cast<LocationAttr>(); + auto locAttr = dyn_cast<LocationAttr>(attr); if (!locAttr) return this->emitError(locInfo.loc) << "expected location, but found '" << attr << "'"; @@ -1930,7 +1930,7 @@ // If this alias can be resolved, do it now. Attribute attr = state.symbols.attributeAliasDefinitions.lookup(identifier); if (attr) { - if (!(loc = attr.dyn_cast<LocationAttr>())) + if (!(loc = dyn_cast<LocationAttr>(attr))) return emitError(tok.getLoc()) << "expected location, but found '" << attr << "'"; } else { diff --git a/mlir/lib/IR/Diagnostics.cpp b/mlir/lib/IR/Diagnostics.cpp --- a/mlir/lib/IR/Diagnostics.cpp +++ b/mlir/lib/IR/Diagnostics.cpp @@ -404,12 +404,12 @@ /// Return a processable CallSiteLoc from the given location.
static Optional<CallSiteLoc> getCallSiteLoc(Location loc) { - if (auto nameLoc = loc.dyn_cast<NameLoc>()) - return getCallSiteLoc(loc.cast<NameLoc>().getChildLoc()); - if (auto callLoc = loc.dyn_cast<CallSiteLoc>()) + if (auto nameLoc = dyn_cast<NameLoc>(loc)) + return getCallSiteLoc(nameLoc.getChildLoc()); + if (auto callLoc = dyn_cast<CallSiteLoc>(loc)) return callLoc; - if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) { - for (auto subLoc : loc.cast<FusedLoc>().getLocations()) { + if (auto fusedLoc = dyn_cast<FusedLoc>(loc)) { + for (auto subLoc : fusedLoc.getLocations()) { if (auto callLoc = getCallSiteLoc(subLoc)) { return callLoc; } diff --git a/mlir/lib/IR/Location.cpp b/mlir/lib/IR/Location.cpp --- a/mlir/lib/IR/Location.cpp +++ b/mlir/lib/IR/Location.cpp @@ -105,7 +105,7 @@ for (auto loc : locs) { // If the location is a fused location we decompose it if it has no // metadata or the metadata is the same as the top level metadata. - if (auto fusedLoc = loc.dyn_cast<FusedLoc>()) { + if (auto fusedLoc = llvm::dyn_cast<FusedLoc>(loc)) { if (fusedLoc.getMetadata() == metadata) { // UnknownLoc's have already been removed from FusedLocs so we can // simply add all of the internal locations.