Skip to content

Commit 0df5659

Browse files
committed
fix ag from vmap/scan.
1 parent 34968e0 commit 0df5659

4 files changed

Lines changed: 33 additions & 30 deletions

File tree

src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,11 @@
2020
"""
2121

2222
import os
23-
import functools
2423
from absl import app
25-
from typing import Sequence, Union, List
26-
from datasets import load_dataset
24+
from typing import Sequence
2725
import csv
28-
import numpy as np
29-
import jax
3026
import jax.numpy as jnp
31-
from jax.sharding import Mesh
32-
from maxdiffusion import pyconfig, max_utils
33-
from maxdiffusion.pipelines.wan.wan_pipeline import WanPipeline
34-
from maxdiffusion.video_processor import VideoProcessor
27+
from maxdiffusion import pyconfig
3528

3629
import torch
3730
import tensorflow as tf
@@ -72,6 +65,7 @@ def create_example(latent, hidden_states):
7265
example = tf.train.Example(features=tf.train.Features(feature=feature))
7366
return example.SerializeToString()
7467

68+
7569
def generate_dataset(config):
7670

7771
tfrecords_dir = config.tfrecords_dir
@@ -88,7 +82,7 @@ def generate_dataset(config):
8882

8983
# Load dataset
9084
metadata_path = os.path.join(config.train_data_dir, "metadata.csv")
91-
with open(metadata_path, 'r', newline='') as file:
85+
with open(metadata_path, "r", newline="") as file:
9286
# Create a csv.reader object
9387
csv_reader = csv.reader(file)
9488
next(csv_reader)
@@ -99,11 +93,11 @@ def generate_dataset(config):
9993
# Iterate over each row in the CSV file
10094
for row in csv_reader:
10195
video_name = row[0]
102-
pth_path = os.path.join(config.train_data_dir,"train", f"{video_name}.tensors.pth")
103-
loaded_state_dict = torch.load(pth_path, map_location=torch.device('cpu'))
96+
pth_path = os.path.join(config.train_data_dir, "train", f"{video_name}.tensors.pth")
97+
loaded_state_dict = torch.load(pth_path, map_location=torch.device("cpu"))
10498
prompt_embeds = loaded_state_dict["prompt_emb"]["context"].squeeze()
10599
latent = loaded_state_dict["latents"]
106-
100+
107101
# Format we want: (Batch, Channels, Frames, Height, Width)
108102
# Save them as float32 because numpy cannot read bfloat16.
109103
latent = jnp.array(latent.float().numpy(), dtype=jnp.float32)
@@ -120,6 +114,7 @@ def generate_dataset(config):
120114
)
121115
shard_record_count = 0
122116

117+
123118
def run(config):
124119
generate_dataset(config)
125120

src/maxdiffusion/models/attention_flax.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -674,8 +674,10 @@ def __init__(
674674
dtype=dtype,
675675
quant=quant,
676676
)
677-
678-
kernel_axes = ("embed", "heads")
677+
# The leading None axis corresponds to the stacked weights across all blocks
678+
# because of the use of nnx.vmap and nnx.scan.
679+
# Dims are [num_blocks, embed, heads]
680+
kernel_axes = (None, "embed", "heads")
679681
qkv_init_kernel = nnx.with_partitioning(nnx.initializers.lecun_normal(), kernel_axes)
680682

681683
self.query = nnx.Linear(
@@ -686,7 +688,7 @@ def __init__(
686688
dtype=dtype,
687689
param_dtype=weights_dtype,
688690
precision=precision,
689-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
691+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed",)),
690692
)
691693

692694
self.key = nnx.Linear(
@@ -697,7 +699,7 @@ def __init__(
697699
dtype=dtype,
698700
param_dtype=weights_dtype,
699701
precision=precision,
700-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
702+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed",)),
701703
)
702704

703705
self.value = nnx.Linear(
@@ -708,14 +710,14 @@ def __init__(
708710
dtype=dtype,
709711
param_dtype=weights_dtype,
710712
precision=precision,
711-
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
713+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed",)),
712714
)
713715

714716
self.proj_attn = nnx.Linear(
715717
rngs=rngs,
716718
in_features=self.inner_dim,
717719
out_features=self.inner_dim,
718-
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed")),
720+
kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), (None, "heads", "embed")),
719721
dtype=dtype,
720722
param_dtype=weights_dtype,
721723
precision=precision,
@@ -729,15 +731,15 @@ def __init__(
729731
rngs=rngs,
730732
epsilon=eps,
731733
dtype=dtype,
732-
scale_init=nnx.with_partitioning(nnx.initializers.ones, ("norm",)),
734+
scale_init=nnx.with_partitioning(nnx.initializers.ones, (None, "norm",)),
733735
param_dtype=weights_dtype,
734736
)
735737

736738
self.norm_k = nnx.RMSNorm(
737739
num_features=self.inner_dim,
738740
rngs=rngs,
739741
dtype=dtype,
740-
scale_init=nnx.with_partitioning(nnx.initializers.ones, ("norm",)),
742+
scale_init=nnx.with_partitioning(nnx.initializers.ones, (None, "norm",)),
741743
param_dtype=weights_dtype,
742744
)
743745

src/maxdiffusion/models/wan/transformers/transformer_wan.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,7 @@ def __init__(
398398

399399
# 3. Transformer blocks
400400
@nnx.split_rngs(splits=num_layers)
401-
@nnx.vmap
401+
@nnx.vmap(in_axes=0, out_axes=0)
402402
def init_block(rngs):
403403
return WanTransformerBlock(
404404
rngs=rngs,
@@ -416,6 +416,7 @@ def init_block(rngs):
416416
precision=precision,
417417
attention=attention,
418418
)
419+
419420
self.blocks = init_block(rngs)
420421

421422
self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False)
@@ -471,10 +472,10 @@ def scan_fn(carry, block):
471472

472473
initial_carry = (hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
473474
final_carry = nnx.scan(
474-
scan_fn,
475-
length=self.num_layers,
476-
in_axes=(nnx.Carry, 0),
477-
out_axes=nnx.Carry,
475+
scan_fn,
476+
length=self.num_layers,
477+
in_axes=(nnx.Carry, 0),
478+
out_axes=nnx.Carry,
478479
)(initial_carry, self.blocks)
479480

480481
hidden_states = final_carry[0]

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,9 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di
8787
new_key = ("blocks",) + pt_tuple_key[2:]
8888
block_index = int(pt_tuple_key[1])
8989
pt_tuple_key = new_key
90-
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL)
90+
flax_key, flax_tensor = rename_key_and_reshape_tensor(
91+
pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL
92+
)
9193
flax_key = rename_for_nnx(flax_key)
9294
flax_key = _tuple_str_to_int(flax_key)
9395

@@ -133,11 +135,12 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di
133135
new_key = ("blocks",) + pt_tuple_key[2:]
134136
block_index = int(pt_tuple_key[1])
135137
pt_tuple_key = new_key
136-
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL)
138+
flax_key, flax_tensor = rename_key_and_reshape_tensor(
139+
pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL
140+
)
137141
flax_key = rename_for_nnx(flax_key)
138142
flax_key = _tuple_str_to_int(flax_key)
139143

140-
141144
if "blocks" in flax_key:
142145
if flax_key in flax_state_dict:
143146
new_tensor = flax_state_dict[flax_key]
@@ -224,7 +227,9 @@ def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: d
224227
new_key = ("blocks",) + pt_tuple_key[2:]
225228
block_index = int(pt_tuple_key[1])
226229
pt_tuple_key = new_key
227-
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL)
230+
flax_key, flax_tensor = rename_key_and_reshape_tensor(
231+
pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL
232+
)
228233
flax_key = rename_for_nnx(flax_key)
229234
flax_key = _tuple_str_to_int(flax_key)
230235

0 commit comments

Comments
 (0)