diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -1028,21 +1028,39 @@
   return SelectInst::Create(X, TVal, FVal);
 }
 
-static Constant *constantFoldOperationIntoSelectOperand(
-    Instruction &I, SelectInst *SI, Value *SO) {
-  auto *ConstSO = dyn_cast<Constant>(SO);
-  if (!ConstSO)
-    return nullptr;
-
+/// Try to constant fold \p I as if the select \p SI (one of its operands) had
+/// taken its true arm (\p IsTrueArm) or its false arm. On that arm SI is
+/// replaced by the arm value, and any other operand X is replaced by C when
+/// SI's condition is `icmp eq X, C` (true arm) or `icmp ne X, C` (false arm).
+/// Returns nullptr if an operand is not constant on that arm or folding fails.
+static Constant *constantFoldOperationIntoSelectOperand(Instruction &I,
+                                                        SelectInst *SI,
+                                                        bool IsTrueArm) {
   SmallVector<Constant *> ConstOps;
   for (Value *Op : I.operands()) {
-    if (Op == SI)
-      ConstOps.push_back(ConstSO);
-    else if (auto *C = dyn_cast<Constant>(Op))
-      ConstOps.push_back(C);
-    else
+    CmpInst::Predicate Pred;
+    Constant *C = nullptr;
+    if (Op == SI) {
+      C = dyn_cast<Constant>(IsTrueArm ? SI->getTrueValue()
+                                       : SI->getFalseValue());
+    } else if (match(SI->getCondition(),
+                     m_ICmp(Pred, m_Specific(Op), m_Constant(C))) &&
+               Pred == (IsTrueArm ? ICmpInst::ICMP_EQ : ICmpInst::ICMP_NE) &&
+               isGuaranteedNotToBeUndefOrPoison(C)) {
+      // The condition proves Op == C on this arm, so fold with C in place of
+      // Op. C must not be (or contain) undef/poison: `icmp eq Op, undef` may
+      // hold only for one particular choice of undef, and re-materializing
+      // undef in the folded operation could pick a value Op never has.
+    } else {
+      // Use Op as-is (this also discards any C bound by the match above).
+      C = dyn_cast<Constant>(Op);
+    }
+    if (C == nullptr)
       return nullptr;
+
+    ConstOps.push_back(C);
   }
+
   return ConstantFoldInstOperands(&I, ConstOps, I.getModule()->getDataLayout());
 }
 
@@ -1085,8 +1103,10 @@
   }
 
   // Make sure that one of the select arms constant folds successfully.
-  Value *NewTV = constantFoldOperationIntoSelectOperand(Op, SI, TV);
-  Value *NewFV = constantFoldOperationIntoSelectOperand(Op, SI, FV);
+  Value *NewTV =
+      constantFoldOperationIntoSelectOperand(Op, SI, /*IsTrueArm=*/true);
+  Value *NewFV =
+      constantFoldOperationIntoSelectOperand(Op, SI, /*IsTrueArm=*/false);
   if (!NewTV && !NewFV)
     return nullptr;
 
diff --git a/llvm/test/Transforms/InstCombine/binop-select.ll b/llvm/test/Transforms/InstCombine/binop-select.ll
--- a/llvm/test/Transforms/InstCombine/binop-select.ll
+++ b/llvm/test/Transforms/InstCombine/binop-select.ll
@@ -73,8 +73,8 @@
 define i32 @test_sub_deduce_true(i32 %x, i32 %y) {
 ; CHECK-LABEL: @test_sub_deduce_true(
 ; CHECK-NEXT:    [[C:%.*]] = icmp eq i32 [[X:%.*]], 9
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[C]], i32 6, i32 [[Y:%.*]]
-; CHECK-NEXT:    [[SUB:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[X]], i32 [[COND]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[X]], i32 [[Y:%.*]])
+; CHECK-NEXT:    [[SUB:%.*]] = select i1 [[C]], i32 15, i32 [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[SUB]]
 ;
   %c = icmp eq i32 %x, 9
@@ -99,8 +99,8 @@
 define i32 @test_sub_deduce_false(i32 %x, i32 %y) {
 ; CHECK-LABEL: @test_sub_deduce_false(
 ; CHECK-NEXT:    [[C_NOT:%.*]] = icmp eq i32 [[X:%.*]], 9
-; CHECK-NEXT:    [[COND:%.*]] = select i1 [[C_NOT]], i32 7, i32 [[Y:%.*]]
-; CHECK-NEXT:    [[SUB:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[X]], i32 [[COND]])
+; CHECK-NEXT:    [[TMP1:%.*]] = call i32 @llvm.sadd.sat.i32(i32 [[X]], i32 [[Y:%.*]])
+; CHECK-NEXT:    [[SUB:%.*]] = select i1 [[C_NOT]], i32 16, i32 [[TMP1]]
 ; CHECK-NEXT:    ret i32 [[SUB]]
 ;
   %c = icmp ne i32 %x, 9