diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -2285,6 +2285,33 @@
   return m_c_Xor(V, m_AllOnes());
 }
 
+/// Matcher for a bitwise 'not' whose -1 operand must not contain undef
+/// elements; the other xor operand is handed to the sub-pattern Val.
+template <typename ValTy> struct NotForbidUndef_match {
+  ValTy Val;
+  NotForbidUndef_match(const ValTy &V) : Val(V) {}
+
+  template <typename OpTy> bool match(OpTy *V) {
+    // We do not use m_c_Xor because that could match an arbitrary APInt that is
+    // not -1 as C and then fail to match the other operand if it is -1.
+    // This code should still work even when both operands are constants.
+    Value *X;
+    const APInt *C;
+    if (m_Xor(m_Value(X), m_APIntForbidUndef(C)).match(V) && C->isAllOnes())
+      return Val.match(X);
+    if (m_Xor(m_APIntForbidUndef(C), m_Value(X)).match(V) && C->isAllOnes())
+      return Val.match(X);
+    return false;
+  }
+};
+
+/// Matches a bitwise 'not' as 'xor V, -1' or 'xor -1, V'. For vectors, the
+/// constant value must be composed of only -1 scalar elements.
+template <typename ValTy>
+inline NotForbidUndef_match<ValTy> m_NotForbidUndef(const ValTy &V) {
+  return NotForbidUndef_match<ValTy>(V);
+}
+
 /// Matches an SMin with LHS and RHS in either order.
 template
 inline MaxMin_match
diff --git a/llvm/lib/Analysis/InstructionSimplify.cpp b/llvm/lib/Analysis/InstructionSimplify.cpp
--- a/llvm/lib/Analysis/InstructionSimplify.cpp
+++ b/llvm/lib/Analysis/InstructionSimplify.cpp
@@ -2408,11 +2408,10 @@
   // The 'not' op must contain a complete -1 operand (no undef elements for
   // vector) for the transform to be safe.
   Value *NotA;
-  const APInt *C;
-  if (match(X, m_c_Or(m_CombineAnd(m_Xor(m_Value(A), m_APIntForbidUndef(C)),
-                                   m_Value(NotA)),
-                      m_Value(B))) &&
-      match(Y, m_c_And(m_Specific(A), m_Specific(B))) && C->isAllOnes())
+  if (match(X,
+            m_c_Or(m_CombineAnd(m_NotForbidUndef(m_Value(A)), m_Value(NotA)),
+                   m_Value(B))) &&
+      match(Y, m_c_And(m_Specific(A), m_Specific(B))))
     return NotA;
 
   return nullptr;
diff --git a/llvm/unittests/IR/PatternMatch.cpp b/llvm/unittests/IR/PatternMatch.cpp
--- a/llvm/unittests/IR/PatternMatch.cpp
+++ b/llvm/unittests/IR/PatternMatch.cpp
@@ -1727,6 +1727,48 @@
   EXPECT_FALSE(match(PtrToInt2, m_VScale(DL)));
 }
 
+TEST_F(PatternMatchTest, NotForbidUndef) {
+  Type *ScalarTy = IRB.getInt8Ty();
+  Type *VectorTy = FixedVectorType::get(ScalarTy, 3);
+  Constant *ScalarUndef = UndefValue::get(ScalarTy);
+  Constant *ScalarOnes = Constant::getAllOnesValue(ScalarTy);
+  Constant *VectorZero = Constant::getNullValue(VectorTy);
+  Constant *VectorOnes = Constant::getAllOnesValue(VectorTy);
+
+  // <-1, -1, undef>: an all-ones vector with one undef lane.
+  SmallVector<Constant *, 3> MixedElems;
+  MixedElems.push_back(ScalarOnes);
+  MixedElems.push_back(ScalarOnes);
+  MixedElems.push_back(ScalarUndef);
+  Constant *VectorMixed = ConstantVector::get(MixedElems);
+
+  // Fully defined -1 as the second operand: both matchers accept and
+  // m_NotForbidUndef binds the other operand.
+  Value *Not = IRB.CreateXor(VectorZero, VectorOnes);
+  Value *X;
+  EXPECT_TRUE(match(Not, m_Not(m_Value())));
+  EXPECT_TRUE(match(Not, m_NotForbidUndef(m_Value(X))));
+  EXPECT_TRUE(match(X, m_Zero()));
+
+  // Fully defined -1 as the first operand: the commuted form matches too.
+  Value *NotCommute = IRB.CreateXor(VectorOnes, VectorZero);
+  Value *Y;
+  EXPECT_TRUE(match(NotCommute, m_Not(m_Value())));
+  EXPECT_TRUE(match(NotCommute, m_NotForbidUndef(m_Value(Y))));
+  EXPECT_TRUE(match(Y, m_Zero()));
+
+  // An undef lane in the -1 operand: m_Not still matches, m_NotForbidUndef
+  // must reject.
+  Value *NotWithUndefs = IRB.CreateXor(VectorZero, VectorMixed);
+  EXPECT_TRUE(match(NotWithUndefs, m_Not(m_Value())));
+  EXPECT_FALSE(match(NotWithUndefs, m_NotForbidUndef(m_Value())));
+
+  // Same check with the operands swapped.
+  Value *NotWithUndefsCommute = IRB.CreateXor(VectorMixed, VectorZero);
+  EXPECT_TRUE(match(NotWithUndefsCommute, m_Not(m_Value())));
+  EXPECT_FALSE(match(NotWithUndefsCommute, m_NotForbidUndef(m_Value(X))));
+}
+
 template struct MutableConstTest : PatternMatchTest { };
 
 typedef ::testing::Types,