diff --git a/requirements.txt b/requirements.txt index eeaf2c9e3..7ae4bc64a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -23,7 +23,6 @@ pytest==8.2.2 tensorflow>=2.17.0 tensorflow-datasets>=4.9.6 ruff>=0.1.5,<=0.2 -git+https://github.com/mlperf/logging.git opencv-python-headless==4.10.0.84 orbax-checkpoint==0.10.3 tokenizers==0.21.0 diff --git a/src/maxdiffusion/common_types.py b/src/maxdiffusion/common_types.py index 2402a3d08..b75f5ceec 100644 --- a/src/maxdiffusion/common_types.py +++ b/src/maxdiffusion/common_types.py @@ -36,6 +36,7 @@ BATCH = "activation_batch" LENGTH = "activation_length" +KV_LENGTH = "activation_kv_length" EMBED = "activation_embed" HEAD = "activation_heads" D_KV = "activation_kv" diff --git a/src/maxdiffusion/configs/base14.yml b/src/maxdiffusion/configs/base14.yml index 00ee172cf..97f2fccf8 100644 --- a/src/maxdiffusion/configs/base14.yml +++ b/src/maxdiffusion/configs/base14.yml @@ -135,6 +135,8 @@ ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +allow_split_physical_axes: False + # Dataset # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' diff --git a/src/maxdiffusion/configs/base21.yml b/src/maxdiffusion/configs/base21.yml index f5a05b0e4..53d06a689 100644 --- a/src/maxdiffusion/configs/base21.yml +++ b/src/maxdiffusion/configs/base21.yml @@ -136,6 +136,8 @@ ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +allow_split_physical_axes: False + # Dataset # Replace with dataset path or train_data_dir. One has to be set. 
dataset_name: 'diffusers/pokemon-gpt4-captions' diff --git a/src/maxdiffusion/configs/base_2_base.yml b/src/maxdiffusion/configs/base_2_base.yml index 1113d03b6..8a38f87f7 100644 --- a/src/maxdiffusion/configs/base_2_base.yml +++ b/src/maxdiffusion/configs/base_2_base.yml @@ -149,6 +149,8 @@ ici_data_parallelism: -1 # recommended ICI axis to be auto-sharded for TPUv5e ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +allow_split_physical_axes: False + # Dataset # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index 3a5f294a9..220a5bb2c 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -162,6 +162,8 @@ ici_data_parallelism: -1 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +allow_split_physical_axes: False + # Dataset # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' diff --git a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml index c37923911..8ae40a779 100644 --- a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml +++ b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml @@ -162,6 +162,8 @@ ici_data_parallelism: -1 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +allow_split_physical_axes: False + # Dataset # Replace with dataset path or train_data_dir. One has to be set. 
dataset_name: 'diffusers/pokemon-gpt4-captions' diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml index 8e8db4a44..80fe9d1ce 100644 --- a/src/maxdiffusion/configs/base_flux_schnell.yml +++ b/src/maxdiffusion/configs/base_flux_schnell.yml @@ -170,6 +170,8 @@ ici_data_parallelism: -1 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +allow_split_physical_axes: False + # Dataset # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 21796738a..f3799e79f 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -56,6 +56,17 @@ split_head_dim: True attention: 'flash' # Supported attention: dot_product, flash, cudnn_flash_te flash_block_sizes: {} +# Use on v6e +# flash_block_sizes: { +# "block_q" : 3024, +# "block_kv_compute" : 1024, +# "block_kv" : 2048, +# "block_q_dkv" : 3024, +# "block_kv_dkv" : 2048, +# "block_kv_dkv_compute" : 2048, +# "block_q_dq" : 3024, +# "block_kv_dq" : 2048 +# } # GroupNorm groups norm_num_groups: 32 @@ -115,17 +126,15 @@ mesh_axes: ['data', 'fsdp', 'tensor'] # conv_out : conv.shape[-1] weight logical_axis_rules: [ ['batch', 'data'], - ['activation_heads', 'fsdp'], - ['activation_batch', ['data','fsdp']], - ['activation_kv', 'tensor'], + ['activation_length', 'fsdp'], + ['activation_heads', 'tensor'], + ['activation_batch', 'data'], ['mlp','tensor'], ['embed','fsdp'], - ['heads', 'tensor'], - ['norm', 'fsdp'], + ['norm', 'tensor'], ['conv_batch', ['data','fsdp']], ['out_channels', 'tensor'], - ['conv_out', 'fsdp'], - ['conv_in', 'fsdp'] + ['conv_in', 'fsdp'], ] data_sharding: [['data', 'fsdp', 'tensor']] @@ -140,6 +149,8 @@ ici_data_parallelism: 1 ici_fsdp_parallelism: -1 # recommended ICI axis to be auto-sharded 
ici_tensor_parallelism: 1 +allow_split_physical_axes: False + # Dataset # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml index e773c19e0..5dd66e7c9 100644 --- a/src/maxdiffusion/configs/base_xl.yml +++ b/src/maxdiffusion/configs/base_xl.yml @@ -135,6 +135,8 @@ ici_data_parallelism: -1 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +allow_split_physical_axes: False + # Dataset # Replace with dataset path or train_data_dir. One has to be set. dataset_name: 'diffusers/pokemon-gpt4-captions' diff --git a/src/maxdiffusion/configs/base_xl_lightning.yml b/src/maxdiffusion/configs/base_xl_lightning.yml index aafeea2bd..ca2ba2306 100644 --- a/src/maxdiffusion/configs/base_xl_lightning.yml +++ b/src/maxdiffusion/configs/base_xl_lightning.yml @@ -115,6 +115,8 @@ ici_data_parallelism: -1 ici_fsdp_parallelism: 1 # recommended ICI axis to be auto-sharded ici_tensor_parallelism: 1 +allow_split_physical_axes: False + # Dataset # Replace with dataset path or train_data_dir. One has to be set. 
dataset_name: '' diff --git a/src/maxdiffusion/generate_wan.py b/src/maxdiffusion/generate_wan.py index fc089337c..ad10cdf06 100644 --- a/src/maxdiffusion/generate_wan.py +++ b/src/maxdiffusion/generate_wan.py @@ -20,6 +20,8 @@ from absl import app from maxdiffusion.utils import export_to_video +jax.config.update("jax_use_shardy_partitioner", True) + def run(config, pipeline=None, filename_prefix=""): print("seed: ", config.seed) @@ -78,7 +80,7 @@ def run(config, pipeline=None, filename_prefix=""): slg_start=slg_start, slg_end=slg_end, ) - print("compile time: ", (time.perf_counter() - s0)) + print("generation time: ", (time.perf_counter() - s0)) s0 = time.perf_counter() if config.enable_profiler: diff --git a/src/maxdiffusion/max_utils.py b/src/maxdiffusion/max_utils.py index fab895f97..aaa929c59 100644 --- a/src/maxdiffusion/max_utils.py +++ b/src/maxdiffusion/max_utils.py @@ -281,9 +281,13 @@ def create_device_mesh(config, devices=None, logging=True): ici_parallelism = fill_unspecified_mesh_axes(ici_parallelism, num_devices_per_slice, "ICI") if multi_slice_env: dcn_parallelism = fill_unspecified_mesh_axes(dcn_parallelism, num_slices, "DCN") - mesh = mesh_utils.create_hybrid_device_mesh(ici_parallelism, dcn_parallelism, devices) + mesh = mesh_utils.create_hybrid_device_mesh( + ici_parallelism, dcn_parallelism, devices, allow_split_physical_axes=config.allow_split_physical_axes + ) else: - mesh = mesh_utils.create_device_mesh(ici_parallelism, devices) + mesh = mesh_utils.create_device_mesh( + ici_parallelism, devices, allow_split_physical_axes=config.allow_split_physical_axes + ) if logging: max_logging.log(f"Decided on mesh: {mesh}") diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index 006614f87..a00928e3e 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -38,6 +38,7 @@ AxisNames = common_types.AxisNames BATCH = common_types.BATCH LENGTH = 
common_types.LENGTH +KV_LENGTH = common_types.KV_LENGTH HEAD = common_types.HEAD D_KV = common_types.D_KV EMBED = common_types.EMBED @@ -75,8 +76,8 @@ def _reshape_batch_dim_to_heads(tensor, heads): head_size = heads tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) tensor = jnp.transpose(tensor, (0, 2, 1, 3)) - tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) - return tensor + reshaped_tensor = tensor.reshape(batch_size // head_size, seq_len, dim * head_size) + return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor")) def _reshape_heads_to_batch_dim(tensor, heads): @@ -85,12 +86,12 @@ def _reshape_heads_to_batch_dim(tensor, heads): head_size = heads tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) tensor = jnp.transpose(tensor, (0, 2, 1, 3)) - tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + reshaped_tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) else: batch_size, head_size, seq_len, head_dim = tensor.shape - tensor = tensor.reshape(batch_size * head_size, seq_len, head_dim) + reshaped_tensor = tensor.reshape(batch_size * head_size, seq_len, head_dim) - return tensor + return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor")) def _reshape_heads_to_head_dim(tensor): @@ -98,7 +99,8 @@ def _reshape_heads_to_head_dim(tensor): # This is used to transform the output of flash attention back into the format of other attention outputs b, h, s, d = tensor.shape tensor = jnp.transpose(tensor, axes=[0, 2, 1, 3]) - return jnp.reshape(tensor, (b, -1, h * d)) + reshaped_tensor = jnp.reshape(tensor, (b, -1, h * d)) + return jax.lax.with_sharding_constraint(reshaped_tensor, PartitionSpec("data", "fsdp", "tensor")) def _unflatten_heads(tensor, heads): @@ -110,33 +112,43 @@ def _unflatten_heads(tensor, heads): return tensor -def _reshape_data_for_flash(tensor, heads, 
flash_block_size): +def _reshape_data_for_flash(tensor, heads, flash_block_size, num_shards: int = 1): """ Reshapes tensors for pallas flash attention adding padding to both seq_len and head_dim. + Pads seq_len to a multiple of flash_block_size, and ensures the resulting number of + blocks is divisible by the number of shards. """ if tensor.ndim != 4: tensor = _unflatten_heads(tensor, heads) - # pad head_dim to 128 if less than that. + # Pad head_dim to 128 if less than that. kv_size = tensor.shape[-1] head_dim_pad = 0 if kv_size < 128: head_dim_pad = 128 - kv_size - # pad seq_len to a multiple of flash_block_size if needed. + # Pad seq_len with sharding constraints. seq_len = tensor.shape[2] - # remainder + + # 1. First, pad seq_len to be a multiple of flash_block_size rem = seq_len % flash_block_size - seq_len_pad = 0 if rem != 0: - # multiplier - mul = seq_len // flash_block_size - # pad to the closest multiplier of flash_block_size - seq_len_pad = (mul + 1) * flash_block_size - seq_len + seq_len_padded_pre = seq_len + (flash_block_size - rem) + else: + seq_len_padded_pre = seq_len + + # 2. 
Ensure num_blocks is divisible by num_shards + num_blocks = seq_len_padded_pre // flash_block_size + if num_blocks % num_shards != 0: + num_blocks += num_shards - (num_blocks % num_shards) - if kv_size < 128 or rem != 0: + final_padded_len = num_blocks * flash_block_size + seq_len_pad = final_padded_len - seq_len + + if kv_size < 128 or seq_len_pad != 0: npad = ((0, 0), (0, 0), (0, seq_len_pad), (0, head_dim_pad)) - tensor = jnp.pad(tensor, npad) + padded_tensor = jnp.pad(tensor, npad) + tensor = jax.lax.with_sharding_constraint(padded_tensor, PartitionSpec("data", "tensor", "fsdp", None)) return tensor, kv_size, seq_len @@ -147,7 +159,8 @@ def _tpu_flash_attention( value: jax.Array, heads: int, mesh: Mesh, - flash_axis_names: AxisNames, + axis_names_q: AxisNames, + axis_names_kv: AxisNames, flash_block_sizes: BlockSizes, dtype: jnp.dtype = jnp.float32, ) -> jax.Array: @@ -168,30 +181,52 @@ def _tpu_flash_attention( block_kv_dq=min(max_block_size, query.shape[2]), ) - query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q) - key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute) - value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute) + num_fsdp_shards = mesh.shape["fsdp"] + query, kv_size, query_seq_len = _reshape_data_for_flash(query, heads, block_sizes.block_q, num_fsdp_shards) + key, _, _ = _reshape_data_for_flash(key, heads, block_sizes.block_kv_compute, num_fsdp_shards) + value, _, _ = _reshape_data_for_flash(value, heads, block_sizes.block_kv_compute, num_fsdp_shards) + q_axis_names = nn.logical_to_mesh_axes(axis_names_q) + kv_axis_names = nn.logical_to_mesh_axes(axis_names_kv) + flash_axis_names_splash_kernel: AxisNames = (HEAD, LENGTH) + axis_names_splash_kernel = nn.logical_to_mesh_axes(flash_axis_names_splash_kernel) + named_sharding = jax.sharding.NamedSharding(mesh, axis_names_splash_kernel) + + shard_head_size = mesh.shape["tensor"] + + @functools.partial( + jax.jit, 
+ static_argnames=["multi_head_mask", "shard_head_size"], + ) + def wrap_splash_kernel(multi_head_mask, shard_head_size=1): + splash_kernel = splash_attention_kernel.make_splash_mha( + mask=multi_head_mask, + head_shards=shard_head_size, # the sizes of the axis is sharding over heads + q_seq_shards=1, # the sizes of the axis is sharding over seq_len + block_sizes=block_sizes, + ) + return splash_kernel + + mask = splash_attention_mask.FullMask(_shape=(query.shape[2], key.shape[2])) - axis_names = nn.logical_to_mesh_axes(flash_axis_names) + multi_head_mask = splash_attention_mask.MultiHeadMask(masks=(mask,) * query.shape[1]) + splash_kernel = wrap_splash_kernel(multi_head_mask, int(shard_head_size)) + segment_axis_names_splash_kernel = splash_kernel.manual_sharding_spec(named_sharding) @functools.partial( shard_map.shard_map, mesh=mesh, in_specs=( - axis_names, - axis_names, - axis_names, + q_axis_names, + kv_axis_names, + kv_axis_names, + segment_axis_names_splash_kernel, ), - out_specs=axis_names, + out_specs=q_axis_names, check_rep=False, ) - def wrap_flash_attention(query, key, value): - masks = [splash_attention_mask.FullMask(_shape=(query.shape[2], query.shape[2])) for _ in range(query.shape[1])] - multi_head_mask = splash_attention_mask.MultiHeadMask(masks=masks) - splash_kernel = splash_attention_kernel.make_splash_mha( - mask=multi_head_mask, head_shards=1, q_seq_shards=1, block_sizes=block_sizes - ) - return jax.vmap(splash_kernel)(query, key, value) + def wrap_flash_attention(query, key, value, splash_kernel): + attention_output = jax.vmap(splash_kernel)(query, key, value) + return attention_output devices_in_data_fsdp = mesh.shape["data"] * mesh.shape["fsdp"] # This warning might show up when doing model eval for example, when calculating model flops @@ -201,7 +236,7 @@ def wrap_flash_attention(query, key, value): "Warning, batch dimension should be shardable among the devices in data and fsdp" f" axis, batch dimension: {query.shape[0]}, 
devices_in_data_fsdp: {devices_in_data_fsdp}" ) - x = wrap_flash_attention(query, key, value) + x = wrap_flash_attention(query, key, value, splash_kernel) x = x[:, :, :query_seq_len, :kv_size] x = _reshape_heads_to_head_dim(x) @@ -327,7 +362,8 @@ def _apply_attention( scale: float, dtype: jnp.dtype, mesh: Mesh, - flash_axis_names: AxisNames, + axis_names_q: AxisNames, + axis_names_kv: AxisNames, flash_block_sizes: BlockSizes, dpa_layer: Callable, ): @@ -350,7 +386,9 @@ def _apply_attention( query, key, value, dtype, heads, dim_head, scale, split_head_dim, float32_qk_product, use_memory_efficient_attention ) elif attention_kernel == "flash": - return _tpu_flash_attention(query, key * scale, value, heads, mesh, flash_axis_names, flash_block_sizes, dtype) + return _tpu_flash_attention( + query, key * scale, value, heads, mesh, axis_names_q, axis_names_kv, flash_block_sizes, dtype + ) elif attention_kernel == "cudnn_flash_te": return _cudnn_flash_attention(query, key, value, heads, mesh, dpa_layer) else: @@ -473,7 +511,8 @@ def __init__( use_memory_efficient_attention: bool = False, split_head_dim: bool = False, float32_qk_product: bool = True, - flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV), + axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV), + axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV), flash_min_seq_length: int = 4096, flash_block_sizes: BlockSizes = None, dtype: DType = jnp.float32, @@ -491,7 +530,8 @@ def __init__( self.use_memory_efficient_attention = use_memory_efficient_attention self.split_head_dim = split_head_dim self.float32_qk_product = float32_qk_product - self.flash_axis_names = flash_axis_names + self.axis_names_q = axis_names_q + self.axis_names_kv = axis_names_kv self.flash_min_seq_length = flash_min_seq_length self.flash_block_sizes = flash_block_sizes self.dtype = dtype @@ -512,7 +552,8 @@ def apply_attention(self, query: Array, key: Array, value: Array): scale=self.scale, dtype=self.dtype, mesh=self.mesh, - 
flash_axis_names=self.flash_axis_names, + axis_names_q=self.axis_names_q, + axis_names_kv=self.axis_names_kv, flash_block_sizes=self.flash_block_sizes, dpa_layer=self.dpa_layer, ) @@ -527,7 +568,8 @@ class AttentionOp(nn.Module): use_memory_efficient_attention: bool = False split_head_dim: bool = False float32_qk_product: bool = True - flash_axis_names: AxisNames = (BATCH, HEAD, LENGTH, D_KV) + axis_names_q: AxisNames = (BATCH, HEAD, LENGTH, D_KV) + axis_names_kv: AxisNames = (BATCH, HEAD, KV_LENGTH, D_KV) flash_min_seq_length: int = 4096 flash_block_sizes: BlockSizes = None dtype: DType = jnp.float32 @@ -568,7 +610,8 @@ def apply_attention(self, query: Array, key: Array, value: Array): scale=self.scale, dtype=self.dtype, mesh=self.mesh, - flash_axis_names=self.flash_axis_names, + axis_names_q=self.axis_names_q, + axis_names_kv=self.axis_names_kv, flash_block_sizes=self.flash_block_sizes, dpa_layer=self.dpa_layer, ) @@ -643,6 +686,7 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), ) self.key = nnx.Linear( @@ -653,6 +697,7 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), ) self.value = nnx.Linear( @@ -663,6 +708,7 @@ def __init__( dtype=dtype, param_dtype=weights_dtype, precision=precision, + bias_init=nnx.with_partitioning(nnx.initializers.zeros, ("embed",)), ) self.proj_attn = nnx.Linear( @@ -732,12 +778,8 @@ def __call__( key_proj = _unflatten_heads(key_proj, self.heads) value_proj = _unflatten_heads(value_proj, self.heads) query_proj, key_proj = self._apply_rope(query_proj, key_proj, rotary_emb) - query_proj = jax.lax.with_sharding_constraint(query_proj, PartitionSpec("data", "tensor", None, None)) - key_proj = jax.lax.with_sharding_constraint(key_proj, PartitionSpec("data", "tensor", None, None)) - value_proj = jax.lax.with_sharding_constraint(value_proj, 
PartitionSpec("data", "tensor", None, None)) attn_output = self.attention_op.apply_attention(query_proj, key_proj, value_proj) - attn_output = jax.lax.with_sharding_constraint(attn_output, PartitionSpec("data", None, None)) attn_output = attn_output.astype(dtype=dtype) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index a084447b6..e0db9dd16 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -43,13 +43,6 @@ def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int): freq = get_1d_rotary_pos_embed(dim, max_seq_len, theta, freqs_dtype=jnp.float64, use_real=False) freqs.append(freq) freqs = jnp.concatenate(freqs, axis=1) - # sizes = jnp.array([ - # attention_head_dim // 2 - 2 * (attention_head_dim // 6), - # attention_head_dim // 6, - # attention_head_dim // 6, - # ]) - # cumulative_sizes = jnp.cumsum(jnp.array(sizes)) - # split_indices = cumulative_sizes[:-1] t_size = attention_head_dim // 2 - 2 * (attention_head_dim // 6) hw_size = attention_head_dim // 6 @@ -469,11 +462,19 @@ def __call__( if encoder_hidden_states_image is not None: raise NotImplementedError("img2vid is not yet implemented.") + + def skip_block_true(hidden_states): + split_bs = hidden_states.shape[0] // 2 + prev_neg_hidden_states = hidden_states[split_bs:] + hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb) + hidden_states = jnp.concatenate([hidden_states[:split_bs], prev_neg_hidden_states], axis=0) + return hidden_states + for block_idx, block in enumerate(self.blocks): should_skip_block = slg_mask[block_idx] & is_uncond hidden_states = jax.lax.cond( should_skip_block, - lambda hs: hs, # If true, pass through original hidden_states (skip block) + lambda _: skip_block_true(hidden_states), # If true, run block on cond half only; keep uncond half unchanged (SLG skip) lambda _: 
block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb), hidden_states, ) diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py index 77a7229ad..6623e78df 100644 --- a/src/maxdiffusion/models/wan/wan_utils.py +++ b/src/maxdiffusion/models/wan/wan_utils.py @@ -1,3 +1,4 @@ +import os import json import torch import jax @@ -139,69 +140,85 @@ def load_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, def load_base_wan_transformer(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): device = jax.devices(device)[0] + subfolder = "transformer" + filename = "diffusion_pytorch_model.safetensors.index.json" + local_files = False + if os.path.isdir(pretrained_model_name_or_path): + index_file_path = os.path.join(pretrained_model_name_or_path, subfolder, filename) + if not os.path.isfile(index_file_path): + raise FileNotFoundError(f"File {index_file_path} not found for local directory.") + local_files = True + elif hf_download: + # download the index file for sharded models. + index_file_path = hf_hub_download( + pretrained_model_name_or_path, + subfolder=subfolder, + filename=filename, + ) with jax.default_device(device): - if hf_download: - # download the index file for sharded models. - index_file_path = hf_hub_download( - pretrained_model_name_or_path, subfolder="transformer", filename="diffusion_pytorch_model.safetensors.index.json" - ) - # open the index file. 
- with open(index_file_path, "r") as f: - index_dict = json.load(f) - model_files = set() - for key in index_dict["weight_map"].keys(): - model_files.add(index_dict["weight_map"][key]) - - model_files = list(model_files) - tensors = {} - for model_file in model_files: - ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder="transformer", filename=model_file) - # now get all the filenames for the model that need downloading - max_logging.log(f"Load and port Wan 2.1 transformer on {device}") - - if ckpt_shard_path is not None: - with safe_open(ckpt_shard_path, framework="pt") as f: - for k in f.keys(): - tensors[k] = torch2jax(f.get_tensor(k)) - flax_state_dict = {} - cpu = jax.local_devices(backend="cpu")[0] - flattened_dict = flatten_dict(eval_shapes) - # turn all block numbers to strings just for matching weights. - # Later they will be turned back to ints. - random_flax_state_dict = {} - for key in flattened_dict: - string_tuple = tuple([str(item) for item in key]) - random_flax_state_dict[string_tuple] = flattened_dict[key] - del flattened_dict - for pt_key, tensor in tensors.items(): - renamed_pt_key = rename_key(pt_key) - renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.") - renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn") - renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out") - renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn") - renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm") - pt_tuple_key = tuple(renamed_pt_key.split(".")) - - flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict) - flax_key = rename_for_nnx(flax_key) - flax_key = _tuple_str_to_int(flax_key) - flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) - validate_flax_state_dict(eval_shapes, flax_state_dict) - flax_state_dict = unflatten_dict(flax_state_dict) - del tensors - jax.clear_caches() - return flax_state_dict + # open 
the index file. + with open(index_file_path, "r") as f: + index_dict = json.load(f) + model_files = set() + for key in index_dict["weight_map"].keys(): + model_files.add(index_dict["weight_map"][key]) + + model_files = list(model_files) + tensors = {} + for model_file in model_files: + if local_files: + ckpt_shard_path = os.path.join(pretrained_model_name_or_path, subfolder, model_file) + else: + ckpt_shard_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=model_file) + # now get all the filenames for the model that need downloading + max_logging.log(f"Load and port Wan 2.1 transformer on {device}") + + if ckpt_shard_path is not None: + with safe_open(ckpt_shard_path, framework="pt") as f: + for k in f.keys(): + tensors[k] = torch2jax(f.get_tensor(k)) + flax_state_dict = {} + cpu = jax.local_devices(backend="cpu")[0] + flattened_dict = flatten_dict(eval_shapes) + # turn all block numbers to strings just for matching weights. + # Later they will be turned back to ints. 
+ random_flax_state_dict = {} + for key in flattened_dict: + string_tuple = tuple([str(item) for item in key]) + random_flax_state_dict[string_tuple] = flattened_dict[key] + del flattened_dict + for pt_key, tensor in tensors.items(): + renamed_pt_key = rename_key(pt_key) + renamed_pt_key = renamed_pt_key.replace("blocks_", "blocks.") + renamed_pt_key = renamed_pt_key.replace("to_out_0", "proj_attn") + renamed_pt_key = renamed_pt_key.replace("ffn.net_2", "ffn.proj_out") + renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn") + renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm") + pt_tuple_key = tuple(renamed_pt_key.split(".")) + + flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict) + flax_key = rename_for_nnx(flax_key) + flax_key = _tuple_str_to_int(flax_key) + flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu) + validate_flax_state_dict(eval_shapes, flax_state_dict) + flax_state_dict = unflatten_dict(flax_state_dict) + del tensors + jax.clear_caches() + return flax_state_dict def load_wan_vae(pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True): device = jax.devices(device)[0] + subfolder = "vae" + filename = "diffusion_pytorch_model.safetensors" + if os.path.isdir(pretrained_model_name_or_path): + ckpt_path = os.path.join(pretrained_model_name_or_path, subfolder, filename) + if not os.path.isfile(ckpt_path): + raise FileNotFoundError(f"File {ckpt_path} not found for local directory.") + elif hf_download: + ckpt_path = hf_hub_download(pretrained_model_name_or_path, subfolder=subfolder, filename=filename) + max_logging.log(f"Load and port Wan 2.1 VAE on {device}") with jax.default_device(device): - if hf_download: - ckpt_path = hf_hub_download( - pretrained_model_name_or_path, subfolder="vae", filename="diffusion_pytorch_model.safetensors" - ) - max_logging.log(f"Load and port Wan 2.1 VAE on {device}") - if ckpt_path is 
not None: tensors = {} with safe_open(ckpt_path, framework="pt") as f: diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index d01aea3fc..ed5b84489 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -25,7 +25,7 @@ from ...pyconfig import HyperParameters from ... import max_logging from ... import max_utils -from ...max_utils import get_flash_block_sizes, get_precision +from ...max_utils import get_flash_block_sizes, get_precision, device_put_replicated from ...models.wan.wan_utils import load_wan_transformer, load_wan_vae from ...models.wan.transformers.transformer_wan import WanModel from ...models.wan.autoencoder_kl_wan import AutoencoderKLWan, AutoencoderKLWanCache @@ -99,7 +99,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) for path, val in flax.traverse_util.flatten_dict(params).items(): sharding = logical_state_sharding[path].value - state[path].value = jax.device_put(val, sharding) + state[path].value = device_put_replicated(val, sharding) state = nnx.from_flat_state(state) wan_transformer = nnx.merge(graphdef, state, rest_of_state) @@ -183,27 +183,42 @@ def load_tokenizer(cls, config: HyperParameters): @classmethod def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): - wan_vae = AutoencoderKLWan.from_config( - config.pretrained_model_name_or_path, - subfolder="vae", - rngs=rngs, - mesh=mesh, - dtype=config.activations_dtype, - weights_dtype=config.weights_dtype, - ) - vae_cache = AutoencoderKLWanCache(wan_vae) + def create_model(rngs: nnx.Rngs, config: HyperParameters): + wan_vae = AutoencoderKLWan.from_config( + config.pretrained_model_name_or_path, + subfolder="vae", + rngs=rngs, + mesh=mesh, + dtype=config.activations_dtype, + weights_dtype=config.weights_dtype, + ) + return wan_vae + + # 1. 
eval shape + p_model_factory = partial(create_model, config=config) + wan_vae = nnx.eval_shape(p_model_factory, rngs=rngs) graphdef, state = nnx.split(wan_vae, nnx.Param) + + # 2. retrieve the state shardings, mapping logical names to mesh axis names. + logical_state_spec = nnx.get_partition_spec(state) + logical_state_sharding = nn.logical_to_mesh_sharding(logical_state_spec, mesh, config.logical_axis_rules) + logical_state_sharding = dict(nnx.to_flat_state(logical_state_sharding)) params = state.to_pure_dict() - # This replaces random params with the model. + state = dict(nnx.to_flat_state(state)) + + # 3. Load pretrained weights and move them to device using the state shardings from (2) above. + # This helps with loading sharded weights directly into the accelerators without first copying them + # all to one device and then distributing them, thus using low HBM memory. params = load_wan_vae(config.pretrained_model_name_or_path, params, "cpu") params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params) - params = jax.device_put(params, NamedSharding(mesh, P())) - wan_vae = nnx.merge(graphdef, params) - p_create_sharded_logical_model = partial(create_sharded_logical_model, logical_axis_rules=config.logical_axis_rules) - # Shard - with mesh: - wan_vae = p_create_sharded_logical_model(model=wan_vae) + for path, val in flax.traverse_util.flatten_dict(params).items(): + sharding = logical_state_sharding[path].value + state[path].value = device_put_replicated(val, sharding) + state = nnx.from_flat_state(state) + + wan_vae = nnx.merge(graphdef, state) + vae_cache = AutoencoderKLWanCache(wan_vae) return wan_vae, vae_cache @classmethod @@ -434,12 +449,13 @@ def __call__( prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_prompt_embeds, ) - latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) - latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) - latents = latents / 
latents_std + latents_mean - latents = latents.astype(self.config.weights_dtype) + latents_mean = jnp.array(self.vae.latents_mean).reshape(1, self.vae.z_dim, 1, 1, 1) + latents_std = 1.0 / jnp.array(self.vae.latents_std).reshape(1, self.vae.z_dim, 1, 1, 1) + latents = latents / latents_std + latents_mean + latents = latents.astype(self.config.weights_dtype) - video = self.vae.decode(latents, self.vae_cache)[0] + with self.mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules): + video = self.vae.decode(latents, self.vae_cache)[0] video = jnp.transpose(video, (0, 4, 1, 2, 3)) video = torch.from_numpy(np.array(video.astype(dtype=jnp.float32))).to(dtype=torch.bfloat16) @@ -447,12 +463,31 @@ def __call__( return video -@jax.jit -def transformer_forward_pass(graphdef, sharded_state, rest_of_state, latents, timestep, prompt_embeds, is_uncond, slg_mask): +@partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale")) +def transformer_forward_pass( + graphdef, + sharded_state, + rest_of_state, + latents, + timestep, + prompt_embeds, + is_uncond, + slg_mask, + do_classifier_free_guidance, + guidance_scale, +): wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) - return wan_transformer( + noise_pred = wan_transformer( hidden_states=latents, timestep=timestep, encoder_hidden_states=prompt_embeds, is_uncond=is_uncond, slg_mask=slg_mask ) + if do_classifier_free_guidance: + bsz = latents.shape[0] // 2 + noise_uncond = noise_pred[bsz:] + noise_pred = noise_pred[:bsz] + noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) + latents = latents[:bsz] + + return noise_pred, latents def run_inference( @@ -472,35 +507,29 @@ def run_inference( slg_end: float = 1.0, ): do_classifier_free_guidance = guidance_scale > 1.0 + if do_classifier_free_guidance: + prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) for step in range(num_inference_steps): slg_mask = jnp.zeros(num_transformer_layers, 
dtype=jnp.bool_) if slg_layers and int(slg_start * num_inference_steps) <= step < int(slg_end * num_inference_steps): slg_mask = slg_mask.at[jnp.array(slg_layers)].set(True) t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] + if do_classifier_free_guidance: + latents = jnp.concatenate([latents] * 2) timestep = jnp.broadcast_to(t, latents.shape[0]) - noise_pred = transformer_forward_pass( + noise_pred, latents = transformer_forward_pass( graphdef, sharded_state, rest_of_state, latents, timestep, prompt_embeds, - is_uncond=jnp.array(False, dtype=jnp.bool_), + is_uncond=jnp.array(True, dtype=jnp.bool_), slg_mask=slg_mask, + do_classifier_free_guidance=do_classifier_free_guidance, + guidance_scale=guidance_scale, ) - if do_classifier_free_guidance: - noise_uncond = transformer_forward_pass( - graphdef, - sharded_state, - rest_of_state, - latents, - timestep, - negative_prompt_embeds, - is_uncond=jnp.array(True, dtype=jnp.bool_), - slg_mask=slg_mask, - ) - noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() return latents diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index edcf96164..1ebd95c83 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -16,6 +16,7 @@ # pylint: disable=missing-module-docstring import os +import ast import json import sys from collections import OrderedDict @@ -36,7 +37,11 @@ def string_to_bool(s: str) -> bool: raise ValueError(f"Can't convert {s} to bool") -_yaml_types_to_parser = {str: str, int: int, float: float, bool: string_to_bool} +def string_to_list(string_list: str) -> list: + return ast.literal_eval(string_list) + + +_yaml_types_to_parser = {str: str, int: int, float: float, bool: string_to_bool, list: string_to_list} _config = None config = None diff --git a/src/maxdiffusion/tests/attention_test.py b/src/maxdiffusion/tests/attention_test.py index 
3b013b791..c2180240f 100644 --- a/src/maxdiffusion/tests/attention_test.py +++ b/src/maxdiffusion/tests/attention_test.py @@ -23,7 +23,6 @@ from ..models.attention_flax import FlaxAttention from .. import max_utils from .. import pyconfig -from maxdiffusion import FlaxUNet2DConditionModel THIS_DIR = os.path.dirname(os.path.abspath(__file__)) @@ -73,54 +72,26 @@ def test_splash_attention(self): devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) flash_block_sizes = max_utils.get_flash_block_sizes(config) - splash_attention = FlaxAttention( - heads * head_depth, - heads, - head_depth, - split_head_dim=True, - attention_kernel="flash", - mesh=mesh, - dtype=jnp.bfloat16, - flash_block_sizes=flash_block_sizes, - ) - - params = splash_attention.init(key2, x)["params"] - p_apply = jax.jit(splash_attention.apply).lower({"params": params}, x).compile() - splash_attention_out = p_apply({"params": params}, x) + with mesh: + splash_attention = FlaxAttention( + heads * head_depth, + heads, + head_depth, + split_head_dim=True, + attention_kernel="flash", + mesh=mesh, + dtype=jnp.bfloat16, + flash_block_sizes=flash_block_sizes, + ) + + params = splash_attention.init(key2, x)["params"] + p_apply = jax.jit(splash_attention.apply).lower({"params": params}, x).compile() + splash_attention_out = p_apply({"params": params}, x) diff_norm = jnp.linalg.norm(dot_attention_out - splash_attention_out) assert diff_norm < 1.0 - def test_flash_block_sizes(self): - """Test loading flash block sizes from cli.""" - - pyconfig.initialize( - [ - None, - os.path.join(THIS_DIR, "..", "configs", "base_2_base.yml"), - 'flash_block_sizes={"block_q" : 256, "block_kv_compute": 256, "block_kv": 256,' - '"block_q_dkv": 256, "block_kv_dkv": 256, "block_kv_dkv_compute": 256,' - '"block_q_dq": 256, "block_kv_dq": 256}', - "attention=flash", - ], - unittest=True, - ) - config = pyconfig.config - devices_array = max_utils.create_device_mesh(config) - mesh = 
Mesh(devices_array, config.mesh_axes) - flash_block_sizes = max_utils.get_flash_block_sizes(config) - _, _ = FlaxUNet2DConditionModel.from_pretrained( - config.pretrained_model_name_or_path, - revision=config.revision, - subfolder="unet", - dtype=jnp.bfloat16, - from_pt=config.from_pt, - attention_kernel=config.attention, - flash_block_sizes=flash_block_sizes, - mesh=mesh, - ) - if __name__ == "__main__": absltest.main() diff --git a/src/maxdiffusion/tests/flop_calculations_test.py b/src/maxdiffusion/tests/flop_calculations_test.py index f4465ec0c..db1216f72 100644 --- a/src/maxdiffusion/tests/flop_calculations_test.py +++ b/src/maxdiffusion/tests/flop_calculations_test.py @@ -1,16 +1,25 @@ import os import unittest import jax +from jax.sharding import Mesh import flax.linen as nn from absl.testing import absltest from maxdiffusion.max_utils import calculate_model_tflops from maxdiffusion.models.attention_flax import FlaxAttention +from .. import pyconfig, max_utils THIS_DIR = os.path.dirname(os.path.abspath(__file__)) class FlopCalculation(unittest.TestCase): + def setUp(self): + FlopCalculation.dummy_data = {} + pyconfig.initialize([None, os.path.join(THIS_DIR, "..", "configs", "base21.yml")], unittest=True) + self.config = pyconfig.config + devices_array = max_utils.create_device_mesh(self.config) + self.mesh = Mesh(devices_array, self.config.mesh_axes) + def test_dense_layer_model_flops(self): class SimpleLinearModel(nn.Module): @@ -45,8 +54,8 @@ def __call__(self, x): model = SimpleConv() rng = jax.random.PRNGKey(0) x = jax.random.normal(rng, (1, 28, 28, 1)) - - training_tflops = calculate_model_tflops(model, rng, train=True, x=x) + with self.mesh: + training_tflops = calculate_model_tflops(model, rng, train=True, x=x) macs = (3 * 3 * 28 * 28 * 16) + (3 * 3 * 28 * 28 * 32 * 16) + (28 * 28 * 32 * 10) forward_tflops = (2 * macs) / 10**12 calculated_training_tflops = 3 * forward_tflops @@ -67,8 +76,8 @@ def __call__(self, x): model = SimpleAttn() rng = 
jax.random.PRNGKey(0) x = jax.random.normal(rng, (1, 9216, 320)) - - training_tflops = calculate_model_tflops(model, rng, train=True, x=x) + with self.mesh: + training_tflops = calculate_model_tflops(model, rng, train=True, x=x) # For linears before attn qkv_macs = 3 * (320 * 320 * 9216) qkv_tflops = 2 * qkv_macs / 10**12 diff --git a/src/maxdiffusion/tests/input_pipeline_interface_test.py b/src/maxdiffusion/tests/input_pipeline_interface_test.py index 79e7a0891..92b1aa3f8 100644 --- a/src/maxdiffusion/tests/input_pipeline_interface_test.py +++ b/src/maxdiffusion/tests/input_pipeline_interface_test.py @@ -151,34 +151,35 @@ def test_make_pokemon_hf_iterator(self): global_batch_size = config.per_device_batch_size * jax.device_count() devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( - config.pretrained_model_name_or_path, - revision=config.revision, - dtype=config.activations_dtype, - safety_checker=None, - feature_extractor=None, - from_pt=config.from_pt, - ) - p_encode = None - p_vae_apply = None - rng = None - tokenize_fn = partial( - tokenize_captions, caption_column=config.caption_column, tokenizer=pipeline.tokenizer, p_encode=p_encode - ) - image_transforms_fn = partial( - transform_images, - image_column=config.image_column, - image_resolution=config.resolution, - rng=rng, - global_batch_size=global_batch_size, - p_vae_apply=p_vae_apply, - ) + with mesh: + pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + dtype=config.activations_dtype, + safety_checker=None, + feature_extractor=None, + from_pt=config.from_pt, + ) + p_encode = None + p_vae_apply = None + rng = None + tokenize_fn = partial( + tokenize_captions, caption_column=config.caption_column, tokenizer=pipeline.tokenizer, p_encode=p_encode + ) + image_transforms_fn = partial( + transform_images, + 
image_column=config.image_column, + image_resolution=config.resolution, + rng=rng, + global_batch_size=global_batch_size, + p_vae_apply=p_vae_apply, + ) - train_iterator = make_data_iterator( - config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn - ) - data = next(train_iterator) - device_count = jax.device_count() + train_iterator = make_data_iterator( + config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn + ) + data = next(train_iterator) + device_count = jax.device_count() assert data["input_ids"].shape == (device_count, 77) assert data["pixel_values"].shape == (device_count, 3, config.resolution, config.resolution) @@ -200,37 +201,38 @@ def test_make_pokemon_hf_iterator_sdxl(self): global_batch_size = config.per_device_batch_size * jax.device_count() devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained( - config.pretrained_model_name_or_path, - revision=config.revision, - dtype=config.activations_dtype, - safety_checker=None, - feature_extractor=None, - from_pt=config.from_pt, - ) - p_encode = None - p_vae_apply = None - rng = None - tokenize_fn = partial( - tokenize_captions_xl, - caption_column=config.caption_column, - tokenizers=[pipeline.tokenizer, pipeline.tokenizer_2], - p_encode=p_encode, - ) - image_transforms_fn = partial( - transform_images, - image_column=config.image_column, - image_resolution=config.resolution, - rng=rng, - global_batch_size=global_batch_size, - p_vae_apply=p_vae_apply, - ) + with mesh: + pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + dtype=config.activations_dtype, + safety_checker=None, + feature_extractor=None, + from_pt=config.from_pt, + ) + p_encode = None + p_vae_apply = None + rng = None + tokenize_fn = partial( + 
tokenize_captions_xl, + caption_column=config.caption_column, + tokenizers=[pipeline.tokenizer, pipeline.tokenizer_2], + p_encode=p_encode, + ) + image_transforms_fn = partial( + transform_images, + image_column=config.image_column, + image_resolution=config.resolution, + rng=rng, + global_batch_size=global_batch_size, + p_vae_apply=p_vae_apply, + ) - train_iterator = make_data_iterator( - config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn - ) - data = next(train_iterator) - device_count = jax.device_count() + train_iterator = make_data_iterator( + config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn + ) + data = next(train_iterator) + device_count = jax.device_count() assert data["input_ids"].shape == (device_count, 2, 77) assert data["pixel_values"].shape == (device_count, 3, config.resolution, config.resolution) @@ -253,40 +255,41 @@ def test_make_pokemon_tf_iterator_cache(self): global_batch_size = config.per_device_batch_size * jax.device_count() devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( - config.pretrained_model_name_or_path, - revision=config.revision, - dtype=config.activations_dtype, - safety_checker=None, - feature_extractor=None, - from_pt=config.from_pt, - ) - rng = jax.random.PRNGKey(config.seed) - p_encode = None - p_vae_apply = None - if config.cache_latents_text_encoder_outputs: - p_encode = jax.jit(partial(encode, text_encoder=pipeline.text_encoder, text_encoder_params=params["text_encoder"])) - p_vae_apply = jax.jit(partial(vae_apply, vae=pipeline.vae, vae_params=params["vae"])) - tokenize_fn = partial( - tokenize_captions, caption_column=config.caption_column, tokenizer=pipeline.tokenizer, p_encode=p_encode - ) - image_transforms_fn = partial( - transform_images, - image_column=config.image_column, - 
image_resolution=config.resolution, - rng=rng, - global_batch_size=global_batch_size, - p_vae_apply=p_vae_apply, - ) + with mesh: + pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + dtype=config.activations_dtype, + safety_checker=None, + feature_extractor=None, + from_pt=config.from_pt, + ) + rng = jax.random.PRNGKey(config.seed) + p_encode = None + p_vae_apply = None + if config.cache_latents_text_encoder_outputs: + p_encode = jax.jit(partial(encode, text_encoder=pipeline.text_encoder, text_encoder_params=params["text_encoder"])) + p_vae_apply = jax.jit(partial(vae_apply, vae=pipeline.vae, vae_params=params["vae"])) + tokenize_fn = partial( + tokenize_captions, caption_column=config.caption_column, tokenizer=pipeline.tokenizer, p_encode=p_encode + ) + image_transforms_fn = partial( + transform_images, + image_column=config.image_column, + image_resolution=config.resolution, + rng=rng, + global_batch_size=global_batch_size, + p_vae_apply=p_vae_apply, + ) - train_iterator = make_data_iterator( - config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn - ) - data = next(train_iterator) - device_count = jax.device_count() + train_iterator = make_data_iterator( + config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn + ) + data = next(train_iterator) + device_count = jax.device_count() - vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) - encoder_hidden_states = data["input_ids"] + vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) + encoder_hidden_states = data["input_ids"] assert encoder_hidden_states.shape == (device_count, 77, 1024) assert data["pixel_values"].shape == ( @@ -316,37 +319,38 @@ def test_make_pokemon_iterator_no_cache(self): global_batch_size = config.per_device_batch_size * jax.device_count() devices_array = 
max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( - config.pretrained_model_name_or_path, - revision=config.revision, - dtype=config.activations_dtype, - safety_checker=None, - feature_extractor=None, - from_pt=config.from_pt, - ) - rng = jax.random.PRNGKey(config.seed) - p_encode = None - p_vae_apply = None - if config.cache_latents_text_encoder_outputs: - p_encode = jax.jit(partial(encode, text_encoder=pipeline.text_encoder, text_encoder_params=params["text_encoder"])) - p_vae_apply = jax.jit(partial(vae_apply, vae=pipeline.vae, vae_params=params["vae"])) - tokenize_fn = partial( - tokenize_captions, caption_column=config.caption_column, tokenizer=pipeline.tokenizer, p_encode=p_encode - ) - image_transforms_fn = partial( - transform_images, - image_column=config.image_column, - image_resolution=config.resolution, - rng=rng, - global_batch_size=global_batch_size, - p_vae_apply=p_vae_apply, - ) + with mesh: + pipeline, params = FlaxStableDiffusionPipeline.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + dtype=config.activations_dtype, + safety_checker=None, + feature_extractor=None, + from_pt=config.from_pt, + ) + rng = jax.random.PRNGKey(config.seed) + p_encode = None + p_vae_apply = None + if config.cache_latents_text_encoder_outputs: + p_encode = jax.jit(partial(encode, text_encoder=pipeline.text_encoder, text_encoder_params=params["text_encoder"])) + p_vae_apply = jax.jit(partial(vae_apply, vae=pipeline.vae, vae_params=params["vae"])) + tokenize_fn = partial( + tokenize_captions, caption_column=config.caption_column, tokenizer=pipeline.tokenizer, p_encode=p_encode + ) + image_transforms_fn = partial( + transform_images, + image_column=config.image_column, + image_resolution=config.resolution, + rng=rng, + global_batch_size=global_batch_size, + p_vae_apply=p_vae_apply, + ) - train_iterator = make_data_iterator( - config, 
jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn - ) - data = next(train_iterator) - device_count = jax.device_count() + train_iterator = make_data_iterator( + config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn + ) + data = next(train_iterator) + device_count = jax.device_count() encoder_hidden_states = data["input_ids"] assert encoder_hidden_states.shape == (device_count, 77) @@ -372,51 +376,52 @@ def test_make_pokemon_iterator_sdxl_cache(self): global_batch_size = config.per_device_batch_size * jax.device_count() devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained( - config.pretrained_model_name_or_path, - revision=config.revision, - dtype=config.activations_dtype, - safety_checker=None, - feature_extractor=None, - from_pt=config.from_pt, - ) - rng = jax.random.PRNGKey(config.seed) - p_encode = None - p_vae_apply = None - if config.cache_latents_text_encoder_outputs: - p_encode = jax.jit( - partial( - encode_xl, - text_encoders=[pipeline.text_encoder, pipeline.text_encoder_2], - text_encoder_params=[params["text_encoder"], params["text_encoder_2"]], - ) + with mesh: + pipeline, params = FlaxStableDiffusionXLPipeline.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + dtype=config.activations_dtype, + safety_checker=None, + feature_extractor=None, + from_pt=config.from_pt, + ) + rng = jax.random.PRNGKey(config.seed) + p_encode = None + p_vae_apply = None + if config.cache_latents_text_encoder_outputs: + p_encode = jax.jit( + partial( + encode_xl, + text_encoders=[pipeline.text_encoder, pipeline.text_encoder_2], + text_encoder_params=[params["text_encoder"], params["text_encoder_2"]], + ) + ) + p_vae_apply = jax.jit(partial(vae_apply, vae=pipeline.vae, vae_params=params["vae"])) + tokenize_fn = partial( + 
tokenize_captions_xl, + caption_column=config.caption_column, + tokenizers=[pipeline.tokenizer, pipeline.tokenizer_2], + p_encode=p_encode, + ) + image_transforms_fn = partial( + transform_images, + image_column=config.image_column, + image_resolution=config.resolution, + rng=rng, + global_batch_size=global_batch_size, + p_vae_apply=p_vae_apply, ) - p_vae_apply = jax.jit(partial(vae_apply, vae=pipeline.vae, vae_params=params["vae"])) - tokenize_fn = partial( - tokenize_captions_xl, - caption_column=config.caption_column, - tokenizers=[pipeline.tokenizer, pipeline.tokenizer_2], - p_encode=p_encode, - ) - image_transforms_fn = partial( - transform_images, - image_column=config.image_column, - image_resolution=config.resolution, - rng=rng, - global_batch_size=global_batch_size, - p_vae_apply=p_vae_apply, - ) - train_iterator = make_data_iterator( - config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn - ) - data = next(train_iterator) - device_count = jax.device_count() + train_iterator = make_data_iterator( + config, jax.process_index(), jax.process_count(), mesh, global_batch_size, tokenize_fn, image_transforms_fn + ) + data = next(train_iterator) + device_count = jax.device_count() - vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) + vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) - prompt_embeds = data["prompt_embeds"] - text_embeds = data["text_embeds"] + prompt_embeds = data["prompt_embeds"] + text_embeds = data["text_embeds"] assert prompt_embeds.shape == (device_count, 77, 2048) assert text_embeds.shape == (device_count, 1280) assert data["pixel_values"].shape == ( @@ -452,27 +457,27 @@ def test_make_laion_grain_iterator(self): global_batch_size = config.per_device_batch_size * jax.device_count() devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) + with mesh: + pipeline, _ = 
FlaxStableDiffusionPipeline.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + dtype=config.activations_dtype, + safety_checker=None, + feature_extractor=None, + from_pt=config.from_pt, + ) - pipeline, _ = FlaxStableDiffusionPipeline.from_pretrained( - config.pretrained_model_name_or_path, - revision=config.revision, - dtype=config.activations_dtype, - safety_checker=None, - feature_extractor=None, - from_pt=config.from_pt, - ) - - train_iterator = make_data_iterator(config, jax.process_index(), jax.process_count(), mesh, global_batch_size) - data = next(train_iterator) - device_count = jax.device_count() + train_iterator = make_data_iterator(config, jax.process_index(), jax.process_count(), mesh, global_batch_size) + data = next(train_iterator) + device_count = jax.device_count() - vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) - encoder_hidden_states = data["input_ids"] + vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) + encoder_hidden_states = data["input_ids"] - # TODO - laion dataset was prepared with an extra dim. - # need to preprocess the dataset with dim removed. - if len(encoder_hidden_states.shape) == 4: - encoder_hidden_states = jnp.squeeze(encoder_hidden_states) + # TODO - laion dataset was prepared with an extra dim. + # need to preprocess the dataset with dim removed. 
+ if len(encoder_hidden_states.shape) == 4: + encoder_hidden_states = jnp.squeeze(encoder_hidden_states) assert encoder_hidden_states.shape == (device_count, 77, 1024) assert data["pixel_values"].shape == ( @@ -496,43 +501,43 @@ def test_make_laion_tfrecord_iterator(self): global_batch_size = config.per_device_batch_size * jax.device_count() devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) + with mesh: + pipeline, _ = FlaxStableDiffusionPipeline.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + dtype=config.activations_dtype, + safety_checker=None, + feature_extractor=None, + from_pt=config.from_pt, + ) - pipeline, _ = FlaxStableDiffusionPipeline.from_pretrained( - config.pretrained_model_name_or_path, - revision=config.revision, - dtype=config.activations_dtype, - safety_checker=None, - feature_extractor=None, - from_pt=config.from_pt, - ) - - feature_description = { - "moments": tf.io.FixedLenFeature([], tf.string), - "clip_embeddings": tf.io.FixedLenFeature([], tf.string), - } - - def _parse_tfrecord_fn(example): - return tf.io.parse_single_example(example, feature_description) - - train_iterator = make_data_iterator( - config, - jax.process_index(), - jax.process_count(), - mesh, - global_batch_size, - feature_description=feature_description, - prepare_sample_fn=_parse_tfrecord_fn, - ) - data = next(train_iterator) - device_count = jax.device_count() + feature_description = { + "moments": tf.io.FixedLenFeature([], tf.string), + "clip_embeddings": tf.io.FixedLenFeature([], tf.string), + } + + def _parse_tfrecord_fn(example): + return tf.io.parse_single_example(example, feature_description) + + train_iterator = make_data_iterator( + config, + jax.process_index(), + jax.process_count(), + mesh, + global_batch_size, + feature_description=feature_description, + prepare_sample_fn=_parse_tfrecord_fn, + ) + data = next(train_iterator) + device_count = jax.device_count() - 
vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) - encoder_hidden_states = data["input_ids"] + vae_scale_factor = 2 ** (len(pipeline.vae.config.block_out_channels) - 1) + encoder_hidden_states = data["input_ids"] - # TODO - laion dataset was prepared with an extra dim. - # need to preprocess the dataset with dim removed. - if len(encoder_hidden_states.shape) == 4: - encoder_hidden_states = jnp.squeeze(encoder_hidden_states) + # TODO - laion dataset was prepared with an extra dim. + # need to preprocess the dataset with dim removed. + if len(encoder_hidden_states.shape) == 4: + encoder_hidden_states = jnp.squeeze(encoder_hidden_states) assert encoder_hidden_states.shape == (device_count, 77, 1024) assert data["pixel_values"].shape == ( diff --git a/src/maxdiffusion/tests/unet_test.py b/src/maxdiffusion/tests/unet_test.py index e24852636..562fb5a33 100644 --- a/src/maxdiffusion/tests/unet_test.py +++ b/src/maxdiffusion/tests/unet_test.py @@ -51,31 +51,34 @@ def test_unet15_sharding_test(self): unittest=True, ) config = pyconfig.config - unet, params = FlaxUNet2DConditionModel.from_pretrained( - config.pretrained_model_name_or_path, - revision=config.revision, - subfolder="unet", - dtype=jnp.bfloat16, - from_pt=config.from_pt, - ) devices_array = max_utils.create_device_mesh(config) - - rng = jax.random.PRNGKey(config.seed) mesh = Mesh(devices_array, config.mesh_axes) - k = jax.random.key(0) - tx = optax.adam(learning_rate=0.001) - latents = jnp.ones((4, 4, 64, 64), dtype=jnp.float32) - timesteps = jnp.ones((4,)) - encoder_hidden_states = jnp.ones((4, 77, 1024)) - - variables = jax.jit(unet.init)(k, latents, timesteps, encoder_hidden_states) - weights_init_fn = functools.partial(unet.init_weights, rng=rng) - _, state_mesh_annotations, _ = max_utils.get_abstract_state(unet, tx, config, mesh, weights_init_fn, False) - del variables - conv_sharding = PartitionSpec(None, None, None, "fsdp") - qkv_sharding = PartitionSpec("fsdp", "tensor") - 
to_out_sharding = PartitionSpec("tensor", "fsdp") - time_emb_proj_sharding = PartitionSpec() + with mesh: + unet, params = FlaxUNet2DConditionModel.from_pretrained( + config.pretrained_model_name_or_path, + revision=config.revision, + subfolder="unet", + dtype=jnp.bfloat16, + from_pt=config.from_pt, + ) + devices_array = max_utils.create_device_mesh(config) + + rng = jax.random.PRNGKey(config.seed) + mesh = Mesh(devices_array, config.mesh_axes) + k = jax.random.key(0) + tx = optax.adam(learning_rate=0.001) + latents = jnp.ones((4, 4, 64, 64), dtype=jnp.float32) + timesteps = jnp.ones((4,)) + encoder_hidden_states = jnp.ones((4, 77, 1024)) + + variables = jax.jit(unet.init)(k, latents, timesteps, encoder_hidden_states) + weights_init_fn = functools.partial(unet.init_weights, rng=rng) + _, state_mesh_annotations, _ = max_utils.get_abstract_state(unet, tx, config, mesh, weights_init_fn, False) + del variables + conv_sharding = PartitionSpec(None, None, None, "fsdp") + qkv_sharding = PartitionSpec("fsdp", "tensor") + to_out_sharding = PartitionSpec("tensor", "fsdp") + time_emb_proj_sharding = PartitionSpec() assert state_mesh_annotations.params["down_blocks_0"]["resnets_0"]["time_emb_proj"]["kernel"] == time_emb_proj_sharding assert state_mesh_annotations.params["down_blocks_0"]["downsamplers_0"]["conv"]["kernel"] == conv_sharding @@ -97,10 +100,10 @@ def test_unet15_sharding_test(self): state_mesh_annotations.params["down_blocks_1"]["attentions_1"]["transformer_blocks_0"]["attn1"]["to_out_0"]["kernel"] == to_out_sharding ) - - state, state_mesh_shardings = max_utils.setup_initial_state( - unet, tx, config, mesh, weights_init_fn, None, None, None, False - ) + with mesh: + state, state_mesh_shardings = max_utils.setup_initial_state( + unet, tx, config, mesh, weights_init_fn, None, None, None, False + ) # Validate named shardings. 
conv_named_sharding = NamedSharding(mesh, conv_sharding) diff --git a/src/maxdiffusion/tests/vae_test.py b/src/maxdiffusion/tests/vae_test.py index cf7fb399d..e3a46b109 100644 --- a/src/maxdiffusion/tests/vae_test.py +++ b/src/maxdiffusion/tests/vae_test.py @@ -47,7 +47,6 @@ def test_flux_vae(self): image = 2.0 * image - 1.0 image = np.expand_dims(image, 0) image = np.transpose(image, (0, 3, 1, 2)) # (1, 3, 1024, 1024), BCWH - vae, vae_params = FlaxAutoencoderKL.from_pretrained( "black-forest-labs/FLUX.1-dev", subfolder="vae", from_pt=True, use_safetensors=True, dtype="bfloat16" ) diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py index 17741191a..4ea50cc7a 100644 --- a/src/maxdiffusion/tests/wan_transformer_test.py +++ b/src/maxdiffusion/tests/wan_transformer_test.py @@ -248,7 +248,7 @@ def test_wan_model(self): batch_size = 1 channels = 16 - frames = 21 + frames = 1 height = 90 width = 160 hidden_states_shape = (batch_size, channels, frames, height, width) @@ -262,12 +262,8 @@ def test_wan_model(self): mesh = Mesh(devices_array, config.mesh_axes) batch_size = 1 - wan_model = WanModel( - rngs=rngs, - attention="flash", - mesh=mesh, - flash_block_sizes=flash_block_sizes, - ) + num_layers = 1 + wan_model = WanModel(rngs=rngs, attention="flash", mesh=mesh, flash_block_sizes=flash_block_sizes, num_layers=num_layers) dummy_timestep = jnp.ones((batch_size)) dummy_encoder_hidden_states = jnp.ones((batch_size, 512, 4096)) @@ -277,7 +273,7 @@ def test_wan_model(self): timestep=dummy_timestep, encoder_hidden_states=dummy_encoder_hidden_states, is_uncond=jnp.array(True, dtype=jnp.bool_), - slg_mask=jnp.zeros(40, dtype=jnp.bool_), + slg_mask=jnp.zeros(num_layers, dtype=jnp.bool_), ) assert dummy_output.shape == hidden_states_shape