Index: llvm/lib/Target/AArch64/AArch64.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64.td
+++ llvm/lib/Target/AArch64/AArch64.td
@@ -106,6 +106,10 @@
     "disable-latency-sched-heuristic", "DisableLatencySchedHeuristic", "true",
     "Disable latency scheduling heuristic">;
 
+def FeatureUseRSqrt : SubtargetFeature<
+    "use-reciprocal-square-root", "UseRSqrt", "true",
+    "Use the reciprocal square root approximation">;
+
 //===----------------------------------------------------------------------===//
 // Architectures.
 //
@@ -227,6 +231,7 @@
                                      FeatureNEON,
                                      FeaturePerfMon,
                                      FeaturePostRAScheduler,
+                                     FeatureUseRSqrt,
                                      FeatureZCZeroing
                                      ]>;
 
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.h
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.h
@@ -187,6 +187,10 @@
   SMULL,
   UMULL,
 
+  // Reciprocal estimates.
+  FRECPE,
+  FRSQRTE,
+
   // NEON Load/Store with post-increment base updates
   LD2post = ISD::FIRST_TARGET_MEMORY_OPCODE,
   LD3post,
@@ -532,6 +536,11 @@
 
   SDValue BuildSDIVPow2(SDNode *N, const APInt &Divisor, SelectionDAG &DAG,
                         std::vector<SDNode *> *Created) const override;
+  SDValue getRsqrtEstimate(SDValue Operand, DAGCombinerInfo &DCI,
+                           unsigned &RefinementSteps,
+                           bool &UseOneConstNR) const override;
+  SDValue getRecipEstimate(SDValue Operand, DAGCombinerInfo &DCI,
+                           unsigned &RefinementSteps) const override;
   unsigned combineRepeatedFPDivisors() const override;
 
   ConstraintType getConstraintType(StringRef Constraint) const override;
Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -33,6 +33,7 @@
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/raw_ostream.h"
 #include "llvm/Target/TargetOptions.h"
+#include "llvm/Target/TargetRecip.h"
 using namespace llvm;
 
 #define DEBUG_TYPE "aarch64-lower"
@@ -634,6 +635,28 @@
     }
   }
 
+  // For the reciprocal estimates, convergence is quadratic, so the number of
+  // digits is doubled after each iteration. In ARMv8, the accuracy of the
+  // initial estimate is 2^-8. Thus the number of extra steps to refine the
+  // result for float (23 mantissa bits) is 2 and for double (52 mantissa bits)
+  // is 3.
+  const unsigned ExtraStepsF = 2,
+                 ExtraStepsD = 3;
+  const bool UseRsqrt = STI.useRSqrt();
+
+  ReciprocalEstimates.set("sqrtf", UseRsqrt, ExtraStepsF);
+  ReciprocalEstimates.set("sqrtd", UseRsqrt, ExtraStepsD);
+  ReciprocalEstimates.set("vec-sqrtf", UseRsqrt, ExtraStepsF);
+  ReciprocalEstimates.set("vec-sqrtd", UseRsqrt, ExtraStepsD);
+
+  // Using the reciprocal estimates for division breaks too many programs in the
+  // wild, so it's unlikely that it should be a feature. It's better left to
+  // user to weigh this choice.
+  ReciprocalEstimates.set("divf", false, ExtraStepsF);
+  ReciprocalEstimates.set("divd", false, ExtraStepsD);
+  ReciprocalEstimates.set("vec-divf", false, ExtraStepsF);
+  ReciprocalEstimates.set("vec-divd", false, ExtraStepsD);
+
   PredictableSelectIsExpensive = Subtarget->predictableSelectIsExpensive();
 }
 
@@ -959,6 +982,8 @@
   case AArch64ISD::ST4LANEpost:       return "AArch64ISD::ST4LANEpost";
   case AArch64ISD::SMULL:             return "AArch64ISD::SMULL";
   case AArch64ISD::UMULL:             return "AArch64ISD::UMULL";
+  case AArch64ISD::FRSQRTE:           return "AArch64ISD::FRSQRTE";
+  case AArch64ISD::FRECPE:            return "AArch64ISD::FRECPE";
   }
   return nullptr;
 }
