diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -14,6 +14,7 @@
 #define LLVM_ANALYSIS_VECTORUTILS_H
 
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/Analysis/LoopAccessAnalysis.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/Support/CheckedArithmetic.h"
@@ -121,14 +122,22 @@
 /// * x86 (libmvec): https://sourceware.org/glibc/wiki/libmvec and
 /// https://sourceware.org/glibc/wiki/libmvec?action=AttachFile&do=view&target=VectorABI.txt
 ///
-///
-///
 /// \param MangledName -> input string in the format
 /// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)].
 Optional<VFInfo> tryDemangleForVFABI(StringRef MangledName);
 
 /// Retrieve the `VFParamKind` from a string token.
 VFParamKind getVFParamKindFromString(const StringRef Token);
+
+/// Name of the attribute where the variant mappings are stored.
+static constexpr char const *MappingsAttrName = "vector-function-abi-variant";
+
+/// Appends to \p VariantMappings the Vector Function ABI variant names
+/// stored in the "vector-function-abi-variant" attribute of \p CI, in
+/// attribute order and with duplicates removed. Nothing is appended when
+/// \p CI does not carry the attribute.
+void getVectorVariantNames(const CallInst &CI,
+                           SmallVectorImpl<std::string> &VariantMappings);
 } // end namespace VFABI
 
 template <typename T> class ArrayRef;
diff --git a/llvm/include/llvm/Transforms/Utils/ModuleUtils.h b/llvm/include/llvm/Transforms/Utils/ModuleUtils.h
--- a/llvm/include/llvm/Transforms/Utils/ModuleUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/ModuleUtils.h
@@ -13,6 +13,7 @@
 #ifndef LLVM_TRANSFORMS_UTILS_MODULEUTILS_H
 #define LLVM_TRANSFORMS_UTILS_MODULEUTILS_H
 
+#include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/StringRef.h"
 #include <utility> // for std::pair
 
@@ -108,6 +109,23 @@
 /// unique identifier for this module, so we return the empty string.
 std::string getUniqueModuleId(Module *M);
 
+class CallInst;
+namespace VFABI {
+
+/// \defgroup VFABIModuleUtils Vector Function ABI (VFABI) Module functions.
+///
+/// Utility functions for VFABI data that can modify the module.
+///
+/// @{
+/// Overwrite the Vector Function ABI variants attribute of \p CI with the
+/// names provided in \p VariantMappings. Does nothing if \p VariantMappings
+/// is empty.
+void setVectorVariantNames(CallInst *CI,
+                           const SmallVector<std::string, 8> &VariantMappings);
+
+/// @}
+} // End VFABI namespace
+
 } // End llvm namespace
 
 #endif // LLVM_TRANSFORMS_UTILS_MODULEUTILS_H
diff --git a/llvm/lib/Analysis/VFABIDemangling.cpp b/llvm/lib/Analysis/VFABIDemangling.cpp
--- a/llvm/lib/Analysis/VFABIDemangling.cpp
+++ b/llvm/lib/Analysis/VFABIDemangling.cpp
@@ -6,6 +6,8 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/ADT/SmallSet.h"
+#include "llvm/ADT/SmallString.h"
 #include "llvm/Analysis/VectorUtils.h"
 
 using namespace llvm;
@@ -34,7 +36,6 @@
             .Case("d", VFISAKind::AVX2)
             .Case("e", VFISAKind::AVX512)
             .Default(VFISAKind::Unknown);
-
   MangledName = MangledName.drop_front(1);
 
   return ParseRet::OK;
@@ -338,7 +339,7 @@
     }
   } while (ParamFound == ParseRet::OK);
 
-  // A valid MangledName mus have at least one valid entry in the
+  // A valid MangledName must have at least one valid entry in the
   // <parameters>.
   if (Parameters.empty())
     return None;
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -1159,3 +1159,29 @@
     propagateMetadata(NewInst, VL);
   }
 }
