Skip to content

Commit 7ebd392

Browse files
committed
fix1 try
1 parent 4c954a5 commit 7ebd392

1 file changed

Lines changed: 75 additions & 33 deletions

File tree

src/maxdiffusion/models/wan/wan_utils.py

Lines changed: 75 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -177,26 +177,6 @@ def load_causvid_transformer(
177177
return flax_state_dict
178178

179179

180-
def load_wan_transformer(
181-
pretrained_model_name_or_path: str,
182-
eval_shapes: dict,
183-
device: str,
184-
hf_download: bool = True,
185-
num_layers: int = 40,
186-
scan_layers: bool = True,
187-
subfolder: str = "",
188-
):
189-
190-
if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
191-
return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers)
192-
elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH:
193-
return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers)
194-
else:
195-
return load_base_wan_transformer(
196-
pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers, subfolder
197-
)
198-
199-
200180
def load_base_wan_transformer(
201181
pretrained_model_name_or_path: str,
202182
eval_shapes: dict,
@@ -206,6 +186,7 @@ def load_base_wan_transformer(
206186
scan_layers: bool = True,
207187
subfolder: str = "",
208188
):
189+
print(f"\n=== DEBUG START: Loading Transformer from {pretrained_model_name_or_path} ===")
209190
device = jax.local_devices(backend=device)[0]
210191
filename = "diffusion_pytorch_model.safetensors.index.json"
211192
local_files = False
@@ -215,14 +196,13 @@ def load_base_wan_transformer(
215196
raise FileNotFoundError(f"File {index_file_path} not found for local directory.")
216197
local_files = True
217198
elif hf_download:
218-
# download the index file for sharded models.
219199
index_file_path = hf_hub_download(
220200
pretrained_model_name_or_path,
221201
subfolder=subfolder,
222202
filename=filename,
223203
)
204+
224205
with jax.default_device(device):
225-
# open the index file.
226206
with open(index_file_path, "r") as f:
227207
index_dict = json.load(f)
228208
model_files = set()
@@ -231,37 +211,68 @@ def load_base_wan_transformer(
231211

232212
model_files = list(model_files)
233213
tensors = {}
214+
print(f"=== DEBUG: Loading {len(model_files)} shard files... ===")
215+
234216
for model_file in model_files:
235217
if local_files:
236218
ckpt_shard_path = os.path.join(pretrained_model_name_or_path, subfolder, model_file)
237219
else:
238220
ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file)
239-
# now get all the filenames for the model that need downloading
221+
240222
max_logging.log(f"Load and port {pretrained_model_name_or_path} {subfolder} on {device}")
241223

242224
if ckpt_shard_path is not None:
243225
with safe_open(ckpt_shard_path, framework="pt") as f:
244226
for k in f.keys():
245227
tensors[k] = torch2jax(f.get_tensor(k))
228+
229+
print(f"=== DEBUG: Total tensors loaded: {len(tensors)} ===")
230+
246231
flax_state_dict = {}
247232
cpu = jax.local_devices(backend="cpu")[0]
248233
flattened_dict = flatten_dict(eval_shapes)
249-
# turn all block numbers to strings just for matching weights.
250-
# Later they will be turned back to ints.
251234
random_flax_state_dict = {}
252235
for key in flattened_dict:
253236
string_tuple = tuple([str(item) for item in key])
254237
random_flax_state_dict[string_tuple] = flattened_dict[key]
255238
del flattened_dict
239+
240+
# --- 1. Initialize Buffer ---
256241
norm_added_q_buffer = {}
242+
norm_added_q_debug_count = 0
243+
244+
print("=== DEBUG: Starting Tensor Loop ===")
257245
for pt_key, tensor in tensors.items():
258246
renamed_pt_key = rename_key(pt_key)
247+
248+
# --- 2. Buffer 'norm_added_q' Logic ---
259249
if "norm_added_q" in pt_key:
260-
parts = pt_key.split(".")
261-
block_idx = int(parts[1])
262-
tensor = tensor.T
263-
norm_added_q_buffer[block_idx] = tensor
264-
continue
250+
norm_added_q_debug_count += 1
251+
# Print the first 3 to confirm we are hitting this block
252+
if norm_added_q_debug_count <= 3:
253+
print(f"DEBUG: Catching norm_added_q key: {pt_key}")
254+
255+
try:
256+
parts = pt_key.split(".")
257+
# Robust parsing: Find the part that is a digit
258+
block_idx = -1
259+
for part in parts:
260+
if part.isdigit():
261+
block_idx = int(part)
262+
break
263+
264+
if block_idx == -1:
265+
print(f"DEBUG ERROR: Could not find block index in {pt_key}")
266+
continue
267+
268+
tensor = tensor.T
269+
norm_added_q_buffer[block_idx] = tensor
270+
except Exception as e:
271+
print(f"DEBUG EXCEPTION parsing {pt_key}: {e}")
272+
273+
continue # SKIP rest of loop
274+
275+
# --- 3. Image Embedder Logic ---
265276
if "image_embedder" in renamed_pt_key:
266277
if "net.0" in renamed_pt_key or "net_0" in renamed_pt_key or \
267278
"net.2" in renamed_pt_key or "net_2" in renamed_pt_key:
@@ -281,22 +292,53 @@ def load_base_wan_transformer(
281292
if "norm1" in renamed_pt_key or "norm2" in renamed_pt_key:
282293
renamed_pt_key = renamed_pt_key.replace("weight", "scale")
283294
renamed_pt_key = renamed_pt_key.replace("kernel", "scale")
295+
296+
# --- 4. Global Replacements ---
284297
renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.")
285298
renamed_pt_key = renamed_pt_key.replace(".scale_shift_table", ".adaln_scale_shift_table")
286299
renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn")
287300
renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out")
288301
renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn")
302+
289303
if "norm2.layer_norm" not in renamed_pt_key:
290-
renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
304+
renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
305+
291306
pt_tuple_key = tuple(renamed_pt_key.split("."))
292307
flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
293308
flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
309+
310+
# --- 5. Final Stack & Insert ---
311+
print(f"=== DEBUG: Loop Finished. norm_added_q count: {len(norm_added_q_buffer)} ===")
312+
294313
if norm_added_q_buffer:
295-
sorted_tensors = [norm_added_q_buffer[i] for i in sorted(norm_added_q_buffer.keys())]
314+
sorted_indices = sorted(norm_added_q_buffer.keys())
315+
print(f"DEBUG: Stacking indices from {sorted_indices[0]} to {sorted_indices[-1]}")
316+
317+
sorted_tensors = [norm_added_q_buffer[i] for i in sorted_indices]
296318
stacked_tensor = jnp.stack(sorted_tensors, axis=0)
319+
print(f"DEBUG: Final Stacked Shape: {stacked_tensor.shape}")
320+
321+
# INSERT AND VERIFY
297322
final_key = ('blocks', 'attn2', 'norm_added_q', 'kernel')
298323
flax_state_dict[final_key] = jax.device_put(stacked_tensor, device=cpu)
299-
324+
325+
if final_key in flax_state_dict:
326+
print(f"DEBUG: SUCCESS - Key {final_key} exists in dict.")
327+
else:
328+
print(f"DEBUG: CRITICAL FAILURE - Key insertion failed.")
329+
else:
330+
print("DEBUG: FAILURE - Buffer is EMPTY. No norm_added_q found!")
331+
332+
print("=== DEBUG: Starting Validation ===")
333+
# Print what the validator EXPECTS vs what we HAVE
334+
expected_keys = set(flatten_dict(eval_shapes).keys())
335+
actual_keys = set(flax_state_dict.keys())
336+
337+
missing = expected_keys - actual_keys
338+
print(f"DEBUG: Missing Keys Count: {len(missing)}")
339+
if len(missing) > 0 and len(missing) < 10:
340+
print(f"DEBUG: Missing Keys: {missing}")
341+
300342
validate_flax_state_dict(eval_shapes, flax_state_dict)
301343
flax_state_dict = unflatten_dict(flax_state_dict)
302344
del tensors

0 commit comments

Comments
 (0)