diff --git a/llvm/include/llvm/Analysis/VecFuncs.def b/llvm/include/llvm/Analysis/VecFuncs.def
--- a/llvm/include/llvm/Analysis/VecFuncs.def
+++ b/llvm/include/llvm/Analysis/VecFuncs.def
@@ -644,6 +644,9 @@
 TLI_DEFINE_VECFUNC("exp10", "_ZGVsMxv_exp10", SCALABLE(2), MASKED)
 TLI_DEFINE_VECFUNC("exp10f", "_ZGVsMxv_exp10f", SCALABLE(4), MASKED)
 
+TLI_DEFINE_VECFUNC("fmod", "_ZGVsMxvv_fmod", SCALABLE(2), MASKED)
+TLI_DEFINE_VECFUNC("fmodf", "_ZGVsMxvv_fmodf", SCALABLE(4), MASKED)
+
 TLI_DEFINE_VECFUNC("lgamma", "_ZGVsMxv_lgamma", SCALABLE(2), MASKED)
 TLI_DEFINE_VECFUNC("lgammaf", "_ZGVsMxv_lgammaf", SCALABLE(4), MASKED)
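For context, the mangled names follow the usual `_ZGV<isa><mask><len><params>` vector-ABI scheme: `s` selects SVE, `Mx` a masked routine with scalable length, and `vv` two vector parameters. Together with SCALABLE(2)/SCALABLE(4) and MASKED, the new mappings imply vector declarations of roughly the following shape (a sketch for orientation only, not part of the patch; the trailing predicate operand is the mask):

```llvm
declare <vscale x 2 x double> @_ZGVsMxvv_fmod(<vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x i1>)
declare <vscale x 4 x float> @_ZGVsMxvv_fmodf(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x i1>)
```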
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -38,6 +38,110 @@
 STATISTIC(NumFuncUsedAdded,
           "Number of functions added to `llvm.compiler.used`");
 
+static void replaceWithNewCallInst(Instruction &I, CallInst *Replacement,
+                                   const StringRef OldName,
+                                   const StringRef TLIName) {
+  I.replaceAllUsesWith(Replacement);
+  if (isa<FPMathOperator>(Replacement)) {
+    // Preserve fast math flags for FP math.
+    Replacement->copyFastMathFlags(&I);
+  }
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << OldName
+                    << "` with call to `" << TLIName << "`.\n");
+  ++NumCallsReplaced;
+}
+
+static void addFunctionToCompilerUsed(Module *M, Function *TLIFunc,
+                                      const StringRef TLIName) {
+  ++NumTLIFuncDeclAdded;
+
+  // Add the freshly created function to llvm.compiler.used,
+  // similar to how it is done in InjectTLIMappings.
+  appendToCompilerUsed(*M, {TLIFunc});
+
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
+                    << "` to `@llvm.compiler.used`.\n");
+  ++NumFuncUsedAdded;
+}
+
+static bool replaceFremWithTLIFunction(Instruction &I, const StringRef TLIName,
+                                       ElementCount NumElements,
+                                       Type *ElementType, bool Masked = false) {
+  Module *M = I.getModule();
+  IRBuilder<> IRBuilder(&I);
+
+  // Check if the vector library function is already declared in this module,
+  // otherwise insert it.
+  Function *TLIFunc = M->getFunction(TLIName);
+  if (!TLIFunc) {
+    FunctionType *FTy = nullptr;
+    Type *RetTy = I.getType();
+    SmallVector<Type *, 3> Tys = {RetTy, RetTy};
+    if (Masked)
+      Tys.push_back(ToVectorTy(IRBuilder.getInt1Ty(), NumElements));
+    FTy = FunctionType::get(RetTy, Tys, false);
+    TLIFunc = Function::Create(FTy, Function::ExternalLinkage, TLIName, *M);
+
+    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
+                      << TLIName << "` of type `" << *(TLIFunc->getType())
+                      << "` to module.\n");
+    addFunctionToCompilerUsed(M, TLIFunc, TLIName);
+  }
+
+  // Replace the frem instruction with a call to the corresponding function
+  // from the vector library.
+  SmallVector<Value *> Args(I.operand_values());
+  if (Masked) {
+    Value *AllActiveMask = ConstantInt::getTrue(VectorType::get(
+        IntegerType::getInt1Ty(TLIFunc->getType()->getContext()),
+        NumElements));
+    Args.push_back(AllActiveMask);
+  }
+  CallInst *Replacement = IRBuilder.CreateCall(TLIFunc, Args);
+  replaceWithNewCallInst(I, Replacement, I.getOpcodeName(), TLIName);
+
+  return true;
+}
+
+static bool replaceFremWithCallToVeclib(const TargetLibraryInfo &TLI,
+                                        Instruction &I) {
+  auto *VectorArgTy = dyn_cast<ScalableVectorType>(I.getType());
+  if (!VectorArgTy) {
+    // We have TLI mappings for FREM on scalable vectors only.
+    return false;
+  }
+  ElementCount NumElements = VectorArgTy->getElementCount();
+  Type *ElementType = VectorArgTy->getElementType();
+  StringRef ScalarName =
+      (ElementType->isFloatTy())
+          ? TLI.getName(LibFunc_fmodf)
+          : ((ElementType->isDoubleTy()) ? TLI.getName(LibFunc_fmod) : "");
+  if (ScalarName.empty())
+    return false;
+  if (!TLI.isFunctionVectorizable(ScalarName)) {
+    // The TargetLibraryInfo does not contain a vectorized version of
+    // the scalar function.
+    return false;
+  }
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Looking up TLI mapping for `"
+                    << ScalarName << "` and vector width " << NumElements
+                    << ".\n");
+  // Try an unmasked mapping first, then fall back to a masked one.
+  std::string TLIName =
+      std::string(TLI.getVectorizedFunction(ScalarName, NumElements));
+  if (!TLIName.empty()) {
+    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found unmasked TLI function `"
+                      << TLIName << "`.\n");
+    return replaceFremWithTLIFunction(I, TLIName, NumElements, ElementType);
+  }
+  TLIName =
+      std::string(TLI.getVectorizedFunction(ScalarName, NumElements, true));
+  if (!TLIName.empty()) {
+    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found masked TLI function `"
+                      << TLIName << "`.\n");
+    return replaceFremWithTLIFunction(I, TLIName, NumElements, ElementType,
+                                      true);
+  }
+  return false;
+}
+
 static bool replaceWithTLIFunction(CallInst &CI, const StringRef TLIName) {
   Module *M = CI.getModule();
 
@@ -55,15 +159,7 @@
                       << TLIName << "` of type `" << *(TLIFunc->getType())
                       << "` to module.\n");
 
-    ++NumTLIFuncDeclAdded;
-
-    // Add the freshly created function to llvm.compiler.used,
-    // similar to as it is done in InjectTLIMappings
-    appendToCompilerUsed(*M, {TLIFunc});
-
-    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
-                      << "` to `@llvm.compiler.used`.\n");
-    ++NumFuncUsedAdded;
+    addFunctionToCompilerUsed(M, TLIFunc, TLIName);
   }
 
   // Replace the call to the vector intrinsic with a call
@@ -76,26 +172,25 @@
   CallInst *Replacement = IRBuilder.CreateCall(TLIFunc, Args, OpBundles);
   assert(OldFunc->getFunctionType() == TLIFunc->getFunctionType() &&
          "Expecting function types to be identical");
-  CI.replaceAllUsesWith(Replacement);
-  if (isa<FPMathOperator>(Replacement)) {
-    // Preserve fast math flags for FP math.
-    Replacement->copyFastMathFlags(&CI);
-  }
+  replaceWithNewCallInst(CI, Replacement, OldFunc->getName(), TLIName);
 
-  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
-                    << OldFunc->getName() << "` with call to `" << TLIName
-                    << "`.\n");
-  ++NumCallsReplaced;
   return true;
 }
 
 static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
-                                    CallInst &CI) {
-  if (!CI.getCalledFunction()) {
-    return false;
+                                    Instruction &I) {
+  if (I.getOpcode() == Instruction::FRem) {
+    // Replacement can be performed for the FRem instruction.
+    return replaceFremWithCallToVeclib(TLI, I);
   }
+  CallInst *CI = dyn_cast<CallInst>(&I);
+  if (!CI)
+    return false;
+  if (!CI->getCalledFunction())
+    return false;
 
-  auto IntrinsicID = CI.getCalledFunction()->getIntrinsicID();
+  auto IntrinsicID = CI->getCalledFunction()->getIntrinsicID();
   if (IntrinsicID == Intrinsic::not_intrinsic) {
     // Replacement is only performed for intrinsic functions
     return false;
@@ -105,7 +200,7 @@
   // all vector operands have identical vector width.
   ElementCount VF = ElementCount::getFixed(0);
   SmallVector<Type *> ScalarTypes;
-  for (auto Arg : enumerate(CI.args())) {
+  for (auto Arg : enumerate(CI->args())) {
     auto *ArgType = Arg.value()->getType();
     // Vector calls to intrinsics can still have
     // scalar operands for specific arguments.
@@ -141,7 +236,7 @@
   // converted to scalar above.
   std::string ScalarName;
   if (Intrinsic::isOverloaded(IntrinsicID)) {
-    ScalarName = Intrinsic::getName(IntrinsicID, ScalarTypes, CI.getModule());
+    ScalarName = Intrinsic::getName(IntrinsicID, ScalarTypes, CI->getModule());
  } else {
     ScalarName = Intrinsic::getName(IntrinsicID).str();
   }
@@ -167,7 +262,7 @@
   // the vector library function.
   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI function `" << TLIName
                     << "`.\n");
-  return replaceWithTLIFunction(CI, TLIName);
+  return replaceWithTLIFunction(*CI, TLIName);
   }
 
   return false;
 }
@@ -175,19 +270,17 @@
 static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
   bool Changed = false;
-  SmallVector<CallInst *> ReplacedCalls;
+  SmallVector<Instruction *> ReplacedCalls;
   for (auto &I : instructions(F)) {
-    if (auto *CI = dyn_cast<CallInst>(&I)) {
-      if (replaceWithCallToVeclib(TLI, *CI)) {
-        ReplacedCalls.push_back(CI);
-        Changed = true;
-      }
+    if (replaceWithCallToVeclib(TLI, I)) {
+      ReplacedCalls.push_back(&I);
+      Changed = true;
     }
   }
-  // Erase the calls to the intrinsics that have been replaced
-  // with calls to the vector library.
-  for (auto *CI : ReplacedCalls) {
-    CI->eraseFromParent();
+  // Erase the calls to the intrinsics and the frem instructions that have
+  // been replaced with calls to the vector library.
+  for (auto *I : ReplacedCalls) {
+    I->eraseFromParent();
   }
   return Changed;
 }
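Net effect on a scalable frem, mirrored by the tests below: the instruction becomes a call to the masked TLI mapping with an all-active predicate appended as the last operand, and fast-math flags are carried over by replaceWithNewCallInst. Schematically (a sketch; value names are illustrative):

```llvm
; Before: a scalable frem carrying fast-math flags.
  %out = frem fast <vscale x 2 x double> %in1, %in2

; After: a call to the masked mapping; the final operand is the all-active
; predicate, i.e. a splat of i1 true in insertelement/shufflevector form.
  %out = call fast <vscale x 2 x double> @_ZGVsMxvv_fmod(<vscale x 2 x double> %in1, <vscale x 2 x double> %in2, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer))
```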
diff --git a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll
--- a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll
+++ b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll
@@ -377,4 +377,26 @@
   ret <vscale x 4 x float> %1
 }
 
+; NOTE: TLI mappings for the FREM instruction.
+
+define <vscale x 2 x double> @frem_vscale_f64(<vscale x 2 x double> %in1, <vscale x 2 x double> %in2) #0 {
+; CHECK-LABEL: define <vscale x 2 x double> @frem_vscale_f64
+; CHECK-SAME: (<vscale x 2 x double> [[IN1:%.*]], <vscale x 2 x double> [[IN2:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 2 x double> @armpl_svfmod_f64_x(<vscale x 2 x double> [[IN1]], <vscale x 2 x double> [[IN2]], <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 2 x double> [[TMP1]]
+;
+  %out = frem fast <vscale x 2 x double> %in1, %in2
+  ret <vscale x 2 x double> %out
+}
+
+define <vscale x 4 x float> @frem_vscale_f32(<vscale x 4 x float> %in1, <vscale x 4 x float> %in2) #0 {
+; CHECK-LABEL: define <vscale x 4 x float> @frem_vscale_f32
+; CHECK-SAME: (<vscale x 4 x float> [[IN1:%.*]], <vscale x 4 x float> [[IN2:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 4 x float> @armpl_svfmod_f32_x(<vscale x 4 x float> [[IN1]], <vscale x 4 x float> [[IN2]], <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 4 x float> [[TMP1]]
+;
+  %out = frem fast <vscale x 4 x float> %in1, %in2
+  ret <vscale x 4 x float> %out
+}
+
 attributes #0 = { "target-features"="+sve" }
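For reference, the ArmPL names checked above follow the `armpl_sv<func>_<type>_x` pattern, where the `_x` suffix marks the predicated variant. The calls imply declarations of roughly this shape (a sketch for orientation only, not part of the patch), matching the signature that replaceFremWithTLIFunction builds:

```llvm
declare <vscale x 2 x double> @armpl_svfmod_f64_x(<vscale x 2 x double>, <vscale x 2 x double>, <vscale x 2 x i1>)
declare <vscale x 4 x float> @armpl_svfmod_f32_x(<vscale x 4 x float>, <vscale x 4 x float>, <vscale x 4 x i1>)
```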
diff --git a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll
--- a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll
+++ b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll
@@ -365,6 +365,26 @@
   ret <vscale x 4 x float> %1
 }
 
+; NOTE: TLI mappings for the FREM instruction.
+
+define <vscale x 2 x double> @frem_vscale_f64(<vscale x 2 x double> %in1, <vscale x 2 x double> %in2) {
+; CHECK-LABEL: @frem_vscale_f64(
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 2 x double> @_ZGVsMxvv_fmod(<vscale x 2 x double> [[IN1:%.*]], <vscale x 2 x double> [[IN2:%.*]], <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 2 x double> [[TMP1]]
+;
+  %out = frem fast <vscale x 2 x double> %in1, %in2
+  ret <vscale x 2 x double> %out
+}
+
+define <vscale x 4 x float> @frem_vscale_f32(<vscale x 4 x float> %in1, <vscale x 4 x float> %in2) {
+; CHECK-LABEL: @frem_vscale_f32(
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 4 x float> @_ZGVsMxvv_fmodf(<vscale x 4 x float> [[IN1:%.*]], <vscale x 4 x float> [[IN2:%.*]], <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 4 x float> [[TMP1]]
+;
+  %out = frem fast <vscale x 4 x float> %in1, %in2
+  ret <vscale x 4 x float> %out
+}
+
 declare <vscale x 2 x double> @llvm.ceil.nxv2f64(<vscale x 2 x double>)
 declare <vscale x 4 x float> @llvm.ceil.nxv4f32(<vscale x 4 x float>)
 declare <vscale x 2 x double> @llvm.copysign.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>)

diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/sleef-calls-aarch64.ll b/llvm/test/Transforms/LoopVectorize/AArch64/sleef-calls-aarch64.ll
--- a/llvm/test/Transforms/LoopVectorize/AArch64/sleef-calls-aarch64.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/sleef-calls-aarch64.ll
@@ -1,6 +1,6 @@
 ; Do NOT use -O3. It will lower exp2 to ldexp, and the test will fail.
-; RUN: opt -vector-library=sleefgnuabi -replace-with-veclib < %s | opt -vector-library=sleefgnuabi -passes=inject-tli-mappings,loop-unroll,loop-vectorize -S | FileCheck %s --check-prefixes=CHECK,NEON
-; RUN: opt -mattr=+sve -vector-library=sleefgnuabi -replace-with-veclib < %s | opt -vector-library=sleefgnuabi -passes=inject-tli-mappings,loop-unroll,loop-vectorize -S | FileCheck %s --check-prefixes=CHECK,SVE
+; RUN: opt -vector-library=sleefgnuabi -passes=inject-tli-mappings,loop-unroll,loop-vectorize -S < %s | FileCheck %s --check-prefixes=CHECK,NEON
+; RUN: opt -mattr=+sve -vector-library=sleefgnuabi -passes=inject-tli-mappings,loop-unroll,loop-vectorize -S < %s | FileCheck %s --check-prefixes=CHECK,SVE
 
 target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
 target triple = "aarch64-unknown-linux-gnu"
@@ -535,6 +535,55 @@
   ret void
 }
 
+declare double @fmod(double, double) #0
+declare float @fmodf(float, float) #0
+
+define void @fmod_f64(double* nocapture %varray) {
+  ; CHECK-LABEL: @fmod_f64(
+  ; SVE: [[TMP5:%.*]] = call <vscale x 2 x double> @_ZGVsMxvv_fmod(<vscale x 2 x double> [[TMP4:%.*]], <vscale x 2 x double> [[TMP4:%.*]], {{.*}})
+  ; CHECK: ret void
+  ;
+  entry:
+  br label %for.body
+
+  for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %tmp = trunc i64 %iv to i32
+  %conv = sitofp i32 %tmp to double
+  %call = tail call double @fmod(double %conv, double %conv)
+  %arrayidx = getelementptr inbounds double, double* %varray, i64 %iv
+  store double %call, double* %arrayidx, align 8
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond = icmp eq i64 %iv.next, 1000
+  br i1 %exitcond, label %for.end, label %for.body
+
+  for.end:
+  ret void
+}
+
+define void @fmod_f32(float* nocapture %varray) {
+  ; CHECK-LABEL: @fmod_f32(
+  ; SVE: [[TMP5:%.*]] = call <vscale x 4 x float> @_ZGVsMxvv_fmodf(<vscale x 4 x float> [[TMP4:%.*]], <vscale x 4 x float> [[TMP4:%.*]], {{.*}})
+  ; CHECK: ret void
+  ;
+  entry:
+  br label %for.body
+
+  for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %tmp = trunc i64 %iv to i32
+  %conv = sitofp i32 %tmp to float
+  %call = tail call float @fmodf(float %conv, float %conv)
+  %arrayidx = getelementptr inbounds float, float* %varray, i64 %iv
+  store float %call, float* %arrayidx, align 4
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond = icmp eq i64 %iv.next, 1000
+  br i1 %exitcond, label %for.end, label %for.body
+
+  for.end:
+  ret void
+}
+
 declare double @lgamma(double) #0
 declare float @lgammaf(float) #0
 declare double @llvm.lgamma.f64(double) #0
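To make the SVE check lines above concrete: inject-tli-mappings attaches the vector-function-abi-variant attribute to the scalar calls, and loop-vectorize then widens them into the masked SLEEF calls when a scalable VF is chosen. Roughly, for @fmod_f64 (an illustrative sketch only; the value names and the surrounding induction/mask setup are hypothetical):

```llvm
; Scalar loop body:
  %call = tail call double @fmod(double %conv, double %conv)

; Widened form matched by the SVE prefix:
  %wide.call = call <vscale x 2 x double> @_ZGVsMxvv_fmod(<vscale x 2 x double> %wide.conv, <vscale x 2 x double> %wide.conv, <vscale x 2 x i1> %mask)
```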