
Commit 5f28f8c

minor fixes for ltx2 in config and pipeline, reverting back %128 assertion change
1 parent ae27bdd commit 5f28f8c

4 files changed: 40 additions & 60 deletions


src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 0 additions & 9 deletions
@@ -37,19 +37,10 @@ negative_prompt: "Bright tones, overexposed, static, blurred details, subtitles,
 height: 512
 width: 768
 num_frames: 121
-flow_shift: 5.0
-downscale_factor: 0.6666666
-spatial_upscaler_model_path: "ltxv-spatial-upscaler-0.9.7.safetensors"
-prompt_enhancement_words_threshold: 120
-stg_mode: "attention_values"
 decode_timestep: 0.05
 decode_noise_scale: 0.025
 quantization: "int8"
 seed: 10
-conditioning_media_paths: None #["IMAGE_PATH"]
-conditioning_start_frames: [0]
-
-
 #parallelism
 mesh_axes: ['data', 'fsdp', 'context', 'tensor']
 logical_axis_rules: [

src/maxdiffusion/maxdiffusion_utils.py

Lines changed: 3 additions & 3 deletions
@@ -287,10 +287,10 @@ def get_dummy_flux_inputs(config, pipeline, batch_size):


 def get_dummy_ltx2_inputs(config, pipeline, batch_size):
-  height = 32
-  width = 32
-  num_frames = 1
   raw_keys = config.get_keys() if hasattr(config, "get_keys") else {}
+  height = raw_keys.get("height", 512) if raw_keys.get("height") else 512
+  width = raw_keys.get("width", 768) if raw_keys.get("width") else 768
+  num_frames = raw_keys.get("num_frames", 121) if raw_keys.get("num_frames") else 121
   fps = raw_keys.get("fps", 24.0) if raw_keys.get("fps") else 24.0
   duration_s = num_frames / fps
   audio_latents_per_second = (
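
The `raw_keys.get(key, default) if raw_keys.get(key) else default` pattern above falls back to the default not only when a key is absent but also when its stored value is falsy (e.g. `None` from an unset YAML field). A minimal standalone sketch of that behavior, with a toy `raw_keys` dict standing in for `config.get_keys()`:

# Sketch: how the dummy-input defaults resolve for different config states.
# `raw_keys` here is a hypothetical stand-in for the parsed config mapping.
for raw_keys in ({}, {"height": None}, {"height": 0}, {"height": 256}):
    height = raw_keys.get("height", 512) if raw_keys.get("height") else 512
    print(raw_keys, "->", height)
# {} -> 512, {'height': None} -> 512, {'height': 0} -> 512, {'height': 256} -> 256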

src/maxdiffusion/models/attention_flax.py

Lines changed: 2 additions & 4 deletions
@@ -235,10 +235,8 @@ def _tpu_flash_attention(
   q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
   # This is the case for cross-attn.
   if key.shape[1] != query.shape[1]:
-    if key.shape[1] % 128 != 0:
-      kv_max_block_size = ((key.shape[1] + 127) // 128) * 128
-    else:
-      kv_max_block_size = key.shape[1]
+    assert key.shape[1] % 128 == 0
+    kv_max_block_size = key.shape[1]
   else:
     kv_max_block_size = q_max_block_size
   # ensure that for cross attention we override the block sizes.
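
The deleted branch rounded the KV sequence length up to the next multiple of 128 before using it as a block size; the restored code instead asserts the length is already 128-aligned, as the TPU flash-attention block sizes are expected to be multiples of 128. A small sketch of the round-up arithmetic that was reverted:

# Round n up to the next multiple of `multiple` via ceiling division.
def round_up_to_multiple(n: int, multiple: int = 128) -> int:
    return ((n + multiple - 1) // multiple) * multiple

assert round_up_to_multiple(300) == 384  # padded up to the next block boundary
assert round_up_to_multiple(256) == 256  # already aligned, unchanged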

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 35 additions & 44 deletions
@@ -677,20 +677,20 @@ def load_transformer(

   @staticmethod
   def _pack_text_embeds(
-      text_hidden_states: torch.Tensor,
-      sequence_lengths: torch.Tensor,
+      text_hidden_states: jax.Array,
+      sequence_lengths: jax.Array,
       padding_side: str = "left",
       scale_factor: int = 8,
       eps: float = 1e-6,
-  ) -> torch.Tensor:
+  ) -> jax.Array:
     """
-    Packs and normalizes text encoder hidden states using PyTorch to save device HBM.
+    Packs and normalizes text encoder hidden states using JAX natively.
     """
     batch_size, seq_len, hidden_dim, num_layers = text_hidden_states.shape
     original_dtype = text_hidden_states.dtype

     # Create padding mask
-    token_indices = torch.arange(seq_len, device=text_hidden_states.device).unsqueeze(0)
+    token_indices = jnp.arange(seq_len)[None, :]
     if padding_side == "right":
       mask = token_indices < sequence_lengths[:, None]
     elif padding_side == "left":
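
For reference, the right-padding branch above builds a per-sequence validity mask by comparing token positions against sequence lengths; a standalone JAX sketch with toy lengths (values illustrative only):

import jax.numpy as jnp

seq_len = 6
sequence_lengths = jnp.array([3, 5])
token_indices = jnp.arange(seq_len)[None, :]      # shape (1, 6)
mask = token_indices < sequence_lengths[:, None]  # right padding: valid tokens come first
# [[ True  True  True False False False]
#  [ True  True  True  True  True False]]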
@@ -700,20 +700,20 @@ def _pack_text_embeds(
       raise ValueError(f"padding_side must be 'left' or 'right', got {padding_side}")
     mask = mask[:, :, None, None]

-    masked_text_hidden_states = text_hidden_states.masked_fill(~mask, 0.0)
-    num_valid_positions = (sequence_lengths * hidden_dim).view(batch_size, 1, 1, 1)
-    masked_mean = masked_text_hidden_states.sum(dim=(1, 2), keepdim=True) / (num_valid_positions + eps)
+    masked_text_hidden_states = jnp.where(mask, text_hidden_states, 0.0)
+    num_valid_positions = (sequence_lengths * hidden_dim).reshape(batch_size, 1, 1, 1)
+    masked_mean = jnp.sum(masked_text_hidden_states, axis=(1, 2), keepdims=True) / (num_valid_positions + eps)

-    x_min = text_hidden_states.masked_fill(~mask, float("inf")).amin(dim=(1, 2), keepdim=True)
-    x_max = text_hidden_states.masked_fill(~mask, float("-inf")).amax(dim=(1, 2), keepdim=True)
+    x_min = jnp.min(jnp.where(mask, text_hidden_states, jnp.inf), axis=(1, 2), keepdims=True)
+    x_max = jnp.max(jnp.where(mask, text_hidden_states, -jnp.inf), axis=(1, 2), keepdims=True)

     normalized_hidden_states = (text_hidden_states - masked_mean) / (x_max - x_min + eps)
     normalized_hidden_states = normalized_hidden_states * scale_factor

-    normalized_hidden_states = normalized_hidden_states.flatten(2)
-    mask_flat = mask.squeeze(-1).expand(-1, -1, hidden_dim * num_layers)
-    normalized_hidden_states = normalized_hidden_states.masked_fill(~mask_flat, 0.0)
-    normalized_hidden_states = normalized_hidden_states.to(dtype=original_dtype)
+    normalized_hidden_states = normalized_hidden_states.reshape(batch_size, seq_len, -1)
+    mask_flat = jnp.broadcast_to(mask.squeeze(-1), (batch_size, seq_len, hidden_dim * num_layers))
+    normalized_hidden_states = jnp.where(mask_flat, normalized_hidden_states, 0.0)
+    normalized_hidden_states = normalized_hidden_states.astype(original_dtype)
     return normalized_hidden_states

   def _get_gemma_prompt_embeds(
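
The hunk above swaps PyTorch's `masked_fill`/`amin`/`amax` idioms for `jnp.where` plus masked reductions. A standalone sketch of the same masked min-max normalization on toy shapes (all values illustrative):

import jax.numpy as jnp

# Toy tensor: batch=1, seq_len=4, hidden_dim=2, num_layers=1; last token is padding.
x = jnp.arange(8, dtype=jnp.float32).reshape(1, 4, 2, 1)
mask = jnp.array([True, True, True, False])[None, :, None, None]
eps = 1e-6

# Mean over valid positions only (3 valid tokens * hidden_dim 2 = 6 values).
masked_mean = jnp.sum(jnp.where(mask, x, 0.0), axis=(1, 2), keepdims=True) / (6 + eps)
x_min = jnp.min(jnp.where(mask, x, jnp.inf), axis=(1, 2), keepdims=True)
x_max = jnp.max(jnp.where(mask, x, -jnp.inf), axis=(1, 2), keepdims=True)
normalized = (x - masked_mean) / (x_max - x_min + eps) * 8  # scale_factor = 8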
@@ -733,7 +733,6 @@ def _get_gemma_prompt_embeds(
     self.tokenizer.pad_token = self.tokenizer.eos_token

     prompt = [p.strip() for p in prompt]
-    # Return Numpy tensors to be compatible with JAX if no text encoder, else PyTorch

     if self.text_encoder is not None:
       # PyTorch Text Encoder
@@ -748,49 +747,41 @@ def _get_gemma_prompt_embeds(
       text_input_ids = text_inputs.input_ids
       prompt_attention_mask = text_inputs.attention_mask

-      # Move to device if needed (assuming text_encoder is on correct device or CPU)
-      # For now, keep on CPU or same device as model
       text_input_ids = text_input_ids.to(self.text_encoder.device)
       prompt_attention_mask = prompt_attention_mask.to(self.text_encoder.device)

-      max_logging.log(f"DEBUG: text_encoder is on {self.text_encoder.device}")
-      max_logging.log(f"DEBUG: text_input_ids is on {text_input_ids.device}")
-
       with torch.no_grad():
         text_encoder_outputs = self.text_encoder(
             input_ids=text_input_ids, attention_mask=prompt_attention_mask, output_hidden_states=True
         )

-      text_encoder_hidden_states = text_encoder_outputs.hidden_states
-      del text_encoder_outputs  # Free memory
+      text_encoder_hidden_states = torch.stack(text_encoder_outputs.hidden_states, dim=-1)
+      sequence_lengths = prompt_attention_mask.sum(dim=-1)

-      prompt_embeds_list = []
-      for state in text_encoder_hidden_states:
-        state_np = state.cpu().to(torch.float32).numpy()
-        prompt_embeds_list.append(jnp.array(state_np, dtype=jnp.bfloat16))
-
-      prompt_embeds = prompt_embeds_list
+      # Convert to JAX arrays to do native JAX math
+      hidden_states_jax = jnp.array(text_encoder_hidden_states.cpu().to(torch.float32).numpy())
+      sequence_lengths_jax = jnp.array(sequence_lengths.cpu().numpy())
+      prompt_attention_mask_jax = jnp.array(prompt_attention_mask.cpu().numpy())
+
+      del text_encoder_outputs  # Free memory
       del text_encoder_hidden_states  # Free PyTorch tensor memory

-      prompt_attention_mask = jnp.array(prompt_attention_mask.cpu().to(torch.float32).numpy(), dtype=jnp.bfloat16)
+      prompt_embeds = self._pack_text_embeds(
+          hidden_states_jax,
+          sequence_lengths_jax,
+          padding_side=self.tokenizer.padding_side,
+          scale_factor=scale_factor,
+      )
+      prompt_attention_mask = prompt_attention_mask_jax
     else:
       raise ValueError("`text_encoder` is required to encode prompts.")
+
     if dtype is not None:
-      if isinstance(prompt_embeds, list):
-        prompt_embeds = [state.astype(dtype) for state in prompt_embeds]
-      else:
-        prompt_embeds = prompt_embeds.astype(dtype)
-
-    if isinstance(prompt_embeds, list):
-      _, seq_len, _ = prompt_embeds[0].shape
-      prompt_embeds = [
-          jnp.repeat(state, num_videos_per_prompt, axis=0).reshape(batch_size * num_videos_per_prompt, seq_len, -1)
-          for state in prompt_embeds
-      ]
-    else:
-      _, seq_len, _ = prompt_embeds.shape
-      prompt_embeds = jnp.repeat(prompt_embeds, num_videos_per_prompt, axis=0)
-      prompt_embeds = prompt_embeds.reshape(batch_size * num_videos_per_prompt, seq_len, -1)
+      prompt_embeds = prompt_embeds.astype(dtype)
+
+    _, seq_len, _ = prompt_embeds.shape
+    prompt_embeds = jnp.repeat(prompt_embeds, num_videos_per_prompt, axis=0)
+    prompt_embeds = prompt_embeds.reshape(batch_size * num_videos_per_prompt, seq_len, -1)

     prompt_attention_mask = prompt_attention_mask.reshape(batch_size, -1)
     prompt_attention_mask = jnp.repeat(prompt_attention_mask, num_videos_per_prompt, axis=0)
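
The rewritten encoding path stacks every Gemma hidden-state layer onto a trailing axis in PyTorch, then crosses the framework boundary once through float32 NumPy instead of converting layer by layer. A minimal sketch of that handoff, with `torch.randn` standing in for encoder outputs and toy shapes:

import torch
import jax.numpy as jnp

# Toy per-layer hidden states: 3 layers of (batch=2, seq=16, dim=8).
hidden_states = tuple(torch.randn(2, 16, 8) for _ in range(3))
stacked = torch.stack(hidden_states, dim=-1)  # (batch, seq, dim, num_layers)
hidden_states_jax = jnp.array(stacked.cpu().to(torch.float32).numpy())
print(hidden_states_jax.shape)  # (2, 16, 8, 3)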

0 commit comments