diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1011,7 +1011,7 @@
   setJumpIsExpensive();
 
   setTargetDAGCombine({ISD::INTRINSIC_WO_CHAIN, ISD::ADD, ISD::SUB, ISD::AND,
-                       ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT});
+                       ISD::OR, ISD::XOR, ISD::SETCC, ISD::SELECT, ISD::MUL});
   if (Subtarget.is64Bit())
     setTargetDAGCombine(ISD::SRA);
 
@@ -8569,6 +8569,134 @@
   return DAG.getNode(ISD::XOR, DL, VT, Logic, DAG.getConstant(1, DL, VT));
 }
 
+static SDValue performMULCombine(SDNode *N, SelectionDAG &DAG,
+                                 const RISCVSubtarget &Subtarget) {
+  SDLoc DL(N);
+  const MVT XLenVT = Subtarget.getXLenVT();
+  const EVT VT = N->getValueType(0);
+
+  // An MUL is usually smaller than any alternative sequence for legal type.
+  const TargetLowering &TLI = DAG.getTargetLoweringInfo();
+  if (DAG.getMachineFunction().getFunction().hasMinSize() &&
+      TLI.isOperationLegal(ISD::MUL, VT))
+    return SDValue();
+
+  SDValue N0 = N->getOperand(0);
+  SDValue N1 = N->getOperand(1);
+  ConstantSDNode *ConstOp = dyn_cast<ConstantSDNode>(N1);
+  // Any optimization requires a constant RHS.
+  if (!ConstOp)
+    return SDValue();
+
+  const APInt &C = ConstOp->getAPIntValue();
+  // A multiply-by-pow2 will be reduced to a shift by the
+  // architecture-independent code.
+  if (C.isPowerOf2())
+    return SDValue();
+
+  // The below optimizations only work for non-negative constants
+  if (!C.isNonNegative())
+    return SDValue();
+
+  auto Shl = [&](SDValue Value, unsigned ShiftAmount) {
+    if (!ShiftAmount)
+      return Value;
+
+    SDValue ShiftAmountConst = DAG.getConstant(ShiftAmount, DL, XLenVT);
+    return DAG.getNode(ISD::SHL, DL, Value.getValueType(), Value,
+                       ShiftAmountConst);
+  };
+  auto Add = [&](SDValue Addend1, SDValue Addend2) {
+    return DAG.getNode(ISD::ADD, DL, Addend1.getValueType(), Addend1, Addend2);
+  };
+
+  if (Subtarget.hasVendorXTHeadBa()) {
+    // We try to simplify using shift-and-add instructions into up to
+    // 3 instructions (e.g. 2x shift-and-add and 1x shift).
+
+    auto isDivisibleByShiftedAddConst = [&](APInt C, APInt &N,
+                                            APInt &Quotient) {
+      unsigned BitWidth = C.getBitWidth();
+      for (unsigned i = 3; i >= 1; --i) {
+        APInt X(BitWidth, (1 << i) + 1);
+        APInt Remainder;
+        APInt::sdivrem(C, X, Quotient, Remainder);
+        if (Remainder == 0) {
+          N = X;
+          return true;
+        }
+      }
+      return false;
+    };
+    auto isShiftedAddConst = [&](APInt C, APInt &N) {
+      APInt Quotient;
+      return isDivisibleByShiftedAddConst(C, N, Quotient) && Quotient == 1;
+    };
+    auto isSmallShiftAmount = [](APInt C) {
+      return (C == 2) || (C == 4) || (C == 8);
+    };
+
+    auto ShiftAndAdd = [&](SDValue Value, unsigned ShiftAmount,
+                           SDValue Addend) {
+      return Add(Shl(Value, ShiftAmount), Addend);
+    };
+    auto AnyExt = [&](SDValue Value) {
+      return DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Value);
+    };
+    auto Trunc = [&](SDValue Value) {
+      return DAG.getNode(ISD::TRUNCATE, DL, VT, Value);
+    };
+
+    unsigned TrailingZeroes = C.countTrailingZeros();
+    const APInt ShiftedC = C.ashr(TrailingZeroes);
+    const APInt ShiftedCMinusOne = ShiftedC - 1;
+
+    // the below comments use the following notation:
+    //   n, m .. a shift-amount for a shift-and-add instruction
+    //           (i.e. in { 2, 4, 8 })
+    //   k    .. a power-of-2 that is equivalent to shifting by
+    //           TrailingZeroes bits
+    //   i, j .. a power-of-2
+
+    APInt ShiftAmt1;
+    APInt ShiftAmt2;
+    APInt Quotient;
+
+    // C = (m + 1) * k
+    if (isShiftedAddConst(ShiftedC, ShiftAmt1)) {
+      SDValue Op0 = AnyExt(N0);
+      SDValue Result = ShiftAndAdd(Op0, ShiftAmt1.logBase2(), Op0);
+      return Trunc(Shl(Result, TrailingZeroes));
+    }
+    // C = (m + 1) * (n + 1) * k
+    if (isDivisibleByShiftedAddConst(ShiftedC, ShiftAmt1, Quotient) &&
+        isShiftedAddConst(Quotient, ShiftAmt2)) {
+      SDValue Op0 = AnyExt(N0);
+      SDValue Result = ShiftAndAdd(Op0, ShiftAmt1.logBase2(), Op0);
+      Result = ShiftAndAdd(Result, ShiftAmt2.logBase2(), Result);
+      return Trunc(Shl(Result, TrailingZeroes));
+    }
+    // C = ((m + 1) * n + 1) * k
+    if (isDivisibleByShiftedAddConst(ShiftedCMinusOne, ShiftAmt1, Quotient) &&
+        isSmallShiftAmount(Quotient)) {
+      SDValue Op0 = AnyExt(N0);
+      SDValue Result = ShiftAndAdd(Op0, ShiftAmt1.logBase2(), Op0);
+      Result = ShiftAndAdd(Result, Quotient.logBase2(), Op0);
+      return Trunc(Shl(Result, TrailingZeroes));
+    }
+
+    // C has 2 bits set: synthesize using 2 shifts and 1 add (which may
+    // see one of the shifts merged into a shift-and-add, if feasible)
+    if (C.countPopulation() == 2) {
+      APInt HighBit = APInt::getOneBitSet(C.getBitWidth(), C.logBase2());
+      APInt LowBit = C - HighBit;
+      return Add(Shl(N0, HighBit.logBase2()), Shl(N0, LowBit.logBase2()));
+    }
+  }
+
+  return SDValue();
+}
+
 static SDValue performTRUNCATECombine(SDNode *N, SelectionDAG &DAG,
                                       const RISCVSubtarget &Subtarget) {
   SDValue N0 = N->getOperand(0);
@@ -10218,6 +10346,8 @@
     return performADDCombine(N, DAG, Subtarget);
   case ISD::SUB:
     return performSUBCombine(N, DAG, Subtarget);
+  case ISD::MUL:
+    return performMULCombine(N, DAG, Subtarget);
   case ISD::AND:
     return performANDCombine(N, DCI, Subtarget);
   case ISD::OR:
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoXTHead.td
@@ -161,67 +161,6 @@
           (TH_ADDSL GPR:$rs2, sh2add_op:$rs1, 2)>;
 def : Pat<(add sh3add_op:$rs1, non_imm12:$rs2),
           (TH_ADDSL GPR:$rs2, sh3add_op:$rs1, 3)>;
-
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 6)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 1), 1)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 10)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 2), 1)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 18)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 3), 1)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 12)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 1), 2)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 20)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 2), 2)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 36)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 3), 2)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 24)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 1), 3)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 40)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 2), 3)>;
-def : Pat<(add (mul_oneuse GPR:$rs1, (XLenVT 72)), GPR:$rs2),
-          (TH_ADDSL GPR:$rs2, (TH_ADDSL GPR:$rs1, GPR:$rs1, 3), 3)>;
-
-def : Pat<(add GPR:$r, CSImm12MulBy4:$i),
-          (TH_ADDSL GPR:$r, (ADDI X0, (SimmShiftRightBy2XForm CSImm12MulBy4:$i)), 2)>;
-def : Pat<(add GPR:$r, CSImm12MulBy8:$i),
-          (TH_ADDSL GPR:$r, (ADDI X0, (SimmShiftRightBy3XForm CSImm12MulBy8:$i)), 3)>;
-
-def : Pat<(mul GPR:$r, C3LeftShift:$i),
-          (SLLI (TH_ADDSL GPR:$r, GPR:$r, 1),
-                (TrailingZeros C3LeftShift:$i))>;
-def : Pat<(mul GPR:$r, C5LeftShift:$i),
-          (SLLI (TH_ADDSL GPR:$r, GPR:$r, 2),
-                (TrailingZeros C5LeftShift:$i))>;
-def : Pat<(mul GPR:$r, C9LeftShift:$i),
-          (SLLI (TH_ADDSL GPR:$r, GPR:$r, 3),
-                (TrailingZeros C9LeftShift:$i))>;
-
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 11)),
-          (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 2), 1)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 19)),
-          (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 3), 1)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 13)),
-          (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 1), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 21)),
-          (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 2), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 37)),
-          (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 3), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 25)),
-          (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 2), (TH_ADDSL GPR:$r, GPR:$r, 2), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 41)),
-          (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 2), 3)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 73)),
-          (TH_ADDSL GPR:$r, (TH_ADDSL GPR:$r, GPR:$r, 3), 3)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 27)),
-          (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 3), (TH_ADDSL GPR:$r, GPR:$r, 3), 1)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 45)),
-          (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 3), (TH_ADDSL GPR:$r, GPR:$r, 3), 2)>;
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 81)),
-          (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 3), (TH_ADDSL GPR:$r, GPR:$r, 3), 3)>;
-
-def : Pat<(mul_const_oneuse GPR:$r, (XLenVT 200)),
-          (SLLI (TH_ADDSL (TH_ADDSL GPR:$r, GPR:$r, 2),
-                          (TH_ADDSL GPR:$r, GPR:$r, 2), 2), 3)>;
 } // Predicates = [HasVendorXTHeadBa]
 
 defm PseudoTHVdotVMAQA  : VPseudoVMAQA_VV_VX;