diff --git a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp --- a/llvm/lib/Transforms/Vectorize/VectorCombine.cpp +++ b/llvm/lib/Transforms/Vectorize/VectorCombine.cpp @@ -92,18 +92,25 @@ } bool VectorCombine::vectorizeLoadInsert(Instruction &I) { - // Match insert into fixed vector of scalar load. + // Match insert into fixed vector of scalar value. auto *Ty = dyn_cast<FixedVectorType>(I.getType()); Value *Scalar; if (!Ty || !match(&I, m_InsertElt(m_Undef(), m_Value(Scalar), m_ZeroInt())) || !Scalar->hasOneUse()) return false; + // Optionally match an extract from another vector. + Value *X; + bool HasExtract = match(Scalar, m_ExtractElt(m_Value(X), m_ZeroInt())); + if (!HasExtract) + X = Scalar; + + // Match source value as load of scalar or vector. // Do not vectorize scalar load (widening) if atomic/volatile or under // asan/hwasan/memtag/tsan. The widened load may load data from dirty regions // or create data races non-existent in the source. - auto *Load = dyn_cast<LoadInst>(Scalar); - if (!Load || !Load->isSimple() || + auto *Load = dyn_cast<LoadInst>(X); + if (!Load || !Load->isSimple() || !Load->hasOneUse() || Load->getFunction()->hasFnAttribute(Attribute::SanitizeMemTag) || mustSuppressSpeculation(*Load)) return false; @@ -134,10 +141,12 @@ return false; - // Original pattern: insertelt undef, load [free casts of] ScalarPtr, 0 - int OldCost = TTI.getMemoryOpCost(Instruction::Load, ScalarTy, Alignment, AS); + // Original pattern: insertelt undef, load [free casts of] PtrOp, 0 + Type *LoadTy = Load->getType(); + int OldCost = TTI.getMemoryOpCost(Instruction::Load, LoadTy, Alignment, AS); APInt DemandedElts = APInt::getOneBitSet(MinVecNumElts, 0); - OldCost += TTI.getScalarizationOverhead(MinVecTy, DemandedElts, true, false); + OldCost += TTI.getScalarizationOverhead(MinVecTy, DemandedElts, + /* Insert */ true, HasExtract); // New pattern: load VecPtr int NewCost = TTI.getMemoryOpCost(Instruction::Load, MinVecTy, Alignment, AS); diff --git 
a/llvm/test/Transforms/VectorCombine/X86/load.ll b/llvm/test/Transforms/VectorCombine/X86/load.ll --- a/llvm/test/Transforms/VectorCombine/X86/load.ll +++ b/llvm/test/Transforms/VectorCombine/X86/load.ll @@ -499,9 +499,8 @@ define <4 x float> @load_v2f32_extract_insert_v4f32(<2 x float>* align 16 dereferenceable(16) %p) { ; CHECK-LABEL: @load_v2f32_extract_insert_v4f32( -; CHECK-NEXT: [[L:%.*]] = load <2 x float>, <2 x float>* [[P:%.*]], align 4 -; CHECK-NEXT: [[S:%.*]] = extractelement <2 x float> [[L]], i32 0 -; CHECK-NEXT: [[R:%.*]] = insertelement <4 x float> undef, float [[S]], i32 0 +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x float>* [[P:%.*]] to <4 x float>* +; CHECK-NEXT: [[R:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4 ; CHECK-NEXT: ret <4 x float> [[R]] ; %l = load <2 x float>, <2 x float>* %p, align 4 @@ -512,9 +511,8 @@ define <4 x float> @load_v8f32_extract_insert_v4f32(<8 x float>* align 16 dereferenceable(16) %p) { ; CHECK-LABEL: @load_v8f32_extract_insert_v4f32( -; CHECK-NEXT: [[L:%.*]] = load <8 x float>, <8 x float>* [[P:%.*]], align 4 -; CHECK-NEXT: [[S:%.*]] = extractelement <8 x float> [[L]], i32 0 -; CHECK-NEXT: [[R:%.*]] = insertelement <4 x float> undef, float [[S]], i32 0 +; CHECK-NEXT: [[TMP1:%.*]] = bitcast <8 x float>* [[P:%.*]] to <4 x float>* +; CHECK-NEXT: [[R:%.*]] = load <4 x float>, <4 x float>* [[TMP1]], align 4 ; CHECK-NEXT: ret <4 x float> [[R]] ; %l = load <8 x float>, <8 x float>* %p, align 4