Skip to content

Commit 0f6ef50

Browse files
prishajain1 authored and Perseus14 committed
Explicit Activation Sharding in forward pass
parity check test fixed
1 parent cbb8cfd commit 0f6ef50

2 files changed

Lines changed: 26 additions & 15 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import jax
1818
import jax.numpy as jnp
1919
from flax import nnx
20+
import flax.linen as nn
2021

2122
from maxdiffusion.models.ltx2.attention_ltx2 import LTX2Attention, LTX2RotaryPosEmbed
2223
from maxdiffusion.models.attention_flax import NNXSimpleFeedForward
@@ -321,6 +322,15 @@ def __call__(
321322
) -> Tuple[jax.Array, jax.Array]:
322323
batch_size = hidden_states.shape[0]
323324

325+
axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
326+
hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names)
327+
audio_hidden_states = jax.lax.with_sharding_constraint(audio_hidden_states, axis_names)
328+
329+
if encoder_hidden_states is not None:
330+
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, axis_names)
331+
if audio_encoder_hidden_states is not None:
332+
audio_encoder_hidden_states = jax.lax.with_sharding_constraint(audio_encoder_hidden_states, axis_names)
333+
324334
# 1. Video and Audio Self-Attention
325335
norm_hidden_states = self.norm1(hidden_states)
326336

src/maxdiffusion/tests/ltx2_parity_test.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -570,21 +570,22 @@ def convert_weight(pt_key_base, jax_key):
570570

571571
# 4. Run Forward
572572
print("Running MaxDiffusion forward pass...")
573-
output = model(
574-
hidden_states=jax_inputs["hidden_states"],
575-
audio_hidden_states=jax_inputs["audio_hidden_states"],
576-
encoder_hidden_states=jax_inputs["encoder_hidden_states"],
577-
audio_encoder_hidden_states=jax_inputs["audio_encoder_hidden_states"],
578-
timestep=jax_inputs["timestep"],
579-
encoder_attention_mask=jax_inputs["encoder_attention_mask"],
580-
audio_encoder_attention_mask=jax_inputs["audio_encoder_attention_mask"],
581-
num_frames=config["num_frames"] if "num_frames" in config else 4,
582-
height=config["height"] if "height" in config else 32,
583-
width=config["width"] if "width" in config else 32,
584-
audio_num_frames=128,
585-
fps=24.0,
586-
return_dict=True,
587-
)
573+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
574+
output = model(
575+
hidden_states=jax_inputs["hidden_states"],
576+
audio_hidden_states=jax_inputs["audio_hidden_states"],
577+
encoder_hidden_states=jax_inputs["encoder_hidden_states"],
578+
audio_encoder_hidden_states=jax_inputs["audio_encoder_hidden_states"],
579+
timestep=jax_inputs["timestep"],
580+
encoder_attention_mask=jax_inputs["encoder_attention_mask"],
581+
audio_encoder_attention_mask=jax_inputs["audio_encoder_attention_mask"],
582+
num_frames=config["num_frames"] if "num_frames" in config else 4,
583+
height=config["height"] if "height" in config else 32,
584+
width=config["width"] if "width" in config else 32,
585+
audio_num_frames=128,
586+
fps=24.0,
587+
return_dict=True,
588+
)
588589

589590
max_sample = output["sample"]
590591
max_audio_sample = output["audio_sample"]

0 commit comments

Comments (0)