Skip to content

Commit 05bf24c

Browse files
Merge pull request #2938 from AI-Hypercomputer:xibin/nnx
PiperOrigin-RevId: 856729133
2 parents d4fe93a + 1cdddd3 commit 05bf24c

3 files changed

Lines changed: 19 additions & 10 deletions

File tree

src/MaxText/maxtext_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -764,8 +764,12 @@ def init_initial_state(model, tx, config, is_training, key):
     image_shape = multimodal_utils.get_dummy_image_shape_for_init(
         config.model_name, batch_size=config.micro_batch_size_to_train_on
     )
+    # Split the master key into independent keys for each RNG collection
+    # Reference: https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/rng_guide.html
+    params_key, dropout_key, aqt_key = jax.random.split(key, 3)
+
     model_vars = model.init(
-        {"params": key, "dropout": key, "aqt": key},
+        {"params": params_key, "dropout": dropout_key, "aqt": aqt_key},
         np.ones(input_shape, dtype=jnp.int32),
         np.ones(input_shape, dtype=jnp.int32),
         encoder_images=np.ones(image_shape, dtype=jnp.int32) if config.use_multimodal else None,

tests/data_loader_test.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ def setUp(self):
         enable_rampup_batch_size=True,
         per_device_batch_size_start=1.0,
         per_device_batch_size_increment=1.0,
-        global_rampup_samples=60,
+        # global_rampup_samples: (rampup increment number) * (Samples for initial 5 steps)
+        global_rampup_samples=3 * (1 * jax.device_count() * 5),
     )
     self.mesh = Mesh(create_device_mesh(self.config), self.config.mesh_axes)
     self.mock_data_iterator = MagicMock()
@@ -140,8 +141,10 @@ def test_rampup_data_loader(self):

     # Expected batch sizes based on test config.
     # The end global batch size is self.num_devices * per_device_batch_size
-    # The rampup should be: 5 steps of size 4, 3 steps of size 8, 2 steps of size 12, then size 16.
-    expected_batch_sizes = [4, 4, 4, 4, 4, 8, 8, 8, 12, 12, 16, 16]
+    # The rampup of per_device_batch_size should be:
+    # 5 steps of size 1, 3 steps of size 2, 2 steps of size 3, then size 4.
+    multipliers = [1] * 5 + [2] * 3 + [3] * 2 + [4] * 2
+    expected_batch_sizes = [m * self.config_rampup.num_target_devices for m in multipliers]
     for i, expected_size in enumerate(expected_batch_sizes):
       batch = data_loader.load_next_batch(rampup_manager=rampup_manager)
       expected_shape = (expected_size, self.config_rampup.max_target_length)
@@ -168,8 +171,10 @@ def test_rampup_data_loader_from_checkpointing(self):

     # Expected batch sizes based on test config.
     # The end global batch size is self.num_devices * per_device_batch_size
-    # The rampup should be: 3 steps of size 8, 2 steps of size 12, then size 16.
-    expected_batch_sizes = [8, 8, 8, 12, 12, 16, 16]
+    # The rampup of per_device_batch_size should be:
+    # 3 steps of size 2, 2 steps of size 3, then size 4.
+    multipliers = [2] * 3 + [3] * 2 + [4] * 2
+    expected_batch_sizes = [m * self.config_rampup.num_target_devices for m in multipliers]
     for i, expected_size in enumerate(expected_batch_sizes):
       batch = data_loader.load_next_batch(rampup_manager=rampup_manager)
       expected_shape = (expected_size, self.config_rampup.max_target_length)

tests/maxtext_utils_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ def test_multi_axis_sharding_pass(self):
     multi-dimensional mesh passes the assertion.
     """
     # Create a mesh shape for a 5D mesh.
-    devices = np.array(jax.devices()).reshape((4, 1, 1, 1, 1))
+    devices = np.array(jax.devices()).reshape((jax.device_count(), 1, 1, 1, 1))
     mesh = Mesh(devices, self.mesh_axes)

     # Shard across multiple axes, including the valid 'fsdp' axis.
@@ -420,7 +420,7 @@ def test_multi_axis_not_sharded_fails(self):
     Tests that a tensor on a complex mesh fails if it's not sharded along any
     of the primary valid axes (like 'fsdp').
     """
-    devices = np.array(jax.devices()).reshape((4, 1, 1, 1, 1))
+    devices = np.array(jax.devices()).reshape((jax.device_count(), 1, 1, 1, 1))
     mesh = Mesh(devices, self.mesh_axes)
     pspec = PartitionSpec(("sequence", "context"), "stage", "tensor", None)
     params = {"complex_layer": jax.device_put(jnp.ones((8, 8, 2, 2)), NamedSharding(mesh, pspec))}
@@ -432,7 +432,7 @@ def test_multi_axis_mixed_sharding_fails(self):
     """
     Tests that a mix of sharded (correctly) and unsharded tensors on a complex mesh fails.
     """
-    devices = np.array(jax.devices()).reshape((4, 1, 1, 1, 1))
+    devices = np.array(jax.devices()).reshape((jax.device_count(), 1, 1, 1, 1))
     mesh = Mesh(devices, self.mesh_axes)
     sharded_pspec = PartitionSpec(("fsdp", "sequence"), "stage", ("tensor"), None)
     sharded_param = jax.device_put(jnp.ones((8, 8, 2, 2)), NamedSharding(mesh, sharded_pspec))
@@ -459,7 +459,7 @@ def setUp(self):
       self.skipTest("This test suite requires at least 4 TPU devices")

     self.mesh_axes = ("fsdp", "sequence", "tensor", "stage", "context")
-    devices = np.array(jax.devices()).reshape((4, 1, 1, 1, 1))
+    devices = np.array(jax.devices()).reshape((jax.device_count(), 1, 1, 1, 1))
     self.mesh = Mesh(devices, self.mesh_axes)

   def test_multi_axis_mixed_formating(self):

0 commit comments

Comments (0)