
Commit 0731a49

Add sharding annotations for the VAE; verified transformer correctness for one step.
1 parent 6973222 commit 0731a49

7 files changed

Lines changed: 552 additions & 186 deletions


src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 1 addition & 0 deletions
```diff
@@ -130,6 +130,7 @@ logical_axis_rules: [
   ['conv_batch', ['data','fsdp']],
   ['out_channels', 'tensor'],
   ['conv_out', 'fsdp'],
+  ['conv_in', 'fsdp']
 ]
 data_sharding: [['data', 'fsdp', 'tensor']]
```
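The one-line change gives the `conv_in` logical axis a rule, mapping it onto the `fsdp` mesh axis so arrays annotated with `conv_in` are sharded across FSDP devices rather than replicated. Below is a minimal sketch of how Flax resolves such rules; the mesh layout and the example array are illustrative assumptions, not taken from this repo.

```python
import jax
import jax.numpy as jnp
import numpy as np
import flax.linen as nn
from jax.sharding import Mesh

# Hypothetical 3-axis mesh; on a single host this collapses to size-1 axes.
devices = np.array(jax.devices()).reshape(1, 1, -1)
mesh = Mesh(devices, axis_names=("data", "fsdp", "tensor"))

# Rules in the spirit of base_wan_14b.yml, including the new 'conv_in' entry.
rules = (
    ("conv_batch", ("data", "fsdp")),
    ("out_channels", "tensor"),
    ("conv_out", "fsdp"),
    ("conv_in", "fsdp"),
)

x = jnp.zeros((8, 16, 32))  # illustrative (batch, in_channels, out_channels)

with mesh, nn.logical_axis_rules(rules):
  # 'conv_in' now resolves to the 'fsdp' mesh axis; before this commit it
  # had no rule, so that dimension stayed unsharded.
  x = nn.with_logical_constraint(x, ("conv_batch", "conv_in", "out_channels"))
```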

src/maxdiffusion/models/attention_flax.py

Lines changed: 10 additions & 41 deletions
```diff
@@ -101,7 +101,8 @@ def _reshape_data_for_flash(tensor, heads, flash_block_size):
   """
   Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim.
   """
-  tensor = _unflatten_heads(tensor, heads)
+  if tensor.ndim != 4:
+    tensor = _unflatten_heads(tensor, heads)

   # pad head_dim to 128 if less than that.
   kv_size = tensor.shape[-1]
```
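The new guard lets `_reshape_data_for_flash` accept tensors that already carry an explicit heads axis: a 4D input is assumed to be in (batch, heads, seq_len, head_dim) layout and passes through untouched, while 3D (batch, seq_len, heads * head_dim) inputs are split as before. A sketch of that shape convention (the real `_unflatten_heads` is defined elsewhere in this file and may differ in detail):

```python
import jax.numpy as jnp

def unflatten_heads(tensor, heads):
  # (batch, seq_len, heads * head_dim) -> (batch, heads, seq_len, head_dim)
  b, s, d = tensor.shape
  tensor = tensor.reshape(b, s, heads, d // heads)
  return jnp.transpose(tensor, (0, 2, 1, 3))

flat = jnp.zeros((2, 10, 8 * 64))
assert unflatten_heads(flat, heads=8).shape == (2, 8, 10, 64)
# A 4D input has already been split, so `if tensor.ndim != 4` skips it.
```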
```diff
@@ -319,12 +320,14 @@ def _apply_attention(
 ):
   """Routes to different attention kernels."""
   _check_attention_inputs(query, key, value)
-
+  seq_len_idx = 1
+  if query.ndim == 4:
+    seq_len_idx = 2
   if attention_kernel == "flash":
     can_use_flash_attention = (
-        query.shape[1] >= flash_min_seq_length
-        and key.shape[1] >= flash_min_seq_length
-        and value.shape[1] >= flash_min_seq_length
+        query.shape[seq_len_idx] >= flash_min_seq_length
+        and key.shape[seq_len_idx] >= flash_min_seq_length
+        and value.shape[seq_len_idx] >= flash_min_seq_length
     )
   else:
     can_use_flash_attention = True
```
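`seq_len_idx` accounts for the sequence axis moving when heads are split out: it sits at axis 1 in a 3D (batch, seq_len, dim) tensor but at axis 2 in the 4D (batch, heads, seq_len, head_dim) layout that query/key/value can now arrive in. A tiny illustration with assumed shapes:

```python
import jax.numpy as jnp

q3 = jnp.zeros((2, 4096, 1024))    # (batch, seq_len, heads * head_dim)
q4 = jnp.zeros((2, 16, 4096, 64))  # (batch, heads, seq_len, head_dim)

def seq_len(q):
  # Mirrors the seq_len_idx selection in _apply_attention.
  return q.shape[2 if q.ndim == 4 else 1]

assert seq_len(q3) == seq_len(q4) == 4096
```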
```diff
@@ -584,7 +587,6 @@ def __init__(

     if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None:
       raise ValueError(f"The flash attention kernel requires a value for mesh, but mesh is {self.mesh}")
-
     self.dim_head = dim_head
     self.heads = heads
     self.inner_dim = dim_head * heads
@@ -681,7 +683,6 @@ def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tup
     xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
     xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])

-    freqs_cis = freqs_cis[None, None, ...]
     xq_out_complex = xq_ * freqs_cis
     xk_out_complex = xk_ * freqs_cis
```
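Deleting `freqs_cis = freqs_cis[None, None, ...]` is behavior-preserving under NumPy-style broadcasting, which implicitly left-pads the smaller operand's shape with 1s: a (seq_len, head_dim/2) complex `freqs_cis` multiplies a (batch, heads, seq_len, head_dim/2) tensor identically with or without the explicit leading axes. A quick check under that shape assumption:

```python
import jax.numpy as jnp

b, h, s, d = 1, 2, 8, 16
xq_ = jnp.ones((b, h, s, d // 2), dtype=jnp.complex64)
freqs = jnp.linspace(0.0, 1.0, s * d // 2).reshape(s, d // 2)
freqs_cis = jnp.exp(1j * freqs)

# Implicit broadcast vs. the old explicit leading axes: same result.
assert jnp.allclose(xq_ * freqs_cis, xq_ * freqs_cis[None, None, ...])
```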

```diff
@@ -696,58 +697,26 @@ def __call__(
       encoder_hidden_states: jax.Array = None,
       rotary_emb: Optional[jax.Array] = None
   ) -> jax.Array:
-    print(" -- -- WanAttention -- ")
     dtype = hidden_states.dtype
     if encoder_hidden_states is None:
       encoder_hidden_states = hidden_states
     query_proj = self.query(hidden_states)
-    print("query_proj min: ", np.min(query_proj))
-    print("query_proj max: ", np.max(query_proj))
     key_proj = self.key(encoder_hidden_states)
-    print("key_proj min: ", np.min(key_proj))
-    print("key_proj max: ", np.max(key_proj))
     value_proj = self.value(encoder_hidden_states)
-    print("value_proj min: ", np.min(value_proj))
-    print("value_proj max: ", np.max(value_proj))
-
-    query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names)
-    key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names)
-    value_proj = nn.with_logical_constraint(value_proj, self.value_axis_names)

     if self.qk_norm:
       query_proj = self.norm_q(query_proj)
       key_proj = self.norm_k(key_proj)
-    print("query_proj min: ", np.min(query_proj))
-    print("query_proj max: ", np.max(query_proj))
-    print("key_proj min: ", np.min(key_proj))
-    print("key_proj max: ", np.max(key_proj))
-
     if rotary_emb is not None:
       query_proj = _unflatten_heads(query_proj, self.heads)
       key_proj = _unflatten_heads(key_proj, self.heads)
-      # value_proj = _unflatten_heads(value_proj, self.heads)
+      value_proj = _unflatten_heads(value_proj, self.heads)
       query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
-      print("Rope query_proj min: ", np.min(query_proj))
-      print("Rope query_proj max: ", np.max(query_proj))
-      print("Rope key_proj min: ", np.min(key_proj))
-      print("Rope key_proj max: ", np.max(key_proj))
-      #breakpoint()
-      query_proj = _reshape_heads_to_head_dim(query_proj)
-      key_proj = _reshape_heads_to_head_dim(key_proj)

     attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
-    try:
-      print("attn_output min: ", np.min(attn_output))
-      print("attn_output_for_print max: ", np.max(attn_output))
-    except:
-      pass
     attn_output = attn_output.astype(dtype=dtype)

-    hidden_states = self.proj_attn(hidden_states)
-    print("hidden_states min: ", np.min(hidden_states))
-    print("hidden_states max: ", np.max(hidden_states))
-    print(" -- -- WanAttention DONE -- ")
-    #breakpoint()
+    hidden_states = self.proj_attn(attn_output)
     return hidden_states
```
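Beyond stripping the debug prints and breakpoints, this hunk makes two functional changes. First, `value_proj` is now unflattened into heads alongside `query_proj` and `key_proj`, and the compensating `_reshape_heads_to_head_dim` calls are gone, so all three reach `apply_attention` in the same 4D layout the `ndim` checks above were added for. Second, the output projection now consumes `attn_output` instead of `hidden_states`; the old code projected the attention input, discarding the attention result entirely. The per-projection `nn.with_logical_constraint` annotations on query/key/value were also dropped.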

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -227,7 +227,7 @@ def get_1d_rotary_pos_embed(
     out = jnp.stack([freqs_cos, -freqs_sin, freqs_sin, freqs_cos], axis=-1)
   else:
     # Wan 2.1
-    out = jax.lax.complex(jnp.cos(freqs), jnp.sin(freqs))
+    out = jnp.exp(1j * freqs)
   return out

 class NNXPixArtAlphaTextProjection(nnx.Module):
```
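The old and new expressions agree by Euler's formula, exp(i·f) = cos(f) + i·sin(f); the new form simply builds the complex rotary phases in one call. A quick numerical check:

```python
import jax
import jax.numpy as jnp

freqs = jnp.linspace(0.0, 6.28, 32)
old = jax.lax.complex(jnp.cos(freqs), jnp.sin(freqs))
new = jnp.exp(1j * freqs)  # 1j promotes the product to complex64
assert jnp.allclose(old, new)
```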
