Skip to content

Commit efc2681

Browse files
committed
moving run_connectors to outside of __call__
1 parent f5608b7 commit efc2681

1 file changed

Lines changed: 7 additions & 7 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -976,6 +976,12 @@ def _unpack_audio_latents(
976976
latents = latents.transpose(0, 2, 1, 3)
977977
return latents
978978

979+
@staticmethod
980+
@jax.jit
981+
def run_connectors(graphdef, state, hidden_states, attention_mask):
982+
model = nnx.merge(graphdef, state)
983+
return model(hidden_states, attention_mask)
984+
979985
def prepare_latents(
980986
self,
981987
batch_size: int = 1,
@@ -1223,13 +1229,7 @@ def __call__(
12231229
with context_manager, axis_rules_context:
12241230
connectors_graphdef, connectors_state = nnx.split(self.connectors)
12251231

1226-
@staticmethod
1227-
@jax.jit
1228-
def run_connectors(graphdef, state, hidden_states, attention_mask):
1229-
model = nnx.merge(graphdef, state)
1230-
return model(hidden_states, attention_mask)
1231-
1232-
video_embeds, audio_embeds, new_attention_mask = run_connectors(
1232+
video_embeds, audio_embeds, new_attention_mask = self.run_connectors(
12331233
connectors_graphdef, connectors_state, prompt_embeds_jax, prompt_attention_mask_jax.astype(jnp.bool_)
12341234
)
12351235

0 commit comments

Comments
 (0)