diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -294,6 +294,12 @@
   VFWADD_W_VL,
   VFWSUB_W_VL,
 
+  // Widening ternary operations with a mask as the fourth operand and VL as the
+  // fifth operand.
+  VWMACC_VL,
+  VWMACCU_VL,
+  VWMACCSU_VL,
+
   // Narrowing logical shift right.
   // Operands are (source, shift, passthru, mask, vl)
   VNSRL_VL,
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -12136,6 +12136,63 @@
   return convertFromScalableVector(VT, Res, DAG, Subtarget);
 }
 
+static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
+                               const RISCVSubtarget &Subtarget) { // Fold ADD_VL(addend, VWMUL*_VL) into the matching VWMACC*_VL node; returns an empty SDValue when no fold applies.
+  assert(N->getOpcode() == RISCVISD::ADD_VL);
+  SDValue Addend = N->getOperand(0);
+  SDValue MulOp = N->getOperand(1);
+  SDValue AddMergeOp = N->getOperand(2);
+
+  if (!AddMergeOp.isUndef()) // Only fold when operand 2 of the add is undef.
+    return SDValue();
+
+  auto IsVWMulOpc = [](unsigned Opc) { // True for the three widening-multiply opcodes.
+    switch (Opc) {
+    case RISCVISD::VWMUL_VL:
+    case RISCVISD::VWMULU_VL:
+    case RISCVISD::VWMULSU_VL:
+      return true;
+    default:
+      return false;
+    }
+  };
+
+  if (!IsVWMulOpc(MulOp.getOpcode()))
+    std::swap(Addend, MulOp); // The multiply may be either operand of the add.
+
+  if (!IsVWMulOpc(MulOp.getOpcode())) // Neither operand is a widening multiply.
+    return SDValue();
+
+  SDValue MulMergeOp = MulOp.getOperand(2);
+
+  if (!MulMergeOp.isUndef()) // Likewise, operand 2 of the multiply must be undef.
+    return SDValue();
+
+  SDValue AddMask = N->getOperand(3);
+  SDValue AddVL = N->getOperand(4);
+  SDValue MulMask = MulOp.getOperand(3);
+  SDValue MulVL = MulOp.getOperand(4);
+
+  if (AddMask != MulMask || AddVL != MulVL) // Add and multiply must share the same mask and VL.
+    return SDValue();
+
+  unsigned Opc = RISCVISD::VWMACC_VL + MulOp.getOpcode() - RISCVISD::VWMUL_VL; // Offset mapping; depends on the enum ordering asserted below.
+  static_assert(RISCVISD::VWMACC_VL + 1 == RISCVISD::VWMACCU_VL,
+                "Unexpected opcode after VWMACC_VL");
+  static_assert(RISCVISD::VWMACC_VL + 2 == RISCVISD::VWMACCSU_VL,
+                "Unexpected opcode after VWMACC_VL!");
+  static_assert(RISCVISD::VWMUL_VL + 1 == RISCVISD::VWMULU_VL,
+                "Unexpected opcode after VWMUL_VL!");
+  static_assert(RISCVISD::VWMUL_VL + 2 == RISCVISD::VWMULSU_VL,
+                "Unexpected opcode after VWMUL_VL!");
+
+  SDLoc DL(N);
+  EVT VT = N->getValueType(0);
+  SDValue Ops[] = {MulOp.getOperand(0), MulOp.getOperand(1), Addend, AddMask, // (mul lhs, mul rhs, addend, mask, vl)
+                   AddVL};
+  return DAG.getNode(Opc, DL, VT, Ops);
+}
+
 SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   SelectionDAG &DAG = DCI.DAG;
@@ -12546,6 +12603,9 @@
     break;
   }
   case RISCVISD::ADD_VL:
+    if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI)) // Try the widening-binop combine first.
+      return V;
+    return combineToVWMACC(N, DAG, Subtarget); // Otherwise try the multiply-accumulate fold; ADD_VL no longer falls through.
   case RISCVISD::SUB_VL:
   case RISCVISD::VWADD_W_VL:
   case RISCVISD::VWADDU_W_VL:
