diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml
index 8149c8292..e8fb88043 100644
--- a/src/maxdiffusion/configs/base_wan_14b.yml
+++ b/src/maxdiffusion/configs/base_wan_14b.yml
@@ -48,6 +48,8 @@ replicate_vae: False
 # fp32 activations and fp32 weights with HIGHEST will provide the best precision
 # at the cost of time.
 precision: "DEFAULT"
+# Use jax.lax.scan for transformer layers
+scan_layers: True

 # if False state is not jitted and instead replicate is called. This is good for debugging on single host
 # It must be True for multi-host.
diff --git a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py
index ef5cd125b..cc0c8fd33 100644
--- a/src/maxdiffusion/input_pipeline/_tfds_data_processing.py
+++ b/src/maxdiffusion/input_pipeline/_tfds_data_processing.py
@@ -24,6 +24,7 @@
 AUTOTUNE = tf.data.AUTOTUNE
 os.environ["TOKENIZERS_PARALLELISM"] = "false"

+
 def load_as_tf_dataset(dataset, global_batch_size, shuffle, dataloading_host_count):
   dataset = dataset.with_format("tensorflow")[:]
   tf_dataset = tf.data.Dataset.from_tensor_slices(dataset)
diff --git a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py
index f960afd28..27f2ad259 100644
--- a/src/maxdiffusion/input_pipeline/input_pipeline_interface.py
+++ b/src/maxdiffusion/input_pipeline/input_pipeline_interface.py
@@ -42,6 +42,7 @@
 AUTOTUNE = tf.data.experimental.AUTOTUNE
 os.environ["TOKENIZERS_PARALLELISM"] = "false"

