diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -14798,7 +14798,7 @@ // More aggressively, some multiplications N0 * C can be lowered to // shift+add+shift if the constant C = A * B where A = 2^N + 1 and B = 2^M, // e.g. 6=3*2=(2+1)*2. - // TODO: consider lowering more cases, e.g. C = 14, -6, -14 or even 45 + // TODO: consider lowering more cases, e.g. C = -6, -14 or even 45 // which equals to (1+2)*16-(1+2). // TrailingZeroes is used to test if the mul can be lowered to @@ -14826,11 +14826,21 @@ // Do we need to negate the result? bool NegateResult = false; + auto Shl = [&](SDValue N0, unsigned N1) { + SDValue RHS = DAG.getConstant(N1, DL, MVT::i64); + return DAG.getNode(ISD::SHL, DL, VT, N0, RHS); + }; + auto Sub = [&](SDValue N0, SDValue N1) { + return DAG.getNode(ISD::SUB, DL, VT, N0, N1); + }; + if (ConstValue.isNonNegative()) { // (mul x, 2^N + 1) => (add (shl x, N), x) // (mul x, 2^N - 1) => (sub (shl x, N), x) // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M) + // (mul x, (2^(N-M) - 1) * 2^M) => (sub (shl x, N), (shl x, M)) APInt SCVMinus1 = ShiftedConstValue - 1; + APInt SCVPlus1 = ShiftedConstValue + 1; APInt CVPlus1 = ConstValue + 1; if (SCVMinus1.isPowerOf2()) { ShiftAmt = SCVMinus1.logBase2(); @@ -14838,6 +14848,9 @@ } else if (CVPlus1.isPowerOf2()) { ShiftAmt = CVPlus1.logBase2(); AddSubOpc = ISD::SUB; + } else if (SCVPlus1.isPowerOf2()) { + ShiftAmt = SCVPlus1.logBase2() + TrailingZeroes; + return Sub(Shl(N0, ShiftAmt), Shl(N0, TrailingZeroes)); } else return SDValue(); } else { @@ -14857,21 +14870,18 @@ return SDValue(); } - SDValue ShiftedVal = DAG.getNode(ISD::SHL, DL, VT, N0, - DAG.getConstant(ShiftAmt, DL, MVT::i64)); - - SDValue AddSubN0 = ShiftValUseIsN0 ? ShiftedVal : N0; - SDValue AddSubN1 = ShiftValUseIsN0 ? N0 : ShiftedVal; + SDValue ShiftedVal0 = Shl(N0, ShiftAmt); + SDValue AddSubN0 = ShiftValUseIsN0 ? ShiftedVal0 : N0; + SDValue AddSubN1 = ShiftValUseIsN0 ? N0 : ShiftedVal0; SDValue Res = DAG.getNode(AddSubOpc, DL, VT, AddSubN0, AddSubN1); assert(!(NegateResult && TrailingZeroes) && "NegateResult and TrailingZeroes cannot both be true for now."); // Negate the result. if (NegateResult) - return DAG.getNode(ISD::SUB, DL, VT, DAG.getConstant(0, DL, VT), Res); + return Sub(DAG.getConstant(0, DL, VT), Res); // Shift the result. if (TrailingZeroes) - return DAG.getNode(ISD::SHL, DL, VT, Res, - DAG.getConstant(TrailingZeroes, DL, MVT::i64)); + return Shl(Res, TrailingZeroes); return Res; } diff --git a/llvm/test/CodeGen/AArch64/mul_pow2.ll b/llvm/test/CodeGen/AArch64/mul_pow2.ll --- a/llvm/test/CodeGen/AArch64/mul_pow2.ll +++ b/llvm/test/CodeGen/AArch64/mul_pow2.ll @@ -408,8 +408,8 @@ define i32 @test14(i32 %x) { ; CHECK-LABEL: test14: ; CHECK: // %bb.0: -; CHECK-NEXT: mov w8, #14 -; CHECK-NEXT: mul w0, w0, w8 +; CHECK-NEXT: lsl w8, w0, #4 +; CHECK-NEXT: sub w0, w8, w0, lsl #1 ; CHECK-NEXT: ret ; ; GISEL-LABEL: test14: