Index: llvm/include/llvm/Analysis/VectorUtils.h
===================================================================
--- llvm/include/llvm/Analysis/VectorUtils.h
+++ llvm/include/llvm/Analysis/VectorUtils.h
@@ -121,14 +121,21 @@
 /// * x86 (libmvec): https://sourceware.org/glibc/wiki/libmvec and
 /// https://sourceware.org/glibc/wiki/libmvec?action=AttachFile&do=view&target=VectorABI.txt
 ///
-///
-///
 /// \param MangledName -> input string in the format
 /// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)].
 Optional<VFInfo> tryDemangleForVFABI(StringRef MangledName);
 
 /// Retrieve the `VFParamKind` from a string token.
 VFParamKind getVFParamKindFromString(const StringRef Token);
+
+/// Name of the call-site function attribute that stores the variant mappings.
+static constexpr char const *MappingsAttrName = "vector-function-abi-variant";
+
+/// Append to \p VariantMappings the Vector Function ABI variant names stored
+/// in the `vector-function-abi-variant` attribute of \p CI, in attribute
+/// order and without duplicates.
+void getVectorVariantNames(const CallInst &CI,
+                           SmallVector<std::string, 8> &VariantMappings);
 } // end namespace VFABI
 
 template <typename T> class ArrayRef;
Index: llvm/include/llvm/Transforms/Utils/ModuleUtils.h
===================================================================
--- llvm/include/llvm/Transforms/Utils/ModuleUtils.h
+++ llvm/include/llvm/Transforms/Utils/ModuleUtils.h
@@ -13,6 +13,7 @@
 #ifndef LLVM_TRANSFORMS_UTILS_MODULEUTILS_H
 #define LLVM_TRANSFORMS_UTILS_MODULEUTILS_H
 
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include <utility> // for std::pair
 
@@ -108,6 +109,30 @@
 /// unique identifier for this module, so we return the empty string.
 std::string getUniqueModuleId(Module *M);
 
+class TargetLibraryInfo;
+class CallInst;
+namespace VFABI {
+
+/// \defgroup VFABI Vector Function ABI (VFABI) Module functions.
+///
+/// Utility functions for VFABI data that can modify the module.
+///
+/// @{
+/// Overwrite the Vector Function ABI variants attribute of \p CI with the
+/// names provided in \p VariantMappings. Does nothing when
+/// \p VariantMappings is empty.
+void setVectorVariantNames(CallInst *CI,
+                           const SmallVector<std::string, 8> &VariantMappings);
+
+/// Add to the Vector Function ABI variants attribute of \p CI the mappings
+/// to the vector versions that \p TLI provides for the function called by
+/// \p CI, declaring any of those vector functions missing from the module.
+/// Indirect calls and calls marked `nobuiltin` are left untouched.
+void addMappingsFromTLI(const TargetLibraryInfo *TLI, CallInst *CI);
+
+/// @}
+} // End VFABI namespace
+
 } // End llvm namespace
 
 #endif // LLVM_TRANSFORMS_UTILS_MODULEUTILS_H
Index: llvm/lib/Analysis/VFABIDemangling.cpp
===================================================================
--- llvm/lib/Analysis/VFABIDemangling.cpp
+++ llvm/lib/Analysis/VFABIDemangling.cpp
@@ -338,7 +338,7 @@
     }
   } while (ParamFound == ParseRet::OK);
 
-  // A valid MangledName mus have at least one valid entry in the
+  // A valid MangledName must have at least one valid entry in the
   // <parameters>.
   if (Parameters.empty())
     return None;
Index: llvm/lib/Analysis/VectorUtils.cpp
===================================================================
--- llvm/lib/Analysis/VectorUtils.cpp
+++ llvm/lib/Analysis/VectorUtils.cpp
@@ -1159,3 +1159,32 @@
   propagateMetadata(NewInst, VL);
 }
 }
