@@ -87,7 +87,7 @@ def create_flux_state(self, pipeline, params, checkpoint_item_name, is_training)
8787 rngs = self .rng , max_sequence_length = self .config .max_sequence_length , eval_only = True
8888 )
8989
90- transformer_params = load_flow_model (self .config .flux_name , transformer_eval_params , "cpu" )
90+ # transformer_params = load_flow_model(self.config.flux_name, transformer_eval_params, "cpu")
9191
9292 weights_init_fn = functools .partial (
9393 pipeline .flux .init_weights , rngs = self .rng , max_sequence_length = self .config .max_sequence_length
@@ -103,9 +103,9 @@ def create_flux_state(self, pipeline, params, checkpoint_item_name, is_training)
103103 checkpoint_item = checkpoint_item_name ,
104104 training = is_training ,
105105 )
106- if not self .config .train_new_flux :
107- flux_state = flux_state .replace (params = transformer_params )
108- flux_state = jax .device_put (flux_state , state_mesh_shardings )
106+ # if not self.config.train_new_flux:
107+ # flux_state = flux_state.replace(params=transformer_params)
108+ # flux_state = jax.device_put(flux_state, state_mesh_shardings)
109109 return flux_state , state_mesh_shardings , learning_rate_scheduler
110110
111111 def create_vae_state (self , pipeline , params , checkpoint_item_name , is_training = False ):
@@ -217,12 +217,13 @@ def load_diffusers_checkpoint(self):
217217 dtype = self .config .activations_dtype ,
218218 weights_dtype = self .config .weights_dtype ,
219219 precision = max_utils .get_precision (self .config ),
220+ num_layers = 1
220221 )
221- transformer_eval_params = transformer .init_weights (
222- rngs = self .rng , max_sequence_length = self .config .max_sequence_length , eval_only = True
222+ transformer_params = transformer .init_weights (
223+ rngs = self .rng , max_sequence_length = self .config .max_sequence_length , eval_only = False
223224 )
224225
225- transformer_params = load_flow_model (self .config .flux_name , transformer_eval_params , "cpu" )
226+ # transformer_params = load_flow_model(self.config.flux_name, transformer_eval_params, "cpu")
226227
227228 pipeline = FluxPipeline (
228229 t5_encoder ,
0 commit comments