|
26 | 26 | from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict) |
27 | 27 |
|
28 | 28 |
|
| 29 | +LTX_2_0_VIDEO_VAE_RENAME_DICT = { |
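| 30 | + # Encoder (checkpoint key substring -> renamed substring; entries are applied in order via str.replace, so ordering matters)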
| 31 | + "down_blocks.0": "down_blocks.0", |
| 32 | + "down_blocks.1": "down_blocks.0.downsamplers.0", |
| 33 | + "down_blocks.2": "down_blocks.1", |
| 34 | + "down_blocks.3": "down_blocks.1.downsamplers.0", |
| 35 | + "down_blocks.4": "down_blocks.2", |
| 36 | + "down_blocks.5": "down_blocks.2.downsamplers.0", |
| 37 | + "down_blocks.6": "down_blocks.3", |
| 38 | + "down_blocks.7": "down_blocks.3.downsamplers.0", |
| 39 | + "down_blocks.8": "mid_block", |
| 40 | + # Decoder |
| 41 | + "up_blocks.0": "mid_block", |
| 42 | + "up_blocks.1": "up_blocks.0.upsamplers.0", |
| 43 | + "up_blocks.2": "up_blocks.0", |
| 44 | + "up_blocks.3": "up_blocks.1.upsamplers.0", |
| 45 | + "up_blocks.4": "up_blocks.1", |
| 46 | + "up_blocks.5": "up_blocks.2.upsamplers.0", |
| 47 | + "up_blocks.6": "up_blocks.2", |
| 48 | + "last_time_embedder": "time_embedder", |
| 49 | + "last_scale_shift_table": "scale_shift_table", |
| 50 | + # Common |
| 51 | + # For all 3D ResNets |
| 52 | + "res_blocks": "resnets", |
| 53 | + "per_channel_statistics.mean-of-means": "latents_mean", |
| 54 | + "per_channel_statistics.std-of-means": "latents_std", |
| 55 | +} |
| 56 | + |
| 57 | +LTX_2_3_VIDEO_VAE_RENAME_DICT = { |
| 58 | + **LTX_2_0_VIDEO_VAE_RENAME_DICT, |
| 59 | + # Decoder extra blocks |
| 60 | + "up_blocks.7": "up_blocks.3.upsamplers.0", |
| 61 | + "up_blocks.8": "up_blocks.3", |
| 62 | +} |
| 63 | + |
| 64 | + |
29 | 65 | def _tuple_str_to_int(in_tuple): |
30 | 66 | out_list = [] |
31 | 67 | for item in in_tuple: |
@@ -225,7 +261,12 @@ def load_vae_weights( |
225 | 261 |
|
226 | 262 | for pt_key, tensor in tensors.items(): |
227 | 263 | # latents_mean and latents_std are nnx.Params and will be loaded correctly. |
228 | | - renamed_pt_key = rename_key(pt_key) |
| 264 | + new_key = pt_key |
| 265 | + if filename == "ltx-2.3-22b-dev.safetensors": |
| 266 | + for replace_key, rename_to in LTX_2_3_VIDEO_VAE_RENAME_DICT.items(): |
| 267 | + new_key = new_key.replace(replace_key, rename_to) |
| 268 | + |
| 269 | + renamed_pt_key = rename_key(new_key) |
229 | 270 | renamed_pt_key = renamed_pt_key.replace("nin_shortcut", "conv_shortcut") |
230 | 271 |
|
231 | 272 | pt_tuple_key = tuple(renamed_pt_key.split(".")) |
|
0 commit comments