diff --git a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
--- a/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
@@ -2516,24 +2516,29 @@
       return 0;
     };
 
-    auto foldMul = [&](SDValue X, SDValue Y, unsigned ShlAmt) {
+    auto foldMul = [&](SDValue Op, SDValue X, SDValue Y, unsigned ShlAmt) {
       EVT ShiftAmtTy = getShiftAmountTy(VT, TLO.DAG.getDataLayout());
       SDValue ShlAmtC = TLO.DAG.getConstant(ShlAmt, dl, ShiftAmtTy);
       SDValue Shl = TLO.DAG.getNode(ISD::SHL, dl, VT, X, ShlAmtC);
-      SDValue Sub = TLO.DAG.getNode(ISD::SUB, dl, VT, Y, Shl);
-      return TLO.CombineTo(Op, Sub);
+      SDValue Res = TLO.DAG.getNode(
+          Op.getOpcode() == ISD::ADD ? ISD::SUB : ISD::ADD, dl, VT, Y, Shl);
+      return TLO.CombineTo(Op, Res);
     };
 
     if (isOperationLegalOrCustom(ISD::SHL, VT)) {
       if (Op.getOpcode() == ISD::ADD) {
         // (X * MulC) + Op1 --> Op1 - (X << log2(-MulC))
         if (unsigned ShAmt = getShiftLeftAmt(Op0))
-          return foldMul(Op0.getOperand(0), Op1, ShAmt);
+          return foldMul(Op, Op0.getOperand(0), Op1, ShAmt);
         // Op0 + (X * MulC) --> Op0 - (X << log2(-MulC))
         if (unsigned ShAmt = getShiftLeftAmt(Op1))
-          return foldMul(Op1.getOperand(0), Op0, ShAmt);
-        // TODO:
+          return foldMul(Op, Op1.getOperand(0), Op0, ShAmt);
+      }
+      if (Op.getOpcode() == ISD::SUB) {
         // Op0 - (X * MulC) --> Op0 + (X << log2(-MulC))
+        if (unsigned ShAmt = getShiftLeftAmt(Op1)) {
+          return foldMul(Op, Op1.getOperand(0), Op0, ShAmt);
+        }
       }
     }
 
diff --git a/llvm/test/CodeGen/RISCV/mul.ll b/llvm/test/CodeGen/RISCV/mul.ll
--- a/llvm/test/CodeGen/RISCV/mul.ll
+++ b/llvm/test/CodeGen/RISCV/mul.ll
@@ -1580,3 +1580,37 @@
   %r = and i8 %a, 15
   ret i8 %r
 }
+
+define i8 @mulsub(i8 %x, i8 %y) nounwind {
+; RV32I-LABEL: mulsub:
+; RV32I:       # %bb.0:
+; RV32I-NEXT:    slli a0, a0, 1
+; RV32I-NEXT:    add a0, a1, a0
+; RV32I-NEXT:    andi a0, a0, 15
+; RV32I-NEXT:    ret
+;
+; RV32IM-LABEL: mulsub:
+; RV32IM:       # %bb.0:
+; RV32IM-NEXT:    slli a0, a0, 1
+; RV32IM-NEXT:    add a0, a1, a0
+; RV32IM-NEXT:    andi a0, a0, 15
+; RV32IM-NEXT:    ret
+;
+; RV64I-LABEL: mulsub:
+; RV64I:       # %bb.0:
+; RV64I-NEXT:    slliw a0, a0, 1
+; RV64I-NEXT:    addw a0, a1, a0
+; RV64I-NEXT:    andi a0, a0, 15
+; RV64I-NEXT:    ret
+;
+; RV64IM-LABEL: mulsub:
+; RV64IM:       # %bb.0:
+; RV64IM-NEXT:    slliw a0, a0, 1
+; RV64IM-NEXT:    addw a0, a1, a0
+; RV64IM-NEXT:    andi a0, a0, 15
+; RV64IM-NEXT:    ret
+  %m = mul i8 %x, 14
+  %a = sub i8 %y, %m
+  %r = and i8 %a, 15
+  ret i8 %r
+}
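
A note on the algebra behind the new SUB arm, separate from the patch itself: in @mulsub only the low 4 bits are demanded (the `and i8 %a, 15`), and 14 == -2 (mod 16), so `y - x*14 == y + 2*x == y + (x << 1) (mod 16)`. The brute-force check below is a standalone sketch written for this explanation, not code from the LLVM tree; it verifies the equivalence over all i8 inputs.

// Standalone sanity check (illustrative, not part of the patch): with only
// the low 4 bits demanded, mul by 14 is congruent to mul by -2, so a sub of
// the mul can become an add of a shift, matching the fold and the @mulsub
// CHECK lines above.
#include <cassert>
#include <cstdint>

int main() {
  for (unsigned x = 0; x < 256; ++x) {
    for (unsigned y = 0; y < 256; ++y) {
      // Original form: (y - x * 14) & 15, i.e. sub of mul, low 4 bits demanded.
      uint8_t SubOfMul = static_cast<uint8_t>(y - x * 14) & 15;
      // Folded form: (y + (x << 1)) & 15, valid because 14 == -2 (mod 16).
      uint8_t AddOfShl = static_cast<uint8_t>(y + (x << 1)) & 15;
      assert(SubOfMul == AddOfShl && "fold must preserve demanded bits");
    }
  }
  return 0;
}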