AArch64: return mul-by-constant decompositions directly in performMulCombine

Drop the ShiftValUseIsN0 / NegateResult / AddSubOpc bookkeeping and the shared
tail that assembled the add/sub from those flags. Add and Negate lambdas now
sit next to the existing Shl and Sub ones, each recognized constant shape
returns its DAG directly, and unmatched constants fall through to a single
`return SDValue();`.

  * (2^N + 1) * 2^M  =>  (shl (add (shl x, N), x), M)
    (M == 0 covers plain 2^N + 1, so that separate comment is dropped)
  * 2^N - 1          =>  (sub (shl x, N), x)
  * -(2^N - 1)       =>  (sub x, (shl x, N))
  * -(2^N + 1)       =>  (sub 0, (add (shl x, N), x))

The assert that NegateResult and TrailingZeroes are never both set goes away
with the flags; the negated form returns directly without any TrailingZeroes
shift, just as the old code returned before reaching its shift.

NOTE(review): every use of AddSubOpc is removed in these hunks, but its
declaration lies outside them; confirm it is deleted as well, otherwise it
becomes an unused variable.
NOTE(review): the (2^N + 1) * 2^M path now always emits the outer shl, even
when TrailingZeroes == 0; presumably SelectionDAG::getNode folds (shl x, 0)
to x -- confirm.

diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -14819,40 +14819,40 @@
   // Use ShiftedConstValue instead of ConstValue to support both shift+add/sub
   // and shift+add+shift.
   APInt ShiftedConstValue = ConstValue.ashr(TrailingZeroes);
-
   unsigned ShiftAmt;
-  // Is the shifted value the LHS operand of the add/sub?
-  bool ShiftValUseIsN0 = true;
-  // Do we need to negate the result?
-  bool NegateResult = false;
 
   auto Shl = [&](SDValue N0, unsigned N1) {
     SDValue RHS = DAG.getConstant(N1, DL, MVT::i64);
     return DAG.getNode(ISD::SHL, DL, VT, N0, RHS);
   };
+  auto Add = [&](SDValue N0, SDValue N1) {
+    return DAG.getNode(ISD::ADD, DL, VT, N0, N1);
+  };
   auto Sub = [&](SDValue N0, SDValue N1) {
     return DAG.getNode(ISD::SUB, DL, VT, N0, N1);
   };
+  auto Negate = [&](SDValue N) {
+    SDValue Zero = DAG.getConstant(0, DL, VT);
+    return DAG.getNode(ISD::SUB, DL, VT, Zero, N);
+  };
 
   if (ConstValue.isNonNegative()) {
-    // (mul x, 2^N + 1) => (add (shl x, N), x)
-    // (mul x, 2^N - 1) => (sub (shl x, N), x)
     // (mul x, (2^N + 1) * 2^M) => (shl (add (shl x, N), x), M)
+    // (mul x, 2^N - 1) => (sub (shl x, N), x)
     // (mul x, (2^(N-M) - 1) * 2^M) => (sub (shl x, N), (shl x, M))
     APInt SCVMinus1 = ShiftedConstValue - 1;
     APInt SCVPlus1 = ShiftedConstValue + 1;
     APInt CVPlus1 = ConstValue + 1;
     if (SCVMinus1.isPowerOf2()) {
       ShiftAmt = SCVMinus1.logBase2();
-      AddSubOpc = ISD::ADD;
+      return Shl(Add(Shl(N0, ShiftAmt), N0), TrailingZeroes);
     } else if (CVPlus1.isPowerOf2()) {
       ShiftAmt = CVPlus1.logBase2();
-      AddSubOpc = ISD::SUB;
+      return Sub(Shl(N0, ShiftAmt), N0);
     } else if (SCVPlus1.isPowerOf2()) {
       ShiftAmt = SCVPlus1.logBase2() + TrailingZeroes;
       return Sub(Shl(N0, ShiftAmt), Shl(N0, TrailingZeroes));
-    } else
-      return SDValue();
+    }
   } else {
     // (mul x, -(2^N - 1)) => (sub x, (shl x, N))
     // (mul x, -(2^N + 1)) => - (add (shl x, N), x)
@@ -14860,29 +14860,14 @@
     APInt CVNegMinus1 = -ConstValue - 1;
     if (CVNegPlus1.isPowerOf2()) {
       ShiftAmt = CVNegPlus1.logBase2();
-      AddSubOpc = ISD::SUB;
-      ShiftValUseIsN0 = false;
+      return Sub(N0, Shl(N0, ShiftAmt));
     } else if (CVNegMinus1.isPowerOf2()) {
       ShiftAmt = CVNegMinus1.logBase2();
-      AddSubOpc = ISD::ADD;
-      NegateResult = true;
-    } else
-      return SDValue();
+      return Negate(Add(Shl(N0, ShiftAmt), N0));
+    }
   }
 
-  SDValue ShiftedVal0 = Shl(N0, ShiftAmt);
-  SDValue AddSubN0 = ShiftValUseIsN0 ? ShiftedVal0 : N0;
-  SDValue AddSubN1 = ShiftValUseIsN0 ? N0 : ShiftedVal0;
-  SDValue Res = DAG.getNode(AddSubOpc, DL, VT, AddSubN0, AddSubN1);
-  assert(!(NegateResult && TrailingZeroes) &&
-         "NegateResult and TrailingZeroes cannot both be true for now.");
-  // Negate the result.
-  if (NegateResult)
-    return Sub(DAG.getConstant(0, DL, VT), Res);
-  // Shift the result.
-  if (TrailingZeroes)
-    return Shl(Res, TrailingZeroes);
-  return Res;
+  return SDValue();
 }
 
 static SDValue performVectorCompareAndMaskUnaryOpCombine(SDNode *N,