[InstCombine] Re-use the first insertelement when folding a broadcast chain

A chain of insertelement instructions that all insert the same scalar into
every lane is folded into a single insertelement at index 0 followed by a
shufflevector with a zeroinitializer mask.

This change:
  * re-uses the first insertelement of the chain (the one whose vector
    operand is not itself an insertelement) when it already inserts the
    scalar at index 0, instead of creating a new insertelement into undef;
  * allows that first insertelement to have additional uses, but only when
    it inserts at index 0. Every other intermediate insertelement must still
    have exactly one use.

Tests: @good5 checks that %ins1 (index 0, extra use) is re-used as the
shuffle source. @bad7 checks that no shuffle is formed when the extra-use
first insertelement inserts at a non-zero index.

Index: lib/Transforms/InstCombine/InstCombineVectorOps.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -615,6 +615,7 @@
   Value *SplatVal = InsElt.getOperand(1);
   InsertElementInst *CurrIE = &InsElt;
   SmallVector<bool, 16> ElementPresent(NumElements, false);
+  InsertElementInst *FirstIE = nullptr;
 
   // Walk the chain backwards, keeping track of which indices we inserted into,
   // until we hit something that isn't an insert of the splatted value.
@@ -623,12 +624,18 @@
     if (!Idx || CurrIE->getOperand(1) != SplatVal)
       return nullptr;
 
-    // Check none of the intermediate steps have any additional uses.
-    if ((CurrIE != &InsElt) && !CurrIE->hasOneUse())
+    InsertElementInst *NextIE =
+        dyn_cast<InsertElementInst>(CurrIE->getOperand(0));
+    // Check none of the intermediate steps have any additional uses, except
+    // for the first insertelement of the chain, which may be re-used if it
+    // inserts at position 0.
+    if (CurrIE != &InsElt && !CurrIE->hasOneUse() &&
+        (NextIE != nullptr || !Idx->isZero()))
       return nullptr;
 
     ElementPresent[Idx->getZExtValue()] = true;
-    CurrIE = dyn_cast<InsertElementInst>(CurrIE->getOperand(0));
+    FirstIE = CurrIE;
+    CurrIE = NextIE;
   }
 
   // Make sure we've seen an insert into every element.
@@ -636,9 +643,14 @@
     return nullptr;
 
   // All right, create the insert + shuffle.
-  Instruction *InsertFirst = InsertElementInst::Create(
-      UndefValue::get(VT), SplatVal,
-      ConstantInt::get(Type::getInt32Ty(InsElt.getContext()), 0), "", &InsElt);
+  Instruction *InsertFirst;
+  if (cast<ConstantInt>(FirstIE->getOperand(2))->isZero())
+    InsertFirst = FirstIE;
+  else
+    InsertFirst = InsertElementInst::Create(
+        UndefValue::get(VT), SplatVal,
+        ConstantInt::get(Type::getInt32Ty(InsElt.getContext()), 0),
+        "", &InsElt);
 
   Constant *ZeroMask = ConstantAggregateZero::get(
       VectorType::get(Type::getInt32Ty(InsElt.getContext()), NumElements));
Index: test/Transforms/InstCombine/broadcast.ll
===================================================================
--- test/Transforms/InstCombine/broadcast.ll
+++ test/Transforms/InstCombine/broadcast.ll
@@ -51,6 +51,22 @@
   ret <4 x float> %tmp7
 }
 
+; CHECK-LABEL: @good5(
+; CHECK-NEXT: %ins1 = insertelement <4 x float> undef, float %v, i32 0
+; CHECK-NEXT: %a1 = fadd <4 x float> %ins1, %ins1
+; CHECK-NEXT: %ins4 = shufflevector <4 x float> %ins1, <4 x float> undef, <4 x i32> zeroinitializer
+; CHECK-NEXT: %res = fadd <4 x float> %a1, %ins4
+; CHECK-NEXT: ret <4 x float> %res
+define <4 x float> @good5(float %v) {
+  %ins1 = insertelement <4 x float> undef, float %v, i32 0
+  %a1 = fadd <4 x float> %ins1, %ins1
+  %ins2 = insertelement <4 x float> %ins1, float %v, i32 1
+  %ins3 = insertelement <4 x float> %ins2, float %v, i32 2
+  %ins4 = insertelement <4 x float> %ins3, float %v, i32 3
+  %res = fadd <4 x float> %a1, %ins4
+  ret <4 x float> %res
+}
+
 ; CHECK-LABEL: bad1
 ; CHECK-NOT: shufflevector
 define <4 x float> @bad1(float %arg) {
@@ -107,3 +123,15 @@
   %tmp6 = insertelement <4 x float> %tmp5, float %arg, i32 3
   ret <4 x float> %tmp6
 }
+
+; CHECK-LABEL: @bad7(
+; CHECK-NOT: shufflevector
+define <4 x float> @bad7(float %v) {
+  %ins1 = insertelement <4 x float> undef, float %v, i32 1
+  %a1 = fadd <4 x float> %ins1, %ins1
+  %ins2 = insertelement <4 x float> %ins1, float %v, i32 2
+  %ins3 = insertelement <4 x float> %ins2, float %v, i32 3
+  %ins4 = insertelement <4 x float> %ins3, float %v, i32 0
+  %res = fadd <4 x float> %a1, %ins4
+  ret <4 x float> %res
+}