diff --git a/llvm/include/llvm/IR/Constant.h b/llvm/include/llvm/IR/Constant.h --- a/llvm/include/llvm/IR/Constant.h +++ b/llvm/include/llvm/IR/Constant.h @@ -101,11 +101,15 @@ /// lane, the constants still match. bool isElementWiseEqual(Value *Y) const; - /// Return true if this is a vector constant that includes any undefined - /// elements. Since it is impossible to inspect a scalable vector element- - /// wise at compile time, this function returns true only if the entire - /// vector is undef - bool containsUndefElement() const; + /// Return true if this is a vector constant that includes any undef or + /// poison elements. Since it is impossible to inspect a scalable vector + /// element-wise at compile time, this function returns true only if the + /// entire vector is undef or poison. + bool containsUndefOrPoisonElement() const; + + /// Return true if this is a vector constant that includes any poison + /// elements; for a scalable vector, true only if the whole vector is poison. + bool containsPoisonElement() const; /// Return true if this is a fixed width vector constant that includes /// any constant expressions. diff --git a/llvm/lib/Analysis/ValueTracking.cpp b/llvm/lib/Analysis/ValueTracking.cpp --- a/llvm/lib/Analysis/ValueTracking.cpp +++ b/llvm/lib/Analysis/ValueTracking.cpp @@ -4895,7 +4895,8 @@ return true; if (C->getType()->isVectorTy() && !isa<ConstantExpr>(C)) - return (PoisonOnly || !C->containsUndefElement()) && + return (PoisonOnly ? !C->containsPoisonElement() + : !C->containsUndefOrPoisonElement()) && !C->containsConstantExpression(); } @@ -5636,10 +5637,10 @@ // elements because those can not be back-propagated for analysis.
Value *OutputZeroVal = nullptr; if (match(TrueVal, m_AnyZeroFP()) && !match(FalseVal, m_AnyZeroFP()) && - !cast<Constant>(TrueVal)->containsUndefElement()) + !cast<Constant>(TrueVal)->containsUndefOrPoisonElement()) OutputZeroVal = TrueVal; else if (match(FalseVal, m_AnyZeroFP()) && !match(TrueVal, m_AnyZeroFP()) && - !cast<Constant>(FalseVal)->containsUndefElement()) + !cast<Constant>(FalseVal)->containsUndefOrPoisonElement()) OutputZeroVal = FalseVal; if (OutputZeroVal) { diff --git a/llvm/lib/IR/ConstantFold.cpp b/llvm/lib/IR/ConstantFold.cpp --- a/llvm/lib/IR/ConstantFold.cpp +++ b/llvm/lib/IR/ConstantFold.cpp @@ -811,7 +811,7 @@ return true; if (C->getType()->isVectorTy()) - return !C->containsUndefElement() && !C->containsConstantExpression(); + return !C->containsPoisonElement() && !C->containsConstantExpression(); // TODO: Recursively analyze aggregates or other constants. return false; diff --git a/llvm/lib/IR/Constants.cpp b/llvm/lib/IR/Constants.cpp --- a/llvm/lib/IR/Constants.cpp +++ b/llvm/lib/IR/Constants.cpp @@ -304,31 +304,42 @@ return isa<UndefValue>(CmpEq) || match(CmpEq, m_One()); } -bool Constant::containsUndefElement() const { - if (auto *VTy = dyn_cast<VectorType>(getType())) { - if (isa<UndefValue>(this)) +static bool +containsUndefinedElement(const Constant *C, + function_ref<bool(const Constant *)> HasFn) { + if (auto *VTy = dyn_cast<VectorType>(C->getType())) { + if (HasFn(C)) return true; - if (isa<ConstantAggregateZero>(this)) + if (isa<ConstantAggregateZero>(C)) return false; - if (isa<ScalableVectorType>(getType())) + if (isa<ScalableVectorType>(C->getType())) return false; for (unsigned i = 0, e = cast<FixedVectorType>(VTy)->getNumElements(); i != e; ++i) - if (isa<UndefValue>(getAggregateElement(i))) + if (HasFn(C->getAggregateElement(i))) return true; } return false; } +bool Constant::containsUndefOrPoisonElement() const { + return containsUndefinedElement( + this, [](const auto *C) { return isa<UndefValue>(C); }); +} + +bool Constant::containsPoisonElement() const { + return containsUndefinedElement( + this, [](const auto *C) { return isa<PoisonValue>(C); }); +} + bool Constant::containsConstantExpression() const { if (auto *VTy = dyn_cast<FixedVectorType>(getType())) { for (unsigned i = 0, e
= VTy->getNumElements(); i != e; ++i) if (isa<ConstantExpr>(getAggregateElement(i))) return true; } - return false; } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp @@ -3370,7 +3370,7 @@ Type *OpTy = M->getType(); auto *VecC = dyn_cast<Constant>(M); auto *OpVTy = dyn_cast<FixedVectorType>(OpTy); - if (OpVTy && VecC && VecC->containsUndefElement()) { + if (OpVTy && VecC && VecC->containsUndefOrPoisonElement()) { Constant *SafeReplacementConstant = nullptr; for (unsigned i = 0, e = OpVTy->getNumElements(); i != e; ++i) { if (!isa<UndefValue>(VecC->getAggregateElement(i))) { @@ -5259,7 +5259,8 @@ // It may not be safe to change a compare predicate in the presence of // undefined elements, so replace those elements with the first safe constant // that we found. - if (C->containsUndefElement()) { + // TODO: Only undef elements need replacing; poison elements are safe here. + if (C->containsUndefOrPoisonElement()) { assert(SafeReplacementConstant && "Replacement constant not set"); C = Constant::replaceUndefsWith(C, SafeReplacementConstant); } diff --git a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp --- a/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp +++ b/llvm/lib/Transforms/InstCombine/InstCombineNegator.cpp @@ -239,8 +239,8 @@ // While this is normally not behind a use-check, // let's consider division to be special since it's costly.
if (auto *Op1C = dyn_cast<Constant>(I->getOperand(1))) { - if (!Op1C->containsUndefElement() && Op1C->isNotMinSignedValue() && - Op1C->isNotOneValue()) { + if (!Op1C->containsUndefOrPoisonElement() && + Op1C->isNotMinSignedValue() && Op1C->isNotOneValue()) { Value *BO = Builder.CreateSDiv(I->getOperand(0), ConstantExpr::getNeg(Op1C), I->getName() + ".neg"); diff --git a/llvm/unittests/Analysis/ValueTrackingTest.cpp b/llvm/unittests/Analysis/ValueTrackingTest.cpp --- a/llvm/unittests/Analysis/ValueTrackingTest.cpp +++ b/llvm/unittests/Analysis/ValueTrackingTest.cpp @@ -888,6 +888,30 @@ EXPECT_EQ(isGuaranteedNotToBeUndefOrPoison(PoisonValue::get(IntegerType::get(Context, 8))), false); EXPECT_EQ(isGuaranteedNotToBePoison(UndefValue::get(IntegerType::get(Context, 8))), true); EXPECT_EQ(isGuaranteedNotToBePoison(PoisonValue::get(IntegerType::get(Context, 8))), false); + + Type *Int32Ty = Type::getInt32Ty(Context); + Constant *CU = UndefValue::get(Int32Ty); + Constant *CP = PoisonValue::get(Int32Ty); + Constant *C1 = ConstantInt::get(Int32Ty, 1); + Constant *C2 = ConstantInt::get(Int32Ty, 2); + + { + Constant *V1 = ConstantVector::get({C1, C2}); + EXPECT_TRUE(isGuaranteedNotToBeUndefOrPoison(V1)); + EXPECT_TRUE(isGuaranteedNotToBePoison(V1)); + } + + { + Constant *V2 = ConstantVector::get({C1, CU}); + EXPECT_FALSE(isGuaranteedNotToBeUndefOrPoison(V2)); + EXPECT_TRUE(isGuaranteedNotToBePoison(V2)); + } + + { + Constant *V3 = ConstantVector::get({C1, CP}); + EXPECT_FALSE(isGuaranteedNotToBeUndefOrPoison(V3)); + EXPECT_FALSE(isGuaranteedNotToBePoison(V3)); + } } TEST_F(ValueTrackingTest, isGuaranteedNotToBeUndefOrPoison_assume) { diff --git a/llvm/unittests/IR/ConstantsTest.cpp b/llvm/unittests/IR/ConstantsTest.cpp --- a/llvm/unittests/IR/ConstantsTest.cpp +++ b/llvm/unittests/IR/ConstantsTest.cpp @@ -585,6 +585,43 @@ Instruction::And, TheConstantExpr, TheConstant)->isNullValue()); } +// Check that containsUndefOrPoisonElement and containsPoisonElement return the +// expected results for fixed vectors with and without undef/poison lanes. +
+TEST(ConstantsTest, containsUndefElemTest) { + LLVMContext Context; + + Type *Int32Ty = Type::getInt32Ty(Context); + Constant *CU = UndefValue::get(Int32Ty); + Constant *CP = PoisonValue::get(Int32Ty); + Constant *C1 = ConstantInt::get(Int32Ty, 1); + Constant *C2 = ConstantInt::get(Int32Ty, 2); + + { + Constant *V1 = ConstantVector::get({C1, C2}); + EXPECT_FALSE(V1->containsUndefOrPoisonElement()); + EXPECT_FALSE(V1->containsPoisonElement()); + } + + { + Constant *V2 = ConstantVector::get({C1, CU}); + EXPECT_TRUE(V2->containsUndefOrPoisonElement()); + EXPECT_FALSE(V2->containsPoisonElement()); + } + + { + Constant *V3 = ConstantVector::get({C1, CP}); + EXPECT_TRUE(V3->containsUndefOrPoisonElement()); + EXPECT_TRUE(V3->containsPoisonElement()); + } + + { + Constant *V4 = ConstantVector::get({CU, CP}); + EXPECT_TRUE(V4->containsUndefOrPoisonElement()); + EXPECT_TRUE(V4->containsPoisonElement()); + } +} + // Check that undefined elements in vector constants are matched // correctly for both integer and floating-point types. Just don't // crash on vectors of pointers (could be handled?).