Skip to content

Commit 36534a2

Browse files
committed
fix
1 parent 1423e0a commit 36534a2

1 file changed

Lines changed: 3 additions & 3 deletions

File tree

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def prepare_video_coords(
195195
# pixel_coords[:, 0, ...] selects Frame dimension.
196196
# pixel_coords shape: [B, 3, num_patches, 2] -> dim 1 is (F, H, W)
197197
frame_coords = pixel_coords[:, 0, ...]
198-
frame_coords = jnp.clip(frame_coords + self.causal_offset - self.scale_factors[0], a_min=0)
198+
frame_coords = jnp.clip(frame_coords + self.causal_offset - self.scale_factors[0], min=0)
199199
pixel_coords = pixel_coords.at[:, 0, ...].set(frame_coords / fps)
200200

201201
return pixel_coords
@@ -212,12 +212,12 @@ def prepare_audio_coords(
212212
# 2. Start timestamps
213213
audio_scale_factor = self.scale_factors[0]
214214
grid_start_mel = grid_f * audio_scale_factor
215-
grid_start_mel = jnp.clip(grid_start_mel + self.causal_offset - audio_scale_factor, a_min=0)
215+
grid_start_mel = jnp.clip(grid_start_mel + self.causal_offset - audio_scale_factor, min=0)
216216
grid_start_s = grid_start_mel * self.hop_length / self.sampling_rate
217217

218218
# 3. End timestamps
219219
grid_end_mel = (grid_f + self.patch_size_t) * audio_scale_factor
220-
grid_end_mel = jnp.clip(grid_end_mel + self.causal_offset - audio_scale_factor, a_min=0)
220+
grid_end_mel = jnp.clip(grid_end_mel + self.causal_offset - audio_scale_factor, min=0)
221221
grid_end_s = grid_end_mel * self.hop_length / self.sampling_rate
222222

223223
# Stack [num_patches, 2]

0 commit comments

Comments
 (0)