diff --git a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.h b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.h
--- a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.h
+++ b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.h
@@ -32,6 +32,12 @@
 class RISCVRegisterBankInfo final : public RISCVGenRegisterBankInfo {
 public:
   RISCVRegisterBankInfo(unsigned HwMode);
+
+  const RegisterBank &getRegBankFromRegClass(const TargetRegisterClass &RC,
+                                             LLT Ty) const override;
+
+  const InstructionMapping &
+  getInstrMapping(const MachineInstr &MI) const override;
 };
 } // end namespace llvm
 #endif
diff --git a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
--- a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
+++ b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
@@ -12,6 +12,7 @@
 
 #include "RISCVRegisterBankInfo.h"
 #include "MCTargetDesc/RISCVMCTargetDesc.h"
+#include "RISCVSubtarget.h"
 #include "llvm/CodeGen/MachineRegisterInfo.h"
 #include "llvm/CodeGen/RegisterBank.h"
 #include "llvm/CodeGen/RegisterBankInfo.h"
@@ -20,7 +21,129 @@
 #define GET_TARGET_REGBANK_IMPL
 #include "RISCVGenRegisterBank.inc"
 
+namespace llvm {
+namespace RISCV {
+
+// Partial mappings for values held in the GPR bank, one per supported width
+// (32 and 64 bits). Indexed by PartialMappingIdx.
+RegisterBankInfo::PartialMapping PartMappings[] = {
+    {0, 32, GPRRegBank},
+    {0, 64, GPRRegBank}
+};
+
+// Indices into PartMappings.
+enum PartialMappingIdx {
+  PMI_GPR32 = 0,
+  PMI_GPR64 = 1
+};
+
+// Value mappings, grouped by width. Each group repeats the same entry three
+// times so that a pointer to its first element can be used directly as the
+// operand-mapping array of an instruction with up to three GPR operands.
+RegisterBankInfo::ValueMapping ValueMappings[] = {
+    // Invalid value mapping.
+    {nullptr, 0},
+    // Maximum 3 GPR operands; 32 bit.
+    {&PartMappings[PMI_GPR32], 1},
+    {&PartMappings[PMI_GPR32], 1},
+    {&PartMappings[PMI_GPR32], 1},
+    // Maximum 3 GPR operands; 64 bit.
+    {&PartMappings[PMI_GPR64], 1},
+    {&PartMappings[PMI_GPR64], 1},
+    {&PartMappings[PMI_GPR64], 1}
+};
+
+// Index of the first entry of each group in ValueMappings.
+enum ValueMappingsIdx {
+  InvalidIdx = 0,
+  GPR32Idx = 1,
+  GPR64Idx = 4
+};
+} // namespace RISCV
+} // namespace llvm
+
 using namespace llvm;
 
 RISCVRegisterBankInfo::RISCVRegisterBankInfo(unsigned HwMode)
     : RISCVGenRegisterBankInfo(HwMode) {}
+
+// Map a register class to its register bank. Every GPR-derived class maps to
+// the GPR bank; FPR classes and any other class are not supported.
+const RegisterBank &
+RISCVRegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
+                                              LLT Ty) const {
+  switch (RC.getID()) {
+  default:
+    llvm_unreachable("Register class not supported");
+  case RISCV::GPRRegClassID:
+  case RISCV::GPRNoX0RegClassID:
+  case RISCV::GPRNoX0X2RegClassID:
+  case RISCV::GPRTCRegClassID:
+  case RISCV::GPRCRegClassID:
+  case RISCV::GPRC_and_GPRTCRegClassID:
+  case RISCV::GPRX0RegClassID:
+  case RISCV::SPRegClassID:
+    return getRegBank(RISCV::GPRRegBankID);
+  case RISCV::FPR32RegClassID:
+  case RISCV::FPR32CRegClassID:
+  case RISCV::FPR64RegClassID:
+  case RISCV::FPR64CRegClassID:
+    llvm_unreachable("Register class not supported");
+  }
+}
+
+// Compute the register-bank mapping for MI. The generic mapping from
+// getInstrMappingImpl is used when it is valid; otherwise the handled opcodes
+// get their register operands assigned to the GPR bank, 64 bits wide on RV64
+// and 32 bits wide otherwise. Unhandled opcodes yield an invalid mapping.
+const RegisterBankInfo::InstructionMapping &
+RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
+  const auto &Mapping = getInstrMappingImpl(MI);
+  if (Mapping.isValid())
+    return Mapping;
+
+  // is64Bit() is specific to RISCVSubtarget, so request that type instead of
+  // the generic TargetSubtargetInfo.
+  const MachineFunction &MF = *MI.getParent()->getParent();
+  bool IsRV64 = MF.getSubtarget<RISCVSubtarget>().is64Bit();
+
+  size_t NumOperands = MI.getNumOperands();
+  const ValueMapping *GPRValueMapping =
+      &RISCV::ValueMappings[IsRV64 ? RISCV::GPR64Idx : RISCV::GPR32Idx];
+  const ValueMapping *OperandsMapping = GPRValueMapping;
+  unsigned MappingID = DefaultMappingID;
+
+  switch (MI.getOpcode()) {
+  case TargetOpcode::G_ADD:
+  case TargetOpcode::G_SUB:
+  case TargetOpcode::G_SHL:
+  case TargetOpcode::G_ASHR:
+  case TargetOpcode::G_LSHR:
+  case TargetOpcode::G_AND:
+  case TargetOpcode::G_OR:
+  case TargetOpcode::G_XOR:
+  case TargetOpcode::G_MUL:
+  case TargetOpcode::G_SDIV:
+  case TargetOpcode::G_SREM:
+  case TargetOpcode::G_UDIV:
+  case TargetOpcode::G_UREM:
+  case TargetOpcode::G_UMULH:
+    // Every operand is a GPR; the default OperandsMapping already points at
+    // the consecutive GPR entries.
+    break;
+  case TargetOpcode::G_CONSTANT:
+    // Only the first operand is mapped; the second gets no mapping.
+    OperandsMapping = getOperandsMapping({GPRValueMapping, nullptr});
+    break;
+  case TargetOpcode::G_ICMP:
+    // The second operand gets no mapping; the other three are GPRs.
+    OperandsMapping = getOperandsMapping(
+        {GPRValueMapping, nullptr, GPRValueMapping, GPRValueMapping});
+    break;
+  default:
+    return getInvalidInstructionMapping();
+  }
+
+  return getInstructionMapping(MappingID, /*Cost=*/1, OperandsMapping,
+                               NumOperands);
+}