[GlobalISel] TableGen getRegBankFromRegClass()

RegisterBankEmitter now emits <Target>GenRegisterBankInfo::getRegBankFromRegClass(),
a switch that maps each register class covered by a register bank (the classes
the bank lists plus their sub-classes, as the test below expects for ClassA via
ClassC) to that bank. ARM and Mips drop their hand-written overrides; AArch64
keeps only the WSeqPairsClass/XSeqPairsClass cases and forwards everything else
to the generated implementation.

Generated-code behaviour:
 - A register class covered by several banks is mapped to the first bank that
   reaches it; later banks only produce a TableGen warning (a duplicate case
   label would not compile).
 - A bank whose classes were all mapped already emits no `return`, so the
   generated switch contains no unreachable statement.

NOTE(review): the bank visitation order is not definition order -- in the test,
GPRRegBank is defined first yet ClassA (listed explicitly in GPRRegBank) maps to
FPRRegBank through ClassC. Confirm this precedence is intended.

Index: llvm/include/llvm/CodeGen/GlobalISel/RegisterBankInfo.h
===================================================================
--- llvm/include/llvm/CodeGen/GlobalISel/RegisterBankInfo.h
+++ llvm/include/llvm/CodeGen/GlobalISel/RegisterBankInfo.h
@@ -599,7 +599,6 @@
   /// that are used in the description of instruction. In other words,
   /// there are just a handful of them and we do not want to waste space.
   ///
-  /// \todo This should be TableGen'ed.
   virtual const RegisterBank &
   getRegBankFromRegClass(const TargetRegisterClass &RC, LLT Ty) const {
     llvm_unreachable("The target must override this method");
Index: llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
===================================================================
--- llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
+++ llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
@@ -228,48 +228,13 @@
 
 const RegisterBank &
 AArch64RegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
-                                                LLT) const {
+                                                LLT Ty) const {
   switch (RC.getID()) {
-  case AArch64::FPR8RegClassID:
-  case AArch64::FPR16RegClassID:
-  case AArch64::FPR16_loRegClassID:
-  case AArch64::FPR32_with_hsub_in_FPR16_loRegClassID:
-  case AArch64::FPR32RegClassID:
-  case AArch64::FPR64RegClassID:
-  case AArch64::FPR64_loRegClassID:
-  case AArch64::FPR128RegClassID:
-  case AArch64::FPR128_loRegClassID:
-  case AArch64::DDRegClassID:
-  case AArch64::DDDRegClassID:
-  case AArch64::DDDDRegClassID:
-  case AArch64::QQRegClassID:
-  case AArch64::QQQRegClassID:
-  case AArch64::QQQQRegClassID:
-    return getRegBank(AArch64::FPRRegBankID);
-  case AArch64::GPR32commonRegClassID:
-  case AArch64::GPR32RegClassID:
-  case AArch64::GPR32spRegClassID:
-  case AArch64::GPR32sponlyRegClassID:
-  case AArch64::GPR32argRegClassID:
-  case AArch64::GPR32allRegClassID:
-  case AArch64::GPR64commonRegClassID:
-  case AArch64::GPR64RegClassID:
-  case AArch64::GPR64spRegClassID:
-  case AArch64::GPR64sponlyRegClassID:
-  case AArch64::GPR64argRegClassID:
-  case AArch64::GPR64allRegClassID:
-  case AArch64::GPR64noipRegClassID:
-  case AArch64::GPR64common_and_GPR64noipRegClassID:
-  case AArch64::GPR64noip_and_tcGPR64RegClassID:
-  case AArch64::tcGPR64RegClassID:
-  case AArch64::rtcGPR64RegClassID:
   case AArch64::WSeqPairsClassRegClassID:
   case AArch64::XSeqPairsClassRegClassID:
     return getRegBank(AArch64::GPRRegBankID);
-  case AArch64::CCRRegClassID:
-    return getRegBank(AArch64::CCRegBankID);
   default:
-    llvm_unreachable("Register class not supported");
+    return AArch64GenRegisterBankInfo::getRegBankFromRegClass(RC, Ty);
   }
 }
 
Index: llvm/lib/Target/ARM/ARMRegisterBankInfo.h
===================================================================
--- llvm/lib/Target/ARM/ARMRegisterBankInfo.h
+++ llvm/lib/Target/ARM/ARMRegisterBankInfo.h
@@ -32,9 +32,6 @@
 public:
   ARMRegisterBankInfo(const TargetRegisterInfo &TRI);
 
