diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1079,12 +1079,11 @@
 /// Reduce logic-of-compares with equality to a constant by substituting a
 /// common operand with the constant. Callers are expected to call this with
 /// Cmp0/Cmp1 switched to handle logic op commutativity.
-static Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1,
-                                          BinaryOperator &Logic,
-                                          InstCombiner::BuilderTy &Builder,
-                                          const SimplifyQuery &Q) {
-  bool IsAnd = Logic.getOpcode() == Instruction::And;
-  assert((IsAnd || Logic.getOpcode() == Instruction::Or) && "Wrong logic op");
+Value *InstCombinerImpl::foldAndOrOfICmpsWithConstEq(
+    ICmpInst *Cmp0, ICmpInst *Cmp1, const Instruction::BinaryOps Logic,
+    const SimplifyQuery &Q, bool SelectForm) {
+  bool IsAnd = Logic == Instruction::And;
+  assert((IsAnd || Logic == Instruction::Or) && "Wrong logic op");
 
   // Match an equality compare with a non-poison constant as Cmp0.
   // Also, give up if the compare can be constant-folded to avoid looping.
@@ -1119,7 +1118,15 @@
       return nullptr;
     SubstituteCmp = Builder.CreateICmp(Pred1, Y, C);
   }
-  return Builder.CreateBinOp(Logic.getOpcode(), Cmp0, SubstituteCmp);
+
+  // Emit a logical and/or with Cmp0 as the condition; when Cmp0 short-circuits
+  // to Short, poison in SubstituteCmp (e.g. from Y) is not propagated.
+  if (SelectForm) {
+    Constant *Short = ConstantInt::getBool(SubstituteCmp->getType(), !IsAnd);
+    return IsAnd ? Builder.CreateSelect(Cmp0, SubstituteCmp, Short)
+                 : Builder.CreateSelect(Cmp0, Short, SubstituteCmp);
+  }
+  return Builder.CreateBinOp(Logic, Cmp0, SubstituteCmp);
 }
 
 /// Fold (icmp)&(icmp) if possible.
@@ -1152,9 +1159,9 @@
   if (Value *V = foldLogOpOfMaskedICmps(LHS, RHS, true, Builder))
     return V;
 
-  if (Value *V = foldAndOrOfICmpsWithConstEq(LHS, RHS, And, Builder, Q))
+  if (Value *V = foldAndOrOfICmpsWithConstEq(LHS, RHS, Instruction::And, Q))
     return V;
-  if (Value *V = foldAndOrOfICmpsWithConstEq(RHS, LHS, And, Builder, Q))
+  if (Value *V = foldAndOrOfICmpsWithConstEq(RHS, LHS, Instruction::And, Q))
     return V;
 
   // E.g. (icmp sge x, 0) & (icmp slt x, n) --> icmp ult x, n
@@ -2396,9 +2403,9 @@
           Builder.CreateAdd(B, Constant::getAllOnesValue(B->getType())), A);
   }
 
-  if (Value *V = foldAndOrOfICmpsWithConstEq(LHS, RHS, Or, Builder, Q))
+  if (Value *V = foldAndOrOfICmpsWithConstEq(LHS, RHS, Instruction::Or, Q))
     return V;
-  if (Value *V = foldAndOrOfICmpsWithConstEq(RHS, LHS, Or, Builder, Q))
+  if (Value *V = foldAndOrOfICmpsWithConstEq(RHS, LHS, Instruction::Or, Q))
     return V;
 
   // E.g. (icmp slt x, 0) | (icmp sgt x, n) --> icmp ugt x, n
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
--- a/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
+++ b/llvm/lib/Transforms/InstCombine/InstCombineInternal.h
@@ -340,6 +340,17 @@
   Value *foldOrOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &Or);
   Value *foldXorOfICmps(ICmpInst *LHS, ICmpInst *RHS, BinaryOperator &Xor);
 
