diff --git a/llvm/include/llvm/IR/PatternMatch.h b/llvm/include/llvm/IR/PatternMatch.h
--- a/llvm/include/llvm/IR/PatternMatch.h
+++ b/llvm/include/llvm/IR/PatternMatch.h
@@ -963,20 +963,24 @@
   // The LHS is always matched first.
   BinaryOp_match(const LHS_t &LHS, const RHS_t &RHS) : L(LHS), R(RHS) {}
 
-  template <typename OpTy> bool match(OpTy *V) {
-    if (V->getValueID() == Value::InstructionVal + Opcode) {
+  /// Match \p V as an instruction or constant expression with opcode \p Opc,
+  /// supplied by the caller instead of the \p Opcode template parameter.
+  template <typename OpTy> inline bool match(unsigned Opc, OpTy *V) {
+    if (V->getValueID() == Value::InstructionVal + Opc) {
       auto *I = cast<BinaryOperator>(V);
       return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) ||
              (Commutable && L.match(I->getOperand(1)) &&
               R.match(I->getOperand(0)));
     }
     if (auto *CE = dyn_cast<ConstantExpr>(V))
-      return CE->getOpcode() == Opcode &&
+      return CE->getOpcode() == Opc &&
              ((L.match(CE->getOperand(0)) && R.match(CE->getOperand(1))) ||
               (Commutable && L.match(CE->getOperand(1)) &&
                R.match(CE->getOperand(0))));
     return false;
   }
+
+  template <typename OpTy> bool match(OpTy *V) { return match(Opcode, V); }
 };
 
 template <typename LHS, typename RHS>
@@ -1221,6 +1225,28 @@
       L, R);
 }
 
+/// Like BinaryOp_match, but the opcode is supplied at runtime instead of as a
+/// template argument.
+template <typename LHS_t, typename RHS_t, bool Commutable = false>
+struct SpecificBinaryOp_match
+    : public BinaryOp_match<LHS_t, RHS_t, 0, Commutable> {
+  unsigned Opcode;
+
+  SpecificBinaryOp_match(unsigned Opcode, const LHS_t &LHS, const RHS_t &RHS)
+      : BinaryOp_match<LHS_t, RHS_t, 0, Commutable>(LHS, RHS), Opcode(Opcode) {}
+
+  template <typename OpTy> bool match(OpTy *V) {
+    return BinaryOp_match<LHS_t, RHS_t, 0, Commutable>::match(Opcode, V);
+  }
+};
+
+/// Matches an instruction or constant expression with the given \p Opcode.
+template <typename LHS, typename RHS>
+inline SpecificBinaryOp_match<LHS, RHS> m_BinOp(unsigned Opcode, const LHS &L,
+                                                const RHS &R) {
+  return SpecificBinaryOp_match<LHS, RHS>(Opcode, L, R);
+}
+
 //===----------------------------------------------------------------------===//
 // Class that matches a group of binary opcodes.
 //
@@ -2198,6 +2224,13 @@
                                                                        R);
 }
 
+/// Matches a specific opcode with LHS and RHS in either order.
+template <typename LHS, typename RHS>
+inline SpecificBinaryOp_match<LHS, RHS, true>
+m_c_BinOp(unsigned Opcode, const LHS &L, const RHS &R) {
+  return SpecificBinaryOp_match<LHS, RHS, true>(Opcode, L, R);
+}
+
 /// Matches a Add with LHS and RHS in either order.
 template <typename LHS, typename RHS>
 inline BinaryOp_match<LHS, RHS, Instruction::Add, true> m_c_Add(const LHS &L,
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1418,15 +1418,15 @@
     std::swap(Op0, Op1);
 
   // Match inner binop and the predicate for combining 2 NAN checks into 1.
-  BinaryOperator *BO1;
+  Value *BO10, *BO11;
   FCmpInst::Predicate NanPred = Opcode == Instruction::And ? FCmpInst::FCMP_ORD
                                                            : FCmpInst::FCMP_UNO;
   if (!match(Op0, m_FCmp(Pred, m_Value(X), m_AnyZeroFP())) || Pred != NanPred ||
-      !match(Op1, m_BinOp(BO1)) || BO1->getOpcode() != Opcode)
+      !match(Op1, m_BinOp(Opcode, m_Value(BO10), m_Value(BO11))))
     return nullptr;
 
   // The inner logic op must have a matching fcmp operand.
-  Value *BO10 = BO1->getOperand(0), *BO11 = BO1->getOperand(1), *Y;
+  Value *Y;
   if (!match(BO10, m_FCmp(Pred, m_Value(Y), m_AnyZeroFP())) ||
       Pred != NanPred || X->getType() != Y->getType())
     std::swap(BO10, BO11);
@@ -1475,14 +1475,11 @@
   // (~B & A) & ~C --> A & ~(B | C)
   // (A | ~B) | ~C --> A | ~(B & C)
   // (~B | A) | ~C --> A | ~(B & C)
-  BinaryOperator *BO;
-  if (match(Op0, m_OneUse(m_BinOp(BO))) && BO->getOpcode() == Opcode) {
-    Value *C;
-    if (match(BO, m_c_BinOp(m_Value(A), m_Not(m_Value(B)))) &&
-        match(Op1, m_Not(m_Value(C)))) {
-      Value *FlippedBO = Builder.CreateBinOp(FlippedOpcode, B, C);
-      return BinaryOperator::Create(Opcode, A, Builder.CreateNot(FlippedBO));
-    }
+  Value *C;
+  if (match(Op0, m_OneUse(m_c_BinOp(Opcode, m_Value(A), m_Not(m_Value(B))))) &&
+      match(Op1, m_Not(m_Value(C)))) {
+    Value *FlippedBO = Builder.CreateBinOp(FlippedOpcode, B, C);
+    return BinaryOperator::Create(Opcode, A, Builder.CreateNot(FlippedBO));
   }
 
   return nullptr;