Skip to content

Commit c10f702

Browse files
committed
Trim unit test time
1 parent e8cbb57 commit c10f702

2 files changed

Lines changed: 33 additions & 47 deletions

File tree

tests/unit/moe_test.py

Lines changed: 9 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023–2025 Google LLC
1+
# Copyright 2023–2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -460,6 +460,7 @@ def test_megablox(self):
460460
megablox=True,
461461
sparse_matmul=True,
462462
per_device_batch_size=1,
463+
max_target_length=128,
463464
)
464465

465466
rng = jax.random.PRNGKey(1234)
@@ -488,6 +489,7 @@ def test_ragged_dot(self):
488489
megablox=False,
489490
sparse_matmul=True,
490491
per_device_batch_size=1,
492+
max_target_length=128,
491493
)
492494

493495
rng = jax.random.PRNGKey(1234)
@@ -516,6 +518,7 @@ def test_dense(self):
516518
megablox=False,
517519
sparse_matmul=False,
518520
per_device_batch_size=1,
521+
max_target_length=128,
519522
)
520523

521524
rng = jax.random.PRNGKey(2345)
@@ -545,6 +548,7 @@ def test_megablox_expert_parallelism(self):
545548
sparse_matmul=True,
546549
per_device_batch_size=4, # TODO(b/450900273): sharding error if pdbs=1
547550
ici_expert_parallelism=4,
551+
max_target_length=128,
548552
)
549553

550554
rng = jax.random.PRNGKey(2345)
@@ -577,6 +581,7 @@ def test_moe_fsdp_two_stage_parallelism_tpu_only(self):
577581
ici_fsdp_parallelism=2,
578582
ici_fsdp_transpose_parallelism=2,
579583
moe_fsdp_use_two_stage_all_gather=True,
584+
max_target_length=128,
580585
)
581586

582587
rng = jax.random.PRNGKey(2345)
@@ -652,6 +657,7 @@ def test_megablox_context_parallelism(self):
652657
sparse_matmul=True,
653658
per_device_batch_size=1,
654659
ici_context_parallelism=4,
660+
max_target_length=128,
655661
)
656662

657663
rng = jax.random.PRNGKey(2345)
@@ -684,6 +690,7 @@ def test_megablox_expert_context_parallelism(self):
684690
ici_context_parallelism=2,
685691
ici_expert_parallelism=2,
686692
packing=False,
693+
max_target_length=128,
687694
)
688695

689696
rng = jax.random.PRNGKey(2345)
@@ -715,6 +722,7 @@ def test_megablox_expert_tensor_parallelism(self):
715722
per_device_batch_size=4,
716723
ici_tensor_parallelism=2,
717724
ici_expert_parallelism=2,
725+
max_target_length=128,
718726
)
719727

720728
rng = jax.random.PRNGKey(2345)

tests/unit/train_compile_test.py

