Skip to content

Commit 41cb3d0

Browse files
committed
Added randn_tensors
1 parent 814327d commit 41cb3d0

1 file changed

Lines changed: 1 addition & 1 deletion

File tree

src/maxdiffusion/max_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,7 @@ def maybe_initialize_jax_distributed_system(raw_keys):
645645

646646
def randn_tensor(
647647
shape: Union[Tuple, List],
648-
generator: Optional[Union[List[jax.random.KeyArray], jax.random.KeyArray]] = None,
648+
generator: Optional[Union[List[jax.Array], jax.Array]] = None,
649649
config=None,
650650
dtype: Optional[jnp.dtype] = None,
651651
):

0 commit comments

Comments
 (0)