-  const RegisterBank &getRegBankFromRegClass(const TargetRegisterClass &RC,
-                                             LLT) const override;
-
   const InstructionMapping &
   getInstrMapping(const MachineInstr &MI) const override;
 };
Index: llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp
===================================================================
--- llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp
+++ llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp
@@ -174,44 +174,6 @@
   llvm::call_once(InitializeRegisterBankFlag, InitializeRegisterBankOnce);
 }
 
-const RegisterBank &
-ARMRegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
-                                            LLT) const {
-  using namespace ARM;
-
-  switch (RC.getID()) {
-  case GPRRegClassID:
-  case GPRwithAPSRRegClassID:
-  case GPRnoipRegClassID:
-  case GPRnopcRegClassID:
-  case GPRnoip_and_GPRnopcRegClassID:
-  case rGPRRegClassID:
-  case GPRspRegClassID:
-  case GPRnoip_and_tcGPRRegClassID:
-  case tcGPRRegClassID:
-  case tGPRRegClassID:
-  case tGPREvenRegClassID:
-  case tGPROddRegClassID:
-  case tGPR_and_tGPREvenRegClassID:
-  case tGPR_and_tGPROddRegClassID:
-  case tGPREven_and_tcGPRRegClassID:
-  case tGPREven_and_GPRnoip_and_tcGPRRegClassID:
-  case tGPROdd_and_tcGPRRegClassID:
-    return getRegBank(ARM::GPRRegBankID);
-  case HPRRegClassID:
-  case SPR_8RegClassID:
-  case SPRRegClassID:
-  case DPR_8RegClassID:
-  case DPRRegClassID:
-  case QPRRegClassID:
-    return getRegBank(ARM::FPRRegBankID);
-  default:
-    llvm_unreachable("Unsupported register kind");
-  }
-
-  llvm_unreachable("Switch should handle all register classes");
-}
-
 const RegisterBankInfo::InstructionMapping &
 ARMRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
   auto Opc = MI.getOpcode();
Index: llvm/lib/Target/Mips/MipsRegisterBankInfo.h
===================================================================
--- llvm/lib/Target/Mips/MipsRegisterBankInfo.h
+++ llvm/lib/Target/Mips/MipsRegisterBankInfo.h
@@ -32,9 +32,6 @@
 public:
   MipsRegisterBankInfo(const TargetRegisterInfo &TRI);
 
-  const RegisterBank &getRegBankFromRegClass(const TargetRegisterClass &RC,
-                                             LLT) const override;
-
   const InstructionMapping &
   getInstrMapping(const MachineInstr &MI) const override;
 
Index: llvm/lib/Target/Mips/MipsRegisterBankInfo.cpp
===================================================================
--- llvm/lib/Target/Mips/MipsRegisterBankInfo.cpp
+++ llvm/lib/Target/Mips/MipsRegisterBankInfo.cpp
@@ -76,35 +76,6 @@
 MipsRegisterBankInfo::MipsRegisterBankInfo(const TargetRegisterInfo &TRI)
     : MipsGenRegisterBankInfo() {}
 
