 import collections
 import grain
 import jax
+import json
 import os
 import pathwaysutils
 import tensorflow_datasets as tfds
 
 from MaxText import max_logging, max_utils, maxtext_utils, pyconfig
 from MaxText import model_creation_utils
+from MaxText.globals import MAXTEXT_PKG_DIR
 from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
 from MaxText.rl.evaluate_rl import evaluate
 from MaxText.rl import utils_rl
@@ -93,7 +95,8 @@ def get_maxtext_model(config, devices=None):
9395 """
9496 model , mesh = model_creation_utils .create_nnx_model (config , devices = devices )
9597 with jax .set_mesh (mesh ):
96- tunix_model = TunixMaxTextAdapter (base_model = model )
98+ use_no_op_mappings = "maxtext_config" in config .vllm_additional_config
99+ tunix_model = TunixMaxTextAdapter (base_model = model , use_no_op_mappings = use_no_op_mappings )
97100 tunix_model .config = None
98101 return tunix_model , mesh
99102
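
# A minimal sketch (hypothetical contents) of the flag derivation above:
# "maxtext_config" acts as a marker key inside vllm_additional_config, and
# its presence switches the adapter onto no-op weight mappings.
additional_config = {"maxtext_config": {"scan_layers": False}}
use_no_op_mappings = "maxtext_config" in additional_config  # True
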
@@ -312,7 +315,7 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
   maxtext_state_flatten = {".".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten}
   max_logging.log(
       f"maxtext_state_flatten[base.token_embedder.embedding].value=\
-      {maxtext_state_flatten['base.token_embedder.embedding'].value}"
+      {maxtext_state_flatten['base.token_embedder.embedding'][...]}"
   )
 
   # TODO(@mazumdera): change this to use LoRA.
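
# A self-contained sketch of the flatten-and-index pattern above, using plain
# jax arrays (in the real code the leaves are nnx state entries, which also
# support the full-array `[...]` read that replaces `.value` here):
import jax.numpy as jnp
from jax.tree_util import tree_flatten_with_path

state = {"base": {"token_embedder": {"embedding": jnp.ones((4, 8))}}}
flat, _ = tree_flatten_with_path(state)
flat_by_name = {".".join(str(k.key) for k in path): v for path, v in flat}
print(flat_by_name["base.token_embedder.embedding"][...])  # the full array
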
@@ -352,6 +355,21 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
       set_profile_options=False,
   )
 
+  # Parse vllm_additional_config: accept a pre-parsed dict or a JSON string.
+  rollout_additional_config = None
+  if trainer_config.vllm_additional_config:
+    if isinstance(trainer_config.vllm_additional_config, dict):
+      # Already parsed into a dict.
+      rollout_additional_config = trainer_config.vllm_additional_config
+    elif isinstance(trainer_config.vllm_additional_config, str):
+      # A raw string, so parse it as JSON.
+      try:
+        rollout_additional_config = json.loads(trainer_config.vllm_additional_config)
+      except json.JSONDecodeError as e:
+        raise ValueError(f"Failed to parse vllm_additional_config JSON: {e}") from e
+
+  max_logging.log(f"Parsed additional config: {rollout_additional_config}")
+
   # RL Cluster config.
   # Note that we use vLLM as the rollout engine,
   # with Tensor Parallelism for the rollout.
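
# Hedged usage sketch of the parsing logic above: both an already-parsed
# dict and a raw JSON string should yield the same rollout config.
import json

def _parse(raw):
  return raw if isinstance(raw, dict) else json.loads(raw)

assert _parse({"maxtext_config": {}}) == {"maxtext_config": {}}
assert _parse('{"maxtext_config": {}}') == {"maxtext_config": {}}
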
@@ -394,6 +412,9 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
           rollout_vllm_hbm_utilization=trainer_config.hbm_utilization_vllm,
           rollout_vllm_tpu_backend_type="jax",
           rollout_vllm_swap_space_size_gb=trainer_config.swap_space_vllm_gb,
+          rollout_vllm_hf_config_path=trainer_config.vllm_hf_config_path,
+          rollout_vllm_additional_config=rollout_additional_config,
+          rollout_vllm_init_with_random_weights=True,
           **get_rollout_kwargs_for_data_parallelism(sampler_config, len(sampler_devices)),
       ),
   )
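
# Hedged example of supplying the new values at launch time, following
# MaxText's key=value override convention (the entrypoint path below is an
# assumption, not confirmed by this diff):
#
#   python -m MaxText.rl.train_rl <config.yml> \
#     vllm_hf_config_path=/path/to/hf/config \
#     vllm_additional_config='{"maxtext_config": {}}'
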
@@ -423,7 +444,12 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
     max_logging.log(
         "enable_tunix_perf_metrics is True but tunix.perf modules are not available, skipping Tunix-managed metrics."
     )
-  with nn_partitioning.axis_rules(trainer_config.logical_axis_rules):
+
+  vllm_config_path = epath.Path(MAXTEXT_PKG_DIR) / "configs" / "vllm.yml"
+  argv_list = ["", str(vllm_config_path), "log_config=False"]
+  vllm_config = pyconfig.initialize(argv_list)
+
+  with nn_partitioning.axis_rules(vllm_config.logical_axis_rules):
     rl_cluster = rl_cluster_lib.RLCluster(
         actor=actor_model,
         reference=reference_model,
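
# The pyconfig.initialize call above follows MaxText's argv convention:
# element 0 is an ignored program-name placeholder, element 1 is the config
# path, and trailing elements are key=value overrides. A minimal sketch:
from etils import epath
from MaxText import pyconfig
from MaxText.globals import MAXTEXT_PKG_DIR

vllm_yml = epath.Path(MAXTEXT_PKG_DIR) / "configs" / "vllm.yml"
cfg = pyconfig.initialize(["", str(vllm_yml), "log_config=False"])
print(cfg.logical_axis_rules)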