diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -82,13 +82,39 @@
 struct VFShape {
   unsigned VF;     // Vectorization factor.
   bool IsScalable; // True if the function is a scalable function.
-  VFISAKind ISA;   // Instruction Set Architecture.
   SmallVector<VFParameter, 8> Parameters; // List of parameter informations.
   // Comparison operator.
   bool operator==(const VFShape &Other) const {
-    return std::tie(VF, IsScalable, ISA, Parameters) ==
-           std::tie(Other.VF, Other.IsScalable, Other.ISA, Other.Parameters);
+    return std::tie(VF, IsScalable, Parameters) ==
+           std::tie(Other.VF, Other.IsScalable, Other.Parameters);
   }
+  /// Replace the parameter at position \p P.ParamPos with \p P.
+  void updateParam(VFParameter P) {
+    assert(P.ParamPos < Parameters.size() && "Invalid parameter position.");
+    assert((P.ParamKind != VFParamKind::GlobalPredicate ||
+            P.ParamPos == (Parameters.size() - 1)) &&
+           "Global predicate parameter must be the last parameter");
+    Parameters[P.ParamPos] = P;
+  }
+  /// Retrieve a vectorization shape of the function, where all
+  /// parameters are mapped to VFParamKind::Vector with \p VF
+  /// lanes. Specifies whether the function is a scalable vector
+  /// function via \p IsScalable, or if it has a Global Predicate
+  /// argument via \p HasGlobalPred.
+  static VFShape getAllVectorsParams(const CallInst &CI, const unsigned VF,
+                                     const bool IsScalable,
+                                     const bool HasGlobalPred) {
+    SmallVector<VFParameter, 8> Parameters;
+    for (unsigned I = 0; I < CI.arg_size(); ++I)
+      Parameters.push_back(VFParameter({I, VFParamKind::Vector}));
+    if (HasGlobalPred)
+      Parameters.push_back(
+          VFParameter({CI.arg_size(), VFParamKind::GlobalPredicate}));
+
+    return {VF, IsScalable, Parameters};
+  }
+  /// Sanity check on the VFShape: true if its parameters are consistent.
+  bool isValid() const;
 };
 
 /// Holds the VFShape for a specific scalar to vector function mapping.
@@ -96,11 +122,12 @@
   VFShape Shape;        // Classification of the vector function.
   StringRef ScalarName; // Scalar Function Name.
   StringRef VectorName; // Vector Function Name associated to this VFInfo.
+  VFISAKind ISA;        // Instruction Set Architecture.
 
   // Comparison operator.
   bool operator==(const VFInfo &Other) const {
-    return std::tie(Shape, ScalarName, VectorName) ==
-           std::tie(Shape, Other.ScalarName, Other.VectorName);
+    return std::tie(Shape, ScalarName, VectorName, ISA) ==
+           std::tie(Other.Shape, Other.ScalarName, Other.VectorName, Other.ISA);
   }
 };
 
diff --git a/llvm/lib/Analysis/VFABIDemangling.cpp b/llvm/lib/Analysis/VFABIDemangling.cpp
--- a/llvm/lib/Analysis/VFABIDemangling.cpp
+++ b/llvm/lib/Analysis/VFABIDemangling.cpp
@@ -402,8 +402,8 @@
     assert(Parameters.back().ParamKind == VFParamKind::GlobalPredicate &&
            "The global predicate must be the last parameter");
 
-  const VFShape Shape({VF, IsScalable, ISA, Parameters});
-  return VFInfo({Shape, ScalarName, VectorName});
+  const VFShape Shape({VF, IsScalable, Parameters});
+  return VFInfo({Shape, ScalarName, VectorName, ISA});
 }
 
 VFParamKind VFABI::getVFParamKindFromString(const StringRef Token) {
diff --git a/llvm/lib/Analysis/VectorUtils.cpp b/llvm/lib/Analysis/VectorUtils.cpp
--- a/llvm/lib/Analysis/VectorUtils.cpp
+++ b/llvm/lib/Analysis/VectorUtils.cpp
@@ -1182,3 +1182,39 @@
     VariantMappings.push_back(S);
   }
 }