Lines changed: 24 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023–2025 Google LLC
1+
# Copyright 2023–2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -188,7 +188,7 @@ def test_sequence_parallelism(self):
188188
"",
189189
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
190190
f"compiled_trainstep_file={compiled_trainstep_file}",
191-
"compile_topology=v5e-256",
191+
"compile_topology=v5p-64",
192192
"use_iota_embed=true",
193193
"compile_topology_num_slices=1",
194194
"ici_sequence_parallelism=16",
@@ -276,12 +276,12 @@ def test_remat_full(self):
276276
"",
277277
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
278278
f"compiled_trainstep_file={compiled_trainstep_file}",
279-
"compile_topology=v5e-256",
279+
"compile_topology=v6e-256",
280280
"compile_topology_num_slices=1",
281281
"per_device_batch_size=1",
282282
"ici_fsdp_parallelism=16",
283283
"ici_tensor_parallelism=16",
284-
"max_target_length=2048",
284+
"max_target_length=1024",
285285
"fused_qkv=true",
286286
"fused_mlp=true",
287287
"remat_policy=full",
@@ -366,7 +366,7 @@ def test_moe_dropping_bf16(self):
366366
"",
367367
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
368368
f"compiled_trainstep_file={compiled_trainstep_file}",
369-
"compile_topology=v6e-256",
369+
"compile_topology=v5p-64",
370370
"use_iota_embed=true",
371371
"compile_topology_num_slices=1",
372372
"model_name=mixtral-8x7b",
@@ -457,7 +457,7 @@ def test_moe_dense_bf16(self):
457457
"",
458458
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
459459
f"compiled_trainstep_file={compiled_trainstep_file}",
460-
"compile_topology=v6e-256",
460+
"compile_topology=v5p-64",
461461
"use_iota_embed=true",
462462
"compile_topology_num_slices=1",
463463
"model_name=mixtral-8x7b",
@@ -503,7 +503,7 @@ def test_moe_pp_bf16(self):
503503
"",
504504
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
505505
f"compiled_trainstep_file={compiled_trainstep_file}",
506-
"compile_topology=v6e-256",
506+
"compile_topology=v5p-64",
507507
"use_iota_embed=true",
508508
"compile_topology_num_slices=2",
509509
"model_name=mixtral-8x7b",
@@ -527,10 +527,10 @@ def test_moe_deepseek_scanned_bf16(self):
527527
"",
528528
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
529529
f"compiled_trainstep_file={compiled_trainstep_file}",
530-
"compile_topology=v5p-256",
530+
"compile_topology=v5p-64",
531531
"use_iota_embed=true",
532532
"compile_topology_num_slices=1",
533-
"model_name=deepseek3-671b",
533+
"model_name=deepseek3-test",
534534
"sparse_matmul=True",
535535
"megablox=False",
536536
"per_device_batch_size=2",
@@ -552,10 +552,10 @@ def test_moe_deepseek_unscanned_bf16(self):
552552
"",
553553
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
554554
f"compiled_trainstep_file={compiled_trainstep_file}",
555-
"compile_topology=v5p-256",
555+
"compile_topology=v5p-64",
556556
"use_iota_embed=true",
557557
"compile_topology_num_slices=1",
558-
"model_name=deepseek3-671b",
558+
"model_name=deepseek3-test",
559559
"sparse_matmul=True",
560560
"megablox=False",
561561
"per_device_batch_size=1",
@@ -575,10 +575,10 @@ def test_moe_deepseek_with_device_limit(self):
575575
"",
576576
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
577577
f"compiled_trainstep_file={compiled_trainstep_file}",
578-
"compile_topology=v5p-256",
578+
"compile_topology=v5p-64",
579579
"use_iota_embed=true",
580580
"compile_topology_num_slices=1",
581-
"model_name=deepseek3-671b",
581+
"model_name=deepseek3-test",
582582
"sparse_matmul=True",
583583
"megablox=False",
584584
"per_device_batch_size=1",
@@ -591,30 +591,6 @@ def test_moe_deepseek_with_device_limit(self):
591591
)
592592
)
593593

