Skip to content

Commit

Permalink
[AArch64] Lower multiplication by a constant int to shl+add+shl
Browse files Browse the repository at this point in the history
Lower a = b * C where C = (2^n + 1) * 2^m to

add     w0, w0, w0, lsl n
lsl     w0, w0, m

Differential Revision: https://reviews.llvm.org/D229245

llvm-svn: 287019
  • Loading branch information
Haicheng Wu committed Nov 15, 2016
1 parent 3666629 commit faee2b7
Show file tree
Hide file tree
Showing 2 changed files with 281 additions and 10 deletions.
48 changes: 39 additions & 9 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7670,6 +7670,31 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
// future CPUs have a cheaper MADD instruction, this may need to be
// gated on a subtarget feature. For Cyclone, 32-bit MADD is 4 cycles and
// 64-bit is 5 cycles, so this is always a win.
// More aggressively, some multiplications N0 * C can be lowered to
// shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M,
// e.g. 6=3*2=(2+1)*2.
// TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45
// which equals to (1+2)*16-(1+2).
SDValue N0 = N->getOperand(0);
// TrailingZeroes is used to test if the mul can be lowered to
// shift+add+shift.
unsigned TrailingZeroes = ConstValue.countTrailingZeros();
if (TrailingZeroes) {
// Conservatively do not lower to shift+add+shift if the mul might be
// folded into smul or umul.
if (N0->hasOneUse() && (isSignExtended(N0.getNode(), DAG) ||
isZeroExtended(N0.getNode(), DAG)))
return SDValue();
// Conservatively do not lower to shift+add+shift if the mul might be
// folded into madd or msub.
if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ADD ||
N->use_begin()->getOpcode() == ISD::SUB))
return SDValue();
}
// Use ShiftedConstValue instead of ConstValue to support both shift+add/sub
// and shift+add+shift.
APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes);

unsigned ShiftAmt, AddSubOpc;
// Is the shifted value the LHS operand of the add/sub?
bool ShiftValUseIsN0 = true;
Expand All @@ -7679,10 +7704,11 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
if (ConstValue.isNonNegative()) {
// (mul x, 2^N + 1) => (add (shl x, N), x)
// (mul x, 2^N - 1) => (sub (shl x, N), x)
APInt CVMinus1 = ConstValue - 1;
// (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
APInt SCVMinus1 = ShiftedConstValue - 1;
APInt CVPlus1 = ConstValue + 1;
if (CVMinus1.isPowerOf2()) {
ShiftAmt = CVMinus1.logBase2();
if (SCVMinus1.isPowerOf2()) {
ShiftAmt = SCVMinus1.logBase2();
AddSubOpc = ISD::ADD;
} else if (CVPlus1.isPowerOf2()) {
ShiftAmt = CVPlus1.logBase2();
Expand All @@ -7708,18 +7734,22 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,

SDLoc DL(N);
EVT VT = N->getValueType(0);
SDValue N0 = N->getOperand(0);
SDValue ShiftedVal = DAG.getNode(ISD::SHL, DL, VT, N->getOperand(0),
SDValue ShiftedVal = DAG.getNode(ISD::SHL, DL, VT, N0,
DAG.getConstant(ShiftAmt, DL, MVT::i64));

SDValue AddSubN0 = ShiftValUseIsN0 ? ShiftedVal : N0;
SDValue AddSubN1 = ShiftValUseIsN0 ? N0 : ShiftedVal;
SDValue Res = DAG.getNode(AddSubOpc, DL, VT, AddSubN0, AddSubN1);
if (!NegateResult)
return Res;

assert(!(NegateResult && TrailingZeroes) &&
"NegateResult and TrailingZeroes cannot both be true for now.");
// Negate the result.
return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), Res);
if (NegateResult)
return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), Res);
// Shift the result.
if (TrailingZeroes)
return DAG.getNode(ISD::SHL, DL, VT, Res,
DAG.getConstant(TrailingZeroes, DL, MVT::i64));
return Res;
}

static SDValue performVectorCompareAndMaskUnaryOpCombine(SDNode *N,
Expand Down
243 changes: 242 additions & 1 deletion llvm/test/CodeGen/AArch64/mul_pow2.ll
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

; Convert mul x, pow2 to shift.
; Convert mul x, pow2 +/- 1 to shift + add/sub.
; Convert mul x, (pow2 + 1) * pow2 to shift + add + shift.
; Lowering of other positive constants is not supported yet.

define i32 @test2(i32 %x) {
; CHECK-LABEL: test2
Expand Down Expand Up @@ -36,6 +38,122 @@ define i32 @test5(i32 %x) {
ret i32 %mul
}

; x * 6 = ((x << 1) + x) << 1: an add with a shifted operand, then a shift
; that must consume the add's result.
define i32 @test6_32b(i32 %x) {
; CHECK-LABEL: test6_32b
; CHECK: add [[SUM:w[0-9]+]], w0, w0, lsl #1
; CHECK: lsl w0, [[SUM]], #1

  %mul = mul nsw i32 %x, 6
  ret i32 %mul
}

; 64-bit form of x * 6 = ((x << 1) + x) << 1; the final shift must consume
; the add's result.
define i64 @test6_64b(i64 %x) {
; CHECK-LABEL: test6_64b
; CHECK: add [[SUM:x[0-9]+]], x0, x0, lsl #1
; CHECK: lsl x0, [[SUM]], #1

  %mul = mul nsw i64 %x, 6
  ret i64 %mul
}

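; A mul whose operand is a sign/zero extension, or whose only user is an add or
; sub, is not converted to lsl + add/sub yet, so that it can still be folded
; into smull/umull, madd/msub, smaddl/umaddl, and related instructions.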
; x * 6 applied to a zero-extended i32 is expected to stay a single umull.
define i64 @test6_umull(i32 %x) {
; CHECK-LABEL: test6_umull
; CHECK: umull x0, w0, {{w[0-9]+}}

  %wide = zext i32 %x to i64
  %prod = mul nsw i64 %wide, 6
  ret i64 %prod
}

; x * 6 applied to a sign-extended i32 is expected to stay a single smull.
define i64 @test6_smull(i32 %x) {
; CHECK-LABEL: test6_smull
; CHECK: smull x0, w0, {{w[0-9]+}}

  %wide = sext i32 %x to i64
  %prod = mul nsw i64 %wide, 6
  ret i64 %prod
}

; (x * 6) + y: the multiply's only user is an add, so a fused madd is expected.
define i32 @test6_madd(i32 %x, i32 %y) {
; CHECK-LABEL: test6_madd
; CHECK: madd w0, w0, {{w[0-9]+}}, w1

  %prod = mul nsw i32 %x, 6
  %acc = add i32 %prod, %y
  ret i32 %acc
}

; y - (x * 6): the multiply's only user is a sub, so a fused msub is expected.
define i32 @test6_msub(i32 %x, i32 %y) {
; CHECK-LABEL: test6_msub
; CHECK: msub w0, w0, {{w[0-9]+}}, w1

  %prod = mul nsw i32 %x, 6
  %diff = sub i32 %y, %prod
  ret i32 %diff
}

; zext(x) * 6 + y is expected to fold into a single umaddl.
define i64 @test6_umaddl(i32 %x, i64 %y) {
; CHECK-LABEL: test6_umaddl
; CHECK: umaddl x0, w0, {{w[0-9]+}}, x1

  %wide = zext i32 %x to i64
  %prod = mul nsw i64 %wide, 6
  %acc = add i64 %prod, %y
  ret i64 %acc
}

; sext(x) * 6 + y is expected to fold into a single smaddl.
define i64 @test6_smaddl(i32 %x, i64 %y) {
; CHECK-LABEL: test6_smaddl
; CHECK: smaddl x0, w0, {{w[0-9]+}}, x1

  %wide = sext i32 %x to i64
  %prod = mul nsw i64 %wide, 6
  %acc = add i64 %prod, %y
  ret i64 %acc
}

; y - zext(x) * 6 is expected to fold into a single umsubl.
define i64 @test6_umsubl(i32 %x, i64 %y) {
; CHECK-LABEL: test6_umsubl
; CHECK: umsubl x0, w0, {{w[0-9]+}}, x1

  %wide = zext i32 %x to i64
  %prod = mul nsw i64 %wide, 6
  %diff = sub i64 %y, %prod
  ret i64 %diff
}

; y - sext(x) * 6 is expected to fold into a single smsubl.
define i64 @test6_smsubl(i32 %x, i64 %y) {
; CHECK-LABEL: test6_smsubl
; CHECK: smsubl x0, w0, {{w[0-9]+}}, x1

  %wide = sext i32 %x to i64
  %prod = mul nsw i64 %wide, 6
  %diff = sub i64 %y, %prod
  ret i64 %diff
}

; 0 - zext(x) * 6 is expected to fold into a single umnegl.
define i64 @test6_umnegl(i32 %x) {
; CHECK-LABEL: test6_umnegl
; CHECK: umnegl x0, w0, {{w[0-9]+}}

  %wide = zext i32 %x to i64
  %prod = mul nsw i64 %wide, 6
  %neg = sub i64 0, %prod
  ret i64 %neg
}

; 0 - sext(x) * 6 is expected to fold into a single smnegl.
define i64 @test6_smnegl(i32 %x) {
; CHECK-LABEL: test6_smnegl
; CHECK: smnegl x0, w0, {{w[0-9]+}}

  %wide = sext i32 %x to i64
  %prod = mul nsw i64 %wide, 6
  %neg = sub i64 0, %prod
  ret i64 %neg
}

define i32 @test7(i32 %x) {
; CHECK-LABEL: test7
; CHECK: lsl {{w[0-9]+}}, w0, #3
Expand All @@ -57,12 +175,72 @@ define i32 @test9(i32 %x) {
; CHECK-LABEL: test9
; CHECK: add w0, w0, w0, lsl #3

%mul = mul nsw i32 %x, 9
%mul = mul nsw i32 %x, 9
ret i32 %mul
}

; x * 10 = ((x << 2) + x) << 1; the final shift must consume the add's result.
define i32 @test10(i32 %x) {
; CHECK-LABEL: test10
; CHECK: add [[SUM:w[0-9]+]], w0, w0, lsl #2
; CHECK: lsl w0, [[SUM]], #1

  %mul = mul nsw i32 %x, 10
  ret i32 %mul
}

; 11 is neither 2^N +/- 1 nor (2^N + 1) * 2^M, so a plain mul is expected.
define i32 @test11(i32 %x) {
; CHECK-LABEL: test11
; CHECK: mul w0, w0, {{w[0-9]+}}

  %prod = mul nsw i32 %x, 11
  ret i32 %prod
}

; x * 12 = ((x << 1) + x) << 2; the final shift must consume the add's result.
define i32 @test12(i32 %x) {
; CHECK-LABEL: test12
; CHECK: add [[SUM:w[0-9]+]], w0, w0, lsl #1
; CHECK: lsl w0, [[SUM]], #2

  %mul = mul nsw i32 %x, 12
  ret i32 %mul
}

; 13 is neither 2^N +/- 1 nor (2^N + 1) * 2^M, so a plain mul is expected.
define i32 @test13(i32 %x) {
; CHECK-LABEL: test13
; CHECK: mul w0, w0, {{w[0-9]+}}

  %prod = mul nsw i32 %x, 13
  ret i32 %prod
}

; 14 = (2^3 - 1) * 2; the (2^N - 1) * 2^M form is not lowered yet, so a plain
; mul is expected.
define i32 @test14(i32 %x) {
; CHECK-LABEL: test14
; CHECK: mul w0, w0, {{w[0-9]+}}

  %prod = mul nsw i32 %x, 14
  ret i32 %prod
}

; x * 15 = (x << 4) - x. The shifted value is the minuend, so a separate lsl is
; expected, and the sub must consume its result.
define i32 @test15(i32 %x) {
; CHECK-LABEL: test15
; CHECK: lsl [[SHL:w[0-9]+]], w0, #4
; CHECK: sub w0, [[SHL]], w0

  %mul = mul nsw i32 %x, 15
  ret i32 %mul
}

; x * 16 is a pure power of two: a single shift is expected.
define i32 @test16(i32 %x) {
; CHECK-LABEL: test16
; CHECK: lsl w0, w0, #4

  %prod = mul nsw i32 %x, 16
  ret i32 %prod
}

; Convert mul x, -pow2 to shift.
; Convert mul x, -(pow2 +/- 1) to shift + add/sub.
; Lowering of other negative constants is not supported yet.

define i32 @ntest2(i32 %x) {
; CHECK-LABEL: ntest2
Expand Down Expand Up @@ -96,6 +274,14 @@ define i32 @ntest5(i32 %x) {
ret i32 %mul
}

; -6 = -((2 + 1) * 2); negative constants of this form are not lowered yet, so
; a plain mul is expected.
define i32 @ntest6(i32 %x) {
; CHECK-LABEL: ntest6
; CHECK: mul w0, w0, {{w[0-9]+}}

  %prod = mul nsw i32 %x, -6
  ret i32 %prod
}

define i32 @ntest7(i32 %x) {
; CHECK-LABEL: ntest7
; CHECK: sub w0, w0, w0, lsl #3
Expand All @@ -120,3 +306,58 @@ define i32 @ntest9(i32 %x) {
%mul = mul nsw i32 %x, -9
ret i32 %mul
}

; -10 = -((4 + 1) * 2); negative constants of this form are not lowered yet, so
; a plain mul is expected.
define i32 @ntest10(i32 %x) {
; CHECK-LABEL: ntest10
; CHECK: mul w0, w0, {{w[0-9]+}}

  %prod = mul nsw i32 %x, -10
  ret i32 %prod
}

; -11 is not -(2^N +/- 1), so a plain mul is expected.
define i32 @ntest11(i32 %x) {
; CHECK-LABEL: ntest11
; CHECK: mul w0, w0, {{w[0-9]+}}

  %prod = mul nsw i32 %x, -11
  ret i32 %prod
}

; -12 = -((2 + 1) * 4); negative constants of this form are not lowered yet, so
; a plain mul is expected.
define i32 @ntest12(i32 %x) {
; CHECK-LABEL: ntest12
; CHECK: mul w0, w0, {{w[0-9]+}}

  %prod = mul nsw i32 %x, -12
  ret i32 %prod
}

; -13 is not -(2^N +/- 1), so a plain mul is expected.
define i32 @ntest13(i32 %x) {
; CHECK-LABEL: ntest13
; CHECK: mul w0, w0, {{w[0-9]+}}

  %prod = mul nsw i32 %x, -13
  ret i32 %prod
}

; -14 = -((8 - 1) * 2); negative constants of this form are not lowered yet, so
; a plain mul is expected.
define i32 @ntest14(i32 %x) {
; CHECK-LABEL: ntest14
; CHECK: mul w0, w0, {{w[0-9]+}}

  %prod = mul nsw i32 %x, -14
  ret i32 %prod
}

; x * -15 = x - (x << 4): a single sub with a shifted operand is expected.
define i32 @ntest15(i32 %x) {
; CHECK-LABEL: ntest15
; CHECK: sub w0, w0, w0, lsl #4

  %prod = mul nsw i32 %x, -15
  ret i32 %prod
}

; x * -16 = -(x << 4): a single neg with a shifted operand is expected.
define i32 @ntest16(i32 %x) {
; CHECK-LABEL: ntest16
; CHECK: neg w0, w0, lsl #4

  %prod = mul nsw i32 %x, -16
  ret i32 %prod
}

0 comments on commit faee2b7

Please sign in to comment.