Commit e9eb4ca

wip - adds trainer and attn changes.
1 parent 220f24b commit e9eb4ca

3 files changed

Lines changed: 15 additions & 11 deletions


src/maxdiffusion/data_preprocessing/wan_pusav1_to_tfrecords.py

Lines changed: 5 additions & 3 deletions
@@ -103,9 +103,11 @@ def generate_dataset(config):
     loaded_state_dict = torch.load(pth_path, map_location=torch.device('cpu'))
     prompt_embeds = loaded_state_dict["prompt_emb"]["context"]
     latent = loaded_state_dict["latents"]
-    # Format we want(4, 16, 1, 64, 64)
-    latent = jnp.array(latent.float().numpy(), dtype=config.weights_dtype)
-    prompt_embeds = jnp.array(prompt_embeds.float().numpy(), dtype=config.weights_dtype)
+
+    # Format we want (Batch, Channels, Frames, Height, Width)
+    # Save them as float32 because numpy cannot read bfloat16.
+    latent = jnp.array(latent.float().numpy(), dtype=jnp.float32)
+    prompt_embeds = jnp.array(prompt_embeds.float().numpy(), dtype=jnp.float32)
     writer.write(create_example(latent, prompt_embeds))
     shard_record_count += 1
     global_record_count += 1
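
Note: the change above stops serializing the latents and prompt embeddings in config.weights_dtype and always writes float32, because plain NumPy (which the TFRecord path goes through) has no native bfloat16 dtype; the cast back to the compute dtype now happens in the trainer (see wan_trainer.py below). The following is a rough sketch of that round trip under the assumption of raw-bytes features; the helper names and shapes are illustrative, not the repo's actual create_example/reader API.

import numpy as np
import tensorflow as tf
import jax.numpy as jnp

def _bytes_feature(arr):
  # Store the raw float32 buffer; shapes are assumed fixed by convention.
  return tf.train.Feature(bytes_list=tf.train.BytesList(value=[arr.tobytes()]))

def make_example(latent, prompt_embeds):
  # float32 on disk because NumPy cannot represent bfloat16.
  return tf.train.Example(features=tf.train.Features(feature={
      "latents": _bytes_feature(np.asarray(latent, dtype=np.float32)),
      "encoder_hidden_states": _bytes_feature(np.asarray(prompt_embeds, dtype=np.float32)),
  }))

def parse_example(record, latent_shape, embed_shape):
  spec = {
      "latents": tf.io.FixedLenFeature([], tf.string),
      "encoder_hidden_states": tf.io.FixedLenFeature([], tf.string),
  }
  parsed = tf.io.parse_single_example(record, spec)
  latents = tf.reshape(tf.io.decode_raw(parsed["latents"], tf.float32), latent_shape)
  embeds = tf.reshape(tf.io.decode_raw(parsed["encoder_hidden_states"], tf.float32), embed_shape)
  return latents, embeds

# At train time the float32 tensors are cast back to the compute dtype, e.g.
# jnp.asarray(latents).astype(jnp.bfloat16), mirroring the .astype(config.weights_dtype)
# added in the trainer below.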

src/maxdiffusion/models/attention_flax.py

Lines changed: 2 additions & 2 deletions
@@ -380,7 +380,7 @@ def _apply_attention(
     )
   else:
     can_use_flash_attention = True
-
+    can_use_flash_attention=True
   if attention_kernel == "dot_product" or use_memory_efficient_attention or not can_use_flash_attention:
     return _apply_attention_dot(
         query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention

@@ -509,7 +509,7 @@ def __init__(
     heads: int,
     dim_head: int,
     use_memory_efficient_attention: bool = False,
-    split_head_dim: bool = False,
+    split_head_dim: bool = True,
     float32_qk_product: bool = True,
     axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV),
     axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV),
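
Two things change in this file: the wip line that re-assigns can_use_flash_attention=True after the existing check, and the default of split_head_dim flipping from False to True. Roughly, split_head_dim chooses between folding the attention heads into the batch dimension before the matmul and keeping an explicit heads axis that is contracted with einsum; the explicit layout is usually easier to shard and annotate on TPU. The toy sketch below only illustrates the two layouts and is not the repo's attention code.

import jax
import jax.numpy as jnp

batch, seq_len, heads, dim_head = 2, 128, 8, 64
q = jax.random.normal(jax.random.key(0), (batch, seq_len, heads * dim_head))
k = jax.random.normal(jax.random.key(1), (batch, seq_len, heads * dim_head))

# split_head_dim=False style: fold heads into the leading (batch) dimension.
def fold(x):
  x = x.reshape(batch, seq_len, heads, dim_head)
  return jnp.transpose(x, (0, 2, 1, 3)).reshape(batch * heads, seq_len, dim_head)

scores_folded = jnp.einsum("bqd,bkd->bqk", fold(q), fold(k))      # (batch*heads, seq, seq)

# split_head_dim=True style: keep an explicit heads axis and contract over it.
q_split = q.reshape(batch, seq_len, heads, dim_head)
k_split = k.reshape(batch, seq_len, heads, dim_head)
scores_split = jnp.einsum("bqhd,bkhd->bhqk", q_split, k_split)    # (batch, heads, seq, seq)

# Same numbers, different layout.
assert jnp.allclose(scores_folded.reshape(batch, heads, seq_len, seq_len), scores_split, atol=1e-4)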

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 8 additions & 6 deletions
@@ -165,7 +165,7 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera

     state = state.to_pure_dict()
     p_train_step = jax.jit(
-        functools.partial(train_step, scheduler=pipeline.scheduler),
+        functools.partial(train_step, scheduler=pipeline.scheduler, config=self.config),
         donate_argnums=(0,),
     )
     rng = jax.random.key(self.config.seed)

@@ -219,16 +219,18 @@ def training_loop(self, pipeline, optimizer, learning_rate_scheduler, data_itera
     return pipeline


-def train_step(state, graphdef, scheduler_state, data, rng, scheduler):
-  return step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng)
+def train_step(state, graphdef, scheduler_state, data, rng, scheduler, config):
+  return step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng, config)


-def step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng):
+def step_optimizer(graphdef, state, scheduler, scheduler_state, data, rng, config):
   _, new_rng, timestep_rng = jax.random.split(rng, num=3)

   def loss_fn(model):
-    latents = data["latents"]
-    encoder_hidden_states = data["encoder_hidden_states"]
+    latents = data["latents"].astype(config.weights_dtype)
+    encoder_hidden_states = data["encoder_hidden_states"].astype(config.weights_dtype)
+    # TODO - fix tf record conversion.
+    encoder_hidden_states = jax.numpy.squeeze(encoder_hidden_states, axis=1)
     bsz = latents.shape[0]
     timesteps = jax.random.randint(
         timestep_rng,
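
The trainer change threads the run config into the jitted step by binding it with functools.partial. Because config is bound at trace time rather than passed as a traced array, donate_argnums=(0,) still points at the first positional argument of the wrapped callable (the state), and inside loss_fn the float32 tensors read from the TFRecords are cast back to config.weights_dtype; the squeeze works around the extra singleton axis the current TFRecord conversion leaves on encoder_hidden_states (hence the TODO). Below is a minimal, self-contained sketch of the same partial-plus-jit pattern; the model, optimizer, and config here are stand-ins, not the repo's objects.

import functools
from types import SimpleNamespace

import jax
import jax.numpy as jnp
import optax

config = SimpleNamespace(weights_dtype=jnp.bfloat16, learning_rate=1e-4)
tx = optax.sgd(config.learning_rate)

def train_step(params, opt_state, batch, config):
  # Data arrives as float32 from the TFRecords and is cast to the compute dtype,
  # mirroring data["latents"].astype(config.weights_dtype) in the real trainer.
  x = batch["x"].astype(config.weights_dtype)

  def loss_fn(p):
    return jnp.mean((x * p["w"]) ** 2)

  loss, grads = jax.value_and_grad(loss_fn)(params)
  updates, opt_state = tx.update(grads, opt_state, params)
  return optax.apply_updates(params, updates), opt_state, loss

# config is baked in via the partial; donate_argnums=(0,) still refers to the first
# positional argument of the wrapped function (params here, state in the trainer).
p_train_step = jax.jit(functools.partial(train_step, config=config), donate_argnums=(0,))

params = {"w": jnp.ones((4,), jnp.float32)}
opt_state = tx.init(params)
batch = {"x": jnp.ones((2, 4), jnp.float32)}
params, opt_state, loss = p_train_step(params, opt_state, batch)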
