@@ -60,22 +60,27 @@ def validate_config(config):
 
 def get_topology_mesh(config):
   """Get the target hardware devices, and create configured mesh with them"""
-  target_hardware = accelerator_to_spec_map.get_system_characteristics(config.compile_topology)
-  if target_hardware.platform == "gpu":
-    # Disable sharded autotuning. This is an optimization to distribute
-    # autotuning across the fleet, but can cause hangs with AoT compilation.
-    os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_shard_autotuning=false"
-    jax.config.update("mock_num_gpu_processes", config.compile_topology_num_slices)
-    topology_devices = jax.devices()
-  else:
+  if config.internal_compile:
     topology_devices = get_topology_desc(
-        platform=target_hardware.platform,
-        topology_name=target_hardware.topology_name,
-        chip_config_name=target_hardware.chip_config_name,
-        chips_per_host_bounds=target_hardware.chips_per_host_bounds,
-        num_slices=config.compile_topology_num_slices,
-        wrap=target_hardware.wrap,
+        platform="tpu", topology_name=config.compile_topology, num_slices=config.compile_topology_num_slices
     ).devices
+  else:
+    target_hardware = accelerator_to_spec_map.get_system_characteristics(config.compile_topology)
+    if target_hardware.platform == "gpu":
+      # Disable sharded autotuning. This is an optimization to distribute
+      # autotuning across the fleet, but can cause hangs with AoT compilation.
+      os.environ["XLA_FLAGS"] = os.environ.get("XLA_FLAGS", "") + " --xla_gpu_shard_autotuning=false"
+      jax.config.update("mock_num_gpu_processes", config.compile_topology_num_slices)
+      topology_devices = jax.devices()
+    else:
+      topology_devices = get_topology_desc(
+          platform=target_hardware.platform,
+          topology_name=target_hardware.topology_name,
+          chip_config_name=target_hardware.chip_config_name,
+          chips_per_host_bounds=target_hardware.chips_per_host_bounds,
+          num_slices=config.compile_topology_num_slices,
+          wrap=target_hardware.wrap,
+      ).devices
   if config.shard_mode == ShardMode.EXPLICIT:
     jax.config.update("jax_remove_size_one_mesh_axis_from_type", True)
   topology_device_mesh = maxtext_utils.create_device_mesh(config, topology_devices)
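In both branches above, get_topology_mesh turns an abstract topology description into a device mesh that ahead-of-time compilation can target without the hardware being attached. The snippet below is a minimal sketch of that pattern, not part of this commit: the topology name, chip_config_name, chips_per_host_bounds values, mesh axis names, and tensor shapes are illustrative placeholders (the keyword names simply mirror the get_topology_desc call in the diff).

# Hedged sketch: AOT compilation against an abstract TPU topology.
# All concrete values below are illustrative, not taken from this commit.
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental.topologies import get_topology_desc
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Abstract devices for a hypothetical single v4-8-like slice (4 chips, megacore).
topology_devices = get_topology_desc(
    platform="tpu",
    topology_name="v4:2x2x1",
    chip_config_name="megacore",
    chips_per_host_bounds=(2, 2, 1),
    num_slices=1,
).devices

# 4 abstract devices arranged into a 2x2 mesh with illustrative axis names.
mesh = Mesh(np.array(topology_devices).reshape(2, 2), ("data", "model"))

def matmul(x, w):
  return x @ w

in_shardings = (NamedSharding(mesh, P("data", None)), NamedSharding(mesh, P(None, "model")))
out_sharding = NamedSharding(mesh, P("data", "model"))

# lower() takes abstract shapes, so nothing here requires real attached devices.
compiled = (
    jax.jit(matmul, in_shardings=in_shardings, out_shardings=out_sharding)
    .lower(jax.ShapeDtypeStruct((8, 16), jnp.float32), jax.ShapeDtypeStruct((16, 8), jnp.float32))
    .compile()
)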
@@ -174,10 +179,14 @@ def is_oom(argv: Sequence[str]) -> bool:
   data_sharding = sharding.get_input_data_sharding(config, topology_mesh)
 
   # Get function to compile and shardings
-  func_to_compile, in_shard, out_shard, static_argnums, donate_argnums = (
-      maxtext_utils.get_functional_train_with_signature(
-          train.train_step, data_sharding, state_mesh_shardings, model, config
-      )
+  (
+      func_to_compile,
+      in_shard,
+      out_shard,
+      static_argnums,
+      donate_argnums,
+  ) = maxtext_utils.get_functional_train_with_signature(
+      train.train_step, data_sharding, state_mesh_shardings, model, config
   )
 
   try:
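The multi-line unpacking above is purely a formatting change; get_functional_train_with_signature still returns the function plus the pieces jax.jit needs for AOT lowering. A hedged sketch of how such a tuple is typically consumed, continuing from the names unpacked above; shaped_train_args and shaped_train_kwargs are hypothetical stand-ins for the jax.ShapeDtypeStruct inputs, which are not shown in this diff.

# Sketch only: feed the unpacked signature into jax.jit and lower ahead of time.
jitted = jax.jit(
    func_to_compile,
    in_shardings=in_shard,
    out_shardings=out_shard,
    static_argnums=static_argnums,
    donate_argnums=donate_argnums,
)
# shaped_train_args / shaped_train_kwargs: placeholder ShapeDtypeStruct pytrees.
lowered = jitted.lower(*shaped_train_args, **shaped_train_kwargs)
compiled = lowered.compile()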
@@ -255,10 +264,14 @@ def main(argv: Sequence[str]) -> None:
     donate_argnums = 0
   else:
     # Get function to compile and shardings
-    func_to_compile, in_shard, out_shard, static_argnums, donate_argnums = (
-        maxtext_utils.get_functional_train_with_signature(
-            train.train_step, data_sharding, state_mesh_shardings, model, config
-        )
+    (
+        func_to_compile,
+        in_shard,
+        out_shard,
+        static_argnums,
+        donate_argnums,
+    ) = maxtext_utils.get_functional_train_with_signature(
+        train.train_step, data_sharding, state_mesh_shardings, model, config
     )
 
     # print weights sharding info under debug sharding mode