diff --git a/llvm/include/llvm/Analysis/ValueTracking.h b/llvm/include/llvm/Analysis/ValueTracking.h --- a/llvm/include/llvm/Analysis/ValueTracking.h +++ b/llvm/include/llvm/Analysis/ValueTracking.h @@ -593,15 +593,17 @@ bool isGuaranteedToExecuteForEveryIteration(const Instruction *I, const Loop *L); - /// Return true if I yields poison or raises UB if any of its operands is - /// poison. - /// Formally, given I = `r = op v1 v2 .. vN`, propagatesPoison returns true - /// if, for all i, r is evaluated to poison or op raises UB if vi = poison. - /// If vi is a vector or an aggregate and r is a single value, any poison - /// element in vi should make r poison or raise UB. + /// Return true if \p PoisonOp's user yields poison or raises UB if its + /// operand \p PoisonOp is poison. + /// + /// If \p PoisonOp is a vector or an aggregate and the operation's result is a + /// single value, any poison element in \p PoisonOp should make the result + /// poison or raise UB. + /// + /// /// To filter out operands that raise UB on poison, you can use /// getGuaranteedNonPoisonOp. - bool propagatesPoison(const Operator *I); + bool propagatesPoison(const Use &PoisonOp); /// Insert operands of I into Ops such that I will trigger undefined behavior /// if I is executed and that operand has a poison value. 
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp --- a/llvm/lib/Analysis/ScalarEvolution.cpp +++ b/llvm/lib/Analysis/ScalarEvolution.cpp @@ -6722,8 +6722,9 @@ while (!PoisonStack.empty() && !LatchControlDependentOnPoison) { const Instruction *Poison = PoisonStack.pop_back_val(); - for (auto *PoisonUser : Poison->users()) { - if (propagatesPoison(cast(PoisonUser))) { + for (const Use &U : Poison->uses()) { + const User *PoisonUser = U.getUser(); + if (propagatesPoison(U)) { if (Pushed.insert(cast(PoisonUser)).second) PoisonStack.push_back(cast(PoisonUser)); } else if (auto *BI = dyn_cast(PoisonUser)) { diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -5088,15 +5088,12 @@ return false; if (const auto *I = dyn_cast(V)) { - if (propagatesPoison(cast(I))) - return any_of(I->operands(), [=](const Value *Op) { - return directlyImpliesPoison(ValAssumedPoison, Op, Depth + 1); - }); + if (isa(I) && any_of(I->operands(), [=](const Use &Op) { + return propagatesPoison(Op) && + directlyImpliesPoison(ValAssumedPoison, Op, Depth + 1); + })) + return true; - // 'select ValAssumedPoison, _, _' is poison. - if (const auto *SI = dyn_cast(I)) - return directlyImpliesPoison(ValAssumedPoison, SI->getCondition(), - Depth + 1); // V = extractvalue V0, idx // V2 = extractvalue V0, idx2 // V0's elements are all poison or not. 
(e.g., add_with_overflow) @@ -5252,7 +5249,9 @@ else if (PoisonOnly && isa(Cond)) { // For poison, we can analyze further auto *Opr = cast(Cond); - if (propagatesPoison(Opr) && is_contained(Opr->operand_values(), V)) + if (any_of(Opr->operands(), [V, Opr](const Use &U) { + return V == U && propagatesPoison(U); + })) return true; } } @@ -5375,13 +5374,15 @@ llvm_unreachable("Instruction not contained in its own parent basic block."); } -bool llvm::propagatesPoison(const Operator *I) { +bool llvm::propagatesPoison(const Use &PoisonOp) { + const Operator *I = cast(PoisonOp.getUser()); switch (I->getOpcode()) { case Instruction::Freeze: - case Instruction::Select: case Instruction::PHI: case Instruction::Invoke: return false; + case Instruction::Select: + return PoisonOp == I->getOperand(0); case Instruction::Call: if (auto *II = dyn_cast(I)) { switch (II->getIntrinsicID()) { @@ -5540,11 +5541,11 @@ SmallSet Visited; YieldsPoison.insert(V); - auto Propagate = [&](const User *User) { - if (propagatesPoison(cast(User))) - YieldsPoison.insert(User); + auto Propagate = [&](const Use &U) { + if (propagatesPoison(U)) + YieldsPoison.insert(U.getUser()); }; - for_each(V->users(), Propagate); + for_each(V->uses(), Propagate); Visited.insert(BB); while (true) { @@ -5560,7 +5561,7 @@ // Mark poison that propagates from I through uses of I. 
if (YieldsPoison.count(&I)) - for_each(I.users(), Propagate); + for_each(I.uses(), Propagate); } BB = BB->getSingleSuccessor(); diff --git a/llvm/lib/Transforms/Instrumentation/PoisonChecking.cpp b/llvm/lib/Transforms/Instrumentation/PoisonChecking.cpp --- a/llvm/lib/Transforms/Instrumentation/PoisonChecking.cpp +++ b/llvm/lib/Transforms/Instrumentation/PoisonChecking.cpp @@ -295,7 +295,10 @@ } SmallVector Checks; - if (propagatesPoison(cast(&I))) + if (any_of(I.operands(), [&ValToPoison](const Use &U) { + return ValToPoison.find(U) != ValToPoison.end() && + propagatesPoison(U); + })) for (Value *V : I.operands()) Checks.push_back(getPoisonFor(ValToPoison, V)); diff --git a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp --- a/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp +++ b/llvm/lib/Transforms/Scalar/IndVarSimplify.cpp @@ -786,7 +786,9 @@ // If we can't analyze propagation through this instruction, just skip it // and transitive users. Safe as false is a conservative result. 
- if (!propagatesPoison(cast(I)) && I != Root) + if (I != Root && !any_of(I->operands(), [&KnownPoison](const Use &U) { + return KnownPoison.contains(U) && propagatesPoison(U); + })) continue; if (KnownPoison.insert(I).second) diff --git a/llvm/test/Analysis/ScalarEvolution/nsw.ll b/llvm/test/Analysis/ScalarEvolution/nsw.ll --- a/llvm/test/Analysis/ScalarEvolution/nsw.ll +++ b/llvm/test/Analysis/ScalarEvolution/nsw.ll @@ -400,7 +400,7 @@ ; CHECK-NEXT: %iv = phi i32 [ %iv.next, %loop ], [ 0, %entry ] ; CHECK-NEXT: --> {0,+,1}<%loop> U: [0,-2147483648) S: [0,-2147483648) Exits: <> LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %iv.next = add nsw i32 %iv, 1 -; CHECK-NEXT: --> {1,+,1}<%loop> U: [1,0) S: [1,0) Exits: <> LoopDispositions: { %loop: Computable } +; CHECK-NEXT: --> {1,+,1}<%loop> U: [1,-2147483648) S: [1,-2147483648) Exits: <> LoopDispositions: { %loop: Computable } ; CHECK-NEXT: %sel = select i1 %cmp, i32 10, i32 20 ; CHECK-NEXT: --> %sel U: [0,31) S: [0,31) Exits: <> LoopDispositions: { %loop: Variant } ; CHECK-NEXT: %cond = call i1 @cond() diff --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp --- a/llvm/unittests/Analysis/ValueTrackingTest.cpp +++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp @@ -847,68 +847,89 @@ "i1 %cond, i8* %p) {\n"; std::string AsmTail = " ret void\n}"; // (propagates poison?, IR instruction) - SmallVector, 32> Data = { - {true, "add i32 %x, %y"}, - {true, "add nsw nuw i32 %x, %y"}, - {true, "ashr i32 %x, %y"}, - {true, "lshr exact i32 %x, 31"}, - {true, "fadd float %fx, %fy"}, - {true, "fsub float %fx, %fy"}, - {true, "fmul float %fx, %fy"}, - {true, "fdiv float %fx, %fy"}, - {true, "frem float %fx, %fy"}, - {true, "fneg float %fx"}, - {true, "fcmp oeq float %fx, %fy"}, - {true, "icmp eq i32 %x, %y"}, - {true, "getelementptr i8, i8* %p, i32 %x"}, - {true, "getelementptr inbounds i8, i8* %p, i32 %x"}, - {true, "bitcast float %fx to i32"}, - {false, "select i1 %cond, i32 
%x, i32 %y"}, - {false, "freeze i32 %x"}, - {true, "udiv i32 %x, %y"}, - {true, "urem i32 %x, %y"}, - {true, "sdiv exact i32 %x, %y"}, - {true, "srem i32 %x, %y"}, - {false, "call i32 @g(i32 %x)"}, - {true, "call {i32, i1} @llvm.sadd.with.overflow.i32(i32 %x, i32 %y)"}, - {true, "call {i32, i1} @llvm.ssub.with.overflow.i32(i32 %x, i32 %y)"}, - {true, "call {i32, i1} @llvm.smul.with.overflow.i32(i32 %x, i32 %y)"}, - {true, "call {i32, i1} @llvm.uadd.with.overflow.i32(i32 %x, i32 %y)"}, - {true, "call {i32, i1} @llvm.usub.with.overflow.i32(i32 %x, i32 %y)"}, - {true, "call {i32, i1} @llvm.umul.with.overflow.i32(i32 %x, i32 %y)"}, - {false, "call float @llvm.sqrt.f32(float %fx)"}, - {false, "call float @llvm.powi.f32.i32(float %fx, i32 %x)"}, - {false, "call float @llvm.sin.f32(float %fx)"}, - {false, "call float @llvm.cos.f32(float %fx)"}, - {false, "call float @llvm.pow.f32(float %fx, float %fy)"}, - {false, "call float @llvm.exp.f32(float %fx)"}, - {false, "call float @llvm.exp2.f32(float %fx)"}, - {false, "call float @llvm.log.f32(float %fx)"}, - {false, "call float @llvm.log10.f32(float %fx)"}, - {false, "call float @llvm.log2.f32(float %fx)"}, - {false, "call float @llvm.fma.f32(float %fx, float %fx, float %fy)"}, - {false, "call float @llvm.fabs.f32(float %fx)"}, - {false, "call float @llvm.minnum.f32(float %fx, float %fy)"}, - {false, "call float @llvm.maxnum.f32(float %fx, float %fy)"}, - {false, "call float @llvm.minimum.f32(float %fx, float %fy)"}, - {false, "call float @llvm.maximum.f32(float %fx, float %fy)"}, - {false, "call float @llvm.copysign.f32(float %fx, float %fy)"}, - {false, "call float @llvm.floor.f32(float %fx)"}, - {false, "call float @llvm.ceil.f32(float %fx)"}, - {false, "call float @llvm.trunc.f32(float %fx)"}, - {false, "call float @llvm.rint.f32(float %fx)"}, - {false, "call float @llvm.nearbyint.f32(float %fx)"}, - {false, "call float @llvm.round.f32(float %fx)"}, - {false, "call float @llvm.roundeven.f32(float %fx)"}, - {false, "call 
i32 @llvm.lround.f32(float %fx)"}, - {false, "call i64 @llvm.llround.f32(float %fx)"}, - {false, "call i32 @llvm.lrint.f32(float %fx)"}, - {false, "call i64 @llvm.llrint.f32(float %fx)"}, - {false, "call float @llvm.fmuladd.f32(float %fx, float %fx, float %fy)"}}; + SmallVector, 32> Data = { + {true, "add i32 %x, %y", 0}, + {true, "add i32 %x, %y", 1}, + {true, "add nsw nuw i32 %x, %y", 0}, + {true, "add nsw nuw i32 %x, %y", 1}, + {true, "ashr i32 %x, %y", 0}, + {true, "ashr i32 %x, %y", 1}, + {true, "lshr exact i32 %x, 31", 0}, + {true, "lshr exact i32 %x, 31", 1}, + {true, "fadd float %fx, %fy", 0}, + {true, "fadd float %fx, %fy", 1}, + {true, "fsub float %fx, %fy", 0}, + {true, "fsub float %fx, %fy", 1}, + {true, "fmul float %fx, %fy", 0}, + {true, "fmul float %fx, %fy", 1}, + {true, "fdiv float %fx, %fy", 0}, + {true, "fdiv float %fx, %fy", 1}, + {true, "frem float %fx, %fy", 0}, + {true, "frem float %fx, %fy", 1}, + {true, "fneg float %fx", 0}, + {true, "fcmp oeq float %fx, %fy", 0}, + {true, "fcmp oeq float %fx, %fy", 1}, + {true, "icmp eq i32 %x, %y", 0}, + {true, "icmp eq i32 %x, %y", 1}, + {true, "getelementptr i8, i8* %p, i32 %x", 0}, + {true, "getelementptr i8, i8* %p, i32 %x", 1}, + {true, "getelementptr inbounds i8, i8* %p, i32 %x", 0}, + {true, "getelementptr inbounds i8, i8* %p, i32 %x", 1}, + {true, "bitcast float %fx to i32", 0}, + {true, "select i1 %cond, i32 %x, i32 %y", 0}, + {false, "select i1 %cond, i32 %x, i32 %y", 1}, + {false, "select i1 %cond, i32 %x, i32 %y", 2}, + {false, "freeze i32 %x", 0}, + {true, "udiv i32 %x, %y", 0}, + {true, "udiv i32 %x, %y", 1}, + {true, "urem i32 %x, %y", 0}, + {true, "urem i32 %x, %y", 1}, + {true, "sdiv exact i32 %x, %y", 0}, + {true, "sdiv exact i32 %x, %y", 1}, + {true, "srem i32 %x, %y", 0}, + {true, "srem i32 %x, %y", 1}, + {false, "call i32 @g(i32 %x)", 0}, + {false, "call i32 @g(i32 %x)", 1}, + {true, "call {i32, i1} @llvm.sadd.with.overflow.i32(i32 %x, i32 %y)", 0}, + {true, "call {i32, i1} 
@llvm.ssub.with.overflow.i32(i32 %x, i32 %y)", 0}, + {true, "call {i32, i1} @llvm.smul.with.overflow.i32(i32 %x, i32 %y)", 0}, + {true, "call {i32, i1} @llvm.uadd.with.overflow.i32(i32 %x, i32 %y)", 0}, + {true, "call {i32, i1} @llvm.usub.with.overflow.i32(i32 %x, i32 %y)", 0}, + {true, "call {i32, i1} @llvm.umul.with.overflow.i32(i32 %x, i32 %y)", 0}, + {false, "call float @llvm.sqrt.f32(float %fx)", 0}, + {false, "call float @llvm.powi.f32.i32(float %fx, i32 %x)", 0}, + {false, "call float @llvm.sin.f32(float %fx)", 0}, + {false, "call float @llvm.cos.f32(float %fx)", 0}, + {false, "call float @llvm.pow.f32(float %fx, float %fy)", 0}, + {false, "call float @llvm.exp.f32(float %fx)", 0}, + {false, "call float @llvm.exp2.f32(float %fx)", 0}, + {false, "call float @llvm.log.f32(float %fx)", 0}, + {false, "call float @llvm.log10.f32(float %fx)", 0}, + {false, "call float @llvm.log2.f32(float %fx)", 0}, + {false, "call float @llvm.fma.f32(float %fx, float %fx, float %fy)", 0}, + {false, "call float @llvm.fabs.f32(float %fx)", 0}, + {false, "call float @llvm.minnum.f32(float %fx, float %fy)", 0}, + {false, "call float @llvm.maxnum.f32(float %fx, float %fy)", 0}, + {false, "call float @llvm.minimum.f32(float %fx, float %fy)", 0}, + {false, "call float @llvm.maximum.f32(float %fx, float %fy)", 0}, + {false, "call float @llvm.copysign.f32(float %fx, float %fy)", 0}, + {false, "call float @llvm.floor.f32(float %fx)", 0}, + {false, "call float @llvm.ceil.f32(float %fx)", 0}, + {false, "call float @llvm.trunc.f32(float %fx)", 0}, + {false, "call float @llvm.rint.f32(float %fx)", 0}, + {false, "call float @llvm.nearbyint.f32(float %fx)", 0}, + {false, "call float @llvm.round.f32(float %fx)", 0}, + {false, "call float @llvm.roundeven.f32(float %fx)", 0}, + {false, "call i32 @llvm.lround.f32(float %fx)", 0}, + {false, "call i64 @llvm.llround.f32(float %fx)", 0}, + {false, "call i32 @llvm.lrint.f32(float %fx)", 0}, + {false, "call i64 @llvm.llrint.f32(float %fx)", 0}, + {false, 
"call float @llvm.fmuladd.f32(float %fx, float %fx, float %fy)", + 0}}; std::string AssemblyStr = AsmHead; for (auto &Itm : Data) - AssemblyStr += Itm.second + "\n"; + AssemblyStr += std::get<1>(Itm) + "\n"; AssemblyStr += AsmTail; LLVMContext Context; @@ -925,7 +946,9 @@ for (auto &I : BB) { if (isa(&I)) break; - EXPECT_EQ(propagatesPoison(cast(&I)), Data[Index].first) + bool ExpectedVal = std::get<0>(Data[Index]); + unsigned OpIdx = std::get<2>(Data[Index]); + EXPECT_EQ(propagatesPoison(I.getOperandUse(OpIdx)), ExpectedVal) << "Incorrect answer at instruction " << Index << " = " << I; Index++; }