diff --git a/llvm/include/llvm/Analysis/TargetTransformInfo.h b/llvm/include/llvm/Analysis/TargetTransformInfo.h
--- a/llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ b/llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -27,6 +27,7 @@
 #include "llvm/Pass.h"
 #include "llvm/Support/AtomicOrdering.h"
 #include "llvm/Support/DataTypes.h"
+#include "llvm/Support/InstructionCost.h"
 #include <functional>
 
 namespace llvm {
@@ -231,19 +232,26 @@
   ///
   /// Note, this method does not cache the cost calculation and it
   /// can be expensive in some cases.
-  int getInstructionCost(const Instruction *I, enum TargetCostKind kind) const {
+  InstructionCost getInstructionCost(const Instruction *I,
+                                     enum TargetCostKind kind) const {
+    InstructionCost Cost;
     switch (kind) {
     case TCK_RecipThroughput:
-      return getInstructionThroughput(I);
-
+      Cost = getInstructionThroughput(I);
+      break;
     case TCK_Latency:
-      return getInstructionLatency(I);
-
+      Cost = getInstructionLatency(I);
+      break;
     case TCK_CodeSize:
     case TCK_SizeAndLatency:
-      return getUserCost(I, kind);
+      Cost = getUserCost(I, kind);
+      break;
     }
-    llvm_unreachable("Unknown instruction cost kind");
+    // The switch covers every TargetCostKind, so it needs no 'default' label.
+    // TTI hooks report an unknown cost as -1; expose that as an Invalid cost.
+    if (Cost == -1)
+      Cost.setInvalid();
+    return Cost;
   }
 
   /// Underlying constants for 'cost' values in this interface.
diff --git a/llvm/include/llvm/IR/DiagnosticInfo.h b/llvm/include/llvm/IR/DiagnosticInfo.h
--- a/llvm/include/llvm/IR/DiagnosticInfo.h
+++ b/llvm/include/llvm/IR/DiagnosticInfo.h
@@ -35,6 +35,7 @@
 class DiagnosticPrinter;
 class Function;
 class Instruction;
+class InstructionCost;
 class LLVMContext;
 class Module;
 class SMDiagnostic;
@@ -437,6 +438,7 @@
     Argument(StringRef Key, ElementCount EC);
     Argument(StringRef Key, bool B) : Key(Key), Val(B ? "true" : "false") {}
     Argument(StringRef Key, DebugLoc dl);
+    Argument(StringRef Key, InstructionCost C);
   };
 
   /// \p PassName is the name of the pass emitting this diagnostic. \p
diff --git a/llvm/include/llvm/Support/InstructionCost.h b/llvm/include/llvm/Support/InstructionCost.h
new file mode 100644
--- /dev/null
+++ b/llvm/include/llvm/Support/InstructionCost.h
@@ -0,0 +1,246 @@
+//===- InstructionCost.h ----------------------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// This file defines an InstructionCost class that is used when calculating
+/// the cost of an instruction, or a group of instructions. In addition to a
+/// numeric value representing the cost the class also contains a state that
+/// can be used to encode particular properties, i.e. a cost being invalid or
+/// unknown.
+///
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_SUPPORT_INSTRUCTIONCOST_H
+#define LLVM_SUPPORT_INSTRUCTIONCOST_H
+
+#include "llvm/ADT/Optional.h"
+
+namespace llvm {
+
+class raw_ostream;
+
+class InstructionCost {
+public:
+  using CostType = int;
+
+  /// These states can currently be used to indicate whether a cost is valid or
+  /// invalid. Examples of an invalid cost might be where the cost is
+  /// prohibitively expensive and the user wants to prevent certain
+  /// optimizations being performed. Or perhaps the cost is simply unknown
+  /// because the operation makes no sense in certain circumstances. These
+  /// states can be expanded in future to support other cases if necessary.
+  enum CostState { Valid, Invalid };
+
+private:
+  CostType Value = 0;
+  CostState State = Valid;
+
+  void propagateState(const InstructionCost &RHS) {
+    if (RHS.State == Invalid)
+      State = Invalid;
+  }
+
+public:
+  /// A default-constructed cost is valid and has the value 0.
+  InstructionCost() = default;
+
+  InstructionCost(CostType Val) : Value(Val), State(Valid) {}
+
+  static InstructionCost getInvalid(CostType Val = 0) {
+    InstructionCost Tmp(Val);
+    Tmp.setInvalid();
+    return Tmp;
+  }
+
+  bool isValid() const { return State == Valid; }
+  void setValid() { State = Valid; }
+  void setInvalid() { State = Invalid; }
+  CostState getState() const { return State; }
+
+  /// This function is intended to be used as sparingly as possible, since the
+  /// class provides the full range of operator support required for arithmetic
+  /// and comparisons.
+  Optional<CostType> getValue() const {
+    if (isValid())
+      return Value;
+    return None;
+  }
+
+  /// For all of the arithmetic operators provided here any invalid state is
+  /// perpetuated and cannot be removed. Once a cost becomes invalid it stays
+  /// invalid, and it also inherits any invalid state from the RHS. Regardless
+  /// of the state, arithmetic and comparisons work on the actual values in the
+  /// same way as they would on a basic type, such as integer.
+
+  InstructionCost &operator+=(const InstructionCost &RHS) {
+    propagateState(RHS);
+    Value += RHS.Value;
+    return *this;
+  }
+
+  InstructionCost &operator+=(const CostType RHS) {
+    InstructionCost RHS2(RHS);
+    *this += RHS2;
+    return *this;
+  }
+
+  InstructionCost &operator-=(const InstructionCost &RHS) {
+    propagateState(RHS);
+    Value -= RHS.Value;
+    return *this;
+  }
+
+  InstructionCost &operator-=(const CostType RHS) {
+    InstructionCost RHS2(RHS);
+    *this -= RHS2;
+    return *this;
+  }
+
+  InstructionCost &operator*=(const InstructionCost &RHS) {
+    propagateState(RHS);
+    Value *= RHS.Value;
+    return *this;
+  }
+
+  InstructionCost &operator*=(const CostType RHS) {
+    InstructionCost RHS2(RHS);
+    *this *= RHS2;
+    return *this;
+  }
+
+  InstructionCost &operator/=(const InstructionCost &RHS) {
+    propagateState(RHS);
+    Value /= RHS.Value;
+    return *this;
+  }
+
+  InstructionCost &operator/=(const CostType RHS) {
+    InstructionCost RHS2(RHS);
+    *this /= RHS2;
+    return *this;
+  }
+
+  InstructionCost &operator++() {
+    *this += 1;
+    return *this;
+  }
+
+  InstructionCost operator++(int) {
+    InstructionCost Copy = *this;
+    ++*this;
+    return Copy;
+  }
+
+  InstructionCost &operator--() {
+    *this -= 1;
+    return *this;
+  }
+
+  InstructionCost operator--(int) {
+    InstructionCost Copy = *this;
+    --*this;
+    return Copy;
+  }
+
+  bool operator==(const InstructionCost &RHS) const {
+    return State == RHS.State && Value == RHS.Value;
+  }
+
+  bool operator!=(const InstructionCost &RHS) const { return !(*this == RHS); }
+
+  bool operator==(const CostType RHS) const {
+    return State == Valid && Value == RHS;
+  }
+
+  bool operator!=(const CostType RHS) const { return !(*this == RHS); }
+
+  /// For the comparison operators we have chosen to use total ordering with
+  /// the following rules:
+  ///  1. If either of the states != Valid then a lexicographical order is
+  ///     applied based upon the state.
+  ///  2. If both states are valid then order based upon value.
+  /// This avoids having to add asserts to the comparison operators that the
+  /// states are valid and users can test for validity of the cost explicitly.
+  bool operator<(const InstructionCost &RHS) const {
+    if (State != Valid || RHS.State != Valid)
+      return State < RHS.State;
+    return Value < RHS.Value;
+  }
+
+  bool operator>(const InstructionCost &RHS) const { return RHS < *this; }
+
+  bool operator<=(const InstructionCost &RHS) const { return !(RHS < *this); }
+
+  bool operator>=(const InstructionCost &RHS) const { return !(*this < RHS); }
+
+  bool operator<(const CostType RHS) const {
+    InstructionCost RHS2(RHS);
+    return *this < RHS2;
+  }
+
+  bool operator>(const CostType RHS) const {
+    InstructionCost RHS2(RHS);
+    return *this > RHS2;
+  }
+
+  bool operator<=(const CostType RHS) const {
+    InstructionCost RHS2(RHS);
+    return *this <= RHS2;
+  }
+
+  bool operator>=(const CostType RHS) const {
+    InstructionCost RHS2(RHS);
+    return *this >= RHS2;
+  }
+
+  static InstructionCost min(InstructionCost LHS, InstructionCost RHS) {
+    return LHS < RHS ? LHS : RHS;
+  }
+
+  static InstructionCost max(InstructionCost LHS, InstructionCost RHS) {
+    return LHS > RHS ? LHS : RHS;
+  }
+
+  void print(raw_ostream &OS) const;
+};
+
+inline InstructionCost operator+(const InstructionCost &LHS,
+                                 const InstructionCost &RHS) {
+  InstructionCost LHS2(LHS);
+  LHS2 += RHS;
+  return LHS2;
+}
+
+inline InstructionCost operator-(const InstructionCost &LHS,
+                                 const InstructionCost &RHS) {
+  InstructionCost LHS2(LHS);
+  LHS2 -= RHS;
+  return LHS2;
+}
+
+inline InstructionCost operator*(const InstructionCost &LHS,
+                                 const InstructionCost &RHS) {
+  InstructionCost LHS2(LHS);
+  LHS2 *= RHS;
+  return LHS2;
+}
+
+inline InstructionCost operator/(const InstructionCost &LHS,
+                                 const InstructionCost &RHS) {
+  InstructionCost LHS2(LHS);
+  LHS2 /= RHS;
+  return LHS2;
+}
+
+inline raw_ostream &operator<<(raw_ostream &OS, const InstructionCost &V) {
+  V.print(OS);
+  return OS;
+}
+
+} // namespace llvm
+
+#endif
diff --git a/llvm/lib/Analysis/CostModel.cpp b/llvm/lib/Analysis/CostModel.cpp
--- a/llvm/lib/Analysis/CostModel.cpp
+++ b/llvm/lib/Analysis/CostModel.cpp
@@ -57,7 +57,7 @@
   /// Returns -1 if the cost is unknown.
   /// Note, this method does not cache the cost calculation and it
   /// can be expensive in some cases.
-  unsigned getInstructionCost(const Instruction *I) const {
+  InstructionCost getInstructionCost(const Instruction *I) const {
     return TTI->getInstructionCost(I, TargetTransformInfo::TCK_RecipThroughput);
   }
 
@@ -103,9 +103,9 @@
 
   for (BasicBlock &B : *F) {
     for (Instruction &Inst : B) {
-      unsigned Cost = TTI->getInstructionCost(&Inst, CostKind);
-      if (Cost != (unsigned)-1)
-        OS << "Cost Model: Found an estimated cost of " << Cost;
+      InstructionCost Cost = TTI->getInstructionCost(&Inst, CostKind);
+      if (auto CostVal = Cost.getValue())
+        OS << "Cost Model: Found an estimated cost of " << *CostVal;
       else
         OS << "Cost Model: Unknown cost";
 
diff --git a/llvm/lib/CodeGen/InterleavedLoadCombinePass.cpp b/llvm/lib/CodeGen/InterleavedLoadCombinePass.cpp
--- a/llvm/lib/CodeGen/InterleavedLoadCombinePass.cpp
+++ b/llvm/lib/CodeGen/InterleavedLoadCombinePass.cpp
@@ -1130,8 +1130,8 @@
   std::set<Instruction *> Is;
   std::set<Instruction *> SVIs;
 
-  unsigned InterleavedCost;
-  unsigned InstructionCost = 0;
+  InstructionCost InterleavedCost;
+  InstructionCost InstructionCost = 0;
 
   // Get the interleave factor
   unsigned Factor = InterleavedLoad.size();
@@ -1174,6 +1174,10 @@
     }
   }
 
+  // We need to have a valid cost in order to proceed.
+  if (!InstructionCost.isValid())
+    return false;
+
   // We know that all LoadInst are within the same BB. This guarantees that
   // either everything or nothing is loaded.
   LoadInst *First = findFirstLoad(LIs);
diff --git a/llvm/lib/IR/DiagnosticInfo.cpp b/llvm/lib/IR/DiagnosticInfo.cpp
--- a/llvm/lib/IR/DiagnosticInfo.cpp
+++ b/llvm/lib/IR/DiagnosticInfo.cpp
@@ -32,6 +32,7 @@
 #include "llvm/Support/Casting.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/ErrorHandling.h"
+#include "llvm/Support/InstructionCost.h"
 #include "llvm/Support/Path.h"
 #include "llvm/Support/Regex.h"
 #include "llvm/Support/ScopedPrinter.h"
@@ -220,6 +221,13 @@
   EC.print(OS);
 }
 
+DiagnosticInfoOptimizationBase::Argument::Argument(StringRef Key,
+                                                   InstructionCost C)
+    : Key(std::string(Key)) {
+  raw_string_ostream OS(Val);
+  C.print(OS);
+}
+
 DiagnosticInfoOptimizationBase::Argument::Argument(StringRef Key, DebugLoc Loc)
     : Key(std::string(Key)), Loc(Loc) {
   if (Loc) {
diff --git a/llvm/lib/Support/CMakeLists.txt b/llvm/lib/Support/CMakeLists.txt
--- a/llvm/lib/Support/CMakeLists.txt
+++ b/llvm/lib/Support/CMakeLists.txt
@@ -128,6 +128,7 @@
   GraphWriter.cpp
   Hashing.cpp
   InitLLVM.cpp
+  InstructionCost.cpp
   IntEqClasses.cpp
   IntervalMap.cpp
   ItaniumManglingCanonicalizer.cpp
diff --git a/llvm/lib/Support/InstructionCost.cpp b/llvm/lib/Support/InstructionCost.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/lib/Support/InstructionCost.cpp
@@ -0,0 +1,24 @@
+//===- InstructionCost.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+/// \file
+/// This file includes the function definitions for the InstructionCost class
+/// that is used when calculating the cost of an instruction, or a group of
+/// instructions.
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Support/InstructionCost.h"
+#include "llvm/Support/raw_ostream.h"
+
+using namespace llvm;
+
+void InstructionCost::print(raw_ostream &OS) const {
+  if (isValid())
+    OS << Value;
+  else
+    OS << "Invalid";
+}
diff --git a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
--- a/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
+++ b/llvm/lib/Transforms/IPO/HotColdSplitting.cpp
@@ -233,11 +233,11 @@
 }
 
 /// Get the benefit score of outlining \p Region.
-static int getOutliningBenefit(ArrayRef<BasicBlock *> Region,
-                               TargetTransformInfo &TTI) {
+static InstructionCost getOutliningBenefit(ArrayRef<BasicBlock *> Region,
+                                           TargetTransformInfo &TTI) {
   // Sum up the code size costs of non-terminator instructions. Tight coupling
   // with \ref getOutliningPenalty is needed to model the costs of terminators.
-  int Benefit = 0;
+  InstructionCost Benefit = 0;
   for (BasicBlock *BB : Region)
     for (Instruction &I : BB->instructionsWithoutDebug())
       if (&I != BB->getTerminator())
@@ -324,12 +324,12 @@
   // splitting.
   SetVector<Value *> Inputs, Outputs, Sinks;
   CE.findInputsOutputs(Inputs, Outputs, Sinks);
-  int OutliningBenefit = getOutliningBenefit(Region, TTI);
+  InstructionCost OutliningBenefit = getOutliningBenefit(Region, TTI);
   int OutliningPenalty =
       getOutliningPenalty(Region, Inputs.size(), Outputs.size());
   LLVM_DEBUG(dbgs() << "Split profitability: benefit = " << OutliningBenefit
                     << ", penalty = " << OutliningPenalty << "\n");
-  if (OutliningBenefit <= OutliningPenalty)
+  if (!OutliningBenefit.isValid() || OutliningBenefit <= OutliningPenalty)
     return nullptr;
 
   Function *OrigF = Region[0]->getParent();
diff --git a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
--- a/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
+++ b/llvm/lib/Transforms/Scalar/CallSiteSplitting.cpp
@@ -208,7 +208,7 @@
   // instructions before the call is less then DuplicationThreshold. The
   // instructions before the call will be duplicated in the split blocks and
   // corresponding uses will be updated.
-  unsigned Cost = 0;
+  InstructionCost Cost = 0;
   for (auto &InstBeforeCall :
        llvm::make_range(CallSiteBB->begin(), CB.getIterator())) {
     Cost += TTI.getInstructionCost(&InstBeforeCall,
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7192,8 +7192,12 @@
       return std::min(CallCost, getVectorIntrinsicCost(CI, VF));
     return CallCost;
   }
-  case Instruction::ExtractValue:
-    return TTI.getInstructionCost(I, TTI::TCK_RecipThroughput);
+  case Instruction::ExtractValue: {
+    InstructionCost ExtractCost =
+        TTI.getInstructionCost(I, TTI::TCK_RecipThroughput);
+    assert(ExtractCost.isValid() && "Invalid cost for ExtractValue");
+    return *(ExtractCost.getValue());
+  }
   default:
     // The cost of executing VF copies of the scalar instruction. This opcode
     // is unknown. Assume that it is the same as 'mul'.
diff --git a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
--- a/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
+++ b/llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp
@@ -3793,17 +3793,23 @@
       if (NeedToShuffleReuses) {
         for (unsigned Idx : E->ReuseShuffleIndices) {
           Instruction *I = cast<Instruction>(VL[Idx]);
-          ReuseShuffleCost -= TTI->getInstructionCost(I, CostKind);
+          InstructionCost Cost = TTI->getInstructionCost(I, CostKind);
+          assert(Cost.isValid() && "Invalid instruction cost");
+          ReuseShuffleCost -= *(Cost.getValue());
         }
         for (Value *V : VL) {
           Instruction *I = cast<Instruction>(V);
-          ReuseShuffleCost += TTI->getInstructionCost(I, CostKind);
+          InstructionCost Cost = TTI->getInstructionCost(I, CostKind);
+          assert(Cost.isValid() && "Invalid instruction cost");
+          ReuseShuffleCost += *(Cost.getValue());
         }
       }
       for (Value *V : VL) {
         Instruction *I = cast<Instruction>(V);
         assert(E->isOpcodeOrAlt(I) && "Unexpected main/alternate opcode");
-        ScalarCost += TTI->getInstructionCost(I, CostKind);
+        InstructionCost Cost = TTI->getInstructionCost(I, CostKind);
+        assert(Cost.isValid() && "Invalid instruction cost");
+        ScalarCost += *(Cost.getValue());
       }
       // VecCost is equal to sum of the cost of creating 2 vectors
       // and the cost of creating shuffle.
diff --git a/llvm/unittests/Support/CMakeLists.txt b/llvm/unittests/Support/CMakeLists.txt
--- a/llvm/unittests/Support/CMakeLists.txt
+++ b/llvm/unittests/Support/CMakeLists.txt
@@ -40,6 +40,7 @@
   GlobPatternTest.cpp
   Host.cpp
   IndexedAccessorTest.cpp
+  InstructionCostTest.cpp
   ItaniumManglingCanonicalizerTest.cpp
   JSONTest.cpp
   KnownBitsTest.cpp
diff --git a/llvm/unittests/Support/InstructionCostTest.cpp b/llvm/unittests/Support/InstructionCostTest.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Support/InstructionCostTest.cpp
@@ -0,0 +1,72 @@
+//===- InstructionCostTest.cpp - InstructionCost tests --------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Support/InstructionCost.h"
+#include "llvm/Support/raw_ostream.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+namespace {
+
+struct CostTest : public testing::Test {
+  CostTest() {}
+};
+
+} // namespace
+
+TEST_F(CostTest, Operators) {
+  InstructionCost VThree = 3;
+  InstructionCost VNegTwo = -2;
+  InstructionCost VSix = 6;
+  InstructionCost IThreeA = InstructionCost::getInvalid(3);
+  InstructionCost IThreeB = InstructionCost::getInvalid(3);
+  InstructionCost TmpCost;
+
+  EXPECT_NE(VThree, VNegTwo);
+  EXPECT_GT(VThree, VNegTwo);
+  EXPECT_NE(VThree, IThreeA);
+  EXPECT_EQ(IThreeA, IThreeB);
+  EXPECT_GE(IThreeA, VNegTwo);
+  EXPECT_LT(VSix, IThreeA);
+  EXPECT_EQ(VSix - IThreeA, IThreeB);
+  EXPECT_EQ(VThree - VNegTwo, 5);
+  EXPECT_EQ(VThree * VNegTwo, -6);
+  EXPECT_EQ(VSix / VThree, 2);
+
+  EXPECT_FALSE(IThreeA.isValid());
+  EXPECT_EQ(IThreeA.getState(), InstructionCost::Invalid);
+
+  TmpCost = VThree + IThreeA;
+  EXPECT_FALSE(TmpCost.isValid());
+
+  // Test increments, decrements
+  EXPECT_EQ(++VThree, 4);
+  EXPECT_EQ(VThree++, 4);
+  EXPECT_EQ(VThree, 5);
+  EXPECT_EQ(--VThree, 4);
+  EXPECT_EQ(VThree--, 4);
+  EXPECT_EQ(VThree, 3);
+
+  TmpCost = VThree * IThreeA;
+  EXPECT_FALSE(TmpCost.isValid());
+
+  // Test value extraction
+  EXPECT_EQ(*(VThree.getValue()), 3);
+  EXPECT_EQ(IThreeA.getValue(), None);
+
+  EXPECT_EQ(InstructionCost::min(VThree, VNegTwo), -2);
+  EXPECT_EQ(InstructionCost::max(VThree, VSix), 6);
+}
+
+TEST_F(CostTest, Printing) {
+  std::string Str;
+  raw_string_ostream OS(Str);
+  OS << InstructionCost(3) << ' ' << InstructionCost::getInvalid(3);
+  EXPECT_EQ(OS.str(), "3 Invalid");
+}