diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h --- a/llvm/include/llvm/Analysis/ValueTracking.h +++ b/llvm/include/llvm/Analysis/ValueTracking.h @@ -594,14 +594,19 @@ const Loop *L); /// Return true if I yields poison or raises UB if any of its operands is - /// poison. - /// Formally, given I = `r = op v1 v2 .. vN`, propagatesPoison returns true - /// if, for all i, r is evaluated to poison or op raises UB if vi = poison. + /// poison, if \p Op is nullptr. Otherwise, \p Op must be one of I's + /// operands, and this returns true if I yields poison or raises UB when + /// \p Op is poison (e.g. true for a select's condition, not its arms). + /// + /// Formally, given I = `r = op v1 v2 .. vN`, propagatesPoison(I, nullptr) + /// returns true if, for all i, r is evaluated to poison or op raises UB if + /// vi = poison. + /// /// If vi is a vector or an aggregate and r is a single value, any poison /// element in vi should make r poison or raise UB. /// To filter out operands that raise UB on poison, you can use /// getGuaranteedNonPoisonOp. - bool propagatesPoison(const Operator *I); + bool propagatesPoison(const Operator *I, const Value *Op = nullptr); /// Insert operands of I into Ops such that I will trigger undefined behavior /// if I is executed and that operand has a poison value.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -6722,8 +6722,9 @@ while (!PoisonStack.empty() && !LatchControlDependentOnPoison) { const Instruction *Poison = PoisonStack.pop_back_val(); - for (auto *PoisonUser : Poison->users()) { - if (propagatesPoison(cast<Operator>(PoisonUser))) { + for (const Use &U : Poison->uses()) { + const User *PoisonUser = U.getUser(); + if (propagatesPoison(cast<Operator>(PoisonUser), U.get())) { if (Pushed.insert(cast<Instruction>(PoisonUser)).second) PoisonStack.push_back(cast<Instruction>(PoisonUser)); } else if (auto *BI = dyn_cast<BranchInst>(PoisonUser)) { diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -5079,15 +5079,12 @@ return false; if (const auto *I = dyn_cast<Instruction>(V)) { - if (propagatesPoison(cast<Operator>(I))) - return any_of(I->operands(), [=](const Value *Op) { - return directlyImpliesPoison(ValAssumedPoison, Op, Depth + 1); - }); + if (any_of(I->operands(), [=](const Value *Op) { + return propagatesPoison(cast<Operator>(I), Op) && + directlyImpliesPoison(ValAssumedPoison, Op, Depth + 1); + })) + return true; - // 'select ValAssumedPoison, _, _' is poison. - if (const auto *SI = dyn_cast<SelectInst>(I)) - return directlyImpliesPoison(ValAssumedPoison, SI->getCondition(), - Depth + 1); // V = extractvalue V0, idx // V2 = extractvalue V0, idx2 // V0's elements are all poison or not.
(e.g., add_with_overflow) @@ -5243,7 +5240,7 @@ else if (PoisonOnly && isa<Operator>(Cond)) { // For poison, we can analyze further auto *Opr = cast<Operator>(Cond); - if (propagatesPoison(Opr) && is_contained(Opr->operand_values(), V)) + if (is_contained(Opr->operand_values(), V) && propagatesPoison(Opr, V)) return true; } } @@ -5366,13 +5363,15 @@ llvm_unreachable("Instruction not contained in its own parent basic block."); } -bool llvm::propagatesPoison(const Operator *I) { +bool llvm::propagatesPoison(const Operator *I, const Value *Op) { + assert((!Op || is_contained(I->operands(), Op)) && "Op must be an operand"); switch (I->getOpcode()) { case Instruction::Freeze: - case Instruction::Select: case Instruction::PHI: case Instruction::Invoke: return false; + case Instruction::Select: + return Op && Op == I->getOperand(0); case Instruction::Call: if (auto *II = dyn_cast<IntrinsicInst>(I)) { switch (II->getIntrinsicID()) { @@ -5531,11 +5530,11 @@ SmallSet<const BasicBlock *, 4> Visited; YieldsPoison.insert(V); - auto Propagate = [&](const User *User) { - if (propagatesPoison(cast<Operator>(User))) - YieldsPoison.insert(User); + auto Propagate = [&](const Use &U) { + if (propagatesPoison(cast<Operator>(U.getUser()), U.get())) + YieldsPoison.insert(U.getUser()); }; - for_each(V->users(), Propagate); + for_each(V->uses(), Propagate); Visited.insert(BB); while (true) { @@ -5551,7 +5550,7 @@ // Mark poison that propagates from I through uses of I.
if (YieldsPoison.count(&I)) - for_each(I.users(), Propagate); + for_each(I.uses(), Propagate); } BB = BB->getSingleSuccessor(); diff --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp --- a/llvm/unittests/Analysis/ValueTrackingTest.cpp +++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp @@ -931,6 +931,45 @@ } } +TEST(ValueTracking, propagatesPoisonWithValAssumedPoison) { + std::string AssemblyStr = "define void @f(i1 %cond, i32 %x, i32 %y) {\n"; + // (propagates poison?, IR instruction, operand index) + SmallVector<std::tuple<bool, std::string, unsigned>, 32> Data = { + {true, "select i1 %cond, i32 %x, i32 %y", 0}, + {false, "select i1 %cond, i32 %x, i32 %y", 1}, + {false, "select i1 %cond, i32 %x, i32 %y", 2}, + {true, "add i32 %x, %y", 0}, + {true, "add i32 %x, %y", 1}, + {false, "freeze i32 %x", 0}, + }; + std::string AsmTail = " ret void\n}"; + for (auto &Itm : Data) + AssemblyStr += std::get<1>(Itm) + "\n"; + AssemblyStr += AsmTail; + LLVMContext Context; + SMDiagnostic Error; + auto M = parseAssemblyString(AssemblyStr, Error, Context); + assert(M && "Bad assembly?"); + auto *F = M->getFunction("f"); + assert(F && "Bad assembly?"); + auto &BB = F->getEntryBlock(); + int Index = 0; + for (auto &I : BB) { + if (isa<ReturnInst>(&I)) + break; + EXPECT_EQ(propagatesPoison(cast<Operator>(&I), + I.getOperand(std::get<2>(Data[Index]))), + std::get<0>(Data[Index])) + << "Incorrect answer at instruction " << Index << " = " << I; + Index++; + } +} + TEST_F(ValueTrackingTest, programUndefinedIfPoison) { parseAssembly("declare i32 @any_num()" "define void @test(i32 %mask) {\n"