Commit 955bd86

adds segment ids for masking. (#236)
* adds segment ids for masking.
* reduce padding by computing it inside sharded qkvs.
* scanned ring attn.
* add ring attention - inference only.
1 parent aad9839 commit 955bd86

3 files changed

Lines changed: 129 additions & 50 deletions

src/maxdiffusion/models/attention_flax.py

Lines changed: 83 additions & 17 deletions
@@ -113,14 +113,24 @@ def _unflatten_heads(tensor, heads):
   return tensor


-def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
+def _reshape_data_for_flash(tensor, heads):
   """
   Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
   Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of
   blocks is divisible by the number of shards.
   """
   if tensor.ndim != 4:
     tensor = _unflatten_heads(tensor, heads)
+  return tensor
+
+
+def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1):
+  """
+  Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
+  Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of
+  blocks is divisible by the number of shards.
+  """
+  tensor = _reshape_data_for_flash(tensor, heads)

   # Pad head_dim to 128 if less than that.
   kv_size = tensor.shape[-1]
@@ -148,8 +158,7 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1

   if kv_size < 128 or seq_len_pad != 0:
     npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad))
-    padded_tensor = jnp.pad(tensor, npad)
-    tensor = jax.lax.with_sharding_constraint(padded_tensor, PartitionSpec("data", "tensor", "fsdp", None))
+    tensor = jnp.pad(tensor, npad)

   return tensor, kv_size, seq_len

@@ -164,12 +173,14 @@ def _tpu_flash_attention(
     axis_names_kv: AxisNames,
     flash_block_sizes: BlockSizes,
     dtype: jnp.dtype = jnp.float32,
+    attention_kernel: str = "flash",
 ) -> jax.Array:
   """TPU Flash Attention"""
+
   q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512
-  # Cross-attention where kv dims are much smaller due to encoder_hidden_states.
-  # If kv seq_len is padded too much, it causes issues in attention calculations.
+  # This is the case for cross-attn.
   if key.shape[1] != query.shape[1]:
+    assert key.shape[1] % 128 == 0
     kv_max_block_size = key.shape[1]
   else:
     kv_max_block_size = q_max_block_size
@@ -186,38 +197,90 @@ def _tpu_flash_attention(
       block_q_dq=min(q_max_block_size, query.shape[2]),
       block_kv_dq=min(kv_max_block_size, query.shape[2]),
   )
-
   num_fsdp_shards = mesh.shape["fsdp"]
-  query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q, num_fsdp_shards)
-  key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards)
-  value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards)
+  query = _reshape_data_for_flash(query, heads)
+  key = _reshape_data_for_flash(key, heads)
+  value = _reshape_data_for_flash(value, heads)
   q_axis_names = nn.logical_to_mesh_axes(axis_names_q)
   kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv)

   @functools.partial(
       shard_map.shard_map,
       mesh=mesh,
-      in_specs=(
-          q_axis_names,
-          kv_axis_names,
-          kv_axis_names,
-      ),
+      in_specs=(q_axis_names, kv_axis_names, kv_axis_names),
       out_specs=q_axis_names,
       check_rep=False,
   )
   def wrap_flash_attention(query, key, value):
+
+    query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_sizes.block_q)
+    key, _, key_seq_len = _pad_data_for_flash(key, heads, block_sizes.block_kv_compute)
+    value, _, _ = _pad_data_for_flash(value, heads, block_sizes.block_kv_compute)
+
     mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]))
     multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1])
+
+    q_padded_len = query.shape[2]
+    q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0)
+    q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32)
+
+    kv_padded_len = key.shape[2]
+    kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0)
+    kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32)
+    segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids)
+
     # make_splash_mha is wrapped around shardmap and seq and head is already
     # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1.
     splash_kernel = splash_attention_kernel.make_splash_mha(
         mask=multi_head_mask,
         head_shards=1, # the sizes of the axis is sharding over heads
         q_seq_shards=1, # the sizes of the axis is sharding over seq_len
         block_sizes=block_sizes,
+        save_residuals=True if attention_kernel == "ring" else False,
     )
-    attention_output = jax.vmap(splash_kernel)(query, key, value)
-    return attention_output
+    vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None))
+
+    if attention_kernel == "flash":
+      attention_output = vmapped_splash(query, key, value, segment_ids)
+    else:
+      if num_fsdp_shards > 1:
+        out, (lse,) = vmapped_splash(query, key, value, segment_ids)
+        m = lse.astype(jnp.float32)
+        l = jnp.exp(lse - m)
+        o = out.astype(jnp.float32) * l[..., None]
+
+        perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)]
+
+        k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm)
+        v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm)
+
+        def ring_scan_body(carry, _):
+          m, l, o, k_current, v_current = carry
+          k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm)
+          v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm)
+
+          out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids)
+
+          m_chunk = lse_chunk.astype(jnp.float32)
+          m_old = m
+          m = jnp.maximum(m_old, m_chunk)
+
+          exp_m_diff = jnp.exp(m_old - m)
+          exp_m_chunk_diff = jnp.exp(m_chunk - m)
+
+          l = l * exp_m_diff + jnp.exp(lse_chunk - m)
+          o = o * exp_m_diff[..., None]
+          o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32)
+
+          # Return the updated state for the next iteration
+          return (m, l, o, k_next, v_next), None
+
+        initial_carry = (m, l, o, k1, v1)
+        (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1)

+        attention_output = o_final / l_final[..., None]
+
+    return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype)

   devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"]
   # This warning might show up when doing model eval for example, when calculating model flops
@@ -228,7 +291,6 @@ def wrap_flash_attention(query, key, value):
         f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}"
     )
   x = wrap_flash_attention(query, key, value)
-  x = x[:, :, :query_seq_len, :kv_size]
   x = _reshape_heads_to_head_dim(x)

   return x
@@ -379,6 +441,10 @@ def _apply_attention(
     return _tpu_flash_attention(
         query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype
     )
+  elif attention_kernel == "ring":
+    return _tpu_flash_attention(
+        query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel
+    )
   elif attention_kernel == "cudnn_flash_te":
     return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer)
   else:
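The two ideas this file's changes rely on can be checked in isolation: segment ids that mark padded positions, and merging per-chunk attention outputs via their log-sum-exp (LSE), which is what ring_scan_body does as key/value shards rotate around the "fsdp" ring. Below is a minimal, illustrative sketch, not the splash/pallas kernel itself; names such as naive_attention_with_lse and merge_chunks are invented for this sketch.

import jax
import jax.numpy as jnp

# Segment ids as used above: real tokens get id 1, padding gets id 0, and the
# kernel only lets a query attend to keys with the same segment id. (Padded
# queries attend to padded keys, but those rows are sliced away afterwards.)
q_seg = (jnp.arange(8) < 5).astype(jnp.int32)     # 5 real queries, 3 padded
kv_seg = (jnp.arange(32) < 20).astype(jnp.int32)  # 20 real keys, 12 padded
allowed = q_seg[:, None] == kv_seg[None, :]       # [Sq, Skv] attention mask

def naive_attention_with_lse(q, k, v):
  """Single-head attention that also returns the log-sum-exp per query."""
  logits = q @ k.T / jnp.sqrt(q.shape[-1])   # [Sq, Skv]
  m = logits.max(axis=-1, keepdims=True)     # running max for numerical stability
  p = jnp.exp(logits - m)
  l = p.sum(axis=-1, keepdims=True)          # softmax denominator
  out = (p @ v) / l                          # normalized output for this KV chunk
  lse = (m + jnp.log(l)).squeeze(-1)         # [Sq] log-sum-exp
  return out, lse

def merge_chunks(out_a, lse_a, out_b, lse_b):
  """Combine two normalized chunk outputs using their LSEs."""
  m = jnp.maximum(lse_a, lse_b)
  w_a = jnp.exp(lse_a - m)[:, None]
  w_b = jnp.exp(lse_b - m)[:, None]
  return (out_a * w_a + out_b * w_b) / (w_a + w_b)

q = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
k = jax.random.normal(jax.random.PRNGKey(1), (32, 16))
v = jax.random.normal(jax.random.PRNGKey(2), (32, 16))

# Attending to the full KV equals attending to two KV chunks merged via LSE.
full, _ = naive_attention_with_lse(q, k, v)
o1, lse1 = naive_attention_with_lse(q, k[:16], v[:16])
o2, lse2 = naive_attention_with_lse(q, k[16:], v[16:])
assert jnp.allclose(full, merge_chunks(o1, lse1, o2, lse2), atol=1e-5)

The scan in the diff carries the unnormalized variant of the same quantities, a running max m, a rescaled denominator l, and a weighted output o, and divides by l_final once after the last ring step.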

src/maxdiffusion/pyconfig.py

Lines changed: 12 additions & 0 deletions
@@ -27,6 +27,7 @@
 from . import max_logging
 from . import max_utils
 from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH
+from maxdiffusion.common_types import LENGTH, KV_LENGTH


 def string_to_bool(s: str) -> bool:
@@ -175,6 +176,17 @@ def user_init(raw_keys):
   max_utils.write_config_raw_keys_for_gcs(raw_keys)

   raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"])
+  # Verify qkv is sharded across sequence.
+  if raw_keys["attention"] == "ring":
+    logical_axis_rules = list(raw_keys["logical_axis_rules"])
+    q_seq_sharding = (LENGTH, "fsdp")
+    kv_seq_sharding = (KV_LENGTH, "fsdp")
+    if q_seq_sharding not in logical_axis_rules:
+      logical_axis_rules.append(q_seq_sharding)
+    if kv_seq_sharding not in logical_axis_rules:
+      logical_axis_rules.append(kv_seq_sharding)
+    raw_keys["logical_axis_rules"] = tuple(logical_axis_rules)
+
   raw_keys["data_sharding"] = _lists_to_tuples(raw_keys["data_sharding"])

   if raw_keys["learning_rate_schedule_steps"] == -1:
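For context, the appended rules only take effect once Flax translates logical axis names into a mesh PartitionSpec. A small sketch, with assumed logical names (the real LENGTH / KV_LENGTH constants live in maxdiffusion.common_types):

import flax.linen as nn

# Assumed names for illustration only; the actual constants come from common_types.
LENGTH, KV_LENGTH = "activation_length", "activation_kv_length"

rules = (
    ("activation_batch", "data"),
    (LENGTH, "fsdp"),      # appended when attention == "ring"
    (KV_LENGTH, "fsdp"),   # appended when attention == "ring"
)
axis_names_q = ("activation_batch", "activation_heads", LENGTH, "activation_kv")

with nn.logical_axis_rules(rules):
  print(nn.logical_to_mesh_axes(axis_names_q))
  # PartitionSpec('data', None, 'fsdp', None): the query sequence axis is split
  # across the "fsdp" mesh axis, which is the axis the ring permutes over.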

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 34 additions & 33 deletions
@@ -255,7 +255,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data
         eval_data_iterator = self.load_dataset(mesh, is_training=False)
         eval_rng = jax.random.key(self.config.seed + step)
         eval_metrics = []
-      # Loop indefinitely until the iterator is exhausted
+        # Loop indefinitely until the iterator is exhausted
         while True:
           try:
             with mesh:
@@ -329,6 +329,7 @@ def loss_fn(params):
   metrics = {"scalar": {"learning/loss": loss}, "scalars": {}}
   return new_state, scheduler_state, metrics, new_rng

+
 def eval_step(state, data, rng, scheduler_state, scheduler, config):
   """
   Computes the evaluation loss for a single batch without updating model weights.
@@ -338,44 +339,44 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config):
   # This ensures the batch size is consistent, though it might be redundant
   # if the evaluation dataloader is already configured correctly.
   for k, v in data.items():
-      data[k] = v[: config.global_batch_size_to_train_on, :]
+    data[k] = v[: config.global_batch_size_to_train_on, :]

   # The loss function logic is identical to training. We are evaluating the model's
   # ability to perform its core training objective (e.g., denoising).
   def loss_fn(params):
-      # Reconstruct the model from its definition and parameters
-      model = nnx.merge(state.graphdef, params, state.rest_of_state)
-
-      # Prepare inputs
-      latents = data["latents"].astype(config.weights_dtype)
-      encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
-      bsz = latents.shape[0]
-
-      # Sample random timesteps and noise, just as in a training step
-      timesteps = jax.random.randint(
-          timestep_rng,
-          (bsz,),
-          0,
-          scheduler.config.num_train_timesteps,
-      )
-      noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
-      noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)
-
-      # Get the model's prediction
-      model_pred = model(
-          hidden_states=noisy_latents,
-          timestep=timesteps,
-          encoder_hidden_states=encoder_hidden_states,
-      )
+    # Reconstruct the model from its definition and parameters
+    model = nnx.merge(state.graphdef, params, state.rest_of_state)
+
+    # Prepare inputs
+    latents = data["latents"].astype(config.weights_dtype)
+    encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
+    bsz = latents.shape[0]
+
+    # Sample random timesteps and noise, just as in a training step
+    timesteps = jax.random.randint(
+        timestep_rng,
+        (bsz,),
+        0,
+        scheduler.config.num_train_timesteps,
+    )
+    noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
+    noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)

-      # Calculate the loss against the target
-      training_target = scheduler.training_target(latents, noise, timesteps)
-      training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
-      loss = (training_target - model_pred) ** 2
-      loss = loss * training_weight
-      loss = jnp.mean(loss)
+    # Get the model's prediction
+    model_pred = model(
+        hidden_states=noisy_latents,
+        timestep=timesteps,
+        encoder_hidden_states=encoder_hidden_states,
+    )

-      return loss
+    # Calculate the loss against the target
+    training_target = scheduler.training_target(latents, noise, timesteps)
+    training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
+    loss = (training_target - model_pred) ** 2
+    loss = loss * training_weight
+    loss = jnp.mean(loss)
+
+    return loss

   # --- Key Difference from train_step ---
   # Directly compute the loss without calculating gradients.