@@ -461,19 +461,20 @@ def __init__(
461461 @nnx .split_rngs (splits = num_layers )
462462 @nnx .vmap (in_axes = 0 , out_axes = 0 , axis_size = num_layers )
463463 def create_resnets (rngs ):
464- return LTX2VideoResnetBlock3d (
465- in_channels = in_channels ,
466- out_channels = in_channels ,
467- dropout = dropout ,
468- eps = resnet_eps ,
469- non_linearity = resnet_act_fn ,
470- spatial_padding_mode = spatial_padding_mode ,
471- rngs = rngs ,
472- mesh = mesh ,
473- dtype = dtype ,
474- weights_dtype = weights_dtype ,
475- precision = precision ,
476- )
464+ return LTX2VideoResnetBlock3d (
465+ in_channels = in_channels ,
466+ out_channels = in_channels ,
467+ dropout = dropout ,
468+ eps = resnet_eps ,
469+ non_linearity = resnet_act_fn ,
470+ spatial_padding_mode = spatial_padding_mode ,
471+ rngs = rngs ,
472+ mesh = mesh ,
473+ dtype = dtype ,
474+ weights_dtype = weights_dtype ,
475+ precision = precision ,
476+ )
477+
477478 self .resnets = create_resnets (rngs )
478479
479480 self .downsamplers = nnx .List ([])
@@ -544,20 +545,19 @@ def __call__(
544545 causal : bool = True ,
545546 deterministic : bool = True ,
546547 ) -> jax .Array :
547-
548548 subkeys = None
549549 if key is not None :
550- subkeys = jax .random .split (key , self .num_layers )
551-
550+ subkeys = jax .random .split (key , self .num_layers )
551+
552552 def resnet_scan_fn (hidden_states , args ):
553- resnet , subkey = args
554- hidden_states = resnet (hidden_states , temb = temb , key = subkey , causal = causal , deterministic = deterministic )
555- return hidden_states , None
556-
553+ resnet , subkey = args
554+ hidden_states = resnet (hidden_states , temb = temb , key = subkey , causal = causal , deterministic = deterministic )
555+ return hidden_states , None
556+
557557 hidden_states , _ = nnx .scan (
558558 resnet_scan_fn ,
559559 length = self .num_layers ,
560- in_axes = (nnx .Carry , 0 ), # Scan over 0-th dim of input tuple
560+ in_axes = (nnx .Carry , 0 ), # Scan over 0-th dim of input tuple
561561 out_axes = (nnx .Carry , 0 ),
562562 )(hidden_states , (self .resnets , subkeys ))
563563
@@ -604,21 +604,22 @@ def __init__(
604604 @nnx .split_rngs (splits = num_layers )
605605 @nnx .vmap (in_axes = 0 , out_axes = 0 , axis_size = num_layers )
606606 def create_resnets (rngs ):
607- return LTX2VideoResnetBlock3d (
608- in_channels = in_channels ,
609- out_channels = in_channels ,
610- dropout = dropout ,
611- eps = resnet_eps ,
612- non_linearity = resnet_act_fn ,
613- inject_noise = inject_noise ,
614- timestep_conditioning = timestep_conditioning ,
615- spatial_padding_mode = spatial_padding_mode ,
616- rngs = rngs ,
617- mesh = mesh ,
618- dtype = dtype ,
619- weights_dtype = weights_dtype ,
620- precision = precision ,
621- )
607+ return LTX2VideoResnetBlock3d (
608+ in_channels = in_channels ,
609+ out_channels = in_channels ,
610+ dropout = dropout ,
611+ eps = resnet_eps ,
612+ non_linearity = resnet_act_fn ,
613+ inject_noise = inject_noise ,
614+ timestep_conditioning = timestep_conditioning ,
615+ spatial_padding_mode = spatial_padding_mode ,
616+ rngs = rngs ,
617+ mesh = mesh ,
618+ dtype = dtype ,
619+ weights_dtype = weights_dtype ,
620+ precision = precision ,
621+ )
622+
622623 self .resnets = create_resnets (rngs )
623624
624625 def __call__ (
@@ -635,12 +636,12 @@ def __call__(
635636
636637 subkeys = None
637638 if key is not None :
638- subkeys = jax .random .split (key , self .num_layers )
639-
639+ subkeys = jax .random .split (key , self .num_layers )
640+
640641 def resnet_scan_fn (hidden_states , args ):
641- resnet , subkey = args
642- hidden_states = resnet (hidden_states , temb = temb , key = subkey , causal = causal , deterministic = deterministic )
643- return hidden_states , None
642+ resnet , subkey = args
643+ hidden_states = resnet (hidden_states , temb = temb , key = subkey , causal = causal , deterministic = deterministic )
644+ return hidden_states , None
644645
645646 hidden_states , _ = nnx .scan (
646647 resnet_scan_fn ,
@@ -730,21 +731,22 @@ def __init__(
730731 @nnx .split_rngs (splits = num_layers )
731732 @nnx .vmap (in_axes = 0 , out_axes = 0 , axis_size = num_layers )
732733 def create_resnets (rngs ):
733- return LTX2VideoResnetBlock3d (
734- in_channels = out_channels ,
735- out_channels = out_channels ,
736- dropout = dropout ,
737- eps = resnet_eps ,
738- non_linearity = resnet_act_fn ,
739- inject_noise = inject_noise ,
740- timestep_conditioning = timestep_conditioning ,
741- spatial_padding_mode = spatial_padding_mode ,
742- rngs = rngs ,
743- mesh = mesh ,
744- dtype = dtype ,
745- weights_dtype = weights_dtype ,
746- precision = precision ,
747- )
734+ return LTX2VideoResnetBlock3d (
735+ in_channels = out_channels ,
736+ out_channels = out_channels ,
737+ dropout = dropout ,
738+ eps = resnet_eps ,
739+ non_linearity = resnet_act_fn ,
740+ inject_noise = inject_noise ,
741+ timestep_conditioning = timestep_conditioning ,
742+ spatial_padding_mode = spatial_padding_mode ,
743+ rngs = rngs ,
744+ mesh = mesh ,
745+ dtype = dtype ,
746+ weights_dtype = weights_dtype ,
747+ precision = precision ,
748+ )
749+
748750 self .resnets = create_resnets (rngs )
749751
750752 def __call__ (
@@ -770,12 +772,12 @@ def __call__(
770772
771773 subkeys = None
772774 if key is not None :
773- subkeys = jax .random .split (key , self .num_layers )
774-
775+ subkeys = jax .random .split (key , self .num_layers )
776+
775777 def resnet_scan_fn (hidden_states , args ):
776- resnet , subkey = args
777- hidden_states = resnet (hidden_states , temb = temb , key = subkey , causal = causal , deterministic = deterministic )
778- return hidden_states , None
778+ resnet , subkey = args
779+ hidden_states = resnet (hidden_states , temb = temb , key = subkey , causal = causal , deterministic = deterministic )
780+ return hidden_states , None
779781
780782 hidden_states , _ = nnx .scan (
781783 resnet_scan_fn ,
0 commit comments