Skip to content

Commit 4bcffd1

Browse files
committed
later
1 parent a272d08 commit 4bcffd1

3 files changed

Lines changed: 8 additions & 13 deletions

File tree

src/maxdiffusion/models/ltx_video/transformers/transformer3d.py

Lines changed: 4 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -153,8 +153,7 @@ def scale_shift_table_init(key):
153153
weight_dtype=self.weight_dtype,
154154
matmul_precision=self.matmul_precision,
155155
)
156-
def init_weights(self, key, in_channels, caption_channels, eval_only=True):
157-
import pdb; pdb.set_trace()
156+
def init_weights(self, in_channels, key, caption_channels, eval_only=True):
158157
example_inputs = {}
159158
batch_size, num_tokens = 4, 256
160159
input_shapes = {
@@ -169,16 +168,15 @@ def init_weights(self, key, in_channels, caption_channels, eval_only=True):
169168
example_inputs[name] = jnp.ones(
170169
shape, dtype=jnp.float32 if name not in ["attention_mask", "encoder_attention_mask"] else jnp.bool
171170
)
172-
171+
173172
if eval_only:
174173
return jax.eval_shape(
175174
self.init,
176-
key, ##need to change?
175+
key,
177176
**example_inputs,
178177
)["params"]
179178
else:
180-
return self.init(key, **example_inputs)['params']
181-
179+
return self.init(key, **example_inputs)["params"]
182180
def create_skip_layer_mask(
183181
self,
184182
batch_size: int,

src/maxdiffusion/models/ltx_video/xora_v1.2-13B-balanced-128.json

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
{
2-
"ckpt_path": "",
2+
"ckpt_path": "/mnt/disks/diffusionproj/jax_weights",
33
"activation_fn": "gelu-approximate",
44
"attention_bias": true,
55
"attention_head_dim": 128,

src/maxdiffusion/pipelines/ltx_video/ltx_video_pipeline.py

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -226,12 +226,9 @@ def load_transformer(cls, config):
226226
**model_config, dtype=jnp.float32, gradient_checkpointing="matmul_without_batch", sharding_mesh=mesh)
227227

228228
weights_init_fn = functools.partial(
229-
transformer.init_weights,
230-
jax.random.PRNGKey(42),
231-
in_channels,
232-
model_config['caption_channels'],
233-
eval_only=True
229+
transformer.init_weights, in_channels, jax.random.PRNGKey(42), model_config["caption_channels"], eval_only=True
234230
)
231+
235232
absolute_ckpt_path = os.path.abspath(relative_ckpt_path)
236233

237234
checkpoint_manager = ocp.CheckpointManager(absolute_ckpt_path)
@@ -240,7 +237,7 @@ def load_transformer(cls, config):
240237
tx=None,
241238
config=config,
242239
mesh=mesh,
243-
weights_init_fn=None,
240+
weights_init_fn=weights_init_fn,
244241
checkpoint_manager=checkpoint_manager,
245242
checkpoint_item=" ",
246243
model_params=None,

0 commit comments

Comments
 (0)