+  /// Reduce logic-of-compares with equality to a constant by substituting a
+  /// common operand with the constant. Callers are expected to call this with
+  /// Cmp0/Cmp1 switched to handle logic op commutativity.
+  /// If SelectForm is true, emit a logical and/or (select i1 with Cmp0 as the
+  /// condition) instead of a bitwise one, so poison in the substituted compare
+  /// does not leak when Cmp0 alone decides the result.
+  Value *foldAndOrOfICmpsWithConstEq(ICmpInst *Cmp0, ICmpInst *Cmp1,
+                                     Instruction::BinaryOps Logic,
+                                     const SimplifyQuery &Q,
+                                     bool SelectForm = false);
+
   /// Optimize (fcmp)&(fcmp) or (fcmp)|(fcmp).
   /// NOTE: Unlike most of instcombine, this returns a Value which should
   /// already be inserted into the function.
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2611,6 +2611,30 @@
     if (match(FalseVal, m_Not(m_Specific(CondVal))))
       return SelectInst::Create(FalseVal, ConstantInt::getTrue(SelType),
                                 TrueVal);
+
+    // select (X == C), (X pred Y), false -> select (X == C), (C pred Y), false
+    // select (X != C), true, (X pred Y) -> select (X != C), true, (C pred Y)
+    // The result must stay a select: a bitwise and/or would propagate poison
+    // from Y even when the condition alone decides the result.
+    if (auto *CondICmp = dyn_cast<ICmpInst>(CondVal)) {
+      const SimplifyQuery Q = SQ.getWithInstruction(&SI);
+      auto *FalseICmp = dyn_cast<ICmpInst>(FalseVal);
+      if (FalseICmp && match(TrueVal, m_One())) {
+        // Logical or: select Cond, true, FalseICmp.
+        if (Value *V = foldAndOrOfICmpsWithConstEq(CondICmp, FalseICmp,
+                                                   Instruction::Or, Q,
+                                                   /*SelectForm=*/true))
+          return replaceInstUsesWith(SI, V);
+      }
+      auto *TrueICmp = dyn_cast<ICmpInst>(TrueVal);
+      if (TrueICmp && match(FalseVal, m_Zero())) {
+        // Logical and: select Cond, TrueICmp, false.
+        if (Value *V = foldAndOrOfICmpsWithConstEq(CondICmp, TrueICmp,
+                                                   Instruction::And, Q,
+                                                   /*SelectForm=*/true))
+          return replaceInstUsesWith(SI, V);
+      }
+    }
   }
 
   // Selecting between two integer or vector splat integer constants?
diff --git a/llvm/test/Transforms/InstCombine/select-safe-transforms.ll b/llvm/test/Transforms/InstCombine/select-safe-transforms.ll
--- a/llvm/test/Transforms/InstCombine/select-safe-transforms.ll
+++ b/llvm/test/Transforms/InstCombine/select-safe-transforms.ll
@@ -17,9 +17,9 @@
 define i1 @cond_eq_and_const(i8 %X, i8 %Y) {
 ; CHECK-LABEL: @cond_eq_and_const(
 ; CHECK-NEXT:    [[COND:%.*]] = icmp eq i8 [[X:%.*]], 10
-; CHECK-NEXT:    [[LHS:%.*]] = icmp ult i8 [[X]], [[Y:%.*]]
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND]], i1 [[LHS]], i1 false
-; CHECK-NEXT:    ret i1 [[RES]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ugt i8 [[Y:%.*]], 10
+; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[COND]], i1 [[TMP1]], i1 false
+; CHECK-NEXT:    ret i1 [[TMP2]]
 ;
   %cond = icmp eq i8 %X, 10
   %lhs = icmp ult i8 %X, %Y
@@ -43,9 +43,9 @@
 define i1 @cond_eq_or_const(i8 %X, i8 %Y) {
 ; CHECK-LABEL: @cond_eq_or_const(
 ; CHECK-NEXT:    [[COND:%.*]] = icmp ne i8 [[X:%.*]], 10
-; CHECK-NEXT:    [[LHS:%.*]] = icmp ult i8 [[X]], [[Y:%.*]]
-; CHECK-NEXT:    [[RES:%.*]] = select i1 [[COND]], i1 true, i1 [[LHS]]
-; CHECK-NEXT:    ret i1 [[RES]]
+; CHECK-NEXT:    [[TMP1:%.*]] = icmp ugt i8 [[Y:%.*]], 10
+; CHECK-NEXT:    [[TMP2:%.*]] = select i1 [[COND]], i1 true, i1 [[TMP1]]
+; CHECK-NEXT:    ret i1 [[TMP2]]
 ;
   %cond = icmp ne i8 %X, 10
   %lhs = icmp ult i8 %X, %Y