diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.h
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h
@@ -533,6 +533,7 @@
                             bool HasMask = true) const;
   SDValue lowerFixedLengthVectorExtendToRVV(SDValue Op, SelectionDAG &DAG,
                                             unsigned ExtendOpc) const;
+  SDValue lowerGET_ROUNDING(SDValue Op, SelectionDAG &DAG) const;
 
   bool isEligibleForTailCallOptimization(
       CCState &CCInfo, CallLoweringInfo &CLI, MachineFunction &MF,
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -372,6 +372,10 @@
     setOperationAction(ISD::STRICT_FP_TO_SINT, MVT::i32, Custom);
   }
 
+  if (Subtarget.hasStdExtF()) {
+    setOperationAction(ISD::FLT_ROUNDS_, XLenVT, Custom);
+  }
+
   setOperationAction(ISD::GlobalAddress, XLenVT, Custom);
   setOperationAction(ISD::BlockAddress, XLenVT, Custom);
   setOperationAction(ISD::ConstantPool, XLenVT, Custom);
@@ -2161,6 +2165,8 @@
     return lowerMGATHER(Op, DAG);
   case ISD::MSCATTER:
     return lowerMSCATTER(Op, DAG);
+  case ISD::FLT_ROUNDS_:
+    return lowerGET_ROUNDING(Op, DAG);
   }
 }
 
@@ -4107,6 +4113,44 @@
                               MSN->getMemoryVT(), MSN->getMemOperand());
 }
 
+/// Custom-lower ISD::FLT_ROUNDS_: read the FRM CSR and translate the RISC-V
+/// rounding-mode encoding into the corresponding FLT_ROUNDS value. The
+/// result merges the translated mode with the chain output of the CSR read.
+SDValue RISCVTargetLowering::lowerGET_ROUNDING(SDValue Op,
+                                               SelectionDAG &DAG) const {
+  const MVT XLenVT = Subtarget.getXLenVT();
+  SDLoc DL(Op);
+  SDValue Chain = Op->getOperand(0);
+  SDValue SysRegNo = DAG.getConstant(
+      RISCVSysReg::lookupSysRegByName("FRM")->Encoding, DL, XLenVT);
+  SDVTList VTs = DAG.getVTList(XLenVT, MVT::Other);
+  SDValue RM = DAG.getNode(RISCVISD::READ_CSR, DL, VTs, Chain, SysRegNo);
+
+  // Encoding used for rounding mode in RISCV differs from that used in
+  // FLT_ROUNDS. To convert it the RISCV rounding mode is used as an index in a
+  // table, which consists of a sequence of 4-bit fields, each representing
+  // corresponding FLT_ROUNDS mode.
+  static const int Table =
+      (int(RoundingMode::NearestTiesToEven) << 4 * RISCVFPRndMode::RNE) |
+      (int(RoundingMode::TowardZero) << 4 * RISCVFPRndMode::RTZ) |
+      (int(RoundingMode::TowardNegative) << 4 * RISCVFPRndMode::RDN) |
+      (int(RoundingMode::TowardPositive) << 4 * RISCVFPRndMode::RUP) |
+      (int(RoundingMode::NearestTiesToAway) << 4 * RISCVFPRndMode::RMM);
+
+  // Scale the mode by 4 (shift left by 2) to get the bit offset of its field.
+  SDValue Shift =
+      DAG.getNode(ISD::SHL, DL, XLenVT, RM, DAG.getConstant(2, DL, XLenVT));
+  SDValue Shifted = DAG.getNode(ISD::SRL, DL, XLenVT,
+                                DAG.getConstant(Table, DL, XLenVT), Shift);
+  // Keep only the low 3 bits of the selected field.
+  SDValue Masked = DAG.getNode(ISD::AND, DL, XLenVT, Shifted,
+                               DAG.getConstant(7, DL, XLenVT));
+
+  // Return the chain produced by the CSR read rather than the incoming one, so
+  // nodes chained after this one stay ordered after the read of FRM.
+  return DAG.getMergeValues({Masked, RM.getValue(1)}, DL);
+}
+
 // Returns the opcode of the target-specific SDNode that implements the 32-bit
 // form of the given Opcode.
 static RISCVISD::NodeType getRISCVWOpcode(unsigned Opcode) {
@@ -4584,6 +4628,15 @@
     if (SDValue V = lowerVECREDUCE(SDValue(N, 0), DAG))
       Results.push_back(V);
     break;
+  case ISD::FLT_ROUNDS_: {
+    // Rebuild the node with an XLenVT value result (plus chain) and hand both
+    // results back.
+    SDVTList VTs = DAG.getVTList(Subtarget.getXLenVT(), MVT::Other);
+    SDValue Res = DAG.getNode(ISD::FLT_ROUNDS_, DL, VTs, N->getOperand(0));
+    Results.push_back(Res.getValue(0));
+    Results.push_back(Res.getValue(1));
+    break;
+  }
   }
 }
 
diff --git a/llvm/test/CodeGen/RISCV/fpenv.ll b/llvm/test/CodeGen/RISCV/fpenv.ll
new file mode 100644
--- /dev/null
+++ b/llvm/test/CodeGen/RISCV/fpenv.ll
@@ -0,0 +1,29 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc -mtriple=riscv32 -mattr=+f -verify-machineinstrs < %s | FileCheck -check-prefix=RV32IF %s
+; RUN: llc -mtriple=riscv64 -mattr=+f -verify-machineinstrs < %s | FileCheck -check-prefix=RV64IF %s
+
+define i32 @func_01() {
+; RV32IF-LABEL: func_01:
+; RV32IF:       # %bb.0:
+; RV32IF-NEXT:    frrm a0
+; RV32IF-NEXT:    slli a0, a0, 2
+; RV32IF-NEXT:    lui a1, 66
+; RV32IF-NEXT:    addi a1, a1, 769
+; RV32IF-NEXT:    srl a0, a1, a0
+; RV32IF-NEXT:    andi a0, a0, 7
+; RV32IF-NEXT:    ret
+;
+; RV64IF-LABEL: func_01:
+; RV64IF:       # %bb.0:
+; RV64IF-NEXT:    frrm a0
+; RV64IF-NEXT:    slli a0, a0, 2
+; RV64IF-NEXT:    lui a1, 66
+; RV64IF-NEXT:    addiw a1, a1, 769
+; RV64IF-NEXT:    srl a0, a1, a0
+; RV64IF-NEXT:    andi a0, a0, 7
+; RV64IF-NEXT:    ret
+  %rm = call i32 @llvm.flt.rounds()
+  ret i32 %rm
+}
+
+declare i32 @llvm.flt.rounds()