[PatternMatch] Add a Commutable template parameter to the binary, cmp and
min/max matchers and use it to implement the m_c_* helpers.

AnyBinaryOp_match, BinaryOp_match, CmpClass_match and MaxMin_match take a
trailing `bool Commutable = false` parameter. When it is true, match() first
tries L against operand 0 and R against operand 1, then retries with the
operands swapped (R against operand 0, L against operand 1). The m_c_*
helpers now build that single matcher instead of an m_CombineOr of two
matchers with L and R exchanged. m_c_ICmp still does not swap the predicate.

Index: include/llvm/IR/PatternMatch.h
===================================================================
--- include/llvm/IR/PatternMatch.h
+++ include/llvm/IR/PatternMatch.h
@@ -419,7 +419,8 @@
 //===----------------------------------------------------------------------===//
 // Matcher for any binary operator.
 //
-template <typename LHS_t, typename RHS_t> struct AnyBinaryOp_match {
+template <typename LHS_t, typename RHS_t, bool Commutable = false>
+struct AnyBinaryOp_match {
   LHS_t L;
   RHS_t R;
 
@@ -427,7 +428,9 @@
 
   template <typename OpTy> bool match(OpTy *V) {
     if (auto *I = dyn_cast<BinaryOperator>(V))
-      return L.match(I->getOperand(0)) && R.match(I->getOperand(1));
+      return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) ||
+             (Commutable && R.match(I->getOperand(0)) &&
+              L.match(I->getOperand(1)));
     return false;
   }
 };
@@ -441,7 +444,8 @@
 // Matchers for specific binary operators.
 //
 
-template <typename LHS_t, typename RHS_t, unsigned Opcode>
+template <typename LHS_t, typename RHS_t, unsigned Opcode,
+          bool Commutable = false>
 struct BinaryOp_match {
   LHS_t L;
   RHS_t R;
@@ -451,11 +455,15 @@
   template <typename OpTy> bool match(OpTy *V) {
     if (V->getValueID() == Value::InstructionVal + Opcode) {
       auto *I = cast<BinaryOperator>(V);
-      return L.match(I->getOperand(0)) && R.match(I->getOperand(1));
+      return (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) ||
+             (Commutable && R.match(I->getOperand(0)) &&
+              L.match(I->getOperand(1)));
     }
     if (auto *CE = dyn_cast<ConstantExpr>(V))
-      return CE->getOpcode() == Opcode && L.match(CE->getOperand(0)) &&
-             R.match(CE->getOperand(1));
+      return CE->getOpcode() == Opcode &&
+             ((L.match(CE->getOperand(0)) && R.match(CE->getOperand(1))) ||
+              (Commutable && R.match(CE->getOperand(0)) &&
+               L.match(CE->getOperand(1))));
     return false;
   }
 };
@@ -726,7 +734,8 @@
 // Matchers for CmpInst classes
 //
 
