Skip to content

Commit 7e098c5

Browse files
committed
format fixed
1 parent 7bed4f9 commit 7e098c5

5 files changed

Lines changed: 33 additions & 52 deletions

File tree

src/maxdiffusion/configs/ltx_video.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,4 @@ cache_latents_text_encoder_outputs: True
6262
per_device_batch_size: 1
6363
compile_topology_num_slices: -1
6464
quantization_local_shard_count: -1
65-
jit_initializers: True
65+
jit_initializers: True

src/maxdiffusion/generate_ltx_video.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,13 @@
2020
import json
2121
from maxdiffusion.models.ltx_video.transformers.transformer3d import Transformer3DModel
2222
import os
23+
import functools
2324
import jax.numpy as jnp
2425
from maxdiffusion import pyconfig
2526
from maxdiffusion.max_utils import (
2627
create_device_mesh,
2728
)
29+
from jax.sharding import Mesh
2830

2931

3032
def validate_transformer_inputs(prompt_embeds, fractional_coords, latents, noise_cond):
@@ -38,7 +40,7 @@ def run(config):
3840
key = jax.random.PRNGKey(0)
3941

4042
devices_array = create_device_mesh(config)
41-
mesh = Mesh(devices_array, config.mesh_axes)
43+
mesh = Mesh(devices_array, config.mesh_axes) # noqa F841
4244

4345
batch_size, text_tokens, num_tokens, features = 4, 256, 2048, 128
4446
base_dir = os.path.dirname(__file__)
@@ -49,12 +51,10 @@ def run(config):
4951
model_config = json.load(f)
5052

5153
transformer = Transformer3DModel(**model_config, dtype=jnp.bfloat16, gradient_checkpointing="matmul_without_batch")
52-
transformer_param_shapes = transformer.init_weights(key, batch_size, text_tokens, num_tokens, features, eval_only=False)
54+
transformer_param_shapes = transformer.init_weights(key, batch_size, text_tokens, num_tokens, features, eval_only=False) # noqa F841
5355

5456
key, split_key = jax.random.split(key)
55-
56-
57-
weights_init_fn = functools.partial(
57+
weights_init_fn = functools.partial( # noqa F841
5858
transformer.init_weights, split_key, batch_size, text_tokens, num_tokens, features, eval_only=True
5959
)
6060

src/maxdiffusion/models/ltx_video/transformers/attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -438,7 +438,7 @@ def __call__(
438438
deterministic: bool = True,
439439
**cross_attention_kwargs,
440440
) -> jnp.ndarray:
441-
cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
441+
cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters} # noqa F821
442442
assert cross_attention_kwargs.get("scale", None) is None, "Not supported"
443443

444444
input_axis_names = ("activation_batch", "activation_length", "activation_embed")

src/maxdiffusion/models/ltx_video/transformers/transformer3d.py

