
Commit 3a44f61

add test which exits wrong for bs; should be 128
1 parent 9e74521 commit 3a44f61

1 file changed

Lines changed: 56 additions & 33 deletions

File tree

src/maxdiffusion/trainers/wan_trainer.py

@@ -23,7 +23,7 @@
 import tensorflow as tf
 import jax.numpy as jnp
 import jax
-from jax.sharding import PartitionSpec as P
+from jax.sharding import Mesh, PartitionSpec as P
 from flax import nnx
 from maxdiffusion.schedulers import FlaxFlowMatchScheduler
 from flax.linen import partitioning as nn_partitioning
@@ -39,7 +39,11 @@
 from flax.training import train_state
 from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
 from jax.experimental import multihost_utils
+from maxdiffusion.max_utils import create_device_mesh
+import copy
 
+class EvalConfig:
+  pass
 
 class TrainState(train_state.TrainState):
   graphdef: nnx.GraphDef
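
The added EvalConfig is an empty class used purely as a mutable attribute bag that gets passed to create_device_mesh in a later hunk. For illustration only, the standard library's types.SimpleNamespace gives the same behavior; the field values below are placeholders, not the commit's:

import types

# A sketch of the same attribute-bag idea; values are placeholders.
eval_config = types.SimpleNamespace(
    ici_data_parallelism=4,
    ici_fsdp_parallelism=1,
    ici_tensor_parallelism=1,
)
print(eval_config.ici_data_parallelism)  # plain attribute access, as with EvalConfig()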
@@ -212,7 +216,7 @@ def start_training(self):
 
     pipeline = self.load_checkpoint()
     # Generate a sample before training to compare against generated sample after training.
-    pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
+    # pretrained_video_path = generate_sample(self.config, pipeline, filename_prefix="pre-training-")
 
     if self.config.eval_every == -1 or (not self.config.enable_generate_video_for_eval):
       # save some memory.
@@ -230,8 +234,8 @@ def start_training(self):
     # Returns pipeline with trained transformer state
     pipeline = self.training_loop(pipeline, optimizer, learning_rate_scheduler, train_data_iterator)
 
-    posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
-    print_ssim(pretrained_video_path, posttrained_video_path)
+    # posttrained_video_path = generate_sample(self.config, pipeline, filename_prefix="post-training-")
+    # print_ssim(pretrained_video_path, posttrained_video_path)
 
   def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data_iterator):
     mesh = pipeline.mesh
@@ -246,7 +250,19 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
     state = jax.lax.with_sharding_constraint(state, state_spec)
     state_shardings = nnx.get_named_sharding(state, mesh)
     data_shardings = self.get_data_shardings(mesh)
-    eval_data_shardings = self.get_eval_data_shardings(mesh)
+
+    single_batch_size = min(self.config.eval_max_processed_batch_size, self.config.global_batch_size_to_train_on)
+    eval_config = EvalConfig()
+    eval_config.dcn_data_parallelism = self.config.dcn_data_parallelism
+    eval_config.dcn_fsdp_parallelism = self.config.dcn_fsdp_parallelism
+    eval_config.dcn_tensor_parallelism = self.config.dcn_tensor_parallelism
+    eval_config.ici_data_parallelism = single_batch_size
+    eval_config.ici_fsdp_parallelism = 1
+    eval_config.ici_tensor_parallelism = 1
+    eval_config.allow_split_physical_axes = self.config.allow_split_physical_axes
+    eval_devices_array = create_device_mesh(eval_config)
+    eval_mesh = Mesh(eval_devices_array, self.config.mesh_axes)
+    eval_data_shardings = self.get_eval_data_shardings(eval_mesh)
 
     writer = max_utils.initialize_summary_writer(self.config)
     writer_thread = threading.Thread(target=_tensorboard_writer_worker, args=(writer, self.config), daemon=True)
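
This hunk builds a second device mesh used only for evaluation: ici_data_parallelism is set to the eval chunk size while FSDP and tensor parallelism are pinned to 1, so eval chunks are sharded purely along the data axis. A minimal sketch of that pattern, assuming stock jax.experimental.mesh_utils in place of maxdiffusion.max_utils.create_device_mesh; the axis name and shapes are illustrative:

import jax
import jax.numpy as jnp
from jax.experimental import mesh_utils
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Fold every device into a single flat "data" axis: pure data parallelism.
devices = mesh_utils.create_device_mesh((jax.device_count(),))
eval_mesh = Mesh(devices, axis_names=("data",))

# Shard a batch along its leading (batch) dimension across the eval mesh.
batch = jnp.zeros((jax.device_count(), 16))
sharded = jax.device_put(batch, NamedSharding(eval_mesh, P("data")))
print(sharded.sharding)  # NamedSharding(..., spec=PartitionSpec('data',))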
@@ -327,25 +343,39 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
         # Loop indefinitely until the iterator is exhausted
         while True:
           try:
-            with mesh:
-              eval_start_time = datetime.datetime.now()
-              eval_batch = load_next_batch(eval_data_iterator, None, self.config)
-              metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
-              losses = metrics["scalar"]["learning/eval_loss"]
-              timesteps = eval_batch["timesteps"]
-              gathered_timesteps_on_device = multihost_utils.process_allgather(timesteps)
-              gathered_timesteps = jax.device_get(gathered_timesteps_on_device)
-              gathered_losses_on_device = multihost_utils.process_allgather(losses)
-              gathered_losses = jax.device_get(gathered_losses_on_device)
-              for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()):
+            eval_start_time = datetime.datetime.now()
+            eval_batch = load_next_batch(eval_data_iterator, None, self.config)
+            bs = len(eval_batch["latents"])
+            for i in range(0, bs, single_batch_size):
+              eval_step_start_time = datetime.datetime.now()
+              start = i
+              end = min(i + single_batch_size, bs)
+              timesteps = eval_batch["timesteps"][start:end]
+              chunk_eval_branch = {
+                  "latents": eval_batch["latents"][start:end, :],
+                  "encoder_hidden_states": eval_batch["encoder_hidden_states"][start:end, :],
+                  "timesteps": timesteps,
+              }
+              with eval_mesh:
+                metrics, eval_rng = p_eval_step(state, chunk_eval_branch, eval_rng, scheduler_state)
+              losses = metrics["scalar"]["learning/eval_loss"]
+              # gathered_timesteps_on_device = multihost_utils.process_allgather(timesteps)
+              # gathered_timesteps = jax.device_get(gathered_timesteps_on_device)
+              gathered_losses_on_device = multihost_utils.process_allgather(losses)
+              gathered_losses = jax.device_get(gathered_losses_on_device)
+              for t, l in zip(timesteps.flatten(), gathered_losses.flatten()):
                 timestep = int(t)
                 if timestep not in eval_losses_by_timestep:
                   eval_losses_by_timestep[timestep] = []
                 eval_losses_by_timestep[timestep].append(l)
-            eval_end_time = datetime.datetime.now()
-            eval_duration = eval_end_time - eval_start_time
+              eval_step_end_time = datetime.datetime.now()
+              eval_step_duration = eval_step_end_time - eval_step_start_time
             if jax.process_index() == 0:
-              max_logging.log(f"  Eval step time {eval_duration.total_seconds():.2f} seconds.")
+              max_logging.log(f"  Eval step processed batch {end} : {start} in {eval_step_duration.total_seconds():.2f} seconds.")
+            eval_end_time = datetime.datetime.now()
+            eval_duration = eval_end_time - eval_start_time
+            if jax.process_index() == 0:
+              max_logging.log(f"  Eval step time {eval_duration.total_seconds():.2f} seconds.")
           except StopIteration:
             # This block is executed when the iterator has no more data
             break
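
The rewritten loop slices each loaded eval batch into chunks of at most single_batch_size rows and runs the jitted eval step once per chunk under eval_mesh. A self-contained sketch of just the slicing arithmetic, with NumPy stand-ins for the real tensors and the eval step elided:

import numpy as np

# Stand-ins: a 10-row eval batch and a chunk size of 4 (both hypothetical).
eval_batch = {
    "latents": np.zeros((10, 8)),
    "encoder_hidden_states": np.zeros((10, 8)),
    "timesteps": np.arange(10),
}
single_batch_size = 4

bs = len(eval_batch["latents"])
for start in range(0, bs, single_batch_size):
  end = min(start + single_batch_size, bs)
  chunk = {k: v[start:end] for k, v in eval_batch.items()}
  # the real loop calls p_eval_step(state, chunk, eval_rng, scheduler_state) here
  print(f"rows [{start}:{end}) -> chunk of {end - start}")

Note that the final chunk can be ragged (2 rows here); under a jitted step and a fixed data-parallel mesh that means a recompile or an unevenly divisible shard, which may be what the commit message's "should be 128" is flagging.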
@@ -440,11 +470,14 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
 
   # The loss function logic is identical to training. We are evaluating the model's
   # ability to perform its core training objective (e.g., denoising).
-  @jax.jit
-  def loss_fn(params, latents, encoder_hidden_states, timesteps, rng):
+  def loss_fn(params):
     # Reconstruct the model from its definition and parameters
     model = nnx.merge(state.graphdef, params, state.rest_of_state)
 
+    latents = data["latents"].astype(config.weights_dtype)
+    encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
+    timesteps = data["timesteps"].astype("int64")
+
     noise = jax.random.normal(key=rng, shape=latents.shape, dtype=latents.dtype)
     noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
 
@@ -468,18 +501,8 @@ def loss_fn(params, latents, encoder_hidden_states, timesteps, rng):
   # --- Key Difference from train_step ---
   # Directly compute the loss without calculating gradients.
   # The model's state.params are used but not updated.
-  bs = len(data["latents"])
-  single_batch_size = min(config.eval_max_processed_batch_size, config.global_batch_size_to_train_on)
-  losses = jnp.zeros(bs)
-  for i in range(0, bs, single_batch_size):
-    start = i
-    end = min(i + single_batch_size, bs)
-    latents = data["latents"][start:end, :].astype(config.weights_dtype)
-    encoder_hidden_states = data["encoder_hidden_states"][start:end, :].astype(config.weights_dtype)
-    timesteps = data["timesteps"][start:end].astype("int64")
-    _, new_rng = jax.random.split(rng, num=2)
-    loss = loss_fn(state.params, latents, encoder_hidden_states, timesteps, new_rng)
-    losses = losses.at[start:end].set(loss)
+  _, new_rng = jax.random.split(rng, num=2)
+  losses = loss_fn(state.params)
 
   # Structure the metrics for logging and aggregation
   metrics = {"scalar": {"learning/eval_loss": losses}}
