Skip to content

Commit 8a18686

Browse files
authored
Merge branch 'main' into elisatsai_disable_unsafe_rng
2 parents f68c7b0 + 7d25dc9 commit 8a18686

9 files changed

Lines changed: 56 additions & 41 deletions

File tree

.github/workflows/UploadDockerImages.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ on:
2121
schedule:
2222
# Run the job daily at 12AM UTC
2323
- cron: '0 0 * * *'
24+
25+
workflow_dispatch:
2426

2527
jobs:
2628
build-image:

maxdiffusion_jax_ai_image_tpu.Dockerfile

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ ARG JAX_AI_IMAGE_BASEIMAGE
33
# JAX AI Base Image
44
FROM $JAX_AI_IMAGE_BASEIMAGE
55

6+
ARG JAX_AI_IMAGE_BASEIMAGE
7+
68
ARG COMMIT_HASH
79

810
ENV COMMIT_HASH=$COMMIT_HASH
@@ -18,5 +20,12 @@ COPY . .
1820
# Install Maxdiffusion Jax AI Image requirements
1921
RUN pip install -r /deps/requirements_with_jax_ai_image.txt
2022

23+
# TODO: Remove the flax pin and fsspec overrides once flax stable version releases
24+
RUN if echo "$JAX_AI_IMAGE_BASEIMAGE" | grep -q "nightly"; then \
25+
echo "Nightly build detected: Installing specific Flax commit and fsspec." && \
26+
pip install --upgrade --force-reinstall git+https://github.com/google/flax.git@ef78d6584623511746be4824965cdef42b464583 && \
27+
pip install "fsspec==2025.10.0"; \
28+
fi
29+
2130
# Run the script available in JAX-AI-Image base image to generate the manifest file
2231
RUN bash /jax-ai-image/generate_manifest.sh PREFIX=maxdiffusion COMMIT_HASH=$COMMIT_HASH

src/maxdiffusion/configs/base_wan_14b.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,10 @@ quantization: ''
341341
quantization_local_shard_count: -1
342342
compile_topology_num_slices: -1 # Number of target slices, set to a positive integer.
343343
use_qwix_quantization: False # Whether to use qwix for quantization. If set to True, the transformer of WAN will be quantized using qwix.
344-
# Quantization calibration method used for weights and activations. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
345-
quantization_calibration_method: "absmax"
344+
# Quantization calibration method used for weights, activations and bwd. Supported methods can be found in https://github.com/google/qwix/blob/dc2a0770351c740e5ab3cce7c0efe9f7beacce9e/qwix/qconfig.py#L70-L80
345+
weight_quantization_calibration_method: "absmax"
346+
act_quantization_calibration_method: "absmax"
347+
bwd_quantization_calibration_method: "absmax"
346348
qwix_module_path: ".*"
347349

348350
# Eval model on per eval_every steps. -1 means don't eval.

