# Define a filename for logging
def _log_to_file(message: str, log_file: str = ""):
  """Appends a timestamped message to the given log file and the standard logger.

  Args:
    message: Text to record.
    log_file: Path of the file to append to. When empty, the message is only
      forwarded to max_logging.
  """
  timestamp = time.strftime("%Y-%m-%d %H:%M:%S %Z", time.localtime())
  full_message = f"[{timestamp}] {message}\n"
  if log_file:
    # Explicit utf-8 avoids platform-dependent default encodings when appending.
    with open(log_file, "a", encoding="utf-8") as f:
      f.write(full_message)
  max_logging.log(full_message.strip())
3636
3737
3838class BaseStableDiffusionTrainer (BaseStableDiffusionCheckpointer ):
@@ -80,32 +80,28 @@ def get_data_shardings(self):
  @abstractmethod
  def create_scheduler(self, pipeline, params):
    """Creates the noise scheduler state for training; must be implemented by subclasses."""
    pass
83-
83+
8484 def _time_and_log_call (
85- self ,
86- func_obj : Callable [..., Any ],
87- * func_args : Any ,
88- description : str = "" ,
89- ** func_kwargs : Any
90- ) -> Any :
85+ self , func_obj : Callable [..., Any ], * func_args : Any , description : str = "" , ** func_kwargs : Any
86+ ) -> Any :
9187 """
9288 Times a function call, logs its duration, and returns its result.
9389 """
9490 if not description :
95- if hasattr (func_obj , ' __name__' ):
91+ if hasattr (func_obj , " __name__" ):
9692 description = func_obj .__name__
97- elif hasattr (func_obj , ' __call__' ) and hasattr (type (func_obj ), ' __name__' ):
93+ elif hasattr (func_obj , " __call__" ) and hasattr (type (func_obj ), " __name__" ):
9894 description = type (func_obj ).__name__
9995 log_file = ""
100-
96+
10197 if self .config .write_timing_metrics and self .config .timing_metrics_file :
10298 log_file = self .config .get .timing_metrics_file
10399 _log_to_file (f"Starting: { description } ..." , log_file = log_file )
104- start_time = time .perf_counter () # Use perf_counter for more precise duration measurement
100+ start_time = time .perf_counter () # Use perf_counter for more precise duration measurement
105101 result = func_obj (* func_args , ** func_kwargs )
106102 end_time = time .perf_counter ()
107103 duration = end_time - start_time
108- _log_to_file (f"Finished: { description } - Duration: { duration :.4f} seconds" ,log_file = log_file )
104+ _log_to_file (f"Finished: { description } - Duration: { duration :.4f} seconds" , log_file = log_file )
109105 return result
110106
111107 def calculate_tflops (self , pipeline , params ):
@@ -129,7 +125,7 @@ def start_training(self):
129125 pipeline = pipeline ,
130126 params = params ,
131127 checkpoint_item_name = "vae_state" ,
132- is_training = False
128+ is_training = False ,
133129 )
134130
135131 train_states ["vae_state" ] = vae_state
@@ -147,13 +143,13 @@ def start_training(self):
147143 state_shardings ["text_encoder_state_shardings" ] = text_encoder_state_mesh_shardings
148144 if hasattr (pipeline , "text_encoder_2" ):
149145 text_encoder_2_state , text_encoder_2_state_mesh_shardings = self ._time_and_log_call (
150- self .create_text_encoder_2_state ,
151- # Arguments for create_text_encoder_2_state
152- pipeline = pipeline ,
153- params = params ,
154- checkpoint_item_name = "text_encoder_2_state" ,
155- is_training = self .config .train_text_encoder ,
156- )
146+ self .create_text_encoder_2_state ,
147+ # Arguments for create_text_encoder_2_state
148+ pipeline = pipeline ,
149+ params = params ,
150+ checkpoint_item_name = "text_encoder_2_state" ,
151+ is_training = self .config .train_text_encoder ,
152+ )
157153 train_states ["text_encoder_2_state" ] = text_encoder_2_state
158154 state_shardings ["text_encoder_2_state_shardings" ] = text_encoder_2_state_mesh_shardings
159155
@@ -167,17 +163,9 @@ def start_training(self):
167163 self .per_device_tflops = per_device_tflops
168164
169165 # Load dataset
170- data_iterator = self ._time_and_log_call (
171- self .load_dataset ,
172- pipeline ,
173- params ,
174- train_states
175- )
166+ data_iterator = self ._time_and_log_call (self .load_dataset , pipeline , params , train_states )
176167 if self .config .dataset_type == "grain" :
177- data_iterator = self ._time_and_log_call (
178- self .restore_data_iterator_state ,
179- data_iterator = data_iterator
180- )
168+ data_iterator = self ._time_and_log_call (self .restore_data_iterator_state , data_iterator = data_iterator )
181169
182170 unet_state , unet_state_mesh_shardings , unet_learning_rate_scheduler = self ._time_and_log_call (
183171 self .create_unet_state ,
@@ -196,13 +184,12 @@ def start_training(self):
196184 data_shardings = self .get_data_shardings ()
197185 # Compile train_step
198186 p_train_step = self ._time_and_log_call (
199- self .compile_train_step ,
200- pipeline , params , train_states , state_shardings , data_shardings
201- )
187+ self .compile_train_step , pipeline , params , train_states , state_shardings , data_shardings
188+ )
202189 # Start training
203- train_states = self ._time_and_log_call (self . training_loop ,
204- p_train_step , pipeline , params , train_states , data_iterator , unet_learning_rate_scheduler
190+ train_states = self ._time_and_log_call (
191+ self . training_loop , p_train_step , pipeline , params , train_states , data_iterator , unet_learning_rate_scheduler
205192 )
206193 # 6. save final checkpoint
207194 # Hook
208- self ._time_and_log_call (self .post_training_steps ,pipeline , params , train_states )
195+ self ._time_and_log_call (self .post_training_steps , pipeline , params , train_states )
0 commit comments