diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -331,6 +331,13 @@
       if (Elt->isNullValue())
         return findScalarElement(Val, EltNo);
 
+  // If the vector is a splat then we can trivially find the scalar element.
+  // Only lanes below the known minimum element count are guaranteed to exist.
+  if (isa<ScalableVectorType>(VTy))
+    if (Value *Splat = getSplatValue(V))
+      if (EltNo < VTy->getElementCount().getKnownMinValue())
+        return Splat;
+
   // Otherwise, we don't know.
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/vscale_extractelement.ll b/llvm/test/Transforms/InstCombine/vscale_extractelement.ll
--- a/llvm/test/Transforms/InstCombine/vscale_extractelement.ll
+++ b/llvm/test/Transforms/InstCombine/vscale_extractelement.ll
@@ -271,6 +271,20 @@
   ret i1 %res
 }
 
+define i64* @ext_lane_from_bitcast_of_splat(i32* %v) {
+; CHECK-LABEL: @ext_lane_from_bitcast_of_splat(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[R:%.*]] = bitcast i32* [[V:%.*]] to i64*
+; CHECK-NEXT:    ret i64* [[R]]
+;
+entry:
+  %in = insertelement <vscale x 4 x i32*> poison, i32* %v, i32 0
+  %splat = shufflevector <vscale x 4 x i32*> %in, <vscale x 4 x i32*> poison, <vscale x 4 x i32> zeroinitializer
+  %bc = bitcast <vscale x 4 x i32*> %splat to <vscale x 4 x i64*>
+  %r = extractelement <vscale x 4 x i64*> %bc, i32 3
+  ret i64* %r
+}
+
 declare <vscale x 2 x i64> @llvm.experimental.stepvector.nxv2i64()
 declare <vscale x 4 x i64> @llvm.experimental.stepvector.nxv4i64()
 declare <vscale x 4 x i32> @llvm.experimental.stepvector.nxv4i32()