Skip to content

Commit 76dec0e

Browse files
committed
Fix weight dtype in pipeline
1 parent df7e8dc commit 76dec0e

2 files changed

Lines changed: 6 additions & 36 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 5 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,7 @@
2020
import flax.linen as nn
2121
import numpy as np
2222

23-
printed_count = 0
24-
def print_shape(name, tensor):
25-
global printed_count
26-
if printed_count > 1000:
27-
return
28-
if tensor is not None:
29-
def _print_fn(n, t):
30-
t_np = np.array(t, dtype=np.float32)
31-
print(f"[{n}] min: {t_np.min():.5f}, max: {t_np.max():.5f}, mean: {t_np.mean():.5f}, std: {t_np.std():.5f}")
32-
if isinstance(tensor, jax.core.Tracer):
33-
jax.debug.callback(_print_fn, name, tensor)
34-
else:
35-
_print_fn(name, tensor)
36-
printed_count += 1
23+
3724

3825
from maxdiffusion.models.ltx2.attention_ltx2 import LTX2Attention, LTX2RotaryPosEmbed
3926
from maxdiffusion.models.attention_flax import NNXSimpleFeedForward
@@ -358,16 +345,7 @@ def __call__(
358345
) -> Tuple[jax.Array, jax.Array]:
359346
batch_size = hidden_states.shape[0]
360347

361-
print_shape("Block Input hidden_states", hidden_states)
362-
print_shape("Block Input audio_hidden_states", audio_hidden_states)
363-
print_shape("Block Input encoder_hidden_states", encoder_hidden_states)
364-
print_shape("Block Input audio_encoder_hidden_states", audio_encoder_hidden_states)
365-
print_shape("Block Input temb", temb)
366-
print_shape("Block Input temb_audio", temb_audio)
367-
print_shape("Block Input temb_ca_scale_shift", temb_ca_scale_shift)
368-
print_shape("Block Input temb_ca_audio_scale_shift", temb_ca_audio_scale_shift)
369-
print_shape("Block Input temb_ca_gate", temb_ca_gate)
370-
print_shape("Block Input temb_ca_audio_gate", temb_ca_audio_gate)
348+
371349

372350
axis_names = nn.logical_to_mesh_axes(("activation_batch", "activation_length", "activation_embed"))
373351
hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names)
@@ -397,12 +375,8 @@ def __call__(
397375
scale_mlp = ada_values[:, :, 4, :]
398376
gate_mlp = ada_values[:, :, 5, :]
399377

400-
print_shape("shift_msa", shift_msa)
401-
print_shape("scale_msa", scale_msa)
402-
print_shape("gate_msa", gate_msa)
403-
print_shape("shift_mlp", shift_mlp)
404-
print_shape("scale_mlp", scale_mlp)
405-
print_shape("gate_mlp", gate_mlp)
378+
379+
406380

407381
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
408382

@@ -923,11 +897,7 @@ def __call__(
923897
audio_encoder_attention_mask = jnp.expand_dims(audio_encoder_attention_mask, axis=1)
924898

925899
batch_size = hidden_states.shape[0]
926-
print_shape("Model Input hidden_states", hidden_states)
927-
print_shape("Model Input audio_hidden_states", audio_hidden_states)
928-
print_shape("Model Input encoder_hidden_states", encoder_hidden_states)
929-
print_shape("Model Input audio_encoder_hidden_states", audio_encoder_hidden_states)
930-
print_shape("Model Input timestep", timestep)
900+
931901

932902
# 1. Prepare RoPE positional embeddings
933903
if video_coords is None:

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
303303
subfolder="connectors",
304304
rngs=rngs,
305305
mesh=mesh,
306-
dtype=jnp.float32,
306+
dtype=config.weights_dtype if hasattr(config, "weights_dtype") else jnp.float32,
307307
weights_dtype=config.weights_dtype if hasattr(config, "weights_dtype") else jnp.float32,
308308
)
309309
return connectors

0 commit comments

Comments (0)