# Review notes for the patch below (InstCombineSelect.cpp, foldSelectIntoOp).
#
# Purpose: the two near-identical blocks are replaced by one lambda,
# TryFoldSelectIntoOp(SI, TrueVal, FalseVal, Swapped). One removed block handled
# a one-use binary operator on the true arm (TVI); the other handled one on the
# false arm (FVI). The lambda is called twice:
#   1. with the arms as given (Swapped = false);
#   2. with the arms exchanged (Swapped = true).
#
# Inside the lambda:
#   - C is the operator's identity constant (ConstantExpr::getBinOpIdentity).
#   - OOp is the operator's other operand.
#   - Swapped only changes operand order. The new select is emitted as
#     `select Cond, C, OOp` instead of `select Cond, OOp, C`. This reproduces
#     the removed false-arm code.
#   - The replacement is `BinaryOperator::Create(<opcode>, <other arm>, NewSel)`.
#     It copies the original operator's IR flags, and NewSel takes that
#     operator's name.
#
# Apart from formatting (dropped braces, `2-OpToFold` -> `2 - OpToFold`) and
# the added TODO about differing fast-math flags, the lambda body mirrors the
# removed blocks. It appears intended as a behavior-preserving (NFC) refactor.
# TODO confirm with the existing InstCombine select tests.
#
# NOTE(review): the lambda's parameters SI, TrueVal and FalseVal shadow the
# enclosing function's names (FalseVal is visibly a parameter there). Builds
# with -Wshadow may warn. Distinct parameter names would avoid the ambiguity.
#
# NOTE(review): this copy of the patch appears damaged, so it will not apply
# or compile as shown. Regenerate it from the original commit before use:
#   - line breaks are collapsed;
#   - template arguments are missing, e.g. `dyn_cast(TrueVal)`,
#     `isa(FalseVal)`, `cast(NewSel)`.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp @@ -439,79 +439,56 @@ Value *FalseVal) { // See the comment above GetSelectFoldableOperands for a description of the // transformation we are doing here. - if (auto *TVI = dyn_cast(TrueVal)) { - if (TVI->hasOneUse() && !isa(FalseVal)) { - if (unsigned SFO = getSelectFoldableOperands(TVI)) { - unsigned OpToFold = 0; - if ((SFO & 1) && FalseVal == TVI->getOperand(0)) { - OpToFold = 1; - } else if ((SFO & 2) && FalseVal == TVI->getOperand(1)) { - OpToFold = 2; - } - - if (OpToFold) { - FastMathFlags FMF; - if (isa(&SI)) - FMF = SI.getFastMathFlags(); - Constant *C = ConstantExpr::getBinOpIdentity( - TVI->getOpcode(), TVI->getType(), true, FMF.noSignedZeros()); - Value *OOp = TVI->getOperand(2-OpToFold); - // Avoid creating select between 2 constants unless it's selecting - // between 0, 1 and -1. - const APInt *OOpC; - bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); - if (!isa(OOp) || - (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) { - Value *NewSel = Builder.CreateSelect(SI.getCondition(), OOp, C); + auto TryFoldSelectIntoOp = [&](SelectInst &SI, Value *TrueVal, + Value *FalseVal, + bool Swapped) -> Instruction * { + if (auto *TVI = dyn_cast(TrueVal)) { + if (TVI->hasOneUse() && !isa(FalseVal)) { + if (unsigned SFO = getSelectFoldableOperands(TVI)) { + unsigned OpToFold = 0; + if ((SFO & 1) && FalseVal == TVI->getOperand(0)) + OpToFold = 1; + else if ((SFO & 2) && FalseVal == TVI->getOperand(1)) + OpToFold = 2; + + if (OpToFold) { + FastMathFlags FMF; + // TODO: We probably ought to revisit cases where the select and FP + // instructions have different flags and add tests to ensure the + // behaviour is correct.
if (isa(&SI)) - cast(NewSel)->setFastMathFlags(FMF); - NewSel->takeName(TVI); - BinaryOperator *BO = BinaryOperator::Create(TVI->getOpcode(), - FalseVal, NewSel); - BO->copyIRFlags(TVI); - return BO; + FMF = SI.getFastMathFlags(); + Constant *C = ConstantExpr::getBinOpIdentity( + TVI->getOpcode(), TVI->getType(), true, FMF.noSignedZeros()); + Value *OOp = TVI->getOperand(2 - OpToFold); + // Avoid creating select between 2 constants unless it's selecting + // between 0, 1 and -1. + const APInt *OOpC; + bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); + if (!isa(OOp) || + (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) { + Value *NewSel = Builder.CreateSelect( + SI.getCondition(), Swapped ? C : OOp, Swapped ? OOp : C); + if (isa(&SI)) + cast(NewSel)->setFastMathFlags(FMF); + NewSel->takeName(TVI); + BinaryOperator *BO = + BinaryOperator::Create(TVI->getOpcode(), FalseVal, NewSel); + BO->copyIRFlags(TVI); + return BO; + } } } } } - } + return nullptr; + }; - if (auto *FVI = dyn_cast(FalseVal)) { - if (FVI->hasOneUse() && !isa(TrueVal)) { - if (unsigned SFO = getSelectFoldableOperands(FVI)) { - unsigned OpToFold = 0; - if ((SFO & 1) && TrueVal == FVI->getOperand(0)) { - OpToFold = 1; - } else if ((SFO & 2) && TrueVal == FVI->getOperand(1)) { - OpToFold = 2; - } + if (Instruction *R = TryFoldSelectIntoOp(SI, TrueVal, FalseVal, false)) + return R; - if (OpToFold) { - FastMathFlags FMF; - if (isa(&SI)) - FMF = SI.getFastMathFlags(); - Constant *C = ConstantExpr::getBinOpIdentity( - FVI->getOpcode(), FVI->getType(), true, FMF.noSignedZeros()); - Value *OOp = FVI->getOperand(2-OpToFold); - // Avoid creating select between 2 constants unless it's selecting - // between 0, 1 and -1.
- const APInt *OOpC; - bool OOpIsAPInt = match(OOp, m_APInt(OOpC)); - if (!isa(OOp) || - (OOpIsAPInt && isSelect01(C->getUniqueInteger(), *OOpC))) { - Value *NewSel = Builder.CreateSelect(SI.getCondition(), C, OOp); - if (isa(&SI)) - cast(NewSel)->setFastMathFlags(FMF); - NewSel->takeName(FVI); - BinaryOperator *BO = BinaryOperator::Create(FVI->getOpcode(), - TrueVal, NewSel); - BO->copyIRFlags(FVI); - return BO; - } - } - } - } - } + if (Instruction *R = TryFoldSelectIntoOp(SI, FalseVal, TrueVal, true)) + return R; return nullptr; }