diff --git a/llvm/include/llvm/Analysis/VecFuncs.def b/llvm/include/llvm/Analysis/VecFuncs.def
--- a/llvm/include/llvm/Analysis/VecFuncs.def
+++ b/llvm/include/llvm/Analysis/VecFuncs.def
@@ -644,6 +644,9 @@
 TLI_DEFINE_VECFUNC("exp10", "_ZGVsMxv_exp10", SCALABLE(2), MASKED)
 TLI_DEFINE_VECFUNC("exp10f", "_ZGVsMxv_exp10f", SCALABLE(4), MASKED)
 
+TLI_DEFINE_VECFUNC("fmod", "_ZGVsMxvv_fmod", SCALABLE(2), MASKED)
+TLI_DEFINE_VECFUNC("fmodf", "_ZGVsMxvv_fmodf", SCALABLE(4), MASKED)
+
 TLI_DEFINE_VECFUNC("lgamma", "_ZGVsMxv_lgamma", SCALABLE(2), MASKED)
 TLI_DEFINE_VECFUNC("lgammaf", "_ZGVsMxv_lgammaf", SCALABLE(4), MASKED)
 
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -38,23 +38,44 @@
 STATISTIC(NumFuncUsedAdded,
           "Number of functions added to `llvm.compiler.used`");
 
-static bool replaceWithTLIFunction(CallInst &CI, const StringRef TLIName) {
-  Module *M = CI.getModule();
-
-  Function *OldFunc = CI.getCalledFunction();
+static bool replaceWithTLIFunction(Instruction &I, const StringRef TLIName,
+                                   ElementCount *NumElements = nullptr,
+                                   bool Masked = false) {
+  Module *M = I.getModule();
+  IRBuilder<> IRBuilder(&I);
+  CallInst *CI = dyn_cast<CallInst>(&I);
 
   // Check if the vector library function is already declared in this module,
   // otherwise insert it.
   Function *TLIFunc = M->getFunction(TLIName);
+  std::string OldName;
   if (!TLIFunc) {
-    TLIFunc = Function::Create(OldFunc->getFunctionType(),
-                               Function::ExternalLinkage, TLIName, *M);
-    TLIFunc->copyAttributesFrom(OldFunc);
-
+    if (!CI) {
+      // FRem handling.
+      assert(I.getOpcode() == Instruction::FRem &&
+             "Must be a FRem instruction.");
+      assert(NumElements != nullptr && "Vectorization factor is missing.");
+      OldName = I.getOpcodeName();
+      Type *RetTy = I.getType();
+      SmallVector<Type *, 3> Tys = {RetTy, RetTy};
+      if (Masked)
+        Tys.push_back(ToVectorTy(IRBuilder.getInt1Ty(), *NumElements));
+      TLIFunc = Function::Create(FunctionType::get(RetTy, Tys, false),
+                                 Function::ExternalLinkage, TLIName, *M);
+    } else {
+      // Intrinsics handling.
+      Function *OldFunc = CI->getCalledFunction();
+      FunctionType *OldFuncTy = OldFunc->getFunctionType();
+      OldName = OldFunc->getName();
+      TLIFunc =
+          Function::Create(OldFuncTy, Function::ExternalLinkage, TLIName, *M);
+      TLIFunc->copyAttributesFrom(OldFunc);
+      assert(OldFuncTy == TLIFunc->getFunctionType() &&
+             "Expecting function types to be identical");
+    }
     LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
                << TLIName << "` of type `" << *(TLIFunc->getType())
                << "` to module.\n");
-
     ++NumTLIFuncDeclAdded;
 
     // Add the freshly created function to llvm.compiler.used,
@@ -65,37 +86,85 @@
                << "` to `@llvm.compiler.used`.\n");
     ++NumFuncUsedAdded;
   }
-
-  // Replace the call to the vector intrinsic with a call
+  // Replace the call to the frem instruction/vector intrinsic with a call
   // to the corresponding function from the vector library.
