Skip to content

Commit a0c377f

Browse files
committed
Jit and shard the VAE; refactor `for` loops in the jitted VAE; 132 sec on 16 TPUs
1 parent 65e7f93 commit a0c377f

3 files changed

Lines changed: 365 additions & 218 deletions

File tree

src/maxdiffusion/models/attention_flax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -334,15 +334,15 @@ def wrap_flash_attention(query, key, value):
334334
mask=mask,
335335
q_seq_shards=1, # the sizes of the axis is sharding over seq_len
336336
config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
337-
save_residuals=True if "ring" in attention_kernel else False,
337+
save_residuals=False,
338338
)
339339
elif attention_kernel == "tokamax_ring":
340340
mask = tokamax_splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2]),)
341341
splash_kernel = tokamax_ring_attention_kernel.make_ring_attention(
342342
mask=mask,
343343
is_mqa=False,
344344
config=convert_to_tokamax_splash_config(block_sizes, residual_checkpoint_name=residual_checkpoint_name),
345-
save_residuals=True,
345+
save_residuals=False,
346346
ring_axis="fsdp",
347347
)
348348
else:

src/maxdiffusion/models/vae_flax.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,13 @@
2222
import flax.linen as nn
2323
import jax
2424
import jax.numpy as jnp
25+
from jax import tree_util
2526
from flax.core.frozen_dict import FrozenDict
2627

2728
from ..configuration_utils import ConfigMixin, flax_register_to_config
2829
from ..utils import BaseOutput
2930
from .modeling_flax_utils import FlaxModelMixin
31+
3032

3133

3234
@flax.struct.dataclass
@@ -931,3 +933,29 @@ def __call__(self, sample, sample_posterior=False, deterministic: bool = True, r
931933
return (sample,)
932934

933935
return FlaxDecoderOutput(sample=sample)
936+
937+
class WanDiagonalGaussianDistribution(FlaxDiagonalGaussianDistribution):
    """Subclass of ``FlaxDiagonalGaussianDistribution`` that adds no behavior.

    It exists as a distinct type that is registered as a JAX pytree node
    further down in this module.
    NOTE(review): presumably a separate class is used so the registration
    does not affect the base class -- confirm with the author.
    """
939+
940+
941+
def _wan_diag_gauss_dist_flatten(dist):
942+
return (dist.mean, dist.logvar, dist.std, dist.var), (dist.deterministic,)
943+
944+
945+
def _wan_diag_gauss_dist_unflatten(aux, children):
    """Rebuild a ``WanDiagonalGaussianDistribution`` from flattened parts.

    Args:
        aux: Auxiliary data; its first element is the ``deterministic`` flag.
        children: Iterable of exactly four items, unpacked in order as
            ``mean``, ``logvar``, ``std`` and ``var``.

    Returns:
        A ``WanDiagonalGaussianDistribution`` whose attributes are set
        directly from ``children`` and ``aux``.
    """
    mean, logvar, std, var = children
    deterministic = aux[0]
    # Allocate via __new__ so __init__ is not run; the attributes are
    # assigned verbatim from the flattened values instead.
    obj = WanDiagonalGaussianDistribution.__new__(WanDiagonalGaussianDistribution)
    obj.mean = mean
    obj.logvar = logvar
    obj.std = std
    obj.var = var
    obj.deterministic = deterministic
    return obj
955+
956+
957+
# Register the distribution class as a JAX pytree node, using the flatten /
# unflatten helpers defined above to split it into children and aux data.
tree_util.register_pytree_node(
    WanDiagonalGaussianDistribution,
    _wan_diag_gauss_dist_flatten,
    _wan_diag_gauss_dist_unflatten,
)

0 commit comments

Comments
 (0)