Skip to content

Commit c95fc1a

Browse files
committed
Add mixed-precision support for better video generation quality.
1 parent 09b01d8 commit c95fc1a

3 files changed

Lines changed: 38 additions & 9 deletions

File tree

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -332,33 +332,33 @@ def __call__(
332332
rngs: nnx.Rngs = None,
333333
):
334334
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
335-
(self.adaln_scale_shift_table + temb), 6, axis=1
335+
(self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
336336
)
337337
hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
338338
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
339339

340340
# 1. Self-attention
341-
norm_hidden_states = (self.norm1(hidden_states) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype)
341+
norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(hidden_states.dtype)
342342
attn_output = self.attn1(
343343
hidden_states=norm_hidden_states,
344344
encoder_hidden_states=norm_hidden_states,
345345
rotary_emb=rotary_emb,
346346
deterministic=deterministic,
347347
rngs=rngs,
348348
)
349-
hidden_states = (hidden_states + attn_output * gate_msa).astype(hidden_states.dtype)
349+
hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)
350350

351351
# 2. Cross-attention
352-
norm_hidden_states = self.norm2(hidden_states)
352+
norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
353353
attn_output = self.attn2(
354354
hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs
355355
)
356356
hidden_states = hidden_states + attn_output
357357

358358
# 3. Feed-forward
359-
norm_hidden_states = (self.norm3(hidden_states) * (1 + c_scale_msa) + c_shift_msa).astype(hidden_states.dtype)
359+
norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(hidden_states.dtype)
360360
ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
361-
hidden_states = (hidden_states + ff_output * c_gate_msa).astype(hidden_states.dtype)
361+
hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(hidden_states.dtype)
362362
return hidden_states
363363

364364

@@ -526,7 +526,7 @@ def scan_fn(carry, block):
526526

527527
shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)
528528

529-
hidden_states = (self.norm_out(hidden_states) * (1 + scale) + shift).astype(hidden_states.dtype)
529+
hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
530530
hidden_states = self.proj_out(hidden_states)
531531

532532
hidden_states = hidden_states.reshape(

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,28 @@
3939
import torch
4040
import qwix
4141

42-
42+
def cast_with_exclusion(path, x, dtype_to_cast):
  """Cast an array to `dtype_to_cast`, keeping precision-sensitive params in float32.

  Any parameter whose pytree path contains one of the exclusion keywords —
  norm layers, the condition embedder, or the scale/shift tables (both the
  final one and the AdaLN tables) — is kept in float32, since these are
  numerically sensitive under mixed precision. Everything else is cast to
  `dtype_to_cast`.

  Args:
    path: Tuple of pytree path entries (e.g. `jax.tree_util.DictKey`)
      identifying the parameter, as supplied by `jax.tree_util.tree_map_with_path`.
    x: The parameter array to cast.
    dtype_to_cast: Target dtype for non-excluded parameters (e.g. bfloat16).

  Returns:
    `x` cast to float32 if its path matches an exclusion keyword, otherwise
    `x` cast to `dtype_to_cast`.
  """
  # Substrings marking parameters that must stay in full precision.
  exclusion_keywords = (
      "norm",                # all LayerNorm/GroupNorm layers
      "condition_embedder",  # the entire time/text conditioning module
      "scale_shift_table",   # catches both the final and the AdaLN tables
  )

  # Flatten the path to a dotted string so keyword matching sees every level.
  path_str = ".".join(
      str(k.key) if isinstance(k, jax.tree_util.DictKey) else str(k) for k in path
  )

  if any(keyword in path_str.lower() for keyword in exclusion_keywords):
    # Keep precision-sensitive weights and biases in full precision.
    return x.astype(jnp.float32)
  # Cast everything else to the requested (typically lower) precision.
  return x.astype(dtype_to_cast)
63+
4364
def basic_clean(text):
4465
if is_ftfy_available():
4566
import ftfy
@@ -113,7 +134,11 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
113134
params = load_wan_transformer(
114135
config.wan_transformer_pretrained_model_name_or_path, params, "cpu", num_layers=wan_config["num_layers"]
115136
)
116-
params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
137+
138+
params = jax.tree_util.tree_map_with_path(
139+
lambda path, x: cast_with_exclusion(path, x, dtype_to_cast=config.weights_dtype),
140+
params
141+
)
117142
for path, val in flax.traverse_util.flatten_dict(params).items():
118143
if restored_checkpoint:
119144
path = path[:-1]

src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -674,6 +674,10 @@ def step(
674674
Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with
675675
the multistep UniPC.
676676
"""
677+
678+
original_dtype = sample.dtype
679+
sample = sample.astype(jnp.float32)
680+
677681
if state.timesteps is None:
678682
raise ValueError("Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler")
679683

0 commit comments

Comments
 (0)