Index: llvm/include/llvm/Analysis/TargetTransformInfo.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfo.h
+++ llvm/include/llvm/Analysis/TargetTransformInfo.h
@@ -1510,6 +1510,14 @@
   bool areTypesABICompatible(const Function *Caller, const Function *Callee,
                              const ArrayRef<Type *> &Types) const;
 
+  /// \returns True if type \p ParamTy at a call base (argument operand or call
+  /// result) is compatible with type \p ArgTy at the callee.
+  bool isValidTypePairForCallEdge(Type *ParamTy, Type *ArgTy) const;
+
+  /// \returns True if \p CB may legally call \p Callee via an indirect call.
+  bool isValidCallBaseForCallee(const CallBase *CB,
+                                const Function *Callee) const;
+
   /// The type of load/store indexing.
   enum MemIndexedMode {
     MIM_Unindexed, ///< No indexing.
@@ -1998,6 +2006,9 @@
   virtual bool areTypesABICompatible(const Function *Caller,
                                      const Function *Callee,
                                      const ArrayRef<Type *> &Types) const = 0;
+  virtual bool isValidTypePairForCallEdge(Type *ParamTy, Type *ArgTy) const = 0;
+  virtual bool isValidCallBaseForCallee(const CallBase *CB,
+                                        const Function *Callee) const = 0;
   virtual bool isIndexedLoadLegal(MemIndexedMode Mode, Type *Ty) const = 0;
   virtual bool isIndexedStoreLegal(MemIndexedMode Mode, Type *Ty) const = 0;
   virtual unsigned getLoadStoreVecRegBitWidth(unsigned AddrSpace) const = 0;
@@ -2654,6 +2665,13 @@
                              const ArrayRef<Type *> &Types) const override {
     return Impl.areTypesABICompatible(Caller, Callee, Types);
   }
+  bool isValidTypePairForCallEdge(Type *ParamTy, Type *ArgTy) const override {
+    return Impl.isValidTypePairForCallEdge(ParamTy, ArgTy);
+  }
+  bool isValidCallBaseForCallee(const CallBase *CB,
+                                const Function *Callee) const override {
+    return Impl.isValidCallBaseForCallee(CB, Callee);
+  }
   bool isIndexedLoadLegal(MemIndexedMode Mode, Type *Ty) const override {
     return Impl.isIndexedLoadLegal(Mode, Ty, getDataLayout());
   }
Index: llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
===================================================================
--- llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
+++ llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
@@ -17,6 +17,7 @@
 #include "llvm/Analysis/ScalarEvolutionExpressions.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/VectorUtils.h"
+#include "llvm/IR/Attributes.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/GetElementPtrTypeIterator.h"
 #include "llvm/IR/IntrinsicInst.h"
@@ -806,6 +807,72 @@
             Callee->getFnAttribute("target-features"));
   }
 
+  /// Default implementation: \p ParamTy and \p ArgTy are compatible if they
+  /// are identical, or if they agree on scalability, sizedness, bit width
+  /// (when sized and not scalable), floating-point-ness and pointer-ness, and
+  /// pointers also satisfy isNoopAddrSpaceCast for their address spaces.
+  bool isValidTypePairForCallEdge(Type *ParamTy, Type *ArgTy) const {
+    if (ParamTy == ArgTy)
+      return true;
+    // Scalable vs non-scalable, not OK.
+    bool ScalableAT = ArgTy->isScalableTy();
+    if (ScalableAT != ParamTy->isScalableTy())
+      return false;
+    // Sized vs non-sized, not OK.
+    bool SizedAT = ArgTy->isSized();
+    if (SizedAT != ParamTy->isSized())
+      return false;
+    // Sized and non-scalable: check the sizes; different sizes, not OK.
+    if (SizedAT && !ScalableAT)
+      if (DL.getTypeSizeInBits(ParamTy) != DL.getTypeSizeInBits(ArgTy))
+        return false;
+    // Floating point vs non-floating point, not OK.
+    if (ParamTy->isFloatingPointTy() != ArgTy->isFloatingPointTy())
+      return false;
+
+    // Pointer vs non-pointer (in either direction), not OK.
+    if (ParamTy->isPointerTy() != ArgTy->isPointerTy())
+      return false;
+    if (ParamTy->isPointerTy()) {
+      // Both are pointers here. Implicit non-trivial AS cast, not OK.
+      if (!isNoopAddrSpaceCast(ParamTy->getPointerAddressSpace(),
+                               ArgTy->getPointerAddressSpace()))
+        return false;
+    }
+    return true;
+  }
+
+  /// Default implementation: the callee must accept the call's operand count
+  /// (extra operands only for variadic callees), the return type and every
+  /// fixed argument type must pass isValidTypePairForCallEdge, and the ABI
+  /// attributes must satisfy AttributeFuncs::areCallCompatible.
+  bool isValidCallBaseForCallee(const CallBase *CB,
+                                const Function *Callee) const {
+    unsigned FnNumArgs = Callee->arg_size();
+    unsigned CBNumArgs = CB->arg_size();
+
+    // Argument mismatch, sometimes OK for variadic functions.
+    if (FnNumArgs > CBNumArgs)
+      return false;
+    if (FnNumArgs < CBNumArgs && !Callee->isVarArg())
+      return false;
+
+    if (!isValidTypePairForCallEdge(CB->getType(), Callee->getReturnType()))
+      return false;
+
+    for (unsigned ArgNo = 0; ArgNo < FnNumArgs; ++ArgNo)
+      if (!isValidTypePairForCallEdge(CB->getArgOperand(ArgNo)->getType(),
+                                      Callee->getArg(ArgNo)->getType()))
+        return false;
+
+    // TODO: Check FnNumArgs till CBNumArgs for OK variadic argument
+    // types/attributes.
+
+    AttributeList CBAttrs = CB->getAttributes();
+    AttributeList FnAttrs = Callee->getAttributes();
+    return AttributeFuncs::areCallCompatible(CBAttrs, FnAttrs);
+  }
+
   bool isIndexedLoadLegal(TTI::MemIndexedMode Mode, Type *Ty,
                           const DataLayout &DL) const {
     return false;
Index: llvm/include/llvm/IR/Attributes.h
===================================================================
--- llvm/include/llvm/IR/Attributes.h
+++ llvm/include/llvm/IR/Attributes.h
@@ -1224,6 +1224,9 @@
 /// attributes for inlining purposes.
 bool areInlineCompatible(const Function &Caller, const Function &Callee);
 
+/// \returns true if the attributes at the call base \p CallBaseAttrs allow a
+/// call to the callee with the attributes in \p CalleeAttrs.
+bool areCallCompatible(AttributeList CallBaseAttrs, AttributeList CalleeAttrs);
 
 /// Checks if there are any incompatible function attributes between
 /// \p A and \p B.
Index: llvm/lib/Analysis/TargetTransformInfo.cpp
===================================================================
--- llvm/lib/Analysis/TargetTransformInfo.cpp
+++ llvm/lib/Analysis/TargetTransformInfo.cpp
@@ -1128,6 +1128,16 @@
   return TTIImpl->areTypesABICompatible(Caller, Callee, Types);
 }
 
