
Commit 4a588c4

coolkphx89 authored and committed
Optimize batch loading and metrics writing, replace PositionalSharding with NamedSharding (#186)
* fix profiling

* Use torch cpu, async write to tensorboard, script to convert latents to tfrecord, batch iterator for tfrecord cached, namedsharding instead of positional sharding

Signed-off-by: Kunjan <kunjanp@google.com>

* Replace positional sharding with named sharding

Signed-off-by: Kunjan <kunjanp@google.com>

* Formatting

Signed-off-by: Kunjan <kunjanp@google.com>

* Formatting

Signed-off-by: Kunjan <kunjanp@google.com>

* Fallback to regular tfrecord iterator for datasets without all the processed features

Signed-off-by: Kunjan <kunjanp@google.com>

* README update

---------

Signed-off-by: Kunjan <kunjanp@google.com>
1 parent d392bf4 commit 4a588c4

3 files changed: 3 additions & 7 deletions
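
The headline change replaces jax.sharding.PositionalSharding with NamedSharding throughout. A minimal before/after sketch, assuming a single-axis mesh over all local devices (variable names here are illustrative, not taken from the repo):

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())

# Before: positional sharding, addressed by device position (this class
# has since been deprecated in newer JAX releases).
# replicated = jax.sharding.PositionalSharding(devices).replicate()

# After: a named sharding over an explicit mesh. An empty PartitionSpec,
# P(), partitions nothing, i.e. the value is fully replicated, the same
# placement that .replicate() expressed positionally.
mesh = Mesh(devices, axis_names=("data",))
replicated = NamedSharding(mesh, P())

x = jax.device_put(jnp.ones((8, 128)), replicated)

Both shardings place a full copy of x on every device; NamedSharding additionally names the mesh axes so later code can partition along them with, say, P("data").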


src/maxdiffusion/generate_flux.py

Lines changed: 1 addition & 1 deletion
@@ -343,7 +343,7 @@ def run(config):
       config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True
   )
 
-  encoders_sharding = NamedSharding(mesh, P())
+  encoders_sharding = NamedSharding(devices_array, P())
   partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding)
   clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params)
   clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params)

src/maxdiffusion/generate_flux_multi_res.py

Lines changed: 1 addition & 1 deletion
@@ -381,7 +381,7 @@ def run(config):
       config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True
   )
 
-  encoders_sharding = NamedSharding(mesh, P())
+  encoders_sharding = NamedSharding(devices_array, P())
   partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding)
   clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params)
   clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params)
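
Both flux entry points wrap the changed line in the same replication pattern. A sketch of that pattern, assuming a Mesh named mesh and a parameter pytree params are already in scope, and with a hypothetical stand-in for the repo's device_put_replicated helper (its real definition lives elsewhere in maxdiffusion and may differ):

import functools
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

# Hypothetical stand-in: place one array on every device per the sharding.
def device_put_replicated(x, sharding):
  return jax.device_put(x, sharding)

encoders_sharding = NamedSharding(mesh, P())  # P() = fully replicated
partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding)

# Cast the text-encoder parameters to bfloat16, then replicate them leaf
# by leaf; tree_map applies the function to every array in the pytree.
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
params = jax.tree_util.tree_map(partial_device_put_replicated, params)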

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 1 addition & 5 deletions
@@ -198,11 +198,7 @@ def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: H
     # This replaces random params with the model.
     params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu")
     params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
-    <<<<<<< HEAD
     params = jax.device_put(params, NamedSharding(mesh, P()))
-    =======
-    params = jax.device_put(params, NamedSharding(devices_array, P()))
-    >>>>>>> f344ab0 (Optimize batch loading and metrics writing, replace PositionalSharding with NamedSharding (#186))
     wan_vae = nnx.merge(graphdef, params)
     p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules)
     # Shard
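
This hunk drops leftover merge-conflict markers from load_vae, keeping the NamedSharding(mesh, P()) variant. For context, a sketch of the flax.nnx round trip the surrounding lines rely on, assuming mesh and an nnx module vae are in scope (vae is a placeholder name):

from flax import nnx
import jax
from jax.sharding import NamedSharding, PartitionSpec as P

# nnx.split separates a module into a static graphdef and a state pytree
# of parameters; the pytree can be placed on devices and merged back into
# a live module.
graphdef, params = nnx.split(vae)
params = jax.device_put(params, NamedSharding(mesh, P()))  # replicate
wan_vae = nnx.merge(graphdef, params)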
@@ -403,7 +399,7 @@ def __call__(
         num_channels_latents=num_channel_latents,
     )
 
-    data_sharding = NamedSharding(self.mesh, P())
+    data_sharding = NamedSharding(self.devices_array, P())
     if len(prompt) % jax.device_count() == 0:
       data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding))

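The second hunk keeps the batch-divisibility fallback: inputs are replicated by default and only partitioned along the configured data axes when the batch splits evenly across devices, since an uneven split would give devices ragged shapes. A sketch of that logic, with the pipeline's self.* attributes and config values replaced by assumed locals:

import jax
from jax.sharding import NamedSharding, PartitionSpec as P

# Default: every device holds a full copy of the inputs.
data_sharding = NamedSharding(mesh, P())

# Partition along the configured axes (e.g. ("data",)) only when each
# device can take a whole, equal slice of the batch.
if len(prompt) % jax.device_count() == 0:
  data_sharding = NamedSharding(mesh, P(*data_axes))

latents = jax.device_put(latents, data_sharding)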