|
15 | 15 | """Smoke test for MoE using ragged_dot.""" |
16 | 16 |
|
17 | 17 | import os |
18 | | -import unittest |
19 | | -from tempfile import gettempdir |
| 18 | +import tempfile |
20 | 19 |
|
21 | 20 | from absl.testing import absltest |
| 21 | +from absl.testing import parameterized |
| 22 | +from MaxText import globals as maxtext_globals |
| 23 | +from MaxText import train |
22 | 24 |
|
# Module-level aliases so the test body reads the same as before the
# import refactor: the project entry point, the package directory used to
# locate base.yml, and the stdlib temp-dir helper.
train_main = train.main
MAXTEXT_PKG_DIR = maxtext_globals.MAXTEXT_PKG_DIR
gettempdir = tempfile.gettempdir
25 | 28 |
|
26 | 29 |
|
class Train(parameterized.TestCase):
  """Smoke test for MoE using ragged_dot in G3 only."""

  @parameterized.named_parameters(
      {"testcase_name": "not_quantized", "quantization": ""},
      {"testcase_name": "fp8_full", "quantization": "fp8_full"},
  )
  def test_tiny_config(self, quantization: str):
    """Runs a tiny deepseek-style MoE training job end to end.

    Args:
      quantization: Quantization mode forwarded to the config override
        (empty string means no quantization).
    """
    test_tmpdir = os.environ.get("TEST_TMPDIR", gettempdir())
    outputs_dir = os.environ.get("TEST_UNDECLARED_OUTPUTS_DIR", test_tmpdir)
    # Keep per-case outputs separate: with a shared run_name and metrics
    # path, the two parameterized cases would write into the same run
    # directory and the second case would overwrite the first's metrics.
    case_label = quantization if quantization else "not_quantized"
    metrics_path = os.path.join(outputs_dir, f"metrics_{case_label}.json")
    train_main([
        None,
        os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
        f"base_output_directory={test_tmpdir}",
        f"run_name=ragged_dot_smoke_test_{case_label}",
        "base_emb_dim=128",
        "base_num_query_heads=4",
        "base_num_kv_heads=4",
        "base_mlp_dim=128",
        "base_moe_mlp_dim=128",
        "base_num_decoder_layers=8",
        "head_dim=128",
        # TODO(b/441100085): When changing the decoder_block we might
        # need to adjust the tiling.
        "decoder_block=deepseek",
        "attention_type=mla",
        "num_experts=2",
        # Enable sparse_matmul.
        "sparse_matmul=True",
        # Enable ragged_dot.
        "megablox=False",
        # Quotes keep an empty value parsing as "" rather than None.
        f'quantization="{quantization}"',
        "use_qwix_quantization=True",
        "per_device_batch_size=2",
        "max_target_length=1024",
        "dataset_type=synthetic",
        "steps=10",
        "enable_checkpointing=False",
        "enable_goodput_recording=False",
        "enable_checkpoint_cloud_logger=False",
        "monitor_goodput=False",
        f"metrics_file={metrics_path}",
    ])
66 | 73 |
|
67 | 74 |
|
68 | 75 | if __name__ == "__main__": |
|
0 commit comments