diff --git a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
--- a/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
+++ b/llvm/lib/Target/RISCV/AsmParser/RISCVAsmParser.cpp
@@ -496,7 +496,7 @@
       return isUImm5();
     if (Kind != KindTy::FPImmediate)
       return false;
-    int Idx = RISCVLoadFPImm::getLoadFP64Imm(
+    int Idx = RISCVLoadFPImm::getLoadFPImm(
         APFloat(APFloat::IEEEdouble(), APInt(64, getFPConst())));
     // Don't allow decimal version of the minimum value. It is a different value
     // for each supported data type.
@@ -985,7 +985,7 @@
       return;
     }
 
-    int Imm = RISCVLoadFPImm::getLoadFP64Imm(
+    int Imm = RISCVLoadFPImm::getLoadFPImm(
         APFloat(APFloat::IEEEdouble(), APInt(64, getFPConst())));
     Inst.addOperand(MCOperand::createImm(Imm));
   }
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.h
@@ -349,20 +349,10 @@
 namespace RISCVLoadFPImm {
 float getFPImm(unsigned Imm);
 
-/// getLoadFP32Imm - Return a 5-bit binary encoding of the 32-bit
-/// floating-point immediate value. If the value cannot be represented as a
-/// 5-bit binary encoding, then return -1.
-int getLoadFP32Imm(const APFloat &FPImm);
-
-/// getLoadFP64Imm - Return a 5-bit binary encoding of the 64-bit
-/// floating-point immediate value. If the value cannot be represented as a
-/// 5-bit binary encoding, then return -1.
-int getLoadFP64Imm(const APFloat &FPImm);
-
-/// getLoadFP16Imm - Return a 5-bit binary encoding of the 16-bit
-/// floating-point immediate value. If the value cannot be represented as a
-/// 5-bit binary encoding, then return -1.
-int getLoadFP16Imm(const APFloat &FPImm);
+/// getLoadFPImm - Return a 5-bit binary encoding of the floating-point
+/// immediate value. If the value cannot be represented as a 5-bit binary
+/// encoding, then return -1. FPImm must use IEEE half, single, or double
+/// semantics; -1.0 is the only negative value with an encoding.
+int getLoadFPImm(APFloat FPImm);
 } // namespace RISCVLoadFPImm
 
 namespace RISCVSysReg {
diff --git a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.cpp b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.cpp
--- a/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.cpp
+++ b/llvm/lib/Target/RISCV/MCTargetDesc/RISCVBaseInfo.cpp
@@ -214,87 +214,37 @@
   return uncompressInst(OutInst, MI, STI);
 }
 
