Index: llvm/include/llvm/CodeGen/GlobalISel/RegisterBankInfo.h
===================================================================
--- llvm/include/llvm/CodeGen/GlobalISel/RegisterBankInfo.h
+++ llvm/include/llvm/CodeGen/GlobalISel/RegisterBankInfo.h
@@ -599,7 +599,6 @@
   /// that are used in the description of instruction. In other words,
   /// there are just a handful of them and we do not want to waste space.
   ///
-  /// \todo This should be TableGen'ed.
   virtual const RegisterBank &
   getRegBankFromRegClass(const TargetRegisterClass &RC, LLT Ty) const {
     llvm_unreachable("The target must override this method");
Index: llvm/test/TableGen/RegisterBankEmitter.td
===================================================================
--- llvm/test/TableGen/RegisterBankEmitter.td
+++ llvm/test/TableGen/RegisterBankEmitter.td
@@ -4,12 +4,24 @@
 def MyTarget : Target;
 
 def R0 : Register<"r0">;
+def F1 : Register<"f1">;
+
 let Size = 32 in {
   def ClassA : RegisterClass<"MyTarget", [i32], 32, (add R0)>;
   def ClassB : RegisterClass<"MyTarget", [i1], 32, (add ClassA)>;
+
+  def ClassC : RegisterClass<"MyTarget", [i32], 32, (add F1, ClassA)>;
 }
 
 // CHECK: GPRRegBankCoverageData
 // CHECK: MyTarget::ClassARegClassID
 // CHECK: MyTarget::ClassBRegClassID
+
+// CHECK: MyTargetGenRegisterBankInfo::getRegBankFromRegClass
+// CHECK: case MyTarget::ClassARegClassID:
+// CHECK: return getRegBank(MyTarget::FPRRegBankID);
+// CHECK-NOT: return getRegBank(MyTarget::GPRRegBankID);
+// CHECK: default:
+
 def GPRRegBank : RegisterBank<"GPR", [ClassA]>;
+def FPRRegBank : RegisterBank<"FPR", [ClassC]>;
Index: llvm/utils/TableGen/RegisterBankEmitter.cpp
===================================================================
--- llvm/utils/TableGen/RegisterBankEmitter.cpp
+++ llvm/utils/TableGen/RegisterBankEmitter.cpp
@@ -13,6 +13,7 @@
 
 #include "llvm/ADT/BitVector.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/FormatVariadic.h"
 #include "llvm/TableGen/Error.h"
 #include "llvm/TableGen/Record.h"
 #include "llvm/TableGen/TableGenBackend.h"
@@ -149,6 +150,11 @@
      << "protected:\n"
      << "  " << TargetName << "GenRegisterBankInfo();\n"
      << "\n";
+
+  OS << "  const RegisterBank &\n"
+     << "  getRegBankFromRegClass(const TargetRegisterClass &RC, LLT Ty) const "
+        "override;\n"
+     << "\n";
 }
 
 /// Visit each register class belonging to the given register bank.
@@ -271,8 +277,50 @@
      << "  for (const auto &RB : RegBanks)\n"
      << "    assert(Index++ == RB->getID() && \"Index != ID\");\n"
      << "#endif // NDEBUG\n"
-     << "}\n"
-     << "} // end namespace llvm\n";
+     << "}\n\n";
+
+  // Emit the register class -> register bank mapping. Banks are visited in
+  // the order of `Banks`; each register class maps to the first bank that
+  // lists it, and any later bank listing the same class only gets a warning.
+  OS << "const RegisterBank &\n"
+     << TargetName
+     << "GenRegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass "
+        "&RC, LLT Ty) const {\n"
+     << "  switch (RC.getID()) {\n";
+
+  // Qualified register class ID -> name of the bank it was first mapped to.
+  std::map<std::string, std::string> RCMap;
+  for (const auto &Bank : Banks) {
+    bool EmittedCase = false;
+    for (const auto &RC : Bank.register_classes()) {
+      std::string QualifiedRegClassID =
+          (Twine(RC->Namespace) + "::" + RC->getName() + "RegClassID").str();
+      auto Inserted = RCMap.insert({QualifiedRegClassID, Bank.getName()});
+      if (Inserted.second) {
+        OS << "  case " << QualifiedRegClassID << ":\n";
+        EmittedCase = true;
+        continue;
+      }
+      std::string WarnStr = formatv(
+          "Register class '{0}' is already mapped to register bank '{1}'.",
+          RC->getName(), Inserted.first->second);
+      PrintWarning(Bank.getDef().getLoc(), WarnStr);
+    }
+    // Every class of this bank was claimed by an earlier bank, so it has no
+    // case labels; emitting its return would only produce unreachable code.
+    if (!EmittedCase)
+      continue;
+    std::string QualifiedBankID =
+        (TargetName + "::" + Bank.getEnumeratorName()).str();
+    OS << "    return getRegBank(" << QualifiedBankID << ");\n";
+  }
+
+  OS << "  default:\n"
+     << "    llvm_unreachable(\"Register class not supported\");\n"
+     << "  }\n"
+     << "}\n";
+
+  OS << "} // end namespace llvm\n";
 }
 
 void RegisterBankEmitter::run(raw_ostream &OS) {