|
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | 15 | """Profiler tests.""" |
| 16 | +import os |
16 | 17 | import sys |
17 | 18 | import unittest |
| 19 | +from unittest.mock import MagicMock, patch |
18 | 20 |
|
19 | 21 | import pytest |
20 | 22 |
|
21 | 23 | from MaxText import pyconfig |
| 24 | +from MaxText.globals import MAXTEXT_PKG_DIR |
| 25 | +from maxtext.configs import types |
22 | 26 | from maxtext.common import profiler |
23 | 27 | from tests.utils.test_helpers import get_test_config_path |
24 | 28 |
|
25 | 29 |
|
26 | 30 | class ProfilerTest(unittest.TestCase): |
27 | 31 | """Test for profiler.""" |
28 | 32 |
|
| 33 | + def setUp(self): |
| 34 | + super().setUp() |
| 35 | + # Mock jax.devices() to be deterministic and avoid runtime initialization errors |
| 36 | + self.mock_devices = [MagicMock(slice_index=0) for _ in range(1)] |
| 37 | + self.jax_patcher = patch("jax.devices", return_value=self.mock_devices) |
| 38 | + self.jax_patcher.start() |
| 39 | + self.jax_process_index_patcher = patch("jax.process_index", return_value=0) |
| 40 | + self.jax_process_index_patcher.start() |
| 41 | + |
| 42 | + def tearDown(self): |
| 43 | + self.jax_patcher.stop() |
| 44 | + self.jax_process_index_patcher.stop() |
| 45 | + super().tearDown() |
| 46 | + |
| 47 | + @pytest.mark.tpu_only |
| 48 | + def test_profiler_options_populated_from_config(self): |
| 49 | + """Verifies that Profiler initializes jax.profiler.ProfileOptions from config.""" |
| 50 | + config = pyconfig.initialize( |
| 51 | + [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], |
| 52 | + enable_checkpointing=False, |
| 53 | + run_name="test_profiler_options", |
| 54 | + base_output_directory="/tmp", |
| 55 | + profiler="xplane", |
| 56 | + xprof_tpu_power_trace_level=1, |
| 57 | + xprof_e2e_enable_fw_throttle_event=True, |
| 58 | + xprof_e2e_enable_fw_power_level_event=True, |
| 59 | + xprof_e2e_enable_fw_thermal_event=True, |
| 60 | + ) |
| 61 | + |
| 62 | + with patch("jax.profiler.ProfileOptions") as mock_options_cls: |
| 63 | + # Setup mock return value |
| 64 | + mock_options_instance = MagicMock() |
| 65 | + mock_options_cls.return_value = mock_options_instance |
| 66 | + |
| 67 | + prof = profiler.Profiler(config) |
| 68 | + |
| 69 | + # Check if ProfileOptions was instantiated |
| 70 | + mock_options_cls.assert_called_once() |
| 71 | + |
| 72 | + # Verify advanced_configuration was populated |
| 73 | + expected_advanced_config = { |
| 74 | + "tpu_power_trace_level": types.XProfTPUPowerTraceMode.POWER_TRACE_NORMAL, |
| 75 | + "e2e_enable_fw_throttle_event": True, |
| 76 | + "e2e_enable_fw_power_level_event": True, |
| 77 | + "e2e_enable_fw_thermal_event": True, |
| 78 | + } |
| 79 | + self.assertEqual(prof.profiling_options.advanced_configuration, expected_advanced_config) |
| 80 | + |
| 81 | + @pytest.mark.tpu_only |
| 82 | + @patch("jax.profiler.start_trace") |
| 83 | + def test_profiler_activate_passes_options(self, mock_start_trace): |
| 84 | + """Verifies that activate() passes the profiling_options to jax.profiler.start_trace.""" |
| 85 | + config = pyconfig.initialize( |
| 86 | + [sys.argv[0], os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml")], |
| 87 | + enable_checkpointing=False, |
| 88 | + run_name="test_profiler_options", |
| 89 | + base_output_directory="/tmp", |
| 90 | + profiler="xplane", |
| 91 | + xprof_tpu_power_trace_level=2, |
| 92 | + ) |
| 93 | + |
| 94 | + # We need to mock ProfileOptions as well to check identity or value |
| 95 | + with patch("jax.profiler.ProfileOptions"): |
| 96 | + prof = profiler.Profiler(config) |
| 97 | + prof.activate() |
| 98 | + |
| 99 | + # Verify start_trace was called with profiler_options |
| 100 | + mock_start_trace.assert_called_once() |
| 101 | + _, kwargs = mock_start_trace.call_args |
| 102 | + self.assertIn("profiler_options", kwargs) |
| 103 | + self.assertEqual(kwargs["profiler_options"], prof.profiling_options) |
| 104 | + self.assertEqual( |
| 105 | + prof.profiling_options.advanced_configuration["tpu_power_trace_level"], |
| 106 | + types.XProfTPUPowerTraceMode.POWER_TRACE_SPI, |
| 107 | + ) |
| 108 | + |
29 | 109 | # These periodic proilfer tests can run on any platform (cpu, gpu or tpu) |
30 | 110 | @pytest.mark.tpu_only |
31 | 111 | def test_periodic_profiler_third_period_starts(self): |
|
0 commit comments