diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1026,22 +1026,49 @@
   return SelectInst::Create(X, TVal, FVal);
 }
 
-static Constant *constantFoldOperationIntoSelectOperand(
-    Instruction &I, SelectInst *SI, Value *SO) {
-  auto *ConstSO = dyn_cast<Constant>(SO);
-  if (!ConstSO)
-    return nullptr;
-
-  SmallVector<Constant *> ConstOps;
+/// Return the operands \p I would have if the select \p SI resolved to its
+/// true arm (\p IsTrueArm) or its false arm. \p SI itself is replaced by that
+/// arm's value, and any other operand that the select condition pins to a
+/// constant on that arm (`icmp eq Op, C` on the true arm, `icmp ne Op, C` on
+/// the false arm) is replaced by that constant. The constant must not be or
+/// contain undef/poison: comparing against undef does not constrain Op, so
+/// substituting it could let the arm fold to a less-defined value than \p I
+/// actually produces.
+static SmallVector<Value *>
+getInstSelectArmOperands(Instruction &I, SelectInst *SI, bool IsTrueArm) {
+  SmallVector<Value *> Ops;
   for (Value *Op : I.operands()) {
+    CmpInst::Predicate Pred;
+    Constant *CondC;
     if (Op == SI)
-      ConstOps.push_back(ConstSO);
-    else if (auto *C = dyn_cast<Constant>(Op))
+      Ops.push_back(IsTrueArm ? SI->getTrueValue() : SI->getFalseValue());
+    else if (match(SI->getCondition(),
+                   m_ICmp(Pred, m_Specific(Op), m_Constant(CondC))) &&
+             Pred == (IsTrueArm ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE) &&
+             isGuaranteedNotToBeUndefOrPoison(CondC))
+      Ops.push_back(CondC);
+    else
+      Ops.push_back(Op);
+  }
+
+  return Ops;
+}
+
+/// Try to constant fold \p I with its operands replaced as computed by
+/// getInstSelectArmOperands for the given arm of \p SI. Returns null if any
+/// resulting operand is not a Constant or if constant folding fails.
+static Constant *constantFoldOperationIntoSelectOperand(Instruction &I,
+                                                        SelectInst *SI,
+                                                        bool IsTrueArm) {
+  SmallVector<Constant *> ConstOps;
+  SmallVector<Value *> Ops = getInstSelectArmOperands(I, SI, IsTrueArm);
+  for (Value *Op : Ops) {
+    if (auto *C = dyn_cast<Constant>(Op))
       ConstOps.push_back(C);
     else
       return nullptr;
   }
   return ConstantFoldInstOperands(&I, ConstOps, I.getModule()->getDataLayout());
 }
 
 static Value *foldOperationIntoSelectOperand(Instruction &I, SelectInst *SI,
@@ -1083,8 +1110,10 @@
   }
 
   // Make sure that one of the select arms constant folds successfully.
-  Value *NewTV = constantFoldOperationIntoSelectOperand(Op, SI, TV);
-  Value *NewFV = constantFoldOperationIntoSelectOperand(Op, SI, FV);
+  Value *NewTV =
+      constantFoldOperationIntoSelectOperand(Op, SI, /*IsTrueArm=*/true);
+  Value *NewFV =
+      constantFoldOperationIntoSelectOperand(Op, SI, /*IsTrueArm=*/false);
   if (!NewTV && !NewFV)
     return nullptr;
 