594-
@pytest.mark.cpu_only
595-
def test_moe_deepseek_without_device_limit(self):
596-
compiled_trainstep_file = "/tmp/test_moe_deepseek_without_device_limit.pickle"
597-
train_compile_main(
598-
(
599-
"",
600-
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
601-
f"compiled_trainstep_file={compiled_trainstep_file}",
602-
"compile_topology=v5p-256",
603-
"use_iota_embed=true",
604-
"compile_topology_num_slices=1",
605-
"model_name=deepseek3-671b",
606-
"sparse_matmul=True",
607-
"megablox=False",
608-
"per_device_batch_size=1",
609-
"max_target_length=1024",
610-
"attention=flash",
611-
"dtype=bfloat16",
612-
"weight_dtype=bfloat16",
613-
"n_routing_groups=-1",
614-
"topk_routing_group=-1",
615-
)
616-
)
617-
618594
@pytest.mark.cpu_only
619595
def test_moe_deepseek_pipeline_subset(self):
620596
compiled_trainstep_file = "/tmp/test_moe_deepseek_pipeline_subset.pickle"
@@ -623,15 +599,15 @@ def test_moe_deepseek_pipeline_subset(self):
623599
"",
624600
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
625601
f"compiled_trainstep_file={compiled_trainstep_file}",
626-
"compile_topology=v6e-256",
602+
"compile_topology=v5p-64",
627603
"compile_topology_num_slices=8",
628604
"use_iota_embed=true",
629-
"model_name=deepseek3-671b",
605+
"model_name=deepseek3-test",
630606
"megablox=True",
631607
"sparse_matmul=False",
632608
"capacity_factor=1",
633609
"per_device_batch_size=1",
634-
"max_target_length=2048",
610+
"max_target_length=1024",
635611
"pipeline_parallel_layers=56",
636612
"ici_expert_parallelism=16",
637613
"dcn_pipeline_parallelism=8",
@@ -646,11 +622,11 @@ def test_pipeline_subset(self):
646622
"",
647623
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
648624
f"compiled_trainstep_file={compiled_trainstep_file}",
649-
"compile_topology=v6e-256",
625+
"compile_topology=v5p-128",
650626
"compile_topology_num_slices=8",
651627
"use_iota_embed=true",
652628
"per_device_batch_size=1",
653-
"max_target_length=2048",
629+
"max_target_length=1024",
654630
"pipeline_parallel_layers=56",
655631
"base_num_decoder_layers=61", # Remainder of 5 will fail when sharded incorrectly.
656632
"ici_expert_parallelism=16",
@@ -666,15 +642,15 @@ def test_moe_llama4_17b_16e(self):
666642
"",
667643
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
668644
f"compiled_trainstep_file={compiled_trainstep_file}",
669-
"compile_topology=v5p-256",
645+
"compile_topology=v5p-128",
670646
"compile_topology_num_slices=1",
671647
"model_name=llama4-17b-16e",
672648
"per_device_batch_size=1",
673649
"max_target_length=1024",
674650
"dtype=bfloat16",
675651
"weight_dtype=bfloat16",
676652
"scan_layers=True",
677-
"ici_fsdp_parallelism=32",
653+
"ici_fsdp_parallelism=16",
678654
"ici_tensor_parallelism=4",
679655
)
680656
)
@@ -687,7 +663,7 @@ def test_moe_gpt_oss_20b_sparse_matmul(self):
687663
"",
688664
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
689665
f"compiled_trainstep_file={compiled_trainstep_file}",
690-
"compile_topology=v5p-64",
666+
"compile_topology=v5p-16",
691667
"compile_topology_num_slices=1",
692668
"model_name=gpt-oss-20b",
693669
"per_device_batch_size=1",
@@ -709,7 +685,7 @@ def test_moe_gpt_oss_20b_dense_matmul(self):
709685
"",
710686
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
711687
f"compiled_trainstep_file={compiled_trainstep_file}",
712-
"compile_topology=v5p-64",
688+
"compile_topology=v5p-16",
713689
"compile_topology_num_slices=1",
714690
"model_name=gpt-oss-20b",
715691
"per_device_batch_size=1",
@@ -767,6 +743,7 @@ def test_qwen3_next(self):
767743
"compile_topology_num_slices=1",
768744
"model_name=qwen3-next-80b-a3b",
769745
"per_device_batch_size=1",
746+
"max_target_length=1024",
770747
)
771748
)
772749

@@ -811,5 +788,6 @@ def test_olmo3_7b(self):
811788
"model_name=olmo3_7b",
812789
"per_device_batch_size=1",
813790
"scan_layers=True",
791+
"max_target_length=1024",
814792
)
815793
)

0 commit comments

Comments
 (0)