Index: lib/Transforms/InstCombine/InstCombineSelect.cpp
===================================================================
--- lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -731,6 +731,88 @@
   return V;
 }
 
+/// Returns true if a select with condition CondVal and arms TrueVal/FalseVal
+/// has the shape that FoldSelectOfShuffleVectors expects:
+///  - the condition is a constant vector (ConstantVector);
+///  - both arms are shufflevector instructions whose second operand is undef
+///    (single-source shuffles);
+///  - both shuffles read source vectors of the same type, so those sources can
+///    become the two operands of one new shufflevector.
+static bool CanFoldSelectOfShuffleVectors(Value *CondVal, Value *TrueVal,
+                                          Value *FalseVal) {
+  // A ConstantVector always has vector type, so this also rejects scalar
+  // conditions.
+  if (!isa<ConstantVector>(CondVal))
+    return false;
+
+  ShuffleVectorInst *TrueSV = dyn_cast<ShuffleVectorInst>(TrueVal);
+  ShuffleVectorInst *FalseSV = dyn_cast<ShuffleVectorInst>(FalseVal);
+  if (!TrueSV || !FalseSV)
+    return false;
+
+  // Checking only operand 1 for undef is enough: instcombine canonicalizes
+  // shufflevector(undef, V) into shufflevector(V, undef).
+  if (!isa<UndefValue>(TrueSV->getOperand(1)) ||
+      !isa<UndefValue>(FalseSV->getOperand(1)))
+    return false;
+
+  // The shuffle sources (not the masks) must have the same type, otherwise the
+  // merged shuffle would get mismatched operands, e.g.
+  //   shufflevector <4 x i32> %x, <2 x i32> %y, ...
+  return TrueSV->getOperand(0)->getType() ==
+         FalseSV->getOperand(0)->getType();
+}
+
+/// Merges
+///   select C, (shufflevector T, undef, MaskT), (shufflevector F, undef, MaskF)
+/// into the single two-source shuffle
+///   shufflevector F, T, Mask
+/// where Mask[i] = MaskF[i] if C[i] is false and MaskT[i] + N if C[i] is true
+/// (N = number of elements of F and T). A lane whose chosen mask element is
+/// undef, or refers to the undef second operand, becomes undef.
+///
+/// Only call this after CanFoldSelectOfShuffleVectors returned true for SI's
+/// operands; the casts below rely on it. Returns nullptr if a condition
+/// element is not a ConstantInt (e.g. undef): `select undef, A, B` must still
+/// yield A or B, so such a lane cannot become an undef mask lane.
+static Instruction *FoldSelectOfShuffleVectors(SelectInst &SI) {
+  // Sources[0] is the false arm and Sources[1] the true arm, so a condition
+  // bit can be used directly as an index.
+  ShuffleVectorInst *Sources[] = {
+      cast<ShuffleVectorInst>(SI.getFalseValue()),
+      cast<ShuffleVectorInst>(SI.getTrueValue())};
+  ConstantVector *CondV = cast<ConstantVector>(SI.getCondition());
+
+  unsigned NumElems = cast<VectorType>(SI.getType())->getNumElements();
+  unsigned NumSourceElems =
+      cast<VectorType>(Sources[0]->getOperand(0)->getType())->getNumElements();
+  Type *Int32Ty = Type::getInt32Ty(SI.getContext());
+
+  // The mask is built from Constant* so that undef lanes are real undef
+  // elements rather than an out-of-range integer index.
+  SmallVector<Constant *, 16> Mask;
+  for (unsigned i = 0; i != NumElems; ++i) {
+    ConstantInt *Element = dyn_cast<ConstantInt>(CondV->getAggregateElement(i));
+    if (!Element)
+      return nullptr;
+
+    unsigned Selector = Element->isOne() ? 1 : 0;
+    int SourceIdx = Sources[Selector]->getMaskValue(i);
+    // getMaskValue returns -1 for undef. Indices >= NumSourceElems read the
+    // original shuffle's undef operand; offsetting them would make them read
+    // the other source (or go out of range), so they become undef too.
+    if (SourceIdx < 0 || unsigned(SourceIdx) >= NumSourceElems)
+      Mask.push_back(UndefValue::get(Int32Ty));
+    else
+      Mask.push_back(
+          ConstantInt::get(Int32Ty, SourceIdx + NumSourceElems * Selector));
+  }
+
+  return new ShuffleVectorInst(Sources[0]->getOperand(0),
+                               Sources[1]->getOperand(0),
+                               ConstantVector::get(Mask));
+}
+
 Instruction *InstCombiner::visitSelectInst(SelectInst &SI) {
   Value *CondVal = SI.getCondition();
   Value *TrueVal = SI.getTrueValue();
@@ -1005,6 +1087,13 @@
     if (isa<ConstantAggregateZero>(CondVal)) {
       return ReplaceInstUsesWith(SI, FalseVal);
     }
+
+    // With a constant condition, a select of two single-source shuffles is
+    // itself a shuffle of the two sources: merge them into one shufflevector
+    // and drop the select.
+    if (CanFoldSelectOfShuffleVectors(CondVal, TrueVal, FalseVal))
+      if (Instruction *I = FoldSelectOfShuffleVectors(SI))
+        return I;
   }
 
   return nullptr;
