We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 814327d commit 41cb3d0Copy full SHA for 41cb3d0
1 file changed
src/maxdiffusion/max_utils.py
@@ -645,7 +645,7 @@ def maybe_initialize_jax_distributed_system(raw_keys):
645
646
def randn_tensor(
647
shape: Union[Tuple, List],
648
- generator: Optional[Union[List[jax.random.KeyArray], jax.random.KeyArray]] = None,
+ generator: Optional[Union[List[jax.Array], jax.Array]] = None,
649
config=None,
650
dtype: Optional[jnp.dtype] = None,
651
):
0 commit comments