Skip to content

Commit ab45a33

Browse files
coolkphx89
authored and committed
Fix namedsharding for replicating params (AI-Hypercomputer#188)
Signed-off-by: Kunjan <kunjanp@google.com>
1 parent 1e7df6b commit ab45a33

4 files changed

Lines changed: 4 additions & 4 deletions

File tree

src/maxdiffusion/generate_flux.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -343,7 +343,7 @@ def run(config):
343343
config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True
344344
)
345345

346-
encoders_sharding = NamedSharding(devices_array, P())
346+
encoders_sharding = NamedSharding(mesh, P())
347347
partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding)
348348
clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params)
349349
clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params)

src/maxdiffusion/generate_flux_multi_res.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -381,7 +381,7 @@ def run(config):
381381
config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True
382382
)
383383

384-
encoders_sharding = NamedSharding(devices_array, P())
384+
encoders_sharding = NamedSharding(mesh, P())
385385
partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding)
386386
clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params)
387387
clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params)

src/maxdiffusion/pipelines/wan/wan_pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -195,7 +195,7 @@ def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: H
195195
# This replaces random params with the model.
196196
params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu")
197197
params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
198-
params = jax.device_put(params, NamedSharding(devices_array, P()))
198+
params = jax.device_put(params, NamedSharding(mesh, P()))
199199
wan_vae = nnx.merge(graphdef, params)
200200
p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules)
201201
# Shard

src/maxdiffusion/trainers/wan_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,7 @@
2626
from ..schedulers import FlaxEulerDiscreteScheduler
2727
from .. import max_utils, max_logging, train_utils, maxdiffusion_utils
2828
from ..checkpointing.wan_checkpointer import (WanCheckpointer, WAN_CHECKPOINT)
29-
from multihost_dataloading import _form_global_array
29+
from maxdiffusion.multihost_dataloading import _form_global_array
3030

3131

3232
class WanTrainer(WanCheckpointer):

0 commit comments

Comments (0)