Index: llvm/include/llvm/Analysis/SearchVectorFunctionSystem.h
===================================================================
--- /dev/null
+++ llvm/include/llvm/Analysis/SearchVectorFunctionSystem.h
@@ -0,0 +1,209 @@
+//===- SearchVectorFunctionSystem.h - Search Vector Function System -------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file
+/// This is the interface used for user provided vector functions with the
+/// vectorizer. The "vector-function-abi-variant" attribute of a scalar
+/// function is demangled into a table of the vector variants it provides.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_ANALYSIS_SEARCHVECTORFUNCTIONSYSTEM_H
+#define LLVM_ANALYSIS_SEARCHVECTORFUNCTIONSYSTEM_H
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/ADT/StringSwitch.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Module.h"
+#include <vector>
+
+namespace llvm {
+
+/// Kind of a vector-variant parameter; each kind has its own Vector Function
+/// ABI token (see ParamType::setParameterKind).
+enum class ParameterKind {
+  Vector,
+  OMP_Linear,
+  OMP_LinearRef,
+  OMP_LinearVal,
+  OMP_LinearUVal,
+  OMP_LinearPos,
+  OMP_LinearValPos,
+  OMP_LinearRefPos,
+  OMP_LinearUValPos,
+  OMP_Uniform
+};
+
+/// Target ISA of a vector variant, encoded by the character that follows the
+/// "_ZGV" prefix of the mangled name (see VFABI::getISA).
+enum class ISAKind {
+  ISA_AdvancedSIMD,
+  ISA_SVE,
+  ISA_SSE,
+  ISA_AVX,
+  ISA_AVX2,
+  ISA_AVX512
+};
+
+/// One parameter of a vector variant: its position in the parameter list, its
+/// kind, and its linear step (or, for the "...Pos" kinds, a parameter
+/// position).
+class ParamType {
+  unsigned ParamPos;
+  ParameterKind ParamKind;
+  int LinearStepOrPos;
+
+public:
+  /// Builds a parameter from its position, its ABI token (e.g. "v", "l",
+  /// "Ls") and its step/position. Used for testing.
+  ParamType(unsigned ParamPos, StringRef ParamKind, int LinearStepOrPos)
+      : ParamPos(ParamPos), LinearStepOrPos(LinearStepOrPos) {
+    setParameterKind(ParamKind);
+  }
+
+  /// Builds a parameter with no step (e.g. ParameterKind::Vector); the
+  /// step/position is set to 0.
+  ParamType(unsigned ParamPos, StringRef ParamKind)
+      : ParamType(ParamPos, ParamKind, 0) {}
+
+  bool operator==(const ParamType &Other) const {
+    return ParamPos == Other.ParamPos && ParamKind == Other.ParamKind &&
+           LinearStepOrPos == Other.LinearStepOrPos;
+  }
+
+  bool isVector() const { return ParamKind == ParameterKind::Vector; }
+  bool isLinear() const { return ParamKind == ParameterKind::OMP_Linear; }
+  bool isLinearVal() const {
+    return ParamKind == ParameterKind::OMP_LinearVal;
+  }
+  bool isLinearRef() const {
+    return ParamKind == ParameterKind::OMP_LinearRef;
+  }
+  bool isLinearUVal() const {
+    return ParamKind == ParameterKind::OMP_LinearUVal;
+  }
+  bool isLinearPos() const {
+    return ParamKind == ParameterKind::OMP_LinearPos;
+  }
+  bool isLinearValPos() const {
+    return ParamKind == ParameterKind::OMP_LinearValPos;
+  }
+  bool isLinearRefPos() const {
+    return ParamKind == ParameterKind::OMP_LinearRefPos;
+  }
+  bool isLinearUValPos() const {
+    return ParamKind == ParameterKind::OMP_LinearUValPos;
+  }
+  bool isUniform() const { return ParamKind == ParameterKind::OMP_Uniform; }
+
+  int getLinearStepOrPos() const { return LinearStepOrPos; }
+  void setLinearStepOrPos(int S) { LinearStepOrPos = S; }
+
+  ParameterKind getParameterKind() const { return ParamKind; }
+
+  /// Sets the kind from its Vector Function ABI token. There is no default
+  /// case, so \p Kind must be one of the tokens listed below.
+  void setParameterKind(StringRef Kind) {
+    ParamKind = StringSwitch<ParameterKind>(Kind)
+                    .Case("v", ParameterKind::Vector)
+                    .Case("l", ParameterKind::OMP_Linear)
+                    .Case("R", ParameterKind::OMP_LinearRef)
+                    .Case("L", ParameterKind::OMP_LinearVal)
+                    .Case("U", ParameterKind::OMP_LinearUVal)
+                    .Case("ls", ParameterKind::OMP_LinearPos)
+                    .Case("Ls", ParameterKind::OMP_LinearValPos)
+                    .Case("Rs", ParameterKind::OMP_LinearRefPos)
+                    .Case("Us", ParameterKind::OMP_LinearUValPos)
+                    .Case("u", ParameterKind::OMP_Uniform);
+  }
+
+  unsigned getParamPos() const { return ParamPos; }
+  void setParamPos(unsigned Position) { ParamPos = Position; }
+};
+
+/// Search Vector Function System: records the vector variants listed in the
+/// "vector-function-abi-variant" attribute of called scalar functions and
+/// looks up which variants exist and which Function implements one of them.
+class SVFS {
+public:
+  /// Module in which vector functions are looked up by name.
+  Module *M;
+
+  explicit SVFS(Module *M) : M(M) {}
+
+  /// Shape of a vector variant, as encoded in its mangled name.
+  struct VectorFunctionShape {
+    unsigned VF; // Vectorization factor; 0 for a scalable ("x") length.
+    bool IsMasked;
+    bool IsScalable;
+    ISAKind ISA;
+    std::vector<ParamType> Parameters;
+
+    bool operator==(const VectorFunctionShape &Other) const {
+      return VF == Other.VF && IsMasked == Other.IsMasked &&
+             IsScalable == Other.IsScalable && ISA == Other.ISA &&
+             Parameters == Other.Parameters;
+    }
+  };
+
+  /// One table entry. ScalarName and VectorName refer into the mangled
+  /// string they were demangled from.
+  struct VectorRecord {
+    VectorFunctionShape VFS;
+    StringRef ScalarName;
+    StringRef VectorName;
+  };
+
+  /// All vector variants recorded so far.
+  std::vector<VectorRecord> RecordTable;
+
+  /// Appends to RecordTable the records demangled from the comma-separated
+  /// "vector-function-abi-variant" attribute of the function called by
+  /// \p Call.
+  ///
+  /// \param Call - Call to the scalar function whose variants are recorded.
+  void createTableLookupRecord(CallInst *Call);
+
+  /// Does name demangling.
+  ///
+  /// \param MangledName - Mangled name for demangling.
+  /// \param Record - Filled with the demangled vector record; left untouched
+  /// when \p MangledName does not start with "_ZGV".
+  void demangleName(StringRef MangledName, SVFS::VectorRecord &Record);
+
+  /// Returns the shapes of all recorded variants of the function called by
+  /// \p Call.
+  std::vector<VectorFunctionShape>
+  isFunctionVectorizable(CallInst *Call) const;
+
+  /// Returns the function in M that implements the variant with shape
+  /// \p Info of the function called by \p Call. When both a mangled and a
+  /// custom name are recorded for that shape, the custom name is used.
+  Function *getVectorizedFunction(llvm::CallInst *Call,
+                                  VectorFunctionShape Info) const;
+};
+
+/// Helpers decoding the fields of a Vector Function ABI mangled name:
+/// _ZGV<isa><mask><vlen><parameters>_<scalar name>[(<vector name>)]
+namespace VFABI {
+bool getABISignature(StringRef MangledName);
+ISAKind getISA(StringRef MangledName);
+bool getIsMasked(StringRef MangledName);
+unsigned getVF(StringRef MangledName);
+void getParameters(StringRef MangledName, SVFS::VectorRecord &NewRecord);
+StringRef getScalarName(StringRef MangledName);
+StringRef getVectorName(StringRef MangledName);
+void parseParamList(StringRef ParameterList,
+                    std::vector<ParamType> &Parameters);
+bool getIsScalable(StringRef MangledName);
+} // end namespace VFABI
+
+} // end namespace llvm
+
+#endif // LLVM_ANALYSIS_SEARCHVECTORFUNCTIONSYSTEM_H
Index: llvm/lib/Analysis/CMakeLists.txt
===================================================================
--- llvm/lib/Analysis/CMakeLists.txt
+++ llvm/lib/Analysis/CMakeLists.txt
@@ -82,6 +82,7 @@
   ScalarEvolutionAliasAnalysis.cpp
   ScalarEvolutionExpander.cpp
   ScalarEvolutionNormalization.cpp
