Skip to content

Commit 34968e0

Browse files
committed
Support Wan transformers for nnx.scan (stack per-block weights along a leading block axis when loading checkpoints).
1 parent 3d2edcc commit 34968e0

3 files changed

Lines changed: 24 additions & 3 deletions

File tree

src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def generate_dataset(config):
101101
video_name = row[0]
102102
pth_path = os.path.join(config.train_data_dir,"train", f"{video_name}.tensors.pth")
103103
loaded_state_dict = torch.load(pth_path, map_location=torch.device('cpu'))
104-
prompt_embeds = loaded_state_dict["prompt_emb"]["context"]
104+
prompt_embeds = loaded_state_dict["prompt_emb"]["context"].squeeze()
105105
latent = loaded_state_dict["latents"]
106106

107107
# Format we want(Batch, channels, Frames, Height, Width)

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,20 @@ def load_fusionx_transformer(pretrained_model_name_or_path: str, eval_shapes: di
8383

8484
pt_tuple_key = tuple(renamed_pt_key.split("."))
8585

86+
if "blocks" in pt_tuple_key:
87+
new_key = ("blocks",) + pt_tuple_key[2:]
88+
block_index = int(pt_tuple_key[1])
89+
pt_tuple_key = new_key
8690
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL)
8791
flax_key = rename_for_nnx(flax_key)
8892
flax_key = _tuple_str_to_int(flax_key)
93+
94+
if "blocks" in flax_key:
95+
if flax_key in flax_state_dict:
96+
new_tensor = flax_state_dict[flax_key]
97+
else:
98+
new_tensor = jnp.zeros((40,) + flax_tensor.shape)
99+
flax_tensor = new_tensor.at[block_index].set(flax_tensor)
89100
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
90101
validate_flax_state_dict(eval_shapes, flax_state_dict)
91102
flax_state_dict = unflatten_dict(flax_state_dict)
@@ -118,9 +129,21 @@ def load_causvid_transformer(pretrained_model_name_or_path: str, eval_shapes: di
118129

119130
pt_tuple_key = tuple(renamed_pt_key.split("."))
120131

132+
if "blocks" in pt_tuple_key:
133+
new_key = ("blocks",) + pt_tuple_key[2:]
134+
block_index = int(pt_tuple_key[1])
135+
pt_tuple_key = new_key
121136
flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL)
122137
flax_key = rename_for_nnx(flax_key)
123138
flax_key = _tuple_str_to_int(flax_key)
139+
140+
141+
if "blocks" in flax_key:
142+
if flax_key in flax_state_dict:
143+
new_tensor = flax_state_dict[flax_key]
144+
else:
145+
new_tensor = jnp.zeros((40,) + flax_tensor.shape)
146+
flax_tensor = new_tensor.at[block_index].set(flax_tensor)
124147
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
125148
validate_flax_state_dict(eval_shapes, flax_state_dict)
126149
flax_state_dict = unflatten_dict(flax_state_dict)

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -229,8 +229,6 @@ def step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng, confi
229229
def loss_fn(model):
230230
latents = data["latents"].astype(config.weights_dtype)
231231
encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
232-
# TODO - fix tf record conversion.
233-
encoder_hidden_states = jax.numpy.squeeze(encoder_hidden_states, axis=1)
234232
bsz = latents.shape[0]
235233
timesteps = jax.random.randint(
236234
timestep_rng,

0 commit comments

Comments
 (0)