diff --git a/lldb/test/API/commands/register/register/aarch64_dynamic_regset/TestArm64DynamicRegsets.py b/lldb/test/API/commands/register/register/aarch64_dynamic_regset/TestArm64DynamicRegsets.py --- a/lldb/test/API/commands/register/register/aarch64_dynamic_regset/TestArm64DynamicRegsets.py +++ b/lldb/test/API/commands/register/register/aarch64_dynamic_regset/TestArm64DynamicRegsets.py @@ -70,15 +70,13 @@ self.runCmd("register write ffr " + "'" + p_regs_value + "'") self.expect("register read ffr", substrs=[p_regs_value]) - @no_debug_info_test - @skipIf(archs=no_match(["aarch64"])) - @skipIf(oslist=no_match(["linux"])) - def test_aarch64_dynamic_regset_config(self): - """Test AArch64 Dynamic Register sets configuration.""" + def setup_register_config_test(self, run_args=None): self.build() self.line = line_number("main.c", "// Set a break point here.") exe = self.getBuildArtifact("a.out") + if run_args is not None: + self.runCmd("settings set target.run-args " + run_args) self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET) lldbutil.run_break_set_by_file_and_line( @@ -92,12 +90,16 @@ substrs=["stop reason = breakpoint 1."], ) - target = self.dbg.GetSelectedTarget() - process = target.GetProcess() - thread = process.GetThreadAtIndex(0) - currentFrame = thread.GetFrameAtIndex(0) + return self.thread().GetSelectedFrame().GetRegisters() + + @no_debug_info_test + @skipIf(archs=no_match(["aarch64"])) + @skipIf(oslist=no_match(["linux"])) + def test_aarch64_dynamic_regset_config(self): + """Test AArch64 Dynamic Register sets configuration.""" + register_sets = self.setup_register_config_test() - for registerSet in currentFrame.GetRegisters(): + for registerSet in register_sets: if "Scalable Vector Extension Registers" in registerSet.GetName(): self.assertTrue( self.isAArch64SVE(), @@ -120,6 +122,19 @@ ) self.expect("register read data_mask", substrs=["data_mask = 0x"]) self.expect("register read code_mask", substrs=["code_mask = 0x"]) + if "Scalable Matrix Extension
Registers" in registerSet.GetName(): + self.assertTrue( + self.isAArch64SME(), + "LLDB Enabled SME register set when it was disabled by target", + ) + + def make_za_value(self, vl, generator): + # Generate a ZA value string "{0x.. 0x.. ...}": vl rows of vl bytes, every byte in row r being generator(r). + rows = [] + for row in range(vl): + byte = "0x{:02x}".format(generator(row)) + rows.append(" ".join([byte] * vl)) + return "{" + " ".join(rows) + "}" @no_debug_info_test @skipIf(archs=no_match(["aarch64"])) @@ -130,27 +145,57 @@ if not self.isAArch64SME(): self.skipTest("SME must be present.") - self.build() - self.line = line_number("main.c", "// Set a break point here.") + register_sets = self.setup_register_config_test("sme") - exe = self.getBuildArtifact("a.out") - self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET) + ssve_registers = register_sets.GetFirstValueByName( + "Scalable Vector Extension Registers") + self.assertTrue(ssve_registers.IsValid()) + self.sve_regs_read_dynamic(ssve_registers) - lldbutil.run_break_set_by_file_and_line( + sme_registers = register_sets.GetFirstValueByName( + "Scalable Matrix Extension Registers" ) - self.runCmd("settings set target.run-args sme") - self.runCmd("run", RUN_SUCCEEDED) + self.assertTrue(sme_registers.IsValid()) - self.expect( - "thread backtrace", - STOPPED_DUE_TO_BREAKPOINT, - substrs=["stop reason = breakpoint 1."], - ) + vg = ssve_registers.GetChildMemberWithName("vg").GetValueAsUnsigned() + vl = vg * 8 + # When first enabled it is all 0s. + self.expect("register read za", substrs=[self.make_za_value(vl, lambda r: 0)]) + za_value = self.make_za_value(vl, lambda r: r + 1) + self.runCmd("register write za '{}'".format(za_value)) + self.expect("register read za", substrs=[za_value]) - register_sets = self.thread().GetSelectedFrame().GetRegisters() + # SVG should match VG because we're in streaming mode.
- ssve_registers = register_sets.GetFirstValueByName( - "Scalable Vector Extension Registers") - self.assertTrue(ssve_registers.IsValid()) - self.sve_regs_read_dynamic(ssve_registers) + self.assertTrue(sme_registers.IsValid()) + svg = sme_registers.GetChildMemberWithName("svg").GetValueAsUnsigned() + self.assertEqual(vg, svg) + + @no_debug_info_test + @skipIf(archs=no_match(["aarch64"])) + @skipIf(oslist=no_match(["linux"])) + def test_aarch64_dynamic_regset_config_sme_za_disabled(self): + """Test that ZA shows as 0s when disabled and can be enabled by writing + to it.""" + if not self.isAArch64SME(): + self.skipTest("SME must be present.") + + # No argument, so ZA will be disabled when we break. + register_sets = self.setup_register_config_test() + + # We are not in streaming mode, so vg is the non-streaming vector length. + # ZA is sized by the streaming vector length, so we need to use svg. + sme_registers = register_sets.GetFirstValueByName( + "Scalable Matrix Extension Registers" + ) + self.assertTrue(sme_registers.IsValid()) + svg = sme_registers.GetChildMemberWithName("svg").GetValueAsUnsigned() + + svl = svg * 8 + # A disabled ZA is shown as all 0s. + self.expect("register read za", substrs=[self.make_za_value(svl, lambda r: 0)]) + za_value = self.make_za_value(svl, lambda r: r + 1) + # Writing to it enables ZA, so the value should be there when we read + # it back.
+ self.runCmd("register write za '{}'".format(za_value)) + self.expect("register read za", substrs=[za_value]) diff --git a/lldb/test/API/commands/register/register/aarch64_sve_registers/rw_access_dynamic_resize/TestSVEThreadedDynamic.py b/lldb/test/API/commands/register/register/aarch64_sve_registers/rw_access_dynamic_resize/TestSVEThreadedDynamic.py --- a/lldb/test/API/commands/register/register/aarch64_sve_registers/rw_access_dynamic_resize/TestSVEThreadedDynamic.py +++ b/lldb/test/API/commands/register/register/aarch64_sve_registers/rw_access_dynamic_resize/TestSVEThreadedDynamic.py @@ -98,6 +98,12 @@ self.expect("register read ffr", substrs=[p_regs_value]) + def build_for_mode(self, mode): + cflags = "-march=armv8-a+sve -lpthread" + if mode == Mode.SSVE: + cflags += " -DUSE_SSVE" + self.build(dictionary={"CFLAGS_EXTRAS": cflags}) + def run_sve_test(self, mode): if (mode == Mode.SVE) and not self.isAArch64SVE(): self.skipTest("SVE registers must be supported.") @@ -105,12 +111,8 @@ if (mode == Mode.SSVE) and not self.isAArch64SME(): self.skipTest("Streaming SVE registers must be supported.") - cflags = "-march=armv8-a+sve -lpthread" - if mode == Mode.SSVE: - cflags += " -DUSE_SSVE" - self.build(dictionary={"CFLAGS_EXTRAS": cflags}) + self.build_for_mode(mode) - self.build() supported_vg = self.get_supported_vg() if not (2 in supported_vg and 4 in supported_vg): @@ -196,3 +198,94 @@ def test_ssve_registers_dynamic_config(self): """Test AArch64 SSVE registers multi-threaded dynamic resize.""" self.run_sve_test(Mode.SSVE) + + def setup_svg_test(self, mode): + # Even when running in SVE mode, we need access to SVG for these tests.
+ if not self.isAArch64SME(): + self.skipTest("Streaming SVE registers must be present.") + + self.build_for_mode(mode) + + supported_vg = self.get_supported_vg() + + main_thread_stop_line = line_number("main.c", "// Break in main thread") + lldbutil.run_break_set_by_file_and_line(self, "main.c", main_thread_stop_line) + + self.runCmd("run", RUN_SUCCEEDED) + + self.expect( + "thread info 1", + STOPPED_DUE_TO_BREAKPOINT, + substrs=["stop reason = breakpoint"], + ) + + target = self.dbg.GetSelectedTarget() + process = target.GetProcess() + + return process, supported_vg + + def read_reg(self, process, regset, reg): + registerSets = process.GetThreadAtIndex(0).GetFrameAtIndex(0).GetRegisters() + register_set = registerSets.GetFirstValueByName(regset) + return register_set.GetChildMemberWithName(reg).GetValueAsUnsigned() + + def read_vg(self, process): + return self.read_reg(process, "Scalable Vector Extension Registers", "vg") + + def read_svg(self, process): + return self.read_reg(process, "Scalable Matrix Extension Registers", "svg") + + def do_svg_test(self, process, vgs, expected_svgs): + for vg, svg in zip(vgs, expected_svgs): + self.runCmd("register write vg {}".format(vg)) + self.assertEqual(svg, self.read_svg(process)) + + @no_debug_info_test + @skipIf(archs=no_match(["aarch64"])) + @skipIf(oslist=no_match(["linux"])) + def test_svg_sve_mode(self): + """When in SVE mode, svg should remain constant as we change vg.""" + process, supported_vg = self.setup_svg_test(Mode.SVE) + svg = self.read_svg(process) + self.do_svg_test(process, supported_vg, [svg] * len(supported_vg)) + + @no_debug_info_test + @skipIf(archs=no_match(["aarch64"])) + @skipIf(oslist=no_match(["linux"])) + def test_svg_ssve_mode(self): + """When in SSVE mode, changing vg should change svg to the same value.""" + process, supported_vg = self.setup_svg_test(Mode.SSVE) + self.do_svg_test(process, supported_vg, supported_vg) + + @no_debug_info_test + @skipIf(archs=no_match(["aarch64"])) +
@skipIf(oslist=no_match(["linux"])) + def test_sme_not_present(self): + """When there is no SME, we should not show the SME register sets.""" + if self.isAArch64SME(): + self.skipTest("Streaming SVE registers must not be present.") + + self.build_for_mode(Mode.SVE) + + exe = self.getBuildArtifact("a.out") + self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET) + + # This test may run on a non-sve system, but we'll stop before any + # SVE instruction would be run. + self.runCmd("b main") + self.runCmd("run", RUN_SUCCEEDED) + + self.expect( + "thread info 1", + STOPPED_DUE_TO_BREAKPOINT, + substrs=["stop reason = breakpoint"], + ) + + target = self.dbg.GetSelectedTarget() + process = target.GetProcess() + + registerSets = process.GetThreadAtIndex(0).GetFrameAtIndex(0).GetRegisters() + sme_registers = registerSets.GetFirstValueByName( + "Scalable Matrix Extension Registers" + ) + self.assertFalse(sme_registers.IsValid()) diff --git a/lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/Makefile b/lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/Makefile new file mode 100644 --- /dev/null +++ b/lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/Makefile @@ -0,0 +1,5 @@ +C_SOURCES := main.c + +CFLAGS_EXTRAS := -march=armv8-a+sve+sme -lpthread + +include Makefile.rules diff --git a/lldb/test/API/commands/register/register/aarch64_sve_registers/rw_access_dynamic_resize/TestSVEThreadedDynamic.py b/lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/TestZAThreadedDynamic.py copy from lldb/test/API/commands/register/register/aarch64_sve_registers/rw_access_dynamic_resize/TestSVEThreadedDynamic.py copy to lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/TestZAThreadedDynamic.py --- a/lldb/test/API/commands/register/register/aarch64_sve_registers/rw_access_dynamic_resize/TestSVEThreadedDynamic.py +++
b/lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/TestZAThreadedDynamic.py @@ -1,11 +1,6 @@ """ -Test the AArch64 SVE and Streaming SVE (SSVE) registers dynamic resize with +Test the AArch64 SME Array Storage (ZA) register dynamic resize with multiple threads. - -This test assumes a minimum supported vector length (VL) of 256 bits -and will test 512 bits if possible. We refer to "vg" which is the -register shown in lldb. This is in units of 64 bits. 256 bit VL is -the same as a vg of 4. """ from enum import Enum @@ -15,21 +10,15 @@ from lldbsuite.test import lldbutil -class Mode(Enum): - SVE = 0 - SSVE = 1 - - -class RegisterCommandsTestCase(TestBase): +class AArch64ZAThreadedTestCase(TestBase): def get_supported_vg(self): - # Changing VL trashes the register state, so we need to run the program - # just to test this. Then run it again for the test. exe = self.getBuildArtifact("a.out") self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET) main_thread_stop_line = line_number("main.c", "// Break in main thread") lldbutil.run_break_set_by_file_and_line(self, "main.c", main_thread_stop_line) + self.runCmd("settings set target.run-args 0") self.runCmd("run", RUN_SUCCEEDED) self.expect( @@ -38,7 +27,6 @@ substrs=["stop reason = breakpoint"], ) - # Write back the current vg to confirm read/write works at all.
current_vg = self.match("register read vg", ["(0x[0-9]+)"]) self.assertTrue(current_vg is not None) self.expect("register write vg {}".format(current_vg.group())) @@ -57,64 +45,36 @@ return supported_vg - def check_sve_registers(self, vg_test_value): - z_reg_size = vg_test_value * 8 - p_reg_size = int(z_reg_size / 8) - - p_value_bytes = ["0xff", "0x55", "0x11", "0x01", "0x00"] - - for i in range(32): - s_reg_value = "s%i = 0x" % (i) + "".join( - "{:02x}".format(i + 1) for _ in range(4) - ) - - d_reg_value = "d%i = 0x" % (i) + "".join( - "{:02x}".format(i + 1) for _ in range(8) - ) - - v_reg_value = "v%i = 0x" % (i) + "".join( - "{:02x}".format(i + 1) for _ in range(16) - ) - - z_reg_value = ( - "{" - + " ".join("0x{:02x}".format(i + 1) for _ in range(z_reg_size)) - + "}" - ) - - self.expect("register read -f hex " + "s%i" % (i), substrs=[s_reg_value]) + def gen_za_value(self, svg, value_generator): + svl = svg * 8 - self.expect("register read -f hex " + "d%i" % (i), substrs=[d_reg_value]) + rows = [] + for row in range(svl): + byte = "0x{:02x}".format(value_generator(row)) + rows.append(" ".join([byte] * svl)) - self.expect("register read -f hex " + "v%i" % (i), substrs=[v_reg_value]) + return "{" + " ".join(rows) + "}" - self.expect("register read " + "z%i" % (i), substrs=[z_reg_value]) - - for i in range(16): - p_regs_value = ( - "{" + " ".join(p_value_bytes[i % 5] for _ in range(p_reg_size)) + "}" - ) - self.expect("register read " + "p%i" % (i), substrs=[p_regs_value]) - - self.expect("register read ffr", substrs=[p_regs_value]) - - def run_sve_test(self, mode): - if (mode == Mode.SVE) and not self.isAArch64SVE(): - self.skipTest("SVE registers must be supported.") + def check_za_register(self, svg, value_offset): + self.expect( + "register read za", + substrs=[self.gen_za_value(svg, lambda r: r + value_offset)], + ) - if (mode == Mode.SSVE) and not self.isAArch64SME(): - self.skipTest("Streaming SVE registers must be supported.") + def
check_disabled_za_register(self, svg): + self.expect("register read za", substrs=[self.gen_za_value(svg, lambda r: 0)]) - cflags = "-march=armv8-a+sve -lpthread" - if mode == Mode.SSVE: - cflags += " -DUSE_SSVE" - self.build(dictionary={"CFLAGS_EXTRAS": cflags}) + def za_test_impl(self, enable_za): + if not self.isAArch64SME(): + self.skipTest("SME must be present.") self.build() supported_vg = self.get_supported_vg() + self.runCmd("settings set target.run-args {}".format("1" if enable_za else "0")) + if not (2 in supported_vg and 4 in supported_vg): - self.skipTest("Not all required SVE vector lengths are supported.") + self.skipTest("Not all required streaming vector lengths are supported.") main_thread_stop_line = line_number("main.c", "// Break in main thread") lldbutil.run_break_set_by_file_and_line(self, "main.c", main_thread_stop_line) @@ -133,8 +93,6 @@ self.runCmd("run", RUN_SUCCEEDED) - process = self.dbg.GetSelectedTarget().GetProcess() - self.expect( "thread info 1", STOPPED_DUE_TO_BREAKPOINT, @@ -142,12 +100,19 @@ ) if 8 in supported_vg: - self.check_sve_registers(8) + if enable_za: + self.check_za_register(8, 1) + else: + self.check_disabled_za_register(8) else: - self.check_sve_registers(4) + if enable_za: + self.check_za_register(4, 1) + else: + self.check_disabled_za_register(4) self.runCmd("process continue", RUN_SUCCEEDED) + process = self.dbg.GetSelectedTarget().GetProcess() for idx in range(1, process.GetNumThreads()): thread = process.GetThreadAtIndex(idx) if thread.GetStopReason() != lldb.eStopReasonBreakpoint: @@ -158,12 +123,12 @@ if stopped_at_line_number == thX_break_line1: self.runCmd("thread select %d" % (idx + 1)) - self.check_sve_registers(4) + self.check_za_register(4, 2) self.runCmd("register write vg 2") elif stopped_at_line_number == thY_break_line1: self.runCmd("thread select %d" % (idx + 1)) - self.check_sve_registers(2) + self.check_za_register(2, 3) self.runCmd("register write vg 4") self.runCmd("thread continue 2") @@ -177,22
+142,24 @@ if stopped_at_line_number == thX_break_line2: self.runCmd("thread select %d" % (idx + 1)) - self.check_sve_registers(2) + self.check_za_register(2, 2) elif stopped_at_line_number == thY_break_line2: self.runCmd("thread select %d" % (idx + 1)) - self.check_sve_registers(4) + self.check_za_register(4, 3) @no_debug_info_test @skipIf(archs=no_match(["aarch64"])) @skipIf(oslist=no_match(["linux"])) - def test_sve_registers_dynamic_config(self): - """Test AArch64 SVE registers multi-threaded dynamic resize.""" - self.run_sve_test(Mode.SVE) + def test_za_register_dynamic_config_main_enabled(self): + """Test multiple threads resizing ZA, with the main thread's ZA + enabled.""" + self.za_test_impl(True) @no_debug_info_test @skipIf(archs=no_match(["aarch64"])) @skipIf(oslist=no_match(["linux"])) - def test_ssve_registers_dynamic_config(self): - """Test AArch64 SSVE registers multi-threaded dynamic resize.""" - self.run_sve_test(Mode.SSVE) + def test_za_register_dynamic_config_main_disabled(self): + """Test multiple threads resizing ZA, with the main thread's ZA + disabled.""" + self.za_test_impl(False) diff --git a/lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/main.c b/lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/main.c new file mode 100644 --- /dev/null +++ b/lldb/test/API/commands/register/register/aarch64_za_register/za_dynamic_resize/main.c @@ -0,0 +1,103 @@ +#include <pthread.h> +#include <stdatomic.h> +#include <stdbool.h> +#include <stdint.h> +#include <string.h> +#include <sys/prctl.h> + +// Important notes for this test: +// * Making a syscall will disable streaming mode. +// * LLDB writing to vg while in streaming mode will disable ZA +// (this is just how ptrace works). +// * Writing to an inactive ZA produces a SIGILL.
+ +#ifndef PR_SME_SET_VL +#define PR_SME_SET_VL 63 +#endif + +#define SM_INST(c) asm volatile("msr s0_3_c4_c" #c "_3, xzr") +#define SMSTART_SM SM_INST(3) +#define SMSTART_ZA SM_INST(5) + +void set_za_register(int svl, int value_offset) { +#define MAX_VL_BYTES 256 + uint8_t data[MAX_VL_BYTES]; + + // Fill the svl rows of ZA, row i with the byte value (i + value_offset). + // ldr za wraps the selected row index by the number of rows, so writing a row + // that doesn't exist would overwrite one that does; hence we need svl here. + for (int i = 0; i < svl; ++i) { + memset(data, i + value_offset, MAX_VL_BYTES); + // Each ldr loads one VL sized row of ZA from data ("memory" keeps the memset). + asm volatile("mov w12, %w0\n\t" + "ldr za[w12, 0], [%1]\n\t" ::"r"(i), + "r"(&data) + : "w12", "memory"); + } +} + +// These are used to make sure we only break in each thread once both of the +// threads have been started. Otherwise when the test does "process continue" +// it could stop in one thread and wait forever for the other one to start. +atomic_bool threadX_ready = false; +atomic_bool threadY_ready = false; + +void *threadX_func(void *x_arg) { + threadX_ready = true; + while (!threadY_ready) { + } + + prctl(PR_SME_SET_VL, 8 * 4); + SMSTART_SM; + SMSTART_ZA; + set_za_register(8 * 4, 2); + SMSTART_ZA; // Thread X breakpoint 1 + set_za_register(8 * 2, 2); + return NULL; // Thread X breakpoint 2 +} + +void *threadY_func(void *y_arg) { + threadY_ready = true; + while (!threadX_ready) { + } + + prctl(PR_SME_SET_VL, 8 * 2); + SMSTART_SM; + SMSTART_ZA; + set_za_register(8 * 2, 3); + SMSTART_ZA; // Thread Y breakpoint 1 + set_za_register(8 * 4, 3); + return NULL; // Thread Y breakpoint 2 +} + +int main(int argc, char *argv[]) { + // Expecting argument to tell us whether to enable ZA on the main thread. + if (argc != 2) + return 1; + + prctl(PR_SME_SET_VL, 8 * 8); + SMSTART_SM; + + if (argv[1][0] == '1') { + SMSTART_ZA; + set_za_register(8 * 8, 1); + } + // else we do not enable ZA and lldb will show 0s for it.
+ + pthread_t x_thread; + if (pthread_create(&x_thread, NULL, threadX_func, 0)) // Break in main thread + return 1; + + pthread_t y_thread; + if (pthread_create(&y_thread, NULL, threadY_func, 0)) + return 1; + + if (pthread_join(x_thread, NULL)) + return 2; + + if (pthread_join(y_thread, NULL)) + return 2; + + return 0; +} + diff --git a/lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/Makefile b/lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/Makefile new file mode 100644 --- /dev/null +++ b/lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/Makefile @@ -0,0 +1,5 @@ +C_SOURCES := main.c + +CFLAGS_EXTRAS := -march=armv8-a+sve+sme + +include Makefile.rules diff --git a/lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/TestZARegisterSaveRestore.py b/lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/TestZARegisterSaveRestore.py new file mode 100644 --- /dev/null +++ b/lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/TestZARegisterSaveRestore.py @@ -0,0 +1,252 @@ +""" +Test that the AArch64 SME ZA register is saved and restored around expressions. + +This attempts to cover expressions that change the following: +* ZA enabled or not. +* Streaming mode or not. +* Streaming vector length (increasing and decreasing). +* Some combinations of the above. +""" + +from enum import IntEnum +import lldb +from lldbsuite.test.decorators import * +from lldbsuite.test.lldbtest import * +from lldbsuite.test import lldbutil + + +# These enum values match the flag values used in the test program. +class Mode(IntEnum): + SVE = 0 + SSVE = 1 + + +class ZA(IntEnum): + Disabled = 0 + Enabled = 1 + + +class AArch64ZATestCase(TestBase): + def get_supported_svg(self): + # Always build this probe program to start as streaming SVE.
+ # We will read/write "vg" here but since we are in streaming mode "svg" + # is really what we are writing ("svg" is a read only pseudo). + self.build() + + exe = self.getBuildArtifact("a.out") + self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET) + # Enter streaming mode, don't enable ZA, start_vl and other_vl don't + # matter here. + self.runCmd("settings set target.run-args 1 0 0 0") + + stop_line = line_number("main.c", "// Set a break point here.") + lldbutil.run_break_set_by_file_and_line( + self, "main.c", stop_line, num_expected_locations=1 + ) + + self.runCmd("run", RUN_SUCCEEDED) + + self.expect( + "thread info 1", + STOPPED_DUE_TO_BREAKPOINT, + substrs=["stop reason = breakpoint"], + ) + + # Write back the current vg to confirm read/write works at all. + current_svg = self.match("register read vg", ["(0x[0-9]+)"]) + self.assertTrue(current_svg is not None) + self.expect("register write vg {}".format(current_svg.group())) + + # Aka 128, 256 and 512 bit. + supported_svg = [] + for svg in [2, 4, 8]: + # This could mask other errors but writing vg is tested elsewhere + # so we assume the hardware rejected the value.
+ self.runCmd("register write vg {}".format(svg), check=False) + if not self.res.GetError(): + supported_svg.append(svg) + + self.runCmd("breakpoint delete 1") + self.runCmd("continue") + + return supported_svg + + def read_vg(self): + process = self.dbg.GetSelectedTarget().GetProcess() + registerSets = process.GetThreadAtIndex(0).GetFrameAtIndex(0).GetRegisters() + sve_registers = registerSets.GetFirstValueByName( + "Scalable Vector Extension Registers" + ) + return sve_registers.GetChildMemberWithName("vg").GetValueAsUnsigned() + + def read_svg(self): + process = self.dbg.GetSelectedTarget().GetProcess() + registerSets = process.GetThreadAtIndex(0).GetFrameAtIndex(0).GetRegisters() + sme_registers = registerSets.GetFirstValueByName( + "Scalable Matrix Extension Registers" + ) + return sme_registers.GetChildMemberWithName("svg").GetValueAsUnsigned() + + def make_za_value(self, vl, generator): + # Generate a ZA value string "{0x.. 0x.. ...}": vl rows of vl bytes, every byte in row r being generator(r). + rows = [] + for row in range(vl): + byte = "0x{:02x}".format(generator(row)) + rows.append(" ".join([byte] * vl)) + return "{" + " ".join(rows) + "}" + + def check_za(self, vl): + # We expect an increasing value starting at 1. Row 0 = 1, row 1 = 2, etc. + self.expect( + "register read za", substrs=[self.make_za_value(vl, lambda row: row + 1)] + ) + + def check_za_disabled(self, vl): + # When ZA is disabled, lldb will show ZA as all 0s. + self.expect("register read za", substrs=[self.make_za_value(vl, lambda row: 0)]) + + def za_expr_test_impl(self, sve_mode, za_state, swap_start_vl): + if not self.isAArch64SME(): + self.skipTest("SME must be present.") + + supported_svg = self.get_supported_svg() + if len(supported_svg) < 2: + self.skipTest("Target must support at least 2 streaming vector lengths.") + + # vg is in units of 8 bytes.
+ start_vl = supported_svg[0] * 8 + other_vl = supported_svg[1] * 8 + + if swap_start_vl: + start_vl, other_vl = other_vl, start_vl + + self.line = line_number("main.c", "// Set a break point here.") + + exe = self.getBuildArtifact("a.out") + self.runCmd("file " + exe, CURRENT_EXECUTABLE_SET) + self.runCmd( + "settings set target.run-args {} {} {} {}".format( + sve_mode, za_state, start_vl, other_vl + ) + ) + + lldbutil.run_break_set_by_file_and_line( + self, "main.c", self.line, num_expected_locations=1 + ) + self.runCmd("run", RUN_SUCCEEDED) + + self.expect( + "thread backtrace", + STOPPED_DUE_TO_BREAKPOINT, + substrs=["stop reason = breakpoint 1."], + ) + + exprs = [ + "expr_disable_za", + "expr_enable_za", + "expr_start_vl", + "expr_other_vl", + "expr_enable_sm", + "expr_disable_sm", + ] + + # This may be the streaming or non-streaming vg. All that matters is + # that it is saved and restored, remaining constant throughout. + start_vg = self.read_vg() + + # Check SVE registers to make sure that combination of scaling SVE + # and scaling ZA works properly. This is a brittle check, but failures + # are likely to be catastrophic when they do happen anyway. + sve_reg_names = "ffr {} {}".format( + " ".join(["z{}".format(n) for n in range(32)]), + " ".join(["p{}".format(n) for n in range(16)]), + ) + self.runCmd("register read " + sve_reg_names) + sve_values = self.res.GetOutput() + + def check_regs(): + if za_state == ZA.Enabled: + self.check_za(start_vl) + else: + self.check_za_disabled(start_vl) + + # svg and vg are in units of 8 bytes. + self.assertEqual(start_vl, self.read_svg() * 8) + self.assertEqual(start_vg, self.read_vg()) + + self.expect("register read " + sve_reg_names, substrs=[sve_values]) + + for expr in exprs: + expr_cmd = "expression {}()".format(expr) + + # We do this twice because there were issues in development where + # using data stored by a previous WriteAllRegisterValues would crash + # the second time around.
+ self.runCmd(expr_cmd) + check_regs() + self.runCmd(expr_cmd) + check_regs() + + # Run them in sequence to make sure there is no state lingering between + # them after a restore. + for expr in exprs: + self.runCmd("expression {}()".format(expr)) + check_regs() + + for expr in reversed(exprs): + self.runCmd("expression {}()".format(expr)) + check_regs() + + # These tests start with the 1st supported SVL and change to the 2nd + # supported SVL as needed. + + @no_debug_info_test + @skipIf(archs=no_match(["aarch64"])) + @skipIf(oslist=no_match(["linux"])) + def test_za_expr_ssve_za_enabled(self): + self.za_expr_test_impl(Mode.SSVE, ZA.Enabled, False) + + @no_debug_info_test + @skipIf(archs=no_match(["aarch64"])) + @skipIf(oslist=no_match(["linux"])) + def test_za_expr_ssve_za_disabled(self): + self.za_expr_test_impl(Mode.SSVE, ZA.Disabled, False) + + @no_debug_info_test + @skipIf(archs=no_match(["aarch64"])) + @skipIf(oslist=no_match(["linux"])) + def test_za_expr_sve_za_enabled(self): + self.za_expr_test_impl(Mode.SVE, ZA.Enabled, False) + + @no_debug_info_test + @skipIf(archs=no_match(["aarch64"])) + @skipIf(oslist=no_match(["linux"])) + def test_za_expr_sve_za_disabled(self): + self.za_expr_test_impl(Mode.SVE, ZA.Disabled, False) + + # These tests start in the 2nd supported SVL and change to the 1st supported + # SVL as needed.
+ + @no_debug_info_test + @skipIf(archs=no_match(["aarch64"])) + @skipIf(oslist=no_match(["linux"])) + def test_za_expr_ssve_za_enabled_different_vl(self): + self.za_expr_test_impl(Mode.SSVE, ZA.Enabled, True) + + @no_debug_info_test + @skipIf(archs=no_match(["aarch64"])) + @skipIf(oslist=no_match(["linux"])) + def test_za_expr_ssve_za_disabled_different_vl(self): + self.za_expr_test_impl(Mode.SSVE, ZA.Disabled, True) + + @no_debug_info_test + @skipIf(archs=no_match(["aarch64"])) + @skipIf(oslist=no_match(["linux"])) + def test_za_expr_sve_za_enabled_different_vl(self): + self.za_expr_test_impl(Mode.SVE, ZA.Enabled, True) + + @no_debug_info_test + @skipIf(archs=no_match(["aarch64"])) + @skipIf(oslist=no_match(["linux"])) + def test_za_expr_sve_za_disabled_different_vl(self): + self.za_expr_test_impl(Mode.SVE, ZA.Disabled, True) diff --git a/lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/main.c b/lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/main.c new file mode 100644 --- /dev/null +++ b/lldb/test/API/commands/register/register/aarch64_za_register/za_save_restore/main.c @@ -0,0 +1,226 @@ +#include <stdbool.h> +#include <stdint.h> +#include <stdlib.h> +#include <string.h> +#include <sys/prctl.h> + +// Important details for this program: +// * Making a syscall will disable streaming mode if it is active. +// * Changing the vector length will make streaming mode and ZA inactive. +// * ZA can be active independent of streaming mode. +// * ZA's size is the streaming vector length squared.
+ +#ifndef PR_SME_SET_VL +#define PR_SME_SET_VL 63 +#endif + +#ifndef PR_SME_GET_VL +#define PR_SME_GET_VL 64 +#endif + +#ifndef PR_SME_VL_LEN_MASK +#define PR_SME_VL_LEN_MASK 0xffff +#endif + +#define SM_INST(c) asm volatile("msr s0_3_c4_c" #c "_3, xzr") +#define SMSTART SM_INST(7) +#define SMSTART_SM SM_INST(3) +#define SMSTART_ZA SM_INST(5) +#define SMSTOP SM_INST(6) +#define SMSTOP_SM SM_INST(2) +#define SMSTOP_ZA SM_INST(4) + +int start_vl = 0; +int other_vl = 0; + +void write_sve_regs() { + // Set FFR, P0-P15 and Z0-Z31 to a fixed pattern. We assume the smefa64 + // feature is present, which allows ffr access in streaming mode. + asm volatile("setffr\n\t"); + asm volatile("ptrue p0.b\n\t"); + asm volatile("ptrue p1.h\n\t"); + asm volatile("ptrue p2.s\n\t"); + asm volatile("ptrue p3.d\n\t"); + asm volatile("pfalse p4.b\n\t"); + asm volatile("ptrue p5.b\n\t"); + asm volatile("ptrue p6.h\n\t"); + asm volatile("ptrue p7.s\n\t"); + asm volatile("ptrue p8.d\n\t"); + asm volatile("pfalse p9.b\n\t"); + asm volatile("ptrue p10.b\n\t"); + asm volatile("ptrue p11.h\n\t"); + asm volatile("ptrue p12.s\n\t"); + asm volatile("ptrue p13.d\n\t"); + asm volatile("pfalse p14.b\n\t"); + asm volatile("ptrue p15.b\n\t"); + + asm volatile("cpy z0.b, p0/z, #1\n\t"); + asm volatile("cpy z1.b, p5/z, #2\n\t"); + asm volatile("cpy z2.b, p10/z, #3\n\t"); + asm volatile("cpy z3.b, p15/z, #4\n\t"); + asm volatile("cpy z4.b, p0/z, #5\n\t"); + asm volatile("cpy z5.b, p5/z, #6\n\t"); + asm volatile("cpy z6.b, p10/z, #7\n\t"); + asm volatile("cpy z7.b, p15/z, #8\n\t"); + asm volatile("cpy z8.b, p0/z, #9\n\t"); + asm volatile("cpy z9.b, p5/z, #10\n\t"); + asm volatile("cpy z10.b, p10/z, #11\n\t"); + asm volatile("cpy z11.b, p15/z, #12\n\t"); + asm volatile("cpy z12.b, p0/z, #13\n\t"); + asm volatile("cpy z13.b, p5/z, #14\n\t"); + asm volatile("cpy z14.b, p10/z, #15\n\t"); + asm volatile("cpy z15.b, p15/z, #16\n\t"); + asm volatile("cpy z16.b, p0/z, #17\n\t"); + asm volatile("cpy z17.b, p5/z, #18\n\t"); + asm volatile("cpy z18.b, p10/z,
#19\n\t"); + asm volatile("cpy z19.b, p15/z, #20\n\t"); + asm volatile("cpy z20.b, p0/z, #21\n\t"); + asm volatile("cpy z21.b, p5/z, #22\n\t"); + asm volatile("cpy z22.b, p10/z, #23\n\t"); + asm volatile("cpy z23.b, p15/z, #24\n\t"); + asm volatile("cpy z24.b, p0/z, #25\n\t"); + asm volatile("cpy z25.b, p5/z, #26\n\t"); + asm volatile("cpy z26.b, p10/z, #27\n\t"); + asm volatile("cpy z27.b, p15/z, #28\n\t"); + asm volatile("cpy z28.b, p0/z, #29\n\t"); + asm volatile("cpy z29.b, p5/z, #30\n\t"); + asm volatile("cpy z30.b, p10/z, #31\n\t"); + asm volatile("cpy z31.b, p15/z, #32\n\t"); +} + +// Called by each expr_* function: write values different from write_sve_regs so a bad restore is visible. +// NOTE(review): p0/p5/p10/p15 are all false here, so these zeroing cpys presumably leave every Z register 0. +void write_sve_regs_expr() { + asm volatile("pfalse p0.b\n\t"); + asm volatile("wrffr p0.b\n\t"); + asm volatile("pfalse p1.b\n\t"); + asm volatile("pfalse p2.b\n\t"); + asm volatile("pfalse p3.b\n\t"); + asm volatile("ptrue p4.b\n\t"); + asm volatile("pfalse p5.b\n\t"); + asm volatile("pfalse p6.b\n\t"); + asm volatile("pfalse p7.b\n\t"); + asm volatile("pfalse p8.b\n\t"); + asm volatile("ptrue p9.b\n\t"); + asm volatile("pfalse p10.b\n\t"); + asm volatile("pfalse p11.b\n\t"); + asm volatile("pfalse p12.b\n\t"); + asm volatile("pfalse p13.b\n\t"); + asm volatile("ptrue p14.b\n\t"); + asm volatile("pfalse p15.b\n\t"); + + asm volatile("cpy z0.b, p0/z, #2\n\t"); + asm volatile("cpy z1.b, p5/z, #3\n\t"); + asm volatile("cpy z2.b, p10/z, #4\n\t"); + asm volatile("cpy z3.b, p15/z, #5\n\t"); + asm volatile("cpy z4.b, p0/z, #6\n\t"); + asm volatile("cpy z5.b, p5/z, #7\n\t"); + asm volatile("cpy z6.b, p10/z, #8\n\t"); + asm volatile("cpy z7.b, p15/z, #9\n\t"); + asm volatile("cpy z8.b, p0/z, #10\n\t"); + asm volatile("cpy z9.b, p5/z, #11\n\t"); + asm volatile("cpy z10.b, p10/z, #12\n\t"); + asm volatile("cpy z11.b, p15/z, #13\n\t"); + asm volatile("cpy z12.b, p0/z, #14\n\t"); + asm volatile("cpy z13.b, p5/z, #15\n\t"); + asm volatile("cpy z14.b, p10/z, #16\n\t"); + asm volatile("cpy z15.b, p15/z,
#17\n\t"); + asm volatile("cpy z16.b, p0/z, #18\n\t"); + asm volatile("cpy z17.b, p5/z, #19\n\t"); + asm volatile("cpy z18.b, p10/z, #20\n\t"); + asm volatile("cpy z19.b, p15/z, #21\n\t"); + asm volatile("cpy z20.b, p0/z, #22\n\t"); + asm volatile("cpy z21.b, p5/z, #23\n\t"); + asm volatile("cpy z22.b, p10/z, #24\n\t"); + asm volatile("cpy z23.b, p15/z, #25\n\t"); + asm volatile("cpy z24.b, p0/z, #26\n\t"); + asm volatile("cpy z25.b, p5/z, #27\n\t"); + asm volatile("cpy z26.b, p10/z, #28\n\t"); + asm volatile("cpy z27.b, p15/z, #29\n\t"); + asm volatile("cpy z28.b, p0/z, #30\n\t"); + asm volatile("cpy z29.b, p5/z, #31\n\t"); + asm volatile("cpy z30.b, p10/z, #32\n\t"); + asm volatile("cpy z31.b, p15/z, #33\n\t"); +} + +void set_za_register(int svl, int value_offset) { +#define MAX_VL_BYTES 256 + uint8_t data[MAX_VL_BYTES]; + + // Fill the svl rows of ZA, row i with the byte value (i + value_offset). + // ldr za wraps the selected row index by the number of rows, so writing a row + // that doesn't exist would overwrite one that does; hence we need svl here. + for (int i = 0; i < svl; ++i) { + memset(data, i + value_offset, MAX_VL_BYTES); + // Each ldr loads one VL sized row of ZA from data ("memory" keeps the memset).
+ asm volatile("mov w12, %w0\n\t" + "ldr za[w12, 0], [%1]\n\t" ::"r"(i), + "r"(&data) + : "w12", "memory"); + } +} + +void expr_disable_za() { + SMSTOP_ZA; + write_sve_regs_expr(); +} + +void expr_enable_za() { + SMSTART_ZA; + set_za_register(start_vl, 2); + write_sve_regs_expr(); +} + +void expr_start_vl() { + prctl(PR_SME_SET_VL, start_vl); + SMSTART_ZA; + set_za_register(start_vl, 4); + write_sve_regs_expr(); +} + +void expr_other_vl() { + prctl(PR_SME_SET_VL, other_vl); + SMSTART_ZA; + set_za_register(other_vl, 5); + write_sve_regs_expr(); +} + +void expr_enable_sm() { + SMSTART_SM; + write_sve_regs_expr(); +} + +void expr_disable_sm() { + SMSTOP_SM; + write_sve_regs_expr(); +} + +int main(int argc, char *argv[]) { + // We expect four arguments: + // * whether to enable streaming mode ('1' means yes) + // * whether to enable ZA ('1' means yes) + // * the starting streaming VL, in bytes + // * the other streaming VL, in bytes (used by expr_other_vl) + if (argc != 5) + return 1; + + bool ssve = argv[1][0] == '1'; + bool za = argv[2][0] == '1'; + start_vl = atoi(argv[3]); + other_vl = atoi(argv[4]); + + prctl(PR_SME_SET_VL, start_vl); + + if (ssve) + SMSTART_SM; + + if (za) { + SMSTART_ZA; + set_za_register(start_vl, 1); + } + + write_sve_regs(); + + return 0; // Set a break point here. +} +