Skip to content

Commit cac3fb5

Browse files
committed
Run conditioning activations in fp32, inputs in fp32, and the VAE in fp32.
1 parent c95fc1a commit cac3fb5

3 files changed

Lines changed: 11 additions & 11 deletions

File tree

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,7 @@ def __init__(
8989
in_features=in_channels,
9090
out_features=time_embed_dim,
9191
use_bias=sample_proj_bias,
92-
dtype=dtype,
92+
dtype=jnp.float32,
9393
param_dtype=weights_dtype,
9494
precision=precision,
9595
kernel_init=nnx.with_partitioning(
@@ -121,7 +121,7 @@ def __init__(
121121
in_features=time_embed_dim,
122122
out_features=time_embed_dim_out,
123123
use_bias=sample_proj_bias,
124-
dtype=dtype,
124+
dtype=jnp.float32,
125125
param_dtype=weights_dtype,
126126
precision=precision,
127127
kernel_init=nnx.with_partitioning(
@@ -269,7 +269,7 @@ def __init__(
269269
in_features=in_features,
270270
out_features=hidden_size,
271271
use_bias=True,
272-
dtype=dtype,
272+
dtype=jnp.float32,
273273
param_dtype=weights_dtype,
274274
precision=precision,
275275
kernel_init=nnx.with_partitioning(
@@ -288,7 +288,7 @@ def __init__(
288288
in_features=hidden_size,
289289
out_features=out_features,
290290
use_bias=True,
291-
dtype=dtype,
291+
dtype=jnp.float32,
292292
param_dtype=weights_dtype,
293293
precision=precision,
294294
kernel_init=nnx.with_partitioning(

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -116,7 +116,7 @@ def __init__(
116116
rngs=rngs,
117117
in_features=dim,
118118
out_features=time_proj_dim,
119-
dtype=dtype,
119+
dtype=jnp.float32,
120120
param_dtype=weights_dtype,
121121
precision=precision,
122122
kernel_init=nnx.with_partitioning(

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -234,8 +234,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
234234
subfolder="vae",
235235
rngs=rngs,
236236
mesh=mesh,
237-
dtype=config.activations_dtype,
238-
weights_dtype=config.weights_dtype,
237+
dtype=jnp.float32,
238+
weights_dtype=jnp.float32,
239239
)
240240
return wan_vae
241241

@@ -494,7 +494,7 @@ def encode_prompt(
494494
num_videos_per_prompt=num_videos_per_prompt,
495495
max_sequence_length=max_sequence_length,
496496
)
497-
prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=self.config.weights_dtype)
497+
prompt_embeds = jnp.array(prompt_embeds.detach().numpy(), dtype=jnp.float32)
498498

499499
if negative_prompt_embeds is None:
500500
negative_prompt = negative_prompt or ""
@@ -504,7 +504,7 @@ def encode_prompt(
504504
num_videos_per_prompt=num_videos_per_prompt,
505505
max_sequence_length=max_sequence_length,
506506
)
507-
negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=self.config.weights_dtype)
507+
negative_prompt_embeds = jnp.array(negative_prompt_embeds.detach().numpy(), dtype=jnp.float32)
508508

509509
return prompt_embeds, negative_prompt_embeds
510510

@@ -527,7 +527,7 @@ def prepare_latents(
527527
int(height) // vae_scale_factor_spatial,
528528
int(width) // vae_scale_factor_spatial,
529529
)
530-
latents = jax.random.normal(rng, shape=shape, dtype=self.config.weights_dtype)
530+
latents = jax.random.normal(rng, shape=shape, dtype=jnp.float32)
531531

532532
return latents
533533

@@ -617,7 +617,7 @@ def __call__(
617617
latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1)
618618
latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1)
619619
latents = latents / latents_std + latents_mean
620-
latents = latents.astype(self.config.weights_dtype)
620+
latents = latents.astype(jnp.float32)
621621

622622
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
623623
video = self.vae.decode(latents, self.vae_cache)[0]

0 commit comments

Comments
 (0)