diff --git a/llvm/include/llvm/Support/InstructionCost.h b/llvm/include/llvm/Support/InstructionCost.h
--- a/llvm/include/llvm/Support/InstructionCost.h
+++ b/llvm/include/llvm/Support/InstructionCost.h
@@ -146,31 +146,33 @@
     return Copy;
   }
 
+  /// For the comparison operators we have chosen to use lexicographical
+  /// ordering on (State, Value), where valid costs are always considered to
+  /// be less than invalid costs. This avoids having to add asserts to the
+  /// comparison operators that the states are valid, and users can test for
+  /// validity of the cost explicitly.
+  bool operator<(const InstructionCost &RHS) const {
+    // The state is the most significant key: only fall through to comparing
+    // values when both states are equal.
+    if (State != RHS.State)
+      return State < RHS.State;
+
+    return Value < RHS.Value;
+  }
+
   bool operator==(const InstructionCost &RHS) const {
-    return State == RHS.State && Value == RHS.Value;
+    return !(*this < RHS) && !(RHS < *this);
   }
 
   bool operator!=(const InstructionCost &RHS) const { return !(*this == RHS); }
 
   bool operator==(const CostType RHS) const {
-    return State == Valid && Value == RHS;
+    InstructionCost RHS2(RHS);
+    return *this == RHS2;
   }
 
   bool operator!=(const CostType RHS) const { return !(*this == RHS); }
 
-  /// For the comparison operators we have chosen to use total ordering with
-  /// the following rules:
-  ///  1. If either of the states != Valid then a lexicographical order is
-  ///     applied based upon the state.
-  ///  2. If both states are valid then order based upon value.
-  /// This avoids having to add asserts the comparison operators that the states
-  /// are valid and users can test for validity of the cost explicitly.
-  bool operator<(const InstructionCost &RHS) const {
-    if (State != Valid || RHS.State != Valid)
-      return State < RHS.State;
-    return Value < RHS.Value;
-  }
-
   bool operator>(const InstructionCost &RHS) const { return RHS < *this; }
 
   bool operator<=(const InstructionCost &RHS) const { return !(RHS < *this); }
diff --git a/llvm/unittests/Support/InstructionCostTest.cpp b/llvm/unittests/Support/InstructionCostTest.cpp
--- a/llvm/unittests/Support/InstructionCostTest.cpp
+++ b/llvm/unittests/Support/InstructionCostTest.cpp
@@ -25,6 +25,7 @@
   InstructionCost VSix = 6;
   InstructionCost IThreeA = InstructionCost::getInvalid(3);
   InstructionCost IThreeB = InstructionCost::getInvalid(3);
+  InstructionCost ITwo = InstructionCost::getInvalid(2);
   InstructionCost TmpCost;
 
   EXPECT_NE(VThree, VNegTwo);
@@ -38,6 +39,13 @@
   EXPECT_EQ(VThree * VNegTwo, -6);
   EXPECT_EQ(VSix / VThree, 2);
 
+  EXPECT_TRUE((IThreeA < ITwo) || (IThreeA > ITwo) || (IThreeA == ITwo));
+  EXPECT_NE(IThreeA, ITwo);
+  EXPECT_LT(ITwo, IThreeA);
+  EXPECT_GT(IThreeA, ITwo);
+  // An invalid cost never orders before a valid one, whatever the values.
+  EXPECT_FALSE(ITwo < VSix);
+
   EXPECT_FALSE(IThreeA.isValid());
   EXPECT_EQ(IThreeA.getState(), InstructionCost::Invalid);
 