diff --git a/llvm/include/llvm/CodeGen/TargetRegisterInfo.h b/llvm/include/llvm/CodeGen/TargetRegisterInfo.h
--- a/llvm/include/llvm/CodeGen/TargetRegisterInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetRegisterInfo.h
@@ -693,6 +693,14 @@
   static void dumpReg(Register Reg, unsigned SubRegIndex = 0,
                       const TargetRegisterInfo *TRI = nullptr);
 
+  /// Return the target defined base register class for physical register
+  /// \p Reg: the class with the lowest BaseClassOrder that contains it.
+  /// Returns nullptr if \p Reg is in no base register class; this default
+  /// is kept when no register class of the target sets BaseClassOrder.
+  virtual const TargetRegisterClass *getPhysRegBaseClass(MCRegister Reg) const {
+    return nullptr;
+  }
+
 protected:
   /// Overridden by TableGen in targets that have sub-registers.
   virtual unsigned composeSubRegIndicesImpl(unsigned, unsigned) const {
diff --git a/llvm/include/llvm/Target/Target.td b/llvm/include/llvm/Target/Target.td
--- a/llvm/include/llvm/Target/Target.td
+++ b/llvm/include/llvm/Target/Target.td
@@ -318,6 +318,12 @@
 
   // Target-specific flags. This becomes the TSFlags field in TargetRegisterClass.
   bits<8> TSFlags = 0;
+
+  // If set, this register class is a base class for the registers in its
+  // MemberList. A register in several base classes is assigned the one with
+  // the lowest BaseClassOrder; ties go to the class with the lower enum value.
+  // Leave unset (the default) for classes that are not base classes.
+  int BaseClassOrder = ?;
 }
 
 // The memberList in a RegisterClass is a dag of set operations. 
TableGen
diff --git a/llvm/test/TableGen/RegisterInfoEmitter-BaseClassOrder.td b/llvm/test/TableGen/RegisterInfoEmitter-BaseClassOrder.td
new file mode 100644
--- /dev/null
+++ b/llvm/test/TableGen/RegisterInfoEmitter-BaseClassOrder.td
@@ -0,0 +1,38 @@
+// RUN: llvm-tblgen -gen-register-info -I %p/../../include -I %p/Common %s | FileCheck %s
+
+include "llvm/Target/Target.td"
+
+let Namespace = "MyTarget" in {
+  def R0 : Register<"r0">; // base class BaseA
+  def R1 : Register<"r1">; // base class BaseA
+  def R2 : Register<"r2">; // base class BaseC
+  def R3 : Register<"r3">; // base class BaseC
+  def R4 : Register<"r4">; // base class BaseA (tie with BaseB)
+  def R5 : Register<"r5">; // base class BaseB
+  def R6 : Register<"r6">; // no base class
+} // Namespace = "MyTarget"
+
+
+// BaseA and BaseB tie on BaseClassOrder; enumeration order decides R4.
+def BaseA : RegisterClass<"MyTarget", [i32], 32, (sequence "R%u", 0, 4)> {
+  let BaseClassOrder = 1;
+}
+def BaseB : RegisterClass<"MyTarget", [i32], 32, (sequence "R%u", 3, 5)> {
+  let BaseClassOrder = 1;
+}
+
+// BaseC's lower BaseClassOrder overrides BaseA and BaseB (R2, R3).
+def BaseC : RegisterClass<"MyTarget", [i32], 32, (sequence "R%u", 2, 3)> {
+  let BaseClassOrder = 0;
+}
+
+def MyTarget : Target;
+
+// CHECK: static const TargetRegisterClass *BaseClasses[4] = {
+// CHECK-NEXT: nullptr,
+// CHECK-NEXT: &MyTarget::BaseCRegClass,
+// CHECK-NEXT: &MyTarget::BaseARegClass,
+// CHECK-NEXT: &MyTarget::BaseBRegClass,
+// CHECK-NEXT: }
+// CHECK-NEXT: static const uint8_t Mapping[8] = {
+// CHECK-NEXT: 0,2,2,1,1,2,3,0, };
diff --git a/llvm/utils/TableGen/CodeGenRegisters.h b/llvm/utils/TableGen/CodeGenRegisters.h
--- a/llvm/utils/TableGen/CodeGenRegisters.h
+++ b/llvm/utils/TableGen/CodeGenRegisters.h
@@ -472,6 +472,13 @@
 
     // Called by CodeGenRegBank::CodeGenRegBank().
     static void computeSubClasses(CodeGenRegBank&);
+
+    // Return the BaseClassOrder value if set, or an empty optional otherwise.
+    std::optional<int> getBaseClassOrder() const {
+      if (TheDef && !TheDef->isValueUnset("BaseClassOrder"))
+        return TheDef->getValueAsInt("BaseClassOrder");
+      return std::nullopt;
+    }
   };
 
   // Register categories are used when we need to deterine the category a
