diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -314,47 +314,73 @@
                             TI->getType());
   }
 
-  // Cond ? -X : -Y --> -(Cond ? X : Y)
-  Value *X, *Y;
-  if (match(TI, m_FNeg(m_Value(X))) && match(FI, m_FNeg(m_Value(Y))) &&
-      (TI->hasOneUse() || FI->hasOneUse())) {
-    // Intersect FMF from the fneg instructions and union those with the select.
-    FastMathFlags FMF = TI->getFastMathFlags();
-    FMF &= FI->getFastMathFlags();
-    FMF |= SI.getFastMathFlags();
-    Value *NewSel = Builder.CreateSelect(Cond, X, Y, SI.getName() + ".v", &SI);
-    if (auto *NewSelI = dyn_cast<Instruction>(NewSel))
-      NewSelI->setFastMathFlags(FMF);
-    Instruction *NewFNeg = UnaryOperator::CreateFNeg(NewSel);
-    NewFNeg->setFastMathFlags(FMF);
-    return NewFNeg;
-  }
-
-  // Min/max intrinsic with a common operand can have the common operand pulled
-  // after the select. This is the same transform as below for binops, but
-  // specialized for intrinsic matching and without the restrictive uses clause.
-  auto *TII = dyn_cast<IntrinsicInst>(TI);
-  auto *FII = dyn_cast<IntrinsicInst>(FI);
-  if (TII && FII && TII->getIntrinsicID() == FII->getIntrinsicID() &&
-      (TII->hasOneUse() || FII->hasOneUse())) {
-    Value *T0, *T1, *F0, *F1;
-    if (match(TII, m_MaxOrMin(m_Value(T0), m_Value(T1))) &&
-        match(FII, m_MaxOrMin(m_Value(F0), m_Value(F1)))) {
-      if (T0 == F0) {
-        Value *NewSel = Builder.CreateSelect(Cond, T1, F1, "minmaxop", &SI);
-        return CallInst::Create(TII->getCalledFunction(), {NewSel, T0});
-      }
-      if (T0 == F1) {
-        Value *NewSel = Builder.CreateSelect(Cond, T1, F0, "minmaxop", &SI);
-        return CallInst::Create(TII->getCalledFunction(), {NewSel, T0});
-      }
-      if (T1 == F0) {
-        Value *NewSel = Builder.CreateSelect(Cond, T0, F1, "minmaxop", &SI);
-        return CallInst::Create(TII->getCalledFunction(), {NewSel, T1});
-      }
-      if (T1 == F1) {
-        Value *NewSel = Builder.CreateSelect(Cond, T0, F0, "minmaxop", &SI);
-        return CallInst::Create(TII->getCalledFunction(), {NewSel, T1});
+  // Look for an operand shared by TI and FI, comparing only operands 0 and 1.
+  // On a match, return it and set OtherOpT/OtherOpF to the other operand of
+  // TI/FI; otherwise return null. Same-position matches are tried first;
+  // swapped-position matches are only tried when Commute is true.
+  // MatchIsOpZero is false only when the match is operand 1 of both.
+  Value *OtherOpT, *OtherOpF;
+  bool MatchIsOpZero;
+  auto getCommonOp = [&](Instruction *TI, Instruction *FI,
+                         bool Commute) -> Value * {
+    Value *CommonOp = nullptr;
+    if (TI->getOperand(0) == FI->getOperand(0)) {
+      CommonOp = TI->getOperand(0);
+      OtherOpT = TI->getOperand(1);
+      OtherOpF = FI->getOperand(1);
+      MatchIsOpZero = true;
+    } else if (TI->getOperand(1) == FI->getOperand(1)) {
+      CommonOp = TI->getOperand(1);
+      OtherOpT = TI->getOperand(0);
+      OtherOpF = FI->getOperand(0);
+      MatchIsOpZero = false;
+    } else if (!Commute) {
+      return nullptr;
+    } else if (TI->getOperand(0) == FI->getOperand(1)) {
+      CommonOp = TI->getOperand(0);
+      OtherOpT = TI->getOperand(1);
+      OtherOpF = FI->getOperand(0);
+      MatchIsOpZero = true;
+    } else if (TI->getOperand(1) == FI->getOperand(0)) {
+      CommonOp = TI->getOperand(1);
+      OtherOpT = TI->getOperand(0);
+      OtherOpF = FI->getOperand(1);
+      MatchIsOpZero = true;
+    }
+    return CommonOp;
+  };
+
+  if (TI->hasOneUse() || FI->hasOneUse()) {
+    // Cond ? -X : -Y --> -(Cond ? X : Y)
+    Value *X, *Y;
+    if (match(TI, m_FNeg(m_Value(X))) && match(FI, m_FNeg(m_Value(Y)))) {
+      // Intersect FMF from the fneg instructions and union those with the
+      // select.
+      FastMathFlags FMF = TI->getFastMathFlags();
+      FMF &= FI->getFastMathFlags();
+      FMF |= SI.getFastMathFlags();
+      Value *NewSel =
+          Builder.CreateSelect(Cond, X, Y, SI.getName() + ".v", &SI);
+      if (auto *NewSelI = dyn_cast<Instruction>(NewSel))
+        NewSelI->setFastMathFlags(FMF);
+      Instruction *NewFNeg = UnaryOperator::CreateFNeg(NewSel);
+      NewFNeg->setFastMathFlags(FMF);
+      return NewFNeg;
+    }
+
+    // Min/max intrinsic with a common operand can have the common operand
+    // pulled after the select. This is the same transform as below for binops,
+    // but specialized for intrinsic matching and without the restrictive uses
+    // clause.
+    auto *TII = dyn_cast<IntrinsicInst>(TI);
+    auto *FII = dyn_cast<IntrinsicInst>(FI);
+    if (TII && FII && TII->getIntrinsicID() == FII->getIntrinsicID()) {
+      if (match(TII, m_MaxOrMin(m_Value(), m_Value()))) {
+        if (Value *MatchOp = getCommonOp(TI, FI, true)) {
+          Value *NewSel =
+              Builder.CreateSelect(Cond, OtherOpT, OtherOpF, "minmaxop", &SI);
+          return CallInst::Create(TII->getCalledFunction(), {NewSel, MatchOp});
+        }
       }
     }
   }
@@ -370,33 +396,9 @@
     return nullptr;
 
   // Figure out if the operations have any operands in common.
-  Value *MatchOp, *OtherOpT, *OtherOpF;
-  bool MatchIsOpZero;
-  if (TI->getOperand(0) == FI->getOperand(0)) {
-    MatchOp = TI->getOperand(0);
-    OtherOpT = TI->getOperand(1);
-    OtherOpF = FI->getOperand(1);
-    MatchIsOpZero = true;
-  } else if (TI->getOperand(1) == FI->getOperand(1)) {
-    MatchOp = TI->getOperand(1);
-    OtherOpT = TI->getOperand(0);
-    OtherOpF = FI->getOperand(0);
-    MatchIsOpZero = false;
-  } else if (!TI->isCommutative()) {
-    return nullptr;
-  } else if (TI->getOperand(0) == FI->getOperand(1)) {
-    MatchOp = TI->getOperand(0);
-    OtherOpT = TI->getOperand(1);
-    OtherOpF = FI->getOperand(0);
-    MatchIsOpZero = true;
-  } else if (TI->getOperand(1) == FI->getOperand(0)) {
-    MatchOp = TI->getOperand(1);
-    OtherOpT = TI->getOperand(0);
-    OtherOpF = FI->getOperand(1);
-    MatchIsOpZero = true;
-  } else {
+  Value *MatchOp = getCommonOp(TI, FI, TI->isCommutative());
+  if (!MatchOp)
     return nullptr;
-  }
 
   // If the select condition is a vector, the operands of the original select's
   // operands also must be vectors. This may not be the case for getelementptr