Skip to content

Commit 866c3d0

Browse files
committed
transformer fix
1 parent 2374f14 commit 866c3d0

1 file changed

Lines changed: 2 additions & 1 deletion

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
limitations under the License.
1515
"""
1616

17-
from typing import Optional, Tuple, Any, Dict
17+
from typing import Optional, Tuple, Any, Dict, List
1818
import jax
1919
import jax.numpy as jnp
2020
from flax import nnx
@@ -998,6 +998,7 @@ def __call__(
998998
use_cross_timestep: bool = False,
999999
modality_mask: Optional[jax.Array] = None,
10001000
isolate_modalities: bool = False,
1001+
spatio_temporal_guidance_blocks: Optional[List[int]] = None,
10011002
return_dict: bool = True,
10021003
perturbation_mask: Optional[jax.Array] = None,
10031004
) -> Any:

0 commit comments

Comments
 (0)