diff --git a/llvm/include/llvm/ADT/FloatingPointMode.h b/llvm/include/llvm/ADT/FloatingPointMode.h
--- a/llvm/include/llvm/ADT/FloatingPointMode.h
+++ b/llvm/include/llvm/ADT/FloatingPointMode.h
@@ -38,6 +38,7 @@
   TowardPositive = 2, ///< roundTowardPositive.
   TowardNegative = 3, ///< roundTowardNegative.
   NearestTiesToAway = 4, ///< roundTiesToAway.
+  MaxIEEEMode = NearestTiesToAway, ///< Highest-numbered mode listed above.
 
   // Special values.
   Dynamic = 7, ///< Denotes mode unknown at compile time.
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -554,6 +554,7 @@
   SDValue lowerFixedLengthVectorExtendToRVV(SDValue Op, SelectionDAG &DAG,
                                             unsigned ExtendOpc) const;
   SDValue lowerGET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
+  SDValue lowerSET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
 
   bool isEligibleForTailCallOptimization(
       CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -375,6 +375,7 @@
   if (Subtarget.hasStdExtF()) {
     setOperationAction(ISD::FLT_ROUNDS_, (Subtarget.is64Bit() ?
                        MVT::i64 : MVT::i32), Custom);
+    setOperationAction(ISD::SET_ROUNDING, MVT::Other, Custom);
   }
 
   setOperationAction(ISD::GlobalAddress, XLenVT, Custom);
@@ -2168,6 +2169,8 @@
     return lowerMSCATTER(Op, DAG);
   case ISD::FLT_ROUNDS_:
     return lowerGET_ROUNDING(Op, DAG);
+  case ISD::SET_ROUNDING:
+    return lowerSET_ROUNDING(Op, DAG);
   }
 }
 
@@ -4139,6 +4142,42 @@
   return DAG.getMergeValues({Masked, Chain}, DL);
 }
 
+/// Custom lowering of ISD::SET_ROUNDING. Operand 0 is the incoming chain and
+/// operand 1 the requested rounding mode in the FLT_ROUNDS numbering. The mode
+/// is translated to the RISC-V numbering and written to the FRM system
+/// register via RISCVISD::WRITE_CSR; that node's chain result is returned.
+SDValue RISCVTargetLowering::lowerSET_ROUNDING(SDValue Op,
+                                               SelectionDAG &DAG) const {
+  const MVT IntTy = Subtarget.getXLenVT();
+  SDLoc DL(Op);
+  SDValue Chain = Op->getOperand(0);
+  SDValue RMValue = Op->getOperand(1);
+  SDValue SysRegNo =
+      DAG.getConstant(static_cast<unsigned>(RISCV::SysReg::FRM), DL, IntTy);
+
+  // The rounding-mode encoding used by RISC-V differs from the one used by
+  // FLT_ROUNDS. To convert, the FLT_ROUNDS value is used as an index into a
+  // table packed into one integer as a sequence of 4-bit fields, each holding
+  // the corresponding RISC-V mode.
+  static const unsigned Table =
+      (RISCV::Rounding::RNE << 4 * int(RoundingMode::NearestTiesToEven)) |
+      (RISCV::Rounding::RTZ << 4 * int(RoundingMode::TowardZero)) |
+      (RISCV::Rounding::RDN << 4 * int(RoundingMode::TowardNegative)) |
+      (RISCV::Rounding::RUP << 4 * int(RoundingMode::TowardPositive)) |
+      (RISCV::Rounding::RMM << 4 * int(RoundingMode::NearestTiesToAway));
+
+  // Bit offset of the selected field is index * 4, i.e. index << 2; shift the
+  // table down to it and keep the low three bits.
+  SDValue Shift =
+      DAG.getNode(ISD::SHL, DL, IntTy, RMValue, DAG.getConstant(2, DL, IntTy));
+  SDValue Shifted = DAG.getNode(ISD::SRL, DL, IntTy,
+                                DAG.getConstant(Table, DL, IntTy), Shift);
+  RMValue = DAG.getNode(ISD::AND, DL, IntTy, Shifted,
+                        DAG.getConstant(0x7, DL, IntTy));
+  return DAG.getNode(RISCVISD::WRITE_CSR, DL, MVT::Other, Chain, SysRegNo,
+                     RMValue);
+}
+
 // Returns the opcode of the target-specific SDNode that implements the 32-bit
 // form of the given Opcode.
 static RISCVISD::NodeType getRISCVWOpcode(unsigned Opcode) {
diff --git a/llvm/test/CodeGen/RISCV/fpenv.ll b/llvm/test/CodeGen/RISCV/fpenv.ll
--- a/llvm/test/CodeGen/RISCV/fpenv.ll
+++ b/llvm/test/CodeGen/RISCV/fpenv.ll
@@ -24,4 +24,80 @@
 ; RV64IF: andi a0, a0, 7
 ; RV64IF: ret
 
+define void @func_02(i32 %rm) {
+  call void @llvm.set.rounding(i32 %rm)
+  ret void
+}
+
+; RV32IF-LABEL: func_02:
+; RV32IF: slli a0, a0, 2
+; RV32IF: lui a1, 66
+; RV32IF: addi a1, a1, 769
+; RV32IF: srl a0, a1, a0
+; RV32IF: andi a0, a0, 7
+; RV32IF: fsrm a0
+; RV32IF: ret
+
+; RV64IF-LABEL: func_02:
+; RV64IF: slli a0, a0, 32
+; RV64IF: srli a0, a0, 30
+; RV64IF: lui a1, 66
+; RV64IF: addiw a1, a1, 769
+; RV64IF: srl a0, a1, a0
+; RV64IF: andi a0, a0, 7
+; RV64IF: fsrm a0
+; RV64IF: ret
+
+
+define void @func_03() {
+  call void @llvm.set.rounding(i32 0)
+  ret void
+}
+
+; COMMON-LABEL: func_03:
+; COMMON: fsrmi 1
+; COMMON: ret
+
+
+define void @func_04() {
+  call void @llvm.set.rounding(i32 1)
+  ret void
+}
+
+; COMMON-LABEL: func_04:
+; COMMON: fsrmi 0
+; COMMON: ret
+
+
+define void @func_05() {
+  call void @llvm.set.rounding(i32 2)
+  ret void
+}
+
+; COMMON-LABEL: func_05:
+; COMMON: fsrmi 3
+; COMMON: ret
+
+
+define void @func_06() {
+  call void @llvm.set.rounding(i32 3)
+  ret void
+}
+
+; COMMON-LABEL: func_06:
+; COMMON: fsrmi 2
+; COMMON: ret
+
+
+define void @func_07() {
+  call void @llvm.set.rounding(i32 4)
+  ret void
+}
+
+; COMMON-LABEL: func_07:
+; COMMON: fsrmi 4
+; COMMON: ret
+
+
+declare void @llvm.set.rounding(i32)
 declare i32 @llvm.flt.rounds()