@@ -88,14 +88,11 @@ def create_unet_state(self, pipeline, params, checkpoint_item_name, is_training)
8888 config = self .config ,
8989 mesh = self .mesh ,
9090 weights_init_fn = weights_init_fn ,
91- model_params = None ,
91+ model_params = None if self . config . train_new_unet else params . get ( "unet" , None ) ,
9292 checkpoint_manager = self .checkpoint_manager ,
9393 checkpoint_item = checkpoint_item_name ,
9494 training = is_training ,
9595 )
96- if not self .config .train_new_unet :
97- unet_state = unet_state .replace (params = params .get ("unet" , None ))
98- unet_state = jax .device_put (unet_state , state_mesh_shardings )
9996 return unet_state , state_mesh_shardings , learning_rate_scheduler
10097
10198 def create_vae_state (self , pipeline , params , checkpoint_item_name , is_training = False ):
@@ -153,20 +150,18 @@ def create_text_encoder_2_state(self, pipeline, params, checkpoint_item_name, is
153150 input_shape = (self .total_train_batch_size , pipeline .tokenizer .model_max_length ),
154151 )
155152
156- state , state_mesh_shardings = max_utils .setup_initial_state (
153+ # state, state_mesh_shardings =
154+ return max_utils .setup_initial_state (
157155 model = pipeline .text_encoder_2 ,
158156 tx = tx ,
159157 config = self .config ,
160158 mesh = self .mesh ,
161159 weights_init_fn = weights_init_fn ,
162- model_params = None ,
160+ model_params = params . get ( "text_encoder_2" , None ) ,
163161 checkpoint_manager = self .checkpoint_manager ,
164162 checkpoint_item = checkpoint_item_name ,
165163 training = is_training ,
166164 )
167- state = state .replace (params = params .get ("text_encoder_2" , None ))
168- state = jax .device_put (state , state_mesh_shardings )
169- return state , state_mesh_shardings
170165
171166 def restore_data_iterator_state (self , data_iterator ):
172167 if (