diff --git a/llvm/include/llvm/IR/Constants.h b/include/llvm/IR/Constants.h --- a/llvm/include/llvm/IR/Constants.h +++ b/include/llvm/IR/Constants.h @@ -575,10 +575,10 @@ //===----------------------------------------------------------------------===// /// ConstantDataSequential - A vector or array constant whose element type is a -/// simple 1/2/4/8-byte integer or float/double, and whose elements are just -/// simple data values (i.e. ConstantInt/ConstantFP). This Constant node has no -/// operands because it stores all of the elements of the constant as densely -/// packed data, instead of as Value*'s. +/// simple 1/2/4/8-byte integer or half/bfloat/float/double, and whose elements +/// are just simple data values (i.e. ConstantInt/ConstantFP). This Constant +/// node has no operands because it stores all of the elements of the constant +/// as densely packed data, instead of as Value*'s. /// /// This is the common base class of ConstantDataArray and ConstantDataVector. /// @@ -717,11 +717,11 @@ return ConstantDataArray::get(Context, makeArrayRef(Elts)); } - /// get() constructor - Return a constant with array type with an element + /// getRaw() constructor - Return a constant with array type with an element /// count and element type matching the NumElements and ElementTy parameters /// passed in. Note that this can return a ConstantAggregateZero object. - /// ElementTy needs to be one of i8/i16/i32/i64/float/double. Data is the - /// buffer containing the elements. Be careful to make sure Data uses the + /// ElementTy must be one of i8/i16/i32/i64/half/bfloat/float/double. Data is + /// the buffer containing the elements. Be careful to make sure Data uses the /// right endianness, the buffer will be used as-is.
static Constant *getRaw(StringRef Data, uint64_t NumElements, Type *ElementTy) { Type *Ty = ArrayType::get(ElementTy, NumElements); @@ -788,6 +788,17 @@ static Constant *get(LLVMContext &Context, ArrayRef Elts); static Constant *get(LLVMContext &Context, ArrayRef Elts); + /// getRaw() constructor - Return a constant with vector type with an element + /// count and element type matching the NumElements and ElementTy parameters + /// passed in. Note that this can return a ConstantAggregateZero object. + /// ElementTy must be one of i8/i16/i32/i64/half/bfloat/float/double. Data is + /// the buffer containing the elements. Be careful to make sure Data uses the + /// right endianness, the buffer will be used as-is. + static Constant *getRaw(StringRef Data, uint64_t NumElements, Type *ElementTy) { + Type *Ty = VectorType::get(ElementTy, ElementCount::getFixed(NumElements)); + return getImpl(Data, Ty); + } + /// getFP() constructors - Return a constant of vector type with a float /// element type taken from argument `ElementType', and count taken from /// argument `Elts'. The amount of bits of the contained type must match the @@ -800,7 +811,7 @@ /// Return a ConstantVector with the specified constant in each element. /// The specified constant has to be a of a compatible type (i8/i16/ - /// i32/i64/float/double) and must be a ConstantFP or ConstantInt. + /// i32/i64/half/bfloat/float/double) and must be a ConstantFP or ConstantInt. 
static Constant *getSplat(unsigned NumElts, Constant *Elt); /// Returns true if this is a splat constant, meaning that all elements have diff --git a/llvm/unittests/IR/ConstantsTest.cpp b/unittests/IR/ConstantsTest.cpp --- a/llvm/unittests/IR/ConstantsTest.cpp +++ b/unittests/IR/ConstantsTest.cpp @@ -418,45 +418,55 @@ TEST(ConstantsTest, BuildConstantDataArrays) { LLVMContext Context; - std::unique_ptr<Module> M(new Module("MyModule", Context)); for (Type *T : {Type::getInt8Ty(Context), Type::getInt16Ty(Context), Type::getInt32Ty(Context), Type::getInt64Ty(Context)}) { ArrayType *ArrayTy = ArrayType::get(T, 2); Constant *Vals[] = {ConstantInt::get(T, 0), ConstantInt::get(T, 1)}; - Constant *CDV = ConstantArray::get(ArrayTy, Vals); - ASSERT_TRUE(dyn_cast<ConstantDataArray>(CDV) != nullptr) - << " T = " << getNameOfType(T); + Constant *CA = ConstantArray::get(ArrayTy, Vals); + ASSERT_TRUE(isa<ConstantDataArray>(CA)) << " T = " << getNameOfType(T); + auto CDA = cast<ConstantDataArray>(CA); + Constant *CA2 = ConstantDataArray::getRaw( + CDA->getRawDataValues(), CDA->getNumElements(), CDA->getElementType()); + ASSERT_TRUE(CA == CA2) << " T = " << getNameOfType(T); } - for (Type *T : {Type::getHalfTy(Context), Type::getFloatTy(Context), - Type::getDoubleTy(Context)}) { + for (Type *T : {Type::getHalfTy(Context), Type::getBFloatTy(Context), + Type::getFloatTy(Context), Type::getDoubleTy(Context)}) { ArrayType *ArrayTy = ArrayType::get(T, 2); Constant *Vals[] = {ConstantFP::get(T, 0), ConstantFP::get(T, 1)}; - Constant *CDV = ConstantArray::get(ArrayTy, Vals); - ASSERT_TRUE(dyn_cast<ConstantDataArray>(CDV) != nullptr) - << " T = " << getNameOfType(T); + Constant *CA = ConstantArray::get(ArrayTy, Vals); + ASSERT_TRUE(isa<ConstantDataArray>(CA)) << " T = " << getNameOfType(T); + auto CDA = cast<ConstantDataArray>(CA); + Constant *CA2 = ConstantDataArray::getRaw( + CDA->getRawDataValues(), CDA->getNumElements(), CDA->getElementType()); + ASSERT_TRUE(CA == CA2) << " T = " << getNameOfType(T); } } TEST(ConstantsTest, BuildConstantDataVectors) { LLVMContext Context; - std::unique_ptr<Module> M(new
Module("MyModule", Context)); for (Type *T : {Type::getInt8Ty(Context), Type::getInt16Ty(Context), Type::getInt32Ty(Context), Type::getInt64Ty(Context)}) { Constant *Vals[] = {ConstantInt::get(T, 0), ConstantInt::get(T, 1)}; - Constant *CDV = ConstantVector::get(Vals); - ASSERT_TRUE(dyn_cast<ConstantDataVector>(CDV) != nullptr) - << " T = " << getNameOfType(T); + Constant *CV = ConstantVector::get(Vals); + ASSERT_TRUE(isa<ConstantDataVector>(CV)) << " T = " << getNameOfType(T); + auto CDV = cast<ConstantDataVector>(CV); + Constant *CV2 = ConstantDataVector::getRaw( + CDV->getRawDataValues(), CDV->getNumElements(), CDV->getElementType()); + ASSERT_TRUE(CV == CV2) << " T = " << getNameOfType(T); } - for (Type *T : {Type::getHalfTy(Context), Type::getFloatTy(Context), - Type::getDoubleTy(Context)}) { + for (Type *T : {Type::getHalfTy(Context), Type::getBFloatTy(Context), + Type::getFloatTy(Context), Type::getDoubleTy(Context)}) { Constant *Vals[] = {ConstantFP::get(T, 0), ConstantFP::get(T, 1)}; - Constant *CDV = ConstantVector::get(Vals); - ASSERT_TRUE(dyn_cast<ConstantDataVector>(CDV) != nullptr) - << " T = " << getNameOfType(T); + Constant *CV = ConstantVector::get(Vals); + ASSERT_TRUE(isa<ConstantDataVector>(CV)) << " T = " << getNameOfType(T); + auto CDV = cast<ConstantDataVector>(CV); + Constant *CV2 = ConstantDataVector::getRaw( + CDV->getRawDataValues(), CDV->getNumElements(), CDV->getElementType()); + ASSERT_TRUE(CV == CV2) << " T = " << getNameOfType(T); } }