-// Lookup table for fli.h for entries 1-31. Entry 0(-1.0) is handled separately.
-// NOTE: The exponent for entry 1 is larger than entry 2 and 3 because they
-// are denormals.
-static constexpr std::pair<uint8_t, uint8_t> LoadFP16ImmArr[] = {
-    {0b00001, 0b00}, {0b00000, 0b01}, {0b00000, 0b10}, {0b00111, 0b00},
-    {0b01000, 0b00}, {0b01011, 0b00}, {0b01100, 0b00}, {0b01101, 0b00},
-    {0b01101, 0b01}, {0b01101, 0b10}, {0b01101, 0b11}, {0b01110, 0b00},
-    {0b01110, 0b01}, {0b01110, 0b10}, {0b01110, 0b11}, {0b01111, 0b00},
-    {0b01111, 0b01}, {0b01111, 0b10}, {0b01111, 0b11}, {0b10000, 0b00},
-    {0b10000, 0b01}, {0b10000, 0b10}, {0b10001, 0b00}, {0b10010, 0b00},
-    {0b10011, 0b00}, {0b10110, 0b00}, {0b10111, 0b00}, {0b11110, 0b00},
-    {0b11111, 0b00}, {0b11111, 0b00}, {0b11111, 0b10},
-};
-
-// Lookup table for fli.s for entries 1-31.
+// Lookup table for fli.s for entries 2-31.
 static constexpr std::pair<uint8_t, uint8_t> LoadFP32ImmArr[] = {
-    {0b00000001, 0b00}, {0b01101111, 0b00}, {0b01110000, 0b00},
-    {0b01110111, 0b00}, {0b01111000, 0b00}, {0b01111011, 0b00},
-    {0b01111100, 0b00}, {0b01111101, 0b00}, {0b01111101, 0b01},
-    {0b01111101, 0b10}, {0b01111101, 0b11}, {0b01111110, 0b00},
-    {0b01111110, 0b01}, {0b01111110, 0b10}, {0b01111110, 0b11},
-    {0b01111111, 0b00}, {0b01111111, 0b01}, {0b01111111, 0b10},
-    {0b01111111, 0b11}, {0b10000000, 0b00}, {0b10000000, 0b01},
-    {0b10000000, 0b10}, {0b10000001, 0b00}, {0b10000010, 0b00},
-    {0b10000011, 0b00}, {0b10000110, 0b00}, {0b10000111, 0b00},
-    {0b10001110, 0b00}, {0b10001111, 0b00}, {0b11111111, 0b00},
-    {0b11111111, 0b10},
-};
-
-// Lookup table for fli.d for entries 1-31.
-static constexpr std::pair<uint16_t, uint8_t> LoadFP64ImmArr[] = {
-    {0b00000000001, 0b00}, {0b01111101111, 0b00}, {0b01111110000, 0b00},
-    {0b01111110111, 0b00}, {0b01111111000, 0b00}, {0b01111111011, 0b00},
-    {0b01111111100, 0b00}, {0b01111111101, 0b00}, {0b01111111101, 0b01},
-    {0b01111111101, 0b10}, {0b01111111101, 0b11}, {0b01111111110, 0b00},
-    {0b01111111110, 0b01}, {0b01111111110, 0b10}, {0b01111111110, 0b11},
-    {0b01111111111, 0b00}, {0b01111111111, 0b01}, {0b01111111111, 0b10},
-    {0b01111111111, 0b11}, {0b10000000000, 0b00}, {0b10000000000, 0b01},
-    {0b10000000000, 0b10}, {0b10000000001, 0b00}, {0b10000000010, 0b00},
-    {0b10000000011, 0b00}, {0b10000000110, 0b00}, {0b10000000111, 0b00},
-    {0b10000001110, 0b00}, {0b10000001111, 0b00}, {0b11111111111, 0b00},
-    {0b11111111111, 0b10},
+    {0b01101111, 0b00}, {0b01110000, 0b00}, {0b01110111, 0b00},
+    {0b01111000, 0b00}, {0b01111011, 0b00}, {0b01111100, 0b00},
+    {0b01111101, 0b00}, {0b01111101, 0b01}, {0b01111101, 0b10},
+    {0b01111101, 0b11}, {0b01111110, 0b00}, {0b01111110, 0b01},
+    {0b01111110, 0b10}, {0b01111110, 0b11}, {0b01111111, 0b00},
+    {0b01111111, 0b01}, {0b01111111, 0b10}, {0b01111111, 0b11},
+    {0b10000000, 0b00}, {0b10000000, 0b01}, {0b10000000, 0b10},
+    {0b10000001, 0b00}, {0b10000010, 0b00}, {0b10000011, 0b00},
+    {0b10000110, 0b00}, {0b10000111, 0b00}, {0b10001110, 0b00},
+    {0b10001111, 0b00}, {0b11111111, 0b00}, {0b11111111, 0b10},
 };
 
