This commit is contained in:
Dan Saunders
2025-04-18 21:35:33 +00:00
parent 4f2d092216
commit 7f4e4076e1


@@ -617,9 +617,6 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
prompt_ids = prompt_ids[:, -self.max_prompt_length :]
prompt_mask = prompt_mask[:, -self.max_prompt_length :]
# Generate completions using either vLLM or regular generation
if self.args.use_vllm:
# First, have main process load weights if needed
@@ -627,6 +624,7 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
self._move_model_to_vllm()
self._last_loaded_step = self.state.global_step
all_prompts_text = gather_object(prompts_text)
if self.accelerator.is_main_process:
# Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
# num_generations outputs for each one. This is faster than generating outputs for each duplicate
@@ -648,7 +646,9 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
guided_decoding_regex=self.guided_decoding_regex,
)
else:
completion_ids = [None] * (
len(all_prompts_text) // self.args.sequence_parallel_degree
)
# Broadcast the completions from the main process to all processes
completion_ids = broadcast_object_list(completion_ids, from_process=0)
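# In miniature, the generate-and-share pattern above (a sketch, not the exact
# calls used here; `llm`, `unique`, and `expected_len` are illustrative names):
#   all_prompts = gather_object(prompts_text)          # every rank contributes its prompts
#   if accelerator.is_main_process:
#       outputs = llm.generate(unique(all_prompts), sampling_params)  # vLLM runs once, on rank 0
#       completion_ids = [out.token_ids for out in outputs]
#   else:
#       completion_ids = [None] * expected_len         # placeholder of matching length
#   completion_ids = broadcast_object_list(completion_ids, from_process=0)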
@@ -719,6 +719,75 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
) # we only need to compute the logits for the completion tokens
with torch.no_grad():
if self.args.sequence_parallel_degree > 1:
# Pad sequence to be divisible by SP degree if needed
total_seq_len = prompt_completion_ids.shape[1]
if total_seq_len % self.local_world_size != 0:
pad_len = self.local_world_size - (
total_seq_len % self.local_world_size
)
pad_token_id = self.processing_class.pad_token_id or 0
# Pad input_ids and attention_mask
padding = torch.full(
(prompt_completion_ids.shape[0], pad_len),
pad_token_id,
dtype=prompt_completion_ids.dtype,
device=prompt_completion_ids.device,
)
prompt_completion_ids = torch.cat(
[prompt_completion_ids, padding], dim=1
)
attn_padding = torch.zeros(
(attention_mask.shape[0], pad_len),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
total_seq_len += pad_len
logits_to_keep += pad_len
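# e.g. with total_seq_len = 803 and local_world_size = 4, pad_len = 1 brings the
# sequence to 804 so that every rank can take an equal slice of 201 tokens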
# Split the sequence
slice_size = total_seq_len // self.local_world_size
start = self.local_rank * slice_size
end = start + slice_size
# Get our slice
prompt_completion_ids = prompt_completion_ids[:, start:end]
attention_mask = attention_mask[:, start:end]
# Calculate how many completion tokens each rank should process
prompt_len = prompt_ids.size(1)
completion_len = completion_ids.size(
1
) # This is equal to logits_to_keep
# Calculate where our slice starts and ends relative to the completion tokens
if start >= prompt_len:
# Slice starts within the completion section
start_in_completion = start - prompt_len
end_in_completion = min(end - prompt_len, completion_len)
logits_to_keep = end_in_completion - start_in_completion
completion_mask = completion_mask[
:, start_in_completion:end_in_completion
]
elif end <= prompt_len:
# Slice is entirely within the prompt section (no completion tokens)
logits_to_keep = 0
completion_mask = torch.zeros(
(completion_mask.size(0), 0), device=completion_mask.device
)
else:
# Slice contains the boundary between prompt and completion
start_in_completion = 0
end_in_completion = min(end - prompt_len, completion_len)
logits_to_keep = end_in_completion - start_in_completion
completion_mask = completion_mask[
:, start_in_completion:end_in_completion
]
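# Worked example of the three cases above: with prompt_len = 500, completion_len = 300,
# local_world_size = 4 and a padded total_seq_len = 800, slice_size = 200 and
#   rank 0 owns [0, 200)   -> entirely prompt, logits_to_keep = 0
#   rank 1 owns [200, 400) -> entirely prompt, logits_to_keep = 0
#   rank 2 owns [400, 600) -> crosses the boundary, keeps completion tokens [0, 100)
#   rank 3 owns [600, 800) -> past the boundary, keeps completion tokens [100, 300)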
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
# computation here and use per_token_logps.detach() instead.
if self.num_iterations > 1:
@@ -731,6 +800,10 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
if self.beta == 0.0:
ref_per_token_logps = None
elif self.ref_model is not None:
print(f"{dist.get_rank()}: prompt_completion_ids.shape: {prompt_completion_ids.shape}")
print(f"{dist.get_rank()}: attention_mask.shape: {attention_mask.shape}")
print(f"{dist.get_rank()}: logits_to_keep: {logits_to_keep}")
ref_per_token_logps = self._get_per_token_logps(
self.ref_model,
prompt_completion_ids,
@@ -943,132 +1016,323 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
"advantages": advantages,
}
@profiling_decorator
def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
):
if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs")
# Unpack inputs
prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
completion_ids, completion_mask = (
inputs["completion_ids"],
inputs["completion_mask"],
)
input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
logits_to_keep = completion_ids.size(1)
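# e.g. prompt_ids of shape (B, P) and completion_ids of shape (B, C) give
# input_ids of shape (B, P + C); only the C completion positions need log-probs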
if self.args.sequence_parallel_degree > 1:
# Pad sequence to be divisible by SP degree if needed
total_seq_len = input_ids.shape[1]
if total_seq_len % self.local_world_size != 0:
pad_len = self.local_world_size - (
total_seq_len % self.local_world_size
)
pad_token_id = self.processing_class.pad_token_id or 0
# Pad input_ids and attention_mask
padding = torch.full(
(input_ids.shape[0], pad_len),
pad_token_id,
dtype=input_ids.dtype,
device=input_ids.device,
)
input_ids = torch.cat([input_ids, padding], dim=1)
attn_padding = torch.zeros(
(attention_mask.shape[0], pad_len),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
total_seq_len += pad_len
logits_to_keep += pad_len
# Split the sequence
slice_size = total_seq_len // self.local_world_size
start = self.local_rank * slice_size
end = start + slice_size
# Get our slice
input_ids_slice = input_ids[:, start:end]
attention_mask_slice = attention_mask[:, start:end]
# Calculate how many completion tokens each rank should process
prompt_len = prompt_ids.size(1)
completion_len = completion_ids.size(1) # This is equal to logits_to_keep
print(f"{self.local_rank}: logits_to_keep: {logits_to_keep}")
# Calculate where our slice starts and ends relative to the completion tokens
if start >= prompt_len:
# Slice starts within the completion section
start_in_completion = start - prompt_len
end_in_completion = min(end - prompt_len, completion_len)
local_logits_to_keep = end_in_completion - start_in_completion
completion_mask = completion_mask[
:, start_in_completion:end_in_completion
]
elif end <= prompt_len:
# Slice is entirely within the prompt section (no completion tokens)
local_logits_to_keep = 0
completion_mask = torch.zeros(
(completion_mask.size(0), 0), device=completion_mask.device
)
else:
# Slice contains the boundary between prompt and completion
start_in_completion = 0
end_in_completion = min(end - prompt_len, completion_len)
local_logits_to_keep = end_in_completion - start_in_completion
completion_mask = completion_mask[
:, start_in_completion:end_in_completion
]
# Run model on our slice
if local_logits_to_keep > 0:
# Get logits with enough context to compute log probs
logits = model(
input_ids=input_ids_slice,
attention_mask=attention_mask_slice,
logits_to_keep=local_logits_to_keep + 1,
).logits
# Drop the last logit: it corresponds to the prediction for the token after the slice
logits = logits[:, :-1, :]
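# The logit at position i predicts token i + 1, so asking the model for
# local_logits_to_keep + 1 logits and dropping the last one leaves exactly one
# logit per kept completion token in this rank's slice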
# Compute log probabilities on our local slice
local_input_ids = input_ids_slice[:, -local_logits_to_keep:]
logits = logits / self.temperature
per_token_logps = selective_log_softmax(logits, local_input_ids)
else:
# This rank doesn't have any tokens to keep
per_token_logps = torch.zeros(
(input_ids.shape[0], 0),
dtype=torch.float32,
device=input_ids.device,
)
else:
per_token_logps = super()._get_per_token_logps(
model, input_ids, attention_mask, logits_to_keep
)
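# For reference, TRL's selective_log_softmax is roughly equivalent to (a sketch):
#   logps = logits.log_softmax(dim=-1)                                 # (B, K, V)
#   per_token_logps = logps.gather(-1, ids.unsqueeze(-1)).squeeze(-1)  # (B, K)
# i.e. the log-probability the model assigns to each realized token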
# Compute the KL divergence between the model and the reference model
if self.beta != 0.0:
ref_per_token_logps = inputs["ref_per_token_logps"]
per_token_kl = (
torch.exp(ref_per_token_logps - per_token_logps)
- (ref_per_token_logps - per_token_logps)
- 1
)
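# This is the non-negative "k3" estimator exp(delta) - delta - 1 with
# delta = ref_logp - logp, an unbiased estimate of KL(pi_theta || pi_ref)
# under samples from pi_theta (Schulman, "Approximating KL Divergence")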
# Compute the loss
advantages = inputs["advantages"]
# When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its computation
# and use per_token_logps.detach() instead.
old_per_token_logps = (
inputs["old_per_token_logps"]
if self.num_iterations > 1
else per_token_logps.detach()
)
coef_1 = torch.exp(per_token_logps - old_per_token_logps)
coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
per_token_loss1 = coef_1 * advantages.unsqueeze(1)
per_token_loss2 = coef_2 * advantages.unsqueeze(1)
per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
if self.beta != 0.0:
per_token_loss = per_token_loss + self.beta * per_token_kl
loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
# Log metrics
mode = "eval" if self.control.should_evaluate else "train"
if self.beta != 0.0:
mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
self._metrics[mode]["kl"].append(
self.accelerator.gather_for_metrics(mean_kl).mean().item()
)
is_clipped = (per_token_loss1 < per_token_loss2).float()
clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
self._metrics[mode]["clip_ratio"].append(
self.accelerator.gather_for_metrics(clip_ratio).mean().item()
)
return loss
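# Worked example of the clipped objective above, assuming epsilon_low = epsilon_high = 0.2
# and a log-prob shift of 0.5, so coef_1 = exp(0.5) ~ 1.65 and coef_2 = 1.2:
#   advantage = +1.0 -> loss1 = 1.65, loss2 = 1.2, per-token loss = -1.2 (ratio clipped)
#   advantage = -1.0 -> loss1 = -1.65, loss2 = -1.2, per-token loss = 1.65 (unclipped)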