@@ -195,7 +195,7 @@ def prepare_video_coords(
195195 # pixel_coords[:, 0, ...] selects Frame dimension.
196196 # pixel_coords shape: [B, 3, num_patches, 2] -> dim 1 is (F, H, W)
197197 frame_coords = pixel_coords [:, 0 , ...]
198- frame_coords = jnp .clip (frame_coords + self .causal_offset - self .scale_factors [0 ], a_min = 0 )
198+ frame_coords = jnp .clip (frame_coords + self .causal_offset - self .scale_factors [0 ], min = 0 )
199199 pixel_coords = pixel_coords .at [:, 0 , ...].set (frame_coords / fps )
200200
201201 return pixel_coords
@@ -212,12 +212,12 @@ def prepare_audio_coords(
212212 # 2. Start timestamps
213213 audio_scale_factor = self .scale_factors [0 ]
214214 grid_start_mel = grid_f * audio_scale_factor
215- grid_start_mel = jnp .clip (grid_start_mel + self .causal_offset - audio_scale_factor , a_min = 0 )
215+ grid_start_mel = jnp .clip (grid_start_mel + self .causal_offset - audio_scale_factor , min = 0 )
216216 grid_start_s = grid_start_mel * self .hop_length / self .sampling_rate
217217
218218 # 3. End timestamps
219219 grid_end_mel = (grid_f + self .patch_size_t ) * audio_scale_factor
220- grid_end_mel = jnp .clip (grid_end_mel + self .causal_offset - audio_scale_factor , a_min = 0 )
220+ grid_end_mel = jnp .clip (grid_end_mel + self .causal_offset - audio_scale_factor , min = 0 )
221221 grid_end_s = grid_end_mel * self .hop_length / self .sampling_rate
222222
223223 # Stack [num_patches, 2]
0 commit comments