@@ -193,9 +193,7 @@ def prepare_video_coords(
193193 # pixel_coords[:, 0, ...] selects Frame dimension.
194194 # pixel_coords shape: [B, 3, num_patches, 2] -> dim 1 is (F, H, W)
195195 frame_coords = pixel_coords [:, 0 , ...]
196- frame_coords = jnp .clip (
197- frame_coords + self .causal_offset - self .scale_factors [0 ], min = 0
198- )
196+ frame_coords = jnp .clip (frame_coords + self .causal_offset - self .scale_factors [0 ], min = 0 )
199197 pixel_coords = pixel_coords .at [:, 0 , ...].set (frame_coords / fps )
200198
201199 return pixel_coords
@@ -212,16 +210,12 @@ def prepare_audio_coords(
212210 # 2. Start timestamps
213211 audio_scale_factor = self .scale_factors [0 ]
214212 grid_start_mel = grid_f * audio_scale_factor
215- grid_start_mel = jnp .clip (
216- grid_start_mel + self .causal_offset - audio_scale_factor , min = 0
217- )
213+ grid_start_mel = jnp .clip (grid_start_mel + self .causal_offset - audio_scale_factor , min = 0 )
218214 grid_start_s = grid_start_mel * self .hop_length / self .sampling_rate
219215
220216 # 3. End timestamps
221217 grid_end_mel = (grid_f + self .patch_size_t ) * audio_scale_factor
222- grid_end_mel = jnp .clip (
223- grid_end_mel + self .causal_offset - audio_scale_factor , min = 0
224- )
218+ grid_end_mel = jnp .clip (grid_end_mel + self .causal_offset - audio_scale_factor , min = 0 )
225219 grid_end_s = grid_end_mel * self .hop_length / self .sampling_rate
226220
227221 # Stack [num_patches, 2]
0 commit comments