# Copyright 2025 Lightricks Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://github.com/Lightricks/LTX-Video/blob/main/LICENSE
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This implementation is based on the Torch version available at:
# https://github.com/Lightricks/LTX-Video/tree/main
from dataclasses import field
from typing import Any, Callable, Dict, List, Optional, Tuple

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.linen import partitioning


class RepeatableCarryBlock(nn.Module):
    """
    Wraps an input module in the JAX scan carry format.

    Ergo, the module assumes the role of a repeated building block:
    it consumes a (data, index) carry and returns the updated carry
    that is threaded through all blocks.
    """

    module: Callable[[Any], nn.Module]
    module_init_args: List[Any]
    module_init_kwargs: Dict[str, Any]

    @nn.compact
    def __call__(self, carry: Tuple[jax.Array, jax.Array], *block_args) -> Tuple[Tuple[jax.Array, jax.Array], None]:
        """
        jax carry-op format of the block.

        The carry holds the input tensor, which changes between cycles,
        along with the current block index; ``block_args`` are assumed to
        have a static role and are broadcast unchanged to every block.

        Returns:
            Tuple[Tuple[jax.Array, jax.Array], None]: The updated
            (data, index) carry, and None (no per-step scan output).
        """
        data_input, index_input = carry

        mod = self.module(*self.module_init_args, **self.module_init_kwargs)

        # block_args are the static arguments passed to each individual block
        output_data = mod(index_input, data_input, *block_args)

        next_index = index_input + 1
        new_carry = (output_data, next_index)

        return new_carry, None

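
# A minimal sketch of a block module compatible with RepeatableCarryBlock.
# `_ToyBlock` is a hypothetical example, not part of LTX-Video: it only
# illustrates the expected call signature mod(index, data, *block_args),
# where `index` is the layer index carried through the scan and any extra
# arguments (here `scale`) are static values broadcast to every layer.
class _ToyBlock(nn.Module):
    features: int

    @nn.compact
    def __call__(self, index: jax.Array, data: jax.Array, scale: jax.Array) -> jax.Array:
        del index  # the layer index is available for per-layer behavior; unused here
        hidden = nn.Dense(self.features)(data)
        # Residual update; assumes `features` matches the last dim of `data`
        return data + scale * nn.relu(hidden)
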
class RepeatableLayer(nn.Module):
    """
    RepeatableLayer assumes a role similar to torch.nn.ModuleList,
    with the condition that every block shares the same graph and only the parameters differ.

    The compilation in RepeatableLayer happens only once, in contrast to compiling the graph once per repeated block.
    """

    module: Callable[[Any], nn.Module]
    """
    A callable for single block construction
    """

    num_layers: int
    """
    The number of blocks to build
    """

    module_init_args: List[Any] = field(default_factory=list)
    """
    args passed to the RepeatableLayer.module callable, to support block construction
    """

    module_init_kwargs: Dict[str, Any] = field(default_factory=dict)
    """
    kwargs passed to the RepeatableLayer.module callable, to support block construction
    """

    pspec_name: Optional[str] = None
    """
    Partition spec metadata
    """

    param_scan_axis: int = 0
    """
    The axis along which the per-layer parameters are aggregated,
    e.g. if a kernel is shaped (8, 16),
    N layers will be stored as (N, 8, 16) if param_scan_axis=0
    and (8, N, 16) if param_scan_axis=1
    """

    @nn.compact
    def __call__(self, *args):
        if not args:
            raise ValueError("RepeatableLayer expects at least one argument for the initial data input.")

        initial_data_input = args[0]  # The first element is the main data input
        static_block_args = args[1:]  # Any subsequent elements are static args for each block

        initial_index = jnp.array(0, dtype=jnp.int32)

        scan_kwargs = {}
        if self.pspec_name is not None:
            scan_kwargs["metadata_params"] = {nn.PARTITION_NAME: self.pspec_name}

        initializing = self.is_mutable_collection("params")
        params_spec = self.param_scan_axis if initializing else partitioning.ScanIn(self.param_scan_axis)

        # in_axes covers only the scanned inputs that follow the carry
        # (nn.scan handles the carry itself): broadcast each of the
        # static_block_args unchanged to every layer.
        in_axes_for_scan = (nn.broadcast,) * len(static_block_args)

        scan_fn = nn.scan(
            RepeatableCarryBlock,
            variable_axes={
                "params": params_spec,
                "cache": 0,
                "intermediates": 0,
                "aqt": 0,
                "_overwrite_with_gradient": 0,
            },  # Separate params per layer
            split_rngs={"params": True},
            in_axes=in_axes_for_scan,
            length=self.num_layers,
            **scan_kwargs,
        )

        wrapped_function = scan_fn(self.module, self.module_init_args, self.module_init_kwargs)

        # Run the stack: thread the (data, index) carry through num_layers blocks
        (final_data, final_index), _ = wrapped_function((initial_data_input, initial_index), *static_block_args)

        # Only the final data output from the sequence of layers is returned
        return final_data
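

# A minimal usage sketch (illustrative assumption, not part of the library):
# stack four copies of the hypothetical _ToyBlock defined above and push data
# through them; the shapes and the `scale` argument are example choices.
if __name__ == "__main__":
    layer = RepeatableLayer(
        module=_ToyBlock,
        num_layers=4,
        module_init_kwargs={"features": 16},
    )
    data = jnp.ones((2, 16))
    scale = jnp.array(0.1)
    variables = layer.init(jax.random.PRNGKey(0), data, scale)
    out = layer.apply(variables, data, scale)
    print(out.shape)  # (2, 16): same shape as `data`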