Skip to content

Commit edcd3c2

Browse files
committed
modify pusav1 generation
1 parent 47c8305 commit edcd3c2

2 files changed

Lines changed: 10 additions & 7 deletions

File tree

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,7 @@ global_batch_size: 0
234234
# For creating tfrecords from dataset
235235
tfrecords_dir: ''
236236
no_records_per_shard: 0
237+
enable_eval_timesteps: False
237238

238239
warmup_steps_fraction: 0.1
239240
learning_rate_schedule_steps: -1 # By default the length of the schedule is set to the number of steps.

src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,14 +112,16 @@ def generate_dataset(config):
112112
latent = jnp.array(latent.float().numpy(), dtype=jnp.float32)
113113
prompt_embeds = jnp.array(prompt_embeds.float().numpy(), dtype=jnp.float32)
114114

115-
# Determine the timestep for the first 420 samples
116115
current_timestep = None
117-
if global_record_count < num_samples_to_process:
118-
bucket_index = global_record_count // bucket_size
119-
current_timestep = timesteps_list[bucket_index]
120-
else:
121-
print(f"value {global_record_count} is greater than or equal to {num_samples_to_process}")
122-
return
116+
# Determine the timestep for the first 420 samples
117+
if config.enable_eval_timesteps:
118+
if global_record_count < num_samples_to_process:
119+
print(f"global_record_count: {global_record_count}")
120+
bucket_index = global_record_count // bucket_size
121+
current_timestep = timesteps_list[bucket_index]
122+
else:
123+
print(f"value {global_record_count} is greater than or equal to {num_samples_to_process}")
124+
return
123125

124126
# Write the example, including the timestep if applicable
125127
writer.write(create_example(latent, prompt_embeds, timestep=current_timestep))

0 commit comments

Comments
 (0)