Reviewer notes on this patch (RISC-V Zfa `fli` immediate selection):
  * RISCVTargetLowering::isLegalZfaFPImm (returning bool) is renamed to
    getLegalZfaFPImm (returning int). It returns -1 when Zfa is unavailable
    or the type is unsupported. Otherwise it returns the result of
    RISCVLoadFPImm::getLoadFPImm(Imm). Callers that used the bool now test
    the result with ">= 0": the second RISCVISelDAGToDAG.cpp hunk and
    isFPImmLegal in RISCVISelLowering.cpp.
  * In RISCVISelDAGToDAG.cpp, the ISD::ConstantFP case now emits FLI_H,
    FLI_S or FLI_D (chosen by VT.SimpleTy) directly. It does so via
    CurDAG->getMachineNode, passing the returned index as an XLenVT target
    constant. Previously it used `break` to fall through to the TableGen
    patterns.
  * RISCVInstrInfoZfa.td removes the fpimm_to_loadfpimm SDNodeXForm and the
    f16/f32/f64 `fpimm` -> FLI_* patterns that used it. The new C++
    selection path replaces them.
  * The `let` wrapping class FPUnaryOp_imm no longer sets
    mayRaiseFPException = 1. NOTE(review): presumably this is because fli
    only materializes a constant and cannot raise FP exceptions. Confirm
    against the Zfa specification.
  * NOTE(review): this copy of the patch appears mangled in two ways. Line
    breaks inside hunks are collapsed. Text between angle brackets looks
    stripped (e.g. `cast(Node)`, `static_cast(TLI)`, `SmallVectorImpl &Ops`,
    `: RVInstR;`, `SDNodeXFormgetTargetConstant`). It will not apply as-is;
    regenerate it from the original commit before use.
diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp @@ -842,8 +842,29 @@ } case ISD::ConstantFP: { const APFloat &APF = cast(Node)->getValueAPF(); - if (static_cast(TLI)->isLegalZfaFPImm(APF, VT)) - break; + int FPImm = static_cast(TLI)->getLegalZfaFPImm( + APF, VT); + if (FPImm >= 0) { + unsigned Opc; + switch (VT.SimpleTy) { + default: + llvm_unreachable("Unexpected size"); + case MVT::f16: + Opc = RISCV::FLI_H; + break; + case MVT::f32: + Opc = RISCV::FLI_S; + break; + case MVT::f64: + Opc = RISCV::FLI_D; + break; + } + + SDNode *Res = CurDAG->getMachineNode( + Opc, DL, VT, CurDAG->getTargetConstant(FPImm, DL, XLenVT)); + ReplaceNode(Node, Res); + return; + } bool NegZeroF64 = APF.isNegZero() && VT == MVT::f64; SDValue Imm; @@ -2967,7 +2988,8 @@ MVT VT = CFP->getSimpleValueType(0); - if (static_cast(TLI)->isLegalZfaFPImm(APF, VT)) + if (static_cast(TLI)->getLegalZfaFPImm(APF, + VT) >= 0) return false; MVT XLenVT = Subtarget->getXLenVT(); diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.h b/llvm/lib/Target/RISCV/RISCVISelLowering.h --- a/llvm/lib/Target/RISCV/RISCVISelLowering.h +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.h @@ -393,7 +393,7 @@ SmallVectorImpl &Ops) const override; bool shouldScalarizeBinop(SDValue VecOp) const override; bool isOffsetFoldingLegal(const GlobalAddressSDNode *GA) const override; - bool isLegalZfaFPImm(const APFloat &Imm, EVT VT) const; + int getLegalZfaFPImm(const APFloat &Imm, EVT VT) const; bool isFPImmLegal(const APFloat &Imm, EVT VT, bool ForCodeSize) const override; bool isExtractSubvectorCheap(EVT ResVT, EVT SrcVT, diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp --- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp @@ -1545,9 +1545,11 @@ return false; } -bool 
RISCVTargetLowering::isLegalZfaFPImm(const APFloat &Imm, EVT VT) const { +// Returns 0-31 if the fli instruction is available for the type and this is +// a legal FP immediate for the type. Returns -1 otherwise. +int RISCVTargetLowering::getLegalZfaFPImm(const APFloat &Imm, EVT VT) const { if (!Subtarget.hasStdExtZfa()) - return false; + return -1; bool IsSupportedVT = false; if (VT == MVT::f16) { @@ -1559,7 +1561,10 @@ IsSupportedVT = true; } - return IsSupportedVT && RISCVLoadFPImm::getLoadFPImm(Imm) != -1; + if (!IsSupportedVT) + return -1; + + return RISCVLoadFPImm::getLoadFPImm(Imm); } bool RISCVTargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT, @@ -1575,7 +1580,7 @@ if (!IsLegalVT) return false; - if (isLegalZfaFPImm(Imm, VT)) + if (getLegalZfaFPImm(Imm, VT) >= 0) return true; // Cannot create a 64 bit floating-point immediate value for rv32. diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZfa.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZfa.td --- a/llvm/lib/Target/RISCV/RISCVInstrInfoZfa.td +++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZfa.td @@ -63,7 +63,7 @@ : RVInstR; -let hasSideEffects = 0, mayLoad = 0, mayStore = 0, mayRaiseFPException = 1 in +let hasSideEffects = 0, mayLoad = 0, mayStore = 0 in class FPUnaryOp_imm funct7, bits<5> rs2val, bits<3> funct3, RISCVOpcode opcode, dag outs, dag ins, string opcodestr, string argstr> : RVInst { @@ -182,14 +182,7 @@ // Codegen patterns //===----------------------------------------------------------------------===// -def fpimm_to_loadfpimm : SDNodeXFormgetTargetConstant(RISCVLoadFPImm::getLoadFPImm(N->getValueAPF()), - SDLoc(N), Subtarget->getXLenVT());}]>; - - let Predicates = [HasStdExtZfa] in { -def : Pat<(f32 fpimm:$imm), (FLI_S (fpimm_to_loadfpimm fpimm:$imm))>; - def: PatFprFpr; def: PatFprFpr; @@ -212,8 +205,6 @@ } // Predicates = [HasStdExtZfa] let Predicates = [HasStdExtZfa, HasStdExtD] in { -def : Pat<(f64 fpimm:$imm), (FLI_D (fpimm_to_loadfpimm fpimm:$imm))>; - def: PatFprFpr; def: PatFprFpr; @@ -242,8 
+233,6 @@ } let Predicates = [HasStdExtZfa, HasStdExtZfh] in { -def : Pat<(f16 fpimm:$imm), (FLI_H (fpimm_to_loadfpimm fpimm:$imm))>; - def: PatFprFpr; def: PatFprFpr;