Skip to content

Commit e08d5af

Browse files
committed
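more annotations added: wrap the output-layer computation (section "6. Output layers") in a `jax.named_scope("Output Projection & Norm")` block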
1 parent 3391a5e commit e08d5af

1 file changed

Lines changed: 15 additions & 14 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1032,23 +1032,24 @@ def scan_fn(carry, block):
1032 1032
)
1033 1033

1034 1034
# 6. Output layers
1035-
scale_shift_values = jnp.expand_dims(self.scale_shift_table, axis=(0, 1)) + jnp.expand_dims(embedded_timestep, axis=2)
1036-
shift = scale_shift_values[:, :, 0, :]
1037-
scale = scale_shift_values[:, :, 1, :]
1035+
with jax.named_scope("Output Projection & Norm"):
1036+
scale_shift_values = jnp.expand_dims(self.scale_shift_table, axis=(0, 1)) + jnp.expand_dims(embedded_timestep, axis=2)
1037+
shift = scale_shift_values[:, :, 0, :]
1038+
scale = scale_shift_values[:, :, 1, :]
1038 1039

1039-
hidden_states = self.norm_out(hidden_states)
1040-
hidden_states = hidden_states * (1 + scale) + shift
1041-
output = self.proj_out(hidden_states)
1040+
hidden_states = self.norm_out(hidden_states)
1041+
hidden_states = hidden_states * (1 + scale) + shift
1042+
output = self.proj_out(hidden_states)
1042 1043

1043-
audio_scale_shift_values = jnp.expand_dims(self.audio_scale_shift_table, axis=(0, 1)) + jnp.expand_dims(
1044-
audio_embedded_timestep, axis=2
1045-
)
1046-
audio_shift = audio_scale_shift_values[:, :, 0, :]
1047-
audio_scale = audio_scale_shift_values[:, :, 1, :]
1044+
audio_scale_shift_values = jnp.expand_dims(self.audio_scale_shift_table, axis=(0, 1)) + jnp.expand_dims(
1045+
audio_embedded_timestep, axis=2
1046+
)
1047+
audio_shift = audio_scale_shift_values[:, :, 0, :]
1048+
audio_scale = audio_scale_shift_values[:, :, 1, :]
1048 1049

1049-
audio_hidden_states = self.audio_norm_out(audio_hidden_states)
1050-
audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift
1051-
audio_output = self.audio_proj_out(audio_hidden_states)
1050+
audio_hidden_states = self.audio_norm_out(audio_hidden_states)
1051+
audio_hidden_states = audio_hidden_states * (1 + audio_scale) + audio_shift
1052+
audio_output = self.audio_proj_out(audio_hidden_states)
1052 1053

1053 1054
if not return_dict:
1054 1055
return (output, audio_output)

0 commit comments

Comments
 (0)