@@ -4589,6 +4614,51 @@
 // AArch64 Optimization Hooks
 //===----------------------------------------------------------------------===//
 
+/// Return an FRECPE/FRSQRTE estimate node for Operand, or an empty SDValue
+/// when NEON is absent, the type is unsupported or the estimate is disabled.
+static SDValue getEstimate(const AArch64Subtarget &ST,
+  const AArch64TargetLowering::DAGCombinerInfo &DCI, TargetRecip &Recip,
+  unsigned Opcode, const SDValue &Operand, unsigned &ExtraSteps) {
+  if (!ST.hasNEON())
+    return SDValue();
+
+  EVT VT = Operand.getValueType();
+  // Only accept the types that have FRECPE/FRSQRTE selection patterns in
+  // AArch64InstrInfo.td; any other type (e.g. f16, v1f32) would fail to
+  // select, so leave it to the default expansion instead.
+  if (VT != MVT::f64 && VT != MVT::v1f64 && VT != MVT::v2f64 &&
+      VT != MVT::f32 && VT != MVT::v2f32 && VT != MVT::v4f32)
+    return SDValue();
+
+  // Build the key of the estimate setting, e.g. "divf" or "vec-sqrtd".
+  std::string RecipOp = Opcode == AArch64ISD::FRECPE ? "div" : "sqrt";
+  RecipOp.insert(0, VT.isVector() ? "vec-" : "");
+  RecipOp += VT.getScalarType() == MVT::f64 ? "d" : "f";
+
+  if (!Recip.isEnabled(RecipOp))
+    return SDValue();
+
+  ExtraSteps = Recip.getRefinementSteps(RecipOp);
+  return DCI.DAG.getNode(Opcode, SDLoc(Operand), VT, Operand);
+}
+
+SDValue AArch64TargetLowering::getRecipEstimate(SDValue Operand,
+  DAGCombinerInfo &DCI, unsigned &ExtraSteps) const {
+  TargetRecip Recip = getTargetRecipForFunc(DCI.DAG.getMachineFunction());
+
+  return getEstimate(*Subtarget, DCI, Recip, AArch64ISD::FRECPE, Operand,
+                     ExtraSteps);
+}
+
+SDValue AArch64TargetLowering::getRsqrtEstimate(SDValue Operand,
+  DAGCombinerInfo &DCI, unsigned &ExtraSteps, bool &UseOneConst) const {
+  TargetRecip Recip = getTargetRecipForFunc(DCI.DAG.getMachineFunction());
+
+  UseOneConst = true;
+  return getEstimate(*Subtarget, DCI, Recip, AArch64ISD::FRSQRTE, Operand,
+                     ExtraSteps);
+}
+
 //===----------------------------------------------------------------------===//
 // AArch64 Inline Assembly Support
 //===----------------------------------------------------------------------===//
Index: llvm/lib/Target/AArch64/AArch64InstrInfo.td
===================================================================
--- llvm/lib/Target/AArch64/AArch64InstrInfo.td
+++ llvm/lib/Target/AArch64/AArch64InstrInfo.td
@@ -286,6 +286,9 @@
 def AArch64smull    : SDNode<"AArch64ISD::SMULL", SDT_AArch64mull>;
 def AArch64umull    : SDNode<"AArch64ISD::UMULL", SDT_AArch64mull>;
 
+def AArch64frecpe   : SDNode<"AArch64ISD::FRECPE", SDTFPUnaryOp>;
+def AArch64frsqrte  : SDNode<"AArch64ISD::FRSQRTE", SDTFPUnaryOp>;
+
 def AArch64saddv    : SDNode<"AArch64ISD::SADDV", SDT_AArch64UnaryVec>;
 def AArch64uaddv    : SDNode<"AArch64ISD::UADDV", SDT_AArch64UnaryVec>;
 def AArch64sminv    : SDNode<"AArch64ISD::SMINV", SDT_AArch64UnaryVec>;