-  IRBuilder<> IRBuilder(&CI);
-  SmallVector<Value *> Args(CI.args());
-  // Preserve the operand bundles.
-  SmallVector<OperandBundleDef, 1> OpBundles;
-  CI.getOperandBundlesAsDefs(OpBundles);
-  CallInst *Replacement = IRBuilder.CreateCall(TLIFunc, Args, OpBundles);
-  assert(OldFunc->getFunctionType() == TLIFunc->getFunctionType() &&
-         "Expecting function types to be identical");
-  CI.replaceAllUsesWith(Replacement);
-  if (isa<FPMathOperator>(Replacement)) {
-    // Preserve fast math flags for FP math.
-    Replacement->copyFastMathFlags(&CI);
+  CallInst *Replacement = nullptr;
+  if (!CI) {
+    // FRem handling.
+    SmallVector<Value *> Args(I.operand_values());
+    if (Masked) {
+      Value *AllActiveMask = IRBuilder.getAllOnesMask(*NumElements);
+      Args.push_back(AllActiveMask);
+    }
+    Replacement = IRBuilder.CreateCall(TLIFunc, Args);
+  } else {
+    // Intrinsics handling.
+    SmallVector<Value *> Args(CI->args());
+    // Preserve the operand bundles.
+    SmallVector<OperandBundleDef, 1> OpBundles;
+    CI->getOperandBundlesAsDefs(OpBundles);
+    Replacement = IRBuilder.CreateCall(TLIFunc, Args, OpBundles);
   }
-
-  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
-             << OldFunc->getName() << "` with call to `" << TLIName
-             << "`.\n");
+  I.replaceAllUsesWith(Replacement);
+  // Preserve fast math flags for FP math.
+  if (isa<FPMathOperator>(Replacement))
+    Replacement->copyFastMathFlags(&I);
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << OldName
+             << "` with call to `" << TLIName << "`.\n");
   ++NumCallsReplaced;
   return true;
 }
 
