2525DType = common_types .DType
2626BlockSizes = common_types .BlockSizes
2727
28+
2829def apply_rotary_emb (x : Array , freqs : Tuple [Array , Array ]) -> Array :
2930 """
3031 Applies Interleaved RoPE to input x.
@@ -193,9 +194,7 @@ def prepare_video_coords(
193194 # pixel_coords[:, 0, ...] selects Frame dimension.
194195 # pixel_coords shape: [B, 3, num_patches, 2] -> dim 1 is (F, H, W)
195196 frame_coords = pixel_coords [:, 0 , ...]
196- frame_coords = jnp .clip (
197- frame_coords + self .causal_offset - self .scale_factors [0 ], min = 0
198- )
197+ frame_coords = jnp .clip (frame_coords + self .causal_offset - self .scale_factors [0 ], min = 0 )
199198 pixel_coords = pixel_coords .at [:, 0 , ...].set (frame_coords / fps )
200199
201200 return pixel_coords
@@ -212,16 +211,12 @@ def prepare_audio_coords(
212211 # 2. Start timestamps
213212 audio_scale_factor = self .scale_factors [0 ]
214213 grid_start_mel = grid_f * audio_scale_factor
215- grid_start_mel = jnp .clip (
216- grid_start_mel + self .causal_offset - audio_scale_factor , min = 0
217- )
214+ grid_start_mel = jnp .clip (grid_start_mel + self .causal_offset - audio_scale_factor , min = 0 )
218215 grid_start_s = grid_start_mel * self .hop_length / self .sampling_rate
219216
220217 # 3. End timestamps
221218 grid_end_mel = (grid_f + self .patch_size_t ) * audio_scale_factor
222- grid_end_mel = jnp .clip (
223- grid_end_mel + self .causal_offset - audio_scale_factor , min = 0
224- )
219+ grid_end_mel = jnp .clip (grid_end_mel + self .causal_offset - audio_scale_factor , min = 0 )
225220 grid_end_s = grid_end_mel * self .hop_length / self .sampling_rate
226221
227222 # Stack [num_patches, 2]
0 commit comments