|
1 | | -# Copyright 2023–2025 Google LLC |
| 1 | +# Copyright 2023–2026 Google LLC |
2 | 2 | # |
3 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 4 | # you may not use this file except in compliance with the License. |
|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -"""Runs a server with maxtext.""" |
16 | | - |
17 | | -from __future__ import annotations |
| 15 | +"""Shim for maxengine_server in `src/maxtext/inference/maxengine/maxengine_server`.""" |
18 | 16 |
|
19 | 17 | import os |
20 | 18 | import sys |
21 | | -from typing import Any |
| 19 | +import importlib |
22 | 20 |
|
23 | 21 | import jax |
| 22 | +from absl import logging |
24 | 23 |
|
25 | 24 | from MaxText import pyconfig |
26 | | -from MaxText import maxengine_config |
27 | | -from maxtext.common import gcloud_stub |
28 | | - |
29 | | -# _PORT = flags.DEFINE_integer('port', 9000, 'port to listen on') |
30 | | -# _THREADS = flags.DEFINE_integer( |
31 | | -# 'threads', 64, 'number of worker threads in thread pool' |
32 | | -# ) |
33 | | -# _CONFIG = flags.DEFINE_string( |
34 | | -# 'config', |
35 | | -# 'MaxtextInterleavedServer', |
36 | | -# 'available servers', |
37 | | -# ) |
38 | | - |
39 | | - |
40 | | -def _create_prefix_caching_config(config, config_lib_module): |
41 | | - if not config.enable_prefix_caching: |
42 | | - return None |
43 | | - |
44 | | - if not config.use_chunked_prefill: |
45 | | - raise ValueError("Prefix caching requires chunked prefill.") |
46 | | - |
47 | | - return config_lib_module.PrefixCachingConfig( |
48 | | - max_hbm_byte=config.prefix_caching_hbm_byte, |
49 | | - max_dram_byte=config.prefix_caching_dram_byte, |
50 | | - ) |
51 | | - |
52 | | - |
53 | | -def main(config): |
54 | | - # Obtain the jetstream helper modules (or stubs if appropriate). |
55 | | - config_lib, _engine_api, *_ = gcloud_stub.jetstream() |
56 | | - |
57 | | - # If running decoupled and gcloud_stub returned lightweight stubs, skip |
58 | | - # starting the real server. Use the explicit _IS_STUB marker when present. |
59 | | - config_lib_is_stub = getattr(config_lib, "_IS_STUB", False) |
60 | | - engine_api_is_stub = getattr(_engine_api, "_IS_STUB", False) |
61 | | - if gcloud_stub.is_decoupled() and (config_lib_is_stub or engine_api_is_stub): |
62 | | - raise RuntimeError( |
63 | | - "JetStream helper modules are stubbed or DECOUPLE_GCLOUD=TRUE; server cannot be started in decoupled mode. " |
64 | | - "Unset DECOUPLE_GCLOUD or install JetStream to run the server." |
65 | | - ) |
66 | | - |
67 | | - # Import the real server_lib now that it's known present. |
68 | | - from jetstream.core import server_lib # type: ignore # pylint: disable=import-outside-toplevel |
69 | | - import pathwaysutils # pylint: disable=unused-import,import-outside-toplevel |
70 | | - |
71 | | - pathwaysutils.initialize() |
72 | | - |
73 | | - # No devices for local cpu test. A None for prefill and a None for generate. |
74 | | - devices = server_lib.get_devices() |
75 | | - server_config = maxengine_config.get_server_config(config.inference_server, config) |
76 | | - |
77 | | - metrics_server_config: Any | None = None |
78 | | - if config.prometheus_port != 0: |
79 | | - metrics_server_config = config_lib.MetricsServerConfig(port=config.prometheus_port) |
80 | | - |
81 | | - # We separate credential from run so that we can unit test it with |
82 | | - # local credentials. |
83 | | - # TODO: Add grpc credentials for OSS. |
84 | | - # pylint: disable=unexpected-keyword-arg |
85 | | - jetstream_server = server_lib.run( |
86 | | - threads=256, |
87 | | - port=9000, |
88 | | - config=server_config, |
89 | | - devices=devices, |
90 | | - metrics_server_config=metrics_server_config, |
91 | | - enable_jax_profiler=config.enable_jax_profiler if config.enable_jax_profiler else False, |
92 | | - jax_profiler_port=config.jax_profiler_port if config.jax_profiler_port else 9999, |
93 | | - enable_model_warmup=config.enable_model_warmup if config.enable_model_warmup else False, |
94 | | - lora_input_adapters_path=config.lora_input_adapters_path, |
95 | | - multi_sampling=config.multi_sampling if config.multi_sampling else False, |
96 | | - prefix_caching_config=_create_prefix_caching_config(config, config_lib), |
97 | | - ) |
98 | | - jetstream_server.wait_for_termination() |
| 25 | +from maxtext.utils import max_logging |
99 | 26 |
|
| 27 | +OLD_MODULE_PATH = "MaxText.maxengine_server" |
| 28 | +NEW_MODULE_PATH = "maxtext.inference.maxengine.maxengine_server" |
100 | 29 |
|
101 | 30 | if __name__ == "__main__": |
102 | | - jax.config.update("jax_default_prng_impl", "unsafe_rbg") |
103 | | - os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" |
104 | | - cfg = pyconfig.initialize(sys.argv) |
105 | | - main(cfg) |
| 31 | + try: |
| 32 | + jax.config.update("jax_default_prng_impl", "unsafe_rbg") |
| 33 | + os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" |
| 34 | + logging.set_verbosity(logging.INFO) |
| 35 | + max_logging.warning(f"'{OLD_MODULE_PATH}' is deprecated; use '{NEW_MODULE_PATH}' instead.\n") |
| 36 | + _new_module = importlib.import_module(NEW_MODULE_PATH) |
| 37 | + if hasattr(_new_module, "main"): |
| 38 | + cfg = pyconfig.initialize(sys.argv) |
| 39 | + _new_module.main(cfg) |
| 40 | + except ImportError as e: |
| 41 | + max_logging.error(f"Shim could not find target module: '{NEW_MODULE_PATH}'\n") |
| 42 | + raise e |
0 commit comments