Skip to content

Commit db4caf0

Browse files
authored
fixes flux training. (#206)
1 parent 12469fa commit db4caf0

4 files changed

Lines changed: 148 additions & 148 deletions

File tree

src/maxdiffusion/configs/base_flux_dev.yml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -228,6 +228,7 @@ enable_profiler: False
228228
# the iteration time a chance to stabilize.
229229
skip_first_n_steps_for_profiler: 5
230230
profiler_steps: 10
231+
profiler: ""
231232

232233
# Generation parameters
233234
prompt: "A magical castle in the middle of a forest, artistic drawing"

src/maxdiffusion/generate_flux_pipeline.py

Lines changed: 80 additions & 79 deletions
Original file line number | Diff line number | Diff line change
@@ -33,85 +33,86 @@ def run(config):
3333
from maxdiffusion.checkpointing.flux_checkpointer import FluxCheckpointer
3434

3535
checkpoint_loader = FluxCheckpointer(config, "FLUX_CHECKPOINT")
36-
pipeline, params = checkpoint_loader.load_checkpoint()
37-
38-
if not params:
39-
## VAE
40-
weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=checkpoint_loader.rng)
41-
unboxed_abstract_state, _, _ = max_utils.get_abstract_state(
42-
pipeline.vae, None, config, checkpoint_loader.mesh, weights_init_fn, False
43-
)
44-
# load unet params from orbax checkpoint
45-
vae_params = load_params_from_path(
46-
config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "vae_state"
47-
)
48-
49-
vae_state = {"params": vae_params}
50-
51-
## Flux
52-
weights_init_fn = functools.partial(
53-
pipeline.flux.init_weights, rngs=checkpoint_loader.rng, max_sequence_length=config.max_sequence_length
54-
)
55-
56-
unboxed_abstract_state, _, _ = max_utils.get_abstract_state(
57-
pipeline.flux, None, config, checkpoint_loader.mesh, weights_init_fn, False
58-
)
59-
# load unet params from orbax checkpoint
60-
flux_params = load_params_from_path(
61-
config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "flux_state"
62-
)
63-
flux_state = {"params": flux_params}
64-
else:
65-
weights_init_fn = functools.partial(
66-
pipeline.flux.init_weights,
67-
rngs=checkpoint_loader.rng,
68-
max_sequence_length=config.max_sequence_length,
69-
eval_only=False,
70-
)
71-
transformer_state, flux_state_shardings = setup_initial_state(
72-
model=pipeline.flux,
73-
tx=None,
74-
config=config,
75-
mesh=checkpoint_loader.mesh,
76-
weights_init_fn=weights_init_fn,
77-
model_params=None,
78-
training=False,
79-
)
80-
transformer_state = transformer_state.replace(params=params["flux_transformer_params"])
81-
transformer_state = jax.device_put(transformer_state, flux_state_shardings)
82-
83-
weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=checkpoint_loader.rng)
84-
vae_state, _ = setup_initial_state(
85-
model=pipeline.vae,
86-
tx=None,
87-
config=config,
88-
mesh=checkpoint_loader.mesh,
89-
weights_init_fn=weights_init_fn,
90-
model_params=params["flux_vae"],
91-
training=False,
92-
)
93-
94-
vae_state = {"params": vae_state.params}
95-
flux_state = {"params": transformer_state.params}
96-
97-
t0 = time.perf_counter()
98-
with ExitStack():
99-
imgs = pipeline(flux_params=flux_state, timesteps=config.num_inference_steps, vae_params=vae_state).block_until_ready()
100-
t1 = time.perf_counter()
101-
max_logging.log(f"Compile time: {t1 - t0:.1f}s.")
102-
103-
t0 = time.perf_counter()
104-
with ExitStack():
105-
imgs = pipeline(flux_params=flux_state, timesteps=config.num_inference_steps, vae_params=vae_state).block_until_ready()
106-
imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True)
107-
t1 = time.perf_counter()
108-
max_logging.log(f"Inference time: {t1 - t0:.1f}s.")
109-
imgs = np.array(imgs)
110-
imgs = (imgs * 0.5 + 0.5).clip(0, 1)
111-
imgs = np.transpose(imgs, (0, 2, 3, 1))
112-
imgs = np.uint8(imgs * 255)
113-
for i, image in enumerate(imgs):
114-
Image.fromarray(image).save(f"flux_{i}.png")
36+
mesh = checkpoint_loader.mesh
37+
with mesh:
38+
pipeline, params = checkpoint_loader.load_checkpoint()
39+
40+
if not params:
41+
## VAE
42+
weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=checkpoint_loader.rng)
43+
unboxed_abstract_state, _, _ = max_utils.get_abstract_state(
44+
pipeline.vae, None, config, checkpoint_loader.mesh, weights_init_fn, False
45+
)
46+
# load unet params from orbax checkpoint
47+
vae_params = load_params_from_path(
48+
config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "vae_state"
49+
)
50+
51+
vae_state = {"params": vae_params}
52+
53+
## Flux
54+
weights_init_fn = functools.partial(
55+
pipeline.flux.init_weights, rngs=checkpoint_loader.rng, max_sequence_length=config.max_sequence_length
56+
)
57+
unboxed_abstract_state, _, _ = max_utils.get_abstract_state(
58+
pipeline.flux, None, config, checkpoint_loader.mesh, weights_init_fn, False
59+
)
60+
# load unet params from orbax checkpoint
61+
flux_params = load_params_from_path(
62+
config, checkpoint_loader.checkpoint_manager, unboxed_abstract_state.params, "flux_state"
63+
)
64+
flux_state = {"params": flux_params}
65+
else:
66+
weights_init_fn = functools.partial(
67+
pipeline.flux.init_weights,
68+
rngs=checkpoint_loader.rng,
69+
max_sequence_length=config.max_sequence_length,
70+
eval_only=False,
71+
)
72+
transformer_state, flux_state_shardings = setup_initial_state(
73+
model=pipeline.flux,
74+
tx=None,
75+
config=config,
76+
mesh=checkpoint_loader.mesh,
77+
weights_init_fn=weights_init_fn,
78+
model_params=None,
79+
training=False,
80+
)
81+
transformer_state = transformer_state.replace(params=params["flux_transformer_params"])
82+
transformer_state = jax.device_put(transformer_state, flux_state_shardings)
83+
84+
weights_init_fn = functools.partial(pipeline.vae.init_weights, rng=checkpoint_loader.rng)
85+
vae_state, _ = setup_initial_state(
86+
model=pipeline.vae,
87+
tx=None,
88+
config=config,
89+
mesh=checkpoint_loader.mesh,
90+
weights_init_fn=weights_init_fn,
91+
model_params=params["flux_vae"],
92+
training=False,
93+
)
94+
95+
vae_state = {"params": vae_state.params}
96+
flux_state = {"params": transformer_state.params}
97+
98+
t0 = time.perf_counter()
99+
with ExitStack():
100+
imgs = pipeline(flux_params=flux_state, timesteps=config.num_inference_steps, vae_params=vae_state).block_until_ready()
101+
t1 = time.perf_counter()
102+
max_logging.log(f"Compile time: {t1 - t0:.1f}s.")
103+
104+
t0 = time.perf_counter()
105+
with ExitStack():
106+
imgs = pipeline(flux_params=flux_state, timesteps=config.num_inference_steps, vae_params=vae_state).block_until_ready()
107+
imgs = jax.experimental.multihost_utils.process_allgather(imgs, tiled=True)
108+
t1 = time.perf_counter()
109+
max_logging.log(f"Inference time: {t1 - t0:.1f}s.")
110+
imgs = np.array(imgs)
111+
imgs = (imgs * 0.5 + 0.5).clip(0, 1)
112+
imgs = np.transpose(imgs, (0, 2, 3, 1))
113+
imgs = np.uint8(imgs * 255)
114+
for i, image in enumerate(imgs):
115+
Image.fromarray(image).save(f"flux_{i}.png")
115116

