Skip to content

Commit ad6391a

Browse files
Merge pull request #381 from AI-Hypercomputer:wan2p2vae
PiperOrigin-RevId: 901340603
2 parents 2965670 + 1da5d2f commit ad6391a

4 files changed

Lines changed: 1485 additions & 37 deletions

File tree

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,29 +35,28 @@ def get_sinusoidal_embeddings(
3535
"""Returns the positional encoding (same as Tensor2Tensor).
3636
3737
Args:
38-
timesteps: a 1-D Tensor of N indices, one per batch element.
38+
timesteps: a 1-D or 2-D Tensor of indices.
3939
These may be fractional.
4040
embedding_dim: The number of output channels.
4141
min_timescale: The smallest time unit (should probably be 0.0).
4242
max_timescale: The largest time unit.
4343
Returns:
44-
a Tensor of timing signals [N, num_channels]
44+
a Tensor of timing signals [B, num_channels] or [B, N, num_channels]
4545
"""
46-
assert timesteps.ndim == 1, "Timesteps should be a 1d-array"
46+
assert timesteps.ndim <= 2, "Timesteps should be a 1d or 2d-array"
4747
assert embedding_dim % 2 == 0, f"Embedding dimension {embedding_dim} should be even"
4848
num_timescales = float(embedding_dim // 2)
4949
log_timescale_increment = math.log(max_timescale / min_timescale) / (num_timescales - freq_shift)
5050
inv_timescales = min_timescale * jnp.exp(jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment)
51-
emb = jnp.expand_dims(timesteps, 1) * jnp.expand_dims(inv_timescales, 0)
51+
emb = jnp.expand_dims(timesteps, -1) * inv_timescales
5252

5353
# scale embeddings
5454
scaled_time = scale * emb
5555

5656
if flip_sin_to_cos:
57-
signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=1)
57+
signal = jnp.concatenate([jnp.cos(scaled_time), jnp.sin(scaled_time)], axis=-1)
5858
else:
59-
signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1)
60-
signal = jnp.reshape(signal, [jnp.shape(timesteps)[0], embedding_dim])
59+
signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=-1)
6160
return signal
6261

6362

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,8 @@ def __call__(self, x: jax.Array, feat_cache=None, feat_idx=0):
360360
feat_cache = _update_cache(feat_cache, idx, cache_x)
361361
feat_idx += 1
362362
x = x.reshape(b, t, h, w, 2, c)
363-
x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
363+
# x = jnp.stack([x[:, :, :, :, 0, :], x[:, :, :, :, 1, :]], axis=1)
364+
x = x.transpose(0, 1, 4, 2, 3, 5)
364365
x = x.reshape(b, t * 2, h, w, c)
365366
t = x.shape[1]
366367
x = x.reshape(b * t, h, w, c)
@@ -1160,23 +1161,7 @@ def _decode(
11601161
out, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx)
11611162
else:
11621163
out_, dec_feat_map, conv_idx = self.decoder(x[:, i : i + 1, :, :, :], feat_cache=dec_feat_map, feat_idx=conv_idx)
1163-
1164-
# This is to bypass an issue where frame[1] should be frame[2] and vice versa.
1165-
# Ideally shouldn't need to do this however, can't find where the frame is going out of sync.
1166-
# Most likely due to an incorrect reshaping in the decoder.
1167-
fm1, fm2, fm3, fm4 = out_[:, 0, :, :, :], out_[:, 1, :, :, :], out_[:, 2, :, :, :], out_[:, 3, :, :, :]
1168-
# When batch_size is 0, expand batch dim for concatenation
1169-
# else, expand frame dim for concatenation so that batch dim stays intact.
1170-
axis = 0
1171-
if fm1.shape[0] > 1:
1172-
axis = 1
1173-
1174-
if len(fm1.shape) == 4:
1175-
fm1 = jnp.expand_dims(fm1, axis=axis)
1176-
fm2 = jnp.expand_dims(fm2, axis=axis)
1177-
fm3 = jnp.expand_dims(fm3, axis=axis)
1178-
fm4 = jnp.expand_dims(fm4, axis=axis)
1179-
out = jnp.concatenate([out, fm1, fm3, fm2, fm4], axis=1)
1164+
out = jnp.concatenate([out, out_], axis=1)
11801165

11811166
feat_cache._feat_map = dec_feat_map
11821167

0 commit comments

Comments
 (0)