Skip to content

Commit 956341e

Browse files
Apply FSDP sharding; do one forward pass in the transformer.
1 parent 3eb5729 commit 956341e

2 files changed

Lines changed: 49 additions & 25 deletions

File tree

src/maxdiffusion/configs/base_flux.yml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,11 +123,11 @@ data_sharding: [['data', 'fsdp', 'tensor']]
123123
# value to auto-shard based on available slices and devices.
124124
# By default, product of the DCN axes should equal number of slices
125125
# and product of the ICI axes should equal number of devices per slice.
126-
dcn_data_parallelism: -1 # recommended DCN axis to be auto-sharded
127-
dcn_fsdp_parallelism: 1
126+
dcn_data_parallelism: 1 # recommended DCN axis to be auto-sharded
127+
dcn_fsdp_parallelism: -1
128128
dcn_tensor_parallelism: 1
129-
ici_data_parallelism: -1
130-
ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded
129+
ici_data_parallelism: 1
130+
ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded
131131
ici_tensor_parallelism: 1
132132

133133
# Dataset

src/maxdiffusion/generate_flux.py

Lines changed: 45 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import functools
2020
import numpy as np
2121
import jax
22-
from jax.sharding import Mesh, PositionalSharding
22+
from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
2323
import jax.numpy as jnp
2424
from chex import Array
2525
from transformers import (
@@ -196,7 +196,7 @@ def run(config):
196196
devices_array = create_device_mesh(config)
197197
mesh = Mesh(devices_array, config.mesh_axes)
198198

199-
per_host_number_of_images = config.per_device_batch_size * jax.local_device_count()
199+
global_batch_size = config.per_device_batch_size * jax.local_device_count()
200200

201201
# LOAD VAE
202202

@@ -225,7 +225,7 @@ def run(config):
225225

226226
num_channels_latents = transformer.in_channels // 4
227227
latents, latent_image_ids = prepare_latents(
228-
batch_size=per_host_number_of_images,
228+
batch_size=global_batch_size,
229229
num_channels_latents=num_channels_latents,
230230
height=config.resolution,
231231
width=config.resolution,
@@ -270,7 +270,7 @@ def run(config):
270270
clip_text_encoder=clip_text_encoder,
271271
t5_tokenizer=t5_tokenizer,
272272
t5_text_encoder=t5_encoder,
273-
num_images_per_prompt=per_host_number_of_images
273+
num_images_per_prompt=global_batch_size
274274
)
275275

276276
def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timesteps, guidance, pooled_prompt_embeds):
@@ -282,8 +282,8 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
282282
print("guidance.shape: ", guidance.shape, guidance.dtype)
283283
print("pooled_prompt_embeds.shape: ", pooled_prompt_embeds.shape, pooled_prompt_embeds.dtype)
284284

285-
timesteps = jnp.asarray([1.0], dtype=jnp.bfloat16)
286-
guidance = jnp.asarray([3.5], dtype=jnp.bfloat16)
285+
timesteps = jnp.asarray([1.0] * global_batch_size, dtype=jnp.bfloat16)
286+
guidance = jnp.asarray([3.5] * global_batch_size, dtype=jnp.bfloat16)
287287
validate_inputs(
288288
latents,
289289
latent_image_ids,
@@ -293,13 +293,26 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
293293
guidance,
294294
pooled_prompt_embeds
295295
)
296+
297+
# TODO - remove this later and figure out why t5x is returning wrong shape
298+
prompt_embeds = jnp.ones((global_batch_size, 512, 4096))
299+
300+
# move inputs to device and shard
301+
data_sharding = jax.sharding.NamedSharding(mesh, P(*config.data_sharding))
302+
latents = jax.device_put(latents, data_sharding)
303+
latent_image_ids = jax.device_put(latent_image_ids, data_sharding)
304+
prompt_embeds = jax.device_put(prompt_embeds, data_sharding)
305+
text_ids = jax.device_put(text_ids, data_sharding)
306+
timesteps = jax.device_put(timesteps, data_sharding)
307+
guidance = jax.device_put(guidance, data_sharding)
308+
pooled_prompt_embeds = jax.device_put(pooled_prompt_embeds, data_sharding)
309+
296310
get_memory_allocations()
297311
# evaluate shapes
298312
transformer_eval_params = transformer.init_weights(rngs=rng, max_sequence_length=512, eval_only=True)
299313

300314
# loads pretrained weights
301315
transformer_params = load_flow_model("flux-dev", transformer_eval_params, "cpu")
302-
get_memory_allocations()
303316
# create transformer state
304317
weights_init_fn = functools.partial(transformer.init_weights, rngs=rng, max_sequence_length=512, eval_only=False)
305318
transformer_state, transformer_state_shardings = setup_initial_state(
@@ -308,24 +321,35 @@ def validate_inputs(latents, latent_image_ids, prompt_embeds, text_ids, timestep
308321
config=config,
309322
mesh=mesh,
310323
weights_init_fn=weights_init_fn,
311-
model_params=None,
324+
model_params=transformer_params,
312325
training=False
313326
)
314-
breakpoint()
315-
transformer_state = transformer_state.replace(params=transformer_params)
316-
img = transformer.apply(
317-
{"params" : transformer_state.params},
318-
img=latents,
319-
img_ids=latent_image_ids,
320-
txt=prompt_embeds,
321-
txt_ids=text_ids,
322-
timesteps=timesteps,
323-
guidance=guidance,
324-
y=pooled_prompt_embeds
325-
)
327+
#transformer_state = transformer_state.replace(params=transformer_params)
326328
get_memory_allocations()
327-
breakpoint()
329+
def run_inference(state, transformer):
330+
img = transformer.apply(
331+
{"params" : state.params},
332+
img=latents,
333+
img_ids=latent_image_ids,
334+
txt=prompt_embeds,
335+
txt_ids=text_ids,
336+
timesteps=timesteps,
337+
guidance=guidance,
338+
y=pooled_prompt_embeds
339+
)
340+
return img
341+
342+
p_run_inference = jax.jit(
343+
functools.partial(
344+
run_inference,
345+
transformer=transformer
346+
),
347+
in_shardings=(transformer_state_shardings,),
348+
out_shardings=None
349+
)
328350

351+
img = p_run_inference(transformer_state)
352+
print("img.shape: ", img.shape)
329353

330354

331355
def main(argv: Sequence[str]) -> None:

0 commit comments

Comments (0)