 import tensorflow as tf
 import jax.numpy as jnp
 import jax
-from jax.sharding import PartitionSpec as P
+from jax.sharding import Mesh, PartitionSpec as P
 from flax import nnx
 from maxdiffusion.schedulers import FlaxFlowMatchScheduler
 from flax.linen import partitioning as nn_partitioning
 from flax.training import train_state
 from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
 from jax.experimental import multihost_utils
+from maxdiffusion.max_utils import create_device_mesh
+import copy
 
+class EvalConfig:
+  pass
 
 class TrainState(train_state.TrainState):
   graphdef: nnx.GraphDef
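Note: `EvalConfig` is intentionally an empty class. It acts as a plain attribute container that duck-types the training config, so `create_device_mesh(eval_config)` can read the same parallelism fields from it (see the `training_loop` hunk below). A minimal sketch of the same pattern using only the standard library; the field values are illustrative:

```python
from types import SimpleNamespace

# Any object exposing the fields that create_device_mesh reads will do,
# since attribute access is duck-typed; SimpleNamespace is the stdlib
# equivalent of an empty class used as an attribute bag.
eval_config = SimpleNamespace(
    ici_data_parallelism=4,    # illustrative: shard the eval batch 4 ways
    ici_fsdp_parallelism=1,    # no FSDP sharding during eval
    ici_tensor_parallelism=1,  # no tensor parallelism during eval
)
```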
@@ -212,7 +216,7 @@ def start_training(self):
 
     pipeline = self.load_checkpoint()
     # Generate a sample before training to compare against the sample generated after training.
-    pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
+    # pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
 
     if self.config.eval_every == -1 or (not self.config.enable_generate_video_for_eval):
       # save some memory.
@@ -230,8 +234,8 @@ def start_training(self):
     # Returns pipeline with trained transformer state
     pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)
 
-    posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
-    print_ssim(pretrained_video_path, posttrained_video_path)
+    # posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
+    # print_ssim(pretrained_video_path, posttrained_video_path)
 
   def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
     mesh = pipeline.mesh
@@ -246,7 +250,19 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
     state = jax.lax.with_sharding_constraint(state, state_spec)
     state_shardings = nnx.get_named_sharding(state, mesh)
     data_shardings = self.get_data_shardings(mesh)
-    eval_data_shardings = self.get_eval_data_shardings(mesh)
+
+    single_batch_size = min(self.config.eval_max_processed_batch_size, self.config.global_batch_size_to_train_on)
+    eval_config = EvalConfig()
+    eval_config.dcn_data_parallelism = self.config.dcn_data_parallelism
+    eval_config.dcn_fsdp_parallelism = self.config.dcn_fsdp_parallelism
+    eval_config.dcn_tensor_parallelism = self.config.dcn_tensor_parallelism
+    eval_config.ici_data_parallelism = single_batch_size
+    eval_config.ici_fsdp_parallelism = 1
+    eval_config.ici_tensor_parallelism = 1
+    eval_config.allow_split_physical_axes = self.config.allow_split_physical_axes
+    eval_devices_array = create_device_mesh(eval_config)
+    eval_mesh = Mesh(eval_devices_array, self.config.mesh_axes)
+    eval_data_shardings = self.get_eval_data_shardings(eval_mesh)
 
     writer = max_utils.initialize_summary_writer(self.config)
     writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True)
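The hunk above builds a second, purely data-parallel mesh for evaluation: `ici_data_parallelism` is set to the eval chunk size while the FSDP and tensor axes collapse to 1, so each eval chunk splits along the batch dimension only. A self-contained sketch of the underlying mechanism; the axis names here are assumptions, not necessarily what `self.config.mesh_axes` contains:

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Lay all devices along a single "data" axis; the other two mesh axes
# have size 1, mirroring ici_fsdp/ici_tensor parallelism of 1.
devices = np.asarray(jax.devices()).reshape(-1, 1, 1)
eval_mesh = Mesh(devices, ("data", "fsdp", "tensor"))

# Shard the leading (batch) dimension across "data"; replicate the rest.
batch_sharding = NamedSharding(eval_mesh, P("data"))
batch = jax.device_put(np.zeros((8, 128), np.float32), batch_sharding)
```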
@@ -327,25 +343,39 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
         # Loop indefinitely until the iterator is exhausted
         while True:
           try:
-            with mesh:
-              eval_start_time = datetime.datetime.now()
-              eval_batch = load_next_batch(eval_data_iterator, None, self.config)
-              metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
-              losses = metrics["scalar"]["learning/eval_loss"]
-              timesteps = eval_batch["timesteps"]
-              gathered_timesteps_on_device = multihost_utils.process_allgather(timesteps)
-              gathered_timesteps = jax.device_get(gathered_timesteps_on_device)
-              gathered_losses_on_device = multihost_utils.process_allgather(losses)
-              gathered_losses = jax.device_get(gathered_losses_on_device)
-              for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()):
+            eval_start_time = datetime.datetime.now()
+            eval_batch = load_next_batch(eval_data_iterator, None, self.config)
+            bs = len(eval_batch["latents"])
+            for i in range(0, bs, single_batch_size):
+              eval_step_start_time = datetime.datetime.now()
+              start = i
+              end = min(i + single_batch_size, bs)
+              timesteps = eval_batch["timesteps"][start:end]
+              chunk_eval_batch = {
+                  "latents": eval_batch["latents"][start:end, :],
+                  "encoder_hidden_states": eval_batch["encoder_hidden_states"][start:end, :],
+                  "timesteps": timesteps,
+              }
+              with eval_mesh:
+                metrics, eval_rng = p_eval_step(state, chunk_eval_batch, eval_rng, scheduler_state)
+              losses = metrics["scalar"]["learning/eval_loss"]
+              # gathered_timesteps_on_device = multihost_utils.process_allgather(timesteps)
+              # gathered_timesteps = jax.device_get(gathered_timesteps_on_device)
+              gathered_losses_on_device = multihost_utils.process_allgather(losses)
+              gathered_losses = jax.device_get(gathered_losses_on_device)
+              for t, l in zip(timesteps.flatten(), gathered_losses.flatten()):
                 timestep = int(t)
                 if timestep not in eval_losses_by_timestep:
                   eval_losses_by_timestep[timestep] = []
                 eval_losses_by_timestep[timestep].append(l)
+              eval_step_end_time = datetime.datetime.now()
+              eval_step_duration = eval_step_end_time - eval_step_start_time
+              if jax.process_index() == 0:
+                max_logging.log(f"  Eval step processed batch slice {start}:{end} in {eval_step_duration.total_seconds():.2f} seconds.")
             eval_end_time = datetime.datetime.now()
             eval_duration = eval_end_time - eval_start_time
             if jax.process_index() == 0:
-              max_logging.log(f"  Eval step time {eval_duration.total_seconds():.2f} seconds.")
+              max_logging.log(f"  Eval total time {eval_duration.total_seconds():.2f} seconds.")
           except StopIteration:
             # This block is executed when the iterator has no more data
             break
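The chunked loop above replaces one large `p_eval_step` call with several fixed-size calls so peak memory stays bounded. A hedged sketch of the slicing pattern as a reusable helper (the helper name is illustrative, not from the diff); note that when `bs` is not a multiple of the chunk size, the final shorter chunk has a new shape and will trigger one extra jit compilation of the eval step:

```python
# Illustrative helper: split a dict-of-arrays batch into fixed-size
# slices along the leading (batch) dimension; the last chunk may be short.
def iter_chunks(batch, chunk_size):
  batch_size = len(batch["latents"])
  for start in range(0, batch_size, chunk_size):
    end = min(start + chunk_size, batch_size)
    yield {key: value[start:end] for key, value in batch.items()}

# Usage, mirroring the loop above:
# for chunk in iter_chunks(eval_batch, single_batch_size):
#   metrics, eval_rng = p_eval_step(state, chunk, eval_rng, scheduler_state)
```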
@@ -440,11 +470,14 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
 
   # The loss function logic is identical to training. We are evaluating the model's
   # ability to perform its core training objective (e.g., denoising).
-  @jax.jit
-  def loss_fn(params, latents, encoder_hidden_states, timesteps, rng):
+  def loss_fn(params):
     # Reconstruct the model from its definition and parameters
     model = nnx.merge(state.graphdef, params, state.rest_of_state)
 
+    latents = data["latents"].astype(config.weights_dtype)
+    encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
+    timesteps = data["timesteps"].astype("int64")
+
     noise = jax.random.normal(key=rng, shape=latents.shape, dtype=latents.dtype)
     noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
 
@@ -468,18 +501,8 @@ def loss_fn(params, latents, encoder_hidden_states, timesteps, rng):
 
   # --- Key Difference from train_step ---
   # Directly compute the loss without calculating gradients.
   # The model's state.params are used but not updated.
-  bs = len(data["latents"])
-  single_batch_size = min(config.eval_max_processed_batch_size, config.global_batch_size_to_train_on)
-  losses = jnp.zeros(bs)
-  for i in range(0, bs, single_batch_size):
-    start = i
-    end = min(i + single_batch_size, bs)
-    latents = data["latents"][start:end, :].astype(config.weights_dtype)
-    encoder_hidden_states = data["encoder_hidden_states"][start:end, :].astype(config.weights_dtype)
-    timesteps = data["timesteps"][start:end].astype("int64")
-    _, new_rng = jax.random.split(rng, num=2)
-    loss = loss_fn(state.params, latents, encoder_hidden_states, timesteps, new_rng)
-    losses = losses.at[start:end].set(loss)
+  _, new_rng = jax.random.split(rng, num=2)
+  losses = loss_fn(state.params)
 
   # Structure the metrics for logging and aggregation
   metrics = {"scalar": {"learning/eval_loss": losses}}
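Because each `p_eval_step` call now sees exactly one chunk, `loss_fn` can return a per-sample loss vector directly and the old host-side `losses.at[start:end].set(...)` accumulation disappears. For illustration, a minimal per-sample denoising MSE of the kind that could populate `learning/eval_loss`; the reduction shown is an assumption, since the true target comes from the flow-match scheduler:

```python
import jax.numpy as jnp

def per_sample_mse(prediction, target):
  # Average the squared error over every axis except the leading batch
  # axis, producing one scalar loss per sample: shape (batch,).
  reduce_axes = tuple(range(1, prediction.ndim))
  return jnp.mean((prediction - target) ** 2, axis=reduce_axes)
```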