
Commit bad4fc3

adding misc lint fixes.
1 parent f1fc688 commit bad4fc3

2 files changed

Lines changed: 29 additions & 56 deletions


src/maxtext/layers/decoders.py

Lines changed: 12 additions & 2 deletions
@@ -319,11 +319,21 @@ def get_remat_policy(self):
       policy = self.minimal_policy()
     elif cfg.remat_policy == "minimal_with_quantization":
       if cfg.scan_layers:
-        warnings.warn('Scan layers can introduce overhead to checkpointed values that in some configurations is slower than not checkpointing at all. If you are using scan layers, benchmark with and without quantization checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is beneficial for performance.')
+        warnings.warn(
+            "Scan layers can introduce overhead to checkpointed values that in some configurations is slower "
+            "than not checkpointing at all. If you are using scan layers, benchmark with and without quantization "
+            "checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is "
+            "beneficial for performance."
+        )
       policy = self.minimal_policy(with_context=False, with_quantization=True)
     elif cfg.remat_policy == "minimal_with_context_and_quantization":
       if cfg.scan_layers:
-        warnings.warn('Scan layers can introduce overhead to checkpointed values that in some configurations is slower than not checkpointing at all. If you are using scan layers, benchmark with and without quantization checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is beneficial for performance.')
+        warnings.warn(
+            "Scan layers can introduce overhead to checkpointed values that in some configurations is slower "
+            "than not checkpointing at all. If you are using scan layers, benchmark with and without quantization "
+            "checkpointing in your workflow to see which is faster. Without scan layers, checkpointing quantizations is "
+            "beneficial for performance."
+        )
       policy = self.minimal_policy(with_context=True, with_quantization=True)
     elif cfg.remat_policy == "save_dot_with_context_except_mlp":
       policy = jax.checkpoint_policies.save_only_these_names(
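For readers skimming the hunk above: `jax.checkpoint_policies.save_only_these_names` is the names-based JAX rematerialization policy these branches build on. A minimal, self-contained sketch of the mechanism (invented layer, tag name, and shapes; not MaxText code):

import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

def layer(x, w):
  h = jnp.dot(x, w)
  # Tag this intermediate; the policy below saves it during the forward pass
  # instead of recomputing it on the backward pass.
  h = checkpoint_name(h, "quantized_matmul")
  return jnp.tanh(h)

policy = jax.checkpoint_policies.save_only_these_names("quantized_matmul")
remat_layer = jax.checkpoint(layer, policy=policy)

x, w = jnp.ones((4, 8)), jnp.ones((8, 8))
grads = jax.grad(lambda w: remat_layer(x, w).sum())(w)
print(grads.shape)  # (8, 8)

Under `jax.lax.scan`, saved values like this are stacked across iterations, which is the overhead the new warning asks users to benchmark.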

tests/integration/grpo_correctness.py

Lines changed: 17 additions & 54 deletions
@@ -60,12 +60,8 @@ def setUp(self):
     self.rng = jax.random.PRNGKey(42)
     devices_array = maxtext_utils.create_device_mesh(self.cfg)
     mesh = Mesh(devices_array, self.cfg.mesh_axes)
-    self.model = models.transformer_as_linen(
-        config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN
-    )
-    self.state, _ = maxtext_utils.setup_decode_state(
-        self.model, self.cfg, self.rng, mesh, None
-    )
+    self.model = models.transformer_as_linen(config=self.cfg, mesh=mesh, quant=None, model_mode=MODEL_MODE_TRAIN)
+    self.state, _ = maxtext_utils.setup_decode_state(self.model, self.cfg, self.rng, mesh, None)
     self.tokenizer_model = transformers.AutoTokenizer.from_pretrained(
         "meta-llama/Llama-3.1-8B",
         add_bos_token=False,
@@ -104,16 +100,12 @@ def _prepare_maxtext_inputs(self):
     """Prepare MaxText inputs."""
     prompt = self.tokenizer_model.encode(self.input_str)
     input_ids = jnp.pad(
-        jnp.tile(
-            jnp.concat([jnp.array(prompt), jnp.array(prompt)], axis=-1), (4, 1)
-        ),
+        jnp.tile(jnp.concat([jnp.array(prompt), jnp.array(prompt)], axis=-1), (4, 1)),
         ((0, 0), (0, 4)),
         constant_values=0,
     )  # pad some tokens at the end of the input prompt
     input_segmentation = (input_ids > 0).astype(jnp.int32)
-    input_position = jnp.where(
-        input_segmentation, jnp.arange(input_segmentation.shape[1]), 0
-    )
+    input_position = jnp.where(input_segmentation, jnp.arange(input_segmentation.shape[1]), 0)
     completion_segmentation = jnp.tile(
         jnp.pad(
             jnp.array([0] * len(prompt) + [1] * len(prompt)),
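The reflowed `_prepare_maxtext_inputs` logic is easy to sanity-check with a toy prompt; the standalone sketch below (hypothetical token ids, not the test's real tokenizer output) shows what the padded ids, segmentation, and positions look like:

import jax.numpy as jnp

prompt = [11, 12, 13]  # stand-in for tokenizer_model.encode(...)
input_ids = jnp.pad(
    jnp.tile(jnp.concat([jnp.array(prompt), jnp.array(prompt)], axis=-1), (4, 1)),
    ((0, 0), (0, 4)),
    constant_values=0,
)  # prompt repeated twice, then 4 pad tokens
input_segmentation = (input_ids > 0).astype(jnp.int32)  # 1 on real tokens, 0 on pad
input_position = jnp.where(input_segmentation, jnp.arange(input_ids.shape[1]), 0)
print(input_ids[0])           # [11 12 13 11 12 13  0  0  0  0]
print(input_segmentation[0])  # [1 1 1 1 1 1 0 0 0 0]
print(input_position[0])      # [0 1 2 3 4 5 0 0 0 0]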
@@ -129,12 +121,9 @@ def _prepare_maxtext_inputs(self):
     )

   def _prepare_trl_inputs(self):
-    tokenized_inputs = self.tokenizer_model(
-        [self.input_str], return_tensors="pt"
-    )
-    input_ids = torch.cat(
-        (tokenized_inputs["input_ids"], tokenized_inputs["input_ids"]), axis=-1
-    )
+    """Prepare TRL inputs."""
+    tokenized_inputs = self.tokenizer_model([self.input_str], return_tensors="pt")
+    input_ids = torch.cat((tokenized_inputs["input_ids"], tokenized_inputs["input_ids"]), axis=-1)
     attention_mask = torch.cat(
         (
             tokenized_inputs["attention_mask"],
@@ -147,9 +136,7 @@ def _prepare_trl_inputs(self):

   def test_logits(self):
     def _prepare_inputs():
-      input_ids = jnp.tile(
-          jnp.array(self.tokenizer_model.encode(self.input_str)), (4, 1)
-      )
+      input_ids = jnp.tile(jnp.array(self.tokenizer_model.encode(self.input_str)), (4, 1))
       input_segmentation = (input_ids > 0).astype(jnp.int32)
       input_position = jnp.tile(jnp.arange(input_ids.shape[1]), (4, 1))

@@ -175,17 +162,11 @@ def _prepare_inputs():
         .numpy()
     )
     print(f"Max Diff {np.max(np.abs(logits - hf_logits))}")
-    self.assertTrue(
-        jax.numpy.allclose(
-            hf_logits, logits, rtol=1e-2, atol=2e-1, equal_nan=False
-        )
-    )
+    self.assertTrue(jax.numpy.allclose(hf_logits, logits, rtol=1e-2, atol=2e-1, equal_nan=False))

   def test_logps(self):

-    input_ids, input_segmentation, input_position, completion_segmentation = (
-        self._prepare_maxtext_inputs()
-    )
+    input_ids, input_segmentation, input_position, completion_segmentation = self._prepare_maxtext_inputs()
     maxtext_per_token_logps, _ = compute_log_probs(
         self.model,
         self.state.params,
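For context on the quantities being compared in this test: a per-token log-probability is the log-softmax of the model's logits gathered at each realized token id. A generic sketch with assumed shapes (this is not the signature of MaxText's compute_log_probs or TRL's _get_per_token_logps):

import jax
import jax.numpy as jnp

def per_token_logps(logits, targets):
  # logits: [batch, seq, vocab]; targets: [batch, seq] token ids
  logps = jax.nn.log_softmax(logits, axis=-1)
  return jnp.take_along_axis(logps, targets[..., None], axis=-1)[..., 0]

logits = jax.random.normal(jax.random.PRNGKey(0), (4, 10, 32))
targets = jnp.zeros((4, 10), dtype=jnp.int32)
print(per_token_logps(logits, targets).shape)  # (4, 10)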
@@ -202,12 +183,7 @@ def test_logps(self):

     print(
         "Max Diff",
-        np.max(
-            np.abs(
-                np.trim_zeros(np.asarray(maxtext_per_token_logps)[0])
-                - hf_per_token_logps.detach().numpy()[0]
-            )
-        ),
+        np.max(np.abs(np.trim_zeros(np.asarray(maxtext_per_token_logps)[0]) - hf_per_token_logps.detach().numpy()[0])),
     )
     self.assertTrue(
         jax.numpy.allclose(
@@ -228,27 +204,16 @@ def test_loss_kl_div(self):

     completions = [{"prompt": self.input_str}] * 4
     rewards = torch.tensor(
-        [
-            self.trainer.reward_funcs[0](completion)
-            for completion in completions
-        ],
+        [self.trainer.reward_funcs[0](completion) for completion in completions],
         dtype=torch.float32,
     )
     # Compute group-wise rewards
-    mean_grouped_rewards = rewards.view(-1, self.trainer.num_generations).mean(
-        dim=1
-    )
-    std_grouped_rewards = rewards.view(-1, self.trainer.num_generations).std(
-        dim=1
-    )
+    mean_grouped_rewards = rewards.view(-1, self.trainer.num_generations).mean(dim=1)
+    std_grouped_rewards = rewards.view(-1, self.trainer.num_generations).std(dim=1)

     # Normalize the rewards to compute the advantages
-    mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(
-        self.trainer.num_generations, dim=0
-    )
-    std_grouped_rewards = std_grouped_rewards.repeat_interleave(
-        self.trainer.num_generations, dim=0
-    )
+    mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.trainer.num_generations, dim=0)
+    std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.trainer.num_generations, dim=0)
     # since we are using the same completion, advantages = 0 for every sequence,
     # but we can keep it this way since our on-policy implementation
     # gets the average advantage, which becomes zero anyway
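The grouped mean/std above are the ingredients of the GRPO advantage. A small sketch (TRL-style normalization; the 1e-4 epsilon is an assumption of this sketch) confirms the comment that identical completions yield zero advantages:

import torch

num_generations = 4
rewards = torch.tensor([1.0, 1.0, 1.0, 1.0])  # same completion -> same reward

mean_g = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=0)
std_g = rewards.view(-1, num_generations).std(dim=1).repeat_interleave(num_generations, dim=0)

advantages = (rewards - mean_g) / (std_g + 1e-4)  # normalize within each group
print(advantages)  # tensor([0., 0., 0., 0.])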
@@ -273,9 +238,7 @@ def test_loss_kl_div(self):

     self.trainer._get_per_token_logps(self.hf_model, hf_input_ids, attention_mask, logits_to_keep)  # pylint: disable=protected-access

-    input_ids, input_segmentation, input_position, completion_segmentation = (
-        self._prepare_maxtext_inputs()
-    )
+    input_ids, input_segmentation, input_position, completion_segmentation = self._prepare_maxtext_inputs()
     maxtext_per_token_logps, _ = compute_log_probs(
         self.model,
         self.state.params,
