@@ -456,25 +456,25 @@ def __init__(
456456 precision : jax .lax .Precision = None ,
457457 ):
458458 out_channels = out_channels or in_channels
459-
460- self . resnets = nnx . List (
461- [
462- LTX2VideoResnetBlock3d (
463- in_channels = in_channels ,
464- out_channels = in_channels ,
465- dropout = dropout ,
466- eps = resnet_eps ,
467- non_linearity = resnet_act_fn ,
468- spatial_padding_mode = spatial_padding_mode ,
469- rngs = rngs ,
470- mesh = mesh ,
471- dtype = dtype ,
472- weights_dtype = weights_dtype ,
473- precision = precision ,
474- )
475- for _ in range ( num_layers )
476- ]
477- )
459+ self . num_layers = num_layers
460+
461+ @ nnx . split_rngs ( splits = num_layers )
462+ @ nnx . vmap ( in_axes = 0 , out_axes = 0 , axis_size = num_layers )
463+ def create_resnets ( rngs ):
464+ return LTX2VideoResnetBlock3d (
465+ in_channels = in_channels ,
466+ out_channels = in_channels ,
467+ dropout = dropout ,
468+ eps = resnet_eps ,
469+ non_linearity = resnet_act_fn ,
470+ spatial_padding_mode = spatial_padding_mode ,
471+ rngs = rngs ,
472+ mesh = mesh ,
473+ dtype = dtype ,
474+ weights_dtype = weights_dtype ,
475+ precision = precision ,
476+ )
477+ self . resnets = create_resnets ( rngs )
478478
479479 self .downsamplers = nnx .List ([])
480480 if spatio_temporal_scale :
@@ -544,11 +544,22 @@ def __call__(
544544 causal : bool = True ,
545545 deterministic : bool = True ,
546546 ) -> jax .Array :
547- for resnet in self .resnets :
548- subkey = None
549- if key is not None :
550- key , subkey = jax .random .split (key )
551- hidden_states = resnet (hidden_states , temb = temb , key = subkey , causal = causal , deterministic = deterministic )
547+
548+ subkeys = None
549+ if key is not None :
550+ subkeys = jax .random .split (key , self .num_layers )
551+
552+ def resnet_scan_fn (hidden_states , args ):
553+ resnet , subkey = args
554+ hidden_states = resnet (hidden_states , temb = temb , key = subkey , causal = causal , deterministic = deterministic )
555+ return hidden_states , None
556+
557+ hidden_states , _ = nnx .scan (
558+ resnet_scan_fn ,
559+ length = self .num_layers ,
560+ in_axes = (nnx .Carry , 0 ), # Scan over 0-th dim of input tuple
561+ out_axes = (nnx .Carry , 0 ),
562+ )(hidden_states , (self .resnets , subkeys ))
552563
553564 for downsampler in self .downsamplers :
554565 hidden_states = downsampler (hidden_states , causal = causal )
@@ -588,26 +599,27 @@ def __init__(
588599 else :
589600 self .time_embedder = None
590601
591- self .resnets = nnx .List (
592- [
593- LTX2VideoResnetBlock3d (
594- in_channels = in_channels ,
595- out_channels = in_channels ,
596- dropout = dropout ,
597- eps = resnet_eps ,
598- non_linearity = resnet_act_fn ,
599- inject_noise = inject_noise ,
600- timestep_conditioning = timestep_conditioning ,
601- spatial_padding_mode = spatial_padding_mode ,
602- rngs = rngs ,
603- mesh = mesh ,
604- dtype = dtype ,
605- weights_dtype = weights_dtype ,
606- precision = precision ,
607- )
608- for _ in range (num_layers )
609- ]
610- )
602+ self .num_layers = num_layers
603+
604+ @nnx .split_rngs (splits = num_layers )
605+ @nnx .vmap (in_axes = 0 , out_axes = 0 , axis_size = num_layers )
606+ def create_resnets (rngs ):
607+ return LTX2VideoResnetBlock3d (
608+ in_channels = in_channels ,
609+ out_channels = in_channels ,
610+ dropout = dropout ,
611+ eps = resnet_eps ,
612+ non_linearity = resnet_act_fn ,
613+ inject_noise = inject_noise ,
614+ timestep_conditioning = timestep_conditioning ,
615+ spatial_padding_mode = spatial_padding_mode ,
616+ rngs = rngs ,
617+ mesh = mesh ,
618+ dtype = dtype ,
619+ weights_dtype = weights_dtype ,
620+ precision = precision ,
621+ )
622+ self .resnets = create_resnets (rngs )
611623
612624 def __call__ (
613625 self ,
@@ -621,12 +633,21 @@ def __call__(
621633 temb = self .time_embedder (timestep = temb .flatten (), hidden_dtype = hidden_states .dtype )
622634 temb = temb .reshape (temb .shape [0 ], 1 , 1 , 1 , - 1 )
623635
624- for resnet in self .resnets :
625- subkey = None
626- if key is not None :
627- key , subkey = jax .random .split (key )
628-
629- hidden_states = resnet (hidden_states , temb = temb , key = subkey , causal = causal , deterministic = deterministic )
636+ subkeys = None
637+ if key is not None :
638+ subkeys = jax .random .split (key , self .num_layers )
639+
640+ def resnet_scan_fn (hidden_states , args ):
641+ resnet , subkey = args
642+ hidden_states = resnet (hidden_states , temb = temb , key = subkey , causal = causal , deterministic = deterministic )
643+ return hidden_states , None
644+
645+ hidden_states , _ = nnx .scan (
646+ resnet_scan_fn ,
647+ length = self .num_layers ,
648+ in_axes = (nnx .Carry , 0 ),
649+ out_axes = (nnx .Carry , 0 ),
650+ )(hidden_states , (self .resnets , subkeys ))
630651
631652 return hidden_states
632653
@@ -688,43 +709,43 @@ def __init__(
688709 )
689710 )
690711
691- self .upsamplers = nnx .List ([])
692712 if spatio_temporal_scale :
693- self .upsamplers .append (
694- LTXVideoUpsampler3d (
695- in_channels = out_channels * upscale_factor ,
696- stride = (2 , 2 , 2 ),
697- residual = upsample_residual ,
698- upscale_factor = upscale_factor ,
699- spatial_padding_mode = spatial_padding_mode ,
700- rngs = rngs ,
701- mesh = mesh ,
702- dtype = dtype ,
703- weights_dtype = weights_dtype ,
704- precision = precision ,
705- )
713+ self .upsampler = LTXVideoUpsampler3d (
714+ in_channels = out_channels * upscale_factor ,
715+ stride = (2 , 2 , 2 ),
716+ residual = upsample_residual ,
717+ upscale_factor = upscale_factor ,
718+ spatial_padding_mode = spatial_padding_mode ,
719+ rngs = rngs ,
720+ mesh = mesh ,
721+ dtype = dtype ,
722+ weights_dtype = weights_dtype ,
723+ precision = precision ,
706724 )
707-
708- self .resnets = nnx .List (
709- [
710- LTX2VideoResnetBlock3d (
711- in_channels = out_channels ,
712- out_channels = out_channels ,
713- dropout = dropout ,
714- eps = resnet_eps ,
715- non_linearity = resnet_act_fn ,
716- inject_noise = inject_noise ,
717- timestep_conditioning = timestep_conditioning ,
718- spatial_padding_mode = spatial_padding_mode ,
719- rngs = rngs ,
720- mesh = mesh ,
721- dtype = dtype ,
722- weights_dtype = weights_dtype ,
723- precision = precision ,
724- )
725- for _ in range (num_layers )
726- ]
727- )
725+ else :
726+ self .upsampler = None
727+
728+ self .num_layers = num_layers
729+
730+ @nnx .split_rngs (splits = num_layers )
731+ @nnx .vmap (in_axes = 0 , out_axes = 0 , axis_size = num_layers )
732+ def create_resnets (rngs ):
733+ return LTX2VideoResnetBlock3d (
734+ in_channels = out_channels ,
735+ out_channels = out_channels ,
736+ dropout = dropout ,
737+ eps = resnet_eps ,
738+ non_linearity = resnet_act_fn ,
739+ inject_noise = inject_noise ,
740+ timestep_conditioning = timestep_conditioning ,
741+ spatial_padding_mode = spatial_padding_mode ,
742+ rngs = rngs ,
743+ mesh = mesh ,
744+ dtype = dtype ,
745+ weights_dtype = weights_dtype ,
746+ precision = precision ,
747+ )
748+ self .resnets = create_resnets (rngs )
728749
729750 def __call__ (
730751 self ,
@@ -744,15 +765,24 @@ def __call__(
744765 temb = self .time_embedder (timestep = temb .flatten (), hidden_dtype = hidden_states .dtype )
745766 temb = temb .reshape (temb .shape [0 ], 1 , 1 , 1 , - 1 )
746767
747- for upsampler in self .upsamplers :
748- hidden_states = upsampler (hidden_states , causal = causal )
749-
750- for resnet in self .resnets :
751- subkey = None
752- if key is not None :
753- key , subkey = jax .random .split (key )
768+ if self .upsampler is not None :
769+ hidden_states = self .upsampler (hidden_states , causal = causal )
754770
755- hidden_states = resnet (hidden_states , temb = temb , key = subkey , causal = causal , deterministic = deterministic )
771+ subkeys = None
772+ if key is not None :
773+ subkeys = jax .random .split (key , self .num_layers )
774+
775+ def resnet_scan_fn (hidden_states , args ):
776+ resnet , subkey = args
777+ hidden_states = resnet (hidden_states , temb = temb , key = subkey , causal = causal , deterministic = deterministic )
778+ return hidden_states , None
779+
780+ hidden_states , _ = nnx .scan (
781+ resnet_scan_fn ,
782+ length = self .num_layers ,
783+ in_axes = (nnx .Carry , 0 ),
784+ out_axes = (nnx .Carry , 0 ),
785+ )(hidden_states , (self .resnets , subkeys ))
756786
757787 return hidden_states
758788
0 commit comments