diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -13994,13 +13994,23 @@
     return SDValue();
   SelectionDAG::FlagInserter FlagsInserter(DAG, N);
 
-  // Try to convert x ** (1/3) into cube root.
+  // Try to convert
+  // a. x ** (1/3) into cube root.
+  // b. x ** (2/3) into the square of the cube root: cbrt(x) * cbrt(x).
   // TODO: Handle the various flavors of long double.
   // TODO: Since we're approximating, we don't need an exact 1/3 exponent.
   //       Some range near 1/3 should be fine.
   EVT VT = N->getValueType(0);
-  if ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(1.0f/3.0f)) ||
-      (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(1.0/3.0))) {
+
+  bool Exponent1by3 =
+      ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(1.0f/3.0f)) ||
+       (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(1.0/3.0)));
+
+  bool Exponent2by3 =
+      ((VT == MVT::f32 && ExponentC->getValueAPF().isExactlyValue(2.0f/3.0f)) ||
+       (VT == MVT::f64 && ExponentC->getValueAPF().isExactlyValue(2.0/3.0)));
+
+  if (Exponent1by3 || Exponent2by3) {
     // pow(-0.0, 1/3) = +0.0; cbrt(-0.0) = -0.0.
     // pow(-inf, 1/3) = +inf; cbrt(-inf) = -inf.
     // pow(-val, 1/3) =  nan; cbrt(-val) = -num.
@@ -14019,16 +14029,25 @@
          DAG.getTargetLoweringInfo().isOperationExpand(ISD::FCBRT, VT)))
       return SDValue();
 
-    return DAG.getNode(ISD::FCBRT, SDLoc(N), VT, N->getOperand(0));
+    SDLoc DL(N);
+    SDValue Cbrt = DAG.getNode(ISD::FCBRT, DL, VT, N->getOperand(0));
+    if (Exponent1by3)
+      return Cbrt;
+    // pow(X, 2/3) --> cbrt(X) * cbrt(X)
+    return DAG.getNode(ISD::FMUL, DL, VT, Cbrt, Cbrt);
   }
 
-  // Try to convert x ** (1/4) and x ** (3/4) into square roots.
+  // Try to convert the following into square roots.
+  // a. x ** (1/4)
+  // b. x ** (3/4)
+  // c. x ** (3/2)
   // x ** (1/2) is canonicalized to sqrt, so we do not bother with that case.
   // TODO: This could be extended (using a target hook) to handle smaller
   // power-of-2 fractional exponents.
   bool ExponentIs025 = ExponentC->getValueAPF().isExactlyValue(0.25);
   bool ExponentIs075 = ExponentC->getValueAPF().isExactlyValue(0.75);