-template <typename LHS_t, typename RHS_t, typename Class, typename PredicateTy>
+template <typename LHS_t, typename RHS_t, typename Class, typename PredicateTy,
+          bool Commutable = false>
 struct CmpClass_match {
   PredicateTy &Predicate;
   LHS_t L;
@@ -737,7 +746,9 @@
 
   template <typename OpTy> bool match(OpTy *V) {
     if (auto *I = dyn_cast<Class>(V))
-      if (L.match(I->getOperand(0)) && R.match(I->getOperand(1))) {
+      if ((L.match(I->getOperand(0)) && R.match(I->getOperand(1))) ||
+          (Commutable && R.match(I->getOperand(0)) &&
+           L.match(I->getOperand(1)))) {
         Predicate = I->getPredicate();
         return true;
       }
@@ -1002,7 +1013,8 @@
 // Matchers for max/min idioms, eg: "select (sgt x, y), x, y" -> smax(x,y).
 //
 
-template <typename CmpInst_t, typename LHS_t, typename RHS_t, typename Pred_t>
+template <typename CmpInst_t, typename LHS_t, typename RHS_t, typename Pred_t,
+          bool Commutable = false>
 struct MaxMin_match {
   LHS_t L;
   RHS_t R;
@@ -1032,7 +1044,8 @@
     if (!Pred_t::match(Pred))
       return false;
     // It does!  Bind the operands.
-    return L.match(LHS) && R.match(RHS);
+    return (L.match(LHS) && R.match(RHS)) ||
+           (Commutable && R.match(LHS) && L.match(RHS));
   }
 };
 
@@ -1376,89 +1389,78 @@
 //
 
 /// \brief Matches a BinaryOperator with LHS and RHS in either order.
-template<typename LHS, typename RHS>
-inline match_combine_or<AnyBinaryOp_match<LHS, RHS>,
-                        AnyBinaryOp_match<RHS, LHS>>
-m_c_BinOp(const LHS &L, const RHS &R) {
-  return m_CombineOr(m_BinOp(L, R), m_BinOp(R, L));
+template <typename LHS, typename RHS>
+inline AnyBinaryOp_match<LHS, RHS, true> m_c_BinOp(const LHS &L, const RHS &R) {
+  return AnyBinaryOp_match<LHS, RHS, true>(L, R);
 }
 
 /// \brief Matches an ICmp with a predicate over LHS and RHS in either order.
 /// Does not swap the predicate.
-template<typename LHS, typename RHS>
-inline match_combine_or<CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate>,
-                        CmpClass_match<RHS, LHS, ICmpInst, ICmpInst::Predicate>>
+template <typename LHS, typename RHS>
+inline CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>
 m_c_ICmp(ICmpInst::Predicate &Pred, const LHS &L, const RHS &R) {
-  return m_CombineOr(m_ICmp(Pred, L, R), m_ICmp(Pred, R, L));
+  return CmpClass_match<LHS, RHS, ICmpInst, ICmpInst::Predicate, true>(Pred, L,
+                                                                       R);
 }
 
 /// \brief Matches a Add with LHS and RHS in either order.
-template<typename LHS, typename RHS>
-inline match_combine_or<BinaryOp_match<LHS, RHS, Instruction::Add>,
-                        BinaryOp_match<RHS, LHS, Instruction::Add>>
-m_c_Add(const LHS &L, const RHS &R) {
-  return m_CombineOr(m_Add(L, R), m_Add(R, L));
+template <typename LHS, typename RHS>
+inline BinaryOp_match<LHS, RHS, Instruction::Add, true> m_c_Add(const LHS &L,
+                                                                const RHS &R) {
+  return BinaryOp_match<LHS, RHS, Instruction::Add, true>(L, R);
 }
 
 /// \brief Matches a Mul with LHS and RHS in either order.
-template<typename LHS, typename RHS>
-inline match_combine_or<BinaryOp_match<LHS, RHS, Instruction::Mul>,
-                        BinaryOp_match<RHS, LHS, Instruction::Mul>>
-m_c_Mul(const LHS &L, const RHS &R) {
-  return m_CombineOr(m_Mul(L, R), m_Mul(R, L));
+template <typename LHS, typename RHS>
+inline BinaryOp_match<LHS, RHS, Instruction::Mul, true> m_c_Mul(const LHS &L,
+                                                                const RHS &R) {
+  return BinaryOp_match<LHS, RHS, Instruction::Mul, true>(L, R);
 }
 
 /// \brief Matches an And with LHS and RHS in either order.
-template<typename LHS, typename RHS>
-inline match_combine_or<BinaryOp_match<LHS, RHS, Instruction::And>,
-                        BinaryOp_match<RHS, LHS, Instruction::And>>
-m_c_And(const LHS &L, const RHS &R) {
-  return m_CombineOr(m_And(L, R), m_And(R, L));
+template <typename LHS, typename RHS>
+inline BinaryOp_match<LHS, RHS, Instruction::And, true> m_c_And(const LHS &L,
+                                                                const RHS &R) {
+  return BinaryOp_match<LHS, RHS, Instruction::And, true>(L, R);
 }
 
 /// \brief Matches an Or with LHS and RHS in either order.
-template<typename LHS, typename RHS>
-inline match_combine_or<BinaryOp_match<LHS, RHS, Instruction::Or>,
-                        BinaryOp_match<RHS, LHS, Instruction::Or>>
-m_c_Or(const LHS &L, const RHS &R) {
-  return m_CombineOr(m_Or(L, R), m_Or(R, L));
+template <typename LHS, typename RHS>
+inline BinaryOp_match<LHS, RHS, Instruction::Or, true> m_c_Or(const LHS &L,
+                                                              const RHS &R) {
+  return BinaryOp_match<LHS, RHS, Instruction::Or, true>(L, R);
 }
 
 /// \brief Matches an Xor with LHS and RHS in either order.
-template<typename LHS, typename RHS>
-inline match_combine_or<BinaryOp_match<LHS, RHS, Instruction::Xor>,
-                        BinaryOp_match<RHS, LHS, Instruction::Xor>>
-m_c_Xor(const LHS &L, const RHS &R) {
-  return m_CombineOr(m_Xor(L, R), m_Xor(R, L));
+template <typename LHS, typename RHS>
+inline BinaryOp_match<LHS, RHS, Instruction::Xor, true> m_c_Xor(const LHS &L,
+                                                                const RHS &R) {
+  return BinaryOp_match<LHS, RHS, Instruction::Xor, true>(L, R);
 }
 
 /// Matches an SMin with LHS and RHS in either order.
 template <typename LHS, typename RHS>
-inline match_combine_or<MaxMin_match<ICmpInst, LHS, RHS, smin_pred_ty>,
-                        MaxMin_match<ICmpInst, RHS, LHS, smin_pred_ty>>
+inline MaxMin_match<ICmpInst, LHS, RHS, smin_pred_ty, true>
 m_c_SMin(const LHS &L, const RHS &R) {
-  return m_CombineOr(m_SMin(L, R), m_SMin(R, L));
+  return MaxMin_match<ICmpInst, LHS, RHS, smin_pred_ty, true>(L, R);
 }
 /// Matches an SMax with LHS and RHS in either order.
 template <typename LHS, typename RHS>
-inline match_combine_or<MaxMin_match<ICmpInst, LHS, RHS, smax_pred_ty>,
-                        MaxMin_match<ICmpInst, RHS, LHS, smax_pred_ty>>
+inline MaxMin_match<ICmpInst, LHS, RHS, smax_pred_ty, true>
 m_c_SMax(const LHS &L, const RHS &R) {
-  return m_CombineOr(m_SMax(L, R), m_SMax(R, L));
+  return MaxMin_match<ICmpInst, LHS, RHS, smax_pred_ty, true>(L, R);
 }
 /// Matches a UMin with LHS and RHS in either order.
 template <typename LHS, typename RHS>
-inline match_combine_or<MaxMin_match<ICmpInst, LHS, RHS, umin_pred_ty>,
-                        MaxMin_match<ICmpInst, RHS, LHS, umin_pred_ty>>
+inline MaxMin_match<ICmpInst, LHS, RHS, umin_pred_ty, true>
 m_c_UMin(const LHS &L, const RHS &R) {
-  return m_CombineOr(m_UMin(L, R), m_UMin(R, L));
+  return MaxMin_match<ICmpInst, LHS, RHS, umin_pred_ty, true>(L, R);
 }
 /// Matches a UMax with LHS and RHS in either order.
 template <typename LHS, typename RHS>
-inline match_combine_or<MaxMin_match<ICmpInst, LHS, RHS, umax_pred_ty>,
-                        MaxMin_match<ICmpInst, RHS, LHS, umax_pred_ty>>
+inline MaxMin_match<ICmpInst, LHS, RHS, umax_pred_ty, true>
 m_c_UMax(const LHS &L, const RHS &R) {
-  return m_CombineOr(m_UMax(L, R), m_UMax(R, L));
+  return MaxMin_match<ICmpInst, LHS, RHS, umax_pred_ty, true>(L, R);
 }
 
 } // end namespace PatternMatch