Training added on top of flux_impl#147
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request.
Force-pushed from ba2d028 to f56234e
Background in #148. @entrpn, I've rebased on flux_lora now and aligned the pipeline with the changes you made to generate_flux.py. Inference is working as expected, but I am a bit suspicious about the training. Please try it out and let's discuss how to move forward with this. In the meantime, I will prepare another PR adding FA for GPUs, similar to how it is done in maxtext.
Hi @entrpn, have you had a chance to test this PR? I think we can try to merge it soon if you think it looks alright.
I started to take a look at it. The pipeline fails for me during the data pipeline due to memory constraints in my environment. During the text encoding, I get OOM. The code will need to be refactored so that this can run on CPU or within 32 GB of accelerator memory (preferably 16), since the t5 encoder cannot be sharded atm. I remember doing something similar before by batching the captions in the data pipeline. I can take a look at it next week and try to get that part working.
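The caption-batching idea mentioned above can be sketched roughly as follows. This is a minimal, hypothetical illustration, not the maxdiffusion code: `encode_in_batches` and `encode_fn` are invented names, and the real t5 encoder call in the data pipeline will look different. The point is only that peak memory is bounded by the batch size rather than by the full caption set.

```python
from typing import Callable, List, Sequence


def encode_in_batches(
    captions: Sequence[str],
    encode_fn: Callable[[Sequence[str]], List[list]],
    batch_size: int = 32,
) -> List[list]:
    """Run a text encoder over small caption batches.

    Only `batch_size` captions are materialized in the encoder at a
    time, so peak accelerator/CPU memory no longer scales with the
    size of the whole dataset.
    """
    embeddings: List[list] = []
    for start in range(0, len(captions), batch_size):
        batch = captions[start : start + batch_size]
        # Each call encodes one small batch; results are accumulated
        # on the host side.
        embeddings.extend(encode_fn(batch))
    return embeddings
```

In practice the batch size would be chosen so one t5 forward pass fits in the available 16–32 GB, and the accumulated embeddings could also be written to disk instead of kept in host memory.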
Force-pushed from f56234e to bffc7dc
@entrpn I rebased on main and cleaned up the commit logs. I'll keep this PR up to date with main to make the merge easier when we are ready.
@entrpn I have merged your branch into mine, added the orbax checkpointing, and a new file for inference that utilizes the flux pipeline. It was easier getting the orbax loading working that way, plus I think it is a bit cleaner. Tell me what you think. I did a small training run of 100 steps, and the resulting images do look OK. Please verify on your side and let me know if and when we can merge this. Thanks.
@ksikiric apologies for the late response, and thank you for adding this functionality. Let me take a look at this the week after next, as I will be traveling until then. I will get back to you right after I'm back.
@ksikiric can you share the commands you use to run the training job and save the checkpoint, and then the command you use to run inference on the saved checkpoint?
`python src/maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_dev.yml save_final_checkpoint=True run_name=flux` should do the trick :) @entrpn
@ksikiric please take a look at the comments. Afterwards, can you rebase on main and run the linter? If all passes, we can merge it.
Force-pushed from 58b1bdc to 5453b3c
Linked to #146
I've added the training code on top of https://github.com/AI-Hypercomputer/maxdiffusion/tree/flux_impl. This PR is meant to be merged after #146.
Along with the training code, I have also added a pipeline for flux, which can be used for inference as well.