Skip to content

Commit 814327d

Browse files
committed
Added randn_tensors
1 parent 4d44417 commit 814327d

1 file changed

Lines changed: 34 additions & 2 deletions

File tree

src/maxdiffusion/max_utils.py

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import functools
2121
from functools import partial, reduce
2222
from contextlib import nullcontext
23-
from typing import Dict, Callable
23+
from typing import Dict, Callable, List, Optional, Tuple, Union
2424
import json
2525
import yaml
2626
import os
@@ -641,4 +641,36 @@ def maybe_initialize_jax_distributed_system(raw_keys):
641641
initialize_jax_for_gpu()
642642
max_logging.log("Jax distributed system initialized on GPU!")
643643
else:
644-
jax.distributed.initialize()
644+
jax.distributed.initialize()
645+
646+
def randn_tensor(
    shape: Union[Tuple, List],
    # NOTE: jax.random.KeyArray was deprecated and removed (JAX >= 0.4.24);
    # jax.Array is the supported annotation for PRNG keys.
    generator: Optional[Union[List[jax.Array], jax.Array]] = None,
    config=None,
    dtype: Optional[jnp.dtype] = None,
):
  """A helper function to create random tensors with the desired `dtype`.

  When passing a list of generators (PRNG keys), each batch element is
  sampled from its own key, so batch entries can be seeded individually.

  Args:
    shape: Output shape (tuple or list); `shape[0]` is treated as the
      batch size.
    generator: A single JAX PRNG key, or a list of keys of length
      `shape[0]` (one per batch element). If None, a key is derived from
      `config.seed`.
    config: Object providing a `seed` attribute; required only when
      `generator` is None.
    dtype: Optional dtype of the returned array.

  Returns:
    A `jnp.ndarray` of the requested shape, sampled from a standard
    normal distribution.

  Raises:
    ValueError: If `generator` is None and `config` is None, or if a
      list of generators does not match the batch size.
  """
  batch_size = shape[0]
  if generator is None:
    if config is None:
      raise ValueError("config must be provided if generator is None.")
    generator = jax.random.key(config.seed)

  # A single-element list behaves the same as a bare key.
  if isinstance(generator, list) and len(generator) == 1:
    generator = generator[0]

  if isinstance(generator, list):
    if len(generator) != batch_size:
      raise ValueError(
          f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
          f" size of {batch_size}. Make sure the batch size matches the length of the generators."
      )
    # Sample each batch element from its own key, then stack along axis 0.
    # tuple(...) guards against `shape` being passed as a list, which would
    # make `(1,) + shape[1:]` raise a TypeError (tuple + list concatenation).
    per_item_shape = (1,) + tuple(shape[1:])
    latents = [jax.random.normal(generator[i], shape=per_item_shape, dtype=dtype) for i in range(batch_size)]
    latents = jnp.concatenate(latents, axis=0)
  else:
    latents = jax.random.normal(generator, shape=tuple(shape), dtype=dtype)

  return latents

0 commit comments

Comments
 (0)