Index: llvm/include/llvm/Analysis/SearchVectorFunctionSystem.h =================================================================== --- llvm/include/llvm/Analysis/SearchVectorFunctionSystem.h +++ llvm/include/llvm/Analysis/SearchVectorFunctionSystem.h @@ -16,12 +16,16 @@ #ifndef LLVM_ANALYSIS_SVFS_H #define LLVM_ANALYSIS_SVFS_H +#include "llvm/ADT/APFloat.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/Statistic.h" #include "llvm/ADT/StringRef.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/IR/DerivedTypes.h" #include "llvm/IR/Instructions.h" #include "llvm/IR/Module.h" +#include "llvm/IR/PassManager.h" +#include "llvm/Pass.h" #include "llvm/Support/Compiler.h" #include "llvm/Support/raw_ostream.h" @@ -159,6 +163,11 @@ ScalarName = parseScalarName(MangledName); } + VectorFunctionShape(unsigned VF, bool IsMasked, bool IsScalable, ISAKind ISA, + SmallVector Parameters) + : VF(VF), IsMasked(IsMasked), IsScalable(IsScalable), ISA(ISA), + Parameters(Parameters) {} + /// parses the name of the Vector variant of the function /// \param MangledName - Mangled name to be parsed /// \param Result Vector Function name @@ -204,8 +213,80 @@ void parseParameters(StringRef &MangledName); }; +/// Search Vector Function System Functionality +class SVFS { +public: + /// Creates a table for quick lookup of available vector functions. 
+ /// \param Call - Extract information to populate the table with + /// the vector function signature + void createTableLookupRecord(CallInst *Call); + + /// Checks for available vectorizable function + /// \param Call - Extract information to query the table + /// \returns Available vector function shapes for vectorization + SmallVector + isFunctionVectorizable(CallInst *Call) const; + + /// Returns the user-defined vector function + /// \param Call - Extract information to query the table + /// \param Info - Return user-defined vector function which matches the + /// VectorFunctionShape passed as argument + /// \returns The matching vector function, or nullptr if no record matches + Function *getVectorizedFunction(llvm::CallInst *Call, + VectorFunctionShape Info) const; + + SmallVector getRecordTable() { return RecordTable; } + +private: + // RecordTable stores all the information for querying + // Uses a vector as this table will not have a large number of entries + SmallVector RecordTable; +}; } // end namespace VFABI +/// Analysis pass which captures all the user-defined vector functions +struct SearchVectorFunctionSystem : public ModulePass { + static char ID; // Pass identification, replacement for typeid + + std::unique_ptr QSVFS; + + SearchVectorFunctionSystem() : ModulePass(ID) { + initializeSearchVectorFunctionSystemPass(*PassRegistry::getPassRegistry()); + } + + VFABI::SVFS &getSVFS() { return *QSVFS; } + const VFABI::SVFS &getSVFS() const { return *QSVFS; } + + bool runOnModule(Module &M) override; + + void getAnalysisUsage(AnalysisUsage &AU) const override { + AU.setPreservesAll(); + } +}; + +ModulePass *createSearchVectorFunctionSystemPass(); + +/// An analysis pass based on the new PM to deliver SVFS. +class SVFSAnalysis : public AnalysisInfoMixin { +public: + typedef VFABI::SVFS Result; + + Result run(Module &M, ModuleAnalysisManager &); + +private: + friend AnalysisInfoMixin; + static AnalysisKey Key; +}; + +/// Printer pass that uses \c SVFSAnalysis. 
+class SVFSPrinterPass : public PassInfoMixin { + raw_ostream &OS; + +public: + explicit SVFSPrinterPass(raw_ostream &OS) : OS(OS) {} + PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM); +}; + } // end namespace llvm #endif // LLVM_ANALYSIS_SVFS_H \ No newline at end of file Index: llvm/include/llvm/InitializePasses.h =================================================================== --- llvm/include/llvm/InitializePasses.h +++ llvm/include/llvm/InitializePasses.h @@ -412,6 +412,7 @@ void initializeWriteBitcodePassPass(PassRegistry&); void initializeWriteThinLTOBitcodePass(PassRegistry&); void initializeXRayInstrumentationPass(PassRegistry&); +void initializeSearchVectorFunctionSystemPass(PassRegistry &); } // end namespace llvm Index: llvm/include/llvm/LinkAllPasses.h =================================================================== --- llvm/include/llvm/LinkAllPasses.h +++ llvm/include/llvm/LinkAllPasses.h @@ -32,6 +32,7 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/ScopedNoAliasAA.h" +#include "llvm/Analysis/SearchVectorFunctionSystem.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TypeBasedAliasAnalysis.h" #include "llvm/CodeGen/Passes.h" @@ -221,6 +222,7 @@ (void) llvm::createEliminateAvailableExternallyPass(); (void) llvm::createScalarizeMaskedMemIntrinPass(); (void) llvm::createWarnMissedTransformationsPass(); + (void)llvm::createSearchVectorFunctionSystemPass(); (void)new llvm::IntervalPartition(); (void)new llvm::ScalarEvolutionWrapperPass(); Index: llvm/lib/Analysis/Analysis.cpp =================================================================== --- llvm/lib/Analysis/Analysis.cpp +++ llvm/lib/Analysis/Analysis.cpp @@ -84,6 +84,7 @@ initializeLCSSAVerificationPassPass(Registry); initializeMemorySSAWrapperPassPass(Registry); initializeMemorySSAPrinterLegacyPassPass(Registry); + initializeSearchVectorFunctionSystemPass(Registry); } void 
LLVMInitializeAnalysis(LLVMPassRegistryRef R) { Index: llvm/lib/Analysis/SearchVectorFunctionSystem.cpp =================================================================== --- llvm/lib/Analysis/SearchVectorFunctionSystem.cpp +++ llvm/lib/Analysis/SearchVectorFunctionSystem.cpp @@ -222,3 +222,116 @@ VectorFunctionShape::parseParamList(ParameterList); MangledName = MangledName.drop_front(ParameterList.size()); } + +void SVFS::createTableLookupRecord(CallInst *Call) { + Function *ScalarFunc = Call->getCalledFunction(); + StringRef MangledName = + ScalarFunc->getFnAttribute("vector-function-abi-variant") + .getValueAsString(); + + SmallVector VectorStrings; + + MangledName.split(VectorStrings, ","); + + for (auto VectorString : VectorStrings) { + // Populate a new entry to be recorded + VectorFunctionShape NewRecord(VectorString.trim()); + // Populate the table + this->RecordTable.push_back(NewRecord); + } +} + +SmallVector +SVFS::isFunctionVectorizable(CallInst *Call) const { + SmallVector AvailableVFS; + // Iterate through all records. + for (auto Record : this->RecordTable) { + if (Record.ScalarName == Call->getCalledFunction()->getName()) { + AvailableVFS.push_back(Record); + } + } + return AvailableVFS; +} + +Function *SVFS::getVectorizedFunction(CallInst *Call, + VectorFunctionShape Info) const { + SmallVector Records; + // Iterate through all records + for (auto Record : this->RecordTable) { + if (Record.ScalarName == Call->getCalledFunction()->getName() && + Record == Info) { + Records.push_back(Record); + } + } + + if (Records.empty()) + return nullptr; + + assert(!Records.empty() && "No scalar to vector function mappings available"); + StringRef FnName = VectorFunctionShape::getVectorName(Records[0].VectorName); + + Module *M = Call->getModule(); + // At most 2 functions can have the same VectorFunctionShape. 
+ assert(Records.size() <= 2 && "Invalid Table Entries"); + if (Records.size() == 2) { + StringRef VecName0 = + VectorFunctionShape::getVectorName(Records[0].VectorName); + StringRef VecName1 = + VectorFunctionShape::getVectorName(Records[1].VectorName); + bool IsMangled = VecName0.startswith("_ZGV"); + // Checks to ensure only one entry has a custom name redirection + if (IsMangled) + assert(VecName1.startswith("_ZGV") == false && "Invalid Table Entries"); + else + assert(VecName1.startswith("_ZGV") == true && "Invalid Table Entries"); + FnName = IsMangled ? VecName1 : VecName0; + } + + StringRef FnMangledName; + FnMangledName = + VectorFunctionShape::getPartMangledName(Records[0].VectorName); + // Extract the existing function signature added by the front-end + FunctionType *FnTy = M->getFunction(FnMangledName)->getFunctionType(); + + return cast(M->getOrInsertFunction(FnName, FnTy).getCallee()); +} + +char SearchVectorFunctionSystem::ID = 0; + +INITIALIZE_PASS(SearchVectorFunctionSystem, "svfs", + "Query System for the Vectorizer", false, true) + +ModulePass *llvm::createSearchVectorFunctionSystemPass() { + return new SearchVectorFunctionSystem(); +} + +// SearchVectorFunctionSystem::runOnModule - This is the main Analysis entry point. 
+bool SearchVectorFunctionSystem::runOnModule(Module &M) { + QSVFS.reset(new VFABI::SVFS()); + for (auto &F : M) + for (auto &I : instructions(F)) + if (auto *Call = dyn_cast(&I)) + QSVFS->createTableLookupRecord(Call); + return false; +} + +AnalysisKey SVFSAnalysis::Key; + +VFABI::SVFS SVFSAnalysis::run(Module &M, ModuleAnalysisManager &) { + VFABI::SVFS QSVFS; + for (auto &F : M) + for (auto &I : instructions(F)) + if (auto *Call = dyn_cast(&I)) + QSVFS.createTableLookupRecord(Call); + return QSVFS; +} + +PreservedAnalyses SVFSPrinterPass::run(Module &M, ModuleAnalysisManager &AM) { + VFABI::SVFS &QSVFS = AM.getResult(M); + + OS << "Functions in Record Table " + << "\n"; + for (auto &Record : QSVFS.getRecordTable()) + OS << Record.VectorName; + return PreservedAnalyses::all(); +} \ No newline at end of file Index: llvm/lib/Passes/PassBuilder.cpp =================================================================== --- llvm/lib/Passes/PassBuilder.cpp +++ llvm/lib/Passes/PassBuilder.cpp @@ -47,6 +47,7 @@ #include "llvm/Analysis/ScalarEvolution.h" #include "llvm/Analysis/ScalarEvolutionAliasAnalysis.h" #include "llvm/Analysis/ScopedNoAliasAA.h" +#include "llvm/Analysis/SearchVectorFunctionSystem.h" #include "llvm/Analysis/StackSafetyAnalysis.h" #include "llvm/Analysis/TargetLibraryInfo.h" #include "llvm/Analysis/TargetTransformInfo.h" @@ -141,8 +142,8 @@ #include "llvm/Transforms/Scalar/LowerWidenableCondition.h" #include "llvm/Transforms/Scalar/MakeGuardsExplicit.h" #include "llvm/Transforms/Scalar/MemCpyOptimizer.h" -#include "llvm/Transforms/Scalar/MergedLoadStoreMotion.h" #include "llvm/Transforms/Scalar/MergeICmps.h" +#include "llvm/Transforms/Scalar/MergedLoadStoreMotion.h" #include "llvm/Transforms/Scalar/NaryReassociate.h" #include "llvm/Transforms/Scalar/NewGVN.h" #include "llvm/Transforms/Scalar/PartiallyInlineLibCalls.h" Index: llvm/lib/Passes/PassRegistry.def =================================================================== --- 
llvm/lib/Passes/PassRegistry.def +++ llvm/lib/Passes/PassRegistry.def @@ -28,6 +28,7 @@ MODULE_ANALYSIS("verify", VerifierAnalysis()) MODULE_ANALYSIS("pass-instrumentation", PassInstrumentationAnalysis(PIC)) MODULE_ANALYSIS("asan-globals-md", ASanGlobalsMetadataAnalysis()) +MODULE_ANALYSIS("svfs", SVFSAnalysis()) #ifndef MODULE_ALIAS_ANALYSIS #define MODULE_ALIAS_ANALYSIS(NAME, CREATE_PASS) \ @@ -81,6 +82,7 @@ MODULE_PASS("sample-profile", SampleProfileLoaderPass()) MODULE_PASS("strip-dead-prototypes", StripDeadPrototypesPass()) MODULE_PASS("synthetic-counts-propagation", SyntheticCountsPropagation()) +MODULE_PASS("print-svfs", SVFSPrinterPass(dbgs())) MODULE_PASS("wholeprogramdevirt", WholeProgramDevirtPass(nullptr, nullptr)) MODULE_PASS("verify", VerifierPass()) MODULE_PASS("asan-module", ModuleAddressSanitizerPass(/*CompileKernel=*/false, false, true, false)) Index: llvm/unittests/Analysis/CMakeLists.txt =================================================================== --- llvm/unittests/Analysis/CMakeLists.txt +++ llvm/unittests/Analysis/CMakeLists.txt @@ -28,6 +28,7 @@ PhiValuesTest.cpp ProfileSummaryInfoTest.cpp ScalarEvolutionTest.cpp + SearchVectorFunctionSystemTest.cpp VectorFunctionABITest.cpp SparsePropagation.cpp TargetLibraryInfoTest.cpp Index: llvm/unittests/Analysis/SearchVectorFunctionSystemTest.cpp =================================================================== --- /dev/null +++ llvm/unittests/Analysis/SearchVectorFunctionSystemTest.cpp @@ -0,0 +1,221 @@ +//===------- SearchVectorFunctionSystemTest.cpp - SVFS Unittests ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "llvm/Analysis/SearchVectorFunctionSystem.h" +#include "llvm/AsmParser/Parser.h" +#include "llvm/IR/Dominators.h" +#include "llvm/Support/SourceMgr.h" +#include "llvm/Support/raw_ostream.h" +#include "gtest/gtest.h" + +using namespace llvm; +using namespace VFABI; + +class SearchVectorFunctionSystemTest : public testing::Test { +protected: + SearchVectorFunctionSystemTest() : M(new Module("MyModule", Ctx)) { + FArgTypes.push_back(Type::getInt32Ty(Ctx)); + FArgTypes.push_back(Type::getInt32Ty(Ctx)); + FunctionType *FTy = + FunctionType::get(Type::getInt32Ty(Ctx), FArgTypes, false); + F = Function::Create(FTy, Function::ExternalLinkage, "sin", M.get()); + } + + LLVMContext Ctx; + std::unique_ptr M; + SmallVector FArgTypes; + Function *F; +}; + +TEST_F(SearchVectorFunctionSystemTest, TableCreation) { + SVFS S; + + // Tests the complete functionality of table creation + Value *Args[] = {ConstantInt::get(Type::getInt32Ty(Ctx), 20), + ConstantInt::get(Type::getInt32Ty(Ctx), 20)}; + std::unique_ptr Caller(CallInst::Create(F, Args)); + CallInst *Call = Caller.get(); + F->addFnAttr("vector-function-abi-variant", + "_ZGVnN2vv_sin(another_vector)," + "_ZGVnN8l8v_sin(sinVector),_ZGVnN4vl_sin"); + + S.createTableLookupRecord(Call); + + // Checks for "_ZGVnN2vv_sin(another_vector)" + EXPECT_EQ(S.getRecordTable()[0].VF, (unsigned)2); + EXPECT_FALSE(S.getRecordTable()[0].IsMasked); + EXPECT_EQ(S.getRecordTable()[0].ISA, ISAKind::ISA_AdvancedSIMD); + + EXPECT_TRUE(S.getRecordTable()[0].Parameters[0].isVector()); + EXPECT_EQ(S.getRecordTable()[0].Parameters[0].getParamPos(), (unsigned)0); + + EXPECT_TRUE(S.getRecordTable()[0].Parameters[1].isVector()); + EXPECT_EQ(S.getRecordTable()[0].Parameters[1].getParamPos(), (unsigned)1); + + EXPECT_EQ(S.getRecordTable()[0].ScalarName, "sin"); + EXPECT_EQ(S.getRecordTable()[0].VectorName, 
"_ZGVnN2vv_sin(another_vector)"); + + // Checks for "_ZGVnN8l8v_sin(sinVector)" + EXPECT_EQ(S.getRecordTable()[1].VF, (unsigned)8); + EXPECT_FALSE(S.getRecordTable()[1].IsMasked); + EXPECT_EQ(S.getRecordTable()[1].ISA, ISAKind::ISA_AdvancedSIMD); + + EXPECT_TRUE(S.getRecordTable()[1].Parameters[0].isLinear()); + EXPECT_EQ(S.getRecordTable()[1].Parameters[0].getParamPos(), (unsigned)0); + EXPECT_EQ(S.getRecordTable()[1].Parameters[0].getLinearStepOrPos(), 8); + + EXPECT_TRUE(S.getRecordTable()[1].Parameters[1].isVector()); + EXPECT_EQ(S.getRecordTable()[1].Parameters[1].getParamPos(), (unsigned)1); + + EXPECT_EQ(S.getRecordTable()[1].ScalarName, "sin"); + EXPECT_EQ(S.getRecordTable()[1].VectorName, "_ZGVnN8l8v_sin(sinVector)"); + + // Checks for "_ZGVnN4vl_sin" + EXPECT_EQ(S.getRecordTable()[2].VF, (unsigned)4); + EXPECT_FALSE(S.getRecordTable()[2].IsMasked); + EXPECT_EQ(S.getRecordTable()[2].ISA, ISAKind::ISA_AdvancedSIMD); + + EXPECT_TRUE(S.getRecordTable()[2].Parameters[0].isVector()); + EXPECT_EQ(S.getRecordTable()[2].Parameters[0].getParamPos(), (unsigned)0); + + EXPECT_TRUE(S.getRecordTable()[2].Parameters[1].isLinear()); + EXPECT_EQ(S.getRecordTable()[2].Parameters[1].getParamPos(), (unsigned)1); + EXPECT_EQ(S.getRecordTable()[2].Parameters[1].getLinearStepOrPos(), 1); + + EXPECT_EQ(S.getRecordTable()[2].ScalarName, "sin"); + EXPECT_EQ(S.getRecordTable()[2].VectorName, "_ZGVnN4vl_sin"); +} + +TEST_F(SearchVectorFunctionSystemTest, isFunctionVectorizable) { + SVFS S; + + Value *Args[] = {ConstantInt::get(Type::getInt32Ty(Ctx), 20), + ConstantInt::get(Type::getInt32Ty(Ctx), 20)}; + std::unique_ptr Caller(CallInst::Create(F, Args)); + CallInst *Call = Caller.get(); + + F->addFnAttr("vector-function-abi-variant", + "_ZGVnN2vv_cos(another_vector), _ZGVnN2vv_sin," + " _ZGVnN4vl_sin(sinVector),_ZGVnN4vl_sin"); + + S.createTableLookupRecord(Call); + SmallVector AvailableVFS; + AvailableVFS = S.isFunctionVectorizable(Call); + + // Only three vector variants of 
function "sin" + EXPECT_EQ(AvailableVFS.size(), (unsigned)3); + EXPECT_EQ(AvailableVFS[0], S.getRecordTable()[1]); + EXPECT_EQ(AvailableVFS[1], S.getRecordTable()[2]); + EXPECT_EQ(AvailableVFS[2], S.getRecordTable()[3]); + + const ParamType Param1((unsigned)0, ParameterKind::Vector); + const ParamType Param2((unsigned)1, ParameterKind::Vector); + const ParamType Param3((unsigned)1, ParameterKind::OMP_Linear, 1); + + // Corresponding VFS for _ZGVnN2vv + SmallVector ParamList1{{Param1, Param2}}; + VectorFunctionShape ExpectedVFS1{(unsigned)2 /* VF */, false /* IsMasked */, + false /* IsScalable */, + ISAKind::ISA_AdvancedSIMD, ParamList1}; + + // Corresponding VFS for _ZGVnN4vl_sin + SmallVector ParamList2{{Param1, Param3}}; + VectorFunctionShape ExpectedVFS2{(unsigned)4 /* VF */, false /* IsMasked */, + false /* IsScalable */, + ISAKind::ISA_AdvancedSIMD, ParamList2}; + + EXPECT_EQ(AvailableVFS[0], ExpectedVFS1); + EXPECT_EQ(AvailableVFS[1], ExpectedVFS2); + EXPECT_EQ(AvailableVFS[2], ExpectedVFS2); +} + +TEST_F(SearchVectorFunctionSystemTest, getVectorizedFunction) { + SVFS S; + + std::unique_ptr Entry(BasicBlock::Create(Ctx, "entry", F)); + + Value *Args[] = {ConstantInt::get(Type::getInt32Ty(Ctx), 20), + ConstantInt::get(Type::getInt32Ty(Ctx), 20)}; + std::unique_ptr Caller( + CallInst::Create(F, Args, "sin", Entry.get())); + CallInst *Call = Caller.get(); + + F->addFnAttr("vector-function-abi-variant", + "_ZGVnN2vv_cos(another_vector),_ZGVnN2vv_sin," + "_ZGVnN4vl_sin(sinVector), _ZGVnN4vl_sin," + " _ZGVnN16ll_sin"); + + const ParamType Param1((unsigned)0, ParameterKind::Vector); + const ParamType Param2((unsigned)0, ParameterKind::OMP_Linear, 1); + const ParamType Param3((unsigned)1, ParameterKind::OMP_Linear, 1); + + // Corresponding VFS for _ZGVnN4vl_sin + SmallVector ParamList1{{Param1, Param3}}; + VectorFunctionShape VFS0{(unsigned)4 /* VF */, false /* IsMasked */, + false /* IsScalable */, ISAKind::ISA_AdvancedSIMD, + ParamList1}; + + Type *Tys0[] = 
{VectorType::get(FArgTypes[0], 4), FArgTypes[1]}; + FunctionType *VectorFTy0 = + FunctionType::get(VectorType::get(Type::getInt32Ty(Ctx), 4), Tys0, false); + Function *ExpectedFn0 = Function::Create( + VectorFTy0, Function::ExternalLinkage, "sinVector", M.get()); + + // This function will be inserted by the front-end as specified in RFC + Function *InsVecFn0 = Function::Create(VectorFTy0, Function::ExternalLinkage, + "_ZGVnN4vl_sin", M.get()); + + S.createTableLookupRecord(Call); + + // Checking for "_ZGVnN4vl_sin" + Function *ResultFn0 = S.getVectorizedFunction(Call, VFS0); + // As the given VFS has two possible options, + // we return the one with custom name + EXPECT_EQ(ResultFn0, ExpectedFn0); + + // Corresponding VFS for _ZGVnN16ll_sin + SmallVector ParamList2{{Param2, Param3}}; + VectorFunctionShape VFS1{(unsigned)16 /* VF */, false /* IsMasked */, + false /* IsScalable */, ISAKind::ISA_AdvancedSIMD, + ParamList2}; + + FunctionType *VectorFTy1 = FunctionType::get( + VectorType::get(Type::getInt32Ty(Ctx), 16), FArgTypes, false); + Function *ExpectedFn1 = Function::Create( + VectorFTy1, Function::ExternalLinkage, "_ZGVnN16ll_sin", M.get()); + + // This function will be inserted by the front-end as specified in RFC + Function *InsVecFn1 = Function::Create(VectorFTy1, Function::ExternalLinkage, + "_ZGVnN16ll_sin", M.get()); + + // Checking for "_ZGVnN16ll_sin" + Function *ResultFn1 = S.getVectorizedFunction(Call, VFS1); + // Only one record exists for the given VFS + EXPECT_EQ(ResultFn1, ExpectedFn1); + + // No such mapping available for given VFS (ISA mismatch) + VectorFunctionShape VFS2{(unsigned)16 /* VF */, false /* IsMasked */, + false /* IsScalable */, ISAKind::ISA_SVE, + ParamList2}; + Function *ResultFn2 = S.getVectorizedFunction(Call, VFS2); + + // nullptr returned + EXPECT_FALSE(ResultFn2); + + // No such mapping available for given VFS (VF mismatch) + VectorFunctionShape VFS3{(unsigned)8 /* VF */, false /* IsMasked */, + false /* IsScalable */, 
ISAKind::ISA_AdvancedSIMD, + ParamList2}; + Function *ResultFn3 = S.getVectorizedFunction(Call, VFS3); + + // nullptr returned + EXPECT_FALSE(ResultFn3); + + Call->removeFromParent(); + Entry.get()->removeFromParent(); +}