diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -14,6 +14,7 @@
 #define LLVM_ANALYSIS_VECTORUTILS_H
 
 #include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/Analysis/LoopAccessAnalysis.h"
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/Support/CheckedArithmetic.h"
@@ -121,14 +122,18 @@
 /// * x86 (libmvec): https://sourceware.org/glibc/wiki/libmvec and
 /// https://sourceware.org/glibc/wiki/libmvec?action=AttachFile&do=view&target=VectorABI.txt
 ///
-///
-///
 /// \param MangledName -> input string in the format
 /// _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)].
 Optional<VFInfo> tryDemangleForVFABI(StringRef MangledName);
 
 /// Retrieve the `VFParamKind` from a string token.
 VFParamKind getVFParamKindFromString(const StringRef Token);
+
+/// Populates \p VariantMappings with the Vector Function ABI variant names
+/// listed (comma-separated) in the "vector-function-abi-variant" function
+/// attribute of the CallInst \p CI. Entries already in the set are kept.
+void getVectorVariantNames(CallInst *CI,
+                           SmallSet<std::string, 8> &VariantMappings);
 } // end namespace VFABI
 
 template <typename T> class ArrayRef;
@@ -137,7 +142,6 @@
 template <typename InstTy> class InterleaveGroup;
 class Loop;
 class ScalarEvolution;
-class TargetLibraryInfo;
 class TargetTransformInfo;
 class Type;
 class Value;
diff --git a/llvm/include/llvm/Transforms/Utils/ModuleUtils.h b/llvm/include/llvm/Transforms/Utils/ModuleUtils.h
--- a/llvm/include/llvm/Transforms/Utils/ModuleUtils.h
+++ b/llvm/include/llvm/Transforms/Utils/ModuleUtils.h
@@ -13,6 +13,7 @@
 #ifndef LLVM_TRANSFORMS_UTILS_MODULEUTILS_H
 #define LLVM_TRANSFORMS_UTILS_MODULEUTILS_H
 
+#include "llvm/ADT/SmallSet.h"
 #include "llvm/ADT/StringRef.h"
 #include <utility> // for std::pair
 
@@ -108,6 +109,22 @@
 /// unique identifier for this module, so we return the empty string.
 std::string getUniqueModuleId(Module *M);
 
+class CallInst;
+namespace VFABI {
+
+/// \defgroup VFABIModuleUtils Vector Function ABI (VFABI) Module functions.
+///
+/// Utility functions for VFABI data that can modify the module.
+///
+/// @{
+/// Overwrite the "vector-function-abi-variant" attribute of \p CI with the
+/// comma-separated names provided in \p VariantMappings.
+void setVectorVariantNames(CallInst *CI,
+                           const SmallSet<std::string, 8> &VariantMappings);
+
+/// @}
+} // End VFABI namespace
+
 } // End llvm namespace
 
 #endif // LLVM_TRANSFORMS_UTILS_MODULEUTILS_H
diff --git a/llvm/lib/Analysis/VFABIDemangling.cpp b/llvm/lib/Analysis/VFABIDemangling.cpp
--- a/llvm/lib/Analysis/VFABIDemangling.cpp
+++ b/llvm/lib/Analysis/VFABIDemangling.cpp
@@ -338,7 +338,7 @@
     }
   } while (ParamFound == ParseRet::OK);
 
-  // A valid MangledName mus have at least one valid entry in the
+  // A valid MangledName must have at least one valid entry in the
   // <parameters>.
   if (Parameters.empty())
     return None;
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -1159,3 +1159,31 @@
   propagateMetadata(NewInst, VL);
 }
 }
