Skip to content

Commit 7375d6e

Browse files
committed
fix: reformat attention_ltx2.py jnp.clip lines to pass pyink formatter
1 parent 5823603 commit 7375d6e

1 file changed

Lines changed: 3 additions & 9 deletions

File tree

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -193,9 +193,7 @@ def prepare_video_coords(
193193
# pixel_coords[:, 0, ...] selects Frame dimension.
194194
# pixel_coords shape: [B, 3, num_patches, 2] -> dim 1 is (F, H, W)
195195
frame_coords = pixel_coords[:, 0, ...]
196-
frame_coords = jnp.clip(
197-
frame_coords + self.causal_offset - self.scale_factors[0], min=0
198-
)
196+
frame_coords = jnp.clip(frame_coords + self.causal_offset - self.scale_factors[0], min=0)
199197
pixel_coords = pixel_coords.at[:, 0, ...].set(frame_coords / fps)
200198

201199
return pixel_coords
@@ -212,16 +210,12 @@ def prepare_audio_coords(
212210
# 2. Start timestamps
213211
audio_scale_factor = self.scale_factors[0]
214212
grid_start_mel = grid_f * audio_scale_factor
215-
grid_start_mel = jnp.clip(
216-
grid_start_mel + self.causal_offset - audio_scale_factor, min=0
217-
)
213+
grid_start_mel = jnp.clip(grid_start_mel + self.causal_offset - audio_scale_factor, min=0)
218214
grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate
219215

220216
# 3. End timestamps
221217
grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor
222-
grid_end_mel = jnp.clip(
223-
grid_end_mel + self.causal_offset - audio_scale_factor, min=0
224-
)
218+
grid_end_mel = jnp.clip(grid_end_mel + self.causal_offset - audio_scale_factor, min=0)
225219
grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate
226220

227221
# Stack [num_patches, 2]

0 commit comments

Comments (0)