diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -2568,11 +2568,13 @@ const Value *Vec = EEI->getVectorOperand(); const Value *Idx = EEI->getIndexOperand(); auto *CIdx = dyn_cast(Idx); - unsigned NumElts = cast(Vec->getType())->getNumElements(); - APInt DemandedVecElts = APInt::getAllOnesValue(NumElts); - if (CIdx && CIdx->getValue().ult(NumElts)) - DemandedVecElts = APInt::getOneBitSet(NumElts, CIdx->getZExtValue()); - return isKnownNonZero(Vec, DemandedVecElts, Depth, Q); + if (auto *VecTy = dyn_cast(Vec->getType())) { + unsigned NumElts = VecTy->getNumElements(); + APInt DemandedVecElts = APInt::getAllOnesValue(NumElts); + if (CIdx && CIdx->getValue().ult(NumElts)) + DemandedVecElts = APInt::getOneBitSet(NumElts, CIdx->getZExtValue()); + return isKnownNonZero(Vec, DemandedVecElts, Depth, Q); + } } KnownBits Known(BitWidth); diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCasts.cpp @@ -2222,8 +2222,7 @@ if (!VectorType::isValidElementType(DestType)) return nullptr; - unsigned NumElts = ExtElt->getVectorOperandType()->getNumElements(); - auto *NewVecType = FixedVectorType::get(DestType, NumElts); + auto *NewVecType = VectorType::get(DestType, ExtElt->getVectorOperandType()); auto *NewBC = IC.Builder.CreateBitCast(ExtElt->getVectorOperand(), NewVecType, "bc"); return ExtractElementInst::Create(NewBC, ExtElt->getIndexOperand()); diff --git a/llvm/test/Transforms/InstCombine/vscale_extractelement.ll b/llvm/test/Transforms/InstCombine/vscale_extractelement.ll --- a/llvm/test/Transforms/InstCombine/vscale_extractelement.ll +++ b/llvm/test/Transforms/InstCombine/vscale_extractelement.ll @@ -146,3 +146,27 @@ %4 = insertelement %3, i32 %vec.e3, i32 3 ret %4 } + +define i32 @bitcast_of_extractelement( %d) { +; CHECK-LABEL: @bitcast_of_extractelement( +; CHECK-NEXT: [[BC:%.*]] = bitcast [[D:%.*]] to +; CHECK-NEXT: [[CAST:%.*]] = extractelement [[BC]], i32 0 +; CHECK-NEXT: ret i32 [[CAST]] +; + %ext = extractelement %d, i32 0 + %cast = bitcast float %ext to i32 + ret i32 %cast +} + +define i1 @extractelement_is_zero( %d, i1 %b, i32 %z) { +; CHECK-LABEL: @extractelement_is_zero( +; CHECK-NEXT: entry: +; CHECK-NEXT: [[EXT:%.*]] = extractelement [[D:%.*]], i32 0 +; CHECK-NEXT: [[BB:%.*]] = icmp eq i32 [[EXT]], 0 +; CHECK-NEXT: ret i1 [[BB]] +; +entry: + %ext = extractelement %d, i32 0 + %bb = icmp eq i32 %ext, 0 + ret i1 %bb +}