-int RISCVLoadFPImm::getLoadFP16Imm(const APFloat &FPImm) {
-  assert(&FPImm.getSemantics() == &APFloat::IEEEhalf());
-
-  APInt Imm = FPImm.bitcastToAPInt();
-
-  if (Imm.extractBitsAsZExtValue(8, 0) != 0)
+int RISCVLoadFPImm::getLoadFPImm(APFloat FPImm) {
+  assert((&FPImm.getSemantics() == &APFloat::IEEEsingle() ||
+          &FPImm.getSemantics() == &APFloat::IEEEdouble() ||
+          &FPImm.getSemantics() == &APFloat::IEEEhalf()) &&
+         "Unexpected semantics");
+
+  // Handle the positive minimum normalized value, which differs for each type.
+  if (FPImm.isSmallestNormalized() && !FPImm.isNegative())
+    return 1;
+
+  // Convert to single precision to use its lookup table.
+  bool LosesInfo;
+  APFloat::opStatus Status = FPImm.convert(
+      APFloat::IEEEsingle(), APFloat::rmNearestTiesToEven, &LosesInfo);
+  if (Status != APFloat::opOK || LosesInfo)
     return -1;
 
-  bool Sign = Imm.extractBitsAsZExtValue(1, 15);
-  uint8_t Mantissa = Imm.extractBitsAsZExtValue(2, 8);
-  uint8_t Exp = Imm.extractBitsAsZExtValue(5, 10);
-
-  // The array isn't sorted so we must use std::find unlike fp32 and fp64.
-  auto EMI = llvm::find(LoadFP16ImmArr, std::make_pair(Exp, Mantissa));
-  if (EMI == std::end(LoadFP16ImmArr))
-    return -1;
-
-  // Table doesn't have entry 0.
-  int Entry = std::distance(std::begin(LoadFP16ImmArr), EMI) + 1;
-
-  // The only legal negative value is -1.0(entry 0). 1.0 is entry 16.
-  if (Sign) {
-    if (Entry == 16)
-      return 0;
-    return false;
-  }
-
-  // Entry 29 and 30 are both infinity, but 30 is the real infinity.
-  if (Entry == 29)
-    ++Entry;
-
-  return Entry;
-}
-
-int RISCVLoadFPImm::getLoadFP32Imm(const APFloat &FPImm) {
-  assert(&FPImm.getSemantics() == &APFloat::IEEEsingle());
-
   APInt Imm = FPImm.bitcastToAPInt();
 
   if (Imm.extractBitsAsZExtValue(21, 0) != 0)
@@ -309,44 +259,14 @@
       EMI->second != Mantissa)
     return -1;
 
-  // Table doesn't have entry 0.
-  int Entry = std::distance(std::begin(LoadFP32ImmArr), EMI) + 1;
-
-  // The only legal negative value is -1.0(entry 0). 1.0 is entry 16.
-  if (Sign) {
-    if (Entry == 16)
-      return 0;
-    return false;
-  }
-
-  return Entry;
-}
-
-int RISCVLoadFPImm::getLoadFP64Imm(const APFloat &FPImm) {
-  assert(&FPImm.getSemantics() == &APFloat::IEEEdouble());
-
-  APInt Imm = FPImm.bitcastToAPInt();
-
-  if (Imm.extractBitsAsZExtValue(50, 0) != 0)
-    return -1;
-
-  bool Sign = Imm.extractBitsAsZExtValue(1, 63);
-  uint8_t Mantissa = Imm.extractBitsAsZExtValue(2, 50);
-  uint16_t Exp = Imm.extractBitsAsZExtValue(11, 52);
-
-  auto EMI = llvm::lower_bound(LoadFP64ImmArr, std::make_pair(Exp, Mantissa));
-  if (EMI == std::end(LoadFP64ImmArr) || EMI->first != Exp ||
-      EMI->second != Mantissa)
-    return -1;
-
-  // Table doesn't have entry 0.
-  int Entry = std::distance(std::begin(LoadFP64ImmArr), EMI) + 1;
+  // Table doesn't have entry 0 or 1.
+  int Entry = std::distance(std::begin(LoadFP32ImmArr), EMI) + 2;
 
   // The only legal negative value is -1.0(entry 0). 1.0 is entry 16.
   if (Sign) {
     if (Entry == 16)
       return 0;
-    return false;
+    return -1;
   }
 
   return Entry;
@@ -362,8 +282,8 @@
     Imm = 16;
   }
 
