Commit d76d5e8

Replace positional sharding with named sharding
Signed-off-by: Kunjan <kunjanp@google.com>
1 parent 04735f4 commit d76d5e8

4 files changed: 9 additions & 9 deletions

src/maxdiffusion/generate_flux.py (2 additions, 2 deletions)

@@ -23,7 +23,7 @@
 import numpy as np
 from PIL import Image
 import jax
-from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
+from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
 import jax.numpy as jnp
 import flax.linen as nn
 from chex import Array
@@ -343,7 +343,7 @@ def run(config):
       config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True
   )

-  encoders_sharding = PositionalSharding(devices_array).replicate()
+  encoders_sharding = NamedSharding(devices_array, P())
   partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding)
   clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params)
   clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params)
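
Both spellings describe a fully replicated layout: `PositionalSharding(devices_array).replicate()` put a complete copy of each array on every device, and `NamedSharding` with an empty `PartitionSpec` does the same, since no array dimension is bound to a mesh axis. `PositionalSharding` has been deprecated and removed in recent JAX releases, which motivates the swap. A minimal sketch of the equivalence; the axis name "data" is a placeholder, and note that `NamedSharding` takes a `jax.sharding.Mesh` as its first argument, so the sketch builds one explicitly:

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

devices = np.array(jax.devices())            # all local devices, e.g. shape (8,)
mesh = Mesh(devices, axis_names=("data",))   # 1-D named mesh over those devices

# An empty PartitionSpec binds no array dimension to any mesh axis,
# so every device holds a full copy of the array.
replicated = NamedSharding(mesh, P())

x = jax.device_put(np.arange(16.0), replicated)
assert x.sharding.is_fully_replicated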

src/maxdiffusion/generate_flux_multi_res.py (2 additions, 2 deletions)

@@ -23,7 +23,7 @@
 import numpy as np
 from PIL import Image
 import jax
-from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
+from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
 import jax.numpy as jnp
 import flax.linen as nn
 from chex import Array
@@ -381,7 +381,7 @@ def run(config):
       config.t5xxl_model_name_or_path, max_length=config.max_sequence_length, use_fast=True
   )

-  encoders_sharding = PositionalSharding(devices_array).replicate()
+  encoders_sharding = NamedSharding(devices_array, P())
   partial_device_put_replicated = functools.partial(device_put_replicated, sharding=encoders_sharding)
   clip_text_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), clip_text_encoder.params)
   clip_text_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, clip_text_encoder.params)
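
This file receives the identical fix. For context, the replicated sharding is partially applied into `device_put_replicated` and then mapped over the encoder's parameter pytree. A sketch of that pattern on a toy pytree; the `device_put_replicated` body below is a stand-in, since the real helper lives in maxdiffusion's utils and its exact implementation is assumed here:

import functools
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

def device_put_replicated(x, sharding):
  # Stand-in helper: place one full copy of x on every device.
  return jax.device_put(jnp.asarray(x), sharding)

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))  # placeholder axis name
encoders_sharding = NamedSharding(mesh, P())
put = functools.partial(device_put_replicated, sharding=encoders_sharding)

params = {"w": np.ones((4, 4)), "b": np.zeros(4)}  # toy parameter pytree
params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), params)
params = jax.tree_util.tree_map(put, params)       # every leaf now replicated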

src/maxdiffusion/pipelines/wan/wan_pipeline.py (3 additions, 3 deletions)

@@ -17,7 +17,7 @@
 import numpy as np
 import jax
 import jax.numpy as jnp
-from jax.sharding import Mesh, PositionalSharding, PartitionSpec as P
+from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
 import flax
 import flax.linen as nn
 from flax import nnx
@@ -195,7 +195,7 @@ def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: H
   # This replaces random params with the model.
   params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu")
   params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
-  params = jax.device_put(params, PositionalSharding(devices_array).replicate())
+  params = jax.device_put(params, NamedSharding(devices_array, P()))
   wan_vae = nnx.merge(graphdef, params)
   p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules)
   # Shard
@@ -395,7 +395,7 @@ def __call__(
       num_channels_latents=num_channel_latents,
   )

-  data_sharding = PositionalSharding(self.devices_array).replicate()
+  data_sharding = NamedSharding(self.devices_array, P())
   if len(prompt) % jax.device_count() == 0:
     data_sharding = jax.sharding.NamedSharding(self.mesh, P(*self.config.data_sharding))

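
The last hunk also shows the runtime fallback around the new replicated default: the prompt batch stays replicated unless it divides evenly across the devices, in which case it is sharded along the configured data axis. A sketch of that branch; the axis name "data" stands in for whatever `self.config.data_sharding` names:

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
prompt = ["a cat", "a dog"]  # toy batch

data_sharding = NamedSharding(mesh, P())            # default: fully replicated
if len(prompt) % jax.device_count() == 0:
  data_sharding = NamedSharding(mesh, P("data"))    # shard the batch dimension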

src/maxdiffusion/trainers/flux_trainer.py (2 additions, 2 deletions)

@@ -21,7 +21,7 @@
 import numpy as np
 import jax
 import jax.numpy as jnp
-from jax.sharding import PositionalSharding, PartitionSpec as P
+from jax.sharding import NamedSharding, PartitionSpec as P
 from flax.linen import partitioning as nn_partitioning
 from maxdiffusion.checkpointing.flux_checkpointer import (
     FluxCheckpointer,
@@ -87,7 +87,7 @@ def start_training(self):
   state_shardings = {}

   # move params to accelerator
-  encoders_sharding = jax.NamedSharding(self.mesh, P(None))
+  encoders_sharding = NamedSharding(self.mesh, P(None))
   partial_device_put_replicated = partial(max_utils.device_put_replicated, sharding=encoders_sharding)
   pipeline.clip_encoder.params = jax.tree_util.tree_map(lambda x: x.astype(jnp.bfloat16), pipeline.clip_encoder.params)
   pipeline.clip_encoder.params = jax.tree_util.tree_map(partial_device_put_replicated, pipeline.clip_encoder.params)
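
Beyond the import, the fix here replaces the `jax.NamedSharding` attribute lookup (the class is exported from `jax.sharding`, not from the top-level `jax` namespace in current releases) with the imported name. Note that `P(None)` and `P()` describe the same layout: neither binds a dimension to a mesh axis, so both are fully replicated. A quick sketch; the axis name "data" is a placeholder:

import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
a = NamedSharding(mesh, P(None))  # explicit: first dimension unsharded
b = NamedSharding(mesh, P())      # implicit: all dimensions unsharded
assert a.is_fully_replicated and b.is_fully_replicated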
