
Training added on top of flux_impl #147

Merged: entrpn merged 9 commits into AI-Hypercomputer:main from ksikiric:kris/flux-impl-training on Apr 16, 2025

Conversation

ksikiric (Contributor) commented Feb 12, 2025

Linked to #146

I've added the training code on top of https://github.com/AI-Hypercomputer/maxdiffusion/tree/flux_impl. This PR is meant to be merged after #146.

With the training code, I have also added a pipeline for flux, which can be used for inference as well.

google-cla bot commented Feb 12, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@ksikiric ksikiric closed this Feb 12, 2025
@ksikiric ksikiric reopened this Feb 12, 2025
@entrpn entrpn mentioned this pull request Feb 13, 2025
@ksikiric force-pushed the kris/flux-impl-training branch from ba2d028 to f56234e on February 13, 2025 12:11
ksikiric (Contributor, Author) commented Feb 13, 2025

Background in #148

@entrpn, I've rebased on flux_lora now and aligned the pipeline with the changes you made to generate_flux.py. Inference is working as expected, but I am a bit suspicious about the training. Please try it out and let's discuss how to move forward with this.

In the meantime, I will prepare another PR adding flash attention (FA) for GPUs, similar to how it is done in maxtext.

ksikiric (Contributor, Author) commented:

Hi @entrpn, have you had a chance to test this PR? I think we can try to merge this soon if you think it looks alright

entrpn (Collaborator) commented Feb 19, 2025

> Hi @entrpn, have you had a chance to test this PR? I think we can try to merge this soon if you think it looks alright

I started to take a look at it. The data pipeline fails for me due to memory constraints in my environment: during text encoding, I get an OOM. The code will need to be refactored so that it can run on CPU, or at least within 32 GB of accelerator memory (preferably 16), since the t5 encoder cannot be sharded at the moment. I remember doing something similar before by batching the captions in the data pipeline. I can take a look at it next week and try to get that part working.
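The batching idea entrpn mentions can be sketched roughly as follows: run the text encoder over captions in small batches so only one batch of activations is resident at a time, which bounds peak memory. This is a hypothetical illustration; `encode_in_batches` and `encode_fn` are illustrative names, not functions from the maxdiffusion codebase, and `encode_fn` stands in for the (currently unsharded) t5 encoder call.

```python
def encode_in_batches(captions, encode_fn, batch_size=32):
    """Run encode_fn over captions batch-by-batch and collect results,
    so peak memory scales with batch_size rather than the full dataset."""
    embeddings = []
    for start in range(0, len(captions), batch_size):
        embeddings.extend(encode_fn(captions[start:start + batch_size]))
    return embeddings

if __name__ == "__main__":
    # Toy stand-in encoder: maps each caption to its character length.
    fake_encode = lambda batch: [len(c) for c in batch]
    print(encode_in_batches(["a cat", "a dog on a mat"], fake_encode, batch_size=1))
```

In the real pipeline the per-batch outputs would typically be written to disk (or fed onward) rather than accumulated in a Python list, but the memory-bounding shape is the same.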

@ksikiric force-pushed the kris/flux-impl-training branch from f56234e to bffc7dc on March 5, 2025 12:10
ksikiric (Contributor, Author) commented Mar 5, 2025

@entrpn I rebased on main and cleaned up the commit logs. Will keep this PR up to date with main to make the merge easier when we are ready.

ksikiric (Contributor, Author) commented:

@entrpn I have merged your branch into mine, added the orbax checkpointing, and a new file for inference that utilizes the flux pipeline. It was easier getting the orbax loading working that way, plus I think it is a bit cleaner. Tell me what you think.

I did a small training run of 100 steps, and the resulting images do look OK. Please verify on your side and let me know if and when we can merge this. Thanks.

entrpn (Collaborator) commented Apr 3, 2025

@ksikiric apologies for the late response and thank you for adding this functionality. Let me take a look at this the week after next as I will be traveling until then. Will get back to you right after I'm back.

Comment thread src/maxdiffusion/schedulers/scheduling_euler_discrete_flax.py Outdated
Comment thread src/maxdiffusion/configs/base_flux_dev.yml Outdated
entrpn (Collaborator) commented Apr 15, 2025

@ksikiric can you share the commands you use to run the training job and save the checkpoint and then the command you use to run inference on the saved checkpoint?

ksikiric (Contributor, Author) commented:

```shell
python src/maxdiffusion/train_flux.py src/maxdiffusion/configs/base_flux_dev.yml save_final_checkpoint=True run_name=flux
```

followed by

```shell
python src/maxdiffusion/generate_flux_pipeline.py src/maxdiffusion/configs/base_flux_dev.yml run_name=flux
```

should do the trick :) @entrpn

Comment thread src/maxdiffusion/configs/base_flux_dev.yml Outdated
Comment thread src/maxdiffusion/configs/base_flux_schnell.yml Outdated
Comment thread src/maxdiffusion/configs/base_flux_schnell.yml Outdated
Comment thread src/maxdiffusion/checkpointing/flux_checkpointer.py Outdated
Comment thread src/maxdiffusion/checkpointing/flux_checkpointer.py Outdated
Comment thread src/maxdiffusion/pipelines/flux/flux_pipeline.py Outdated
entrpn (Collaborator) commented Apr 15, 2025

@ksikiric please take a look at the comments. Afterwards, can you rebase with main and run the linter? If everything passes, we can merge it.

@ksikiric force-pushed the kris/flux-impl-training branch from 58b1bdc to 5453b3c on April 16, 2025 08:58
ksikiric (Contributor, Author) commented Apr 16, 2025

@entrpn @coolkp all comments have now been addressed, and ruff + code_style.sh have been applied.

@ksikiric ksikiric marked this pull request as ready for review April 16, 2025 10:29
@entrpn entrpn merged commit b951454 into AI-Hypercomputer:main Apr 16, 2025
2 checks passed