Index: llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
===================================================================
--- llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
+++ llvm/include/llvm/CodeGen/GlobalISel/MIPatternMatch.h
@@ -654,8 +654,11 @@
     if (!P.match(MRI, TmpPred))
       return false;
 
-    return L.match(MRI, TmpMI->getOperand(2).getReg()) &&
-           R.match(MRI, TmpMI->getOperand(3).getReg());
+    return (L.match(MRI, TmpMI->getOperand(2).getReg()) &&
+            R.match(MRI, TmpMI->getOperand(3).getReg())) ||
+           (CmpInst::isCommutative(TmpPred) &&
+            (L.match(MRI, TmpMI->getOperand(3).getReg()) &&
+             R.match(MRI, TmpMI->getOperand(2).getReg())));
   }
 };
 
Index: llvm/include/llvm/IR/InstrTypes.h
===================================================================
--- llvm/include/llvm/IR/InstrTypes.h
+++ llvm/include/llvm/IR/InstrTypes.h
@@ -924,9 +924,11 @@
   /// the same comparison.
   void swapOperands();
 
-  /// This is just a convenience that dispatches to the subclasses.
-  /// Determine if this CmpInst is commutative.
-  bool isCommutative() const;
+  /// \returns true if \p P is a commutative comparison predicate.
+  static bool isCommutative(Predicate P);
+
+  /// \returns true if this CmpInst is commutative.
+  bool isCommutative() const { return isCommutative(getPredicate()); }
 
   /// Determine if this is an equals/not equals predicate.
   /// This is a static version that you can use without an instruction
Index: llvm/include/llvm/IR/Instructions.h
===================================================================
--- llvm/include/llvm/IR/Instructions.h
+++ llvm/include/llvm/IR/Instructions.h
@@ -1282,9 +1282,11 @@
     return isEquality(getPredicate());
   }
 
+  /// @returns true if \p P is a commutative integer comparison predicate.
+  static bool isCommutative(Predicate P) { return isEquality(P); }
+
   /// @returns true if the predicate of this ICmpInst is commutative
-  /// Determine if this relation is commutative.
-  bool isCommutative() const { return isEquality(); }
+  bool isCommutative() const { return isCommutative(getPredicate()); }
 
   /// Return true if the predicate is relational (not EQ or NE).
   ///
@@ -1424,14 +1426,23 @@
   /// Determine if this is an equality predicate.
   bool isEquality() const { return isEquality(getPredicate()); }
 
+  /// @returns true if \p Pred is a commutative floating point comparison
+  /// predicate.
+  static bool isCommutative(Predicate Pred) {
+    switch (Pred) {
+    default:
+      return isEquality(Pred);
+    case FCMP_FALSE:
+    case FCMP_TRUE:
+    case FCMP_ORD:
+    case FCMP_UNO:
+      return true;
+    }
+  }
+
   /// @returns true if the predicate of this instruction is commutative.
-  /// Determine if this is a commutative predicate.
   bool isCommutative() const {
-    return isEquality() ||
-           getPredicate() == FCMP_FALSE ||
-           getPredicate() == FCMP_TRUE ||
-           getPredicate() == FCMP_ORD ||
-           getPredicate() == FCMP_UNO;
+    return isCommutative(getPredicate());
   }
 
   /// @returns true if the predicate is relational (not EQ or NE).
Index: llvm/lib/IR/Instructions.cpp
===================================================================
--- llvm/lib/IR/Instructions.cpp
+++ llvm/lib/IR/Instructions.cpp
@@ -3955,10 +3955,12 @@
     cast<FCmpInst>(this)->swapOperands();
 }
 
