Index: include/llvm/Analysis/TargetTransformInfo.h =================================================================== --- include/llvm/Analysis/TargetTransformInfo.h +++ include/llvm/Analysis/TargetTransformInfo.h @@ -216,9 +216,23 @@ /// other context they may not be folded. This routine can distinguish such /// cases. /// + /// \p Operands is a list of operands which can be a result of transformations + /// of the current operands. The number of the operands on the list must be equal + /// to the number of the current operands the IR user has. Their order on the + /// list must be the same as the order of the current operands the IR user + /// has. + /// /// The returned cost is defined in terms of \c TargetCostConstants, see its /// comments for a detailed explanation of the cost values. - int getUserCost(const User *U) const; + int getUserCost(const User *U, ArrayRef Operands) const; + + /// \brief This is a helper function which calls the two-argument getUserCost + /// with \p Operands which are the current operands U has. + int getUserCost(const User *U) const { + SmallVector Operands(U->value_op_begin(), + U->value_op_end()); + return getUserCost(U, Operands); + } /// \brief Return true if branch divergence exists. 
/// @@ -823,7 +837,8 @@ ArrayRef Arguments) = 0; virtual unsigned getEstimatedNumberOfCaseClusters(const SwitchInst &SI, unsigned &JTSize) = 0; - virtual int getUserCost(const User *U) = 0; + virtual int + getUserCost(const User *U, ArrayRef Operands) = 0; virtual bool hasBranchDivergence() = 0; virtual bool isSourceOfDivergence(const Value *V) = 0; virtual bool isAlwaysUniform(const Value *V) = 0; @@ -998,7 +1013,9 @@ ArrayRef Arguments) override { return Impl.getIntrinsicCost(IID, RetTy, Arguments); } - int getUserCost(const User *U) override { return Impl.getUserCost(U); } + int getUserCost(const User *U, ArrayRef Operands) override { + return Impl.getUserCost(U, Operands); + } bool hasBranchDivergence() override { return Impl.hasBranchDivergence(); } bool isSourceOfDivergence(const Value *V) override { return Impl.isSourceOfDivergence(V); Index: include/llvm/Analysis/TargetTransformInfoImpl.h =================================================================== --- include/llvm/Analysis/TargetTransformInfoImpl.h +++ include/llvm/Analysis/TargetTransformInfoImpl.h @@ -684,14 +684,14 @@ return static_cast(this)->getIntrinsicCost(IID, RetTy, ParamTys); } - unsigned getUserCost(const User *U) { + unsigned getUserCost(const User *U, ArrayRef Operands) { if (isa(U)) return TTI::TCC_Free; // Model all PHI nodes as free. 
if (const GEPOperator *GEP = dyn_cast(U)) { - SmallVector Indices(GEP->idx_begin(), GEP->idx_end()); - return static_cast(this)->getGEPCost( - GEP->getSourceElementType(), GEP->getPointerOperand(), Indices); + return static_cast(this)->getGEPCost(GEP->getSourceElementType(), + GEP->getPointerOperand(), + Operands.drop_front()); } if (auto CS = ImmutableCallSite(U)) { Index: lib/Analysis/TargetTransformInfo.cpp =================================================================== --- lib/Analysis/TargetTransformInfo.cpp +++ lib/Analysis/TargetTransformInfo.cpp @@ -89,8 +89,9 @@ return TTIImpl->getEstimatedNumberOfCaseClusters(SI, JTSize); } -int TargetTransformInfo::getUserCost(const User *U) const { - int Cost = TTIImpl->getUserCost(U); +int TargetTransformInfo::getUserCost(const User *U, + ArrayRef Operands) const { + int Cost = TTIImpl->getUserCost(U, Operands); assert(Cost >= 0 && "TTI should not produce negative costs!"); return Cost; } Index: lib/Target/Hexagon/HexagonTargetTransformInfo.h =================================================================== --- lib/Target/Hexagon/HexagonTargetTransformInfo.h +++ lib/Target/Hexagon/HexagonTargetTransformInfo.h @@ -61,7 +61,7 @@ /// @} - int getUserCost(const User *U); + int getUserCost(const User *U, ArrayRef Operands); }; } // end namespace llvm Index: lib/Target/Hexagon/HexagonTargetTransformInfo.cpp =================================================================== --- lib/Target/Hexagon/HexagonTargetTransformInfo.cpp +++ lib/Target/Hexagon/HexagonTargetTransformInfo.cpp @@ -46,8 +46,9 @@ return getST()->getL1CacheLineSize(); } -int HexagonTTIImpl::getUserCost(const User *U) { - auto isCastFoldedIntoLoad = [] (const CastInst *CI) -> bool { +int HexagonTTIImpl::getUserCost(const User *U, + ArrayRef Operands) { + auto isCastFoldedIntoLoad = [](const CastInst *CI) -> bool { if (!CI->isIntegerCast()) return false; const LoadInst *LI = dyn_cast(CI->getOperand(0)); @@ -67,5 +68,5 @@ if (const CastInst *CI = 
dyn_cast(U)) if (isCastFoldedIntoLoad(CI)) return TargetTransformInfo::TCC_Free; - return BaseT::getUserCost(U); + return BaseT::getUserCost(U, Operands); }