Skip to content

Commit 0d54fc3

Browse files
committed
adding parity test file
1 parent 80e8923 commit 0d54fc3

1 file changed

Lines changed: 2 additions & 0 deletions

File tree

parity_ltx2_maxdiffusion.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ def custom_normal(key, shape, dtype=None, **kwargs):
5050
pipe_module.jax.random.normal = custom_normal
5151

5252
def print_stat(name, t):
53+
if hasattr(t, "cpu"):
54+
t = t.detach().cpu().float().numpy()
5355
t_np = np.array(t, dtype=np.float32)
5456
print(f"[{name}] min: {t_np.min():.5f}, max: {t_np.max():.5f}, mean: {t_np.mean():.5f}, std: {t_np.std():.5f}")
5557

0 commit comments

Comments
 (0)