+
+bool VFShape::isValid() const {
+  for (const VFParameter &VFP : Parameters)
+    switch (VFP.ParamKind) {
+    default: // No checks
+      break;
+    case VFParamKind::OMP_LinearRef:
+    case VFParamKind::OMP_LinearVal:
+    case VFParamKind::OMP_LinearUVal:
+      // Compile time linear steps must be non-zero.
+      if (VFP.LinearStepOrPos == 0)
+        return false;
+      break;
+    case VFParamKind::OMP_LinearRefPos:
+    case VFParamKind::OMP_LinearValPos:
+    case VFParamKind::OMP_LinearUValPos:
+      // The runtime linear step must refer to some other parameter in
+      // the signature, so it must be a valid (non-negative) index.
+      if (VFP.LinearStepOrPos < 0 ||
+          VFP.LinearStepOrPos >= static_cast<int>(Parameters.size()))
+        return false;
+      // The linear step parameter must be marked as uniform.
+      if (Parameters[VFP.LinearStepOrPos].ParamKind != VFParamKind::OMP_Uniform)
+        return false;
+      break;
+    case VFParamKind::GlobalPredicate:
+      // The global predicate must be the last one, as required by the
+      // Vector Function ABIs supported by LLVM (x86, AArch64, LLVM
+      // internal).
+      if (VFP.ParamPos != Parameters.size() - 1)
+        return false;
+      break;
+    }
+
+  return true;
+}
diff --git a/llvm/unittests/Analysis/VectorFunctionABITest.cpp b/llvm/unittests/Analysis/VectorFunctionABITest.cpp
--- a/llvm/unittests/Analysis/VectorFunctionABITest.cpp
+++ b/llvm/unittests/Analysis/VectorFunctionABITest.cpp
@@ -89,7 +89,7 @@
 protected:
   // Referencies to the parser output field.
   unsigned &VF = Info.Shape.VF;
-  VFISAKind &ISA = Info.Shape.ISA;
+  VFISAKind &ISA = Info.ISA;
   SmallVector<VFParameter, 8> &Parameters = Info.Shape.Parameters;
   StringRef &ScalarName = Info.ScalarName;
   StringRef &VectorName = Info.VectorName;
diff --git a/llvm/unittests/Analysis/VectorUtilsTest.cpp b/llvm/unittests/Analysis/VectorUtilsTest.cpp
--- a/llvm/unittests/Analysis/VectorUtilsTest.cpp
+++ b/llvm/unittests/Analysis/VectorUtilsTest.cpp
@@ -279,3 +279,136 @@
       "}\n");
   EXPECT_EQ(getSplatValue(A), nullptr);
 }