Index: test/Transforms/InstCombine/select.ll
===================================================================
--- test/Transforms/InstCombine/select.ll
+++ test/Transforms/InstCombine/select.ll
@@ -1031,3 +1031,44 @@
 ; CHECK: lshr exact i32 %2, 1
 ; CHECK: xor i32 %3, 42
 }
+
+define <4 x float> @add2f_0(<2 x float> %a0, <2 x float> %b0, <2 x float> %a1, <2 x float> %b1) {
+; CHECK-LABEL: @add2f_0
+; CHECK-NOT: select
+; CHECK: shufflevector <2 x float> %a0, <2 x float> %a1, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK: shufflevector <2 x float> %b0, <2 x float> %b1, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK: ret
+  %1 = shufflevector <2 x float> %a0, <2 x float> undef, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
+  %2 = shufflevector <2 x float> %a1, <2 x float> undef, <4 x i32> <i32 undef, i32 undef, i32 0, i32 1>
+  %3 = select <4 x i1> <i1 false, i1 false, i1 true, i1 true>, <4 x float> %2, <4 x float> %1
+  %4 = shufflevector <2 x float> %b0, <2 x float> undef, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
+  %5 = shufflevector <2 x float> %b1, <2 x float> undef, <4 x i32> <i32 undef, i32 undef, i32 0, i32 1>
+  %6 = select <4 x i1> <i1 false, i1 false, i1 true, i1 true>, <4 x float> %5, <4 x float> %4
+  %7 = fadd <4 x float> %3, %6
+  ret <4 x float> %7
+}
+
+;; This test might need to change if we implement the transform
+;; (shuffle (add x y) (add z w)) -> (add (shuffle x z) (shuffle y w))
+define <4 x float> @add2f_1(<2 x float> %a0, <2 x float> %b0, <2 x float> %a1, <2 x float> %b1) {
+; CHECK-LABEL: @add2f_1
+; CHECK-NOT: select
+; CHECK: [[ADD0:%[a-z0-9]+]] = fadd <2 x float> %a0, %b0
+; CHECK: [[ADD1:%[a-z0-9]+]] = fadd <2 x float> %a1, %b1
+; CHECK: [[RES:%[a-z0-9]+]] = shufflevector <2 x float> [[ADD0]], <2 x float> [[ADD1]], <4 x i32> <i32 0, i32 1, i32 2, i32 3>
+; CHECK: ret
+  %1 = fadd <2 x float> %a0, %b0
+  %2 = shufflevector <2 x float> %1, <2 x float> undef, <4 x i32> <i32 0, i32 1, i32 undef, i32 undef>
+  %3 = fadd <2 x float> %a1, %b1
+  %4 = shufflevector <2 x float> %3, <2 x float> undef, <4 x i32> <i32 undef, i32 undef, i32 0, i32 1>
+  %5 = select <4 x i1> <i1 false, i1 false, i1 true, i1 true>, <4 x float> %4, <4 x float> %2
+  ret <4 x float> %5
+}
+
+define <4 x float> @select_shuffle_undef_lanes(<2 x float> %a, <2 x float> %b) {
+; CHECK-LABEL: @select_shuffle_undef_lanes
+; CHECK-NOT: select
+; CHECK: shufflevector <2 x float> %a, <2 x float> %b, <4 x i32> <i32 0, i32 undef, i32 3, i32 undef>
+; CHECK: ret
+  %1 = shufflevector <2 x float> %a, <2 x float> undef, <4 x i32> <i32 0, i32 undef, i32 undef, i32 undef>
+  %2 = shufflevector <2 x float> %b, <2 x float> undef, <4 x i32> <i32 undef, i32 undef, i32 1, i32 undef>
+  %3 = select <4 x i1> <i1 false, i1 false, i1 true, i1 true>, <4 x float> %2, <4 x float> %1
+  ret <4 x float> %3
+}