2222import jax
2323import optax
2424import jax .numpy as jnp
25- from jax .sharding import PartitionSpec as P
25+ from jax .sharding import PositionalSharding , PartitionSpec as P
2626from flax .linen import partitioning as nn_partitioning
27- from maxdiffusion .checkpointing .flux_checkpointer import (FluxCheckpointer , FLUX_CHECKPOINT )
27+ from maxdiffusion .checkpointing .flux_checkpointer import (
28+ FluxCheckpointer ,
29+ FLUX_CHECKPOINT ,
30+ FLUX_TRANSFORMER_PARAMS_KEY ,
31+ FLUX_STATE_KEY ,
32+ FLUX_STATE_SHARDINGS_KEY ,
33+ FLUX_VAE_PARAMS_KEY ,
34+ VAE_STATE_KEY ,
35+ VAE_STATE_SHARDINGS_KEY )
2836
2937from maxdiffusion .input_pipeline .input_pipeline_interface import (make_data_iterator )
3038
@@ -57,7 +65,7 @@ def __init__(self, config):
5765 raise ValueError ("this script currently doesn't support training text_encoders" )
5866
5967 def post_training_steps (self , pipeline , params , train_states , msg = "" ):
60- imgs = pipeline (flux_params = train_states ["flux_state" ],
68+ imgs = pipeline (flux_params = train_states [FLUX_STATE_KEY ],
6169 timesteps = 50 ,
6270 vae_params = train_states ["vae_state" ])
6371 imgs = np .array (imgs )
@@ -94,11 +102,21 @@ def start_training(self):
94102 # create train states
95103 train_states = {}
96104 state_shardings = {}
105+
106+ # move params to accelerator
107+ encoders_sharding = PositionalSharding (self .devices_array ).replicate ()
108+ partial_device_put_replicated = partial (max_utils .device_put_replicated , sharding = encoders_sharding )
109+ pipeline .clip_encoder .params = jax .tree_util .tree_map (lambda x : x .astype (jnp .bfloat16 ), pipeline .clip_encoder .params )
110+ pipeline .clip_encoder .params = jax .tree_util .tree_map (partial_device_put_replicated , pipeline .clip_encoder .params )
111+ pipeline .t5_encoder .params = jax .tree_util .tree_map (lambda x : x .astype (jnp .bfloat16 ), pipeline .t5_encoder .params )
112+ pipeline .t5_encoder .params = jax .tree_util .tree_map (partial_device_put_replicated , pipeline .t5_encoder .params )
113+
114+
97115 vae_state , vae_state_mesh_shardings = self .create_vae_state (
98- pipeline = pipeline , params = params , checkpoint_item_name = "vae_state" , is_training = False
116+ pipeline = pipeline , params = params [ FLUX_VAE_PARAMS_KEY ] , checkpoint_item_name = VAE_STATE_KEY , is_training = False
99117 )
100- train_states ["vae_state" ] = vae_state
101- state_shardings ["vae_state_shardings" ] = vae_state_mesh_shardings
118+ train_states [VAE_STATE_KEY ] = vae_state
119+ state_shardings [VAE_STATE_SHARDINGS_KEY ] = vae_state_mesh_shardings
102120
103121 # Load dataset
104122 data_iterator = self .load_dataset (pipeline , params , train_states )
@@ -107,18 +125,23 @@ def start_training(self):
107125
108126 # don't need this anymore, clear some memory.
109127 del pipeline .t5_encoder
128+
129+  # NOTE(review): dangling placeholder — presumably shapes are evaluated inside create_flux_state below; confirm or remove.
130+
110131 flux_state , flux_state_mesh_shardings , flux_learning_rate_scheduler = self .create_flux_state (
111- # ambiguous here, but if self. params.get("unet") doesn't exist
132+  # ambiguous here, but if params=None
112133  # Then it's one of two scenarios:
113134  # 1. the flux state will be loaded directly from orbax
114135  # 2. a new flux transformer is being trained from scratch.
115136 pipeline = pipeline ,
116137 params = None , # Params are loaded inside create_flux_state
117- checkpoint_item_name = "flux_state" ,
138+ checkpoint_item_name = FLUX_STATE_KEY ,
118139 is_training = True ,
119140 )
120- train_states ["flux_state" ] = flux_state
121- state_shardings ["flux_state_shardings" ] = flux_state_mesh_shardings
141+ flux_state = flux_state .replace (params = params [FLUX_TRANSFORMER_PARAMS_KEY ])
142+ flux_state = jax .device_put (flux_state , flux_state_mesh_shardings )
143+ train_states [FLUX_STATE_KEY ] = flux_state
144+ state_shardings [FLUX_STATE_SHARDINGS_KEY ] = flux_state_mesh_shardings
122145 #self.post_training_steps(pipeline, params, train_states, msg="before_training")
123146
124147 # Create scheduler
@@ -320,15 +343,15 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, da
320343 max_logging .log ("Precompiling..." )
321344 s = time .time ()
322345 dummy_batch = self .get_shaped_batch (self .config , pipeline )
323- p_train_step = p_train_step .lower (train_states ["flux_state" ], dummy_batch , train_rngs )
346+ p_train_step = p_train_step .lower (train_states [FLUX_STATE_KEY ], dummy_batch , train_rngs )
324347 p_train_step = p_train_step .compile ()
325348 max_logging .log (f"Compile time: { (time .time () - s )} " )
326349 return p_train_step
327350
328351 def training_loop (self , p_train_step , pipeline , params , train_states , data_iterator , unet_learning_rate_scheduler ):
329352
330353 writer = max_utils .initialize_summary_writer (self .config )
331- flux_state = train_states ["flux_state" ]
354+ flux_state = train_states [FLUX_STATE_KEY ]
332355 num_model_parameters = max_utils .calculate_num_params_from_pytree (flux_state .params )
333356
334357 max_utils .add_text_to_summary_writer ("number_model_parameters" , str (num_model_parameters ), writer )
@@ -352,7 +375,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
352375 last_profiling_step = np .clip (
353376 first_profiling_step + self .config .profiler_steps - 1 , first_profiling_step , self .config .max_train_steps - 1
354377 )
355- start_step = get_first_step (train_states ["flux_state" ])
378+ start_step = get_first_step (train_states [FLUX_STATE_KEY ])
356379 _ , train_rngs = jax .random .split (self .rng )
357380 times = []
358381 for step in np .arange (start_step , self .config .max_train_steps ):
@@ -379,7 +402,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
379402
380403 if step != 0 and self .config .checkpoint_every != - 1 and samples_count % self .config .checkpoint_every == 0 :
381404 max_logging .log (f"Saving checkpoint for step { step } " )
382- train_states ["flux_state" ] = flux_state
405+ train_states [FLUX_STATE_KEY ] = flux_state
383406 self .save_checkpoint (step , pipeline , train_states )
384407
385408 if self .config .enable_profiler and step == last_profiling_step :
@@ -390,7 +413,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
390413 writer , local_metrics_file , running_gcs_metrics , train_metric , self .config .max_train_steps - 1 , self .config
391414 )
392415
393- train_states ["flux_state" ] = flux_state
416+ train_states [FLUX_STATE_KEY ] = flux_state
394417 max_logging .log (f"Average time per step: { sum (times [2 :], datetime .timedelta (0 )) / len (times [2 :])} " )
395418 if self .config .save_final_checkpoint :
396419 max_logging .log (f"Saving checkpoint for step { step } " )
@@ -402,7 +425,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
402425def _train_step (flux_state , batch , train_rng , guidance_vec , pipeline , scheduler , config ):
403426 _ , gen_dummy_rng = jax .random .split (train_rng )
404427 sample_rng , timestep_bias_rng , new_train_rng = jax .random .split (gen_dummy_rng , 3 )
405- state_params = {"flux_state" : flux_state .params }
428+ state_params = {FLUX_STATE_KEY : flux_state .params }
406429
407430 def compute_loss (state_params ):
408431 latents = batch ["pixel_values" ]
@@ -424,7 +447,7 @@ def compute_loss(state_params):
424447 noisy_latents = pipeline .scheduler .add_noise (scheduler , latents , noise , timesteps , flux = True )
425448
426449 model_pred = pipeline .flux .apply (
427- {"params" : state_params ["flux_state" ]},
450+ {"params" : state_params [FLUX_STATE_KEY ]},
428451 hidden_states = noisy_latents ,
429452 img_ids = img_ids ,
430453 encoder_hidden_states = text_embeds ,
@@ -444,7 +467,7 @@ def compute_loss(state_params):
444467 grad_fn = jax .value_and_grad (compute_loss )
445468 loss , grad = grad_fn (state_params )
446469
447- new_state = flux_state .apply_gradients (grads = grad ["flux_state" ])
470+ new_state = flux_state .apply_gradients (grads = grad [FLUX_STATE_KEY ])
448471
449472 metrics = {"scalar" : {"learning/loss" : loss }, "scalars" : {}}
450473
0 commit comments