2222from maxtext .common .gcloud_stub import is_decoupled
2323from maxtext .trainers .pre_train .train import main as train_main
2424from maxtext .utils .globals import MAXTEXT_ASSETS_ROOT
25- from tests .utils .test_helpers import get_test_config_path , get_test_dataset_path , get_test_base_output_directory
25+ from tests .utils .test_helpers import (
26+ get_test_config_path ,
27+ get_test_dataset_path ,
28+ get_test_base_output_directory ,
29+ get_decoupled_parallelism_overrides ,
30+ is_rocm_backend ,
31+ )
2632
2733
2834class TrainTests (unittest .TestCase ):
@@ -37,9 +43,9 @@ class TrainTests(unittest.TestCase):
3743 _fsdp_tp4_override = []
3844 if decoupled :
3945 if dev_count >= 4 and dev_count % 4 == 0 :
40- _fsdp_tp4_override = [ f"ici_fsdp_parallelism= { dev_count // 4 } " ]
46+ _fsdp_tp4_override = get_decoupled_parallelism_overrides ( fsdp_parallelism = dev_count // 4 , as_argv = True )
4147 elif dev_count < 4 :
42- _fsdp_tp4_override = [ f"ici_fsdp_parallelism= { dev_count } " ]
48+ _fsdp_tp4_override = get_decoupled_parallelism_overrides ( fsdp_parallelism = dev_count , as_argv = True )
4349
4450 CONFIGS = {
4551 "base" : [ # short test for train.py with TFDS c4
@@ -53,7 +59,7 @@ class TrainTests(unittest.TestCase):
5359 "enable_goodput_recording=False" ,
5460 rf"tokenizer_path={ os .path .join (MAXTEXT_ASSETS_ROOT , 'tokenizers' , 'tokenizer.llama2' )} " ,
5561 ]
56- + ([ f"ici_fsdp_parallelism= { dev_count } " ] if decoupled else [] ),
62+ + get_decoupled_parallelism_overrides ( fsdp_parallelism = dev_count , as_argv = True ),
5763 "synthetic" : [ # tests base config with synthetic dataset
5864 None ,
5965 get_test_config_path (),
@@ -66,7 +72,7 @@ class TrainTests(unittest.TestCase):
6672 "dataset_type=synthetic" ,
6773 rf"tokenizer_path={ os .path .join (MAXTEXT_ASSETS_ROOT , 'tokenizers' , 'tokenizer.llama2' )} " ,
6874 ]
69- + ([ f"ici_fsdp_parallelism= { dev_count } " ] if decoupled else [] ),
75+ + get_decoupled_parallelism_overrides ( fsdp_parallelism = dev_count , as_argv = True ),
7076 "pdb_lt_1" : [ # tests base config with per_device_batch_size < 1
7177 None ,
7278 get_test_config_path (),
@@ -80,7 +86,7 @@ class TrainTests(unittest.TestCase):
8086 "ici_tensor_parallelism=4" ,
8187 rf"tokenizer_path={ os .path .join (MAXTEXT_ASSETS_ROOT , 'tokenizers' , 'tokenizer.llama2' )} " ,
8288 ]
83- + ([ f"ici_fsdp_parallelism= { dev_count } " ] if decoupled else [] ),
89+ + get_decoupled_parallelism_overrides ( fsdp_parallelism = dev_count , as_argv = True ),
8490 "tp_transpose" : [ # tests base config with ici_tensor_transpose_parallelism=4
8591 None ,
8692 get_test_config_path (),
@@ -92,7 +98,7 @@ class TrainTests(unittest.TestCase):
9298 "enable_goodput_recording=False" ,
9399 rf"tokenizer_path={ os .path .join (MAXTEXT_ASSETS_ROOT , 'tokenizers' , 'tokenizer.llama2' )} " ,
94100 ]
95- + ([ f"ici_fsdp_parallelism= { dev_count } " ] if decoupled else [] ),
101+ + get_decoupled_parallelism_overrides ( fsdp_parallelism = dev_count , as_argv = True ),
96102 "int8" : [ # tests base config with int8
97103 None ,
98104 get_test_config_path (),
@@ -105,7 +111,7 @@ class TrainTests(unittest.TestCase):
105111 "enable_goodput_recording=False" ,
106112 rf"tokenizer_path={ os .path .join (MAXTEXT_ASSETS_ROOT , 'tokenizers' , 'tokenizer.llama2' )} " ,
107113 ]
108- + ([ f"ici_fsdp_parallelism= { dev_count } " ] if decoupled else [] ),
114+ + get_decoupled_parallelism_overrides ( fsdp_parallelism = dev_count , as_argv = True ),
109115 "fp8" : [ # tests base config with fp8
110116 None ,
111117 get_test_config_path (),
@@ -118,7 +124,7 @@ class TrainTests(unittest.TestCase):
118124 "enable_goodput_recording=False" ,
119125 rf"tokenizer_path={ os .path .join (MAXTEXT_ASSETS_ROOT , 'tokenizers' , 'tokenizer.llama2' )} " ,
120126 ]
121- + ([ f"ici_fsdp_parallelism= { dev_count } " ] if decoupled else [] ),
127+ + get_decoupled_parallelism_overrides ( fsdp_parallelism = dev_count , as_argv = True ),
122128 "nanoo_fp8" : [ # tests base config with nanoo_fp8
123129 None ,
124130 get_test_config_path (),
@@ -131,7 +137,7 @@ class TrainTests(unittest.TestCase):
131137 "enable_goodput_recording=False" ,
132138 rf"tokenizer_path={ os .path .join (MAXTEXT_ASSETS_ROOT , 'tokenizers' , 'tokenizer.llama2' )} " ,
133139 ]
134- + ([ f"ici_fsdp_parallelism= { dev_count } " ] if decoupled else [] ),
140+ + get_decoupled_parallelism_overrides ( fsdp_parallelism = dev_count , as_argv = True ),
135141 "te_fp8_delayedscaling" : [ # tests base config with te_fp8_delayedscaling
136142 None ,
137143 get_test_config_path (),
@@ -144,7 +150,7 @@ class TrainTests(unittest.TestCase):
144150 "enable_goodput_recording=False" ,
145151 rf"tokenizer_path={ os .path .join (MAXTEXT_ASSETS_ROOT , 'tokenizers' , 'tokenizer.llama2' )} " ,
146152 ]
147- + ([ f"ici_fsdp_parallelism= { dev_count } " ] if decoupled else [] ),
153+ + get_decoupled_parallelism_overrides ( fsdp_parallelism = dev_count , as_argv = True ),
148154 "te_fp8_currentscaling" : [ # tests base config with te_fp8_currentscaling
149155 None ,
150156 get_test_config_path (),
@@ -157,7 +163,7 @@ class TrainTests(unittest.TestCase):
157163 "enable_goodput_recording=False" ,
158164 rf"tokenizer_path={ os .path .join (MAXTEXT_ASSETS_ROOT , 'tokenizers' , 'tokenizer.llama2' )} " ,
159165 ]
160- + ([ f"ici_fsdp_parallelism= { dev_count } " ] if decoupled else [] ),
166+ + get_decoupled_parallelism_overrides ( fsdp_parallelism = dev_count , as_argv = True ),
161167 "te_mxfp8" : [ # tests base config with te_mxfp8
162168 None ,
163169 get_test_config_path (),
@@ -170,7 +176,7 @@ class TrainTests(unittest.TestCase):
170176 "enable_goodput_recording=False" ,
171177 rf"tokenizer_path={ os .path .join (MAXTEXT_ASSETS_ROOT , 'tokenizers' , 'tokenizer.llama2' )} " ,
172178 ]
173- + ([ f"ici_fsdp_parallelism= { dev_count } " ] if decoupled else [] ),
179+ + get_decoupled_parallelism_overrides ( fsdp_parallelism = dev_count , as_argv = True ),
174180 "dropout" : [ # tests base config with dropout
175181 None ,
176182 get_test_config_path (),
@@ -185,7 +191,7 @@ class TrainTests(unittest.TestCase):
185191 "dropout_rate=0.02" ,
186192 rf"tokenizer_path={ os .path .join (MAXTEXT_ASSETS_ROOT , 'tokenizers' , 'tokenizer.llama2' )} " ,
187193 ]
188- + ([ f"ici_fsdp_parallelism= { dev_count } " ] if decoupled else [] ),
194+ + get_decoupled_parallelism_overrides ( fsdp_parallelism = dev_count , as_argv = True ),
189195 "hf_input_pipeline" : [ # test for train.py with TFDS c4, using HF input pipeline
190196 None ,
191197 get_test_config_path (),
@@ -199,7 +205,7 @@ class TrainTests(unittest.TestCase):
199205 f"hf_train_files={ dataset_path } /hf/c4/c4-train-00000-of-01637.parquet" ,
200206 "tokenizer_path=google-t5/t5-large" ,
201207 ]
202- + ([ f"ici_fsdp_parallelism= { dev_count } " ] if decoupled else [] ),
208+ + get_decoupled_parallelism_overrides ( fsdp_parallelism = dev_count , as_argv = True ),
203209 }
204210
205211 @pytest .mark .integration_test
@@ -427,7 +433,7 @@ def test_gpu_optimizer_offload(self):
427433 "enable_goodput_recording=False" ,
428434 rf"tokenizer_path={ os .path .join (MAXTEXT_ASSETS_ROOT , 'tokenizers' , 'tokenizer.llama2' )} " ,
429435 ]
430- train_main (optimizer_offload + ([ f"ici_fsdp_parallelism= { self .dev_count } " ] if self . decoupled else [] ))
436+ train_main (optimizer_offload + get_decoupled_parallelism_overrides ( fsdp_parallelism = self .dev_count , as_argv = True ))
431437
432438 @pytest .mark .integration_test
433439 @pytest .mark .gpu_only
@@ -448,7 +454,7 @@ def test_gpu_parameter_offload(self):
448454 "enable_goodput_recording=False" ,
449455 rf"tokenizer_path={ os .path .join (MAXTEXT_ASSETS_ROOT , 'tokenizers' , 'tokenizer.llama2' )} " ,
450456 ]
451- train_main (parameter_offload + ([ f"ici_fsdp_parallelism= { self .dev_count } " ] if self . decoupled else [] ))
457+ train_main (parameter_offload + get_decoupled_parallelism_overrides ( fsdp_parallelism = self .dev_count , as_argv = True ))
452458
453459 @pytest .mark .gpu_only
454460 def test_gpu_cudnn_flash_jax (self ):
@@ -567,6 +573,8 @@ def test_gpu_packed_attention(self):
567573 @pytest .mark .gpu_only
568574 @pytest .mark .skip (reason = "b/489133823. Previously transient in b/462548581." )
569575 def test_gpu_ring_attention (self ):
576+ if is_rocm_backend ():
577+ pytest .skip ("TE ring attention context parallelism not supported on ROCm." )
570578 os .environ ["NVTE_FUSED_ATTN" ] = "1" # Enable fused attention
571579 os .environ ["NVTE_FUSED_RING_ATTENTION_USE_SCAN" ] = "0" # Disable scan for ring attention
572580 ring_attention = [ # tests base config on GPU with ring attention
0 commit comments