Skip to content

Commit a0f63d0

Browse files
shralexNuojCheng
authored andcommitted
[DRAFT] v7x AOT support
This is still blocked by downstream issues.
1 parent f8aeead commit a0f63d0

2 files changed

Lines changed: 135 additions & 2 deletions

File tree

src/MaxText/accelerator_to_spec_map.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,107 @@ class SystemCharacteristics:
3232

3333

3434
UserFacingNameToSystemCharacteristics = {
35+
# tpu7x: two cores per chip, 192 GiB HBM per chip
36+
# TODO: chip_config_name might need to change for SparseCore
37+
"tpu7x-2": SystemCharacteristics("tpu", "tpu7x:1x1x1", "default", (1, 1, 1), 1, (False, False, False)),
38+
"tpu7x-8": SystemCharacteristics("tpu", "tpu7x:2x2x1", "default", (2, 2, 1), 4, (False, False, False)),
39+
"tpu7x-16": SystemCharacteristics("tpu", "tpu7x:2x2x2", "default", (2, 2, 1), 8, (False, False, False)),
40+
"tpu7x-32": SystemCharacteristics("tpu", "tpu7x:2x2x4", "default", (2, 2, 1), 16, (False, False, False)),
41+
"tpu7x-64": SystemCharacteristics("tpu", "tpu7x:2x4x4", "default", (2, 2, 1), 32, (False, False, False)),
42+
"tpu7x-128": SystemCharacteristics("tpu", "tpu7x:4x4x4", "default", (2, 2, 1), 64, (True, True, True)),
43+
"tpu7x-256": SystemCharacteristics("tpu", "tpu7x:4x4x8", "default", (2, 2, 1), 128, (True, True, True)),
44+
"tpu7x-384": SystemCharacteristics("tpu", "tpu7x:4x4x12", "default", (2, 2, 1), 192, (True, True, True)),
45+
"tpu7x-512": SystemCharacteristics("tpu", "tpu7x:4x8x8", "default", (2, 2, 1), 256, (True, True, True)),
46+
"tpu7x-640": SystemCharacteristics("tpu", "tpu7x:4x4x20", "default", (2, 2, 1), 320, (True, True, True)),
47+
"tpu7x-768": SystemCharacteristics("tpu", "tpu7x:4x8x12", "default", (2, 2, 1), 384, (True, True, True)),
48+
"tpu7x-896": SystemCharacteristics("tpu", "tpu7x:4x4x28", "default", (2, 2, 1), 448, (True, True, True)),
49+
"tpu7x-1024": SystemCharacteristics("tpu", "tpu7x:8x8x8", "default", (2, 2, 1), 512, (True, True, True)),
50+
"tpu7x-1152": SystemCharacteristics("tpu", "tpu7x:4x12x12", "default", (2, 2, 1), 576, (True, True, True)),
51+
"tpu7x-1280": SystemCharacteristics("tpu", "tpu7x:4x8x20", "default", (2, 2, 1), 640, (True, True, True)),
52+
"tpu7x-1408": SystemCharacteristics("tpu", "tpu7x:4x4x44", "default", (2, 2, 1), 704, (True, True, True)),
53+
"tpu7x-1536": SystemCharacteristics("tpu", "tpu7x:8x8x12", "default", (2, 2, 1), 768, (True, True, True)),
54+
"tpu7x-1664": SystemCharacteristics("tpu", "tpu7x:4x4x52", "default", (2, 2, 1), 832, (True, True, True)),
55+
"tpu7x-1792": SystemCharacteristics("tpu", "tpu7x:4x8x28", "default", (2, 2, 1), 896, (True, True, True)),
56+
"tpu7x-1920": SystemCharacteristics("tpu", "tpu7x:4x12x20", "default", (2, 2, 1), 960, (True, True, True)),
57+
"tpu7x-2048": SystemCharacteristics("tpu", "tpu7x:8x8x16", "default", (2, 2, 1), 1024, (True, True, True)),
58+
"tpu7x-2176": SystemCharacteristics("tpu", "tpu7x:4x4x68", "default", (2, 2, 1), 1088, (True, True, True)),
59+
"tpu7x-2304": SystemCharacteristics("tpu", "tpu7x:8x12x12", "default", (2, 2, 1), 1152, (True, True, True)),
60+
"tpu7x-2432": SystemCharacteristics("tpu", "tpu7x:4x4x76", "default", (2, 2, 1), 1216, (True, True, True)),
61+
"tpu7x-2560": SystemCharacteristics("tpu", "tpu7x:8x8x20", "default", (2, 2, 1), 1280, (True, True, True)),
62+
"tpu7x-2688": SystemCharacteristics("tpu", "tpu7x:4x12x28", "default", (2, 2, 1), 1344, (True, True, True)),
63+
"tpu7x-2816": SystemCharacteristics("tpu", "tpu7x:4x8x44", "default", (2, 2, 1), 1408, (True, True, True)),
64+
"tpu7x-2944": SystemCharacteristics("tpu", "tpu7x:4x4x92", "default", (2, 2, 1), 1472, (True, True, True)),
65+
"tpu7x-3072": SystemCharacteristics("tpu", "tpu7x:8x12x16", "default", (2, 2, 1), 1536, (True, True, True)),
66+
"tpu7x-3200": SystemCharacteristics("tpu", "tpu7x:4x20x20", "default", (2, 2, 1), 1600, (True, True, True)),
67+
"tpu7x-3328": SystemCharacteristics("tpu", "tpu7x:4x8x52", "default", (2, 2, 1), 1664, (True, True, True)),
68+
"tpu7x-3456": SystemCharacteristics("tpu", "tpu7x:12x12x12", "default", (2, 2, 1), 1728, (True, True, True)),
69+
"tpu7x-3584": SystemCharacteristics("tpu", "tpu7x:8x8x28", "default", (2, 2, 1), 1792, (True, True, True)),
70+
"tpu7x-3712": SystemCharacteristics("tpu", "tpu7x:4x4x116", "default", (2, 2, 1), 1856, (True, True, True)),
71+
"tpu7x-3840": SystemCharacteristics("tpu", "tpu7x:8x12x20", "default", (2, 2, 1), 1920, (True, True, True)),
72+
"tpu7x-3968": SystemCharacteristics("tpu", "tpu7x:4x4x124", "default", (2, 2, 1), 1984, (True, True, True)),
73+
"tpu7x-4096": SystemCharacteristics("tpu", "tpu7x:8x16x16", "default", (2, 2, 1), 2048, (True, True, True)),
74+
"tpu7x-4224": SystemCharacteristics("tpu", "tpu7x:4x12x44", "default", (2, 2, 1), 2112, (True, True, True)),
75+
"tpu7x-4352": SystemCharacteristics("tpu", "tpu7x:4x8x68", "default", (2, 2, 1), 2176, (True, True, True)),
76+
"tpu7x-4480": SystemCharacteristics("tpu", "tpu7x:4x20x28", "default", (2, 2, 1), 2240, (True, True, True)),
77+
"tpu7x-4608": SystemCharacteristics("tpu", "tpu7x:12x12x16", "default", (2, 2, 1), 2304, (True, True, True)),
78+
"tpu7x-4736": SystemCharacteristics("tpu", "tpu7x:4x4x148", "default", (2, 2, 1), 2368, (True, True, True)),
79+
"tpu7x-4864": SystemCharacteristics("tpu", "tpu7x:4x8x76", "default", (2, 2, 1), 2432, (True, True, True)),
80+
"tpu7x-4992": SystemCharacteristics("tpu", "tpu7x:4x12x52", "default", (2, 2, 1), 2496, (True, True, True)),
81+
"tpu7x-5120": SystemCharacteristics("tpu", "tpu7x:8x16x20", "default", (2, 2, 1), 2560, (True, True, True)),
82+
"tpu7x-5248": SystemCharacteristics("tpu", "tpu7x:4x4x164", "default", (2, 2, 1), 2624, (True, True, True)),
83+
"tpu7x-5376": SystemCharacteristics("tpu", "tpu7x:8x12x28", "default", (2, 2, 1), 2688, (True, True, True)),
84+
"tpu7x-5504": SystemCharacteristics("tpu", "tpu7x:4x4x172", "default", (2, 2, 1), 2752, (True, True, True)),
85+
"tpu7x-5632": SystemCharacteristics("tpu", "tpu7x:8x8x44", "default", (2, 2, 1), 2816, (True, True, True)),
86+
"tpu7x-5760": SystemCharacteristics("tpu", "tpu7x:12x12x20", "default", (2, 2, 1), 2880, (True, True, True)),
87+
"tpu7x-5888": SystemCharacteristics("tpu", "tpu7x:4x8x92", "default", (2, 2, 1), 2944, (True, True, True)),
88+
"tpu7x-6016": SystemCharacteristics("tpu", "tpu7x:4x4x188", "default", (2, 2, 1), 3008, (True, True, True)),
89+
"tpu7x-6144": SystemCharacteristics("tpu", "tpu7x:12x16x16", "default", (2, 2, 1), 3072, (True, True, True)),
90+
"tpu7x-6272": SystemCharacteristics("tpu", "tpu7x:4x28x28", "default", (2, 2, 1), 3136, (True, True, True)),
91+
"tpu7x-6400": SystemCharacteristics("tpu", "tpu7x:8x20x20", "default", (2, 2, 1), 3200, (True, True, True)),
92+
"tpu7x-6528": SystemCharacteristics("tpu", "tpu7x:4x12x68", "default", (2, 2, 1), 3264, (True, True, True)),
93+
"tpu7x-6656": SystemCharacteristics("tpu", "tpu7x:8x8x52", "default", (2, 2, 1), 3328, (True, True, True)),
94+
"tpu7x-6784": SystemCharacteristics("tpu", "tpu7x:4x4x212", "default", (2, 2, 1), 3392, (True, True, True)),
95+
"tpu7x-6912": SystemCharacteristics("tpu", "tpu7x:12x12x24", "default", (2, 2, 1), 3456, (True, True, True)),
96+
"tpu7x-7040": SystemCharacteristics("tpu", "tpu7x:4x20x44", "default", (2, 2, 1), 3520, (True, True, True)),
97+
"tpu7x-7168": SystemCharacteristics("tpu", "tpu7x:8x16x28", "default", (2, 2, 1), 3584, (True, True, True)),
98+
"tpu7x-7296": SystemCharacteristics("tpu", "tpu7x:4x12x76", "default", (2, 2, 1), 3648, (True, True, True)),
99+
"tpu7x-7424": SystemCharacteristics("tpu", "tpu7x:4x8x116", "default", (2, 2, 1), 3712, (True, True, True)),
100+
"tpu7x-7552": SystemCharacteristics("tpu", "tpu7x:4x4x236", "default", (2, 2, 1), 3776, (True, True, True)),
101+
"tpu7x-7680": SystemCharacteristics("tpu", "tpu7x:12x16x20", "default", (2, 2, 1), 3840, (True, True, True)),
102+
"tpu7x-7808": SystemCharacteristics("tpu", "tpu7x:4x4x244", "default", (2, 2, 1), 3904, (True, True, True)),
103+
"tpu7x-7936": SystemCharacteristics("tpu", "tpu7x:4x8x124", "default", (2, 2, 1), 3968, (True, True, True)),
104+
"tpu7x-8064": SystemCharacteristics("tpu", "tpu7x:12x12x28", "default", (2, 2, 1), 4032, (True, True, True)),
105+
"tpu7x-8192": SystemCharacteristics("tpu", "tpu7x:16x16x16", "default", (2, 2, 1), 4096, (True, True, True)),
106+
"tpu7x-8320": SystemCharacteristics("tpu", "tpu7x:4x20x52", "default", (2, 2, 1), 4160, (True, True, True)),
107+
"tpu7x-8448": SystemCharacteristics("tpu", "tpu7x:8x12x44", "default", (2, 2, 1), 4224, (True, True, True)),
108+
"tpu7x-8704": SystemCharacteristics("tpu", "tpu7x:8x8x68", "default", (2, 2, 1), 4352, (True, True, True)),
109+
"tpu7x-8832": SystemCharacteristics("tpu", "tpu7x:4x12x92", "default", (2, 2, 1), 4416, (True, True, True)),
110+
"tpu7x-8960": SystemCharacteristics("tpu", "tpu7x:8x20x28", "default", (2, 2, 1), 4480, (True, True, True)),
111+
"tpu7x-9216": SystemCharacteristics("tpu", "tpu7x:12x16x24", "default", (2, 2, 1), 4608, (True, True, True)),
112+
"tpu7x-9472": SystemCharacteristics("tpu", "tpu7x:4x8x148", "default", (2, 2, 1), 4736, (True, True, True)),
113+
"tpu7x-9600": SystemCharacteristics("tpu", "tpu7x:12x20x20", "default", (2, 2, 1), 4800, (True, True, True)),
114+
"tpu7x-9728": SystemCharacteristics("tpu", "tpu7x:8x8x76", "default", (2, 2, 1), 4864, (True, True, True)),
115+
"tpu7x-9856": SystemCharacteristics("tpu", "tpu7x:4x28x44", "default", (2, 2, 1), 4928, (True, True, True)),
116+
"tpu7x-9984": SystemCharacteristics("tpu", "tpu7x:8x12x52", "default", (2, 2, 1), 4992, (True, True, True)),
117+
"tpu7x-10240": SystemCharacteristics("tpu", "tpu7x:16x16x20", "default", (2, 2, 1), 5120, (True, True, True)),
118+
"tpu7x-10368": SystemCharacteristics("tpu", "tpu7x:12x12x36", "default", (2, 2, 1), 5184, (True, True, True)),
119+
"tpu7x-10496": SystemCharacteristics("tpu", "tpu7x:4x8x164", "default", (2, 2, 1), 5248, (True, True, True)),
120+
"tpu7x-10752": SystemCharacteristics("tpu", "tpu7x:12x16x28", "default", (2, 2, 1), 5376, (True, True, True)),
121+
"tpu7x-10880": SystemCharacteristics("tpu", "tpu7x:4x20x68", "default", (2, 2, 1), 5440, (True, True, True)),
122+
"tpu7x-11008": SystemCharacteristics("tpu", "tpu7x:4x8x172", "default", (2, 2, 1), 5504, (True, True, True)),
123+
"tpu7x-11136": SystemCharacteristics("tpu", "tpu7x:4x12x116", "default", (2, 2, 1), 5568, (True, True, True)),
124+
"tpu7x-11264": SystemCharacteristics("tpu", "tpu7x:8x16x44", "default", (2, 2, 1), 5632, (True, True, True)),
125+
"tpu7x-11520": SystemCharacteristics("tpu", "tpu7x:12x20x24", "default", (2, 2, 1), 5760, (True, True, True)),
126+
"tpu7x-11648": SystemCharacteristics("tpu", "tpu7x:4x28x52", "default", (2, 2, 1), 5824, (True, True, True)),
127+
"tpu7x-11776": SystemCharacteristics("tpu", "tpu7x:8x8x92", "default", (2, 2, 1), 5888, (True, True, True)),
128+
"tpu7x-11904": SystemCharacteristics("tpu", "tpu7x:4x12x124", "default", (2, 2, 1), 5952, (True, True, True)),
129+
"tpu7x-12032": SystemCharacteristics("tpu", "tpu7x:4x8x188", "default", (2, 2, 1), 6016, (True, True, True)),
130+
"tpu7x-12160": SystemCharacteristics("tpu", "tpu7x:4x20x76", "default", (2, 2, 1), 6080, (True, True, True)),
131+
"tpu7x-12288": SystemCharacteristics("tpu", "tpu7x:16x16x24", "default", (2, 2, 1), 6144, (True, True, True)),
132+
"tpu7x-13824": SystemCharacteristics("tpu", "tpu7x:12x24x24", "default", (2, 2, 1), 6912, (True, True, True)),
133+
"tpu7x-16384": SystemCharacteristics("tpu", "tpu7x:16x16x32", "default", (2, 2, 1), 8192, (True, True, True)),
134+
"tpu7x-17920": SystemCharacteristics("tpu", "tpu7x:16x20x28", "default", (2, 2, 1), 8960, (True, True, True)),
135+
"tpu7x-18432": SystemCharacteristics("tpu", "tpu7x:16x24x24", "default", (2, 2, 1), 9216, (True, True, True)),
35136
# v6e: one core per chip with 32 GB HBM
36137
"v6e-1": SystemCharacteristics("tpu", "v6e:1x1", "default", (1, 1, 1), 1, (False, False, False)),
37138
"v6e-4": SystemCharacteristics("tpu", "v6e:2x2", "default", (2, 2, 1), 4, (False, False, False)),

tests/train_compile_test.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,6 @@ def test_save_compiled_v5p_two_slices(self):
124124
)
125125
)
126126