@@ -15683,6 +15743,9 @@
   NODE_NAME_CASE(VFWSUB_VL)
   NODE_NAME_CASE(VFWADD_W_VL)
   NODE_NAME_CASE(VFWSUB_W_VL)
+  NODE_NAME_CASE(VWMACC_VL)
+  NODE_NAME_CASE(VWMACCU_VL)
+  NODE_NAME_CASE(VWMACCSU_VL)
   NODE_NAME_CASE(VNSRL_VL)
   NODE_NAME_CASE(SETCC_VL)
   NODE_NAME_CASE(VSELECT_VL)
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
@@ -395,6 +395,19 @@
 def riscv_vwsub_vl   : SDNode<"RISCVISD::VWSUB_VL",   SDT_RISCVVWIntBinOp_VL, []>;
 def riscv_vwsubu_vl  : SDNode<"RISCVISD::VWSUBU_VL",  SDT_RISCVVWIntBinOp_VL, []>;
 
+def SDT_RISCVVWIntTernOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisInt<0>, // One result, five operands: (lhs, rhs, addend, mask, vl).
+                                                   SDTCisInt<1>,
+                                                   SDTCisSameNumEltsAs<0, 1>,
+                                                   SDTCisOpSmallerThanOp<1, 0>, // The multiplicand type is narrower than the result.
+                                                   SDTCisSameAs<1, 2>, // Both multiplicands share one type.
+                                                   SDTCisSameAs<0, 3>, // The addend has the result type.
+                                                   SDTCisSameNumEltsAs<1, 4>,
+                                                   SDTCVecEltisVT<4, i1>, // Mask operand.
+                                                   SDTCisVT<5, XLenVT>]>; // VL operand.
+def riscv_vwmacc_vl : SDNode<"RISCVISD::VWMACC_VL", SDT_RISCVVWIntTernOp_VL, [SDNPCommutative]>;
+def riscv_vwmaccu_vl : SDNode<"RISCVISD::VWMACCU_VL", SDT_RISCVVWIntTernOp_VL, [SDNPCommutative]>;
+def riscv_vwmaccsu_vl : SDNode<"RISCVISD::VWMACCSU_VL", SDT_RISCVVWIntTernOp_VL, []>; // Not marked SDNPCommutative, unlike the two nodes above.
+
 def SDT_RISCVVWFPBinOp_VL : SDTypeProfile<1, 5, [SDTCisVec<0>, SDTCisFP<0>,
                                                  SDTCisFP<1>,
                                                  SDTCisSameNumEltsAs<0, 1>,
@@ -1407,30 +1420,27 @@
   }
 }
 
