 from huggingface_hub import hf_hub_download
 from flax import nnx
 from flax import traverse_util
+
 def convert_ltx2_vae(hf_repo, output_path):
     # Load weights directly from Safetensors
     print(f"Downloading/Loading weights from {hf_repo}...")
@@ -48,40 +49,77 @@ def convert_ltx2_vae(hf_repo, output_path): |
         dtype=jnp.float32,
         rngs=nnx.Rngs(0)
     )
 
     # Load the PyTorch state dict from the safetensors checkpoint
     pt_state_dict = load_file(ckpt_path)
 
-    # Define mapping
-    # We will need to map PT keys to Flax keys
-    # Helper to print PT keys
-    print("PyTorch Keys:")
-    sorted_pt_keys = sorted(pt_state_dict.keys())
-    for k in sorted_pt_keys:
-        v = pt_state_dict[k]
-        print(f"{k}: {v.shape}")
+    print("\nMapping weights...")
+    graphdef, state = nnx.split(model)
+    params = state.filter(nnx.Param)
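+    # Note: nnx.split separates the static graph definition from the live
+    # state; filtering on nnx.Param keeps only trainable parameters and drops
+    # other variable kinds (RNG keys, counters), if the model carries any.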
 
-    print("\nMaxDiffusion Keys (initialization):")
-    # Get MaxDiffusion keys from initialized model
-    # We need to run a dummy forward or init to get parameters if they are lazy,
-    # but nnx.Module usually has them after init if shape is provided?
-    # Wait, nnx modules need to be split to see params.
-    graphdef, state = nnx.split(model); params = state.filter(nnx.Param)
-    flat_params = traverse_util.flatten_dict(params.to_pure_dict())
-    sorted_flat_keys = sorted(flat_params.keys())
-    for k in sorted_flat_keys:
-        v = flat_params[k]
-        print(f"{k}: {v.shape}")
-
-    params = {}
+    params_dict = params.to_pure_dict()
+    flat_params = traverse_util.flatten_dict(params_dict)
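+    # flatten_dict yields tuple keys such as ('encoder', 'conv_in', 'kernel')
+    # (an illustrative path; actual names depend on the VAE module tree)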
 
-    # TODO: Implement the mapping logic here
-    # This acts as a template for now
+    new_params = {}
+
+    for key_tuple, value in flat_params.items():
+        # Defensive: skip any RNG state that slipped past the Param filter
+        if "rngs" in key_tuple or "count" in key_tuple or "key" in key_tuple:
+            continue
+
+        # Construct the PyTorch key: stringify integer list indices in the path
+        pt_key_parts = [str(p) for p in key_tuple]
+
+        # Rename Flax leaf names to PyTorch conventions: both "kernel"
+        # (conv/dense) and "scale" (norm) correspond to "weight"
+        if pt_key_parts[-1] in ("kernel", "scale"):
+            pt_key_parts[-1] = "weight"
+
+        pt_key = ".".join(pt_key_parts)
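+        # e.g. ('decoder', 'conv_out', 'kernel') -> "decoder.conv_out.weight"
+        # (hypothetical path, shown only to illustrate the tuple-to-key mapping)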
+
+        # Skip parameters that have no counterpart in the PyTorch checkpoint
+        if pt_key not in pt_state_dict:
+            print(f"Warning: {pt_key} not found in PyTorch state dict; skipping.")
+            continue
+
+        pt_tensor = pt_state_dict[pt_key]
+        # Cast to float32 before .numpy(): NumPy cannot represent bfloat16,
+        # and the Flax model is initialized in float32 anyway
+        np_array = pt_tensor.float().numpy()
+
+        # Conv3d kernels need their axes permuted for Flax
+        is_conv_weight = "conv" in pt_key and "weight" in pt_key and len(np_array.shape) == 5
+
+        if is_conv_weight:
+            # PyTorch Conv3d: (Out, In, T, H, W)
+            # JAX/Flax Conv:  (T, H, W, In, Out)
+            # Permutation: (0, 1, 2, 3, 4) -> (2, 3, 4, 1, 0)
+            np_array = np_array.transpose(2, 3, 4, 1, 0)
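+            # e.g. an illustrative (128, 3, 3, 3, 3) kernel becomes
+            # (3, 3, 3, 3, 128)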
+
+        # Verify shape, tolerating a leading singleton axis on either side
+        # (e.g. a (1, C) tensor against a (C,) param)
+        if np_array.shape != value.shape:
+            if np_array.shape == (1,) + value.shape:
+                np_array = np_array.squeeze(0)
+            elif value.shape == (1,) + np_array.shape:
+                np_array = np_array[None]
+            else:
+                print(f"Shape mismatch for {pt_key}: PT {np_array.shape} vs Max {value.shape}")
+                continue
+
+        new_params[key_tuple] = jnp.array(np_array)
+
+    print(f"Mapped {len(new_params)} out of {len(flat_params)} parameters.")
+
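+    # A further sanity check one could add here (sketch): record every mapped
+    # pt_key in a set inside the loop, then print
+    # sorted(set(pt_state_dict) - used_pt_keys) to spot checkpoint tensors
+    # that were never consumed (used_pt_keys is a hypothetical name).
+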
+    # Reconstruct the nested parameter dictionary
+    params_nested = traverse_util.unflatten_dict(new_params)
 
+    # Save checkpoint
     print(f"Saving converted weights to {output_path}...")
-    checkpointer = orbax.checkpoint.PyTreeCheckpointer()
-    save_args = orbax_utils.save_args_from_target(params)
-    checkpointer.save(output_path, params, save_args=save_args)
+    checkpointer = orbax.checkpoint.Checkpointer(orbax.checkpoint.PyTreeCheckpointHandler())
+    save_args = orbax_utils.save_args_from_target(params_nested)
+    checkpointer.save(output_path, params_nested, save_args=save_args)
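+    # Sketch of loading the converted weights back (assumes the same legacy
+    # Orbax PyTree API used above):
+    #   restored = orbax.checkpoint.Checkpointer(
+    #       orbax.checkpoint.PyTreeCheckpointHandler()).restore(output_path)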
     print("Done!")
 
 if __name__ == "__main__":