diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp --- a/llvm/lib/IR/ConstantFold.cpp +++ b/llvm/lib/IR/ConstantFold.cpp @@ -55,8 +55,8 @@ // If this cast changes element count then we can't handle it here: // doing so requires endianness information. This should be handled by // Analysis/ConstantFolding.cpp - unsigned NumElts = DstTy->getNumElements(); - if (NumElts != cast(CV->getType())->getNumElements()) + unsigned NumElts = cast(DstTy)->getNumElements(); + if (NumElts != cast(CV->getType())->getNumElements()) return nullptr; Type *DstEltTy = DstTy->getElementType(); @@ -571,10 +571,11 @@ // If the cast operand is a constant vector, perform the cast by // operating on each element. In the cast of bitcasts, the element // count may be mismatched; don't attempt to handle that here. + // FIXME: handle DstTy being a scalable vector if ((isa(V) || isa(V)) && DestTy->isVectorTy() && - cast(DestTy)->getNumElements() == - cast(V->getType())->getNumElements()) { + cast(DestTy)->getNumElements() == + cast(V->getType())->getNumElements()) { VectorType *DestVecTy = cast(DestTy); Type *DstEltTy = DestVecTy->getElementType(); // Fast path for splatted constants. @@ -585,7 +586,8 @@ } SmallVector res; Type *Ty = IntegerType::get(V->getContext(), 32); - for (unsigned i = 0, e = cast(V->getType())->getNumElements(); + for (unsigned i = 0, + e = cast(V->getType())->getNumElements(); i != e; ++i) { Constant *C = ConstantExpr::getExtractElement(V, ConstantInt::get(Ty, i)); @@ -798,17 +800,17 @@ Constant *llvm::ConstantFoldExtractElementInstruction(Constant *Val, Constant *Idx) { - auto *ValVTy = cast(Val->getType()); - // extractelt undef, C -> undef // extractelt C, undef -> undef if (isa(Val) || isa(Idx)) - return UndefValue::get(ValVTy->getElementType()); + return UndefValue::get(cast(Val->getType())->getElementType()); auto *CIdx = dyn_cast(Idx); if (!CIdx) return nullptr; + auto *ValVTy = cast(Val->getType()); + // ee({w,x,y,z}, wrong_value) -> undef if (CIdx->uge(ValVTy->getNumElements())) return UndefValue::get(ValVTy->getElementType()); @@ -847,11 +849,12 @@ // Do not iterate on scalable vector. The num of elements is unknown at // compile-time. - VectorType *ValTy = cast(Val->getType()); - if (isa(ValTy)) + if (isa(Val->getType())) return nullptr; - unsigned NumElts = cast(Val->getType())->getNumElements(); + auto *ValTy = cast(Val->getType()); + + unsigned NumElts = ValTy->getNumElements(); if (CIdx->uge(NumElts)) return UndefValue::get(Val->getType()); @@ -898,7 +901,7 @@ if (isa(V1VTy)) return nullptr; - unsigned SrcNumElts = V1VTy->getNumElements(); + unsigned SrcNumElts = V1VTy->getElementCount().Min; // Loop over the shuffle mask, evaluating each element. SmallVector Result; @@ -998,11 +1001,12 @@ case Instruction::FNeg: return ConstantFP::get(C->getContext(), neg(CV)); } - } else if (VectorType *VTy = dyn_cast(C->getType())) { + } else if (IsScalableVector) { // Do not iterate on scalable vector. The number of elements is unknown at // compile-time. - if (IsScalableVector) - return nullptr; + return nullptr; + } else if (auto *VTy = dyn_cast(C->getType())) { + Type *Ty = IntegerType::get(VTy->getContext(), 32); // Fast path for splatted constants. if (Constant *Splat = C->getSplatValue()) { @@ -1011,7 +1015,7 @@ } // Fold each element and create a vector constant from those constants. - SmallVector Result; + SmallVector Result; for (unsigned i = 0, e = VTy->getNumElements(); i != e; ++i) { Constant *ExtractIdx = ConstantInt::get(Ty, i); Constant *Elt = ConstantExpr::getExtractElement(C, ExtractIdx); @@ -1367,11 +1371,11 @@ return ConstantFP::get(C1->getContext(), C3V); } } - } else if (VectorType *VTy = dyn_cast(C1->getType())) { + } else if (IsScalableVector) { // Do not iterate on scalable vector. The number of elements is unknown at // compile-time. - if (IsScalableVector) - return nullptr; + return nullptr; + } else if (auto *VTy = dyn_cast(C1->getType())) { // Fast path for splatted constants. if (Constant *C2Splat = C2->getSplatValue()) { if (Instruction::isIntDivRem(Opcode) && C2Splat->isNullValue()) @@ -2014,7 +2018,7 @@ SmallVector ResElts; Type *Ty = IntegerType::get(C1->getContext(), 32); // Compare the elements, producing an i1 result or constant expr. - for (unsigned i = 0, e = C1VTy->getNumElements(); i != e; ++i) { + for (unsigned i = 0, e = C1VTy->getElementCount().Min; i != e; ++i) { Constant *C1E = ConstantExpr::getExtractElement(C1, ConstantInt::get(Ty, i)); Constant *C2E = @@ -2286,14 +2290,18 @@ assert(Ty && "Invalid indices for GEP!"); Type *OrigGEPTy = PointerType::get(Ty, PtrTy->getAddressSpace()); Type *GEPTy = PointerType::get(Ty, PtrTy->getAddressSpace()); - if (VectorType *VT = dyn_cast(C->getType())) - GEPTy = FixedVectorType::get(OrigGEPTy, VT->getNumElements()); - + if (VectorType *VT = dyn_cast(C->getType())) { + // FIXME: handle scalable vectors + GEPTy = FixedVectorType::get( + OrigGEPTy, cast(VT)->getNumElements()); + } // The GEP returns a vector of pointers when one of more of // its arguments is a vector. for (unsigned i = 0, e = Idxs.size(); i != e; ++i) { if (auto *VT = dyn_cast(Idxs[i]->getType())) { - GEPTy = FixedVectorType::get(OrigGEPTy, VT->getNumElements()); + // FIXME: handle scalable vectors + GEPTy = FixedVectorType::get( + OrigGEPTy, cast(VT)->getNumElements()); break; } } @@ -2500,19 +2508,19 @@ if (!IsCurrIdxVector && IsPrevIdxVector) CurrIdx = ConstantDataVector::getSplat( - cast(PrevIdx->getType())->getNumElements(), CurrIdx); + cast(PrevIdx->getType())->getNumElements(), CurrIdx); if (!IsPrevIdxVector && IsCurrIdxVector) PrevIdx = ConstantDataVector::getSplat( - cast(CurrIdx->getType())->getNumElements(), PrevIdx); + cast(CurrIdx->getType())->getNumElements(), PrevIdx); Constant *Factor = ConstantInt::get(CurrIdx->getType()->getScalarType(), NumElements); if (UseVector) Factor = ConstantDataVector::getSplat( IsPrevIdxVector - ? cast(PrevIdx->getType())->getNumElements() - : cast(CurrIdx->getType())->getNumElements(), + ? cast(PrevIdx->getType())->getNumElements() + : cast(CurrIdx->getType())->getNumElements(), Factor); NewIdxs[i] = ConstantExpr::getSRem(CurrIdx, Factor); @@ -2531,8 +2539,8 @@ ExtendedTy = FixedVectorType::get( ExtendedTy, IsPrevIdxVector - ? cast(PrevIdx->getType())->getNumElements() - : cast(CurrIdx->getType())->getNumElements()); + ? cast(PrevIdx->getType())->getNumElements() + : cast(CurrIdx->getType())->getNumElements()); if (!PrevIdx->getType()->isIntOrIntVectorTy(CommonExtendedWidth)) PrevIdx = ConstantExpr::getSExt(PrevIdx, ExtendedTy);