-const RegisterBank &
-MipsRegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
-                                             LLT) const {
-  using namespace Mips;
-
-  switch (RC.getID()) {
-  case Mips::GPR32RegClassID:
-  case Mips::CPU16Regs_and_GPRMM16ZeroRegClassID:
-  case Mips::GPRMM16MovePPairFirstRegClassID:
-  case Mips::CPU16Regs_and_GPRMM16MovePPairSecondRegClassID:
-  case Mips::GPRMM16MoveP_and_CPU16Regs_and_GPRMM16ZeroRegClassID:
-  case Mips::GPRMM16MovePPairFirst_and_GPRMM16MovePPairSecondRegClassID:
-  case Mips::SP32RegClassID:
-  case Mips::GP32RegClassID:
-    return getRegBank(Mips::GPRBRegBankID);
-  case Mips::FGRCCRegClassID:
-  case Mips::FGR32RegClassID:
-  case Mips::FGR64RegClassID:
-  case Mips::AFGR64RegClassID:
-  case Mips::MSA128BRegClassID:
-  case Mips::MSA128HRegClassID:
-  case Mips::MSA128WRegClassID:
-  case Mips::MSA128DRegClassID:
-    return getRegBank(Mips::FPRBRegBankID);
-  default:
-    llvm_unreachable("Register class not supported");
-  }
-}
-
 // Instructions where all register operands are floating point.
 static bool isFloatingPointOpcode(unsigned Opc) {
   switch (Opc) {
Index: llvm/test/TableGen/RegisterBankEmitter.td
===================================================================
--- llvm/test/TableGen/RegisterBankEmitter.td
+++ llvm/test/TableGen/RegisterBankEmitter.td
@@ -4,12 +4,24 @@
 
 def MyTarget : Target;
 def R0 : Register<"r0">;
+def F1 : Register<"f1">;
+
 let Size = 32 in {
   def ClassA : RegisterClass<"MyTarget", [i32], 32, (add R0)>;
   def ClassB : RegisterClass<"MyTarget", [i1], 32, (add ClassA)>;
+
+  def ClassC : RegisterClass<"MyTarget", [i32], 32, (add F1, ClassA)>;
 }
 
 // CHECK: GPRRegBankCoverageData
 // CHECK: MyTarget::ClassARegClassID
 // CHECK: MyTarget::ClassBRegClassID
+
+// CHECK: MyTargetGenRegisterBankInfo::getRegBankFromRegClass
+// CHECK: case MyTarget::ClassARegClassID:
+// CHECK: return getRegBank(MyTarget::FPRRegBankID);
+// CHECK-NOT: return getRegBank(MyTarget::GPRRegBankID);
+// CHECK: default:
+
 def GPRRegBank : RegisterBank<"GPR", [ClassA]>;
+def FPRRegBank : RegisterBank<"FPR", [ClassC]>;
Index: llvm/utils/TableGen/RegisterBankEmitter.cpp
===================================================================
--- llvm/utils/TableGen/RegisterBankEmitter.cpp
+++ llvm/utils/TableGen/RegisterBankEmitter.cpp
@@ -13,6 +13,7 @@
 
 #include "llvm/ADT/BitVector.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
@@ -149,6 +150,11 @@
      << "protected:\n"
      << "  " << TargetName << "GenRegisterBankInfo();\n"
      << "\n";
+
+  OS << "  const RegisterBank &\n"
+     << "  getRegBankFromRegClass(const TargetRegisterClass &RC, LLT Ty) const "
+        "override;\n"
+     << "\n";
 }
 
 /// Visit each register class belonging to the given register bank.
@@ -271,8 +277,50 @@
      << "  for (const auto &RB : RegBanks)\n"
      << "    assert(Index++ == RB->getID() && \"Index != ID\");\n"
      << "#endif // NDEBUG\n"
-     << "}\n"
-     << "} // end namespace llvm\n";
+     << "}\n\n";
+
+  // Emit getRegBankFromRegClass(): one case label per register class covered
+  // by a bank, followed by a return of that bank. A register class covered by
+  // several banks is mapped to the first bank that reaches it; the others only
+  // get a warning, since a duplicate case label would not compile.
+  OS << "const RegisterBank &\n"
+     << TargetName
+     << "GenRegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass "
+        "&RC, LLT) const {\n"
+     << "  switch (RC.getID()) {\n";
+
+  // Qualified register class ID -> name of the bank it has been mapped to.
+  std::map<std::string, StringRef> RCToBank;
+  for (const auto &Bank : Banks) {
+    bool HasCase = false;
+    for (const auto &RC : Bank.register_classes()) {
+      std::string QualifiedRegClassID =
+          (Twine(RC->Namespace) + "::" + RC->getName() + "RegClassID").str();
+      auto [It, Inserted] =
+          RCToBank.try_emplace(QualifiedRegClassID, Bank.getName());
+      if (!Inserted) {
+        PrintWarning(Bank.getDef().getLoc(),
+                     formatv("Register class '{0}' is already mapped to "
+                             "register bank '{1}'.",
+                             RC->getName(), It->second)
+                         .str());
+        continue;
+      }
+      OS << "  case " << QualifiedRegClassID << ":\n";
+      HasCase = true;
+    }
+    // Without a preceding case label the return would be unreachable.
+    if (HasCase)
+      OS << "    return getRegBank(" << TargetName
+         << "::" << Bank.getEnumeratorName() << ");\n";
+  }
+
+  OS << "  default:\n"
+     << "    llvm_unreachable(\"Register class not supported\");\n"
+     << "  }\n"
+     << "}\n";
+
+  OS << "} // end namespace llvm\n";
 }
 
 void RegisterBankEmitter::run(raw_ostream &OS) {