diff --git a/llvm/include/llvm/Support/InstructionCost.h b/llvm/include/llvm/Support/InstructionCost.h --- a/llvm/include/llvm/Support/InstructionCost.h +++ b/llvm/include/llvm/Support/InstructionCost.h @@ -9,8 +9,9 @@ /// This file defines an InstructionCost class that is used when calculating /// the cost of an instruction, or a group of instructions. In addition to a /// numeric value representing the cost the class also contains a state that -/// can be used to encode particular properties, i.e. a cost being invalid or -/// unknown. +/// can be used to encode particular properties, such as a cost being invalid. +/// Operations on InstructionCost implement saturation arithmetic, so that +/// accumulating costs on large cost-values doesn't overflow. /// //===----------------------------------------------------------------------===// @@ -18,6 +19,8 @@ #define LLVM_SUPPORT_INSTRUCTIONCOST_H #include "llvm/ADT/Optional.h" +#include "llvm/Support/MathExtras.h" +#include <limits> namespace llvm { @@ -27,13 +30,24 @@ public: using CostType = int; - /// These states can currently be used to indicate whether a cost is valid or - /// invalid. Examples of an invalid cost might be where the cost is - /// prohibitively expensive and the user wants to prevent certain - /// optimizations being performed. Or perhaps the cost is simply unknown - /// because the operation makes no sense in certain circumstances. These - /// states can be expanded in future to support other cases if necessary. - enum CostState { Valid, Invalid }; + /// CostState describes the state of a cost. + enum CostState { + Valid, ///< The cost value represents a valid cost, even when the + /// cost-value is large. + Invalid ///< Invalid indicates there is no way to represent the cost as a + /// numeric value. This state exists to represent a possible issue, + /// e.g. if the cost-model knows the operation cannot be expanded + /// into a valid code-sequence by the code-generator.
While some + /// passes may assert that the calculated cost must be valid, it is + /// up to individual passes how to interpret an Invalid cost. For + /// example, a transformation pass could choose not to perform a + /// transformation if the resulting cost would end up Invalid. + /// Because some passes may assert a cost is Valid, it is not + /// recommended to use Invalid costs to model 'Unknown'. + /// Note that Invalid is semantically different from a (very) high, + /// but valid cost, which intentionally indicates no issue, but + /// rather a strong preference not to select a certain operation. + }; private: CostType Value = 0; @@ -44,6 +58,9 @@ State = Invalid; } + static CostType getMaxValue() { return std::numeric_limits<CostType>::max(); } + static CostType getMinValue() { return std::numeric_limits<CostType>::min(); } + public: // A default constructed InstructionCost is a valid zero cost InstructionCost() = default; @@ -51,6 +68,8 @@ InstructionCost(CostState) = delete; InstructionCost(CostType Val) : Value(Val), State(Valid) {} + static InstructionCost getMax() { return getMaxValue(); } + static InstructionCost getMin() { return getMinValue(); } static InstructionCost getInvalid(CostType Val = 0) { InstructionCost Tmp(Val); Tmp.setInvalid(); @@ -73,13 +92,19 @@ /// For all of the arithmetic operators provided here any invalid state is /// perpetuated and cannot be removed. Once a cost becomes invalid it stays - /// invalid, and it also inherits any invalid state from the RHS. Regardless - /// of the state, arithmetic work on the actual values in the same way as they - /// would on a basic type, such as integer. + /// invalid, and it also inherits any invalid state from the RHS. + /// Arithmetic on the actual values is implemented with saturation, + /// to avoid overflow when using more extreme cost values. InstructionCost &operator+=(const InstructionCost &RHS) { propagateState(RHS); - Value += RHS.Value; + + // Saturating addition.
+ InstructionCost::CostType Result; + if (AddOverflow(Value, RHS.Value, Result)) + Result = RHS.Value > 0 ? getMaxValue() : getMinValue(); + + Value = Result; return *this; } @@ -91,7 +116,12 @@ InstructionCost &operator-=(const InstructionCost &RHS) { propagateState(RHS); - Value -= RHS.Value; + + // Saturating subtract. + InstructionCost::CostType Result; + if (SubOverflow(Value, RHS.Value, Result)) + Result = RHS.Value > 0 ? getMinValue() : getMaxValue(); + Value = Result; return *this; } @@ -103,7 +133,17 @@ InstructionCost &operator*=(const InstructionCost &RHS) { propagateState(RHS); - Value *= RHS.Value; + + // Saturating multiply. + InstructionCost::CostType Result; + if (MulOverflow(Value, RHS.Value, Result)) { + if ((Value > 0 && RHS.Value > 0) || (Value < 0 && RHS.Value < 0)) + Result = getMaxValue(); + else + Result = getMinValue(); + } + + Value = Result; return *this; } diff --git a/llvm/unittests/Support/InstructionCostTest.cpp b/llvm/unittests/Support/InstructionCostTest.cpp --- a/llvm/unittests/Support/InstructionCostTest.cpp +++ b/llvm/unittests/Support/InstructionCostTest.cpp @@ -8,6 +8,7 @@ #include "llvm/Support/InstructionCost.h" #include "gtest/gtest.h" +#include <limits> using namespace llvm; @@ -75,4 +76,20 @@ EXPECT_EQ(std::min(VThree, VNegTwo), -2); EXPECT_EQ(std::max(VThree, VSix), 6); + + // Test saturation + auto Max = InstructionCost::getMax(); + auto Min = InstructionCost::getMin(); + auto MinusOne = InstructionCost(-1); + auto MinusTwo = InstructionCost(-2); + auto One = InstructionCost(1); + auto Two = InstructionCost(2); + EXPECT_EQ(Max + One, Max); + EXPECT_EQ(Min + MinusOne, Min); + EXPECT_EQ(Min - One, Min); + EXPECT_EQ(Max - MinusOne, Max); + EXPECT_EQ(Max * Two, Max); + EXPECT_EQ(Min * Two, Min); + EXPECT_EQ(Max * MinusTwo, Min); + EXPECT_EQ(Min * MinusTwo, Max); }