Index: lib/Transforms/InstCombine/InstCombineSelect.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -731,6 +731,99 @@
   return V;
 }
 
+// Return true if SI, whose operands are CondVal, TrueVal and FalseVal, is a
+// select between two single-source shufflevectors that
+// FoldSelectOfShuffleVectors can merge into one two-source shufflevector.
+//
+// That requires a ConstantVector condition whose lanes are each a ConstantInt
+// or undef, two shufflevector arms whose second operand is undef and whose
+// sources have the same type, and a select with exactly one use.
+static bool CanFoldSelectOfShuffleVectors(SelectInst &SI, Value *CondVal,
+                                          Value *TrueVal, Value *FalseVal) {
+  VectorType *VecTy = dyn_cast<VectorType>(CondVal->getType());
+  ShuffleVectorInst *TrueSV = dyn_cast<ShuffleVectorInst>(TrueVal);
+  ShuffleVectorInst *FalseSV = dyn_cast<ShuffleVectorInst>(FalseVal);
+  if (!(VecTy && TrueSV && FalseSV))
+    return false;
+
+  ConstantVector *CondV = dyn_cast<ConstantVector>(CondVal);
+  if (!(CondV && SI.hasOneUse()))
+    return false;
+
+  // Every condition lane must be a literal i1 or undef. Any other constant
+  // (e.g. a constant expression) has a value the fold cannot see, so it
+  // cannot decide which arm that lane comes from.
+  for (unsigned i = 0, e = VecTy->getNumElements(); i != e; ++i) {
+    Constant *Element = CondV->getAggregateElement(i);
+    if (!Element || !(isa<ConstantInt>(Element) || isa<UndefValue>(Element)))
+      return false;
+  }
+
+  Value *TrueV2 = TrueSV->getOperand(1);
+  Value *FalseV2 = FalseSV->getOperand(1);
+  // We just check for *V2 being undef since instcombine will turn
+  // shufflevector(undef, v) into shufflevector(v, undef)
+  if (!(isa<UndefValue>(TrueV2) && isa<UndefValue>(FalseV2)))
+    return false;
+
+  // The source vectors (not the mask) for the shuffle vectors have to
+  // have the same type. Otherwise we could end up trying to do a
+  // shufflevector <4 x i32> <2 x i32>
+  VectorType *TrueShuffleSrc = cast<VectorType>(TrueV2->getType());
+  VectorType *FalseShuffleSrc = cast<VectorType>(FalseV2->getType());
+  if (TrueShuffleSrc != FalseShuffleSrc)
+    return false;
+
+  return true;
+}
+
+// Build the shufflevector that replaces SI, a select between two
+// single-source shufflevectors under a constant vector condition.
+//
+// This function is only safe to call if CanFoldSelectOfShuffleVectors is
+// true. The new shuffle takes the false arm's source as its first operand
+// and the true arm's source as its second.
+static Instruction *FoldSelectOfShuffleVectors(SelectInst &SI) {
+  // Indexed by the condition bit: 0 selects the false arm, 1 the true arm.
+  ShuffleVectorInst *Sources[] = { cast<ShuffleVectorInst>(SI.getFalseValue()),
+                                   cast<ShuffleVectorInst>(SI.getTrueValue()) };
+  ConstantVector *CondV = cast<ConstantVector>(SI.getCondition());
+  Type *Int32Ty = Type::getInt32Ty(SI.getContext());
+
+  unsigned NumElems = cast<VectorType>(SI.getType())->getNumElements();
+  unsigned NumSourceElems =
+    cast<VectorType>(Sources[0]->getOperand(0)->getType())->getNumElements();
+
+  // Undef mask lanes have to be UndefValue constants: a -1 stored in a
+  // ConstantDataVector is just an out-of-range index, which is not a valid
+  // shuffle mask.
+  SmallVector<Constant *, 16> ShuffleMask;
+  for (unsigned i = 0; i < NumElems; ++i) {
+    // An undef condition lane may take its value from either arm but must not
+    // become undef itself, so treat it like false.
+    unsigned Selector = 0;
+    if (ConstantInt *Element =
+            dyn_cast<ConstantInt>(CondV->getAggregateElement(i)))
+      Selector = Element->isOne();
+
+    // A negative index is an undef lane; an index past the first operand
+    // reads the arm's undef second operand. Either way the result is undef.
+    int SourceIdx = Sources[Selector]->getMaskValue(i);
+    if (SourceIdx < 0 || unsigned(SourceIdx) >= NumSourceElems) {
+      ShuffleMask.push_back(UndefValue::get(Int32Ty));
+      continue;
+    }
+
+    // Lanes of the true arm live in the second operand of the new shuffle.
+    ShuffleMask.push_back(
+        ConstantInt::get(Int32Ty, SourceIdx + NumSourceElems * Selector));
+  }
+
+  return new ShuffleVectorInst(Sources[0]->getOperand(0),
+                               Sources[1]->getOperand(0),
+                               ConstantVector::get(ShuffleMask));
+}
+
 Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
   Value *CondVal = SI.getCondition();
   Value *TrueVal = SI.getTrueValue();
