
Commit f05c97b

reverted
1 parent 7ebd392 commit f05c97b

1 file changed

Lines changed: 34 additions & 76 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

@@ -177,6 +177,26 @@ def load_causvid_transformer(
   return flax_state_dict
 
 
+def load_wan_transformer(
+    pretrained_model_name_or_path: str,
+    eval_shapes: dict,
+    device: str,
+    hf_download: bool = True,
+    num_layers: int = 40,
+    scan_layers: bool = True,
+    subfolder: str = "",
+):
+
+  if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
+    return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers)
+  elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH:
+    return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers)
+  else:
+    return load_base_wan_transformer(
+        pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers, subfolder
+    )
+
+
 def load_base_wan_transformer(
     pretrained_model_name_or_path: str,
     eval_shapes: dict,
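The new `load_wan_transformer` entry point dispatches on the checkpoint name: the CausVid and FusionX constants route to their dedicated loaders, and any other name falls through to `load_base_wan_transformer`. A minimal call-site sketch; the checkpoint id, subfolder, and the `eval_shapes` placeholder below are illustrative, not part of this commit:

from maxdiffusion.models.wan.wan_utils import load_wan_transformer

# eval_shapes: pytree of Flax parameter shapes, obtained from model
# initialization (e.g. via jax.eval_shape); hypothetical placeholder name here.
flax_state_dict = load_wan_transformer(
    "Wan-AI/Wan2.1-T2V-14B-Diffusers",  # hypothetical checkpoint id
    eval_shapes=transformer_eval_shapes,
    device="cpu",
    num_layers=40,
    scan_layers=True,
    subfolder="transformer",
)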
@@ -186,7 +206,6 @@ def load_base_wan_transformer(
     scan_layers: bool = True,
     subfolder: str = "",
 ):
-  print(f"\n=== DEBUG START: Loading Transformer from {pretrained_model_name_or_path} ===")
   device = jax.local_devices(backend=device)[0]
   filename = "diffusion_pytorch_model.safetensors.index.json"
   local_files = False
@@ -196,13 +215,14 @@ def load_base_wan_transformer(
       raise FileNotFoundError(f"File {index_file_path} not found for local directory.")
     local_files = True
   elif hf_download:
+    # download the index file for sharded models.
     index_file_path = hf_hub_download(
         pretrained_model_name_or_path,
         subfolder=subfolder,
         filename=filename,
     )
-
   with jax.default_device(device):
+    # open the index file.
     with open(index_file_path, "r") as f:
       index_dict = json.load(f)
       model_files = set()
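For sharded checkpoints, the index file is a small JSON document that maps every tensor name to the shard file storing it; the loader only needs the unique set of shard names from that map. A self-contained sketch of the layout (the `metadata` and `weight_map` fields are the standard safetensors index format, but the shard names, weight keys, and sizes below are illustrative):

import json

index_json = """
{
  "metadata": {"total_size": 28651663360},
  "weight_map": {
    "blocks.0.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00006.safetensors",
    "blocks.39.ffn.net.2.weight": "diffusion_pytorch_model-00006-of-00006.safetensors"
  }
}
"""
index_dict = json.loads(index_json)

# Mirror of the loop in the diff: keep only the unique shard files.
model_files = set()
for k, v in index_dict["weight_map"].items():
  model_files.add(v)
print(sorted(model_files))  # two distinct shards in this toy example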
@@ -211,68 +231,37 @@ def load_base_wan_transformer(
 
       model_files = list(model_files)
   tensors = {}
-  print(f"=== DEBUG: Loading {len(model_files)} shard files... ===")
-
   for model_file in model_files:
     if local_files:
       ckpt_shard_path = os.path.join(pretrained_model_name_or_path, subfolder, model_file)
     else:
       ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file)
-
+    # now get all the filenames for the model that need downloading
     max_logging.log(f"Load and port {pretrained_model_name_or_path} {subfolder} on {device}")
 
     if ckpt_shard_path is not None:
       with safe_open(ckpt_shard_path, framework="pt") as f:
         for k in f.keys():
           tensors[k] = torch2jax(f.get_tensor(k))
-
-  print(f"=== DEBUG: Total tensors loaded: {len(tensors)} ===")
-
   flax_state_dict = {}
   cpu = jax.local_devices(backend="cpu")[0]
   flattened_dict = flatten_dict(eval_shapes)
+  # turn all block numbers to strings just for matching weights.
+  # Later they will be turned back to ints.
   random_flax_state_dict = {}
   for key in flattened_dict:
     string_tuple = tuple([str(item) for item in key])
     random_flax_state_dict[string_tuple] = flattened_dict[key]
   del flattened_dict
-
-  # --- 1. Initialize Buffer ---
   norm_added_q_buffer = {}
-  norm_added_q_debug_count = 0
-
-  print("=== DEBUG: Starting Tensor Loop ===")
   for pt_key, tensor in tensors.items():
     renamed_pt_key = rename_key(pt_key)
-
-    # --- 2. Buffer 'norm_added_q' Logic ---
     if "norm_added_q" in pt_key:
-      norm_added_q_debug_count += 1
-      # Print the first 3 to confirm we are hitting this block
-      if norm_added_q_debug_count <= 3:
-        print(f"DEBUG: Catching norm_added_q key: {pt_key}")
-
-      try:
-        parts = pt_key.split(".")
-        # Robust parsing: Find the part that is a digit
-        block_idx = -1
-        for part in parts:
-          if part.isdigit():
-            block_idx = int(part)
-            break
-
-        if block_idx == -1:
-          print(f"DEBUG ERROR: Could not find block index in {pt_key}")
-          continue
-
-        tensor = tensor.T
-        norm_added_q_buffer[block_idx] = tensor
-      except Exception as e:
-        print(f"DEBUG EXCEPTION parsing {pt_key}: {e}")
-
-      continue # SKIP rest of loop
-
-    # --- 3. Image Embedder Logic ---
+      parts = pt_key.split(".")
+      block_idx = int(parts[1])
+      tensor = tensor.T
+      norm_added_q_buffer[block_idx] = tensor
+      continue
     if "image_embedder" in renamed_pt_key:
       if "net.0" in renamed_pt_key or "net_0" in renamed_pt_key or \
          "net.2" in renamed_pt_key or "net_2" in renamed_pt_key:
@@ -292,53 +281,22 @@ def load_base_wan_transformer(
     if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
       renamed_pt_key = renamed_pt_key.replace("weight", "scale")
       renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
-
-    # --- 4. Global Replacements ---
     renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
     renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
     renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")
     renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out")
     renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn")
-
     if "norm2.layer_norm" not in renamed_pt_key:
-      renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
-
+      renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
     pt_tuple_key = tuple(renamed_pt_key.split("."))
     flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
     flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
-
-  # --- 5. Final Stack & Insert ---
-  print(f"=== DEBUG: Loop Finished. norm_added_q count: {len(norm_added_q_buffer)} ===")
-
   if norm_added_q_buffer:
-    sorted_indices = sorted(norm_added_q_buffer.keys())
-    print(f"DEBUG: Stacking indices from {sorted_indices[0]} to {sorted_indices[-1]}")
-
-    sorted_tensors = [norm_added_q_buffer[i] for i in sorted_indices]
+    sorted_tensors = [norm_added_q_buffer[i] for i in sorted(norm_added_q_buffer.keys())]
     stacked_tensor = jnp.stack(sorted_tensors, axis=0)
-    print(f"DEBUG: Final Stacked Shape: {stacked_tensor.shape}")
-
-    # INSERT AND VERIFY
     final_key = ('blocks', 'attn2', 'norm_added_q', 'kernel')
     flax_state_dict[final_key] = jax.device_put(stacked_tensor, device=cpu)
-
-    if final_key in flax_state_dict:
-      print(f"DEBUG: SUCCESS - Key {final_key} exists in dict.")
-    else:
-      print(f"DEBUG: CRITICAL FAILURE - Key insertion failed.")
-  else:
-    print("DEBUG: FAILURE - Buffer is EMPTY. No norm_added_q found!")
-
-  print("=== DEBUG: Starting Validation ===")
-  # Print what the validator EXPECTS vs what we HAVE
-  expected_keys = set(flatten_dict(eval_shapes).keys())
-  actual_keys = set(flax_state_dict.keys())
-
-  missing = expected_keys - actual_keys
-  print(f"DEBUG: Missing Keys Count: {len(missing)}")
-  if len(missing) > 0 and len(missing) < 10:
-    print(f"DEBUG: Missing Keys: {missing}")
-
+
   validate_flax_state_dict(eval_shapes, flax_state_dict)
   flax_state_dict = unflatten_dict(flax_state_dict)
   del tensors
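With `scan_layers=True`, the per-block `norm_added_q` tensors buffered during the loop are stacked once, in block order, into a single array with a leading layer axis and inserted under `('blocks', 'attn2', 'norm_added_q', 'kernel')`. A toy sketch of that buffer-then-stack step; the layer count and tensor shape are illustrative:

import jax.numpy as jnp

num_layers, dim = 4, 8  # toy sizes; the real model defaults to num_layers=40
norm_added_q_buffer = {i: jnp.full((dim,), float(i)) for i in range(num_layers)}

# Same stacking as the diff: sort by block index, then stack on axis 0.
sorted_tensors = [norm_added_q_buffer[i] for i in sorted(norm_added_q_buffer.keys())]
stacked_tensor = jnp.stack(sorted_tensors, axis=0)
assert stacked_tensor.shape == (num_layers, dim)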
@@ -404,4 +362,4 @@ def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device:
   else:
     raise FileNotFoundError(f"Path {ckpt_path} was not found")
 
-  return flax_state_dict
+  return flax_state_dict