+
+void VFABI::getVectorVariantNames(
+    const CallInst &CI, SmallVectorImpl<std::string> &VariantMappings) {
+  const StringRef S =
+      CI.getAttribute(AttributeList::FunctionIndex, VFABI::MappingsAttrName)
+          .getValueAsString();
+  // A call without the attribute yields an empty string. Splitting it would
+  // produce a single empty entry, which is not a valid VFABI name.
+  if (S.empty())
+    return;
+
+  SmallVector<StringRef, 8> ListAttr;
+  S.split(ListAttr, ",");
+
+  // SetVector drops duplicate entries while preserving their order.
+  for (const StringRef Name :
+       SetVector<StringRef>(ListAttr.begin(), ListAttr.end())) {
+#ifndef NDEBUG
+    Optional<VFInfo> Info = VFABI::tryDemangleForVFABI(Name);
+    assert(Info.hasValue() && "Invalid name for a VFABI variant.");
+    assert(CI.getModule()->getFunction(Info.getValue().VectorName) &&
+           "Vector function is missing.");
+#endif
+    VariantMappings.push_back(Name.str());
+  }
+}
diff --git a/llvm/lib/Transforms/Utils/ModuleUtils.cpp b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
--- a/llvm/lib/Transforms/Utils/ModuleUtils.cpp
+++ b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Utils/ModuleUtils.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
@@ -280,3 +281,31 @@
   MD5::stringifyResult(R, Str);
   return ("$" + Str).str();
 }
+
+void VFABI::setVectorVariantNames(
+    CallInst *CI, const SmallVector<std::string, 8> &VariantMappings) {
+  if (VariantMappings.empty())
+    return;
+
+  SmallString<256> Buffer;
+  llvm::raw_svector_ostream Out(Buffer);
+  for (const std::string &VariantMapping : VariantMappings)
+    Out << VariantMapping << ",";
+  // Get rid of the trailing ','.
+  assert(!Buffer.str().empty() && "Must have at least one char.");
+  Buffer.pop_back();
+
+  Module *M = CI->getModule();
+#ifndef NDEBUG
+  for (const std::string &VariantMapping : VariantMappings) {
+    Optional<VFInfo> VI = VFABI::tryDemangleForVFABI(VariantMapping);
+    assert(VI.hasValue() && "Cannot add an invalid VFABI name.");
+    assert(M->getNamedValue(VI.getValue().VectorName) &&
+           "Cannot add variant to attribute: "
+           "vector function declaration is missing.");
+  }
+#endif
+  CI->addAttribute(
+      AttributeList::FunctionIndex,
+      Attribute::get(M->getContext(), MappingsAttrName, Buffer.str()));
+}
diff --git a/llvm/unittests/Analysis/VectorFunctionABITest.cpp b/llvm/unittests/Analysis/VectorFunctionABITest.cpp
--- a/llvm/unittests/Analysis/VectorFunctionABITest.cpp
+++ b/llvm/unittests/Analysis/VectorFunctionABITest.cpp
@@ -7,6 +7,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Analysis/VectorUtils.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/Support/SourceMgr.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
@@ -437,3 +440,60 @@
   EXPECT_EQ(Parameters[1], VFParameter({1, VFParamKind::GlobalPredicate}));
   EXPECT_EQ(ScalarName, "sin");
 }
