PatternMatch: let the APInt/APFloat matchers optionally accept splat vector
constants that contain undef elements. apint_match/apfloat_match gain an
AllowUndef flag that is forwarded to getSplatValue(); m_APInt/m_APFloat keep
forbidding undef, and explicit *AllowUndef / *ForbidUndef variants are added.

diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -153,8 +153,10 @@
 
 struct apint_match {
   const APInt *&Res;
+  bool AllowUndef;
 
-  apint_match(const APInt *&R) : Res(R) {}
+  apint_match(const APInt *&Res, bool AllowUndef)
+      : Res(Res), AllowUndef(AllowUndef) {}
 
   template <typename ITy> bool match(ITy *V) {
     if (auto *CI = dyn_cast<ConstantInt>(V)) {
@@ -163,7 +165,8 @@
     }
     if (V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
-        if (auto *CI = dyn_cast_or_null<ConstantInt>(C->getSplatValue())) {
+        if (auto *CI = dyn_cast_or_null<ConstantInt>(
+                C->getSplatValue(AllowUndef))) {
           Res = &CI->getValue();
           return true;
         }
@@ -175,7 +178,11 @@
 // function for both apint/apfloat.
 struct apfloat_match {
   const APFloat *&Res;
-  apfloat_match(const APFloat *&R) : Res(R) {}
+  bool AllowUndef;
+
+  apfloat_match(const APFloat *&Res, bool AllowUndef)
+      : Res(Res), AllowUndef(AllowUndef) {}
+
   template <typename ITy> bool match(ITy *V) {
     if (auto *CI = dyn_cast<ConstantFP>(V)) {
       Res = &CI->getValueAPF();
@@ -183,7 +190,8 @@
     }
     if (V->getType()->isVectorTy())
       if (const auto *C = dyn_cast<Constant>(V))
-        if (auto *CI = dyn_cast_or_null<ConstantFP>(C->getSplatValue())) {
+        if (auto *CI = dyn_cast_or_null<ConstantFP>(
+                C->getSplatValue(AllowUndef))) {
           Res = &CI->getValueAPF();
           return true;
         }
@@ -193,11 +201,37 @@
 
 /// Match a ConstantInt or splatted ConstantVector, binding the
 /// specified pointer to the contained APInt.
-inline apint_match m_APInt(const APInt *&Res) { return Res; }
+inline apint_match m_APInt(const APInt *&Res) {
+  // Forbid undefs by default to maintain previous behavior.
+  return apint_match(Res, /* AllowUndef */ false);
+}
+
+/// Match APInt while allowing undefs in splat vector constants.
+inline apint_match m_APIntAllowUndef(const APInt *&Res) {
+  return apint_match(Res, /* AllowUndef */ true);
+}
+
+/// Match APInt while forbidding undefs in splat vector constants.
+inline apint_match m_APIntForbidUndef(const APInt *&Res) {
+  return apint_match(Res, /* AllowUndef */ false);
+}
 
 /// Match a ConstantFP or splatted ConstantVector, binding the
 /// specified pointer to the contained APFloat.
-inline apfloat_match m_APFloat(const APFloat *&Res) { return Res; }
+inline apfloat_match m_APFloat(const APFloat *&Res) {
+  // Forbid undefs by default to maintain previous behavior.
+  return apfloat_match(Res, /* AllowUndef */ false);
+}
+
+/// Match APFloat while allowing undefs in splat vector constants.
+inline apfloat_match m_APFloatAllowUndef(const APFloat *&Res) {
+  return apfloat_match(Res, /* AllowUndef */ true);
+}
+
+/// Match APFloat while forbidding undefs in splat vector constants.
+inline apfloat_match m_APFloatForbidUndef(const APFloat *&Res) {
+  return apfloat_match(Res, /* AllowUndef */ false);
+}
 
 template <int64_t Val> struct constantint_match {
   template <typename ITy> bool match(ITy *V) {
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -1045,6 +1045,43 @@
   EXPECT_TRUE(match(ScalarZero, m_Zero()));
   EXPECT_TRUE(match(VectorZero, m_Zero()));
   EXPECT_TRUE(match(VectorZeroUndef, m_Zero()));
+
+  const APInt *C;
+  // Regardless of whether undefs are allowed,
+  // a fully undef constant does not match.
+  EXPECT_FALSE(match(ScalarUndef, m_APInt(C)));
+  EXPECT_FALSE(match(ScalarUndef, m_APIntForbidUndef(C)));
+  EXPECT_FALSE(match(ScalarUndef, m_APIntAllowUndef(C)));
+  EXPECT_FALSE(match(VectorUndef, m_APInt(C)));
+  EXPECT_FALSE(match(VectorUndef, m_APIntForbidUndef(C)));
+  EXPECT_FALSE(match(VectorUndef, m_APIntAllowUndef(C)));
+
+  // We can always match simple constants and simple splats.
+  C = nullptr;
+  EXPECT_TRUE(match(ScalarZero, m_APInt(C)));
+  EXPECT_TRUE(C->isNullValue());
+  C = nullptr;
+  EXPECT_TRUE(match(ScalarZero, m_APIntForbidUndef(C)));
+  EXPECT_TRUE(C->isNullValue());
+  C = nullptr;
+  EXPECT_TRUE(match(ScalarZero, m_APIntAllowUndef(C)));
+  EXPECT_TRUE(C->isNullValue());
+  C = nullptr;
+  EXPECT_TRUE(match(VectorZero, m_APInt(C)));
+  EXPECT_TRUE(C->isNullValue());
+  C = nullptr;
+  EXPECT_TRUE(match(VectorZero, m_APIntForbidUndef(C)));
+  EXPECT_TRUE(C->isNullValue());
+  C = nullptr;
+  EXPECT_TRUE(match(VectorZero, m_APIntAllowUndef(C)));
+  EXPECT_TRUE(C->isNullValue());
+
+  // Whether splats with undef can be matched depends on the matcher.
+  EXPECT_FALSE(match(VectorZeroUndef, m_APInt(C)));
+  EXPECT_FALSE(match(VectorZeroUndef, m_APIntForbidUndef(C)));
+  C = nullptr;
+  EXPECT_TRUE(match(VectorZeroUndef, m_APIntAllowUndef(C)));
+  EXPECT_TRUE(C->isNullValue());
 }
 
 TEST_F(PatternMatchTest, VectorUndefFloat) {
@@ -1073,6 +1110,43 @@
   EXPECT_TRUE(match(ScalarZero, m_AnyZeroFP()));
   EXPECT_TRUE(match(VectorZero, m_AnyZeroFP()));
   EXPECT_TRUE(match(VectorZeroUndef, m_AnyZeroFP()));
+
+  const APFloat *C;
+  // Regardless of whether undefs are allowed,
+  // a fully undef constant does not match.
+  EXPECT_FALSE(match(ScalarUndef, m_APFloat(C)));
+  EXPECT_FALSE(match(ScalarUndef, m_APFloatForbidUndef(C)));
+  EXPECT_FALSE(match(ScalarUndef, m_APFloatAllowUndef(C)));
+  EXPECT_FALSE(match(VectorUndef, m_APFloat(C)));
+  EXPECT_FALSE(match(VectorUndef, m_APFloatForbidUndef(C)));
+  EXPECT_FALSE(match(VectorUndef, m_APFloatAllowUndef(C)));
+
+  // We can always match simple constants and simple splats.
+  C = nullptr;
+  EXPECT_TRUE(match(ScalarZero, m_APFloat(C)));
+  EXPECT_TRUE(C->isZero());
+  C = nullptr;
+  EXPECT_TRUE(match(ScalarZero, m_APFloatForbidUndef(C)));
+  EXPECT_TRUE(C->isZero());
+  C = nullptr;
+  EXPECT_TRUE(match(ScalarZero, m_APFloatAllowUndef(C)));
+  EXPECT_TRUE(C->isZero());
+  C = nullptr;
+  EXPECT_TRUE(match(VectorZero, m_APFloat(C)));
+  EXPECT_TRUE(C->isZero());
+  C = nullptr;
+  EXPECT_TRUE(match(VectorZero, m_APFloatForbidUndef(C)));
+  EXPECT_TRUE(C->isZero());
+  C = nullptr;
+  EXPECT_TRUE(match(VectorZero, m_APFloatAllowUndef(C)));
+  EXPECT_TRUE(C->isZero());
+
+  // Whether splats with undef can be matched depends on the matcher.
+  EXPECT_FALSE(match(VectorZeroUndef, m_APFloat(C)));
+  EXPECT_FALSE(match(VectorZeroUndef, m_APFloatForbidUndef(C)));
+  C = nullptr;
+  EXPECT_TRUE(match(VectorZeroUndef, m_APFloatAllowUndef(C)));
+  EXPECT_TRUE(C->isZero());
 }
 
 TEST_F(PatternMatchTest, FloatingPointFNeg) {