+bool TargetTransformInfo::isValidTypePairForCallEdge(Type *ParamTy,
+                                                     Type *ArgTy) const {
+  return TTIImpl->isValidTypePairForCallEdge(ParamTy, ArgTy);
+}
+
+bool TargetTransformInfo::isValidCallBaseForCallee(
+    const CallBase *CB, const Function *Callee) const {
+  return TTIImpl->isValidCallBaseForCallee(CB, Callee);
+}
+
 bool TargetTransformInfo::isIndexedLoadLegal(MemIndexedMode Mode,
                                              Type *Ty) const {
   return TTIImpl->isIndexedLoadLegal(Mode, Ty);
Index: llvm/lib/IR/Attributes.cpp
===================================================================
--- llvm/lib/IR/Attributes.cpp
+++ llvm/lib/IR/Attributes.cpp
@@ -2214,6 +2214,48 @@
   return hasCompatibleFnAttrs(Caller, Callee);
 }
 
+bool AttributeFuncs::areCallCompatible(AttributeList CallBaseAttrs,
+                                       AttributeList CalleeAttrs) {
+  // Match sext/zext for return values.
+  const Attribute::AttrKind RetABIAttrs[] = {Attribute::SExt, Attribute::ZExt};
+  // Match ABI-affecting attributes for parameters. Each kind is listed once;
+  // byval and byref are distinct and both change how the argument is passed.
+  const Attribute::AttrKind ParamABIAttrs[] = {
+      Attribute::SExt,         Attribute::ZExt,     Attribute::ByVal,
+      Attribute::ByRef,        Attribute::InAlloca, Attribute::InReg,
+      Attribute::Preallocated, Attribute::StructRet};
+
+  // Fast path: neither side carries any of the relevant attributes anywhere.
+  if (llvm::all_of(ParamABIAttrs, [&](Attribute::AttrKind AK) {
+        return !CallBaseAttrs.hasAttrSomewhere(AK) &&
+               !CalleeAttrs.hasAttrSomewhere(AK);
+      }))
+    return true;
+
+  AttributeSet CBRetAS = CallBaseAttrs.getRetAttrs();
+  AttributeSet FnRetAS = CalleeAttrs.getRetAttrs();
+
+  for (Attribute::AttrKind AK : RetABIAttrs)
+    if (CBRetAS.hasAttribute(AK) != FnRetAS.hasAttribute(AK))
+      return false;
+
+  unsigned NumArgs =
+      std::max(CallBaseAttrs.getNumAttrSets(), CalleeAttrs.getNumAttrSets()) -
+      AttributeList::FirstArgIndex;
+
+  for (unsigned ArgNo = 0; ArgNo < NumArgs; ++ArgNo) {
+    AttributeSet CBParamAS = CallBaseAttrs.getParamAttrs(ArgNo);
+    AttributeSet FnParamAS = CalleeAttrs.getParamAttrs(ArgNo);
+    // Compare whole attributes, not just presence, so that type-carrying
+    // attributes such as byval(<ty>) or sret(<ty>) must also agree on type.
+    for (Attribute::AttrKind AK : ParamABIAttrs)
+      if (CBParamAS.getAttribute(AK) != FnParamAS.getAttribute(AK))
+        return false;
+  }
+
+  return true;
+}
+
 bool AttributeFuncs::areOutlineCompatible(const Function &A,
                                           const Function &B) {
   return hasCompatibleFnAttrs(A, B);