127-
# TODO (b/374764692) : Enable when v6e AOT test when stable Jax supports v6e AOT.
128-
@pytest.mark.skip(reason="Enable when downstream v6e AOT support reaches stable Jax.")
129127
@pytest.mark.cpu_only
130128
def test_save_compiled_v6e(self):
131129
temp_dir = gettempdir()
@@ -143,6 +141,40 @@ def test_save_compiled_v6e(self):
143141
)
144142
)
145143

144+
@pytest.mark.cpu_only
145+
def test_save_compiled_tpu7x(self):
146+
temp_dir = gettempdir()
147+
compiled_trainstep_file = os.path.join(temp_dir, "test_compiled_tpu7x.pickle")
148+
train_compile_main(
149+
(
150+
None,
151+
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
152+
f"compiled_trainstep_file={compiled_trainstep_file}",
153+
"compile_topology=tpu7x-16",
154+
"compile_topology_num_slices=1",
155+
"base_emb_dim=256",
156+
"base_mlp_dim=256",
157+
"base_num_decoder_layers=2",
158+
)
159+
)
160+
161+
@pytest.mark.cpu_only
162+
def test_save_compiled_tpu7x_two_slices(self):
163+
temp_dir = gettempdir()
164+
compiled_trainstep_file = os.path.join(temp_dir, "test_compiled_tpu7x_two_slices.pickle")
165+
train_compile_main(
166+
(
167+
None,
168+
os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"),
169+
f"compiled_trainstep_file={compiled_trainstep_file}",
170+
"compile_topology=tpu7x-8",
171+
"compile_topology_num_slices=2",
172+
"base_emb_dim=256",
173+
"base_mlp_dim=256",
174+
"base_num_decoder_layers=2",
175+
)
176+
)
177+
146178
@pytest.mark.cpu_only
147179
def test_sequence_parallelism(self):
148180
temp_dir = gettempdir()

0 commit comments

Comments
 (0)