src/maxdiffusion/models/attention_flax.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1005,13 +1005,12 @@ def __call__(
10051005
if encoder_hidden_states is None:
10061006
encoder_hidden_states = hidden_states
10071007

1008-
with self.conditional_named_scope("attn_qkv_proj"):
1009-
with self.conditional_named_scope("proj_query"):
1010-
query_proj = self.query(hidden_states)
1011-
with self.conditional_named_scope("proj_key"):
1012-
key_proj = self.key(encoder_hidden_states)
1013-
with self.conditional_named_scope("proj_value"):
1014-
value_proj = self.value(encoder_hidden_states)
1008+
with jax.named_scope("query_proj"):
1009+
query_proj = self.query(hidden_states)
1010+
with jax.named_scope("key_proj"):
1011+
key_proj = self.key(encoder_hidden_states)
1012+
with jax.named_scope("value_proj"):
1013+
value_proj = self.value(encoder_hidden_states)
10151014

10161015
if self.qk_norm:
10171016
with self.conditional_named_scope("attn_q_norm"):
@@ -1031,13 +1030,13 @@ def __call__(
10311030
key_proj = checkpoint_name(key_proj, "key_proj")
10321031
value_proj = checkpoint_name(value_proj, "value_proj")
10331032

1034-
with self.conditional_named_scope("attn_compute"):
1033+
with jax.named_scope("apply_attention"):
10351034
attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj)
10361035

10371036
attn_output = attn_output.astype(dtype=dtype)
10381037
attn_output = checkpoint_name(attn_output, "attn_output")
10391038

1040-
with self.conditional_named_scope("attn_out_proj"):
1039+
with jax.named_scope("proj_attn"):
10411040
hidden_states = self.proj_attn(attn_output)
10421041
hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
10431042
return hidden_states

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -142,8 +142,8 @@ def __call__(
142142
):
143143
timestep = self.timesteps_proj(timestep)
144144
temb = self.time_embedder(timestep)
145-
146-
timestep_proj = self.time_proj(self.act_fn(temb))
145+
with jax.named_scope("time_proj"):
146+
timestep_proj = self.time_proj(self.act_fn(temb))
147147

148148
encoder_hidden_states = self.text_embedder(encoder_hidden_states)
149149
if encoder_hidden_states_image is not None:
@@ -186,7 +186,8 @@ def __init__(
186186
)
187187

188188
def __call__(self, x: jax.Array) -> jax.Array:
189-
x = self.proj(x)
189+
with jax.named_scope("gelu"):
190+
x = self.proj(x)
190191
return nnx.gelu(x)
191192

192193

@@ -244,12 +245,11 @@ def conditional_named_scope(self, name: str):
244245
return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
245246

246247
def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array:
247-
with self.conditional_named_scope("mlp_up_proj_and_gelu"):
248248
hidden_states = self.act_fn(hidden_states) # Output is (4, 75600, 13824)
249249
hidden_states = checkpoint_name(hidden_states, "ffn_activation")
250250
hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs)
251-
with self.conditional_named_scope("mlp_down_proj"):
252-
return self.proj_out(hidden_states) # output is (4, 75600, 5120)
251+
with jax.named_scope("proj_out"):
252+
return self.proj_out(hidden_states) # output is (4, 75600, 5120)
253253

254254

255255
class WanTransformerBlock(nnx.Module):
@@ -359,10 +359,9 @@ def __call__(
359359
rngs: nnx.Rngs = None,
360360
):
361361
with self.conditional_named_scope("transformer_block"):
362-
with self.conditional_named_scope("adaln"):
363-
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
364-
(self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
365-
)
362+
shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split(
363+
(self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1
364+
)
366365
hidden_states = jax.lax.with_sharding_constraint(hidden_states, PartitionSpec("data", "fsdp", "tensor"))
367366
hidden_states = checkpoint_name(hidden_states, "hidden_states")
368367
encoder_hidden_states = jax.lax.with_sharding_constraint(encoder_hidden_states, PartitionSpec("data", "fsdp", None))
@@ -558,6 +557,7 @@ def conditional_named_scope(self, name: str):
558557
"""Return a JAX named scope if enabled, otherwise a null context."""
559558
return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext()
560559

560+
@jax.named_scope('WanModel')
561561
def __call__(
562562
self,
563563
hidden_states: jax.Array,
@@ -625,9 +625,8 @@ def layer_forward(hidden_states):
625625
hidden_states = rematted_layer_forward(hidden_states)
626626

627627
shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)
628-
with self.conditional_named_scope("output_norm"):
629-
hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
630-
with self.conditional_named_scope("output_proj"):
628+
hidden_states = (self.norm_out(hidden_states.astype(jnp.float32)) * (1 + scale) + shift).astype(hidden_states.dtype)
629+
with jax.named_scope("proj_out"):
631630
hidden_states = self.proj_out(hidden_states)
632631

633632
hidden_states = hidden_states.reshape(

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -303,9 +303,9 @@ def get_fp8_config(cls, config: HyperParameters):
303303
act_qtype=jnp.float8_e4m3fn,
304304
bwd_qtype=jnp.float8_e5m2,
305305
disable_channelwise_axes=True, # per_tensor calibration
306-
weight_calibration_method=config.quantization_calibration_method,
307-
act_calibration_method=config.quantization_calibration_method,
308-
bwd_calibration_method=config.quantization_calibration_method,
306+
weight_calibration_method=config.weight_quantization_calibration_method,
307+
act_calibration_method=config.act_quantization_calibration_method,
308+
bwd_calibration_method=config.bwd_quantization_calibration_method,
309309
op_names=("dot_general", "einsum"),
310310
),
311311
qwix.QtRule(
@@ -314,9 +314,9 @@ def get_fp8_config(cls, config: HyperParameters):
314314
act_qtype=jnp.float8_e4m3fn,
315315
bwd_qtype=jnp.float8_e4m3fn,
316316
disable_channelwise_axes=True, # per_tensor calibration
317-
weight_calibration_method=config.quantization_calibration_method,
318-
act_calibration_method=config.quantization_calibration_method,
319-
bwd_calibration_method=config.quantization_calibration_method,
317+
weight_calibration_method=config.weight_quantization_calibration_method,
318+
act_calibration_method=config.act_quantization_calibration_method,
319+
bwd_calibration_method=config.bwd_quantization_calibration_method,
320320
op_names=("conv_general_dilated"),
321321
),
322322
]

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t
6262
@classmethod
6363
def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True):
6464
pipeline , transformer = cls._load_and_init(config, None, vae_only, load_transformer)
65-
transformer = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh)
65+
pipeline.transformer = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh)
6666
return pipeline
6767

