Skip to content

Commit 8bb73fc

Browse files
Merge pull request #3237 from AI-Hypercomputer:nicogrande/fix-pylint-new
PiperOrigin-RevId: 874862313
2 parents 42ec065 + bad4fc3 commit 8bb73fc

2 files changed

Lines changed: 29 additions & 56 deletions

File tree

src/maxtext/layers/decoders.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,11 +344,21 @@ def get_remat_policy(self):
344344
policy = self.minimal_policy()
345345
elif cfg.remat_policy == "minimal_with_quantization":
346346
if cfg.scan_layers:
347-
warnings.warn('Scan layers can introduce overhead to checkpointed values that in some configurations is slower than not checkpointing at all. If you are using scan layers, benchmark with and without quantization checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is beneficial for performance.')
347+
warnings.warn(
348+
"Scan layers can introduce overhead to checkpointed values that in some configurations is slower"
349+
"than not checkpointing at all. If you are using scan layers, benchmark with and without quantization "
350+
"checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is "
351+
"beneficial for performance."
352+
)
348353
policy = self.minimal_policy(with_context=False, with_quantization=True)
349354
elif cfg.remat_policy == "minimal_with_context_and_quantization":
350355
if cfg.scan_layers:
351-
warnings.warn('Scan layers can introduce overhead to checkpointed values that in some configurations is slower than not checkpointing at all. If you are using scan layers, benchmark with and without quantization checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is beneficial for performance.')
356+
warnings.warn(
357+
"Scan layers can introduce overhead to checkpointed values that in some configurations is slower"
358+
"than not checkpointing at all. If you are using scan layers, benchmark with and without quantization "
359+
"checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is "
360+
"beneficial for performance."
361+
)
352362
policy = self.minimal_policy(with_context=True, with_quantization=True)
353363
elif cfg.remat_policy == "save_dot_with_context_except_mlp":
354364
policy = jax.checkpoint_policies.save_only_these_names(

tests/integration/grpo_correctness.py

Lines changed: 17 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,8 @@ def setUp(self):
6060
self.rng = jax.random.PRNGKey(42)
6161
devices_array = maxtext_utils.create_device_mesh(self.cfg)
6262
mesh = Mesh(devices_array, self.cfg.mesh_axes)
63-
self.model = models.transformer_as_linen(
64-
config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN
65-
)
66-
self.state, _ = maxtext_utils.setup_decode_state(
67-
self.model, self.cfg, self.rng, mesh, None
68-
)
63+
self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN)
64+
self.state, _ = maxtext_utils.setup_decode_state(self.model, self.cfg, self.rng, mesh, None)
6965
self.tokenizer_model = transformers.AutoTokenizer.from_pretrained(
7066
"meta-llama/Llama-3.1-8B",
7167
add_bos_token=False,
@@ -104,16 +100,12 @@ def _prepare_maxtext_inputs(self):
104100
"""prepare maxtext inputs"""
105101
prompt = self.tokenizer_model.encode(self.input_str)
106102
input_ids = jnp.pad(
107-
jnp.tile(
108-
jnp.concat([jnp.array(prompt), jnp.array(prompt)], axis=-1), (4, 1)
109-
),
103+
jnp.tile(jnp.concat([jnp.array(prompt), jnp.array(prompt)], axis=-1), (4, 1)),
110104
((0, 0), (0, 4)),
111105
constant_values=0,
112106
) # pad some tokens at the end of input prompt
113107
input_segmentation = (input_ids > 0).astype(jnp.int32)
114-
input_position = jnp.where(
115-
input_segmentation, jnp.arange(input_segmentation.shape[1]), 0
116-
)
108+
input_position = jnp.where(input_segmentation, jnp.arange(input_segmentation.shape[1]), 0)
117109
completion_segmentation = jnp.tile(
118110
jnp.pad(
119111
jnp.array([0] * len(prompt) + [1] * len(prompt)),
@@ -129,12 +121,9 @@ def _prepare_maxtext_inputs(self):
129121
)
130122

131123
def _prepare_trl_inputs(self):
132-
tokenized_inputs = self.tokenizer_model(
133-
[self.input_str], return_tensors="pt"
134-
)
135-
input_ids = torch.cat(
136-
(tokenized_inputs["input_ids"], tokenized_inputs["input_ids"]), axis=-1
137-
)
124+
"""Prepare TRL inputs."""
125+
tokenized_inputs = self.tokenizer_model([self.input_str], return_tensors="pt")
126+
input_ids = torch.cat((tokenized_inputs["input_ids"], tokenized_inputs["input_ids"]), axis=-1)
138127
attention_mask = torch.cat(
139128
(
140129
tokenized_inputs["attention_mask"],
@@ -147,9 +136,7 @@ def _prepare_trl_inputs(self):
147136

148137
def test_logits(self):
149138
def _prepare_inputs():
150-
input_ids = jnp.tile(
151-
jnp.array(self.tokenizer_model.encode(self.input_str)), (4, 1)
152-
)
139+
input_ids = jnp.tile(jnp.array(self.tokenizer_model.encode(self.input_str)), (4, 1))
153140
input_segmentation = (input_ids > 0).astype(jnp.int32)
154141
input_position = jnp.tile(jnp.arange(input_ids.shape[1]), (4, 1))
155142

@@ -175,17 +162,11 @@ def _prepare_inputs():
175162
.numpy()
176163
)
177164
print(f"Max Diff {np.max(np.abs(logits - hf_logits))}")
178-
self.assertTrue(
179-
jax.numpy.allclose(
180-
hf_logits, logits, rtol=1e-2, atol=2e-1, equal_nan=False
181-
)
182-
)
165+
self.assertTrue(jax.numpy.allclose(hf_logits, logits, rtol=1e-2, atol=2e-1, equal_nan=False))
183166

184167
def test_logps(self):
185168

186-
input_ids, input_segmentation, input_position, completion_segmentation = (
187-
self._prepare_maxtext_inputs()
188-
)
169+
input_ids, input_segmentation, input_position, completion_segmentation = self._prepare_maxtext_inputs()
189170
maxtext_per_token_logps, _ = compute_log_probs(
190171
self.model,
191172
self.state.params,
@@ -202,12 +183,7 @@ def test_logps(self):
202183

203184
print(
204185
"Max Diff",
205-
np.max(
206-
np.abs(
207-
np.trim_zeros(np.asarray(maxtext_per_token_logps)[0])
208-
- hf_per_token_logps.detach().numpy()[0]
209-
)
210-
),
186+
np.max(np.abs(np.trim_zeros(np.asarray(maxtext_per_token_logps)[0]) - hf_per_token_logps.detach().numpy()[0])),
211187
)
212188
self.assertTrue(
213189
jax.numpy.allclose(
@@ -228,27 +204,16 @@ def test_loss_kl_div(self):
228204

229205
completions = [{"prompt": self.input_str}] * 4
230206
rewards = torch.tensor(
231-
[
232-
self.trainer.reward_funcs[0](completion)
233-
for completion in completions
234-
],
207+
[self.trainer.reward_funcs[0](completion) for completion in completions],
235208
dtype=torch.float32,
236209
)
237210
# Compute grouped-wise rewards
238-
mean_grouped_rewards = rewards.view(-1, self.trainer.num_generations).mean(
239-
dim=1
240-
)
241-
std_grouped_rewards = rewards.view(-1, self.trainer.num_generations).std(
242-
dim=1
243-
)
211+
mean_grouped_rewards = rewards.view(-1, self.trainer.num_generations).mean(dim=1)
212+
std_grouped_rewards = rewards.view(-1, self.trainer.num_generations).std(dim=1)
244213

245214
# Normalize the rewards to compute the advantages
246-
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(
247-
self.trainer.num_generations, dim=0
248-
)
249-
std_grouped_rewards = std_grouped_rewards.repeat_interleave(
250-
self.trainer.num_generations, dim=0
251-
)
215+
mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.trainer.num_generations, dim=0)
216+
std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.trainer.num_generations, dim=0)
252217
# since we are using the same completion, so advantages = 0 for every sequence
253218
# but we can keep it this way since our on-policy implementation
254219
# gets average advantage which becomes zero anyway
@@ -273,9 +238,7 @@ def test_loss_kl_div(self):
273238

274239
self.trainer._get_per_token_logps(self.hf_model, hf_input_ids, attention_mask, logits_to_keep) # pylint: disable=protected-access
275240

276-
input_ids, input_segmentation, input_position, completion_segmentation = (
277-
self._prepare_maxtext_inputs()
278-
)
241+
input_ids, input_segmentation, input_position, completion_segmentation = self._prepare_maxtext_inputs()
279242
maxtext_per_token_logps, _ = compute_log_probs(
280243
self.model,
281244
self.state.params,

0 commit comments

Comments
 (0)