diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index fcdb7cf65..25788fb69 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -113,7 +113,7 @@ def _unflatten_heads(tensor, heads): return tensor -def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1): +def _reshape_data_for_flash(tensor, heads): """ Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim. Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of @@ -121,6 +121,16 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1 """ if tensor.ndim != 4: tensor = _unflatten_heads(tensor, heads) + return tensor + + +def _pad_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1): + """ + Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim. + Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of + blocks is divisible by the number of shards. + """ + tensor = _reshape_data_for_flash(tensor, heads) # Pad head_dim to 128 if less than that. kv_size = tensor.shape[-1] @@ -148,8 +158,7 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1 if kv_size < 128 or seq_len_pad != 0: npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad)) - padded_tensor = jnp.pad(tensor, npad) - tensor = jax.lax.with_sharding_constraint(padded_tensor, PartitionSpec("data", "tensor", "fsdp", None)) + tensor = jnp.pad(tensor, npad) return tensor, kv_size, seq_len @@ -164,12 +173,14 @@ def _tpu_flash_attention( axis_names_kv: AxisNames, flash_block_sizes: BlockSizes, dtype: jnp.dtype = jnp.float32, + attention_kernel: str = "flash", ) -> jax.Array: """TPU Flash Attention""" + q_max_block_size = 1024 if dtype == jnp.bfloat16 else 512 - # Cross-attention where kv dims are much smaller due to encoder_hidden_states. - # If kv seq_len is padded too much, it causes issues in attention calculations. + # This is the case for cross-attn. if key.shape[1] != query.shape[1]: + assert key.shape[1] % 128 == 0 kv_max_block_size = key.shape[1] else: kv_max_block_size = q_max_block_size @@ -186,28 +197,38 @@ def _tpu_flash_attention( block_q_dq=min(q_max_block_size, query.shape[2]), block_kv_dq=min(kv_max_block_size, query.shape[2]), ) - num_fsdp_shards = mesh.shape["fsdp"] - query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q, num_fsdp_shards) - key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards) - value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards) + query = _reshape_data_for_flash(query, heads) + key = _reshape_data_for_flash(key, heads) + value = _reshape_data_for_flash(value, heads) q_axis_names = nn.logical_to_mesh_axes(axis_names_q) kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv) @functools.partial( shard_map.shard_map, mesh=mesh, - in_specs=( - q_axis_names, - kv_axis_names, - kv_axis_names, - ), + in_specs=(q_axis_names, kv_axis_names, kv_axis_names), out_specs=q_axis_names, check_rep=False, ) def wrap_flash_attention(query, key, value): + + query, kv_size, query_seq_len = _pad_data_for_flash(query, heads, block_sizes.block_q) + key, _, key_seq_len = _pad_data_for_flash(key, heads, block_sizes.block_kv_compute) + value, _, _ = _pad_data_for_flash(value, heads, block_sizes.block_kv_compute) + mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2])) multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) + + q_padded_len = query.shape[2] + q_indices = jax.lax.broadcasted_iota(jnp.int32, (q_padded_len,), 0) + q_segment_ids = (q_indices < query_seq_len).astype(jnp.int32) + + kv_padded_len = key.shape[2] + kv_indices = jax.lax.broadcasted_iota(jnp.int32, (kv_padded_len,), 0) + kv_segment_ids = (kv_indices < key_seq_len).astype(jnp.int32) + segment_ids = splash_attention_kernel.SegmentIds(q=q_segment_ids, kv=kv_segment_ids) + # make_splash_mha is wrapped around shardmap and seq and head is already # sharded based on in_specs, therefore setting head_shards=1 and q_seq_shards=1. splash_kernel = splash_attention_kernel.make_splash_mha( @@ -215,9 +236,51 @@ def wrap_flash_attention(query, key, value): head_shards=1, # the sizes of the axis is sharding over heads q_seq_shards=1, # the sizes of the axis is sharding over seq_len block_sizes=block_sizes, + save_residuals=True if attention_kernel == "ring" else False, ) - attention_output = jax.vmap(splash_kernel)(query, key, value) - return attention_output + vmapped_splash = jax.vmap(splash_kernel, in_axes=(0, 0, 0, None)) + + if attention_kernel == "flash": + attention_output = vmapped_splash(query, key, value, segment_ids) + else: + if num_fsdp_shards > 1: + out, (lse,) = vmapped_splash(query, key, value, segment_ids) + m = lse.astype(jnp.float32) + l = jnp.exp(lse - m) + o = out.astype(jnp.float32) * l[..., None] + + perm = [(j, (j + 1) % num_fsdp_shards) for j in range(num_fsdp_shards)] + + k1 = jax.lax.ppermute(key, axis_name="fsdp", perm=perm) + v1 = jax.lax.ppermute(value, axis_name="fsdp", perm=perm) + + def ring_scan_body(carry, _): + m, l, o, k_current, v_current = carry + k_next = jax.lax.ppermute(k_current, axis_name="fsdp", perm=perm) + v_next = jax.lax.ppermute(v_current, axis_name="fsdp", perm=perm) + + out_chunk, (lse_chunk,) = vmapped_splash(query, k_current, v_current, segment_ids) + + m_chunk = lse_chunk.astype(jnp.float32) + m_old = m + m = jnp.maximum(m_old, m_chunk) + + exp_m_diff = jnp.exp(m_old - m) + exp_m_chunk_diff = jnp.exp(m_chunk - m) + + l = l * exp_m_diff + jnp.exp(lse_chunk - m) + o = o * exp_m_diff[..., None] + o += exp_m_chunk_diff[..., None] * out_chunk.astype(jnp.float32) + + # Return the updated state for the next iteration + return (m, l, o, k_next, v_next), None + + initial_carry = (m, l, o, k1, v1) + (m_final, l_final, o_final, _, _), _ = jax.lax.scan(ring_scan_body, initial_carry, None, length=num_fsdp_shards - 1) + + attention_output = o_final / l_final[..., None] + + return attention_output[:, :, :query_seq_len, :kv_size].astype(query.dtype) devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"] # This warning might show up when doing model eval for example, when calculating model flops @@ -228,7 +291,6 @@ def wrap_flash_attention(query, key, value): f" axis, batch dimension: {query.shape[0]}, devices_in_data_fsdp: {devices_in_data_fsdp}" ) x = wrap_flash_attention(query, key, value) - x = x[:, :, :query_seq_len, :kv_size] x = _reshape_heads_to_head_dim(x) return x @@ -379,6 +441,10 @@ def _apply_attention( return _tpu_flash_attention( query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype ) + elif attention_kernel == "ring": + return _tpu_flash_attention( + query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype, attention_kernel + ) elif attention_kernel == "cudnn_flash_te": return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer) else: diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 182a427bb..33fc62f83 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -27,6 +27,7 @@ from . import max_logging from . import max_utils from .models.wan.wan_utils import CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH, WAN_21_FUSION_X_MODEL_NAME_OR_PATH +from maxdiffusion.common_types import LENGTH, KV_LENGTH def string_to_bool(s: str) -> bool: @@ -175,6 +176,17 @@ def user_init(raw_keys): max_utils.write_config_raw_keys_for_gcs(raw_keys) raw_keys["logical_axis_rules"] = _lists_to_tuples(raw_keys["logical_axis_rules"]) + # Verify qkv is sharded across sequence. + if raw_keys["attention"] == "ring": + logical_axis_rules = list(raw_keys["logical_axis_rules"]) + q_seq_sharding = (LENGTH, "fsdp") + kv_seq_sharding = (KV_LENGTH, "fsdp") + if q_seq_sharding not in logical_axis_rules: + logical_axis_rules.append(q_seq_sharding) + if kv_seq_sharding not in logical_axis_rules: + logical_axis_rules.append(kv_seq_sharding) + raw_keys["logical_axis_rules"] = tuple(logical_axis_rules) + raw_keys["data_sharding"] = _lists_to_tuples(raw_keys["data_sharding"]) if raw_keys["learning_rate_schedule_steps"] == -1: diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py index a267e0653..37615b076 100644 --- a/src/maxdiffusion/trainers/wan_trainer.py +++ b/src/maxdiffusion/trainers/wan_trainer.py @@ -255,7 +255,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, train_data eval_data_iterator = self.load_dataset(mesh, is_training=False) eval_rng = jax.random.key(self.config.seed + step) eval_metrics = [] - # Loop indefinitely until the iterator is exhausted + # Loop indefinitely until the iterator is exhausted while True: try: with mesh: @@ -329,6 +329,7 @@ def loss_fn(params): metrics = {"scalar": {"learning/loss": loss}, "scalars": {}} return new_state, scheduler_state, metrics, new_rng + def eval_step(state, data, rng, scheduler_state, scheduler, config): """ Computes the evaluation loss for a single batch without updating model weights. @@ -338,44 +339,44 @@ def eval_step(state, data, rng, scheduler_state, scheduler, config): # This ensures the batch size is consistent, though it might be redundant # if the evaluation dataloader is already configured correctly. for k, v in data.items(): - data[k] = v[: config.global_batch_size_to_train_on, :] + data[k] = v[: config.global_batch_size_to_train_on, :] # The loss function logic is identical to training. We are evaluating the model's # ability to perform its core training objective (e.g., denoising). def loss_fn(params): - # Reconstruct the model from its definition and parameters - model = nnx.merge(state.graphdef, params, state.rest_of_state) - - # Prepare inputs - latents = data["latents"].astype(config.weights_dtype) - encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype) - bsz = latents.shape[0] - - # Sample random timesteps and noise, just as in a training step - timesteps = jax.random.randint( - timestep_rng, - (bsz,), - 0, - scheduler.config.num_train_timesteps, - ) - noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype) - noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps) - - # Get the model's prediction - model_pred = model( - hidden_states=noisy_latents, - timestep=timesteps, - encoder_hidden_states=encoder_hidden_states, - ) + # Reconstruct the model from its definition and parameters + model = nnx.merge(state.graphdef, params, state.rest_of_state) + + # Prepare inputs + latents = data["latents"].astype(config.weights_dtype) + encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype) + bsz = latents.shape[0] + + # Sample random timesteps and noise, just as in a training step + timesteps = jax.random.randint( + timestep_rng, + (bsz,), + 0, + scheduler.config.num_train_timesteps, + ) + noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype) + noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps) - # Calculate the loss against the target - training_target = scheduler.training_target(latents, noise, timesteps) - training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4)) - loss = (training_target - model_pred) ** 2 - loss = loss * training_weight - loss = jnp.mean(loss) + # Get the model's prediction + model_pred = model( + hidden_states=noisy_latents, + timestep=timesteps, + encoder_hidden_states=encoder_hidden_states, + ) - return loss + # Calculate the loss against the target + training_target = scheduler.training_target(latents, noise, timesteps) + training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4)) + loss = (training_target - model_pred) ** 2 + loss = loss * training_weight + loss = jnp.mean(loss) + + return loss # --- Key Difference from train_step --- # Directly compute the loss without calculating gradients.