diff --git a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
--- a/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
+++ b/llvm/lib/Transforms/AggressiveInstCombine/AggressiveInstCombine.cpp
@@ -31,6 +31,7 @@
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
+#include "llvm/Transforms/Utils/BuildLibCalls.h"
 #include "llvm/Transforms/Utils/Local.h"
 
 using namespace llvm;
@@ -427,27 +428,73 @@
   return true;
 }
 
+/// Try to replace a mathlib call to sqrt with the LLVM intrinsic. This avoids
+/// pessimistic codegen that has to account for setting errno and can enable
+/// vectorization.
+static bool
+foldSqrt(Instruction &I, TargetTransformInfo &TTI, TargetLibraryInfo &TLI) {
+  // Match a call to sqrt mathlib function.
+  auto *Call = dyn_cast<CallInst>(&I);
+  if (!Call)
+    return false;
+
+  Module *M = Call->getModule();
+  LibFunc Func;
+  if (!TLI.getLibFunc(*Call, Func) || !isLibFuncEmittable(M, &TLI, Func))
+    return false;
+
+  if (Func != LibFunc_sqrt && Func != LibFunc_sqrtf && Func != LibFunc_sqrtl)
+    return false;
+
+  // If (1) this is a sqrt libcall, (2) we can assume that NAN is not created,
+  // and (3) we would not end up lowering to a libcall anyway (which could
+  // change the value of errno), then:
+  // (1) the operand arg must not be less than -0.0.
+  // (2) errno won't be set.
+  // (3) it is safe to convert this to an intrinsic call.
+  // TODO: Check if the arg is known non-negative.
+  Type *Ty = Call->getType();
+  if (TTI.haveFastSqrt(Ty) && Call->hasNoNaNs()) {
+    IRBuilder<> Builder(&I);
+    IRBuilderBase::FastMathFlagGuard Guard(Builder);
+    Builder.setFastMathFlags(Call->getFastMathFlags());
+
+    Function *Sqrt = Intrinsic::getDeclaration(M, Intrinsic::sqrt, Ty);
+    Value *NewSqrt = Builder.CreateCall(Sqrt, Call->getArgOperand(0), "sqrt");
+    I.replaceAllUsesWith(NewSqrt);
+
+    // Explicitly erase the old call because a call with side effects is not
+    // trivially dead.
+    I.eraseFromParent();
+    return true;
+  }
+
+  return false;
+}
+
 /// This is the entry point for folds that could be implemented in regular
 /// InstCombine, but they are separated because they are not expected to
 /// occur frequently and/or have more than a constant-length pattern match.
 static bool foldUnusualPatterns(Function &F, DominatorTree &DT,
-                                TargetTransformInfo &TTI) {
+                                TargetTransformInfo &TTI,
+                                TargetLibraryInfo &TLI) {
   bool MadeChange = false;
   for (BasicBlock &BB : F) {
     // Ignore unreachable basic blocks.
     if (!DT.isReachableFromEntry(&BB))
       continue;
-    // Do not delete instructions under here and invalidate the iterator.
+
     // Walk the block backwards for efficiency. We're matching a chain of
     // use->defs, so we're more likely to succeed by starting from the bottom.
     // Also, we want to avoid matching partial patterns.
     // TODO: It would be more efficient if we removed dead instructions
     // iteratively in this loop rather than waiting until the end.
-    for (Instruction &I : llvm::reverse(BB)) {
+    for (Instruction &I : make_early_inc_range(llvm::reverse(BB))) {
       MadeChange |= foldAnyOrAllBitsSet(I);
       MadeChange |= foldGuardedFunnelShift(I, DT);
       MadeChange |= tryToRecognizePopCount(I);
       MadeChange |= tryToFPToSat(I, TTI);
+      MadeChange |= foldSqrt(I, TTI, TLI);
     }
   }
 
@@ -467,7 +514,7 @@
   const DataLayout &DL = F.getParent()->getDataLayout();
   TruncInstCombine TIC(AC, TLI, DL, DT);
   MadeChange |= TIC.run(F);
-  MadeChange |= foldUnusualPatterns(F, DT, TTI);
+  MadeChange |= foldUnusualPatterns(F, DT, TTI, TLI);
 
   return MadeChange;
 }
diff --git a/llvm/test/Transforms/AggressiveInstCombine/X86/sqrt.ll b/llvm/test/Transforms/AggressiveInstCombine/X86/sqrt.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/Transforms/AggressiveInstCombine/X86/sqrt.ll
@@ -0,0 +1,53 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt < %s -passes=aggressive-instcombine -mtriple x86_64-- -S | FileCheck %s
+
+declare float @sqrtf(float)
+declare double @sqrt(double)
+declare fp128 @sqrtl(fp128)
+
+; "nnan" implies no setting of errno and the target can lower this to an
+; instruction, so transform to an intrinsic.
+
+define float @sqrt_call_nnan_f32(float %x) {
+; CHECK-LABEL: @sqrt_call_nnan_f32(
+; CHECK-NEXT:    [[SQRT1:%.*]] = call nnan float @llvm.sqrt.f32(float [[X:%.*]])
+; CHECK-NEXT:    ret float [[SQRT1]]
+;
+  %sqrt = call nnan float @sqrtf(float %x)
+  ret float %sqrt
+}
+
+; Verify that other FMF are propagated to the intrinsic call.
+; We don't care about propagating 'tail' because this is not going to be lowered as a call.
+
+define double @sqrt_call_nnan_f64(double %x) {
+; CHECK-LABEL: @sqrt_call_nnan_f64(
+; CHECK-NEXT:    [[SQRT1:%.*]] = call nnan ninf double @llvm.sqrt.f64(double [[X:%.*]])
+; CHECK-NEXT:    ret double [[SQRT1]]
+;
+  %sqrt = tail call nnan ninf double @sqrt(double %x)
+  ret double %sqrt
+}
+
+; We don't change this because it will be lowered to a call that could
+; theoretically still change errno and affect other accessors of errno.
+
+define fp128 @sqrt_call_nnan_f128(fp128 %x) {
+; CHECK-LABEL: @sqrt_call_nnan_f128(
+; CHECK-NEXT:    [[SQRT:%.*]] = call nnan fp128 @sqrtl(fp128 [[X:%.*]])
+; CHECK-NEXT:    ret fp128 [[SQRT]]
+;
+  %sqrt = call nnan fp128 @sqrtl(fp128 %x)
+  ret fp128 %sqrt
+}
+
+; Don't alter a no-builtin libcall.
+
+define float @sqrt_call_nnan_f32_nobuiltin(float %x) {
+; CHECK-LABEL: @sqrt_call_nnan_f32_nobuiltin(
+; CHECK-NEXT:    [[SQRT:%.*]] = call nnan float @sqrtf(float [[X:%.*]]) #[[ATTR1:[0-9]+]]
+; CHECK-NEXT:    ret float [[SQRT]]
+;
+  %sqrt = call nnan float @sqrtf(float %x) nobuiltin
+  ret float %sqrt
+}
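
Illustration (not part of the patch itself): a minimal before/after sketch of the fold,
based on the f32 test above and assuming a target where TTI.haveFastSqrt() returns true
for the type, so the resulting intrinsic will not be lowered back to a libcall:

  ; before: libcall that codegen must assume may set errno
  %sqrt = call nnan float @sqrtf(float %x)

  ; after: intrinsic call with the original fast-math flags propagated
  %sqrt1 = call nnan float @llvm.sqrt.f32(float %x)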