+
+////////////////////////////////////////////////////////////////////////////////
+// VFShape API
+////////////////////////////////////////////////////////////////////////////////
+
+class VFShapeAPITest : public testing::Test {
+protected:
+  void SetUp() override {
+    M = parseAssemblyString(IR, Err, Ctx);
+    // Get the only call instruction in the block, which is the first
+    // instruction.
+    CI = dyn_cast<CallInst>(&*(instructions(M->getFunction("f")).begin()));
+  }
+  const char *IR = "define i32 @f(i32 %a, i64 %b, double %c) {\n"
+                   " %1 = call i32 @g(i32 %a, i64 %b, double %c)\n"
+                   " ret i32 %1\n"
+                   "}\n"
+                   "declare i32 @g(i32, i64, double)\n";
+  LLVMContext Ctx;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M;
+  CallInst *CI;
+  VFShape Shape;
+  VFShape Expected;
+  void buildShape(unsigned VF, bool IsScalable, bool HasGlobalPred) {
+    Shape = VFShape::getAllVectorsParams(*CI, VF, IsScalable, HasGlobalPred);
+  }
+
+  void check() { EXPECT_EQ(Shape, Expected); }
+};
+
+TEST_F(VFShapeAPITest, API_buildVFShape) {
+  buildShape(/*VF*/ 2, /*IsScalable*/ false, /*HasGlobalPred*/ false);
+  Expected = {/*VF*/ 2, /*IsScalable*/ false, /*Parameters*/ {
+                  {0, VFParamKind::Vector},
+                  {1, VFParamKind::Vector},
+                  {2, VFParamKind::Vector},
+              }};
+  check();
+
+  buildShape(/*VF*/ 4, /*IsScalable*/ false, /*HasGlobalPred*/ true);
+  Expected = {/*VF*/ 4, /*IsScalable*/ false, /*Parameters*/ {
+                  {0, VFParamKind::Vector},
+                  {1, VFParamKind::Vector},
+                  {2, VFParamKind::Vector},
+                  {3, VFParamKind::GlobalPredicate},
+              }};
+  check();
+
+  buildShape(/*VF*/ 16, /*IsScalable*/ true, /*HasGlobalPred*/ false);
+  Expected = {/*VF*/ 16, /*IsScalable*/ true, /*Parameters*/ {
+                  {0, VFParamKind::Vector},
+                  {1, VFParamKind::Vector},
+                  {2, VFParamKind::Vector},
+              }};
+  check();
+}
+
+TEST_F(VFShapeAPITest, API_updateVFShape) {
+
+  buildShape(/*VF*/ 2, /*IsScalable*/ false, /*HasGlobalPred*/ false);
+  Expected = {/*VF*/ 2, /*IsScalable*/ false, /*Parameters*/ {
+                  {0, VFParamKind::Vector},
+                  {1, VFParamKind::Vector},
+                  {2, VFParamKind::Vector},
+              }};
+  check();
+
+  VFParameter VFP = {0 /*Pos*/, VFParamKind::OMP_Linear, 0, Align(4)};
+  Shape.updateParam(VFP);
+  Expected = {/*VF*/ 2, /*IsScalable*/ false, /*Parameters*/ {
+                  {0, VFParamKind::OMP_Linear, 0, Align(4)},
+                  {1, VFParamKind::Vector},
+                  {2, VFParamKind::Vector},
+              }};
+  check();
+
+  VFP = {1 /*Pos*/, VFParamKind::OMP_Uniform};
+  Shape.updateParam(VFP);
+  Expected = {/*VF*/ 2, /*IsScalable*/ false, /*Parameters*/ {
+                  {0, VFParamKind::OMP_Linear, 0, Align(4)},
+                  {1, VFParamKind::OMP_Uniform},
+                  {2, VFParamKind::Vector},
+              }};
+  check();
+
+  VFP = {2 /*Pos*/, VFParamKind::OMP_LinearRefPos, 1};
+  Shape.updateParam(VFP);
+  Expected = {/*VF*/ 2, /*IsScalable*/ false, /*Parameters*/ {
+                  {0, VFParamKind::OMP_Linear, 0, Align(4)},
+                  {1, VFParamKind::OMP_Uniform},
+                  {2, VFParamKind::OMP_LinearRefPos, 1},
+              }};
+  check();
+}
+
+TEST_F(VFShapeAPITest, API_updateVFShape_GlobalPredicate) {
+
+  buildShape(/*VF*/ 2, /*IsScalable*/ true, /*HasGlobalPred*/ true);
+  VFParameter VFP = {1 /*Pos*/, VFParamKind::OMP_Uniform};
+  Shape.updateParam(VFP);
+  Expected = {/*VF*/ 2, /*IsScalable*/ true,
+              /*Parameters*/ {{0, VFParamKind::Vector},
+                              {1, VFParamKind::OMP_Uniform},
+                              {2, VFParamKind::Vector},
+                              {3, VFParamKind::GlobalPredicate}}};
+  check();
+}
+
+TEST_F(VFShapeAPITest, API_isValid) {
+  // Shapes built by getAllVectorsParams are valid.
+  buildShape(/*VF*/ 2, /*IsScalable*/ false, /*HasGlobalPred*/ false);
+  EXPECT_TRUE(Shape.isValid());
+  buildShape(/*VF*/ 2, /*IsScalable*/ false, /*HasGlobalPred*/ true);
+  EXPECT_TRUE(Shape.isValid());
+
+  // Compile time linear steps must be non-zero.
+  buildShape(/*VF*/ 2, /*IsScalable*/ false, /*HasGlobalPred*/ false);
+  Shape.updateParam({0 /*Pos*/, VFParamKind::OMP_LinearVal, 0});
+  EXPECT_FALSE(Shape.isValid());
+  Shape.updateParam({0 /*Pos*/, VFParamKind::OMP_LinearVal, 1});
+  EXPECT_TRUE(Shape.isValid());
+
+  // Runtime linear steps must refer to an in-range, uniform parameter.
+  Shape.updateParam({1 /*Pos*/, VFParamKind::OMP_LinearRefPos, 2});
+  EXPECT_FALSE(Shape.isValid()); // Parameter 2 is not uniform.
+  Shape.updateParam({2 /*Pos*/, VFParamKind::OMP_Uniform});
+  EXPECT_TRUE(Shape.isValid());
+  Shape.updateParam({1 /*Pos*/, VFParamKind::OMP_LinearRefPos, 3});
+  EXPECT_FALSE(Shape.isValid()); // Out of range.
+  Shape.updateParam({1 /*Pos*/, VFParamKind::OMP_LinearRefPos, -1});
+  EXPECT_FALSE(Shape.isValid()); // Negative position.
+}