diff --git a/lldb/include/lldb/Utility/RegisterValue.h b/lldb/include/lldb/Utility/RegisterValue.h --- a/lldb/include/lldb/Utility/RegisterValue.h +++ b/lldb/include/lldb/Utility/RegisterValue.h @@ -33,7 +33,9 @@ // byte AArch64 SVE. kTypicalRegisterByteSize = 256u, // Anything else we'll heap allocate storage for it. - kMaxRegisterByteSize = kTypicalRegisterByteSize, + // 256x256 to support 256 byte AArch64 SME's array storage (ZA) register. + // Which is a square of vector length x vector length. + kMaxRegisterByteSize = 256u * 256u, }; typedef llvm::SmallVector BytesContainer; diff --git a/lldb/source/Plugins/Process/Linux/NativeRegisterContextLinux_arm64.h b/lldb/source/Plugins/Process/Linux/NativeRegisterContextLinux_arm64.h --- a/lldb/source/Plugins/Process/Linux/NativeRegisterContextLinux_arm64.h +++ b/lldb/source/Plugins/Process/Linux/NativeRegisterContextLinux_arm64.h @@ -85,6 +85,8 @@ bool m_mte_ctrl_is_valid; bool m_sve_header_is_valid; + bool m_za_buffer_is_valid; + bool m_za_header_is_valid; bool m_pac_mask_is_valid; bool m_tls_is_valid; size_t m_tls_size; @@ -98,6 +100,9 @@ struct sve::user_sve_header m_sve_header; std::vector m_sve_ptrace_payload; + sve::user_za_header m_za_header; + std::vector m_za_ptrace_payload; + bool m_refresh_hwdebug_info; struct user_pac_mask { @@ -109,6 +114,12 @@ uint64_t m_mte_ctrl_reg; + struct sme_regs { + uint64_t svg_reg; + }; + + struct sme_regs m_sme_regs; + struct tls_regs { uint64_t tpidr_reg; // Only valid when SME is present. @@ -139,10 +150,24 @@ Status WriteTLS(); + Status ReadSMESVG(); + + Status ReadZAHeader(); + + Status ReadZA(); + + Status WriteZA(); + + // No WriteZAHeader because writing only the header will disable ZA. + // Instead use WriteZA and ensure you have the correct ZA buffer size set + // beforehand if you wish to disable it. 
+ bool IsSVE(unsigned reg) const; + bool IsZA(unsigned reg) const; bool IsPAuth(unsigned reg) const; bool IsMTE(unsigned reg) const; bool IsTLS(unsigned reg) const; + bool IsSME(unsigned reg) const; uint64_t GetSVERegVG() { return m_sve_header.vl / 8; } @@ -150,12 +175,18 @@ void *GetSVEHeader() { return &m_sve_header; } + void *GetZAHeader() { return &m_za_header; } + + size_t GetZAHeaderSize() { return sizeof(m_za_header); } + void *GetPACMask() { return &m_pac_mask; } void *GetMTEControl() { return &m_mte_ctrl_reg; } void *GetTLSBuffer() { return &m_tls_regs; } + void *GetSMEBuffer() { return &m_sme_regs; } + void *GetSVEBuffer() { return m_sve_ptrace_payload.data(); } size_t GetSVEHeaderSize() { return sizeof(m_sve_header); } @@ -166,10 +197,16 @@ unsigned GetSVERegSet(); + void *GetZABuffer() { return m_za_ptrace_payload.data(); }; + + size_t GetZABufferSize() { return m_za_ptrace_payload.size(); } + size_t GetMTEControlSize() { return sizeof(m_mte_ctrl_reg); } size_t GetTLSBufferSize() { return m_tls_size; } + size_t GetSMEBufferSize() { return sizeof(m_sme_regs); } + llvm::Error ReadHardwareDebugInfo() override; llvm::Error WriteHardwareDebugRegs(DREGType hwbType) override; diff --git a/lldb/source/Plugins/Process/Linux/NativeRegisterContextLinux_arm64.cpp b/lldb/source/Plugins/Process/Linux/NativeRegisterContextLinux_arm64.cpp --- a/lldb/source/Plugins/Process/Linux/NativeRegisterContextLinux_arm64.cpp +++ b/lldb/source/Plugins/Process/Linux/NativeRegisterContextLinux_arm64.cpp @@ -41,6 +41,10 @@ 0x40b /* ARM Scalable Matrix Extension, Streaming SVE mode */ #endif +#ifndef NT_ARM_ZA +#define NT_ARM_ZA 0x40c /* ARM Scalable Matrix Extension, Array Storage */ +#endif + #ifndef NT_ARM_PAC_MASK #define NT_ARM_PAC_MASK 0x406 /* Pointer authentication code masks */ #endif @@ -90,6 +94,16 @@ opt_regsets.Set(RegisterInfoPOSIX_arm64::eRegsetMaskSSVE); } + sve::user_za_header za_header; + ioVec.iov_base = &za_header; + ioVec.iov_len = sizeof(za_header); + regset = 
NT_ARM_ZA; + if (NativeProcessLinux::PtraceWrapper(PTRACE_GETREGSET, + native_thread.GetID(), &regset, + &ioVec, sizeof(za_header)) + .Success()) + opt_regsets.Set(RegisterInfoPOSIX_arm64::eRegsetMaskZA); + NativeProcessLinux &process = native_thread.GetProcess(); std::optional auxv_at_hwcap = @@ -133,6 +147,7 @@ ::memset(&m_sve_header, 0, sizeof(m_sve_header)); ::memset(&m_pac_mask, 0, sizeof(m_pac_mask)); ::memset(&m_tls_regs, 0, sizeof(m_tls_regs)); + ::memset(&m_sme_regs, 0, sizeof(m_sme_regs)); m_mte_ctrl_reg = 0; @@ -314,6 +329,39 @@ offset = reg_info->byte_offset - GetRegisterInfo().GetMTEOffset(); assert(offset < GetMTEControlSize()); src = (uint8_t *)GetMTEControl() + offset; + } else if (IsZA(reg)) { + error = ReadZAHeader(); + if (error.Fail()) + return error; + + // If there is only a header and no registers, ZA is inactive. Read as 0 + // in this case. + if (m_za_header.size == sizeof(m_za_header)) { + // This will get reconfigured/reset later, so we are safe to use it. + // ZA is a square of VL * VL and the ptrace buffer also includes the + // header itself. + m_za_ptrace_payload.resize(((m_za_header.vl) * (m_za_header.vl)) + + GetZAHeaderSize()); + std::fill(m_za_ptrace_payload.begin(), m_za_ptrace_payload.end(), 0); + } else { + // ZA is active, read the real register. 
+ error = ReadZA(); + if (error.Fail()) + return error; + } + + offset = reg_info->byte_offset - GetRegisterInfo().GetZAOffset() + + GetZAHeaderSize(); + assert(offset < GetZABufferSize()); + src = (uint8_t *)GetZABuffer() + offset; + } else if (IsSME(reg)) { + error = ReadSMESVG(); + if (error.Fail()) + return error; + + offset = reg_info->byte_offset - GetRegisterInfo().GetSMEOffset(); + assert(offset < GetSMEBufferSize()); + src = (uint8_t *)GetSMEBuffer() + offset; } else return Status("failed - register wasn't recognized to be a GPR or an FPR, " "write strategy unknown"); @@ -420,8 +468,12 @@ SetSVERegVG(vg_value); error = WriteSVEHeader(); - if (error.Success()) + if (error.Success()) { + // Changing VG during streaming mode also changes the size of ZA. + if (m_sve_state == SVEState::Streaming) + m_za_header_is_valid = false; ConfigureRegisterContext(); + } if (m_sve_header_is_valid && vg_value == GetSVERegVG()) return error; @@ -494,6 +546,23 @@ ::memcpy(dst, reg_value.GetBytes(), reg_info->byte_size); return WriteTLS(); + } else if (IsZA(reg)) { + error = ReadZA(); + if (error.Fail()) + return error; + + offset = reg_info->byte_offset - GetRegisterInfo().GetZAOffset() + + GetZAHeaderSize(); + assert(offset < GetZABufferSize()); + dst = (uint8_t *)GetZABuffer() + offset; + ::memcpy(dst, reg_value.GetBytes(), reg_info->byte_size); + + // While this is writing a header that contains a vector length, the only + // way to change that is via the vg register. So here we assume the length + // will always be the current length and no reconfigure is needed. + return WriteZA(); + } else if (IsSME(reg)) { + return Status("Writing to SVG is not supported."); } return Status("Failed to write register value"); @@ -503,8 +572,11 @@ GPR, SVE, // Used for SVE and SSVE. FPR, // When there is no SVE, or SVE in FPSIMD mode. + // Pointer authentication registers are read only, so not included here. MTE, TLS, + ZA, + // SME pseudo registers are read only. 
}; static uint8_t *AddSavedRegistersKind(uint8_t *dst, SavedRegistersKind kind) { @@ -527,8 +599,9 @@ lldb::WritableDataBufferSP &data_sp) { // AArch64 register data must contain GPRs and either FPR or SVE registers. // SVE registers can be non-streaming (aka SVE) or streaming (aka SSVE). - // Finally an optional MTE register. Pointer Authentication (PAC) registers - // are read-only and will be skipped. + // Followed optionally by MTE, TLS and ZA register(s). SME pseudo registers + // are derived from other data, and Pointer Authentication (PAC) registers + // are read-only, so they are all skipped. // In order to create register data checkpoint we first read all register // values if not done already and calculate total size of register set data. @@ -541,6 +614,22 @@ if (error.Fail()) return error; + // Here this means, does the system have ZA, not whether it is active. + if (GetRegisterInfo().IsZAEnabled()) { + error = ReadZAHeader(); + if (error.Fail()) + return error; + // Use header size here because the buffer may contain fake data when ZA is + // disabled. + reg_data_byte_size += sizeof(SavedRegistersKind) + m_za_header.size; + // For the same reason, we need to force it to be re-read so that it will + // always contain the real header. + m_za_buffer_is_valid = false; + error = ReadZA(); + if (error.Fail()) + return error; + } + // If SVE is enabled we need not copy FPR separately. if (GetRegisterInfo().IsSVEEnabled() || GetRegisterInfo().IsSSVEEnabled()) { // Store mode and register data. @@ -573,6 +662,45 @@ dst = AddSavedRegisters(dst, SavedRegistersKind::GPR, GetGPRBuffer(), GetGPRBufferSize()); + // Streaming SVE and the ZA register both use the streaming vector length. + // When you change this, the kernel will invalidate parts of the process + // state. Therefore we need a specific order of restoration for each mode, if + // we also have ZA to restore. + // + // Streaming mode enabled, ZA enabled: + // * Write streaming registers. 
This sets SVCR.SM and clears SVCR.ZA. + // * Write ZA, this sets SVCR.ZA. The register data we provide is written to + // ZA. + // * Result is SVCR.SM and SVCR.ZA set, with the expected data in both + // register sets. + // + // Streaming mode disabled, ZA enabled: + // * Write ZA. This sets SVCR.ZA, and the ZA content. In the majority of cases + // the streaming vector length is changing, so the thread is converted into + // an FPSIMD thread if it is not already one. This also clears SVCR.SM. + // * Write SVE registers, which also clears SVCR.SM but most importantly, puts + // us into full SVE mode instead of FPSIMD mode (where the registers are + // actually the 128 bit Neon registers). + // * Result is we have SVCR.SM = 0, SVCR.ZA = 1 and the expected register + // state. + // + // Restoring in different orders leads to things like the SVE registers being + // truncated due to the FPSIMD mode and ZA being disabled or filled with 0s + // (disabled and 0s look the same from inside lldb since we fake the value + // when it's disabled). + // + // For more information on this, look up the uses of the relevant NT_ARM_ + // constants and the functions vec_set_vector_length, sve_set_common and + // za_set in the Linux Kernel. + + if ((m_sve_state != SVEState::Streaming) && GetRegisterInfo().IsZAEnabled()) { + // Use the header size not the buffer size, as we may be using the buffer + // for fake data, which we do not want to write out. 
+ assert(m_za_header.size <= GetZABufferSize()); + dst = AddSavedRegisters(dst, SavedRegistersKind::ZA, GetZABuffer(), + m_za_header.size); + } + if (GetRegisterInfo().IsSVEEnabled() || GetRegisterInfo().IsSSVEEnabled()) { dst = AddSavedRegistersKind(dst, SavedRegistersKind::SVE); *(reinterpret_cast(dst)) = m_sve_state; @@ -583,6 +711,12 @@ GetFPRSize()); } + if ((m_sve_state == SVEState::Streaming) && GetRegisterInfo().IsZAEnabled()) { + assert(m_za_header.size <= GetZABufferSize()); + dst = AddSavedRegisters(dst, SavedRegistersKind::ZA, GetZABuffer(), + m_za_header.size); + } + if (GetRegisterInfo().IsMTEEnabled()) { dst = AddSavedRegisters(dst, SavedRegistersKind::MTE, GetMTEControl(), GetMTEControlSize()); @@ -675,6 +809,8 @@ return error; // SVE header has been written configure SVE vector length if needed. + // This could change ZA data too, but that will be restored again later + // anyway. ConfigureRegisterContext(); // Write header and register data, incrementing src this time. @@ -697,6 +833,33 @@ GetTLSBuffer(), &src, GetTLSBufferSize(), m_tls_is_valid, std::bind(&NativeRegisterContextLinux_arm64::WriteTLS, this)); break; + case SavedRegistersKind::ZA: + // To enable or disable ZA you write the regset with or without register + // data. The kernel detects this by looking at the ioVec's length, not the + // ZA header size you pass in. Therefore we must write header and register + // data (if present) in one go every time. Read the header only first just + // to get the size. + ::memcpy(GetZAHeader(), src, GetZAHeaderSize()); + // Read the header and register data. Can't use the buffer size here, it + // may be incorrect due to being filled with dummy data previously. Resize + // this so WriteZA uses the correct size. 
+ m_za_ptrace_payload.resize(m_za_header.size); + ::memcpy(GetZABuffer(), src, GetZABufferSize()); + m_za_buffer_is_valid = true; + + error = WriteZA(); + if (error.Fail()) + return error; + + // Update size of ZA, which resizes the ptrace payload potentially + // trashing our copy of the data we just wrote. + ConfigureRegisterContext(); + + // ZA buffer now has proper size, read back the data we wrote above, from + // ptrace. + error = ReadZA(); + src += GetZABufferSize(); + break; } if (error.Fail()) @@ -724,6 +887,10 @@ return GetRegisterInfo().IsSVEReg(reg); } +bool NativeRegisterContextLinux_arm64::IsZA(unsigned reg) const { + return GetRegisterInfo().IsZAReg(reg); +} + bool NativeRegisterContextLinux_arm64::IsPAuth(unsigned reg) const { return GetRegisterInfo().IsPAuthReg(reg); } @@ -736,6 +903,10 @@ return GetRegisterInfo().IsTLSReg(reg); } +bool NativeRegisterContextLinux_arm64::IsSME(unsigned reg) const { + return GetRegisterInfo().IsSMEReg(reg); +} + llvm::Error NativeRegisterContextLinux_arm64::ReadHardwareDebugInfo() { if (!m_refresh_hwdebug_info) { return llvm::Error::success(); @@ -877,11 +1048,13 @@ m_fpu_is_valid = false; m_sve_buffer_is_valid = false; m_sve_header_is_valid = false; + m_za_buffer_is_valid = false; + m_za_header_is_valid = false; m_pac_mask_is_valid = false; m_mte_ctrl_is_valid = false; m_tls_is_valid = false; - // Update SVE registers in case there is change in configuration. + // Update SVE and ZA registers in case there is change in configuration. 
ConfigureRegisterContext(); } @@ -1047,6 +1220,62 @@ return WriteRegisterSet(&ioVec, GetTLSBufferSize(), NT_ARM_TLS); } +Status NativeRegisterContextLinux_arm64::ReadZAHeader() { + Status error; + + if (m_za_header_is_valid) + return error; + + struct iovec ioVec; + ioVec.iov_base = GetZAHeader(); + ioVec.iov_len = GetZAHeaderSize(); + + error = ReadRegisterSet(&ioVec, GetZAHeaderSize(), NT_ARM_ZA); + + if (error.Success()) + m_za_header_is_valid = true; + + return error; +} + +Status NativeRegisterContextLinux_arm64::ReadZA() { + Status error; + + if (m_za_buffer_is_valid) + return error; + + struct iovec ioVec; + ioVec.iov_base = GetZABuffer(); + ioVec.iov_len = GetZABufferSize(); + + error = ReadRegisterSet(&ioVec, GetZABufferSize(), NT_ARM_ZA); + + if (error.Success()) + m_za_buffer_is_valid = true; + + return error; +} + +Status NativeRegisterContextLinux_arm64::WriteZA() { + // Note that because the ZA ptrace payload contains the header also, this + // method will write both. This is done because writing only the header + // will disable ZA, even if .size in the header is correct for an enabled ZA. 
+ Status error; + + error = ReadZA(); + if (error.Fail()) + return error; + + struct iovec ioVec; + ioVec.iov_base = GetZABuffer(); + ioVec.iov_len = GetZABufferSize(); + + m_za_buffer_is_valid = false; + m_za_header_is_valid = false; + + return WriteRegisterSet(&ioVec, GetZABufferSize(), NT_ARM_ZA); +} + void NativeRegisterContextLinux_arm64::ConfigureRegisterContext() { // ConfigureRegisterContext gets called from InvalidateAllRegisters // on every stop and configures SVE vector length and whether we are in @@ -1096,6 +1325,19 @@ m_sve_ptrace_payload.resize(sve::PTraceSize(vq, sve::ptrace_regs_sve)); } } + + if (!m_za_header_is_valid) { + Status error = ReadZAHeader(); + if (error.Success()) { + uint32_t vq = RegisterInfoPOSIX_arm64::eVectorQuadwordAArch64SVE; + if (sve::vl_valid(m_za_header.vl)) + vq = sve::vq_from_vl(m_za_header.vl); + + GetRegisterInfo().ConfigureVectorLengthZA(vq); + m_za_ptrace_payload.resize(m_za_header.size); + m_za_buffer_is_valid = false; + } + } } uint32_t NativeRegisterContextLinux_arm64::CalculateFprOffset( @@ -1121,12 +1363,27 @@ return sve_reg_offset; } +Status NativeRegisterContextLinux_arm64::ReadSMESVG() { + // This register is the streaming vector length, so we will get it from + // NT_ARM_ZA regardless of the current streaming mode. + Status error = ReadZAHeader(); + if (error.Success()) + m_sme_regs.svg_reg = m_za_header.vl / 8; + + return error; +} + std::vector NativeRegisterContextLinux_arm64::GetExpeditedRegisters( ExpeditedRegs expType) const { std::vector expedited_reg_nums = NativeRegisterContext::GetExpeditedRegisters(expType); + // SVE, non-streaming vector length. if (m_sve_state == SVEState::FPSIMD || m_sve_state == SVEState::Full) expedited_reg_nums.push_back(GetRegisterInfo().GetRegNumSVEVG()); + // SME, streaming vector length. This is used by the ZA register which is + // present even when streaming mode is not enabled. 
+ if (GetRegisterInfo().IsSSVEEnabled()) + expedited_reg_nums.push_back(GetRegisterInfo().GetRegNumSMEVG()); return expedited_reg_nums; } diff --git a/lldb/source/Plugins/Process/Utility/LinuxPTraceDefines_arm64sve.h b/lldb/source/Plugins/Process/Utility/LinuxPTraceDefines_arm64sve.h --- a/lldb/source/Plugins/Process/Utility/LinuxPTraceDefines_arm64sve.h +++ b/lldb/source/Plugins/Process/Utility/LinuxPTraceDefines_arm64sve.h @@ -152,6 +152,8 @@ uint16_t reserved; }; +using user_za_header = user_sve_header; + /* Definitions for user_sve_header.flags: */ const uint16_t ptrace_regs_mask = 1 << 0; const uint16_t ptrace_regs_fpsimd = 0; diff --git a/lldb/source/Plugins/Process/Utility/RegisterContextPOSIX_arm64.h b/lldb/source/Plugins/Process/Utility/RegisterContextPOSIX_arm64.h --- a/lldb/source/Plugins/Process/Utility/RegisterContextPOSIX_arm64.h +++ b/lldb/source/Plugins/Process/Utility/RegisterContextPOSIX_arm64.h @@ -54,6 +54,7 @@ size_t GetFPUSize() { return sizeof(RegisterInfoPOSIX_arm64::FPU); } bool IsSVE(unsigned reg) const; + bool IsZA(unsigned reg) const; bool IsPAuth(unsigned reg) const; bool IsTLS(unsigned reg) const; diff --git a/lldb/source/Plugins/Process/Utility/RegisterContextPOSIX_arm64.cpp b/lldb/source/Plugins/Process/Utility/RegisterContextPOSIX_arm64.cpp --- a/lldb/source/Plugins/Process/Utility/RegisterContextPOSIX_arm64.cpp +++ b/lldb/source/Plugins/Process/Utility/RegisterContextPOSIX_arm64.cpp @@ -43,6 +43,10 @@ return m_register_info_up->IsSVEReg(reg); } +bool RegisterContextPOSIX_arm64::IsZA(unsigned reg) const { + return m_register_info_up->IsZAReg(reg); +} + bool RegisterContextPOSIX_arm64::IsPAuth(unsigned reg) const { return m_register_info_up->IsPAuthReg(reg); } diff --git a/lldb/source/Plugins/Process/Utility/RegisterInfoPOSIX_arm64.h b/lldb/source/Plugins/Process/Utility/RegisterInfoPOSIX_arm64.h --- a/lldb/source/Plugins/Process/Utility/RegisterInfoPOSIX_arm64.h +++ b/lldb/source/Plugins/Process/Utility/RegisterInfoPOSIX_arm64.h @@ 
-30,6 +30,7 @@ eRegsetMaskPAuth = 4, eRegsetMaskMTE = 8, eRegsetMaskTLS = 16, + eRegsetMaskZA = 32, eRegsetMaskDynamic = ~1, }; @@ -106,8 +107,14 @@ void AddRegSetTLS(bool has_tpidr2); + void AddRegSetZA(); + + void AddRegSetSME(); + uint32_t ConfigureVectorLength(uint32_t sve_vq); + void ConfigureVectorLengthZA(uint32_t za_vq); + bool VectorSizeIsValid(uint32_t vq) { // coverity[unsigned_compare] if (vq >= eVectorQuadwordAArch64 && vq <= eVectorQuadwordAArch64SVEMax) @@ -117,6 +124,7 @@ bool IsSVEEnabled() const { return m_opt_regsets.AnySet(eRegsetMaskSVE); } bool IsSSVEEnabled() const { return m_opt_regsets.AnySet(eRegsetMaskSSVE); } + bool IsZAEnabled() const { return m_opt_regsets.AnySet(eRegsetMaskZA); } bool IsPAuthEnabled() const { return m_opt_regsets.AnySet(eRegsetMaskPAuth); } bool IsMTEEnabled() const { return m_opt_regsets.AnySet(eRegsetMaskMTE); } bool IsTLSEnabled() const { return m_opt_regsets.AnySet(eRegsetMaskTLS); } @@ -128,15 +136,20 @@ bool IsPAuthReg(unsigned reg) const; bool IsMTEReg(unsigned reg) const; bool IsTLSReg(unsigned reg) const; + bool IsZAReg(unsigned reg) const; + bool IsSMEReg(unsigned reg) const; uint32_t GetRegNumSVEZ0() const; uint32_t GetRegNumSVEFFR() const; uint32_t GetRegNumFPCR() const; uint32_t GetRegNumFPSR() const; uint32_t GetRegNumSVEVG() const; + uint32_t GetRegNumSMEVG() const; uint32_t GetPAuthOffset() const; uint32_t GetMTEOffset() const; uint32_t GetTLSOffset() const; + uint32_t GetZAOffset() const; + uint32_t GetSMEOffset() const; private: typedef std::map> @@ -145,7 +158,10 @@ per_vq_register_infos m_per_vq_reg_infos; uint32_t m_vector_reg_vq = eVectorQuadwordAArch64; + uint32_t m_za_reg_vq = eVectorQuadwordAArch64; + // In normal operation this is const. Only when SVE or SME registers change + // size is it either replaced or the content modified. 
const lldb_private::RegisterInfo *m_register_info_p; uint32_t m_register_info_count; @@ -164,6 +180,8 @@ std::vector pauth_regnum_collection; std::vector m_mte_regnum_collection; std::vector m_tls_regnum_collection; + std::vector m_za_regnum_collection; + std::vector m_sme_regnum_collection; }; #endif diff --git a/lldb/source/Plugins/Process/Utility/RegisterInfoPOSIX_arm64.cpp b/lldb/source/Plugins/Process/Utility/RegisterInfoPOSIX_arm64.cpp --- a/lldb/source/Plugins/Process/Utility/RegisterInfoPOSIX_arm64.cpp +++ b/lldb/source/Plugins/Process/Utility/RegisterInfoPOSIX_arm64.cpp @@ -83,6 +83,14 @@ // Only present when SME is present DEFINE_EXTENSION_REG(tpidr2)}; +static lldb_private::RegisterInfo g_register_infos_za[] = + // 16 is a default size we will change later. + {{"za", nullptr, 16, 0, lldb::eEncodingVector, lldb::eFormatVectorOfUInt8, + KIND_ALL_INVALID, nullptr, nullptr, nullptr}}; + +static lldb_private::RegisterInfo g_register_infos_sme[] = { + DEFINE_EXTENSION_REG(svg)}; + // Number of register sets provided by this context. enum { k_num_gpr_registers = gpr_w28 - gpr_x0 + 1, @@ -91,6 +99,8 @@ k_num_mte_register = 1, // Number of TLS registers is dynamic so it is not listed here. k_num_pauth_register = 2, + k_num_za_register = 1, + k_num_sme_register = 1, k_num_register_sets_default = 2, k_num_register_sets = 3 }; @@ -197,6 +207,13 @@ // The size of the TLS set is dynamic, so not listed here. +static const lldb_private::RegisterSet g_reg_set_za_arm64 = { + "Scalable Matrix Array Storage Registers", "za", k_num_za_register, + nullptr}; + +static const lldb_private::RegisterSet g_reg_set_sme_arm64 = { + "Scalable Matrix Extension Registers", "sme", k_num_sme_register, nullptr}; + RegisterInfoPOSIX_arm64::RegisterInfoPOSIX_arm64( const lldb_private::ArchSpec &target_arch, lldb_private::Flags opt_regsets) : lldb_private::RegisterInfoAndSetInterface(target_arch), @@ -241,6 +258,11 @@ // present. 
AddRegSetTLS(m_opt_regsets.AllSet(eRegsetMaskSSVE)); + if (m_opt_regsets.AnySet(eRegsetMaskSSVE)) { + AddRegSetZA(); + AddRegSetSME(); + } + m_register_info_count = m_dynamic_reg_infos.size(); m_register_info_p = m_dynamic_reg_infos.data(); m_register_set_p = m_dynamic_reg_sets.data(); @@ -344,6 +366,40 @@ m_dynamic_reg_sets.back().registers = m_tls_regnum_collection.data(); } +void RegisterInfoPOSIX_arm64::AddRegSetZA() { + uint32_t za_regnum = m_dynamic_reg_infos.size(); + m_za_regnum_collection.push_back(za_regnum); + + m_dynamic_reg_infos.push_back(g_register_infos_za[0]); + m_dynamic_reg_infos[za_regnum].byte_offset = + m_dynamic_reg_infos[za_regnum - 1].byte_offset + + m_dynamic_reg_infos[za_regnum - 1].byte_size; + m_dynamic_reg_infos[za_regnum].kinds[lldb::eRegisterKindLLDB] = za_regnum; + + m_per_regset_regnum_range[m_register_set_count] = + std::make_pair(za_regnum, za_regnum + 1); + m_dynamic_reg_sets.push_back(g_reg_set_za_arm64); + m_dynamic_reg_sets.back().registers = m_za_regnum_collection.data(); +} + +void RegisterInfoPOSIX_arm64::AddRegSetSME() { + uint32_t sme_regnum = m_dynamic_reg_infos.size(); + for (uint32_t i = 0; i < k_num_sme_register; i++) { + m_sme_regnum_collection.push_back(sme_regnum + i); + m_dynamic_reg_infos.push_back(g_register_infos_sme[i]); + m_dynamic_reg_infos[sme_regnum + i].byte_offset = + m_dynamic_reg_infos[sme_regnum + i - 1].byte_offset + + m_dynamic_reg_infos[sme_regnum + i - 1].byte_size; + m_dynamic_reg_infos[sme_regnum + i].kinds[lldb::eRegisterKindLLDB] = + sme_regnum + i; + } + + m_per_regset_regnum_range[m_register_set_count] = + std::make_pair(sme_regnum, m_dynamic_reg_infos.size()); + m_dynamic_reg_sets.push_back(g_reg_set_sme_arm64); + m_dynamic_reg_sets.back().registers = m_sme_regnum_collection.data(); +} + uint32_t RegisterInfoPOSIX_arm64::ConfigureVectorLength(uint32_t sve_vq) { // sve_vq contains SVE Quad vector length in context of AArch64 SVE. 
// SVE register infos if enabled cannot be disabled by selecting sve_vq = 0. @@ -408,6 +464,20 @@ return m_vector_reg_vq; } +void RegisterInfoPOSIX_arm64::ConfigureVectorLengthZA(uint32_t za_vq) { + if (!VectorSizeIsValid(za_vq) || m_za_reg_vq == za_vq) + return; + + m_za_reg_vq = za_vq; + + // For SVE changes, we replace m_register_info_p completely. ZA is in a + // dynamic set and is just 1 register so we make an exception to const here. + lldb_private::RegisterInfo *non_const_reginfo = + const_cast(m_register_info_p); + non_const_reginfo[m_za_regnum_collection[0]].byte_size = + (za_vq * 16) * (za_vq * 16); +} + bool RegisterInfoPOSIX_arm64::IsSVEReg(unsigned reg) const { if (m_vector_reg_vq > eVectorQuadwordAArch64) return (sve_vg <= reg && reg <= sve_ffr); @@ -439,6 +509,14 @@ return llvm::is_contained(m_tls_regnum_collection, reg); } +bool RegisterInfoPOSIX_arm64::IsZAReg(unsigned reg) const { + return llvm::is_contained(m_za_regnum_collection, reg); +} + +bool RegisterInfoPOSIX_arm64::IsSMEReg(unsigned reg) const { + return llvm::is_contained(m_sme_regnum_collection, reg); +} + uint32_t RegisterInfoPOSIX_arm64::GetRegNumSVEZ0() const { return sve_z0; } uint32_t RegisterInfoPOSIX_arm64::GetRegNumSVEFFR() const { return sve_ffr; } @@ -449,6 +527,10 @@ uint32_t RegisterInfoPOSIX_arm64::GetRegNumSVEVG() const { return sve_vg; } +uint32_t RegisterInfoPOSIX_arm64::GetRegNumSMEVG() const { + return m_sme_regnum_collection[0]; +} + uint32_t RegisterInfoPOSIX_arm64::GetPAuthOffset() const { return m_register_info_p[pauth_regnum_collection[0]].byte_offset; } @@ -460,3 +542,11 @@ uint32_t RegisterInfoPOSIX_arm64::GetTLSOffset() const { return m_register_info_p[m_tls_regnum_collection[0]].byte_offset; } + +uint32_t RegisterInfoPOSIX_arm64::GetZAOffset() const { + return m_register_info_p[m_za_regnum_collection[0]].byte_offset; +} + +uint32_t RegisterInfoPOSIX_arm64::GetSMEOffset() const { + return m_register_info_p[m_sme_regnum_collection[0]].byte_offset; +} diff --git 
a/lldb/source/Plugins/Process/elf-core/RegisterUtilities.h b/lldb/source/Plugins/Process/elf-core/RegisterUtilities.h --- a/lldb/source/Plugins/Process/elf-core/RegisterUtilities.h +++ b/lldb/source/Plugins/Process/elf-core/RegisterUtilities.h @@ -119,6 +119,10 @@ {llvm::Triple::Linux, llvm::Triple::aarch64, llvm::ELF::NT_ARM_SVE}, }; +constexpr RegsetDesc AARCH64_ZA_Desc[] = { + {llvm::Triple::Linux, llvm::Triple::aarch64, llvm::ELF::NT_ARM_ZA}, +}; + constexpr RegsetDesc AARCH64_PAC_Desc[] = { {llvm::Triple::Linux, llvm::Triple::aarch64, llvm::ELF::NT_ARM_PAC_MASK}, }; diff --git a/lldb/source/Plugins/Process/gdb-remote/GDBRemoteRegisterContext.h b/lldb/source/Plugins/Process/gdb-remote/GDBRemoteRegisterContext.h --- a/lldb/source/Plugins/Process/gdb-remote/GDBRemoteRegisterContext.h +++ b/lldb/source/Plugins/Process/gdb-remote/GDBRemoteRegisterContext.h @@ -39,6 +39,7 @@ ~GDBRemoteDynamicRegisterInfo() override = default; void UpdateARM64SVERegistersInfos(uint64_t vg); + void UpdateARM64SMERegistersInfos(uint64_t vg); }; class GDBRemoteRegisterContext : public RegisterContext { @@ -77,7 +78,9 @@ uint32_t ConvertRegisterKindToRegisterNumber(lldb::RegisterKind kind, uint32_t num) override; - bool AArch64SVEReconfigure(); + void AArch64SVEReconfigure(); + + void AArch64SMEReconfigure(); protected: friend class ThreadGDBRemote; diff --git a/lldb/source/Plugins/Process/gdb-remote/GDBRemoteRegisterContext.cpp b/lldb/source/Plugins/Process/gdb-remote/GDBRemoteRegisterContext.cpp --- a/lldb/source/Plugins/Process/gdb-remote/GDBRemoteRegisterContext.cpp +++ b/lldb/source/Plugins/Process/gdb-remote/GDBRemoteRegisterContext.cpp @@ -373,14 +373,14 @@ if (dst == nullptr) return false; - // Code below is specific to AArch64 target in SVE state + // Code below is specific to AArch64 target in SVE or SME state // If vector granule (vg) register is being written then thread's // register context reconfiguration is triggered on success. 
- bool do_reconfigure_arm64_sve = false; + // We do not allow writes to SVG so it is not mentioned here. const ArchSpec &arch = process->GetTarget().GetArchitecture(); - if (arch.IsValid() && arch.GetTriple().isAArch64()) - if (strcmp(reg_info->name, "vg") == 0) - do_reconfigure_arm64_sve = true; + bool do_reconfigure_arm64_sve = arch.IsValid() && + arch.GetTriple().isAArch64() && + (strcmp(reg_info->name, "vg") == 0); if (data.CopyByteOrderedData(data_offset, // src offset reg_info->byte_size, // src length @@ -400,10 +400,12 @@ {m_reg_data.GetDataStart(), size_t(m_reg_data.GetByteSize())})) { - SetAllRegisterValid(false); - - if (do_reconfigure_arm64_sve) + if (do_reconfigure_arm64_sve) { AArch64SVEReconfigure(); + AArch64SMEReconfigure(); + } + + InvalidateAllRegisters(); return true; } @@ -435,8 +437,11 @@ // This is an actual register, write it success = SetPrimordialRegister(reg_info, gdb_comm); - if (success && do_reconfigure_arm64_sve) + if (success && do_reconfigure_arm64_sve) { AArch64SVEReconfigure(); + AArch64SMEReconfigure(); + InvalidateAllRegisters(); + } } // Check if writing this register will invalidate any other register @@ -760,37 +765,52 @@ return m_reg_info_sp->ConvertRegisterKindToRegisterNumber(kind, num); } -bool GDBRemoteRegisterContext::AArch64SVEReconfigure() { - if (!m_reg_info_sp) - return false; - +void GDBRemoteRegisterContext::AArch64SVEReconfigure() { + assert(m_reg_info_sp); const RegisterInfo *reg_info = m_reg_info_sp->GetRegisterInfo("vg"); if (!reg_info) - return false; + return; uint64_t fail_value = LLDB_INVALID_ADDRESS; uint32_t vg_reg_num = reg_info->kinds[eRegisterKindLLDB]; uint64_t vg_reg_value = ReadRegisterAsUnsigned(vg_reg_num, fail_value); if (vg_reg_value == fail_value || vg_reg_value > 32) - return false; + return; reg_info = m_reg_info_sp->GetRegisterInfo("p0"); // Predicate registers have 1 bit per byte in the vector so their size is // VL / 8. 
VG is in units of 8 bytes already, so if the size of p0 == VG // already, we do not have to reconfigure. if (!reg_info || vg_reg_value == reg_info->byte_size) - return false; + return; m_reg_info_sp->UpdateARM64SVERegistersInfos(vg_reg_value); // Make a heap based buffer that is big enough to store all registers m_reg_data.SetData(std::make_shared( m_reg_info_sp->GetRegisterDataByteSize(), 0)); m_reg_data.SetByteOrder(GetByteOrder()); +} - InvalidateAllRegisters(); +void GDBRemoteRegisterContext::AArch64SMEReconfigure() { + assert(m_reg_info_sp); + const RegisterInfo *reg_info = m_reg_info_sp->GetRegisterInfo("svg"); + // Target does not have SME, nothing for us to reconfigure. + if (!reg_info) + return; - return true; + uint64_t fail_value = LLDB_INVALID_ADDRESS; + uint32_t svg_reg_num = reg_info->kinds[eRegisterKindLLDB]; + uint64_t svg_reg_value = ReadRegisterAsUnsigned(svg_reg_num, fail_value); + + if (svg_reg_value == LLDB_INVALID_ADDRESS || svg_reg_value > 32) + return; + + m_reg_info_sp->UpdateARM64SMERegistersInfos(svg_reg_value); + // Make a heap based buffer that is big enough to store all registers + m_reg_data.SetData(std::make_shared( + m_reg_info_sp->GetRegisterDataByteSize(), 0)); + m_reg_data.SetByteOrder(GetByteOrder()); } void GDBRemoteDynamicRegisterInfo::UpdateARM64SVERegistersInfos(uint64_t vg) { @@ -815,3 +835,15 @@ // Re-calculate register offsets ConfigureOffsets(); } + +void GDBRemoteDynamicRegisterInfo::UpdateARM64SMERegistersInfos(uint64_t svg) { + for (auto &reg : m_regs) { + if (strcmp(reg.name, "za") == 0) { + // ZA is a register with size (svg*8) * (svg*8). A square essentially. 
+ reg.byte_size = (svg * 8) * (svg * 8); + } + reg.byte_offset = LLDB_INVALID_INDEX32; + } + + ConfigureOffsets(); +} diff --git a/lldb/source/Plugins/Process/gdb-remote/ProcessGDBRemote.cpp b/lldb/source/Plugins/Process/gdb-remote/ProcessGDBRemote.cpp --- a/lldb/source/Plugins/Process/gdb-remote/ProcessGDBRemote.cpp +++ b/lldb/source/Plugins/Process/gdb-remote/ProcessGDBRemote.cpp @@ -1660,17 +1660,19 @@ gdb_thread->PrivateSetRegisterValue(lldb_regnum, buffer_sp->GetData()); } - // AArch64 SVE specific code below calls AArch64SVEReconfigure to update - // SVE register sizes and offsets if value of VG register has changed - // since last stop. + // AArch64 SVE/SME specific code below updates SVE and ZA register sizes and + // offsets if value of VG or SVG registers has changed since last stop. const ArchSpec &arch = GetTarget().GetArchitecture(); if (arch.IsValid() && arch.GetTriple().isAArch64()) { GDBRemoteRegisterContext *reg_ctx_sp = static_cast<GDBRemoteRegisterContext *>( gdb_thread->GetRegisterContext().get()); - if (reg_ctx_sp) + if (reg_ctx_sp) { reg_ctx_sp->AArch64SVEReconfigure(); + reg_ctx_sp->AArch64SMEReconfigure(); + reg_ctx_sp->InvalidateAllRegisters(); + } } thread_sp->SetName(thread_name.empty() ? nullptr : thread_name.c_str()); diff --git a/lldb/source/Target/DynamicRegisterInfo.cpp b/lldb/source/Target/DynamicRegisterInfo.cpp --- a/lldb/source/Target/DynamicRegisterInfo.cpp +++ b/lldb/source/Target/DynamicRegisterInfo.cpp @@ -614,10 +614,11 @@ ConfigureOffsets(); // Check if register info is reconfigurable - // AArch64 SVE register set has configurable register sizes + // AArch64 SVE register set has configurable register sizes, as does the ZA + // register that SME added (the streaming state of SME reuses the SVE state). 
if (arch.GetTriple().isAArch64()) { for (const auto &reg : m_regs) { - if (strcmp(reg.name, "vg") == 0) { + if ((strcmp(reg.name, "vg") == 0) || (strcmp(reg.name, "svg") == 0)) { m_is_reconfigurable = true; break; } diff --git a/lldb/test/API/commands/register/register/aarch64_dynamic_regset/TestArm64DynamicRegsets.py b/lldb/test/API/commands/register/register/aarch64_dynamic_regset/TestArm64DynamicRegsets.py --- a/lldb/test/API/commands/register/register/aarch64_dynamic_regset/TestArm64DynamicRegsets.py +++ b/lldb/test/API/commands/register/register/aarch64_dynamic_regset/TestArm64DynamicRegsets.py @@ -70,15 +70,14 @@ self.runCmd("register write ffr " + "'" + p_regs_value + "'") self.expect("register read ffr", substrs=[p_regs_value]) - @no_debug_info_test - @skipIf(archs=no_match(["aarch64"])) - @skipIf(oslist=no_match(["linux"])) - def test_aarch64_dynamic_regset_config(self): - """Test AArch64 Dynamic Register sets configuration.""" + + def setup_register_config_test(self, run_args=None): self.build() self.line = line_number("main.c", "// Set a break point here.") exe = self.getBuildArtifact("a.out") + if run_args is not None: + self.runCmd("settings set target.run-args " + run_args) self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET) lldbutil.run_break_set_by_file_and_line( @@ -97,7 +96,16 @@ thread = process.GetThreadAtIndex(0) currentFrame = thread.GetFrameAtIndex(0) - for registerSet in currentFrame.GetRegisters(): + return currentFrame.GetRegisters() + + @no_debug_info_test + @skipIf(archs=no_match(["aarch64"])) + @skipIf(oslist=no_match(["linux"])) + def test_aarch64_dynamic_regset_config(self): + """Test AArch64 Dynamic Register sets configuration.""" + register_sets = self.setup_register_config_test() + + for registerSet in register_sets: if "Scalable Vector Extension Registers" in registerSet.GetName(): self.assertTrue( self.isAArch64SVE(), @@ -120,6 +128,20 @@ ) self.expect("register read data_mask", substrs=["data_mask = 0x"]) 
self.expect("register read 
code_mask", substrs=["code_mask = 0x"]) + if "Scalable Matrix Extension Registers" in registerSet.GetName(): + self.assertTrue(self.isAArch64SME(), + "LLDB Enabled SME register set when it was disabled by target") + if "Scalable Matrix Array Storage Registers" in registerSet.GetName(): + self.assertTrue(self.isAArch64SME(), + "LLDB Enabled SME array storage register set when it was disabled by target.") + + def make_za_value(self, vl, generator): + # Generate a vector value string "{0x00 0x01....}". + rows = [] + for row in range(vl): + byte = "0x{:02x}".format(generator(row)) + rows.append(" ".join([byte]*vl)) + return "{" + " ".join(rows) + "}" @no_debug_info_test @skipIf(archs=no_match(["aarch64"])) @@ -130,32 +152,58 @@ if not self.isAArch64SME(): self.skipTest("SME must be present.") - self.build() - self.line = line_number("main.c", "// Set a break point here.") - - exe = self.getBuildArtifact("a.out") - self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET) - - lldbutil.run_break_set_by_file_and_line( - self, "main.c", self.line, num_expected_locations=1 - ) - self.runCmd("settings set target.run-args sme") - self.runCmd("run", RUN_SUCCEEDED) - - self.expect( - "thread backtrace", - STOPPED_DUE_TO_BREAKPOINT, - substrs=["stop reason = breakpoint 1."], - ) - - target = self.dbg.GetSelectedTarget() - process = target.GetProcess() - thread = process.GetThreadAtIndex(0) - currentFrame = thread.GetFrameAtIndex(0) - - register_sets = currentFrame.GetRegisters() + register_sets = self.setup_register_config_test("sme") ssve_registers = register_sets.GetFirstValueByName( "Scalable Vector Extension Registers") self.assertTrue(ssve_registers.IsValid()) self.sve_regs_read_dynamic(ssve_registers) + + za_register = register_sets.GetFirstValueByName( + "Scalable Matrix Array Storage Registers") + self.assertTrue(za_register.IsValid()) + vg = ssve_registers.GetChildMemberWithName("vg").GetValueAsUnsigned() + vl = vg * 8 + # When first enabled it is all 0s. 
+ self.expect("register read za", substrs=[self.make_za_value(vl, lambda r: 0)]) + za_value = self.make_za_value(vl, lambda r:r+1) + self.runCmd("register write za '{}'".format(za_value)) + self.expect("register read za", substrs=[za_value]) + + # SVG should match VG because we're in streaming mode. + sme_registers = register_sets.GetFirstValueByName( + "Scalable Matrix Extension Registers") + self.assertTrue(sme_registers.IsValid()) + svg = sme_registers.GetChildMemberWithName("svg").GetValueAsUnsigned() + self.assertEqual(vg, svg) + + @no_debug_info_test + @skipIf(archs=no_match(["aarch64"])) + @skipIf(oslist=no_match(["linux"])) + def test_aarch64_dynamic_regset_config_sme_za_disabled(self): + """Test that ZA shows as 0s when disabled and can be enabled by writing + to it.""" + if not self.isAArch64SME(): + self.skipTest("SME must be present.") + + # No argument, so ZA will be disabled when we break. + register_sets = self.setup_register_config_test() + + # vg is the non-streaming vg as we are in non-streaming mode, so we need + # to use svg. + sme_registers = register_sets.GetFirstValueByName( + "Scalable Matrix Extension Registers") + self.assertTrue(sme_registers.IsValid()) + svg = sme_registers.GetChildMemberWithName("svg").GetValueAsUnsigned() + + za_register = register_sets.GetFirstValueByName( + "Scalable Matrix Array Storage Registers") + self.assertTrue(za_register.IsValid()) + svl = svg * 8 + # A disabled ZA is shown as all 0s. + self.expect("register read za", substrs=[self.make_za_value(svl, lambda r: 0)]) + za_value = self.make_za_value(svl, lambda r:r+1) + # Writing to it enables ZA, so the value should be there when we read + # it back. 
+ self.runCmd("register write za '{}'".format(za_value)) + self.expect("register read za", substrs=[za_value]) diff --git a/lldb/test/API/commands/register/register/aarch64_sve_registers/rw_access_dynamic_resize/TestSVEThreadedDynamic.py b/lldb/test/API/commands/register/register/aarch64_sve_registers/rw_access_dynamic_resize/TestSVEThreadedDynamic.py --- a/lldb/test/API/commands/register/register/aarch64_sve_registers/rw_access_dynamic_resize/TestSVEThreadedDynamic.py +++ b/lldb/test/API/commands/register/register/aarch64_sve_registers/rw_access_dynamic_resize/TestSVEThreadedDynamic.py @@ -98,6 +98,12 @@ self.expect("register read ffr", substrs=[p_regs_value]) + def build_for_mode(self, mode): + cflags = "-march=armv8-a+sve -lpthread" + if mode == Mode.SSVE: + cflags += " -DUSE_SSVE" + self.build(dictionary={"CFLAGS_EXTRAS": cflags}) + def run_sve_test(self, mode): if (mode == Mode.SVE) and not self.isAArch64SVE(): self.skipTest("SVE registers must be supported.") @@ -105,12 +111,8 @@ if (mode == Mode.SSVE) and not self.isAArch64SME(): self.skipTest("Streaming SVE registers must be supported.") - cflags = "-march=armv8-a+sve -lpthread" - if mode == Mode.SSVE: - cflags += " -DUSE_SSVE" - self.build(dictionary={"CFLAGS_EXTRAS": cflags}) + self.build_for_mode(mode) - self.build() supported_vg = self.get_supported_vg() if not (2 in supported_vg and 4 in supported_vg): @@ -200,3 +202,95 @@ def test_ssve_registers_dynamic_config(self): """Test AArch64 SSVE registers multi-threaded dynamic resize.""" self.run_sve_test(Mode.SSVE) + + def setup_svg_test(self, mode): + # Even when running in SVE mode, we need access to SVG for these tests. 
+ if not self.isAArch64SME(): + self.skipTest("Streaming SVE registers must be present.") + + self.build_for_mode(mode) + + supported_vg = self.get_supported_vg() + + main_thread_stop_line = line_number("main.c", "// Break in main thread") + lldbutil.run_break_set_by_file_and_line(self, "main.c", main_thread_stop_line) + + self.runCmd("run", RUN_SUCCEEDED) + + self.expect( + "thread info 1", + STOPPED_DUE_TO_BREAKPOINT, + substrs=["stop reason = breakpoint"], + ) + + target = self.dbg.GetSelectedTarget() + process = target.GetProcess() + + return process, supported_vg + + def read_reg(self, process, regset, reg): + registerSets = process.GetThreadAtIndex(0).GetFrameAtIndex(0).GetRegisters() + sve_registers = registerSets.GetFirstValueByName(regset) + return sve_registers.GetChildMemberWithName(reg).GetValueAsUnsigned() + + def read_vg(self, process): + return self.read_reg(process, "Scalable Vector Extension Registers", "vg") + + def read_svg(self, process): + return self.read_reg(process, "Scalable Matrix Extension Registers", "svg") + + def do_svg_test(self, process, vgs, expected_svgs): + for vg, svg in zip(vgs, expected_svgs): + self.runCmd("register write vg {}".format(vg)) + self.assertEqual(svg, self.read_svg(process)) + + @no_debug_info_test + @skipIf(archs=no_match(["aarch64"])) + @skipIf(oslist=no_match(["linux"])) + def test_svg_sve_mode(self): + """ When in SVE mode, svg should remain constant as we change vg. """ + process, supported_vg = self.setup_svg_test(Mode.SVE) + svg = self.read_svg(process) + self.do_svg_test(process, supported_vg, [svg]*len(supported_vg)) + + @no_debug_info_test + @skipIf(archs=no_match(["aarch64"])) + @skipIf(oslist=no_match(["linux"])) + def test_svg_ssve_mode(self): + """ When in SSVE mode, changing vg should change svg to the same value. 
""" + process, supported_vg = self.setup_svg_test(Mode.SSVE) + self.do_svg_test(process, supported_vg, supported_vg) + + @no_debug_info_test + @skipIf(archs=no_match(["aarch64"])) + @skipIf(oslist=no_match(["linux"])) + def test_sme_not_present(self): + """ When there is no SME, we should not show the SME register sets.""" + if self.isAArch64SME(): + self.skipTest("Streaming SVE registers must not be present.") + + self.build_for_mode(Mode.SVE) + + exe = self.getBuildArtifact("a.out") + self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET) + + # This test may run on a non-sve system, but we'll stop before any + # SVE instruction would be run. + self.runCmd("b main") + self.runCmd("run", RUN_SUCCEEDED) + + self.expect( + "thread info 1", + STOPPED_DUE_TO_BREAKPOINT, + substrs=["stop reason = breakpoint"], + ) + + target = self.dbg.GetSelectedTarget() + process = target.GetProcess() + + registerSets = process.GetThreadAtIndex(0).GetFrameAtIndex(0).GetRegisters() + sme_registers = registerSets.GetFirstValueByName("Scalable Matrix Extension Registers") + self.assertFalse(sme_registers.IsValid()) + + za = registerSets.GetFirstValueByName("Scalable Matrix Array Storage Registers") + self.assertFalse(za.IsValid()) diff --git a/lldb/test/API/commands/register/register/aarch64_za_reg/za_dynamic_resize/Makefile b/lldb/test/API/commands/register/register/aarch64_za_reg/za_dynamic_resize/Makefile new file mode 100644 --- /dev/null +++ b/lldb/test/API/commands/register/register/aarch64_za_reg/za_dynamic_resize/Makefile @@ -0,0 +1,5 @@ +C_SOURCES := main.c + +CFLAGS_EXTRAS := -march=armv8-a+sve+sme -lpthread + +include Makefile.rules diff --git a/lldb/test/API/commands/register/register/aarch64_sve_registers/rw_access_dynamic_resize/TestSVEThreadedDynamic.py b/lldb/test/API/commands/register/register/aarch64_za_reg/za_dynamic_resize/TestZAThreadedDynamic.py copy from lldb/test/API/commands/register/register/aarch64_sve_registers/rw_access_dynamic_resize/TestSVEThreadedDynamic.py 
copy to lldb/test/API/commands/register/register/aarch64_za_reg/za_dynamic_resize/TestZAThreadedDynamic.py --- a/lldb/test/API/commands/register/register/aarch64_sve_registers/rw_access_dynamic_resize/TestSVEThreadedDynamic.py +++ b/lldb/test/API/commands/register/register/aarch64_za_reg/za_dynamic_resize/TestZAThreadedDynamic.py @@ -1,11 +1,6 @@ """ -Test the AArch64 SVE and Streaming SVE (SSVE) registers dynamic resize with +Test the AArch64 SME Array Storage (ZA) register dynamic resize with multiple threads. - -This test assumes a minimum supported vector length (VL) of 256 bits -and will test 512 bits if possible. We refer to "vg" which is the -register shown in lldb. This is in units of 64 bits. 256 bit VL is -the same as a vg of 4. """ from enum import Enum @@ -15,21 +10,15 @@ from lldbsuite.test import lldbutil -class Mode(Enum): - SVE = 0 - SSVE = 1 - - -class RegisterCommandsTestCase(TestBase): +class AArch64ZAThreadedTestCase(TestBase): def get_supported_vg(self): - # Changing VL trashes the register state, so we need to run the program - # just to test this. Then run it again for the test. exe = self.getBuildArtifact("a.out") self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET) main_thread_stop_line = line_number("main.c", "// Break in main thread") lldbutil.run_break_set_by_file_and_line(self, "main.c", main_thread_stop_line) + self.runCmd("settings set target.run-args 0") self.runCmd("run", RUN_SUCCEEDED) self.expect( @@ -38,7 +27,6 @@ substrs=["stop reason = breakpoint"], ) - # Write back the current vg to confirm read/write works at all. 
current_vg = self.match("register read vg", ["(0x[0-9]+)"]) self.assertTrue(current_vg is not None) self.expect("register write vg {}".format(current_vg.group())) @@ -57,64 +45,36 @@ return supported_vg - def check_sve_registers(self, vg_test_value): - z_reg_size = vg_test_value * 8 - p_reg_size = int(z_reg_size / 8) - - p_value_bytes = ["0xff", "0x55", "0x11", "0x01", "0x00"] - - for i in range(32): - s_reg_value = "s%i = 0x" % (i) + "".join( - "{:02x}".format(i + 1) for _ in range(4) - ) - - d_reg_value = "d%i = 0x" % (i) + "".join( - "{:02x}".format(i + 1) for _ in range(8) - ) - - v_reg_value = "v%i = 0x" % (i) + "".join( - "{:02x}".format(i + 1) for _ in range(16) - ) - - z_reg_value = ( - "{" - + " ".join("0x{:02x}".format(i + 1) for _ in range(z_reg_size)) - + "}" - ) - - self.expect("register read -f hex " + "s%i" % (i), substrs=[s_reg_value]) + def gen_za_value(self, svg, value_generator): + svl = svg*8 - self.expect("register read -f hex " + "d%i" % (i), substrs=[d_reg_value]) + rows = [] + for row in range(svl): + byte = "0x{:02x}".format(value_generator(row)) + rows.append(" ".join([byte]*svl)) - self.expect("register read -f hex " + "v%i" % (i), substrs=[v_reg_value]) + return "{" + " ".join(rows) + "}" - self.expect("register read " + "z%i" % (i), substrs=[z_reg_value]) + def check_za_register(self, svg, value_offset): + self.expect("register read za", substrs=[ + self.gen_za_value(svg, lambda r: r+value_offset)]) - for i in range(16): - p_regs_value = ( - "{" + " ".join(p_value_bytes[i % 5] for _ in range(p_reg_size)) + "}" - ) - self.expect("register read " + "p%i" % (i), substrs=[p_regs_value]) + def check_disabled_za_register(self, svg): + self.expect("register read za", substrs=[ + self.gen_za_value(svg, lambda r: 0)]) - self.expect("register read ffr", substrs=[p_regs_value]) - - def run_sve_test(self, mode): - if (mode == Mode.SVE) and not self.isAArch64SVE(): - self.skipTest("SVE registers must be supported.") - - if (mode == Mode.SSVE) and 
not self.isAArch64SME(): - self.skipTest("Streaming SVE registers must be supported.") - - cflags = "-march=armv8-a+sve -lpthread" - if mode == Mode.SSVE: - cflags += " -DUSE_SSVE" - self.build(dictionary={"CFLAGS_EXTRAS": cflags}) + def za_test_impl(self, enable_za): + if not self.isAArch64SME(): + self.skipTest("SME must be present.") self.build() supported_vg = self.get_supported_vg() + self.runCmd("settings set target.run-args {}".format( + '1' if enable_za else '0')) + if not (2 in supported_vg and 4 in supported_vg): - self.skipTest("Not all required SVE vector lengths are supported.") + self.skipTest("Not all required streaming vector lengths are supported.") main_thread_stop_line = line_number("main.c", "// Break in main thread") lldbutil.run_break_set_by_file_and_line(self, "main.c", main_thread_stop_line) @@ -133,8 +93,6 @@ self.runCmd("run", RUN_SUCCEEDED) - process = self.dbg.GetSelectedTarget().GetProcess() - self.expect( "thread info 1", STOPPED_DUE_TO_BREAKPOINT, @@ -142,16 +100,19 @@ ) if 8 in supported_vg: - self.check_sve_registers(8) + if enable_za: + self.check_za_register(8, 1) + else: + self.check_disabled_za_register(8) else: - self.check_sve_registers(4) + if enable_za: + self.check_za_register(4, 1) + else: + self.check_disabled_za_register(4) self.runCmd("process continue", RUN_SUCCEEDED) - # If we start the checks too quickly, thread 3 may not have started. 
- while process.GetNumThreads() < 3: - pass - + process = self.dbg.GetSelectedTarget().GetProcess() for idx in range(1, process.GetNumThreads()): thread = process.GetThreadAtIndex(idx) if thread.GetStopReason() != lldb.eStopReasonBreakpoint: @@ -162,12 +123,12 @@ if stopped_at_line_number == thX_break_line1: self.runCmd("thread select %d" % (idx + 1)) - self.check_sve_registers(4) + self.check_za_register(4, 2) self.runCmd("register write vg 2") elif stopped_at_line_number == thY_break_line1: self.runCmd("thread select %d" % (idx + 1)) - self.check_sve_registers(2) + self.check_za_register(2, 3) self.runCmd("register write vg 4") self.runCmd("thread continue 2") @@ -181,22 +142,24 @@ if stopped_at_line_number == thX_break_line2: self.runCmd("thread select %d" % (idx + 1)) - self.check_sve_registers(2) + self.check_za_register(2, 2) elif stopped_at_line_number == thY_break_line2: self.runCmd("thread select %d" % (idx + 1)) - self.check_sve_registers(4) + self.check_za_register(4, 3) @no_debug_info_test @skipIf(archs=no_match(["aarch64"])) @skipIf(oslist=no_match(["linux"])) - def test_sve_registers_dynamic_config(self): - """Test AArch64 SVE registers multi-threaded dynamic resize.""" - self.run_sve_test(Mode.SVE) + def test_za_register_dynamic_config_main_enabled(self): + """ Test multiple threads resizing ZA, with the main thread's ZA + enabled.""" + self.za_test_impl(True) @no_debug_info_test @skipIf(archs=no_match(["aarch64"])) @skipIf(oslist=no_match(["linux"])) - def test_ssve_registers_dynamic_config(self): - """Test AArch64 SSVE registers multi-threaded dynamic resize.""" - self.run_sve_test(Mode.SSVE) + def test_za_register_dynamic_config_main_disabled(self): + """ Test multiple threads resizing ZA, with the main thread's ZA + disabled.""" + self.za_test_impl(False) \ No newline at end of file diff --git a/lldb/test/API/commands/register/register/aarch64_za_reg/za_dynamic_resize/main.c 
b/lldb/test/API/commands/register/register/aarch64_za_reg/za_dynamic_resize/main.c new file mode 100644 --- /dev/null +++ b/lldb/test/API/commands/register/register/aarch64_za_reg/za_dynamic_resize/main.c @@ -0,0 +1,102 @@ +#include <pthread.h> +#include <stdatomic.h> +#include <stdbool.h> +#include <stdint.h> +#include <string.h> +#include <sys/prctl.h> + +// Important notes for this test: +// * Making a syscall will disable streaming mode. +// * LLDB writing to vg while in streaming mode will disable ZA +// (this is just how ptrace works). +// * Writing to an inactive ZA produces a SIGILL. + +#ifndef PR_SME_SET_VL +#define PR_SME_SET_VL 63 +#endif + +#define SM_INST(c) asm volatile("msr s0_3_c4_c" #c "_3, xzr") +#define SMSTART_SM SM_INST(3) +#define SMSTART_ZA SM_INST(5) + +void set_za_register(int svl, int value_offset) { +#define MAX_VL_BYTES 256 + uint8_t data[MAX_VL_BYTES]; + + // ldr za will actually wrap the selected vector row, by the number of rows + // you have. So setting one that didn't exist would actually set one that did. + // That's why we need the streaming vector length here. + for (int i = 0; i < svl; ++i) { + memset(data, i + value_offset, MAX_VL_BYTES); + // Each one of these loads a VL sized row of ZA. + asm volatile("mov w12, %w0\n\t" + "ldr za[w12, 0], [%1]\n\t" ::"r"(i), + "r"(&data) + : "w12"); + } +} + +// These are used to make sure we only break in each thread once both of the +// threads have been started. Otherwise when the test does "process continue" +// it could stop in one thread and wait forever for the other one to start. 
+atomic_bool threadX_ready = false; +atomic_bool threadY_ready = false; + +void *threadX_func(void *x_arg) { + threadX_ready = true; + while (!threadY_ready) { + } + + prctl(PR_SME_SET_VL, 8 * 4); + SMSTART_SM; + SMSTART_ZA; + set_za_register(8 * 4, 2); + SMSTART_ZA; // Thread X breakpoint 1 + set_za_register(8 * 2, 2); + return NULL; // Thread X breakpoint 2 +} + +void *threadY_func(void *y_arg) { + threadY_ready = true; + while (!threadX_ready) { + } + + prctl(PR_SME_SET_VL, 8 * 2); + SMSTART_SM; + SMSTART_ZA; + set_za_register(8 * 2, 3); + SMSTART_ZA; // Thread Y breakpoint 1 + set_za_register(8 * 4, 3); + return NULL; // Thread Y breakpoint 2 +} + +int main(int argc, char *argv[]) { + // Expecting argument to tell us whether to enable ZA on the main thread. + if (argc != 2) + return 1; + + prctl(PR_SME_SET_VL, 8 * 8); + SMSTART_SM; + + if (argv[1][0] == '1') { + SMSTART_ZA; + set_za_register(8 * 8, 1); + } + // else we do not enable ZA and lldb will show 0s for it. + + pthread_t x_thread; + if (pthread_create(&x_thread, NULL, threadX_func, 0)) // Break in main thread + return 1; + + pthread_t y_thread; + if (pthread_create(&y_thread, NULL, threadY_func, 0)) + return 1; + + if (pthread_join(x_thread, NULL)) + return 2; + + if (pthread_join(y_thread, NULL)) + return 2; + + return 0; +} diff --git a/lldb/test/API/commands/register/register/aarch64_za_reg/za_save_restore/Makefile b/lldb/test/API/commands/register/register/aarch64_za_reg/za_save_restore/Makefile new file mode 100644 --- /dev/null +++ b/lldb/test/API/commands/register/register/aarch64_za_reg/za_save_restore/Makefile @@ -0,0 +1,5 @@ +C_SOURCES := main.c + +CFLAGS_EXTRAS := -march=armv8-a+sve+sme + +include Makefile.rules diff --git a/lldb/test/API/commands/register/register/aarch64_za_reg/za_save_restore/TestZARegisterSaveRestore.py b/lldb/test/API/commands/register/register/aarch64_za_reg/za_save_restore/TestZARegisterSaveRestore.py new file mode 100644 --- /dev/null +++ 
b/lldb/test/API/commands/register/register/aarch64_za_reg/za_save_restore/TestZARegisterSaveRestore.py @@ -0,0 +1,237 @@ +""" +Test the AArch64 SME ZA register is saved and restored around expressions. + +This attempts to cover expressions that change the following: +* ZA enabled or not. +* Streaming mode or not. +* Streaming vector length (increasing and decreasing). +* Some combinations of the above. +""" + +from enum import IntEnum +import lldb +from lldbsuite.test.decorators import * +from lldbsuite.test.lldbtest import * +from lldbsuite.test import lldbutil + + +# These enum values match the flag values used in the test program. +class Mode(IntEnum): + SVE = 0 + SSVE = 1 + + +class ZA(IntEnum): + Disabled = 0 + Enabled = 1 + + +class AArch64ZATestCase(TestBase): + def get_supported_svg(self): + # Always build this probe program to start as streaming SVE. + # We will read/write "vg" here but since we are in streaming mode "svg" + # is really what we are writing ("svg" is a read only pseudo). + self.build() + + exe = self.getBuildArtifact("a.out") + self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET) + # Enter streaming mode, don't enable ZA, start_vl and other_vl don't + # matter here. + self.runCmd("settings set target.run-args 1 0 0 0") + + stop_line = line_number("main.c", "// Set a break point here.") + lldbutil.run_break_set_by_file_and_line(self, "main.c", stop_line, + num_expected_locations=1) + + self.runCmd("run", RUN_SUCCEEDED) + + self.expect( + "thread info 1", + STOPPED_DUE_TO_BREAKPOINT, + substrs=["stop reason = breakpoint"], + ) + + # Write back the current vg to confirm read/write works at all. + current_svg = self.match("register read vg", ["(0x[0-9]+)"]) + self.assertTrue(current_svg is not None) + self.expect("register write vg {}".format(current_svg.group())) + + # Aka 128, 256 and 512 bit. 
+ supported_svg = [] + for svg in [2, 4, 8]: + # This could mask other errors but writing vg is tested elsewhere + # so we assume the hardware rejected the value. + self.runCmd("register write vg {}".format(svg), check=False) + if not self.res.GetError(): + supported_svg.append(svg) + + self.runCmd("breakpoint delete 1") + self.runCmd("continue") + + return supported_svg + + def read_vg(self): + process = self.dbg.GetSelectedTarget().GetProcess() + registerSets = process.GetThreadAtIndex(0).GetFrameAtIndex(0).GetRegisters() + sve_registers = registerSets.GetFirstValueByName("Scalable Vector Extension Registers") + return sve_registers.GetChildMemberWithName("vg").GetValueAsUnsigned() + + def read_svg(self): + process = self.dbg.GetSelectedTarget().GetProcess() + registerSets = process.GetThreadAtIndex(0).GetFrameAtIndex(0).GetRegisters() + sve_registers = registerSets.GetFirstValueByName("Scalable Matrix Extension Registers") + return sve_registers.GetChildMemberWithName("svg").GetValueAsUnsigned() + + def make_za_value(self, vl, generator): + # Generate a vector value string "{0x00 0x01....}". + rows = [] + for row in range(vl): + byte = "0x{:02x}".format(generator(row)) + rows.append(" ".join([byte]*vl)) + return "{" + " ".join(rows) + "}" + + def check_za(self, vl): + # We expect an increasing value starting at 1. Row 0=1, row 1 = 2, etc. + self.expect("register read za", substrs=[ + self.make_za_value(vl, lambda row: row+1)]) + + def check_za_disabled(self, vl): + # When ZA is disabled, lldb will show ZA as all 0s. + self.expect("register read za", substrs=[ + self.make_za_value(vl, lambda row: 0)]) + + def za_expr_test_impl(self, sve_mode, za_state, swap_start_vl): + if not self.isAArch64SME(): + self.skipTest("SME must be present.") + + supported_svg = self.get_supported_svg() + if len(supported_svg) < 2: + self.skipTest("Target must support at least 2 streaming vector lengths.") + + # vg is in units of 8 bytes. 
+ start_vl = supported_svg[0] * 8 + other_vl = supported_svg[1] * 8 + + if swap_start_vl: + start_vl, other_vl = other_vl, start_vl + + self.line = line_number("main.c", "// Set a break point here.") + + exe = self.getBuildArtifact("a.out") + self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET) + self.runCmd("settings set target.run-args {} {} {} {}".format(sve_mode, + za_state, start_vl, other_vl)) + + lldbutil.run_break_set_by_file_and_line( + self, "main.c", self.line, num_expected_locations=1 + ) + self.runCmd("run", RUN_SUCCEEDED) + + self.expect( + "thread backtrace", + STOPPED_DUE_TO_BREAKPOINT, + substrs=["stop reason = breakpoint 1."], + ) + + exprs = ["expr_disable_za", "expr_enable_za", "expr_start_vl", + "expr_other_vl", "expr_enable_sm", "expr_disable_sm"] + + # This may be the streaming or non-streaming vg. All that matters is + # that it is saved and restored, remaining constant throughout. + start_vg = self.read_vg() + + # Check SVE registers to make sure that combination of scaling SVE + # and scaling ZA works properly. This is a brittle check, but failures + # are likely to be catastrophic when they do happen anyway. + sve_reg_names = "ffr {} {}".format( + " ".join(["z{}".format(n) for n in range(32)]), + " ".join(["p{}".format(n) for n in range(16)])) + self.runCmd("register read " + sve_reg_names) + sve_values = self.res.GetOutput() + + def check_regs(): + if za_state == ZA.Enabled: + self.check_za(start_vl) + else: + self.check_za_disabled(start_vl) + + # svg and vg are in units of 8 bytes. + self.assertEqual(start_vl, self.read_svg()*8) + self.assertEqual(start_vg, self.read_vg()) + + self.expect("register read " + sve_reg_names, substrs=[sve_values]) + + for expr in exprs: + expr_cmd = "expression {}()".format(expr) + + # We do this twice because there were issues in development where + # using data stored by a previous WriteAllRegisterValues would crash + # the second time around. 
+ self.runCmd(expr_cmd) + check_regs() + self.runCmd(expr_cmd) + check_regs() + + # Run them in sequence to make sure there is no state lingering between + # them after a restore. + for expr in exprs: + self.runCmd("expression {}()".format(expr)) + check_regs() + + for expr in reversed(exprs): + self.runCmd("expression {}()".format(expr)) + check_regs() + + # These tests start with the 1st supported SVL and change to the 2nd + # supported SVL as needed. + + @no_debug_info_test + @skipIf(archs=no_match(["aarch64"])) + @skipIf(oslist=no_match(["linux"])) + def test_za_expr_ssve_za_enabled(self): + self.za_expr_test_impl(Mode.SSVE, ZA.Enabled, False) + + @no_debug_info_test + @skipIf(archs=no_match(["aarch64"])) + @skipIf(oslist=no_match(["linux"])) + def test_za_expr_ssve_za_disabled(self): + self.za_expr_test_impl(Mode.SSVE, ZA.Disabled, False) + + @no_debug_info_test + @skipIf(archs=no_match(["aarch64"])) + @skipIf(oslist=no_match(["linux"])) + def test_za_expr_sve_za_enabled(self): + self.za_expr_test_impl(Mode.SVE, ZA.Enabled, False) + + @no_debug_info_test + @skipIf(archs=no_match(["aarch64"])) + @skipIf(oslist=no_match(["linux"])) + def test_za_expr_sve_za_disabled(self): + self.za_expr_test_impl(Mode.SVE, ZA.Disabled, False) + + # These tests start in the 2nd supported SVL and change to the 1st supported + # SVL as needed. 
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_za_expr_ssve_za_enabled_different_vl(self):
+        self.za_expr_test_impl(Mode.SSVE, ZA.Enabled, True)
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_za_expr_ssve_za_disabled_different_vl(self):
+        self.za_expr_test_impl(Mode.SSVE, ZA.Disabled, True)
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_za_expr_sve_za_enabled_different_vl(self):
+        self.za_expr_test_impl(Mode.SVE, ZA.Enabled, True)
+
+    @no_debug_info_test
+    @skipIf(archs=no_match(["aarch64"]))
+    @skipIf(oslist=no_match(["linux"]))
+    def test_za_expr_sve_za_disabled_different_vl(self):
+        self.za_expr_test_impl(Mode.SVE, ZA.Disabled, True)
diff --git a/lldb/test/API/commands/register/register/aarch64_za_reg/za_save_restore/main.c b/lldb/test/API/commands/register/register/aarch64_za_reg/za_save_restore/main.c
new file mode 100644
--- /dev/null
+++ b/lldb/test/API/commands/register/register/aarch64_za_reg/za_save_restore/main.c
@@ -0,0 +1,258 @@
+#include <stdbool.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/prctl.h>
+
+// Important details for this program:
+// * Making a syscall will disable streaming mode if it is active.
+// * Changing the vector length will make streaming mode and ZA inactive.
+// * ZA can be active independent of streaming mode.
+// * ZA's size is the streaming vector length squared.
+
+// Only define the SME prctl constants if the included headers did not.
+#ifndef PR_SME_SET_VL
+#define PR_SME_SET_VL 63
+#endif
+
+#ifndef PR_SME_GET_VL
+#define PR_SME_GET_VL 64
+#endif
+
+#ifndef PR_SME_VL_LEN_MASK
+#define PR_SME_VL_LEN_MASK 0xffff
+#endif
+
+// Every SMSTART/SMSTOP variant below is a write of xzr to the system register
+// s0_3_c4_c<N>_3; only <N> differs. Presumably the raw register name is used
+// so that the file builds without assembler support for the SME mnemonics.
+#define SM_INST(c) asm volatile("msr s0_3_c4_c" #c "_3, xzr")
+#define SMSTART SM_INST(7)
+#define SMSTART_SM SM_INST(3)
+#define SMSTART_ZA SM_INST(5)
+#define SMSTOP SM_INST(6)
+#define SMSTOP_SM SM_INST(2)
+#define SMSTOP_ZA SM_INST(4)
+
+// Vector lengths, set from the command line in main() and read by the expr_*
+// functions below.
+int start_vl = 0;
+int other_vl = 0;
+
+// Give FFR, P0-P15 and Z0-Z31 a known pattern: p4, p9 and p14 are all-false,
+// the other predicates are all-true at varying element sizes, and every byte
+// of Zn is set to n + 1 (the governing predicates p0/p5/p10/p15 are all-true).
+void write_sve_regs() {
+  // We assume the smefa64 feature is present, which allows ffr access
+  // in streaming mode.
+  asm volatile("setffr\n\t");
+  asm volatile("ptrue p0.b\n\t");
+  asm volatile("ptrue p1.h\n\t");
+  asm volatile("ptrue p2.s\n\t");
+  asm volatile("ptrue p3.d\n\t");
+  asm volatile("pfalse p4.b\n\t");
+  asm volatile("ptrue p5.b\n\t");
+  asm volatile("ptrue p6.h\n\t");
+  asm volatile("ptrue p7.s\n\t");
+  asm volatile("ptrue p8.d\n\t");
+  asm volatile("pfalse p9.b\n\t");
+  asm volatile("ptrue p10.b\n\t");
+  asm volatile("ptrue p11.h\n\t");
+  asm volatile("ptrue p12.s\n\t");
+  asm volatile("ptrue p13.d\n\t");
+  asm volatile("pfalse p14.b\n\t");
+  asm volatile("ptrue p15.b\n\t");
+
+  asm volatile("cpy z0.b, p0/z, #1\n\t");
+  asm volatile("cpy z1.b, p5/z, #2\n\t");
+  asm volatile("cpy z2.b, p10/z, #3\n\t");
+  asm volatile("cpy z3.b, p15/z, #4\n\t");
+  asm volatile("cpy z4.b, p0/z, #5\n\t");
+  asm volatile("cpy z5.b, p5/z, #6\n\t");
+  asm volatile("cpy z6.b, p10/z, #7\n\t");
+  asm volatile("cpy z7.b, p15/z, #8\n\t");
+  asm volatile("cpy z8.b, p0/z, #9\n\t");
+  asm volatile("cpy z9.b, p5/z, #10\n\t");
+  asm volatile("cpy z10.b, p10/z, #11\n\t");
+  asm volatile("cpy z11.b, p15/z, #12\n\t");
+  asm volatile("cpy z12.b, p0/z, #13\n\t");
+  asm volatile("cpy z13.b, p5/z, #14\n\t");
+  asm volatile("cpy z14.b, p10/z, #15\n\t");
+  asm volatile("cpy z15.b, p15/z, #16\n\t");
+  asm volatile("cpy z16.b, p0/z, #17\n\t");
+  asm volatile("cpy z17.b, p5/z, #18\n\t");
+  asm volatile("cpy z18.b, p10/z, #19\n\t");
+  asm volatile("cpy z19.b, p15/z, #20\n\t");
+  asm volatile("cpy z20.b, p0/z, #21\n\t");
+  asm volatile("cpy z21.b, p5/z, #22\n\t");
+  asm volatile("cpy z22.b, p10/z, #23\n\t");
+  asm volatile("cpy z23.b, p15/z, #24\n\t");
+  asm volatile("cpy z24.b, p0/z, #25\n\t");
+  asm volatile("cpy z25.b, p5/z, #26\n\t");
+  asm volatile("cpy z26.b, p10/z, #27\n\t");
+  asm volatile("cpy z27.b, p15/z, #28\n\t");
+  asm volatile("cpy z28.b, p0/z, #29\n\t");
+  asm volatile("cpy z29.b, p5/z, #30\n\t");
+  asm volatile("cpy z30.b, p10/z, #31\n\t");
+  asm volatile("cpy z31.b, p15/z, #32\n\t");
+}
+
+// Write something different so we will know if we didn't restore them
+// correctly.
+// NOTE(review): p0, p5, p10 and p15 are all-false by the time the cpy
+// instructions run and those use zeroing predication (/z), so the Z registers
+// presumably end up zero rather than holding the immediates. Either way the
+// result differs from write_sve_regs() - confirm this is intended.
+void write_sve_regs_expr() {
+  asm volatile("pfalse p0.b\n\t");
+  asm volatile("wrffr p0.b\n\t");
+  asm volatile("pfalse p1.b\n\t");
+  asm volatile("pfalse p2.b\n\t");
+  asm volatile("pfalse p3.b\n\t");
+  asm volatile("ptrue p4.b\n\t");
+  asm volatile("pfalse p5.b\n\t");
+  asm volatile("pfalse p6.b\n\t");
+  asm volatile("pfalse p7.b\n\t");
+  asm volatile("pfalse p8.b\n\t");
+  asm volatile("ptrue p9.b\n\t");
+  asm volatile("pfalse p10.b\n\t");
+  asm volatile("pfalse p11.b\n\t");
+  asm volatile("pfalse p12.b\n\t");
+  asm volatile("pfalse p13.b\n\t");
+  asm volatile("ptrue p14.b\n\t");
+  asm volatile("pfalse p15.b\n\t");
+
+  asm volatile("cpy z0.b, p0/z, #2\n\t");
+  asm volatile("cpy z1.b, p5/z, #3\n\t");
+  asm volatile("cpy z2.b, p10/z, #4\n\t");
+  asm volatile("cpy z3.b, p15/z, #5\n\t");
+  asm volatile("cpy z4.b, p0/z, #6\n\t");
+  asm volatile("cpy z5.b, p5/z, #7\n\t");
+  asm volatile("cpy z6.b, p10/z, #8\n\t");
+  asm volatile("cpy z7.b, p15/z, #9\n\t");
+  asm volatile("cpy z8.b, p0/z, #10\n\t");
+  asm volatile("cpy z9.b, p5/z, #11\n\t");
+  asm volatile("cpy z10.b, p10/z, #12\n\t");
+  asm volatile("cpy z11.b, p15/z, #13\n\t");
+  asm volatile("cpy z12.b, p0/z, #14\n\t");
+  asm volatile("cpy z13.b, p5/z, #15\n\t");
+  asm volatile("cpy z14.b, p10/z, #16\n\t");
+  asm volatile("cpy z15.b, p15/z, #17\n\t");
+  asm volatile("cpy z16.b, p0/z, #18\n\t");
+  asm volatile("cpy z17.b, p5/z, #19\n\t");
+  asm volatile("cpy z18.b, p10/z, #20\n\t");
+  asm volatile("cpy z19.b, p15/z, #21\n\t");
+  asm volatile("cpy z20.b, p0/z, #22\n\t");
+  asm volatile("cpy z21.b, p5/z, #23\n\t");
+  asm volatile("cpy z22.b, p10/z, #24\n\t");
+  asm volatile("cpy z23.b, p15/z, #25\n\t");
+  asm volatile("cpy z24.b, p0/z, #26\n\t");
+  asm volatile("cpy z25.b, p5/z, #27\n\t");
+  asm volatile("cpy z26.b, p10/z, #28\n\t");
+  asm volatile("cpy z27.b, p15/z, #29\n\t");
+  asm volatile("cpy z28.b, p0/z, #30\n\t");
+  asm volatile("cpy z29.b, p5/z, #31\n\t");
+  asm volatile("cpy z30.b, p10/z, #32\n\t");
+  asm volatile("cpy z31.b, p15/z, #33\n\t");
+}
+
+// Fill ZA one row at a time: every byte of row i is set to i + value_offset
+// (as converted to unsigned char by memset). svl is the number of rows to
+// write, which callers pass as the streaming vector length.
+void set_za_register(int svl, int value_offset) {
+#define MAX_VL_BYTES 256
+  uint8_t data[MAX_VL_BYTES];
+
+  // ldr za will actually wrap the selected vector row, by the number of rows
+  // you have. So setting one that didn't exist would actually set one that did.
+  // That's why we need the streaming vector length here.
+  for (int i = 0; i < svl; ++i) {
+    memset(data, i + value_offset, MAX_VL_BYTES);
+    // Each one of these loads a VL sized row of ZA.
+    // The "memory" clobber tells the compiler that the asm reads the buffer
+    // behind the pointer operand, so the memset above can neither be dropped
+    // as a dead store nor reordered past the load.
+    asm volatile("mov w12, %w0\n\t"
+                 "ldr za[w12, 0], [%1]\n\t" ::"r"(i),
+                 "r"(data)
+                 : "w12", "memory");
+  }
+}
+
+// The expr_* functions are not called from main(); presumably the test runs
+// them through the debugger's expression evaluator. Each one changes the SME
+// state and then overwrites the SVE registers with write_sve_regs_expr().
+
+// Turn ZA off.
+void expr_disable_za() {
+  SMSTOP_ZA;
+  write_sve_regs_expr();
+}
+
+// Turn ZA on and fill it using row values offset by 2.
+void expr_enable_za() {
+  SMSTART_ZA;
+  set_za_register(start_vl, 2);
+  write_sve_regs_expr();
+}
+
+// Set the vector length to start_vl, turn ZA on and fill it (offset 4).
+void expr_start_vl() {
+  prctl(PR_SME_SET_VL, start_vl);
+  SMSTART_ZA;
+  set_za_register(start_vl, 4);
+  write_sve_regs_expr();
+}
+
+// Set the vector length to other_vl, turn ZA on and fill it (offset 5).
+void expr_other_vl() {
+  prctl(PR_SME_SET_VL, other_vl);
+  SMSTART_ZA;
+  set_za_register(other_vl, 5);
+  write_sve_regs_expr();
+}
+
+// Enter streaming mode.
+void expr_enable_sm() {
+  SMSTART_SM;
+  write_sve_regs_expr();
+}
+
+// Leave streaming mode.
+void expr_disable_sm() {
+  SMSTOP_SM;
+  write_sve_regs_expr();
+}
+
+// Puts the process in the requested SME state, fills the registers with known
+// values and then reaches the return statement marked for a breakpoint.
+int main(int argc, char *argv[]) {
+  // We expect to get:
+  // * whether to enable streaming mode
+  // * whether to enable ZA
+  // * what the starting VL should be
+  // * what the other VL should be
+  if (argc != 5)
+    return 1;
+
+  bool ssve = argv[1][0] == '1';
+  bool za = argv[2][0] == '1';
+  start_vl = atoi(argv[3]);
+  other_vl = atoi(argv[4]);
+
+  // Set the vector length before enabling anything, because changing it makes
+  // streaming mode and ZA inactive.
+  prctl(PR_SME_SET_VL, start_vl);
+
+  if (ssve)
+    SMSTART_SM;
+
+  if (za) {
+    SMSTART_ZA;
+    set_za_register(start_vl, 1);
+  }
+
+  write_sve_regs();
+
+  return 0; // Set a break point here.
+}