2424size that does not cause an out-of-memory (OOM) error.
2525
2626The key functions in this script are:
27- - `is_oom`: Checks if a given configuration results in an OOM error.
28- - `largest_batch_size`: Finds the largest batch size for a given policy.
29- - `search`: The main algorithm that iterates through policies and batch sizes.
27+ - `` is_oom` `: Checks if a given configuration results in an OOM error.
28+ - `` largest_batch_size` `: Finds the largest batch size for a given policy.
29+ - `` search` `: The main algorithm that iterates through policies and batch sizes.
3030
3131By automating this search, the script helps to efficiently find the most
3232performant and memory-efficient training configurations.
@@ -141,8 +141,8 @@ def next_policy(policy: dict) -> dict[str, str] | None:
141141 Generates the next rematerialization policy in the sequence.
142142
143143 This function iterates through the policy and changes the first tensor it
144- finds with a ' device' value to ' offload' , or the first ' offload' to 'remat'.
145- If all tensors are already set to ' remat' , it returns None.
144+ finds with a `` device`` value to `` offload`` , or the first `` offload`` to
145+ ``remat``. If all tensors are already set to `` remat`` , it returns None.
146146
147147 Args:
148148 policy: The current policy dictionary.
@@ -166,18 +166,20 @@ def next_policy(policy: dict) -> dict[str, str] | None:
166166
167167def largest_batch_size (base_argv , policy , min_pdb , max_pdb = 64 ) -> int :
168168 """
169- Finds the largest possible per_device_batch_size (pdb) that does not cause an OOM error.
169+ Finds the largest possible ``per_device_batch_size`` (pdb) that does not cause
170+ an OOM error.
170171
171172 This function uses a binary search algorithm within the provided min and max
172173 range to efficiently find the optimal batch size.
173174
174175 Args:
175176 policy: The rematerialization policy dictionary.
176- min_pdb: The minimum per_device_batch_size to test.
177- max_pdb: The maximum per_device_batch_size to test.
177+ min_pdb: The minimum `` per_device_batch_size`` to test.
178+ max_pdb: The maximum `` per_device_batch_size`` to test.
178179
179180 Returns:
180- The largest per_device_batch_size within the range that does not result in an OOM error.
181+ The largest ``per_device_batch_size`` within the range that does not result
182+ in an OOM error.
181183 """
182184 print (f"Starting binary search for the largest batch size between { min_pdb } and { max_pdb } ." )
183185
@@ -263,7 +265,7 @@ def search_policy_only(
263265 base_argv: The base command-line arguments.
264266 pdb: The fixed per-device batch size to test against.
265267 init_policy: The policy to start searching from. If None, defaults to
266- ' full_device_policy' (no remat).
268+ `` full_device_policy`` (no remat).
267269
268270 Returns:
269271 The first rematerialization policy that did *not* OOM.
@@ -300,7 +302,7 @@ def search(
300302
301303 Args:
302304 config: The model configuration.
303- max_pdb: The maximum per_device_batch_size to test.
305+ max_pdb: The maximum `` per_device_batch_size`` to test.
304306
305307 Returns:
306308 A list of tuples, where each tuple contains a batch size and its
@@ -341,12 +343,13 @@ def get_parameter_value(config_tuple, prefix):
341343
342344 Args:
343345 config_tuple: A tuple of strings to search.
344- prefix: The prefix string to look for (e.g., ' key=' ).
346+ prefix: The prefix string to look for (e.g., `` key=`` ).
345347
346348 Returns:
347349 A tuple of (bool, str or None).
348- - (True, value) if the prefix is found.
349- - (False, None) if the prefix is not found.
350+
351+ * ``(True, value)`` if the prefix is found.
352+ * ``(False, None)`` if the prefix is not found.
350353 """
351354 for item in config_tuple :
352355 if item .startswith (prefix ):
@@ -361,15 +364,16 @@ def get_parameter_value(config_tuple, prefix):
361364
362365def find_batch_size (base_argv ):
363366 """
364- Parses the base arguments to find the ' per_device_batch_size' .
367+ Parses the base arguments to find the `` per_device_batch_size`` .
365368
366369 Args:
367- base_argv: The tuple of command-line arguments.
370+ base_argv: The tuple of command-line arguments.
368371
369372 Returns:
370- A tuple of (bool, int or None):
371- - (True, batch_size) if 'per_device_batch_size=...' was found.
372- - (False, None) if it was not found.
373+ A tuple of (bool, int or None)
374+
375+ * ``(True, batch_size)`` if ``per_device_batch_size=...`` was found.
376+ * ``(False, None)`` if it was not found.
373377 """
374378 pdb_provided , pdb_str = get_parameter_value (base_argv , prefix = "per_device_batch_size=" )
375379
@@ -384,10 +388,10 @@ def find_remat_policy_tensor_names(base_argv):
384388 to be considered for rematerialization.
385389
386390 Args:
387- base_argv: The tuple of command-line arguments.
391+ base_argv: The tuple of command-line arguments.
388392
389393 Returns:
390- A list of tensor names that were passed as flags.
394+ A list of tensor names that were passed as flags.
391395 """
392396 full_tensor_list = [
393397 "context" ,
0 commit comments