Skip to content

Commit faad9b5

Browse files
committed
adding parity test file
1 parent 10d4f73 commit faad9b5

1 file changed

Lines changed: 7 additions & 1 deletion

File tree

parity_ltx2_maxdiffusion.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,18 @@ def custom_normal(key, shape, dtype=None, **kwargs):
4949
# Patch pipeline module's random normal if needed
5050
pipe_module.jax.random.normal = custom_normal
5151

52-
def print_stat(name, t):
52+
def _print_stat_impl(name, t):
5353
if hasattr(t, "cpu"):
5454
t = t.detach().cpu().float().numpy()
5555
t_np = np.array(t, dtype=np.float32)
5656
print(f"[{name}] min: {t_np.min():.5f}, max: {t_np.max():.5f}, mean: {t_np.mean():.5f}, std: {t_np.std():.5f}")
5757

58+
def print_stat(name, t):
59+
if isinstance(t, jax.core.Tracer):
60+
jax.debug.callback(_print_stat_impl, name, t)
61+
else:
62+
_print_stat_impl(name, t)
63+
5864
# Patch transformer forward pass
5965
orig_transformer_forward_pass = pipe_module.transformer_forward_pass
6066
def patched_transformer_forward_pass(*args, **kwargs):

0 commit comments

Comments
 (0)