
Commit 7cbb714

fix name scopes to be picked up by quantization config (#294)
* fix name scopes
* Fixed named_scope
* Update name scopes
* Fix name scopes
* Update transformer_wan.py
1 parent 26a6ac3 commit 7cbb714

2 files changed

Lines changed: 30 additions & 37 deletions
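The point of the change: quantization configs typically select layers by matching patterns against the module / named-scope path, so those names have to exist unconditionally and keep stable spellings. A rough, hypothetical sketch of that kind of matching (the rule list and helper below are illustrative only, not MaxDiffusion's actual config API):

```python
import re

# Illustrative stand-in for a quantization config that picks layers by
# regex-matching the named-scope / module path. Patterns and helper names
# are hypothetical.
QUANT_RULES = [
  (re.compile(r".*attn1.*(query|key|value)_proj.*"), "int8"),
  (re.compile(r".*ffn.*proj_out.*"), "int8"),
]

def quant_dtype_for(scope_path):
  """Return the quantized dtype for the first rule whose pattern matches, else None."""
  for pattern, dtype in QUANT_RULES:
    if pattern.search(scope_path):
      return dtype
  return None

# With the fixed scope names from this commit, a path like this is always
# present and therefore matchable:
print(quant_dtype_for("transformer_block/attn1/query_proj"))  # -> int8
# With the old flag-gated / differently named scopes, rules keyed on these
# segments would presumably never fire:
print(quant_dtype_for("transformer_block"))  # -> None
```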


src/maxdiffusion/models/attention_flax.py

Lines changed: 8 additions & 9 deletions
@@ -895,13 +895,12 @@ def __call__(
     if encoder_hidden_states is None:
       encoder_hidden_states = hidden_states

-    with self.conditional_named_scope("attn_qkv_proj"):
-      with self.conditional_named_scope("proj_query"):
-        query_proj = self.query(hidden_states)
-      with self.conditional_named_scope("proj_key"):
-        key_proj = self.key(encoder_hidden_states)
-      with self.conditional_named_scope("proj_value"):
-        value_proj = self.value(encoder_hidden_states)
+    with jax.named_scope("query_proj"):
+      query_proj = self.query(hidden_states)
+    with jax.named_scope("key_proj"):
+      key_proj = self.key(encoder_hidden_states)
+    with jax.named_scope("value_proj"):
+      value_proj = self.value(encoder_hidden_states)

     if self.qk_norm:
       with self.conditional_named_scope("attn_q_norm"):

@@ -921,13 +920,13 @@ def __call__(
       key_proj = checkpoint_name(key_proj, "key_proj")
       value_proj = checkpoint_name(value_proj, "value_proj")

-    with self.conditional_named_scope("attn_compute"):
+    with jax.named_scope("apply_attention"):
       attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)

     attn_output = attn_output.astype(dtype=dtype)
     attn_output = checkpoint_name(attn_output, "attn_output")

-    with self.conditional_named_scope("attn_out_proj"):
+    with jax.named_scope("proj_attn"):
       hidden_states = self.proj_attn(attn_output)
     hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
     return hidden_states
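For context, `jax.named_scope` is an unconditional context manager, unlike the flag-gated `conditional_named_scope` helper it replaces for these projections. A minimal, self-contained sketch of the pattern the diff adopts, using a toy nnx module rather than the real attention class:

```python
import jax
import jax.numpy as jnp
from flax import nnx

class TinyAttentionProjections(nnx.Module):
  """Toy stand-in for the attention module: only the q/k/v projections."""

  def __init__(self, dim: int, rngs: nnx.Rngs):
    self.query = nnx.Linear(dim, dim, rngs=rngs)
    self.key = nnx.Linear(dim, dim, rngs=rngs)
    self.value = nnx.Linear(dim, dim, rngs=rngs)

  def __call__(self, hidden_states, encoder_hidden_states):
    # Unconditional named scopes: the names are always recorded, independent
    # of any enable_jax_named_scopes-style flag.
    with jax.named_scope("query_proj"):
      query_proj = self.query(hidden_states)
    with jax.named_scope("key_proj"):
      key_proj = self.key(encoder_hidden_states)
    with jax.named_scope("value_proj"):
      value_proj = self.value(encoder_hidden_states)
    return query_proj, key_proj, value_proj

x = jnp.ones((2, 16, 64))
module = TinyAttentionProjections(64, rngs=nnx.Rngs(0))
q, k, v = module(x, x)
print(q.shape, k.shape, v.shape)  # (2, 16, 64) each
```

Because the scope is always entered, names like `query_proj` show up in traces and op metadata whether or not the debug named-scope flag is set.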

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 22 additions & 28 deletions
@@ -142,8 +142,8 @@ def __call__(
   ):
     timestep = self.timesteps_proj(timestep)
     temb = self.time_embedder(timestep)
-
-    timestep_proj = self.time_proj(self.act_fn(temb))
+    with jax.named_scope("time_proj"):
+      timestep_proj = self.time_proj(self.act_fn(temb))

     encoder_hidden_states = self.text_embedder(encoder_hidden_states)
     if encoder_hidden_states_image is not None:

@@ -186,7 +186,8 @@ def __init__(
     )

   def __call__(self, x: jax.Array) -> jax.Array:
-    x = self.proj(x)
+    with jax.named_scope("gelu"):
+      x = self.proj(x)
     return nnx.gelu(x)

@@ -244,12 +245,11 @@ def conditional_named_scope(self, name: str):
     return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()

   def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
-    with self.conditional_named_scope("mlp_up_proj_and_gelu"):
      hidden_states = self.act_fn(hidden_states)  # Output is (4, 75600, 13824)
      hidden_states = checkpoint_name(hidden_states, "ffn_activation")
      hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
-    with self.conditional_named_scope("mlp_down_proj"):
-      return self.proj_out(hidden_states)  # output is (4, 75600, 5120)
+    with jax.named_scope("proj_out"):
+      return self.proj_out(hidden_states)  # output is (4, 75600, 5120)


 class WanTransformerBlock(nnx.Module):
@@ -354,48 +354,42 @@ def __call__(
       rngs: nnx.Rngs = None,
   ):
     with self.conditional_named_scope("transformer_block"):
-      with self.conditional_named_scope("adaln"):
-        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
-            (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
-        )
+      shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
+          (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
+      )
       hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
       hidden_states = checkpoint_name(hidden_states, "hidden_states")
       encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))

       # 1. Self-attention
-      with self.conditional_named_scope("self_attn"):
-        with self.conditional_named_scope("self_attn_norm"):
+      with jax.named_scope("attn1"):
         norm_hidden_states = (self.norm1(hidden_states.astype(jnp.float32)) * (1 + scale_msa) + shift_msa).astype(
             hidden_states.dtype
         )
-        with self.conditional_named_scope("self_attn_attn"):
         attn_output = self.attn1(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_hidden_states,
             rotary_emb=rotary_emb,
             deterministic=deterministic,
             rngs=rngs,
         )
-        with self.conditional_named_scope("self_attn_residual"):
         hidden_states = (hidden_states.astype(jnp.float32) + attn_output * gate_msa).astype(hidden_states.dtype)

       # 2. Cross-attention
-      norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
-      attn_output = self.attn2(
-          hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs
-      )
-      hidden_states = hidden_states + attn_output
+      with jax.named_scope('attn2'):
+        norm_hidden_states = self.norm2(hidden_states.astype(jnp.float32)).astype(hidden_states.dtype)
+        attn_output = self.attn2(
+            hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states, deterministic=deterministic, rngs=rngs
+        )
+        hidden_states = hidden_states + attn_output

       # 3. Feed-forward
-      with self.conditional_named_scope("mlp"):
-        with self.conditional_named_scope("mlp_norm"):
+      with jax.named_scope("ffn"):
         norm_hidden_states = (self.norm3(hidden_states.astype(jnp.float32)) * (1 + c_scale_msa) + c_shift_msa).astype(
             hidden_states.dtype
         )
-        with self.conditional_named_scope("mlp_ffn"):
-          ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
-        with self.conditional_named_scope("mlp_residual"):
-          hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
+        ff_output = self.ffn(norm_hidden_states, deterministic=deterministic, rngs=rngs)
+        hidden_states = (hidden_states.astype(jnp.float32) + ff_output.astype(jnp.float32) * c_gate_msa).astype(
             hidden_states.dtype
         )
       return hidden_states
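Nested scopes compose into a slash-separated path (e.g. `transformer_block/attn1/query_proj`), which is what path-matching tools and profiler views key on. A small sketch for eyeballing that path in the compiled HLO, assuming the op-metadata dump is an acceptable place to look (toy function, not the real block):

```python
import jax
import jax.numpy as jnp

def block(x):
  # Nested named scopes mirror the block/attn1 structure in the diff above.
  with jax.named_scope("transformer_block"):
    with jax.named_scope("attn1"):
      with jax.named_scope("query_proj"):
        x = x * 2.0
  return x

# Print the compiled HLO text; the op metadata should carry the composed
# scope path (something like ".../transformer_block/attn1/query_proj/mul").
print(jax.jit(block).lower(jnp.ones((4,))).compile().as_text())
```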
@@ -543,6 +537,7 @@ def conditional_named_scope(self, name: str):
     """Return a JAX named scope if enabled, otherwise a null context."""
     return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()

+  @jax.named_scope('WanModel')
   def __call__(
       self,
       hidden_states: jax.Array,
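`jax.named_scope` also works as a decorator, which is what the new `@jax.named_scope('WanModel')` line relies on: everything traced inside `__call__` nests under `WanModel`. A minimal sketch (toy function, not the real model):

```python
import jax
import jax.numpy as jnp

# As a decorator, jax.named_scope wraps the whole function body, so any
# scopes opened inside nest under "WanModel".
@jax.named_scope("WanModel")
def forward(x):
  with jax.named_scope("proj_out"):
    return x @ jnp.ones((x.shape[-1], 8))

y = forward(jnp.ones((2, 4)))
print(y.shape)  # (2, 8)
```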
@@ -609,9 +604,8 @@ def layer_forward(hidden_states):
     hidden_states = rematted_layer_forward(hidden_states)

     shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)
-    with self.conditional_named_scope("output_norm"):
-      hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
-    with self.conditional_named_scope("output_proj"):
+    hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
+    with jax.named_scope("proj_out"):
       hidden_states = self.proj_out(hidden_states)

     hidden_states = hidden_states.reshape(
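The `conditional_named_scope` helper visible in the context lines above still gates the remaining scopes behind `enable_jax_named_scopes`. A self-contained sketch of that helper (toy class, assuming a plain boolean toggle), which also shows why flag-gated scopes were not reliable anchors for the quantization config:

```python
import contextlib
import jax

class ScopedModule:
  """Illustrative only: the flag-gated scope helper kept for the scopes this
  commit did not convert to plain jax.named_scope."""

  def __init__(self, enable_jax_named_scopes: bool = False):
    self.enable_jax_named_scopes = enable_jax_named_scopes

  def conditional_named_scope(self, name: str):
    # Real named scope only when the flag is on; otherwise a no-op context.
    return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()

m = ScopedModule(enable_jax_named_scopes=False)
with m.conditional_named_scope("attn_q_norm"):
  pass  # no scope is recorded here, so path-based rules keyed on this name
        # would silently not match, which is what the commit fixes for the
        # quantization-facing scopes.
```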
