diff --git a/llvm/include/llvm/CodeGen/TargetRegisterInfo.h b/llvm/include/llvm/CodeGen/TargetRegisterInfo.h
--- a/llvm/include/llvm/CodeGen/TargetRegisterInfo.h
+++ b/llvm/include/llvm/CodeGen/TargetRegisterInfo.h
@@ -693,6 +693,14 @@
   static void dumpReg(Register Reg, unsigned SubRegIndex = 0,
                       const TargetRegisterInfo *TRI = nullptr);
 
+  /// Return the target-defined base register class for the physical register
+  /// \p Reg, i.e. the register class with the lowest BaseClassOrder that
+  /// contains it (ties are resolved by enumeration order).
+  /// Returns nullptr if \p Reg is not in any base register class.
+  virtual const TargetRegisterClass *getPhysRegBaseClass(MCRegister Reg) const {
+    return nullptr;
+  }
+
 protected:
   /// Overridden by TableGen in targets that have sub-registers.
   virtual unsigned composeSubRegIndicesImpl(unsigned, unsigned) const {
diff --git a/llvm/include/llvm/Target/Target.td b/llvm/include/llvm/Target/Target.td
--- a/llvm/include/llvm/Target/Target.td
+++ b/llvm/include/llvm/Target/Target.td
@@ -318,6 +318,12 @@
 
   // Target-specific flags. This becomes the TSFlags field in TargetRegisterClass.
   bits<8> TSFlags = 0;
+
+  // If set then consider this register class to be the base class for registers in
+  // its MemberList. The base class for registers present in multiple base register
+  // classes will be resolved in the order defined by this value, with lower values
+  // taking precedence over higher ones. Ties are resolved by enumeration order.
+  int BaseClassOrder = ?;
 }
 
 // The memberList in a RegisterClass is a dag of set operations. TableGen
diff --git a/llvm/utils/TableGen/CodeGenRegisters.h b/llvm/utils/TableGen/CodeGenRegisters.h
--- a/llvm/utils/TableGen/CodeGenRegisters.h
+++ b/llvm/utils/TableGen/CodeGenRegisters.h
@@ -472,6 +472,13 @@
 
     // Called by CodeGenRegBank::CodeGenRegBank().
     static void computeSubClasses(CodeGenRegBank&);
+
+    // BaseClassOrder of this class, or std::nullopt if it is not a base class.
+    std::optional<int> getBaseClassOrder() const {
+      if (TheDef && !TheDef->isValueUnset("BaseClassOrder"))
+        return TheDef->getValueAsInt("BaseClassOrder");
+      return {};
+    }
   };
 
   // Register categories are used when we need to deterine the category a
diff --git a/llvm/utils/TableGen/RegisterInfoEmitter.cpp b/llvm/utils/TableGen/RegisterInfoEmitter.cpp
--- a/llvm/utils/TableGen/RegisterInfoEmitter.cpp
+++ b/llvm/utils/TableGen/RegisterInfoEmitter.cpp
@@ -1195,10 +1195,19 @@
      << "  bool isConstantPhysReg(MCRegister PhysReg) const override final;\n"
      << "  /// Devirtualized TargetFrameLowering.\n"
      << "  static const " << TargetName << "FrameLowering *getFrameLowering(\n"
-     << "      const MachineFunction &MF);\n"
-     << "};\n\n";
+     << "      const MachineFunction &MF);\n";
 
   const auto &RegisterClasses = RegBank.getRegClasses();
+  // Only declare the getPhysRegBaseClass() override when some register class
+  // sets BaseClassOrder; the out-of-line definition is emitted under the same
+  // condition.
+  if (llvm::any_of(RegisterClasses, [](const CodeGenRegisterClass &RC) {
+        return RC.getBaseClassOrder().has_value();
+      }))
+    OS << "  const TargetRegisterClass *getPhysRegBaseClass(MCRegister Reg) "
+          "const override;\n";
+
+  OS << "};\n\n";
 
   if (!RegisterClasses.empty()) {
     OS << "namespace " << RegisterClasses.front().Namespace
@@ -1595,6 +1604,56 @@
 
   EmitRegUnitPressure(OS, RegBank, ClassName);
 
+  // Emit the register to base register class mapper. It is only emitted when
+  // some register class sets BaseClassOrder, which is the same condition under
+  // which the generated header declares the getPhysRegBaseClass() override.
+  SmallVector<const CodeGenRegisterClass *> BaseClasses;
+  for (const auto &RC : RegisterClasses)
+    if (RC.getBaseClassOrder())
+      BaseClasses.push_back(&RC);
+
+  if (!BaseClasses.empty()) {
+    // Mapping entries are uint8_t values holding a 1-based index into
+    // BaseClasses (0 means "no base class"), so at most UINT8_MAX base
+    // classes can be encoded.
+    assert(BaseClasses.size() <= UINT8_MAX &&
+           "Too many base register classes");
+
+    // Lower BaseClassOrder values take precedence; ties are broken by
+    // enumeration order.
+    llvm::stable_sort(BaseClasses, [](const CodeGenRegisterClass *LHS,
+                                      const CodeGenRegisterClass *RHS) {
+      return std::make_pair(*LHS->getBaseClassOrder(), LHS->EnumValue) <
+             std::make_pair(*RHS->getBaseClassOrder(), RHS->EnumValue);
+    });
+
+    // Mapping is indexed by register enum value (1-based, hence the extra
+    // slot). Classes are visited from lowest to highest precedence so that a
+    // higher-precedence class overwrites the entry of any register it shares
+    // with a lower-precedence one.
+    std::vector<uint8_t> Mapping(Regs.size() + 1, 0);
+    for (unsigned RCIdx = BaseClasses.size(); RCIdx != 0; --RCIdx)
+      for (const CodeGenRegister *Reg : BaseClasses[RCIdx - 1]->getMembers())
+        Mapping[Reg->EnumValue] = RCIdx;
+
+    OS << "\n// Register to base register class mapping\n\n";
+    OS << "const TargetRegisterClass *" << ClassName
+       << "::getPhysRegBaseClass(MCRegister Reg) const {\n";
+    OS << "  static const TargetRegisterClass *const BaseClasses["
+       << BaseClasses.size() << "] = {\n";
+    for (const CodeGenRegisterClass *RC : BaseClasses)
+      OS << "    &" << RC->getQualifiedName() << "RegClass,\n";
+    OS << "  };\n";
+    OS << "  static const uint8_t Mapping[" << Mapping.size() << "] = {";
+    for (size_t I = 0, E = Mapping.size(); I != E; ++I)
+      OS << (I % 16 == 0 ? "\n    " : " ") << unsigned(Mapping[I]) << ',';
+    OS << "\n  };\n\n";
+    OS << "  if (Reg >= std::size(Mapping) || !Mapping[Reg])\n"
+          "    return nullptr;\n";
+    OS << "  return BaseClasses[Mapping[Reg] - 1];\n";
+    OS << "}\n";
+  }
+
   // Emit the constructor of the class...
   OS << "extern const MCRegisterDesc " << TargetName << "RegDesc[];\n";
   OS << "extern const MCPhysReg " << TargetName << "RegDiffLists[];\n";