Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/maxdiffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@
],
}

if is_flax_available():
from flax import config as flax_config
flax_config.update('flax_always_shard_variable', False)

try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
Expand Down
6 changes: 3 additions & 3 deletions src/maxdiffusion/configuration_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,11 +376,11 @@ def load_config(
if os.path.isfile(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path
elif os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
if subfolder is not None and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)):
config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
# Load from a PyTorch checkpoint
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
elif subfolder is not None and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)):
config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
else:
raise EnvironmentError(f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}.")
else:
Expand Down
26 changes: 14 additions & 12 deletions src/maxdiffusion/trainers/wan_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,19 +444,21 @@ def loss_fn(params):
noise = jax.random.normal(key=new_rng, shape=latents.shape, dtype=latents.dtype)
noisy_latents = scheduler.add_noise(scheduler_state, latents, noise, timesteps)

model_pred = model(
hidden_states=noisy_latents,
timestep=timesteps,
encoder_hidden_states=encoder_hidden_states,
deterministic=False,
rngs=nnx.Rngs(dropout_rng),
)
with jax.named_scope('forward_pass'):
model_pred = model(
hidden_states=noisy_latents,
timestep=timesteps,
encoder_hidden_states=encoder_hidden_states,
deterministic=False,
rngs=nnx.Rngs(dropout_rng),
)

training_target = scheduler.training_target(latents, noise, timesteps)
training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
loss = (training_target - model_pred) ** 2
loss = loss * training_weight
loss = jnp.mean(loss)
with jax.named_scope('loss'):
training_target = scheduler.training_target(latents, noise, timesteps)
training_weight = jnp.expand_dims(scheduler.training_weight(scheduler_state, timesteps), axis=(1, 2, 3, 4))
loss = (training_target - model_pred) ** 2
loss = loss * training_weight
loss = jnp.mean(loss)

return loss

Expand Down
Loading