Commit c2574ab

Merge pull request #2778 from AI-Hypercomputer:nicogrande/maxtext-vllm-rl-integration
PiperOrigin-RevId: 847084739
2 parents 4e927f5 + e0e5a25 commit c2574ab

Showing 11 changed files with 100 additions and 41 deletions.

src/MaxText/configs/base.yml

Lines changed: 6 additions & 0 deletions
```diff
@@ -979,3 +979,9 @@ use_tokamax_gmm: false
 use_tokamax_splash: false
 # Setting this flag will use a non-pallas implementation.
 use_jax_splash: false
+
+# vLLM Adapter Configurations
+# Path to the HuggingFace-style config directory for the adapter (e.g. src/MaxText/integration/vllm/maxtext_vllm_adapter)
+vllm_hf_config_path: ""
+# JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}')
+vllm_additional_config: {}
```
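As a sanity check on the `vllm_additional_config` format documented above, a minimal sketch of the JSON round-trip (the `model_name` value is a made-up placeholder):

```python
import json

# A JSON string as it might be passed on the MaxText command line.
raw = '{"maxtext_config": {"model_name": "llama3.1-8b"}}'

parsed = json.loads(raw)
assert "maxtext_config" in parsed  # the key train_rl.py looks for (see below)
```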

src/MaxText/configs/types.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -1373,6 +1373,8 @@ class VLLM(BaseModel):
   kv_cache_buffer: int = Field(256, description="Buffer for KV cache.")
   hbm_utilization_vllm: float = Field(0.72, description="Target HBM utilization for vLLM.")
   swap_space_vllm_gb: int = Field(2, description="Swap space in GB for vLLM.")
+  vllm_additional_config: dict[str, Any] = Field(default_factory=dict, description="Additional vLLM config options.")
+  vllm_hf_config_path: str = Field("", description="Path to HuggingFace model config for MaxText model.")


 class GRPO(BaseModel):
@@ -2160,6 +2162,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
       "tensor": self.ici_tensor_parallelism,
       "tensor_transpose": self.ici_tensor_transpose_parallelism,
       "tensor_sequence": self.ici_tensor_sequence_parallelism,
+      "model": self.ici_tensor_parallelism,
       "expert": self.ici_expert_parallelism,
       "autoregressive": self.ici_autoregressive_parallelism,
     }
@@ -2176,6 +2179,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
       "tensor": self.dcn_tensor_parallelism,
       "tensor_transpose": self.dcn_tensor_transpose_parallelism,
       "tensor_sequence": self.dcn_tensor_sequence_parallelism,
+      "model": self.dcn_tensor_parallelism,
       "expert": self.dcn_expert_parallelism,
       "autoregressive": self.dcn_autoregressive_parallelism,
     }
```
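A note on `default_factory=dict` in the new `VLLM` fields: a shared mutable default would leak state across config instances. A minimal self-contained sketch (the class name here is illustrative, not the real one):

```python
from typing import Any
from pydantic import BaseModel, Field

class VLLMSketch(BaseModel):
  # default_factory builds a fresh dict per instance
  vllm_additional_config: dict[str, Any] = Field(default_factory=dict)
  vllm_hf_config_path: str = ""

a, b = VLLMSketch(), VLLMSketch()
a.vllm_additional_config["k"] = 1
assert b.vllm_additional_config == {}  # b is unaffected
```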

src/MaxText/configs/vllm.yml

Lines changed: 6 additions & 0 deletions
```diff
@@ -20,6 +20,8 @@ enable_nnx: True
 skip_jax_distributed_system: True
 # Scanned layers are not supported with vLLM integration
 scan_layers: False
+# Set weight dtype to bfloat16 as is done in vLLM
+weight_dtype: bfloat16


 # -------------- Logical Axis Rules --------------
@@ -41,6 +43,7 @@ logical_axis_rules: [
   ['activation_kv_batch_no_exp', []],
   ['activation_kv_head_dim', ['model']],
   ['activation_vocab', ['model']],
+  ['activation_embed', ['model']],
   ['activation_exp', ['expert']],
   ['decode_batch', ['expert']],
   ['mlp', ['model']],
@@ -49,7 +52,10 @@ logical_axis_rules: [
   ['heads', ['model']],
   ['q_heads', ['model']],
   ['kv_heads', ['model']],
+  ['kv_head_dim', []],
+  ['kv', []],
   ['embed', ['expert']],
+  ['embed_no_exp', []],
   ['q_lora', ['expert']],
   ['kv_lora', ['expert']],
   ['norm', ['model']],
```
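For readers unfamiliar with logical axis rules: each entry maps a logical tensor axis name to zero or more mesh axes, and an empty list means the axis is replicated rather than sharded. A toy lookup, illustrative only (not MaxText's actual resolver):

```python
# Stand-in rule table mirroring entries added above.
RULES = {"embed": ["expert"], "embed_no_exp": [], "kv": [], "heads": ["model"]}

def mesh_axes_for(logical_axes):
  # Unknown names fall back to replication here; real resolution is stricter.
  return [RULES.get(name, []) for name in logical_axes]

print(mesh_axes_for(["embed_no_exp", "heads"]))  # [[], ['model']]
```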

src/MaxText/integration/tunix/tunix_adapter.py

Lines changed: 14 additions & 0 deletions
```diff
@@ -37,6 +37,7 @@ def __init__(
       self,
       base_model: Transformer,
       use_standalone_mappings: bool = True,
+      use_no_op_mappings: bool = False,
   ):
     super().__init__()
     self.base = base_model
@@ -45,6 +46,7 @@ def __init__(
         HF_MODEL_CONFIGS[self.base.config.model_name].to_dict(),
         use_standalone_mappings,
     )
+    self.use_no_op_mappings = use_no_op_mappings

     # ------------------------------------------------------------------ #
     # Tunix call signature
@@ -69,13 +71,25 @@ def __call__(
     return logits, None

   def to_hf_mappings(self):
+    if self.use_no_op_mappings:
+      return {}
+
     return self._vllm_weight_mapping.to_hf_mapping()

   def to_hf_transpose_keys(self):
+    if self.use_no_op_mappings:
+      return {}
+
     return self._vllm_weight_mapping.to_hf_transpose_keys()

   def to_hf_hook_fns(self):
+    if self.use_no_op_mappings:
+      return {}
+
     return self._vllm_weight_mapping.to_hf_hook_fns()

   def lora_to_hf_mappings(self):
+    if self.use_no_op_mappings:
+      return {}
+
     return self._vllm_weight_mapping.lora_to_hf_mappings()
```
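Usage in brief: with `use_no_op_mappings=True` every MaxText-to-HF weight-mapping hook becomes a no-op, which is what the rollout path wants when vLLM already runs the native MaxText adapter and no HF renaming is needed. A hedged sketch (assumes a MaxText `model` has already been constructed):

```python
adapter = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=True)
# All mapping hooks collapse to empty dicts:
assert adapter.to_hf_mappings() == {}
assert adapter.to_hf_transpose_keys() == {}
assert adapter.to_hf_hook_fns() == {}
assert adapter.lora_to_hf_mappings() == {}
```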

src/MaxText/integration/tunix/utils.py

Lines changed: 4 additions & 1 deletion
```diff
@@ -147,7 +147,10 @@ def to_hf_hook_fns(self):
       return STANDALONE_VLLM_WEIGHT_MAPPING[self.model_name].to_hf_hook_fns()

     model_family = self.model_name.split("-")[0]
-    return VLLM_HOOK_FNS[model_family]()
+    if model_family in VLLM_HOOK_FNS:
+      return VLLM_HOOK_FNS[model_family]()
+    else:
+      return {}

   def lora_to_hf_mappings(self):
     if self.use_standalone_mappings:
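The same fallback could be spelled with `dict.get`; a self-contained illustration of the idiom (stand-in table, not the real `VLLM_HOOK_FNS`):

```python
HOOK_FNS = {"llama3": lambda: {"hook": "fn"}}  # stand-in for VLLM_HOOK_FNS

def hooks_for(family: str):
  # dict() == {} serves as the no-op default for unknown families
  return HOOK_FNS.get(family, dict)()

assert hooks_for("llama3") == {"hook": "fn"}
assert hooks_for("qwen") == {}
```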

src/MaxText/integration/vllm/maxtext_vllm_adapter/adapter.py

Lines changed: 17 additions & 27 deletions
```diff
@@ -16,8 +16,8 @@

 import jax
 import jax.numpy as jnp
+import os

-from etils import epath
 from flax import nnx
 import flax.linen as nn
 from jax.sharding import Mesh
@@ -49,32 +49,23 @@ def generate_maxtext_config(vllm_config: VllmConfig) -> pyconfig.HyperParameters
   Raises:
     ValueError: If `hf_config_path` is not provided in the vLLM model config.
   """
-
-  def _path_exists(path: str) -> bool:
-    if not path:
-      return False
-    return epath.Path(path).exists()
-
   if "maxtext_config" in vllm_config.additional_config:
     overrides = vllm_config.additional_config["maxtext_config"]
   else:
     overrides = {}
-  load_path = None
-  if _path_exists(vllm_config.load.download_dir):
-    load_path = vllm_config.load.download_dir
-  elif _path_exists(vllm_config.model.model):
-    load_path = vllm_config.model.model

-  if load_path:
-    overrides["load_parameters_path"] = load_path
-  elif vllm_config.model.model:
-    overrides["model_name"] = vllm_config.model.model
+  if vllm_config.load_config.load_format == "dummy":
+    if overrides.get("load_parameters_path") is not None:
+      max_logging.log(
+          "Warning: load_parameters_path is set when using dummy load format. Checkpoint loading will be skipped."
+      )
+    overrides["load_parameters_path"] = None

   if vllm_config.model_config.hf_config_path is None:
     raise ValueError("hf_config_path must be provided when using MaxTextForCausalLM.")

   # Add base config path to positional args
-  base_config_path = epath.Path(MAXTEXT_PKG_DIR) / "configs" / "vllm.yml"
+  base_config_path = os.path.join(MAXTEXT_PKG_DIR, "configs", "vllm.yml")
   argv_list = ["", str(base_config_path)]

   maxtext_config = pyconfig.initialize(argv_list, **overrides)
@@ -110,12 +101,6 @@ def __init__(self, vllm_config: VllmConfig, rng_key: jax.Array, mesh: Mesh) -> None:

     # Handle dummy weight loading during initialization
     if vllm_config.load_config.load_format == "dummy":
-      if self.maxtext_config.load_parameters_path is not None:
-        max_logging.log(
-            "Warning: load_parameters_path is set when using dummy load format. Checkpoint loading will be skipped."
-        )
-      self.maxtext_config.load_parameters_path = None
-
       with self.mesh:
         self.load_weights(rng_key)

@@ -173,7 +158,7 @@ def __call__(
       hidden = jnp.squeeze(hidden, axis=0)
       logits = jnp.squeeze(logits, axis=0)

-    self.logits = logits  # cache logits for compute_logits call
+    self.logits = nnx.data(logits)  # cache logits for compute_logits call

     return kv_caches, hidden, aux_hidden_states

@@ -199,9 +184,14 @@ def load_weights(self, rng_key: jax.Array) -> None:
     Args:
       rng_key: A JAX random key for model initialization.
     """
-    self.model, _ = model_creation_utils.create_nnx_model(
-        self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
-    )
+    if self.model is not None:
+      return
+
+    with nn.logical_axis_rules(""):
+      model, _ = model_creation_utils.create_nnx_model(
+          self.maxtext_config, mesh=self.mesh, model_mode=self.model_mode, rng_key=rng_key
+      )
+    self.model = nnx.data(model)


 class MaxTextForCausalLM(nnx.Module):
```
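The net effect of the config changes above, in isolation: dummy load format now always clears the checkpoint path before `pyconfig.initialize` runs. A sketch under stated assumptions (`SimpleNamespace` stands in for vLLM's config objects; the path is a placeholder):

```python
from types import SimpleNamespace

def resolve_overrides(vllm_config):
  overrides = dict(vllm_config.additional_config.get("maxtext_config", {}))
  if vllm_config.load_config.load_format == "dummy":
    overrides["load_parameters_path"] = None  # random init, skip checkpoint
  return overrides

cfg = SimpleNamespace(
    additional_config={"maxtext_config": {"load_parameters_path": "gs://some/ckpt"}},
    load_config=SimpleNamespace(load_format="dummy"),
)
assert resolve_overrides(cfg)["load_parameters_path"] is None
```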

src/MaxText/layers/attentions.py

Lines changed: 10 additions & 3 deletions
```diff
@@ -423,7 +423,7 @@ def __init__(
     # Module attribute names must match names previously passed to Linen for checkpointing
     self.KVCache_0 = (
         self.init_kv_caches(inputs_kv_shape=inputs_kv_shape)
-        if self.model_mode != MODEL_MODE_TRAIN and base_kv_cache
+        if self.model_mode != MODEL_MODE_TRAIN and base_kv_cache and config.attention != "vllm_rpa"
         else None
     )

@@ -909,7 +909,7 @@ def forward_serve_vllm(
     try:
       # pylint: disable=import-outside-toplevel
       # pytype: disable=import-error
-      from tpu_inference.layers.jax.attention_interface import sharded_ragged_paged_attention as rpa_ops
+      from tpu_inference.layers.common.attention_interface import sharded_ragged_paged_attention as rpa_ops
     except ImportError as e:
       raise ImportError(
           "vLLM RPA attention ops require the vllm-tpu package. Please install it with `pip install vllm-tpu`."
@@ -930,7 +930,8 @@ def forward_serve_vllm(

     md = rpa_metadata

-    output, kv_cache = rpa_ops(1.0, self.mesh, attention_chunk_size, q_scale, k_scale, v_scale)(
+    output, kv_cache = rpa_ops(
+        self.mesh,
         query,
         key,
         value,
@@ -939,6 +940,12 @@ def forward_serve_vllm(
         md.block_tables,
         md.query_start_loc,
         md.request_distribution,
+        None,
+        1.0,
+        attention_chunk_size,
+        q_scale,
+        k_scale,
+        v_scale,
     )
     return kv_cache, output
```
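Two things change above: the import moves from `tpu_inference.layers.jax` to `tpu_inference.layers.common`, and `rpa_ops` is now one flat call instead of a curried factory applied to the tensors. The guarded-import pattern, extracted as a runnable sketch (the wrapper function name is mine):

```python
def get_rpa_ops():
  try:
    # pylint: disable=import-outside-toplevel
    from tpu_inference.layers.common.attention_interface import (
        sharded_ragged_paged_attention as rpa_ops,
    )
  except ImportError as e:
    raise ImportError(
        "vLLM RPA attention ops require the vllm-tpu package. "
        "Please install it with `pip install vllm-tpu`."
    ) from e
  return rpa_ops
```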

src/MaxText/model_creation_utils.py

Lines changed: 3 additions & 3 deletions
```diff
@@ -78,9 +78,9 @@ def from_config(
   Example:
     model = from_config(config)
   """
-  devices_array = maxtext_utils.create_device_mesh(config, devices)
-
   if mesh is None:
+    devices_array = maxtext_utils.create_device_mesh(config, devices)
+
     if config.shard_mode == ShardMode.EXPLICIT:
       axis_types = tuple([AxisType.Explicit] * len(config.mesh_axes))
     else:
@@ -154,7 +154,7 @@ def create_sharded_state():
     model = _create_model_partial()
     return nnx.state(model)

-  with jax.set_mesh(mesh):
+  with mesh:
     # Create the model with sharded parameters.
     with nn.logical_axis_rules(config.logical_axis_rules):
       sharded_state = create_sharded_state()
```
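On the `with mesh:` swap: `jax.sharding.Mesh` is itself a context manager, so the diff enters the mesh directly instead of calling `jax.set_mesh`. A minimal sketch, assuming a single-axis mesh over all local devices (the axis name is arbitrary):

```python
import numpy as np
import jax
from jax.sharding import Mesh

mesh = Mesh(np.array(jax.devices()), axis_names=("data",))
with mesh:
  pass  # model and state creation run while this mesh is active
```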

src/MaxText/rl/evaluate_rl.py

Lines changed: 6 additions & 3 deletions
```diff
@@ -121,13 +121,16 @@ def score_responses(tmvp_config, question, responses, answer):

     # Check exact correctness
     try:
-      if float(extracted_response.strip()) == float(answer.strip()):
-        is_correct = True
+      # Remove ',' and '$' then convert to float
+      val_extracted = float(extracted_response.replace(",", "").replace("$", "").strip())
+      val_answer = float(answer.replace(",", "").replace("$", "").strip())
+      is_correct = val_extracted == val_answer

       # Check partial correctness (within 10%)
-      ratio = float(extracted_response.strip()) / float(answer.strip())
+      ratio = val_extracted / val_answer
       if 0.9 <= ratio <= 1.1:
         is_partially_correct = True
+
     except Exception as e:
       if tmvp_config.debug["rl"]:
         max_logging.log(f"Evaluation Exception: {e}")
```

src/MaxText/rl/train_rl.py

Lines changed: 29 additions & 3 deletions
```diff
@@ -48,6 +48,7 @@
 import collections
 import grain
 import jax
+import json
 import os
 import pathwaysutils
 import tensorflow_datasets as tfds
@@ -70,6 +71,7 @@

 from MaxText import max_logging, max_utils, maxtext_utils, pyconfig
 from MaxText import model_creation_utils
+from MaxText.globals import MAXTEXT_PKG_DIR
 from MaxText.integration.tunix.tunix_adapter import TunixMaxTextAdapter
 from MaxText.rl.evaluate_rl import evaluate
 from MaxText.rl import utils_rl
@@ -93,7 +95,8 @@ def get_maxtext_model(config, devices=None):
   """
   model, mesh = model_creation_utils.create_nnx_model(config, devices=devices)
   with jax.set_mesh(mesh):
-    tunix_model = TunixMaxTextAdapter(base_model=model)
+    use_no_op_mappings = "maxtext_config" in config.vllm_additional_config
+    tunix_model = TunixMaxTextAdapter(base_model=model, use_no_op_mappings=use_no_op_mappings)
   tunix_model.config = None
   return tunix_model, mesh

@@ -312,7 +315,7 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
   maxtext_state_flatten = {".".join(str(key) for key in keys): v for keys, v in _maxtext_state_flatten}
   max_logging.log(
       f"maxtext_state_flatten[base.token_embedder.embedding].value=\
-      {maxtext_state_flatten['base.token_embedder.embedding'].value}"
+      {maxtext_state_flatten['base.token_embedder.embedding'][...]}"
   )

   # TODO: @mazumdera: change this to use lora
@@ -352,6 +355,21 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
       set_profile_options=False,
   )

+  # Parse vllm_additional_config
+  rollout_additional_config = None
+  if trainer_config.vllm_additional_config:
+    if isinstance(trainer_config.vllm_additional_config, dict):
+      # It's already parsed into a dict
+      rollout_additional_config = trainer_config.vllm_additional_config
+    elif isinstance(trainer_config.vllm_additional_config, str):
+      # It's a string, so we need to parse it
+      try:
+        rollout_additional_config = json.loads(trainer_config.vllm_additional_config)
+      except json.JSONDecodeError as e:
+        raise ValueError(f"Failed to parse additional_config JSON: {e}") from e
+
+  max_logging.log(f"Parsed additional config: {rollout_additional_config}")
+
   # RL Cluster config
   # Note that we use vLLM as the rollout engine.
   # and we are using Tensor Parallelism for rollout
@@ -394,6 +412,9 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
           rollout_vllm_hbm_utilization=trainer_config.hbm_utilization_vllm,
           rollout_vllm_tpu_backend_type="jax",
           rollout_vllm_swap_space_size_gb=trainer_config.swap_space_vllm_gb,
+          rollout_vllm_hf_config_path=trainer_config.vllm_hf_config_path,
+          rollout_vllm_additional_config=rollout_additional_config,
+          rollout_vllm_init_with_random_weights=True,
          **get_rollout_kwargs_for_data_parallelism(sampler_config, len(sampler_devices)),
      ),
  )
@@ -423,7 +444,12 @@ def rl_train(trainer_config, sampler_config, trainer_devices, sampler_devices):
   max_logging.log(
       "enable_tunix_perf_metrics is True but tunix.perf modules are not available, skipping Tunix-managed metrics."
   )
-  with nn_partitioning.axis_rules(trainer_config.logical_axis_rules):
+
+  vllm_config_path = epath.Path(MAXTEXT_PKG_DIR) / "configs" / "vllm.yml"
+  argv_list = ["", str(vllm_config_path), "log_config=False"]
+  vllm_config = pyconfig.initialize(argv_list)
+
+  with nn_partitioning.axis_rules(vllm_config.logical_axis_rules):
     rl_cluster = rl_cluster_lib.RLCluster(
         actor=actor_model,
         reference=reference_model,
```

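The `vllm_additional_config` handling added above, condensed into one testable function (the name is mine; the behavior mirrors the hunk, including leaving the result `None` when the config is empty or of an unexpected type):

```python
import json

def parse_vllm_additional_config(value):
  if not value:
    return None
  if isinstance(value, dict):
    return value  # already parsed
  if isinstance(value, str):
    try:
      return json.loads(value)  # JSON string from the CLI
    except json.JSONDecodeError as e:
      raise ValueError(f"Failed to parse additional_config JSON: {e}") from e
  return None

assert parse_vllm_additional_config('{"maxtext_config": {}}') == {"maxtext_config": {}}
assert parse_vllm_additional_config(None) is None
```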