Index: llvm/trunk/docs/ProgrammersManual.rst =================================================================== --- llvm/trunk/docs/ProgrammersManual.rst +++ llvm/trunk/docs/ProgrammersManual.rst @@ -3283,13 +3283,13 @@ * ``const Type * getElementType() const``: Returns the type of each of the elements in the sequential type. + * ``uint64_t getNumElements() const``: Returns the number of elements + in the sequential type. + ``ArrayType`` This is a subclass of SequentialType and defines the interface for array types. - * ``unsigned getNumElements() const``: Returns the number of elements - in the array. - ``PointerType`` Subclass of Type for pointer types. Index: llvm/trunk/include/llvm/IR/DerivedTypes.h =================================================================== --- llvm/trunk/include/llvm/IR/DerivedTypes.h +++ llvm/trunk/include/llvm/IR/DerivedTypes.h @@ -313,18 +313,21 @@ /// identically. class SequentialType : public CompositeType { Type *ContainedType; ///< Storage for the single contained type. + uint64_t NumElements; SequentialType(const SequentialType &) = delete; const SequentialType &operator=(const SequentialType &) = delete; protected: - SequentialType(TypeID TID, Type *ElType) - : CompositeType(ElType->getContext(), TID), ContainedType(ElType) { + SequentialType(TypeID TID, Type *ElType, uint64_t NumElements) + : CompositeType(ElType->getContext(), TID), ContainedType(ElType), + NumElements(NumElements) { ContainedTys = &ContainedType; NumContainedTys = 1; } public: - Type *getElementType() const { return getSequentialElementType(); } + uint64_t getNumElements() const { return NumElements; } + Type *getElementType() const { return ContainedType; } /// Methods for support type inquiry through isa, cast, and dyn_cast. static inline bool classof(const Type *T) { @@ -334,8 +337,6 @@ /// Class to represent array types. class ArrayType : public SequentialType { - uint64_t NumElements; - ArrayType(const ArrayType &) = delete; const ArrayType &operator=(const ArrayType &) = delete; ArrayType(Type *ElType, uint64_t NumEl); @@ -347,8 +348,6 @@ /// Return true if the specified type is valid as a element type. static bool isValidElementType(Type *ElemTy); - uint64_t getNumElements() const { return NumElements; } - /// Methods for support type inquiry through isa, cast, and dyn_cast. static inline bool classof(const Type *T) { return T->getTypeID() == ArrayTyID; @@ -361,8 +360,6 @@ /// Class to represent vector types. class VectorType : public SequentialType { - unsigned NumElements; - VectorType(const VectorType &) = delete; const VectorType &operator=(const VectorType &) = delete; VectorType(Type *ElType, unsigned NumEl); @@ -418,13 +415,10 @@ /// Return true if the specified type is valid as a element type. static bool isValidElementType(Type *ElemTy); - /// Return the number of elements in the Vector type. - unsigned getNumElements() const { return NumElements; } - /// Return the number of bits in the Vector type. /// Returns zero when the vector is a vector of pointers. unsigned getBitWidth() const { - return NumElements * getElementType()->getPrimitiveSizeInBits(); + return getNumElements() * getElementType()->getPrimitiveSizeInBits(); } /// Methods for support type inquiry through isa, cast, and dyn_cast. Index: llvm/trunk/include/llvm/IR/GetElementPtrTypeIterator.h =================================================================== --- llvm/trunk/include/llvm/IR/GetElementPtrTypeIterator.h +++ llvm/trunk/include/llvm/IR/GetElementPtrTypeIterator.h @@ -74,12 +74,9 @@ generic_gep_type_iterator& operator++() { // Preincrement Type *Ty = getIndexedType(); - if (auto *ATy = dyn_cast(Ty)) { - CurTy = ATy->getElementType(); - NumElements = ATy->getNumElements(); - } else if (auto *VTy = dyn_cast(Ty)) { - CurTy = VTy->getElementType(); - NumElements = VTy->getNumElements(); + if (auto *STy = dyn_cast(Ty)) { + CurTy = STy->getElementType(); + NumElements = STy->getNumElements(); } else CurTy = dyn_cast(Ty); ++OpIt; Index: llvm/trunk/lib/IR/ConstantFold.cpp =================================================================== --- llvm/trunk/lib/IR/ConstantFold.cpp +++ llvm/trunk/lib/IR/ConstantFold.cpp @@ -891,10 +891,8 @@ unsigned NumElts; if (StructType *ST = dyn_cast(Agg->getType())) NumElts = ST->getNumElements(); - else if (ArrayType *AT = dyn_cast(Agg->getType())) - NumElts = AT->getNumElements(); else - NumElts = Agg->getType()->getVectorNumElements(); + NumElts = cast(Agg->getType())->getNumElements(); SmallVector Result; for (unsigned i = 0; i != NumElts; ++i) { @@ -2210,10 +2208,7 @@ Unknown = true; continue; } - if (isIndexInRangeOfArrayType(isa(STy) - ? cast(STy)->getNumElements() - : cast(STy)->getNumElements(), - CI)) + if (isIndexInRangeOfArrayType(STy->getNumElements(), CI)) // It's in range, skip to the next index. continue; if (isa(Prev)) { Index: llvm/trunk/lib/IR/Constants.cpp =================================================================== --- llvm/trunk/lib/IR/Constants.cpp +++ llvm/trunk/lib/IR/Constants.cpp @@ -794,10 +794,8 @@ unsigned UndefValue::getNumElements() const { Type *Ty = getType(); - if (auto *AT = dyn_cast(Ty)) - return AT->getNumElements(); - if (auto *VT = dyn_cast(Ty)) - return VT->getNumElements(); + if (auto *ST = dyn_cast(Ty)) + return ST->getNumElements(); return Ty->getStructNumElements(); } Index: llvm/trunk/lib/IR/Type.cpp =================================================================== --- llvm/trunk/lib/IR/Type.cpp +++ llvm/trunk/lib/IR/Type.cpp @@ -601,9 +601,7 @@ //===----------------------------------------------------------------------===// ArrayType::ArrayType(Type *ElType, uint64_t NumEl) - : SequentialType(ArrayTyID, ElType) { - NumElements = NumEl; -} + : SequentialType(ArrayTyID, ElType, NumEl) {} ArrayType *ArrayType::get(Type *ElementType, uint64_t NumElements) { assert(isValidElementType(ElementType) && "Invalid type for array element!"); @@ -628,9 +626,7 @@ //===----------------------------------------------------------------------===// VectorType::VectorType(Type *ElType, unsigned NumEl) - : SequentialType(VectorTyID, ElType) { - NumElements = NumEl; -} + : SequentialType(VectorTyID, ElType, NumEl) {} VectorType *VectorType::get(Type *ElementType, unsigned NumElements) { assert(NumElements > 0 && "#Elements of a VectorType must be greater than 0"); Index: llvm/trunk/lib/Linker/IRMover.cpp =================================================================== --- llvm/trunk/lib/Linker/IRMover.cpp +++ llvm/trunk/lib/Linker/IRMover.cpp @@ -169,11 +169,9 @@ if (DSTy->isLiteral() != SSTy->isLiteral() || DSTy->isPacked() != SSTy->isPacked()) return false; - } else if (ArrayType *DATy = dyn_cast(DstTy)) { - if (DATy->getNumElements() != cast(SrcTy)->getNumElements()) - return false; - } else if (VectorType *DVTy = dyn_cast(DstTy)) { - if (DVTy->getNumElements() != cast(SrcTy)->getNumElements()) + } else if (auto *DSeqTy = dyn_cast(DstTy)) { + if (DSeqTy->getNumElements() != + cast(SrcTy)->getNumElements()) return false; } Index: llvm/trunk/lib/Transforms/IPO/GlobalOpt.cpp =================================================================== --- llvm/trunk/lib/Transforms/IPO/GlobalOpt.cpp +++ llvm/trunk/lib/Transforms/IPO/GlobalOpt.cpp @@ -467,12 +467,7 @@ NGV->setAlignment(NewAlign); } } else if (SequentialType *STy = dyn_cast(Ty)) { - unsigned NumElements = 0; - if (ArrayType *ATy = dyn_cast(STy)) - NumElements = ATy->getNumElements(); - else - NumElements = cast(STy)->getNumElements(); - + unsigned NumElements = STy->getNumElements(); if (NumElements > 16 && GV->hasNUsesOrMore(16)) return nullptr; // It's not worth it. NewGlobals.reserve(NumElements); @@ -2119,12 +2114,7 @@ ConstantInt *CI = cast(Addr->getOperand(OpNo)); SequentialType *InitTy = cast(Init->getType()); - - uint64_t NumElts; - if (ArrayType *ATy = dyn_cast(InitTy)) - NumElts = ATy->getNumElements(); - else - NumElts = InitTy->getVectorNumElements(); + uint64_t NumElts = InitTy->getNumElements(); // Break up the array into elements. for (uint64_t i = 0, e = NumElts; i != e; ++i) Index: llvm/trunk/lib/Transforms/Scalar/SROA.cpp =================================================================== --- llvm/trunk/lib/Transforms/Scalar/SROA.cpp +++ llvm/trunk/lib/Transforms/Scalar/SROA.cpp @@ -3222,13 +3222,8 @@ Type *ElementTy = SeqTy->getElementType(); uint64_t ElementSize = DL.getTypeAllocSize(ElementTy); uint64_t NumSkippedElements = Offset / ElementSize; - if (ArrayType *ArrTy = dyn_cast(SeqTy)) { - if (NumSkippedElements >= ArrTy->getNumElements()) - return nullptr; - } else if (VectorType *VecTy = dyn_cast(SeqTy)) { - if (NumSkippedElements >= VecTy->getNumElements()) - return nullptr; - } + if (NumSkippedElements >= SeqTy->getNumElements()) + return nullptr; Offset -= NumSkippedElements * ElementSize; // First check if we need to recurse. Index: llvm/trunk/lib/Transforms/Utils/FunctionComparator.cpp =================================================================== --- llvm/trunk/lib/Transforms/Utils/FunctionComparator.cpp +++ llvm/trunk/lib/Transforms/Utils/FunctionComparator.cpp @@ -387,12 +387,6 @@ case Type::IntegerTyID: return cmpNumbers(cast(TyL)->getBitWidth(), cast(TyR)->getBitWidth()); - case Type::VectorTyID: { - VectorType *VTyL = cast(TyL), *VTyR = cast(TyR); - if (int Res = cmpNumbers(VTyL->getNumElements(), VTyR->getNumElements())) - return Res; - return cmpTypes(VTyL->getElementType(), VTyR->getElementType()); - } // TyL == TyR would have returned true earlier, because types are uniqued. case Type::VoidTyID: case Type::FloatTyID: @@ -445,12 +439,13 @@ return 0; } - case Type::ArrayTyID: { - ArrayType *ATyL = cast(TyL); - ArrayType *ATyR = cast(TyR); - if (ATyL->getNumElements() != ATyR->getNumElements()) - return cmpNumbers(ATyL->getNumElements(), ATyR->getNumElements()); - return cmpTypes(ATyL->getElementType(), ATyR->getElementType()); + case Type::ArrayTyID: + case Type::VectorTyID: { + auto *STyL = cast(TyL); + auto *STyR = cast(TyR); + if (STyL->getNumElements() != STyR->getNumElements()) + return cmpNumbers(STyL->getNumElements(), STyR->getNumElements()); + return cmpTypes(STyL->getElementType(), STyR->getElementType()); } } }