diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h --- a/llvm/include/llvm/Analysis/VectorUtils.h +++ b/llvm/include/llvm/Analysis/VectorUtils.h @@ -96,6 +96,12 @@ assert(hasValidParameterList() && "Invalid parameter list"); } + // Retrieve the VFShape that maps a scalar function to itself: VF = 1, not + // scalable, and no global predicate. + static VFShape getScalarShape(const CallInst &CI) { + return VFShape::get(CI, /*EC*/ {1, false}, /*HasGlobalPredicate*/ false); + } + // Retrieve the basic vectorization shape of the function, where all // parameters are mapped to VFParamKind::Vector with \p EC // lanes. Specifies whether the function has a Global Predicate @@ -186,6 +192,8 @@ class VFDatabase { /// The Module of the CallInst CI. const Module *M; + /// The queried CallInst, held by reference: it must outlive this object. + const CallInst &CI; /// List of vector functions descritors associated to the call /// instruction. const SmallVector ScalarToVectorMappings; @@ -233,13 +241,16 @@ /// Constructor, requires a CallInst instance. VFDatabase(CallInst &CI) - : M(CI.getModule()), ScalarToVectorMappings(VFDatabase::getMappings(CI)) { - } + : M(CI.getModule()), CI(CI), + ScalarToVectorMappings(VFDatabase::getMappings(CI)) {} /// \defgroup VFDatabase query interface. /// /// @{ /// Retrieve the Function with VFShape \p Shape. 
Function *getVectorizedFunction(const VFShape &Shape) const { + if (Shape == VFShape::getScalarShape(CI)) + return CI.getCalledFunction(); + for (const auto &Info : ScalarToVectorMappings) if (Info.Shape == Shape) return M->getFunction(Info.VectorName); diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp --- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp @@ -4382,22 +4382,15 @@ if (VF > 1) TysForDecl[0] = VectorType::get(CI->getType()->getScalarType(), VF); VectorF = Intrinsic::getDeclaration(M, ID, TysForDecl); + assert(VectorF && "Can't retrieve vector intrinsic."); } else { // Use vector version of the function call. const VFShape Shape = VFShape::get(*CI, {VF, false} /*EC*/, false /*HasGlobalPred*/); -#ifndef NDEBUG - const SmallVector Infos = VFDatabase::getMappings(*CI); - assert(std::find_if(Infos.begin(), Infos.end(), - [&Shape](const VFInfo &Info) { - return Info.Shape == Shape; - }) != Infos.end() && - "Vector function shape is missing from the database."); -#endif VectorF = VFDatabase(*CI).getVectorizedFunction(Shape); + // Single lookup: the result is both asserted on and used for the call. + assert(VectorF && "Can't create vector function."); } - assert(VectorF && "Can't create vector function."); - SmallVector OpBundles; CI->getOperandBundlesAsDefs(OpBundles); CallInst *V = Builder.CreateCall(VectorF, Args, OpBundles); diff --git a/llvm/test/Transforms/LoopVectorize/vectorizeVFone.ll b/llvm/test/Transforms/LoopVectorize/vectorizeVFone.ll new file mode 100644 --- /dev/null +++ b/llvm/test/Transforms/LoopVectorize/vectorizeVFone.ll @@ -0,0 +1,29 @@ +; RUN: opt < %s -passes=loop-vectorize -S 2>&1 | FileCheck %s + +target triple = "powerpc64le-unknown-linux-gnu" +%type = type { [3 x double] } + +define void @getScalarFunc(double* %A, double* %C, %type* %B) #0 { +; CHECK-LABEL: getScalarFunc +; CHECK-NOT: call fast <{{[0-9]+}} x double> @{{.*}}atan{{.*}}(<{{[0-9]+}} x 
double> %{{[0-9]+}}) +entry: + br label %for.body + +for.body: + %i = phi i64 [ %inc, %for.body ], [ 0, %entry ] + %dummyload2 = load double, double* %A, align 8 + %arrayidx.i24 = getelementptr inbounds %type, %type* %B, i64 %i, i32 0, i32 0 + %_15 = load double, double* %arrayidx.i24, align 8 + %call10 = tail call fast double @atan(double %_15) #1 + %inc = add i64 %i, 1 + %cmp = icmp ugt i64 1000, %inc + br i1 %cmp, label %for.body, label %for.end + +for.end: + ret void +} + +declare double @atan(double) local_unnamed_addr +declare <2 x double> @__atand2_massv(<2 x double>) #1 +attributes #0 = { "target-cpu"="pwr9" } +attributes #1 = { nounwind readnone "vector-function-abi-variant"="_ZGV_LLVM_N2v_atan(__atand2_massv)" } diff --git a/llvm/unittests/Analysis/VectorUtilsTest.cpp b/llvm/unittests/Analysis/VectorUtilsTest.cpp --- a/llvm/unittests/Analysis/VectorUtilsTest.cpp +++ b/llvm/unittests/Analysis/VectorUtilsTest.cpp @@ -536,6 +536,26 @@ EXPECT_EQ(Shape, Expected); } +TEST_F(VFShapeAPITest, API_getScalarShape) { + buildShape(/*VF*/ 1, /*IsScalable*/ false, /*HasGlobalPred*/ false); + EXPECT_EQ(VFShape::getScalarShape(*CI), Shape); +} + +// The scalar shape must resolve to the function "g"; the other VF = 1 shapes +// (scalable and/or globally predicated) must yield nullptr. +TEST_F(VFShapeAPITest, API_getVectorizedFunction) { + VFShape ScalarShape = VFShape::getScalarShape(*CI); + EXPECT_EQ(VFDatabase(*CI).getVectorizedFunction(ScalarShape), + M->getFunction("g")); + + buildShape(/*VF*/ 1, /*IsScalable*/ true, /*HasGlobalPred*/ false); + EXPECT_EQ(VFDatabase(*CI).getVectorizedFunction(Shape), nullptr); + buildShape(/*VF*/ 1, /*IsScalable*/ false, /*HasGlobalPred*/ true); + EXPECT_EQ(VFDatabase(*CI).getVectorizedFunction(Shape), nullptr); + buildShape(/*VF*/ 1, /*IsScalable*/ true, /*HasGlobalPred*/ true); + EXPECT_EQ(VFDatabase(*CI).getVectorizedFunction(Shape), nullptr); +} + TEST_F(VFShapeAPITest, API_updateVFShape) { buildShape(/*VF*/ 2, /*IsScalable*/ false, /*HasGlobalPred*/ false);