116117
return imgs
117118

src/maxdiffusion/train_flux.py

Lines changed: 1 addition & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -18,10 +18,7 @@
1818

1919
import jax
2020
from absl import app
21-
from maxdiffusion import (
22-
max_logging,
23-
pyconfig,
24-
)
21+
from maxdiffusion import (max_logging, pyconfig)
2522

2623
from maxdiffusion.train_utils import (
2724
validate_train_config,

src/maxdiffusion/trainers/flux_trainer.py

Lines changed: 66 additions & 65 deletions
Original file line number | Diff line number | Diff line change
@@ -80,70 +80,71 @@ def start_training(self):
8080
# Hook
8181
# self.pre_training_steps()
8282
# Load checkpoint - will load or create states
83-
pipeline, params = self.load_checkpoint()
84-
85-
# create train states
86-
train_states = {}
87-
state_shardings = {}
88-
89-
# move params to accelerator
90-
encoders_sharding = NamedSharding(self.mesh, P(None))
91-
partial_device_put_replicated = partial(max_utils.device_put_replicated, sharding=encoders_sharding)
92-
pipeline.clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.clip_encoder.params)
93-
pipeline.clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.clip_encoder.params)
94-
pipeline.t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.t5_encoder.params)
95-
pipeline.t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.t5_encoder.params)
96-
97-
vae_state, vae_state_mesh_shardings = self.create_vae_state(
98-
pipeline=pipeline, params=params, checkpoint_item_name=VAE_STATE_KEY, is_training=False
99-
)
100-
train_states[VAE_STATE_KEY] = vae_state
101-
state_shardings[VAE_STATE_SHARDINGS_KEY] = vae_state_mesh_shardings
102-
103-
# Load dataset
104-
data_iterator = self.load_dataset(pipeline, params, train_states)
105-
if self.config.dataset_type == "grain":
106-
data_iterator = self.restore_data_iterator_state(data_iterator)
107-
108-
# don't need this anymore, clear some memory.
109-
del pipeline.t5_encoder
110-
111-
# evaluate shapes
112-
113-
flux_state, flux_state_mesh_shardings, flux_learning_rate_scheduler = self.create_flux_state(
114-
# ambiguous here, but if params=None
115-
# Then its 1 of 2 scenarios:
116-
# 1. flux state will be loaded directly from orbax
117-
# 2. a new flux is being trained from scratch.
118-
pipeline=pipeline,
119-
params=None, # Params are loaded inside create_flux_state
120-
checkpoint_item_name=FLUX_STATE_KEY,
121-
is_training=True,
122-
)
123-
flux_state = jax.device_put(flux_state, flux_state_mesh_shardings)
124-
train_states[FLUX_STATE_KEY] = flux_state
125-
state_shardings[FLUX_STATE_SHARDINGS_KEY] = flux_state_mesh_shardings
126-
# self.post_training_steps(pipeline, params, train_states, msg="before_training")
127-
128-
# Create scheduler
129-
noise_scheduler, noise_scheduler_state = self.create_scheduler(pipeline, params)
130-
pipeline.scheduler = noise_scheduler
131-
train_states["scheduler"] = noise_scheduler_state
132-
133-
# Calculate tflops
134-
per_device_tflops = self.calculate_tflops(pipeline)
135-
self.per_device_tflops = per_device_tflops
136-
137-
data_shardings = self.get_data_shardings()
138-
# Compile train_step
139-
p_train_step = self.compile_train_step(pipeline, params, train_states, state_shardings, data_shardings)
140-
# Start training
141-
train_states = self.training_loop(
142-
p_train_step, pipeline, params, train_states, data_iterator, flux_learning_rate_scheduler
143-
)
144-
# 6. save final checkpoint
145-
# Hook
146-
self.post_training_steps(pipeline, params, train_states, "after_training")
83+
with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
84+
pipeline, params = self.load_checkpoint()
85+
86+
# create train states
87+
train_states = {}
88+
state_shardings = {}
89+
90+
# move params to accelerator
91+
encoders_sharding = NamedSharding(self.mesh, P(None))
92+
partial_device_put_replicated = partial(max_utils.device_put_replicated, sharding=encoders_sharding)
93+
pipeline.clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.clip_encoder.params)
94+
pipeline.clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.clip_encoder.params)
95+
pipeline.t5_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.t5_encoder.params)
96+
pipeline.t5_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.t5_encoder.params)
97+
98+
vae_state, vae_state_mesh_shardings = self.create_vae_state(
99+
pipeline=pipeline, params=params, checkpoint_item_name=VAE_STATE_KEY, is_training=False
100+
)
101+
train_states[VAE_STATE_KEY] = vae_state
102+
state_shardings[VAE_STATE_SHARDINGS_KEY] = vae_state_mesh_shardings
103+
104+
# Load dataset
105+
data_iterator = self.load_dataset(pipeline, params, train_states)
106+
if self.config.dataset_type == "grain":
107+
data_iterator = self.restore_data_iterator_state(data_iterator)
108+
109+
# don't need this anymore, clear some memory.
110+
del pipeline.t5_encoder
111+
112+
# evaluate shapes
113+
114+
flux_state, flux_state_mesh_shardings, flux_learning_rate_scheduler = self.create_flux_state(
115+
# ambiguous here, but if params=None
116+
# Then its 1 of 2 scenarios:
117+
# 1. flux state will be loaded directly from orbax
118+
# 2. a new flux is being trained from scratch.
119+
pipeline=pipeline,
120+
params=None, # Params are loaded inside create_flux_state
121+
checkpoint_item_name=FLUX_STATE_KEY,
122+
is_training=True,
123+
)
124+
flux_state = jax.device_put(flux_state, flux_state_mesh_shardings)
125+
train_states[FLUX_STATE_KEY] = flux_state
126+
state_shardings[FLUX_STATE_SHARDINGS_KEY] = flux_state_mesh_shardings
127+
# self.post_training_steps(pipeline, params, train_states, msg="before_training")
128+
129+
# Create scheduler
130+
noise_scheduler, noise_scheduler_state = self.create_scheduler(pipeline, params)
131+
pipeline.scheduler = noise_scheduler
132+
train_states["scheduler"] = noise_scheduler_state
133+
134+
# Calculate tflops
135+
per_device_tflops = self.calculate_tflops(pipeline)
136+
self.per_device_tflops = per_device_tflops
137+
138+
data_shardings = self.get_data_shardings()
139+
# Compile train_step
140+
p_train_step = self.compile_train_step(pipeline, params, train_states, state_shardings, data_shardings)
141+
# Start training
142+
train_states = self.training_loop(
143+
p_train_step, pipeline, params, train_states, data_iterator, flux_learning_rate_scheduler
144+
)
145+
# 6. save final checkpoint
146+
# Hook
147+
self.post_training_steps(pipeline, params, train_states, "after_training")
147148

148149
def get_shaped_batch(self, config, pipeline=None):
149150
"""Return the shape of the batch - this is what eval_shape would return for the
@@ -349,7 +350,7 @@ def training_loop(self, p_train_step, pipeline, params, train_states, data_itera
349350
example_batch = load_next_batch(data_iterator, example_batch, self.config)
350351
example_batch = {key: jnp.asarray(value, dtype=self.config.activations_dtype) for key, value in example_batch.items()}
351352

352-
if self.config.profiler == 'nsys':
353+
if self.config.profiler == "nsys":
353354
with self.mesh:
354355
flux_state, train_metric, train_rngs = p_train_step(flux_state, example_batch, train_rngs)
355356
else:

0 commit comments

Comments
 (0)