+  SearchVectorFunctionSystem.cpp
   StackSafetyAnalysis.cpp
   SyncDependenceAnalysis.cpp
   SyntheticCountsUtils.cpp
Index: llvm/lib/Analysis/SearchVectorFunctionSystem.cpp
===================================================================
--- /dev/null
+++ llvm/lib/Analysis/SearchVectorFunctionSystem.cpp
@@ -0,0 +1,235 @@
+//===- SearchVectorFunctionSystem.cpp - Search Vector Function System -----===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/SearchVectorFunctionSystem.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/IR/Function.h"
+#include "llvm/Support/Compiler.h"
+
+namespace llvm {
+
+/// Characters that start a parameter token in the parameter list of a
+/// mangled name (see ParamType::setParameterKind).
+static const char ParamTokenChars[] = "vlRLUu";
+
+/// Returns the index of the first parameter token of \p MangledName, or
+/// StringRef::npos if it has no parameter list. The search stops at the '_'
+/// that precedes the scalar name, so letters of the scalar name are never
+/// mistaken for parameter tokens.
+static std::size_t getParamListStart(StringRef MangledName) {
+  StringRef Prefix = MangledName.substr(0, MangledName.find('_', 6));
+  return Prefix.find_first_of(ParamTokenChars, 6);
+}
+
+bool VFABI::getABISignature(StringRef MangledName) {
+  // Every mangled vector variant name starts with "_ZGV".
+  return MangledName.startswith("_ZGV");
+}
+
+ISAKind VFABI::getISA(StringRef MangledName) {
+  // The ISA is the single character that follows "_ZGV".
+  StringRef ISAValue = MangledName.substr(4, 1);
+  assert((ISAValue == "n" || ISAValue == "s" || ISAValue == "b" ||
+          ISAValue == "c" || ISAValue == "d" || ISAValue == "e") &&
+         "Unknown ISA Specified");
+  return StringSwitch<ISAKind>(ISAValue)
+      .Case("n", ISAKind::ISA_AdvancedSIMD)
+      .Case("s", ISAKind::ISA_SVE)
+      .Case("b", ISAKind::ISA_SSE)
+      .Case("c", ISAKind::ISA_AVX)
+      .Case("d", ISAKind::ISA_AVX2)
+      .Case("e", ISAKind::ISA_AVX512);
+}
+
+bool VFABI::getIsMasked(StringRef MangledName) {
+  // SVE variants are always treated as masked.
+  if (VFABI::getISA(MangledName) == ISAKind::ISA_SVE)
+    return true;
+  // Otherwise the character after the ISA is 'N' (unmasked) or 'M' (masked).
+  StringRef Mask = MangledName.substr(5, 1);
+  assert((Mask == "N" || Mask == "M") && "Invalid masking option");
+  return Mask != "N";
+}
+
+unsigned VFABI::getVF(StringRef MangledName) {
+  // The VF runs from index 6 to the first parameter token (or to the '_'
+  // before the scalar name when there are no parameters). It is a decimal
+  // integer, or "x" for a scalable SVE vector length, reported as 0.
+  std::size_t End = getParamListStart(MangledName);
+  if (End == StringRef::npos)
+    End = MangledName.find('_', 6);
+  StringRef VFString = MangledName.slice(6, End);
+  if (VFString == "x") {
+    assert(VFABI::getISA(MangledName) == ISAKind::ISA_SVE &&
+           "Incompatible ISA");
+    return 0;
+  }
+  unsigned VF = 0;
+  bool Failed = VFString.consumeInteger(10, VF);
+  assert(!Failed && "Invalid vectorization factor");
+  (void)Failed;
+  return VF;
+}
+
+bool VFABI::getIsScalable(StringRef MangledName) {
+  // Only SVE variants with a VF of 0 (the "x" vector length) are scalable.
+  bool IsSVE = VFABI::getISA(MangledName) == ISAKind::ISA_SVE;
+  bool IsVFZero = VFABI::getVF(MangledName) == 0;
+  return IsSVE && IsVFZero;
+}
+
+void VFABI::getParameters(StringRef MangledName,
+                          SVFS::VectorRecord &NewRecord) {
+  // The parameter list runs from its first token to the '_' that precedes
+  // the scalar name; a name without parameters yields an empty list.
+  std::size_t Pos = getParamListStart(MangledName);
+  std::size_t NewPos = MangledName.find('_', Pos);
+  VFABI::parseParamList(MangledName.slice(Pos, NewPos),
+                        NewRecord.VFS.Parameters);
+}
+
+StringRef VFABI::getScalarName(StringRef MangledName) {
+  // The scalar name follows the first '_' after the VF and ends at the '('
+  // of an optional custom vector name, or at the end of the string.
+  std::size_t Pos = MangledName.find('_', 6) + 1;
+  std::size_t NewPos = MangledName.find('(', Pos);
+  return MangledName.slice(Pos, NewPos);
+}
+
+StringRef VFABI::getVectorName(StringRef MangledName) {
+  // A custom vector name is given in parentheses after the scalar name;
+  // without one, the variant is named by the mangled name itself.
+  std::size_t Pos = MangledName.find('(', 6);
+  if (Pos == StringRef::npos)
+    return MangledName;
+  return MangledName.slice(Pos + 1, MangledName.find(')', Pos + 1));
+}
+
+void VFABI::parseParamList(StringRef ParamList,
+                           std::vector<ParamType> &Params) {
+  // Each parameter is a kind letter, an optional 's' suffix selecting the
+  // "...Pos" kinds, then an optional step/position running up to the next
+  // kind letter. A negative step is written as 'n' followed by its absolute
+  // value, e.g. linear(val(ParamPos):-3) is "Ln3".
+  unsigned ParamPos = 0;
+  while (!ParamList.empty()) {
+    std::size_t KindLen =
+        (ParamList.size() > 1 && ParamList[1] == 's') ? 2 : 1;
+    StringRef ParamKind = ParamList.take_front(KindLen);
+    ParamList = ParamList.drop_front(KindLen);
+
+    StringRef StepOrPos =
+        ParamList.substr(0, ParamList.find_first_of(ParamTokenChars));
+    ParamList = ParamList.drop_front(StepOrPos.size());
+
+    if (!StepOrPos.empty()) {
+      bool IsNegative = StepOrPos.consume_front("n");
+      int LinearStepOrPos = 0;
+      bool Failed = StepOrPos.consumeInteger(10, LinearStepOrPos);
+      assert(!Failed && "Invalid linear step or position");
+      (void)Failed;
+      if (IsNegative)
+        LinearStepOrPos = -LinearStepOrPos;
+      Params.push_back(ParamType(ParamPos, ParamKind, LinearStepOrPos));
+    } else {
+      ParamType Parameter(ParamPos, ParamKind);
+      // Only the "...Pos" kinds must always carry an explicit position.
+      assert((Parameter.isLinear() || Parameter.isLinearRef() ||
+              Parameter.isLinearVal() || Parameter.isLinearUVal() ||
+              Parameter.isVector() || Parameter.isUniform()) &&
+             "Step Not Specified for given ParameterKind");
+      // Linear kinds default to a step of 1 when none is given.
+      if (Parameter.isLinear() || Parameter.isLinearRef() ||
+          Parameter.isLinearVal() || Parameter.isLinearUVal())
+        Parameter.setLinearStepOrPos(1);
+      Params.push_back(Parameter);
+    }
+    ++ParamPos;
+  }
+}
+
+void SVFS::demangleName(StringRef MangledName, SVFS::VectorRecord &NewRecord) {
+  // Demangles following the Vector Function ABI for AArch64 (Sec 3.5). Names
+  // without the "_ZGV" prefix leave NewRecord untouched.
+  if (!VFABI::getABISignature(MangledName))
+    return;
+  NewRecord.VFS.ISA = VFABI::getISA(MangledName);
+  NewRecord.VFS.IsMasked = VFABI::getIsMasked(MangledName);
+  NewRecord.VFS.VF = VFABI::getVF(MangledName);
+  NewRecord.VFS.IsScalable = VFABI::getIsScalable(MangledName);
+  VFABI::getParameters(MangledName, NewRecord);
+  NewRecord.ScalarName = VFABI::getScalarName(MangledName);
+  NewRecord.VectorName = VFABI::getVectorName(MangledName);
+}
+
+void SVFS::createTableLookupRecord(CallInst *Call) {
+  // Indirect calls have no callee whose attributes could be inspected.
+  Function *ScalarFunc = Call->getCalledFunction();
+  if (!ScalarFunc)
+    return;
+  StringRef MangledNames =
+      ScalarFunc->getFnAttribute("vector-function-abi-variant")
+          .getValueAsString();
+
+  SmallVector<StringRef, 8> VectorStrings;
+  MangledNames.split(VectorStrings, ",");
+  for (StringRef VectorString : VectorStrings) {
+    StringRef MangledName = VectorString.trim();
+    // demangleName leaves the record uninitialized for anything that is not
+    // a mangled vector name (e.g. an empty entry), so such entries are
+    // skipped rather than recorded.
+    if (!VFABI::getABISignature(MangledName))
+      continue;
+    SVFS::VectorRecord NewRecord;
+    SVFS::demangleName(MangledName, NewRecord);
+    RecordTable.push_back(NewRecord);
+  }
+}
+
+std::vector<SVFS::VectorFunctionShape>
+SVFS::isFunctionVectorizable(CallInst *Call) const {
+  std::vector<VectorFunctionShape> AvailableVFS;
+  // Indirect calls have no named callee and hence no recorded variants.
+  const Function *Callee = Call->getCalledFunction();
+  if (!Callee)
+    return AvailableVFS;
+  for (const VectorRecord &Record : RecordTable)
+    if (Record.ScalarName == Callee->getName())
+      AvailableVFS.push_back(Record.VFS);
+  return AvailableVFS;
+}
+
+Function *SVFS::getVectorizedFunction(CallInst *Call,
+                                      SVFS::VectorFunctionShape Info) const {
+  const Function *Callee = Call->getCalledFunction();
+  if (!Callee)
+    return nullptr;
+  SmallVector<const VectorRecord *, 2> Records;
+  for (const VectorRecord &Record : RecordTable)
+    if (Record.ScalarName == Callee->getName() && Record.VFS == Info)
+      Records.push_back(&Record);
+  // No variant with the requested shape has been recorded.
+  if (Records.empty())
+    return nullptr;
+
+  // At most two records share a shape: one named by the mangled name and
+  // one with a custom name, which is preferred.
+  assert(Records.size() <= 2 && "Invalid Table Entries");
+  StringRef FuncName = Records[0]->VectorName;
+  if (Records.size() == 2) {
+    bool FirstIsMangled = VFABI::getABISignature(Records[0]->VectorName);
+    bool SecondIsMangled = VFABI::getABISignature(Records[1]->VectorName);
+    assert(FirstIsMangled != SecondIsMangled && "Invalid Table Entries");
+    (void)SecondIsMangled;
+    if (FirstIsMangled)
+      FuncName = Records[1]->VectorName;
+  }
+  return M->getFunction(FuncName);
+}
+
+} // end namespace llvm
Index: llvm/unittests/Analysis/CMakeLists.txt
===================================================================
--- llvm/unittests/Analysis/CMakeLists.txt
+++ llvm/unittests/Analysis/CMakeLists.txt
@@ -28,6 +28,7 @@
   PhiValuesTest.cpp
   ProfileSummaryInfoTest.cpp
   ScalarEvolutionTest.cpp
