Index: include/llvm/Support/LowLevelTypeImpl.h
===================================================================
--- include/llvm/Support/LowLevelTypeImpl.h
+++ include/llvm/Support/LowLevelTypeImpl.h
@@ -70,6 +70,20 @@
                ScalarTy.isPointer() ? ScalarTy.getAddressSpace() : 0};
   }
 
+  /// Get a low-level type holding \p NumElements elements of \p ScalarTy:
+  /// \p ScalarTy itself when \p NumElements is 1, otherwise
+  /// LLT::vector(NumElements, ScalarTy). Pointer element types are allowed.
+  static LLT scalarOrVector(uint16_t NumElements, LLT ScalarTy) {
+    return NumElements == 1 ? ScalarTy : LLT::vector(NumElements, ScalarTy);
+  }
+
+  /// Get a \p ScalarSize-bit scalar when \p NumElements is 1, otherwise a
+  /// vector of \p NumElements \p ScalarSize-bit scalars. Equivalent to
+  /// scalarOrVector(NumElements, LLT::scalar(ScalarSize)).
+  static LLT scalarOrVector(uint16_t NumElements, unsigned ScalarSize) {
+    return scalarOrVector(NumElements, LLT::scalar(ScalarSize));
+  }
+
   explicit LLT(bool isPointer, bool isVector, uint16_t NumElements,
                unsigned SizeInBits, unsigned AddressSpace) {
     init(isPointer, isVector, NumElements, SizeInBits, AddressSpace);
Index: unittests/CodeGen/LowLevelTypeTest.cpp
===================================================================
--- unittests/CodeGen/LowLevelTypeTest.cpp
+++ unittests/CodeGen/LowLevelTypeTest.cpp
@@ -103,6 +103,21 @@
   }
 }
 
+TEST(LowLevelTypeTest, ScalarOrVector) {
+  // Test version with number of bits for scalar type.
+  EXPECT_EQ(LLT::scalar(32), LLT::scalarOrVector(1, 32));
+  EXPECT_EQ(LLT::vector(2, 32), LLT::scalarOrVector(2, 32));
+
+  // Test version with LLT for scalar type.
+  EXPECT_EQ(LLT::scalar(32), LLT::scalarOrVector(1, LLT::scalar(32)));
+  EXPECT_EQ(LLT::vector(2, 32), LLT::scalarOrVector(2, LLT::scalar(32)));
+
+  // Test with pointer elements.
+  EXPECT_EQ(LLT::pointer(1, 32), LLT::scalarOrVector(1, LLT::pointer(1, 32)));
+  EXPECT_EQ(LLT::vector(2, LLT::pointer(1, 32)),
+            LLT::scalarOrVector(2, LLT::pointer(1, 32)));
+}
+
 TEST(LowLevelTypeTest, Pointer) {
   LLVMContext C;
   DataLayout DL("");