Commit c5c9587

chunking reverted
1 parent 5f28f8c commit c5c9587

1 file changed: src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 25 additions & 61 deletions
@@ -675,47 +675,6 @@ def load_transformer(
 
 
 
-  @staticmethod
-  def _pack_text_embeds(
-      text_hidden_states: jax.Array,
-      sequence_lengths: jax.Array,
-      padding_side: str = "left",
-      scale_factor: int = 8,
-      eps: float = 1e-6,
-  ) -> jax.Array:
-    """
-    Packs and normalizes text encoder hidden states using JAX natively.
-    """
-    batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
-    original_dtype = text_hidden_states.dtype
-
-    # Create padding mask
-    token_indices = jnp.arange(seq_len)[None, :]
-    if padding_side == "right":
-      mask = token_indices < sequence_lengths[:, None]
-    elif padding_side == "left":
-      start_indices = seq_len - sequence_lengths[:, None]
-      mask = token_indices >= start_indices
-    else:
-      raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
-    mask = mask[:, :, None, None]
-
-    masked_text_hidden_states = jnp.where(mask, text_hidden_states, 0.0)
-    num_valid_positions = (sequence_lengths * hidden_dim).reshape(batch_size, 1, 1, 1)
-    masked_mean = jnp.sum(masked_text_hidden_states, axis=(1, 2), keepdims=True) / (num_valid_positions + eps)
-
-    x_min = jnp.min(jnp.where(mask, text_hidden_states, jnp.inf), axis=(1, 2), keepdims=True)
-    x_max = jnp.max(jnp.where(mask, text_hidden_states, -jnp.inf), axis=(1, 2), keepdims=True)
-
-    normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
-    normalized_hidden_states = normalized_hidden_states * scale_factor
-
-    normalized_hidden_states = normalized_hidden_states.reshape(batch_size, seq_len, -1)
-    mask_flat = jnp.broadcast_to(mask.squeeze(-1), (batch_size, seq_len, hidden_dim * num_layers))
-    normalized_hidden_states = jnp.where(mask_flat, normalized_hidden_states, 0.0)
-    normalized_hidden_states = normalized_hidden_states.astype(original_dtype)
-    return normalized_hidden_states
-
   def _get_gemma_prompt_embeds(
       self,
       prompt: Union[str, List[str]],
@@ -755,33 +714,38 @@ def _get_gemma_prompt_embeds
           input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
       )
 
-      text_encoder_hidden_states = torch.stack(text_encoder_outputs.hidden_states, dim=-1)
-      sequence_lengths = prompt_attention_mask.sum(dim=-1)
-
-      # Convert to JAX arrays to do native JAX math
-      hidden_states_jax = jnp.array(text_encoder_hidden_states.cpu().to(torch.float32).numpy())
-      sequence_lengths_jax = jnp.array(sequence_lengths.cpu().numpy())
-      prompt_attention_mask_jax = jnp.array(prompt_attention_mask.cpu().numpy())
-
+      text_encoder_hidden_states = text_encoder_outputs.hidden_states
       del text_encoder_outputs  # Free memory
+
+      prompt_embeds_list = []
+      # Iterate instead of stacking eagerly to avoid 5.7+ GB HBM allocations outside JIT
+      for state in text_encoder_hidden_states:
+        state_np = state.cpu().to(torch.float32).numpy()
+        prompt_embeds_list.append(jnp.array(state_np, dtype=jnp.bfloat16))
+
+      prompt_embeds = prompt_embeds_list
       del text_encoder_hidden_states  # Free PyTorch tensor memory
 
-      prompt_embeds = self._pack_text_embeds(
-          hidden_states_jax,
-          sequence_lengths_jax,
-          padding_side=self.tokenizer.padding_side,
-          scale_factor=scale_factor,
-      )
-      prompt_attention_mask = prompt_attention_mask_jax
+      prompt_attention_mask = jnp.array(prompt_attention_mask.cpu().to(torch.float32).numpy(), dtype=jnp.bool_)
     else:
       raise ValueError("`text_encoder` is required to encode prompts.")
 
     if dtype is not None:
-      prompt_embeds = prompt_embeds.astype(dtype)
-
-    _, seq_len, _ = prompt_embeds.shape
-    prompt_embeds = jnp.repeat(prompt_embeds, num_videos_per_prompt, axis=0)
-    prompt_embeds = prompt_embeds.reshape(batch_size * num_videos_per_prompt, seq_len, -1)
+      if isinstance(prompt_embeds, list):
+        prompt_embeds = [state.astype(dtype) for state in prompt_embeds]
+      else:
+        prompt_embeds = prompt_embeds.astype(dtype)
+
+    if isinstance(prompt_embeds, list):
+      _, seq_len, _ = prompt_embeds[0].shape
+      prompt_embeds = [
+          jnp.repeat(state, num_videos_per_prompt, axis=0).reshape(batch_size * num_videos_per_prompt, seq_len, -1)
+          for state in prompt_embeds
+      ]
+    else:
+      _, seq_len, _ = prompt_embeds.shape
+      prompt_embeds = jnp.repeat(prompt_embeds, num_videos_per_prompt, axis=0)
+      prompt_embeds = prompt_embeds.reshape(batch_size * num_videos_per_prompt, seq_len, -1)
 
     prompt_attention_mask = prompt_attention_mask.reshape(batch_size, -1)
     prompt_attention_mask = jnp.repeat(prompt_attention_mask, num_videos_per_prompt, axis=0)
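
For context on the memory reasoning in the comment above ("iterate instead of stacking eagerly"), the following is a minimal sketch, not part of the commit: it contrasts the reverted stack-then-convert path with the per-layer conversion the commit keeps. The layer count, shapes, and variable names are illustrative assumptions; in the pipeline the tuple comes from the Gemma text encoder called with output_hidden_states=True.

    # Illustrative sketch only; shapes and layer count are assumed, not from the repository.
    import torch
    import jax.numpy as jnp

    # Stand-in for text_encoder_outputs.hidden_states: a tuple of per-layer activations,
    # each of shape (batch, seq_len, hidden_dim).
    num_layers, batch, seq_len, hidden_dim = 4, 1, 32, 16
    hidden_states = tuple(torch.randn(batch, seq_len, hidden_dim) for _ in range(num_layers))

    # Reverted path: stack every layer into one (batch, seq_len, hidden_dim, num_layers)
    # tensor and convert it to a single float32 JAX array in one shot; for a real encoder
    # this materializes the full stacked buffer outside JIT before any packing happens.
    stacked = torch.stack(hidden_states, dim=-1)
    stacked_jax = jnp.array(stacked.to(torch.float32).numpy())

    # Kept path: convert one layer at a time, casting each to bfloat16, so only a single
    # layer's float32 copy exists at any moment and no stacked buffer is allocated.
    per_layer = [jnp.array(s.to(torch.float32).numpy(), dtype=jnp.bfloat16) for s in hidden_states]

    assert stacked_jax.shape == (batch, seq_len, hidden_dim, num_layers)
    assert len(per_layer) == num_layers and per_layer[0].shape == (batch, seq_len, hidden_dim)

Downstream code then receives a list of per-layer arrays rather than one packed array, which is why the dtype cast and the num_videos_per_prompt repeat in the hunk above branch on isinstance(prompt_embeds, list).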