@@ -3406,6 +3409,19 @@
 def : Pat<(v1f64 (int_aarch64_neon_frecpe (v1f64 FPR64:$Rn))),
           (FRECPEv1i64 FPR64:$Rn)>;
 
+def : Pat<(f32 (AArch64frecpe (f32 FPR32:$Rn))),
+          (FRECPEv1i32 FPR32:$Rn)>;
+def : Pat<(v2f32 (AArch64frecpe (v2f32 V64:$Rn))),
+          (FRECPEv2f32 V64:$Rn)>;
+def : Pat<(v4f32 (AArch64frecpe (v4f32 FPR128:$Rn))),
+          (FRECPEv4f32 FPR128:$Rn)>;
+def : Pat<(f64 (AArch64frecpe (f64 FPR64:$Rn))),
+          (FRECPEv1i64 FPR64:$Rn)>;
+def : Pat<(v1f64 (AArch64frecpe (v1f64 FPR64:$Rn))),
+          (FRECPEv1i64 FPR64:$Rn)>;
+def : Pat<(v2f64 (AArch64frecpe (v2f64 FPR128:$Rn))),
+          (FRECPEv2f64 FPR128:$Rn)>;
+
 def : Pat<(f32 (int_aarch64_neon_frecpx (f32 FPR32:$Rn))),
           (FRECPXv1i32 FPR32:$Rn)>;
 def : Pat<(f64 (int_aarch64_neon_frecpx (f64 FPR64:$Rn))),
@@ -3418,6 +3434,19 @@
 def : Pat<(v1f64 (int_aarch64_neon_frsqrte (v1f64 FPR64:$Rn))),
           (FRSQRTEv1i64 FPR64:$Rn)>;
 
+def : Pat<(f32 (AArch64frsqrte (f32 FPR32:$Rn))),
+          (FRSQRTEv1i32 FPR32:$Rn)>;
+def : Pat<(v2f32 (AArch64frsqrte (v2f32 V64:$Rn))),
+          (FRSQRTEv2f32 V64:$Rn)>;
+def : Pat<(v4f32 (AArch64frsqrte (v4f32 FPR128:$Rn))),
+          (FRSQRTEv4f32 FPR128:$Rn)>;
+def : Pat<(f64 (AArch64frsqrte (f64 FPR64:$Rn))),
+          (FRSQRTEv1i64 FPR64:$Rn)>;
+def : Pat<(v1f64 (AArch64frsqrte (v1f64 FPR64:$Rn))),
+          (FRSQRTEv1i64 FPR64:$Rn)>;
+def : Pat<(v2f64 (AArch64frsqrte (v2f64 FPR128:$Rn))),
+          (FRSQRTEv2f64 FPR128:$Rn)>;
+
 // If an integer is about to be converted to a floating point value,
 // just load it on the floating point unit.
 // Here are the patterns for 8 and 16-bits to float.
Index: llvm/lib/Target/AArch64/AArch64Subtarget.h
===================================================================
--- llvm/lib/Target/AArch64/AArch64Subtarget.h
+++ llvm/lib/Target/AArch64/AArch64Subtarget.h
@@ -83,6 +83,7 @@
   bool HasArithmeticBccFusion = false;
   bool HasArithmeticCbzFusion = false;
   bool DisableLatencySchedHeuristic = false;
+  bool UseRSqrt = false;
   uint8_t MaxInterleaveFactor = 2;
   uint8_t VectorInsertExtractBaseCost = 3;
   uint16_t CacheLineSize = 0;
@@ -191,6 +192,7 @@
   }
   bool hasArithmeticBccFusion() const { return HasArithmeticBccFusion; }
   bool hasArithmeticCbzFusion() const { return HasArithmeticCbzFusion; }
