Make InstructionCost arithmetic saturate instead of overflowing.

operator+=, operator-= and operator*= now detect overflow with AddOverflow /
SubOverflow / MulOverflow and clamp the result to the maximum or minimum
CostType value; getMax()/getMin() expose those bounds, and the unit test
checks each saturating case.

NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed and whose angle-bracket tokens ("<limits>",
"<CostType>") had been stripped. Indentation and blank context lines were
reconstructed from the hunk line counts and LLVM's 2-space style; confirm
against the upstream tree before applying.

diff --git a/llvm/include/llvm/Support/InstructionCost.h b/llvm/include/llvm/Support/InstructionCost.h
--- a/llvm/include/llvm/Support/InstructionCost.h
+++ b/llvm/include/llvm/Support/InstructionCost.h
@@ -10,7 +10,7 @@
 /// the cost of an instruction, or a group of instructions. In addition to a
 /// numeric value representing the cost the class also contains a state that
 /// can be used to encode particular properties, i.e. a cost being invalid or
-/// unknown.
+/// unknown. Operations on InstructionCost implement saturation arithmetic.
 ///
 //===----------------------------------------------------------------------===//
 
@@ -18,6 +18,8 @@
 #define LLVM_SUPPORT_INSTRUCTIONCOST_H
 
 #include "llvm/ADT/Optional.h"
+#include "llvm/Support/MathExtras.h"
+#include <limits>
 
 namespace llvm {
 
@@ -44,6 +46,9 @@
     State = Invalid;
   }
 
+  static CostType getMaxValue() { return std::numeric_limits<CostType>::max(); }
+  static CostType getMinValue() { return std::numeric_limits<CostType>::min(); }
+
 public:
   // A default constructed InstructionCost is a valid zero cost
   InstructionCost() = default;
@@ -51,6 +56,8 @@
   InstructionCost(CostState) = delete;
   InstructionCost(CostType Val) : Value(Val), State(Valid) {}
 
+  static InstructionCost getMax() { return getMaxValue(); }
+  static InstructionCost getMin() { return getMinValue(); }
   static InstructionCost getInvalid(CostType Val = 0) {
     InstructionCost Tmp(Val);
     Tmp.setInvalid();
@@ -73,13 +80,19 @@
 
   /// For all of the arithmetic operators provided here any invalid state is
   /// perpetuated and cannot be removed. Once a cost becomes invalid it stays
-  /// invalid, and it also inherits any invalid state from the RHS. Regardless
-  /// of the state, arithmetic work on the actual values in the same way as they
-  /// would on a basic type, such as integer.
+  /// invalid, and it also inherits any invalid state from the RHS.
+  /// Arithmetic work on the actual values is implemented with saturation,
+  /// to avoid overflow when using more extreme cost values.
 
   InstructionCost &operator+=(const InstructionCost &RHS) {
     propagateState(RHS);
-    Value += RHS.Value;
+
+    // Saturating addition.
+    InstructionCost::CostType Result;
+    if (AddOverflow(Value, RHS.Value, Result))
+      Result = RHS.Value > 0 ? getMaxValue() : getMinValue();
+
+    Value = Result;
     return *this;
   }
 
@@ -91,7 +104,12 @@
 
   InstructionCost &operator-=(const InstructionCost &RHS) {
     propagateState(RHS);
-    Value -= RHS.Value;
+
+    // Saturating subtract.
+    InstructionCost::CostType Result;
+    if (SubOverflow(Value, RHS.Value, Result))
+      Result = RHS.Value > 0 ? getMinValue() : getMaxValue();
+    Value = Result;
     return *this;
   }
 
@@ -103,7 +121,17 @@
 
   InstructionCost &operator*=(const InstructionCost &RHS) {
     propagateState(RHS);
-    Value *= RHS.Value;
+
+    // Saturating multiply.
+    InstructionCost::CostType Result;
+    if (MulOverflow(Value, RHS.Value, Result)) {
+      if ((Value > 0 && RHS.Value > 0) || (Value < 0 && RHS.Value < 0))
+        Result = getMaxValue();
+      else
+        Result = getMinValue();
+    }
+
+    Value = Result;
     return *this;
   }
 
diff --git a/llvm/unittests/Support/InstructionCostTest.cpp b/llvm/unittests/Support/InstructionCostTest.cpp
--- a/llvm/unittests/Support/InstructionCostTest.cpp
+++ b/llvm/unittests/Support/InstructionCostTest.cpp
@@ -8,6 +8,7 @@
 
 #include "llvm/Support/InstructionCost.h"
 #include "gtest/gtest.h"
+#include <limits>
 
 using namespace llvm;
 
@@ -75,4 +76,20 @@
 
   EXPECT_EQ(std::min(VThree, VNegTwo), -2);
   EXPECT_EQ(std::max(VThree, VSix), 6);
+
+  // Test saturation
+  auto Max = InstructionCost::getMax();
+  auto Min = InstructionCost::getMin();
+  auto MinusOne = InstructionCost(-1);
+  auto MinusTwo = InstructionCost(-2);
+  auto One = InstructionCost(1);
+  auto Two = InstructionCost(2);
+  EXPECT_EQ(Max + One, Max);
+  EXPECT_EQ(Min + MinusOne, Min);
+  EXPECT_EQ(Min - One, Min);
+  EXPECT_EQ(Max - MinusOne, Max);
+  EXPECT_EQ(Max * Two, Max);
+  EXPECT_EQ(Min * Two, Min);
+  EXPECT_EQ(Max * MinusTwo, Min);
+  EXPECT_EQ(Min * MinusTwo, Max);
 }