@@ -140,9 +140,6 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
140140 else :
141141 ltx2_config = LTX2VideoTransformer3DModel .load_config (config .pretrained_model_name_or_path , subfolder = subfolder )
142142
143- # Align RoPE type with connectors
144- ltx2_config ["rope_type" ] = "split"
145-
146143 if ltx2_config .get ("activation_fn" ) == "gelu-approximate" :
147144 ltx2_config ["activation_fn" ] = "gelu"
148145
@@ -157,13 +154,6 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
157154 ltx2_config ["remat_policy" ] = config .remat_policy
158155 ltx2_config ["names_which_can_be_saved" ] = config .names_which_can_be_saved
159156 ltx2_config ["names_which_can_be_offloaded" ] = config .names_which_can_be_offloaded
160- ltx2_config ["use_prompt_embeddings" ] = True
161-
162- if getattr (config , "model_name" , "" ) == "ltx2.3" :
163- ltx2_config ["gated_attn" ] = True
164- ltx2_config ["cross_attn_mod" ] = True
165- ltx2_config ["perturbed_attn" ] = True
166- ltx2_config ["use_prompt_embeddings" ] = False
167157
168158 # 2. eval_shape
169159 p_model_factory = partial (create_model , ltx2_config = ltx2_config )
@@ -184,25 +174,13 @@ def create_model(rngs: nnx.Rngs, ltx2_config: dict):
184174 else :
185175 params = restored_checkpoint ["ltx2_state" ]
186176 else :
187- filename = "ltx-2.3-22b-dev.safetensors" if getattr (config , "model_name" , "" ) == "ltx2.3" else None
188- subfolder = "" if getattr (config , "model_name" , "" ) == "ltx2.3" else subfolder
189-
190- if tensors is not None and getattr (config , "model_name" , "" ) == "ltx2.3" :
191- from maxdiffusion .models .ltx2 .ltx2_3_utils import load_transformer_weights_2_3
192- params = load_transformer_weights_2_3 (
193- params , # eval_shapes
194- "cpu" ,
195- tensors ,
196- scan_layers = getattr (config , "scan_layers" , True ),
197- )
198- else :
199- params = load_transformer_weights (
200- config .pretrained_model_name_or_path ,
201- params , # eval_shapes
202- "cpu" ,
203- scan_layers = getattr (config , "scan_layers" , True ),
204- subfolder = subfolder ,
205- )
177+ params = load_transformer_weights (
178+ config .pretrained_model_name_or_path ,
179+ params , # eval_shapes
180+ "cpu" ,
181+ scan_layers = getattr (config , "scan_layers" , True ),
182+ subfolder = subfolder ,
183+ )
206184
207185 params = jax .tree_util .tree_map_with_path (
208186 lambda path , x : cast_with_exclusion (path , x , dtype_to_cast = config .weights_dtype ), params
0 commit comments