diff --git a/llvm/include/llvm/IR/DerivedTypes.h b/llvm/include/llvm/IR/DerivedTypes.h
--- a/llvm/include/llvm/IR/DerivedTypes.h
+++ b/llvm/include/llvm/IR/DerivedTypes.h
@@ -562,6 +562,17 @@
 public:
   static FixedVectorType *get(Type *ElementType, unsigned NumElts);
 
+  static FixedVectorType *get(Type *ElementType, ElementCount EC) {
+    if (EC.Scalable)
+      return nullptr; // No FixedVectorType has a scalable element count.
+    // EC is fixed: forward EC.Min to the unsigned-count overload.
+    return get(ElementType, EC.Min);
+  }
+  /// \p VTy must be non-null; returns null if \p VTy is scalable.
+  static FixedVectorType *get(Type *ElementType, const VectorType *VTy) {
+    return get(ElementType, VTy->getElementCount());
+  }
+
   static FixedVectorType *get(Type *ElementType, const FixedVectorType *FVTy) {
     return get(ElementType, FVTy->getNumElements());
   }
@@ -607,6 +618,17 @@
 public:
   static ScalableVectorType *get(Type *ElementType, unsigned MinNumElts);
 
+  static ScalableVectorType *get(Type *ElementType, ElementCount EC) {
+    if (!EC.Scalable)
+      return nullptr; // No ScalableVectorType has a fixed element count.
+    // EC is scalable: forward EC.Min to the minimum-count overload.
+    return get(ElementType, EC.Min);
+  }
+  /// \p VTy must be non-null; returns null if \p VTy is not scalable.
+  static ScalableVectorType *get(Type *ElementType, const VectorType *VTy) {
+    return get(ElementType, VTy->getElementCount());
+  }
+
   static ScalableVectorType *get(Type *ElementType,
                                  const ScalableVectorType *SVTy) {
     return get(ElementType, SVTy->getMinNumElements());
diff --git a/llvm/unittests/IR/VectorTypesTest.cpp b/llvm/unittests/IR/VectorTypesTest.cpp
--- a/llvm/unittests/IR/VectorTypesTest.cpp
+++ b/llvm/unittests/IR/VectorTypesTest.cpp
@@ -377,4 +377,45 @@
   // non-scalable vector sizes.
 }
 
+TEST(VectorTypesTest, DerivedGettersHidingBaseGetters) {
+  // The derived vector types implement get(Type *, ElementCount) and
+  // get(Type *, const VectorType *), which hide the base class versions and
+  // return the derived type, or null on a fixed/scalable mismatch.
+
+  LLVMContext Ctx;
+
+  Type *Ty = Type::getInt32Ty(Ctx);
+
+  ElementCount FV4 = {4, false};
+  ElementCount SV4 = {4, true};
+
+  auto *BFV4Ty = VectorType::get(Ty, FV4); // Base-class getter baseline.
+  auto *BSV4Ty = VectorType::get(Ty, SV4); // Base-class getter baseline.
+
+  EXPECT_NE(nullptr, BFV4Ty);
+  EXPECT_NE(nullptr, BSV4Ty);
+
+  // ElementCount getters: a count of the wrong kind yields null.
+  auto *FV4Ty = FixedVectorType::get(Ty, FV4);
+  auto *FV4TyNaught = FixedVectorType::get(Ty, SV4); // Scalable EC: null.
+  auto *SV4Ty = ScalableVectorType::get(Ty, SV4);
+  auto *SV4TyNaught = ScalableVectorType::get(Ty, FV4); // Fixed EC: null.
+
+  EXPECT_EQ(BFV4Ty, FV4Ty);
+  EXPECT_EQ(nullptr, FV4TyNaught);
+  EXPECT_EQ(BSV4Ty, SV4Ty);
+  EXPECT_EQ(nullptr, SV4TyNaught);
+
+  // VectorType getters: a source vector of the wrong kind yields null.
+  FV4Ty = FixedVectorType::get(Ty, BFV4Ty);
+  FV4TyNaught = FixedVectorType::get(Ty, BSV4Ty); // Scalable source: null.
+  SV4Ty = ScalableVectorType::get(Ty, BSV4Ty);
+  SV4TyNaught = ScalableVectorType::get(Ty, BFV4Ty); // Fixed source: null.
+
+  EXPECT_EQ(BFV4Ty, FV4Ty);
+  EXPECT_EQ(nullptr, FV4TyNaught);
+  EXPECT_EQ(BSV4Ty, SV4Ty);
+  EXPECT_EQ(nullptr, SV4TyNaught);
+}
+
 } // end anonymous namespace