
Commit a34529f ("fix")
1 parent 5757481

1 file changed: src/maxdiffusion/models/ltx2/ltx2_utils.py
Lines changed: 65 additions & 162 deletions
@@ -1,4 +1,3 @@
-
 import os
 import json
 import torch
@@ -28,79 +27,73 @@ def rename_for_ltx2_transformer(key):
   """
   Renames Diffusers LTX-2 keys to MaxDiffusion Flax LTX-2 keys.
   """
-  # General replacements
   key = key.replace("patchify_proj", "proj_in")
   key = key.replace("audio_patchify_proj", "audio_proj_in")
-  key = key.replace("transformer_blocks", "transformer_blocks") # kept same
-
-  # AdaLN / Timestep Embed handling
-  # Diffusers uses: time_embed, audio_time_embed, av_cross_attn_...
-  # These match Flax implementation names mostly.
-
-  # Attention QK Norms -> Flax uses "norm_q", "norm_k" (Diffusers often uses q_norm, k_norm but conversion script mapped them to norm_q/norm_k already?
-  # Wait, the conversion script maps *from* original *to* Diffusers.
-  # If loading Diffusers checkpoint, we should expect "norm_q", "norm_k" if that's what Diffusers uses.
-  # Checking conversion script: LTX_2_0_TRANSFORMER_KEYS_RENAME_DICT maps "q_norm" -> "norm_q".
-  # So Diffusers likely uses "norm_q".
-
-  # Handle "weight" -> "kernel" for Linear/Conv layers is done in rename_key_and_reshape_tensor
-  # checking rename_key_and_reshape_tensor: it handles "weight" -> "kernel" for linear/conv.
-
-  # Specific LTX-2 nested renaming
-  # Diffusers: transformer_blocks.0.attn1.to_q.weight
-  # Flax: transformer_blocks.layers.0.attn1.query.kernel (if scanned)
-
-  # rename_key_and_reshape_tensor handles:
-  # to_q -> query
-  # to_k -> key
-  # to_v -> value
-  # to_out.0 -> proj_attn
-
-  # We might need to handle specific mismatches if any.
-
-  # The "scale" vs "weight" for LayerNorm is also handled in rename_key_and_reshape_tensor
-  # BUT only if it detects "norm" in key.
-
-  # LTX2AdaLayerNormSingle usually has "linear" which is a Linear layer.
-
+  key = key.replace("transformer_blocks", "transformer_blocks")
   return key
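Note on the surviving renames: both are plain `str.replace` calls applied left to right, so their behavior can be checked in isolation. A minimal sketch (the sample keys below are illustrative, not taken from a real checkpoint):

```python
def rename_for_ltx2_transformer(key):
  # Same substitutions as the function in this diff.
  key = key.replace("patchify_proj", "proj_in")
  key = key.replace("audio_patchify_proj", "audio_proj_in")
  return key

# "audio_patchify_proj" contains "patchify_proj" as a substring, so the first
# replace already turns it into "audio_proj_in"; the dedicated audio rule is
# then a harmless no-op for such keys.
assert rename_for_ltx2_transformer("patchify_proj.weight") == "proj_in.weight"
assert rename_for_ltx2_transformer("audio_patchify_proj.bias") == "audio_proj_in.bias"
```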
 
 
 def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers, num_layers=48):
   if scan_layers:
     if "transformer_blocks" in pt_tuple_key:
-      # transformer_blocks.0.attn1... -> transformer_blocks.layers.attn1...
-      # We need to extract the block index
-      new_key = ("transformer_blocks",) + pt_tuple_key[2:] # removing index
+      new_key = ("transformer_blocks",) + pt_tuple_key[2:]
       block_index = int(pt_tuple_key[1])
       pt_tuple_key = new_key
 
-  # For scanned layers, we need to locate the param in the huge stacked tensor
-  # But wait, rename_key_and_reshape_tensor takes the *modified* pt_tuple_key?
-  # No, it takes the original one usually to check against random_flax_state_dict.
-  # But here we are constructing it.
-
   flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, scan_layers)
-
-  # Custom cleaning after generic rename
-  # e.g. converting "weight" to "value" for Params if needed, though they usually just take array.
-
   flax_key = _tuple_str_to_int(flax_key)
 
   if scan_layers:
     if "transformer_blocks" in flax_key:
-      # We need to stack correct index
       if flax_key in flax_state_dict:
         new_tensor = flax_state_dict[flax_key]
       else:
-        # Initialize with zeros of shape (num_layers, ...) + tensor.shape
         new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape, dtype=flax_tensor.dtype)
-
       new_tensor = new_tensor.at[block_index].set(flax_tensor)
       flax_tensor = new_tensor
 
   return flax_key, flax_tensor
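For the scanned path, each per-block tensor is written into one stacked `(num_layers, ...)` array that is allocated on first touch. A self-contained sketch of that pattern, assuming nothing beyond `jax.numpy` (the layer count and shapes here are made up for illustration; the real transformer defaults to 48 layers):

```python
import jax.numpy as jnp

num_layers = 4  # illustrative only
per_block = {i: jnp.full((2, 3), i, dtype=jnp.float32) for i in range(num_layers)}

flax_state_dict = {}
key = ("transformer_blocks", "attn1", "query", "kernel")
for block_index, t in per_block.items():
  if key not in flax_state_dict:
    # First visit allocates the full (num_layers, ...) buffer.
    flax_state_dict[key] = jnp.zeros((num_layers,) + t.shape, t.dtype)
  # Functional update: .at[...].set(...) returns a new array with one slice filled.
  flax_state_dict[key] = flax_state_dict[key].at[block_index].set(t)

assert flax_state_dict[key].shape == (num_layers, 2, 3)
assert bool((flax_state_dict[key][2] == 2.0).all())
```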
 
+def load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device):
+  """
+  Loads weights from a sharded safetensors checkpoint.
+  """
+  index_file = "diffusion_pytorch_model.safetensors.index.json"
+  tensors = {}
+
+  # Try to download index file
+  try:
+    index_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=index_file)
+    with open(index_path, "r") as f:
+      index_data = json.load(f)
+    weight_map = index_data["weight_map"]
+    shards = set(weight_map.values())
+
+    for shard_file in shards:
+      shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=shard_file)
+      with safe_open(shard_path, framework="pt") as f:
+        for k in f.keys():
+          tensors[k] = torch2jax(f.get_tensor(k))
+  except Exception:
+    # Fallback to single file
+    filename = "diffusion_pytorch_model.safetensors"
+    try:
+      ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)
+    except Exception:
+      filename = "diffusion_pytorch_model.bin"
+      ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)
+
+    if filename.endswith(".safetensors"):
+      with safe_open(ckpt_path, framework="pt") as f:
+        for k in f.keys():
+          tensors[k] = torch2jax(f.get_tensor(k))
+    else:
+      loaded_state_dict = torch.load(ckpt_path, map_location="cpu")
+      for k, v in loaded_state_dict.items():
+        tensors[k] = torch2jax(v)
+
+  return tensors
+
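The new helper keys off the standard Hugging Face sharding index, `diffusion_pytorch_model.safetensors.index.json`, whose `weight_map` maps each tensor name to the shard file that holds it. A small illustration of why the `set(...)` call matters (the tensor and shard names below are invented):

```python
import json

index_data = json.loads("""
{
  "metadata": {"total_size": 123456},
  "weight_map": {
    "proj_in.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
    "transformer_blocks.0.attn1.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors",
    "transformer_blocks.1.attn1.to_q.weight": "diffusion_pytorch_model-00002-of-00002.safetensors"
  }
}
""")

# Many tensors point at the same shard; set() deduplicates so each shard
# is downloaded and opened exactly once.
shards = set(index_data["weight_map"].values())
assert len(shards) == 2
```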
 def load_transformer_weights(
     pretrained_model_name_or_path: str,
     eval_shapes: dict,
@@ -111,57 +104,25 @@ def load_transformer_weights(
     subfolder: str = "transformer",
 ):
   device = jax.local_devices(backend=device)[0]
-
-  # Determine if local or hub
-  filename = "diffusion_pytorch_model.safetensors"
-  local_files = False
-  if os.path.isdir(pretrained_model_name_or_path):
-    ckpt_path = os.path.join(pretrained_model_name_or_path, subfolder, filename)
-    if not os.path.isfile(ckpt_path):
-      # Try .bin just in case
-      filename = "diffusion_pytorch_model.bin"
-      ckpt_path = os.path.join(pretrained_model_name_or_path, subfolder, filename)
-      if not os.path.isfile(ckpt_path):
-        raise FileNotFoundError(f"File {ckpt_path} not found for local directory.")
-    local_files = True
-  elif hf_download:
-    try:
-      ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)
-    except Exception:
-      filename = "diffusion_pytorch_model.bin"
-      ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)
-
   max_logging.log(f"Load and port {pretrained_model_name_or_path} {subfolder} on {device}")
 
   with jax.default_device(device):
-    tensors = {}
-    if filename.endswith(".safetensors"):
-      with safe_open(ckpt_path, framework="pt") as f:
-        for k in f.keys():
-          tensors[k] = torch2jax(f.get_tensor(k))
-    else: # bin/pt
-      loaded_state_dict = torch.load(ckpt_path, map_location="cpu")
-      for k, v in loaded_state_dict.items():
-        tensors[k] = torch2jax(v)
+    # Support sharded loading
+    tensors = load_sharded_checkpoint(pretrained_model_name_or_path, subfolder, device)
 
   flax_state_dict = {}
   cpu = jax.local_devices(backend="cpu")[0]
   flattened_dict = flatten_dict(eval_shapes)
 
-  # Create random state dict with string keys for matching
   random_flax_state_dict = {}
   for key in flattened_dict:
-    # Convert all ints to strings in key tuple
     string_tuple = tuple([str(item) for item in key])
     random_flax_state_dict[string_tuple] = flattened_dict[key]
 
   for pt_key, tensor in tensors.items():
     renamed_pt_key = rename_key(pt_key)
     renamed_pt_key = rename_for_ltx2_transformer(renamed_pt_key)
 
-    # Handling specific replacements that `rename_key` might miss or `rename_for_ltx2` specifically targets
-    # The `scan_layers` handling requires splitting the key differently if needed.
-
     pt_tuple_key = tuple(renamed_pt_key.split("."))
 
     flax_key, flax_tensor = get_key_and_value(
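The matching dict is keyed by all-string tuples because the PyTorch keys arrive as dot-separated strings, while `flatten_dict` preserves whatever key types the pytree uses. A minimal sketch of the normalization, assuming `flatten_dict` is `flax.traverse_util.flatten_dict` (which matches how it is used here); the nested shape dict is illustrative:

```python
import jax
import jax.numpy as jnp
from flax.traverse_util import flatten_dict

# eval_shapes mirrors the model pytree with shape/dtype leaves (illustrative):
eval_shapes = {
    "transformer_blocks": {
        "attn1": {"query": {"kernel": jax.ShapeDtypeStruct((48, 128, 128), jnp.float32)}}
    }
}

flattened_dict = flatten_dict(eval_shapes)
# Normalize every path component to str so that tuple(key.split(".")) from the
# PyTorch side can be compared against it directly.
random_flax_state_dict = {tuple(str(p) for p in k): v for k, v in flattened_dict.items()}

pt_tuple_key = tuple("transformer_blocks.attn1.query.kernel".split("."))
assert pt_tuple_key in random_flax_state_dict
```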
@@ -176,7 +137,6 @@ def load_transformer_weights(
   jax.clear_caches()
   return flax_state_dict
 
-
 def load_vae_weights(
     pretrained_model_name_or_path: str,
     eval_shapes: dict,
@@ -185,22 +145,16 @@ def load_vae_weights(
     subfolder: str = "vae"
 ):
   device = jax.local_devices(backend=device)[0]
-  filename = "diffusion_pytorch_model.safetensors"
+  # VAE for LTX-2 is likely single file, but safe to use the helper if we wanted general robustness.
+  # But `lightricks/LTX-2` VAE is single file.
 
-  if os.path.isdir(pretrained_model_name_or_path):
-    ckpt_path = os.path.join(pretrained_model_name_or_path, subfolder, filename)
-    if not os.path.isfile(ckpt_path):
-      filename = "diffusion_pytorch_model.bin"
-      ckpt_path = os.path.join(pretrained_model_name_or_path, subfolder, filename)
-      if not os.path.isfile(ckpt_path):
-        raise FileNotFoundError(f"File {ckpt_path} not found for local directory.")
-  elif hf_download:
-    try:
+  filename = "diffusion_pytorch_model.safetensors"
+  try:
     ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)
-    except Exception:
+  except Exception:
     filename = "diffusion_pytorch_model.bin"
     ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename)
-
+
   max_logging.log(f"Load and port {pretrained_model_name_or_path} VAE on {device}")
 
   with jax.default_device(device):
@@ -218,83 +172,32 @@ def load_vae_weights(
   cpu = jax.local_devices(backend="cpu")[0]
   flattened_eval = flatten_dict(eval_shapes)
 
-  # Build random state dict for shape checking/key matching help
-  # VAE usually doesn't need scan layers logic for mapping (unless we implement scanned VAE similar to Transformer, but autoencoder_kl_ltx2.py uses scan but keys seem compatible with standard diffusers structure if mapped correctly)
-  # Wait, `autoencoder_kl_ltx2.py` DOES use scan for `resnets`!
-  # See `create_resnets` and `resnet_scan_fn`.
-  # So we DO need scan layer handling for VAE if we want to load it into that structure.
-  # The VAE resnets are scanned over `num_layers`.
-
-  # Mapping Diffusers VAE to Scanned VAE:
-  # Diffusers: down_blocks.0.resnets.0 ...
-  # Flax Scanned: down_blocks.0.resnets.layers.0 ... (if mapped that way)
-  # OR: down_blocks.0.resnets -> (num_layers, ...) tensor if we stack them.
-
-  # Let's check `autoencoder_kl_ltx2.py` again.
-  # `self.resnets = create_resnets(rngs)` where `create_resnets` is vmapped.
-  # This creates params with a leading dimension = num_layers.
-  # So we need to stack Diffusers resnets weights.
-
-  # We need a custom `get_key_and_value` for VAE or modify the existing one to handle VAE blocks too.
-  pass
-
-  # For now, let's just write the loading logic and we might need to iterate and fix VAE scanning logic if it fails validation.
-  # Ideally we use `rename_key_and_reshape_tensor` heavily.
-
   random_flax_state_dict = {}
   for key in flattened_eval:
     string_tuple = tuple([str(item) for item in key])
     random_flax_state_dict[string_tuple] = flattened_eval[key]
-
+
   for pt_key, tensor in tensors.items():
     renamed_pt_key = rename_key(pt_key)
-
-    # VAE specific renames
-    renamed_pt_key = renamed_pt_key.replace("mid_block.resnets.", "mid_block.resnets.layers.")
-    renamed_pt_key = renamed_pt_key.replace("down_blocks.", "down_blocks.") # keeping same
-    # Need to handle resnets.0 -> resnets.layers.0 etc if we want to be explicit, or rely on scanning logic.
-
-    # If we use scan, we need to stack "resnets.0", "resnets.1" etc into "resnets" tensor.
-    # The logic in `get_key_and_value` handles `transformer_blocks` scanning. We should extend it for VAE `resnets`.
-
-    # Actually, `autoencoder_kl_ltx2.py` VAE scanning is slightly different.
-    # It scans over `resnets`.
-    # Diffusers has `down_blocks.0.resnets.0`, `down_blocks.0.resnets.1`.
-    # We need to stack these.
-
+    if ".resnets." in renamed_pt_key:
+      # pattern: resnets.0 -> resnets_0
+      parts = renamed_pt_key.split(".")
+      new_parts = []
+      i = 0
+      while i < len(parts):
+        if parts[i] == "resnets" and i+1 < len(parts) and parts[i+1].isdigit():
+          new_parts.append(f"resnets_{parts[i+1]}")
+          i += 2
+        else:
+          new_parts.append(parts[i])
+          i += 1
+      renamed_pt_key = ".".join(new_parts)
+
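The added block rewrites numeric resnet indices into attribute-style names, which suggests the Flax VAE exposes each resnet as a separate `resnets_0`, `resnets_1`, ... module rather than a scanned stack. A self-contained replica of the loop, run on invented keys:

```python
def flatten_resnet_index(key):
  parts = key.split(".")
  new_parts = []
  i = 0
  while i < len(parts):
    if parts[i] == "resnets" and i + 1 < len(parts) and parts[i + 1].isdigit():
      new_parts.append(f"resnets_{parts[i + 1]}")  # resnets.0 -> resnets_0
      i += 2
    else:
      new_parts.append(parts[i])
      i += 1
  return ".".join(new_parts)

assert flatten_resnet_index("down_blocks.0.resnets.1.conv1.weight") == "down_blocks.0.resnets_1.conv1.weight"
# A "resnets" segment without a numeric index is left untouched:
assert flatten_resnet_index("mid_block.resnets.norm.weight") == "mid_block.resnets.norm.weight"
```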
     pt_tuple_key = tuple(renamed_pt_key.split("."))
 
-    # Let's add VAE scanning logic here or in a helper
-    # Identifying keys to stack: keys containing `resnets._`
-
-    # Simplified VAE Loading (non-scanned or manual stacking):
-    # If `rename_key_and_reshape_tensor` expects exact matching, we might have trouble if keys are "resnets.0" but flax expects "resnets" (stacked).
-
-    # I will implement a check: if key has `resnets.N`, we try to stack it.
-
     flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict)
-
-    # If it didn't match immediately, check if it's a resnet layer that needs stacking
-    # This part is tricky without strictly knowing num_layers per block.
-    # But we can infer or just load individually if Flax wasn't scanned?
-    # The Flax code definitely uses scan.
-
-    # HACK: For VAE, let's assume we might need to manually stack or map to specific indices if `rename_key_and_reshape_tensor` didn't catch it.
-    # But for now, let's just use `rename_key_and_reshape_tensor` and `validate_flax_state_dict` will tell us what failed.
-
     flax_key = _tuple_str_to_int(flax_key)
 
-    # Manual VAE Stacking logic if needed:
-    # if "resnets" in flax_key and generic match failed...
-
-    # Let's rely on `validate_flax_state_dict` to debug VAE mapping in the test phase if it's complex.
-    # But I should probably add the `resnets` -> `resnets.layers` replacement to be safe?
-    # Wait, if I replace `resnets.0` with `resnets.layers.0`, and Flax expects `resnets` (stacked), it still won't match.
-    # Flax `nnx.vmap` with `transform_metadata={nnx.PARTITION_NAME: "layers"}` usually expects a stacked axis.
-    # The parameter key in `state_dict` for a vmapped layer often depends on how it's stored.
-    # In NNX/Flax, it might be stored as `resnets.layers`? No, usually just `resnets` with an extra dim?
-    # Or `resnets.layers.kernel` if it kept the name.
-
     flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
 
   validate_flax_state_dict(eval_shapes, flax_state_dict)
