Index: llvm/include/llvm/Analysis/InstructionCost.h
===================================================================
--- /dev/null
+++ llvm/include/llvm/Analysis/InstructionCost.h
@@ -0,0 +1,250 @@
+//===- InstructionCost.h ----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// This file defines an InstructionCost class that is used when calculating
+/// the cost of an instruction, or a group of instructions. In addition to a
+/// numeric value representing the cost, the class also contains a state that
+/// can be used to encode particular properties, e.g. a cost being invalid or
+/// unknown.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_INSTRUCTIONCOST_H
+#define LLVM_ANALYSIS_INSTRUCTIONCOST_H
+
+#include "llvm/ADT/Optional.h"
+
+namespace llvm {
+
+class raw_ostream;
+
+class InstructionCost {
+public:
+  using CostType = int;
+
+private:
+  // These states can currently be used to indicate whether a cost is valid or
+  // invalid. Examples of an invalid cost might be where the cost is
+  // prohibitively expensive and the user wants to prevent certain
+  // optimizations being performed. Or perhaps the cost is simply unknown
+  // because the operation makes no sense in certain circumstances. These states
+  // can be expanded in future to support other cases if necessary.
+  enum CostState { Valid, Invalid };
+
+  // Default member initializers make a default-constructed cost a valid zero
+  // rather than leaving Value and State indeterminate.
+  CostType Value = 0;
+  CostState State = Valid;
+
+  void propagateState(const InstructionCost &RHS) {
+    if (RHS.State == Invalid)
+      State = Invalid;
+  }
+
+public:
+  InstructionCost() = default;
+
+  InstructionCost(CostType Val) : Value(Val), State(Valid) {}
+
+  static InstructionCost getInvalid(CostType Val = 0) {
+    InstructionCost Tmp(Val);
+    Tmp.setInvalid();
+    return Tmp;
+  }
+
+  bool isValid() const { return State == Valid; }
+  bool isInvalid() const { return State == Invalid; }
+  void setValid() { State = Valid; }
+  void setInvalid() { State = Invalid; }
+
+  // This function is intended to be used as sparingly as possible, since the
+  // class provides the full range of operator support required for arithmetic
+  // and comparisons. Users of this class can check for validity before trying
+  // to extract a value, for example using the isValid() interface provided
+  // above.
+  Optional<CostType> getValue() const {
+    if (isValid())
+      return Value;
+    return None;
+  }
+
+  // For all of the arithmetic operators provided here any invalid state is
+  // perpetuated and cannot be removed. Once a cost becomes invalid it stays
+  // invalid, and it also inherits any invalid state from the RHS. Regardless
+  // of the state, arithmetic works on the actual values in the same way as it
+  // would on a basic type, such as integer.
+
+  InstructionCost &operator+=(const InstructionCost &RHS) {
+    propagateState(RHS);
+    Value += RHS.Value;
+    return *this;
+  }
+
+  InstructionCost &operator+=(const CostType RHS) {
+    InstructionCost RHS2(RHS);
+    *this += RHS2;
+    return *this;
+  }
+
+  InstructionCost &operator-=(const InstructionCost &RHS) {
+    propagateState(RHS);
+    Value -= RHS.Value;
+    return *this;
+  }
+
+  InstructionCost &operator-=(const CostType RHS) {
+    InstructionCost RHS2(RHS);
+    *this -= RHS2;
+    return *this;
+  }
+
+  InstructionCost &operator*=(const InstructionCost &RHS) {
+    propagateState(RHS);
+    Value *= RHS.Value;
+    return *this;
+  }
+
+  InstructionCost &operator*=(const CostType RHS) {
+    InstructionCost RHS2(RHS);
+    *this *= RHS2;
+    return *this;
+  }
+
+  InstructionCost &operator/=(const InstructionCost &RHS) {
+    propagateState(RHS);
+    Value /= RHS.Value;
+    return *this;
+  }
+
+  InstructionCost &operator/=(const CostType RHS) {
+    InstructionCost RHS2(RHS);
+    *this /= RHS2;
+    return *this;
+  }
+
+  InstructionCost &operator++() {
+    *this += 1;
+    return *this;
+  }
+
+  InstructionCost operator++(int) {
+    InstructionCost Copy = *this;
+    ++*this;
+    return Copy;
+  }
+
+  InstructionCost &operator--() {
+    *this -= 1;
+    return *this;
+  }
+
+  InstructionCost operator--(int) {
+    InstructionCost Copy = *this;
+    --*this;
+    return Copy;
+  }
+
+  bool operator==(const InstructionCost &RHS) const {
+    return State == RHS.State && Value == RHS.Value;
+  }
+
+  bool operator!=(const InstructionCost &RHS) const { return !(*this == RHS); }
+
+  bool operator==(const CostType RHS) const {
+    return State == Valid && Value == RHS;
+  }
+
+  bool operator!=(const CostType RHS) const { return !(*this == RHS); }
+
+  // For the comparison operators we have chosen to use total ordering with
+  // the following rules:
+  // 1. If either of the states != Valid then a lexicographical order is
+  //    applied based upon the state, so an invalid cost compares greater
+  //    than any valid cost.
+  // 2. If both states are valid then order based upon value.
+  // This avoids having to add asserts to the comparison operators that the
+  // states are valid and users can test for validity of the cost explicitly.
+  bool operator<(const InstructionCost &RHS) const {
+    if (State != Valid || RHS.State != Valid)
+      return State < RHS.State;
+    return Value < RHS.Value;
+  }
+
+  bool operator>(const InstructionCost &RHS) const { return RHS < *this; }
+
+  bool operator<=(const InstructionCost &RHS) const { return !(RHS < *this); }
+
+  bool operator>=(const InstructionCost &RHS) const { return !(*this < RHS); }
+
+  bool operator<(const CostType RHS) const {
+    InstructionCost RHS2(RHS);
+    return *this < RHS2;
+  }
+
+  bool operator>(const CostType RHS) const {
+    InstructionCost RHS2(RHS);
+    return *this > RHS2;
+  }
+
+  bool operator<=(const CostType RHS) const {
+    InstructionCost RHS2(RHS);
+    return *this <= RHS2;
+  }
+
+  bool operator>=(const CostType RHS) const {
+    InstructionCost RHS2(RHS);
+    return *this >= RHS2;
+  }
+
+  static InstructionCost min(InstructionCost LHS, InstructionCost RHS) {
+    return LHS < RHS ? LHS : RHS;
+  }
+
+  static InstructionCost max(InstructionCost LHS, InstructionCost RHS) {
+    return LHS > RHS ? LHS : RHS;
+  }
+
+  void print(raw_ostream &OS) const;
+};
+
+inline InstructionCost operator+(const InstructionCost &LHS,
+                                 const InstructionCost &RHS) {
+  InstructionCost LHS2(LHS);
+  LHS2 += RHS;
+  return LHS2;
+}
+
+inline InstructionCost operator-(const InstructionCost &LHS,
+                                 const InstructionCost &RHS) {
+  InstructionCost LHS2(LHS);
+  LHS2 -= RHS;
+  return LHS2;
+}
+
+inline InstructionCost operator*(const InstructionCost &LHS,
+                                 const InstructionCost &RHS) {
+  InstructionCost LHS2(LHS);
+  LHS2 *= RHS;
+  return LHS2;
+}
+
+inline InstructionCost operator/(const InstructionCost &LHS,
+                                 const InstructionCost &RHS) {
+  InstructionCost LHS2(LHS);
+  LHS2 /= RHS;
+  return LHS2;
+}
+
+inline raw_ostream &operator<<(raw_ostream &OS, const InstructionCost &V) {
+  V.print(OS);
+  return OS;
+}
+
+} // namespace llvm
+
+#endif
Index: llvm/include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -21,6 +21,7 @@
 #ifndef LLVM_ANALYSIS_TARGETTRANSFORMINFO_H
 #define LLVM_ANALYSIS_TARGETTRANSFORMINFO_H
 
