File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -234,6 +234,7 @@ global_batch_size: 0
234234# For creating tfrecords from dataset
235235tfrecords_dir : ' '
236236no_records_per_shard : 0
237+ enable_eval_timesteps : False
237238
238239warmup_steps_fraction : 0.1
239240learning_rate_schedule_steps : -1 # By default the length of the schedule is set to the number of steps.
Original file line number Diff line number Diff line change @@ -112,14 +112,16 @@ def generate_dataset(config):
112112 latent = jnp .array (latent .float ().numpy (), dtype = jnp .float32 )
113113 prompt_embeds = jnp .array (prompt_embeds .float ().numpy (), dtype = jnp .float32 )
114114
115- # Determine the timestep for the first 420 samples
116115 current_timestep = None
117- if global_record_count < num_samples_to_process :
118- bucket_index = global_record_count // bucket_size
119- current_timestep = timesteps_list [bucket_index ]
120- else :
121- print (f"value { global_record_count } is greater than or equal to { num_samples_to_process } " )
122- return
116+ # Determine the timestep for the first 420 samples
117+ if config .enable_eval_timesteps :
118+ if global_record_count < num_samples_to_process :
119+ print (f"global_record_count: { global_record_count } " )
120+ bucket_index = global_record_count // bucket_size
121+ current_timestep = timesteps_list [bucket_index ]
122+ else :
123+ print (f"value { global_record_count } is greater than or equal to { num_samples_to_process } " )
124+ return
123125
124126 # Write the example, including the timestep if applicable
125127 writer .write (create_example (latent , prompt_embeds , timestep = current_timestep ))
You can’t perform that action at this time.
0 commit comments