Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions src/maxdiffusion/models/attention_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,27 @@ def _unflatten_heads(tensor, heads):
return tensor


def _reshape_data_for_flash(tensor, heads, num_context_shards=1):
  """Unflatten heads if needed and pad the sequence axis for context sharding.

  A tensor that is not already 4-D is first passed through
  `_unflatten_heads(tensor, heads)`. Axis 2 (treated here as the sequence
  axis) is then zero-padded at its end up to the next multiple of
  `num_context_shards`, because shard_map requires the sharded dimension to
  be evenly divisible by the context mesh axis.

  This function does not pad head_dim and does not round seq_len to a flash
  block size.

  Args:
    tensor: Input array; 4-D, or any other rank accepted by
      `_unflatten_heads`.
    heads: Number of attention heads, forwarded to `_unflatten_heads`.
    num_context_shards: Size of the context mesh axis. Values <= 1 disable
      padding.

  Returns:
    A `(tensor, org_seq_len)` tuple: the (possibly padded) 4-D tensor and the
    size of axis 2 before padding, so callers can trim the output back.
  """
  if tensor.ndim != 4:
    tensor = _unflatten_heads(tensor, heads)

  org_seq_len = tensor.shape[2]

  # Nothing to align when the context axis is not actually sharded.
  if num_context_shards <= 1:
    return tensor, org_seq_len

  # Number of trailing positions needed to reach the next multiple.
  pad_len = -org_seq_len % num_context_shards
  if pad_len == 0:
    return tensor, org_seq_len

  pad_width = [(0, 0)] * tensor.ndim
  pad_width[2] = (0, pad_len)
  return jnp.pad(tensor, pad_width), org_seq_len


def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
Expand All @@ -145,7 +157,7 @@ def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of
blocks is divisible by the number of shards.
"""
tensor = _reshape_data_for_flash(tensor, heads)
tensor, _ = _reshape_data_for_flash(tensor, heads)

# Pad head_dim to 128 if less than that.
kv_size = tensor.shape[-1]
Expand Down Expand Up @@ -255,9 +267,10 @@ def _tpu_flash_attention(
use_fused_bwd_kernel=True if attention_kernel == "tokamax_flash" else False,
)
num_context_shards = mesh.shape["context"]
query = _reshape_data_for_flash(query, heads)
key = _reshape_data_for_flash(key, heads)
value = _reshape_data_for_flash(value, heads)
query, orig_q_seq_len = _reshape_data_for_flash(query, heads, num_context_shards)
key, _ = _reshape_data_for_flash(key, heads, num_context_shards)
value, _ = _reshape_data_for_flash(value, heads, num_context_shards)

q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)

Expand Down Expand Up @@ -401,6 +414,8 @@ def ring_scan_body(carry, _):
f" axis, batch dimension: {query.shape[0]}, devices_in_data_context: {devices_in_data_context}"
)
x = wrap_flash_attention(query, key, value)
# Trim back to original sequence length after context-axis padding.
x = x[:, :, :orig_q_seq_len, :]
x = _reshape_heads_to_head_dim(x)

return x
Expand Down
12 changes: 3 additions & 9 deletions src/maxdiffusion/models/ltx2/attention_ltx2.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,7 @@ def prepare_video_coords(
# pixel_coords[:, 0, ...] selects Frame dimension.
# pixel_coords shape: [B, 3, num_patches, 2] -> dim 1 is (F, H, W)
frame_coords = pixel_coords[:, 0, ...]
frame_coords = jnp.clip(
frame_coords + self.causal_offset - self.scale_factors[0], min=0
)
frame_coords = jnp.clip(frame_coords + self.causal_offset - self.scale_factors[0], min=0)
pixel_coords = pixel_coords.at[:, 0, ...].set(frame_coords / fps)

return pixel_coords
Expand All @@ -212,16 +210,12 @@ def prepare_audio_coords(
# 2. Start timestamps
audio_scale_factor = self.scale_factors[0]
grid_start_mel = grid_f * audio_scale_factor
grid_start_mel = jnp.clip(
grid_start_mel + self.causal_offset - audio_scale_factor, min=0
)
grid_start_mel = jnp.clip(grid_start_mel + self.causal_offset - audio_scale_factor, min=0)
grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate

# 3. End timestamps
grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor
grid_end_mel = jnp.clip(
grid_end_mel + self.causal_offset - audio_scale_factor, min=0
)
grid_end_mel = jnp.clip(grid_end_mel + self.causal_offset - audio_scale_factor, min=0)
grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate

# Stack [num_patches, 2]
Expand Down
Loading