diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h
--- a/llvm/include/llvm/Analysis/ValueTracking.h
+++ b/llvm/include/llvm/Analysis/ValueTracking.h
@@ -594,14 +594,18 @@
                                               const Loop *L);
 
   /// Return true if I yields poison or raises UB if any of its operands is
-  /// poison.
-  /// Formally, given I = `r = op v1 v2 .. vN`, propagatesPoison returns true
-  /// if, for all i, r is evaluated to poison or op raises UB if vi = poison.
+  /// poison, when \p PoisonOp is nullptr. Otherwise \p PoisonOp must be one of
+  /// I's operand uses, and only that concrete operand is assumed to be poison.
+  ///
+  /// Formally, given I = `r = op v1 v2 .. vN`, propagatesPoison(I, nullptr)
+  /// returns true if, for all i, r is evaluated to poison or op raises UB if
+  /// vi = poison.
+  ///
   /// If vi is a vector or an aggregate and r is a single value, any poison
   /// element in vi should make r poison or raise UB.
   /// To filter out operands that raise UB on poison, you can use
   /// getGuaranteedNonPoisonOp.
-  bool propagatesPoison(const Operator *I);
+  bool propagatesPoison(const Operator *I, const Use *PoisonOp = nullptr);
 
   /// Insert operands of I into Ops such that I will trigger undefined behavior
   /// if I is executed and that operand has a poison value.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -6722,8 +6722,9 @@
   while (!PoisonStack.empty() && !LatchControlDependentOnPoison) {
     const Instruction *Poison = PoisonStack.pop_back_val();
 
-    for (auto *PoisonUser : Poison->users()) {
-      if (propagatesPoison(cast<Operator>(PoisonUser))) {
+    for (const Use &U : Poison->uses()) {
+      const User *PoisonUser = U.getUser();
+      if (propagatesPoison(cast<Operator>(PoisonUser), &U)) {
         if (Pushed.insert(cast<Instruction>(PoisonUser)).second)
           PoisonStack.push_back(cast<Instruction>(PoisonUser));
       } else if (auto *BI = dyn_cast<BranchInst>(PoisonUser)) {
diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp
--- a/llvm/lib/Analysis/ValueTracking.cpp
+++ b/llvm/lib/Analysis/ValueTracking.cpp
@@ -5088,15 +5088,12 @@
     return false;
 
   if (const auto *I = dyn_cast<Instruction>(V)) {
-    if (propagatesPoison(cast<Operator>(I)))
-      return any_of(I->operands(), [=](const Value *Op) {
-        return directlyImpliesPoison(ValAssumedPoison, Op, Depth + 1);
-      });
+    if (any_of(I->operands(), [=](const Use &Op) {
+          return propagatesPoison(cast<Operator>(I), &Op) &&
+                 directlyImpliesPoison(ValAssumedPoison, Op, Depth + 1);
+        }))
+      return true;
 
-    // 'select ValAssumedPoison, _, _' is poison.
-    if (const auto *SI = dyn_cast<SelectInst>(I))
-      return directlyImpliesPoison(ValAssumedPoison, SI->getCondition(),
-                                   Depth + 1);
     // V = extractvalue V0, idx
     // V2 = extractvalue V0, idx2
     // V0's elements are all poison or not. (e.g., add_with_overflow)
@@ -5252,7 +5249,9 @@
         else if (PoisonOnly && isa<Operator>(Cond)) {
           // For poison, we can analyze further
           auto *Opr = cast<Operator>(Cond);
-          if (propagatesPoison(Opr) && is_contained(Opr->operand_values(), V))
+          if (any_of(Opr->operands(), [V, Opr](const Use &U) {
+                return V == U && propagatesPoison(Opr, &U);
+              }))
             return true;
         }
       }
@@ -5375,13 +5374,15 @@
   llvm_unreachable("Instruction not contained in its own parent basic block.");
 }
 
-bool llvm::propagatesPoison(const Operator *I) {
+bool llvm::propagatesPoison(const Operator *I, const Use *PoisonOp) {
   switch (I->getOpcode()) {
   case Instruction::Freeze:
-  case Instruction::Select:
   case Instruction::PHI:
   case Instruction::Invoke:
     return false;
+  case Instruction::Select:
+    // Only a poison condition is guaranteed to make the select poison.
+    return PoisonOp && PoisonOp == &I->getOperandUse(0);
   case Instruction::Call:
     if (auto *II = dyn_cast<IntrinsicInst>(I)) {
       switch (II->getIntrinsicID()) {
@@ -5540,11 +5541,11 @@
   SmallSet<const BasicBlock *, 4> Visited;
 
   YieldsPoison.insert(V);
-  auto Propagate = [&](const User *User) {
-    if (propagatesPoison(cast<Operator>(User)))
-      YieldsPoison.insert(User);
+  auto Propagate = [&](const Use &U) {
+    if (propagatesPoison(cast<Operator>(U.getUser()), &U))
+      YieldsPoison.insert(U.getUser());
   };
-  for_each(V->users(), Propagate);
+  for_each(V->uses(), Propagate);
   Visited.insert(BB);
 
   while (true) {
@@ -5560,7 +5561,7 @@
 
       // Mark poison that propagates from I through uses of I.
       if (YieldsPoison.count(&I))
-        for_each(I.users(), Propagate);
+        for_each(I.uses(), Propagate);
     }
 
     BB = BB->getSingleSuccessor();
diff --git a/llvm/test/Analysis/ScalarEvolution/nsw.ll b/llvm/test/Analysis/ScalarEvolution/nsw.ll
--- a/llvm/test/Analysis/ScalarEvolution/nsw.ll
+++ b/llvm/test/Analysis/ScalarEvolution/nsw.ll
@@ -400,7 +400,7 @@
 ; CHECK-NEXT:    %iv = phi i32 [ %iv.next, %loop ], [ 0, %entry ]
 ; CHECK-NEXT:    --> {0,+,1}<%loop> U: [0,-2147483648) S: [0,-2147483648) Exits: <<Unknown>> LoopDispositions: { %loop: Computable }
 ; CHECK-NEXT:    %iv.next = add nsw i32 %iv, 1
-; CHECK-NEXT:    --> {1,+,1}<%loop> U: [1,0) S: [1,0) Exits: <<Unknown>> LoopDispositions: { %loop: Computable }
+; CHECK-NEXT:    --> {1,+,1}<%loop> U: [1,-2147483648) S: [1,-2147483648) Exits: <<Unknown>> LoopDispositions: { %loop: Computable }
 ; CHECK-NEXT:    %sel = select i1 %cmp, i32 10, i32 20
 ; CHECK-NEXT:    --> %sel U: [0,31) S: [0,31) Exits: <<Unknown>> LoopDispositions: { %loop: Variant }
 ; CHECK-NEXT:    %cond = call i1 @cond()
diff --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp
--- a/llvm/unittests/Analysis/ValueTrackingTest.cpp
+++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp
@@ -931,6 +931,42 @@
   }
 }
 
+TEST(ValueTracking, propagatesPoisonWithValAssumedPoison) {
+  std::string AssemblyStr = "define void @f(i1 %cond, i32 %x, i32 %y) {\n";
+  // (propagates poison?, IR instruction, operand index)
+  SmallVector<std::tuple<bool, std::string, unsigned>, 32> Data = {
+      {true, "select i1 %cond, i32 %x, i32 %y", 0},
+      {false, "select i1 %cond, i32 %x, i32 %y", 1},
+      {false, "select i1 %cond, i32 %x, i32 %y", 2},
+  };
+  std::string AsmTail = "  ret void\n}";
+
+  for (auto &Itm : Data)
+    AssemblyStr += std::get<1>(Itm) + "\n";
+  AssemblyStr += AsmTail;
+
+  LLVMContext Context;
+  SMDiagnostic Error;
+  auto M = parseAssemblyString(AssemblyStr, Error, Context);
+  assert(M && "Bad assembly?");
+
+  auto *F = M->getFunction("f");
+  assert(F && "Bad assembly?");
+
+  auto &BB = F->getEntryBlock();
+
+  int Index = 0;
+  for (auto &I : BB) {
+    if (isa<ReturnInst>(&I))
+      break;
+    EXPECT_EQ(propagatesPoison(cast<Operator>(&I),
+                               &I.getOperandUse(std::get<2>(Data[Index]))),
+              std::get<0>(Data[Index]))
+        << "Incorrect answer at instruction " << Index << " = " << I;
+    Index++;
+  }
+}
+
 TEST_F(ValueTrackingTest, programUndefinedIfPoison) {
   parseAssembly("declare i32 @any_num()"
                 "define void @test(i32 %mask) {\n"