diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -38,23 +38,44 @@
 STATISTIC(NumFuncUsedAdded,
           "Number of functions added to `llvm.compiler.used`");
 
-static bool replaceWithTLIFunction(CallInst &CI, const StringRef TLIName) {
-  Module *M = CI.getModule();
-
-  Function *OldFunc = CI.getCalledFunction();
+static bool replaceWithTLIFunction(Instruction &I, const StringRef TLIName,
+                                   bool Masked = false) {
+  Module *M = I.getModule();
+  CallInst *CI = dyn_cast<CallInst>(&I);
 
   // Check if the vector library function is already declared in this module,
   // otherwise insert it.
   Function *TLIFunc = M->getFunction(TLIName);
+  std::string OldName;
+  ElementCount NumElements;
   if (!TLIFunc) {
-    TLIFunc = Function::Create(OldFunc->getFunctionType(),
-                               Function::ExternalLinkage, TLIName, *M);
-    TLIFunc->copyAttributesFrom(OldFunc);
-
+    if (CI) {
+      // Intrinsics handling.
+      Function *OldFunc = CI->getCalledFunction();
+      FunctionType *OldFuncTy = OldFunc->getFunctionType();
+      OldName = OldFunc->getName();
+      TLIFunc =
+          Function::Create(OldFuncTy, Function::ExternalLinkage, TLIName, *M);
+      TLIFunc->copyAttributesFrom(OldFunc);
+      assert(OldFuncTy == TLIFunc->getFunctionType() &&
+             "Expecting function types to be identical");
+    } else {
+      // FRem handling.
+      assert(I.getOpcode() == Instruction::FRem &&
+             "Must be a FRem instruction.");
+      OldName = I.getOpcodeName();
+      Type *RetTy = I.getType();
+      NumElements = (dyn_cast<VectorType>(RetTy))->getElementCount();
+      SmallVector<Type *, 3> Tys = {RetTy, RetTy};
+      if (Masked)
+        Tys.push_back(
+            VectorType::get(Type::getInt1Ty(M->getContext()), NumElements));
+      TLIFunc = Function::Create(FunctionType::get(RetTy, Tys, false),
+                                 Function::ExternalLinkage, TLIName, *M);
+    }
     LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
                       << TLIName << "` of type `" << *(TLIFunc->getType())
                       << "` to module.\n");
-
     ++NumTLIFuncDeclAdded;
 
     // Add the freshly created function to llvm.compiler.used,
@@ -65,30 +86,71 @@
                       << "` to `@llvm.compiler.used`.\n");
     ++NumFuncUsedAdded;
   }
-
-  // Replace the call to the vector intrinsic with a call
+  // Replace the call to the FRem instruction/vector intrinsic with a call
   // to the corresponding function from the vector library.
-  IRBuilder<> IRBuilder(&CI);
-  SmallVector<Value *> Args(CI.args());
-  // Preserve the operand bundles.
-  SmallVector<OperandBundleDef, 1> OpBundles;
-  CI.getOperandBundlesAsDefs(OpBundles);
-  CallInst *Replacement = IRBuilder.CreateCall(TLIFunc, Args, OpBundles);
-  assert(OldFunc->getFunctionType() == TLIFunc->getFunctionType() &&
-         "Expecting function types to be identical");
-  CI.replaceAllUsesWith(Replacement);
+  IRBuilder<> IRBuilder(&I);
+  CallInst *Replacement = nullptr;
+  if (CI) {
+    // Intrinsics handling.
+    SmallVector<Value *> Args(CI->args());
+    // Preserve the operand bundles.
+    SmallVector<OperandBundleDef, 1> OpBundles;
+    CI->getOperandBundlesAsDefs(OpBundles);
+    Replacement = IRBuilder.CreateCall(TLIFunc, Args, OpBundles);
+  } else {
+    // FRem handling.
+    SmallVector<Value *> Args(I.operand_values());
+    if (Masked)
+      Args.push_back(IRBuilder.getAllOnesMask(NumElements));
+    Replacement = IRBuilder.CreateCall(TLIFunc, Args);
+  }
+  I.replaceAllUsesWith(Replacement);
   if (isa<FPMathOperator>(Replacement)) {
     // Preserve fast math flags for FP math.
-    Replacement->copyFastMathFlags(&CI);
+    Replacement->copyFastMathFlags(&I);
   }
-
-  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
-                    << OldFunc->getName() << "` with call to `" << TLIName
-                    << "`.\n");
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << OldName
+                    << "` with call to `" << TLIName << "`.\n");
   ++NumCallsReplaced;
   return true;
 }
 
+static bool replaceFremWithCallToVeclib(const TargetLibraryInfo &TLI,
+                                        Instruction &I) {
+  auto *VectorArgTy = dyn_cast<ScalableVectorType>(I.getType());
+  // We have TLI mappings for FRem on scalable vectors only.
+  if (!VectorArgTy)
+    return false;
+  ElementCount NumElements = VectorArgTy->getElementCount();
+  auto *ElementType = VectorArgTy->getElementType();
+  StringRef ScalarName;
+  if (ElementType->isFloatTy())
+    ScalarName = TLI.getName(LibFunc_fmodf);
+  else if (ElementType->isDoubleTy())
+    ScalarName = TLI.getName(LibFunc_fmod);
+  else
+    return false;
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Looking up TLI mapping for `"
+                    << ScalarName << "` and vector width " << NumElements
+                    << ".\n");
+  StringRef TLIName = TLI.getVectorizedFunction(ScalarName, NumElements);
+  if (!TLIName.empty()) {
+    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found unmasked TLI function `"
+                      << TLIName << "`.\n");
+    return replaceWithTLIFunction(I, TLIName);
+  }
+  TLIName = TLI.getVectorizedFunction(ScalarName, NumElements, true);
+  if (!TLIName.empty()) {
+    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found masked TLI function `"
+                      << TLIName << "`.\n");
+    return replaceWithTLIFunction(I, TLIName, /*Masked*/ true);
+  }
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE
+                    << ": Did not find suitable vectorized version of `"
+                    << ScalarName << "`.\n");
+  return false;
+}
+
 static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
                                     CallInst &CI) {
   if (!CI.getCalledFunction()) {
@@ -100,7 +162,6 @@
     // Replacement is only performed for intrinsic functions
     return false;
   }
-
   // Convert vector arguments to scalar type and check that
   // all vector operands have identical vector width.
   ElementCount VF = ElementCount::getFixed(0);
@@ -175,19 +236,26 @@
 
 static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
   bool Changed = false;
-  SmallVector<CallInst *> ReplacedCalls;
+  SmallVector<Instruction *> ReplacedCalls;
   for (auto &I : instructions(F)) {
     if (auto *CI = dyn_cast<CallInst>(&I)) {
      if (replaceWithCallToVeclib(TLI, *CI)) {
-        ReplacedCalls.push_back(CI);
+        ReplacedCalls.push_back(&I);
+        Changed = true;
+      }
+    } else if (I.getOpcode() == Instruction::FRem) {
+      // If there is a suitable TLI mapping for FRem instruction,
+      // replace the instruction.
+      if (replaceFremWithCallToVeclib(TLI, I)) {
+        ReplacedCalls.push_back(&I);
         Changed = true;
       }
     }
   }
-  // Erase the calls to the intrinsics that have been replaced
-  // with calls to the vector library.
-  for (auto *CI : ReplacedCalls) {
-    CI->eraseFromParent();
+  // Erase the calls to the intrinsics and the FRem instructions that have been
+  // replaced with calls to the vector library.
+  for (auto *I : ReplacedCalls) {
+    I->eraseFromParent();
   }
   return Changed;
 }
diff --git a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll
--- a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll
+++ b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll
@@ -15,7 +15,7 @@
 declare <vscale x 4 x float> @llvm.cos.nxv4f32(<vscale x 4 x float>)
 
 ;.
-; CHECK: @[[LLVM_COMPILER_USED:[a-zA-Z0-9_$"\\.-]+]] = appending global [14 x ptr] [ptr @armpl_vcosq_f64, ptr @armpl_vcosq_f32, ptr @armpl_vsinq_f64, ptr @armpl_vsinq_f32, ptr @armpl_vexpq_f64, ptr @armpl_vexpq_f32, ptr @armpl_vexp2q_f64, ptr @armpl_vexp2q_f32, ptr @armpl_vlogq_f64, ptr @armpl_vlogq_f32, ptr @armpl_vlog2q_f64, ptr @armpl_vlog2q_f32, ptr @armpl_vlog10q_f64, ptr @armpl_vlog10q_f32], section "llvm.metadata"
+; CHECK: @[[LLVM_COMPILER_USED:[a-zA-Z0-9_$"\\.-]+]] = appending global [16 x ptr] [ptr @armpl_vcosq_f64, ptr @armpl_vcosq_f32, ptr @armpl_vsinq_f64, ptr @armpl_vsinq_f32, ptr @armpl_vexpq_f64, ptr @armpl_vexpq_f32, ptr @armpl_vexp2q_f64, ptr @armpl_vexp2q_f32, ptr @armpl_vlogq_f64, ptr @armpl_vlogq_f32, ptr @armpl_vlog2q_f64, ptr @armpl_vlog2q_f32, ptr @armpl_vlog10q_f64, ptr @armpl_vlog10q_f32, ptr @armpl_svfmod_f64_x, ptr @armpl_svfmod_f32_x], section "llvm.metadata"
 ;.
 define <2 x double> @llvm_cos_f64(<2 x double> %in) {
 ; CHECK-LABEL: define <2 x double> @llvm_cos_f64
@@ -380,6 +380,28 @@
   ret <vscale x 4 x float> %1
 }
 
+; NOTE: TLI mappings for FREM instruction.
+
+define <vscale x 2 x double> @frem_vscale_f64(<vscale x 2 x double> %in1, <vscale x 2 x double> %in2) #0 {
+; CHECK-LABEL: define <vscale x 2 x double> @frem_vscale_f64
+; CHECK-SAME: (<vscale x 2 x double> [[IN1:%.*]], <vscale x 2 x double> [[IN2:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 2 x double> @armpl_svfmod_f64_x(<vscale x 2 x double> [[IN1]], <vscale x 2 x double> [[IN2]], <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 2 x double> [[TMP1]]
+;
+  %out = frem fast <vscale x 2 x double> %in1, %in2
+  ret <vscale x 2 x double> %out
+}
+
+define <vscale x 4 x float> @frem_vscale_f32(<vscale x 4 x float> %in1, <vscale x 4 x float> %in2) #0 {
+; CHECK-LABEL: define <vscale x 4 x float> @frem_vscale_f32
+; CHECK-SAME: (<vscale x 4 x float> [[IN1:%.*]], <vscale x 4 x float> [[IN2:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 4 x float> @armpl_svfmod_f32_x(<vscale x 4 x float> [[IN1]], <vscale x 4 x float> [[IN2]], <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 4 x float> [[TMP1]]
+;
+  %out = frem fast <vscale x 4 x float> %in1, %in2
+  ret <vscale x 4 x float> %out
+}
+
 attributes #0 = { "target-features"="+sve" }
 ;.
 ; CHECK: attributes #[[ATTR0:[0-9]+]] = { nocallback nofree nosync nounwind speculatable willreturn memory(none) }
diff --git a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll
--- a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll
+++ b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll
@@ -5,6 +5,9 @@
 
 ; NOTE: The existing TLI mappings are not used since the -replace-with-veclib pass is broken for scalable vectors.
 
+;.
+; CHECK: @[[LLVM_COMPILER_USED:[a-zA-Z0-9_$"\\.-]+]] = appending global [2 x ptr] [ptr @_ZGVsMxvv_fmod, ptr @_ZGVsMxvv_fmodf], section "llvm.metadata"
+;.
 define <vscale x 2 x double> @llvm_ceil_vscale_f64(<vscale x 2 x double> %in) {
 ; CHECK-LABEL: @llvm_ceil_vscale_f64(
 ; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 2 x double> @llvm.ceil.nxv2f64(<vscale x 2 x double> [[IN:%.*]])
@@ -365,6 +368,26 @@
   ret <vscale x 4 x float> %1
 }
 
+; NOTE: TLI mapping for FREM instruction.
+
+define <vscale x 2 x double> @frem_vscale_f64(<vscale x 2 x double> %in1, <vscale x 2 x double> %in2) {
+; CHECK-LABEL: @frem_vscale_f64(
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 2 x double> @_ZGVsMxvv_fmod(<vscale x 2 x double> [[IN1:%.*]], <vscale x 2 x double> [[IN2:%.*]], <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 2 x double> [[TMP1]]
+;
+  %out = frem fast <vscale x 2 x double> %in1, %in2
+  ret <vscale x 2 x double> %out
+}
+
+define <vscale x 4 x float> @frem_vscale_f32(<vscale x 4 x float> %in1, <vscale x 4 x float> %in2) {
+; CHECK-LABEL: @frem_vscale_f32(
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 4 x float> @_ZGVsMxvv_fmodf(<vscale x 4 x float> [[IN1:%.*]], <vscale x 4 x float> [[IN2:%.*]], <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 4 x float> [[TMP1]]
+;
+  %out = frem fast <vscale x 4 x float> %in1, %in2
+  ret <vscale x 4 x float> %out
+}
+
 declare <vscale x 2 x double> @llvm.ceil.nxv2f64(<vscale x 2 x double>)
 declare <vscale x 4 x float> @llvm.ceil.nxv4f32(<vscale x 4 x float>)
 declare <vscale x 2 x double> @llvm.copysign.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>)
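
The snippet below is an illustrative sketch, not part of the patch: it shows what the extended pass is expected to do to a scalable-vector frem when the SLEEF mappings are in use, mirroring the test expectations above. The opt invocation is an assumption here, since the test RUN lines are not included in this excerpt.

; Assumed invocation (not shown in this excerpt):
;   opt -vector-library=sleefgnuabi -replace-with-veclib -S input.ll
;
; Input IR:
;   define <vscale x 2 x double> @example(<vscale x 2 x double> %a, <vscale x 2 x double> %b) {
;     %r = frem fast <vscale x 2 x double> %a, %b
;     ret <vscale x 2 x double> %r
;   }
;
; Expected output: the frem is rewritten as a call to the masked TLI mapping,
; with an all-true <vscale x 2 x i1> mask appended as the last operand, and
; @_ZGVsMxvv_fmod is declared and recorded in @llvm.compiler.used:
;   %r = call fast <vscale x 2 x double> @_ZGVsMxvv_fmod(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer))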