-  if (ExponentIs025 || ExponentIs075) {
+  bool ExponentIs150 = ExponentC->getValueAPF().isExactlyValue(1.50);
+  if (ExponentIs025 || ExponentIs075 || ExponentIs150) {
     // pow(-0.0, 0.25) = +0.0; sqrt(sqrt(-0.0)) = -0.0.
     // pow(-inf, 0.25) = +inf; sqrt(sqrt(-inf)) =  NaN.
     // pow(-0.0, 0.75) = +0.0; sqrt(-0.0) * sqrt(sqrt(-0.0)) = +0.0.
@@ -14055,6 +14074,11 @@
     // pow(X, 0.25) --> sqrt(sqrt(X))
     SDLoc DL(N);
     SDValue Sqrt = DAG.getNode(ISD::FSQRT, DL, VT, N->getOperand(0));
+
+    // pow(X, 1.50) --> X * sqrt(X)
+    if (ExponentIs150)
+      return DAG.getNode(ISD::FMUL, DL, VT, N->getOperand(0), Sqrt);
+
     SDValue SqrtSqrt = DAG.getNode(ISD::FSQRT, DL, VT, Sqrt);
     if (ExponentIs025)
       return SqrtSqrt;
diff --git a/llvm/test/CodeGen/X86/pow.150.ll b/llvm/test/CodeGen/X86/pow.150.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/X86/pow.150.ll
@@ -0,0 +1,42 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-- -debug 2>&1 | FileCheck %s
+; REQUIRES: asserts
+
+declare float @llvm.pow.f32(float, float)
+declare <4 x float> @llvm.pow.v4f32(<4 x float>, <4 x float>)
+declare double @llvm.pow.f64(double, double)
+declare <2 x double> @llvm.pow.v2f64(<2 x double>, <2 x double>)
+
+define float @pow_f32_three_halves_fmf(float %x) nounwind {
+; CHECK:       Combining: {{.*}}: f32 = fpow ninf nsz afn [[X:t[0-9]+]], ConstantFP:f32<1.500000e+00>
+; CHECK-NEXT:  Creating new node: [[SQRT:t[0-9]+]]: f32 = fsqrt ninf nsz afn [[X]]
+; CHECK-NEXT:  Creating new node: [[R:t[0-9]+]]: f32 = fmul ninf nsz afn [[X]], [[SQRT]]
+  %r = call nsz ninf afn float @llvm.pow.f32(float %x, float 1.5e-00)
+  ret float %r
+}
+
+define double @pow_f64_three_halves_fmf(double %x) nounwind {
+; CHECK:       Combining: {{.*}}: f64 = fpow ninf nsz afn [[X:t[0-9]+]], ConstantFP:f64<1.500000e+00>
+; CHECK-NEXT:  Creating new node: [[SQRT:t[0-9]+]]: f64 = fsqrt ninf nsz afn [[X]]
+; CHECK-NEXT:  Creating new node: [[R:t[0-9]+]]: f64 = fmul ninf nsz afn [[X]], [[SQRT]]
+  %r = call nsz ninf afn double @llvm.pow.f64(double %x, double 1.5e-00)
+  ret double %r
+}
+
+define <4 x float> @pow_v4f32_three_halves_fmf(<4 x float> %x) nounwind {
+; CHECK:       Combining: {{.*}}: v4f32 = fpow nnan ninf nsz arcp contract afn reassoc [[X:t[0-9]+]], {{.*}}
+; CHECK-NEXT:  Creating new node: [[SQRT:t[0-9]+]]: v4f32 = fsqrt nnan ninf nsz arcp contract afn reassoc [[X]]
+; CHECK-NEXT:  Creating new node: [[R:t[0-9]+]]: v4f32 = fmul nnan ninf nsz arcp contract afn reassoc [[X]], [[SQRT]]
+; CHECK-NEXT:  ... into: [[R]]: v4f32 = fmul nnan ninf nsz arcp contract afn reassoc [[X]], [[SQRT]]
+  %r = call fast <4 x float> @llvm.pow.v4f32(<4 x float> %x, <4 x float> <float 1.5e+00, float 1.5e+00, float 1.5e+00, float 1.5e+00>)
+  ret <4 x float> %r
+}
+
+define <2 x double> @pow_v2f64_three_halves_fmf(<2 x double> %x) nounwind {
+; CHECK:       Combining: {{.*}}: v2f64 = fpow nnan ninf nsz arcp contract afn reassoc [[X:t[0-9]+]], {{.*}}
+; CHECK-NEXT:  Creating new node: [[SQRT:t[0-9]+]]: v2f64 = fsqrt nnan ninf nsz arcp contract afn reassoc [[X]]
+; CHECK-NEXT:  Creating new node: [[R:t[0-9]+]]: v2f64 = fmul nnan ninf nsz arcp contract afn reassoc [[X]], [[SQRT]]
+; CHECK-NEXT:  ... into: [[R]]: v2f64 = fmul nnan ninf nsz arcp contract afn reassoc [[X]], [[SQRT]]
+  %r = call fast <2 x double> @llvm.pow.v2f64(<2 x double> %x, <2 x double> <double 1.5e+00, double 1.5e+00>)
+  ret <2 x double> %r
+}