-bool CmpInst::isCommutative() const {
-  if (const ICmpInst *IC = dyn_cast<ICmpInst>(this))
-    return IC->isCommutative();
-  return cast<FCmpInst>(this)->isCommutative();
+bool CmpInst::isCommutative(Predicate P) {
+  if (ICmpInst::isIntPredicate(P))
+    return ICmpInst::isCommutative(P);
+  if (FCmpInst::isFPPredicate(P))
+    return FCmpInst::isCommutative(P);
+  llvm_unreachable("Unsupported predicate kind");
 }
 
 bool CmpInst::isEquality(Predicate P) {
Index: llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
===================================================================
--- llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
+++ llvm/unittests/CodeGen/GlobalISel/PatternMatchTest.cpp
@@ -296,6 +296,37 @@
   EXPECT_EQ(CmpInst::ICMP_EQ, Pred);
   EXPECT_EQ(Copies[0], Reg0);
   EXPECT_EQ(Copies[1], Reg1);
+  // Check that we can commute commutative predicates, and that we don't commute
+  // non-commutative predicates.
+  Register LHS = Copies[0];
+  Register RHS = Copies[1];
+  for (unsigned P = CmpInst::Predicate::FIRST_ICMP_PREDICATE;
+       P <= CmpInst::Predicate::LAST_ICMP_PREDICATE; ++P) {
+    auto TestingPred = static_cast<CmpInst::Predicate>(P);
+    auto Cmp = B.buildICmp(TestingPred, s1, LHS, RHS);
+    auto Dst = Cmp.getReg(0);
+    // Check that we can match the basic case with and without a matched
+    // predicate.
+    bool BasicNoMatchPred = mi_match(
+        Dst, *MRI, m_GICmp(m_Pred(), m_SpecificReg(LHS), m_SpecificReg(RHS)));
+    bool BasicMatchPred =
+        mi_match(Dst, *MRI,
+                 m_GICmp(m_Pred(Pred), m_SpecificReg(LHS), m_SpecificReg(RHS)));
+    EXPECT_TRUE(BasicNoMatchPred);
+    EXPECT_TRUE(BasicMatchPred);
+    EXPECT_EQ(TestingPred, Pred);
+    // Check that the commuted case only matches if the predicate is
+    // commutative.
+    bool CommutedNoMatchPred = mi_match(
+        Dst, *MRI, m_GICmp(m_Pred(), m_SpecificReg(RHS), m_SpecificReg(LHS)));
+    bool CommutedMatchPred =
+        mi_match(Dst, *MRI,
+                 m_GICmp(m_Pred(Pred), m_SpecificReg(RHS), m_SpecificReg(LHS)));
+    // Commuted case will only match if the predicate is commutative.
+    bool IsCommutative = CmpInst::isCommutative(TestingPred);
+    EXPECT_EQ(BasicNoMatchPred && CommutedNoMatchPred, IsCommutative);
+    EXPECT_EQ(BasicMatchPred && CommutedMatchPred, IsCommutative);
+  }
 }
 
 TEST_F(AArch64GISelMITest, MatchFCmp) {
@@ -321,6 +352,37 @@
   EXPECT_EQ(CmpInst::FCMP_OEQ, Pred);
   EXPECT_EQ(Copies[0], Reg0);
   EXPECT_EQ(Copies[1], Reg1);
+  // Check that we can commute commutative predicates, and that we don't commute
+  // non-commutative predicates.
+  Register LHS = Copies[0];
+  Register RHS = Copies[1];
+  for (unsigned P = CmpInst::Predicate::FIRST_FCMP_PREDICATE;
+       P <= CmpInst::Predicate::LAST_FCMP_PREDICATE; ++P) {
+    auto TestingPred = static_cast<CmpInst::Predicate>(P);
+    auto Cmp = B.buildFCmp(TestingPred, s1, LHS, RHS);
+    auto Dst = Cmp.getReg(0);
+    // Check that we can match the basic case with and without a matched
+    // predicate.
+    bool BasicNoMatchPred = mi_match(
+        Dst, *MRI, m_GFCmp(m_Pred(), m_SpecificReg(LHS), m_SpecificReg(RHS)));
+    bool BasicMatchPred =
+        mi_match(Dst, *MRI,
+                 m_GFCmp(m_Pred(Pred), m_SpecificReg(LHS), m_SpecificReg(RHS)));
+    EXPECT_TRUE(BasicNoMatchPred);
+    EXPECT_TRUE(BasicMatchPred);
+    EXPECT_EQ(TestingPred, Pred);
+    // Check that the commuted case only matches if the predicate is
+    // commutative.
+    bool CommutedNoMatchPred = mi_match(
+        Dst, *MRI, m_GFCmp(m_Pred(), m_SpecificReg(RHS), m_SpecificReg(LHS)));
+    bool CommutedMatchPred =
+        mi_match(Dst, *MRI,
+                 m_GFCmp(m_Pred(Pred), m_SpecificReg(RHS), m_SpecificReg(LHS)));
+    // Commuted case will only match if the predicate is commutative.
+    bool IsCommutative = CmpInst::isCommutative(TestingPred);
+    EXPECT_EQ(BasicNoMatchPred && CommutedNoMatchPred, IsCommutative);
+    EXPECT_EQ(BasicMatchPred && CommutedMatchPred, IsCommutative);
+  }
 }
 
 TEST_F(AArch64GISelMITest, MatchFPUnaryOp) {