Skip to content

Commit 2b0a783

Browse files
committed: "sharding attempt"
1 parent cd3616a commit 2b0a783

3 files changed

Lines changed: 36 additions & 11 deletions

File tree

src/maxdiffusion/configs/ltx2_video.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,6 @@ logical_axis_rules: [
6666
['conv_out', 'fsdp'],
6767
['conv_in', 'fsdp']
6868
]
69-
data_sharding: [['data', 'fsdp', 'context', 'tensor']]
7069
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
7170
dcn_fsdp_parallelism: -1
7271
dcn_context_parallelism: 1

src/maxdiffusion/models/ltx2/attention_ltx2.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -342,12 +342,24 @@ def __init__(
342342
self.dropout_rate = dropout
343343

344344
# 1. Projections
345-
self.to_q = nnx.Linear(query_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)
345+
self.to_q = nnx.Linear(
346+
query_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype,
347+
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "heads")),
348+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("heads",)),
349+
)
346350

347351
# Handle Self vs Cross Attention input dims
348352
kv_dim = context_dim if context_dim is not None else query_dim
349-
self.to_k = nnx.Linear(kv_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)
350-
self.to_v = nnx.Linear(kv_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype)
353+
self.to_k = nnx.Linear(
354+
kv_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype,
355+
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "heads")),
356+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("heads",)),
357+
)
358+
self.to_v = nnx.Linear(
359+
kv_dim, self.inner_dim, use_bias=bias, rngs=rngs, dtype=dtype,
360+
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("embed", "heads")),
361+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("heads",)),
362+
)
351363

352364
# 2. Normalization (Applied to full inner_dim, NOT per-head)
353365
self.norm_q = nnx.RMSNorm(
@@ -358,7 +370,11 @@ def __init__(
358370
)
359371

360372
# 3. Output
361-
self.to_out = nnx.Linear(self.inner_dim, query_dim, use_bias=out_bias, rngs=rngs, dtype=dtype)
373+
self.to_out = nnx.Linear(
374+
self.inner_dim, query_dim, use_bias=out_bias, rngs=rngs, dtype=dtype,
375+
kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), ("heads", "embed")),
376+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
377+
)
362378

363379
if self.dropout_rate > 0:
364380
self.dropout_layer = nnx.Dropout(self.dropout_rate, rngs=rngs)

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,8 @@ def __init__(
5757
use_bias=True,
5858
dtype=dtype,
5959
param_dtype=weights_dtype,
60-
kernel_init=nnx.initializers.zeros,
61-
bias_init=nnx.initializers.zeros,
60+
kernel_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed", "embed")),
61+
bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)),
6262
)
6363

6464
def __call__(
@@ -291,12 +291,22 @@ def __init__(
291291
key = rngs.params()
292292
k1, k2, k3, k4 = jax.random.split(key, 4)
293293

294-
self.scale_shift_table = nnx.Param(jax.random.normal(k1, (6, self.dim), dtype=weights_dtype) / jnp.sqrt(self.dim))
294+
self.scale_shift_table = nnx.Param(
295+
jax.random.normal(k1, (6, self.dim), dtype=weights_dtype) / jnp.sqrt(self.dim),
296+
kernel_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")),
297+
)
295298
self.audio_scale_shift_table = nnx.Param(
296-
jax.random.normal(k2, (6, audio_dim), dtype=weights_dtype) / jnp.sqrt(audio_dim)
299+
jax.random.normal(k2, (6, audio_dim), dtype=weights_dtype) / jnp.sqrt(audio_dim),
300+
kernel_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")),
301+
)
302+
self.video_a2v_cross_attn_scale_shift_table = nnx.Param(
303+
jax.random.normal(k3, (5, self.dim), dtype=weights_dtype),
304+
kernel_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")),
305+
)
306+
self.audio_a2v_cross_attn_scale_shift_table = nnx.Param(
307+
jax.random.normal(k4, (5, audio_dim), dtype=weights_dtype),
308+
kernel_init=nnx.with_partitioning(nnx.initializers.zeros, (None, "embed")),
297309
)
298-
self.video_a2v_cross_attn_scale_shift_table = nnx.Param(jax.random.normal(k3, (5, self.dim), dtype=weights_dtype))
299-
self.audio_a2v_cross_attn_scale_shift_table = nnx.Param(jax.random.normal(k4, (5, audio_dim), dtype=weights_dtype))
300310

301311
def __call__(
302312
self,

0 commit comments

Comments (0)