Index: llvm/lib/Target/AArch64/AArch64ISelLowering.cpp =================================================================== --- llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -14556,8 +14556,29 @@ // Conservatively do not lower to shift+add+shift if the mul might be // folded into madd or msub. if (N->hasOneUse() && (N->use_begin()->getOpcode() == ISD::ADD || - N->use_begin()->getOpcode() == ISD::SUB)) - return SDValue(); + N->use_begin()->getOpcode() == ISD::SUB)) { + SDNode *AddSub = *N->use_begin(); + // Don't block the transform to add+shifts when the add/sub operand + // is a constant, since a madd/msub operand cannot be a constant value. + SDValue BinaryOP; + if (N == AddSub->getOperand(0).getNode()) + BinaryOP = AddSub->getOperand(1); + else + BinaryOP = AddSub->getOperand(0); + const ConstantSDNode *OpC = dyn_cast<ConstantSDNode>(BinaryOP); + bool match = false; + // For sub, may transform to msub only when OpC is a reg. + if (AddSub->getOpcode() == ISD::SUB) + match = !OpC; + // For add, may transform to madd when OpC is a reg, or when the + // immediate does not satisfy isLegalAddImmediate. + const TargetLowering &TLI = DAG.getTargetLoweringInfo(); + if (AddSub->getOpcode() == ISD::ADD) + match = !OpC || + !TLI.isLegalAddImmediate(OpC->getAPIntValue().getSExtValue()); + if (match) + return SDValue(); + } } // Use ShiftedConstValue instead of ConstValue to support both shift+add/sub // and shift+add+shift. @@ -14568,12 +14589,16 @@ bool ShiftValUseIsN0 = true; // Do we need to negate the result? bool NegateResult = false; + // Does the sub have two shifted value operands? 
+ bool Sub2Shift = false; if (ConstValue.isNonNegative()) { // (mul x, 2^N + 1) => (add (shl x, N), x) // (mul x, 2^N - 1) => (sub (shl x, N), x) // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M) + // (mul x, (2^(N-M) - 1) * 2^M) => (sub (shl x, N), (shl x, M)) APInt SCVMinus1 = ShiftedConstValue - 1; + APInt SCVPlus1 = ShiftedConstValue + 1; APInt CVPlus1 = ConstValue + 1; if (SCVMinus1.isPowerOf2()) { ShiftAmt = SCVMinus1.logBase2(); @@ -14581,6 +14606,10 @@ } else if (CVPlus1.isPowerOf2()) { ShiftAmt = CVPlus1.logBase2(); AddSubOpc = ISD::SUB; + } else if (SCVPlus1.isPowerOf2()) { + ShiftAmt = SCVPlus1.logBase2() + TrailingZeroes; + AddSubOpc = ISD::SUB; + Sub2Shift = true; } else return SDValue(); } else { @@ -14600,11 +14629,15 @@ return SDValue(); } - SDValue ShiftedVal = DAG.getNode(ISD::SHL, DL, VT, N0, - DAG.getConstant(ShiftAmt, DL, MVT::i64)); - - SDValue AddSubN0 = ShiftValUseIsN0 ? ShiftedVal : N0; - SDValue AddSubN1 = ShiftValUseIsN0 ? N0 : ShiftedVal; + SDValue ShiftedVal0 = DAG.getNode(ISD::SHL, DL, VT, N0, + DAG.getConstant(ShiftAmt, DL, MVT::i64)); + if (Sub2Shift) { + SDValue ShiftedVal1 = DAG.getNode( + ISD::SHL, DL, VT, N0, DAG.getConstant(TrailingZeroes, DL, MVT::i64)); + return DAG.getNode(ISD::SUB, DL, VT, ShiftedVal0, ShiftedVal1); + } + SDValue AddSubN0 = ShiftValUseIsN0 ? ShiftedVal0 : N0; + SDValue AddSubN1 = ShiftValUseIsN0 ? 
N0 : ShiftedVal0; SDValue Res = DAG.getNode(AddSubOpc, DL, VT, AddSubN0, AddSubN1); assert(!(NegateResult && TrailingZeroes) && "NegateResult and TrailingZeroes cannot both be true for now."); Index: llvm/test/CodeGen/AArch64/mul_pow2.ll =================================================================== --- llvm/test/CodeGen/AArch64/mul_pow2.ll +++ llvm/test/CodeGen/AArch64/mul_pow2.ll @@ -290,6 +290,25 @@ ret i64 %sub } +define i32 @mull6_sub(i32 %x) { +; CHECK-LABEL: mull6_sub: +; CHECK: // %bb.0: +; CHECK-NEXT: add w8, w0, w0, lsl #1 +; CHECK-NEXT: lsl w8, w8, #1 +; CHECK-NEXT: sub w0, w8, #1 +; CHECK-NEXT: ret +; +; GISEL-LABEL: mull6_sub: +; GISEL: // %bb.0: +; GISEL-NEXT: mov w8, #6 +; GISEL-NEXT: mul w8, w0, w8 +; GISEL-NEXT: sub w0, w8, #1 +; GISEL-NEXT: ret + %mul = mul nsw i32 %x, 6 + %sub = add nsw i32 %mul, -1 + ret i32 %sub +} + define i32 @test7(i32 %x) { ; CHECK-LABEL: test7: ; CHECK: // %bb.0: @@ -408,8 +427,8 @@ define i32 @test14(i32 %x) { ; CHECK-LABEL: test14: ; CHECK: // %bb.0: -; CHECK-NEXT: mov w8, #14 -; CHECK-NEXT: mul w0, w0, w8 +; CHECK-NEXT: lsl w8, w0, #4 +; CHECK-NEXT: sub w0, w8, w0, lsl #1 ; CHECK-NEXT: ret ; ; GISEL-LABEL: test14: @@ -731,11 +750,11 @@ ; ; GISEL-LABEL: muladd_demand_commute: ; GISEL: // %bb.0: -; GISEL-NEXT: adrp x8, .LCPI42_1 -; GISEL-NEXT: ldr q2, [x8, :lo12:.LCPI42_1] -; GISEL-NEXT: adrp x8, .LCPI42_0 +; GISEL-NEXT: adrp x8, .LCPI43_1 +; GISEL-NEXT: ldr q2, [x8, :lo12:.LCPI43_1] +; GISEL-NEXT: adrp x8, .LCPI43_0 ; GISEL-NEXT: mla v1.4s, v0.4s, v2.4s -; GISEL-NEXT: ldr q0, [x8, :lo12:.LCPI42_0] +; GISEL-NEXT: ldr q0, [x8, :lo12:.LCPI43_0] ; GISEL-NEXT: and v0.16b, v1.16b, v0.16b ; GISEL-NEXT: ret %m = mul <4 x i32> %x,