@@ -165,7 +165,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera

     state = state.to_pure_dict()
     p_train_step = jax.jit(
-        functools.partial(train_step, scheduler=pipeline.scheduler),
+        functools.partial(train_step, scheduler=pipeline.scheduler, config=self.config),
         donate_argnums=(0,),
     )
     rng = jax.random.key(self.config.seed)
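This hunk follows a common JAX pattern: non-array arguments (`scheduler`, and now `config`) are bound with `functools.partial` before `jax.jit`, so they are captured as trace-time constants rather than traced inputs, while `donate_argnums=(0,)` lets XLA reuse the state's buffers. Below is a minimal, self-contained sketch of that pattern; the `Config` dataclass and the toy update are placeholders, not the trainer's real code:

```python
import dataclasses
import functools

import jax
import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Config:
  # Hypothetical stand-in for the trainer's config object.
  weights_dtype: type = jnp.bfloat16


def train_step(state, data, rng, config):
  # `config` was bound via functools.partial, so jit sees it as a
  # Python-level constant rather than a traced array argument.
  del data, rng  # unused in this toy update
  grads = jnp.ones_like(state)  # placeholder for a real gradient
  return state - 0.1 * grads.astype(config.weights_dtype)


config = Config()
p_train_step = jax.jit(
    functools.partial(train_step, config=config),
    donate_argnums=(0,),  # donate `state` so its buffer can be reused
)

state = jnp.zeros((4,), dtype=config.weights_dtype)
state = p_train_step(state, data=None, rng=jax.random.key(0))
```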
@@ -219,16 +219,18 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
     return pipeline


-def train_step(state, graphdef, scheduler_state, data, rng, scheduler):
-  return step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng)
+def train_step(state, graphdef, scheduler_state, data, rng, scheduler, config):
+  return step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng, config)


-def step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng):
+def step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng, config):
   _, new_rng, timestep_rng = jax.random.split(rng, num=3)

   def loss_fn(model):
-    latents = data["latents"]
-    encoder_hidden_states = data["encoder_hidden_states"]
+    latents = data["latents"].astype(config.weights_dtype)
+    encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
+    # TODO - fix tf record conversion.
+    encoder_hidden_states = jax.numpy.squeeze(encoder_hidden_states, axis=1)
     bsz = latents.shape[0]
     timesteps = jax.random.randint(
         timestep_rng,
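The change inside `loss_fn` does two things: it casts the batch to the configured weights dtype, and it drops a spurious length-1 axis that the current TFRecord conversion leaves on `encoder_hidden_states` (hence the TODO). A standalone sketch of that preprocessing, with made-up shapes and the same hypothetical stand-in config as above:

```python
import dataclasses

import jax.numpy as jnp


@dataclasses.dataclass(frozen=True)
class Config:
  # Hypothetical stand-in; the real trainer reads this from its config system.
  weights_dtype: type = jnp.bfloat16


config = Config()

# Illustrative shapes only: the text embeddings arrive from the TFRecord
# pipeline with an extra length-1 axis that the model does not expect.
data = {
    "latents": jnp.zeros((2, 16, 32, 32), dtype=jnp.float32),
    "encoder_hidden_states": jnp.zeros((2, 1, 77, 768), dtype=jnp.float32),
}

latents = data["latents"].astype(config.weights_dtype)
encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
# Work around the conversion artifact: (2, 1, 77, 768) -> (2, 77, 768).
encoder_hidden_states = jnp.squeeze(encoder_hidden_states, axis=1)

assert latents.dtype == jnp.bfloat16
assert encoder_hidden_states.shape == (2, 77, 768)
```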