Skip to content

Commit b83a1ab

Browse files
committed
fix and ltx2 backward compatibility
1 parent c159b84 commit b83a1ab

3 files changed

Lines changed: 133 additions & 55 deletions

File tree

src/maxdiffusion/models/ltx2/ltx2_3_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@
3939
"audio_linear.bias": "audio_text_proj_in.bias",
4040
"video_linear.weight": "video_text_proj_in.kernel",
4141
"video_linear.bias": "video_text_proj_in.bias",
42+
}
43+
44+
LTX_2_3_ONLY_RENAME_DICT = {
4245
"video_embeddings_connector": "video_connector",
4346
"audio_embeddings_connector": "audio_connector",
4447
}
@@ -50,6 +53,7 @@ def load_connectors_weights(
5053
hf_download: bool = True,
5154
subfolder: str = "",
5255
filename: str = None,
56+
is_ltx2_3: bool = False,
5357
):
5458
device = jax.local_devices(backend=device)[0]
5559

@@ -69,6 +73,10 @@ def load_connectors_weights(
6973
for replace_key, rename_to in LTX_2_3_CONNECTORS_KEYS_RENAME_DICT.items():
7074
flax_key_str = flax_key_str.replace(replace_key, rename_to)
7175

76+
if is_ltx2_3:
77+
for replace_key, rename_to in LTX_2_3_ONLY_RENAME_DICT.items():
78+
flax_key_str = flax_key_str.replace(replace_key, rename_to)
79+
7280
segments = flax_key_str.split(".")
7381

7482
# Only extract digit if it immediately follows 'stacked_blocks'

src/maxdiffusion/models/ltx2/text_encoders/text_encoders_ltx2.py

Lines changed: 117 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -73,69 +73,135 @@ def __init__(
7373

7474
self.per_modality_projections = per_modality_projections
7575

76-
self.feature_extractor = LTX2GemmaFeatureExtractor(
77-
input_dim=input_dim,
78-
output_dim=caption_channels,
79-
dtype=dtype,
80-
rngs=rngs,
81-
per_modality_projections=per_modality_projections,
82-
use_bias=proj_bias,
83-
video_output_dim=v_dim,
84-
audio_output_dim=a_dim,
85-
)
86-
87-
# Two independent connectors
88-
self.video_embeddings_connector = Embeddings1DConnector(
89-
input_dim=v_dim,
90-
heads=video_connector_num_attention_heads,
91-
head_dim=video_connector_attention_head_dim,
92-
layers=video_connector_num_layers,
93-
num_learnable_registers=video_connector_num_learnable_registers,
94-
rope_type=rope_type,
95-
theta=rope_theta,
96-
base_seq_len=connector_rope_base_seq_len,
97-
double_precision=rope_double_precision,
98-
attention_kernel=attention_kernel,
99-
mesh=mesh,
100-
rngs=rngs,
101-
gated_attn=video_gated_attn,
102-
)
103-
104-
self.audio_embeddings_connector = Embeddings1DConnector(
105-
input_dim=a_dim,
106-
heads=audio_connector_num_attention_heads,
107-
head_dim=audio_connector_attention_head_dim,
108-
layers=audio_connector_num_layers,
109-
num_learnable_registers=audio_connector_num_learnable_registers,
110-
rope_type=rope_type,
111-
theta=rope_theta,
112-
base_seq_len=connector_rope_base_seq_len,
113-
double_precision=rope_double_precision,
114-
attention_kernel=attention_kernel,
115-
mesh=mesh,
116-
rngs=rngs,
117-
gated_attn=audio_gated_attn,
118-
)
76+
if per_modality_projections:
77+
self.video_text_proj_in = nnx.Linear(
78+
in_features=input_dim, out_features=v_dim, use_bias=proj_bias, rngs=rngs
79+
)
80+
self.audio_text_proj_in = nnx.Linear(
81+
in_features=input_dim, out_features=a_dim, use_bias=proj_bias, rngs=rngs
82+
)
83+
84+
self.video_connector = Embeddings1DConnector(
85+
input_dim=v_dim,
86+
heads=video_connector_num_attention_heads,
87+
head_dim=video_connector_attention_head_dim,
88+
layers=video_connector_num_layers,
89+
num_learnable_registers=video_connector_num_learnable_registers,
90+
rope_type=rope_type,
91+
theta=rope_theta,
92+
base_seq_len=connector_rope_base_seq_len,
93+
double_precision=rope_double_precision,
94+
attention_kernel=attention_kernel,
95+
mesh=mesh,
96+
rngs=rngs,
97+
gated_attn=video_gated_attn,
98+
)
99+
self.audio_connector = Embeddings1DConnector(
100+
input_dim=a_dim,
101+
heads=audio_connector_num_attention_heads,
102+
head_dim=audio_connector_attention_head_dim,
103+
layers=audio_connector_num_layers,
104+
num_learnable_registers=audio_connector_num_learnable_registers,
105+
rope_type=rope_type,
106+
theta=rope_theta,
107+
base_seq_len=connector_rope_base_seq_len,
108+
double_precision=rope_double_precision,
109+
attention_kernel=attention_kernel,
110+
mesh=mesh,
111+
rngs=rngs,
112+
gated_attn=audio_gated_attn,
113+
)
114+
else:
115+
self.feature_extractor = LTX2GemmaFeatureExtractor(
116+
input_dim=input_dim,
117+
output_dim=caption_channels,
118+
dtype=dtype,
119+
rngs=rngs,
120+
per_modality_projections=per_modality_projections,
121+
use_bias=proj_bias,
122+
video_output_dim=v_dim,
123+
audio_output_dim=a_dim,
124+
)
125+
126+
# Two independent connectors
127+
self.video_embeddings_connector = Embeddings1DConnector(
128+
input_dim=v_dim,
129+
heads=video_connector_num_attention_heads,
130+
head_dim=video_connector_attention_head_dim,
131+
layers=video_connector_num_layers,
132+
num_learnable_registers=video_connector_num_learnable_registers,
133+
rope_type=rope_type,
134+
theta=rope_theta,
135+
base_seq_len=connector_rope_base_seq_len,
136+
double_precision=rope_double_precision,
137+
attention_kernel=attention_kernel,
138+
mesh=mesh,
139+
rngs=rngs,
140+
gated_attn=video_gated_attn,
141+
)
142+
self.audio_embeddings_connector = Embeddings1DConnector(
143+
input_dim=a_dim,
144+
heads=audio_connector_num_attention_heads,
145+
head_dim=audio_connector_attention_head_dim,
146+
layers=audio_connector_num_layers,
147+
num_learnable_registers=audio_connector_num_learnable_registers,
148+
rope_type=rope_type,
149+
theta=rope_theta,
150+
base_seq_len=connector_rope_base_seq_len,
151+
double_precision=rope_double_precision,
152+
attention_kernel=attention_kernel,
153+
mesh=mesh,
154+
rngs=rngs,
155+
gated_attn=audio_gated_attn,
156+
)
119157

120158
def __call__(
121159
self,
122160
hidden_states: Union[Tuple[Array, ...], List[Array]],
123161
attention_mask: Array,
124-
) -> Tuple[Array, Array]:
162+
) -> Tuple[Array, Array, Array]:
125163
"""
126164
Returns:
127165
(video_embeds, audio_embeds, new_attention_mask)
128166
"""
129167
with jax.named_scope("Text Encoder Forward"):
130-
# 1. Shared Feature Extraction
131-
features = self.feature_extractor(hidden_states, attention_mask)
132-
133-
# 2. Parallel Connection
134168
if self.per_modality_projections:
135-
video_features, audio_features = features
136-
video_embeds, new_attention_mask = self.video_embeddings_connector(video_features, attention_mask)
137-
audio_embeds, _ = self.audio_embeddings_connector(audio_features, attention_mask)
169+
# 1. Stack Hidden States if needed
170+
if isinstance(hidden_states, (tuple, list)):
171+
x = jnp.stack(hidden_states, axis=-1)
172+
else:
173+
x = hidden_states
174+
175+
b, l, d, k = x.shape
176+
177+
# 2. Per-token RMS norm
178+
variance = jnp.mean(x**2, axis=2, keepdims=True)
179+
norm_text_encoder_hidden_states = x * jax.lax.rsqrt(variance + 1e-6)
180+
181+
norm_text_encoder_hidden_states = norm_text_encoder_hidden_states.reshape(b, l, -1)
182+
183+
bool_mask = (attention_mask > 0.5).astype(jnp.float32)[..., None]
184+
norm_text_encoder_hidden_states = norm_text_encoder_hidden_states * bool_mask
185+
186+
# 3. Rescale norms
187+
# Using self.caption_channels if available, or fallback to config or 3840
188+
cap_channels = getattr(self, "caption_channels", getattr(self.config, "caption_channels", 3840))
189+
190+
video_scale_factor = jnp.sqrt(self.video_connector.dim / cap_channels)
191+
video_norm_text_emb = norm_text_encoder_hidden_states * video_scale_factor
192+
audio_scale_factor = jnp.sqrt(self.audio_connector.dim / cap_channels)
193+
audio_norm_text_emb = norm_text_encoder_hidden_states * audio_scale_factor
194+
195+
video_text_emb_proj = self.video_text_proj_in(video_norm_text_emb)
196+
audio_text_emb_proj = self.audio_text_proj_in(audio_norm_text_emb)
197+
198+
video_embeds, new_attention_mask = self.video_connector(video_text_emb_proj, attention_mask)
199+
audio_embeds, _ = self.audio_connector(audio_text_emb_proj, attention_mask)
138200
else:
201+
# 1. Shared Feature Extraction
202+
features = self.feature_extractor(hidden_states, attention_mask)
203+
204+
# 2. Parallel Connection
139205
video_embeds, new_attention_mask = self.video_embeddings_connector(features, attention_mask)
140206
audio_embeds, _ = self.audio_embeddings_connector(features, attention_mask)
141207

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -379,11 +379,15 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
379379
logical_state_spec = nnx.get_partition_spec(state)
380380
logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
381381
logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding))
382-
params = state.to_pure_dict()
383-
state = dict(nnx.to_flat_state(state))
384-
385382
filename = "ltx-2.3-22b-dev.safetensors" if getattr(config, "model_name", "") == "ltx2.3" else None
386-
params = load_connectors_weights(config.pretrained_model_name_or_path, params, "cpu", subfolder="", filename=filename)
383+
params = load_connectors_weights(
384+
config.pretrained_model_name_or_path,
385+
params,
386+
"cpu",
387+
subfolder="",
388+
filename=filename,
389+
is_ltx2_3=(getattr(config, "model_name", "") == "ltx2.3"),
390+
)
387391
if hasattr(config, "weights_dtype"):
388392
params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
389393

0 commit comments

Comments (0)