Index: lib/Target/AArch64/AArch64ISelLowering.cpp =================================================================== --- lib/Target/AArch64/AArch64ISelLowering.cpp +++ lib/Target/AArch64/AArch64ISelLowering.cpp @@ -7614,52 +7614,92 @@ // future CPUs have a cheaper MADD instruction, this may need to be // gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and // 64-bit is 5 cycles, so this is always a win. - if (ConstantSDNode *C = dyn_cast(N->getOperand(1))) { - const APInt &Value = C->getAPIntValue(); - EVT VT = N->getValueType(0); - SDLoc DL(N); - if (Value.isNonNegative()) { + // More aggressively, some multiplications Var * C can be lowered to + // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M, + // e.g. 6=3*2=(2+1)*2. + // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45 + // which equals to (1+2)*16-(1+2). + auto C = dyn_cast(N->getOperand(1)); + if (!C) + return SDValue(); + + const APInt &ValueOfC = C->getAPIntValue(); + EVT VT = N->getValueType(0); + SDLoc DL(N); + SDValue Var = N->getOperand(0); + // TrailingZeroes is used to test if the mul can be lowered to + // shift+add+shift. + unsigned TrailingZeroes = ValueOfC.countTrailingZeros(); + if (TrailingZeroes) { + // Conservatively do not lower to shift+add+shift if the mul might be + // folded into smul or umul. + if (Var->hasOneUse() && (isSignExtended(Var.getNode(), DAG) || + isZeroExtended(Var.getNode(), DAG))) + return SDValue(); + // Conservatively do not lower to shift+add+shift if the mul might be + // folded into madd or msub. + if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ADD || + N->use_begin()->getOpcode() == ISD::SUB)) + return SDValue(); + } + APInt ConstantA = ValueOfC.ashr(TrailingZeroes); + + APInt IntValue; + unsigned Operation; + // SwapValues decides (Var - ShiftedValue) or (ShiftedValue - Var). It does + // not matter if the operation is Add. 
+ bool SwapValues; + // ExtraNeg decides if a Neg is needed at last if C is negative. + bool ExtraNeg; + if (ValueOfC.isNonNegative()) { + // add+shl+add is supported. Use ConstantA instead of ValueOfC. + APInt ConstantAMinus1 = ConstantA - 1; + APInt ValueOfCPlus1 = ValueOfC + 1; + SwapValues = false; + ExtraNeg = false; + if (ConstantAMinus1.isPowerOf2()) { // (mul x, 2^N + 1) => (add (shl x, N), x) - APInt VM1 = Value - 1; - if (VM1.isPowerOf2()) { - SDValue ShiftedVal = - DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0), - DAG.getConstant(VM1.logBase2(), DL, MVT::i64)); - return DAG.getNode(ISD::ADD, DL, VT, ShiftedVal, - N->getOperand(0)); - } + // Or (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M) + IntValue = ConstantAMinus1; + Operation = ISD::ADD; + } else if (ValueOfCPlus1.isPowerOf2()) { // (mul x, 2^N - 1) => (sub (shl x, N), x) - APInt VP1 = Value + 1; - if (VP1.isPowerOf2()) { - SDValue ShiftedVal = - DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0), - DAG.getConstant(VP1.logBase2(), DL, MVT::i64)); - return DAG.getNode(ISD::SUB, DL, VT, ShiftedVal, - N->getOperand(0)); - } - } else { + IntValue = ValueOfCPlus1; + Operation = ISD::SUB; + } else + return SDValue(); + } else { + APInt NegativeValueOfCPlus1 = -ValueOfC + 1; + APInt NegativeValueOfCMinus1 = -ValueOfC - 1; + if (NegativeValueOfCPlus1.isPowerOf2()) { // (mul x, -(2^N - 1)) => (sub x, (shl x, N)) - APInt VNP1 = -Value + 1; - if (VNP1.isPowerOf2()) { - SDValue ShiftedVal = - DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0), - DAG.getConstant(VNP1.logBase2(), DL, MVT::i64)); - return DAG.getNode(ISD::SUB, DL, VT, N->getOperand(0), - ShiftedVal); - } - // (mul x, -(2^N + 1)) => - (add (shl x, N), x) - APInt VNM1 = -Value - 1; - if (VNM1.isPowerOf2()) { - SDValue ShiftedVal = - DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0), - DAG.getConstant(VNM1.logBase2(), DL, MVT::i64)); - SDValue Add = - DAG.getNode(ISD::ADD, DL, VT, ShiftedVal, N->getOperand(0)); - return DAG.getNode(ISD::SUB, DL, VT, 
DAG.getConstant(0, DL, VT), Add); - } - } + IntValue = NegativeValueOfCPlus1; + Operation = ISD::SUB; + SwapValues = true; + ExtraNeg = false; + } else if (NegativeValueOfCMinus1.isPowerOf2()) { + // (mul x, -(2^N + 1)) => -(add (shl x, N), x) + IntValue = NegativeValueOfCMinus1; + Operation = ISD::ADD; + SwapValues = false; + ExtraNeg = true; + } else + return SDValue(); } - return SDValue(); + assert(IntValue.isPowerOf2() && "IntValue must be power of 2"); + SDValue ShiftedVal = + DAG.getNode(ISD::SHL, DL, VT, Var, + DAG.getConstant(IntValue.logBase2(), DL, MVT::i64)); + SDValue AddOrSubVal = + DAG.getNode(Operation, DL, VT, SwapValues ? Var : ShiftedVal, + SwapValues ? ShiftedVal : Var); + + if (TrailingZeroes == 0 && !ExtraNeg) + return AddOrSubVal; + if (TrailingZeroes) + return DAG.getNode(ISD::SHL, DL, VT, AddOrSubVal, + DAG.getConstant(TrailingZeroes, DL, MVT::i64)); + return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), AddOrSubVal); } static SDValue performVectorCompareAndMaskUnaryOpCombine(SDNode *N, Index: test/CodeGen/AArch64/mul_pow2.ll =================================================================== --- test/CodeGen/AArch64/mul_pow2.ll +++ test/CodeGen/AArch64/mul_pow2.ll @@ -2,6 +2,8 @@ ; Convert mul x, pow2 to shift. ; Convert mul x, pow2 +/- 1 to shift + add/sub. +; Convert mul x, (pow2 + 1) * pow2 to shift + add + shift. +; Lowering other positive constants are not supported yet. 
define i32 @test2(i32 %x) { ; CHECK-LABEL: test2 @@ -36,6 +38,122 @@ ret i32 %mul } +define i32 @test6_32b(i32 %x) { +; CHECK-LABEL: test6_32b +; CHECK: add {{w[0-9]+}}, w0, w0, lsl #1 +; CHECK: lsl w0, {{w[0-9]+}}, #1 + + %mul = mul nsw i32 %x, 6 + ret i32 %mul +} + +define i64 @test6_64b(i64 %x) { +; CHECK-LABEL: test6_64b +; CHECK: add {{x[0-9]+}}, x0, x0, lsl #1 +; CHECK: lsl x0, {{x[0-9]+}}, #1 + + %mul = mul nsw i64 %x, 6 + ret i64 %mul +} + +; A mul whose result feeds an add/sub, or whose operand is a sext/zext, is +; kept as a mul rather than lsl + add/sub, so it folds into the ops checked below. +define i64 @test6_umull(i32 %x) { +; CHECK-LABEL: test6_umull +; CHECK: umull x0, w0, {{w[0-9]+}} + + %ext = zext i32 %x to i64 + %mul = mul nsw i64 %ext, 6 + ret i64 %mul +} + +define i64 @test6_smull(i32 %x) { +; CHECK-LABEL: test6_smull +; CHECK: smull x0, w0, {{w[0-9]+}} + + %ext = sext i32 %x to i64 + %mul = mul nsw i64 %ext, 6 + ret i64 %mul +} + +define i32 @test6_madd(i32 %x, i32 %y) { +; CHECK-LABEL: test6_madd +; CHECK: madd w0, w0, {{w[0-9]+}}, w1 + + %mul = mul nsw i32 %x, 6 + %add = add i32 %mul, %y + ret i32 %add +} + +define i32 @test6_msub(i32 %x, i32 %y) { +; CHECK-LABEL: test6_msub +; CHECK: msub w0, w0, {{w[0-9]+}}, w1 + + %mul = mul nsw i32 %x, 6 + %sub = sub i32 %y, %mul + ret i32 %sub +} + +define i64 @test6_umaddl(i32 %x, i64 %y) { +; CHECK-LABEL: test6_umaddl +; CHECK: umaddl x0, w0, {{w[0-9]+}}, x1 + + %ext = zext i32 %x to i64 + %mul = mul nsw i64 %ext, 6 + %add = add i64 %mul, %y + ret i64 %add +} + +define i64 @test6_smaddl(i32 %x, i64 %y) { +; CHECK-LABEL: test6_smaddl +; CHECK: smaddl x0, w0, {{w[0-9]+}}, x1 + + %ext = sext i32 %x to i64 + %mul = mul nsw i64 %ext, 6 + %add = add i64 %mul, %y + ret i64 %add +} + +define i64 @test6_umsubl(i32 %x, i64 %y) { +; CHECK-LABEL: test6_umsubl +; CHECK: umsubl x0, w0, {{w[0-9]+}}, x1 + + %ext = zext i32 %x to i64 + %mul = mul nsw i64 %ext, 6 + %sub = sub i64 %y, %mul + ret i64 %sub +} + +define i64 @test6_smsubl(i32 %x, i64 %y) {
+; CHECK-LABEL: test6_smsubl +; CHECK: smsubl x0, w0, {{w[0-9]+}}, x1 + + %ext = sext i32 %x to i64 + %mul = mul nsw i64 %ext, 6 + %sub = sub i64 %y, %mul + ret i64 %sub +} + +define i64 @test6_umnegl(i32 %x) { +; CHECK-LABEL: test6_umnegl +; CHECK: umnegl x0, w0, {{w[0-9]+}} + + %ext = zext i32 %x to i64 + %mul = mul nsw i64 %ext, 6 + %sub = sub i64 0, %mul + ret i64 %sub +} + +define i64 @test6_smnegl(i32 %x) { +; CHECK-LABEL: test6_smnegl +; CHECK: smnegl x0, w0, {{w[0-9]+}} + + %ext = sext i32 %x to i64 + %mul = mul nsw i64 %ext, 6 + %sub = sub i64 0, %mul + ret i64 %sub +} + define i32 @test7(i32 %x) { ; CHECK-LABEL: test7 ; CHECK: lsl {{w[0-9]+}}, w0, #3 @@ -57,12 +175,72 @@ ; CHECK-LABEL: test9 ; CHECK: add w0, w0, w0, lsl #3 - %mul = mul nsw i32 %x, 9 + %mul = mul nsw i32 %x, 9 + ret i32 %mul +} + +define i32 @test10(i32 %x) { +; CHECK-LABEL: test10 +; CHECK: add {{w[0-9]+}}, w0, w0, lsl #2 +; CHECK: lsl w0, {{w[0-9]+}}, #1 + + %mul = mul nsw i32 %x, 10 + ret i32 %mul +} + +define i32 @test11(i32 %x) { +; CHECK-LABEL: test11 +; CHECK: mul w0, w0, {{w[0-9]+}} + + %mul = mul nsw i32 %x, 11 + ret i32 %mul +} + +define i32 @test12(i32 %x) { +; CHECK-LABEL: test12 +; CHECK: add {{w[0-9]+}}, w0, w0, lsl #1 +; CHECK: lsl w0, {{w[0-9]+}}, #2 + + %mul = mul nsw i32 %x, 12 + ret i32 %mul +} + +define i32 @test13(i32 %x) { +; CHECK-LABEL: test13 +; CHECK: mul w0, w0, {{w[0-9]+}} + + %mul = mul nsw i32 %x, 13 + ret i32 %mul +} + +define i32 @test14(i32 %x) { +; CHECK-LABEL: test14 +; CHECK: mul w0, w0, {{w[0-9]+}} + + %mul = mul nsw i32 %x, 14 + ret i32 %mul +} + +define i32 @test15(i32 %x) { +; CHECK-LABEL: test15 +; CHECK: lsl {{w[0-9]+}}, w0, #4 +; CHECK: sub w0, {{w[0-9]+}}, w0 + + %mul = mul nsw i32 %x, 15 + ret i32 %mul +} + +define i32 @test16(i32 %x) { +; CHECK-LABEL: test16 +; CHECK: lsl w0, w0, #4 + + %mul = mul nsw i32 %x, 16 ret i32 %mul } ; Convert mul x, -pow2 to shift. ; Convert mul x, -(pow2 +/- 1) to shift + add/sub.
+; Other negative constants (e.g. -6, -10 through -14) are not lowered yet and stay a mul. define i32 @ntest2(i32 %x) { ; CHECK-LABEL: ntest2 @@ -96,6 +274,14 @@ ret i32 %mul } +define i32 @ntest6(i32 %x) { +; CHECK-LABEL: ntest6 +; CHECK: mul w0, w0, {{w[0-9]+}} + + %mul = mul nsw i32 %x, -6 + ret i32 %mul +} + define i32 @ntest7(i32 %x) { ; CHECK-LABEL: ntest7 ; CHECK: sub w0, w0, w0, lsl #3 @@ -120,3 +306,58 @@ %mul = mul nsw i32 %x, -9 ret i32 %mul } + +define i32 @ntest10(i32 %x) { +; CHECK-LABEL: ntest10 +; CHECK: mul w0, w0, {{w[0-9]+}} + + %mul = mul nsw i32 %x, -10 + ret i32 %mul +} + +define i32 @ntest11(i32 %x) { +; CHECK-LABEL: ntest11 +; CHECK: mul w0, w0, {{w[0-9]+}} + + %mul = mul nsw i32 %x, -11 + ret i32 %mul +} + +define i32 @ntest12(i32 %x) { +; CHECK-LABEL: ntest12 +; CHECK: mul w0, w0, {{w[0-9]+}} + + %mul = mul nsw i32 %x, -12 + ret i32 %mul +} + +define i32 @ntest13(i32 %x) { +; CHECK-LABEL: ntest13 +; CHECK: mul w0, w0, {{w[0-9]+}} + %mul = mul nsw i32 %x, -13 + ret i32 %mul +} + +define i32 @ntest14(i32 %x) { +; CHECK-LABEL: ntest14 +; CHECK: mul w0, w0, {{w[0-9]+}} + + %mul = mul nsw i32 %x, -14 + ret i32 %mul +} + +define i32 @ntest15(i32 %x) { +; CHECK-LABEL: ntest15 +; CHECK: sub w0, w0, w0, lsl #4 + + %mul = mul nsw i32 %x, -15 + ret i32 %mul +} + +define i32 @ntest16(i32 %x) { +; CHECK-LABEL: ntest16 +; CHECK: neg w0, w0, lsl #4 + + %mul = mul nsw i32 %x, -16 + ret i32 %mul +}