+
+void VFABI::getVectorVariantNames(
+    const CallInst &CI, SmallVector<std::string, 8> &VariantMappings) {
+  const StringRef S =
+      CI.getAttribute(AttributeList::FunctionIndex, VFABI::MappingsAttrName)
+          .getValueAsString();
+  // A call without the attribute yields an empty string. Splitting it would
+  // produce a single empty (hence invalid) variant name, so bail out early.
+  if (S.empty())
+    return;
+
+  SmallVector<StringRef, 8> ListAttr;
+  S.split(ListAttr, ",");
+
+#ifndef NDEBUG
+  const Module *M = CI.getModule();
+#endif
+  // SetVector drops duplicated names while preserving their order.
+  for (const StringRef Name :
+       SetVector<StringRef>(ListAttr.begin(), ListAttr.end())) {
+#ifndef NDEBUG
+    Optional<VFInfo> Info = VFABI::tryDemangleForVFABI(Name);
+    assert(Info.hasValue() && "Invalid name for a VFABI variant.");
+    assert(M->getFunction(Info.getValue().VectorName) &&
+           "Vector function is missing.");
+#endif
+    VariantMappings.push_back(Name.str());
+  }
+}
Index: llvm/lib/Transforms/Utils/ModuleUtils.cpp
===================================================================
--- llvm/lib/Transforms/Utils/ModuleUtils.cpp
+++ llvm/lib/Transforms/Utils/ModuleUtils.cpp
@@ -11,6 +11,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Utils/ModuleUtils.h"
+#include "llvm/Analysis/TargetLibraryInfo.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
@@ -280,3 +282,140 @@
   MD5::stringifyResult(R, Str);
   return ("$" + Str).str();
 }
