@@ -44,6 +44,7 @@ activations_dtype: 'bfloat16'
 
 # Replicates vae across devices instead of using the model's sharding annotations for sharding.
 replicate_vae: False
+vae_spatial: -1  # defaults to total_device * 2 // dp
 
 # matmul and conv precision from https://jax.readthedocs.io/en/latest/jax.lax.html#jax.lax.Precision
 # Options are "DEFAULT", "HIGH", "HIGHEST"
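Note: the new `vae_spatial` key above ships as a `-1` placeholder that, per its comment, is meant to resolve to `total_device * 2 // dp` at runtime. A minimal sketch of how that placeholder might be filled in, assuming a hypothetical `resolve_vae_spatial` helper and an explicit `dp` (data-parallel degree) argument rather than the repository's actual code:

    import jax

    def resolve_vae_spatial(vae_spatial: int, dp: int) -> int:
        # -1 is a placeholder; replace it with the documented default.
        if vae_spatial == -1:
            return jax.device_count() * 2 // dp
        return vae_spatial

    # e.g. on 8 devices with dp=4 this yields a vae_spatial axis size of 4:
    # resolve_vae_spatial(-1, dp=4)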
@@ -60,7 +61,7 @@ jit_initializers: True
 # Set true to load weights from pytorch
 from_pt: True
 split_head_dim: True
-attention: 'flash'  # Supported attention: dot_product, flash, cudnn_flash_te, ring, ulysses
+attention: 'flash'  # Supported attention: dot_product, flash, tokamax_flash, cudnn_flash_te, ring, tokamax_ring, ulysses
 flash_min_seq_length: 0
 
 # If mask_padding_tokens is True, we pass in segment ids to splash attention to avoid attending to padding tokens.
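Note: the updated comment adds `tokamax_flash` and `tokamax_ring` to the list of supported attention kinds. A minimal validation sketch, assuming the config has been loaded into a plain dict; `SUPPORTED_ATTENTION` here simply mirrors the comment above and is not the repository's actual constant:

    SUPPORTED_ATTENTION = (
        "dot_product", "flash", "tokamax_flash",
        "cudnn_flash_te", "ring", "tokamax_ring", "ulysses",
    )

    def validate_attention(config: dict) -> str:
        kind = config["attention"]
        if kind not in SUPPORTED_ATTENTION:
            raise ValueError(f"unsupported attention kernel: {kind!r}")
        return kind

    # validate_attention({"attention": "tokamax_flash"}) -> "tokamax_flash"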
@@ -180,6 +181,19 @@ logical_axis_rules: [
   ['out_channels', 'tensor'],
   ['conv_out', 'context'],
 ]
+vae_logical_axis_rules: [
+  ['activation_batch', 'redundant'],
+  ['activation_length', 'vae_spatial'],
+  ['activation_heads', null],
+  ['activation_kv_length', null],
+  ['embed', null],
+  ['heads', null],
+  ['norm', null],
+  ['conv_batch', 'redundant'],
+  ['out_channels', 'vae_spatial'],
+  ['conv_out', 'vae_spatial'],
+  ['conv_in', 'vae_spatial'],
+]
 data_sharding: [['data', 'fsdp', 'context', 'tensor']]
 
 # One axis for each parallelism type may hold a placeholder (-1)
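Note: `vae_logical_axis_rules` gives the VAE its own mapping from logical axis names to mesh axes ('redundant', 'vae_spatial') instead of reusing the transformer's `logical_axis_rules`. A minimal sketch of how such a rule set can be consumed with Flax's logical-axis helpers; the mesh shape, axis sizes, and the subset of rules below are illustrative assumptions, not the repository's actual mesh construction:

    import jax
    import numpy as np
    import flax.linen as nn
    from jax.sharding import Mesh

    vae_rules = (
        ("activation_batch", "redundant"),
        ("activation_length", "vae_spatial"),
        ("conv_batch", "redundant"),
        ("out_channels", "vae_spatial"),
        ("conv_out", "vae_spatial"),
        ("conv_in", "vae_spatial"),
    )

    # Illustrative mesh: every device on the 'vae_spatial' axis, 'redundant' of size 1.
    devices = np.array(jax.devices()).reshape(1, -1)
    mesh = Mesh(devices, axis_names=("redundant", "vae_spatial"))

    with mesh, nn.logical_axis_rules(vae_rules):
        spec = nn.logical_to_mesh_axes(("activation_batch", "activation_length"))
        # spec is PartitionSpec('redundant', 'vae_spatial'): the batch dimension maps
        # to 'redundant' (size 1 here, i.e. replicated) and the sequence/spatial
        # dimension is sharded over 'vae_spatial'.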