-multiclass VPatWidenMultiplyAddVL_VV_VX {
+multiclass VPatWidenMultiplyAddVL_VV_VX { // Selects the masked VV/VX pseudos for one widening multiply-accumulate node.
   foreach vtiTowti = AllWidenableIntVectors in {
     defvar vti = vtiTowti.Vti;
     defvar wti = vtiTowti.Wti;
     let Predicates = !listconcat(GetVTypePredicates.Predicates,
                                  GetVTypePredicates.Predicates) in {
-      def : Pat<(wti.Vector
-                 (riscv_add_vl wti.RegClass:$rd,
-                               (op1 vti.RegClass:$rs1,
-                                    (vti.Vector vti.RegClass:$rs2),
-                                    srcvalue, (vti.Mask true_mask), VLOpFrag),
-                               srcvalue, (vti.Mask true_mask), VLOpFrag)),
-                (!cast(instruction_name#"_VV_" # vti.LMul.MX)
-                   wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
-                   GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
-      def : Pat<(wti.Vector
-                 (riscv_add_vl wti.RegClass:$rd,
-                               (op1 (SplatPat XLenVT:$rs1),
-                                    (vti.Vector vti.RegClass:$rs2),
-                                    srcvalue, (vti.Mask true_mask), VLOpFrag),
-                               srcvalue, (vti.Mask true_mask), VLOpFrag)),
-                (!cast(instruction_name#"_VX_" # vti.LMul.MX)
-                   wti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2,
-                   GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+      def : Pat<(vwmacc_op (vti.Vector vti.RegClass:$rs1), // Vector-vector form.
+                           (vti.Vector vti.RegClass:$rs2),
+                           (wti.Vector wti.RegClass:$rd),
+                           (vti.Mask V0), VLOpFrag),
+                (!cast(instr_name#"_VV_"#vti.LMul.MX#"_MASK")
+                    wti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2,
+                    (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+      def : Pat<(vwmacc_op (SplatPat XLenVT:$rs1), // Vector-scalar form: the first multiplicand is a splat.
+                           (vti.Vector vti.RegClass:$rs2),
+                           (wti.Vector wti.RegClass:$rd),
+                           (vti.Mask V0), VLOpFrag),
+                (!cast(instr_name#"_VX_"#vti.LMul.MX#"_MASK")
+                    wti.RegClass:$rd, vti.ScalarRegClass:$rs1,
+                    vti.RegClass:$rs2, (vti.Mask V0), GPR:$vl, vti.Log2SEW,
+                    TAIL_AGNOSTIC)>;
     }
   }
 }
@@ -1704,25 +1714,21 @@
 defm : VPatMultiplyAccVL_VV_VX;
 
 // 11.14. Vector Widening Integer Multiply-Add Instructions
-defm : VPatWidenMultiplyAddVL_VV_VX;
-defm : VPatWidenMultiplyAddVL_VV_VX;
-defm : VPatWidenMultiplyAddVL_VV_VX;
+defm : VPatWidenMultiplyAddVL_VV_VX;
+defm : VPatWidenMultiplyAddVL_VV_VX;
+defm : VPatWidenMultiplyAddVL_VV_VX;
 foreach vtiTowti = AllWidenableIntVectors in {
   defvar vti = vtiTowti.Vti;
   defvar wti = vtiTowti.Wti;
   let Predicates = !listconcat(GetVTypePredicates.Predicates,
                                GetVTypePredicates.Predicates) in
-  def : Pat<(wti.Vector
-             (riscv_add_vl wti.RegClass:$rd,
-                           (riscv_vwmulsu_vl_oneuse (vti.Vector vti.RegClass:$rs1),
-                                                    (SplatPat XLenVT:$rs2),
-                                                    srcvalue,
-                                                    (vti.Mask true_mask),
-                                                    VLOpFrag),
-                           srcvalue, (vti.Mask true_mask),VLOpFrag)),
-            (!cast("PseudoVWMACCUS_VX_" # vti.LMul.MX)
-               wti.RegClass:$rd, vti.ScalarRegClass:$rs2, vti.RegClass:$rs1,
-               GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
+  def : Pat<(riscv_vwmaccsu_vl (vti.Vector vti.RegClass:$rs1), // Splat as the second multiplicand selects the VWMACCUS_VX pseudo.
+                               (SplatPat XLenVT:$rs2),
+                               (wti.Vector wti.RegClass:$rd),
+                               (vti.Mask V0), VLOpFrag),
+            (!cast("PseudoVWMACCUS_VX_"#vti.LMul.MX#"_MASK")
+                wti.RegClass:$rd, vti.ScalarRegClass:$rs2, vti.RegClass:$rs1, // The scalar $rs2 is passed ahead of the vector $rs1.
+                (vti.Mask V0), GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>;
 }
 
 // 11.15. Vector Integer Merge Instructions
diff --git a/llvm/test/CodeGen/RISCV/rvv/vwmaccus-vp.ll b/llvm/test/CodeGen/RISCV/rvv/vwmaccsu-vp.ll
rename from llvm/test/CodeGen/RISCV/rvv/vwmaccus-vp.ll
rename to llvm/test/CodeGen/RISCV/rvv/vwmaccsu-vp.ll