diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -2772,12 +2772,89 @@
       return true;
     break;
   }
-  case Instruction::Select:
+  case Instruction::Select: {
     // (C ? X : Y) != 0 if X != 0 and Y != 0.
-    if (isKnownNonZero(I->getOperand(1), DemandedElts, Depth, Q) &&
-        isKnownNonZero(I->getOperand(2), DemandedElts, Depth, Q))
+    bool Op1NonZero = isKnownNonZero(I->getOperand(1), DemandedElts, Depth, Q);
+    bool Op2NonZero = isKnownNonZero(I->getOperand(2), DemandedElts, Depth, Q);
+    if (Op1NonZero && Op2NonZero)
+      return true;
+
+    // Returns true iff MaybeC is a Constant for which isNullValue() holds.
+    auto IsConstZero = [](Value *MaybeC) {
+      if (auto *C = dyn_cast<Constant>(MaybeC))
+        return C->isNullValue();
+      return false;
+    };
+
+    // The condition of the select dominates the true/false arm. Check if the
+    // condition implies that a given arm is non-zero.
+    auto SelectCondImpliesNonZero = [&](bool IsTrueArm) {
+      Value *X, *Op;
+      CmpInst::Predicate Pred;
+
+      Op = IsTrueArm ? I->getOperand(1) : I->getOperand(2);
+      // TODO: We currently only check if the exact value used in the `icmp`
+      // matches one of the arms. We could go further and see if the value used
+      // in the condition being non-zero implies one of the arms is non-zero
+      // (through a use).
+      if (match(I->getOperand(0), m_ICmp(Pred, m_Specific(Op), m_Value(X)))) {
+        // The arm is the LHS: Pred already reads as `Op Pred X`.
+      } else if (match(I->getOperand(0),
+                       m_ICmp(Pred, m_Value(X), m_Specific(Op)))) {
+        // The arm is the RHS: swap so that Pred reads as `Op Pred X`.
+        Pred = ICmpInst::getSwappedPredicate(Pred);
+      } else {
+        return false;
+      }
+
+      // On the false arm the inverse of the condition holds.
+      if (!IsTrueArm)
+        Pred = ICmpInst::getInversePredicate(Pred);
+
+      // From here on the known fact is `Op Pred X`.
+      switch (Pred) {
+      case ICmpInst::ICMP_NE:
+        // Op != X: only conclusive when X is the zero constant.
+        return IsConstZero(X);
+      case ICmpInst::ICMP_UGT:
+        // Op >u X: zero is never unsigned-greater-than anything.
+        return true;
+      case ICmpInst::ICMP_EQ:
+      case ICmpInst::ICMP_UGE:
+        // Op == X or Op >=u X, with X != 0.
+        return isKnownNonZero(X, DemandedElts, Depth, Q);
+
+      case ICmpInst::ICMP_SGT:
+        // Op >s X with X >= 0.
+        return computeKnownBits(X, DemandedElts, Depth, Q).isNonNegative();
+      case ICmpInst::ICMP_SGE: {
+        // Op >=s X with X > 0.
+        KnownBits KnownX = computeKnownBits(X, DemandedElts, Depth, Q);
+        return KnownX.isNonNegative() &&
+               (KnownX.isNonZero() ||
+                isKnownNonZero(X, DemandedElts, Depth, Q));
+      }
+
+      case ICmpInst::ICMP_SLT:
+        // Op <s 0; otherwise fall through to the `X < 0` check.
+        if (IsConstZero(X))
+          return true;
+        [[fallthrough]];
+      case ICmpInst::ICMP_SLE:
+        // Op <=s X (or Op <s X) with X < 0.
+        return computeKnownBits(X, DemandedElts, Depth, Q).isNegative();
+      default:
+        return false;
+      }
+    };
+
+    // One arm is known non-zero; try to prove the other from the condition.
+    if (Op2NonZero && SelectCondImpliesNonZero(/* IsTrueArm */ true))
+      return true;
+    if (Op1NonZero && SelectCondImpliesNonZero(/* IsTrueArm */ false))
       return true;
     break;
+  }
   case Instruction::PHI: {
     auto *PN = cast<PHINode>(I);
     if (Q.IIQ.UseInstrInfo && isNonZeroRecurrence(PN))
diff --git a/llvm/test/Analysis/ValueTracking/select-known-non-zero.ll b/llvm/test/Analysis/ValueTracking/select-known-non-zero.ll
--- a/llvm/test/Analysis/ValueTracking/select-known-non-zero.ll
+++ b/llvm/test/Analysis/ValueTracking/select-known-non-zero.ll
@@ -6,10 +6,7 @@
 ; CHECK-LABEL: @select_v_ne_z(
 ; CHECK-NEXT:    [[YNZ:%.*]] = icmp ne i8 [[Y:%.*]], 0
 ; CHECK-NEXT:    call void @llvm.assume(i1 [[YNZ]])
-; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i8 [[V:%.*]], 0
-; CHECK-NEXT:    [[S:%.*]] = select i1 [[CMP]], i8 [[V]], i8 [[Y]]
-; CHECK-NEXT:    [[R:%.*]] = icmp eq i8 [[S]], 0
-; CHECK-NEXT:    ret i1 [[R]]
+; CHECK-NEXT:    ret i1 false
 ;
   %ynz = icmp ne i8 %y, 0
   call void
@llvm.assume(i1 %ynz) @@ -42,10 +39,7 @@ ; CHECK-NEXT: call void @llvm.assume(i1 [[YNZ]]) ; CHECK-NEXT: [[PCOND0:%.*]] = icmp ne i8 [[C:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[PCOND0]]) -; CHECK-NEXT: [[CMP:%.*]] = icmp eq i8 [[C]], [[V:%.*]] -; CHECK-NEXT: [[S:%.*]] = select i1 [[CMP]], i8 [[V]], i8 [[Y]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[S]], 0 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 false ; %ynz = icmp ne i8 %y, 0 call void @llvm.assume(i1 %ynz) @@ -63,10 +57,7 @@ ; CHECK-NEXT: call void @llvm.assume(i1 [[YNZ]]) ; CHECK-NEXT: [[PCOND0:%.*]] = icmp ne i8 [[C:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[PCOND0]]) -; CHECK-NEXT: [[CMP:%.*]] = icmp ugt i8 [[C]], [[V:%.*]] -; CHECK-NEXT: [[S:%.*]] = select i1 [[CMP]], i8 [[Y]], i8 [[V]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[S]], 0 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 false ; %ynz = icmp ne i8 %y, 0 call void @llvm.assume(i1 %ynz) @@ -103,10 +94,7 @@ ; CHECK-LABEL: @select_v_ult( ; CHECK-NEXT: [[YNZ:%.*]] = icmp ne i8 [[Y:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[YNZ]]) -; CHECK-NEXT: [[CMP:%.*]] = icmp ult i8 [[C:%.*]], [[V:%.*]] -; CHECK-NEXT: [[S:%.*]] = select i1 [[CMP]], i8 [[V]], i8 [[Y]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[S]], 0 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 false ; %ynz = icmp ne i8 %y, 0 call void @llvm.assume(i1 %ynz) @@ -122,10 +110,7 @@ ; CHECK-NEXT: call void @llvm.assume(i1 [[YNZ]]) ; CHECK-NEXT: [[PCOND0:%.*]] = icmp ne i8 [[C:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[PCOND0]]) -; CHECK-NEXT: [[CMP:%.*]] = icmp uge i8 [[V:%.*]], [[C]] -; CHECK-NEXT: [[S:%.*]] = select i1 [[CMP]], i8 [[V]], i8 [[Y]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[S]], 0 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 false ; %ynz = icmp ne i8 %y, 0 call void @llvm.assume(i1 %ynz) @@ -141,10 +126,7 @@ ; CHECK-LABEL: @inv_select_v_ule( ; CHECK-NEXT: [[YNZ:%.*]] = icmp ne i8 [[Y:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[YNZ]]) -; 
CHECK-NEXT: [[CMP:%.*]] = icmp ule i8 [[V:%.*]], [[C:%.*]] -; CHECK-NEXT: [[S:%.*]] = select i1 [[CMP]], i8 [[Y]], i8 [[V]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[S]], 0 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 false ; %ynz = icmp ne i8 %y, 0 call void @llvm.assume(i1 %ynz) @@ -160,10 +142,7 @@ ; CHECK-NEXT: call void @llvm.assume(i1 [[YNZ]]) ; CHECK-NEXT: [[PCOND0:%.*]] = icmp sge i8 [[C:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[PCOND0]]) -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[V:%.*]], [[C]] -; CHECK-NEXT: [[S:%.*]] = select i1 [[CMP]], i8 [[V]], i8 [[Y]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[S]], 0 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 false ; %ynz = icmp ne i8 %y, 0 call void @llvm.assume(i1 %ynz) @@ -202,10 +181,7 @@ ; CHECK-NEXT: call void @llvm.assume(i1 [[YNZ]]) ; CHECK-NEXT: [[PCOND0:%.*]] = icmp slt i8 [[C:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[PCOND0]]) -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[V:%.*]], [[C]] -; CHECK-NEXT: [[S:%.*]] = select i1 [[CMP]], i8 [[Y]], i8 [[V]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[S]], 0 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 false ; %ynz = icmp ne i8 %y, 0 call void @llvm.assume(i1 %ynz) @@ -223,10 +199,7 @@ ; CHECK-NEXT: call void @llvm.assume(i1 [[YNZ]]) ; CHECK-NEXT: [[PCOND0:%.*]] = icmp sgt i8 [[C:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[PCOND0]]) -; CHECK-NEXT: [[CMP:%.*]] = icmp sgt i8 [[C]], [[V:%.*]] -; CHECK-NEXT: [[S:%.*]] = select i1 [[CMP]], i8 [[Y]], i8 [[V]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[S]], 0 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 false ; %ynz = icmp ne i8 %y, 0 call void @llvm.assume(i1 %ynz) @@ -244,10 +217,7 @@ ; CHECK-NEXT: call void @llvm.assume(i1 [[YNZ]]) ; CHECK-NEXT: [[PCOND0:%.*]] = icmp sge i8 [[C:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[PCOND0]]) -; CHECK-NEXT: [[CMP:%.*]] = icmp slt i8 [[C]], [[V:%.*]] -; CHECK-NEXT: [[S:%.*]] = select i1 [[CMP]], i8 [[V]], i8 [[Y]] -; CHECK-NEXT: [[R:%.*]] = 
icmp eq i8 [[S]], 0 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 false ; %ynz = icmp ne i8 %y, 0 call void @llvm.assume(i1 %ynz) @@ -286,10 +256,7 @@ ; CHECK-NEXT: call void @llvm.assume(i1 [[YNZ]]) ; CHECK-NEXT: [[PCOND0:%.*]] = icmp slt i8 [[C:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[PCOND0]]) -; CHECK-NEXT: [[CMP:%.*]] = icmp slt i8 [[V:%.*]], [[C]] -; CHECK-NEXT: [[S:%.*]] = select i1 [[CMP]], i8 [[V]], i8 [[Y]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[S]], 0 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 false ; %ynz = icmp ne i8 %y, 0 call void @llvm.assume(i1 %ynz) @@ -307,10 +274,7 @@ ; CHECK-NEXT: call void @llvm.assume(i1 [[YNZ]]) ; CHECK-NEXT: [[PCOND0:%.*]] = icmp sgt i8 [[C:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[PCOND0]]) -; CHECK-NEXT: [[CMP:%.*]] = icmp sge i8 [[V:%.*]], [[C]] -; CHECK-NEXT: [[S:%.*]] = select i1 [[CMP]], i8 [[V]], i8 [[Y]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[S]], 0 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 false ; %ynz = icmp ne i8 %y, 0 call void @llvm.assume(i1 %ynz) @@ -328,10 +292,7 @@ ; CHECK-NEXT: call void @llvm.assume(i1 [[YNZ]]) ; CHECK-NEXT: [[PCOND0:%.*]] = icmp slt i8 [[C:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[PCOND0]]) -; CHECK-NEXT: [[CMP:%.*]] = icmp sge i8 [[C]], [[V:%.*]] -; CHECK-NEXT: [[S:%.*]] = select i1 [[CMP]], i8 [[V]], i8 [[Y]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[S]], 0 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 false ; %ynz = icmp ne i8 %y, 0 call void @llvm.assume(i1 %ynz) @@ -347,10 +308,7 @@ ; CHECK-LABEL: @inv_select_v_sge_z( ; CHECK-NEXT: [[YNZ:%.*]] = icmp ne i8 [[Y:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[YNZ]]) -; CHECK-NEXT: [[CMP:%.*]] = icmp sge i8 [[V:%.*]], 0 -; CHECK-NEXT: [[S:%.*]] = select i1 [[CMP]], i8 [[Y]], i8 [[V]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[S]], 0 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 false ; %ynz = icmp ne i8 %y, 0 call void @llvm.assume(i1 %ynz) @@ -383,10 +341,7 @@ ; CHECK-NEXT: call 
void @llvm.assume(i1 [[YNZ]]) ; CHECK-NEXT: [[PCOND0:%.*]] = icmp slt i8 [[C:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[PCOND0]]) -; CHECK-NEXT: [[CMP:%.*]] = icmp sle i8 [[V:%.*]], [[C]] -; CHECK-NEXT: [[S:%.*]] = select i1 [[CMP]], i8 [[V]], i8 [[Y]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[S]], 0 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 false ; %ynz = icmp ne i8 %y, 0 call void @llvm.assume(i1 %ynz) @@ -404,10 +359,7 @@ ; CHECK-NEXT: call void @llvm.assume(i1 [[YNZ]]) ; CHECK-NEXT: [[PCOND0:%.*]] = icmp sgt i8 [[C:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[PCOND0]]) -; CHECK-NEXT: [[CMP:%.*]] = icmp sle i8 [[C]], [[V:%.*]] -; CHECK-NEXT: [[S:%.*]] = select i1 [[CMP]], i8 [[V]], i8 [[Y]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[S]], 0 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 false ; %ynz = icmp ne i8 %y, 0 call void @llvm.assume(i1 %ynz) @@ -446,10 +398,7 @@ ; CHECK-NEXT: call void @llvm.assume(i1 [[YNZ]]) ; CHECK-NEXT: [[PCOND0:%.*]] = icmp sge i8 [[C:%.*]], 0 ; CHECK-NEXT: call void @llvm.assume(i1 [[PCOND0]]) -; CHECK-NEXT: [[CMP:%.*]] = icmp sle i8 [[V:%.*]], [[C]] -; CHECK-NEXT: [[S:%.*]] = select i1 [[CMP]], i8 [[Y]], i8 [[V]] -; CHECK-NEXT: [[R:%.*]] = icmp eq i8 [[S]], 0 -; CHECK-NEXT: ret i1 [[R]] +; CHECK-NEXT: ret i1 false ; %ynz = icmp ne i8 %y, 0 call void @llvm.assume(i1 %ynz) diff --git a/llvm/test/Transforms/InstSimplify/compare.ll b/llvm/test/Transforms/InstSimplify/compare.ll --- a/llvm/test/Transforms/InstSimplify/compare.ll +++ b/llvm/test/Transforms/InstSimplify/compare.ll @@ -729,10 +729,7 @@ define i1 @select6(i32 %x) { ; CHECK-LABEL: @select6( -; CHECK-NEXT: [[C:%.*]] = icmp sgt i32 [[X:%.*]], 0 -; CHECK-NEXT: [[S:%.*]] = select i1 [[C]], i32 [[X]], i32 4 -; CHECK-NEXT: [[C2:%.*]] = icmp eq i32 [[S]], 0 -; CHECK-NEXT: ret i1 [[C2]] +; CHECK-NEXT: ret i1 false ; %c = icmp sgt i32 %x, 0 %s = select i1 %c, i32 %x, i32 4