Skip to content

Commit 365605c

Browse files
committed
fix
1 parent dcee8e8 commit 365605c

2 files changed

Lines changed: 47 additions & 9 deletions

File tree

src/maxdiffusion/pipelines/ltx2/ltx2_pipeline.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -247,9 +247,19 @@ def __init__(
247247

248248
# Initialize video processor
249249
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_compression_ratio)
250-
251250
self.tokenizer_max_length = getattr(self.tokenizer, "model_max_length", 1024)
252251

252+
@staticmethod
253+
def _init_dummy_shape(node):
254+
if isinstance(node, jax.ShapeDtypeStruct):
255+
if jax.dtypes.issubdtype(node.dtype, jax.dtypes.prng_key):
256+
dummy_key = jax.random.key(0)
257+
if node.shape == ():
258+
return dummy_key
259+
return jax.random.split(dummy_key, node.shape[0])
260+
return jnp.zeros(node.shape, dtype=node.dtype)
261+
return node
262+
253263
@classmethod
254264
def load_tokenizer(cls, config: HyperParameters):
255265
max_logging.log("Loading Gemma Tokenizer...")
@@ -286,7 +296,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
286296

287297
p_model_factory = partial(create_model, config=config)
288298
connectors = nnx.eval_shape(p_model_factory, rngs=rngs)
289-
graphdef, state = nnx.split(connectors, nnx.Param)
299+
graphdef, state, rest_of_state = nnx.split(connectors, nnx.Param, ...)
300+
rest_of_state = jax.tree_util.tree_map(cls._init_dummy_shape, rest_of_state)
290301

291302
logical_state_spec = nnx.get_partition_spec(state)
292303
logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
@@ -307,7 +318,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
307318
state[path].value = jax.device_put(val)
308319

309320
state = nnx.from_flat_state(state)
310-
connectors = nnx.merge(graphdef, state)
321+
connectors = nnx.merge(graphdef, state, rest_of_state)
311322
return connectors
312323

313324
@classmethod
@@ -326,7 +337,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
326337

327338
p_model_factory = partial(create_model, config=config)
328339
vae = nnx.eval_shape(p_model_factory, rngs=rngs)
329-
graphdef, state = nnx.split(vae, nnx.Param)
340+
graphdef, state, rest_of_state = nnx.split(vae, nnx.Param, ...)
341+
rest_of_state = jax.tree_util.tree_map(cls._init_dummy_shape, rest_of_state)
330342

331343
logical_state_spec = nnx.get_partition_spec(state)
332344
logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
@@ -353,7 +365,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
353365
state[path].value = jax.device_put(val)
354366

355367
state = nnx.from_flat_state(state)
356-
vae = nnx.merge(graphdef, state)
368+
vae = nnx.merge(graphdef, state, rest_of_state)
357369
return vae
358370

359371
@classmethod
@@ -372,7 +384,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
372384

373385
p_model_factory = partial(create_model, config=config)
374386
audio_vae = nnx.eval_shape(p_model_factory, rngs=rngs)
375-
graphdef, state = nnx.split(audio_vae, nnx.Param)
387+
graphdef, state, rest_of_state = nnx.split(audio_vae, nnx.Param, ...)
388+
rest_of_state = jax.tree_util.tree_map(cls._init_dummy_shape, rest_of_state)
376389

377390
logical_state_spec = nnx.get_partition_spec(state)
378391
logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
@@ -399,7 +412,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
399412
state[path].value = jax.device_put(val)
400413

401414
state = nnx.from_flat_state(state)
402-
audio_vae = nnx.merge(graphdef, state)
415+
audio_vae = nnx.merge(graphdef, state, rest_of_state)
403416
return audio_vae
404417

405418
@classmethod
@@ -439,7 +452,8 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
439452

440453
p_model_factory = partial(create_model, config=config)
441454
vocoder = nnx.eval_shape(p_model_factory, rngs=rngs)
442-
graphdef, state = nnx.split(vocoder, nnx.Param)
455+
graphdef, state, rest_of_state = nnx.split(vocoder, nnx.Param, ...)
456+
rest_of_state = jax.tree_util.tree_map(cls._init_dummy_shape, rest_of_state)
443457

444458
logical_state_spec = nnx.get_partition_spec(state)
445459
logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules)
@@ -460,7 +474,7 @@ def create_model(rngs: nnx.Rngs, config: HyperParameters):
460474
state[path].value = jax.device_put(val)
461475

462476
state = nnx.from_flat_state(state)
463-
vocoder = nnx.merge(graphdef, state)
477+
vocoder = nnx.merge(graphdef, state, rest_of_state)
464478
return vocoder
465479

466480
@classmethod

test_keys.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import jax
2+
import jax.numpy as jnp
3+
from flax import nnx
4+
5+
def fix_struct(x):
    """Materialize a ``jax.ShapeDtypeStruct`` as a concrete placeholder array.

    Args:
        x: Any value. Only ``jax.ShapeDtypeStruct`` instances are converted.

    Returns:
        A typed PRNG key array of shape ``x.shape`` (seeded with ``0``) when
        the struct has a PRNG-key dtype, a zero array of ``x.shape`` and
        ``x.dtype`` for any other struct, or ``x`` unchanged when it is not a
        ``ShapeDtypeStruct``.
    """
    if not isinstance(x, jax.ShapeDtypeStruct):
        return x

    if not jax.dtypes.issubdtype(x.dtype, jax.dtypes.prng_key):
        return jnp.zeros(x.shape, x.dtype)

    key = jax.random.key(0)
    if x.shape == ():
        return key

    # Split into the full requested shape rather than only shape[0], so
    # multi-dimensional key structs are honoured; 1-D results are unchanged.
    return jax.random.split(key, tuple(x.shape))
16+
17+
# Abstract descriptions of a 5-element PRNG key array and a 5-element
# uint32 counter array.
struct_key = jax.ShapeDtypeStruct(shape=(5,), dtype=jax.dtypes.prng_key)
struct_count = jax.ShapeDtypeStruct(shape=(5,), dtype=jnp.uint32)

# Materialize both descriptions into concrete arrays.
fixed_key, fixed_count = (fix_struct(spec) for spec in (struct_key, struct_count))

# Report the dtype and shape of each materialized array.
for label, materialized in (("Key dtype:", fixed_key), ("Count dtype:", fixed_count)):
    print(label, materialized.dtype, "shape:", materialized.shape)

0 commit comments

Comments
 (0)