@@ -82,7 +82,6 @@ def setUpClass(cls):
8282 "attention=flash" ,
8383 "remat_policy=FULL" ,
8484 "allow_split_physical_axes=True" ,
85- "skip_jax_distributed_system=True" ,
8685 "weights_dtype=bfloat16" ,
8786 "activations_dtype=bfloat16" ,
8887 "per_device_batch_size=0.25" ,
@@ -158,7 +157,6 @@ def setUpClass(cls):
158157 "attention=flash" ,
159158 "remat_policy=FULL" ,
160159 "allow_split_physical_axes=True" ,
161- "skip_jax_distributed_system=True" ,
162160 "weights_dtype=bfloat16" ,
163161 "activations_dtype=bfloat16" ,
164162 "per_device_batch_size=0.25" ,
@@ -238,7 +236,6 @@ def setUpClass(cls):
238236 "attention=flash" ,
239237 "remat_policy=FULL" ,
240238 "allow_split_physical_axes=True" ,
241- "skip_jax_distributed_system=True" ,
242239 "weights_dtype=bfloat16" ,
243240 "activations_dtype=bfloat16" ,
244241 "per_device_batch_size=0.25" ,
@@ -319,7 +316,6 @@ def setUpClass(cls):
319316 "attention=flash" ,
320317 "remat_policy=FULL" ,
321318 "allow_split_physical_axes=True" ,
322- "skip_jax_distributed_system=True" ,
323319 "weights_dtype=bfloat16" ,
324320 "activations_dtype=bfloat16" ,
325321 "per_device_batch_size=0.25" ,
0 commit comments