Index: mlir/include/mlir/IR/Types.h
===================================================================
--- mlir/include/mlir/IR/Types.h
+++ mlir/include/mlir/IR/Types.h
@@ -94,11 +94,9 @@
 
   bool operator!() const { return impl == nullptr; }
 
-  template <typename U>
-  bool isa() const;
-  template <typename First, typename Second, typename... Rest>
+  template <typename... Tys>
   bool isa() const;
-  template <typename First, typename... Rest>
+  template <typename... Tys>
   bool isa_and_nonnull() const;
   template <typename U>
   U dyn_cast() const;
@@ -185,6 +183,9 @@
   /// Return the abstract type descriptor for this type.
   const AbstractTy &getAbstractType() { return impl->getAbstractType(); }
 
+  /// Return the Type implementation.
+  ImplType *getImpl() const { return impl; }
+
 protected:
   ImplType *impl{nullptr};
 };
@@ -250,34 +251,29 @@
   return DenseMapInfo<const Type::ImplType *>::getHashValue(arg.impl);
 }
 
-template <typename U>
-bool Type::isa() const {
-  assert(impl && "isa<> used on a null type.");
-  return U::classof(*this);
-}
-
-template <typename First, typename Second, typename... Rest>
+template <typename... Tys>
 bool Type::isa() const {
-  return isa<First>() || isa<Second, Rest...>();
+  return llvm::isa<Tys...>(*this);
 }
 
-template <typename First, typename... Rest>
+template <typename... Tys>
 bool Type::isa_and_nonnull() const {
-  return impl && isa<First, Rest...>();
+  return llvm::isa_and_present<Tys...>(*this);
 }
 
 template <typename U>
 U Type::dyn_cast() const {
-  return isa<U>() ? U(impl) : U(nullptr);
+  return llvm::dyn_cast<U>(*this);
 }
+
 template <typename U>
 U Type::dyn_cast_or_null() const {
-  return (impl && isa<U>()) ? U(impl) : U(nullptr);
+  return llvm::dyn_cast_or_null<U>(*this);
 }
+
 template <typename U>
 U Type::cast() const {
-  assert(isa<U>());
-  return U(impl);
+  return llvm::cast<U>(*this);
 }
 
 } // namespace mlir
@@ -325,6 +321,32 @@
   static constexpr int NumLowBitsAvailable = 3;
 };
 
+/// Add support for llvm style casts.
+/// We provide a cast between To and From if From is mlir::Type or derives from
+/// it.
+template <typename To, typename From>
+struct CastInfo<To, From,
+                typename std::enable_if<
+                    std::is_same_v<mlir::Type, std::remove_const_t<From>> ||
+                    std::is_base_of_v<mlir::Type, From>>::type>
+    : NullableValueCastFailed<To>,
+      DefaultDoCastIfPossible<To, From, CastInfo<To, From>> {
+  /// Arguments are taken as mlir::Type here and not as `From`, because when
+  /// casting from an intermediate type of the hierarchy to one of its children,
+  /// the val.getTypeID() inside T::classof will use the static getTypeID of the
+  /// parent instead of the non-static Type::getTypeID that returns the dynamic
+  /// ID. This means that T::classof would end up comparing the static TypeID of
+  /// the children to the static TypeID of its parent, making it impossible to
+  /// downcast from the parent to the child.
+  static inline bool isPossible(mlir::Type ty) {
+    // Return a constant true instead of a dynamic true when casting to self or
+    // up the hierarchy.
+    return std::is_same_v<To, std::remove_const_t<From>> ||
+           std::is_base_of_v<To, From> || To::classof(ty);
+  }
+  static inline To doCast(mlir::Type ty) { return To(ty.getImpl()); }
+};
+
 } // namespace llvm
 
 #endif // MLIR_IR_TYPES_H
Index: mlir/unittests/IR/CMakeLists.txt
===================================================================
--- mlir/unittests/IR/CMakeLists.txt
+++ mlir/unittests/IR/CMakeLists.txt
@@ -7,6 +7,7 @@
   PatternMatchTest.cpp
   ShapedTypeTest.cpp
   SubElementInterfaceTest.cpp
+  TypeTest.cpp
 
   DEPENDS
   MLIRTestInterfaceIncGen
Index: mlir/unittests/IR/TypeTest.cpp
===================================================================
--- /dev/null
+++ mlir/unittests/IR/TypeTest.cpp
@@ -0,0 +1,67 @@
+//===- TypeTest.cpp - Type API unit tests ---------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/BuiltinTypes.h"
+#include "mlir/IR/Dialect.h"
+#include "mlir/IR/Types.h"
+#include "mlir/IR/Value.h"
+#include "gtest/gtest.h"
+
+using namespace mlir;
+
+/// Mock implementations of a Type hierarchy.
+struct LeafType;
+
+struct MiddleType : Type::TypeBase<MiddleType, Type, TypeStorage> {
+  using Base::Base;
+  static bool classof(Type ty) {
+    return ty.getTypeID() == TypeID::get<LeafType>() || Base::classof(ty);
+  }
+};
+
+struct LeafType : Type::TypeBase<LeafType, MiddleType, TypeStorage> {
+  using Base::Base;
+};
+
+struct FakeDialect : Dialect {
+  FakeDialect(MLIRContext *context)
+      : Dialect(getDialectNamespace(), context, TypeID::get<FakeDialect>()) {
+    addTypes<MiddleType, LeafType>();
+  }
+  static constexpr ::llvm::StringLiteral getDialectNamespace() {
+    return ::llvm::StringLiteral("fake");
+  }
+};
+
+TEST(Type, Casting) {
+  MLIRContext ctx;
+  ctx.loadDialect<FakeDialect>();
+
+  Type intTy = IntegerType::get(&ctx, 8);
+  Type nullTy;
+  MiddleType middleTy = MiddleType::get(&ctx);
+  MiddleType leafTy = LeafType::get(&ctx);
+  Type leaf2Ty = LeafType::get(&ctx);
+
+  EXPECT_TRUE(isa<IntegerType>(intTy));
+  EXPECT_FALSE(isa<FunctionType>(intTy));
+  EXPECT_FALSE(llvm::isa_and_present<IntegerType>(nullTy));
+  EXPECT_TRUE(isa<MiddleType>(middleTy));
+  EXPECT_FALSE(isa<LeafType>(middleTy));
+  EXPECT_TRUE(isa<MiddleType>(leafTy));
+  EXPECT_TRUE(isa<LeafType>(leaf2Ty));
+  EXPECT_TRUE(isa<LeafType>(leafTy));
+
+  EXPECT_TRUE(static_cast<bool>(dyn_cast<IntegerType>(intTy)));
+  EXPECT_FALSE(static_cast<bool>(dyn_cast<FunctionType>(intTy)));
+  EXPECT_FALSE(static_cast<bool>(llvm::cast_if_present<FunctionType>(nullTy)));
+  EXPECT_FALSE(
+      static_cast<bool>(llvm::dyn_cast_if_present<IntegerType>(nullTy)));
+
+  EXPECT_EQ(8u, cast<IntegerType>(intTy).getWidth());
+}