diff --git a/llvm/include/llvm/Analysis/VectorUtils.h b/llvm/include/llvm/Analysis/VectorUtils.h
--- a/llvm/include/llvm/Analysis/VectorUtils.h
+++ b/llvm/include/llvm/Analysis/VectorUtils.h
@@ -179,8 +179,9 @@
 /// name. At the moment, this parameter is needed only to retrieve the
 /// Vectorization Factor of scalable vector functions from their
 /// respective IR declarations.
-std::optional<VFInfo> tryDemangleForVFABI(StringRef MangledName,
-                                          const Module &M);
+std::optional<VFInfo>
+tryDemangleForVFABI(StringRef MangledName, const Module &M,
+                    std::optional<ElementCount> EC = std::nullopt);
 
 /// This routine mangles the given VectorName according to the LangRef
 /// specification for vector-function-abi-variant attribute and is specific to
diff --git a/llvm/lib/Analysis/VFABIDemangling.cpp b/llvm/lib/Analysis/VFABIDemangling.cpp
--- a/llvm/lib/Analysis/VFABIDemangling.cpp
+++ b/llvm/lib/Analysis/VFABIDemangling.cpp
@@ -314,8 +314,9 @@
 
 // Format of the ABI name:
 // _ZGV<isa><mask><vlen><parameters>_<scalarname>[(<redirection>)]
-std::optional<VFInfo> VFABI::tryDemangleForVFABI(StringRef MangledName,
-                                                 const Module &M) {
+std::optional<VFInfo>
+VFABI::tryDemangleForVFABI(StringRef MangledName, const Module &M,
+                           std::optional<ElementCount> EC) {
   const StringRef OriginalName = MangledName;
   // Assume there is no custom name <redirection>, and therefore the
   // vector name consists of
@@ -434,21 +435,24 @@
   // need to make sure that the VF field of the VFShape class is never
   // set to 0.
   if (IsScalable) {
-    const Function *F = M.getFunction(VectorName);
-    // The declaration of the function must be present in the module
-    // to be able to retrieve its signature.
-    if (!F)
-      return std::nullopt;
-    const ElementCount EC = getECFromSignature(F->getFunctionType());
-    VF = EC.getKnownMinValue();
+    if (EC) {
+      VF = EC->getKnownMinValue();
+    } else {
+      const Function *F = M.getFunction(VectorName);
+      // The declaration of the function must be present in the module
+      // to be able to retrieve its signature.
+      if (!F)
+        return std::nullopt;
+      const ElementCount EC = getECFromSignature(F->getFunctionType());
+      VF = EC.getKnownMinValue();
+    }
   }
-
   // 1. We don't accept a zero lanes vectorization factor.
   // 2. We don't accept the demangling if the vector function is not
   // present in the module.
   if (VF == 0)
     return std::nullopt;
-  if (!M.getFunction(VectorName))
+  if (!EC && !M.getFunction(VectorName))
     return std::nullopt;
 
   const VFShape Shape({ElementCount::get(VF, IsScalable), Parameters});
diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -38,23 +38,54 @@
 STATISTIC(NumFuncUsedAdded,
           "Number of functions added to `llvm.compiler.used`");
 
-static bool replaceWithTLIFunction(CallInst &CI, const StringRef TLIName) {
-  Module *M = CI.getModule();
-
-  Function *OldFunc = CI.getCalledFunction();
+static bool replaceWithTLIFunction(Instruction &I, const StringRef TLIName,
+                                   bool Masked = false) {
+  Module *M = I.getModule();
+  CallInst *CI = dyn_cast<CallInst>(&I);
   // Check if the vector library function is already declared in this module,
   // otherwise insert it.
   Function *TLIFunc = M->getFunction(TLIName);
+  StringRef OldName =
+      CI ? CI->getCalledFunction()->getName() : I.getOpcodeName();
+  ElementCount NumElements =
+      CI ? ElementCount::getFixed(0)
+         : (dyn_cast<VectorType>(I.getType()))->getElementCount();
   if (!TLIFunc) {
-    TLIFunc = Function::Create(OldFunc->getFunctionType(),
-                               Function::ExternalLinkage, TLIName, *M);
-    TLIFunc->copyAttributesFrom(OldFunc);
-
+    if (CI) {
+      // Intrinsics handling.
+      Function *OldFunc = CI->getCalledFunction();
+      FunctionType *OldFuncTy = OldFunc->getFunctionType();
+      TLIFunc =
+          Function::Create(OldFuncTy, Function::ExternalLinkage, TLIName, *M);
+      TLIFunc->copyAttributesFrom(OldFunc);
+      assert(OldFuncTy == TLIFunc->getFunctionType() &&
+             "Expecting function types to be identical");
+    } else {
+      // FRem handling.
+      assert(I.getOpcode() == Instruction::FRem &&
+             "Must be a FRem instruction.");
+      Type *RetTy = I.getType();
+      SmallVector<Type *, 2> Tys = {RetTy, RetTy};
+      if (Masked) {
+        // Get the mask position.
+        std::optional<VFInfo> Info =
+            VFABI::tryDemangleForVFABI(TLIName, *M, NumElements);
+        if (!Info)
+          return false;
+        unsigned MaskPos = Info->getParamIndexForOptionalMask().value();
+        LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Mask position `" << MaskPos
+                          << "`.\n");
+        Tys.insert(
+            Tys.begin() + MaskPos,
+            VectorType::get(Type::getInt1Ty(M->getContext()), NumElements));
+      }
+      TLIFunc = Function::Create(FunctionType::get(RetTy, Tys, false),
+                                 Function::ExternalLinkage, TLIName, *M);
+    }
     LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
                       << TLIName << "` of type `" << *(TLIFunc->getType())
                       << "` to module.\n");
-
     ++NumTLIFuncDeclAdded;
 
     // Add the freshly created function to llvm.compiler.used,
@@ -65,30 +96,68 @@
                       << "` to `@llvm.compiler.used`.\n");
     ++NumFuncUsedAdded;
   }
-
-  // Replace the call to the vector intrinsic with a call
+  // Replace the call to the FRem instruction/vector intrinsic with a call
   // to the corresponding function from the vector library.
-  IRBuilder<> IRBuilder(&CI);
-  SmallVector<Value *> Args(CI.args());
-  // Preserve the operand bundles.
-  SmallVector<OperandBundleDef, 1> OpBundles;
-  CI.getOperandBundlesAsDefs(OpBundles);
-  CallInst *Replacement = IRBuilder.CreateCall(TLIFunc, Args, OpBundles);
-  assert(OldFunc->getFunctionType() == TLIFunc->getFunctionType() &&
-         "Expecting function types to be identical");
-  CI.replaceAllUsesWith(Replacement);
+  IRBuilder<> IRBuilder(&I);
+  CallInst *Replacement = nullptr;
+  if (CI) {
+    // Intrinsics handling.
+    SmallVector<Value *> Args(CI->args());
+    // Preserve the operand bundles.
+    SmallVector<OperandBundleDef, 1> OpBundles;
+    CI->getOperandBundlesAsDefs(OpBundles);
+    Replacement = IRBuilder.CreateCall(TLIFunc, Args, OpBundles);
+  } else {
+    // FRem handling.
+    SmallVector<Value *> Args(I.operand_values());
+    if (Masked)
+      Args.push_back(IRBuilder.getAllOnesMask(NumElements));
+    Replacement = IRBuilder.CreateCall(TLIFunc, Args);
+  }
+  I.replaceAllUsesWith(Replacement);
   if (isa<FPMathOperator>(Replacement)) {
     // Preserve fast math flags for FP math.
-    Replacement->copyFastMathFlags(&CI);
+    Replacement->copyFastMathFlags(&I);
   }
-
-  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
-                    << OldFunc->getName() << "` with call to `" << TLIName
-                    << "`.\n");
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << OldName
+                    << "` with call to `" << TLIName << "`.\n");
   ++NumCallsReplaced;
   return true;
 }
 
