diff --git a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h --- a/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h +++ b/llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h @@ -636,7 +636,8 @@ // General helper for generic MI compares, i.e. G_ICMP and G_FCMP // TODO: Allow checking a specific predicate. -template +template struct CompareOp_match { Pred_P P; LHS_P L; @@ -655,9 +656,14 @@ static_cast(TmpMI->getOperand(1).getPredicate()); if (!P.match(MRI, TmpPred)) return false; - - return L.match(MRI, TmpMI->getOperand(2).getReg()) && - R.match(MRI, TmpMI->getOperand(3).getReg()); + Register LHS = TmpMI->getOperand(2).getReg(); + Register RHS = TmpMI->getOperand(3).getReg(); + if (L.match(MRI, LHS) && R.match(MRI, RHS)) + return true; + if (Commutable && L.match(MRI, RHS) && R.match(MRI, LHS) && + P.match(MRI, CmpInst::getSwappedPredicate(TmpPred))) + return true; + return false; } }; @@ -673,6 +679,36 @@ return CompareOp_match(P, L, R); } +/// G_ICMP matcher that also matches commuted compares. +/// E.g. +/// +/// m_c_GICmp(m_Pred(...), m_GAdd(...), m_GSub(...)) +/// +/// Could match both of: +/// +/// icmp ugt (add x, y) (sub a, b) +/// icmp ult (sub a, b) (add x, y) +template +inline CompareOp_match +m_c_GICmp(const Pred &P, const LHS &L, const RHS &R) { + return CompareOp_match(P, L, R); +} + +/// G_FCMP matcher that also matches commuted compares. +/// E.g. +/// +/// m_c_GFCmp(m_Pred(...), m_GFAdd(...), m_GFMul(...)) +/// +/// Could match both of: +/// +/// fcmp ogt (fadd x, y) (fmul a, b) +/// fcmp olt (fmul a, b) (fadd x, y) +template +inline CompareOp_match +m_c_GFCmp(const Pred &P, const LHS &L, const RHS &R) { + return CompareOp_match(P, L, R); +} + // Helper for checking if a Reg is of specific type. 
struct CheckType { LLT Ty; diff --git a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp --- a/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp +++ b/llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp @@ -323,6 +323,64 @@ EXPECT_EQ(Copies[1], Reg1); } +TEST_F(AArch64GISelMITest, MatchCommutativeICmp) { + setUp(); + if (!TM) + return; + const LLT s1 = LLT::scalar(1); + Register LHS = Copies[0]; + Register RHS = Copies[1]; + CmpInst::Predicate MatchedPred; + bool Match = false; + for (unsigned P = CmpInst::Predicate::FIRST_ICMP_PREDICATE; + P <= CmpInst::Predicate::LAST_ICMP_PREDICATE; ++P) { + auto CurrPred = static_cast(P); + auto Cmp = B.buildICmp(CurrPred, s1, LHS, RHS); + // Basic matching. + Match = mi_match( + Cmp.getReg(0), *MRI, + m_c_GICmp(m_Pred(MatchedPred), m_SpecificReg(LHS), m_SpecificReg(RHS))); + EXPECT_TRUE(Match); + EXPECT_EQ(MatchedPred, CurrPred); + // Commuting operands should still match, but the predicate should be + // swapped. + Match = mi_match( + Cmp.getReg(0), *MRI, + m_c_GICmp(m_Pred(MatchedPred), m_SpecificReg(RHS), m_SpecificReg(LHS))); + EXPECT_TRUE(Match); + EXPECT_EQ(MatchedPred, CmpInst::getSwappedPredicate(CurrPred)); + } +} + +TEST_F(AArch64GISelMITest, MatchCommutativeFCmp) { + setUp(); + if (!TM) + return; + const LLT s1 = LLT::scalar(1); + Register LHS = Copies[0]; + Register RHS = Copies[1]; + CmpInst::Predicate MatchedPred; + bool Match = false; + for (unsigned P = CmpInst::Predicate::FIRST_FCMP_PREDICATE; + P <= CmpInst::Predicate::LAST_FCMP_PREDICATE; ++P) { + auto CurrPred = static_cast(P); + auto Cmp = B.buildFCmp(CurrPred, s1, LHS, RHS); + // Basic matching. + Match = mi_match( + Cmp.getReg(0), *MRI, + m_c_GFCmp(m_Pred(MatchedPred), m_SpecificReg(LHS), m_SpecificReg(RHS))); + EXPECT_TRUE(Match); + EXPECT_EQ(MatchedPred, CurrPred); + // Commuting operands should still match, but the predicate should be + // swapped. 
+ Match = mi_match( + Cmp.getReg(0), *MRI, + m_c_GFCmp(m_Pred(MatchedPred), m_SpecificReg(RHS), m_SpecificReg(LHS))); + EXPECT_TRUE(Match); + EXPECT_EQ(MatchedPred, CmpInst::getSwappedPredicate(CurrPred)); + } +} + TEST_F(AArch64GISelMITest, MatchFPUnaryOp) { setUp(); if (!TM)