Skip to content

Commit 5503f9c

Browse files
committed
add hyper
1 parent 9edefc3 commit 5503f9c

3 files changed

Lines changed: 10 additions & 6 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,8 @@ global_batch_size: 0
235235
tfrecords_dir: ''
236236
no_records_per_shard: 0
237237
enable_eval_timesteps: False
238+
considered_timesteps_list: [125, 250, 375, 500, 625, 750, 875]
239+
num_eval_samples: 420
238240

239241
warmup_steps_fraction: 0.1
240242
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.
@@ -316,3 +318,4 @@ quantization_calibration_method: "absmax"
316318
eval_every: -1
317319
eval_data_dir: ""
318320
enable_generate_video_for_eval: False # This will increase the used TPU memory.
321+
eval_max_number_of_samples_in_bucket: 60

src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,9 +85,10 @@ def generate_dataset(config):
8585
shard_record_count = 0
8686

8787
# Define timesteps and bucket configuration
88-
timesteps_list = [125, 250, 375, 500, 625, 750, 875]
89-
bucket_size = 60
90-
num_samples_to_process = 420
88+
num_eval_samples = config.num_eval_samples
89+
timesteps_list = config.considered_timesteps_list
90+
assert num_eval_samples % len(timesteps_list) == 0
91+
bucket_size = num_eval_samples // len(timesteps_list)
9192

9293
# Load dataset
9394
metadata_path = os.path.join(config.train_data_dir, "metadata.csv")
@@ -115,12 +116,12 @@ def generate_dataset(config):
115116
current_timestep = None
116117
# Determine the timestep for the first num_eval_samples samples
117118
if config.enable_eval_timesteps:
118-
if global_record_count < num_samples_to_process:
119+
if global_record_count < num_eval_samples:
119120
print(f"global_record_count: {global_record_count}")
120121
bucket_index = global_record_count // bucket_size
121122
current_timestep = timesteps_list[bucket_index]
122123
else:
123-
print(f"value {global_record_count} is greater than or equal to {num_samples_to_process}")
124+
print(f"value {global_record_count} is greater than or equal to {num_eval_samples}")
124125
return
125126

126127
# Write the example, including the timestep if applicable

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
345345
max_logging.log(f"Step {step}, calculating mean loss per timestep...")
346346
for timestep, losses in sorted(eval_losses_by_timestep.items()):
347347
losses = jnp.array(losses)
348-
losses = losses[: min(60, len(losses))]
348+
losses = losses[: min(self.config.eval_max_number_of_samples_in_bucket, len(losses))]
349349
mean_loss = jnp.mean(losses)
350350
max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}, num of losses: {len(losses)}")
351351
mean_per_timestep.append(mean_loss)

0 commit comments

Comments
 (0)