diff --git a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
--- a/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
+++ b/llvm/lib/CodeGen/ReplaceWithVeclib.cpp
@@ -38,23 +38,44 @@
 STATISTIC(NumFuncUsedAdded,
           "Number of functions added to `llvm.compiler.used`");
 
-static bool replaceWithTLIFunction(CallInst &CI, const StringRef TLIName) {
-  Module *M = CI.getModule();
-
-  Function *OldFunc = CI.getCalledFunction();
+static bool replaceWithTLIFunction(Instruction &I, const StringRef TLIName,
+                                   bool Masked = false) {
+  Module *M = I.getModule();
+  IRBuilder<> IRBuilder(&I);
+  CallInst *CI = dyn_cast<CallInst>(&I);
 
   // Check if the vector library function is already declared in this module,
   // otherwise insert it.
   Function *TLIFunc = M->getFunction(TLIName);
+  std::string OldName;
+  ElementCount NumElements;
   if (!TLIFunc) {
-    TLIFunc = Function::Create(OldFunc->getFunctionType(),
-                               Function::ExternalLinkage, TLIName, *M);
-    TLIFunc->copyAttributesFrom(OldFunc);
-
+    if (!CI) {
+      // FRem handling.
+      assert(I.getOpcode() == Instruction::FRem &&
+             "Must be a FRem instruction.");
+      OldName = I.getOpcodeName();
+      Type *RetTy = I.getType();
+      NumElements = (dyn_cast<VectorType>(RetTy))->getElementCount();
+      SmallVector<Type *, 3> Tys = {RetTy, RetTy};
+      if (Masked)
+        Tys.push_back(ToVectorTy(IRBuilder.getInt1Ty(), NumElements));
+      TLIFunc = Function::Create(FunctionType::get(RetTy, Tys, false),
+                                 Function::ExternalLinkage, TLIName, *M);
+    } else {
+      // Intrinsics handling.
+      Function *OldFunc = CI->getCalledFunction();
+      FunctionType *OldFuncTy = OldFunc->getFunctionType();
+      OldName = OldFunc->getName();
+      TLIFunc =
+          Function::Create(OldFuncTy, Function::ExternalLinkage, TLIName, *M);
+      TLIFunc->copyAttributesFrom(OldFunc);
+      assert(OldFuncTy == TLIFunc->getFunctionType() &&
+             "Expecting function types to be identical");
+    }
     LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
                       << TLIName << "` of type `" << *(TLIFunc->getType())
                       << "` to module.\n");
-
     ++NumTLIFuncDeclAdded;
 
     // Add the freshly created function to llvm.compiler.used,
@@ -65,30 +86,72 @@
                       << "` to `@llvm.compiler.used`.\n");
     ++NumFuncUsedAdded;
   }
-
-  // Replace the call to the vector intrinsic with a call
+  // Replace the call to the frem instruction/vector intrinsic with a call
   // to the corresponding function from the vector library.
-  IRBuilder<> IRBuilder(&CI);
-  SmallVector<Value *> Args(CI.args());
-  // Preserve the operand bundles.
-  SmallVector<OperandBundleDef, 1> OpBundles;
-  CI.getOperandBundlesAsDefs(OpBundles);
-  CallInst *Replacement = IRBuilder.CreateCall(TLIFunc, Args, OpBundles);
-  assert(OldFunc->getFunctionType() == TLIFunc->getFunctionType() &&
-         "Expecting function types to be identical");
-  CI.replaceAllUsesWith(Replacement);
+  CallInst *Replacement = nullptr;
+  if (!CI) {
+    // FRem handling.
+    SmallVector<Value *> Args(I.operand_values());
+    if (Masked)
+      Args.push_back(IRBuilder.getAllOnesMask(NumElements));
+    Replacement = IRBuilder.CreateCall(TLIFunc, Args);
+  } else {
+    // Intrinsics handling.
+    SmallVector<Value *> Args(CI->args());
+    // Preserve the operand bundles.
+    SmallVector<OperandBundleDef, 1> OpBundles;
+    CI->getOperandBundlesAsDefs(OpBundles);
+    Replacement = IRBuilder.CreateCall(TLIFunc, Args, OpBundles);
+  }
+  I.replaceAllUsesWith(Replacement);
   if (isa<FPMathOperator>(Replacement)) {
     // Preserve fast math flags for FP math.
-    Replacement->copyFastMathFlags(&CI);
+    Replacement->copyFastMathFlags(&I);
   }
-
-  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `"
-             << OldFunc->getName() << "` with call to `" << TLIName
-             << "`.\n");
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << OldName
+                    << "` with call to `" << TLIName << "`.\n");
   ++NumCallsReplaced;
   return true;
 }
 
+static bool replaceFremWithCallToVeclib(const TargetLibraryInfo &TLI,
+                                        Instruction &I) {
+  auto *VectorArgTy = dyn_cast<ScalableVectorType>(I.getType());
+  // We have TLI mappings for FRem on scalable vectors only.
+  if (!VectorArgTy)
+    return false;
+  ElementCount NumElements = VectorArgTy->getElementCount();
+  auto *ElementType = VectorArgTy->getElementType();
+  StringRef ScalarName;
+  if (ElementType->isFloatTy())
+    ScalarName = TLI.getName(LibFunc_fmodf);
+  else if (ElementType->isDoubleTy())
+    ScalarName = TLI.getName(LibFunc_fmod);
+  else
+    return false;
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Looking up TLI mapping for `"
                    << ScalarName << "` and vector width " << NumElements
                    << ".\n");
+  std::string TLIName =
+      std::string(TLI.getVectorizedFunction(ScalarName, NumElements));
+  if (!TLIName.empty()) {
+    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found unmasked TLI function `"
+                      << TLIName << "`.\n");
+    return replaceWithTLIFunction(I, TLIName);
+  }
+  TLIName =
+      std::string(TLI.getVectorizedFunction(ScalarName, NumElements, true));
+  if (!TLIName.empty()) {
+    LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found masked TLI function `"
+                      << TLIName << "`.\n");
+    return replaceWithTLIFunction(I, TLIName, true);
+  }
+  LLVM_DEBUG(dbgs() << DEBUG_TYPE
+                    << ": Did not find suitable vectorized version of `"
+                    << ScalarName << "`.\n");
+  return false;
+}
+
 static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
                                     CallInst &CI) {
   if (!CI.getCalledFunction()) {
@@ -100,7 +163,6 @@
     // Replacement is only performed for intrinsic functions
     return false;
   }
-
   // Convert vector arguments to scalar type and check that
   // all vector operands have identical vector width.
   ElementCount VF = ElementCount::getFixed(0);
@@ -175,19 +237,26 @@
 
 static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
   bool Changed = false;
-  SmallVector<CallInst *> ReplacedCalls;
+  SmallVector<Instruction *> ReplacedCalls;
   for (auto &I : instructions(F)) {
     if (auto *CI = dyn_cast<CallInst>(&I)) {
       if (replaceWithCallToVeclib(TLI, *CI)) {
-        ReplacedCalls.push_back(CI);
+        ReplacedCalls.push_back(&I);
+        Changed = true;
+      }
+    } else if (I.getOpcode() == Instruction::FRem) {
+      // If there is a suitable TLI mapping for the FRem instruction,
+      // replace the instruction.
+      if (replaceFremWithCallToVeclib(TLI, I)) {
+        ReplacedCalls.push_back(&I);
         Changed = true;
       }
     }
   }
-  // Erase the calls to the intrinsics that have been replaced
-  // with calls to the vector library.
-  for (auto *CI : ReplacedCalls) {
-    CI->eraseFromParent();
+  // Erase the calls to the intrinsics and the frem instructions that have been
+  // replaced with calls to the vector library.
+  for (auto *I : ReplacedCalls) {
+    I->eraseFromParent();
   }
   return Changed;
 }
diff --git a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll
--- a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll
+++ b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-armpl.ll
@@ -377,4 +377,26 @@
   ret <vscale x 4 x float> %1
 }
 
+; NOTE: TLI mappings for FREM instruction.
+
+define <vscale x 2 x double> @frem_vscale_f64(<vscale x 2 x double> %in1, <vscale x 2 x double> %in2) #0 {
+; CHECK-LABEL: define <vscale x 2 x double> @frem_vscale_f64
+; CHECK-SAME: (<vscale x 2 x double> [[IN1:%.*]], <vscale x 2 x double> [[IN2:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 2 x double> @armpl_svfmod_f64_x(<vscale x 2 x double> [[IN1]], <vscale x 2 x double> [[IN2]], <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 2 x double> [[TMP1]]
+;
+  %out = frem fast <vscale x 2 x double> %in1, %in2
+  ret <vscale x 2 x double> %out
+}
+
+define <vscale x 4 x float> @frem_vscale_f32(<vscale x 4 x float> %in1, <vscale x 4 x float> %in2) #0 {
+; CHECK-LABEL: define <vscale x 4 x float> @frem_vscale_f32
+; CHECK-SAME: (<vscale x 4 x float> [[IN1:%.*]], <vscale x 4 x float> [[IN2:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 4 x float> @armpl_svfmod_f32_x(<vscale x 4 x float> [[IN1]], <vscale x 4 x float> [[IN2]], <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 4 x float> [[TMP1]]
+;
+  %out = frem fast <vscale x 4 x float> %in1, %in2
+  ret <vscale x 4 x float> %out
+}
+
 attributes #0 = { "target-features"="+sve" }
diff --git a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll
--- a/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll
+++ b/llvm/test/CodeGen/AArch64/replace-intrinsics-with-veclib-sleef-scalable.ll
@@ -365,6 +365,26 @@
   ret <vscale x 4 x float> %1
 }
 
+; NOTE: TLI mapping for FREM instruction.
+
+define <vscale x 2 x double> @frem_vscale_f64(<vscale x 2 x double> %in1, <vscale x 2 x double> %in2) {
+; CHECK-LABEL: @frem_vscale_f64(
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 2 x double> @_ZGVsMxvv_fmod(<vscale x 2 x double> [[IN1:%.*]], <vscale x 2 x double> [[IN2:%.*]], <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 2 x double> [[TMP1]]
+;
+  %out = frem fast <vscale x 2 x double> %in1, %in2
+  ret <vscale x 2 x double> %out
+}
+
+define <vscale x 4 x float> @frem_vscale_f32(<vscale x 4 x float> %in1, <vscale x 4 x float> %in2) {
+; CHECK-LABEL: @frem_vscale_f32(
+; CHECK-NEXT:    [[TMP1:%.*]] = call fast <vscale x 4 x float> @_ZGVsMxvv_fmodf(<vscale x 4 x float> [[IN1:%.*]], <vscale x 4 x float> [[IN2:%.*]], <vscale x 4 x i1> shufflevector (<vscale x 4 x i1> insertelement (<vscale x 4 x i1> poison, i1 true, i64 0), <vscale x 4 x i1> poison, <vscale x 4 x i32> zeroinitializer))
+; CHECK-NEXT:    ret <vscale x 4 x float> [[TMP1]]
+;
+  %out = frem fast <vscale x 4 x float> %in1, %in2
+  ret <vscale x 4 x float> %out
+}
+
 declare <vscale x 2 x double> @llvm.ceil.nxv2f64(<vscale x 2 x double>)
 declare <vscale x 4 x float> @llvm.ceil.nxv4f32(<vscale x 4 x float>)
 declare <vscale x 2 x double> @llvm.copysign.nxv2f64(<vscale x 2 x double>, <vscale x 2 x double>)
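Illustration (not part of the patch): the IR below is a minimal sketch of the transformation the new tests check, using the SLEEF mapping exercised by the second test file. The opt invocation in the first comment is an assumption meant to mirror the tests' RUN lines, and @example is a made-up function name; the vector-function name and the all-true mask come directly from the CHECK lines above.

; Assumed invocation (see the tests' RUN lines for the authoritative flags):
;   opt -vector-library=sleefgnuabi -replace-with-veclib -S example.ll
define <vscale x 2 x double> @example(<vscale x 2 x double> %a, <vscale x 2 x double> %b) {
  ; Input: a fast-math frem on a scalable vector type that has a TLI mapping.
  %r = frem fast <vscale x 2 x double> %a, %b
  ret <vscale x 2 x double> %r
}
; After the pass, the frem is erased and replaced with a masked libcall, and
; @_ZGVsMxvv_fmod is declared and added to @llvm.compiler.used:
;   %r = call fast <vscale x 2 x double> @_ZGVsMxvv_fmod(<vscale x 2 x double> %a, <vscale x 2 x double> %b, <vscale x 2 x i1> shufflevector (<vscale x 2 x i1> insertelement (<vscale x 2 x i1> poison, i1 true, i64 0), <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer))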