Skip to content

Commit f3d3234

Browse files
committed
trying fix for NaN
1 parent d05ca47 commit f3d3234

7 files changed

Lines changed: 13 additions & 10 deletions

File tree

src/maxdiffusion/configs/base_wan_i2v_14b.yml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,10 +113,11 @@ timestep_bias: {
113113

114114
# Override parameters from the checkpoint's scheduler.
115115
diffusion_scheduler_config: {
116-
_class_name: 'FlaxEulerDiscreteScheduler',
116+
_class_name: 'FlaxUniPCMultistepScheduler',
117117
prediction_type: 'epsilon',
118118
rescale_zero_terminal_snr: False,
119-
timestep_spacing: 'trailing'
119+
timestep_spacing: 'trailing',
120+
final_sigmas_type: 'sigma_min'
120121
}
121122

122123
# Output directory

src/maxdiffusion/models/embeddings_flax.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,6 @@ def __call__(self, encoder_hidden_states_image: jax.Array) -> jax.Array:
287287
padding_size = target_seq_len - current_seq_len
288288
padding = jnp.zeros((B, padding_size, D_out), dtype=hidden_states.dtype)
289289
hidden_states = jnp.concatenate([hidden_states, padding], axis=1)
290-
print(f"[DEBUG EMB] Padded image embeds from {current_seq_len} to {target_seq_len}. New shape: {hidden_states.shape}")
291290

292291
return hidden_states
293292

src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t
6363
@classmethod
6464
def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True):
6565
pipeline , transformer = cls._load_and_init(config, None, vae_only, load_transformer)
66-
transformer = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh)
66+
pipeline.transformer = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh)
6767
return pipeline
6868

6969
@classmethod

src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,8 +72,8 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t
7272
@classmethod
7373
def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True):
7474
pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init(config, None, vae_only, load_transformer)
75-
low_noise_transformer = cls.quantize_transformer(config, low_noise_transformer, pipeline, pipeline.mesh)
76-
high_noise_transformer = cls.quantize_transformer(config, high_noise_transformer, pipeline, pipeline.mesh)
75+
pipeline.low_noise_transformer = cls.quantize_transformer(config, low_noise_transformer, pipeline, pipeline.mesh)
76+
pipeline.high_noise_transformer = cls.quantize_transformer(config, high_noise_transformer, pipeline, pipeline.mesh)
7777
return pipeline
7878

7979
@classmethod

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t
6868

6969
@classmethod
7070
def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True):
71-
pipeline , _ = cls._load_and_init(config, None, vae_only, load_transformer)
71+
pipeline , transformer = cls._load_and_init(config, None, vae_only, load_transformer)
72+
pipeline.transformer = cls.quantize_transformer(config, transformer, pipeline, pipeline.mesh)
7273
return pipeline
7374

7475
@classmethod

src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@ def _load_and_init(cls, config, restored_checkpoint=None, vae_only=False, load_t
6565

6666
@classmethod
6767
def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transformer=True):
68-
pipeline, _, _ = cls._load_and_init(config, None, vae_only, load_transformer)
68+
pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init(config, None, vae_only, load_transformer)
69+
pipeline.low_noise_transformer = cls.quantize_transformer(config, low_noise_transformer, pipeline, pipeline.mesh)
70+
pipeline.high_noise_transformer = cls.quantize_transformer(config, high_noise_transformer, pipeline, pipeline.mesh)
6971
return pipeline
7072

7173
@classmethod

src/maxdiffusion/schedulers/scheduling_unipc_multistep_flax.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -518,13 +518,13 @@ def solve_for_rhos_p(R_mat, b_vec, current_order):
518518
check_nan_jit(pred_res, "P pred_res", step)
519519

520520
if self.config.predict_x0:
521-
x_t_ = sigma_t / (sigma_s0 + 1e-8) * x - alpha_t * h_phi_1 * m0
521+
x_t_ = sigma_t / (sigma_s0) * x - alpha_t * h_phi_1 * m0
522522
check_nan_jit(x_t_, "P x_t_ term", step)
523523
term2 = alpha_t * B_h * pred_res
524524
check_nan_jit(term2, "P term2", step)
525525
x_t = x_t_ - term2
526526
else: # Predict epsilon
527-
x_t_ = alpha_t / (alpha_s0 + 1e-8) * x - sigma_t * h_phi_1 * m0
527+
x_t_ = alpha_t / (alpha_s0) * x - sigma_t * h_phi_1 * m0
528528
check_nan_jit(x_t_, "P x_t_ term eps", step)
529529
term2 = sigma_t * B_h * pred_res
530530
check_nan_jit(term2, "P term2 eps", step)

0 commit comments

Comments (0)