diff --git a/llvm/include/llvm/Analysis/TargetLibraryInfo.h b/llvm/include/llvm/Analysis/TargetLibraryInfo.h
--- a/llvm/include/llvm/Analysis/TargetLibraryInfo.h
+++ b/llvm/include/llvm/Analysis/TargetLibraryInfo.h
@@ -199,6 +199,7 @@
   unsigned getWCharSize(const Module &M) const;
 };
 
+class SearchVFSystem;
 /// Provides information about what library functions are available for
 /// the current target.
 ///
@@ -207,7 +208,7 @@
 class TargetLibraryInfo {
   friend class TargetLibraryAnalysis;
   friend class TargetLibraryInfoWrapperPass;
-
+  friend class SearchVFSystem;
   const TargetLibraryInfoImpl *Impl;
 
 public:
@@ -248,6 +249,8 @@
   bool has(LibFunc F) const {
     return Impl->getState(F) != TargetLibraryInfoImpl::Unavailable;
   }
+
+private:
   bool isFunctionVectorizable(StringRef F, unsigned VF) const {
     return Impl->isFunctionVectorizable(F, VF);
   }
@@ -258,6 +261,7 @@
     return Impl->getVectorizedFunction(F, VF);
   }
 
+public:
   /// Tests if the function is both available and a candidate for optimized code
   /// generation.
   bool hasOptimizedCodeGen(LibFunc F) const {
diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -15,6 +15,7 @@
 
 #include "llvm/ADT/MapVector.h"
 #include "llvm/Analysis/LoopAccessAnalysis.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/Support/CheckedArithmetic.h"
 
@@ -158,13 +159,103 @@
   }
 };
 
+/// Answers whether the function called by a CallInst has a vector variant
+/// and what that variant is named. Mappings listed in the callee's
+/// "vector-function-abi-variant" attribute are consulted first; when none
+/// applies and a TargetLibraryInfo was supplied, the TLI is asked instead.
+class SearchVFSystem {
+private:
+  /// Fallback for callees without a matching attribute mapping; may be null.
+  const TargetLibraryInfo *TLI;
+
+  /// Splits the comma-separated "vector-function-abi-variant" attribute of
+  /// the function called by \p CI into its mangled names. \p CI must be a
+  /// direct call.
+  SmallVector<StringRef, 8> extractMangledNames(CallInst *CI) const {
+    const StringRef AttrString =
+        CI->getCalledFunction()
+            ->getFnAttribute("vector-function-abi-variant")
+            .getValueAsString();
+    SmallVector<StringRef, 8> ListOfStrings;
+    AttrString.split(ListOfStrings, ",");
+    return ListOfStrings;
+  }
+
+public:
+  explicit SearchVFSystem(const TargetLibraryInfo *TLI) : TLI(TLI) {}
+
+  /// Returns the attribute mappings of the function called by \p CI that
+  /// VFInfo::getFromVFABI decodes and whose scalar name equals the callee's
+  /// name. Returns an empty list for an indirect call.
+  SmallVector<VFInfo, 8> getVFMappings(CallInst *CI) const {
+    SmallVector<VFInfo, 8> Ret;
+    // getCalledFunction() is null for indirect calls: nothing to look up.
+    const Function *Callee = CI->getCalledFunction();
+    if (!Callee)
+      return Ret;
+
+    const StringRef ScalarName = Callee->getName();
+    for (StringRef MangledName : extractMangledNames(CI)) {
+      const auto Shape = VFInfo::getFromVFABI(MangledName);
+      if (Shape.hasValue() && Shape.getValue().ScalarName == ScalarName)
+        Ret.push_back(Shape.getValue());
+    }
+
+    return Ret;
+  }
+
+  /// Returns true if the function called by \p CI has at least one attribute
+  /// mapping or, failing that, if the TLI reports its name as vectorizable.
+  bool isFunctionVectorizable(CallInst *CI) const {
+    if (!getVFMappings(CI).empty())
+      return true;
+
+    const Function *Callee = CI->getCalledFunction();
+    return TLI && Callee && TLI->isFunctionVectorizable(Callee->getName());
+  }
+
+  /// Same as above, but only mappings (or TLI entries) for the vectorization
+  /// factor \p VF are taken into account.
+  bool isFunctionVectorizable(CallInst *CI, unsigned VF) const {
+    const auto Shapes = getVFMappings(CI);
+    // Take each mapping by reference: the predicate only reads its VF.
+    if (std::any_of(Shapes.begin(), Shapes.end(),
+                    [VF](const VFInfo &S) { return S.Shape.VF == VF; }))
+      return true;
+
+    const Function *Callee = CI->getCalledFunction();
+    return TLI && Callee && TLI->isFunctionVectorizable(Callee->getName(), VF);
+  }
+
+  /// Returns the vector name of the first attribute mapping of the function
+  /// called by \p CI whose factor is \p VF; otherwise forwards the query to
+  /// the TLI. Returns an empty string when neither source can be used.
+  StringRef getVectorizedFunction(CallInst *CI, unsigned VF) const {
+    const auto Mappings = getVFMappings(CI);
+    for (const auto &Info : Mappings)
+      if (Info.Shape.VF == VF)
+        return Info.VectorName;
+
+    const Function *Callee = CI->getCalledFunction();
+    if (TLI && Callee)
+      return TLI->getVectorizedFunction(Callee->getName(), VF);
+
+    return "";
+  }
+
+  /// Returns true if the TLI reports \p F as vectorizable. Attribute mappings
+  /// are not consulted, and the answer is false when no TLI was supplied.
+  bool isKnownVectorFunctionInLibrary(StringRef F) const {
+    return TLI && TLI->isFunctionVectorizable(F);
+  }
+};
+
 template <typename T> class ArrayRef;
 class DemandedBits;
 class GetElementPtrInst;
 template <typename InstTy> class InterleaveGroup;
 class Loop;
 class ScalarEvolution;
