@@ -383,48 +383,43 @@ def setup(self):
       precision=self.precision,
     )

-    self.double_blocks = nn.Sequential(
-      [
-        *[
-          FluxTransformerBlock(
-            dim=self.inner_dim,
-            num_attention_heads=self.num_attention_heads,
-            attention_head_dim=self.attention_head_dim,
-            attention_kernel=self.attention_kernel,
-            flash_min_seq_length=self.flash_min_seq_length,
-            flash_block_sizes=self.flash_block_sizes,
-            mesh=self.mesh,
-            dtype=self.dtype,
-            weights_dtype=self.weights_dtype,
-            precision=self.precision,
-            mlp_ratio=self.mlp_ratio,
-            qkv_bias=self.qkv_bias,
-          )
-          for _ in range(self.num_layers)
-        ]
-      ]
-    )
-
-    self.single_blocks = nn.Sequential(
-      [
-        *[
-          FluxSingleTransformerBlock(
-            dim=self.inner_dim,
-            num_attention_heads=self.num_attention_heads,
-            attention_head_dim=self.attention_head_dim,
-            attention_kernel=self.attention_kernel,
-            flash_min_seq_length=self.flash_min_seq_length,
-            flash_block_sizes=self.flash_block_sizes,
-            mesh=self.mesh,
-            dtype=self.dtype,
-            weights_dtype=self.weights_dtype,
-            precision=self.precision,
-            mlp_ratio=self.mlp_ratio,
-          )
-          for _ in range(self.num_single_layers)
-        ]
-      ]
-    )
+    double_blocks = []
+    for _ in range(self.num_layers):
+      double_block = FluxTransformerBlock(
+        dim=self.inner_dim,
+        num_attention_heads=self.num_attention_heads,
+        attention_head_dim=self.attention_head_dim,
+        attention_kernel=self.attention_kernel,
+        flash_min_seq_length=self.flash_min_seq_length,
+        flash_block_sizes=self.flash_block_sizes,
+        mesh=self.mesh,
+        dtype=self.dtype,
+        weights_dtype=self.weights_dtype,
+        precision=self.precision,
+        mlp_ratio=self.mlp_ratio,
+        qkv_bias=self.qkv_bias,
+      )
+      double_blocks.append(double_block)
+    self.double_blocks = double_blocks
+
+    single_blocks = []
+    for _ in range(self.num_single_layers):
+      single_block = FluxSingleTransformerBlock(
+        dim=self.inner_dim,
+        num_attention_heads=self.num_attention_heads,
+        attention_head_dim=self.attention_head_dim,
+        attention_kernel=self.attention_kernel,
+        flash_min_seq_length=self.flash_min_seq_length,
+        flash_block_sizes=self.flash_block_sizes,
+        mesh=self.mesh,
+        dtype=self.dtype,
+        weights_dtype=self.weights_dtype,
+        precision=self.precision,
+        mlp_ratio=self.mlp_ratio,
+      )
+      single_blocks.append(single_block)
+
+    self.single_blocks = single_blocks

     self.norm_out = AdaLayerNormContinuous(
       self.inner_dim,
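
Why the list-of-blocks rewrite above is safe in Flax Linen: submodules stored in a plain Python list assigned during `setup()` are registered automatically (named `<attr>_0`, `<attr>_1`, ...), so an explicit loop in `__call__` keeps full parameter tracking without `nn.Sequential`. Below is a minimal runnable sketch of the pattern; `Block` and `Stack` are hypothetical stand-ins, not names from this repo:

```python
# Minimal sketch, not from this commit: a list of Modules assigned in setup()
# is registered by Flax Linen, so __call__ can loop over it explicitly.
import jax
import jax.numpy as jnp
import flax.linen as nn


class Block(nn.Module):
  dim: int

  @nn.compact
  def __call__(self, x):
    # Toy residual block standing in for FluxTransformerBlock.
    return x + nn.Dense(self.dim)(x)


class Stack(nn.Module):
  num_layers: int
  dim: int

  def setup(self):
    # Each list element is registered as blocks_0, blocks_1, ... in params.
    self.blocks = [Block(dim=self.dim) for _ in range(self.num_layers)]

  def __call__(self, x):
    for block in self.blocks:  # same explicit-loop pattern as the diff
      x = block(x)
    return x


model = Stack(num_layers=3, dim=8)
x = jnp.ones((2, 8))
params = model.init(jax.random.PRNGKey(0), x)
print(jax.tree_util.tree_map(jnp.shape, params))  # shows blocks_0 .. blocks_2
y = model.apply(params, x)
```
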
@@ -509,18 +504,19 @@ def __call__(
     image_rotary_emb = self.pe_embedder(ids)
     image_rotary_emb = nn.with_logical_constraint(image_rotary_emb, ("activation_batch", "activation_embed"))

-    hidden_states, encoder_hidden_states, temb, image_rotary_emb = self.double_blocks(
-      hidden_states=hidden_states,
-      encoder_hidden_states=encoder_hidden_states,
-      temb=temb,
-      image_rotary_emb=image_rotary_emb,
-    )
+    for double_block in self.double_blocks:
+      hidden_states, encoder_hidden_states, temb, image_rotary_emb = double_block(
+        hidden_states=hidden_states,
+        encoder_hidden_states=encoder_hidden_states,
+        temb=temb,
+        image_rotary_emb=image_rotary_emb,
+      )
     hidden_states = jnp.concatenate([encoder_hidden_states, hidden_states], axis=1)
     hidden_states = nn.with_logical_constraint(hidden_states, ("activation_batch", "activation_length", "activation_embed"))
-
-    hidden_states, temb, image_rotary_emb = self.single_blocks(
-      hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb
-    )
+    for single_block in self.single_blocks:
+      hidden_states, temb, image_rotary_emb = single_block(
+        hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb
+      )
     hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

     hidden_states = self.norm_out(hidden_states, temb)
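
A note on the design choice, as an observation rather than anything stated in the diff: under `jax.jit` the Python `for` loop is unrolled at trace time, so the looped form compiles to the same per-layer ops the `nn.Sequential` version did. The practical win is readability and the ability to thread several values (`hidden_states`, `encoder_hidden_states`, `temb`, `image_rotary_emb`) through each block with explicit keyword arguments, which `nn.Sequential`'s chained calling convention handles awkwardly. The change may also rename the blocks in the parameter tree (e.g. to `double_blocks_0`, `double_blocks_1`, ...), which could matter when loading checkpoints saved before this commit.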