Skip to content

Commit 7f0028b

Browse files
Merge pull request #3203 from AI-Hypercomputer:hengtaoguo-re
PiperOrigin-RevId: 874230320
2 parents d8c8862 + 6061c9d commit 7f0028b

82 files changed

Lines changed: 315 additions & 205 deletions

File tree

Some content is hidden

Large commits have some content hidden by default. Use the search box below for content that may be hidden.

.vscode/launch.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
"console": "integratedTerminal",
99
"justMyCode": false,
1010
"python": "python3",
11-
"module": "maxtext.decode",
11+
"module": "maxtext.inference.decode",
1212
"args": ["src/maxtext/configs/base.yml",
1313
"run_name=runner_$(date +%Y-%m-%d-%H-%M)",
1414
"base_output_directory=gs://test-maxtext-output",
@@ -35,7 +35,7 @@
3535
"console": "integratedTerminal",
3636
"justMyCode": false,
3737
"python": "python3",
38-
"module": "maxtext.decode",
38+
"module": "maxtext.inference.decode",
3939
"args": ["src/maxtext/configs/base.yml",
4040
"run_name=runner_$(date +%Y-%m-%d-%H-%M)",
4141
"base_output_directory=gs://test-maxtext-output",

benchmarks/api_server/maxtext_generator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@
3434

3535
from dataclasses import dataclass, field
3636

37-
from MaxText import maxengine, pyconfig
37+
from MaxText import pyconfig
38+
from maxtext.inference.maxengine import maxengine
3839
from maxtext.multimodal import processor as mm_processor
3940
from maxtext.multimodal import utils as mm_utils
4041
from maxtext.utils import max_logging, max_utils

benchmarks/mmlu/mmlu_eval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
from tqdm import tqdm
5858

5959
from MaxText import pyconfig
60-
from MaxText import maxengine
60+
from maxtext.inference.maxengine import maxengine
6161
from maxtext.utils import max_logging
6262
from maxtext.utils import max_utils
6363

docs/run_maxtext/run_maxtext_localhost.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
7272
To demonstrate model output, run the following command:
7373

7474
```bash
75-
python3 -m maxtext.decode src/maxtext/configs/base.yml \
75+
python3 -m maxtext.inference.decode src/maxtext/configs/base.yml \
7676
run_name=$YOUR_JOB_NAME \
7777
base_output_directory=gs://<my-bucket> \
7878
per_device_batch_size=1

docs/tutorials/first_run.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ Optional: If you want to try training on a Hugging Face dataset, see [Data Input
6161
5. To demonstrate model output, run the following command:
6262

6363
```sh
64-
python3 -m maxtext.decode src/maxtext/configs/base.yml \
64+
python3 -m maxtext.inference.decode src/maxtext/configs/base.yml \
6565
run_name=$YOUR_JOB_NAME \
6666
base_output_directory=gs://<my-bucket> \
6767
per_device_batch_size=1
@@ -93,7 +93,7 @@ python3 -m maxtext.trainers.pre_train.train src/maxtext/configs/base.yml \
9393
3. To demonstrate model output, run the following command:
9494

9595
```sh
96-
python3 -m maxtext.decode src/maxtext/configs/base.yml \
96+
python3 -m maxtext.inference.decode src/maxtext/configs/base.yml \
9797
run_name=$YOUR_JOB_NAME \
9898
base_output_directory=gs://<my-bucket> \
9999
per_device_batch_size=1

docs/tutorials/posttraining/multimodal.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ To run a forward pass and verify the model's output, use the following command:
7272

7373
```shell
7474
# Gemma3 decode
75-
python -m maxtext.decode \
75+
python -m maxtext.inference.decode \
7676
maxtext/configs/base.yml \
7777
model_name=gemma3-4b \
7878
hf_access_token=$HF_ACCESS_TOKEN \
@@ -108,7 +108,7 @@ To decode with multiple images at once, you can provide multiple image paths lik
108108
export TARGET_LENGTH=... # Adjust to fit expected output length
109109
export PREDICT_LENGTH=... # Adjust to fit image tokens + text prompt
110110
111-
python -m maxtext.decode \
111+
python -m maxtext.inference.decode \
112112
maxtext/configs/base.yml \
113113
model_name=gemma3-4b \
114114
... \

src/MaxText/decode.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright 2023–2026 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Shim for inference decode in `src/maxtext/inference/decode`."""
16+
17+
import sys
18+
import importlib
19+
20+
from absl import logging
21+
22+
from maxtext.utils import max_logging
23+
24+
OLD_MODULE_PATH = "MaxText.decode"
25+
NEW_MODULE_PATH = "maxtext.inference.decode"
26+
27+
if __name__ == "__main__":
28+
try:
29+
logging.set_verbosity(logging.INFO)
30+
_new_module = importlib.import_module(NEW_MODULE_PATH)
31+
if hasattr(_new_module, "main"):
32+
max_logging.warning(f"'{OLD_MODULE_PATH}' is deprecated; use '{NEW_MODULE_PATH}' instead.\n")
33+
_new_module.main(sys.argv)
34+
except ImportError as e:
35+
max_logging.error(f"Shim could not find target module: '{NEW_MODULE_PATH}'\n")
36+
raise e

src/MaxText/maxengine_server.py

Lines changed: 19 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright 2023–2025 Google LLC
1+
# Copyright 2023–2026 Google LLC
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -12,94 +12,31 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Runs a server with maxtext."""
16-
17-
from __future__ import annotations
15+
"""Shim for maxengine_server in `src/maxtext/inference/maxengine/maxengine_server`."""
1816

1917
import os
2018
import sys
21-
from typing import Any
19+
import importlib
2220

2321
import jax
22+
from absl import logging
2423

2524
from MaxText import pyconfig
26-
from MaxText import maxengine_config
27-
from maxtext.common import gcloud_stub
28-
29-
# _PORT = flags.DEFINE_integer('port', 9000, 'port to listen on')
30-
# _THREADS = flags.DEFINE_integer(
31-
# 'threads', 64, 'number of worker threads in thread pool'
32-
# )
33-
# _CONFIG = flags.DEFINE_string(
34-
# 'config',
35-
# 'MaxtextInterleavedServer',
36-
# 'available servers',
37-
# )
38-
39-
40-
def _create_prefix_caching_config(config, config_lib_module):
41-
if not config.enable_prefix_caching:
42-
return None
43-
44-
if not config.use_chunked_prefill:
45-
raise ValueError("Prefix caching requires chunked prefill.")
46-
47-
return config_lib_module.PrefixCachingConfig(
48-
max_hbm_byte=config.prefix_caching_hbm_byte,
49-
max_dram_byte=config.prefix_caching_dram_byte,
50-
)
51-
52-
53-
def main(config):
54-
# Obtain the jetstream helper modules (or stubs if appropriate).
55-
config_lib, _engine_api, *_ = gcloud_stub.jetstream()
56-
57-
# If running decoupled and gcloud_stub returned lightweight stubs, skip
58-
# starting the real server. Use the explicit _IS_STUB marker when present.
59-
config_lib_is_stub = getattr(config_lib, "_IS_STUB", False)
60-
engine_api_is_stub = getattr(_engine_api, "_IS_STUB", False)
61-
if gcloud_stub.is_decoupled() and (config_lib_is_stub or engine_api_is_stub):
62-
raise RuntimeError(
63-
"JetStream helper modules are stubbed or DECOUPLE_GCLOUD=TRUE; server cannot be started in decoupled mode. "
64-
"Unset DECOUPLE_GCLOUD or install JetStream to run the server."
65-
)
66-
67-
# Import the real server_lib now that it's known present.
68-
from jetstream.core import server_lib # type: ignore # pylint: disable=import-outside-toplevel
69-
import pathwaysutils # pylint: disable=unused-import,import-outside-toplevel
70-
71-
pathwaysutils.initialize()
72-
73-
# No devices for local cpu test. A None for prefill and a None for generate.
74-
devices = server_lib.get_devices()
75-
server_config = maxengine_config.get_server_config(config.inference_server, config)
76-
77-
metrics_server_config: Any | None = None
78-
if config.prometheus_port != 0:
79-
metrics_server_config = config_lib.MetricsServerConfig(port=config.prometheus_port)
80-
81-
# We separate credential from run so that we can unit test it with
82-
# local credentials.
83-
# TODO: Add grpc credentials for OSS.
84-
# pylint: disable=unexpected-keyword-arg
85-
jetstream_server = server_lib.run(
86-
threads=256,
87-
port=9000,
88-
config=server_config,
89-
devices=devices,
90-
metrics_server_config=metrics_server_config,
91-
enable_jax_profiler=config.enable_jax_profiler if config.enable_jax_profiler else False,
92-
jax_profiler_port=config.jax_profiler_port if config.jax_profiler_port else 9999,
93-
enable_model_warmup=config.enable_model_warmup if config.enable_model_warmup else False,
94-
lora_input_adapters_path=config.lora_input_adapters_path,
95-
multi_sampling=config.multi_sampling if config.multi_sampling else False,
96-
prefix_caching_config=_create_prefix_caching_config(config, config_lib),
97-
)
98-
jetstream_server.wait_for_termination()
25+
from maxtext.utils import max_logging
9926

27+
OLD_MODULE_PATH = "MaxText.maxengine_server"
28+
NEW_MODULE_PATH = "maxtext.inference.maxengine.maxengine_server"
10029

10130
if __name__ == "__main__":
102-
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
103-
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
104-
cfg = pyconfig.initialize(sys.argv)
105-
main(cfg)
31+
try:
32+
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
33+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
34+
logging.set_verbosity(logging.INFO)
35+
max_logging.warning(f"'{OLD_MODULE_PATH}' is deprecated; use '{NEW_MODULE_PATH}' instead.\n")
36+
_new_module = importlib.import_module(NEW_MODULE_PATH)
37+
if hasattr(_new_module, "main"):
38+
cfg = pyconfig.initialize(sys.argv)
39+
_new_module.main(cfg)
40+
except ImportError as e:
41+
max_logging.error(f"Shim could not find target module: '{NEW_MODULE_PATH}'\n")
42+
raise e

src/maxtext/checkpoint_conversion/examples/convert_gemma2_to_mt.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ echo "--- Checkpoint Conversion Complete ---"
3434
# --- Step 2 (Optional): Decode using the Converted Checkpoint ---
3535

3636
echo "--- Starting Decoding ---"
37-
python3 -m maxtext.decode \
37+
python3 -m maxtext.inference.decode \
3838
${MAXTEXT_CONFIGS_DIR:-${MAXTEXT_REPO_ROOT:-$PWD}/src/maxtext/configs}/base.yml \
3939
model_name="${MODEL_NAME}" \
4040
tokenizer_path="${TOKENIZER_PATH}" \

src/maxtext/checkpoint_conversion/load_and_quantize_checkpoint.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121

2222
import jax
2323

24-
from MaxText import maxengine
2524
from MaxText import pyconfig
25+
from maxtext.inference.maxengine import maxengine
2626
from maxtext.utils import max_utils
2727

2828

0 commit comments

Comments (0)