Skip to content

Commit 8f05f7c

Browse files
committed
Cleaned up ltx2 transformer tests and implementations
1 parent 218f16f commit 8f05f7c

3 files changed

Lines changed: 8 additions & 4 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
"""
2+
This is a test file used for ensuring numerical parity between pytorch and jax implementation of LTX2.
3+
This is to be ignored and will not be pushed when commiting to main branch.
4+
"""
15
from typing import Optional, Tuple, Any, Dict, Union
26
import jax
37
import jax.numpy as jnp
@@ -243,7 +247,7 @@ def __init__(
243247
rngs=rngs,
244248
dim=dim,
245249
dim_out=dim,
246-
activation_fn=activation_fn, # Diffusers uses gelu
250+
activation_fn=activation_fn,
247251
dtype=dtype,
248252
weights_dtype=weights_dtype,
249253
)
@@ -633,7 +637,7 @@ def __init__(
633637
weights_dtype=self.weights_dtype,
634638
)
635639

636-
# 3.3. Output Layer Scale/Shift Modulation parameters
640+
# 3. Output Layer Scale/Shift Modulation parameters
637641
param_rng = rngs.params()
638642
self.scale_shift_table = nnx.Param(
639643
jax.random.normal(param_rng, (2, inner_dim), dtype=self.weights_dtype) / jnp.sqrt(inner_dim),
@@ -816,7 +820,7 @@ def __call__(
816820
audio_coords: Optional[jax.Array] = None,
817821
attention_kwargs: Optional[Dict[str, Any]] = None,
818822
return_dict: bool = True,
819-
) -> Any: # Should be AudioVisualModelOutput or Tuple
823+
) -> Any:
820824
# Determine timestep for audio.
821825
audio_timestep = audio_timestep if audio_timestep is not None else timestep
822826

@@ -974,4 +978,4 @@ def scan_fn(carry, block):
974978

975979
if not return_dict:
976980
return (output, audio_output)
977-
return {"sample": output, "audio_sample": audio_output} # Placeholder
981+
return {"sample": output, "audio_sample": audio_output}
File renamed without changes.
File renamed without changes.

0 commit comments

Comments
 (0)