diff --git a/mlir/include/mlir/IR/Value.h b/mlir/include/mlir/IR/Value.h
--- a/mlir/include/mlir/IR/Value.h
+++ b/mlir/include/mlir/IR/Value.h
@@ -88,26 +88,22 @@
 
   template <typename U>
   bool isa() const {
-    assert(*this && "isa<> used on a null type.");
-    return U::classof(*this);
+    return llvm::isa<U>(*this);
   }
-  template <typename First, typename Second, typename... Rest>
-  bool isa() const {
-    return isa<First>() || isa<Second, Rest...>();
-  }
 
   template <typename U>
   U dyn_cast() const {
-    return isa<U>() ? U(impl) : U(nullptr);
+    return llvm::dyn_cast<U>(*this);
   }
+
   template <typename U>
   U dyn_cast_or_null() const {
-    return (*this && isa<U>()) ? U(impl) : U(nullptr);
+    return llvm::dyn_cast_if_present<U>(*this);
   }
+
   template <typename U>
   U cast() const {
-    assert(isa<U>());
-    return U(impl);
+    return llvm::cast<U>(*this);
   }
 
   explicit operator bool() const { return impl; }
@@ -560,6 +556,36 @@
   }
 };
 
+/// Add support for llvm style casts. We provide a cast between To and From if
+/// From is mlir::Value or derives from it.
+template <typename To, typename From>
+struct CastInfo<
+    To, From,
+    std::enable_if_t<std::is_same_v<mlir::Value, std::remove_const_t<From>> ||
+                     std::is_base_of_v<mlir::Value, From>>>
+    : NullableValueCastFailed<To>,
+      DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
+  /// Arguments are taken as mlir::Value here and not as `From`, because
+  /// when casting from an intermediate type of the hierarchy to one of its
+  /// children, the val.getKind() inside To::classof will use the static
+  /// getKind() of the parent instead of the non-static ValueImpl::getKind()
+  /// that returns the dynamic type. This means that To::classof would end up
+  /// comparing the static Kind of the children to the static Kind of its
+  /// parent, making it impossible to downcast from the parent to the child.
+  static inline bool isPossible(mlir::Value ty) {
+    // Casting to the same type or up the hierarchy always succeeds, so return
+    // a constant true there. `if constexpr` (rather than a plain `||`) keeps
+    // `To::classof` from being instantiated in that case, so identity casts
+    // and upcasts compile even when `To` does not provide a `classof`.
+    if constexpr (std::is_base_of_v<To, From>) {
+      return true;
+    } else {
+      return To::classof(ty);
+    }
+  }
+  static inline To doCast(mlir::Value value) { return To(value.getImpl()); }
+};
+
 } // namespace llvm
 
 #endif
diff --git a/mlir/lib/AsmParser/AsmParserState.cpp b/mlir/lib/AsmParser/AsmParserState.cpp
--- a/mlir/lib/AsmParser/AsmParserState.cpp
+++ b/mlir/lib/AsmParser/AsmParserState.cpp
@@ -273,7 +273,7 @@
 
 void AsmParserState::addUses(Value value, ArrayRef<SMLoc> locations) {
   // Handle the case where the value is an operation result.
-  if (OpResult result = value.dyn_cast<OpResult>()) {
+  if (OpResult result = dyn_cast<OpResult>(value)) {
     // Check to see if a definition for the parent operation has been recorded.
     // If one hasn't, we treat the provided value as a placeholder value that
     // will be refined further later.
diff --git a/mlir/lib/AsmParser/Parser.cpp b/mlir/lib/AsmParser/Parser.cpp
--- a/mlir/lib/AsmParser/Parser.cpp
+++ b/mlir/lib/AsmParser/Parser.cpp
@@ -2255,7 +2255,7 @@
 
       // If the value isn't a forward reference, we also add the name of the op
       // to the detail.
-      if (auto result = frontValue.dyn_cast<OpResult>()) {
+      if (auto result = dyn_cast<OpResult>(frontValue)) {
         if (!forwardRefPlaceholders.count(result))
           detailOS << result.getOwner()->getName() << ": ";
       } else {
diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp
--- a/mlir/lib/IR/AsmPrinter.cpp
+++ b/mlir/lib/IR/AsmPrinter.cpp
@@ -1009,7 +1009,7 @@
 
   // If this is an operation result, collect the head lookup value of the result
   // group and the result number of 'result' within that group.
-  if (OpResult result = value.dyn_cast<OpResult>())
+  if (OpResult result = dyn_cast<OpResult>(value))
     getResultIDAndNumber(result, lookupValue, resultNo);
 
   auto it = valueIDs.find(lookupValue);
diff --git a/mlir/lib/IR/Dominance.cpp b/mlir/lib/IR/Dominance.cpp
--- a/mlir/lib/IR/Dominance.cpp
+++ b/mlir/lib/IR/Dominance.cpp
@@ -297,7 +297,7 @@
 bool DominanceInfo::properlyDominates(Value a, Operation *b) const {
   // block arguments properly dominate all operations in their own block, so
   // we use a dominates check here, not a properlyDominates check.
-  if (auto blockArg = a.dyn_cast<BlockArgument>())
+  if (auto blockArg = dyn_cast<BlockArgument>(a))
     return dominates(blockArg.getOwner(), b->getBlock());
 
   // `a` properlyDominates `b` if the operation defining `a` properlyDominates