diff --git a/llvm/include/llvm/Analysis/InstructionSimplify.h b/llvm/include/llvm/Analysis/InstructionSimplify.h
--- a/llvm/include/llvm/Analysis/InstructionSimplify.h
+++ b/llvm/include/llvm/Analysis/InstructionSimplify.h
@@ -37,6 +37,7 @@
 
 #include "llvm/IR/Instruction.h"
 #include "llvm/IR/Operator.h"
+#include "llvm/IR/PatternMatch.h"
 
 namespace llvm {
 
@@ -133,7 +134,9 @@
   bool isUndefValue(Value *V) const {
     if (!CanUseUndef)
       return false;
-    return isa<UndefValue>(V);
+
+    using namespace PatternMatch;
+    return match(V, m_Undef());
   }
 };
 
diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -88,8 +88,52 @@
 /// Matches any compare instruction and ignore it.
 inline class_match<CmpInst> m_Cmp() { return class_match<CmpInst>(); }
 
-/// Match an arbitrary undef constant.
-inline class_match<UndefValue> m_Undef() { return class_match<UndefValue>(); }
+struct undef_match {
+  static bool check(const Value *V) {
+    if (isa<UndefValue>(V))
+      return true;
+
+    const auto *CA = dyn_cast<ConstantAggregate>(V);
+    if (!CA)
+      return false;
+
+    SmallPtrSet<const ConstantAggregate *, 8> Seen;
+    SmallVector<const ConstantAggregate *, 8> Worklist;
+
+    // Accept undef/poison (PoisonValue derives from UndefValue), or an
+    // aggregate that recursively contains only these. CheckValue returns
+    // false as soon as Agg has an operand violating this constraint.
+    auto CheckValue = [&](const ConstantAggregate *Agg) {
+      for (const Value *Op : Agg->operand_values()) {
+        if (isa<UndefValue>(Op))
+          continue;
+
+        const auto *SubAgg = dyn_cast<ConstantAggregate>(Op);
+        if (!SubAgg)
+          return false;
+        if (Seen.insert(SubAgg).second)
+          Worklist.emplace_back(SubAgg);
+      }
+
+      return true;
+    };
+
+    if (!CheckValue(CA))
+      return false;
+
+    while (!Worklist.empty()) {
+      if (!CheckValue(Worklist.pop_back_val()))
+        return false;
+    }
+    return true;
+  }
+  template <typename ITy> bool match(ITy *V) { return check(V); }
+};
+
+/// Match an arbitrary undef constant. This matches poison as well.
+/// If this is an aggregate and contains a non-aggregate element that is
+/// neither undef nor poison, the aggregate is not matched.
+inline auto m_Undef() { return undef_match(); }
 
 /// Match an arbitrary poison constant.
 inline class_match<PoisonValue> m_Poison() { return class_match<PoisonValue>(); }
diff --git a/llvm/test/Transforms/InstSimplify/icmp-constant.ll b/llvm/test/Transforms/InstSimplify/icmp-constant.ll
--- a/llvm/test/Transforms/InstSimplify/icmp-constant.ll
+++ b/llvm/test/Transforms/InstSimplify/icmp-constant.ll
@@ -1069,8 +1069,7 @@
 
 define <2 x i1> @heterogeneous_constvector(<2 x i8> %x) {
 ; CHECK-LABEL: @heterogeneous_constvector(
-; CHECK-NEXT:    [[C:%.*]] = icmp ult <2 x i8> [[X:%.*]], <i8 undef, i8 poison>
-; CHECK-NEXT:    ret <2 x i1> [[C]]
+; CHECK-NEXT:    ret <2 x i1> zeroinitializer
 ;
   %c = icmp ult <2 x i8> %x, <i8 undef, i8 poison>
   ret <2 x i1> %c
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -1022,6 +1022,37 @@
   EXPECT_TRUE(A == Val);
 }
 
+TEST_F(PatternMatchTest, UndefPoisonMix) {
+  Type *ScalarTy = IRB.getInt8Ty();
+  ArrayType *ArrTy = ArrayType::get(ScalarTy, 2);
+  StructType *StTy = StructType::get(ScalarTy, ScalarTy);
+  StructType *StTy2 = StructType::get(ScalarTy, StTy);
+  StructType *StTy3 = StructType::get(StTy, ScalarTy);
+  Constant *Zero = ConstantInt::getNullValue(ScalarTy);
+  UndefValue *U = UndefValue::get(ScalarTy);
+  UndefValue *P = PoisonValue::get(ScalarTy);
+
+  EXPECT_TRUE(match(ConstantVector::get({U, P}), m_Undef()));
+  EXPECT_TRUE(match(ConstantVector::get({P, U}), m_Undef()));
+
+  EXPECT_TRUE(match(ConstantArray::get(ArrTy, {U, P}), m_Undef()));
+  EXPECT_TRUE(match(ConstantArray::get(ArrTy, {P, U}), m_Undef()));
+
+  auto *UP = ConstantStruct::get(StTy, {U, P});
+  EXPECT_TRUE(match(ConstantStruct::get(StTy2, {U, UP}), m_Undef()));
+  EXPECT_TRUE(match(ConstantStruct::get(StTy2, {P, UP}), m_Undef()));
+  EXPECT_TRUE(match(ConstantStruct::get(StTy3, {UP, U}), m_Undef()));
+  EXPECT_TRUE(match(ConstantStruct::get(StTy3, {UP, P}), m_Undef()));
+
+  EXPECT_FALSE(match(ConstantStruct::get(StTy, {U, Zero}), m_Undef()));
+  EXPECT_FALSE(match(ConstantStruct::get(StTy, {Zero, U}), m_Undef()));
+  EXPECT_FALSE(match(ConstantStruct::get(StTy, {P, Zero}), m_Undef()));
+  EXPECT_FALSE(match(ConstantStruct::get(StTy, {Zero, P}), m_Undef()));
+
+  EXPECT_FALSE(match(ConstantStruct::get(StTy2, {Zero, UP}), m_Undef()));
+  EXPECT_FALSE(match(ConstantStruct::get(StTy3, {UP, Zero}), m_Undef()));
+}
+
 TEST_F(PatternMatchTest, VectorUndefInt) {
   Type *ScalarTy = IRB.getInt8Ty();
   Type *VectorTy = FixedVectorType::get(ScalarTy, 4);