2020import functools
2121from functools import partial , reduce
2222from contextlib import nullcontext
23- from typing import Dict , Callable
23+ from typing import Dict , Callable , List , Optional , Tuple , Union
2424import json
2525import yaml
2626import os
@@ -641,4 +641,36 @@ def maybe_initialize_jax_distributed_system(raw_keys):
641641 initialize_jax_for_gpu ()
642642 max_logging .log ("Jax distributed system initialized on GPU!" )
643643 else :
644- jax .distributed .initialize ()
644+ jax .distributed .initialize ()
645+
def randn_tensor(
    shape: Union[Tuple, List],
    # NOTE: jax.random.KeyArray was removed in JAX >= 0.4.24; jax.Array is the
    # supported annotation for PRNG keys (both old-style and new-style keys).
    generator: Optional[Union[List[jax.Array], jax.Array]] = None,
    config=None,
    dtype: Optional[jnp.dtype] = None,
):
  """Create a random-normal tensor of `shape`, optionally seeded per batch element.

  Args:
    shape: Output shape; `shape[0]` is treated as the batch dimension.
    generator: A single PRNG key, or a list of PRNG keys of length `shape[0]`
      (one per batch element). If None, a key is derived from `config.seed`.
    config: Object providing a `seed` attribute; required only when
      `generator` is None.
    dtype: Optional dtype for the output (defaults to jax.random.normal's
      default, float32).

  Returns:
    A jnp array of shape `shape` sampled from a standard normal distribution.

  Raises:
    ValueError: If `shape` is empty, if `generator` is None and `config` is
      also None, or if a list of generators does not match the batch size.
  """
  if not shape:
    raise ValueError("shape must be non-empty.")
  batch_size = shape[0]

  if generator is None:
    if config is None:
      raise ValueError("config must be provided if generator is None.")
    generator = jax.random.key(config.seed)

  # A one-element list behaves exactly like passing the single key directly.
  if isinstance(generator, list) and len(generator) == 1:
    generator = generator[0]

  if isinstance(generator, list):
    if len(generator) != batch_size:
      raise ValueError(
          f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
          f" size of {batch_size}. Make sure the batch size matches the length of the generators."
      )
    # Sample each batch element from its own key, then stack along axis 0.
    per_item_shape = (1,) + tuple(shape[1:])
    latents = [jax.random.normal(generator[i], shape=per_item_shape, dtype=dtype) for i in range(batch_size)]
    latents = jnp.concatenate(latents, axis=0)
  else:
    latents = jax.random.normal(generator, shape=shape, dtype=dtype)

  return latents
0 commit comments