// Patch for llvm/include/llvm/Support/InstructionCost.h and its unit test.
// - Default member initializers (Value = 0, State = Valid) make a
//   default-constructed InstructionCost a valid cost of zero.
// - operator< now orders lexicographically on (State, Value). Previously two
//   invalid costs with different values compared equivalent under operator<
//   but unequal under operator==. operator== is now derived from operator<,
//   and operator==(CostType) compares against an InstructionCost constructed
//   from the integer.
// - Tests: new DefaultCtor test; Operators gains ITwo (an invalid cost with
//   value 2) to check valid < invalid and the ordering between invalid costs.
// NOTE(review): as shown here the diff's line breaks appear to be collapsed;
// it must be restored to one line per context/+/- entry before `git apply`
// or `patch` can consume it.
diff --git a/llvm/include/llvm/Support/InstructionCost.h b/llvm/include/llvm/Support/InstructionCost.h --- a/llvm/include/llvm/Support/InstructionCost.h +++ b/llvm/include/llvm/Support/InstructionCost.h @@ -36,8 +36,8 @@ enum CostState { Valid, Invalid }; private: - CostType Value; - CostState State; + CostType Value = 0; + CostState State = Valid; void propagateState(const InstructionCost &RHS) { if (RHS.State == Invalid) @@ -45,6 +45,7 @@ } public: + // A default-constructed InstructionCost is a valid cost of zero. InstructionCost() = default; InstructionCost(CostState) = delete; @@ -146,31 +147,32 @@ return Copy; } + /// Costs are ordered lexicographically on (State, Value): since Valid is + /// declared before Invalid, every valid cost is less than every invalid + /// cost, and costs with the same state are ordered by Value. This avoids + /// having to assert in the comparison operators that the states are valid; + /// users can test for validity of the cost explicitly. + bool operator<(const InstructionCost &RHS) const { + if (State != RHS.State) + return State < RHS.State; + return Value < RHS.Value; + } + + // Implemented in terms of operator< so that equality and ordering can never + // disagree. bool operator==(const InstructionCost &RHS) const { - return State == RHS.State && Value == RHS.Value; + return !(*this < RHS) && !(RHS < *this); } bool operator!=(const InstructionCost &RHS) const { return !(*this == RHS); } bool operator==(const CostType RHS) const { - return State == Valid && Value == RHS; + InstructionCost RHS2(RHS); + return *this == RHS2; } bool operator!=(const CostType RHS) const { return !(*this == RHS); } - /// For the comparison operators we have chosen to use total ordering with - /// the following rules: - /// 1. If either of the states != Valid then a lexicographical order is - /// applied based upon the state. - /// 2. If both states are valid then order based upon value. 
- /// This avoids having to add asserts the comparison operators that the states - /// are valid and users can test for validity of the cost explicitly. - bool operator<(const InstructionCost &RHS) const { - if (State != Valid || RHS.State != Valid) - return State < RHS.State; - return Value < RHS.Value; - } - bool operator>(const InstructionCost &RHS) const { return RHS < *this; } bool operator<=(const InstructionCost &RHS) const { return !(RHS < *this); } diff --git a/llvm/unittests/Support/InstructionCostTest.cpp b/llvm/unittests/Support/InstructionCostTest.cpp --- a/llvm/unittests/Support/InstructionCostTest.cpp +++ b/llvm/unittests/Support/InstructionCostTest.cpp @@ -19,12 +19,21 @@ } // namespace +TEST_F(CostTest, DefaultCtor) { + InstructionCost DefaultCost; + + ASSERT_TRUE(DefaultCost.isValid()); + EXPECT_EQ(*(DefaultCost.getValue()), 0); +} + TEST_F(CostTest, Operators) { + InstructionCost VThree = 3; InstructionCost VNegTwo = -2; InstructionCost VSix = 6; InstructionCost IThreeA = InstructionCost::getInvalid(3); InstructionCost IThreeB = InstructionCost::getInvalid(3); + InstructionCost ITwo = InstructionCost::getInvalid(2); InstructionCost TmpCost; EXPECT_NE(VThree, VNegTwo); @@ -33,10 +42,15 @@ EXPECT_EQ(IThreeA, IThreeB); EXPECT_GE(IThreeA, VNegTwo); EXPECT_LT(VSix, IThreeA); + EXPECT_LT(VThree, ITwo); + EXPECT_GE(ITwo, VThree); EXPECT_EQ(VSix - IThreeA, IThreeB); EXPECT_EQ(VThree - VNegTwo, 5); EXPECT_EQ(VThree * VNegTwo, -6); EXPECT_EQ(VSix / VThree, 2); + EXPECT_NE(IThreeA, ITwo); + EXPECT_LT(ITwo, IThreeA); + EXPECT_GT(IThreeA, ITwo); EXPECT_FALSE(IThreeA.isValid()); EXPECT_EQ(IThreeA.getState(), InstructionCost::Invalid);