Commit 95e8f37

Format with pyink

1 parent 935b457 commit 95e8f37
3 files changed: 41 additions & 124 deletions
src/maxdiffusion/models/embeddings_flax.py

Lines changed: 2 additions & 6 deletions
@@ -54,13 +54,9 @@ def get_sinusoidal_embeddings(
   scaled_time = scale * emb
 
   if flip_sin_to_cos:
-    signal = jnp.concatenate(
-        [jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=-1
-    )
+    signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=-1)
   else:
-    signal = jnp.concatenate(
-        [jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1
-    )
+    signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1)
   return signal
 
 
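Note: the collapsed branches are behavior-preserving; flip_sin_to_cos only decides whether the cosine or the sine half leads the concatenation. A minimal standalone sketch of that ordering (simplified: the real get_sinusoidal_embeddings also takes scale and frequency-shift parameters):

import jax.numpy as jnp

def sinusoidal_sketch(timesteps, embedding_dim, flip_sin_to_cos=False):
  # Simplified log-spaced frequency schedule.
  half = embedding_dim // 2
  freqs = jnp.exp(-jnp.log(10000.0) * jnp.arange(half) / half)
  scaled_time = timesteps[:, None] * freqs[None, :]
  if flip_sin_to_cos:
    return jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=-1)
  return jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1)

print(sinusoidal_sketch(jnp.arange(4.0), 8, flip_sin_to_cos=True).shape)  # (4, 8)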
src/maxdiffusion/models/wan/autoencoder_kl_wan_2p2.py

Lines changed: 27 additions & 86 deletions
@@ -48,19 +48,14 @@ def _update_cache(cache, idx, value):
 
 
 # Helper to ensure kernel_size, stride, padding are tuples of 3 integers
-def _canonicalize_tuple(
-    x: Union[int, Sequence[int]], rank: int, name: str
-) -> Tuple[int, ...]:
+def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]:
   """Canonicalizes a value to a tuple of integers."""
   if isinstance(x, int):
     return (x,) * rank
   elif isinstance(x, Sequence) and len(x) == rank:
     return tuple(x)
   else:
-    raise ValueError(
-        f"Argument '{name}' must be an integer or a sequence of {rank}"
-        f" integers. Got {x}"
-    )
+    raise ValueError(f"Argument '{name}' must be an integer or a sequence of {rank}" f" integers. Got {x}")
 
 
 class RepSentinel:
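The reformatted helper is self-contained, so its three branches are easy to exercise directly; a standalone copy with example calls:

from typing import Sequence, Tuple, Union

def _canonicalize_tuple(x: Union[int, Sequence[int]], rank: int, name: str) -> Tuple[int, ...]:
  # Same logic as above: broadcast an int, accept a rank-length sequence, else raise.
  if isinstance(x, int):
    return (x,) * rank
  elif isinstance(x, Sequence) and len(x) == rank:
    return tuple(x)
  else:
    raise ValueError(f"Argument '{name}' must be an integer or a sequence of {rank} integers. Got {x}")

print(_canonicalize_tuple(3, 3, "kernel_size"))     # (3, 3, 3)
print(_canonicalize_tuple((1, 2, 2), 3, "stride"))  # (1, 2, 2)
# _canonicalize_tuple((1, 2), 3, "stride")          # raises ValueError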
@@ -69,9 +64,7 @@ def __eq__(self, other):
     return isinstance(other, RepSentinel)
 
 
-tree_util.register_pytree_node(
-    RepSentinel, lambda x: ((), None), lambda _, __: RepSentinel()
-)
+tree_util.register_pytree_node(RepSentinel, lambda x: ((), None), lambda _, __: RepSentinel())
 
 
 class WanPatchify(nnx.Module):
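The one-liner keeps RepSentinel registered as a leafless pytree node, which is what lets sentinel-holding caches pass through JAX tree transformations without being treated as data. A standalone sketch of the same pattern:

import jax

class RepSentinel:
  def __eq__(self, other):
    return isinstance(other, RepSentinel)

# Flatten to no children; unflatten by constructing a fresh sentinel.
jax.tree_util.register_pytree_node(RepSentinel, lambda x: ((), None), lambda _, __: RepSentinel())

cache = [RepSentinel(), 1.0, None]
print(jax.tree_util.tree_leaves(cache))  # [1.0]: the sentinel contributes no leaves
doubled = jax.tree_util.tree_map(lambda v: v * 2, cache)  # sentinel passes through untouched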
@@ -217,9 +210,7 @@ def __init__(
     self.bias = 0
 
   def __call__(self, x: jax.Array) -> jax.Array:
-    normalized = jnp.linalg.norm(
-        x, ord=2, axis=(1 if self.channel_first else -1), keepdims=True
-    )
+    normalized = jnp.linalg.norm(x, ord=2, axis=(1 if self.channel_first else -1), keepdims=True)
     normalized = x / jnp.maximum(normalized, self.eps)
     normalized = normalized * self.scale * self.gamma
     if self.bias:
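As the hunk shows, WanRMS_norm divides by the channel-wise L2 norm (floored at eps) and rescales by self.scale and a learned gamma. A minimal functional sketch of the channel-last path; the eps value and scale = sqrt(dim) are assumptions, since both are set in __init__ outside this hunk:

import jax.numpy as jnp

def wan_rms_norm_sketch(x, gamma, eps=1e-12):
  # Channel-last: normalize over the final axis, then rescale.
  scale = x.shape[-1] ** 0.5  # assumed sqrt(dim), as in common RMS-norm variants
  norm = jnp.linalg.norm(x, ord=2, axis=-1, keepdims=True)
  return x / jnp.maximum(norm, eps) * scale * gamma

x = jnp.ones((2, 4))
print(wan_rms_norm_sketch(x, gamma=jnp.ones((4,))))  # rows of ones come back as ones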
@@ -229,9 +220,7 @@ def __call__(self, x: jax.Array) -> jax.Array:
 
 class WanUpsample(nnx.Module):
 
-  def __init__(
-      self, scale_factor: Tuple[float, float], method: str = "nearest"
-  ):
+  def __init__(self, scale_factor: Tuple[float, float], method: str = "nearest"):
     # scale_factor for (H, W)
     # JAX resize works on spatial dims, H, W assuming (N, D, H, W, C) or (N, H, W, C)
     self.scale_factor = scale_factor
@@ -244,9 +233,7 @@ def __call__(self, x: jax.Array) -> jax.Array:
     n, h, w, c = in_shape
     target_h = int(h * self.scale_factor[0])
     target_w = int(w * self.scale_factor[1])
-    out = jax.image.resize(
-        x.astype(jnp.float32), (n, target_h, target_w, c), method=self.method
-    )
+    out = jax.image.resize(x.astype(jnp.float32), (n, target_h, target_w, c), method=self.method)
     return out.astype(input_dtype)
 
 
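The collapsed jax.image.resize call is the whole of the upsample: cast to float32, resize H and W by the scale factors, cast back to the input dtype. A standalone sketch of the 4D path:

import jax
import jax.numpy as jnp

def upsample_sketch(x, scale_factor=(2.0, 2.0), method="nearest"):
  # x: (N, H, W, C), mirroring the hunk above.
  input_dtype = x.dtype
  n, h, w, c = x.shape
  target_h = int(h * scale_factor[0])
  target_w = int(w * scale_factor[1])
  out = jax.image.resize(x.astype(jnp.float32), (n, target_h, target_w, c), method=method)
  return out.astype(input_dtype)

x = jnp.arange(4.0).reshape(1, 2, 2, 1)
print(upsample_sketch(x).shape)  # (1, 4, 4, 1)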
@@ -282,9 +269,7 @@ def __init__(
         use_bias=True,
         padding=[(0, 1), (0, 1)],
         rngs=rngs,
-        kernel_init=nnx.with_partitioning(
-            nnx.initializers.xavier_uniform(), (None, None, None, None)
-        ),
+        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, None)),
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
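nnx.with_partitioning wraps an initializer so the created parameter carries sharding metadata alongside its value (here, all four kernel axes unsharded). A small sketch on a plain nnx.Linear, which is a stand-in rather than anything from this file:

from flax import nnx

# Axis names (or None) line up with the kernel's dimensions.
init = nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, "model"))
layer = nnx.Linear(4, 8, kernel_init=init, rngs=nnx.Rngs(0))
print(layer.kernel.sharding)  # (None, 'model'), attached as variable metadata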
@@ -409,11 +394,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
         feat_idx += 1
       else:
         cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
-        if (
-            cache_x.shape[1] < 2
-            and feat_cache[idx] is not None
-            and not isinstance(feat_cache[idx], RepSentinel)
-        ):
+        if cache_x.shape[1] < 2 and feat_cache[idx] is not None and not isinstance(feat_cache[idx], RepSentinel):
           # cache last frame of last two chunk
           cache_x = jnp.concatenate(
               [
@@ -422,14 +403,8 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
               ],
               axis=1,
           )
-        if (
-            cache_x.shape[1] < 2
-            and feat_cache[idx] is not None
-            and isinstance(feat_cache[idx], RepSentinel)
-        ):
-          cache_x = jnp.concatenate(
-              [jnp.zeros(cache_x.shape), cache_x], axis=1
-          )
+        if cache_x.shape[1] < 2 and feat_cache[idx] is not None and isinstance(feat_cache[idx], RepSentinel):
+          cache_x = jnp.concatenate([jnp.zeros(cache_x.shape), cache_x], axis=1)
         if isinstance(feat_cache[idx], RepSentinel):
           x = self.time_conv(x)
         else:
@@ -453,9 +428,7 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
         feat_idx += 1
       else:
         cache_x = jnp.copy(x[:, -1:, :, :, :])
-        x = self.time_conv(
-            jnp.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1)
-        )
+        x = self.time_conv(jnp.concatenate([feat_cache[idx][:, -1:, :, :, :], x], axis=1))
        feat_cache = _update_cache(feat_cache, idx, cache_x)
        feat_idx += 1
 
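The hunks above all serve one streaming pattern: each chunk stashes trailing frames in feat_cache so the next chunk's causal time conv can see them. A simplified sketch of that carry logic, with an identity op standing in for the real time_conv and the RepSentinel/first-chunk special cases omitted:

import jax.numpy as jnp

CACHE_T = 2  # trailing frames carried between chunks

def stream_chunks(chunks, time_op):
  # Prepend the cached tail of the previous chunk, apply the temporal op,
  # then stash this chunk's tail for the next iteration.
  cache = None
  outs = []
  for x in chunks:  # x: (N, D, H, W, C)
    tail = jnp.copy(x[:, -CACHE_T:, :, :, :])
    if cache is None:
      outs.append(time_op(x))
    else:
      outs.append(time_op(jnp.concatenate([cache[:, -1:, :, :, :], x], axis=1)))
    cache = tail
  return outs

chunks = [jnp.ones((1, 4, 2, 2, 3)), jnp.ones((1, 4, 2, 2, 3))]
print([y.shape for y in stream_chunks(chunks, time_op=lambda v: v)])  # [(1, 4, 2, 2, 3), (1, 5, 2, 2, 3)]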
@@ -479,9 +452,7 @@ def __init__(
     self.nonlinearity = get_activation(non_linearity)
 
     # layers
-    self.norm1 = WanRMS_norm(
-        dim=in_dim, rngs=rngs, images=False, channel_first=False
-    )
+    self.norm1 = WanRMS_norm(dim=in_dim, rngs=rngs, images=False, channel_first=False)
     self.conv1 = WanCausalConv3d(
         rngs=rngs,
         in_channels=in_dim,
@@ -493,9 +464,7 @@ def __init__(
         weights_dtype=weights_dtype,
         precision=precision,
     )
-    self.norm2 = WanRMS_norm(
-        dim=out_dim, rngs=rngs, images=False, channel_first=False
-    )
+    self.norm2 = WanRMS_norm(dim=out_dim, rngs=rngs, images=False, channel_first=False)
     self.conv2 = WanCausalConv3d(
         rngs=rngs,
         in_channels=out_dim,
@@ -581,9 +550,7 @@ def __init__(
         out_features=dim * 3,
         kernel_size=(1, 1),
         rngs=rngs,
-        kernel_init=nnx.with_partitioning(
-            nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")
-        ),
+        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, None, "conv_out")),
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
@@ -593,9 +560,7 @@ def __init__(
         out_features=dim,
         kernel_size=(1, 1),
         rngs=rngs,
-        kernel_init=nnx.with_partitioning(
-            nnx.initializers.xavier_uniform(), (None, None, "conv_in", None)
-        ),
+        kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), (None, None, "conv_in", None)),
         dtype=dtype,
         param_dtype=weights_dtype,
         precision=precision,
@@ -709,9 +674,7 @@ def __init__(
     self.factor = self.factor_t * self.factor_s * self.factor_s
     self.group_size = in_channels * self.factor // out_channels
 
-  def __call__(
-      self, x: jax.Array, feat_cache=None, feat_idx=0
-  ) -> Tuple[jax.Array, Any, int]:
+  def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0) -> Tuple[jax.Array, Any, int]:
     if self.factor_t > 1 or self.factor_s > 1:
       n, d, h, w, c = x.shape
       pad_d = (self.factor_t - d % self.factor_t) % self.factor_t
@@ -769,9 +732,7 @@ def __init__(
     self.out_channels = out_channels
     self.repeats = out_channels * self.factor // in_channels
 
-  def __call__(
-      self, x: jax.Array, feat_cache=None, feat_idx=0, first_chunk: bool = False
-  ) -> Tuple[jax.Array, Any, int]:
+  def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0, first_chunk: bool = False) -> Tuple[jax.Array, Any, int]:
     # Duplicate channels to match the expected total channels for upsampling.
     # x: (N, D, H, W, in_channels) -> (N, D, H, W, in_channels * self.repeats)
     x = jnp.repeat(x, repeats=self.repeats, axis=4)
@@ -891,9 +852,7 @@ def __call__(
 
     x_shortcut = None
     if self.avg_shortcut is not None:
-      x_shortcut, feat_cache, feat_idx = self.avg_shortcut(
-          x_main, feat_cache, feat_idx
-      )
+      x_shortcut, feat_cache, feat_idx = self.avg_shortcut(x_main, feat_cache, feat_idx)
       x = x + x_shortcut
 
     if return_shortcut:
@@ -994,9 +953,7 @@ def __call__(
 
     x_shortcut = None
     if self.avg_shortcut is not None:
-      x_shortcut, feat_cache, feat_idx = self.avg_shortcut(
-          x_main, feat_cache, feat_idx, first_chunk
-      )
+      x_shortcut, feat_cache, feat_idx = self.avg_shortcut(x_main, feat_cache, feat_idx, first_chunk)
       x = x + x_shortcut
 
     if return_shortcut:
@@ -1052,9 +1009,7 @@ def __init__(
     self.down_blocks = []
     for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
       if i != len(dim_mult) - 1:
-        downsample_mode = (
-            "downsample3d" if temperal_downsample[i] else "downsample2d"
-        )
+        downsample_mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
       else:
         downsample_mode = None
       self.down_blocks.append(
@@ -1120,9 +1075,7 @@ def __init__(
     )
 
     # output blocks
-    self.norm_out = WanRMS_norm(
-        out_dim, channel_first=False, images=False, rngs=rngs
-    )
+    self.norm_out = WanRMS_norm(out_dim, channel_first=False, images=False, rngs=rngs)
     self.conv_out = WanCausalConv3d(
         rngs=rngs,
         in_channels=out_dim,
@@ -1281,9 +1234,7 @@ def __init__(
     self.up_blocks = nnx.data(self.up_blocks)
 
     # output blocks
-    self.norm_out = WanRMS_norm(
-        dim=out_dim, images=False, rngs=rngs, channel_first=False
-    )
+    self.norm_out = WanRMS_norm(dim=out_dim, images=False, rngs=rngs, channel_first=False)
     self.conv_out = WanCausalConv3d(
         rngs=rngs,
         in_channels=out_dim,
@@ -1297,9 +1248,7 @@ def __init__(
     )
 
   @nnx.jit(static_argnames=("feat_idx", "first_chunk"))
-  def __call__(
-      self, x: jax.Array, feat_cache=None, feat_idx=0, first_chunk: bool = False
-  ):
+  def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0, first_chunk: bool = False):
     if feat_cache is not None:
       idx = feat_idx
       cache_x = jnp.copy(x[:, -CACHE_T:, :, :, :])
@@ -1553,9 +1502,7 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
     if x.shape[-1] != 3:
       # reshape channel last for JAX
       x = jnp.transpose(x, (0, 2, 3, 4, 1))
-    assert (
-        x.shape[-1] == 3
-    ), f"Expected input shape (N, D, H, W, 3), got {x.shape}"
+    assert x.shape[-1] == 3, f"Expected input shape (N, D, H, W, 3), got {x.shape}"
 
     x = self.patchify(x)
 
@@ -1566,9 +1513,7 @@ def _encode(self, x: jax.Array, feat_cache: AutoencoderKLWanCache):
     for i in range(iter_):
       enc_conv_idx = 0
       if i == 0:
-        out, enc_feat_map, enc_conv_idx = self.encoder(
-            x[:, :1, :, :, :], feat_cache=enc_feat_map, feat_idx=enc_conv_idx
-        )
+        out, enc_feat_map, enc_conv_idx = self.encoder(x[:, :1, :, :, :], feat_cache=enc_feat_map, feat_idx=enc_conv_idx)
       else:
         out_, enc_feat_map, enc_conv_idx = self.encoder(
             x[:, 1 + 4 * (i - 1) : 1 + 4 * i, :, :, :],
@@ -1621,9 +1566,7 @@ def _decode(
             first_chunk=True,
         )
       else:
-        out_, dec_feat_map, conv_idx = self.decoder(
-            x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx
-        )
+        out_, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx)
       out = jnp.concatenate([out, out_], axis=1)
 
     feat_cache._feat_map = dec_feat_map
@@ -1645,9 +1588,7 @@ def decode(
     if z.shape[-1] != self.z_dim:
       # reshape channel last for JAX
       z = jnp.transpose(z, (0, 2, 3, 4, 1))
-    assert (
-        z.shape[-1] == self.z_dim
-    ), f"Expected input shape (N, D, H, W, {self.z_dim}, got {z.shape}"
+    assert z.shape[-1] == self.z_dim, f"Expected input shape (N, D, H, W, {self.z_dim}, got {z.shape}"
     decoded = self._decode(z, feat_cache).sample
     if not return_dict:
       return (decoded,)

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 12 additions & 32 deletions
@@ -380,9 +380,7 @@ def __call__(
     # Support both global [B, 6, dim] and per-token [B, seq_len, 6, dim] temb.
     # Per-token temb is used by TI2V where first-frame tokens have timestep=0.
     if temb.ndim == 4:  # Per-token: [B, seq_len, 6, dim]
-      adaln = jnp.expand_dims(
-          self.adaln_scale_shift_table, 0
-      )  # [1, 1, 6, dim]
+      adaln = jnp.expand_dims(self.adaln_scale_shift_table, 0)  # [1, 1, 6, dim]
       combined = adaln + temb.astype(jnp.float32)  # [B, seq_len, 6, dim]
       parts = jnp.split(combined, 6, axis=2)
       shift_msa = parts[0].squeeze(2)
@@ -392,12 +390,10 @@ def __call__(
       c_scale_msa = parts[4].squeeze(2)
       c_gate_msa = parts[5].squeeze(2)
     else:  # Global: [B, 6, dim]
-      shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
-          jnp.split(
-              (self.adaln_scale_shift_table + temb.astype(jnp.float32)),
-              6,
-              axis=1,
-          )
+      shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
+          (self.adaln_scale_shift_table + temb.astype(jnp.float32)),
+          6,
+          axis=1,
       )
     axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_heads"))
     hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names)
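Both branches compute the same adaLN modulation: add the learned table to the time embedding and split into six tensors (shift/scale/gate for self-attention, then for the second half of the block). A shape-level sketch of the global branch; the [1, 6, dim] table shape follows from the per-token branch's expand_dims comment:

import jax.numpy as jnp

B, dim = 2, 16
adaln_table = jnp.zeros((1, 6, dim))  # learned scale-shift table
temb = jnp.ones((B, 6, dim))          # global time embedding

combined = adaln_table + temb.astype(jnp.float32)  # [B, 6, dim]
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(combined, 6, axis=1)
print(shift_msa.shape)  # (2, 1, 16); the singleton axis broadcasts over sequence length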
@@ -639,22 +635,14 @@ def __call__(
       # tokens receive timestep=0 (clean) and other tokens receive timestep=t.
       bt, sl = timestep.shape
       t_flat = timestep.reshape(-1)  # [B*seq_len]
-      t_sinusoidal = self.condition_embedder.timesteps_proj(
-          t_flat
-      )  # [B*sl, freq_dim]
+      t_sinusoidal = self.condition_embedder.timesteps_proj(t_flat)  # [B*sl, freq_dim]
       t_sinusoidal = t_sinusoidal.reshape(bt, sl, -1)  # [B, sl, freq_dim]
-      temb = self.condition_embedder.time_embedder(
-          t_sinusoidal
-      )  # [B, sl, dim]
+      temb = self.condition_embedder.time_embedder(t_sinusoidal)  # [B, sl, dim]
       with jax.named_scope("time_proj"):
-        timestep_proj = self.condition_embedder.time_proj(
-            self.condition_embedder.act_fn(temb)
-        )  # [B, sl, dim*6]
+        timestep_proj = self.condition_embedder.time_proj(self.condition_embedder.act_fn(temb))  # [B, sl, dim*6]
       timestep_proj = timestep_proj.reshape(bt, sl, 6, -1)  # [B, sl, 6, dim]
       # Text processing
-      encoder_hidden_states = self.condition_embedder.text_embedder(
-          encoder_hidden_states
-      )
+      encoder_hidden_states = self.condition_embedder.text_embedder(encoder_hidden_states)
       encoder_hidden_states_image = None
       encoder_attention_mask = None
     else:
@@ -664,9 +652,7 @@ def __call__(
         encoder_hidden_states,
         encoder_hidden_states_image,
         encoder_attention_mask,
-      ) = self.condition_embedder(
-          timestep, encoder_hidden_states, encoder_hidden_states_image
-      )
+      ) = self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image)
       timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1)
 
     if encoder_hidden_states_image is not None:
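The per-token branch runs the standard timestep pipeline (sinusoidal features, MLP embedding, activation, a 6*dim projection) on a flattened [B*seq_len] vector and reshapes back, so TI2V can hand first-frame tokens timestep 0 while the rest carry t. A shape-level sketch with random matrices standing in for time_embedder and time_proj, and silu assumed for act_fn:

import jax
import jax.numpy as jnp

B, sl, freq_dim, dim = 2, 8, 32, 16
timestep = jnp.zeros((B, sl)).at[:, 1:].set(5.0)  # first-frame tokens stay at t=0

def sinusoidal(t, n):  # stand-in for condition_embedder.timesteps_proj
  half = n // 2
  freqs = jnp.exp(-jnp.log(10000.0) * jnp.arange(half) / half)
  angles = t[:, None] * freqs[None, :]
  return jnp.concatenate([jnp.sin(angles), jnp.cos(angles)], axis=-1)

k1, k2 = jax.random.split(jax.random.PRNGKey(0))
W_embed = jax.random.normal(k1, (freq_dim, dim))  # stand-in for time_embedder
W_proj = jax.random.normal(k2, (dim, dim * 6))    # stand-in for time_proj

t_sin = sinusoidal(timestep.reshape(-1), freq_dim).reshape(B, sl, -1)  # [B, sl, freq_dim]
temb = t_sin @ W_embed                                                 # [B, sl, dim]
timestep_proj = (jax.nn.silu(temb) @ W_proj).reshape(B, sl, 6, -1)     # [B, sl, 6, dim]
print(timestep_proj.shape)  # (2, 8, 6, 16)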
@@ -745,18 +731,12 @@ def layer_forward(hidden_states):
 
     if per_token_t:
       # temb: [B, seq_len, dim] — per-token modulation for final head
-      combined_head = jnp.expand_dims(
-          self.scale_shift_table, 0
-      ) + jnp.expand_dims(
-          temb, 2
-      )  # [B, sl, 2, dim]
+      combined_head = jnp.expand_dims(self.scale_shift_table, 0) + jnp.expand_dims(temb, 2)  # [B, sl, 2, dim]
       shift, scale = jnp.split(combined_head, 2, axis=2)
       shift = shift.squeeze(2)  # [B, sl, dim]
       scale = scale.squeeze(2)  # [B, sl, dim]
     else:
-      shift, scale = jnp.split(
-          self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1
-      )
+      shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)
     hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
     with jax.named_scope("proj_out"):
       hidden_states = self.proj_out(hidden_states)
