diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -245,6 +245,27 @@
 };
 } // namespace RISCVISD
 
+namespace RISCV {
+// Must agree with definitions in RISCVSystemOperands.td
+enum class SysReg {
+  // User Floating-Point CSRs
+  FFLAGS = 0x001,
+  FRM = 0x002,
+  FCSR = 0x003
+};
+
+// RISC-V rounding-mode encoding; lowerGET_ROUNDING decodes frm with it.
+enum Rounding {
+  RNE = 0, ///< roundTiesToEven.
+  RTZ = 1, ///< roundTowardZero.
+  RDN = 2, ///< roundTowardNegative.
+  RUP = 3, ///< roundTowardPositive.
+  RMM = 4, ///< roundTiesToAway.
+  MaxRounding = RMM,
+  DYN = 7
+};
+} // namespace RISCV
+
 class RISCVTargetLowering : public TargetLowering {
   const RISCVSubtarget &Subtarget;
 
@@ -508,6 +529,7 @@
                             bool HasMask = true) const;
   SDValue lowerFixedLengthVectorExtendToRVV(SDValue Op, SelectionDAG &DAG,
                                             unsigned ExtendOpc) const;
+  SDValue lowerGET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
 
   bool isEligibleForTailCallOptimization(
       CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -359,6 +359,9 @@
     setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::i32, Custom);
   }
 
+  if (Subtarget.hasStdExtF())
+    setOperationAction(ISD::FLT_ROUNDS_, XLenVT, Custom);
+
   setOperationAction(ISD::GlobalAddress, XLenVT, Custom);
   setOperationAction(ISD::BlockAddress, XLenVT, Custom);
   setOperationAction(ISD::ConstantPool, XLenVT, Custom);
@@ -1726,6 +1729,8 @@
   case ISD::MGATHER:
   case ISD::MSCATTER:
     return lowerMGATHERMSCATTER(Op, DAG);
+  case ISD::FLT_ROUNDS_:
+    return lowerGET_ROUNDING(Op, DAG);
   }
 }
 
@@ -3537,6 +3542,44 @@
                                  Ops, N->getMemoryVT(), N->getMemOperand());
 }
 
+// Lower FLT_ROUNDS_: read the frm CSR and translate the RISC-V rounding-mode
+// encoding into the FLT_ROUNDS encoding. Produces {value, chain}.
+SDValue RISCVTargetLowering::lowerGET_ROUNDING(SDValue Op,
+                                               SelectionDAG &DAG) const {
+  const MVT XLenVT = Subtarget.getXLenVT();
+  SDLoc DL(Op);
+  SDValue Chain = Op.getOperand(0);
+  SDValue SysRegNo = DAG.getConstant(
+      static_cast<unsigned>(RISCV::SysReg::FRM), DL, XLenVT);
+  SDVTList VTs = DAG.getVTList(XLenVT, MVT::Other);
+  SDValue RM = DAG.getNode(RISCVISD::READ_CSR, DL, VTs, Chain, SysRegNo);
+  // Use the CSR read's output chain so later chained nodes (for example, a
+  // write to frm) are ordered after this read.
+  Chain = RM.getValue(1);
+
+  // Encoding used for rounding mode in RISCV differs from that used in
+  // FLT_ROUNDS. To convert it the RISCV rounding mode is used as an index in a
+  // table, which consists of a sequence of 4-bit fields, each representing
+  // corresponding FLT_ROUNDS mode.
+  constexpr int Table =
+      (int(RoundingMode::NearestTiesToEven) << 4 * RISCV::Rounding::RNE) |
+      (int(RoundingMode::TowardZero) << 4 * RISCV::Rounding::RTZ) |
+      (int(RoundingMode::TowardNegative) << 4 * RISCV::Rounding::RDN) |
+      (int(RoundingMode::TowardPositive) << 4 * RISCV::Rounding::RUP) |
+      (int(RoundingMode::NearestTiesToAway) << 4 * RISCV::Rounding::RMM);
+
+  // Result = (Table >> (RM * 4)) & 7. Reserved frm values (5-7) select empty
+  // fields of the table and therefore yield 0.
+  SDValue Shift =
+      DAG.getNode(ISD::SHL, DL, XLenVT, RM, DAG.getConstant(2, DL, XLenVT));
+  SDValue Shifted = DAG.getNode(ISD::SRL, DL, XLenVT,
+                                DAG.getConstant(Table, DL, XLenVT), Shift);
+  SDValue Masked = DAG.getNode(ISD::AND, DL, XLenVT, Shifted,
+                               DAG.getConstant(7, DL, XLenVT));
+
+  return DAG.getMergeValues({Masked, Chain}, DL);
+}
+
 // Returns the opcode of the target-specific SDNode that implements the 32-bit
 // form of the given Opcode.
 static RISCVISD::NodeType getRISCVWOpcode(unsigned Opcode) {
@@ -3945,6 +3988,15 @@
     if (SDValue V = lowerVECREDUCE(SDValue(N, 0), DAG))
       Results.push_back(V);
     break;
+  case ISD::FLT_ROUNDS_: {
+    assert(N->getValueType(0) == MVT::i32 && Subtarget.is64Bit() &&
+           Subtarget.hasStdExtF() && "Unexpected custom legalization");
+    SDVTList VTs = DAG.getVTList(MVT::i64, MVT::Other);
+    SDValue Res = DAG.getNode(ISD::FLT_ROUNDS_, DL, VTs, N->getOperand(0));
+    Results.push_back(Res.getValue(0));
+    Results.push_back(Res.getValue(1));
+    break;
+  }
   }
 }
 
diff --git a/llvm/test/CodeGen/RISCV/fpenv.ll b/llvm/test/CodeGen/RISCV/fpenv.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/fpenv.ll
@@ -0,0 +1,27 @@
+; RUN: llc -mtriple=riscv32 -mattr=+f -verify-machineinstrs < %s | FileCheck -check-prefix=RV32IF %s
+; RUN: llc -mtriple=riscv64 -mattr=+f -verify-machineinstrs < %s | FileCheck -check-prefix=RV64IF %s
+
+define i32 @func_01() {
+  %rm = call i32 @llvm.flt.rounds()
+  ret i32 %rm
+}
+
+; RV32IF-LABEL: func_01:
+; RV32IF: csrrc a0, frm, zero
+; RV32IF: slli a0, a0, 2
+; RV32IF: lui a1, 66
+; RV32IF: addi a1, a1, 769
+; RV32IF: srl a0, a1, a0
+; RV32IF: andi a0, a0, 7
+; RV32IF: ret
+
+; RV64IF-LABEL: func_01:
+; RV64IF: csrrc a0, frm, zero
+; RV64IF: slli a0, a0, 2
+; RV64IF: lui a1, 66
+; RV64IF: addiw a1, a1, 769
+; RV64IF: srl a0, a1, a0
+; RV64IF: andi a0, a0, 7
+; RV64IF: ret
+
+declare i32 @llvm.flt.rounds()