+
 def make_data_iterator(
     config,
     dataloading_host_index,
diff --git a/src/maxdiffusion/models/gradient_checkpoint.py b/src/maxdiffusion/models/gradient_checkpoint.py
index a111ef717..086223f84 100644
--- a/src/maxdiffusion/models/gradient_checkpoint.py
+++ b/src/maxdiffusion/models/gradient_checkpoint.py
@@ -85,6 +85,7 @@ def apply(
       names_which_can_be_saved: list = [],
       names_which_can_be_offloaded: list = [],
       static_argnums=(),
+      prevent_cse: bool = False,
   ) -> nnx.Module:
     """
     Applies a gradient checkpoint policy to a module
@@ -99,4 +100,4 @@
     policy = self.to_jax_policy(names_which_can_be_saved, names_which_can_be_offloaded)
     if policy == SKIP_GRADIENT_CHECKPOINT_KEY:
       return module
-    return nnx.remat(module, prevent_cse=False, policy=policy, static_argnums=static_argnums)  # pylint: disable=invalid-name
+    return nnx.remat(module, prevent_cse=prevent_cse, policy=policy, static_argnums=static_argnums)  # pylint: disable=invalid-name
diff --git a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py
index 0edab5070..685b0c0b3 100644
--- a/src/maxdiffusion/models/modeling_flax_pytorch_utils.py
+++ b/src/maxdiffusion/models/modeling_flax_pytorch_utils.py
@@ -25,7 +25,6 @@
 from chex import Array
 from ..utils import logging
 from .. import max_logging
-from .. import common_types

 logger = logging.get_logger(__name__)
@@ -87,7 +86,7 @@
 # Adapted from https://github.com/huggingface/transformers/blob/c603c80f46881ae18b2ca50770ef65fa4033eacd/src/transformers/modeling_flax_pytorch_utils.py#L69
 # and https://github.com/patil-suraj/stable-diffusion-jax/blob/main/stable_diffusion_jax/convert_diffusers_to_jax.py
-def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict, model_type=None):
+def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dict, scan_layers=False):
   """Rename PT weight names to corresponding Flax weight names and reshape tensor if necessary"""
   # conv norm or layer norm
   renamed_pt_tuple_key = pt_tuple_key[:-1] + ("scale",)
@@ -112,12 +111,12 @@ def rename_key_and_reshape_tensor(pt_tuple_key, pt_tensor, random_flax_state_dic
   if isinstance(random_flax_state_dict[renamed_pt_tuple_key], Partitioned):
     # Wan 2.1 uses nnx.scan and nnx.vmap which stacks layer weights which will cause a shape mismatch
     # from the original weights which are not stacked.
-    if model_type is not None and model_type == common_types.WAN_MODEL:
+    if scan_layers:
       pass
     else:
       assert random_flax_state_dict[renamed_pt_tuple_key].value.shape == pt_tensor.T.shape
   else:
-    if model_type is not None and model_type == common_types.WAN_MODEL:
+    if scan_layers:
       pass
     else:
       assert random_flax_state_dict[renamed_pt_tuple_key].shape == pt_tensor.T.shape
diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py
index 48ed7b8ec..a246380f1 100644
--- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py
+++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py
@@ -396,10 +396,12 @@ def __init__(
       remat_policy: str = "None",
       names_which_can_be_saved: list = [],
       names_which_can_be_offloaded: list = [],
+      scan_layers: bool = True,
   ):
     inner_dim = num_attention_heads * attention_head_dim
     out_channels = out_channels or in_channels
     self.num_layers = num_layers
+    self.scan_layers = scan_layers

     # 1. Patch & position embedding
     self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
@@ -455,8 +457,29 @@ def init_block(rngs):
     self.gradient_checkpoint = GradientCheckpointType.from_str(remat_policy)
     self.names_which_can_be_offloaded = names_which_can_be_offloaded
     self.names_which_can_be_saved = names_which_can_be_saved
-
-    self.blocks = init_block(rngs)
+    if scan_layers:
+      self.blocks = init_block(rngs)
+    else:
+      blocks = nnx.List([])
+      for _ in range(num_layers):
+        block = WanTransformerBlock(
+            rngs=rngs,
+            dim=inner_dim,
+            ffn_dim=ffn_dim,
+            num_heads=num_attention_heads,
+            qk_norm=qk_norm,
+            cross_attn_norm=cross_attn_norm,
+            eps=eps,
+            flash_min_seq_length=flash_min_seq_length,
+            flash_block_sizes=flash_block_sizes,
+            mesh=mesh,
+            dtype=dtype,
+            weights_dtype=weights_dtype,
+            precision=precision,
+            attention=attention,
+        )
+        blocks.append(block)
+      self.blocks = blocks

     self.norm_out = FP32LayerNorm(rngs=rngs, dim=inner_dim, eps=eps, elementwise_affine=False)
     self.proj_out = nnx.Linear(
@@ -505,24 +528,38 @@ def __call__(
     if encoder_hidden_states_image is not None:
       raise NotImplementedError("img2vid is not yet implemented.")

-    def scan_fn(carry, block):
-      hidden_states_carry, rngs_carry = carry
-      hidden_states = block(hidden_states_carry, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs_carry)
-      new_carry = (hidden_states, rngs_carry)
-      return new_carry, None
+    if self.scan_layers:

-    rematted_block_forward = self.gradient_checkpoint.apply(
-        scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded
-    )
-    initial_carry = (hidden_states, rngs)
-    final_carry, _ = nnx.scan(
-        rematted_block_forward,
-        length=self.num_layers,
-        in_axes=(nnx.Carry, 0),
-        out_axes=(nnx.Carry, 0),
-    )(initial_carry, self.blocks)
-
-    hidden_states, _ = final_carry
+      def scan_fn(carry, block):
+        hidden_states_carry, rngs_carry = carry
+        hidden_states = block(
+            hidden_states_carry, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs_carry
+        )
+        new_carry = (hidden_states, rngs_carry)
+        return new_carry, None
+
+      rematted_block_forward = self.gradient_checkpoint.apply(
+          scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers
+      )
+      initial_carry = (hidden_states, rngs)
+      final_carry, _ = nnx.scan(
+          rematted_block_forward,
+          length=self.num_layers,
+          in_axes=(nnx.Carry, 0),
+          out_axes=(nnx.Carry, 0),
+      )(initial_carry, self.blocks)
+
+      hidden_states, _ = final_carry
+    else:
+      for block in self.blocks:
+
+        def layer_forward(hidden_states):
+          return block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb, deterministic, rngs)
+
+        rematted_layer_forward = self.gradient_checkpoint.apply(
+            layer_forward, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers
+        )
+        hidden_states = rematted_layer_forward(hidden_states)

     shift, scale = jnp.split(self.scale_shift_table + jnp.expand_dims(temb, axis=1), 2, axis=1)
diff --git a/src/maxdiffusion/models/wan/wan_utils.py b/src/maxdiffusion/models/wan/wan_utils.py
index 628207a9a..ec97abd30 100644
--- a/src/maxdiffusion/models/wan/wan_utils.py
+++ b/src/maxdiffusion/models/wan/wan_utils.py
@@ -24,7 +24,6 @@
 from safetensors import safe_open
 from flax.traverse_util import unflatten_dict, flatten_dict
 from ..modeling_flax_pytorch_utils import (rename_key, rename_key_and_reshape_tensor, torch2jax, validate_flax_state_dict)
-from ...common_types import WAN_MODEL

 CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH = "lightx2v/Wan2.1-T2V-14B-CausVid"
 WAN_21_FUSION_X_MODEL_NAME_OR_PATH = "vrgamedevgirl84/Wan14BT2VFusioniX"
@@ -73,8 +72,35 @@ def rename_for_custom_trasformer(key):
   return renamed_pt_key


+def get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers):
+  if scan_layers:
+    if "blocks" in pt_tuple_key:
+      new_key = ("blocks",) + pt_tuple_key[2:]
+      block_index = int(pt_tuple_key[1])
+      pt_tuple_key = new_key
+
+  flax_key, flax_tensor = rename_key_and_reshape_tensor(pt_tuple_key, tensor, random_flax_state_dict, scan_layers)
+
+  flax_key = rename_for_nnx(flax_key)
+  flax_key = _tuple_str_to_int(flax_key)
+
+  if scan_layers:
+    if "blocks" in flax_key:
+      if flax_key in flax_state_dict:
+        new_tensor = flax_state_dict[flax_key]
+      else:
+        new_tensor = jnp.zeros((40,) + flax_tensor.shape)
+      flax_tensor = new_tensor.at[block_index].set(flax_tensor)
+  return flax_key, flax_tensor
+
+
 def load_fusionx_transformer(
-    pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
+    pretrained_model_name_or_path: str,
+    eval_shapes: dict,
+    device: str,
+    hf_download: bool = True,
+    num_layers: int = 40,
+    scan_layers: bool = True,
 ):
   device = jax.local_devices(backend=device)[0]
   with jax.default_device(device):
@@ -101,23 +127,9 @@ def load_fusionx_transformer(
       pt_tuple_key = tuple(renamed_pt_key.split("."))
-      if "blocks" in pt_tuple_key:
-        new_key = ("blocks",) + pt_tuple_key[2:]
-        block_index = int(pt_tuple_key[1])
-        pt_tuple_key = new_key
-      flax_key, flax_tensor = rename_key_and_reshape_tensor(
-          pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL
-      )
-      flax_key = rename_for_nnx(flax_key)
-      flax_key = _tuple_str_to_int(flax_key)
-
-      if "blocks" in flax_key:
-        if flax_key in flax_state_dict:
-          new_tensor = flax_state_dict[flax_key]
-        else:
-          new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
-        flax_tensor = new_tensor.at[block_index].set(flax_tensor)
+      flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
       flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
+
   validate_flax_state_dict(eval_shapes, flax_state_dict)
   flax_state_dict = unflatten_dict(flax_state_dict)
   del tensors
@@ -126,7 +138,12 @@
 def load_causvid_transformer(
-    pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
+    pretrained_model_name_or_path: str,
+    eval_shapes: dict,
+    device: str,
+    hf_download: bool = True,
+    num_layers: int = 40,
+    scan_layers: bool = True,
 ):
   device = jax.local_devices(backend=device)[0]
   with jax.default_device(device):
@@ -150,24 +167,9 @@
      renamed_pt_key = rename_for_custom_trasformer(renamed_pt_key)
      pt_tuple_key = tuple(renamed_pt_key.split("."))
-
-      if "blocks" in pt_tuple_key:
-        new_key = ("blocks",) + pt_tuple_key[2:]
-        block_index = int(pt_tuple_key[1])
-        pt_tuple_key = new_key
-      flax_key, flax_tensor = rename_key_and_reshape_tensor(
-          pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL
-      )
-      flax_key = rename_for_nnx(flax_key)
-      flax_key = _tuple_str_to_int(flax_key)
-
-      if "blocks" in flax_key:
-        if flax_key in flax_state_dict:
-          new_tensor = flax_state_dict[flax_key]
-        else:
-          new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
-        flax_tensor = new_tensor.at[block_index].set(flax_tensor)
+      flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
       flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
+
   validate_flax_state_dict(eval_shapes, flax_state_dict)
   flax_state_dict = unflatten_dict(flax_state_dict)
   del tensors
@@ -176,19 +178,31 @@
 def load_wan_transformer(
-    pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
+    pretrained_model_name_or_path: str,
+    eval_shapes: dict,
+    device: str,
+    hf_download: bool = True,
+    num_layers: int = 40,
+    scan_layers: bool = True,
 ):
   if pretrained_model_name_or_path == CAUSVID_TRANSFORMER_MODEL_NAME_OR_PATH:
-    return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)
+    return load_causvid_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers)
   elif pretrained_model_name_or_path == WAN_21_FUSION_X_MODEL_NAME_OR_PATH:
-    return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)
+    return load_fusionx_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers)
   else:
-    return load_base_wan_transformer(pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers)
+    return load_base_wan_transformer(
+        pretrained_model_name_or_path, eval_shapes, device, hf_download, num_layers, scan_layers
+    )


 def load_base_wan_transformer(
-    pretrained_model_name_or_path: str, eval_shapes: dict, device: str, hf_download: bool = True, num_layers: int = 40
+    pretrained_model_name_or_path: str,
+    eval_shapes: dict,
+    device: str,
+    hf_download: bool = True,
+    num_layers: int = 40,
+    scan_layers: bool = True,
 ):
   device = jax.local_devices(backend=device)[0]
   subfolder = "transformer"
@@ -247,24 +261,9 @@
      renamed_pt_key = renamed_pt_key.replace("ffn.net_0", "ffn.act_fn")
      renamed_pt_key = renamed_pt_key.replace("norm2", "norm2.layer_norm")
      pt_tuple_key = tuple(renamed_pt_key.split("."))
-
-      if "blocks" in pt_tuple_key:
-        new_key = ("blocks",) + pt_tuple_key[2:]
-        block_index = int(pt_tuple_key[1])
-        pt_tuple_key = new_key
-      flax_key, flax_tensor = rename_key_and_reshape_tensor(
-          pt_tuple_key, tensor, random_flax_state_dict, model_type=WAN_MODEL
-      )
-      flax_key = rename_for_nnx(flax_key)
-      flax_key = _tuple_str_to_int(flax_key)
-
-      if "blocks" in flax_key:
-        if flax_key in flax_state_dict:
-          new_tensor = flax_state_dict[flax_key]
-        else:
-          new_tensor = jnp.zeros((num_layers,) + flax_tensor.shape)
-        flax_tensor = new_tensor.at[block_index].set(flax_tensor)
+      flax_key, flax_tensor = get_key_and_value(pt_tuple_key, tensor, flax_state_dict, random_flax_state_dict, scan_layers)
       flax_state_dict[flax_key] = jax.device_put(jnp.asarray(flax_tensor), device=cpu)
+
   validate_flax_state_dict(eval_shapes, flax_state_dict)
   flax_state_dict = unflatten_dict(flax_state_dict)
   del tensors
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
index c78d8bae2..f596d827e 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -90,6 +90,7 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
   wan_config["names_which_can_be_offloaded"] = config.names_which_can_be_offloaded
   wan_config["flash_min_seq_length"] = config.flash_min_seq_length
   wan_config["dropout"] = config.dropout
+  wan_config["scan_layers"] = config.scan_layers

   # 2. eval_shape - will not use flops or create weights on device
   # thus not using HBM memory.
@@ -111,7 +112,11 @@ def create_model(rngs: nnx.Rngs, wan_config: dict):
     params = restored_checkpoint["wan_state"]
   else:
     params = load_wan_transformer(
-        config.wan_transformer_pretrained_model_name_or_path, params, "cpu", num_layers=wan_config["num_layers"]
+        config.wan_transformer_pretrained_model_name_or_path,
+        params,
+        "cpu",
+        num_layers=wan_config["num_layers"],
+        scan_layers=config.scan_layers,
     )
     params = jax.tree_util.tree_map(lambda x: x.astype(config.weights_dtype), params)
     for path, val in flax.traverse_util.flatten_dict(params).items():
@@ -249,7 +254,7 @@ def get_basic_config(cls, dtype, config: HyperParameters):
            module_path=config.qwix_module_path,
            weight_qtype=dtype,
            act_qtype=dtype,
-            op_names=("dot_general","einsum", "conv_general_dilated"),
+            op_names=("dot_general", "einsum", "conv_general_dilated"),
        )
    ]
    return rules
@@ -272,11 +277,11 @@ def get_fp8_config(cls, config: HyperParameters):
            weight_calibration_method=config.quantization_calibration_method,
            act_calibration_method=config.quantization_calibration_method,
            bwd_calibration_method=config.quantization_calibration_method,
-            op_names=("dot_general","einsum"),
+            op_names=("dot_general", "einsum"),
        ),
        qwix.QtRule(
            module_path=config.qwix_module_path,
-            weight_qtype=jnp.float8_e4m3fn, # conv_general_dilated requires the same dtypes
+            weight_qtype=jnp.float8_e4m3fn,  # conv_general_dilated requires the same dtypes
            act_qtype=jnp.float8_e4m3fn,
            bwd_qtype=jnp.float8_e4m3fn,
            bwd_use_original_residuals=True,
@@ -285,7 +290,7 @@ def get_fp8_config(cls, config: HyperParameters):
            act_calibration_method=config.quantization_calibration_method,
            bwd_calibration_method=config.quantization_calibration_method,
            op_names=("conv_general_dilated"),
-        )
+        ),
    ]
    return rules
diff --git a/src/maxdiffusion/tests/configuration_utils_test.py b/src/maxdiffusion/tests/configuration_utils_test.py
index a70aac1ac..b7dae6814 100644
--- a/src/maxdiffusion/tests/configuration_utils_test.py
+++ b/src/maxdiffusion/tests/configuration_utils_test.py
@@ -5,38 +5,40 @@
 from maxdiffusion.configuration_utils import ConfigMixin
 from maxdiffusion import __version__

+
 class DummyConfigMixin(ConfigMixin):
-    config_name = "config.json"
+  config_name = "config.json"
+
+  def __init__(self, **kwargs):
+    self.register_to_config(**kwargs)

-    def __init__(self, **kwargs):
-        self.register_to_config(**kwargs)
+

 def test_to_json_string_with_config():
-    # Load the YAML config file
-    config_path = os.path.join(os.path.dirname(__file__), "..", "configs", "base_wan_14b.yml")
+  # Load the YAML config file
+  config_path = os.path.join(os.path.dirname(__file__), "..", "configs", "base_wan_14b.yml")

-    # Initialize pyconfig with the YAML config
-    pyconfig.initialize([None, config_path], unittest=True)
-    config = pyconfig.config
+  # Initialize pyconfig with the YAML config
+  pyconfig.initialize([None, config_path], unittest=True)
+  config = pyconfig.config

-    # Create a DummyConfigMixin instance
-    dummy_config = DummyConfigMixin(**config.get_keys())
+  # Create a DummyConfigMixin instance
+  dummy_config = DummyConfigMixin(**config.get_keys())

-    # Get the JSON string
-    json_string = dummy_config.to_json_string()
+  # Get the JSON string
+  json_string = dummy_config.to_json_string()

-    # Parse the JSON string
-    parsed_json = json.loads(json_string)
+  # Parse the JSON string
+  parsed_json = json.loads(json_string)

-    # Assertions
-    assert parsed_json["_class_name"] == "DummyConfigMixin"
-    assert parsed_json["_diffusers_version"] == __version__
+  # Assertions
+  assert parsed_json["_class_name"] == "DummyConfigMixin"
+  assert parsed_json["_diffusers_version"] == __version__

-    # Check a few values from the config
-    assert parsed_json["run_name"] == config.run_name
-    assert parsed_json["pretrained_model_name_or_path"] == config.pretrained_model_name_or_path
-    assert parsed_json["flash_block_sizes"]["block_q"] == config.flash_block_sizes["block_q"]
+  # Check a few values from the config
+  assert parsed_json["run_name"] == config.run_name
+  assert parsed_json["pretrained_model_name_or_path"] == config.pretrained_model_name_or_path
+  assert parsed_json["flash_block_sizes"]["block_q"] == config.flash_block_sizes["block_q"]

-    # The following keys are explicitly removed in to_json_string, so we assert they are not present
-    assert "weights_dtype" not in parsed_json
-    assert "precision" not in parsed_json
+  # The following keys are explicitly removed in to_json_string, so we assert they are not present
+  assert "weights_dtype" not in parsed_json
+  assert "precision" not in parsed_json
diff --git a/src/maxdiffusion/tests/wan_transformer_test.py b/src/maxdiffusion/tests/wan_transformer_test.py
index 3d1327c3b..2a3da9094 100644
--- a/src/maxdiffusion/tests/wan_transformer_test.py
+++ b/src/maxdiffusion/tests/wan_transformer_test.py
@@ -40,7 +40,7 @@
 import qwix
 import flax

-flax.config.update('flax_always_shard_variable', False)
+flax.config.update("flax_always_shard_variable", False)

 RealQtRule = qwix.QtRule
@@ -287,8 +287,10 @@ def test_get_qt_provider(self, mock_qt_rule):
     """
     Tests the provider logic for all config branches.
     """
+
     def create_real_rule_instance(*args, **kwargs):
-        return RealQtRule(*args, **kwargs)
+      return RealQtRule(*args, **kwargs)
+
     mock_qt_rule.side_effect = create_real_rule_instance

     # Case 1: Quantization disabled
@@ -303,7 +305,12 @@ def create_real_rule_instance(*args, **kwargs):
     config_int8.qwix_module_path = ".*"
     provider_int8 = WanPipeline.get_qt_provider(config_int8)
     self.assertIsNotNone(provider_int8)
-    mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.int8, act_qtype=jnp.int8, op_names=("dot_general","einsum", "conv_general_dilated"))
+    mock_qt_rule.assert_called_once_with(
+        module_path=".*",
+        weight_qtype=jnp.int8,
+        act_qtype=jnp.int8,
+        op_names=("dot_general", "einsum", "conv_general_dilated"),
+    )

     # Case 3: Quantization enabled, type 'fp8'
     mock_qt_rule.reset_mock()
@@ -313,7 +320,12 @@ def create_real_rule_instance(*args, **kwargs):
     config_fp8.qwix_module_path = ".*"
     provider_fp8 = WanPipeline.get_qt_provider(config_fp8)
     self.assertIsNotNone(provider_fp8)
-    mock_qt_rule.assert_called_once_with(module_path=".*", weight_qtype=jnp.float8_e4m3fn, act_qtype=jnp.float8_e4m3fn, op_names=("dot_general","einsum", "conv_general_dilated"))
+    mock_qt_rule.assert_called_once_with(
+        module_path=".*",
+        weight_qtype=jnp.float8_e4m3fn,
+        act_qtype=jnp.float8_e4m3fn,
+        op_names=("dot_general", "einsum", "conv_general_dilated"),
+    )

     # Case 4: Quantization enabled, type 'fp8_full'
     mock_qt_rule.reset_mock()
@@ -325,29 +337,30 @@ def create_real_rule_instance(*args, **kwargs):
     provider_fp8_full = WanPipeline.get_qt_provider(config_fp8_full)
     self.assertIsNotNone(provider_fp8_full)
     expected_calls = [
-        call(module_path=".*", # Apply to all modules
-             weight_qtype=jnp.float8_e4m3fn,
-             act_qtype=jnp.float8_e4m3fn,
-             bwd_qtype=jnp.float8_e5m2,
-             bwd_use_original_residuals=True,
-             disable_channelwise_axes=True, # per_tensor calibration
-             weight_calibration_method=config_fp8_full.quantization_calibration_method,
-             act_calibration_method=config_fp8_full.quantization_calibration_method,
-             bwd_calibration_method=config_fp8_full.quantization_calibration_method,
-             op_names=("dot_general","einsum"),
+        call(
+            module_path=".*",  # Apply to all modules
+            weight_qtype=jnp.float8_e4m3fn,
+            act_qtype=jnp.float8_e4m3fn,
+            bwd_qtype=jnp.float8_e5m2,
+            bwd_use_original_residuals=True,
+            disable_channelwise_axes=True,  # per_tensor calibration
+            weight_calibration_method=config_fp8_full.quantization_calibration_method,
+            act_calibration_method=config_fp8_full.quantization_calibration_method,
+            bwd_calibration_method=config_fp8_full.quantization_calibration_method,
+            op_names=("dot_general", "einsum"),
+        ),
+        call(
+            module_path=".*",  # Apply to all modules
+            weight_qtype=jnp.float8_e4m3fn,
+            act_qtype=jnp.float8_e4m3fn,
+            bwd_qtype=jnp.float8_e4m3fn,
+            bwd_use_original_residuals=True,
+            disable_channelwise_axes=True,  # per_tensor calibration
+            weight_calibration_method=config_fp8_full.quantization_calibration_method,
+            act_calibration_method=config_fp8_full.quantization_calibration_method,
+            bwd_calibration_method=config_fp8_full.quantization_calibration_method,
+            op_names=("conv_general_dilated"),
         ),
-        call(
-            module_path=".*", # Apply to all modules
-            weight_qtype=jnp.float8_e4m3fn,
-            act_qtype=jnp.float8_e4m3fn,
-            bwd_qtype=jnp.float8_e4m3fn,
-            bwd_use_original_residuals=True,
-            disable_channelwise_axes=True, # per_tensor calibration
-            weight_calibration_method=config_fp8_full.quantization_calibration_method,
-            act_calibration_method=config_fp8_full.quantization_calibration_method,
-            bwd_calibration_method=config_fp8_full.quantization_calibration_method,
-            op_names=("conv_general_dilated"),
-        )
     ]
     mock_qt_rule.assert_has_calls(expected_calls, any_order=True)
diff --git a/src/maxdiffusion/tests/wan_vae_test.py b/src/maxdiffusion/tests/wan_vae_test.py
index 66d8dce9d..2268411c2 100644
--- a/src/maxdiffusion/tests/wan_vae_test.py
+++ b/src/maxdiffusion/tests/wan_vae_test.py
@@ -52,7 +52,9 @@
 CACHE_T = 2

-flax.config.update('flax_always_shard_variable', False)
+flax.config.update("flax_always_shard_variable", False)
+
+
 class TorchWanRMS_norm(nn.Module):
   r"""
   A custom RMS normalization layer.
diff --git a/src/maxdiffusion/train_flux.py b/src/maxdiffusion/train_flux.py
index fef182062..05cdae44d 100644
--- a/src/maxdiffusion/train_flux.py
+++ b/src/maxdiffusion/train_flux.py
@@ -40,6 +40,7 @@ def main(argv: Sequence[str]) -> None:
   max_logging.log(f"Found {jax.device_count()} devices.")
   train(config)

+
 if __name__ == "__main__":
   with transformer_engine_context():
     app.run(main)
diff --git a/src/maxdiffusion/train_utils.py b/src/maxdiffusion/train_utils.py
index 8f986a90a..9d2b8a3f1 100644
--- a/src/maxdiffusion/train_utils.py
+++ b/src/maxdiffusion/train_utils.py
@@ -22,6 +22,7 @@
 from maxdiffusion import max_utils, max_logging
 from contextlib import contextmanager

+
 def get_first_step(state):
   return int(state.step)
@@ -200,16 +201,16 @@ def generate_timestep_weights(config, num_timesteps):
 @contextmanager
 def transformer_engine_context():
-  """ If TransformerEngine is available, this context manager will provide the library with MaxDiffusion-specific details needed for correcct operation. """
+  """If TransformerEngine is available, this context manager will provide the library with MaxDiffusion-specific details needed for correcct operation."""
   try:
     from transformer_engine.jax.sharding import global_shard_guard, MeshResource
     # Inform TransformerEngine of MaxDiffusion's physical mesh resources.
     mesh_resource = MeshResource(
-      dp_resource = "data",
-      tp_resource = "tensor",
-      fsdp_resource = "fsdp",
-      pp_resource = None,
-      cp_resource = None,
+      dp_resource="data",
+      tp_resource="tensor",
+      fsdp_resource="fsdp",
+      pp_resource=None,
+      cp_resource=None,
     )
     with global_shard_guard(mesh_resource):
       yield
diff --git a/src/maxdiffusion/train_wan.py b/src/maxdiffusion/train_wan.py
index 2fbb069d3..3b45ea8fd 100644
--- a/src/maxdiffusion/train_wan.py
+++ b/src/maxdiffusion/train_wan.py
@@ -35,7 +35,7 @@ def main(argv: Sequence[str]) -> None:
   config = pyconfig.config
   validate_train_config(config)
   max_logging.log(f"Found {jax.device_count()} devices.")
-  flax.config.update('flax_always_shard_variable', False)
+  flax.config.update("flax_always_shard_variable", False)
   train(config)
diff --git a/src/maxdiffusion/trainers/wan_trainer.py b/src/maxdiffusion/trainers/wan_trainer.py
index 1265223a7..d6a0cc803 100644
--- a/src/maxdiffusion/trainers/wan_trainer.py
+++ b/src/maxdiffusion/trainers/wan_trainer.py
@@ -243,9 +243,7 @@ def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, wr
       try:
         eval_start_time = datetime.datetime.now()
         eval_batch = load_next_batch(eval_data_iterator, None, self.config)
-        with mesh, nn_partitioning.axis_rules(
-            self.config.logical_axis_rules
-        ):
+        with mesh, nn_partitioning.axis_rules(self.config.logical_axis_rules):
          metrics, eval_rng = p_eval_step(state, eval_batch, eval_rng, scheduler_state)
          metrics["scalar"]["learning/eval_loss"].block_until_ready()
          losses = metrics["scalar"]["learning/eval_loss"]
@@ -258,7 +256,7 @@ def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, wr
        for t, l in zip(gathered_timesteps.flatten(), gathered_losses.flatten()):
          timestep = int(t)
          if timestep not in eval_losses_by_timestep:
-              eval_losses_by_timestep[timestep] = []
+            eval_losses_by_timestep[timestep] = []
          eval_losses_by_timestep[timestep].append(l)
        eval_end_time = datetime.datetime.now()
        eval_duration = eval_end_time - eval_start_time
@@ -272,11 +270,11 @@ def eval(self, mesh, eval_rng_key, step, p_eval_step, state, scheduler_state, wr
      if jax.process_index() == 0:
        max_logging.log(f"Step {step}, calculating mean loss per timestep...")
        for timestep, losses in sorted(eval_losses_by_timestep.items()):
-            losses = jnp.array(losses)
-            losses = losses[: min(self.config.eval_max_number_of_samples_in_bucket, len(losses))]
-            mean_loss = jnp.mean(losses)
-            max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}")
-            mean_per_timestep.append(mean_loss)
+          losses = jnp.array(losses)
+          losses = losses[: min(self.config.eval_max_number_of_samples_in_bucket, len(losses))]
+          mean_loss = jnp.mean(losses)
+          max_logging.log(f" Mean eval loss for timestep {timestep}: {mean_loss:.4f}")
+          mean_per_timestep.append(mean_loss)
        final_eval_loss = jnp.mean(jnp.array(mean_per_timestep))
        max_logging.log(f"Step {step}, Final Average Eval loss: {final_eval_loss:.4f}")
        if writer:
@@ -480,7 +478,7 @@ def loss_fn(params, latents, encoder_hidden_states, timesteps, rng):
     for i in range(0, bs, single_batch_size):
       start = i
       end = min(i + single_batch_size, bs)
-      latents= data["latents"][start:end, :].astype(config.weights_dtype)
+      latents = data["latents"][start:end, :].astype(config.weights_dtype)
       encoder_hidden_states = data["encoder_hidden_states"][start:end, :].astype(config.weights_dtype)
       timesteps = data["timesteps"][start:end].astype("int64")
       _, new_rng = jax.random.split(rng, num=2)