
Commit e0e5a25

gagika authored and NicoGrande committed
Support Custom MaxText model (with vLLM engine) in RL rollouts.
Fix formatting.
Refactor model creation and error handling in RL training.
Fix linting.
Add no-op mappings to the Tunix adapter.
Remove KV cache init for the vLLM case.
Latest updates from debugging.
Add null logical axis rules to the adapter.
Add linting fixes.
Fix pyink formatting.
Remove unused imports.
Update attentions test.
Add fixes.
Address comments in evaluate_rl.
Set weight dtype to bf16 by default.
Remove unnecessary logical axis rules.
Remove epath.
Remove deprecated .value call.
1 parent 28d570a commit e0e5a25

11 files changed: 102 additions & 47 deletions


src/MaxText/configs/base.yml

Lines changed: 6 additions & 0 deletions

@@ -979,3 +979,9 @@ use_tokamax_gmm: false
 use_tokamax_splash: false
 # Setting this flag will use a non-pallas implementation.
 use_jax_splash: false
+
+# vLLM Adapter Configurations
+# Path to the HuggingFace-style config directory for the adapter (e.g. src/MaxText/integration/vllm/maxtext_vllm_adapter)
+vllm_hf_config_path: ""
+# JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}')
+vllm_additional_config: {}
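
Note: both new keys are ordinary config fields, so they can also be set programmatically. A minimal sketch, where the base config path, adapter directory, and model name are illustrative rather than fixed by this commit; pyconfig.initialize(argv, **overrides) is the same entry point the vLLM adapter uses below:

from MaxText import pyconfig

# Sketch only: paths and model name are illustrative.
config = pyconfig.initialize(
    ["", "src/MaxText/configs/base.yml"],
    vllm_hf_config_path="src/MaxText/integration/vllm/maxtext_vllm_adapter",
    # Keys under "maxtext_config" are MaxText overrides applied by the vLLM adapter.
    vllm_additional_config={"maxtext_config": {"model_name": "llama3.1-8b"}},
)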

src/MaxText/configs/types.py

Lines changed: 6 additions & 6 deletions

@@ -279,9 +279,7 @@ class Checkpointing(BaseModel):
   save_checkpoint_on_completion: bool = Field(
       True, description="If True, saves a final checkpoint upon training completion."
   )
-  enable_continuous_checkpointing: bool = Field(
-      False, description="If True, enables continuous checkpointing."
-  )
+  enable_continuous_checkpointing: bool = Field(False, description="If True, enables continuous checkpointing.")


 class OrbaxStorage(BaseModel):
@@ -463,9 +461,7 @@ class Attention(BaseModel):
   ragged_block_size: int = Field(256, description="Block size for ragged attention.")
   enable_padding_causal_mask: bool = Field(True, description="Temporary flag for TE padding.")
   use_tokamax_splash: bool = Field(False, description="Whether to use tokamax splash attention.")
-  use_jax_splash: bool = Field(
-      False, description="Whether to use jax splash attention."
-  )
+  use_jax_splash: bool = Field(False, description="Whether to use jax splash attention.")


 class MoBa(BaseModel):
@@ -1376,6 +1372,8 @@ class VLLM(BaseModel):
   kv_cache_buffer: int = Field(256, description="Buffer for KV cache.")
   hbm_utilization_vllm: float = Field(0.72, description="Target HBM utilization for vLLM.")
   swap_space_vllm_gb: int = Field(2, description="Swap space in GB for vLLM.")
+  vllm_additional_config: dict[str, Any] = Field(default_factory=dict, description="Additional vLLM config options.")
+  vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")


 class GRPO(BaseModel):
@@ -2163,6 +2161,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
         "tensor": self.ici_tensor_parallelism,
         "tensor_transpose": self.ici_tensor_transpose_parallelism,
         "tensor_sequence": self.ici_tensor_sequence_parallelism,
+        "model": self.ici_tensor_parallelism,
         "expert": self.ici_expert_parallelism,
         "autoregressive": self.ici_autoregressive_parallelism,
     }
@@ -2179,6 +2178,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
         "tensor": self.dcn_tensor_parallelism,
         "tensor_transpose": self.dcn_tensor_transpose_parallelism,
         "tensor_sequence": self.dcn_tensor_sequence_parallelism,
+        "model": self.dcn_tensor_parallelism,
         "expert": self.dcn_expert_parallelism,
         "autoregressive": self.dcn_autoregressive_parallelism,
     }
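
Note on the two new "model" entries: the logical axis rules in configs/vllm.yml (next file) shard over a mesh axis named 'model', which these ICI/DCN parallelism maps previously did not expose; aliasing it to the tensor-parallel degree makes those rules resolvable. A toy sketch with an illustrative degree:

# 'model' mirrors 'tensor', so rules targeting either axis name resolve to the
# same parallelism degree (4 is illustrative).
ici_tensor_parallelism = 4
ici_parallelism = {
    "tensor": ici_tensor_parallelism,
    "model": ici_tensor_parallelism,  # alias added by this commit
}
assert ici_parallelism["model"] == ici_parallelism["tensor"]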

src/MaxText/configs/vllm.yml

Lines changed: 6 additions & 0 deletions

@@ -20,6 +20,8 @@ enable_nnx: True
 skip_jax_distributed_system: True
 # Scanned layers are not supported with vLLM integration
 scan_layers: False
+# Set weight dtype to bfloat16 as is done in vLLM
+weight_dtype: bfloat16


 # -------------- Logical Axis Rules --------------
@@ -41,6 +43,7 @@ logical_axis_rules: [
   ['activation_kv_batch_no_exp', []],
   ['activation_kv_head_dim', ['model']],
   ['activation_vocab', ['model']],
+  ['activation_embed', ['model']],
   ['activation_exp', ['expert']],
   ['decode_batch', ['expert']],
   ['mlp', ['model']],
@@ -49,7 +52,10 @@ logical_axis_rules: [
   ['heads', ['model']],
   ['q_heads', ['model']],
   ['kv_heads', ['model']],
+  ['kv_head_dim', []],
+  ['kv', []],
   ['embed', ['expert']],
+  ['embed_no_exp', []],
   ['q_lora', ['expert']],
   ['kv_lora', ['expert']],
   ['norm', ['model']],
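
For context, each rule maps a logical axis name to mesh axes, and an empty list leaves that axis unsharded. A minimal sketch of how Flax resolves such rules (None stands in for the YAML's [] here; the axis names are taken from the rules above):

import flax.linen as nn

rules = (("activation_embed", "model"), ("kv", None), ("embed_no_exp", None))
with nn.logical_axis_rules(rules):
  # Resolves to a jax.sharding.PartitionSpec: 'activation_embed' is sharded
  # over the 'model' mesh axis, 'kv' is left replicated.
  spec = nn.logical_to_mesh_axes(("activation_embed", "kv"))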

src/MaxText/integration/tunix/tunix_adapter.py

Lines changed: 14 additions & 0 deletions

@@ -37,6 +37,7 @@ def __init__(
       self,
       base_model: Transformer,
       use_standalone_mappings: bool = True,
+      use_no_op_mappings: bool = False,
   ):
     super().__init__()
     self.base = base_model
@@ -45,6 +46,7 @@ def __init__(
         HF_MODEL_CONFIGS[self.base.config.model_name].to_dict(),
         use_standalone_mappings,
     )
+    self.use_no_op_mappings = use_no_op_mappings

   # ------------------------------------------------------------------ #
   # Tunix call signature
@@ -69,13 +71,25 @@ def __call__(
     return logits, None

   def to_hf_mappings(self):
+    if self.use_no_op_mappings:
+      return {}
+
     return self._vllm_weight_mapping.to_hf_mapping()

   def to_hf_transpose_keys(self):
+    if self.use_no_op_mappings:
+      return {}
+
     return self._vllm_weight_mapping.to_hf_transpose_keys()

   def to_hf_hook_fns(self):
+    if self.use_no_op_mappings:
+      return {}
+
     return self._vllm_weight_mapping.to_hf_hook_fns()

   def lora_to_hf_mappings(self):
+    if self.use_no_op_mappings:
+      return {}
+
     return self._vllm_weight_mapping.lora_to_hf_mappings()
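
The new flag short-circuits all four HuggingFace weight-mapping hooks at once. A sketch of the intended call pattern, mirroring the train_rl.py hunk further down (model and config are assumed to exist as in get_maxtext_model):

# When the rollout model is a custom MaxText model configured via
# vllm_additional_config, no HF weight renaming is needed.
use_no_op_mappings = "maxtext_config" in config.vllm_additional_config
adapter = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=use_no_op_mappings)

if adapter.use_no_op_mappings:
  assert adapter.to_hf_mappings() == {}        # nothing to rename
  assert adapter.to_hf_transpose_keys() == {}  # nothing to transpose
  assert adapter.to_hf_hook_fns() == {}        # no per-weight hook fns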

src/MaxText/integration/tunix/utils.py

Lines changed: 4 additions & 1 deletion

@@ -147,7 +147,10 @@ def to_hf_hook_fns(self):
       return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_hook_fns()

     model_family = self.model_name.split("-")[0]
-    return VLLM_HOOK_FNS[model_family]()
+    if model_family in VLLM_HOOK_FNS:
+      return VLLM_HOOK_FNS[model_family]()
+    else:
+      return {}

   def lora_to_hf_mappings(self):
     if self.use_standalone_mappings:

src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 17 additions & 27 deletions

@@ -16,8 +16,8 @@

 import jax
 import jax.numpy as jnp
+import os

-from etils import epath
 from flax import nnx
 import flax.linen as nn
 from jax.sharding import Mesh
@@ -49,32 +49,23 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters
   Raises:
     ValueError: If `hf_config_path` is not provided in the vLLM model config.
   """
-
-  def _path_exists(path: str) -> bool:
-    if not path:
-      return False
-    return epath.Path(path).exists()
-
   if "maxtext_config" in vllm_config.additional_config:
     overrides = vllm_config.additional_config["maxtext_config"]
   else:
     overrides = {}
-  load_path = None
-  if _path_exists(vllm_config.load.download_dir):
-    load_path = vllm_config.load.download_dir
-  elif _path_exists(vllm_config.model.model):
-    load_path = vllm_config.model.model

-  if load_path:
-    overrides["load_parameters_path"] = load_path
-  elif vllm_config.model.model:
-    overrides["model_name"] = vllm_config.model.model
+  if vllm_config.load_config.load_format == "dummy":
+    if overrides.get("load_parameters_path") is not None:
+      max_logging.log(
+          "Warning: load_parameters_path is set when using dummy load format. Checkpoint loading will be skipped."
+      )
+    overrides["load_parameters_path"] = None

   if vllm_config.model_config.hf_config_path is None:
     raise ValueError("hf_config_path must be provided when using MaxTextForCausalLM.")

   # Add base config path to positional args
-  base_config_path = epath.Path(MAXTEXT_PKG_DIR) / "configs" / "vllm.yml"
+  base_config_path = os.path.join(MAXTEXT_PKG_DIR, "configs", "vllm.yml")
   argv_list = ["", str(base_config_path)]

   maxtext_config = pyconfig.initialize(argv_list, **overrides)
@@ -110,12 +101,6 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh) -> N

     # Handle dummy weight loading during initialization
     if vllm_config.load_config.load_format == "dummy":
-      if self.maxtext_config.load_parameters_path is not None:
-        max_logging.log(
-            "Warning: load_parameters_path is set when using dummy load format. Checkpoint loading will be skipped."
-        )
-      self.maxtext_config.load_parameters_path = None
-
       with self.mesh:
         self.load_weights(rng_key)

@@ -173,7 +158,7 @@ def __call__(
     hidden = jnp.squeeze(hidden, axis=0)
     logits = jnp.squeeze(logits, axis=0)

-    self.logits = logits  # cache logits for compute_logits call
+    self.logits = nnx.data(logits)  # cache logits for compute_logits call

     return kv_caches, hidden, aux_hidden_states

@@ -199,9 +184,14 @@ def load_weights(self, rng_key: jax.Array) -> None:
     Args:
       rng_key: A JAX random key for model initialization.
     """
-    self.model, _ = model_creation_utils.create_nnx_model(
-        self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
-    )
+    if self.model is not None:
+      return
+
+    with nn.logical_axis_rules(""):
+      model, _ = model_creation_utils.create_nnx_model(
+          self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
+      )
+    self.model = nnx.data(model)


 class MaxTextForCausalLM(nnx.Module):
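
To make the new flow concrete, a sketch of the additional_config payload that generate_maxtext_config now consumes; the keys inside "maxtext_config" are illustrative overrides layered on top of configs/vllm.yml:

# Hypothetical payload handed to vLLM; only the "maxtext_config" key is read.
additional_config = {
    "maxtext_config": {
        "model_name": "llama3.1-8b",  # illustrative MaxText override
    }
}
# Inside generate_maxtext_config these overrides are applied on top of vllm.yml:
#   maxtext_config = pyconfig.initialize(["", base_config_path], **overrides)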

src/MaxText/layers/attentions.py

Lines changed: 10 additions & 3 deletions

@@ -423,7 +423,7 @@
     # Module attribute names must match names previously passed to Linen for checkpointing
     self.KVCache_0 = (
         self.init_kv_caches(inputs_kv_shape=inputs_kv_shape)
-        if self.model_mode != MODEL_MODE_TRAIN and base_kv_cache
+        if self.model_mode != MODEL_MODE_TRAIN and base_kv_cache and config.attention != "vllm_rpa"
         else None
     )

@@ -909,7 +909,7 @@ def forward_serve_vllm(
     try:
       # pylint: disable=import-outside-toplevel
       # pytype: disable=import-error
-      from tpu_inference.layers.jax.attention_interface import sharded_ragged_paged_attention as rpa_ops
+      from tpu_inference.layers.common.attention_interface import sharded_ragged_paged_attention as rpa_ops
     except ImportError as e:
       raise ImportError(
           "vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
@@ -930,7 +930,8 @@ def forward_serve_vllm(

     md = rpa_metadata

-    output, kv_cache = rpa_ops(1.0, self.mesh, attention_chunk_size, q_scale, k_scale, v_scale)(
+    output, kv_cache = rpa_ops(
+        self.mesh,
         query,
         key,
         value,
@@ -939,6 +940,12 @@ def forward_serve_vllm(
         md.block_tables,
         md.query_start_loc,
         md.request_distribution,
+        None,
+        1.0,
+        attention_chunk_size,
+        q_scale,
+        k_scale,
+        v_scale,
     )
     return kv_cache, output

src/MaxText/model_creation_utils.py

Lines changed: 3 additions & 3 deletions

@@ -78,9 +78,9 @@ def from_config(
   Example:
     model = from_config(config)
   """
-  devices_array = maxtext_utils.create_device_mesh(config, devices)
-
   if mesh is None:
+    devices_array = maxtext_utils.create_device_mesh(config, devices)
+
     if config.shard_mode == ShardMode.EXPLICIT:
       axis_types = tuple([AxisType.Explicit] * len(config.mesh_axes))
     else:
@@ -154,7 +154,7 @@ def create_sharded_state():
     model = _create_model_partial()
     return nnx.state(model)

-  with jax.set_mesh(mesh):
+  with mesh:
     # Create the model with sharded parameters.
     with nn.logical_axis_rules(config.logical_axis_rules):
       sharded_state = create_sharded_state()

src/MaxText/rl/evaluate_rl.py

Lines changed: 6 additions & 3 deletions

@@ -121,13 +121,16 @@ def score_responses(tmvp_config, question, responses, answer):

   # Check exact correctness
   try:
-    if float(extracted_response.strip()) == float(answer.strip()):
-      is_correct = True
+    # Remove ',' and '$' then convert to float
+    val_extracted = float(extracted_response.replace(",", "").replace("$", "").strip())
+    val_answer = float(answer.replace(",", "").replace("$", "").strip())
+    is_correct = val_extracted == val_answer

     # Check partial correctness (within 10%)
-    ratio = float(extracted_response.strip()) / float(answer.strip())
+    ratio = val_extracted / val_answer
     if 0.9 <= ratio <= 1.1:
       is_partially_correct = True
+
   except Exception as e:
     if tmvp_config.debug["rl"]:
       max_logging.log(f"Evaluation Exception: {e}")

src/MaxText/rl/train_rl.py

Lines changed: 29 additions & 3 deletions

@@ -48,6 +48,7 @@
 import collections
 import grain
 import jax
+import json
 import os
 import pathwaysutils
 import tensorflow_datasets as tfds
@@ -70,6 +71,7 @@

 from MaxText import max_logging, max_utils, maxtext_utils, pyconfig
 from MaxText import model_creation_utils
+from MaxText.globals import MAXTEXT_PKG_DIR
 from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
 from MaxText.rl.evaluate_rl import evaluate
 from MaxText.rl import utils_rl
@@ -93,7 +95,8 @@ def get_maxtext_model(config, devices=None):
   """
   model, mesh = model_creation_utils.create_nnx_model(config, devices=devices)
   with jax.set_mesh(mesh):
-    tunix_model = TunixMaxTextAdapter(base_model=model)
+    use_no_op_mappings = "maxtext_config" in config.vllm_additional_config
+    tunix_model = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=use_no_op_mappings)
   tunix_model.config = None
   return tunix_model, mesh

@@ -312,7 +315,7 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
   maxtext_state_flatten = {".".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten}
   max_logging.log(
       f"maxtext_state_flatten[base.token_embedder.embedding].value=\
-      {maxtext_state_flatten['base.token_embedder.embedding'].value}"
+      {maxtext_state_flatten['base.token_embedder.embedding'][...]}"
   )

   # TODO: @mazumdera: change this to use lora
@@ -352,6 +355,21 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
       set_profile_options=False,
   )

+  # Parse vllm_additional_config
+  rollout_additional_config = None
+  if trainer_config.vllm_additional_config:
+    if isinstance(trainer_config.vllm_additional_config, dict):
+      # It's already parsed into a dict
+      rollout_additional_config = trainer_config.vllm_additional_config
+    elif isinstance(trainer_config.vllm_additional_config, str):
+      # It's a string, so we need to parse it
+      try:
+        rollout_additional_config = json.loads(trainer_config.vllm_additional_config)
+      except json.JSONDecodeError as e:
+        raise ValueError(f"Failed to parse additional_config JSON: {e}") from e
+
+    max_logging.log(f"Parsed additional config: {rollout_additional_config}")
+
   # RL Cluster config
   # Note that we use vLLM as the rollout engine.
   # and we are using Tensor Parallelism for rollout
@@ -394,6 +412,9 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
           rollout_vllm_hbm_utilization=trainer_config.hbm_utilization_vllm,
           rollout_vllm_tpu_backend_type="jax",
           rollout_vllm_swap_space_size_gb=trainer_config.swap_space_vllm_gb,
+          rollout_vllm_hf_config_path=trainer_config.vllm_hf_config_path,
+          rollout_vllm_additional_config=rollout_additional_config,
+          rollout_vllm_init_with_random_weights=True,
           **get_rollout_kwargs_for_data_parallelism(sampler_config, len(sampler_devices)),
       ),
   )
@@ -423,7 +444,12 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
     max_logging.log(
         "enable_tunix_perf_metrics is True but tunix.perf modules are not available, skipping Tunix-managed metrics."
    )
-  with nn_partitioning.axis_rules(trainer_config.logical_axis_rules):
+
+  vllm_config_path = epath.Path(MAXTEXT_PKG_DIR) / "configs" / "vllm.yml"
+  argv_list = ["", str(vllm_config_path), "log_config=False"]
+  vllm_config = pyconfig.initialize(argv_list)
+
+  with nn_partitioning.axis_rules(vllm_config.logical_axis_rules):
     rl_cluster = rl_cluster_lib.RLCluster(
         actor=actor_model,
         reference=reference_model,
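
Because pyconfig may deliver vllm_additional_config either as an already-parsed dict (from YAML) or as a raw JSON string (from the command line), the parsing block above accepts both. A standalone sketch of that contract; the helper name is ours, not the commit's:

import json

def parse_additional_config(value):
  """Accepts a dict (already parsed) or a JSON string; rejects malformed JSON."""
  if isinstance(value, dict):
    return value
  if isinstance(value, str):
    try:
      return json.loads(value)
    except json.JSONDecodeError as e:
      raise ValueError(f"Failed to parse additional_config JSON: {e}") from e
  return None

assert parse_additional_config({"maxtext_config": {}}) == {"maxtext_config": {}}
assert parse_additional_config('{"maxtext_config": {}}') == {"maxtext_config": {}}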
