From 0585c9585a5c4233d7270210aa1d70b3cb00e92a Mon Sep 17 00:00:00 2001
From: Kunjan
Date: Fri, 20 Jun 2025 07:49:20 +0000
Subject: [PATCH] Fix NamedSharding for replicating params

Signed-off-by: Kunjan
---
 src/maxdiffusion/generate_flux.py              | 2 +-
 src/maxdiffusion/generate_flux_multi_res.py    | 2 +-
 src/maxdiffusion/pipelines/wan/wan_pipeline.py | 2 +-
 src/maxdiffusion/trainers/wan_trainer.py       | 2 +-
 4 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/maxdiffusion/generate_flux.py b/src/maxdiffusion/generate_flux.py
index 0e6866346..f2ae1b3f1 100644
--- a/src/maxdiffusion/generate_flux.py
+++ b/src/maxdiffusion/generate_flux.py
@@ -343,7 +343,7 @@ def run(config):
       config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True
   )

-  encoders_sharding = NamedSharding(devices_array, P())
+  encoders_sharding = NamedSharding(mesh, P())
   partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding)
   clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params)
   clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params)
diff --git a/src/maxdiffusion/generate_flux_multi_res.py b/src/maxdiffusion/generate_flux_multi_res.py
index 4c824db8b..7d07883c6 100644
--- a/src/maxdiffusion/generate_flux_multi_res.py
+++ b/src/maxdiffusion/generate_flux_multi_res.py
@@ -381,7 +381,7 @@ def run(config):
       config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True
   )

-  encoders_sharding = NamedSharding(devices_array, P())
+  encoders_sharding = NamedSharding(mesh, P())
   partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding)
   clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params)
   clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params)
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
index cf97c890a..0e8336c81 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -195,7 +195,7 @@ def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: H
     # This replaces random params with the model.
     params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu")
     params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
-    params = jax.device_put(params, NamedSharding(devices_array, P()))
+    params = jax.device_put(params, NamedSharding(mesh, P()))
     wan_vae = nnx.merge(graphdef, params)
     p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules)
     # Shard
diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index 3740e2cf1..b11626c27 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -26,7 +26,7 @@
 from ..schedulers import FlaxEulerDiscreteScheduler
 from .. import max_utils, max_logging, train_utils, maxdiffusion_utils
 from ..checkpointing.wan_checkpointer import (WanCheckpointer, WAN_CHECKPOINT)
-from multihost_dataloading import _form_global_array
+from maxdiffusion.multihost_dataloading import _form_global_array


 class WanTrainer(WanCheckpointer):
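
Context for reviewers: jax.sharding.NamedSharding takes a jax.sharding.Mesh
as its first argument, so passing the raw numpy device array (as the old code
did) fails before any replication happens. Below is a minimal, self-contained
sketch of the corrected pattern; the single "data" mesh axis and the toy
parameter tree are illustrative assumptions, not code from this repo.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a Mesh over the local devices. NamedSharding(devices_array, P())
# fails because the first argument must be a Mesh, which is exactly what
# this patch fixes.
devices_array = np.array(jax.devices())
mesh = Mesh(devices_array, axis_names=("data",))  # axis name is an assumption

# An empty PartitionSpec means "replicate on every device in the mesh".
replicated = NamedSharding(mesh, P())

# Replicate a toy parameter tree, mirroring the tree_map calls in the patch.
params = {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))}
params = jax.tree_util.tree_map(lambda x: jax.device_put(x, replicated), params)
print(jax.tree_util.tree_map(lambda x: x.sharding, params))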