Index: lib/Transforms/Utils/FunctionComparator.cpp
===================================================================
--- lib/Transforms/Utils/FunctionComparator.cpp
+++ lib/Transforms/Utils/FunctionComparator.cpp
@@ -758,13 +758,46 @@
     if (needToCmpOperands) {
       assert(InstL->getNumOperands() == InstR->getNumOperands());
 
-      for (unsigned i = 0, e = InstL->getNumOperands(); i != e; ++i) {
-        Value *OpL = InstL->getOperand(i);
-        Value *OpR = InstR->getOperand(i);
-        if (int Res = cmpValues(OpL, OpR))
-          return Res;
-        // cmpValues should ensure this is true.
-        assert(cmpTypes(OpL->getType(), OpR->getType()) == 0);
+      const auto *BOL = dyn_cast<BinaryOperator>(&*InstL);
+      const auto *BOR = dyn_cast<BinaryOperator>(&*InstR);
+      if (BOL && BOR && BOL->isCommutative()) {
+        // cmpOperations() has already required identical opcodes, so the
+        // right-hand operator is commutative as well.
+        assert(BOR->isCommutative() && "opcode mismatch");
+        Value *Op0L = BOL->getOperand(0);
+        Value *Op1L = BOL->getOperand(1);
+        Value *Op0R = BOR->getOperand(0);
+        Value *Op1R = BOR->getOperand(1);
+
+        // Compare the operands in their original order first, exactly as
+        // the generic loop below would. cmpValues() may number values it
+        // has not seen before, so keeping this call order leaves that
+        // numbering unchanged whenever the operands already line up.
+        int Res = cmpValues(Op0L, Op0R);
+        if (Res == 0)
+          Res = cmpValues(Op1L, Op1R);
+
+        // If that fails, still treat the operators as equal when the
+        // operands match in swapped order. Otherwise report the
+        // original-order result, which keeps the result antisymmetric.
+        // FIXME: Accepting swapped operands can make compare() intransitive
+        // (e.g. op(x,y) == op(y,x) and op(x,y) < op(x,z) need not imply
+        // op(y,x) < op(x,z)); callers that keep functions in an ordered
+        // container may rely on transitivity.
+        if (Res != 0 && cmpValues(Op0L, Op1R) == 0 &&
+            cmpValues(Op1L, Op0R) == 0)
+          Res = 0;
+        if (Res)
+          return Res;
+      } else {
+        for (unsigned i = 0, e = InstL->getNumOperands(); i != e; ++i) {
+          Value *OpL = InstL->getOperand(i);
+          Value *OpR = InstR->getOperand(i);
+          if (int Res = cmpValues(OpL, OpR))
+            return Res;
+          // cmpValues should ensure this is true.
+          assert(cmpTypes(OpL->getType(), OpR->getType()) == 0);
+        }
       }
     }
 
Index: unittests/Transforms/Utils/FunctionComparatorTest.cpp
===================================================================
--- unittests/Transforms/Utils/FunctionComparatorTest.cpp
+++ unittests/Transforms/Utils/FunctionComparatorTest.cpp
@@ -11,10 +11,22 @@
 #include "llvm/IR/Instructions.h"
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/Module.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/Support/SourceMgr.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
 
+/// Parses \p IR into a new module owned by the caller. On failure the parser
+/// diagnostic is printed to errs() and a null pointer is returned.
+static std::unique_ptr<Module> parseIR(LLVMContext &C, const char *IR) {
+  SMDiagnostic Err;
+  std::unique_ptr<Module> Mod = parseAssemblyString(IR, Err, C);
+  if (!Mod)
+    Err.print("FunctionComparatorTest", errs());
+  return Mod;
+}
+
 /// Generates a simple test function.
 struct TestFunction {
   Function *F;
@@ -127,3 +139,62 @@
   EXPECT_EQ(Cmp.testCmpTypes(F1.T, F2.T), 0);
   EXPECT_EQ(Cmp.testCmpPrimitives(), -4);
 }
+
+/// Commutative binary operators whose operands appear in swapped order must
+/// still compare equal.
+TEST(FunctionComparatorTest, TestOperandReordering) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C, R"(
+define dso_local i32 @foo(i32 %a, i32 %b, i32 %c) local_unnamed_addr {
+entry:
+  %mul1 = mul nsw i32 %b, %c
+  %mul = mul nsw i32 3, %a
+  %add = add nsw i32 %mul1, %mul
+  ret i32 %add
+}
+
+define dso_local i32 @bar(i32 %a, i32 %b, i32 %c) local_unnamed_addr {
+entry:
+  %mul = mul nsw i32 %c, %b
+  %mul1 = mul nsw i32 %a, 3
+  %add = add nsw i32 %mul1, %mul
+  ret i32 %add
+}
+)");
+  ASSERT_TRUE(M != nullptr) << "failed to parse test IR";
+
+  Function *F1 = M->getFunction("foo");
+  Function *F2 = M->getFunction("bar");
+  ASSERT_TRUE(F1 && F2);
+
+  GlobalNumberState GlobalNumbers;
+  FunctionComparator FCmp(F1, F2, &GlobalNumbers);
+  EXPECT_EQ(FCmp.compare(), 0);
+}
+
+/// Operand swapping must only be tolerated for commutative operators.
+TEST(FunctionComparatorTest, TestNonCommutativeOperandsNotReordered) {
+  LLVMContext C;
+  std::unique_ptr<Module> M = parseIR(C, R"(
+define i32 @foo(i32 %a, i32 %b) {
+entry:
+  %sub = sub i32 %a, %b
+  ret i32 %sub
+}
+
+define i32 @bar(i32 %a, i32 %b) {
+entry:
+  %sub = sub i32 %b, %a
+  ret i32 %sub
+}
+)");
+  ASSERT_TRUE(M != nullptr) << "failed to parse test IR";
+
+  Function *F1 = M->getFunction("foo");
+  Function *F2 = M->getFunction("bar");
+  ASSERT_TRUE(F1 && F2);
+
+  GlobalNumberState GlobalNumbers;
+  FunctionComparator FCmp(F1, F2, &GlobalNumbers);
+  EXPECT_NE(FCmp.compare(), 0);
+}