Skip to content

Commit 085774a

Browse files
committed
use_real set to False in get_frequencies
1 parent a0d6dc4 commit 085774a

2 files changed

Lines changed: 1 addition & 3 deletions

File tree

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -244,10 +244,8 @@ def get_1d_rotary_pos_embed(
244244
freqs_cos = jnp.cos(freqs)
245245
freqs_sin = jnp.sin(freqs)
246246
out = jnp.stack([freqs_cos, -freqs_sin, freqs_sin, freqs_cos], axis=-1)
247-
print("Using real rotary embeddings (Flux-style)")
248247
else:
249248
# Wan 2.1
250-
print("Using complex rotary embeddings (Wan-style)")
251249
out = jnp.exp(1j * freqs)
252250
return out
253251

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int, use_r
4646
t_dim = attention_head_dim - h_dim - w_dim
4747
freqs = []
4848
for dim in [t_dim, h_dim, w_dim]:
49-
freq = get_1d_rotary_pos_embed(dim, max_seq_len, theta, freqs_dtype=jnp.float32, use_real=use_real)
49+
freq = get_1d_rotary_pos_embed(dim, max_seq_len, theta, freqs_dtype=jnp.float32, use_real=False)
5050
freqs.append(freq)
5151
freqs = jnp.concatenate(freqs, axis=1)
5252
t_size = attention_head_dim // 2 - 2 * (attention_head_dim // 6)

0 commit comments

Comments
 (0)