-  uint32_t Exp = LoadFP32ImmArr[Imm - 1].first;
-  uint32_t Mantissa = LoadFP32ImmArr[Imm - 1].second;
+  uint32_t Exp = LoadFP32ImmArr[Imm - 2].first;
+  uint32_t Mantissa = LoadFP32ImmArr[Imm - 2].second;
 
   uint32_t I = Sign << 31 | Exp << 23 | Mantissa << 21;
   return bit_cast<float>(I);
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1540,21 +1540,20 @@
 }
 
 bool RISCVTargetLowering::isLegalZfaFPImm(const APFloat &Imm, EVT VT) const {
-  if (!Subtarget.hasStdExtZfa() || !VT.isSimple())
+  if (!Subtarget.hasStdExtZfa())
     return false;
 
-  switch (VT.getSimpleVT().SimpleTy) {
-  default:
-    return false;
-  case MVT::f16:
-    return (Subtarget.hasStdExtZfh() || Subtarget.hasStdExtZvfh()) &&
-           RISCVLoadFPImm::getLoadFP16Imm(Imm) != -1;
-  case MVT::f32:
-    return RISCVLoadFPImm::getLoadFP32Imm(Imm) != -1;
-  case MVT::f64:
+  bool IsSupportedVT = false;
+  if (VT == MVT::f16) {
+    IsSupportedVT = Subtarget.hasStdExtZfh() || Subtarget.hasStdExtZvfh();
+  } else if (VT == MVT::f32) {
+    IsSupportedVT = true;
+  } else if (VT == MVT::f64) {
     assert(Subtarget.hasStdExtD() && "Expect D extension");
-    return RISCVLoadFPImm::getLoadFP64Imm(Imm) != -1;
+    IsSupportedVT = true;
   }
+
+  return IsSupportedVT && RISCVLoadFPImm::getLoadFPImm(Imm) != -1;
 }
 
 bool RISCVTargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT,
diff --git a/llvm/lib/Target/RISCV/RISCVInstrInfoZfa.td b/llvm/lib/Target/RISCV/RISCVInstrInfoZfa.td
--- a/llvm/lib/Target/RISCV/RISCVInstrInfoZfa.td
+++ b/llvm/lib/Target/RISCV/RISCVInstrInfoZfa.td
@@ -179,20 +179,13 @@
 // Codegen patterns
 //===----------------------------------------------------------------------===//
 
-def fp32imm_to_loadfpimm : SDNodeXForm<fpimm, [{
-  return CurDAG->getTargetConstant(RISCVLoadFPImm::getLoadFP32Imm(N->getValueAPF()),
+def fpimm_to_loadfpimm : SDNodeXForm<fpimm, [{
+  return CurDAG->getTargetConstant(RISCVLoadFPImm::getLoadFPImm(N->getValueAPF()),
                                    SDLoc(N), Subtarget->getXLenVT());}]>;
 
-def fp64imm_to_loadfpimm : SDNodeXForm<fpimm, [{
-  return CurDAG->getTargetConstant(RISCVLoadFPImm::getLoadFP64Imm(N->getValueAPF()),
-                                   SDLoc(N), Subtarget->getXLenVT());}]>;
-
-def fp16imm_to_loadfpimm : SDNodeXForm<fpimm, [{
-  return CurDAG->getTargetConstant(RISCVLoadFPImm::getLoadFP16Imm(N->getValueAPF()),
-                                   SDLoc(N), Subtarget->getXLenVT());}]>;
 
 let Predicates = [HasStdExtZfa] in {
-def : Pat<(f32 fpimm:$imm), (FLI_S (fp32imm_to_loadfpimm fpimm:$imm))>;
+def : Pat<(f32 fpimm:$imm), (FLI_S (fpimm_to_loadfpimm fpimm:$imm))>;
 
 def: PatFprFpr<fminimum, FMINM_S, FPR32>;
 def: PatFprFpr<fmaximum, FMAXM_S, FPR32>;
@@ -216,7 +209,7 @@
 } // Predicates = [HasStdExtZfa]
 
 let Predicates = [HasStdExtZfa, HasStdExtD] in {
-def : Pat<(f64 fpimm:$imm), (FLI_D (fp64imm_to_loadfpimm fpimm:$imm))>;
+def : Pat<(f64 fpimm:$imm), (FLI_D (fpimm_to_loadfpimm fpimm:$imm))>;
 
 def: PatFprFpr<fminimum, FMINM_D, FPR64>;
 def: PatFprFpr<fmaximum, FMAXM_D, FPR64>;
@@ -246,7 +239,7 @@
 }
 
 let Predicates = [HasStdExtZfa, HasStdExtZfh] in {
-def : Pat<(f16 fpimm:$imm), (FLI_H (fp16imm_to_loadfpimm fpimm:$imm))>;
+def : Pat<(f16 fpimm:$imm), (FLI_H (fpimm_to_loadfpimm fpimm:$imm))>;
 
 def: PatFprFpr<fminimum, FMINM_H, FPR16>;
 def: PatFprFpr<fmaximum, FMAXM_H, FPR16>;