diff --git a/clang/include/clang/Basic/riscv_vector.td b/clang/include/clang/Basic/riscv_vector.td --- a/clang/include/clang/Basic/riscv_vector.td +++ b/clang/include/clang/Basic/riscv_vector.td @@ -226,6 +226,11 @@ [["vv", "v", "vvv"], ["vf", "v", "vve"]]>; +multiclass RVVFloatingBinBuiltinSetRoundingMode + : RVVOutOp1BuiltinSet; + multiclass RVVFloatingBinVFBuiltinSet : RVVOutOp1BuiltinSet; @@ -2206,10 +2211,71 @@ defm vnclipu : RVVUnsignedNShiftBuiltinSetRoundingMode; defm vnclip : RVVSignedNShiftBuiltinSetRoundingMode; } +} // 14. Vector Floating-Point Instructions +let HeaderCode = +[{ +enum __RISCV_FRM { + __RISCV_FRM_RNE = 0, + __RISCV_FRM_RTZ = 1, + __RISCV_FRM_RDN = 2, + __RISCV_FRM_RUP = 3, + __RISCV_FRM_RMM = 4, +}; +}] in def frm_enum : RVVHeader; + +let UnMaskedPolicyScheme = HasPassthruOperand in { // 14.2. Vector Single-Width Floating-Point Add/Subtract Instructions -defm vfadd : RVVFloatingBinBuiltinSet; +let ManualCodegen = [{ + { + // LLVM intrinsic operand order (frm is always present at the IR level): + // Unmasked: (passthru, op0, op1, frm, vl) + // Masked: (passthru, op0, op1, mask, frm, vl, policy) + + SmallVector Operands; + bool HasMaskedOff = !( + (IsMasked && (PolicyAttrs & RVV_VTA) && (PolicyAttrs & RVV_VMA)) || + (!IsMasked && PolicyAttrs & RVV_VTA)); + bool HasRoundModeOp = IsMasked ? + (HasMaskedOff ? Ops.size() == 6 : Ops.size() == 5) : + (HasMaskedOff ? Ops.size() == 5 : Ops.size() == 4); + + unsigned Offset = IsMasked ? + (HasMaskedOff ? 2 : 1) : (HasMaskedOff ? 1 : 0); + + if (!HasMaskedOff) + Operands.push_back(llvm::PoisonValue::get(ResultType)); + else + Operands.push_back(Ops[IsMasked ? 
1 : 0]); + + Operands.push_back(Ops[Offset]); // op0 + Operands.push_back(Ops[Offset + 1]); // op1 + + if (IsMasked) + Operands.push_back(Ops[0]); // mask + + if (HasRoundModeOp) { + Operands.push_back(Ops[Offset + 2]); // frm (explicit, from the _rm builtin) + Operands.push_back(Ops[Offset + 3]); // vl + } else { + Operands.push_back(ConstantInt::get(Ops[Offset + 2]->getType(), 99)); // frm: 99 = hint for RISCVInsertReadWriteCSR to leave frm unchanged + Operands.push_back(Ops[Offset + 2]); // vl + } + + if (IsMasked) + Operands.push_back(ConstantInt::get(Ops.back()->getType(), PolicyAttrs)); + + IntrinsicTypes = {ResultType, Ops[Offset + 1]->getType(), Ops.back()->getType()}; + llvm::Function *F = CGM.getIntrinsic(ID, IntrinsicTypes); + return Builder.CreateCall(F, Operands, ""); + } +}] in { + let HasFRMRoundModeOp = true in { + defm vfadd : RVVFloatingBinBuiltinSetRoundingMode; + } + defm vfadd : RVVFloatingBinBuiltinSet; +} defm vfsub : RVVFloatingBinBuiltinSet; defm vfrsub : RVVFloatingBinVFBuiltinSet; diff --git a/clang/include/clang/Basic/riscv_vector_common.td b/clang/include/clang/Basic/riscv_vector_common.td --- a/clang/include/clang/Basic/riscv_vector_common.td +++ b/clang/include/clang/Basic/riscv_vector_common.td @@ -234,6 +234,10 @@ // Set to true if the builtin is associated with tuple types. bit IsTuple = false; + + // Set to true if the builtin has a parameter that models floating-point + // rounding mode control (an explicit frm operand; such builtins get "_rm"). + bit HasFRMRoundModeOp = false; } // This is the code emitted in the header.
diff --git a/clang/include/clang/Support/RISCVVIntrinsicUtils.h b/clang/include/clang/Support/RISCVVIntrinsicUtils.h --- a/clang/include/clang/Support/RISCVVIntrinsicUtils.h +++ b/clang/include/clang/Support/RISCVVIntrinsicUtils.h @@ -381,6 +381,7 @@ std::vector IntrinsicTypes; unsigned NF = 1; Policy PolicyAttrs; + bool HasFRMRoundModeOp; public: RVVIntrinsic(llvm::StringRef Name, llvm::StringRef Suffix, @@ -391,7 +392,7 @@ const RVVTypes &Types, const std::vector &IntrinsicTypes, const std::vector &RequiredFeatures, - unsigned NF, Policy PolicyAttrs); + unsigned NF, Policy PolicyAttrs, bool HasFRMRoundModeOp); ~RVVIntrinsic() = default; RVVTypePtr getOutputType() const { return OutputType; } @@ -461,7 +462,7 @@ static void updateNamesAndPolicy(bool IsMasked, bool HasPolicy, std::string &Name, std::string &BuiltinName, std::string &OverloadedName, - Policy &PolicyAttrs); + Policy &PolicyAttrs, bool HasFRMRoundModeOp); }; // RVVRequire should be sync'ed with target features, but only @@ -520,6 +521,7 @@ bool HasMaskedOffOperand : 1; bool HasTailPolicy : 1; bool HasMaskPolicy : 1; + bool HasFRMRoundModeOp : 1; bool IsTuple : 1; uint8_t UnMaskedPolicyScheme : 2; uint8_t MaskedPolicyScheme : 2; diff --git a/clang/lib/Sema/SemaRISCVVectorLookup.cpp b/clang/lib/Sema/SemaRISCVVectorLookup.cpp --- a/clang/lib/Sema/SemaRISCVVectorLookup.cpp +++ b/clang/lib/Sema/SemaRISCVVectorLookup.cpp @@ -349,7 +349,8 @@ std::string BuiltinName = "__builtin_rvv_" + std::string(Record.Name); RVVIntrinsic::updateNamesAndPolicy(IsMasked, HasPolicy, Name, BuiltinName, - OverloadedName, PolicyAttrs); + OverloadedName, PolicyAttrs, + Record.HasFRMRoundModeOp); // Put into IntrinsicList.
size_t Index = IntrinsicList.size(); diff --git a/clang/lib/Support/RISCVVIntrinsicUtils.cpp b/clang/lib/Support/RISCVVIntrinsicUtils.cpp --- a/clang/lib/Support/RISCVVIntrinsicUtils.cpp +++ b/clang/lib/Support/RISCVVIntrinsicUtils.cpp @@ -870,20 +870,19 @@ //===----------------------------------------------------------------------===// // RVVIntrinsic implementation //===----------------------------------------------------------------------===// -RVVIntrinsic::RVVIntrinsic(StringRef NewName, StringRef Suffix, - StringRef NewOverloadedName, - StringRef OverloadedSuffix, StringRef IRName, - bool IsMasked, bool HasMaskedOffOperand, bool HasVL, - PolicyScheme Scheme, bool SupportOverloading, - bool HasBuiltinAlias, StringRef ManualCodegen, - const RVVTypes &OutInTypes, - const std::vector &NewIntrinsicTypes, - const std::vector &RequiredFeatures, - unsigned NF, Policy NewPolicyAttrs) +RVVIntrinsic::RVVIntrinsic( + StringRef NewName, StringRef Suffix, StringRef NewOverloadedName, + StringRef OverloadedSuffix, StringRef IRName, bool IsMasked, + bool HasMaskedOffOperand, bool HasVL, PolicyScheme Scheme, + bool SupportOverloading, bool HasBuiltinAlias, StringRef ManualCodegen, + const RVVTypes &OutInTypes, const std::vector &NewIntrinsicTypes, + const std::vector &RequiredFeatures, unsigned NF, + Policy NewPolicyAttrs, bool HasFRMRoundModeOp) : IRName(IRName), IsMasked(IsMasked), HasMaskedOffOperand(HasMaskedOffOperand), HasVL(HasVL), Scheme(Scheme), SupportOverloading(SupportOverloading), HasBuiltinAlias(HasBuiltinAlias), - ManualCodegen(ManualCodegen.str()), NF(NF), PolicyAttrs(NewPolicyAttrs) { + ManualCodegen(ManualCodegen.str()), NF(NF), PolicyAttrs(NewPolicyAttrs), + HasFRMRoundModeOp(HasFRMRoundModeOp) { // Init BuiltinName, Name and OverloadedName BuiltinName = NewName.str(); @@ -898,7 +897,7 @@ OverloadedName += "_" + OverloadedSuffix.str(); updateNamesAndPolicy(IsMasked, hasPolicy(), Name, BuiltinName, OverloadedName, - PolicyAttrs, 
HasFRMRoundModeOp); // Init OutputType and InputTypes OutputType = OutInTypes[0]; @@ -1023,13 +1022,11 @@ "and mask policy"); } -void RVVIntrinsic::updateNamesAndPolicy(bool IsMasked, bool HasPolicy, - std::string &Name, - std::string &BuiltinName, - std::string &OverloadedName, - Policy &PolicyAttrs) { +void RVVIntrinsic::updateNamesAndPolicy( + bool IsMasked, bool HasPolicy, std::string &Name, std::string &BuiltinName, + std::string &OverloadedName, Policy &PolicyAttrs, bool HasFRMRoundModeOp) { - auto appendPolicySuffix = [&](const std::string &suffix) { + auto appendSuffix = [&](const std::string &suffix) { Name += suffix; BuiltinName += suffix; OverloadedName += suffix; @@ -1042,11 +1039,19 @@ + // Builtins with an explicit frm operand get an "_rm" suffix ahead of the + // policy suffix (e.g. vfadd_vv_f32m1_rm_tu). Overloaded names do not get + // "_rm"; the extra frm argument already selects the overload. + if (HasFRMRoundModeOp) { + Name += "_rm"; + BuiltinName += "_rm"; + } + if (IsMasked) { if (PolicyAttrs.isTUMUPolicy()) - appendPolicySuffix("_tumu"); + appendSuffix("_tumu"); else if (PolicyAttrs.isTUMAPolicy()) - appendPolicySuffix("_tum"); + appendSuffix("_tum"); else if (PolicyAttrs.isTAMUPolicy()) - appendPolicySuffix("_mu"); + appendSuffix("_mu"); else if (PolicyAttrs.isTAMAPolicy()) { Name += "_m"; if (HasPolicy) @@ -1057,13 +1062,13 @@ llvm_unreachable("Unhandled policy condition"); } else { if (PolicyAttrs.isTUPolicy()) - appendPolicySuffix("_tu"); + appendSuffix("_tu"); else if (PolicyAttrs.isTAPolicy()) { if (HasPolicy) BuiltinName += "_ta"; } else llvm_unreachable("Unhandled policy condition"); } } SmallVector parsePrototypes(StringRef Prototypes) { @@ -1110,6 +1115,7 @@ OS << (int)Record.HasMaskedOffOperand << ","; OS << (int)Record.HasTailPolicy << ","; OS << (int)Record.HasMaskPolicy << ","; + OS << (int)Record.HasFRMRoundModeOp << ","; OS << (int)Record.IsTuple << ","; OS << (int)Record.UnMaskedPolicyScheme << ","; OS << (int)Record.MaskedPolicyScheme << 
","; diff --git a/clang/utils/TableGen/RISCVVEmitter.cpp b/clang/utils/TableGen/RISCVVEmitter.cpp --- a/clang/utils/TableGen/RISCVVEmitter.cpp +++ b/clang/utils/TableGen/RISCVVEmitter.cpp @@ -65,6 +65,7 @@ bool 
HasMaskedOffOperand :1; bool HasTailPolicy : 1; bool HasMaskPolicy : 1; + bool HasFRMRoundModeOp : 1; bool IsTuple : 1; uint8_t UnMaskedPolicyScheme : 2; uint8_t MaskedPolicyScheme : 2; @@ -512,6 +513,7 @@ StringRef MaskedIRName = R->getValueAsString("MaskedIRName"); unsigned NF = R->getValueAsInt("NF"); bool IsTuple = R->getValueAsBit("IsTuple"); + bool HasFRMRoundModeOp = R->getValueAsBit("HasFRMRoundModeOp"); const Policy DefaultPolicy; SmallVector SupportedUnMaskedPolicies = @@ -559,7 +561,7 @@ /*IsMasked=*/false, /*HasMaskedOffOperand=*/false, HasVL, UnMaskedPolicyScheme, SupportOverloading, HasBuiltinAlias, ManualCodegen, *Types, IntrinsicTypes, RequiredFeatures, NF, - DefaultPolicy)); + DefaultPolicy, HasFRMRoundModeOp)); if (UnMaskedPolicyScheme != PolicyScheme::SchemeNone) for (auto P : SupportedUnMaskedPolicies) { SmallVector PolicyPrototype = @@ -574,7 +576,7 @@ /*IsMask=*/false, /*HasMaskedOffOperand=*/false, HasVL, UnMaskedPolicyScheme, SupportOverloading, HasBuiltinAlias, ManualCodegen, *PolicyTypes, IntrinsicTypes, RequiredFeatures, - NF, P)); + NF, P, HasFRMRoundModeOp)); } if (!HasMasked) continue; @@ -585,7 +587,8 @@ Name, SuffixStr, OverloadedName, OverloadedSuffixStr, MaskedIRName, /*IsMasked=*/true, HasMaskedOffOperand, HasVL, MaskedPolicyScheme, SupportOverloading, HasBuiltinAlias, ManualCodegen, *MaskTypes, - IntrinsicTypes, RequiredFeatures, NF, DefaultPolicy)); + IntrinsicTypes, RequiredFeatures, NF, DefaultPolicy, + HasFRMRoundModeOp)); if (MaskedPolicyScheme == PolicyScheme::SchemeNone) continue; for (auto P : SupportedMaskedPolicies) { @@ -600,7 +603,7 @@ MaskedIRName, /*IsMasked=*/true, HasMaskedOffOperand, HasVL, MaskedPolicyScheme, SupportOverloading, HasBuiltinAlias, ManualCodegen, *PolicyTypes, IntrinsicTypes, RequiredFeatures, NF, - P)); + P, HasFRMRoundModeOp)); } } // End for Log2LMULList } // End for TypeRange @@ -653,6 +656,7 @@ SR.Suffix = parsePrototypes(SuffixProto); SR.OverloadedSuffix = 
parsePrototypes(OverloadedSuffixProto); SR.IsTuple = IsTuple; + SR.HasFRMRoundModeOp = HasFRMRoundModeOp; SemaRecords->push_back(SR); } @@ -695,6 +699,7 @@ R.UnMaskedPolicyScheme = SR.UnMaskedPolicyScheme; R.MaskedPolicyScheme = SR.MaskedPolicyScheme; R.IsTuple = SR.IsTuple; + R.HasFRMRoundModeOp = SR.HasFRMRoundModeOp; assert(R.PrototypeIndex != static_cast(SemaSignatureTable::INVALID_INDEX)); diff --git a/llvm/include/llvm/IR/IntrinsicsRISCV.td b/llvm/include/llvm/IR/IntrinsicsRISCV.td --- a/llvm/include/llvm/IR/IntrinsicsRISCV.td +++ b/llvm/include/llvm/IR/IntrinsicsRISCV.td @@ -420,6 +420,27 @@ let ScalarOperand = 2; let VLOperand = 4; } + // For destination vector type is the same as first source vector. + // Input: (passthru, vector_in, vector_in/scalar_in, frm, vl) + class RISCVBinaryAAXUnMaskedRoundingMode + : DefaultAttrsIntrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, LLVMMatchType<0>, llvm_any_ty, + llvm_anyint_ty, LLVMMatchType<2>], + [ImmArg>, IntrNoMem]>, RISCVVIntrinsic { + let ScalarOperand = 2; + let VLOperand = 4; + } + // For destination vector type is the same as first source vector (with mask). + // Input: (maskedoff, vector_in, vector_in/scalar_in, mask, frm, vl, policy) + class RISCVBinaryAAXMaskedRoundingMode + : DefaultAttrsIntrinsic<[llvm_anyvector_ty], + [LLVMMatchType<0>, LLVMMatchType<0>, llvm_any_ty, + LLVMScalarOrSameVectorWidth<0, llvm_i1_ty>, llvm_anyint_ty, + LLVMMatchType<2>, LLVMMatchType<2>], + [ImmArg>, ImmArg>, IntrNoMem]>, RISCVVIntrinsic { + let ScalarOperand = 2; + let VLOperand = 5; + } // For destination vector type is the same as first source vector. The // second source operand must match the destination type or be an XLen scalar.
// Input: (passthru, vector_in, vector_in/scalar_in, vl) @@ -1084,6 +1105,10 @@ def "int_riscv_" # NAME : RISCVBinaryAAXUnMasked; def "int_riscv_" # NAME # "_mask" : RISCVBinaryAAXMasked; } + multiclass RISCVBinaryAAXRoundingMode { + def "int_riscv_" # NAME : RISCVBinaryAAXUnMaskedRoundingMode; + def "int_riscv_" # NAME # "_mask" : RISCVBinaryAAXMaskedRoundingMode; + } // Like RISCVBinaryAAX, but the second operand is used a shift amount so it // must be a vector or an XLen scalar. multiclass RISCVBinaryAAShift { @@ -1292,7 +1317,7 @@ defm vwmaccus : RISCVTernaryWide; defm vwmaccsu : RISCVTernaryWide; - defm vfadd : RISCVBinaryAAX; + defm vfadd : RISCVBinaryAAXRoundingMode; defm vfsub : RISCVBinaryAAX; defm vfrsub : RISCVBinaryAAX; diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h --- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h +++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h @@ -115,6 +115,9 @@ HasRoundModeOpShift = IsSignExtendingOpWShift + 1, HasRoundModeOpMask = 1 << HasRoundModeOpShift, + + IsRVVFixedPointShift = HasRoundModeOpShift + 1, + IsRVVFixedPointMask = 1 << IsRVVFixedPointShift, }; enum VLMUL : uint8_t { @@ -181,6 +184,11 @@ return TSFlags & HasRoundModeOpMask; } +/// \returns true if this instruction is a RISC-V Vector fixed-point instruction +static inline bool isRVVFixedPoint(uint64_t TSFlags) { + return TSFlags & IsRVVFixedPointMask; +} + static inline unsigned getMergeOpNum(const MCInstrDesc &Desc) { assert(hasMergeOp(Desc.TSFlags)); assert(!Desc.isVariadic()); diff --git a/llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp b/llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp --- a/llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp +++ b/llvm/lib/Target/RISCV/RISCVInsertReadWriteCSR.cpp @@ -13,6 +13,8 @@ // //===----------------------------------------------------------------------===// +#include "MCTargetDesc/RISCVBaseInfo.h" +#include "MCTargetDesc/RISCVMCTargetDesc.h"
#include "RISCV.h" #include "RISCVSubtarget.h" #include "llvm/CodeGen/MachineFunctionPass.h" @@ -45,7 +47,7 @@ } private: - bool emitWriteVXRM(MachineBasicBlock &MBB); + bool emitWriteRoundingMode(MachineBasicBlock &MBB); std::optional getRoundModeIdx(const MachineInstr &MI); }; @@ -74,22 +76,42 @@ -// This function inserts a write to vxrm when encountering an RVV fixed-point -// instruction. -bool RISCVInsertReadWriteCSR::emitWriteVXRM(MachineBasicBlock &MBB) { +// This function inserts a write to vxrm before every RVV fixed-point +// instruction, and a save/write/restore sequence of frm around every RVV +// floating-point instruction that carries a static rounding mode operand. +bool RISCVInsertReadWriteCSR::emitWriteRoundingMode(MachineBasicBlock &MBB) { + // Rounding mode operand value meaning "leave the CSR unchanged". It must + // match the value emitted by the clang builtins and the isel patterns. + constexpr unsigned NoRoundModeChange = 99; + bool Changed = false; for (MachineInstr &MI : MBB) { if (auto RoundModeIdx = getRoundModeIdx(MI)) { - Changed = true; - - unsigned VXRMImm = MI.getOperand(*RoundModeIdx).getImm(); - - // The value '4' is a hint to this pass to not alter the vxrm value. - if (VXRMImm == 4) - continue; - - BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteVXRMImm)) - .addImm(VXRMImm); - MI.addOperand(MachineOperand::CreateReg(RISCV::VXRM, /*IsDef*/ false, - /*IsImp*/ true)); + unsigned RoundModeImm = MI.getOperand(*RoundModeIdx).getImm(); + if (RoundModeImm == NoRoundModeChange) + continue; + + Changed = true; + if (RISCVII::isRVVFixedPoint(MI.getDesc().TSFlags)) { + BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteVXRMImm)) + .addImm(RoundModeImm); + MI.addOperand(MachineOperand::CreateReg(RISCV::VXRM, /*IsDef*/ false, + /*IsImp*/ true)); + continue; + } + + // frm is the dynamic rounding mode used by every later FP instruction + // that does not encode a static rounding mode, so the requested mode must + // only apply to this instruction: save the current frm, install the + // requested mode, and restore the saved value right after MI. + MachineRegisterInfo &MRI = MBB.getParent()->getRegInfo(); + Register SavedFRM = MRI.createVirtualRegister(&RISCV::GPRRegClass); + BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::ReadFRM), SavedFRM); + BuildMI(MBB, MI, MI.getDebugLoc(), TII->get(RISCV::WriteFRMImm)) + .addImm(RoundModeImm); + MI.addOperand(MachineOperand::CreateReg(RISCV::FRM, /*IsDef*/ false, + /*IsImp*/ true)); + BuildMI(MBB, std::next(MI.getIterator()), MI.getDebugLoc(), + TII->get(RISCV::WriteFRM)) + .addReg(SavedFRM); } } return Changed; @@ -106,7 +128,7 @@ bool Changed = false; for (MachineBasicBlock &MBB : MF) - Changed |= emitWriteVXRM(MBB); + Changed |= emitWriteRoundingMode(MBB); return Changed; } diff --git a/llvm/lib/Target/RISCV/RISCVInstrFormats.td b/llvm/lib/Target/RISCV/RISCVInstrFormats.td --- a/llvm/lib/Target/RISCV/RISCVInstrFormats.td +++ b/llvm/lib/Target/RISCV/RISCVInstrFormats.td @@ -220,6 +220,14 @@ bit HasRoundModeOp = 0; let TSFlags{20} = HasRoundModeOp; + + // This is only valid when HasRoundModeOp is set to 1. HasRoundModeOp is set + // to 1 for vector fixed-point or floating-point intrinsics. This bit is + // read by the RISCVInsertReadWriteCSR pass to distinguish between + // fixed-point and floating-point instructions, so that the rounding mode is + // written to the correct CSR (vxrm or frm).
+ bit IsRVVFixedPoint = 0; + let TSFlags{21} = IsRVVFixedPoint; } // Pseudo instructions diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td @@ -1168,7 +1168,8 @@ VReg Op1Class, DAGOperand Op2Class, string Constraint, - int DummyMask = 1> : + int DummyMask = 1, + int RVVFixedPoint = 1> : Pseudo<(outs RetClass:$rd), (ins Op1Class:$rs2, Op2Class:$rs1, ixlenimm:$rm, AVL:$vl, ixlenimm:$sew), []>, RISCVVPseudo { @@ -1180,12 +1181,14 @@ let HasSEWOp = 1; let HasDummyMask = DummyMask; let HasRoundModeOp = 1; + let IsRVVFixedPoint = RVVFixedPoint; } class VPseudoBinaryNoMaskTURoundingMode : + string Constraint, + int RVVFixedPoint> : Pseudo<(outs RetClass:$rd), (ins RetClass:$merge, Op1Class:$rs2, Op2Class:$rs1, ixlenimm:$rm, AVL:$vl, ixlenimm:$sew), []>, @@ -1199,12 +1202,14 @@ let HasDummyMask = 1; let HasMergeOp = 1; let HasRoundModeOp = 1; + let IsRVVFixedPoint = RVVFixedPoint; } class VPseudoBinaryMaskPolicyRoundingMode : + string Constraint, + int RVVFixedPoint> : Pseudo<(outs GetVRegNoV0.R:$rd), (ins GetVRegNoV0.R:$merge, Op1Class:$rs2, Op2Class:$rs1, @@ -1221,6 +1226,7 @@ let HasVecPolicyOp = 1; let UsesMaskPolicy = 1; let HasRoundModeOp = 1; + let IsRVVFixedPoint = RVVFixedPoint; } // Special version of VPseudoBinaryNoMask where we pretend the first source is @@ -2036,16 +2042,18 @@ VReg Op1Class, DAGOperand Op2Class, LMULInfo MInfo, - string Constraint = ""> { + string Constraint = "", + int IsRVVFixedPoint = 1> { let VLMul = MInfo.value in { def "_" # MInfo.MX : - VPseudoBinaryNoMaskRoundingMode; + VPseudoBinaryNoMaskRoundingMode; def "_" # MInfo.MX # "_TU" : VPseudoBinaryNoMaskTURoundingMode; + Constraint, IsRVVFixedPoint>; def "_" # MInfo.MX # "_MASK" : VPseudoBinaryMaskPolicyRoundingMode, + Constraint, IsRVVFixedPoint>, RISCVMaskedPseudo; } } @@ -2109,6 +2117,11 @@ defm _VV : VPseudoBinary; }
+multiclass VPseudoBinaryFV_VV_RM { + defm _VV : VPseudoBinaryRoundingMode; +} + multiclass VPseudoVGTR_VV_EEW { foreach m = MxList in { defvar mx = m.MX; @@ -2157,6 +2170,12 @@ f.fprclass, m, Constraint, sew>; } +multiclass VPseudoBinaryV_VF_RM { + defm "_V" # f.FX : VPseudoBinaryRoundingMode; +} + multiclass VPseudoVSLD1_VF { foreach f = FPList in { foreach m = f.MxList in { @@ -2891,6 +2910,28 @@ } } +multiclass VPseudoVALU_VV_VF_RM { + foreach m = MxListF in { + defvar mx = m.MX; + defvar WriteVFALUV_MX = !cast("WriteVFALUV_" # mx); + defvar ReadVFALUV_MX = !cast("ReadVFALUV_" # mx); + + defm "" : VPseudoBinaryFV_VV_RM, + Sched<[WriteVFALUV_MX, ReadVFALUV_MX, ReadVFALUV_MX, ReadVMask]>; + } + + foreach f = FPList in { + foreach m = f.MxList in { + defvar mx = m.MX; + defvar WriteVFALUF_MX = !cast("WriteVFALUF_" # mx); + defvar ReadVFALUV_MX = !cast("ReadVFALUV_" # mx); + defvar ReadVFALUF_MX = !cast("ReadVFALUF_" # mx); + defm "" : VPseudoBinaryV_VF_RM, + Sched<[WriteVFALUF_MX, ReadVFALUV_MX, ReadVFALUF_MX, ReadVMask]>; + } + } +} + multiclass VPseudoVALU_VF { foreach f = FPList in { foreach m = f.MxList in { @@ -6008,7 +6049,7 @@ // 13.2. Vector Single-Width Floating-Point Add/Subtract Instructions //===----------------------------------------------------------------------===// let Uses = [FRM], mayRaiseFPException = true in { -defm PseudoVFADD : VPseudoVALU_VV_VF; +defm PseudoVFADD : VPseudoVALU_VV_VF_RM; defm PseudoVFSUB : VPseudoVALU_VV_VF; defm PseudoVFRSUB : VPseudoVALU_VF; } @@ -6681,7 +6722,8 @@ //===----------------------------------------------------------------------===// // 13.2. 
Vector Single-Width Floating-Point Add/Subtract Instructions //===----------------------------------------------------------------------===// -defm : VPatBinaryV_VV_VX<"int_riscv_vfadd", "PseudoVFADD", AllFloatVectors>; +defm : VPatBinaryV_VV_VXRoundingMode<"int_riscv_vfadd", "PseudoVFADD", + AllFloatVectors>; defm : VPatBinaryV_VV_VX<"int_riscv_vfsub", "PseudoVFSUB", AllFloatVectors>; defm : VPatBinaryV_VX<"int_riscv_vfrsub", "PseudoVFRSUB", AllFloatVectors>; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVSDPatterns.td @@ -110,7 +110,9 @@ instruction_name#"_VV_"# vlmul.MX)) op_reg_class:$rs1, op_reg_class:$rs2, - (XLenVT 4), // vxrm value for RISCVInertReadWriteCSR + // Value to indicate no rounding mode change in + // RISCVInsertReadWriteCSR + (XLenVT 99), avl, log2sew)>; class VPatBinarySDNode_XI; multiclass VPatBinarySDNode_VV_VX; +class VPatBinarySDNode_VF_RM : + Pat<(result_type (vop (vop_type vop_reg_class:$rs1), + (vop_type (SplatFPOp xop_kind:$rs2)))), + (!cast( + !if(isSEWAware, + instruction_name#"_"#vlmul.MX#"_E"#!shl(1, log2sew), + instruction_name#"_"#vlmul.MX)) + vop_reg_class:$rs1, + (xop_type xop_kind:$rs2), + // Value to indicate no rounding mode change in + // RISCVInsertReadWriteCSR + (XLenVT 99), + avl, log2sew)>; + multiclass VPatBinaryFPSDNode_VV_VF { foreach vti = AllFloatVectors in { @@ -254,6 +282,21 @@ } } +multiclass VPatBinaryFPSDNode_VV_VF_RM { + foreach vti = AllFloatVectors in { + let Predicates = GetVTypePredicates.Predicates in { + def : VPatBinarySDNode_VV_RM; + def : VPatBinarySDNode_VF_RM; + } + } +} + multiclass VPatBinaryFPSDNode_R_VF { foreach fvti = AllFloatVectors in @@ -993,7 +1036,7 @@ // 13. Vector Floating-Point Instructions // 13.2. 
Vector Single-Width Floating-Point Add/Subtract Instructions -defm : VPatBinaryFPSDNode_VV_VF; +defm : VPatBinaryFPSDNode_VV_VF_RM; defm : VPatBinaryFPSDNode_VV_VF; defm : VPatBinaryFPSDNode_R_VF; diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td @@ -611,7 +611,9 @@ op1_reg_class:$rs1, op2_reg_class:$rs2, (mask_type V0), - (XLenVT 4), // vxrm value for RISCVInertReadWriteCSR + // Value to indicate no rounding mode change in + // RISCVInsertReadWriteCSR + (XLenVT 99), GPR:$vl, log2sew, TAIL_AGNOSTIC)>; @@ -706,7 +708,9 @@ vop_reg_class:$rs1, xop_kind:$rs2, (mask_type V0), - (XLenVT 4), // vxrm value for RISCVInertReadWriteCSR + // Value to indicate no rounding mode change in + // RISCVInsertReadWriteCSR + (XLenVT 99), GPR:$vl, log2sew, TAIL_AGNOSTIC)>; @@ -861,6 +865,36 @@ scalar_reg_class:$rs2, (mask_type V0), GPR:$vl, log2sew, TAIL_AGNOSTIC)>; +class VPatBinaryVL_VF_RM + : Pat<(result_type (vop (vop1_type vop_reg_class:$rs1), + (vop2_type (SplatFPOp scalar_reg_class:$rs2)), + (result_type result_reg_class:$merge), + (mask_type V0), + VLOpFrag)), + (!cast( + !if(isSEWAware, + instruction_name#"_"#vlmul.MX#"_E"#!shl(1, log2sew)#"_MASK", + instruction_name#"_"#vlmul.MX#"_MASK")) + result_reg_class:$merge, + vop_reg_class:$rs1, + scalar_reg_class:$rs2, + (mask_type V0), + // Value to indicate no rounding mode change in + // RISCVInsertReadWriteCSR + (XLenVT 99), + GPR:$vl, log2sew, TAIL_AGNOSTIC)>; + multiclass VPatBinaryFPVL_VV_VF { foreach vti = AllFloatVectors in { @@ -877,6 +911,22 @@ } } +multiclass VPatBinaryFPVL_VV_VF_RM { + foreach vti = AllFloatVectors in { + let Predicates = GetVTypePredicates.Predicates in { + def : VPatBinaryVL_V_RM; + def : VPatBinaryVL_VF_RM; + } + } +} + multiclass VPatBinaryFPVL_R_VF { foreach fvti = AllFloatVectors in { @@ -1897,7 +1947,7 @@ // 13. 
Vector Floating-Point Instructions // 13.2. Vector Single-Width Floating-Point Add/Subtract Instructions -defm : VPatBinaryFPVL_VV_VF; +defm : VPatBinaryFPVL_VV_VF_RM; defm : VPatBinaryFPVL_VV_VF; defm : VPatBinaryFPVL_R_VF;