+
+class VFABIAttrTest : public testing::Test {
+protected:
+  void SetUp() override {
+    M = parseAssemblyString(IR, Err, Ctx);
+    ASSERT_TRUE(M) << "Failed to parse the test IR.";
+    // Get the only call instruction in the block, which is the first
+    // instruction.
+    CI = dyn_cast<CallInst>(&*(instructions(M->getFunction("f")).begin()));
+    ASSERT_TRUE(CI) << "The first instruction of @f must be a call.";
+  }
+  const char *IR = "define i32 @f(i32 %a) {\n"
+                   "  %1 = call i32 @g(i32 %a) #0\n"
+                   "  ret i32 %1\n"
+                   "}\n"
+                   "declare i32 @g(i32)\n"
+                   "declare <2 x i32> @custom_vg(<2 x i32>)\n"
+                   "declare <4 x i32> @_ZGVnN4v_g(<4 x i32>)\n"
+                   "declare <8 x i32> @_ZGVnN8v_g(<8 x i32>)\n"
+                   "attributes #0 = { "
+                   "\"vector-function-abi-variant\"=\""
+                   "_ZGVnN2v_g(custom_vg),_ZGVnN4v_g\" }";
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M;
+  CallInst *CI = nullptr;
+  SmallVector<std::string, 8> Mappings;
+};
+
+TEST_F(VFABIAttrTest, Read) {
+  VFABI::getVectorVariantNames(*CI, Mappings);
+  SmallVector<std::string, 8> Exp;
+  Exp.push_back("_ZGVnN2v_g(custom_vg)");
+  Exp.push_back("_ZGVnN4v_g");
+  EXPECT_EQ(Mappings, Exp);
+}
+
+TEST(VFABIAttrNoVariantsTest, ReadMissingAttribute) {
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M =
+      parseAssemblyString("define i32 @f(i32 %a) {\n"
+                          "  %1 = call i32 @g(i32 %a)\n"
+                          "  ret i32 %1\n"
+                          "}\n"
+                          "declare i32 @g(i32)\n",
+                          Err, Ctx);
+  ASSERT_TRUE(M) << "Failed to parse the test IR.";
+  auto *CI =
+      dyn_cast<CallInst>(&*(instructions(M->getFunction("f")).begin()));
+  ASSERT_TRUE(CI);
+  SmallVector<std::string, 8> Mappings;
+  // A call without the "vector-function-abi-variant" attribute has no
+  // variants.
+  VFABI::getVectorVariantNames(*CI, Mappings);
+  EXPECT_TRUE(Mappings.empty());
+}
diff --git a/llvm/unittests/Transforms/Utils/CMakeLists.txt b/llvm/unittests/Transforms/Utils/CMakeLists.txt
--- a/llvm/unittests/Transforms/Utils/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Utils/CMakeLists.txt
@@ -18,4 +18,5 @@
   SSAUpdaterBulkTest.cpp
   UnrollLoopTest.cpp
   ValueMapperTest.cpp
+  VFABIUtils.cpp
   )
diff --git a/llvm/unittests/Transforms/Utils/VFABIUtils.cpp b/llvm/unittests/Transforms/Utils/VFABIUtils.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Transforms/Utils/VFABIUtils.cpp
@@ -0,0 +1,56 @@
+//===------- VFABIUtils.cpp - VFABI Unittests ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Transforms/Utils/ModuleUtils.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+class VFABIAttrTest : public testing::Test {
+protected:
+  void SetUp() override {
+    M = parseAssemblyString(IR, Err, Ctx);
+    ASSERT_TRUE(M) << "Failed to parse the test IR.";
+    // Get the only call instruction in the block, which is the first
+    // instruction.
+    CI = dyn_cast<CallInst>(&*(instructions(M->getFunction("f")).begin()));
+    ASSERT_TRUE(CI) << "The first instruction of @f must be a call.";
+  }
+  const char *IR = "define i32 @f(i32 %a) {\n"
+                   "  %1 = call i32 @g(i32 %a) #0\n"
+                   "  ret i32 %1\n"
+                   "}\n"
+                   "declare i32 @g(i32)\n"
+                   "declare <2 x i32> @custom_vg(<2 x i32>)\n"
+                   "declare <4 x i32> @_ZGVnN4v_g(<4 x i32>)\n"
+                   "declare <8 x i32> @_ZGVnN8v_g(<8 x i32>)\n"
+                   "attributes #0 = { "
+                   "\"vector-function-abi-variant\"=\""
+                   "_ZGVnN2v_g(custom_vg),_ZGVnN4v_g\" }";
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M;
+  CallInst *CI = nullptr;
+  SmallVector<std::string, 8> Mappings;
+};
+
+TEST_F(VFABIAttrTest, Write) {
+  Mappings.push_back("_ZGVnN8v_g");
+  Mappings.push_back("_ZGVnN2v_g(custom_vg)");
+  // The new names replace the ones the attribute carried in the input IR.
+  VFABI::setVectorVariantNames(CI, Mappings);
+  const StringRef S = CI->getAttribute(AttributeList::FunctionIndex,
+                                       "vector-function-abi-variant")
+                          .getValueAsString();
+  EXPECT_EQ(S, "_ZGVnN8v_g,_ZGVnN2v_g(custom_vg)");
+}