diff --git a/llvm/include/llvm/CodeGen/RegisterBank.h b/llvm/include/llvm/CodeGen/RegisterBank.h --- a/llvm/include/llvm/CodeGen/RegisterBank.h +++ b/llvm/include/llvm/CodeGen/RegisterBank.h @@ -29,7 +29,6 @@ private: unsigned ID; const char *Name; - unsigned Size; BitVector ContainedRegClasses; /// Sentinel value used to recognize register bank not properly @@ -40,8 +39,8 @@ friend RegisterBankInfo; public: - RegisterBank(unsigned ID, const char *Name, unsigned Size, - const uint32_t *CoveredClasses, unsigned NumRegClasses); + RegisterBank(unsigned ID, const char *Name, const uint32_t *CoveredClasses, + unsigned NumRegClasses); /// Get the identifier of this register bank. unsigned getID() const { return ID; } @@ -50,9 +49,6 @@ /// Should be used only for debugging purposes. const char *getName() const { return Name; } - /// Get the maximal size in bits that fits in this register bank. - unsigned getSize() const { return Size; } - /// Check whether this instance is ready to be used. bool isValid() const; @@ -62,7 +58,7 @@ /// \note This method does not check anything when assertions are disabled. /// /// \return True is the check was successful. - bool verify(const TargetRegisterInfo &TRI) const; + bool verify(const RegisterBankInfo &RBI, const TargetRegisterInfo &TRI) const; /// Check whether this register bank covers \p RC. /// In other words, check if this register bank fully covers diff --git a/llvm/include/llvm/CodeGen/RegisterBankInfo.h b/llvm/include/llvm/CodeGen/RegisterBankInfo.h --- a/llvm/include/llvm/CodeGen/RegisterBankInfo.h +++ b/llvm/include/llvm/CodeGen/RegisterBankInfo.h @@ -20,6 +20,7 @@ #include "llvm/ADT/iterator_range.h" #include "llvm/CodeGen/LowLevelType.h" #include "llvm/CodeGen/Register.h" +#include "llvm/CodeGen/RegisterBank.h" #include "llvm/Support/ErrorHandling.h" #include #include @@ -30,7 +31,6 @@ class MachineInstr; class MachineRegisterInfo; class raw_ostream; -class RegisterBank; class TargetInstrInfo; class TargetRegisterClass; class TargetRegisterInfo; @@ -83,7 +83,7 @@ /// \note This method does not check anything when assertions are disabled. /// /// \return True is the check was successful. - bool verify() const; + bool verify(const RegisterBankInfo &RBI) const; }; /// Helper struct that represents how a value is mapped through @@ -175,7 +175,7 @@ /// \note This method does not check anything when assertions are disabled. /// /// \return True is the check was successful. - bool verify(unsigned MeaningfulBitWidth) const; + bool verify(const RegisterBankInfo &RBI, unsigned MeaningfulBitWidth) const; /// Print this on dbgs() stream. void dump() const; @@ -384,11 +384,17 @@ protected: /// Hold the set of supported register banks. - RegisterBank **RegBanks; + const RegisterBank **RegBanks; /// Total number of register banks. unsigned NumRegBanks; + /// Hold the sizes of the register banks for all HwModes. + const unsigned *Sizes; + + /// Current HwMode for the target. + unsigned HwMode; + /// Keep dynamically allocated PartialMapping in a separate map. /// This shouldn't be needed when everything gets TableGen'ed. mutable DenseMap> @@ -415,7 +421,8 @@ /// Create a RegisterBankInfo that can accommodate up to \p NumRegBanks /// RegisterBank instances. - RegisterBankInfo(RegisterBank **RegBanks, unsigned NumRegBanks); + RegisterBankInfo(const RegisterBank **RegBanks, unsigned NumRegBanks, + const unsigned *Sizes, unsigned HwMode); /// This constructor is meaningless. /// It just provides a default constructor that can be used at link time @@ -428,7 +435,7 @@ } /// Get the register bank identified by \p ID. - RegisterBank &getRegBank(unsigned ID) { + const RegisterBank &getRegBank(unsigned ID) { assert(ID < getNumRegBanks() && "Accessing an unknown register bank"); return *RegBanks[ID]; } @@ -576,6 +583,11 @@ return const_cast(this)->getRegBank(ID); } + /// Get the maximum size in bits that fits in the given register bank. + unsigned getMaximumSize(unsigned RegBankID) const { + return Sizes[RegBankID + HwMode * NumRegBanks]; + } + /// Get the register bank of \p Reg. /// If Reg has not been assigned a register, a register class, /// or a register bank, then this returns nullptr. diff --git a/llvm/lib/CodeGen/MachineVerifier.cpp b/llvm/lib/CodeGen/MachineVerifier.cpp --- a/llvm/lib/CodeGen/MachineVerifier.cpp +++ b/llvm/lib/CodeGen/MachineVerifier.cpp @@ -2174,6 +2174,7 @@ } const RegisterBank *RegBank = MRI->getRegBankOrNull(Reg); + const RegisterBankInfo *RBI = MF->getSubtarget().getRegBankInfo(); // If we're post-RegBankSelect, the gvreg must have a bank. if (!RegBank && isFunctionRegBankSelected) { @@ -2185,12 +2186,12 @@ // Make sure the register fits into its register bank if any. if (RegBank && Ty.isValid() && - RegBank->getSize() < Ty.getSizeInBits()) { + RBI->getMaximumSize(RegBank->getID()) < Ty.getSizeInBits()) { report("Register bank is too small for virtual register", MO, MONum); errs() << "Register bank " << RegBank->getName() << " too small(" - << RegBank->getSize() << ") to fit " << Ty.getSizeInBits() - << "-bits\n"; + << RBI->getMaximumSize(RegBank->getID()) << ") to fit " + << Ty.getSizeInBits() << "-bits\n"; return; } } diff --git a/llvm/lib/CodeGen/RegisterBank.cpp b/llvm/lib/CodeGen/RegisterBank.cpp --- a/llvm/lib/CodeGen/RegisterBank.cpp +++ b/llvm/lib/CodeGen/RegisterBank.cpp @@ -11,6 +11,7 @@ #include "llvm/CodeGen/RegisterBank.h" #include "llvm/ADT/StringExtras.h" +#include "llvm/CodeGen/RegisterBankInfo.h" #include "llvm/CodeGen/TargetRegisterInfo.h" #include "llvm/Config/llvm-config.h" #include "llvm/Support/Debug.h" @@ -21,15 +22,16 @@ const unsigned RegisterBank::InvalidID = UINT_MAX; -RegisterBank::RegisterBank( - unsigned ID, const char *Name, unsigned Size, - const uint32_t *CoveredClasses, unsigned NumRegClasses) - : ID(ID), Name(Name), Size(Size) { +RegisterBank::RegisterBank(unsigned ID, const char *Name, + const uint32_t *CoveredClasses, + unsigned NumRegClasses) + : ID(ID), Name(Name) { ContainedRegClasses.resize(NumRegClasses); ContainedRegClasses.setBitsInMask(CoveredClasses); } -bool RegisterBank::verify(const TargetRegisterInfo &TRI) const { +bool RegisterBank::verify(const RegisterBankInfo &RBI, + const TargetRegisterInfo &TRI) const { assert(isValid() && "Invalid register bank"); for (unsigned RCId = 0, End = TRI.getNumRegClasses(); RCId != End; ++RCId) { const TargetRegisterClass &RC = *TRI.getRegClass(RCId); @@ -50,7 +52,7 @@ // Verify that the Size of the register bank is big enough to cover // all the register classes it covers. - assert(getSize() >= TRI.getRegSizeInBits(SubRC) && + assert(RBI.getMaximumSize(getID()) >= TRI.getRegSizeInBits(SubRC) && "Size is not big enough for all the subclasses!"); assert(covers(SubRC) && "Not all subclasses are covered"); } @@ -64,7 +66,7 @@ } bool RegisterBank::isValid() const { - return ID != InvalidID && Name != nullptr && Size != 0 && + return ID != InvalidID && Name != nullptr && // A register bank that does not cover anything is useless. !ContainedRegClasses.empty(); } @@ -89,7 +91,7 @@ OS << getName(); if (!IsForDebug) return; - OS << "(ID:" << getID() << ", Size:" << getSize() << ")\n" + OS << "(ID:" << getID() << ")\n" << "isValid:" << isValid() << '\n' << "Number of Covered register classes: " << ContainedRegClasses.count() << '\n'; diff --git a/llvm/lib/CodeGen/RegisterBankInfo.cpp b/llvm/lib/CodeGen/RegisterBankInfo.cpp --- a/llvm/lib/CodeGen/RegisterBankInfo.cpp +++ b/llvm/lib/CodeGen/RegisterBankInfo.cpp @@ -52,9 +52,11 @@ //------------------------------------------------------------------------------ // RegisterBankInfo implementation. //------------------------------------------------------------------------------ -RegisterBankInfo::RegisterBankInfo(RegisterBank **RegBanks, - unsigned NumRegBanks) - : RegBanks(RegBanks), NumRegBanks(NumRegBanks) { +RegisterBankInfo::RegisterBankInfo(const RegisterBank **RegBanks, + unsigned NumRegBanks, const unsigned *Sizes, + unsigned HwMode) + : RegBanks(RegBanks), NumRegBanks(NumRegBanks), Sizes(Sizes), + HwMode(HwMode) { #ifndef NDEBUG for (unsigned Idx = 0, End = getNumRegBanks(); Idx != End; ++Idx) { assert(RegBanks[Idx] != nullptr && "Invalid RegisterBank"); @@ -70,7 +72,7 @@ assert(Idx == RegBank.getID() && "ID does not match the index in the array"); LLVM_DEBUG(dbgs() << "Verify " << RegBank << '\n'); - assert(RegBank.verify(TRI) && "RegBank is invalid"); + assert(RegBank.verify(*this, TRI) && "RegBank is invalid"); } #endif // NDEBUG return true; @@ -516,12 +518,14 @@ } #endif -bool RegisterBankInfo::PartialMapping::verify() const { +bool RegisterBankInfo::PartialMapping::verify( + const RegisterBankInfo &RBI) const { assert(RegBank && "Register bank not set"); assert(Length && "Empty mapping"); assert((StartIdx <= getHighBitIdx()) && "Overflow, switch to APInt?"); // Check if the minimum width fits into RegBank. - assert(RegBank->getSize() >= Length && "Register bank too small for Mask"); + assert(RBI.getMaximumSize(RegBank->getID()) >= Length && + "Register bank too small for Mask"); return true; } @@ -546,13 +550,14 @@ return true; } -bool RegisterBankInfo::ValueMapping::verify(unsigned MeaningfulBitWidth) const { +bool RegisterBankInfo::ValueMapping::verify(const RegisterBankInfo &RBI, + unsigned MeaningfulBitWidth) const { assert(NumBreakDowns && "Value mapped nowhere?!"); unsigned OrigValueBitWidth = 0; for (const RegisterBankInfo::PartialMapping &PartMap : *this) { // Check that each register bank is big enough to hold the partial value: // this check is done by PartialMapping::verify - assert(PartMap.verify() && "Partial mapping is invalid"); + assert(PartMap.verify(RBI) && "Partial mapping is invalid"); // The original value should completely be mapped. // Thus the maximum accessed index + 1 is the size of the original value. OrigValueBitWidth = @@ -626,8 +631,9 @@ (void)MOMapping; // Register size in bits. // This size must match what the mapping expects. - assert(MOMapping.verify(RBI->getSizeInBits( - Reg, MF.getRegInfo(), *MF.getSubtarget().getRegisterInfo())) && + assert(MOMapping.verify(*RBI, RBI->getSizeInBits( + Reg, MF.getRegInfo(), + *MF.getSubtarget().getRegisterInfo())) && "Value mapping is invalid"); } return true; diff --git a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp --- a/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp +++ b/llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp @@ -71,7 +71,8 @@ // GR64all + its subclasses. assert(RBGPR.covers(*TRI.getRegClass(AArch64::GPR32RegClassID)) && "Subclass not added?"); - assert(RBGPR.getSize() == 128 && "GPRs should hold up to 128-bit"); + assert(getMaximumSize(RBGPR.getID()) == 128 && + "GPRs should hold up to 128-bit"); // The FPR register bank is fully defined by all the registers in // GR64all + its subclasses. @@ -79,12 +80,13 @@ "Subclass not added?"); assert(RBFPR.covers(*TRI.getRegClass(AArch64::FPR64RegClassID)) && "Subclass not added?"); - assert(RBFPR.getSize() == 512 && + assert(getMaximumSize(RBFPR.getID()) == 512 && "FPRs should hold up to 512-bit via QQQQ sequence"); assert(RBCCR.covers(*TRI.getRegClass(AArch64::CCRRegClassID)) && "Class not added?"); - assert(RBCCR.getSize() == 32 && "CCR should hold up to 32-bit"); + assert(getMaximumSize(RBCCR.getID()) == 32 && + "CCR should hold up to 32-bit"); // Check that the TableGen'ed like file is in sync we our expectations. // First, the Idx. diff --git a/llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp b/llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp --- a/llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp +++ b/llvm/lib/Target/ARM/ARMRegisterBankInfo.cpp @@ -162,7 +162,8 @@ "Subclass not added?"); assert(RBGPR.covers(*TRI.getRegClass(ARM::tGPROdd_and_tcGPRRegClassID)) && "Subclass not added?"); - assert(RBGPR.getSize() == 32 && "GPRs should hold up to 32-bit"); + assert(getMaximumSize(RBGPR.getID()) == 32 && + "GPRs should hold up to 32-bit"); #ifndef NDEBUG ARM::checkPartialMappings(); diff --git a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.h b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.h --- a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.h +++ b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.h @@ -31,7 +31,7 @@ /// This class provides the information for the target register banks. class RISCVRegisterBankInfo final : public RISCVGenRegisterBankInfo { public: - RISCVRegisterBankInfo(const TargetRegisterInfo &TRI); + RISCVRegisterBankInfo(unsigned HwMode); }; } // end namespace llvm #endif diff --git a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp --- a/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp +++ b/llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp @@ -22,4 +22,5 @@ using namespace llvm; -RISCVRegisterBankInfo::RISCVRegisterBankInfo(const TargetRegisterInfo &TRI) {} +RISCVRegisterBankInfo::RISCVRegisterBankInfo(unsigned HwMode) + : RISCVGenRegisterBankInfo(HwMode) {} diff --git a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp --- a/llvm/lib/Target/RISCV/RISCVSubtarget.cpp +++ b/llvm/lib/Target/RISCV/RISCVSubtarget.cpp @@ -86,7 +86,7 @@ CallLoweringInfo.reset(new RISCVCallLowering(*getTargetLowering())); Legalizer.reset(new RISCVLegalizerInfo(*this)); - auto *RBI = new RISCVRegisterBankInfo(*getRegisterInfo()); + auto *RBI = new RISCVRegisterBankInfo(getHwMode()); RegBankInfo.reset(RBI); InstSelector.reset(createRISCVInstructionSelector( *static_cast(&TM), *this, *RBI)); diff --git a/llvm/lib/Target/X86/X86RegisterBankInfo.cpp b/llvm/lib/Target/X86/X86RegisterBankInfo.cpp --- a/llvm/lib/Target/X86/X86RegisterBankInfo.cpp +++ b/llvm/lib/Target/X86/X86RegisterBankInfo.cpp @@ -36,7 +36,8 @@ // GR64 + its subclasses. assert(RBGPR.covers(*TRI.getRegClass(X86::GR64RegClassID)) && "Subclass not added?"); - assert(RBGPR.getSize() == 64 && "GPRs should hold up to 64-bit"); + assert(getMaximumSize(RBGPR.getID()) == 64 && + "GPRs should hold up to 64-bit"); } const RegisterBank & diff --git a/llvm/utils/TableGen/RegisterBankEmitter.cpp b/llvm/utils/TableGen/RegisterBankEmitter.cpp --- a/llvm/utils/TableGen/RegisterBankEmitter.cpp +++ b/llvm/utils/TableGen/RegisterBankEmitter.cpp @@ -37,11 +37,11 @@ RegisterClassesTy RCs; /// The register class with the largest register size. - const CodeGenRegisterClass *RCWithLargestRegsSize; + std::vector RCsWithLargestRegSize; public: - RegisterBank(const Record &TheDef) - : TheDef(TheDef), RCWithLargestRegsSize(nullptr) {} + RegisterBank(const Record &TheDef, unsigned NumModeIds) + : TheDef(TheDef), RCsWithLargestRegSize(NumModeIds) {} /// Get the human-readable name for the bank. StringRef getName() const { return TheDef.getValueAsString("Name"); } @@ -79,18 +79,21 @@ // register size anywhere (we could sum the sizes of the subregisters // but there may be additional bits too) and we can't derive it from // the VT's reliably due to Untyped. - if (RCWithLargestRegsSize == nullptr) - RCWithLargestRegsSize = RC; - else if (RCWithLargestRegsSize->RSI.get(DefaultMode).SpillSize < - RC->RSI.get(DefaultMode).SpillSize) - RCWithLargestRegsSize = RC; - assert(RCWithLargestRegsSize && "RC was nullptr?"); + unsigned NumModeIds = RCsWithLargestRegSize.size(); + for (unsigned M = 0; M < NumModeIds; ++M) { + if (RCsWithLargestRegSize[M] == nullptr) + RCsWithLargestRegSize[M] = RC; + else if (RCsWithLargestRegSize[M]->RSI.get(M).SpillSize < + RC->RSI.get(M).SpillSize) + RCsWithLargestRegSize[M] = RC; + assert(RCsWithLargestRegSize[M] && "RC was nullptr?"); + } RCs.emplace_back(RC); } - const CodeGenRegisterClass *getRCWithLargestRegsSize() const { - return RCWithLargestRegsSize; + const CodeGenRegisterClass *getRCWithLargestRegSize(unsigned HwMode) const { + return RCsWithLargestRegSize[HwMode]; } iterator_range @@ -144,9 +147,10 @@ raw_ostream &OS, const StringRef TargetName, const std::vector &Banks) { OS << "private:\n" - << " static RegisterBank *RegBanks[];\n\n" + << " static const RegisterBank *RegBanks[];\n" + << " static const unsigned Sizes[];\n\n" << "protected:\n" - << " " << TargetName << "GenRegisterBankInfo();\n" + << " " << TargetName << "GenRegisterBankInfo(unsigned HwMode = 0);\n" << "\n"; } @@ -211,6 +215,7 @@ raw_ostream &OS, StringRef TargetName, std::vector &Banks) { const CodeGenRegBank &RegisterClassHierarchy = Target.getRegBank(); + const CodeGenHwModes &CGH = Target.getHwModes(); OS << "namespace llvm {\n" << "namespace " << TargetName << " {\n"; @@ -241,11 +246,8 @@ for (const auto &Bank : Banks) { std::string QualifiedBankID = (TargetName + "::" + Bank.getEnumeratorName()).str(); - const CodeGenRegisterClass &RC = *Bank.getRCWithLargestRegsSize(); - unsigned Size = RC.RSI.get(DefaultMode).SpillSize; - OS << "RegisterBank " << Bank.getInstanceVarName() << "(/* ID */ " - << QualifiedBankID << ", /* Name */ \"" << Bank.getName() - << "\", /* Size */ " << Size << ", " + OS << "const RegisterBank " << Bank.getInstanceVarName() << "(/* ID */ " + << QualifiedBankID << ", /* Name */ \"" << Bank.getName() << "\", " << "/* CoveredRegClasses */ " << Bank.getCoverageArrayName() << ", /* NumRegClasses */ " << RegisterClassHierarchy.getRegClasses().size() << ");\n"; @@ -253,16 +255,33 @@ OS << "} // end namespace " << TargetName << "\n" << "\n"; - OS << "RegisterBank *" << TargetName + OS << "const RegisterBank *" << TargetName << "GenRegisterBankInfo::RegBanks[] = {\n"; for (const auto &Bank : Banks) OS << " &" << TargetName << "::" << Bank.getInstanceVarName() << ",\n"; OS << "};\n\n"; + unsigned NumModeIds = CGH.getNumModeIds(); + OS << "const unsigned " << TargetName << "GenRegisterBankInfo::Sizes[] = {\n"; + for (unsigned M = 0; M < NumModeIds; ++M) { + OS << " // Mode = " << M << " ("; + if (M == DefaultMode) + OS << "Default"; + else + OS << CGH.getMode(M).Name; + OS << ")\n"; + for (const auto &Bank : Banks) { + const CodeGenRegisterClass &RC = *Bank.getRCWithLargestRegSize(M); + unsigned Size = RC.RSI.get(M).SpillSize; + OS << " " << Size << ",\n"; + } + } + OS << "};\n\n"; + OS << TargetName << "GenRegisterBankInfo::" << TargetName - << "GenRegisterBankInfo()\n" + << "GenRegisterBankInfo(unsigned HwMode)\n" << " : RegisterBankInfo(RegBanks, " << TargetName - << "::NumRegisterBanks) {\n" + << "::NumRegisterBanks, Sizes, HwMode) {\n" << " // Assert that RegBank indices match their ID's\n" << "#ifndef NDEBUG\n" << " for (auto RB : enumerate(RegBanks))\n" @@ -275,12 +294,13 @@ void RegisterBankEmitter::run(raw_ostream &OS) { StringRef TargetName = Target.getName(); const CodeGenRegBank &RegisterClassHierarchy = Target.getRegBank(); + const CodeGenHwModes &CGH = Target.getHwModes(); Records.startTimer("Analyze records"); std::vector Banks; for (const auto &V : Records.getAllDerivedDefinitions("RegisterBank")) { SmallPtrSet VisitedRCs; - RegisterBank Bank(*V); + RegisterBank Bank(*V, CGH.getNumModeIds()); for (const CodeGenRegisterClass *RC : Bank.getExplicitlySpecifiedRegisterClasses(RegisterClassHierarchy)) {