@@ -53,11 +53,13 @@ class QKNorm(nn.Module):
def __call__(self, q: Array, k: Array, v: Array) -> tuple[Array, Array]:
    """Normalize the query and key tensors with separate ``nn.RMSNorm`` layers.

    Each tensor goes through its own explicitly named ``nn.RMSNorm``
    submodule (``query_norm`` for ``q``, ``key_norm`` for ``k``). Both
    layers use ``self.dtype`` as ``dtype`` and ``self.weights_dtype`` as
    ``param_dtype``.

    Args:
        q: Query tensor.
        k: Key tensor.
        v: Value tensor. Accepted for call-site symmetry only; it is
            neither read nor returned here.

    Returns:
        The ``(q, k)`` pair after normalization.
    """
    # Both norms share the same dtype configuration; only the name differs.
    shared_config = {
        "dtype": self.dtype,
        "param_dtype": self.weights_dtype,
    }
    normed_q = nn.RMSNorm(name="query_norm", **shared_config)(q)
    normed_k = nn.RMSNorm(name="key_norm", **shared_config)(k)
    return normed_q, normed_k
6365
@@ -173,7 +175,8 @@ def __call__(self, vec: Array) -> tuple[ModulationOut, ModulationOut | None]:
173175 kernel_init = nn .with_logical_partitioning (
174176 nn .initializers .lecun_normal (),
175177 ("embed" , "heads" )
176- )
178+ ),
179+ name = "lin"
177180 )(nn .silu (vec ))
178181
179182 out = jnp .split (lin [:, None , :], multiplier , axis = - 1 )
@@ -205,14 +208,16 @@ def __call__(self, x: Array, vec: Array, pe: Array) -> Array:
205208 double = False ,
206209 dtype = self .dtype ,
207210 weights_dtype = self .weights_dtype ,
208- precision = self .precision
211+ precision = self .precision ,
212+ name = "modulation"
209213 )(vec )
210214 x_mod = (1 + mod .scale ) * nn .LayerNorm (
211215 use_scale = False ,
212216 use_bias = False ,
213217 epsilon = 1e-6 ,
214218 dtype = self .dtype ,
215- param_dtype = self .weights_dtype
219+ param_dtype = self .weights_dtype ,
220+ name = "pre_norm"
216221 )(x ) + mod .shift
217222
218223 x_mod = nn .Dense (
@@ -231,7 +236,8 @@ def __call__(self, x: Array, vec: Array, pe: Array) -> Array:
231236 q , k , v = rearrange (qkv , "B L (K H D) -> K B H L D" , K = 3 , H = self .num_heads )
232237 q , k = QKNorm (
233238 dtype = self .dtype ,
234- weights_dtype = self .weights_dtype
239+ weights_dtype = self .weights_dtype ,
240+ name = "norm"
235241 )(q , k , v )
236242
237243 q , k = apply_rope (q , k , pe )
@@ -286,15 +292,17 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array
286292 double = True ,
287293 dtype = self .dtype ,
288294 weights_dtype = self .weights_dtype ,
289- precision = self .precision
295+ precision = self .precision ,
296+ name = "img_mod"
290297 )(vec )
291298
292299 txt_mod1 , txt_mod2 = Modulation (
293300 self .hidden_size ,
294301 double = True ,
295302 dtype = self .dtype ,
296303 weights_dtype = self .weights_dtype ,
297- precision = self .precision
304+ precision = self .precision ,
305+ name = "txt_mod"
298306 )(vec )
299307
300308 # prepare image for attention
@@ -303,7 +311,8 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array
303311 use_bias = False ,
304312 epsilon = 1e-6 ,
305313 dtype = self .dtype ,
306- param_dtype = self .weights_dtype
314+ param_dtype = self .weights_dtype ,
315+ name = "img_norm1"
307316 )(img )
308317 img_modulated = (1 + img_mod1 .scale ) * img_modulated + img_mod1 .shift
309318 img_qkv = nn .Dense (
@@ -315,14 +324,16 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array
315324 kernel_init = nn .with_logical_partitioning (
316325 nn .initializers .lecun_normal (),
317326 ("embed" , "heads" )
318- )
327+ ),
328+ name = "img_attn_qkv"
319329 )(img_modulated )
320330 img_q , img_k , img_v = rearrange (
321331 img_qkv , "B L (K H D) -> K B H L D" , K = 3 , H = self .num_heads
322332 )
323333 img_q , img_k = QKNorm (
324334 dtype = self .dtype ,
325- weights_dtype = self .weights_dtype
335+ weights_dtype = self .weights_dtype ,
336+ name = "img_attn_norm"
326337 )(img_q , img_k , img_v )
327338
328339 # prepare text for attention
@@ -331,7 +342,8 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array
331342 use_bias = False ,
332343 epsilon = 1e-6 ,
333344 dtype = self .dtype ,
334- param_dtype = self .weights_dtype
345+ param_dtype = self .weights_dtype ,
346+ name = "txt_norm1"
335347 )(txt )
336348 txt_modulated = (1 + txt_mod1 .scale ) * txt_modulated + txt_mod1 .shift
337349 txt_qkv = nn .Dense (
@@ -343,14 +355,16 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array
343355 kernel_init = nn .with_logical_partitioning (
344356 nn .initializers .lecun_normal (),
345357 ("embed" , "heads" )
346- )
358+ ),
359+ name = "txt_attn_qkv"
347360 )(txt_modulated )
348361 txt_q , txt_k , txt_v = rearrange (
349362 txt_qkv , "B L (K H D) -> K B H L D" , K = 3 , H = self .num_heads
350363 )
351364 txt_q , txt_k = QKNorm (
352365 dtype = self .dtype ,
353- weights_dtype = self .weights_dtype
366+ weights_dtype = self .weights_dtype ,
367+ name = "txt_attn_norm"
354368 )(txt_q , txt_k , txt_v )
355369
356370 # run actual attention
@@ -385,6 +399,7 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array
385399 nn .initializers .lecun_normal (),
386400 ("heads" , "embed" )
387401 ),
402+ name = "img_attn_proj"
388403 )(img_attn )
389404 img = img + img_mod2 .gate * nn .Sequential (
390405 [
@@ -397,7 +412,8 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array
397412 kernel_init = nn .with_logical_partitioning (
398413 nn .initializers .lecun_normal (),
399414 ("embed" , "heads" )
400- )
415+ ),
416+ name = "img_mlp_0"
401417 ),
402418 nn .gelu ,
403419 nn .Dense (
@@ -408,14 +424,16 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array
408424 kernel_init = nn .with_logical_partitioning (
409425 nn .initializers .lecun_normal (),
410426 ("heads" , "embed" )
411- )
412- )
413- ]
427+ ),
428+ name = "img_mlp_2"
429+ ),
430+ ],
414431 )(
415432 (1 + img_mod2 .scale ) * nn .LayerNorm (
416433 use_scale = False ,
417434 use_bias = False ,
418- param_dtype = self .weights_dtype
435+ param_dtype = self .weights_dtype ,
436+ name = "img_norm2"
419437 )(img ) + img_mod2 .shift
420438 )
421439
@@ -430,6 +448,7 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array
430448 nn .initializers .lecun_normal (),
431449 ("heads" , "embed" )
432450 ),
451+ name = "txt_attn_proj"
433452 )(txt_attn )
434453 txt = txt + txt_mod1 .gate * txt_proj
435454
@@ -444,7 +463,8 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array
444463 kernel_init = nn .with_logical_partitioning (
445464 nn .initializers .lecun_normal (),
446465 ("embed" , "heads" )
447- )
466+ ),
467+ name = "txt_mlp_0"
448468 ),
449469 nn .gelu ,
450470 nn .Dense (
@@ -455,14 +475,16 @@ def __call__(self, img: Array, txt: Array, vec: Array, pe: Array) -> tuple[Array
455475 kernel_init = nn .with_logical_partitioning (
456476 nn .initializers .lecun_normal (),
457477 ("heads" , "embed" )
458- )
459- )
460- ]
478+ ),
479+ name = "txt_mlp_2"
480+ ),
481+ ],
461482 )(
462483 (1 + txt_mod2 .scale ) * nn .LayerNorm (
463484 use_scale = False ,
464485 use_bias = False ,
465- param_dtype = self .weights_dtype
486+ param_dtype = self .weights_dtype ,
487+ name = "txt_norm2"
466488 )(txt ) + txt_mod2 .shift
467489 )
468490
@@ -491,8 +513,9 @@ def __call__(self, x: Array, vec: Array) -> Array:
491513 kernel_init = nn .with_logical_partitioning (
492514 nn .initializers .lecun_normal (),
493515 ("embed" , "heads" )
494- )
495- )
516+ ),
517+ name = "adaLN_modulation_1"
518+ ),
496519 ]
497520 )(vec ), 2 , axis = 1
498521 )
@@ -515,5 +538,5 @@ def __call__(self, x: Array, vec: Array) -> Array:
515538 ("heads" , "embed" )
516539 ),
517540 name = "linear"
518- )
541+ )( x )
519542 return x
0 commit comments