diff --git a/llvm/utils/TableGen/RegisterInfoEmitter.cpp b/llvm/utils/TableGen/RegisterInfoEmitter.cpp
--- a/llvm/utils/TableGen/RegisterInfoEmitter.cpp
+++ b/llvm/utils/TableGen/RegisterInfoEmitter.cpp
@@ -1195,10 +1195,18 @@
      << "  bool isConstantPhysReg(MCRegister PhysReg) const override final;\n"
      << "  /// Devirtualized TargetFrameLowering.\n"
      << "  static const " << TargetName << "FrameLowering *getFrameLowering(\n"
-     << "      const MachineFunction &MF);\n"
-     << "};\n\n";
+     << "      const MachineFunction &MF);\n";
 
   const auto &RegisterClasses = RegBank.getRegClasses();
+  // Only declare the getPhysRegBaseClass override when some register class
+  // sets BaseClassOrder; the definition is emitted under the same condition.
+  if (llvm::any_of(RegisterClasses, [](const auto &RC) {
+        return RC.getBaseClassOrder().has_value();
+      }))
+    OS << "  const TargetRegisterClass *getPhysRegBaseClass(MCRegister Reg) "
+          "const override;\n";
+
+  OS << "};\n\n";
 
   if (!RegisterClasses.empty()) {
     OS << "namespace " << RegisterClasses.front().Namespace
@@ -1595,6 +1603,60 @@
 
   EmitRegUnitPressure(OS, RegBank, ClassName);
 
+  // Emit register base class mapper
+  if (!RegisterClasses.empty()) {
+    // Collect base classes
+    SmallVector<const CodeGenRegisterClass *> BaseClasses;
+    for (const auto &RC : RegisterClasses) {
+      if (RC.getBaseClassOrder())
+        BaseClasses.push_back(&RC);
+    }
+    if (!BaseClasses.empty()) {
+      // Mapping entries are uint8_t and entry 0 is reserved for nullptr, so
+      // at most UINT8_MAX base classes can be encoded. Diagnose this in all
+      // builds: an assert would let a release TableGen emit wrapped indices.
+      if (BaseClasses.size() > UINT8_MAX)
+        PrintFatalError("Too many base register classes");
+
+      // Apply order: lower BaseClassOrder first, ties broken by enum value.
+      llvm::stable_sort(BaseClasses, [](const CodeGenRegisterClass *LHS,
+                                        const CodeGenRegisterClass *RHS) {
+        return std::pair(*LHS->getBaseClassOrder(), LHS->EnumValue) <
+               std::pair(*RHS->getBaseClassOrder(), RHS->EnumValue);
+      });
+
+      // Build mapping for Regs (+1 for NoRegister). Entry 0 means "no base
+      // class"; entry I refers to BaseClasses[I - 1]. Iterating from the back
+      // lets higher-precedence classes overwrite lower-precedence ones.
+      std::vector<uint8_t> Mapping(Regs.size() + 1, 0);
+      for (size_t Idx = BaseClasses.size(); Idx > 0; --Idx) {
+        for (const auto *Reg : BaseClasses[Idx - 1]->getMembers())
+          Mapping[Reg->EnumValue] = static_cast<uint8_t>(Idx);
+      }
+
+      OS << "\n// Register to base register class mapping\n\n\n";
+      OS << "const TargetRegisterClass *" << ClassName
+         << "::getPhysRegBaseClass(MCRegister Reg)"
+         << " const {\n";
+      OS << "  static const TargetRegisterClass *BaseClasses["
+         << (BaseClasses.size() + 1) << "] = {\n";
+      OS << "    nullptr,\n";
+      for (const CodeGenRegisterClass *RC : BaseClasses)
+        OS << "    &" << RC->getQualifiedName() << "RegClass,\n";
+      OS << "  };\n";
+      OS << "  static const uint8_t Mapping[" << Mapping.size()
+         << "] = {\n    ";
+      for (const uint8_t Value : Mapping)
+        OS << unsigned(Value) << ",";
+      OS << " };\n\n";
+      // Bound the index by element count (not byte size) and index with the
+      // raw register number.
+      OS << "  assert(Reg.id() < std::size(Mapping));\n";
+      OS << "  return BaseClasses[Mapping[Reg.id()]];\n";
+      OS << "}\n";
+    }
+  }
+
   // Emit the constructor of the class...
   OS << "extern const MCRegisterDesc " << TargetName << "RegDesc[];\n";
   OS << "extern const MCPhysReg " << TargetName << "RegDiffLists[];\n";