Skip to content

Commit 00517d9

Browse files
committed
reformatted
1 parent eb1f7b8 commit 00517d9

1 file changed

Lines changed: 4 additions & 9 deletions

File tree

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
DType = common_types.DType
2626
BlockSizes = common_types.BlockSizes
2727

28+
2829
def apply_rotary_emb(x: Array, freqs: Tuple[Array, Array]) -> Array:
2930
"""
3031
Applies Interleaved RoPE to input x.
@@ -193,9 +194,7 @@ def prepare_video_coords(
193194
# pixel_coords[:, 0, ...] selects Frame dimension.
194195
# pixel_coords shape: [B, 3, num_patches, 2] -> dim 1 is (F, H, W)
195196
frame_coords = pixel_coords[:, 0, ...]
196-
frame_coords = jnp.clip(
197-
frame_coords + self.causal_offset - self.scale_factors[0], min=0
198-
)
197+
frame_coords = jnp.clip(frame_coords + self.causal_offset - self.scale_factors[0], min=0)
199198
pixel_coords = pixel_coords.at[:, 0, ...].set(frame_coords / fps)
200199

201200
return pixel_coords
@@ -212,16 +211,12 @@ def prepare_audio_coords(
212211
# 2. Start timestamps
213212
audio_scale_factor = self.scale_factors[0]
214213
grid_start_mel = grid_f * audio_scale_factor
215-
grid_start_mel = jnp.clip(
216-
grid_start_mel + self.causal_offset - audio_scale_factor, min=0
217-
)
214+
grid_start_mel = jnp.clip(grid_start_mel + self.causal_offset - audio_scale_factor, min=0)
218215
grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate
219216

220217
# 3. End timestamps
221218
grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor
222-
grid_end_mel = jnp.clip(
223-
grid_end_mel + self.causal_offset - audio_scale_factor, min=0
224-
)
219+
grid_end_mel = jnp.clip(grid_end_mel + self.causal_offset - audio_scale_factor, min=0)
225220
grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate
226221

227222
# Stack [num_patches, 2]

0 commit comments

Comments
 (0)