@@ -109,15 +109,19 @@ def load_dataset(self, pipeline, params, train_states):
109109 p_vae_apply = None
110110 rng = None
111111 if config .dataset_type == "tf" and config .cache_latents_text_encoder_outputs :
112- p_encode = jax .jit (
113- partial (
112+ text_encoder_partial = partial (
114113 maxdiffusion_utils .encode_xl ,
115114 text_encoders = [pipeline .text_encoder , pipeline .text_encoder_2 ],
116115 text_encoder_params = [train_states ["text_encoder_state" ].params , train_states ["text_encoder_2_state" ].params ],
117116 )
117+ text_encoder_partial .__name__ = "Text encoder"
118+ p_encode = jax .jit (
119+ text_encoder_partial
118120 )
121+ vae_partial = partial (maxdiffusion_utils .vae_apply , vae = pipeline .vae , vae_params = train_states ["vae_state" ].params )
122+ vae_partial .__name__ = "VAE Partial"
119123 p_vae_apply = jax .jit (
120- partial ( maxdiffusion_utils . vae_apply , vae = pipeline . vae , vae_params = train_states [ "vae_state" ]. params )
124+ vae_partial
121125 )
122126 rng = self .rng
123127
@@ -152,8 +156,10 @@ def compile_train_step(self, pipeline, params, train_states, state_shardings, da
152156
153157 self .rng , train_rngs = jax .random .split (self .rng )
154158 with self .mesh , nn_partitioning .axis_rules (self .config .logical_axis_rules ):
159+ train_step_partial = partial (_train_step , pipeline = pipeline , params = params , config = self .config )
160+ train_step_partial .__name__ = "Train Step"
155161 p_train_step = jax .jit (
156- partial ( _train_step , pipeline = pipeline , params = params , config = self . config ) ,
162+ train_step_partial ,
157163 in_shardings = (
158164 state_shardings ["unet_state_shardings" ],
159165 state_shardings ["vae_state_shardings" ],
0 commit comments