+#include "llvm/Analysis/InstructionCost.h"
 #include "llvm/IR/InstrTypes.h"
 #include "llvm/IR/Operator.h"
 #include "llvm/IR/PassManager.h"
@@ -236,19 +237,26 @@
   ///
   /// Note, this method does not cache the cost calculation and it
   /// can be expensive in some cases.
-  int getInstructionCost(const Instruction *I, enum TargetCostKind kind) const {
+  InstructionCost getInstructionCost(const Instruction *I,
+                                     enum TargetCostKind kind) const {
+    InstructionCost Cost;
     switch (kind) {
     case TCK_RecipThroughput:
-      return getInstructionThroughput(I);
-
+      Cost = getInstructionThroughput(I);
+      break;
     case TCK_Latency:
-      return getInstructionLatency(I);
-
+      Cost = getInstructionLatency(I);
+      break;
     case TCK_CodeSize:
     case TCK_SizeAndLatency:
-      return getUserCost(I, kind);
+      Cost = getUserCost(I, kind);
+      break;
     }
-    llvm_unreachable("Unknown instruction cost kind");
+    // The underlying hooks report an unknown cost as -1; translate that into
+    // an explicitly invalid InstructionCost.
+    if (Cost == -1)
+      Cost.setInvalid();
+    return Cost;
   }
 
   /// Underlying constants for 'cost' values in this interface.
Index: llvm/include/llvm/IR/DiagnosticInfo.h
===================================================================
--- llvm/include/llvm/IR/DiagnosticInfo.h
+++ llvm/include/llvm/IR/DiagnosticInfo.h
@@ -38,6 +38,7 @@
 class LLVMContext;
 class Module;
 class SMDiagnostic;
+class InstructionCost;
 
 /// Defines the different supported severity of a diagnostic.
 enum DiagnosticSeverity : char {
@@ -437,6 +438,7 @@
     Argument(StringRef Key, ElementCount EC);
     Argument(StringRef Key, bool B) : Key(Key), Val(B ? "true" : "false") {}
     Argument(StringRef Key, DebugLoc dl);
+    Argument(StringRef Key, InstructionCost C);
   };
 
   /// \p PassName is the name of the pass emitting this diagnostic. \p
Index: llvm/lib/Analysis/CostModel.cpp
===================================================================
--- llvm/lib/Analysis/CostModel.cpp
+++ llvm/lib/Analysis/CostModel.cpp
@@ -57,7 +57,7 @@
-  /// Returns -1 if the cost is unknown.
+  /// Returns an invalid InstructionCost if the cost is unknown.
   /// Note, this method does not cache the cost calculation and it
   /// can be expensive in some cases.
-  unsigned getInstructionCost(const Instruction *I) const {
+  InstructionCost getInstructionCost(const Instruction *I) const {
     return TTI->getInstructionCost(I, TargetTransformInfo::TCK_RecipThroughput);
   }
 
@@ -103,8 +103,8 @@
 
   for (BasicBlock &B : *F) {
     for (Instruction &Inst : B) {
-      unsigned Cost = TTI->getInstructionCost(&Inst, CostKind);
-      if (Cost != (unsigned)-1)
+      InstructionCost Cost = TTI->getInstructionCost(&Inst, CostKind);
+      if (Cost.isValid())
         OS << "Cost Model: Found an estimated cost of " << Cost;
       else
         OS << "Cost Model: Unknown cost";
Index: llvm/lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- llvm/lib/Analysis/TargetTransformInfo.cpp
+++ llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -54,6 +54,13 @@
   return true;
 }
 
+void InstructionCost::print(raw_ostream &OS) const {
+  if (isInvalid())
+    OS << "Invalid";
+  else
+    OS << Value;
+}
+
 IntrinsicCostAttributes::IntrinsicCostAttributes(const IntrinsicInst &I)
     : II(&I), RetTy(I.getType()), IID(I.getIntrinsicID()) {
 
Index: llvm/lib/CodeGen/InterleavedLoadCombinePass.cpp
===================================================================
--- llvm/lib/CodeGen/InterleavedLoadCombinePass.cpp
+++ llvm/lib/CodeGen/InterleavedLoadCombinePass.cpp
@@ -1130,8 +1130,8 @@
   std::set<Instruction *> Is;
   std::set<Instruction *> SVIs;
 
-  unsigned InterleavedCost;
-  unsigned InstructionCost = 0;
+  InstructionCost InterleavedCost;
+  InstructionCost InstructionCost = 0;
 
   // Get the interleave factor
   unsigned Factor = InterleavedLoad.size();
@@ -1174,6 +1174,10 @@
     }
   }
 
+  // We need to have a valid cost in order to proceed.
+  if (InstructionCost.isInvalid())
+    return false;
+
   // We know that all LoadInst are within the same BB. This guarantees that
   // either everything or nothing is loaded.
   LoadInst *First = findFirstLoad(LIs);
Index: llvm/lib/IR/DiagnosticInfo.cpp
===================================================================
--- llvm/lib/IR/DiagnosticInfo.cpp
+++ llvm/lib/IR/DiagnosticInfo.cpp
@@ -16,6 +16,7 @@
 #include "llvm/ADT/StringExtras.h"
 #include "llvm/ADT/Twine.h"
 #include "llvm/ADT/iterator_range.h"
+#include "llvm/Analysis/InstructionCost.h"
 #include "llvm/IR/BasicBlock.h"
 #include "llvm/IR/Constants.h"
 #include "llvm/IR/DebugInfoMetadata.h"
@@ -220,6 +221,20 @@
   EC.print(OS);
 }
 
+DiagnosticInfoOptimizationBase::Argument::Argument(StringRef Key,
+                                                   InstructionCost C)
+    : Key(std::string(Key)) {
+  // Use the inline getValue() accessor rather than InstructionCost::print():
+  // print() lives in LLVMAnalysis, which LLVMCore must not depend on.
+  // NOTE(review): keep in sync with print(); moving InstructionCost into
+  // Support would remove both this duplication and the IR->Analysis include.
+  raw_string_ostream OS(Val);
+  if (auto V = C.getValue())
+    OS << *V;
+  else
+    OS << "Invalid";
+}
+
 DiagnosticInfoOptimizationBase::Argument::Argument(StringRef Key, DebugLoc Loc)
     : Key(std::string(Key)), Loc(Loc) {
   if (Loc) {
Index: llvm/lib/Transforms/IPO/HotColdSplitting.cpp
===================================================================
--- llvm/lib/Transforms/IPO/HotColdSplitting.cpp
+++ llvm/lib/Transforms/IPO/HotColdSplitting.cpp
@@ -233,11 +233,11 @@
 }
 
 /// Get the benefit score of outlining \p Region.
-static int getOutliningBenefit(ArrayRef<BasicBlock *> Region,
-                               TargetTransformInfo &TTI) {
+static InstructionCost getOutliningBenefit(ArrayRef<BasicBlock *> Region,
+                                           TargetTransformInfo &TTI) {
   // Sum up the code size costs of non-terminator instructions. Tight coupling
   // with \ref getOutliningPenalty is needed to model the costs of terminators.
-  int Benefit = 0;
+  InstructionCost Benefit = 0;
   for (BasicBlock *BB : Region)
     for (Instruction &I : BB->instructionsWithoutDebug())
       if (&I != BB->getTerminator())
@@ -324,12 +324,14 @@
   // splitting.
   SetVector<Value *> Inputs, Outputs, Sinks;
   CE.findInputsOutputs(Inputs, Outputs, Sinks);
-  int OutliningBenefit = getOutliningBenefit(Region, TTI);
+  InstructionCost OutliningBenefit = getOutliningBenefit(Region, TTI);
   int OutliningPenalty =
       getOutliningPenalty(Region, Inputs.size(), Outputs.size());
   LLVM_DEBUG(dbgs() << "Split profitability: benefit = " << OutliningBenefit
                     << ", penalty = " << OutliningPenalty << "\n");
-  if (OutliningBenefit <= OutliningPenalty)
+  // An invalid benefit compares greater than any valid penalty, so reject it
+  // explicitly instead of treating it as an arbitrarily large benefit.
+  if (OutliningBenefit.isInvalid() || OutliningBenefit <= OutliningPenalty)
     return nullptr;
 
   Function *OrigF = Region[0]->getParent();
Index: llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
===================================================================
--- llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
+++ llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
@@ -208,7 +208,7 @@
   // instructions before the call is less then DuplicationThreshold. The
   // instructions before the call will be duplicated in the split blocks and
   // corresponding uses will be updated.
-  unsigned Cost = 0;
+  InstructionCost Cost = 0;
   for (auto &InstBeforeCall :
        llvm::make_range(CallSiteBB->begin(), CB.getIterator())) {
     Cost += TTI.getInstructionCost(&InstBeforeCall,
Index: llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
===================================================================
--- llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -3781,17 +3781,21 @@
+      // FIXME: These scalar costs are assumed to be valid. Dereferencing an
+      // empty Optional is only caught by an assertion, not handled here.
       if (NeedToShuffleReuses) {
         for (unsigned Idx : E->ReuseShuffleIndices) {
           Instruction *I = cast<Instruction>(VL[Idx]);
-          ReuseShuffleCost -= TTI->getInstructionCost(I, CostKind);
+          ReuseShuffleCost -=
+              *(TTI->getInstructionCost(I, CostKind).getValue());
         }
         for (Value *V : VL) {
           Instruction *I = cast<Instruction>(V);
-          ReuseShuffleCost += TTI->getInstructionCost(I, CostKind);
+          ReuseShuffleCost +=
+              *(TTI->getInstructionCost(I, CostKind).getValue());
         }
       }
       for (Value *V : VL) {
         Instruction *I = cast<Instruction>(V);
         assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode");
-        ScalarCost += TTI->getInstructionCost(I, CostKind);
+        ScalarCost += *(TTI->getInstructionCost(I, CostKind).getValue());
       }
       // VecCost is equal to sum of the cost of creating 2 vectors
       // and the cost of creating shuffle.
Index: llvm/unittests/Analysis/CMakeLists.txt
===================================================================
--- llvm/unittests/Analysis/CMakeLists.txt
+++ llvm/unittests/Analysis/CMakeLists.txt
@@ -29,6 +29,7 @@
   DomTreeUpdaterTest.cpp
   GlobalsModRefTest.cpp
   FunctionPropertiesAnalysisTest.cpp
+  InstructionCostTest.cpp
   IRSimilarityIdentifierTest.cpp
   IVDescriptorsTest.cpp
   LazyCallGraphTest.cpp
Index: llvm/unittests/Analysis/InstructionCostTest.cpp
===================================================================
--- /dev/null
+++ llvm/unittests/Analysis/InstructionCostTest.cpp
@@ -0,0 +1,51 @@
+//===- InstructionCostTest.cpp - InstructionCost tests --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/InstructionCost.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+TEST(CostTest, Operators) {
+  InstructionCost VThree = 3;
+  InstructionCost VNegTwo = -2;
+  InstructionCost VSix = 6;
+  InstructionCost IThreeA = InstructionCost::getInvalid(3);
+  InstructionCost IThreeB = InstructionCost::getInvalid(3);
+  InstructionCost TmpCost;
+
+  EXPECT_NE(VThree, VNegTwo);
+  EXPECT_GT(VThree, VNegTwo);
+  EXPECT_NE(VThree, IThreeA);
+  EXPECT_EQ(IThreeA, IThreeB);
+  EXPECT_GE(IThreeA, VNegTwo);
+  EXPECT_LT(VSix, IThreeA);
+  EXPECT_EQ(VSix - IThreeA, IThreeB);
+  EXPECT_EQ(VThree - VNegTwo, 5);
+  EXPECT_EQ(VThree * VNegTwo, -6);
+  EXPECT_EQ(VSix / VThree, 2);
+
+  EXPECT_TRUE(IThreeA.isInvalid());
+  EXPECT_FALSE(IThreeA.getValue().hasValue());
+  EXPECT_EQ(*VThree.getValue(), 3);
+
+  TmpCost = VThree + IThreeA;
+  EXPECT_TRUE(TmpCost.isInvalid());
+
+  TmpCost = VThree * IThreeA;
+  EXPECT_TRUE(TmpCost.isInvalid());
+
+  TmpCost = VThree;
+  EXPECT_EQ(TmpCost++, 3);
+  EXPECT_EQ(--TmpCost, 3);
+
+  EXPECT_EQ(InstructionCost::min(VThree, VNegTwo), -2);
+  EXPECT_EQ(InstructionCost::max(VThree, VSix), 6);
+  EXPECT_TRUE(InstructionCost::min(VThree, IThreeA).isValid());
+  EXPECT_TRUE(InstructionCost::max(VThree, IThreeA).isInvalid());
+}