-static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
-                                    CallInst &CI) {
-  if (!CI.getCalledFunction()) {
+static bool replaceFremWithCallToVeclib(const TargetLibraryInfo &TLI,
+                                        Instruction &I) {
+  auto *VectorArgTy = dyn_cast<ScalableVectorType>(I.getType());
+  // We have TLI mappings for FRem on scalable vectors only.
+  if (!VectorArgTy)
     return false;
+  ElementCount NumElements = VectorArgTy->getElementCount();
+  Type *ElementType = VectorArgTy->getElementType();
+  StringRef ScalarName;
+  if (ElementType->isFloatTy())
+    ScalarName = TLI.getName(LibFunc_fmodf);
+  else if (ElementType->isDoubleTy())
+    ScalarName = TLI.getName(LibFunc_fmod);
+  else
+    return false;
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Looking up TLI mapping for `"
+             << ScalarName << "` and vector width " << NumElements
+             << ".\n");
+  std::string TLIName =
+      std::string(TLI.getVectorizedFunction(ScalarName, NumElements));
+  if (!TLIName.empty()) {
+    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found unmasked TLI function `"
+               << TLIName << "`.\n");
+    return replaceWithTLIFunction(I, TLIName, &NumElements);
   }
+  TLIName =
+      std::string(TLI.getVectorizedFunction(ScalarName, NumElements, true));
+  if (!TLIName.empty()) {
+    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found masked TLI function `"
+               << TLIName << "`.\n");
+    return replaceWithTLIFunction(I, TLIName, &NumElements, true);
+  }
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE
+             << ": Did not find suitable vectorized version of `"
+             << ScalarName << "`.\n");
+  return false;
+}
+
+static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
+                                    Instruction &I) {
-  auto IntrinsicID = CI.getCalledFunction()->getIntrinsicID();
+  // Replacement can be performed for FRem instruction.
+  if (I.getOpcode() == Instruction::FRem)
+    return replaceFremWithCallToVeclib(TLI, I);
+
+  CallInst *CI = dyn_cast<CallInst>(&I);
+  if (!CI || !CI->getCalledFunction())
+    return false;
+
+  auto IntrinsicID = CI->getCalledFunction()->getIntrinsicID();
   if (IntrinsicID == Intrinsic::not_intrinsic) {
     // Replacement is only performed for intrinsic functions
     return false;
@@ -105,7 +174,7 @@
   // all vector operands have identical vector width.
   ElementCount VF = ElementCount::getFixed(0);
   SmallVector<Type *> ScalarTypes;
-  for (auto Arg : enumerate(CI.args())) {
+  for (auto Arg : enumerate(CI->args())) {
     auto *ArgType = Arg.value()->getType();
     // Vector calls to intrinsics can still have
     // scalar operands for specific arguments.
@@ -141,7 +210,7 @@
   // converted to scalar above.
   std::string ScalarName;
   if (Intrinsic::isOverloaded(IntrinsicID)) {
-    ScalarName = Intrinsic::getName(IntrinsicID, ScalarTypes, CI.getModule());
+    ScalarName = Intrinsic::getName(IntrinsicID, ScalarTypes, CI->getModule());
   } else {
     ScalarName = Intrinsic::getName(IntrinsicID).str();
   }
@@ -167,7 +236,7 @@
   // the vector library function.
   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI function `" << TLIName
              << "`.\n");
-  return replaceWithTLIFunction(CI, TLIName);
+  return replaceWithTLIFunction(*CI, TLIName);
   }
 
   return false;
@@ -175,20 +244,18 @@
 static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
   bool Changed = false;
-  SmallVector<CallInst *> ReplacedCalls;
+  SmallVector<Instruction *> ReplacedCalls;
   for (auto &I : instructions(F)) {
-    if (auto *CI = dyn_cast<CallInst>(&I)) {
-      if (replaceWithCallToVeclib(TLI, *CI)) {
-        ReplacedCalls.push_back(CI);
-        Changed = true;
-      }
+    if (replaceWithCallToVeclib(TLI, I)) {
+      ReplacedCalls.push_back(&I);
+      Changed = true;
     }
   }
-  // Erase the calls to the intrinsics that have been replaced
-  // with calls to the vector library.
-  for (auto *CI : ReplacedCalls) {
-    CI->eraseFromParent();
-  }
+  // Erase the calls to the intrinsics and the frem instructions that have been
+  // replaced with calls to the vector library.
+  for (auto *I : ReplacedCalls)
+    I->eraseFromParent();
+
   return Changed;
 }
 
diff --git a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll
--- a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll
+++ b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll
@@ -377,4 +377,26 @@
   ret <vscale x 4 x float> %1
 }
 
+; NOTE: TLI mappings for FREM instruction.
+
+define <vscale x 2 x double> @frem_vscale_f64(<vscale x 2 x double> %in1, <vscale x 2 x double> %in2) #0 {
+; CHECK-LABEL: define <vscale x 2 x double> @frem_vscale_f64
+; CHECK-SAME: (<vscale x 2 x double> [[IN1:%.*]], <vscale x 2 x double> [[IN2:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 2 x double> @armpl_svfmod_f64_x(<vscale x 2 x double> [[IN1]], <vscale x 2 x double> [[IN2]], <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 2 x double> [[TMP1]]
+;
+  %out = frem fast <vscale x 2 x double> %in1, %in2
+  ret <vscale x 2 x double> %out
+}
+
+define <vscale x 4 x float> @frem_vscale_f32(<vscale x 4 x float> %in1, <vscale x 4 x float> %in2) #0 {
+; CHECK-LABEL: define <vscale x 4 x float> @frem_vscale_f32
+; CHECK-SAME: (<vscale x 4 x float> [[IN1:%.*]], <vscale x 4 x float> [[IN2:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 4 x float> @armpl_svfmod_f32_x(<vscale x 4 x float> [[IN1]], <vscale x 4 x float> [[IN2]], <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 4 x float> [[TMP1]]
+;
+  %out = frem fast <vscale x 4 x float> %in1, %in2
+  ret <vscale x 4 x float> %out
+}
+
 attributes #0 = { "target-features"="+sve" }
diff --git a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll
--- a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll
+++ b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll
@@ -365,6 +365,26 @@
   ret <vscale x 4 x float> %1
 }
 
+; NOTE: TLI mapping for FREM instruction.
+
+define <vscale x 2 x double> @frem_vscale_f64(<vscale x 2 x double> %in1, <vscale x 2 x double> %in2) {
+; CHECK-LABEL: @frem_vscale_f64(
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 2 x double> @_ZGVsMxvv_fmod(<vscale x 2 x double> [[IN1:%.*]], <vscale x 2 x double> [[IN2:%.*]], <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 2 x double> [[TMP1]]
+;
+  %out = frem fast <vscale x 2 x double> %in1, %in2
+  ret <vscale x 2 x double> %out
+}
+
+define <vscale x 4 x float> @frem_vscale_f32(<vscale x 4 x float> %in1, <vscale x 4 x float> %in2) {
+; CHECK-LABEL: @frem_vscale_f32(
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 4 x float> @_ZGVsMxvv_fmodf(<vscale x 4 x float> [[IN1:%.*]], <vscale x 4 x float> [[IN2:%.*]], <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 4 x float> [[TMP1]]
+;
+  %out = frem fast <vscale x 4 x float> %in1, %in2
+  ret <vscale x 4 x float> %out
+}
+
 declare <vscale x 2 x double> @llvm.ceil.nxv2f64(<vscale x 2 x double>)
 declare <vscale x 4 x float> @llvm.ceil.nxv4f32(<vscale x 4 x float>)
 declare <vscale x 2 x double> @llvm.copysign.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>)
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/sleef-calls-aarch64.ll b/llvm/test/Transforms/LoopVectorize/AArch64/sleef-calls-aarch64.ll
--- a/llvm/test/Transforms/LoopVectorize/AArch64/sleef-calls-aarch64.ll
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/sleef-calls-aarch64.ll
@@ -1,6 +1,6 @@
 ; Do NOT use -O3. It will lower exp2 to ldexp, and the test will fail.
-; RUN: opt -vector-library=sleefgnuabi -replace-with-veclib < %s | opt -vector-library=sleefgnuabi -passes=inject-tli-mappings,loop-unroll,loop-vectorize -S | FileCheck %s --check-prefixes=CHECK,NEON
-; RUN: opt -mattr=+sve -vector-library=sleefgnuabi -replace-with-veclib < %s | opt -vector-library=sleefgnuabi -passes=inject-tli-mappings,loop-unroll,loop-vectorize -S | FileCheck %s --check-prefixes=CHECK,SVE
+; RUN: opt -vector-library=sleefgnuabi -passes=inject-tli-mappings,loop-unroll,loop-vectorize -S < %s | FileCheck %s --check-prefixes=CHECK,NEON
+; RUN: opt -mattr=+sve -vector-library=sleefgnuabi -passes=inject-tli-mappings,loop-unroll,loop-vectorize -S < %s | FileCheck %s --check-prefixes=CHECK,SVE
 
 target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
 target triple = "aarch64-unknown-linux-gnu"
@@ -535,6 +535,55 @@
   ret void
 }
 
+declare double @fmod(double, double) #0
+declare float @fmodf(float, float) #0
+
+define void @fmod_f64(double* nocapture %varray) {
+  ; CHECK-LABEL: @fmod_f64(
+  ; SVE: [[TMP5:%.*]] = call <vscale x 2 x double> @_ZGVsMxvv_fmod(<vscale x 2 x double> [[TMP4:%.*]], <vscale x 2 x double> [[TMP4:%.*]], {{.*}})
+  ; CHECK: ret void
+  ;
+  entry:
+  br label %for.body
+
+  for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %tmp = trunc i64 %iv to i32
+  %conv = sitofp i32 %tmp to double
+  %call = tail call double @fmod(double %conv, double %conv)
+  %arrayidx = getelementptr inbounds double, double* %varray, i64 %iv
+  store double %call, double* %arrayidx, align 8
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond = icmp eq i64 %iv.next, 1000
+  br i1 %exitcond, label %for.end, label %for.body
+
+  for.end:
+  ret void
+}
+
+define void @fmod_f32(float* nocapture %varray) {
+  ; CHECK-LABEL: @fmod_f32(
+  ; SVE: [[TMP5:%.*]] = call <vscale x 4 x float> @_ZGVsMxvv_fmodf(<vscale x 4 x float> [[TMP4:%.*]], <vscale x 4 x float> [[TMP4:%.*]], {{.*}})
+  ; CHECK: ret void
+  ;
+  entry:
+  br label %for.body
+
+  for.body:
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %tmp = trunc i64 %iv to i32
+  %conv = sitofp i32 %tmp to float
+  %call = tail call float @fmodf(float %conv, float %conv)
+  %arrayidx = getelementptr inbounds float, float* %varray, i64 %iv
+  store float %call, float* %arrayidx, align 4
+  %iv.next = add nuw nsw i64 %iv, 1
+  %exitcond = icmp eq i64 %iv.next, 1000
+  br i1 %exitcond, label %for.end, label %for.body
+
+  for.end:
+  ret void
+}
+
 declare double @lgamma(double) #0
 declare float @lgammaf(float) #0
 declare double @llvm.lgamma.f64(double) #0