Skip to content

Commit e4e281c

Browse files
committed
feat(ltx2): pass sharding specs to text projections
1 parent aa526b1 commit e4e281c

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

src/maxdiffusion/models/ltx2/transformer_ltx2.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,13 +696,15 @@ def __init__(
696696
hidden_size=inner_dim,
697697
dtype=self.dtype,
698698
weights_dtype=self.weights_dtype,
699+
sharding_specs=self.sharding_specs,
699700
)
700701
self.audio_caption_projection = NNXPixArtAlphaTextProjection(
701702
rngs=rngs,
702703
in_features=self.caption_channels,
703704
hidden_size=audio_inner_dim,
704705
dtype=self.dtype,
705706
weights_dtype=self.weights_dtype,
707+
sharding_specs=self.sharding_specs,
706708
)
707709
# 3. Timestep Modulation Params and Embedding
708710
self.time_embed = LTX2AdaLayerNormSingle(

0 commit comments

Comments
 (0)