@@ -85,97 +85,18 @@ def load_ltx2_configs_from_orbax(self, step: Optional[int]) -> Tuple[Optional[di
8585 max_logging .log (f"optimizer found in checkpoint { 'opt_state' in restored_checkpoint .ltx2_state .keys ()} " )
8686 return restored_checkpoint , step
8787
88- def load_diffusers_checkpoint (self ):
89- config = self .config
90- max_logging .log ("Loading LTX2 components from Hugging Face base models." )
91-
92- # 1. Tokenizer
93- max_logging .log ("Loading Gemma Tokenizer..." )
94- tokenizer = AutoTokenizer .from_pretrained (
95- config .pretrained_model_name_or_path ,
96- subfolder = "tokenizer" ,
97- )
98- # 3. Connectors
99- max_logging .log ("Loading Connectors..." )
100- connectors = LTX2AudioVideoGemmaTextEncoder .from_pretrained (
101- config .pretrained_model_name_or_path ,
102- subfolder = "connectors" ,
103- )
104-
105- # 4. Video VAE
106- max_logging .log ("Loading Video VAE..." )
107- vae = LTX2VideoAutoencoderKL .from_pretrained (
108- config .pretrained_model_name_or_path ,
109- subfolder = "vae" ,
110- )
111-
112- # 5. Audio VAE
113- max_logging .log ("Loading Audio VAE..." )
114- audio_vae = FlaxAutoencoderKLLTX2Audio .from_pretrained (
115- config .pretrained_model_name_or_path ,
116- subfolder = "audio_vae" ,
117- )
118-
119- # 6. Transformer
120- max_logging .log ("Loading Transformer..." )
121- # NOTE: Transformer weights are usually sharded and loaded separately in generation scripts
122- # This just instantiates the architecture wrapper or loads full weights.
123- # In MaxDiffusion we typically let the pipeline or generation script handle sharded loading
124- # but we load the raw config/eval shape here.
125- transformer = LTX2VideoTransformer3DModel .from_pretrained (
126- config .pretrained_model_name_or_path ,
127- subfolder = "transformer" ,
128- )
129-
130- # 7. Vocoder
131- max_logging .log ("Loading Vocoder..." )
132- vocoder = LTX2Vocoder .from_pretrained (
133- config .pretrained_model_name_or_path ,
134- subfolder = "vocoder" ,
135- )
136-
137- # 8. Scheduler
138- max_logging .log ("Loading Scheduler..." )
139- scheduler = FlaxFlowMatchScheduler .from_pretrained (
140- config .pretrained_model_name_or_path ,
141- subfolder = "scheduler" ,
142- )
143- # 2. Text Encoder (PyTorch)
144- max_logging .log ("Loading Gemma3 Text Encoder..." )
145- text_encoder = Gemma3ForConditionalGeneration .from_pretrained (
146- config .pretrained_model_name_or_path ,
147- subfolder = "text_encoder" ,
148- torch_dtype = torch .bfloat16 ,
149- )
150- text_encoder .eval ()
151-
152-
153-
154- pipeline = LTX2Pipeline (
155- scheduler = scheduler ,
156- vae = vae ,
157- audio_vae = audio_vae ,
158- text_encoder = text_encoder ,
159- tokenizer = tokenizer ,
160- connectors = connectors ,
161- transformer = transformer ,
162- vocoder = vocoder ,
163- )
164-
165- return pipeline
166-
167- def load_checkpoint (self , step = None ) -> Tuple [LTX2Pipeline , Optional [dict ], Optional [int ]]:
88+ def load_checkpoint (self , step = None , vae_only = False , load_transformer = True ) -> Tuple [LTX2Pipeline , Optional [dict ], Optional [int ]]:
16889 restored_checkpoint , step = self .load_ltx2_configs_from_orbax (step )
16990 opt_state = None
91+
17092 if restored_checkpoint :
171- max_logging .log ("Loading LTX2 pipeline from checkpoint (TODO: implement fully if needed)" )
172- # pipeline = LTX2Pipeline.from_checkpoint(self.config, restored_checkpoint)
173- # if "opt_state" in restored_checkpoint.ltx2_state.keys():
174- # opt_state = restored_checkpoint.ltx2_state["opt_state"]
175- pipeline = self .load_diffusers_checkpoint () # Fallback for now
93+ max_logging .log ("Loading LTX2 pipeline from checkpoint" )
94+ pipeline = LTX2Pipeline .from_checkpoint (self .config , restored_checkpoint , vae_only , load_transformer )
95+ if "opt_state" in restored_checkpoint .ltx2_state .keys ():
96+ opt_state = restored_checkpoint .ltx2_state ["opt_state" ]
17697 else :
177- max_logging .log ("No checkpoint found, loading default pipeline. " )
178- pipeline = self . load_diffusers_checkpoint ( )
98+ max_logging .log ("No checkpoint found, loading pipeline from pretrained hub " )
99+ pipeline = LTX2Pipeline . from_pretrained ( self . config , vae_only , load_transformer )
179100
180101 return pipeline , opt_state , step
181102