Skip to content

Commit dcee8e8

Browse files
committed
fix
1 parent bf69bdf commit dcee8e8

1 file changed

Lines changed: 8 additions & 8 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
286286

287287
p_model_factory = partial(create_model, config=config)
288288
connectors = nnx.eval_shape(p_model_factory, rngs=rngs)
289-
graphdef, state, rest_of_state = nnx.split(connectors, nnx.Param, ...)
289+
graphdef, state = nnx.split(connectors, nnx.Param)
290290

291291
logical_state_spec = nnx.get_partition_spec(state)
292292
logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
@@ -307,7 +307,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
307307
state[path].value = jax.device_put(val)
308308

309309
state = nnx.from_flat_state(state)
310-
connectors = nnx.merge(graphdef, state, rest_of_state)
310+
connectors = nnx.merge(graphdef, state)
311311
return connectors
312312

313313
@classmethod
@@ -326,7 +326,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
326326

327327
p_model_factory = partial(create_model, config=config)
328328
vae = nnx.eval_shape(p_model_factory, rngs=rngs)
329-
graphdef, state, rest_of_state = nnx.split(vae, nnx.Param, ...)
329+
graphdef, state = nnx.split(vae, nnx.Param)
330330

331331
logical_state_spec = nnx.get_partition_spec(state)
332332
logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
@@ -353,7 +353,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
353353
state[path].value = jax.device_put(val)
354354

355355
state = nnx.from_flat_state(state)
356-
vae = nnx.merge(graphdef, state, rest_of_state)
356+
vae = nnx.merge(graphdef, state)
357357
return vae
358358

359359
@classmethod
@@ -372,7 +372,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
372372

373373
p_model_factory = partial(create_model, config=config)
374374
audio_vae = nnx.eval_shape(p_model_factory, rngs=rngs)
375-
graphdef, state, rest_of_state = nnx.split(audio_vae, nnx.Param, ...)
375+
graphdef, state = nnx.split(audio_vae, nnx.Param)
376376

377377
logical_state_spec = nnx.get_partition_spec(state)
378378
logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
@@ -399,7 +399,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
399399
state[path].value = jax.device_put(val)
400400

401401
state = nnx.from_flat_state(state)
402-
audio_vae = nnx.merge(graphdef, state, rest_of_state)
402+
audio_vae = nnx.merge(graphdef, state)
403403
return audio_vae
404404

405405
@classmethod
@@ -439,7 +439,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
439439

440440
p_model_factory = partial(create_model, config=config)
441441
vocoder = nnx.eval_shape(p_model_factory, rngs=rngs)
442-
graphdef, state, rest_of_state = nnx.split(vocoder, nnx.Param, ...)
442+
graphdef, state = nnx.split(vocoder, nnx.Param)
443443

444444
logical_state_spec = nnx.get_partition_spec(state)
445445
logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
@@ -460,7 +460,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
460460
state[path].value = jax.device_put(val)
461461

462462
state = nnx.from_flat_state(state)
463-
vocoder = nnx.merge(graphdef, state, rest_of_state)
463+
vocoder = nnx.merge(graphdef, state)
464464
return vocoder
465465

466466
@classmethod

0 commit comments

Comments
 (0)