+  SearchVectorFunctionSystemTest.cpp
   SparsePropagation.cpp
   TargetLibraryInfoTest.cpp
   TBAATest.cpp
Index: llvm/unittests/Analysis/SearchVectorFunctionSystemTest.cpp
===================================================================
--- /dev/null
+++ llvm/unittests/Analysis/SearchVectorFunctionSystemTest.cpp
@@ -0,0 +1,290 @@
+//===------- SearchVectorFunctionSystemTest.cpp - SVFS Unittests ---------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "llvm/Analysis/SearchVectorFunctionSystem.h"
+#include "llvm/AsmParser/Parser.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/Support/SourceMgr.h"
+#include "llvm/Support/raw_ostream.h"
+#include "gtest/gtest.h"
+
+using namespace llvm;
+
+/// Fixture providing a module that declares a scalar function "sin" of type
+/// i32(i32, i32), to which tests attach "vector-function-abi-variant" lists.
+class SearchVectorFunctionSystemTest : public testing::Test {
+protected:
+  SearchVectorFunctionSystemTest() : M(new Module("MyModule", Ctx)) {
+    FArgTypes.push_back(Type::getInt32Ty(Ctx));
+    FArgTypes.push_back(Type::getInt32Ty(Ctx));
+    FunctionType *FTy =
+        FunctionType::get(Type::getInt32Ty(Ctx), FArgTypes, false);
+    F = Function::Create(FTy, Function::ExternalLinkage, "sin", M.get());
+  }
+
+  LLVMContext Ctx;
+  std::unique_ptr<Module> M;
+  SmallVector<Type *, 2> FArgTypes;
+  Function *F;
+};
+
+// Demangles a mix of custom-named and mangled-only variants into the table.
+TEST_F(SearchVectorFunctionSystemTest, TableCreation) {
+  SVFS S(M.get());
+
+  // Tests the complete functionality of table creation
+  Value *Args[] = {ConstantInt::get(Type::getInt32Ty(Ctx), 20),
+                   ConstantInt::get(Type::getInt32Ty(Ctx), 20)};
+  std::unique_ptr<CallInst> Caller(CallInst::Create(F, Args));
+  CallInst *Call = Caller.get();
+  F->addFnAttr("vector-function-abi-variant", "_ZGVnN2vv_sin(another_vector),"
+               "_ZGVnN8l8v_sin(sinVector),_ZGVnN4vl_sin");
+
+  S.createTableLookupRecord(Call);
+
+  // Checks for "_ZGVnN2vv_sin(another_vector)"
+  EXPECT_EQ(S.RecordTable[0].VFS.VF, (unsigned) 2);
+  EXPECT_FALSE(S.RecordTable[0].VFS.IsMasked);
+  EXPECT_EQ(S.RecordTable[0].VFS.ISA, ISAKind::ISA_AdvancedSIMD);
+
+  EXPECT_TRUE(S.RecordTable[0].VFS.Parameters[0].isVector());
+  EXPECT_EQ(S.RecordTable[0].VFS.Parameters[0].getParamPos(), (unsigned) 0);
+
+  EXPECT_TRUE(S.RecordTable[0].VFS.Parameters[1].isVector());
+  EXPECT_EQ(S.RecordTable[0].VFS.Parameters[1].getParamPos(), (unsigned) 1);
+
+  EXPECT_EQ(S.RecordTable[0].ScalarName, "sin");
+  EXPECT_EQ(S.RecordTable[0].VectorName, "another_vector");
+
+  // Checks for "_ZGVnN8l8v_sin(sinVector)"
+  EXPECT_EQ(S.RecordTable[1].VFS.VF, (unsigned) 8);
+  EXPECT_FALSE(S.RecordTable[1].VFS.IsMasked);
+  EXPECT_EQ(S.RecordTable[1].VFS.ISA, ISAKind::ISA_AdvancedSIMD);
+
+  EXPECT_TRUE(S.RecordTable[1].VFS.Parameters[0].isLinear());
+  EXPECT_EQ(S.RecordTable[1].VFS.Parameters[0].getParamPos(), (unsigned) 0);
+  EXPECT_EQ(S.RecordTable[1].VFS.Parameters[0].getLinearStepOrPos(), 8);
+
+  EXPECT_TRUE(S.RecordTable[1].VFS.Parameters[1].isVector());
+  EXPECT_EQ(S.RecordTable[1].VFS.Parameters[1].getParamPos(), (unsigned) 1);
+
+  EXPECT_EQ(S.RecordTable[1].ScalarName, "sin");
+  EXPECT_EQ(S.RecordTable[1].VectorName, "sinVector");
+
+  // Checks for "_ZGVnN4vl_sin"
+  EXPECT_EQ(S.RecordTable[2].VFS.VF, (unsigned) 4);
+  EXPECT_FALSE(S.RecordTable[2].VFS.IsMasked);
+  EXPECT_EQ(S.RecordTable[2].VFS.ISA, ISAKind::ISA_AdvancedSIMD);
+
+  EXPECT_TRUE(S.RecordTable[2].VFS.Parameters[0].isVector());
+  EXPECT_EQ(S.RecordTable[2].VFS.Parameters[0].getParamPos(), (unsigned) 0);
+
+  EXPECT_TRUE(S.RecordTable[2].VFS.Parameters[1].isLinear());
+  EXPECT_EQ(S.RecordTable[2].VFS.Parameters[1].getParamPos(), (unsigned) 1);
+  EXPECT_EQ(S.RecordTable[2].VFS.Parameters[1].getLinearStepOrPos(), 1);
+
+  EXPECT_EQ(S.RecordTable[2].ScalarName, "sin");
+  EXPECT_EQ(S.RecordTable[2].VectorName, "_ZGVnN4vl_sin");
+}
+
+// Demangles a single mangled-only variant.
+TEST_F(SearchVectorFunctionSystemTest, NameDemangling) {
+  SVFS S(M.get());
+
+  SVFS::VectorRecord Output;
+
+  S.demangleName("_ZGVnN2vl_sin", Output);
+
+  EXPECT_EQ(Output.VFS.VF, (unsigned) 2);
+  EXPECT_FALSE(Output.VFS.IsMasked);
+  EXPECT_EQ(Output.VFS.ISA, ISAKind::ISA_AdvancedSIMD);
+
+  // Does call parseParamList()
+  EXPECT_TRUE(Output.VFS.Parameters[0].isVector());
+  EXPECT_EQ(Output.VFS.Parameters[0].getParamPos(), (unsigned) 0);
+
+  EXPECT_EQ(Output.VFS.Parameters[1].getLinearStepOrPos(), 1);
+  EXPECT_TRUE(Output.VFS.Parameters[1].isLinear());
+  EXPECT_EQ(Output.VFS.Parameters[1].getParamPos(), (unsigned) 1);
+
+  EXPECT_EQ(Output.ScalarName, "sin");
+  EXPECT_EQ(Output.VectorName, "_ZGVnN2vl_sin");
+}
+
+// Parses a parameter list mixing explicit and default steps.
+TEST_F(SearchVectorFunctionSystemTest, ParamListParsing) {
+  std::vector<ParamType> Output;
+
+  // Tokens: "v", "l16", "Ls32", "R3" and "l" (default step 1).
+  VFABI::parseParamList("vl16Ls32R3l", Output);
+
+  EXPECT_EQ(Output[0].getParamPos(), (unsigned) 0);
+  EXPECT_TRUE(Output[0].isVector());
+
+  EXPECT_EQ(Output[1].getParamPos(), (unsigned) 1);
+  EXPECT_TRUE(Output[1].isLinear());
+  EXPECT_EQ(Output[1].getLinearStepOrPos(), 16);
+
+  EXPECT_EQ(Output[2].getParamPos(), (unsigned) 2);
+  EXPECT_TRUE(Output[2].isLinearValPos());
+  EXPECT_EQ(Output[2].getLinearStepOrPos(), 32);
+
+  EXPECT_EQ(Output[3].getParamPos(), (unsigned) 3);
+  EXPECT_TRUE(Output[3].isLinearRef());
+  EXPECT_EQ(Output[3].getLinearStepOrPos(), 3);
+
+  EXPECT_EQ(Output[4].getParamPos(), (unsigned) 4);
+  EXPECT_TRUE(Output[4].isLinear());
+  EXPECT_EQ(Output[4].getLinearStepOrPos(), 1);
+}
+
+// Decodes each field of a mangled name with the VFABI helpers.
+TEST_F(SearchVectorFunctionSystemTest, VFABITest) {
+  // Getting ISA Kind
+  EXPECT_EQ(VFABI::getISA("_ZGVnN2_sin"), ISAKind::ISA_AdvancedSIMD);
+  EXPECT_EQ(VFABI::getISA("_ZGVsN2_sin"), ISAKind::ISA_SVE);
+  EXPECT_EQ(VFABI::getISA("_ZGVbN2_sin"), ISAKind::ISA_SSE);
+  EXPECT_EQ(VFABI::getISA("_ZGVcN2_sin"), ISAKind::ISA_AVX);
+  EXPECT_EQ(VFABI::getISA("_ZGVdN2_sin"), ISAKind::ISA_AVX2);
+  EXPECT_EQ(VFABI::getISA("_ZGVeN2_sin"), ISAKind::ISA_AVX512);
+
+  // Getting Mask Value
+  EXPECT_TRUE(VFABI::getIsMasked("_ZGVnM2v_sin"));
+  EXPECT_FALSE(VFABI::getIsMasked("_ZGVnN2v_sin"));
+  EXPECT_TRUE(VFABI::getIsMasked("_ZGVsN2v_sin"));
+  EXPECT_TRUE(VFABI::getIsMasked("_ZGVsM2v_sin"));
+
+  // Getting Scalable Value
+  EXPECT_TRUE(VFABI::getIsScalable("_ZGVsMxv_sin"));
+  EXPECT_FALSE(VFABI::getIsScalable("_ZGVsM2v_sin"));
+
+  // Getting VF
+  EXPECT_EQ(VFABI::getVF("_ZGVnM2v_sin"), (unsigned) 2);
+  EXPECT_EQ(VFABI::getVF("_ZGVnM22v_sin"), (unsigned) 22);
+  EXPECT_EQ(VFABI::getVF("_ZGVsMxv_sin"), (unsigned) 0);
+
+  // Getting Scalar Name
+  EXPECT_EQ(VFABI::getScalarName("_ZGVnM2v_sin"), "sin");
+  EXPECT_EQ(VFABI::getScalarName("_ZGVnM2v_sin(UserFunc)"), "sin");
+  EXPECT_EQ(VFABI::getScalarName("_ZGVnM2v___sin_sin_sin"), "__sin_sin_sin");
+
+  // Getting Vector Name
+  EXPECT_EQ(VFABI::getVectorName("_ZGVnM2v_sin"), "_ZGVnM2v_sin");
+  EXPECT_EQ(VFABI::getVectorName("_ZGVnM2v_sin(UserFunc)"), "UserFunc");
+}
+
+// Only the variants of the called function ("sin", not "cos") are returned.
+TEST_F(SearchVectorFunctionSystemTest, isFunctionVectorizable) {
+  SVFS S(M.get());
+
+  Value *Args[] = {ConstantInt::get(Type::getInt32Ty(Ctx), 20),
+                   ConstantInt::get(Type::getInt32Ty(Ctx), 20)};
+  std::unique_ptr<CallInst> Caller(CallInst::Create(F, Args));
+  CallInst *Call = Caller.get();
+
+  F->addFnAttr("vector-function-abi-variant",
+               "_ZGVnN2vv_cos(another_vector), _ZGVnN2vv_sin,"
+               " _ZGVnN4vl_sin(sinVector),_ZGVnN4vl_sin");
+
+  S.createTableLookupRecord(Call);
+  std::vector<SVFS::VectorFunctionShape> AvailableVFS;
+  AvailableVFS = S.isFunctionVectorizable(Call);
+
+  EXPECT_EQ(AvailableVFS.size(), (unsigned) 3);
+  EXPECT_EQ(AvailableVFS[0], S.RecordTable[1].VFS);
+  EXPECT_EQ(AvailableVFS[1], S.RecordTable[2].VFS);
+  EXPECT_EQ(AvailableVFS[2], S.RecordTable[3].VFS);
+
+  const ParamType Param1((unsigned) 0, "v");
+  const ParamType Param2((unsigned) 1, "v");
+  const ParamType Param3((unsigned) 1, "l", 1);
+
+  // Corresponding VFS for _ZGVnN2vv
+  std::vector<ParamType> ParamList1{{Param1, Param2}};
+  SVFS::VectorFunctionShape ExpectedVFS1{(unsigned) 2 /* VF */,
+                                         false /* IsMasked */,
+                                         false /* IsScalable */,
+                                         ISAKind::ISA_AdvancedSIMD,
+                                         ParamList1};
+
+  // Corresponding VFS for _ZGVnN4vl_sin
+  std::vector<ParamType> ParamList2{{Param1, Param3}};
+  SVFS::VectorFunctionShape ExpectedVFS2{(unsigned) 4 /* VF */,
+                                         false /* IsMasked */,
+                                         false /* IsScalable */,
+                                         ISAKind::ISA_AdvancedSIMD,
+                                         ParamList2};
+
+  EXPECT_EQ(AvailableVFS[0], ExpectedVFS1);
+  EXPECT_EQ(AvailableVFS[1], ExpectedVFS2);
+  EXPECT_EQ(AvailableVFS[2], ExpectedVFS2);
+}
+
+// Looks up the Function implementing a given vector shape.
+TEST_F(SearchVectorFunctionSystemTest, getVectorizedFunction) {
+  SVFS S(M.get());
+
+  Value *Args[] = {ConstantInt::get(Type::getInt32Ty(Ctx), 20),
+                   ConstantInt::get(Type::getInt32Ty(Ctx), 20)};
+  std::unique_ptr<CallInst> Caller(CallInst::Create(F, Args));
+  CallInst *Call = Caller.get();
+
+  F->addFnAttr("vector-function-abi-variant",
+               "_ZGVnN2vv_cos(another_vector),_ZGVnN2vv_sin,"
+               "_ZGVnN4vl_sin(sinVector), _ZGVnN4vl_sin,"
+               " _ZGVnN16ll_sin");
+
+  const ParamType Param1((unsigned) 0, "v");
+  const ParamType Param2((unsigned) 0, "l", 1);
+  const ParamType Param3((unsigned) 1, "l", 1);
+
+  // Corresponding VFS for _ZGVnN4vl_sin
+  std::vector<ParamType> ParamList1{{Param1, Param3}};
+  SVFS::VectorFunctionShape VFS1{(unsigned) 4 /* VF */,
+                                 false /* IsMasked */,
+                                 false /* IsScalable */,
+                                 ISAKind::ISA_AdvancedSIMD,
+                                 ParamList1};
+
+  // The custom-named "sinVector" variant: <4 x i32> (<4 x i32>, i32).
+  Type* Tys0[] = {VectorType::get(FArgTypes[0], 4),
+                  FArgTypes[1]};
+  FunctionType* VectorFTy0 =
+      FunctionType::get(VectorType::get(Type::getInt32Ty(Ctx), 4),
+                        Tys0, false);
+  Function* VectorFunc0 = Function::Create(VectorFTy0,
+                                           Function::ExternalLinkage,
+                                           "sinVector", M.get());
+
+  S.createTableLookupRecord(Call);
+  // Checking for "_ZGVnN4vl_sin"
+  Function *ResFunc0 = S.getVectorizedFunction(Call, VFS1);
+  // As the given VFS has two possible options,
+  // we return the one with custom name
+  EXPECT_EQ(ResFunc0, VectorFunc0);
+
+  // Corresponding VFS for _ZGVnN16ll_sin
+  std::vector<ParamType> ParamList2{{Param2, Param3}};
+  SVFS::VectorFunctionShape VFS2{(unsigned) 16 /* VF */,
+                                 false /* IsMasked */,
+                                 false /* IsScalable */,
+                                 ISAKind::ISA_AdvancedSIMD,
+                                 ParamList2};
+
+  FunctionType *VectorFTy1 =
+      FunctionType::get(VectorType::get(Type::getInt32Ty(Ctx), 16),
+                        FArgTypes, false);
+  Function *VectorFunc1 = Function::Create(VectorFTy1,
+                                           Function::ExternalLinkage,
+                                           "_ZGVnN16ll_sin", M.get());
+
+  // Checking for "_ZGVnN16ll_sin"
+  Function *ResFunc1 = S.getVectorizedFunction(Call, VFS2);
+  // Only one record exists for the given VFS
+  EXPECT_EQ(ResFunc1, VectorFunc1);
+}