6868
@classmethod

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,8 +70,8 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t
7070
@classmethod
7171
def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True):
7272
pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init(config, None, vae_only, load_transformer)
73-
low_noise_transformer = cls.quantize_transformer(config, low_noise_transformer, pipeline, pipeline.mesh)
74-
high_noise_transformer = cls.quantize_transformer(config, high_noise_transformer, pipeline, pipeline.mesh)
73+
pipeline.low_noise_transformer = cls.quantize_transformer(config, low_noise_transformer, pipeline, pipeline.mesh)
74+
pipeline.high_noise_transformer = cls.quantize_transformer(config, high_noise_transformer, pipeline, pipeline.mesh)
7575
return pipeline
7676

7777
@classmethod

src/maxdiffusion/tests/wan_transformer_test.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -346,7 +346,9 @@ def create_real_rule_instance(*args, **kwargs):
346346
config_fp8_full = Mock(spec=HyperParameters)
347347
config_fp8_full.use_qwix_quantization = True
348348
config_fp8_full.quantization = "fp8_full"
349-
config_fp8_full.quantization_calibration_method = "absmax"
349+
config_fp8_full.weight_quantization_calibration_method = "fixed,-224,224"
350+
config_fp8_full.act_quantization_calibration_method = "fixed,-224,224"
351+
config_fp8_full.bwd_quantization_calibration_method = "absmax"
350352
config_fp8_full.qwix_module_path = ".*"
351353
provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
352354
self.assertIsNotNone(provider_fp8_full)
@@ -357,9 +359,9 @@ def create_real_rule_instance(*args, **kwargs):
357359
act_qtype=jnp.float8_e4m3fn,
358360
bwd_qtype=jnp.float8_e5m2,
359361
disable_channelwise_axes=True, # per_tensor calibration
360-
weight_calibration_method=config_fp8_full.quantization_calibration_method,
361-
act_calibration_method=config_fp8_full.quantization_calibration_method,
362-
bwd_calibration_method=config_fp8_full.quantization_calibration_method,
362+
weight_calibration_method=config_fp8_full.weight_quantization_calibration_method,
363+
act_calibration_method=config_fp8_full.act_quantization_calibration_method,
364+
bwd_calibration_method=config_fp8_full.bwd_quantization_calibration_method,
363365
op_names=("dot_general", "einsum"),
364366
),
365367
call(
@@ -368,9 +370,9 @@ def create_real_rule_instance(*args, **kwargs):
368370
act_qtype=jnp.float8_e4m3fn,
369371
bwd_qtype=jnp.float8_e4m3fn,
370372
disable_channelwise_axes=True, # per_tensor calibration
371-
weight_calibration_method=config_fp8_full.quantization_calibration_method,
372-
act_calibration_method=config_fp8_full.quantization_calibration_method,
373-
bwd_calibration_method=config_fp8_full.quantization_calibration_method,
373+
weight_calibration_method=config_fp8_full.weight_quantization_calibration_method,
374+
act_calibration_method=config_fp8_full.act_quantization_calibration_method,
375+
bwd_calibration_method=config_fp8_full.bwd_quantization_calibration_method,
374376
op_names=("conv_general_dilated"),
375377
),
376378
]
@@ -395,7 +397,9 @@ def test_quantize_transformer_enabled(self, mock_get_dummy_inputs, mock_quantize
395397
mock_config.quantization = "fp8_full"
396398
mock_config.qwix_module_path = ".*"
397399
mock_config.per_device_batch_size = 1
398-
mock_config.quantization_calibration_method = "absmax"
400+
mock_config.weight_quantization_calibration_method = "fixed,-224,224"
401+
mock_config.act_quantization_calibration_method = "fixed,-224,224"
402+
mock_config.bwd_quantization_calibration_method = "absmax"
399403

400404
mock_model = Mock(spec=WanModel)
401405
mock_pipeline = Mock()

0 commit comments

Comments (0)