|
27 | 27 | import flax.linen as nn |
28 | 28 | import flax.traverse_util |
29 | 29 | from flax import nnx |
| 30 | +from flax.linen import partitioning as nn_partitioning |
30 | 31 | from transformers import AutoTokenizer, GemmaTokenizer, GemmaTokenizerFast, Gemma3ForConditionalGeneration |
31 | 32 | from tqdm.auto import tqdm |
32 | 33 | import qwix |
@@ -1222,59 +1223,69 @@ def __call__( |
1222 | 1223 | graphdef, state = nnx.split(self.transformer) |
1223 | 1224 |
|
1224 | 1225 | # 7. Denoising Loop |
1225 | | - connectors_graphdef, connectors_state = nnx.split(self.connectors) |
1226 | | - |
1227 | | - @jax.jit |
1228 | | - def run_connectors(graphdef, state, hidden_states, attention_mask): |
1229 | | - model = nnx.merge(graphdef, state) |
1230 | | - return model(hidden_states, attention_mask) |
1231 | | - |
1232 | | - video_embeds, audio_embeds = run_connectors( |
1233 | | - connectors_graphdef, connectors_state, prompt_embeds_jax, prompt_attention_mask_jax.astype(jnp.bool_) |
| 1226 | + import contextlib |
| 1227 | + context_manager = ( |
| 1228 | + self.mesh if hasattr(self, "mesh") and self.mesh is not None else contextlib.nullcontext() |
| 1229 | + ) |
| 1230 | + axis_rules_context = ( |
| 1231 | + nn_partitioning.axis_rules(self.config.logical_axis_rules) |
| 1232 | + if hasattr(self, "config") and hasattr(self.config, "logical_axis_rules") else contextlib.nullcontext() |
1234 | 1233 | ) |
1235 | 1234 |
|
1236 | | - for i, t in enumerate(timesteps): |
1237 | | - noise_pred, noise_pred_audio = transformer_forward_pass( |
1238 | | - graphdef, state, |
1239 | | - latents_jax, |
1240 | | - audio_latents_jax, |
1241 | | - t, |
1242 | | - video_embeds, |
1243 | | - audio_embeds, |
1244 | | - prompt_attention_mask_jax, |
1245 | | - prompt_attention_mask_jax, |
1246 | | - guidance_scale > 1.0, |
1247 | | - guidance_scale, |
1248 | | - num_frames, |
1249 | | - height, |
1250 | | - width, |
1251 | | - audio_num_frames, |
1252 | | - frame_rate, |
| 1235 | + with context_manager, axis_rules_context: |
| 1236 | + connectors_graphdef, connectors_state = nnx.split(self.connectors) |
| 1237 | + |
| 1238 | + @jax.jit |
| 1239 | + def run_connectors(graphdef, state, hidden_states, attention_mask): |
| 1240 | + model = nnx.merge(graphdef, state) |
| 1241 | + return model(hidden_states, attention_mask) |
| 1242 | + |
| 1243 | + video_embeds, audio_embeds = run_connectors( |
| 1244 | + connectors_graphdef, connectors_state, prompt_embeds_jax, prompt_attention_mask_jax.astype(jnp.bool_) |
1253 | 1245 | ) |
1254 | | - |
1255 | | - if guidance_scale > 1.0: |
1256 | | - noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0) |
1257 | | - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
1258 | | - # Audio guidance |
1259 | | - noise_pred_audio_uncond, noise_pred_audio_text = jnp.split(noise_pred_audio, 2, axis=0) |
1260 | | - noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond) |
1261 | | - |
1262 | | - latents_step = latents_jax[batch_size:] |
1263 | | - audio_latents_step = audio_latents_jax[batch_size:] |
1264 | | - else: |
1265 | | - latents_step = latents_jax |
1266 | | - audio_latents_step = audio_latents_jax |
1267 | | - |
1268 | | - # Step |
1269 | | - latents_step, _ = self.scheduler.step(scheduler_state, noise_pred, t, latents_step, return_dict=False) |
1270 | | - audio_latents_step, _ = self.scheduler.step(scheduler_state, noise_pred_audio, t, audio_latents_step, return_dict=False) |
1271 | | - |
1272 | | - if guidance_scale > 1.0: |
1273 | | - latents_jax = jnp.concatenate([latents_step] * 2, axis=0) |
1274 | | - audio_latents_jax = jnp.concatenate([audio_latents_step] * 2, axis=0) |
1275 | | - else: |
1276 | | - latents_jax = latents_step |
1277 | | - audio_latents_jax = audio_latents_step |
| 1246 | + |
| 1247 | + for i, t in enumerate(timesteps): |
| 1248 | + noise_pred, noise_pred_audio = transformer_forward_pass( |
| 1249 | + graphdef, state, |
| 1250 | + latents_jax, |
| 1251 | + audio_latents_jax, |
| 1252 | + t, |
| 1253 | + video_embeds, |
| 1254 | + audio_embeds, |
| 1255 | + prompt_attention_mask_jax, |
| 1256 | + prompt_attention_mask_jax, |
| 1257 | + guidance_scale > 1.0, |
| 1258 | + guidance_scale, |
| 1259 | + num_frames, |
| 1260 | + height, |
| 1261 | + width, |
| 1262 | + audio_num_frames, |
| 1263 | + frame_rate, |
| 1264 | + ) |
| 1265 | + |
| 1266 | + if guidance_scale > 1.0: |
| 1267 | + noise_pred_uncond, noise_pred_text = jnp.split(noise_pred, 2, axis=0) |
| 1268 | + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) |
| 1269 | + # Audio guidance |
| 1270 | + noise_pred_audio_uncond, noise_pred_audio_text = jnp.split(noise_pred_audio, 2, axis=0) |
| 1271 | + noise_pred_audio = noise_pred_audio_uncond + guidance_scale * (noise_pred_audio_text - noise_pred_audio_uncond) |
| 1272 | + |
| 1273 | + latents_step = latents_jax[batch_size:] |
| 1274 | + audio_latents_step = audio_latents_jax[batch_size:] |
| 1275 | + else: |
| 1276 | + latents_step = latents_jax |
| 1277 | + audio_latents_step = audio_latents_jax |
| 1278 | + |
| 1279 | + # Step |
| 1280 | + latents_step, _ = self.scheduler.step(scheduler_state, noise_pred, t, latents_step, return_dict=False) |
| 1281 | + audio_latents_step, _ = self.scheduler.step(scheduler_state, noise_pred_audio, t, audio_latents_step, return_dict=False) |
| 1282 | + |
| 1283 | + if guidance_scale > 1.0: |
| 1284 | + latents_jax = jnp.concatenate([latents_step] * 2, axis=0) |
| 1285 | + audio_latents_jax = jnp.concatenate([audio_latents_step] * 2, axis=0) |
| 1286 | + else: |
| 1287 | + latents_jax = latents_step |
| 1288 | + audio_latents_jax = audio_latents_step |
1278 | 1289 |
|
1279 | 1290 | # 8. Decode Latents |
1280 | 1291 | if guidance_scale > 1.0: |
|
0 commit comments