+  bool useRSqrt() const { return UseRSqrt; }
   unsigned getMaxInterleaveFactor() const { return MaxInterleaveFactor; }
   unsigned getVectorInsertExtractBaseCost() const {
     return VectorInsertExtractBaseCost;
Index: llvm/test/CodeGen/AArch64/recp-fastmath.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/recp-fastmath.ll
@@ -0,0 +1,148 @@
+; RUN: llc < %s -mtriple=aarch64-unknown-linux-gnu -mattr=+neon | FileCheck %s
+
+define float @frecp0(float %x) #0 {
+  %div = fdiv fast float 1.0, %x
+  ret float %div
+
+; CHECK-LABEL: frecp0:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: fdiv
+}
+
+define float @frecp1(float %x) #1 {
+  %div = fdiv fast float 1.0, %x
+  ret float %div
+
+; CHECK-LABEL: frecp1:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: frecpe
+; CHECK-NEXT: fmov
+}
+
+define <2 x float> @f2recp0(<2 x float> %x) #0 {
+  %div = fdiv fast <2 x float> <float 1.0, float 1.0>, %x
+  ret <2 x float> %div
+
+; CHECK-LABEL: f2recp0:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: fdiv
+}
+
+define <2 x float> @f2recp1(<2 x float> %x) #1 {
+  %div = fdiv fast <2 x float> <float 1.0, float 1.0>, %x
+  ret <2 x float> %div
+
+; CHECK-LABEL: f2recp1:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: frecpe
+}
+
+define <4 x float> @f4recp0(<4 x float> %x) #0 {
+  %div = fdiv fast <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>, %x
+  ret <4 x float> %div
+
+; CHECK-LABEL: f4recp0:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: fdiv
+}
+
+define <4 x float> @f4recp1(<4 x float> %x) #1 {
+  %div = fdiv fast <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>, %x
+  ret <4 x float> %div
+
+; CHECK-LABEL: f4recp1:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: frecpe
+}
+
+define <8 x float> @f8recp0(<8 x float> %x) #0 {
+  %div = fdiv fast <8 x float> <float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0>, %x
+  ret <8 x float> %div
+
+; CHECK-LABEL: f8recp0:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: fdiv
+; CHECK-NEXT: fdiv
+}
+
+define <8 x float> @f8recp1(<8 x float> %x) #1 {
+  %div = fdiv fast <8 x float> <float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0, float 1.0>, %x
+  ret <8 x float> %div
+
+; CHECK-LABEL: f8recp1:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: frecpe
+; CHECK: frecpe
+}
+
+define double @drecp0(double %x) #0 {
+  %div = fdiv fast double 1.0, %x
+  ret double %div
+
+; CHECK-LABEL: drecp0:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: fdiv
+}
+
+define double @drecp1(double %x) #1 {
+  %div = fdiv fast double 1.0, %x
+  ret double %div
+
+; CHECK-LABEL: drecp1:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: frecpe
+; CHECK-NEXT: fmov
+}
+
+define <2 x double> @d2recp0(<2 x double> %x) #0 {
+  %div = fdiv fast <2 x double> <double 1.0, double 1.0>, %x
+  ret <2 x double> %div
+
+; CHECK-LABEL: d2recp0:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: fdiv
+}
+
+define <2 x double> @d2recp1(<2 x double> %x) #1 {
+  %div = fdiv fast <2 x double> <double 1.0, double 1.0>, %x
+  ret <2 x double> %div
+
+; CHECK-LABEL: d2recp1:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: frecpe
+}
+
+define <4 x double> @d4recp0(<4 x double> %x) #0 {
+  %div = fdiv fast <4 x double> <double 1.0, double 1.0, double 1.0, double 1.0>, %x
+  ret <4 x double> %div
+
+; CHECK-LABEL: d4recp0:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: fdiv
+; CHECK-NEXT: fdiv
+}
+
+define <4 x double> @d4recp1(<4 x double> %x) #1 {
+  %div = fdiv fast <4 x double> <double 1.0, double 1.0, double 1.0, double 1.0>, %x
+  ret <4 x double> %div
+
+; CHECK-LABEL: d4recp1:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: frecpe
+; CHECK: frecpe
+}
+
+attributes #0 = { nounwind "unsafe-fp-math"="true" }
+attributes #1 = { nounwind "unsafe-fp-math"="true" "reciprocal-estimates"="div,vec-div" }
Index: llvm/test/CodeGen/AArch64/sqrt-fastmath.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/AArch64/sqrt-fastmath.ll
@@ -0,0 +1,228 @@
+; RUN: llc < %s -mtriple=aarch64-unknown-linux-gnu -mattr=+neon,-use-reciprocal-square-root | FileCheck %s --check-prefix=FAULT
+; RUN: llc < %s -mtriple=aarch64-unknown-linux-gnu -mattr=+neon,+use-reciprocal-square-root | FileCheck %s
+
+declare float @llvm.sqrt.f32(float) #0
+declare <2 x float> @llvm.sqrt.v2f32(<2 x float>) #0
+declare <4 x float> @llvm.sqrt.v4f32(<4 x float>) #0
+declare <8 x float> @llvm.sqrt.v8f32(<8 x float>) #0
+declare double @llvm.sqrt.f64(double) #0
+declare <2 x double> @llvm.sqrt.v2f64(<2 x double>) #0
+declare <4 x double> @llvm.sqrt.v4f64(<4 x double>) #0
+
+define float @fsqrt(float %a) #0 {
+  %1 = tail call fast float @llvm.sqrt.f32(float %a)
+  ret float %1
+
+; FAULT-LABEL: fsqrt:
+; FAULT-NEXT: BB#0
+; FAULT-NEXT: fsqrt
+
+; CHECK-LABEL: fsqrt:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: frsqrte
+}
+
+define <2 x float> @f2sqrt(<2 x float> %a) #0 {
+  %1 = tail call fast <2 x float> @llvm.sqrt.v2f32(<2 x float> %a)
+  ret <2 x float> %1
+
+; FAULT-LABEL: f2sqrt:
+; FAULT-NEXT: BB#0
+; FAULT-NEXT: fsqrt
+
+; CHECK-LABEL: f2sqrt:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: mov
+; CHECK-NEXT: frsqrte
+}
+
+define <4 x float> @f4sqrt(<4 x float> %a) #0 {
+  %1 = tail call fast <4 x float> @llvm.sqrt.v4f32(<4 x float> %a)
+  ret <4 x float> %1
+
+; FAULT-LABEL: f4sqrt:
+; FAULT-NEXT: BB#0
+; FAULT-NEXT: fsqrt
+
+; CHECK-LABEL: f4sqrt:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: mov
+; CHECK-NEXT: frsqrte
+}
+
+define <8 x float> @f8sqrt(<8 x float> %a) #0 {
+  %1 = tail call fast <8 x float> @llvm.sqrt.v8f32(<8 x float> %a)
+  ret <8 x float> %1
+
+; FAULT-LABEL: f8sqrt:
+; FAULT-NEXT: BB#0
+; FAULT-NEXT: fsqrt
+; FAULT-NEXT: fsqrt
+
+; CHECK-LABEL: f8sqrt:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: mov
+; CHECK-NEXT: frsqrte
+; CHECK: frsqrte
+}
+
+define double @dsqrt(double %a) #0 {
+  %1 = tail call fast double @llvm.sqrt.f64(double %a)
+  ret double %1
+
+; FAULT-LABEL: dsqrt:
+; FAULT-NEXT: BB#0
+; FAULT-NEXT: fsqrt
+
+; CHECK-LABEL: dsqrt:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: frsqrte
+}
+
+define <2 x double> @d2sqrt(<2 x double> %a) #0 {
+  %1 = tail call fast <2 x double> @llvm.sqrt.v2f64(<2 x double> %a)
+  ret <2 x double> %1
+
+; FAULT-LABEL: d2sqrt:
+; FAULT-NEXT: BB#0
+; FAULT-NEXT: fsqrt
+
+; CHECK-LABEL: d2sqrt:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: mov
+; CHECK-NEXT: frsqrte
+}
+
+define <4 x double> @d4sqrt(<4 x double> %a) #0 {
+  %1 = tail call fast <4 x double> @llvm.sqrt.v4f64(<4 x double> %a)
+  ret <4 x double> %1
+
+; FAULT-LABEL: d4sqrt:
+; FAULT-NEXT: BB#0
+; FAULT-NEXT: fsqrt
+; FAULT-NEXT: fsqrt
+
+; CHECK-LABEL: d4sqrt:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: mov
+; CHECK-NEXT: frsqrte
+; CHECK: frsqrte
+}
+
+define float @frsqrt(float %a) #0 {
+  %1 = tail call fast float @llvm.sqrt.f32(float %a)
+  %2 = fdiv fast float 1.000000e+00, %1
+  ret float %2
+
+; FAULT-LABEL: frsqrt:
+; FAULT-NEXT: BB#0
+; FAULT-NEXT: fsqrt
+
+; CHECK-LABEL: frsqrt:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: frsqrte
+}
+
+define <2 x float> @f2rsqrt(<2 x float> %a) #0 {
+  %1 = tail call fast <2 x float> @llvm.sqrt.v2f32(<2 x float> %a)
+  %2 = fdiv fast <2 x float> <float 1.000000e+00, float 1.000000e+00>, %1
+  ret <2 x float> %2
+
+; FAULT-LABEL: f2rsqrt:
+; FAULT-NEXT: BB#0
+; FAULT-NEXT: fsqrt
+
+; CHECK-LABEL: f2rsqrt:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: frsqrte
+}
+
+define <4 x float> @f4rsqrt(<4 x float> %a) #0 {
+  %1 = tail call fast <4 x float> @llvm.sqrt.v4f32(<4 x float> %a)
+  %2 = fdiv fast <4 x float> <float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00>, %1
+  ret <4 x float> %2
+
+; FAULT-LABEL: f4rsqrt:
+; FAULT-NEXT: BB#0
+; FAULT-NEXT: fsqrt
+
+; CHECK-LABEL: f4rsqrt:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: frsqrte
+}
+
+define <8 x float> @f8rsqrt(<8 x float> %a) #0 {
+  %1 = tail call fast <8 x float> @llvm.sqrt.v8f32(<8 x float> %a)
+  %2 = fdiv fast <8 x float> <float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00, float 1.000000e+00>, %1
+  ret <8 x float> %2
+
+; FAULT-LABEL: f8rsqrt:
+; FAULT-NEXT: BB#0
+; FAULT-NEXT: fsqrt
+; FAULT-NEXT: fsqrt
+
+; CHECK-LABEL: f8rsqrt:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: frsqrte
+; CHECK: frsqrte
+}
+
+define double @drsqrt(double %a) #0 {
+  %1 = tail call fast double @llvm.sqrt.f64(double %a)
+  %2 = fdiv fast double 1.000000e+00, %1
+  ret double %2
+
+; FAULT-LABEL: drsqrt:
+; FAULT-NEXT: BB#0
+; FAULT-NEXT: fsqrt
+
+; CHECK-LABEL: drsqrt:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: frsqrte
+}
+
+define <2 x double> @d2rsqrt(<2 x double> %a) #0 {
+  %1 = tail call fast <2 x double> @llvm.sqrt.v2f64(<2 x double> %a)
+  %2 = fdiv fast <2 x double> <double 1.000000e+00, double 1.000000e+00>, %1
+  ret <2 x double> %2
+
+; FAULT-LABEL: d2rsqrt:
+; FAULT-NEXT: BB#0
+; FAULT-NEXT: fsqrt
+
+; CHECK-LABEL: d2rsqrt:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: frsqrte
+}
+
+define <4 x double> @d4rsqrt(<4 x double> %a) #0 {
+  %1 = tail call fast <4 x double> @llvm.sqrt.v4f64(<4 x double> %a)
+  %2 = fdiv fast <4 x double> <double 1.000000e+00, double 1.000000e+00, double 1.000000e+00, double 1.000000e+00>, %1
+  ret <4 x double> %2
+
+; FAULT-LABEL: d4rsqrt:
+; FAULT-NEXT: BB#0
+; FAULT-NEXT: fsqrt
+; FAULT-NEXT: fsqrt
+
+; CHECK-LABEL: d4rsqrt:
+; CHECK-NEXT: BB#0
+; CHECK-NEXT: fmov
+; CHECK-NEXT: frsqrte
+; CHECK: frsqrte
+}
+
+attributes #0 = { nounwind "unsafe-fp-math"="true" }