+static bool replaceFremWithCallToVeclib(const TargetLibraryInfo &TLI,
+                                        Instruction &I) {
+  auto *VectorArgTy = dyn_cast<ScalableVectorType>(I.getType());
+  // We have TLI mappings for FRem on scalable vectors only.
+  if (!VectorArgTy)
+    return false;
+  ElementCount NumElements = VectorArgTy->getElementCount();
+  auto *ElementType = VectorArgTy->getElementType();
+  StringRef ScalarName;
+  if (ElementType->isFloatTy())
+    ScalarName = TLI.getName(LibFunc_fmodf);
+  else if (ElementType->isDoubleTy())
+    ScalarName = TLI.getName(LibFunc_fmod);
+  else
+    return false;
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Looking up TLI mapping for `"
+                    << ScalarName << "` and vector width " << NumElements
+                    << ".\n");
+  StringRef TLIName = TLI.getVectorizedFunction(ScalarName, NumElements);
+  if (!TLIName.empty()) {
+    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found unmasked TLI function `"
+                      << TLIName << "`.\n");
+    return replaceWithTLIFunction(I, TLIName);
+  }
+  TLIName = TLI.getVectorizedFunction(ScalarName, NumElements, true);
+  if (!TLIName.empty()) {
+    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found masked TLI function `"
+                      << TLIName << "`.\n");
+    return replaceWithTLIFunction(I, TLIName, /*Masked*/ true);
+  }
+  return false;
+}
+
 static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
                                     CallInst &CI) {
   if (!CI.getCalledFunction()) {
@@ -175,19 +244,26 @@
 static bool
 runImpl(const TargetLibraryInfo &TLI, Function &F) {
   bool Changed = false;
-  SmallVector<CallInst *> ReplacedCalls;
+  SmallVector<Instruction *> ReplacedCalls;
   for (auto &I : instructions(F)) {
     if (auto *CI = dyn_cast<CallInst>(&I)) {
       if (replaceWithCallToVeclib(TLI, *CI)) {
-        ReplacedCalls.push_back(CI);
+        ReplacedCalls.push_back(&I);
+        Changed = true;
+      }
+    } else if (I.getOpcode() == Instruction::FRem) {
+      // If there is a suitable TLI mapping for FRem instruction,
+      // replace the instruction.
+      if (replaceFremWithCallToVeclib(TLI, I)) {
+        ReplacedCalls.push_back(&I);
         Changed = true;
       }
     }
   }
-  // Erase the calls to the intrinsics that have been replaced
-  // with calls to the vector library.
-  for (auto *CI : ReplacedCalls) {
-    CI->eraseFromParent();
+  // Erase the calls to the intrinsics and the FRem instructions that have been
+  // replaced with calls to the vector library.
+  for (auto *I : ReplacedCalls) {
+    I->eraseFromParent();
   }
   return Changed;
 }
diff --git a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll
--- a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll
+++ b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll
@@ -192,6 +192,28 @@
   ret <vscale x 4 x float> %1
 }
 
+; NOTE: TLI mappings for FREM instruction.
+
+define <vscale x 2 x double> @llvm_frem_vscale_f64(<vscale x 2 x double> %in1, <vscale x 2 x double> %in2) #0 {
+; CHECK-LABEL: define <vscale x 2 x double> @llvm_frem_vscale_f64
+; CHECK-SAME: (<vscale x 2 x double> [[IN1:%.*]], <vscale x 2 x double> [[IN2:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:    [[OUT:%.*]] = frem fast <vscale x 2 x double> [[IN1]], [[IN2]]
+; CHECK-NEXT:    ret <vscale x 2 x double> [[OUT]]
+;
+  %out = frem fast <vscale x 2 x double> %in1, %in2
+  ret <vscale x 2 x double> %out
+}
+
+define <vscale x 4 x float> @llvm_frem_vscale_f32(<vscale x 4 x float> %in1, <vscale x 4 x float> %in2) #0 {
+; CHECK-LABEL: define <vscale x 4 x float> @llvm_frem_vscale_f32
+; CHECK-SAME: (<vscale x 4 x float> [[IN1:%.*]], <vscale x 4 x float> [[IN2:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:    [[OUT:%.*]] = frem fast <vscale x 4 x float> [[IN1]], [[IN2]]
+; CHECK-NEXT:    ret <vscale x 4 x float> [[OUT]]
+;
+  %out = frem fast <vscale x 4 x float> %in1, %in2
+  ret <vscale x 4 x float> %out
+}
+
 declare <2 x double> @llvm.log.v2f64(<2 x double>)
 declare <4 x float> @llvm.log.v4f32(<4 x float>)
 
diff --git a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll
--- a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll
+++ b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll
@@ -5,6 +5,9 @@
 
 ; NOTE: The existing TLI mappings are not used since the -replace-with-veclib pass is broken for scalable vectors.
 
+;.
+; CHECK: @[[LLVM_COMPILER_USED:[a-zA-Z0-9_$"\\.-]+]] = appending global [2 x ptr] [ptr @_ZGVsMxvv_fmod, ptr @_ZGVsMxvv_fmodf], section "llvm.metadata"
+;.
 define <vscale x 2 x double> @llvm_ceil_vscale_f64(<vscale x 2 x double> %in) {
 ; CHECK-LABEL: @llvm_ceil_vscale_f64(
 ; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 2 x double> @llvm.ceil.nxv2f64(<vscale x 2 x double> [[IN:%.*]])
@@ -149,6 +152,26 @@
   ret <vscale x 4 x float> %1
 }
 
+; NOTE: TLI mapping for FREM instruction.
+
+define <vscale x 2 x double> @llvm_frem_vscale_f64(<vscale x 2 x double> %in1, <vscale x 2 x double> %in2) {
+; CHECK-LABEL: @llvm_frem_vscale_f64(
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 2 x double> @_ZGVsMxvv_fmod(<vscale x 2 x double> [[IN1:%.*]], <vscale x 2 x double> [[IN2:%.*]], <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 2 x double> [[TMP1]]
+;
+  %out = frem fast <vscale x 2 x double> %in1, %in2
+  ret <vscale x 2 x double> %out
+}
+
+define <vscale x 4 x float> @llvm_frem_vscale_f32(<vscale x 4 x float> %in1, <vscale x 4 x float> %in2) {
+; CHECK-LABEL: @llvm_frem_vscale_f32(
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 4 x float> @_ZGVsMxvv_fmodf(<vscale x 4 x float> [[IN1:%.*]], <vscale x 4 x float> [[IN2:%.*]], <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 4 x float> [[TMP1]]
+;
+  %out = frem fast <vscale x 4 x float> %in1, %in2
+  ret <vscale x 4 x float> %out
+}
+
 define <vscale x 2 x double> @llvm_log_vscale_f64(<vscale x 2 x double> %in) {
 ; CHECK-LABEL: @llvm_log_vscale_f64(
 ; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 2 x double> @llvm.log.nxv2f64(<vscale x 2 x double> [[IN:%.*]])