Index: include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- include/llvm/Analysis/TargetTransformInfo.h
+++ include/llvm/Analysis/TargetTransformInfo.h
@@ -216,9 +216,22 @@
   /// other context they may not be folded. This routine can distinguish such
   /// cases.
   ///
+  /// \p Operands lists the operands to use when estimating the cost instead
+  /// of the ones \p U currently has; they may be the result of transforming
+  /// the current operands. The list must contain exactly as many operands as
+  /// \p U has, in the same order.
+  ///
   /// The returned cost is defined in terms of \c TargetCostConstants, see its
   /// comments for a detailed explanation of the cost values.
-  int getUserCost(const User *U) const;
+  int getUserCost(const User *U, ArrayRef<const Value *> Operands) const;
+
+  /// \brief This is a helper function which calls the two-argument getUserCost
+  /// with the operands \p U currently has.
+  int getUserCost(const User *U) const {
+    SmallVector<const Value *, 4> Operands(U->value_op_begin(),
+                                           U->value_op_end());
+    return getUserCost(U, Operands);
+  }
 
   /// \brief Return true if branch divergence exists.
   ///
@@ -823,7 +836,8 @@
                                ArrayRef<const Value *> Arguments) = 0;
   virtual unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI,
                                                     unsigned &JTSize) = 0;
-  virtual int getUserCost(const User *U) = 0;
+  virtual int getUserCost(const User *U,
+                          ArrayRef<const Value *> Operands) = 0;
   virtual bool hasBranchDivergence() = 0;
   virtual bool isSourceOfDivergence(const Value *V) = 0;
   virtual bool isAlwaysUniform(const Value *V) = 0;
@@ -998,7 +1012,10 @@
                        ArrayRef<const Value *> Arguments) override {
     return Impl.getIntrinsicCost(IID, RetTy, Arguments);
   }
-  int getUserCost(const User *U) override { return Impl.getUserCost(U); }
+  int getUserCost(const User *U,
+                  ArrayRef<const Value *> Operands) override {
+    return Impl.getUserCost(U, Operands);
+  }
   bool hasBranchDivergence() override { return Impl.hasBranchDivergence(); }
   bool isSourceOfDivergence(const Value *V) override {
     return Impl.isSourceOfDivergence(V);
Index: include/llvm/Analysis/TargetTransformInfoImpl.h
===================================================================
--- include/llvm/Analysis/TargetTransformInfoImpl.h
+++ include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -684,12 +684,14 @@
     return static_cast<T *>(this)->getIntrinsicCost(IID, RetTy, ParamTys);
   }
 
-  unsigned getUserCost(const User *U) {
+  unsigned getUserCost(const User *U, ArrayRef<const Value *> Operands) {
     if (isa<PHINode>(U))
       return TTI::TCC_Free; // Model all PHI nodes as free.
 
     if (const GEPOperator *GEP = dyn_cast<GEPOperator>(U)) {
-      SmallVector<Value *, 4> Indices(GEP->idx_begin(), GEP->idx_end());
-      return static_cast<T *>(this)->getGEPCost(
-          GEP->getSourceElementType(), GEP->getPointerOperand(), Indices);
+      // Operand 0 is the pointer and the rest are the indices; take both from
+      // Operands so that replaced operands are what gets costed.
+      return static_cast<T *>(this)->getGEPCost(GEP->getSourceElementType(),
+                                                Operands.front(),
+                                                Operands.drop_front());
     }
Index: lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- lib/Analysis/TargetTransformInfo.cpp
+++ lib/Analysis/TargetTransformInfo.cpp
@@ -89,8 +89,11 @@
   return TTIImpl->getEstimatedNumberOfCaseClusters(SI, JTSize);
 }
 
-int TargetTransformInfo::getUserCost(const User *U) const {
-  int Cost = TTIImpl->getUserCost(U);
+int TargetTransformInfo::getUserCost(
+    const User *U, ArrayRef<const Value *> Operands) const {
+  assert(Operands.size() == U->getNumOperands() &&
+         "Expected one replacement operand per operand of the user");
+  int Cost = TTIImpl->getUserCost(U, Operands);
   assert(Cost >= 0 && "TTI should not produce negative costs!");
   return Cost;
 }
Index: lib/Target/Hexagon/HexagonTargetTransformInfo.h
===================================================================
--- lib/Target/Hexagon/HexagonTargetTransformInfo.h
+++ lib/Target/Hexagon/HexagonTargetTransformInfo.h
@@ -61,7 +61,7 @@
 
   /// @}
 
-  int getUserCost(const User *U);
+  int getUserCost(const User *U, ArrayRef<const Value *> Operands);
 };
 
 } // end namespace llvm
Index: lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
===================================================================
--- lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
+++ lib/Target/Hexagon/HexagonTargetTransformInfo.cpp
@@ -46,8 +46,9 @@
   return getST()->getL1CacheLineSize();
 }
 
-int HexagonTTIImpl::getUserCost(const User *U) {
-  auto isCastFoldedIntoLoad = [] (const CastInst *CI) -> bool {
+int HexagonTTIImpl::getUserCost(const User *U,
+                                ArrayRef<const Value *> Operands) {
+  auto isCastFoldedIntoLoad = [](const CastInst *CI) -> bool {
     if (!CI->isIntegerCast())
       return false;
     const LoadInst *LI = dyn_cast<const LoadInst>(CI->getOperand(0));
@@ -67,5 +68,5 @@
   if (const CastInst *CI = dyn_cast<const CastInst>(U))
     if (isCastFoldedIntoLoad(CI))
       return TargetTransformInfo::TCC_Free;
-  return BaseT::getUserCost(U);
+  return BaseT::getUserCost(U, Operands);
 }