Commit 498d482

Fix formatting in docstrings for src/MaxText
1 parent 5b0bb89 commit 498d482

5 files changed

Lines changed: 43 additions & 33 deletions


src/MaxText/estimator.py

Lines changed: 25 additions & 21 deletions
@@ -24,9 +24,9 @@
 size that does not cause an out-of-memory (OOM) error.
 
 The key functions in this script are:
-- `is_oom`: Checks if a given configuration results in an OOM error.
-- `largest_batch_size`: Finds the largest batch size for a given policy.
-- `search`: The main algorithm that iterates through policies and batch sizes.
+- ``is_oom``: Checks if a given configuration results in an OOM error.
+- ``largest_batch_size``: Finds the largest batch size for a given policy.
+- ``search``: The main algorithm that iterates through policies and batch sizes.
 
 By automating this search, the script helps to efficiently find the most
 performant and memory-efficient training configurations.
@@ -141,8 +141,8 @@ def next_policy(policy: dict) -> dict[str, str] | None:
 Generates the next rematerialization policy in the sequence.
 
 This function iterates through the policy and changes the first tensor it
-finds with a 'device' value to 'offload', or the first 'offload' to 'remat'.
-If all tensors are already set to 'remat', it returns None.
+finds with a ``device`` value to ``offload``, or the first ``offload`` to
+``remat``. If all tensors are already set to ``remat``, it returns None.
 
 Args:
 policy: The current policy dictionary.
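
For context on the ``next_policy`` behaviour described in this docstring, here is a minimal self-contained sketch of one plausible reading (this is not the MaxText implementation, and any tensor name other than ``context`` is made up):

```python
def next_policy_sketch(policy: dict[str, str]) -> dict[str, str] | None:
    """Advance a remat policy one step: promote the first 'device' tensor to
    'offload'; if none remain, promote the first 'offload' tensor to 'remat'.
    Returns None once every tensor is already 'remat'."""
    new_policy = dict(policy)
    for value_from, value_to in (("device", "offload"), ("offload", "remat")):
        for tensor, placement in new_policy.items():
            if placement == value_from:
                new_policy[tensor] = value_to
                return new_policy
    return None  # everything is already rematerialized


# Example walk over a toy two-tensor policy ("out_proj" is a made-up name):
p = {"context": "device", "out_proj": "device"}
while p is not None:
    print(p)
    p = next_policy_sketch(p)
```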
@@ -166,18 +166,20 @@ def next_policy(policy: dict) -> dict[str, str] | None:
 
 def largest_batch_size(base_argv, policy, min_pdb, max_pdb=64) -> int:
 """
-Finds the largest possible per_device_batch_size (pdb) that does not cause an OOM error.
+Finds the largest possible ``per_device_batch_size`` (pdb) that does not cause
+an OOM error.
 
 This function uses a binary search algorithm within the provided min and max
 range to efficiently find the optimal batch size.
 
 Args:
 policy: The rematerialization policy dictionary.
-min_pdb: The minimum per_device_batch_size to test.
-max_pdb: The maximum per_device_batch_size to test.
+min_pdb: The minimum ``per_device_batch_size`` to test.
+max_pdb: The maximum ``per_device_batch_size`` to test.
 
 Returns:
-The largest per_device_batch_size within the range that does not result in an OOM error.
+The largest ``per_device_batch_size`` within the range that does not result
+in an OOM error.
 """
 print(f"Starting binary search for the largest batch size between {min_pdb} and {max_pdb}.")
 
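The binary search this docstring refers to can be illustrated with a small stand-alone sketch; the ``fits`` predicate stands in for an actual OOM check and is an assumption of this example, not MaxText code:

```python
def largest_batch_size_sketch(fits, min_pdb: int = 1, max_pdb: int = 64) -> int:
    """Binary-search the largest integer batch size in [min_pdb, max_pdb] for
    which fits(pdb) is True (i.e. the run does not OOM).
    Returns min_pdb - 1 if even min_pdb does not fit."""
    lo, hi = min_pdb, max_pdb
    best = min_pdb - 1
    while lo <= hi:
        mid = (lo + hi) // 2
        if fits(mid):
            best = mid        # mid fits; try something larger
            lo = mid + 1
        else:
            hi = mid - 1      # mid OOMs; try something smaller
    return best


# Toy stand-in for "does this batch size fit?": pretend anything up to 13 fits.
# The real estimator would launch a run/compile and check for an OOM error.
print(largest_batch_size_sketch(lambda pdb: pdb <= 13, 1, 64))  # -> 13
```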

@@ -263,7 +265,7 @@ def search_policy_only(
 base_argv: The base command-line arguments.
 pdb: The fixed per-device batch size to test against.
 init_policy: The policy to start searching from. If None, defaults to
-'full_device_policy' (no remat).
+``full_device_policy`` (no remat).
 
 Returns:
 The first rematerialization policy that did *not* OOM.
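
``search_policy_only`` is documented as returning the first policy that does not OOM at a fixed batch size. A minimal sketch of that first-fit loop, with a toy ``fits`` predicate standing in for a real OOM probe:

```python
def search_policy_only_sketch(policies, fits):
    """Return the first policy (in order of increasing rematerialization)
    for which fits(policy) is True, i.e. the first one that does not OOM."""
    for policy in policies:
        if fits(policy):
            return policy
    return None  # no policy fits at this batch size


# Toy example: only the most aggressive placement "fits".
candidates = [{"context": "device"}, {"context": "offload"}, {"context": "remat"}]
print(search_policy_only_sketch(candidates, lambda p: p["context"] == "remat"))
```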
@@ -300,7 +302,7 @@ def search(
 
 Args:
 config: The model configuration.
-max_pdb: The maximum per_device_batch_size to test.
+max_pdb: The maximum ``per_device_batch_size`` to test.
 
 Returns:
 A list of tuples, where each tuple contains a batch size and its
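
Putting the pieces together, ``search`` is described as pairing each candidate policy with a batch size. A rough, self-contained sketch of that pairing loop (the capacity numbers and the exact pairing rule are invented for illustration, not taken from estimator.py):

```python
def search_sketch(policies, largest_pdb_for):
    """For each candidate policy, record the largest batch size that fits,
    returning a list of (batch_size, policy) tuples."""
    results = []
    for policy in policies:
        pdb = largest_pdb_for(policy)
        if pdb >= 1:
            results.append((pdb, policy))
    return results


# Toy run: pretend more rematerialization buys a bigger batch size.
toy_policies = [{"context": "device"}, {"context": "offload"}, {"context": "remat"}]
toy_capacity = {"device": 4, "offload": 8, "remat": 16}
print(search_sketch(toy_policies, lambda p: toy_capacity[p["context"]]))
```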
@@ -341,12 +343,13 @@ def get_parameter_value(config_tuple, prefix):
 
 Args:
 config_tuple: A tuple of strings to search.
-prefix: The prefix string to look for (e.g., 'key=').
+prefix: The prefix string to look for (e.g., ``key=``).
 
 Returns:
 A tuple of (bool, str or None).
-- (True, value) if the prefix is found.
-- (False, None) if the prefix is not found.
+
+* ``(True, value)`` if the prefix is found.
+* ``(False, None)`` if the prefix is not found.
 """
 for item in config_tuple:
 if item.startswith(prefix):
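
The ``(found, value)`` convention documented here is easy to mirror in a stand-alone sketch; the example argument strings are illustrative, not taken from a real run:

```python
def get_parameter_value_sketch(config_tuple, prefix):
    """Scan a tuple of 'key=value' strings for one starting with `prefix`;
    return (True, value) on a hit, (False, None) otherwise."""
    for item in config_tuple:
        if item.startswith(prefix):
            return True, item[len(prefix):]
    return False, None


argv = ("run_name=test", "per_device_batch_size=4")
print(get_parameter_value_sketch(argv, "key="))                    # (False, None)
print(get_parameter_value_sketch(argv, "per_device_batch_size="))  # (True, '4')
```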
@@ -361,15 +364,16 @@ def get_parameter_value(config_tuple, prefix):
 
 def find_batch_size(base_argv):
 """
-Parses the base arguments to find the 'per_device_batch_size'.
+Parses the base arguments to find the ``per_device_batch_size``.
 
 Args:
-base_argv: The tuple of command-line arguments.
+base_argv: The tuple of command-line arguments.
 
 Returns:
-A tuple of (bool, int or None):
-- (True, batch_size) if 'per_device_batch_size=...' was found.
-- (False, None) if it was not found.
+A tuple of (bool, int or None)
+
+* ``(True, batch_size)`` if ``per_device_batch_size=...`` was found.
+* ``(False, None)`` if it was not found.
 """
 pdb_provided, pdb_str = get_parameter_value(base_argv, prefix="per_device_batch_size=")
 
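``find_batch_size`` is described as a thin wrapper that extracts and converts the batch size. A hedged sketch, assuming the value may be written as an int or a float on the command line (the conversion choice is this example's, not necessarily MaxText's):

```python
def find_batch_size_sketch(base_argv):
    """Return (True, batch_size) if 'per_device_batch_size=...' appears in the
    argument tuple, else (False, None)."""
    for item in base_argv:
        if item.startswith("per_device_batch_size="):
            value = item.split("=", 1)[1]
            return True, int(float(value))  # tolerate "8" as well as "8.0"
    return False, None


print(find_batch_size_sketch(("run_name=test", "per_device_batch_size=8")))  # (True, 8)
print(find_batch_size_sketch(("run_name=test",)))                            # (False, None)
```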

@@ -384,10 +388,10 @@ def find_remat_policy_tensor_names(base_argv):
 to be considered for rematerialization.
 
 Args:
-base_argv: The tuple of command-line arguments.
+base_argv: The tuple of command-line arguments.
 
 Returns:
-A list of tensor names that were passed as flags.
+A list of tensor names that were passed as flags.
 """
 full_tensor_list = [
 "context",

src/MaxText/generate_param_only_checkpoint.py

Lines changed: 8 additions & 6 deletions
@@ -14,11 +14,12 @@
 
 # pylint: disable=g-bad-todo, abstract-method, consider-using-with
 """Transforms a "full state" including optimizer state to a bfloat16 "parameter state" without optimizer state.
-This typically used for turning a state output by training.py into a state than can be consumed by decode.py.
 
-The input "fullstate" is passed in via:
-load_full_state_path.
-The output "parameter state" is output to the checkpoint directory. Additionally it is cast down to bf16.
+This typically used for turning a state output by training.py into a state than can be consumed by decode.py.
+
+The input "fullstate" is passed in via ``load_full_state_path``.
+
+The output "parameter state" is output to the checkpoint directory. Additionally it is cast down to bf16.
 """
 
 import os.path
@@ -155,8 +156,9 @@ def _save_decode_checkpoint(config, state, checkpoint_manager):
 def generate_decode_checkpoint(config):
 """
 Generate an decode checkpoint from a given training checkpoint.
-- Training checkpoint is loaded from config.load_full_state_path.
-- Inference checkpoint will be saved at the config's checkpoint directory.
+
+* Training checkpoint is loaded from config.load_full_state_path.
+* Inference checkpoint will be saved at the config's checkpoint directory.
 """
 
 devices_array = maxtext_utils.create_device_mesh(config)
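
The docstring above says the saved parameter state is cast down to bf16. A minimal JAX sketch of just that cast (not the actual ``_save_decode_checkpoint`` path; the toy parameter tree is invented for illustration):

```python
import jax
import jax.numpy as jnp


def to_bf16_params_sketch(params):
    """Cast every floating-point leaf of a parameter pytree to bfloat16,
    leaving non-float leaves (e.g. integer counters) untouched."""
    def cast(leaf):
        if jnp.issubdtype(leaf.dtype, jnp.floating):
            return leaf.astype(jnp.bfloat16)
        return leaf
    return jax.tree_util.tree_map(cast, params)


# Toy tree standing in for the params of a full training state.
params = {"dense": {"kernel": jnp.ones((4, 4), jnp.float32)}, "step": jnp.array(10)}
print(jax.tree_util.tree_map(lambda x: x.dtype, to_bf16_params_sketch(params)))
```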

src/MaxText/pyconfig_deprecated.py

Lines changed: 8 additions & 4 deletions
@@ -1285,12 +1285,16 @@ def validate_and_update_keys(raw_keys, model_keys, config_name: str):
 
 def get_individual_scales(scale):
 """Choose appropriate scales for individual dimensions based on global scale
+
 We choose to rotate between doubling:
-num_head and mlp_dim
-embed_dim
-num_layers
+
+* ``num_head`` and ``mlp_dim``
+* ``embed_dim``
+* ``num_layers``
+
 Any one of these steps is not a perfect doubling, although going through a cycle
-of three is a near perfect 8x scaling except for the linear -> softmax -> output step"""
+of three is a near perfect 8x scaling except for the linear -> softmax -> output step
+"""
 
 log_2_scale = math.floor((math.log2(scale)))
 if 2**log_2_scale != scale:
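
One way to read the rotation described in this docstring, as a self-contained sketch (the return convention of per-dimension multipliers, and how leftover doublings are assigned, are assumptions of this example rather than the actual ``pyconfig_deprecated`` code):

```python
import math


def get_individual_scales_sketch(scale: int):
    """Split a global power-of-two scale into per-dimension doublings, rotating
    through (num_head & mlp_dim) -> embed_dim -> num_layers."""
    log_2_scale = math.floor(math.log2(scale))
    if 2**log_2_scale != scale:
        raise ValueError("global scale should be a power of 2")
    base, remainder = divmod(log_2_scale, 3)
    # A full cycle of three doublings contributes one doubling to each slot;
    # leftover doublings go to the earliest slots in the rotation (assumed).
    num_head_scale = mlp_dim_scale = base + (1 if remainder >= 1 else 0)
    embed_scale = base + (1 if remainder >= 2 else 0)
    layer_scale = base
    return 2**num_head_scale, 2**mlp_dim_scale, 2**embed_scale, 2**layer_scale


print(get_individual_scales_sketch(8))   # one full cycle -> (2, 2, 2, 2)
print(get_individual_scales_sketch(16))  # extra doubling lands on num_head & mlp_dim
```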

src/MaxText/sft/sft_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Shim for SFT Trainer in `src/maxtext/trainers/post_train/sft`."""
+"""Shim for SFT Trainer in ``src/maxtext/trainers/post_train/sft``."""
 
 import sys
 import importlib

src/MaxText/train_tokenizer.py

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-"""Shim for `train_tokenizer` in `src/maxtext/trainers/tokenizer`."""
+"""Shim for ``train_tokenizer`` in ``src/maxtext/trainers/tokenizer``."""
 
 from absl import logging
 