
Commit 5213ea5

fix for dtypes
1 parent e3b9eec commit 5213ea5

1 file changed: 35 additions & 31 deletions

src/maxdiffusion/models/wan/autoencoder_kl_wan.py

@@ -61,8 +61,8 @@ def __init__(
       padding: Union[int, Tuple[int, int, int]] = 0,
       use_bias: bool = True,
       mesh: jax.sharding.Mesh = None,
-      dtype: jnp.dtype = jnp.float32,
-      weights_dtype: jnp.dtype = jnp.float32,
+      dtype: jnp.dtype = jnp.bfloat16,
+      weights_dtype: jnp.dtype = jnp.bfloat16,
       precision: jax.lax.Precision = None,
   ):
     self.kernel_size = _canonicalize_tuple(kernel_size, 3, "kernel_size")
@@ -270,8 +270,8 @@ def __init__(
       mode: str,
       rngs: nnx.Rngs,
       mesh: jax.sharding.Mesh = None,
-      dtype: jnp.dtype = jnp.float32,
-      weights_dtype: jnp.dtype = jnp.float32,
+      dtype: jnp.dtype = jnp.bfloat16,
+      weights_dtype: jnp.dtype = jnp.bfloat16,
       precision: jax.lax.Precision = None,
   ):
     self.dtype = dtype
@@ -443,8 +443,8 @@ def __init__(
       dropout: float = 0.0,
       non_linearity: str = "silu",
       mesh: jax.sharding.Mesh = None,
-      dtype: jnp.dtype = jnp.float32,
-      weights_dtype: jnp.dtype = jnp.float32,
+      dtype: jnp.dtype = jnp.bfloat16,
+      weights_dtype: jnp.dtype = jnp.bfloat16,
       precision: jax.lax.Precision = None,
   ):
     self.dtype = dtype
@@ -511,19 +511,19 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
     input_dtype = x.dtype

     h, sc_cache = self.conv_shortcut(x, cache.get("shortcut"))
-    new_cache["shortcut"] = sc_cache
+    new_cache["shortcut"] = sc_cache.astype(self.dtype)

     x = self.norm1(x)
     x = self.nonlinearity(x)

     x, c1 = self.conv1(x, cache.get("conv1"))
-    new_cache["conv1"] = c1
+    new_cache["conv1"] = c1.astype(self.dtype)

     x = self.norm2(x)
     x = self.nonlinearity(x)

     x, c2 = self.conv2(x, cache.get("conv2"))
-    new_cache["conv2"] = c2
+    new_cache["conv2"] = c2.astype(self.dtype)

     x = (x + h).astype(self.dtype)
     return x, new_cache
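
Note: the `.astype(self.dtype)` casts on the cache entries guard against JAX's implicit type promotion — once a cached feature map is held in float32, any arithmetic that mixes it with bfloat16 activations on the next chunk silently upcasts back to float32. A minimal standalone sketch of the promotion behavior (not code from this file):

import jax.numpy as jnp

x = jnp.ones((2, 2), dtype=jnp.bfloat16)   # bfloat16 activations
c = jnp.ones((2, 2), dtype=jnp.float32)    # cache entry left in float32

print((x + c).dtype)                       # float32 -- promotion upcasts
print((x + c.astype(jnp.bfloat16)).dtype)  # bfloat16 -- cast first
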
@@ -535,8 +535,8 @@ def __init__(
       dim: int,
       rngs: nnx.Rngs,
       mesh: jax.sharding.Mesh = None,
-      dtype: jnp.dtype = jnp.float32,
-      weights_dtype: jnp.dtype = jnp.float32,
+      dtype: jnp.dtype = jnp.bfloat16,
+      weights_dtype: jnp.dtype = jnp.bfloat16,
       precision: jax.lax.Precision = None,
   ):
     self.dim = dim
@@ -597,10 +597,11 @@ def __init__(
       non_linearity: str = "silu",
       num_layers: int = 1,
       mesh: jax.sharding.Mesh = None,
-      dtype: jnp.dtype = jnp.float32,
-      weights_dtype: jnp.dtype = jnp.float32,
+      dtype: jnp.dtype = jnp.bfloat16,
+      weights_dtype: jnp.dtype = jnp.bfloat16,
       precision: jax.lax.Precision = None,
   ):
+    self.dtype = dtype
     self.dim = dim
     self.resnets = nnx.List(
         [
@@ -657,13 +658,13 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
     new_cache = {"resnets": []}

     x, c = self.resnets[0](x, cache.get("resnets", [None])[0])
-    new_cache["resnets"].append(c)
+    new_cache["resnets"].append(c.astype(self.dtype))

     for i, (attn, resnet) in enumerate(zip(self.attentions, self.resnets[1:])):
       if attn is not None:
         x = attn(x)
       x, c = resnet(x, cache.get("resnets", [None] * len(self.resnets))[i + 1])
-      new_cache["resnets"].append(c)
+      new_cache["resnets"].append(c.astype(self.dtype))

     return x, new_cache

@@ -679,10 +680,11 @@ def __init__(
       upsample_mode: Optional[str] = None,
       non_linearity: str = "silu",
       mesh: jax.sharding.Mesh = None,
-      dtype: jnp.dtype = jnp.float32,
-      weights_dtype: jnp.dtype = jnp.float32,
+      dtype: jnp.dtype = jnp.bfloat16,
+      weights_dtype: jnp.dtype = jnp.bfloat16,
       precision: jax.lax.Precision = None,
   ):
+    self.dtype = dtype
     self.resnets = nnx.List([])
     current_dim = in_dim
     for _ in range(num_res_blocks + 1):
@@ -736,11 +738,11 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):

     for i, resnet in enumerate(self.resnets):
       x, c = resnet(x, cache.get("resnets", [None] * len(self.resnets))[i])
-      new_cache["resnets"].append(c)
+      new_cache["resnets"].append(c.astype(self.dtype))

     if self.upsamplers:
       x, c = self.upsamplers[0](x, cache.get("upsamplers", [None])[0])
-      new_cache["upsamplers"].append(c)
+      new_cache["upsamplers"].append(c.astype(self.dtype))
     return x, new_cache

@@ -757,10 +759,11 @@ def __init__(
       dropout=0.0,
       non_linearity: str = "silu",
       mesh: jax.sharding.Mesh = None,
-      dtype: jnp.dtype = jnp.float32,
-      weights_dtype: jnp.dtype = jnp.float32,
+      dtype: jnp.dtype = jnp.bfloat16,
+      weights_dtype: jnp.dtype = jnp.bfloat16,
       precision: jax.lax.Precision = None,
   ):
+    self.dtype = dtype
     self.dim = dim
     self.z_dim = z_dim
     self.dim_mult = dim_mult
@@ -885,27 +888,27 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
     new_cache = {}

     x, c = self.conv_in(x, cache.get("conv_in"))
-    new_cache["conv_in"] = c
+    new_cache["conv_in"] = c.astype(self.dtype)

     new_cache["down_blocks"] = []
     current_down_caches = cache.get("down_blocks", [None] * len(self.down_blocks))

     for i, layer in enumerate(self.down_blocks):
       if isinstance(layer, (WanResidualBlock, WanResample)):
         x, c = layer(x, current_down_caches[i])
-        new_cache["down_blocks"].append(c)
+        new_cache["down_blocks"].append(c.astype(self.dtype))
       else:
         x = layer(x)
         new_cache["down_blocks"].append(None)

     x, c = self.mid_block(x, cache.get("mid_block"))
-    new_cache["mid_block"] = c
+    new_cache["mid_block"] = c.astype(self.dtype)

     x = self.norm_out(x)
     x = self.nonlinearity(x)

     x, c = self.conv_out(x, cache.get("conv_out"))
-    new_cache["conv_out"] = c
+    new_cache["conv_out"] = c.astype(self.dtype)

     return x, new_cache

@@ -923,10 +926,11 @@ def __init__(
       dropout=0.0,
       non_linearity: str = "silu",
       mesh: jax.sharding.Mesh = None,
-      dtype: jnp.dtype = jnp.float32,
-      weights_dtype: jnp.dtype = jnp.float32,
+      dtype: jnp.dtype = jnp.bfloat16,
+      weights_dtype: jnp.dtype = jnp.bfloat16,
       precision: jax.lax.Precision = None,
   ):
+    self.dtype = dtype
     self.dim = dim
     self.dim_mult = dim_mult
     self.nonlinearity = get_activation(non_linearity)
@@ -1022,21 +1026,21 @@ def __call__(self, x: jax.Array, cache: Dict[str, Any] = None):
     new_cache = {}

     x, c = self.conv_in(x, cache.get("conv_in"))
-    new_cache["conv_in"] = c
+    new_cache["conv_in"] = c.astype(self.dtype)

     x, c = self.mid_block(x, cache.get("mid_block"))
-    new_cache["mid_block"] = c
+    new_cache["mid_block"] = c.astype(self.dtype)

     new_cache["up_blocks"] = []
     current_up_caches = cache.get("up_blocks", [None] * len(self.up_blocks))
     for i, up_block in enumerate(self.up_blocks):
       x, c = up_block(x, current_up_caches[i])
-      new_cache["up_blocks"].append(c)
+      new_cache["up_blocks"].append(c.astype(self.dtype))

     x = self.norm_out(x)
     x = self.nonlinearity(x)
     x, c = self.conv_out(x, cache.get("conv_out"))
-    new_cache["conv_out"] = c
+    new_cache["conv_out"] = c.astype(self.dtype)

     return x, new_cache

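Taken together: the constructor defaults for `dtype` and `weights_dtype` move from float32 to bfloat16 across the Wan autoencoder modules, and every feature-map cache entry is cast to the module dtype before it is stored, so caches carried across chunked encode/decode calls stay in bfloat16. A minimal sketch of the cached-block pattern under those assumptions (`step` and `feat` are hypothetical names, not from the file):

import jax.numpy as jnp

def step(x, cache, dtype=jnp.bfloat16):
  # Hypothetical stand-in for one cached block: consume the previous
  # cache entry if present, produce a new one, and cast it to the
  # module dtype so a float32 entry cannot leak into the next call.
  prev = cache.get("feat")
  if prev is not None:
    x = x + prev.astype(dtype)
  new_cache = {"feat": x.astype(dtype)}
  return x.astype(dtype), new_cache

x = jnp.ones((1, 4), dtype=jnp.bfloat16)
y, c = step(x, {})
y, c = step(y, c)                # second chunk reuses the bfloat16 cache
print(y.dtype, c["feat"].dtype)  # bfloat16 bfloat16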