11"""
2- Copyright 2025 Google LLC
2+ Copyright 2026 Google LLC
33
44Licensed under the Apache License, Version 2.0 (the "License");
55you may not use this file except in compliance with the License.
@@ -237,14 +237,21 @@ def prepare_coords(self, *args, **kwargs):
     return None
 
   def __call__(self, coords: Array) -> Tuple[Array, Array]:
-    # coords: [B, num_pos_dims, num_patches, 2]
-    num_pos_dims = coords.shape[1]
-
-    # 1. Midpoint
+    # Handle both [B, num_pos_dims, num_patches, 2] (from prepare_coords)
+    # and [B, num_patches, num_pos_dims] (raw grid coordinates).
     if coords.ndim == 4:
+      num_pos_dims = coords.shape[1]
+      # 1. Midpoint of each [start, end] span
       coords_start = coords[..., 0]
       coords_end = coords[..., 1]
       coords = (coords_start + coords_end) / 2.0  # [B, num_pos_dims, num_patches]
+      # Transpose to the standardized layout: [B, num_patches, num_pos_dims]
+      grid = coords.transpose(0, 2, 1)
+    elif coords.ndim == 3:
+      num_pos_dims = coords.shape[-1]
+      grid = coords  # already [B, num_patches, num_pos_dims]
+    else:
+      raise ValueError(f"coords must be 3D or 4D, got {coords.ndim}D")
 
     # 2. Fractions
     if self.modality == "video":
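
As a quick sanity check on the two accepted layouts, here is a minimal standalone sketch of the new 4D path (toy shapes, not the module itself): midpoint over each [start, end] span, then the transpose to the standardized [B, num_patches, num_pos_dims] layout.

import jax.numpy as jnp

B, num_pos_dims, num_patches = 2, 3, 8  # e.g. (frame, height, width) spans
coords = jnp.zeros((B, num_pos_dims, num_patches, 2))  # [..., 0] = start, [..., 1] = end

mid = (coords[..., 0] + coords[..., 1]) / 2.0  # [B, num_pos_dims, num_patches]
grid = mid.transpose(0, 2, 1)                  # [B, num_patches, num_pos_dims]
assert grid.shape == (B, num_patches, num_pos_dims)
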
@@ -253,10 +260,11 @@ def __call__(self, coords: Array) -> Tuple[Array, Array]:
       max_positions = jnp.array((self.base_num_frames,), dtype=coords.dtype)
 
     max_positions = max_positions[:num_pos_dims]
-    max_positions = max_positions.reshape(1, num_pos_dims, 1)
-    grid = coords / max_positions
-
-    grid = grid.transpose(0, 2, 1)
+    # Reshape to broadcast against [B, num_patches, num_pos_dims]
+    max_positions = max_positions.reshape(1, 1, num_pos_dims)
+
+    # Scale positions to [0, 1]
+    grid = grid / max_positions
 
     num_rope_elems = num_pos_dims * 2
 
@@ -265,12 +273,19 @@ def __call__(self, coords: Array) -> Tuple[Array, Array]:
     # linspace 0..1
     steps = self.dim // num_rope_elems
     pow_indices = jnp.power(self.theta, jnp.linspace(0.0, 1.0, steps, dtype=freqs_dtype))
-    freqs = (pow_indices * jnp.pi / 2.0).astype(jnp.float32)  # [D//2K]
+    base_freqs = (pow_indices * jnp.pi / 2.0).astype(jnp.float32)  # [steps]
 
     # 4. Outer product
-    freqs = (jnp.expand_dims(grid, -1) * 2 - 1) * freqs
-
-    # Flatten last two dims: K, S -> K*S = dim//2
+    # Map grid from [0, 1] to [-1, 1]
+    scaled_grid = grid * 2.0 - 1.0  # [B, num_patches, num_pos_dims]
+
+    # [B, num_patches, num_pos_dims, 1] * [steps] -> [B, num_patches, num_pos_dims, steps]
+    freqs = jnp.expand_dims(scaled_grid, -1) * base_freqs
+
+    # NOTE: swap the last two axes so the flatten below matches the Diffusers ordering.
+    freqs = jnp.swapaxes(freqs, -1, -2)  # [B, num_patches, steps, num_pos_dims]
+
+    # Flatten the last two dims -> [B, num_patches, dim // 2]
     freqs = freqs.reshape(*freqs.shape[:2], -1)
 
     # 5. Cos/Sin
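
Why the swapaxes matters: flattening [num_pos_dims, steps] directly groups all frequency steps of one position dim together, while flattening the transposed [steps, num_pos_dims] interleaves the position dims within each step, which is the ordering the comment above attributes to Diffusers. A toy illustration (made-up sizes, one token):

import jax.numpy as jnp

num_pos_dims, steps = 2, 3
f = jnp.arange(num_pos_dims * steps).reshape(num_pos_dims, steps)
# f = [[0, 1, 2],
#      [3, 4, 5]]

dim_major = f.reshape(-1)                         # [0 1 2 3 4 5]: all steps of dim 0 first
step_major = jnp.swapaxes(f, -1, -2).reshape(-1)  # [0 3 1 4 2 5]: dims interleaved per step
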
@@ -294,25 +309,22 @@ def __call__(self, coords: Array) -> Tuple[Array, Array]:
 
     elif self.rope_type == "split":
       # Cos/Sin
-      cos_freq = jnp.cos(freqs)
-      sin_freq = jnp.sin(freqs)
-
-      curr_dim = cos_freq.shape[-1]
+      curr_dim = cos_freqs.shape[-1]
       expected_dim = self.dim // 2
       pad_size = expected_dim - curr_dim
 
       if pad_size > 0:
-        cos_padding = jnp.ones((*cos_freq.shape[:-1], pad_size), dtype=cos_freq.dtype)
-        sin_padding = jnp.zeros((*sin_freq.shape[:-1], pad_size), dtype=sin_freq.dtype)
-        cos_freq = jnp.concatenate([cos_padding, cos_freq], axis=-1)
-        sin_freq = jnp.concatenate([sin_padding, sin_freq], axis=-1)
+        cos_padding = jnp.ones((*cos_freqs.shape[:-1], pad_size), dtype=cos_freqs.dtype)
+        sin_padding = jnp.zeros((*sin_freqs.shape[:-1], pad_size), dtype=sin_freqs.dtype)
+        cos_freqs = jnp.concatenate([cos_padding, cos_freqs], axis=-1)
+        sin_freqs = jnp.concatenate([sin_padding, sin_freqs], axis=-1)
 
-      b = cos_freq.shape[0]
-      s = cos_freq.shape[1]
+      b = cos_freqs.shape[0]
+      s = cos_freqs.shape[1]
       h = self.num_attention_heads
 
-      cos_freqs = cos_freq.reshape(b, s, h, -1).transpose(0, 2, 1, 3)
-      sin_freqs = sin_freq.reshape(b, s, h, -1).transpose(0, 2, 1, 3)
+      cos_freqs = cos_freqs.reshape(b, s, h, -1).transpose(0, 2, 1, 3)
+      sin_freqs = sin_freqs.reshape(b, s, h, -1).transpose(0, 2, 1, 3)
 
     return cos_freqs, sin_freqs
 
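
The padding values are not arbitrary: under the usual pairwise RoPE rotation (assumed here, since the apply function is outside this hunk), cos = 1 and sin = 0 correspond to a rotation angle of zero, so the padded channels pass through unrotated:

# Rotation of a channel pair (a, b) by angle theta:
#   a' = a*cos(theta) - b*sin(theta)
#   b' = b*cos(theta) + a*sin(theta)
a, b = 0.7, -1.2
cos, sin = 1.0, 0.0  # the pad values used above
assert a * cos - b * sin == a  # identity: padded channels are unchanged
assert b * cos + a * sin == b
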
@@ -341,24 +353,39 @@ def __init__(
     self.inner_dim = dim_head * heads
     self.dropout_rate = dropout
 
-    # 1. Projections
-    self.to_q = nnx.Linear(query_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)
+
+    # 1. Define partitioned initializers (logical axes)
+    # Q, K, V kernels: [in_features (embed), out_features (heads)]
+    qkv_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads"))
+    # Q, K, V biases: [out_features (heads)]
+    qkv_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("heads",))
+
+    # Out kernel: [in_features (heads), out_features (embed)]
+    out_kernel_init = nnx.with_partitioning(nnx.initializers.lecun_normal(), ("heads", "embed"))
+    # Out bias: [out_features (embed)]
+    out_bias_init = nnx.with_partitioning(nnx.initializers.zeros_init(), ("embed",))
+
+    # Norm scales
+    norm_scale_init = nnx.with_partitioning(nnx.initializers.ones_init(), ("norm",))
+
+    # 2. Projections
+    self.to_q = nnx.Linear(query_dim, self.inner_dim, use_bias=bias, kernel_init=qkv_kernel_init, bias_init=qkv_bias_init, rngs=rngs, dtype=dtype)
 
     # Handle self- vs cross-attention input dims
     kv_dim = context_dim if context_dim is not None else query_dim
-    self.to_k = nnx.Linear(kv_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)
-    self.to_v = nnx.Linear(kv_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)
+    self.to_k = nnx.Linear(kv_dim, self.inner_dim, use_bias=bias, kernel_init=qkv_kernel_init, bias_init=qkv_bias_init, rngs=rngs, dtype=dtype)
+    self.to_v = nnx.Linear(kv_dim, self.inner_dim, use_bias=bias, kernel_init=qkv_kernel_init, bias_init=qkv_bias_init, rngs=rngs, dtype=dtype)
 
-    # 2. Normalization (applied to the full inner_dim, NOT per-head)
+    # 3. Normalization (applied to the full inner_dim, NOT per-head)
     self.norm_q = nnx.RMSNorm(
-        self.inner_dim, epsilon=eps, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs
+        self.inner_dim, epsilon=eps, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, scale_init=norm_scale_init, rngs=rngs
     )
     self.norm_k = nnx.RMSNorm(
-        self.inner_dim, epsilon=eps, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, rngs=rngs
+        self.inner_dim, epsilon=eps, dtype=jnp.float32, param_dtype=jnp.float32, use_scale=True, scale_init=norm_scale_init, rngs=rngs
     )
 
-    # 3. Output
-    self.to_out = nnx.Linear(self.inner_dim, query_dim, use_bias=out_bias, rngs=rngs, dtype=dtype)
+    # 4. Output
+    self.to_out = nnx.Linear(self.inner_dim, query_dim, use_bias=out_bias, kernel_init=out_kernel_init, bias_init=out_bias_init, rngs=rngs, dtype=dtype)
 
     if self.dropout_rate > 0:
       self.dropout_layer = nnx.Dropout(self.dropout_rate, rngs=rngs)
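
For context on the initializer wrappers in the hunk above: nnx.with_partitioning attaches the logical axis names ("embed", "heads", "norm") to each created parameter as sharding metadata, which resolves to PartitionSpecs that the training setup later maps onto physical mesh axes. A minimal sketch of inspecting that metadata, assuming a recent flax.nnx:

from flax import nnx

linear = nnx.Linear(
    16, 32,
    kernel_init=nnx.with_partitioning(nnx.initializers.lecun_normal(), ("embed", "heads")),
    rngs=nnx.Rngs(0),
)
state = nnx.state(linear)
print(nnx.get_partition_spec(state))
# kernel -> PartitionSpec('embed', 'heads'); these logical names are resolved
# to mesh axes (or left replicated) by the surrounding sharding rules.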