* tf processing on cpu
* improve attention for gpus by using shard_map.
* update readme to include training on GPUs. Revert max_utils jitting of state.
* remove commented out line.
* lint
* flag to jit initializers
---------
Co-authored-by: Juan Acevedo <jfacevedo@google.com>
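The second bullet mentions improving GPU attention by using `shard_map`. As a minimal sketch (not the actual maxdiffusion code; the function and shapes here are illustrative assumptions), a batch-sharded attention-score computation with `jax.experimental.shard_map` looks like:

```python
# Hypothetical sketch of sharding an attention-style matmul across devices
# with shard_map; maxdiffusion's real implementation differs.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

# One mesh axis over all available devices (GPUs, or CPU when testing).
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

def attention_scores(q, k):
    # Runs per-shard: each device sees only its slice of the batch axis.
    return jnp.einsum("bqd,bkd->bqk", q, k) / jnp.sqrt(q.shape[-1])

sharded_scores = shard_map(
    attention_scores,
    mesh=mesh,
    in_specs=(P("data", None, None), P("data", None, None)),
    out_specs=P("data", None, None),
)

q = jnp.ones((8, 16, 64))  # (batch, seq, head_dim)
k = jnp.ones((8, 16, 64))
scores = sharded_scores(q, k)
print(scores.shape)  # (8, 16, 16)
```

Sharding only the batch axis keeps each device's attention computation local, avoiding cross-device communication inside the matmul itself.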
To generate images with a trained checkpoint, run:
```bash
python src/maxdiffusion/generate_flux.py src/maxdiffusion/configs/base_flux_schnell.yml jax_cache_dir=/tmp/cache_dir run_name=flux_test output_dir=/tmp/ prompt="photograph of an electronics chip in the shape of a race car with trillium written on its side" per_device_batch_size=1 ici_data_parallelism=1 ici_fsdp_parallelism=-1 offload_encoders=False
```
## Fused Attention for GPU:

Fused Attention for GPU is supported via TransformerEngine. Installation instructions:
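The excerpt cuts off before the installation steps themselves. As a hedged sketch only (the package extra and version constraints below are assumptions, not taken from this document; defer to the README and NVIDIA's TransformerEngine docs for the supported versions), a typical pip-based install for the JAX integration is:

```shell
# Hypothetical example; check the maxdiffusion README for the exact
# TransformerEngine version and CUDA toolkit requirements.
pip install --upgrade pip
pip install "transformer_engine[jax]"
```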