+
+/// Helper function to map the TLI name to a string that holds the
+/// scalar-to-vector mapping.
+///
+///    _ZGV<isa><mask><vlen><vparams>_<scalarname>(<vectorname>)
+///
+/// where:
+///
+/// <isa> = "_LLVM_TLI_"
+/// <mask> = "N". Note: TLI does not support masked interfaces.
+/// <vlen> = Number of concurrent lanes, stored in the `VectorizationFactor`
+///          field of the `VecDesc` struct.
+/// <vparams> = "v", as many as are the number of parameters of CI.
+/// <scalarname> = the name of the scalar function called by CI.
+/// <vectorname> = the name of the vector function mapped by the TLI.
+///
+/// TODO(review): confirm that tryDemangleForVFABI accepts the "_LLVM_TLI_"
+/// <isa> token: the debug-build assertions in getVectorVariantNames and
+/// setVectorVariantNames require every stored name to demangle.
+static std::string mangleTLIName(StringRef VectorName, const CallInst *CI,
+                                 unsigned VF) {
+  SmallString<256> Buffer;
+  llvm::raw_svector_ostream Out(Buffer);
+  // _ZGV <isa> <mask> <vlen>
+  Out << "_ZGV" << "_LLVM_TLI_" << "N" << VF;
+  for (unsigned I = 0, E = CI->getNumArgOperands(); I < E; ++I)
+    Out << "v";
+  Out << "_" << CI->getCalledFunction()->getName() << "(" << VectorName << ")";
+  return Out.str().str();
+}
+
+/// A helper function for converting Scalar types to vector types.
+/// If the incoming type is void, we return void. If the VF is 1, we return
+/// the scalar type.
+static Type *ToVectorTy(Type *Scalar, unsigned VF) {
+  if (Scalar->isVoidTy() || VF == 1)
+    return Scalar;
+  return VectorType::get(Scalar, VF);
+}
+
+/// A helper function that adds the vector function declaration that
+/// vectorizes the CallInst CI with a vectorization factor of VF
+/// lanes. The TLI assumes that all parameters and the return type of
+/// CI (other than void) need to be widened to a VectorType of VF
+/// lanes. Nothing is done if the module already has a global named VFName.
+static void addVariantDeclaration(CallInst *CI, const unsigned VF,
+                                  const StringRef VFName) {
+  assert(CI && "Invalid CallInst.");
+  Module *M = CI->getModule();
+  // Nothing to do if the function already exists in the module.
+  if (M->getNamedValue(VFName))
+    return;
+
+  Type *RetTy = ToVectorTy(CI->getType(), VF);
+  SmallVector<Type *, 4> Tys;
+  for (Value *ArgOperand : CI->arg_operands())
+    Tys.push_back(ToVectorTy(ArgOperand->getType(), VF));
+  FunctionType *FTy = FunctionType::get(RetTy, Tys, /*isVarArg=*/false);
+  Function *VectorF =
+      Function::Create(FTy, Function::ExternalLinkage, VFName, M);
+  VectorF->copyAttributesFrom(CI->getCalledFunction());
+
+  // Make the new declaration (it has no body) "sticky" in the IR by listing
+  // it in the @llvm.compiler.used intrinsic variable.
+  assert(VectorF->isDeclaration() && "Expected a function declaration.");
+  appendToCompilerUsed(*M, {VectorF});
+}
+
+void VFABI::addMappingsFromTLI(const TargetLibraryInfo *TLI, CallInst *CI) {
+  assert(TLI && "Invalid TLI.");
+  assert(CI && "Invalid CallInst.");
+
+  // This is needed to make sure we don't query the TLI for calls to
+  // bitcast of function pointers, like `%call = call i32 (i32*, ...)
+  // bitcast (i32 (...)* @goo to i32 (i32*, ...)*)(i32* nonnull %i)`,
+  // as such calls make the `isFunctionVectorizable` raise an
+  // exception.
+  if (CI->isNoBuiltin() || !CI->getCalledFunction())
+    return;
+
+  const std::string ScalarName = CI->getCalledFunction()->getName().str();
+  // Nothing to be done if the TLI thinks the function is not
+  // vectorizable.
+  if (!TLI->isFunctionVectorizable(ScalarName))
+    return;
+
+  // Start from the mappings already attached to CI. A vector, unlike a set,
+  // keeps the attribute contents in a deterministic order, and it is the
+  // type that getVectorVariantNames and setVectorVariantNames take.
+  SmallVector<std::string, 8> Mappings;
+  VFABI::getVectorVariantNames(*CI, Mappings);
+
+  // Only power-of-two VFs up to 16 are queried.
+  for (unsigned VF = 2; VF <= 16; VF *= 2) {
+    const StringRef TLIName = TLI->getVectorizedFunction(ScalarName, VF);
+    if (TLIName.empty())
+      continue;
+    const std::string MangledName = mangleTLIName(TLIName, CI, VF);
+    if (!is_contained(Mappings, MangledName))
+      Mappings.push_back(MangledName);
+    // Declare the vector function unless the module already has it.
+    addVariantDeclaration(CI, VF, TLIName);
+  }
+
+  VFABI::setVectorVariantNames(CI, Mappings);
+}
+
+void VFABI::setVectorVariantNames(
+    CallInst *CI, const SmallVector<std::string, 8> &VariantMappings) {
+  assert(CI && "Invalid CallInst");
+
+  // An empty list leaves any existing attribute untouched.
+  if (VariantMappings.empty())
+    return;
+
+  SmallString<256> Buffer;
+  llvm::raw_svector_ostream Out(Buffer);
+  for (const std::string &VariantMapping : VariantMappings)
+    Out << VariantMapping << ",";
+  // Get rid of the trailing ','.
+  assert(!Buffer.str().empty() && "Must have at least one char.");
+  Buffer.pop_back();
+
+  Module *M = CI->getModule();
+#ifndef NDEBUG
+  for (const std::string &VariantMapping : VariantMappings) {
+    Optional<VFInfo> VI = VFABI::tryDemangleForVFABI(VariantMapping);
+    assert(VI.hasValue() && "Cannot add an invalid VFABI name.");
+    assert(M->getNamedValue(VI.getValue().VectorName) &&
+           "Cannot add variant to attribute: "
+           "vector function declaration is missing.");
+  }
+#endif
+  CI->addAttribute(
+      AttributeList::FunctionIndex,
+      Attribute::get(M->getContext(), VFABI::MappingsAttrName, Buffer.str()));
+}
Index: llvm/unittests/Analysis/VectorFunctionABITest.cpp
===================================================================
--- llvm/unittests/Analysis/VectorFunctionABITest.cpp
+++ llvm/unittests/Analysis/VectorFunctionABITest.cpp
@@ -7,6 +7,8 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Analysis/VectorUtils.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/InstIterator.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
@@ -437,3 +439,37 @@
   EXPECT_EQ(Parameters[1], VFParameter({1, VFParamKind::GlobalPredicate}));
   EXPECT_EQ(ScalarName, "sin");
 }
