Commit 1f56029

adding necessary files

- add decoupling logic config patch to tests
- add correct ICI parallelism to tests for decoupled mode
- fixing more UTs
- adding decoupling logic, biggest change
- add tensorboardX stub
- fixing train_tests
- removing tunix from decoupling logic
- renaming datasets to local_datasets to avoid confusion with HF datasets library
- make jax_remove_size_one_mesh_axis_from_type param setting in try block, todo: remove this after updating jax
- Configure ICI data parallelism for decoupled mode
- revert legacy rl trainer
- Update grpo_trainer.py
- adding conditional imports because pytest collect always imports before marks are used
- centralize decoupled dataset paths and base_output_directory
- skip packed attention if not on cuda sm90+
- parameterize test_env_smoke tests
- undo sft test decoupling changes as it is marked as external training
- move path logic to setup method and use if logic in train_smoke_test.py
- fix ref to dummy summary writer
- add yield to GOODPUT_STUB
- Add pytest marker for train_compile tests. These are actually requiring libtpu and should be TPU tests.
- fixed path for GrainArrayRecordBestFitPackingTest
- update local output
- adding refactoring for test_utils import
- moved local_datasets, refactored
- add missing import from flop_calculation test
- adding gcloud_stub_test.py
- fix gcloud_stub test, add cpu_only marker

1 parent f44534f commit 1f56029

130 files changed

Lines changed: 1836 additions & 708 deletions


README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -35,7 +35,7 @@ Check out our [Read The Docs site](https://maxtext.readthedocs.io/en/latest/) or
 See our installation guide to [install MaxText with pip from PyPI](https://maxtext.readthedocs.io/en/latest/install_maxtext.html#from-pypi-recommended).
 
-See our guide on running MaxText in decoupled mode, without any GCP dependencies in [Decoupled Mode Guide](https://maxtext.readthedocs.io/en/latest/guides/run_maxtext/decoupled_mode.html).
+See our guide on running MaxText in decoupled mode, without any GCP dependencies in [Decoupled Mode Guide](https://maxtext.readthedocs.io/en/latest/run_maxtext/decoupled_mode.html).
 
 <!-- NEWS START -->
```

codecov.yml

Lines changed: 1 addition & 1 deletion
```diff
@@ -65,7 +65,7 @@ coverage:
   patch:
     default:
       target: auto
-      threshold: 5% # fail on 5+ percent degradation
+      threshold: 10% # fail on 10+ percent degradation
       flags:
         - regular
```

dependencies/requirements/requirements_decoupled_jax_0_7.1.txt

Lines changed: 3 additions & 2 deletions
```diff
@@ -8,6 +8,7 @@ flax
 grain>=0.2.12
 grpcio>=1.75.1
 huggingface_hub>=0.35.3
+jax==0.7.1
 jaxtyping>=0.3.3
 jsonlines>=4.0.0
 matplotlib>=3.10.3
@@ -19,6 +20,7 @@ omegaconf>=2.3.0
 optax>=0.2.6
 orbax-checkpoint>=0.11.25
 pandas>=2.3.3
+parameterized==0.9.0
 pathwaysutils>=0.1.3
 pillow>=11.3.0
 protobuf>=5.29.5
@@ -39,5 +41,4 @@ tiktoken>=0.12.0
 tqdm>=4.67.1
 transformers>=4.57.0
 urllib3>=2.5.0
-jax==0.7.1
-git+https://github.com/google/tunix.git
+git+https://github.com/google/tunix.git
```

docs/run_maxtext/decoupled_mode.md

Lines changed: 34 additions & 28 deletions
````diff
@@ -14,38 +14,40 @@
 limitations under the License.
 -->
 
-
 # Via Decoupled Mode (No Google Cloud Dependencies)
 
 Set `DECOUPLE_GCLOUD=TRUE` to run MaxText tests and local development without any Google Cloud SDK, `gs://` buckets, JetStream, or Vertex AI integrations.
 
 When enabled:
-* Skips external integration tests with markers:
-  * `external_serving` (`jetstream`, `serving`, `decode_server`)
-  * `external_training` (`goodput`)
-  * `decoupled` – Applied by `tests/conftest.py` to tests that are runnable in decoupled mode (i.e. not skipped for TPU or external markers).
-* Production / serving entrypoints (`decode.py`, `maxengine_server.py`, `maxengine_config.py`, tokenizer access in `maxengine.py`) **fail fast with a clear RuntimeError** when decoupled. This prevents accidentally running partial serving logic locally when decoupled mode is ON.
-* Import-time safety is preserved by lightweight stubs returned from `decouple.py` (so modules import cleanly); only active use of missing functionality raises.
-* Conditionally replaces dataset paths in certain tests to point at minimal local datasets.
-* Uses a local base output directory (users can override with `LOCAL_BASE_OUTPUT`).
-* All tests that previously hard-coded `configs/base.yml` now use the helper `get_test_config_path()` from `tests/utils/test_helper.py`. This helper ensures usage of `decoupled_base_test.yml`.
+
+- Skips external integration tests with markers:
+  - `external_serving` (`jetstream`, `serving`, `decode_server`)
+  - `external_training` (`goodput`)
+  - `decoupled` – Applied by `tests/conftest.py` to tests that are runnable in decoupled mode (i.e. not skipped for TPU or external markers).
+- Production / serving entrypoints (`decode.py`, `maxengine_server.py`, `maxengine_config.py`, tokenizer access in `maxengine.py`) **fail fast with a clear RuntimeError** when decoupled. This prevents accidentally running partial serving logic locally when decoupled mode is ON.
+- Import-time safety is preserved by lightweight stubs returned from `decouple.py` (so modules import cleanly); only active use of missing functionality raises.
+- Conditionally replaces dataset paths in certain tests to point at minimal local datasets.
+- Uses a local base output directory (users can override with `LOCAL_BASE_OUTPUT`).
+- All tests that previously hard-coded `configs/base.yml` now use the helper `get_test_config_path()` from `tests/utils/test_utils.py`. This helper ensures usage of `decoupled_base_test.yml`.
 
 Minimal datasets included (checked into the repo):
-* ArrayRecord shards: generated via `python local_datasets/get_minimal_c4_en_dataset.py`,
+
+- ArrayRecord shards: generated via `python local_datasets/get_minimal_c4_en_dataset.py`,
   located in `local_datasets/c4_en_dataset_minimal/c4/en/3.0.1/c4-{train,validation}.array_record-*`
-* Parquet (HF style): generated via `python local_datasets/get_minimal_hf_c4_parquet.py`,
+- Parquet (HF style): generated via `python local_datasets/get_minimal_hf_c4_parquet.py`,
   located in `local_datasets/c4_en_dataset_minimal/hf/c4`
 
-
 Run a local smoke test fully offline:
+
 ```bash
 export DECOUPLE_GCLOUD=TRUE
 pytest -k train_gpu_smoke_test -q
 ```
 
 Optional environment variables:
-* `LOCAL_GCLOUD_PROJECT` - placeholder project string (default: `local-maxtext-project`).
-* `LOCAL_BASE_OUTPUT` - override default local output directory used in tests.
+
+- `LOCAL_GCLOUD_PROJECT` - placeholder project string (default: `local-maxtext-project`).
+- `LOCAL_BASE_OUTPUT` - override default local output directory used in tests.
 
 ## Centralized Decoupling API (`gcloud_stub.py`)
 
@@ -55,32 +57,36 @@ MaxText exposes a single module `MaxText.gcloud_stub` to avoid scattering enviro
 from MaxText.gcloud_stub import is_decoupled, cloud_diagnostics, jetstream
 
 if is_decoupled():
-  # Skip optional integrations or use local fallbacks
-  pass
+    # Skip optional integrations or use local fallbacks
+    pass
 
 # Cloud diagnostics (returns diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration)
-diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration = cloud_diagnostics()
+diagnostic, debug_configuration, diagnostic_configuration, stack_trace_configuration = (
+    cloud_diagnostics()
+)
 
 # JetStream (serving) components
 config_lib, engine_api, token_utils, tokenizer_api, token_params_ns = jetstream()
 TokenizerParameters = getattr(token_params_ns, "TokenizerParameters", object)
 ```
 
 Behavior when `DECOUPLE_GCLOUD=TRUE`:
-* `is_decoupled()` returns True.
-* Each helper returns lightweight stubs whose attributes are safe to access; calling methods raises a clear `RuntimeError` only when actually invoked.
-* Prevents import-time failures for optional dependencies (JetStream).
+
+- `is_decoupled()` returns True.
+- Each helper returns lightweight stubs whose attributes are safe to access; calling methods raises a clear `RuntimeError` only when actually invoked.
+- Prevents import-time failures for optional dependencies (JetStream).
 
 ## Guidelines:
-* Prefer calling `jetstream()` / `cloud_diagnostics()` once at module import and branching on `is_decoupled()` for functionality that truly requires the dependency.
-* Use `is_decoupled()` to avoid direct `os.environ["DECOUPLE_GCLOUD"]` checking.
-* Use `get_test_config_path()` instead of hard-coded `base.yml`.
-* Prefer conditional local fallbacks for cloud buckets and avoid introducing direct `gs://...` paths.
-* Please add the appropriate external dependency marker (`external_serving` or `external_training`) for new tests. Prefer the smallest scope instead of module-wide `pytestmark` when only a part of a file needs an external dependency.
-* Tests add a `decoupled` marker if DECOUPLE_GCLOUD && not marked with external dependency markers. Run tests with:
+
+- Prefer calling `jetstream()` / `cloud_diagnostics()` once at module import and branching on `is_decoupled()` for functionality that truly requires the dependency.
+- Use `is_decoupled()` to avoid direct `os.environ["DECOUPLE_GCLOUD"]` checking.
+- Use `get_test_config_path()` instead of hard-coded `base.yml`.
+- Prefer conditional local fallbacks for cloud buckets and avoid introducing direct `gs://...` paths.
+- Please add the appropriate external dependency marker (`external_serving` or `external_training`) for new tests. Prefer the smallest scope instead of module-wide `pytestmark` when only a part of a file needs an external dependency.
+- Tests add a `decoupled` marker if DECOUPLE_GCLOUD && not marked with external dependency markers. Run tests with:
+
 ```
 pytest -m decoupled -vv tests
 ```
 
 This centralized approach keeps optional integrations cleanly separated from core MaxText logic, making local development (e.g. on ROCm/NVIDIA GPUs) frictionless.
-
````

src/MaxText/configs/decoupled_base_test.yml

Lines changed: 9 additions & 13 deletions
```diff
@@ -1,9 +1,9 @@
 # Decoupled base test config: used when DECOUPLE_GCLOUD=TRUE for tests that previously relied on base.yml.
-# Inherit all model defaults from base.yml but override any cloud-coupled paths and disable optional cloud features.
-base_config: base.yml
+# Inherit all model defaults (PyDantic already does this) but override any cloud-coupled paths and disable
+# optional cloud features.
 
 # Output goes to a local relative directory so tests do not require GCS.
-base_output_directory: ./maxtext_local_output
+base_output_directory: ./maxtext_local_output/gcloud_decoupled_test_logs
 run_name: test_decoupled
 
 # Disable checkpointing by default for speed unless a test explicitly enables it.
@@ -23,7 +23,9 @@ profile_periodically_period: 0
 profiler_steps: 0
 
 # Leave dataset-related keys to be overridden by individual tests.
-dataset_type: ""
+dataset_path: "tests/assets/local_datasets/c4_en_dataset_minimal/"
+dataset_name: 'c4/en:3.1.0'
+eval_dataset_name: 'c4/en:3.1.0'
 
 # Use dot_product attention to avoid GPU Pallas shared memory limits on AMD GPUs
 attention: "dot_product"
@@ -44,6 +46,8 @@ ici_tensor_sequence_parallelism: 1
 ici_autoregressive_parallelism: 1
 ici_fsdp_parallelism: 1
 ici_fsdp_transpose_parallelism: 1
+# Allow higher unsharded parameter percentage for small device count
+sharding_tolerance: 0.3
 
 # DCN dimensions to 1 (no multi-slice expectation locally).
 dcn_data_parallelism: 1
@@ -68,12 +72,4 @@ goodput_upload_interval_seconds: 0
 enable_pathways_goodput: false
 enable_gcp_goodput_metrics: false
 
-# Disable any cloud logging / BigQuery or external metric uploads.
-enable_cloud_logging: false
-upload_metrics_to_bigquery: false
-bigquery_project: ""
-bigquery_dataset: ""
-bigquery_table: ""
-
-# Force local-only behavior for tests: avoid accidental env pickup.
-tensorboard_dir: "./maxtext_local_output/tensorboard"
+tensorboard_dir: "./maxtext_local_output/gcloud_decoupled_test_logs/tensorboard"
```
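Tests select this config through the `get_test_config_path()` helper mentioned in the commit message and docs. A minimal sketch of how such a helper could choose between `base.yml` and `decoupled_base_test.yml` is shown below; the actual signature and default directory in `tests/utils/test_utils.py` may differ.

```python
import os


def is_decoupled() -> bool:
    """DECOUPLE_GCLOUD=TRUE switches tests to local, cloud-free behavior."""
    return os.environ.get("DECOUPLE_GCLOUD", "").upper() == "TRUE"


def get_test_config_path(config_dir: str = "src/MaxText/configs") -> str:
    """Return the decoupled test config when decoupled, else the standard base.yml."""
    name = "decoupled_base_test.yml" if is_decoupled() else "base.yml"
    return os.path.join(config_dir, name)
```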

src/MaxText/maxengine.py

Lines changed: 42 additions & 23 deletions
```diff
@@ -15,7 +15,7 @@
 """Implementation of Engine API for MaxText."""
 
 from collections import defaultdict
-from typing import Any, Callable, Union
+from typing import Any, Callable
 import functools
 import os.path
 import uuid
@@ -36,13 +36,6 @@
 from flax.linen import partitioning as nn_partitioning
 import flax
 
-from jetstream.core import config_lib
-from jetstream.engine import engine_api
-from jetstream.engine import token_utils
-from jetstream.engine import tokenizer_api
-from jetstream.engine.tokenizer_pb2 import TokenizerParameters
-from jetstream.engine.tokenizer_pb2 import TokenizerType
-
 from MaxText import multimodal_utils
 from MaxText import pyconfig
 from MaxText.common_types import MODEL_MODE_PREFILL, DECODING_ACTIVE_SEQUENCE_INDICATOR, MODEL_MODE_AUTOREGRESSIVE
@@ -53,6 +46,11 @@
 from maxtext.utils import lora_utils
 from maxtext.utils import max_utils
 from maxtext.utils import maxtext_utils
+from maxtext.common.gcloud_stub import jetstream, is_decoupled
+
+config_lib, engine_api, token_utils, tokenizer_api, _token_params_ns = jetstream()
+TokenizerParameters = getattr(_token_params_ns, "TokenizerParameters", object)  # type: ignore[assignment]
+TokenizerType = getattr(_token_params_ns, "TokenizerType", object)  # type: ignore[assignment]
 
 
 warnings.simplefilter("ignore", category=FutureWarning)
@@ -95,14 +93,17 @@ def get_keys(self):
     return self.keys
 
 
-class MaxEngine(engine_api.Engine):
+_BaseEngine = engine_api.Engine if (not is_decoupled() and hasattr(engine_api, "Engine")) else object
+
+
+class MaxEngine(_BaseEngine):
   """The computational core of the generative model server.
 
   Engine defines an API that models must adhere to as they plug into the
   JetStream efficient serving infrastructure.
   """
 
-  def __init__(self, config: Any, devices: Union[config_lib.Devices, None] = None):
+  def __init__(self, config: Any, devices: Any | None = None):
     self.config = config
 
     # Mesh definition
@@ -139,7 +140,7 @@ def print_stats(self, label: str):
 
   def generate_aot(
       self, params: Params, decode_state: DecodeState, rng: PRNGKeyType | None = None
-  ) -> tuple[DecodeState, engine_api.ResultTokens]:
+  ):  # returns (new_decode_state, result_tokens)
    """Wrapper to generate for ahead of time compilation."""
 
    return self.generate(params=params, decode_state=decode_state, rng=rng)
@@ -393,7 +394,7 @@ def prefill_aot(  # pylint: disable=too-many-positional-arguments
       padded_tokens: jax.Array,
       true_length: int,
       rng: PRNGKeyType | None = None,
-  ) -> tuple[Prefix, engine_api.ResultTokens]:
+  ):  # returns (new_prefix, result_tokens)
     """Wrapper for prefill for ahead-of-time compilation."""
 
     return self.prefill(
@@ -426,7 +427,7 @@ def _prefill_jit(
       topk: int | None = None,
       nucleus_topp: float | None = None,
       temperature: float | None = None,
-  ) -> tuple[Prefix, engine_api.ResultTokens]:
+  ):  # returns (new_prefix, result_tokens)
     """Performs a JIT-compiled prefill operation on a sequence of tokens.
 
     This function processes an input sequence (prompt) through the model to compute
@@ -594,7 +595,7 @@ def prefill(
       topk: int | None = None,
       nucleus_topp: float | None = None,
       temperature: float | None = None,
-  ) -> tuple[Prefix, engine_api.ResultTokens]:
+  ):  # returns (new_prefix, result_tokens)
     """Public API for prefill that updates page state outside JIT."""
     # Update page state before JIT call
     if self.config.attention == "paged" and self.page_manager is not None and self.page_state is not None:
@@ -643,7 +644,7 @@ def prefill_multisampling_aot(  # pylint: disable=too-many-positional-arguments
       topk: int | None = None,
       nucleus_topp: float | None = None,
       temperature: float | None = None,
-  ) -> tuple[Prefix, engine_api.ResultTokens]:
+  ):  # returns (new_prefix, result_tokens)
     """Wrapper for multi-sampling prefill for ahead-of-time compilation."""
     return self.prefill_multisampling(
         params=params,
@@ -672,7 +673,7 @@ def prefill_multisampling(
       topk: int | None = None,
       nucleus_topp: float | None = None,
       temperature: float | None = None,
-  ) -> tuple[Prefix, engine_api.ResultTokens]:
+  ):  # returns (new_prefix, result_tokens)
     """Public API for prefill multisampling."""
 
     # Sample rng before JIT call
@@ -709,7 +710,7 @@ def _prefill_multisampling_jit(
       topk: int | None = None,
       nucleus_topp: float | None = None,
       temperature: float | None = None,
-  ) -> tuple[Prefix, engine_api.ResultTokens]:
+  ) -> tuple[Prefix, Any]:
     """Computes a kv-cache for a new generate request.
 
     With multi-sampling, the engine will generate multiple first tokens in the
@@ -816,7 +817,7 @@ def prefill_concat(
       topk: int | None = None,
       nucleus_topp: float | None = None,
       temperature: float | None = None,
-  ) -> tuple[Any, PackedPrefix, list[engine_api.ResultTokens]]:
+  ):  # returns (maybe_batch, packed_prefix, list_of_result_tokens)
     """Computes a kv-cache for a new packed generate request, which is a
     concatenation of several shorter prompts. Experimentation shows that
     longer prefill sequences gives approximately 15% boost in time per prefilled
@@ -933,7 +934,7 @@ def generate(
       topk: int | None = None,
       nucleus_topp: float | None = None,
       temperature: float | None = None,
-  ) -> tuple[DecodeState, engine_api.ResultTokens]:
+  ):  # returns (decode_state, result_tokens)
     """Public API for generate that updates page state outside JIT."""
 
     # Update page state before JIT call
@@ -976,7 +977,7 @@ def _generate_jit(
       topk: int | None = None,
       nucleus_topp: float | None = None,
       temperature: float | None = None,
-  ) -> tuple[DecodeState, engine_api.ResultTokens]:
+  ):  # returns (decode_state, result_tokens)
     """Performs a single, JIT-compiled autoregressive decoding step.
 
     This function takes the current decoding state, which includes the KV cache
@@ -1497,8 +1498,19 @@ def get_prefix_destination_sharding(self) -> Any:
         "token_logp": self.replicated_sharding,
     }
 
-  def get_tokenizer(self) -> TokenizerParameters:
-    """Return a protobuf of tokenizer info, callable from Py or C++."""
+  def get_tokenizer(self) -> Any:
+    """Return tokenizer parameters; requires JetStream when decoupled.
+
+    When DECOUPLE_GCLOUD is FALSE we provide a clear error instead of failing
+    cryptically on attribute access.
+    """
+    token_params_is_stub = getattr(_token_params_ns, "_IS_STUB", False)
+    engine_api_is_stub = getattr(engine_api, "_IS_STUB", False)
+    if is_decoupled() and (token_params_is_stub or engine_api_is_stub):
+      raise RuntimeError(
+          "JetStream disabled by DECOUPLE_GCLOUD=TRUE or stubbed; get_tokenizer is unsupported. "
+          "Unset DECOUPLE_GCLOUD or install JetStream to enable tokenizer functionality."
+      )
     try:
       tokenizer_type_val = TokenizerType.DESCRIPTOR.values_by_name[self.config.tokenizer_type].number
       return TokenizerParameters(
@@ -1511,8 +1523,15 @@ def get_tokenizer(self) -> TokenizerParameters:
     except KeyError as _:
       raise KeyError(f"Unsupported tokenizer type: {self.config.tokenizer_type}") from None
 
-  def build_tokenizer(self, metadata: TokenizerParameters) -> tokenizer_api.Tokenizer:
+  def build_tokenizer(self, metadata: Any):  # return type depends on JetStream
     """Return a tokenizer"""
+    token_params_is_stub = getattr(_token_params_ns, "_IS_STUB", False)
+    engine_api_is_stub = getattr(engine_api, "_IS_STUB", False)
+    if is_decoupled() and (token_params_is_stub or engine_api_is_stub):
+      raise RuntimeError(
+          "JetStream disabled by DECOUPLE_GCLOUD=TRUE or stubbed; build_tokenizer is unsupported. "
+          "Unset DECOUPLE_GCLOUD or install JetStream to enable tokenizer functionality."
+      )
     if metadata.tokenizer_type == TokenizerType.tiktoken:
       return token_utils.TikToken(metadata)
     elif metadata.tokenizer_type == TokenizerType.sentencepiece:
```