@@ -1005,6 +1098,12 @@
     if (isa<ConstantAggregateZero>(CondVal)) {
       return ReplaceInstUsesWith(SI, FalseVal);
     }
+
+    // A select with a constant condition between two single-source
+    // shufflevectors takes every lane from one of the two sources, so it
+    // can be replaced by a single shufflevector of both sources.
+ if (CanFoldSelectOfShuffleVectors(SI, CondVal, TrueVal, FalseVal)) + return FoldSelectOfShuffleVectors(SI); } return nullptr; Index: test/Transforms/InstCombine/select.ll =================================================================== --- test/Transforms/InstCombine/select.ll +++ test/Transforms/InstCombine/select.ll @@ -1031,3 +1031,37 @@ ; CHECK: lshr exact i32 %2, 1 ; CHECK: xor i32 %3, 42 } + +define <4 x float> @add2f_0(double %a0.coerce, double %b0.coerce, double %a1.coerce, double %b1.coerce) { +; CHECK-LABEL: @add2f_0 +; CHECK-NOT: select +; CHECK: ret + %1 = bitcast double %a0.coerce to <2 x float> + %2 = bitcast double %b0.coerce to <2 x float> + %3 = bitcast double %a1.coerce to <2 x float> + %4 = bitcast double %b1.coerce to <2 x float> + %5 = shufflevector <2 x float> %1, <2 x float> undef, <4 x i32> + %6 = shufflevector <2 x float> %3, <2 x float> undef, <4 x i32> + %7 = select <4 x i1> , <4 x float> %6, <4 x float> %5 + %8 = shufflevector <2 x float> %2, <2 x float> undef, <4 x i32> + %9 = shufflevector <2 x float> %4, <2 x float> undef, <4 x i32> + %10 = select <4 x i1> , <4 x float> %9, <4 x float> %8 + %11 = fadd <4 x float> %7, %10 + ret <4 x float> %11 +} + +define <4 x float> @add2f_1(double %a0.coerce, double %b0.coerce, double %a1.coerce, double %b1.coerce) { +; CHECK-LABEL: @add2f_1 +; CHECK-NOT: select +; CHECK: ret + %1 = bitcast double %a0.coerce to <2 x float> + %2 = bitcast double %b0.coerce to <2 x float> + %3 = bitcast double %a1.coerce to <2 x float> + %4 = bitcast double %b1.coerce to <2 x float> + %5 = fadd <2 x float> %1, %2 + %6 = shufflevector <2 x float> %5, <2 x float> undef, <4 x i32> + %7 = fadd <2 x float> %3, %4 + %8 = shufflevector <2 x float> %7, <2 x float> undef, <4 x i32> + %9 = select <4 x i1> , <4 x float> %8, <4 x float> %6 + ret <4 x float> %9 +}