-class TargetLibraryInfo;
 class TargetTransformInfo;
 class Type;
 class Value;
diff --git a/llvm/lib/Analysis/LazyCallGraph.cpp b/llvm/lib/Analysis/LazyCallGraph.cpp
--- a/llvm/lib/Analysis/LazyCallGraph.cpp
+++ b/llvm/lib/Analysis/LazyCallGraph.cpp
@@ -15,6 +15,7 @@
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/iterator_range.h"
 #include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/Config/llvm-config.h"
 #include "llvm/IR/CallSite.h"
 #include "llvm/IR/Function.h"
@@ -146,8 +147,11 @@
 static bool isKnownLibFunction(Function &F, TargetLibraryInfo &TLI) {
   LibFunc LF;
 
-  // Either this is a normal library function or a "vectorizable" function.
-  return TLI.getLibFunc(F, LF) || TLI.isFunctionVectorizable(F.getName());
+  // Either this is a normal library function or a "vectorizable"
+  // function. Only the TLI-backed query of SearchVFSystem is used here
+  // because this check is related only to libraries handled via the TLI.
+  return TLI.getLibFunc(F, LF) ||
+         SearchVFSystem(&TLI).isKnownVectorFunctionInLibrary(F.getName());
 }
 
 LazyCallGraph::LazyCallGraph(
diff --git a/llvm/lib/Analysis/LoopAccessAnalysis.cpp b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
--- a/llvm/lib/Analysis/LoopAccessAnalysis.cpp
+++ b/llvm/lib/Analysis/LoopAccessAnalysis.cpp
@@ -1844,7 +1844,7 @@
         // If the function has an explicit vectorized counterpart, we can safely
         // assume that it can be vectorized.
         if (Call && !Call->isNoBuiltin() && Call->getCalledFunction() &&
-            TLI->isFunctionVectorizable(Call->getCalledFunction()->getName()))
+            SearchVFSystem(TLI).isFunctionVectorizable(Call))
           continue;
 
         auto *Ld = dyn_cast<LoadInst>(&I);
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp
@@ -669,7 +669,7 @@
       if (CI && !getVectorIntrinsicIDForCall(CI, TLI) &&
           !isa<DbgInfoIntrinsic>(CI) &&
           !(CI->getCalledFunction() && TLI &&
-            TLI->isFunctionVectorizable(CI->getCalledFunction()->getName()))) {
+            SearchVFSystem(TLI).isFunctionVectorizable(CI))) {
         // If the call is a recognized math libary call, it is likely that
         // we can vectorize it given loosened floating-point constraints.
         LibFunc Func;
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -3183,7 +3183,6 @@
                                                        unsigned VF,
                                                        bool &NeedToScalarize) {
   Function *F = CI->getCalledFunction();
-  StringRef FnName = CI->getCalledFunction()->getName();
   Type *ScalarRetTy = CI->getType();
   SmallVector<Type *, 4> Tys, ScalarTys;
   for (auto &ArgOp : CI->arg_operands())
@@ -3211,7 +3210,8 @@
   // If we can't emit a vector call for this function, then the currently found
   // cost is the cost we need to return.
   NeedToScalarize = true;
-  if (!TLI || !TLI->isFunctionVectorizable(FnName, VF) || CI->isNoBuiltin())
+  if (!TLI || !SearchVFSystem(TLI).isFunctionVectorizable(CI, VF) ||
+      CI->isNoBuiltin())
     return Cost;
 
   // If the corresponding vector cost is cheaper, return its cost.
@@ -4225,7 +4225,6 @@
     Module *M = I.getParent()->getParent()->getParent();
     auto *CI = cast<CallInst>(&I);
 
-    StringRef FnName = CI->getCalledFunction()->getName();
     Function *F = CI->getCalledFunction();
     Type *RetTy = ToVectorTy(CI->getType(), VF);
     SmallVector<Type *, 4> Tys;
@@ -4264,7 +4263,7 @@
         VectorF = Intrinsic::getDeclaration(M, ID, TysForDecl);
       } else {
         // Use vector version of the library call.
-        StringRef VFnName = TLI->getVectorizedFunction(FnName, VF);
+        StringRef VFnName = SearchVFSystem(TLI).getVectorizedFunction(CI, VF);
         assert(!VFnName.empty() && "Vector function name is empty.");
         VectorF = M->getFunction(VFnName);
         if (!VectorF) {