+
+void VFABI::getVectorVariantNames(CallInst *CI,
+                                  SmallSet<std::string, 8> &VariantMappings) {
+  AttributeList Attrs = CI->getAttributes();
+  AttributeSet FnAttrs = Attrs.getFnAttributes();
+  const StringRef S =
+      FnAttrs.getAttribute("vector-function-abi-variant").getValueAsString();
+  // A call without the attribute reads back as an empty string; splitting
+  // that would yield one empty, invalid variant name. Nothing to add.
+  if (S.empty())
+    return;
+
+  SmallVector<StringRef, 8> ListAttr;
+  S.split(ListAttr, ",");
+
+#ifndef NDEBUG
+  Module *M = CI->getModule();
+#endif
+  for (StringRef VariantName : ListAttr) {
+#ifndef NDEBUG
+    Optional<VFInfo> Info = VFABI::tryDemangleForVFABI(VariantName);
+    assert(Info.hasValue() && "Invalid name for a VFABI variant.");
+    assert(M->getFunction(Info.getValue().VectorName) &&
+           "Vector function is missing.");
+#endif
+    VariantMappings.insert(VariantName.str());
+  }
+}
diff --git a/llvm/lib/Transforms/Utils/ModuleUtils.cpp b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
--- a/llvm/lib/Transforms/Utils/ModuleUtils.cpp
+++ b/llvm/lib/Transforms/Utils/ModuleUtils.cpp
@@ -11,6 +11,7 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Transforms/Utils/ModuleUtils.h"
+#include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/DerivedTypes.h"
 #include "llvm/IR/Function.h"
 #include "llvm/IR/IRBuilder.h"
@@ -280,3 +281,35 @@
   MD5::stringifyResult(R, Str);
   return ("$" + Str).str();
 }
+
+void VFABI::setVectorVariantNames(
+    CallInst *CI, const SmallSet<std::string, 8> &VariantMappings) {
+  assert(CI && "Invalid CallInst");
+#ifndef NDEBUG
+  // Every name must be a valid VFABI mangled name whose vector function is
+  // already declared in the module that contains the call.
+  Module *M = CI->getModule();
+  for (const std::string &VariantMapping : VariantMappings) {
+    Optional<VFInfo> VI = VFABI::tryDemangleForVFABI(VariantMapping);
+    assert(VI.hasValue() && "Cannot add an invalid VFABI name.");
+    assert(M->getNamedValue(VI.getValue().VectorName) &&
+           "Cannot add variant to attribute: "
+           "vector function declaration is missing.");
+  }
+#endif
+
+  // Join the names with ',' in the iteration order of the set.
+  SmallString<256> Buffer;
+  bool First = true;
+  for (const std::string &VariantMapping : VariantMappings) {
+    if (!First)
+      Buffer += ',';
+    Buffer += VariantMapping;
+    First = false;
+  }
+
+  AttributeList Attrs = CI->getAttributes();
+  Attrs = Attrs.addAttribute(CI->getContext(), AttributeList::FunctionIndex,
+                             "vector-function-abi-variant", Buffer.str());
+  CI->setAttributes(Attrs);
+}
diff --git a/llvm/unittests/Analysis/VectorFunctionABITest.cpp b/llvm/unittests/Analysis/VectorFunctionABITest.cpp
--- a/llvm/unittests/Analysis/VectorFunctionABITest.cpp
+++ b/llvm/unittests/Analysis/VectorFunctionABITest.cpp
@@ -7,6 +7,10 @@
 //===----------------------------------------------------------------------===//
 
 #include "llvm/Analysis/VectorUtils.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/SourceMgr.h"
 #include "gtest/gtest.h"
 
 using namespace llvm;
@@ -437,3 +441,44 @@
   EXPECT_EQ(Parameters[1], VFParameter({1, VFParamKind::GlobalPredicate}));
   EXPECT_EQ(ScalarName, "sin");
 }
