diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -886,6 +886,17 @@
 /// impliesPoison returns true.
 bool impliesPoison(const Value *ValAssumedPoison, const Value *V);
 
+/// Return true if V is poison given that ValAssumedPoison is already poison,
+/// while ignoring the poison-generating flags and metadata of the instructions
+/// that are looked through on the way from ValAssumedPoison. Every instruction
+/// whose flags or metadata were ignored is appended to IgnoredInsts, and the
+/// result only holds once the caller drops them (for example with
+/// Instruction::dropPoisonGeneratingFlagsAndMetadata). If false is returned,
+/// IgnoredInsts is left unchanged.
+bool impliesPoisonIgnoreFlagsOrMetadata(
+    Value *ValAssumedPoison, const Value *V,
+    SmallVectorImpl<Instruction *> &IgnoredInsts);
+
 /// Return true if this function can prove that V does not have undef bits
 /// and is never poison. If V is an aggregate value or vector, check whether
 /// all elements (except padding) are not undef or poison.
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -6684,8 +6684,13 @@
   return false;
 }
 
-static bool impliesPoison(const Value *ValAssumedPoison, const Value *V,
-                          unsigned Depth) {
+/// Shared implementation of impliesPoison and
+/// impliesPoisonIgnoreFlagsOrMetadata. If IgnoredInsts is non-null, the
+/// poison-generating flags and metadata of the instructions walked through are
+/// ignored, and each such instruction that carries them is appended to it.
+static bool
+impliesPoison(Value *ValAssumedPoison, const Value *V, unsigned Depth,
+              SmallVectorImpl<Instruction *> *IgnoredInsts = nullptr) {
   if (isGuaranteedNotToBePoison(ValAssumedPoison))
     return true;
 
@@ -6696,17 +6701,41 @@
   if (Depth >= MaxDepth)
     return false;
 
-  const auto *I = dyn_cast<Instruction>(ValAssumedPoison);
-  if (I && !canCreatePoison(cast<Operator>(I))) {
-    return all_of(I->operands(), [=](const Value *Op) {
-      return impliesPoison(Op, V, Depth + 1);
-    });
-  }
-  return false;
+  // If I cannot create poison by itself, it can only be poison when one of
+  // its operands is, so it suffices that every operand implies V is poison.
+  // When IgnoredInsts is non-null, I's poison-generating flags and metadata
+  // are disregarded here; if I has any, it is recorded below for dropping.
+  auto *I = dyn_cast<Instruction>(ValAssumedPoison);
+  if (!I || canCreatePoison(cast<Operator>(I),
+                            /*ConsiderFlagsAndMetadata*/ !IgnoredInsts))
+    return false;
+
+  for (Value *Op : I->operands())
+    if (!impliesPoison(Op, V, Depth + 1, IgnoredInsts))
+      return false;
+
+  if (IgnoredInsts && I->hasPoisonGeneratingFlagsOrMetadata())
+    IgnoredInsts->push_back(I);
+
+  return true;
 }
 
 bool llvm::impliesPoison(const Value *ValAssumedPoison, const Value *V) {
-  return ::impliesPoison(ValAssumedPoison, V, /* Depth */ 0);
+  return ::impliesPoison(const_cast<Value *>(ValAssumedPoison), V,
+                         /* Depth */ 0);
+}
+
+bool llvm::impliesPoisonIgnoreFlagsOrMetadata(
+    Value *ValAssumedPoison, const Value *V,
+    SmallVectorImpl<Instruction *> &IgnoredInsts) {
+  // Sub-queries for operands proven before a sibling operand failed may
+  // already have recorded instructions; roll them back on failure so the
+  // caller never drops flags or metadata needlessly.
+  size_t OldSize = IgnoredInsts.size();
+  if (::impliesPoison(ValAssumedPoison, V, /* Depth */ 0, &IgnoredInsts))
+    return true;
+  IgnoredInsts.truncate(OldSize);
+  return false;
 }
 
 static bool programUndefinedIfUndefOrPoison(const Value *V,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -2923,21 +2923,34 @@
   auto *Zero = ConstantInt::getFalse(SelType);
   Value *A, *B, *C, *D;
 
+  // Drop the flags/metadata that impliesPoisonIgnoreFlagsOrMetadata() ignored
+  // so that the poison implication it reported actually holds.
+  auto dropPoisonGeneratingFlagsAndMetadata =
+      [](SmallVectorImpl<Instruction *> &Insts) {
+        for (Instruction *I : Insts)
+          I->dropPoisonGeneratingFlagsAndMetadata();
+      };
+
   // Folding select to and/or i1 isn't poison safe in general. impliesPoison
   // checks whether folding it does not convert a well-defined value into
   // poison.
   if (match(TrueVal, m_One())) {
-    if (impliesPoison(FalseVal, CondVal)) {
-      // Change: A = select B, true, C --> A = or B, C
-      return BinaryOperator::CreateOr(CondVal, FalseVal);
-    }
-
     if (auto *LHS = dyn_cast<FCmpInst>(CondVal))
       if (auto *RHS = dyn_cast<FCmpInst>(FalseVal))
         if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ false,
                                         /*IsSelectLogical*/ true))
           return replaceInstUsesWith(SI, V);
 
+    // Some patterns can be matched by both the fold above and the one below.
+    // Because the fold below may have to drop poison-generating flags and
+    // metadata, it is given lower priority than the fold above.
+    SmallVector<Instruction *> IgnoredInsts;
+    if (impliesPoisonIgnoreFlagsOrMetadata(FalseVal, CondVal, IgnoredInsts)) {
+      dropPoisonGeneratingFlagsAndMetadata(IgnoredInsts);
+      // Change: A = select B, true, C --> A = or B, C
+      return BinaryOperator::CreateOr(CondVal, FalseVal);
+    }
+
     // (A && B) || (C && B) --> (A || C) && B
     if (match(CondVal, m_LogicalAnd(m_Value(A), m_Value(B))) &&
         match(FalseVal, m_LogicalAnd(m_Value(C), m_Value(D))) &&
@@ -2968,17 +2981,22 @@
   }
 
   if (match(FalseVal, m_Zero())) {
-    if (impliesPoison(TrueVal, CondVal)) {
-      // Change: A = select B, C, false --> A = and B, C
-      return BinaryOperator::CreateAnd(CondVal, TrueVal);
-    }
-
     if (auto *LHS = dyn_cast<FCmpInst>(CondVal))
       if (auto *RHS = dyn_cast<FCmpInst>(TrueVal))
         if (Value *V = foldLogicOfFCmps(LHS, RHS, /*IsAnd*/ true,
                                         /*IsSelectLogical*/ true))
           return replaceInstUsesWith(SI, V);
 
+    // Some patterns can be matched by both the fold above and the one below.
+    // Because the fold below may have to drop poison-generating flags and
+    // metadata, it is given lower priority than the fold above.
+    SmallVector<Instruction *> IgnoredInsts;
+    if (impliesPoisonIgnoreFlagsOrMetadata(TrueVal, CondVal, IgnoredInsts)) {
+      dropPoisonGeneratingFlagsAndMetadata(IgnoredInsts);
+      // Change: A = select B, C, false --> A = and B, C
+      return BinaryOperator::CreateAnd(CondVal, TrueVal);
+    }
+
     // (A || B) && (C || B) --> (A && C) || B
     if (match(CondVal, m_LogicalOr(m_Value(A), m_Value(B))) &&
         match(TrueVal, m_LogicalOr(m_Value(C), m_Value(D))) &&
diff --git a/llvm/test/Transforms/InstCombine/ispow2.ll b/llvm/test/Transforms/InstCombine/ispow2.ll
--- a/llvm/test/Transforms/InstCombine/ispow2.ll
+++ b/llvm/test/Transforms/InstCombine/ispow2.ll
@@ -282,7 +282,7 @@
 ; CHECK-NEXT:    [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[T0]], 3
 ; CHECK-NEXT:    [[NOTZERO:%.*]] = icmp ne i32 [[X]], 0
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[NOTZERO]], i1 [[CMP]], i1 false
+; CHECK-NEXT:    [[R:%.*]] = and i1 [[NOTZERO]], [[CMP]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %t0 = tail call i32 @llvm.ctpop.i32(i32 %x)
@@ -314,7 +314,7 @@
 ; CHECK-NEXT:    [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[T0]], 2
 ; CHECK-NEXT:    [[NOTZERO:%.*]] = icmp ne i32 [[X]], 1
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[NOTZERO]], i1 [[CMP]], i1 false
+; CHECK-NEXT:    [[R:%.*]] = and i1 [[NOTZERO]], [[CMP]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %t0 = tail call i32 @llvm.ctpop.i32(i32 %x)
@@ -346,7 +346,7 @@
 ; CHECK-NEXT:    [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i32 [[T0]], 2
 ; CHECK-NEXT:    [[NOTZERO:%.*]] = icmp ne i32 [[X]], 0
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[NOTZERO]], i1 [[CMP]], i1 false
+; CHECK-NEXT:    [[R:%.*]] = and i1 [[NOTZERO]], [[CMP]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %t0 = tail call i32 @llvm.ctpop.i32(i32 %x)
@@ -378,7 +378,7 @@
 ; CHECK-NEXT:    [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ult i32 [[T0]], 2
 ; CHECK-NEXT:    [[CMP2:%.*]] = icmp sgt i32 [[X]], 0
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[CMP2]], i1 [[CMP]], i1 false
+; CHECK-NEXT:    [[R:%.*]] = and i1 [[CMP2]], [[CMP]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %t0 = tail call i32 @llvm.ctpop.i32(i32 %x)
@@ -493,7 +493,7 @@
 ; CHECK-NEXT:    [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i32 [[T0]], 2
 ; CHECK-NEXT:    [[ISZERO:%.*]] = icmp eq i32 [[X]], 0
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[ISZERO]], i1 true, i1 [[CMP]]
+; CHECK-NEXT:    [[R:%.*]] = or i1 [[ISZERO]], [[CMP]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %t0 = tail call i32 @llvm.ctpop.i32(i32 %x)
@@ -525,7 +525,7 @@
 ; CHECK-NEXT:    [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i32 [[T0]], 1
 ; CHECK-NEXT:    [[ISZERO:%.*]] = icmp eq i32 [[X]], 1
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[ISZERO]], i1 true, i1 [[CMP]]
+; CHECK-NEXT:    [[R:%.*]] = or i1 [[ISZERO]], [[CMP]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %t0 = tail call i32 @llvm.ctpop.i32(i32 %x)
@@ -557,7 +557,7 @@
 ; CHECK-NEXT:    [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ugt i32 [[T0]], 1
 ; CHECK-NEXT:    [[CMP2:%.*]] = icmp slt i32 [[X]], 0
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[CMP2]], i1 true, i1 [[CMP]]
+; CHECK-NEXT:    [[R:%.*]] = or i1 [[CMP2]], [[CMP]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %t0 = tail call i32 @llvm.ctpop.i32(i32 %x)
@@ -855,7 +855,7 @@
 ; CHECK-NEXT:    [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[T0]], 3
 ; CHECK-NEXT:    [[ISZERO:%.*]] = icmp eq i32 [[X]], 0
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[ISZERO]], i1 true, i1 [[CMP]]
+; CHECK-NEXT:    [[R:%.*]] = or i1 [[ISZERO]], [[CMP]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %t0 = tail call i32 @llvm.ctpop.i32(i32 %x)
@@ -914,7 +914,11 @@
 
 define i1 @is_pow2or0_ctpop_wrong_pred2_logical(i32 %x) {
 ; CHECK-LABEL: @is_pow2or0_ctpop_wrong_pred2_logical(
-; CHECK-NEXT:    ret i1 true
+; CHECK-NEXT:    [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i32 [[T0]], 1
+; CHECK-NEXT:    [[ISZERO:%.*]] = icmp ne i32 [[X]], 0
+; CHECK-NEXT:    [[R:%.*]] = or i1 [[ISZERO]], [[CMP]]
+; CHECK-NEXT:    ret i1 [[R]]
 ;
   %t0 = tail call i32 @llvm.ctpop.i32(i32 %x)
   %cmp = icmp ne i32 %t0, 1
@@ -1058,7 +1062,7 @@
 ; CHECK-NEXT:    [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]]
 ; CHECK-NEXT:    [[CMP:%.*]] = icmp ne i32 [[T0]], 5
 ; CHECK-NEXT:    [[NOTZERO:%.*]] = icmp ne i32 [[X]], 0
-; CHECK-NEXT:    [[R:%.*]] = select i1 [[NOTZERO]], i1 [[CMP]], i1 false
+; CHECK-NEXT:    [[R:%.*]] = and i1 [[NOTZERO]], [[CMP]]
 ; CHECK-NEXT:    ret i1 [[R]]
 ;
   %t0 = tail call i32 @llvm.ctpop.i32(i32 %x)
@@ -1117,7 +1121,11 @@
 
 define i1 @isnot_pow2nor0_ctpop_wrong_pred2_logical(i32 %x) {
 ; CHECK-LABEL: @isnot_pow2nor0_ctpop_wrong_pred2_logical(
-; CHECK-NEXT:    ret i1 false
+; CHECK-NEXT:    [[T0:%.*]] = tail call i32 @llvm.ctpop.i32(i32 [[X:%.*]]), !range [[RNG0]]
+; CHECK-NEXT:    [[CMP:%.*]] = icmp eq i32 [[T0]], 1
+; CHECK-NEXT:    [[NOTZERO:%.*]] = icmp eq i32 [[X]], 0
+; CHECK-NEXT:    [[R:%.*]] = and i1 [[NOTZERO]], [[CMP]]
+; CHECK-NEXT:    ret i1 [[R]]
 ;
   %t0 = tail call i32 @llvm.ctpop.i32(i32 %x)
   %cmp = icmp eq i32 %t0, 1
diff --git a/llvm/test/Transforms/InstCombine/prevent-cmp-merge.ll b/llvm/test/Transforms/InstCombine/prevent-cmp-merge.ll
--- a/llvm/test/Transforms/InstCombine/prevent-cmp-merge.ll
+++ b/llvm/test/Transforms/InstCombine/prevent-cmp-merge.ll
@@ -71,10 +71,10 @@
 
 define zeroext i1 @test3_logical(i32 %lhs, i32 %rhs) {
 ; CHECK-LABEL: @test3_logical(
-; CHECK-NEXT:    [[SUB:%.*]] = sub nsw i32 [[LHS:%.*]], [[RHS:%.*]]
+; CHECK-NEXT:    [[SUB:%.*]] = sub i32 [[LHS:%.*]], [[RHS:%.*]]
 ; CHECK-NEXT:    [[CMP1:%.*]] = icmp eq i32 [[LHS]], [[RHS]]
 ; CHECK-NEXT:    [[CMP2:%.*]] = icmp eq i32 [[SUB]], 31
-; CHECK-NEXT:    [[SEL:%.*]] = select i1 [[CMP1]], i1 true, i1 [[CMP2]]
+; CHECK-NEXT:    [[SEL:%.*]] = or i1 [[CMP1]], [[CMP2]]
 ; CHECK-NEXT:    ret i1 [[SEL]]
 ;