Commit 82b719e

fix rope calculations.
1 parent 0ef8c71 commit 82b719e

4 files changed

Lines changed: 81 additions & 23 deletions

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 2 additions & 2 deletions
@@ -139,8 +139,8 @@ data_sharding: [['data', 'fsdp', 'tensor']]
 dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
 dcn_fsdp_parallelism: -1
 dcn_tensor_parallelism: 1
-ici_data_parallelism: -1
-ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
+ici_data_parallelism: 1
+ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
 ici_tensor_parallelism: 1

 # Dataset
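
Note on the sharding change: the auto-sharded ICI axis moves from data to fsdp. A minimal sketch, assuming the MaxText-style convention that a -1 entry absorbs whatever device count remains on that topology; the helper below is hypothetical, not maxdiffusion code.

import numpy as np

def resolve_mesh_axes(parallelism, num_devices):
  # Hypothetical helper: replace the single -1 ("auto") entry with the
  # remaining device count so the axis sizes multiply to num_devices.
  axes = list(parallelism)
  known = int(np.prod([p for p in axes if p != -1]))
  if -1 in axes:
    axes[axes.index(-1)] = num_devices // known
  assert int(np.prod(axes)) == num_devices
  return axes

# Old config was [-1, 1, 1] (data absorbs the chips); the new config is [1, -1, 1] (fsdp does).
print(resolve_mesh_axes([1, -1, 1], num_devices=8))  # [1, 8, 1] for (data, fsdp, tensor)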

src/maxdiffusion/models/attention_flax.py

Lines changed: 30 additions & 5 deletions
@@ -15,6 +15,7 @@
 import functools
 import math
 from typing import Optional, Callable, Tuple
+import numpy as np
 import flax.linen as nn
 from flax import nnx
 import jax
@@ -318,7 +319,7 @@ def _apply_attention(
 ):
   """Routes to different attention kernels."""
   _check_attention_inputs(query, key, value)
-
+
   if attention_kernel == "flash":
     can_use_flash_attention = (
         query.shape[1] >= flash_min_seq_length
@@ -578,8 +579,7 @@ def __init__(
       qkv_bias: bool = False,
       quant: Quant = None,
   ):
-
-    if attention_kernel == "cudnn_flash_te" or attention_kernel == "dot_product":
+    if attention_kernel == "cudnn_flash_te":
       raise NotImplementedError(f"Wan 2.1 has not been tested with {attention_kernel}")

     if attention_kernel in {"flash", "cudnn_flash_te"} and mesh is None:
@@ -676,7 +676,7 @@ def __init__(
   def _apply_rope(self, xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array) -> Tuple[jax.Array, jax.Array]:
     dtype = xq.dtype
     reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
-    reshape_xk = xq.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)
+    reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)

     xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
     xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])
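
For context on what the one-character change above fixes: the key path was previously built from xq, so the "rotated keys" were really a copy of the rotated queries and the actual key values were dropped. A self-contained sketch of the rotation (shapes and names here are illustrative, not the module's API):

import jax
import jax.numpy as jnp

def apply_rope_sketch(xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array):
  # View the head dimension as (..., head_dim // 2, 2) real/imag pairs.
  xq_pairs = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
  xk_pairs = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)  # built from xk, not xq

  xq_c = jax.lax.complex(xq_pairs[..., 0], xq_pairs[..., 1])
  xk_c = jax.lax.complex(xk_pairs[..., 0], xk_pairs[..., 1])

  # Complex multiplication rotates each pair by the per-position phase in freqs_cis.
  xq_out = jnp.stack([jnp.real(xq_c * freqs_cis), jnp.imag(xq_c * freqs_cis)], axis=-1).reshape(xq.shape)
  xk_out = jnp.stack([jnp.real(xk_c * freqs_cis), jnp.imag(xk_c * freqs_cis)], axis=-1).reshape(xk.shape)
  return xq_out.astype(xq.dtype), xk_out.astype(xk.dtype)

# Tiny usage example: batch=1, heads=1, seq=4, head_dim=8.
q = jnp.ones((1, 1, 4, 8))
k = jnp.ones((1, 1, 4, 8))
theta = jnp.linspace(0.0, 1.0, 16).reshape(1, 1, 4, 4)
freqs_cis = jax.lax.complex(jnp.cos(theta), jnp.sin(theta))
q_rot, k_rot = apply_rope_sketch(q, k, freqs_cis)
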
@@ -696,13 +696,19 @@ def __call__(
       encoder_hidden_states: jax.Array = None,
       rotary_emb: Optional[jax.Array] = None
   ) -> jax.Array:
-
+    print(" -- -- WanAttention -- ")
     dtype = hidden_states.dtype
     if encoder_hidden_states is None:
       encoder_hidden_states = hidden_states
     query_proj = self.query(hidden_states)
+    print("query_proj min: ", np.min(query_proj))
+    print("query_proj max: ", np.max(query_proj))
     key_proj = self.key(encoder_hidden_states)
+    print("key_proj min: ", np.min(key_proj))
+    print("key_proj max: ", np.max(key_proj))
     value_proj = self.value(encoder_hidden_states)
+    print("value_proj min: ", np.min(value_proj))
+    print("value_proj max: ", np.max(value_proj))

     query_proj = nn.with_logical_constraint(query_proj, self.query_axis_names)
     key_proj = nn.with_logical_constraint(key_proj, self.key_axis_names)
@@ -711,18 +717,37 @@ def __call__(
     if self.qk_norm:
       query_proj = self.norm_q(query_proj)
       key_proj = self.norm_k(key_proj)
+      print("query_proj min: ", np.min(query_proj))
+      print("query_proj max: ", np.max(query_proj))
+      print("key_proj min: ", np.min(key_proj))
+      print("key_proj max: ", np.max(key_proj))

     if rotary_emb is not None:
       query_proj = _unflatten_heads(query_proj, self.heads)
       key_proj = _unflatten_heads(key_proj, self.heads)
+      # value_proj = _unflatten_heads(value_proj, self.heads)
       query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb)
+      print("Rope query_proj min: ", np.min(query_proj))
+      print("Rope query_proj max: ", np.max(query_proj))
+      print("Rope key_proj min: ", np.min(key_proj))
+      print("Rope key_proj max: ", np.max(key_proj))
+      #breakpoint()
       query_proj = _reshape_heads_to_head_dim(query_proj)
       key_proj = _reshape_heads_to_head_dim(key_proj)

     attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
+    try:
+      print("attn_output min: ", np.min(attn_output))
+      print("attn_output_for_print max: ", np.max(attn_output))
+    except:
+      pass
     attn_output = attn_output.astype(dtype=dtype)

     hidden_states = self.proj_attn(hidden_states)
+    print("hidden_states min: ", np.min(hidden_states))
+    print("hidden_states max: ", np.max(hidden_states))
+    print(" -- -- WanAttention DONE -- ")
+    #breakpoint()
     return hidden_states

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 1 addition & 1 deletion
@@ -227,7 +227,7 @@ def get_1d_rotary_pos_embed(
     out = jnp.stack([freqs_cos, -freqs_sin, freqs_sin, freqs_cos], axis=-1)
   else:
     # Wan 2.1
-    out = jax.lax.complex(jnp.ones_like(freqs), freqs)
+    out = jax.lax.complex(jnp.cos(freqs), jnp.sin(freqs))
   return out

 class NNXPixArtAlphaTextProjection(nnx.Module):
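
This is the core of the RoPE fix: the per-position rotary factor should be the unit-magnitude phase e^{i*theta} = cos(theta) + i*sin(theta). The previous expression, 1 + i*theta, is only the first-order Taylor approximation, and its magnitude grows with theta, so it rescales queries and keys instead of purely rotating them. A quick check with jax.numpy:

import jax
import jax.numpy as jnp

freqs = jnp.linspace(0.0, 6.0, 5)

old = jax.lax.complex(jnp.ones_like(freqs), freqs)     # 1 + i*theta (removed)
new = jax.lax.complex(jnp.cos(freqs), jnp.sin(freqs))  # e^{i*theta} (added)

print(jnp.abs(old))  # grows with theta: roughly [1.0, 1.8, 3.2, 4.6, 6.1]
print(jnp.abs(new))  # always 1.0: a pure rotation, as RoPE requires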

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 48 additions & 15 deletions
@@ -19,6 +19,7 @@
 import jax
 import jax.numpy as jnp
 from flax import nnx
+import numpy as np
 from .... import common_types
 from ...modeling_flax_utils import FlaxModelMixin, get_activation
 from ....configuration_utils import ConfigMixin, register_to_config
@@ -58,12 +59,7 @@ def __init__(
           use_real=False
       )
       freqs.append(freq)
-    self.freqs = jnp.concatenate(freqs, axis=1)
-
-  def __call__(self, hidden_states: jax.Array) -> jax.Array:
-    _, num_frames, height, width, _ = hidden_states.shape
-    p_t, p_h, p_w = self.patch_size
-    ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
+    freqs = jnp.concatenate(freqs, axis=1)

     sizes = [
         self.attention_head_dim // 2 - 2 * (self.attention_head_dim // 6),
@@ -72,16 +68,21 @@ def __call__(self, hidden_states: jax.Array) -> jax.Array:
     ]
     cumulative_sizes = jnp.cumsum(jnp.array(sizes))
     split_indices = cumulative_sizes[:-1]
-    freqs_split = jnp.split(self.freqs, split_indices, axis=1)
+    self.freqs_split = jnp.split(freqs, split_indices, axis=1)
+
+  def __call__(self, hidden_states: jax.Array) -> jax.Array:
+    _, num_frames, height, width, _ = hidden_states.shape
+    p_t, p_h, p_w = self.patch_size
+    ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w

-    freqs_f = jnp.expand_dims(jnp.expand_dims(freqs_split[0][:ppf], axis=1), axis=1)
-    freqs_f = jnp.broadcast_to(freqs_f, (ppf, pph, ppw, freqs_split[0].shape[-1]))
+    freqs_f = jnp.expand_dims(jnp.expand_dims(self.freqs_split[0][:ppf], axis=1), axis=1)
+    freqs_f = jnp.broadcast_to(freqs_f, (ppf, pph, ppw, self.freqs_split[0].shape[-1]))

-    freqs_h = jnp.expand_dims(jnp.expand_dims(freqs_split[1][:pph], axis=0), axis=2)
-    freqs_h = jnp.broadcast_to(freqs_h, (ppf, pph, ppw, freqs_split[1].shape[-1]))
+    freqs_h = jnp.expand_dims(jnp.expand_dims(self.freqs_split[1][:pph], axis=0), axis=2)
+    freqs_h = jnp.broadcast_to(freqs_h, (ppf, pph, ppw, self.freqs_split[1].shape[-1]))

-    freqs_w = jnp.expand_dims(jnp.expand_dims(freqs_split[2][:ppw], axis=0), axis=1)
-    freqs_w = jnp.broadcast_to(freqs_w, (ppf, pph, ppw, freqs_split[2].shape[-1]))
+    freqs_w = jnp.expand_dims(jnp.expand_dims(self.freqs_split[2][:ppw], axis=0), axis=1)
+    freqs_w = jnp.broadcast_to(freqs_w, (ppf, pph, ppw, self.freqs_split[2].shape[-1]))

     freqs_concat = jnp.concatenate([freqs_f, freqs_h, freqs_w], axis=-1)
     freqs_final = jnp.reshape(freqs_concat, (1, 1, ppf * pph * ppw, -1))
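
The two hunks above move the concatenation and split of the frequency tables from __call__ into __init__, so they are built once at construction; __call__ then only slices and broadcasts the cached pieces. A simplified sketch of the resulting structure (names and shapes condensed; this is illustrative, not the actual module):

import jax.numpy as jnp

class RotaryTablesSketch:
  """Illustrative only: precompute and split the frequency tables once, reuse per call."""

  def __init__(self, freqs_list, sizes):
    freqs = jnp.concatenate(freqs_list, axis=1)
    split_indices = jnp.cumsum(jnp.array(sizes))[:-1]
    # Done once here instead of on every forward pass.
    self.freqs_split = jnp.split(freqs, split_indices, axis=1)

  def __call__(self, ppf, pph, ppw):
    f, h, w = self.freqs_split
    # Broadcast each table along the two grid axes it does not index.
    freqs_f = jnp.broadcast_to(f[:ppf, None, None, :], (ppf, pph, ppw, f.shape[-1]))
    freqs_h = jnp.broadcast_to(h[None, :pph, None, :], (ppf, pph, ppw, h.shape[-1]))
    freqs_w = jnp.broadcast_to(w[None, None, :ppw, :], (ppf, pph, ppw, w.shape[-1]))
    return jnp.concatenate([freqs_f, freqs_h, freqs_w], axis=-1).reshape(1, 1, ppf * pph * ppw, -1)
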
@@ -361,22 +362,41 @@ def __call__(
     shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
         (self.scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
     )
+    print("Wan Block -- START -- ")

     # 1. Self-attention
     norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype)
+    print("Wan Block -- norm_hidden_states, min: ", np.min(norm_hidden_states))
+    print("Wan Block -- norm_hidden_states, max: ", np.max(norm_hidden_states))
     attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+    print("Wan Block -- Self Attn. attn_output, min: ", np.min(attn_output))
+    print("Wan Block -- Self Attn. attn_output, max: ", np.max(attn_output))
     hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)
+    print("Wan Block -- hidden_states, min: ", np.min(hidden_states))
+    print("Wan Block -- hidden_states, max: ", np.max(hidden_states))

     # 2. Cross-attention
     norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32))
+    print("Wan Block -- norm_hidden_states, min: ", np.min(norm_hidden_states))
+    print("Wan Block -- norm_hidden_states, max: ", np.max(norm_hidden_states))
     attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+    print("Wan Block -- Cross Attn. attn_output, min: ", np.min(attn_output))
+    print("Wan Block -- Cross Attn. attn_output, max: ", np.max(attn_output))
     hidden_states = hidden_states + attn_output
+    print("Wan Block -- hidden_states, min: ", np.min(hidden_states))
+    print("Wan Block -- hidden_states, max: ", np.max(hidden_states))

     # 3. Feed-forward
     norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(hidden_states.dtype)
-
+    print("Wan Block -- norm_hidden_states, min: ", np.min(norm_hidden_states))
+    print("Wan Block -- norm_hidden_states, max: ", np.max(norm_hidden_states))
     ff_output = self.ffn(norm_hidden_states)
+    print("Wan Block -- ff_output, min: ", np.min(ff_output))
+    print("Wan Block -- ff_output, max: ", np.max(ff_output))
     hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(hidden_states.dtype)
+    print("Wan Block -- hidden_states, min: ", np.min(hidden_states))
+    print("Wan Block -- hidden_states, max: ", np.max(hidden_states))
+    print("Wan Block -- COMPLETE -- ")
     return hidden_states

@@ -495,19 +515,32 @@ def __call__(

     rotary_emb = self.rope(hidden_states)
     hidden_states = self.patch_embedding(hidden_states)
+    print("***** After patch embedding")
+    print("hidden_states, min: ", np.min(hidden_states))
+    print("hidden_states, max: ", np.max(hidden_states))
     hidden_states = jax.lax.collapse(hidden_states, 1, -1)

     temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
         timestep, encoder_hidden_states, encoder_hidden_states_image
     )
+    print("***** After condition embedder")
+    print("temb, min: ", np.min(temb))
+    print("temb, max: ", np.max(temb))
+    print("timestep_proj, min: ", np.min(timestep_proj))
+    print("timestep_proj, max: ", np.max(timestep_proj))
+    print("encoder_hidden_states min: ", np.min(encoder_hidden_states))
+    print("encoder_hidden_states max: ", np.max(encoder_hidden_states))
+
     timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)

     if encoder_hidden_states_image is not None:
       raise NotImplementedError("img2vid is not yet implemented.")

     for block in self.blocks:
       hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
-
+      print("After block, hidden_states min:", np.min(hidden_states))
+      print("After block, hidden_states max:", np.max(hidden_states))
+      #breakpoint()
     shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)

     hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)