|
| 1 | +""" |
| 2 | +This is a test file used for ensuring numerical parity between pytorch and jax implementation of LTX2. |
| 3 | +This is to be ignored and will not be pushed when commiting to main branch. |
| 4 | +""" |
1 | 5 | from typing import Optional, Tuple, Any, Dict, Union |
2 | 6 | import jax |
3 | 7 | import jax.numpy as jnp |
@@ -243,7 +247,7 @@ def __init__( |
243 | 247 | rngs=rngs, |
244 | 248 | dim=dim, |
245 | 249 | dim_out=dim, |
246 | | - activation_fn=activation_fn, # Diffusers uses gelu |
| 250 | + activation_fn=activation_fn, |
247 | 251 | dtype=dtype, |
248 | 252 | weights_dtype=weights_dtype, |
249 | 253 | ) |
@@ -633,7 +637,7 @@ def __init__( |
633 | 637 | weights_dtype=self.weights_dtype, |
634 | 638 | ) |
635 | 639 |
|
636 | | - # 3.3. Output Layer Scale/Shift Modulation parameters |
| 640 | + # 3. Output Layer Scale/Shift Modulation parameters |
637 | 641 | param_rng = rngs.params() |
638 | 642 | self.scale_shift_table = nnx.Param( |
639 | 643 | jax.random.normal(param_rng, (2, inner_dim), dtype=self.weights_dtype) / jnp.sqrt(inner_dim), |
@@ -816,7 +820,7 @@ def __call__( |
816 | 820 | audio_coords: Optional[jax.Array] = None, |
817 | 821 | attention_kwargs: Optional[Dict[str, Any]] = None, |
818 | 822 | return_dict: bool = True, |
819 | | - ) -> Any: # Should be AudioVisualModelOutput or Tuple |
| 823 | + ) -> Any: |
820 | 824 | # Determine timestep for audio. |
821 | 825 | audio_timestep = audio_timestep if audio_timestep is not None else timestep |
822 | 826 |
|
@@ -974,4 +978,4 @@ def scan_fn(carry, block): |
974 | 978 |
|
975 | 979 | if not return_dict: |
976 | 980 | return (output, audio_output) |
977 | | - return {"sample": output, "audio_sample": audio_output} # Placeholder |
| 981 | + return {"sample": output, "audio_sample": audio_output} |
0 commit comments