4040 load_next_batch ,
4141 record_scalar_metrics ,
4242 write_metrics ,
43- _metrics_queue
43+ _metrics_queue ,
4444)
4545
4646from maxdiffusion .checkpointing .base_stable_diffusion_checkpointer import (STABLE_DIFFUSION_XL_CHECKPOINT )
@@ -67,7 +67,7 @@ def get_shaped_batch(self, config, pipeline):
6767 total_train_batch_size = config .total_train_batch_size
6868 shaped_batch = {}
6969
70- if self .config .dataset_type in ["tf" ,"tfrecord" ] and self .config .cache_latents_text_encoder_outputs :
70+ if self .config .dataset_type in ["tf" , "tfrecord" ] and self .config .cache_latents_text_encoder_outputs :
7171 batch_image_shape = (
7272 total_train_batch_size ,
7373 pipeline .unet .config .in_channels ,
@@ -92,7 +92,7 @@ def get_shaped_batch(self, config, pipeline):
9292
9393 def get_data_shardings (self ):
9494 data_sharding = jax .sharding .NamedSharding (self .mesh , P (* self .config .data_sharding ))
95- if self .config .dataset_type in ["tf" ,"tfrecord" ] and self .config .cache_latents_text_encoder_outputs :
95+ if self .config .dataset_type in ["tf" , "tfrecord" ] and self .config .cache_latents_text_encoder_outputs :
9696 data_sharding = {
9797 "input_ids" : data_sharding ,
9898 "pixel_values" : data_sharding ,
@@ -188,11 +188,7 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, da
188188 def training_loop (self , p_train_step , pipeline , params , train_states , data_iterator , unet_learning_rate_scheduler ):
189189
190190 writer = max_utils .initialize_summary_writer (self .config )
191- writer_thread = threading .Thread (
192- target = _tensorboard_writer_worker ,
193- args = (writer , self .config ),
194- daemon = True
195- )
191+ writer_thread = threading .Thread (target = _tensorboard_writer_worker , args = (writer , self .config ), daemon = True )
196192 writer_thread .start ()
197193 unet_state = train_states ["unet_state" ]
198194 vae_state = train_states ["vae_state" ]
@@ -228,14 +224,14 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
228224 for step in np .arange (start_step , self .config .max_train_steps ):
229225 if self .config .enable_profiler and step == first_profiling_step :
230226 max_utils .activate_profiler (self .config )
231-
227+
232228 next_batch_future = executor .submit (load_next_batch , data_iterator , example_batch , self .config )
233229 start_step_time = datetime .datetime .now ()
234230 with jax .profiler .StepTraceAnnotation ("train-new" , step_num = step ):
235231 (unet_state , train_metric , train_rngs ) = p_train_step (
236232 unet_state , vae_state , text_encoder_state , text_encoder_2_state , example_batch , train_rngs
237233 )
238- train_metric [' scalar' ][ ' learning/loss' ].block_until_ready ()
234+ train_metric [" scalar" ][ " learning/loss" ].block_until_ready ()
239235 samples_count = self .total_train_batch_size * (step + 1 )
240236 last_step_completion = datetime .datetime .now ()
241237 time_difference = last_step_completion - start_step_time
@@ -247,7 +243,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
247243 if self .config .write_metrics :
248244 write_metrics (writer , local_metrics_file , running_gcs_metrics , train_metric , step , self .config )
249245 example_batch = next_batch_future .result ()
250-
246+
251247 if step != 0 and self .config .checkpoint_every != - 1 and samples_count % self .config .checkpoint_every == 0 :
252248 train_states ["unet_state" ] = unet_state
253249 train_states ["vae_state" ] = vae_state
@@ -265,7 +261,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
265261 _metrics_queue .put (None )
266262 writer_thread .join ()
267263 if writer :
268- writer .flush ()
264+ writer .flush ()
269265 train_states ["unet_state" ] = unet_state
270266 train_states ["text_encoder_state" ] = text_encoder_state
271267 train_states ["text_encoder_2_state" ] = text_encoder_2_state
@@ -369,5 +365,5 @@ def compute_loss(state_params):
369365 new_state = unet_state .apply_gradients (grads = grad ["unet" ])
370366
371367 metrics = {"scalar" : {"learning/loss" : loss }, "scalars" : {}}
372-
368+
373369 return new_state , metrics , new_train_rng
0 commit comments