@@ -405,7 +405,7 @@ def test_multi_axis_sharding_pass(self):
405405 multi-dimensional mesh passes the assertion.
406406 """
407407 # Create a mesh shape for a 5D mesh.
408- devices = np .array (jax .devices ()).reshape ((4 , 1 , 1 , 1 , 1 ))
408+ devices = np .array (jax .devices ()).reshape ((jax . device_count () , 1 , 1 , 1 , 1 ))
409409 mesh = Mesh (devices , self .mesh_axes )
410410
411411 # Shard across multiple axes, including the valid 'fsdp' axis.
@@ -420,7 +420,7 @@ def test_multi_axis_not_sharded_fails(self):
420420 Tests that a tensor on a complex mesh fails if it's not sharded along any
421421 of the primary valid axes (like 'fsdp').
422422 """
423- devices = np .array (jax .devices ()).reshape ((4 , 1 , 1 , 1 , 1 ))
423+ devices = np .array (jax .devices ()).reshape ((jax . device_count () , 1 , 1 , 1 , 1 ))
424424 mesh = Mesh (devices , self .mesh_axes )
425425 pspec = PartitionSpec (("sequence" , "context" ), "stage" , "tensor" , None )
426426 params = {"complex_layer" : jax .device_put (jnp .ones ((8 , 8 , 2 , 2 )), NamedSharding (mesh , pspec ))}
@@ -432,7 +432,7 @@ def test_multi_axis_mixed_sharding_fails(self):
432432 """
433433 Tests that a mix of sharded (correctly) and unsharded tensors on a complex mesh fails.
434434 """
435- devices = np .array (jax .devices ()).reshape ((4 , 1 , 1 , 1 , 1 ))
435+ devices = np .array (jax .devices ()).reshape ((jax . device_count () , 1 , 1 , 1 , 1 ))
436436 mesh = Mesh (devices , self .mesh_axes )
437437 sharded_pspec = PartitionSpec (("fsdp" , "sequence" ), "stage" , ("tensor" ), None )
438438 sharded_param = jax .device_put (jnp .ones ((8 , 8 , 2 , 2 )), NamedSharding (mesh , sharded_pspec ))
@@ -459,7 +459,7 @@ def setUp(self):
459459 self .skipTest ("This test suite requires at least 4 TPU devices" )
460460
461461 self .mesh_axes = ("fsdp" , "sequence" , "tensor" , "stage" , "context" )
462- devices = np .array (jax .devices ()).reshape ((4 , 1 , 1 , 1 , 1 ))
462+ devices = np .array (jax .devices ()).reshape ((jax . device_count () , 1 , 1 , 1 , 1 ))
463463 self .mesh = Mesh (devices , self .mesh_axes )
464464
465465 def test_multi_axis_mixed_formating (self ):
0 commit comments