Lines changed: 24 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -25,15 +25,13 @@ class Transformer3DModel(nn.Module):
2525
only_cross_attention: bool = False
2626
double_self_attention: bool = False
2727
upcast_attention: bool = False
28-
# 'single_scale_shift' or 'single_scale'
29-
adaptive_norm: str = "single_scale_shift"
28+
adaptive_norm: str = "single_scale_shift" # 'single_scale_shift' or 'single_scale'
3029
standardization_norm: str = "layer_norm" # 'layer_norm' or 'rms_norm'
3130
norm_elementwise_affine: bool = True
3231
norm_eps: float = 1e-5
3332
attention_type: str = "default"
3433
caption_channels: int = None
35-
# if True uses the TPU attention offload ('flash attention')
36-
use_tpu_flash_attention: bool = True
34+
use_tpu_flash_attention: bool = True # if True uses the TPU attention offload ('flash attention')
3735
qk_norm: Optional[str] = None
3836
positional_embedding_type: str = "rope"
3937
positional_embedding_theta: Optional[float] = None
@@ -98,7 +96,7 @@ def scale_shift_table_init(key):
9896
self.transformer_blocks = RepeatableLayer(
9997
RemattedBasicTransformerBlock,
10098
num_layers=self.num_layers,
101-
module_init_kwargs=dict(
99+
module_init_kwargs=dict( # noqa C408
102100
dim=self.inner_dim,
103101
num_attention_heads=self.num_attention_heads,
104102
attention_head_dim=self.attention_head_dim,
@@ -139,46 +137,30 @@ def scale_shift_table_init(key):
139137
matmul_precision=self.matmul_precision,
140138
)
141139

142-
def init_weights(self, key, batch_size, text_tokens, num_tokens, features, eval_only=True):
143-
144-
# bookkeeping, for convenient changes later
145-
latents_shape = (batch_size, num_tokens, features)
146-
fractional_cords_shape = (batch_size, 3, num_tokens)
147-
prompt_embeds_shape = (batch_size, text_tokens, features)
148-
noise_cond_shape = (batch_size, 1)
149-
latents_dtype = jnp.bfloat16
150-
fractional_coords_dtype = jnp.bfloat16
151-
prompt_embeds_dtype = jnp.bfloat16
152-
noise_cond_dtype = jnp.bfloat16
153-
154-
# initialize to random
155-
key, split_key = jax.random.split(key)
156-
prompt_embeds = jax.random.normal(split_key, shape=prompt_embeds_shape, dtype=latents_dtype)
157-
key, split_key = jax.random.split(key)
158-
fractional_coords = jax.random.normal(split_key, shape=fractional_cords_shape, dtype=fractional_coords_dtype)
159-
key, split_key = jax.random.split(key)
160-
latents = jax.random.normal(split_key, shape=latents_shape, dtype=prompt_embeds_dtype)
161-
key, split_key = jax.random.split(key)
162-
noise_cond = jax.random.normal(split_key, shape=noise_cond_shape, dtype=noise_cond_dtype)
163-
164-
key, split_key = jax.random.split(key)
140+
def init_weights(self, in_channels, key, caption_channels, eval_only=True):
141+
example_inputs = {}
142+
batch_size, num_tokens = 4, 256
143+
input_shapes = {
144+
"hidden_states": (batch_size, num_tokens, in_channels),
145+
"indices_grid": (batch_size, 3, num_tokens),
146+
"encoder_hidden_states": (batch_size, 128, caption_channels),
147+
"timestep": (batch_size, 256),
148+
"segment_ids": (batch_size, 256),
149+
"encoder_attention_segment_ids": (batch_size, 128),
150+
}
151+
for name, shape in input_shapes.items():
152+
example_inputs[name] = jnp.ones(
153+
shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool
154+
)
155+
165156
if eval_only:
166157
return jax.eval_shape(
167158
self.init,
168-
rngs={"params": split_key},
169-
hidden_states=latents,
170-
indices_grid=fractional_coords,
171-
encoder_hidden_states=prompt_embeds,
172-
timestep=noise_cond,
159+
key,
160+
**example_inputs,
173161
)["params"]
174162
else:
175-
return self.init(
176-
rngs={"params": split_key},
177-
hidden_states=latents,
178-
indices_grid=fractional_coords,
179-
encoder_hidden_states=prompt_embeds,
180-
timestep=noise_cond,
181-
)["params"]
163+
return self.init(key, **example_inputs)["params"]
182164

183165
def __call__(
184166
self,
@@ -271,8 +253,7 @@ def get_fractional_positions(self, indices_grid: jax.Array) -> jax.Array:
271253
@nn.compact
272254
def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]:
273255
source_dtype = indices_grid.dtype
274-
# We need full precision in the freqs_cis computation.
275-
dtype = jnp.float32
256+
dtype = jnp.float32 # We need full precision in the freqs_cis computation.
276257
dim = self.inner_dim
277258
theta = self.positional_embedding_theta
278259

@@ -294,8 +275,7 @@ def __call__(self, indices_grid: jax.Array) -> Tuple[jax.Array, jax.Array]:
294275
indices = indices * jnp.pi / 2
295276

296277
freqs = (indices * (jnp.expand_dims(fractional_positions, axis=-1) * 2 - 1)).swapaxes(-1, -2)
297-
# Flatten along axis 2
298-
freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1)
278+
freqs = freqs.reshape(freqs.shape[0], freqs.shape[1], -1) # Flatten along axis 2
299279

300280
cos_freq = jnp.cos(freqs).repeat(2, axis=-1)
301281
sin_freq = jnp.sin(freqs).repeat(2, axis=-1)

src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,5 +20,6 @@
2020
"positional_embedding_type": "rope",
2121
"positional_embedding_theta": 10000.0,
2222
"positional_embedding_max_pos": [20, 2048, 2048],
23-
"timestep_scale_multiplier": 1000
23+
"timestep_scale_multiplier": 1000,
24+
"in_channels": 128
2425
}

0 commit comments

Comments (0)