diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td @@ -2328,18 +2328,25 @@ multiclass VPseudoVSQR_V { foreach m = MxListF in { defvar mx = m.MX; - defvar WriteVFSqrtV_MX = !cast("WriteVFSqrtV_" # mx); - defvar ReadVFSqrtV_MX = !cast("ReadVFSqrtV_" # mx); + defvar sews = SchedSEWSet.val; - let VLMul = m.value in { - def "_V_" # mx : VPseudoUnaryNoMask, - Sched<[WriteVFSqrtV_MX, ReadVFSqrtV_MX, ReadVMask]>; - def "_V_" # mx # "_TU": VPseudoUnaryNoMaskTU, - Sched<[WriteVFSqrtV_MX, ReadVFSqrtV_MX, ReadVMask]>; - def "_V_" # mx # "_MASK" : VPseudoUnaryMaskTA, - RISCVMaskedPseudo, - Sched<[WriteVFSqrtV_MX, ReadVFSqrtV_MX, ReadVMask]>; - } + let VLMul = m.value in + foreach e = sews in { + defvar suffix = "_" # mx # "_E" # e; + defvar WriteVFSqrtV_MX_E = !cast("WriteVFSqrtV" # suffix); + defvar ReadVFSqrtV_MX_E = !cast("ReadVFSqrtV" # suffix); + + def "_V" # suffix : VPseudoUnaryNoMask, + Sched<[WriteVFSqrtV_MX_E, ReadVFSqrtV_MX_E, + ReadVMask]>; + def "_V" # suffix # "_TU": VPseudoUnaryNoMaskTU, + Sched<[WriteVFSqrtV_MX_E, ReadVFSqrtV_MX_E, + ReadVMask]>; + def "_V" # suffix # "_MASK" : VPseudoUnaryMaskTA, + RISCVMaskedPseudo, + Sched<[WriteVFSqrtV_MX_E, ReadVFSqrtV_MX_E, + ReadVMask]>; + } } } @@ -3835,6 +3842,23 @@ (op2_type op2_reg_class:$rs2), GPR:$vl, sew)>; +class VPatUnaryNoMask_E : + Pat<(result_type (!cast(intrinsic_name) + (result_type undef), + (op2_type op2_reg_class:$rs2), + VLOpFrag)), + (!cast(inst#"_"#kind#"_"#vlmul.MX#"_E"#sew) + (op2_type op2_reg_class:$rs2), + GPR:$vl, log2sew)>; + class VPatUnaryNoMaskTU; +class VPatUnaryNoMaskTU_E : + Pat<(result_type (!cast(intrinsic_name) + (result_type result_reg_class:$merge), + (op2_type op2_reg_class:$rs2), + VLOpFrag)), + (!cast(inst#"_"#kind#"_"#vlmul.MX#"_E"#sew#"_TU") + (result_type result_reg_class:$merge), + (op2_type op2_reg_class:$rs2), + GPR:$vl, log2sew)>; + class VPatUnaryMask; +class VPatUnaryMaskTA_E : + Pat<(result_type (!cast(intrinsic_name#"_mask") + (result_type result_reg_class:$merge), + (op2_type op2_reg_class:$rs2), + (mask_type V0), + VLOpFrag, (XLenVT timm:$policy))), + (!cast(inst#"_"#kind#"_"#vlmul.MX#"_E"#sew#"_MASK") + (result_type result_reg_class:$merge), + (op2_type op2_reg_class:$rs2), + (mask_type V0), GPR:$vl, log2sew, (XLenVT timm:$policy))>; + class VPatMaskUnaryNoMask : @@ -4336,6 +4400,23 @@ } } +multiclass VPatUnaryV_V_E vtilist> { + foreach vti = vtilist in { + def : VPatUnaryNoMask_E; + def : VPatUnaryNoMaskTU_E; + def : VPatUnaryMaskTA_E; + } +} + multiclass VPatNullaryV { foreach vti = AllIntegerVectors in { @@ -6292,7 +6373,7 @@ //===----------------------------------------------------------------------===// // 13.8. Vector Floating-Point Square-Root Instruction //===----------------------------------------------------------------------===// -defm : VPatUnaryV_V<"int_riscv_vfsqrt", "PseudoVFSQRT", AllFloatVectors>; +defm : VPatUnaryV_V_E<"int_riscv_vfsqrt", "PseudoVFSQRT", AllFloatVectors>; //===----------------------------------------------------------------------===// // 13.9. Vector Floating-Point Reciprocal Square-Root Estimate Instruction diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td @@ -1003,7 +1003,7 @@ foreach vti = AllFloatVectors in { // 13.8. Vector Floating-Point Square-Root Instruction def : Pat<(fsqrt (vti.Vector vti.RegClass:$rs2)), - (!cast("PseudoVFSQRT_V_"# vti.LMul.MX) + (!cast("PseudoVFSQRT_V_"# vti.LMul.MX#"_E"#vti.SEW) vti.RegClass:$rs2, vti.AVL, vti.Log2SEW)>; // 13.12. Vector Floating-Point Sign-Injection Instructions diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td @@ -1781,7 +1781,7 @@ // 13.8. Vector Floating-Point Square-Root Instruction def : Pat<(riscv_fsqrt_vl (vti.Vector vti.RegClass:$rs2), (vti.Mask V0), VLOpFrag), - (!cast("PseudoVFSQRT_V_"# vti.LMul.MX #"_MASK") + (!cast("PseudoVFSQRT_V_"# vti.LMul.MX # "_E" # vti.SEW # "_MASK") (vti.Vector (IMPLICIT_DEF)), vti.RegClass:$rs2, (vti.Mask V0), GPR:$vl, vti.Log2SEW, TA_MA)>; diff --git a/llvm/lib/Target/RISCV/RISCVScheduleV.td b/llvm/lib/Target/RISCV/RISCVScheduleV.td --- a/llvm/lib/Target/RISCV/RISCVScheduleV.td +++ b/llvm/lib/Target/RISCV/RISCVScheduleV.td @@ -309,7 +309,7 @@ defm "" : LMULSchedWritesFW<"WriteVFWMulAddV">; defm "" : LMULSchedWritesFW<"WriteVFWMulAddF">; // 13.8. Vector Floating-Point Square-Root Instruction -defm "" : LMULSchedWrites<"WriteVFSqrtV">; +defm "" : LMULSEWSchedWrites<"WriteVFSqrtV">; // 13.9. Vector Floating-Point Reciprocal Square-Root Estimate Instruction // 13.10. Vector Floating-Point Reciprocal Estimate Instruction defm "" : LMULSchedWrites<"WriteVFRecpV">; @@ -528,7 +528,7 @@ defm "" : LMULSchedReadsFW<"ReadVFWMulAddV">; defm "" : LMULSchedReadsFW<"ReadVFWMulAddF">; // 13.8. Vector Floating-Point Square-Root Instruction -defm "" : LMULSchedReads<"ReadVFSqrtV">; +defm "" : LMULSEWSchedReads<"ReadVFSqrtV">; // 13.9. Vector Floating-Point Reciprocal Square-Root Estimate Instruction // 13.10. Vector Floating-Point Reciprocal Estimate Instruction defm "" : LMULSchedReads<"ReadVFRecpV">; @@ -757,7 +757,7 @@ defm "" : LMULWriteRes<"WriteVFMulAddF", []>; defm "" : LMULWriteResFW<"WriteVFWMulAddV", []>; defm "" : LMULWriteResFW<"WriteVFWMulAddF", []>; -defm "" : LMULWriteRes<"WriteVFSqrtV", []>; +defm "" : LMULSEWWriteRes<"WriteVFSqrtV", []>; defm "" : LMULWriteRes<"WriteVFRecpV", []>; defm "" : LMULWriteRes<"WriteVFCmpV", []>; defm "" : LMULWriteRes<"WriteVFCmpF", []>; @@ -907,7 +907,7 @@ defm "" : LMULReadAdvance<"ReadVFMulAddF", 0>; defm "" : LMULReadAdvanceFW<"ReadVFWMulAddV", 0>; defm "" : LMULReadAdvanceFW<"ReadVFWMulAddF", 0>; -defm "" : LMULReadAdvance<"ReadVFSqrtV", 0>; +defm "" : LMULSEWReadAdvance<"ReadVFSqrtV", 0>; defm "" : LMULReadAdvance<"ReadVFRecpV", 0>; defm "" : LMULReadAdvance<"ReadVFCmpV", 0>; defm "" : LMULReadAdvance<"ReadVFCmpF", 0>;