diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -2152,6 +2152,15 @@
   return AnyBinaryOp_match<LHS, RHS, true>(L, R);
 }
 
+/// Matches a Cmp (ICmp or FCmp) with a predicate over LHS and RHS in either
+/// order. Swaps the predicate if operands are commuted.
+template <typename LHS, typename RHS>
+inline CmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate, true>
+m_c_Cmp(CmpInst::Predicate &Pred, const LHS &L, const RHS &R) {
+  return CmpClass_match<LHS, RHS, CmpInst, CmpInst::Predicate, true>(Pred, L,
+                                                                     R);
+}
+
 /// Matches an ICmp with a predicate over LHS and RHS in either order.
 /// Swaps the predicate if operands are commuted.
 template <typename LHS, typename RHS>
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -4848,17 +4848,46 @@
     PoisonI = dyn_cast<Instruction>(NextVal);
   }
 
-  if (PoisonValues.contains(V))
-    return true;
+  // Walk V and, transitively, the operands through which poison is
+  // guaranteed to reach it, up to MaxDepth levels. Entries are (value, depth).
+  SmallVector<std::pair<const Value *, unsigned>> Worklist;
+  Worklist.push_back({V, 0});
 
-  // Let's look one level further, by seeing its arguments if I was an
-  // instruction.
-  // This happens when I is e.g. 'icmp X, const' where X is in PoisonValues.
-  const auto *I = dyn_cast<Instruction>(V);
-  if (I && propagatesPoison(cast<Operator>(I))) {
-    for (const auto &Op : I->operands())
-      if (PoisonValues.count(Op.get()))
-        return true;
+  while (!Worklist.empty()) {
+    auto Item = Worklist.pop_back_val();
+    const Value *Cur = Item.first;
+    unsigned Depth = Item.second;
+
+    if (PoisonValues.contains(Cur))
+      return true;
+
+    // Common syntactic pattern:
+    //   cmp pred1 X, Y -> cmp pred2 X, Y
+    // If a known-poison value VP compares the same operands (in either
+    // order) and cannot create poison itself, then VP being poison means X
+    // or Y is poison, which in turn makes Cur poison.
+    const Value *V1, *V2;
+    CmpInst::Predicate Pred, Pred2;
+    if (match(Cur, m_Cmp(Pred, m_Value(V1), m_Value(V2)))) {
+      for (auto *VP : PoisonValues) {
+        if (match(VP, m_c_Cmp(Pred2, m_Specific(V1), m_Specific(V2))) &&
+            !canCreatePoison(cast<Operator>(VP)))
+          return true;
+      }
+    }
+
+    if (Depth == MaxDepth)
+      continue;
+
+    // For a select only the condition is followed: a poison arm that is not
+    // selected does not make the result poison.
+    if (const auto *SI = dyn_cast<SelectInst>(Cur))
+      Worklist.push_back({SI->getCondition(), Depth + 1});
+    else if (const auto *Oper = dyn_cast<Operator>(Cur)) {
+      if (propagatesPoison(Oper))
+        for (const auto &Op : Oper->operands())
+          Worklist.push_back({Op.get(), Depth + 1});
+    }
   }
 
   return false;
diff --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp
--- a/llvm/unittests/Analysis/ValueTrackingTest.cpp
+++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp
@@ -748,6 +748,36 @@
   EXPECT_FALSE(impliesPoison(A2, A));
 }
 
+TEST_F(ValueTrackingTest, impliesPoisonTest_Cmp) {
+  parseAssembly("define void @test(i32 %x, i32 %y, i1 %c) {\n"
+                "  %A2 = icmp eq i32 %x, %y\n"
+                "  %A0 = icmp ult i32 %x, %y\n"
+                "  %A = or i1 %A0, %c\n"
+                "  ret void\n"
+                "}");
+  EXPECT_TRUE(impliesPoison(A2, A));
+}
+
+TEST_F(ValueTrackingTest, impliesPoisonTest_CmpCommuted) {
+  parseAssembly("define void @test(i32 %x, i32 %y, i1 %c) {\n"
+                "  %A2 = icmp eq i32 %y, %x\n"
+                "  %A0 = icmp ult i32 %x, %y\n"
+                "  %A = or i1 %A0, %c\n"
+                "  ret void\n"
+                "}");
+  EXPECT_TRUE(impliesPoison(A2, A));
+}
+
+TEST_F(ValueTrackingTest, impliesPoisonTest_FCmpFMF) {
+  parseAssembly("define void @test(float %x, float %y, i1 %c) {\n"
+                "  %A2 = fcmp nnan oeq float %x, %y\n"
+                "  %A0 = fcmp olt float %x, %y\n"
+                "  %A = or i1 %A0, %c\n"
+                "  ret void\n"
+                "}");
+  EXPECT_FALSE(impliesPoison(A2, A));
+}
+
 TEST_F(ValueTrackingTest, ComputeNumSignBits_Shuffle_Pointers) {
   parseAssembly(
       "define void @test(<2 x i32*> %x) {\n"