
Commit 935b457

add support for wan vae 2.2 & fix hacky wan vae 2.1

1 parent 18f6f0f

4 files changed: 1747 additions & 45 deletions

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 18 additions & 15 deletions
@@ -35,29 +35,32 @@ def get_sinusoidal_embeddings(
   """Returns the positional encoding (same as Tensor2Tensor).
 
   Args:
-    timesteps: a 1-D Tensor of N indices, one per batch element.
+    timesteps: a 1-D or 2-D Tensor of indices.
       These may be fractional.
     embedding_dim: The number of output channels.
     min_timescale: The smallest time unit (should probably be 0.0).
     max_timescale: The largest time unit.
   Returns:
-    a Tensor of timing signals [N, num_channels]
+    a Tensor of timing signals [B, num_channels] or [B, N, num_channels]
   """
-  assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
+  assert timesteps.ndim <= 2, "Timesteps should be a 1d or 2d-array"
   assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
   num_timescales = float(embedding_dim // 2)
   log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
   inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
-  emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
+  emb = jnp.expand_dims(timesteps, -1) * inv_timescales
 
   # scale embeddings
   scaled_time = scale * emb
 
   if flip_sin_to_cos:
-    signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
+    signal = jnp.concatenate(
+        [jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=-1
+    )
   else:
-    signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
-  signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
+    signal = jnp.concatenate(
+        [jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1
+    )
   return signal
 
 
@@ -84,7 +87,7 @@ def __init__(
       sample_proj_bias=True,
       dtype: jnp.dtype = jnp.float32,
       weights_dtype: jnp.dtype = jnp.float32,
-      precision: jax.lax.Precision = None,
+      precision: jax.lax.Precision | None = None,
   ):
     self.linear_1 = nnx.Linear(
         rngs=rngs,
@@ -221,7 +224,7 @@ def __call__(self, timesteps):
 
 def get_1d_rotary_pos_embed(
     dim: int,
-    pos: Union[jnp.array, int],
+    pos: Union[jnp.ndarray, int],
     theta: float = 10000.0,
     linear_factor=1.0,
     ntk_factor=1.0,
@@ -332,11 +335,11 @@ def __init__(
       rngs: nnx.Rngs,
       in_features: int,
       hidden_size: int,
-      out_features: int = None,
+      out_features: int | None = None,
       act_fn: str = "gelu_tanh",
       dtype: jnp.dtype = jnp.float32,
       weights_dtype: jnp.dtype = jnp.float32,
-      precision: jax.lax.Precision = None,
+      precision: jax.lax.Precision | None = None,
   ):
     if out_features is None:
       out_features = hidden_size
@@ -392,11 +395,11 @@ class PixArtAlphaTextProjection(nn.Module):
   """
 
   hidden_size: int
-  out_features: int = None
+  out_features: int | None = None
   act_fn: str = "gelu_tanh"
   dtype: jnp.dtype = jnp.float32
   weights_dtype: jnp.dtype = jnp.float32
-  precision: jax.lax.Precision = None
+  precision: jax.lax.Precision | None = None
 
   @nn.compact
   def __call__(self, caption):
@@ -455,7 +458,7 @@ class CombinedTimestepTextProjEmbeddings(nn.Module):
   pooled_projection_dim: int
   dtype: jnp.dtype = jnp.float32
   weights_dtype: jnp.dtype = jnp.float32
-  precision: jax.lax.Precision = None
+  precision: jax.lax.Precision | None = None
 
   @nn.compact
   def __call__(self, timestep, pooled_projection):
@@ -479,7 +482,7 @@ class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
   pooled_projection_dim: int
   dtype: jnp.dtype = jnp.float32
   weights_dtype: jnp.dtype = jnp.float32
-  precision: jax.lax.Precision = None
+  precision: jax.lax.Precision | None = None
 
   @nn.compact
   def __call__(self, timestep, guidance, pooled_projection):
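Note on the sinusoidal-embedding change: the new code broadcasts over the trailing axis instead of assuming a flat batch, so per-token (2-D) timesteps work without the final reshape. A minimal sketch of the new broadcasting path (illustrative shapes and values only, not part of the commit):

import jax.numpy as jnp

# With 2-D timesteps of shape (B, N), expand_dims(..., -1) yields (B, N, 1),
# which broadcasts against inv_timescales of shape (D/2,) to (B, N, D/2).
timesteps = jnp.arange(6, dtype=jnp.float32).reshape(2, 3)   # (B=2, N=3)
inv_timescales = jnp.exp(-jnp.arange(4, dtype=jnp.float32))  # (D/2=4,)

emb = jnp.expand_dims(timesteps, -1) * inv_timescales        # (2, 3, 4)
signal = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)
print(signal.shape)  # (2, 3, 8) -> [B, N, embedding_dim]

# A 1-D input of shape (B,) still yields (B, embedding_dim), which is why the
# old jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim]) is dropped.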

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 3 additions & 18 deletions
@@ -360,7 +360,8 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
         feat_cache = _update_cache(feat_cache, idx, cache_x)
         feat_idx += 1
     x = x.reshape(b, t, h, w, 2, c)
-    x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
+    # x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
+    x = x.transpose(0, 1, 4, 2, 3, 5)
     x = x.reshape(b, t * 2, h, w, c)
     t = x.shape[1]
     x = x.reshape(b * t, h, w, c)
@@ -1160,23 +1161,7 @@ def _decode(
         out, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx)
       else:
         out_, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx)
-
-        # This is to bypass an issue where frame[1] should be frame[2] and vise versa.
-        # Ideally shouldn't need to do this however, can't find where the frame is going out of sync.
-        # Most likely due to an incorrect reshaping in the decoder.
-        fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :]
-        # When batch_size is 0, expand batch dim for concatenation
-        # else, expand frame dim for concatenation so that batch dim stays intact.
-        axis = 0
-        if fm1.shape[0] > 1:
-          axis = 1
-
-        if len(fm1.shape) == 4:
-          fm1 = jnp.expand_dims(fm1, axis=axis)
-          fm2 = jnp.expand_dims(fm2, axis=axis)
-          fm3 = jnp.expand_dims(fm3, axis=axis)
-          fm4 = jnp.expand_dims(fm4, axis=axis)
-        out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1)
+        out = jnp.concatenate([out, out_], axis=1)
 
     feat_cache._feat_map = dec_feat_map
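The transpose in the first hunk is what makes the decoder-side reordering hack removable: it moves the size-2 axis next to the time axis before flattening, so the temporally upsampled frames come out correctly interleaved. A small sketch demonstrating the difference (toy shapes, not part of the commit):

import jax.numpy as jnp

# Toy temporal upsample: t frames, each split into 2 sub-frames along a
# size-2 axis; the true output frame index of each element is 2*frame + phase.
b, t, h, w = 1, 3, 1, 1
x = jnp.arange(t * 2, dtype=jnp.float32).reshape(b, t, h, w, 2, 1)

# Old path: stacking on axis=1 gives (b, 2, t, h, w, c), so the reshape
# emits all phase-0 frames before all phase-1 frames.
old = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
print(old.reshape(b, t * 2, h, w, 1).squeeze())  # [0. 2. 4. 1. 3. 5.]

# New path: moving the size-2 axis next to t interleaves frames correctly.
new = x.transpose(0, 1, 4, 2, 3, 5)  # (b, t, 2, h, w, c)
print(new.reshape(b, t * 2, h, w, 1).squeeze())  # [0. 1. 2. 3. 4. 5.]

With the upsampler emitting frames in order, the _decode workaround that re-expanded the output and swapped fm2/fm3 is no longer needed, which is why the second hunk collapses it to a plain jnp.concatenate.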
