Index: lib/Target/AArch64/AArch64ISelLowering.cpp
===================================================================
--- lib/Target/AArch64/AArch64ISelLowering.cpp
+++ lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -7614,52 +7614,86 @@
   // future CPUs have a cheaper MADD instruction, this may need to be
   // gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and
   // 64-bit is 5 cycles, so this is always a win.
-  if (ConstantSDNode *C = dyn_cast<ConstantSDNode>(N->getOperand(1))) {
-    const APInt &Value = C->getAPIntValue();
-    EVT VT = N->getValueType(0);
-    SDLoc DL(N);
-    if (Value.isNonNegative()) {
+  // More aggressively, some multiplications can be lowered to shift+add+shift
+  // if the constant is (2^N + 1) * 2^M.
+  // TODO: consider constants in the form of (2^N - 1) * 2^M, (1 - 2^N) * 2^M,
+  // -(2^N + 1) * 2^M, +/-(2^N +/- 1) * 2^M +/- 1, or (2^N +/- 1) * (2^M +/- 1).
+  auto *C = dyn_cast<ConstantSDNode>(N->getOperand(1));
+  if (!C)
+    return SDValue();
+
+  const APInt &Value = C->getAPIntValue();
+  EVT VT = N->getValueType(0);
+  SDLoc DL(N);
+  SDValue N0 = N->getOperand(0);
+  // Lg2 is used to test whether the mul can be lowered to shift+add+shift.
+  unsigned Lg2 = Value.countTrailingZeros();
+  if (Lg2) {
+    // Conservatively do not lower to shift+add+shift if the mul might be
+    // folded into smull or umull.
+    if (N0->hasOneUse() && (isSignExtended(N0.getNode(), DAG) ||
+                            isZeroExtended(N0.getNode(), DAG)))
+      return SDValue();
+    // Conservatively do not lower to shift+add+shift if the mul might be
+    // folded into madd or msub.
+    if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ADD ||
+                           N->use_begin()->getOpcode() == ISD::SUB))
+      return SDValue();
+  }
+  APInt ShiftedInt = Value.ashr(Lg2);
+
+  APInt IntValue;
+  unsigned AddOrSub;
+  bool IsN0FirstOperand;
+  bool ExtraNeg;
+  if (Value.isNonNegative()) {
+    APInt VM1 = ShiftedInt - 1;
+    APInt VP1 = Value + 1;
+    if (VM1.isPowerOf2()) {
       // (mul x, 2^N + 1) => (add (shl x, N), x)
-      APInt VM1 = Value - 1;
-      if (VM1.isPowerOf2()) {
-        SDValue ShiftedVal =
-            DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
-                        DAG.getConstant(VM1.logBase2(), DL, MVT::i64));
-        return DAG.getNode(ISD::ADD, DL, VT, ShiftedVal,
-                           N->getOperand(0));
-      }
+      // Or (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
+      IntValue = VM1;
+      AddOrSub = ISD::ADD;
+      IsN0FirstOperand = false;
+      ExtraNeg = false;
+    } else if (VP1.isPowerOf2()) {
       // (mul x, 2^N - 1) => (sub (shl x, N), x)
-      APInt VP1 = Value + 1;
-      if (VP1.isPowerOf2()) {
-        SDValue ShiftedVal =
-            DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
-                        DAG.getConstant(VP1.logBase2(), DL, MVT::i64));
-        return DAG.getNode(ISD::SUB, DL, VT, ShiftedVal,
-                           N->getOperand(0));
-      }
-    } else {
+      IntValue = VP1;
+      AddOrSub = ISD::SUB;
+      IsN0FirstOperand = false;
+      ExtraNeg = false;
+    } else
+      return SDValue();
+  } else {
+    APInt VNP1 = -Value + 1;
+    APInt VNM1 = -Value - 1;
+    if (VNP1.isPowerOf2()) {
      // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
-      APInt VNP1 = -Value + 1;
-      if (VNP1.isPowerOf2()) {
-        SDValue ShiftedVal =
-            DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
-                        DAG.getConstant(VNP1.logBase2(), DL, MVT::i64));
-        return DAG.getNode(ISD::SUB, DL, VT, N->getOperand(0),
-                           ShiftedVal);
-      }
-      // (mul x, -(2^N + 1)) => - (add (shl x, N), x)
-      APInt VNM1 = -Value - 1;
-      if (VNM1.isPowerOf2()) {
-        SDValue ShiftedVal =
-            DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
-                        DAG.getConstant(VNM1.logBase2(), DL, MVT::i64));
-        SDValue Add =
-            DAG.getNode(ISD::ADD, DL, VT, ShiftedVal, N->getOperand(0));
-        return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), Add);
-      }
-    }
+      IntValue = VNP1;
+      AddOrSub = ISD::SUB;
+      IsN0FirstOperand = true;
+      ExtraNeg = false;
+    } else if (VNM1.isPowerOf2()) {
+      // (mul x, -(2^N + 1)) => -(add (shl x, N), x)
+      IntValue = VNM1;
+      AddOrSub = ISD::ADD;
+      IsN0FirstOperand = false;
+      ExtraNeg = true;
+    } else
+      return SDValue();
   }
-  return SDValue();
+  SDValue ShiftedVal = DAG.getNode(
+      ISD::SHL, DL, VT, N0, DAG.getConstant(IntValue.logBase2(), DL, MVT::i64));
+  SDValue AddOrSubVal =
+      DAG.getNode(AddOrSub, DL, VT, IsN0FirstOperand ? N0 : ShiftedVal,
+                  IsN0FirstOperand ? ShiftedVal : N0);
+  if (Lg2)
+    return DAG.getNode(ISD::SHL, DL, VT, AddOrSubVal,
+                       DAG.getConstant(Lg2, DL, MVT::i64));
+  if (ExtraNeg)
+    return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT),
+                       AddOrSubVal);
+  return AddOrSubVal;
 }
 
 static SDValue performVectorCompareAndMaskUnaryOpCombine(SDNode *N,
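For review convenience, here is a standalone model of the decomposition the new code performs; it may help when auditing the constants exercised by the tests below. This is an illustrative sketch, not part of the patch: decomposeMulConstant and MulShiftAddForm are hypothetical names, plain 64-bit integers stand in for APInt, and the one-use checks that keep a mul foldable into smull/umull or madd/msub are not modeled.

#include <bit>
#include <cassert>
#include <cstdint>
#include <optional>

// Mirrors the VM1/VP1/VNP1/VNM1 classification above: a multiplier C of the
// form (2^N + 1) * 2^M, 2^N - 1, -(2^N - 1), or -(2^N + 1) lowers to
//   ((x << N) +/- x) << M, optionally negated.
struct MulShiftAddForm {
  unsigned N;   // inner shift amount
  unsigned M;   // outer shift amount (Lg2 in the patch; only the first form
                // can have M != 0)
  bool IsAdd;   // (x << N) + x  versus a sub form
  bool XFirst;  // sub operand order x - (x << N), used for -(2^N - 1)
  bool Negate;  // wrap the result as 0 - (...), used for -(2^N + 1)
};

std::optional<MulShiftAddForm> decomposeMulConstant(int64_t C) {
  if (C > 0) {
    uint64_t U = static_cast<uint64_t>(C);
    unsigned M = std::countr_zero(U);       // Lg2
    uint64_t Shifted = U >> M;              // ShiftedInt
    if (std::has_single_bit(Shifted - 1))   // VM1: (2^N + 1) * 2^M
      return MulShiftAddForm{
          static_cast<unsigned>(std::countr_zero(Shifted - 1)), M,
          /*IsAdd=*/true, /*XFirst=*/false, /*Negate=*/false};
    if (std::has_single_bit(U + 1))         // VP1: 2^N - 1 (odd, so M == 0)
      return MulShiftAddForm{
          static_cast<unsigned>(std::countr_zero(U + 1)), 0,
          /*IsAdd=*/false, /*XFirst=*/false, /*Negate=*/false};
  } else if (C < 0 && C != INT64_MIN) {
    uint64_t NegC = static_cast<uint64_t>(-C);
    if (std::has_single_bit(NegC + 1))      // VNP1: -(2^N - 1)
      return MulShiftAddForm{
          static_cast<unsigned>(std::countr_zero(NegC + 1)), 0,
          /*IsAdd=*/false, /*XFirst=*/true, /*Negate=*/false};
    if (std::has_single_bit(NegC - 1))      // VNM1: -(2^N + 1)
      return MulShiftAddForm{
          static_cast<unsigned>(std::countr_zero(NegC - 1)), 0,
          /*IsAdd=*/true, /*XFirst=*/false, /*Negate=*/true};
  }
  return std::nullopt;  // stays a mul (pure powers of two become shifts elsewhere)
}

int main() {
  // 6 = (2^1 + 1) * 2^1: add w8, w0, w0, lsl #1; lsl w0, w8, #1 (test6_32b)
  auto D6 = decomposeMulConstant(6);
  assert(D6 && D6->N == 1 && D6->M == 1 && D6->IsAdd);
  // 11 matches none of the forms and stays a mul (test11)
  assert(!decomposeMulConstant(11));
  // -15 = -(2^4 - 1): sub w0, w0, w0, lsl #4 (ntest15)
  auto DN15 = decomposeMulConstant(-15);
  assert(DN15 && DN15->N == 4 && !DN15->IsAdd && DN15->XFirst);
}

Note the asymmetry the TODO records: only the positive (2^N + 1) * 2^M form accepts a nonzero outer shift, so negative even multipliers such as -6 and -12 still lower to a plain mul, which the ntest6 and ntest12 cases below pin down.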
Index: test/CodeGen/AArch64/mul_pow2.ll
===================================================================
--- test/CodeGen/AArch64/mul_pow2.ll
+++ test/CodeGen/AArch64/mul_pow2.ll
@@ -2,6 +2,8 @@
 
 ; Convert mul x, pow2 to shift.
 ; Convert mul x, pow2 +/- 1 to shift + add/sub.
+; Convert mul x, (pow2 + 1) * pow2 to shift + add + shift.
+; Lowering other positive constants is not supported yet.
 
 define i32 @test2(i32 %x) {
 ; CHECK-LABEL: test2
@@ -36,6 +38,122 @@
   ret i32 %mul
 }
 
+define i32 @test6_32b(i32 %x) {
+; CHECK-LABEL: test6_32b
+; CHECK: add w8, w0, w0, lsl #1
+; CHECK: lsl w0, w8, #1
+
+  %mul = mul nsw i32 %x, 6
+  ret i32 %mul
+}
+
+define i64 @test6_64b(i64 %x) {
+; CHECK-LABEL: test6_64b
+; CHECK: add x8, x0, x0, lsl #1
+; CHECK: lsl x0, x8, #1
+
+  %mul = mul nsw i64 %x, 6
+  ret i64 %mul
+}
+
+; A mul fed by a sign/zero extend or used only by an add/sub is left as a mul,
+; so that it can still fold into the smull/umull and madd/msub forms below.
+define i64 @test6_umull(i32 %x) {
+; CHECK-LABEL: test6_umull
+; CHECK: umull x0, w0, w8
+
+  %ext = zext i32 %x to i64
+  %mul = mul nsw i64 %ext, 6
+  ret i64 %mul
+}
+
+define i64 @test6_smull(i32 %x) {
+; CHECK-LABEL: test6_smull
+; CHECK: smull x0, w0, w8
+
+  %ext = sext i32 %x to i64
+  %mul = mul nsw i64 %ext, 6
+  ret i64 %mul
+}
+
+define i32 @test6_madd(i32 %x, i32 %y) {
+; CHECK-LABEL: test6_madd
+; CHECK: madd w0, w0, w8, w1
+
+  %mul = mul nsw i32 %x, 6
+  %add = add i32 %mul, %y
+  ret i32 %add
+}
+
+define i32 @test6_msub(i32 %x, i32 %y) {
+; CHECK-LABEL: test6_msub
+; CHECK: msub w0, w0, w8, w1
+
+  %mul = mul nsw i32 %x, 6
+  %sub = sub i32 %y, %mul
+  ret i32 %sub
+}
+
+define i64 @test6_umaddl(i32 %x, i64 %y) {
+; CHECK-LABEL: test6_umaddl
+; CHECK: umaddl x0, w0, w8, x1
+
+  %ext = zext i32 %x to i64
+  %mul = mul nsw i64 %ext, 6
+  %add = add i64 %mul, %y
+  ret i64 %add
+}
+
+define i64 @test6_smaddl(i32 %x, i64 %y) {
+; CHECK-LABEL: test6_smaddl
+; CHECK: smaddl x0, w0, w8, x1
+
+  %ext = sext i32 %x to i64
+  %mul = mul nsw i64 %ext, 6
+  %add = add i64 %mul, %y
+  ret i64 %add
+}
+
+define i64 @test6_umsubl(i32 %x, i64 %y) {
+; CHECK-LABEL: test6_umsubl
+; CHECK: umsubl x0, w0, w8, x1
+
+  %ext = zext i32 %x to i64
+  %mul = mul nsw i64 %ext, 6
+  %sub = sub i64 %y, %mul
+  ret i64 %sub
+}
+
+define i64 @test6_smsubl(i32 %x, i64 %y) {
+; CHECK-LABEL: test6_smsubl
+; CHECK: smsubl x0, w0, w8, x1
+
+  %ext = sext i32 %x to i64
+  %mul = mul nsw i64 %ext, 6
+  %sub = sub i64 %y, %mul
+  ret i64 %sub
+}
+
+define i64 @test6_umnegl(i32 %x) {
+; CHECK-LABEL: test6_umnegl
+; CHECK: umnegl x0, w0, w8
+
+  %ext = zext i32 %x to i64
+  %mul = mul nsw i64 %ext, 6
+  %sub = sub i64 0, %mul
+  ret i64 %sub
+}
+
+define i64 @test6_smnegl(i32 %x) {
+; CHECK-LABEL: test6_smnegl
+; CHECK: smnegl x0, w0, w8
+
+  %ext = sext i32 %x to i64
+  %mul = mul nsw i64 %ext, 6
+  %sub = sub i64 0, %mul
+  ret i64 %sub
+}
+
 define i32 @test7(i32 %x) {
 ; CHECK-LABEL: test7
 ; CHECK: lsl {{w[0-9]+}}, w0, #3
@@ -57,12 +175,72 @@
 ; CHECK-LABEL: test9
 ; CHECK: add w0, w0, w0, lsl #3
 
-  %mul = mul nsw i32 %x, 9
+  %mul = mul nsw i32 %x, 9
   ret i32 %mul
 }
 
+define i32 @test10(i32 %x) {
+; CHECK-LABEL: test10
+; CHECK: add w8, w0, w0, lsl #2
+; CHECK: lsl w0, w8, #1
+
+  %mul = mul nsw i32 %x, 10
+  ret i32 %mul
+}
+
+define i32 @test11(i32 %x) {
+; CHECK-LABEL: test11
+; CHECK: mul w0, w0, w8
+
+  %mul = mul nsw i32 %x, 11
+  ret i32 %mul
+}
+
+define i32 @test12(i32 %x) {
+; CHECK-LABEL: test12
+; CHECK: add w8, w0, w0, lsl #1
+; CHECK: lsl w0, w8, #2
+
+  %mul = mul nsw i32 %x, 12
+  ret i32 %mul
+}
+
+define i32 @test13(i32 %x) {
+; CHECK-LABEL: test13
+; CHECK: mul w0, w0, w8
+
+  %mul = mul nsw i32 %x, 13
+  ret i32 %mul
+}
+
+define i32 @test14(i32 %x) {
+; CHECK-LABEL: test14
+; CHECK: mul w0, w0, w8
+
+  %mul = mul nsw i32 %x, 14
+  ret i32 %mul
+}
+
+define i32 @test15(i32 %x) {
+; CHECK-LABEL: test15
+; CHECK: lsl {{w[0-9]+}}, w0, #4
+; CHECK: sub w0, {{w[0-9]+}}, w0
+
+  %mul = mul nsw i32 %x, 15
+  ret i32 %mul
+}
+
+define i32 @test16(i32 %x) {
+; CHECK-LABEL: test16
+; CHECK: lsl w0, w0, #4
+
+  %mul = mul nsw i32 %x, 16
+  ret i32 %mul
+}
+
 ; Convert mul x, -pow2 to shift.
 ; Convert mul x, -(pow2 +/- 1) to shift + add/sub.
+; Lowering other negative constants is not supported yet.
 
 define i32 @ntest2(i32 %x) {
 ; CHECK-LABEL: ntest2
@@ -96,6 +274,14 @@
   ret i32 %mul
 }
 
+define i32 @ntest6(i32 %x) {
+; CHECK-LABEL: ntest6
+; CHECK: mul w0, w0, w8
+
+  %mul = mul nsw i32 %x, -6
+  ret i32 %mul
+}
+
 define i32 @ntest7(i32 %x) {
 ; CHECK-LABEL: ntest7
 ; CHECK: sub w0, w0, w0, lsl #3
@@ -120,3 +306,59 @@
   %mul = mul nsw i32 %x, -9
   ret i32 %mul
 }
+
+define i32 @ntest10(i32 %x) {
+; CHECK-LABEL: ntest10
+; CHECK: mul w0, w0, w8
+
+  %mul = mul nsw i32 %x, -10
+  ret i32 %mul
+}
+
+define i32 @ntest11(i32 %x) {
+; CHECK-LABEL: ntest11
+; CHECK: mul w0, w0, w8
+
+  %mul = mul nsw i32 %x, -11
+  ret i32 %mul
+}
+
+define i32 @ntest12(i32 %x) {
+; CHECK-LABEL: ntest12
+; CHECK: mul w0, w0, w8
+
+  %mul = mul nsw i32 %x, -12
+  ret i32 %mul
+}
+
+define i32 @ntest13(i32 %x) {
+; CHECK-LABEL: ntest13
+; CHECK: mul w0, w0, w8
+
+  %mul = mul nsw i32 %x, -13
+  ret i32 %mul
+}
+
+define i32 @ntest14(i32 %x) {
+; CHECK-LABEL: ntest14
+; CHECK: mul w0, w0, w8
+
+  %mul = mul nsw i32 %x, -14
+  ret i32 %mul
+}
+
+define i32 @ntest15(i32 %x) {
+; CHECK-LABEL: ntest15
+; CHECK: sub w0, w0, w0, lsl #4
+
+  %mul = mul nsw i32 %x, -15
+  ret i32 %mul
+}
+
+define i32 @ntest16(i32 %x) {
+; CHECK-LABEL: ntest16
+; CHECK: neg w0, w0, lsl #4
+
+  %mul = mul nsw i32 %x, -16
+  ret i32 %mul
+}
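
As a final cross-check, the identities behind the CHECK lines above can be verified exhaustively with plain integers. This snippet is illustrative C++20 (which defines left shifts of negative values), not part of the patch:

#include <cassert>
#include <cstdint>

// Each lowered shift/add sequence must compute exactly x * C.
int main() {
  for (int64_t x = -1000; x <= 1000; ++x) {
    assert(((x + (x << 1)) << 1) == 6 * x);   // test6_32b: add, lsl #1
    assert(((x + (x << 2)) << 1) == 10 * x);  // test10: add, lsl #1
    assert(((x + (x << 1)) << 2) == 12 * x);  // test12: add, lsl #2
    assert(((x << 4) - x) == 15 * x);         // test15: lsl #4, sub
    assert((x - (x << 3)) == -7 * x);         // ntest7: sub w0, w0, w0, lsl #3
    assert((x - (x << 4)) == -15 * x);        // ntest15: sub w0, w0, w0, lsl #4
    assert((0 - (x << 4)) == -16 * x);        // ntest16: neg w0, w0, lsl #4
  }
}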