diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td @@ -533,6 +533,27 @@ } } +// Fold (op $rs2, (mul $rs1, $rd)) into one instruction_name _VV/_VX pseudo. +multiclass VPatMultiplyAddSDNode_VV_VX<SDNode op, string instruction_name> { + foreach vti = AllIntegerVectors in { + defvar suffix = vti.LMul.MX; + // NOTE: We choose VMADD because it has the most commuting freedom. So it + // works best with how TwoAddressInstructionPass tries commuting. + def : Pat<(vti.Vector (op vti.RegClass:$rs2, + (mul_oneuse vti.RegClass:$rs1, vti.RegClass:$rd))), + (!cast<Instruction>(instruction_name#"_VV_"# suffix) + vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2, + vti.AVL, vti.Log2SEW, TAIL_AGNOSTIC)>; + // The choice of VMADD here is arbitrary, vmadd.vx and vmacc.vx are equally + // commutable. + def : Pat<(vti.Vector (op vti.RegClass:$rs2, + (mul_oneuse (SplatPat XLenVT:$rs1), vti.RegClass:$rd))), + (!cast<Instruction>(instruction_name#"_VX_" # suffix) + vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2, + vti.AVL, vti.Log2SEW, TAIL_AGNOSTIC)>; + } +} + //===----------------------------------------------------------------------===// // Patterns. //===----------------------------------------------------------------------===// @@ -678,36 +699,8 @@ "PseudoVWMULSU">; // 12.13 Vector Single-Width Integer Multiply-Add Instructions. -foreach vti = AllIntegerVectors in { - // NOTE: We choose VMADD because it has the most commuting freedom. So it - // works best with how TwoAddressInstructionPass tries commuting. 
- defvar suffix = vti.LMul.MX; - def : Pat<(vti.Vector (add vti.RegClass:$rs2, - (mul_oneuse vti.RegClass:$rs1, vti.RegClass:$rd))), - (!cast<Instruction>("PseudoVMADD_VV_"# suffix) - vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2, - vti.AVL, vti.Log2SEW, TAIL_AGNOSTIC)>; - def : Pat<(vti.Vector (sub vti.RegClass:$rs2, - (mul_oneuse vti.RegClass:$rs1, vti.RegClass:$rd))), - (!cast<Instruction>("PseudoVNMSUB_VV_"# suffix) - vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2, - vti.AVL, vti.Log2SEW, TAIL_AGNOSTIC)>; - - // The choice of VMADD here is arbitrary, vmadd.vx and vmacc.vx are equally - // commutable. - def : Pat<(vti.Vector (add vti.RegClass:$rs2, - (mul_oneuse (SplatPat XLenVT:$rs1), - vti.RegClass:$rd))), - (!cast<Instruction>("PseudoVMADD_VX_" # suffix) - vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2, - vti.AVL, vti.Log2SEW, TAIL_AGNOSTIC)>; - def : Pat<(vti.Vector (sub vti.RegClass:$rs2, - (mul_oneuse (SplatPat XLenVT:$rs1), - vti.RegClass:$rd))), - (!cast<Instruction>("PseudoVNMSUB_VX_" # suffix) - vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2, - vti.AVL, vti.Log2SEW, TAIL_AGNOSTIC)>; -} +defm : VPatMultiplyAddSDNode_VV_VX<add, "PseudoVMADD">; +defm : VPatMultiplyAddSDNode_VV_VX<sub, "PseudoVNMSUB">; // 12.14 Vector Widening Integer Multiply-Add Instructions defm : VPatWidenMulAddSDNode_VV; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td @@ -852,6 +852,35 @@ } } +// Fold (op $rs2, (riscv_mul_vl $rs1, $rd)) into instruction_name _VV/_VX. +multiclass VPatMultiplyAddVL_VV_VX<SDNode op, string instruction_name> { + foreach vti = AllIntegerVectors in { + defvar suffix = vti.LMul.MX; + // NOTE: We choose VMADD because it has the most commuting freedom. So it + // works best with how TwoAddressInstructionPass tries commuting. 
+ def : Pat<(vti.Vector + (op vti.RegClass:$rs2, + (riscv_mul_vl_oneuse vti.RegClass:$rs1, + vti.RegClass:$rd, + (vti.Mask true_mask), VLOpFrag), + (vti.Mask true_mask), VLOpFrag)), + (!cast<Instruction>(instruction_name#"_VV_"# suffix) + vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2, + GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>; + // The choice of VMADD here is arbitrary, vmadd.vx and vmacc.vx are equally + // commutable. + def : Pat<(vti.Vector + (op vti.RegClass:$rs2, + (riscv_mul_vl_oneuse (SplatPat XLenVT:$rs1), + vti.RegClass:$rd, + (vti.Mask true_mask), VLOpFrag), + (vti.Mask true_mask), VLOpFrag)), + (!cast<Instruction>(instruction_name#"_VX_" # suffix) + vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2, + GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>; + } +} + //===----------------------------------------------------------------------===// // Patterns. //===----------------------------------------------------------------------===// @@ -1031,51 +1060,8 @@ defm : VPatBinaryWVL_VV_VX; // 12.13 Vector Single-Width Integer Multiply-Add Instructions -foreach vti = AllIntegerVectors in { - // NOTE: We choose VMADD because it has the most commuting freedom. So it - // works best with how TwoAddressInstructionPass tries commuting. 
- defvar suffix = vti.LMul.MX; - def : Pat<(vti.Vector - (riscv_add_vl vti.RegClass:$rs2, - (riscv_mul_vl_oneuse vti.RegClass:$rs1, - vti.RegClass:$rd, - (vti.Mask true_mask), VLOpFrag), - (vti.Mask true_mask), VLOpFrag)), - (!cast<Instruction>("PseudoVMADD_VV_"# suffix) - vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2, - GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>; - def : Pat<(vti.Vector - (riscv_sub_vl vti.RegClass:$rs2, - (riscv_mul_vl_oneuse vti.RegClass:$rs1, - vti.RegClass:$rd, - (vti.Mask true_mask), VLOpFrag), - (vti.Mask true_mask), VLOpFrag)), - (!cast<Instruction>("PseudoVNMSUB_VV_"# suffix) - vti.RegClass:$rd, vti.RegClass:$rs1, vti.RegClass:$rs2, - GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>; - - // The choice of VMADD here is arbitrary, vmadd.vx and vmacc.vx are equally - // commutable. - def : Pat<(vti.Vector - (riscv_add_vl vti.RegClass:$rs2, - (riscv_mul_vl_oneuse (SplatPat XLenVT:$rs1), - vti.RegClass:$rd, - (vti.Mask true_mask), VLOpFrag), - (vti.Mask true_mask), VLOpFrag)), - (!cast<Instruction>("PseudoVMADD_VX_" # suffix) - vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2, - GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>; - def : Pat<(vti.Vector - (riscv_sub_vl vti.RegClass:$rs2, - (riscv_mul_vl_oneuse (SplatPat XLenVT:$rs1), - vti.RegClass:$rd, - (vti.Mask true_mask), - VLOpFrag), - (vti.Mask true_mask), VLOpFrag)), - (!cast<Instruction>("PseudoVNMSUB_VX_" # suffix) - vti.RegClass:$rd, vti.ScalarRegClass:$rs1, vti.RegClass:$rs2, - GPR:$vl, vti.Log2SEW, TAIL_AGNOSTIC)>; -} +defm : VPatMultiplyAddVL_VV_VX<riscv_add_vl, "PseudoVMADD">; +defm : VPatMultiplyAddVL_VV_VX<riscv_sub_vl, "PseudoVNMSUB">; // 12.14. Vector Widening Integer Multiply-Add Instructions foreach vtiTowti = AllWidenableIntVectors in {