diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -516,6 +516,71 @@
   return nullptr;
 }
 
+/// In the case of a comparison with two select instructions having the same
+/// condition, try to simplify the comparison by comparing the true arms and
+/// the false arms of the selects separately. Since both selects always pick
+/// the same arm, "cmp (select C, A, B), (select C, X, Y)" is equivalent to
+/// "select C, (cmp A, X), (cmp B, Y)". Returns the simplified value if one is
+/// found, otherwise returns null.
+/// For example, if we have:
+///  %tmp1 = select i1 %cmp, i32 1, i32 %param
+///  %tmp2 = select i1 %cmp, i32 9, i32 %param
+///  %cmp2 = icmp slt i32 %tmp2, %tmp1
+/// We can simplify %cmp2 to false, because comparing the true arms
+/// ("slt 9, 1") and comparing the false arms ("slt %param, %param") both
+/// yield false.
+static Value *threadCmpOverMultSelect(CmpInst::Predicate Pred, Value *LHS,
+                                      Value *RHS, const SimplifyQuery &Q,
+                                      unsigned MaxRecurse) {
+  // Recursion is always used, so bail out at once if we already hit the limit.
+  if (!MaxRecurse--)
+    return nullptr;
+
+  // Make sure both LHS and RHS are selects.
+  assert(isa<SelectInst>(LHS) && isa<SelectInst>(RHS) &&
+         "Not comparing with select instructions!");
+  auto *LHSSI = cast<SelectInst>(LHS);
+  auto *RHSSI = cast<SelectInst>(RHS);
+
+  // The arms can only be paired up when both selects share the condition.
+  Value *Cond = LHSSI->getCondition();
+  if (Cond != RHSSI->getCondition())
+    return nullptr;
+
+  Value *LHSTV = LHSSI->getTrueValue();
+  Value *LHSFV = LHSSI->getFalseValue();
+  Value *RHSTV = RHSSI->getTrueValue();
+  Value *RHSFV = RHSSI->getFalseValue();
+
+  // Now that we have "cmp select(Cond, LHSTV, LHSFV),
+  //                       select(Cond, RHSTV, RHSFV)",
+  // analyse it.
+  // Does "cmp LHSTV, RHSTV" simplify?
+  Value *TCmp = simplifyCmpSelTrueCase(Pred, LHSTV, RHSTV, Cond, Q, MaxRecurse);
+  if (!TCmp)
+    return nullptr;
+
+  // Does "cmp LHSFV, RHSFV" simplify?
+  Value *FCmp =
+      simplifyCmpSelFalseCase(Pred, LHSFV, RHSFV, Cond, Q, MaxRecurse);
+  if (!FCmp)
+    return nullptr;
+
+  // If both sides simplified to the same value, then use it as the result of
+  // the original comparison.
+  if (TCmp == FCmp)
+    return TCmp;
+
+  // Otherwise the comparison is equivalent to "select Cond, TCmp, FCmp", so
+  // try the same logical folds as threadCmpOverSelect. These require the
+  // select condition and the comparison result to be both scalars or both
+  // vectors, so bail out otherwise.
+  if (Cond->getType()->isVectorTy() == LHS->getType()->isVectorTy())
+    return handleOtherCmpSelSimplifications(TCmp, FCmp, Cond, Q, MaxRecurse);
+
+  return nullptr;
+}
+
 /// In the case of a binary operation with an operand that is a PHI instruction,
 /// try to simplify the binop by seeing whether evaluating it on the incoming
 /// phi values yields the same result for every value. If so returns the common
@@ -3964,6 +4029,13 @@
     if (Value *V = threadCmpOverSelect(Pred, LHS, RHS, Q, MaxRecurse))
       return V;
 
+  // If the comparison is with the result of two select instructions with the
+  // same condition, check whether comparing the corresponding arms of the
+  // selects lets the comparison simplify.
+  if (isa<SelectInst>(LHS) && isa<SelectInst>(RHS))
+    if (Value *V = threadCmpOverMultSelect(Pred, LHS, RHS, Q, MaxRecurse))
+      return V;
+
   // If the comparison is with the result of a phi instruction, check whether
   // doing the compare with each incoming phi value yields a common result.
   if (isa<PHINode>(LHS) || isa<PHINode>(RHS))
diff --git a/llvm/test/Transforms/InstSimplify/icmp-with-selects.ll b/llvm/test/Transforms/InstSimplify/icmp-with-selects.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/InstSimplify/icmp-with-selects.ll
@@ -0,0 +1,63 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 2
+; RUN: opt < %s -passes=instsimplify -S | FileCheck %s
+
+define i32 @_Z3fooib(i32 noundef %param, i1 noundef zeroext %cond) {
+; CHECK-LABEL: define i32 @_Z3fooib
+; CHECK-SAME: (i32 noundef [[PARAM:%.*]], i1 noundef zeroext [[COND:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    ret i32 0
+;
+entry:
+  %cond1 = select i1 %cond, i32 1, i32 %param
+  %cond6 = select i1 %cond, i32 9, i32 %param
+  %cmp = icmp slt i32 %cond6, %cond1
+  %cond7 = zext i1 %cmp to i32
+  ret i32 %cond7
+}
+
+define i32 @_Z3barib(i32 noundef %param, i1 noundef zeroext %cond) {
+; CHECK-LABEL: define i32 @_Z3barib
+; CHECK-SAME: (i32 noundef [[PARAM:%.*]], i1 noundef zeroext [[COND:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    ret i32 0
+;
+entry:
+  %cond1 = select i1 %cond, i32 9, i32 %param
+  %cond6 = select i1 %cond, i32 1, i32 %param
+  %cmp = icmp sgt i32 %cond6, %cond1
+  %cond7 = zext i1 %cmp to i32
+  ret i32 %cond7
+}
+
+; The true arms compare to true and the false arms to false, so the compare
+; folds to the shared select condition.
+define i1 @cmp_selects_same_cond_fold_to_cond(i32 %param, i1 %cond) {
+; CHECK-LABEL: define i1 @cmp_selects_same_cond_fold_to_cond
+; CHECK-SAME: (i32 [[PARAM:%.*]], i1 [[COND:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    ret i1 [[COND]]
+;
+entry:
+  %sel1 = select i1 %cond, i32 1, i32 %param
+  %sel2 = select i1 %cond, i32 9, i32 %param
+  %cmp = icmp slt i32 %sel1, %sel2
+  ret i1 %cmp
+}
+
+; Negative test: the selects use different conditions, so the arms cannot be
+; paired up.
+define i1 @cmp_selects_different_cond(i32 %param, i1 %cond1, i1 %cond2) {
+; CHECK-LABEL: define i1 @cmp_selects_different_cond
+; CHECK-SAME: (i32 [[PARAM:%.*]], i1 [[COND1:%.*]], i1 [[COND2:%.*]]) {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[SEL1:%.*]] = select i1 [[COND1]], i32 9, i32 [[PARAM]]
+; CHECK-NEXT:    [[SEL2:%.*]] = select i1 [[COND2]], i32 1, i32 [[PARAM]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp slt i32 [[SEL1]], [[SEL2]]
+; CHECK-NEXT:    ret i1 [[CMP]]
+;
+entry:
+  %sel1 = select i1 %cond1, i32 9, i32 %param
+  %sel2 = select i1 %cond2, i32 1, i32 %param
+  %cmp = icmp slt i32 %sel1, %sel2
+  ret i1 %cmp
+}