Index: llvm/include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1510,6 +1510,16 @@
   bool areTypesABICompatible(const Function *Caller, const Function *Callee,
                              const ArrayRef<Type *> &Types) const;
 
+  /// \returns True if \p ParamTy with \p ParamAS at a call site can be passed
+  /// to \p ArgTy with \p ArgAS at a callee.
+  bool isValidTypePairForCallEdge(Type *ParamTy, Type *ArgTy,
+                                  AttributeSet ParamAS,
+                                  AttributeSet ArgAS) const;
+
+  /// \returns True if \p CB may legally call \p Callee via an indirect call.
+  bool isValidCallBaseForCallee(const CallBase *CB,
+                                const Function *Callee) const;
+
   /// The type of load/store indexing.
   enum MemIndexedMode {
     MIM_Unindexed, ///< No indexing.
@@ -1998,6 +2008,11 @@
   virtual bool areTypesABICompatible(const Function *Caller,
                                      const Function *Callee,
                                      const ArrayRef<Type *> &Types) const = 0;
+  virtual bool isValidTypePairForCallEdge(Type *ParamTy, Type *ArgTy,
+                                          AttributeSet ParamAS,
+                                          AttributeSet ArgAS) const = 0;
+  virtual bool isValidCallBaseForCallee(const CallBase *CB,
+                                        const Function *Callee) const = 0;
   virtual bool isIndexedLoadLegal(MemIndexedMode Mode, Type *Ty) const = 0;
   virtual bool isIndexedStoreLegal(MemIndexedMode Mode, Type *Ty) const = 0;
   virtual unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const = 0;
@@ -2654,6 +2669,15 @@
                              const ArrayRef<Type *> &Types) const override {
     return Impl.areTypesABICompatible(Caller, Callee, Types);
   }
+  bool isValidTypePairForCallEdge(Type *ParamTy, Type *ArgTy,
+                                  AttributeSet ParamAS,
+                                  AttributeSet ArgAS) const override {
+    return Impl.isValidTypePairForCallEdge(ParamTy, ArgTy, ParamAS, ArgAS);
+  }
+  bool isValidCallBaseForCallee(const CallBase *CB,
+                                const Function *Callee) const override {
+    return Impl.isValidCallBaseForCallee(CB, Callee);
+  }
   bool isIndexedLoadLegal(MemIndexedMode Mode, Type *Ty) const override {
     return Impl.isIndexedLoadLegal(Mode, Ty, getDataLayout());
   }
Index: llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -17,6 +17,7 @@
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/VectorUtils.h"
+#include "llvm/IR/Attributes.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/GetElementPtrTypeIterator.h"
 #include "llvm/IR/IntrinsicInst.h"
@@ -806,6 +807,79 @@
             Callee->getFnAttribute("target-features"));
   }
 
+  /// Default check whether a value of type \p ParamTy carrying the call-site
+  /// attributes \p ParamAS may be bound to a callee argument of type \p ArgTy
+  /// carrying the attributes \p ArgAS.
+  ///
+  /// Distinct types are accepted only if they agree on scalability, sizedness
+  /// and floating-point-ness and, for sized non-scalable types, have the same
+  /// bit width. If either type is a pointer, the types carried by the byval,
+  /// byref, inalloca, sret and preallocated attributes must match as well.
+  bool isValidTypePairForCallEdge(Type *ParamTy, Type *ArgTy,
+                                  AttributeSet ParamAS,
+                                  AttributeSet ArgAS) const {
+    if (ParamTy != ArgTy) {
+      bool ScalableAT = ArgTy->isScalableTy();
+      if (ScalableAT != ParamTy->isScalableTy())
+        return false;
+      bool SizedAT = ArgTy->isSized();
+      if (SizedAT != ParamTy->isSized())
+        return false;
+      // Bit widths are only compared for sized, non-scalable types.
+      if (SizedAT && !ScalableAT)
+        if (DL.getTypeSizeInBits(ParamTy) != DL.getTypeSizeInBits(ArgTy))
+          return false;
+      if (ParamTy->isFloatingPointTy() != ArgTy->isFloatingPointTy())
+        return false;
+    }
+    // The attribute checks below apply only when a pointer is involved.
+    if (!ParamTy->isPointerTy() && !ArgTy->isPointerTy())
+      return true;
+    if (!ParamAS.hasAttributes() && !ArgAS.hasAttributes())
+      return true;
+    if (ParamAS.getByValType() != ArgAS.getByValType())
+      return false;
+    if (ParamAS.getByRefType() != ArgAS.getByRefType())
+      return false;
+    if (ParamAS.getInAllocaType() != ArgAS.getInAllocaType())
+      return false;
+    if (ParamAS.getStructRetType() != ArgAS.getStructRetType())
+      return false;
+    if (ParamAS.getPreallocatedType() != ArgAS.getPreallocatedType())
+      return false;
+    return true;
+  }
+
+  /// Default check whether \p CB could call \p Callee: the argument count must
+  /// equal the callee's parameter count, and the return type as well as every
+  /// argument/parameter pair must pass isValidTypePairForCallEdge.
+  ///
+  /// NOTE(review): the exact-count check also rejects variadic callees that
+  /// are passed extra arguments; confirm this conservatism is intended.
+  bool isValidCallBaseForCallee(const CallBase *CB,
+                                const Function *Callee) const {
+    unsigned NumArgs = CB->arg_size();
+    if (Callee->arg_size() != NumArgs)
+      return false;
+
+    // The return types are compared without any attributes.
+    if (!isValidTypePairForCallEdge(CB->getType(), Callee->getReturnType(),
+                                    AttributeSet(), AttributeSet()))
+      return false;
+
+    AttributeList CBAttrs = CB->getAttributes();
+    AttributeList FnAttrs = Callee->getAttributes();
+    for (unsigned ArgNo = 0; ArgNo < NumArgs; ++ArgNo) {
+      if (!isValidTypePairForCallEdge(CB->getArgOperand(ArgNo)->getType(),
+                                      Callee->getArg(ArgNo)->getType(),
+                                      CBAttrs.getParamAttrs(ArgNo),
+                                      FnAttrs.getParamAttrs(ArgNo)))
+        return false;
+    }
+
+    return true;
+  }
+
   bool isIndexedLoadLegal(TTI::MemIndexedMode Mode, Type *Ty,
                           const DataLayout &DL) const {
     return false;
Index: llvm/lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- llvm/lib/Analysis/TargetTransformInfo.cpp
+++ llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1128,6 +1128,17 @@
   return TTIImpl->areTypesABICompatible(Caller, Callee, Types);
 }
 
+bool TargetTransformInfo::isValidTypePairForCallEdge(Type *ParamTy, Type *ArgTy,
+                                                     AttributeSet ParamAS,
+                                                     AttributeSet ArgAS) const {
+  return TTIImpl->isValidTypePairForCallEdge(ParamTy, ArgTy, ParamAS, ArgAS);
+}
+
+bool TargetTransformInfo::isValidCallBaseForCallee(
+    const CallBase *CB, const Function *Callee) const {
+  return TTIImpl->isValidCallBaseForCallee(CB, Callee);
+}
+
 bool TargetTransformInfo::isIndexedLoadLegal(MemIndexedMode Mode,
                                              Type *Ty) const {
   return TTIImpl->isIndexedLoadLegal(Mode, Ty);