+
+// Helper for the attribute reader test: parse \p IR into a Module.
+static std::unique_ptr<Module> parseIR(const char *IR) {
+  // We just use a static context here. This is never called from multiple
+  // threads so it is harmless no matter how it is implemented. We just need
+  // the context to outlive the module which it does.
+  static LLVMContext C;
+  SMDiagnostic Err;
+  return parseAssemblyString(IR, Err, C);
+}
+
+class VFABIAttrTest : public testing::Test {
+protected:
+  VFABIAttrTest()
+      : M(parseIR("define i32 @f(i32 %a) {\n"
+                  "  %1 = call i32 @g(i32 %a) #0\n"
+                  "  ret i32 %1\n"
+                  "}\n"
+                  "declare i32 @g(i32)\n"
+                  "declare <2 x i32> @custom_vg(<2 x i32>)\n"
+                  "declare <4 x i32> @_ZGVnN4v_g(<4 x i32>)\n"
+                  "declare <8 x i32> @_ZGVnN8v_g(<8 x i32>)\n"
+                  "attributes #0 = { "
+                  "\"vector-function-abi-variant\"=\""
+                  "_ZGVnN2v_g(custom_vg),_ZGVnN4v_g\" }")) {
+    // The first instruction of @f is the annotated call.
+    CI = cast<CallInst>(&*instructions(M->getFunction("f")).begin());
+  }
+
+  std::unique_ptr<Module> M;
+  CallInst *CI = nullptr;
+  SmallSet<std::string, 8> Mappings;
+};
+
+TEST_F(VFABIAttrTest, Read) {
+  VFABI::getVectorVariantNames(CI, Mappings);
+  SmallSet<std::string, 8> Exp;
+  Exp.insert("_ZGVnN2v_g(custom_vg)");
+  Exp.insert("_ZGVnN4v_g");
+  EXPECT_EQ(Mappings, Exp);
+}
diff --git a/llvm/unittests/Transforms/Utils/CMakeLists.txt b/llvm/unittests/Transforms/Utils/CMakeLists.txt
--- a/llvm/unittests/Transforms/Utils/CMakeLists.txt
+++ b/llvm/unittests/Transforms/Utils/CMakeLists.txt
@@ -18,4 +18,5 @@
   SSAUpdaterBulkTest.cpp
   UnrollLoopTest.cpp
   ValueMapperTest.cpp
+  VFABIUtils.cpp
   )
diff --git a/llvm/unittests/Transforms/Utils/VFABIUtils.cpp b/llvm/unittests/Transforms/Utils/VFABIUtils.cpp
new file mode 100644
--- /dev/null
+++ b/llvm/unittests/Transforms/Utils/VFABIUtils.cpp
@@ -0,0 +1,61 @@
+//===------- VFABIUtils.cpp - VFABI Unittests ----------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/InstIterator.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Transforms/Utils/ModuleUtils.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+// Helper for the attribute writer test: parse \p IR into a Module.
+static std::unique_ptr<Module> parseIR(const char *IR) {
+  // We just use a static context here. This is never called from multiple
+  // threads so it is harmless no matter how it is implemented. We just need
+  // the context to outlive the module which it does.
+  static LLVMContext C;
+  SMDiagnostic Err;
+  return parseAssemblyString(IR, Err, C);
+}
+
+class VFABIAttrTest : public testing::Test {
+protected:
+  VFABIAttrTest()
+      : M(parseIR("define i32 @f(i32 %a) {\n"
+                  "  %1 = call i32 @g(i32 %a) #0\n"
+                  "  ret i32 %1\n"
+                  "}\n"
+                  "declare i32 @g(i32)\n"
+                  "declare <2 x i32> @custom_vg(<2 x i32>)\n"
+                  "declare <4 x i32> @_ZGVnN4v_g(<4 x i32>)\n"
+                  "declare <8 x i32> @_ZGVnN8v_g(<8 x i32>)\n"
+                  "attributes #0 = { "
+                  "\"vector-function-abi-variant\"=\""
+                  "_ZGVnN2v_g(custom_vg),_ZGVnN4v_g\" }")) {
+    // The first instruction of @f is the annotated call.
+    CI = cast<CallInst>(&*instructions(M->getFunction("f")).begin());
+  }
+
+  std::unique_ptr<Module> M;
+  CallInst *CI = nullptr;
+  SmallSet<std::string, 8> Mappings;
+};
+
+TEST_F(VFABIAttrTest, Write) {
+  Mappings.insert("_ZGVnN8v_g");
+  Mappings.insert("_ZGVnN2v_g(custom_vg)");
+  VFABI::setVectorVariantNames(CI, Mappings);
+  const AttributeList Attrs = CI->getAttributes();
+  const AttributeSet FnAttrs = Attrs.getFnAttributes();
+  const StringRef S =
+      FnAttrs.getAttribute("vector-function-abi-variant").getValueAsString();
+  EXPECT_EQ(S, "_ZGVnN8v_g,_ZGVnN2v_g(custom_vg)");
+}