+
+class VFABIAttrTest : public testing::Test {
+protected:
+  void SetUp() override {
+    M = parseAssemblyString(IR, Err, Ctx);
+    // Get the only call instruction in the block, which is the first
+    // instruction.
+    CI = dyn_cast<CallInst>(&*(instructions(M->getFunction("f")).begin()));
+  }
+  const char *IR = "define i32 @f(i32 %a) {\n"
+                   " %1 = call i32 @g(i32 %a) #0\n"
+                   "  ret i32 %1\n"
+                   "}\n"
+                   "declare i32 @g(i32)\n"
+                   "declare <2 x i32> @custom_vg(<2 x i32>)\n"
+                   "declare <4 x i32> @_ZGVnN4v_g(<4 x i32>)\n"
+                   "declare <8 x i32> @_ZGVnN8v_g(<8 x i32>)\n"
+                   "attributes #0 = { "
+                   "\"vector-function-abi-variant\"=\""
+                   "_ZGVnN2v_g(custom_vg),_ZGVnN4v_g\" }";
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M;
+  CallInst *CI = nullptr;
+  SmallVector<std::string, 8> Mappings;
+};
+
+TEST_F(VFABIAttrTest, Read) {
+  VFABI::getVectorVariantNames(*CI, Mappings);
+  SmallVector<std::string, 8> Exp;
+  Exp.push_back("_ZGVnN2v_g(custom_vg)");
+  Exp.push_back("_ZGVnN4v_g");
+  EXPECT_EQ(Mappings, Exp);
+}
Index: llvm/unittests/Transforms/Utils/CMakeLists.txt
===================================================================
--- llvm/unittests/Transforms/Utils/CMakeLists.txt
+++ llvm/unittests/Transforms/Utils/CMakeLists.txt
@@ -18,4 +18,5 @@
   SSAUpdaterBulkTest.cpp
   UnrollLoopTest.cpp
   ValueMapperTest.cpp
+  VFABIUtils.cpp
   )
Index: llvm/unittests/Transforms/Utils/VFABIUtils.cpp
===================================================================
--- /dev/null
+++ llvm/unittests/Transforms/Utils/VFABIUtils.cpp
@@ -0,0 +1,53 @@
+//===------- VFABIUtils.cpp - VFABI Unittests ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Transforms/Utils/ModuleUtils.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+class VFABIAttrTest : public testing::Test {
+protected:
+  void SetUp() override {
+    M = parseAssemblyString(IR, Err, Ctx);
+    // Get the only call instruction in the block, which is the first
+    // instruction.
+    CI = dyn_cast<CallInst>(&*(instructions(M->getFunction("f")).begin()));
+  }
+  const char *IR = "define i32 @f(i32 %a) {\n"
+                   " %1 = call i32 @g(i32 %a) #0\n"
+                   "  ret i32 %1\n"
+                   "}\n"
+                   "declare i32 @g(i32)\n"
+                   "declare <2 x i32> @custom_vg(<2 x i32>)\n"
+                   "declare <4 x i32> @_ZGVnN4v_g(<4 x i32>)\n"
+                   "declare <8 x i32> @_ZGVnN8v_g(<8 x i32>)\n"
+                   "attributes #0 = { "
+                   "\"vector-function-abi-variant\"=\""
+                   "_ZGVnN2v_g(custom_vg),_ZGVnN4v_g\" }";
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M;
+  CallInst *CI = nullptr;
+  SmallVector<std::string, 8> Mappings;
+};
+
+TEST_F(VFABIAttrTest, Write) {
+  Mappings.push_back("_ZGVnN8v_g");
+  Mappings.push_back("_ZGVnN2v_g(custom_vg)");
+  VFABI::setVectorVariantNames(CI, Mappings);
+  const StringRef S = CI->getAttribute(AttributeList::FunctionIndex,
+                                       "vector-function-abi-variant")
+                          .getValueAsString();
+  EXPECT_EQ(S, "_ZGVnN8v_g,_ZGVnN2v_g(custom_vg)");
+}