|
18 | 18 | import jax |
19 | 19 | import jax.numpy as jnp |
20 | 20 | from jax.sharding import Mesh, PositionalSharding |
| 21 | +import flax |
| 22 | +import flax.linen as nn |
21 | 23 | from flax import nnx |
22 | 24 | from ...pyconfig import HyperParameters |
23 | 25 | from ... import max_logging |
@@ -54,6 +56,48 @@ def _add_sharding_rule(vs: nnx.VariableState, logical_axis_rules) -> nnx.Variabl |
54 | 56 | vs.sharding_rules = logical_axis_rules |
55 | 57 | return vs |
56 | 58 |
|
| 59 | + |
| 61 | +def create_sharded_logical_transformer(devices_array: np.ndarray, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):
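| | +  """Abstractly construct the WanModel, load its pretrained weights, and shard them across the mesh."""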
| 63 | + def create_model(rngs: nnx.Rngs, wan_config: dict): |
| 64 | + wan_transformer = WanModel(**wan_config, rngs=rngs) |
| 65 | + return wan_transformer |
| 66 | + |
| 67 | + wan_config = WanModel.load_config( |
| 68 | + config.pretrained_model_name_or_path, |
| 69 | + subfolder="transformer" |
| 70 | + ) |
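| | +  # Add runtime settings (mesh, dtypes, attention kernel) on top of the loaded config.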
| 71 | + wan_config["mesh"] = mesh |
| 72 | + wan_config["dtype"] = config.activations_dtype |
| 73 | + wan_config["weights_dtype"] = config.weights_dtype |
| 74 | + wan_config["attention"] = config.attention |
| 75 | + p_model_factory = partial(create_model, wan_config=wan_config) |
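| | +  # Build the model abstractly with eval_shape so no parameters are materialized yet.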
| 76 | + wan_transformer = nnx.eval_shape(p_model_factory, rngs=rngs) |
| 77 | + graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...) |
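| | +  # Derive logical partition specs for the params and map them to concrete shardings on the mesh.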
| 79 | + logical_state_spec = nnx.get_partition_spec(state) |
| 80 | + logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules) |
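| | +  # Flatten params and shardings into dicts keyed by path for per-parameter lookup.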
| 81 | + logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding)) |
| 82 | + params = state.to_pure_dict() |
| 83 | + state = dict(nnx.to_flat_state(state)) |
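| | +  # Load the pretrained weights on the host, then cast them to the configured weights dtype.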
| 85 | + params = load_wan_transformer(config.pretrained_model_name_or_path, params, "cpu") |
| 86 | + params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) |
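| | +  # Place each loaded tensor on devices according to its mesh sharding.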
| 87 | + for path, val in flax.traverse_util.flatten_dict(params).items(): |
| 88 | + sharding = logical_state_sharding[path].value |
| 89 | + state[path].value = jax.device_put(val, sharding) |
| 90 | + state = nnx.from_flat_state(state) |
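| | +  # Attach the logical axis rules to every variable state.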
| 91 | + p_add_sharding_rule = partial(_add_sharding_rule, logical_axis_rules=config.logical_axis_rules) |
| 92 | + state = jax.tree.map(p_add_sharding_rule, state, is_leaf=lambda x: isinstance(x, nnx.VariableState)) |
| 93 | + pspecs = nnx.get_partition_spec(state) |
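| | +  # Constrain the reassembled state to the resolved partition specs.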
| 95 | + sharded_state = jax.lax.with_sharding_constraint(state, pspecs) |
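| | +  # Recombine the graph definition with the sharded state into a module.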
| 98 | + wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) |
| 99 | + return wan_transformer |
| 100 | + |
57 | 101 | partial(nnx.jit, static_argnums=(1,)) |
58 | 102 | def create_sharded_logical_model(model, logical_axis_rules): |
59 | 103 | graphdef, state, rest_of_state = nnx.split(model, nnx.Param, ...) |
@@ -154,26 +198,29 @@ def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: H |
154 | 198 |
|
155 | 199 | @classmethod |
156 | 200 | def load_transformer(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): |
157 | | - wan_transformer = WanModel.from_config( |
158 | | - config.pretrained_model_name_or_path, |
159 | | - subfolder="transformer", |
160 | | - rngs=rngs, |
161 | | - attention=config.attention, |
162 | | - mesh=mesh, |
163 | | - dtype=config.activations_dtype, |
164 | | - weights_dtype=config.weights_dtype |
165 | | - ) |
166 | | - graphdef, state, rest_of_state = nnx.split(wan_transformer, nnx.Param, ...) |
167 | | - params = state.to_pure_dict() |
168 | | - del state |
169 | | - params = load_wan_transformer(config.pretrained_model_name_or_path, params, "cpu") |
170 | | - params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) |
171 | | - params = jax.device_put(params, PositionalSharding(devices_array).replicate()) |
172 | | - wan_transformer = nnx.merge(graphdef, params, rest_of_state) |
173 | | - # Shard |
174 | | - p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules) |
175 | 201 | with mesh: |
176 | | - wan_transformer = p_create_sharded_logical_model(model=wan_transformer) |
| 202 | + wan_transformer = create_sharded_logical_transformer(devices_array=devices_array, mesh=mesh, rngs=rngs, config=config) |
177 | 224 | return wan_transformer |
178 | 225 |
|
179 | 226 | @classmethod |
|