From 7f4e4076e1a10091b7c2ba7022e44d0fca8bcc8f Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Fri, 18 Apr 2025 21:35:33 +0000
Subject: [PATCH] progress

---
 src/axolotl/core/trainers/grpo/trainer.py | 480 +++++++++++++++++-----
 1 file changed, 372 insertions(+), 108 deletions(-)

diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py
index d672770e9..48e28c22f 100644
--- a/src/axolotl/core/trainers/grpo/trainer.py
+++ b/src/axolotl/core/trainers/grpo/trainer.py
@@ -617,9 +617,6 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
         prompt_ids = prompt_ids[:, -self.max_prompt_length :]
         prompt_mask = prompt_mask[:, -self.max_prompt_length :]
 
-        # Generate completions using vLLM: gather all prompts and use them in a single call in the main process
-        all_prompts_text = gather_object(prompts_text)
-
         # Generate completions using either vLLM or regular generation
         if self.args.use_vllm:
             # First, have main process load weights if needed
@@ -627,6 +624,7 @@
                 self._move_model_to_vllm()
                 self._last_loaded_step = self.state.global_step
 
+            all_prompts_text = gather_object(prompts_text)
             if self.accelerator.is_main_process:
                 # Since 'prompts' contains 'num_generations' duplicates, we first take unique prompts, and generate
                 # num_generations outputs for each one. This is faster than generating outputs for each duplicate
@@ -648,7 +646,9 @@
                     guided_decoding_regex=self.guided_decoding_regex,
                 )
             else:
-                completion_ids = [None] * len(all_prompts_text)
+                completion_ids = [None] * (
+                    len(all_prompts_text) // self.args.sequence_parallel_degree
+                )
 
             # Broadcast the completions from the main process to all processes
             completion_ids = broadcast_object_list(completion_ids, from_process=0)
@@ -719,6 +719,75 @@
         )  # we only need to compute the logits for the completion tokens
 
         with torch.no_grad():
+            if self.args.sequence_parallel_degree > 1:
+                # Pad sequence to be divisible by SP degree if needed
+                total_seq_len = prompt_completion_ids.shape[1]
+                if total_seq_len % self.local_world_size != 0:
+                    pad_len = self.local_world_size - (
+                        total_seq_len % self.local_world_size
+                    )
+                    pad_token_id = self.processing_class.pad_token_id or 0
+
+                    # Pad input_ids and attention_mask
+                    padding = torch.full(
+                        (prompt_completion_ids.shape[0], pad_len),
+                        pad_token_id,
+                        dtype=prompt_completion_ids.dtype,
+                        device=prompt_completion_ids.device,
+                    )
+                    prompt_completion_ids = torch.cat(
+                        [prompt_completion_ids, padding], dim=1
+                    )
+
+                    attn_padding = torch.zeros(
+                        (attention_mask.shape[0], pad_len),
+                        dtype=attention_mask.dtype,
+                        device=attention_mask.device,
+                    )
+                    attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
+
+                    total_seq_len += pad_len
+                    logits_to_keep += pad_len
+
+                # Split the sequence evenly across the SP ranks
+                slice_size = total_seq_len // self.local_world_size
+                start = self.local_rank * slice_size
+                end = start + slice_size
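+
+                # Illustrative example (hypothetical sizes): with total_seq_len=510
+                # and local_world_size=4, pad_len=2 brings the sequence to 512, so
+                # each rank takes an even 128-token slice: rank 0 -> [0, 128),
+                # rank 1 -> [128, 256), rank 2 -> [256, 384), rank 3 -> [384, 512).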
+
+                # Get our slice
+                prompt_completion_ids = prompt_completion_ids[:, start:end]
+                attention_mask = attention_mask[:, start:end]
+
+                # Calculate how many completion tokens each rank should process
+                prompt_len = prompt_ids.size(1)
+                completion_len = completion_ids.size(1)  # equal to logits_to_keep
+
+                # Calculate where our slice starts and ends relative to the completion tokens
+                if start >= prompt_len:
+                    # Slice lies entirely within the completion section
+                    start_in_completion = start - prompt_len
+                    end_in_completion = min(end - prompt_len, completion_len)
+                    logits_to_keep = end_in_completion - start_in_completion
+                    completion_mask = completion_mask[
+                        :, start_in_completion:end_in_completion
+                    ]
+                elif end <= prompt_len:
+                    # Slice is entirely within the prompt section (no completion tokens)
+                    logits_to_keep = 0
+                    completion_mask = torch.zeros(
+                        (completion_mask.size(0), 0), device=completion_mask.device
+                    )
+                else:
+                    # Slice contains the boundary between prompt and completion
+                    start_in_completion = 0
+                    end_in_completion = min(end - prompt_len, completion_len)
+                    logits_to_keep = end_in_completion - start_in_completion
+                    completion_mask = completion_mask[
+                        :, start_in_completion:end_in_completion
+                    ]
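+
+                # Worked example (hypothetical sizes, no padding needed):
+                # prompt_len=300, completion_len=212, four ranks over 512 tokens.
+                # Ranks 0 and 1 ([0, 128) and [128, 256)) hold only prompt tokens
+                # and keep no logits; rank 2 ([256, 384)) straddles the boundary
+                # and keeps 84; rank 3 ([384, 512)) keeps 128. Together the ranks
+                # cover all 212 completion tokens exactly once.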
+
             # When using num_iterations == 1, old_per_token_logps == per_token_logps, so we can skip its
             # computation here, and use per_token_logps.detach() instead.
             if self.num_iterations > 1:
@@ -943,132 +1016,323 @@
             "advantages": advantages,
         }
 
-    # Get the per-token log probabilities for the completions for the model and the reference model
-    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
-        if self.args.sequence_parallel_degree > 1:
-            print(f"{self.local_rank}: input_ids.shape: {input_ids.shape}")
-            print(f"{self.local_rank}: input_ids[0, :20]: {input_ids[0, :20]}")
-            print(f"{self.local_rank}: input_ids[0, -20:]: {input_ids[0, -20:]}")
+    @profiling_decorator
+    def compute_loss(
+        self, model, inputs, return_outputs=False, num_items_in_batch=None
+    ):
+        if return_outputs:
+            raise ValueError("The GRPOTrainer does not support returning outputs")
 
-            # Pad sequence if needed
+        # Unpack inputs
+        prompt_ids, prompt_mask = inputs["prompt_ids"], inputs["prompt_mask"]
+        completion_ids, completion_mask = (
+            inputs["completion_ids"],
+            inputs["completion_mask"],
+        )
+        input_ids = torch.cat([prompt_ids, completion_ids], dim=1)
+        attention_mask = torch.cat([prompt_mask, completion_mask], dim=1)
+        logits_to_keep = completion_ids.size(1)
+
+        if self.args.sequence_parallel_degree > 1:
+            # Pad sequence to be divisible by SP degree if needed
             total_seq_len = input_ids.shape[1]
-            remainder = total_seq_len % self.local_world_size
-            if remainder != 0:
-                to_pad = self.local_world_size - remainder
+            if total_seq_len % self.local_world_size != 0:
+                pad_len = self.local_world_size - (
+                    total_seq_len % self.local_world_size
+                )
                 pad_token_id = self.processing_class.pad_token_id or 0
+
+                # Pad input_ids and attention_mask
                 padding = torch.full(
-                    (input_ids.shape[0], to_pad),
+                    (input_ids.shape[0], pad_len),
                     pad_token_id,
                     dtype=input_ids.dtype,
                     device=input_ids.device,
                 )
                 input_ids = torch.cat([input_ids, padding], dim=1)
 
-                # Also pad attention mask if it exists
-                if attention_mask is not None:
-                    attn_padding = torch.zeros(
-                        (attention_mask.shape[0], to_pad),
-                        dtype=attention_mask.dtype,
-                        device=attention_mask.device,
-                    )
-                    attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
+                attn_padding = torch.zeros(
+                    (attention_mask.shape[0], pad_len),
+                    dtype=attention_mask.dtype,
+                    device=attention_mask.device,
+                )
+                attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
 
-                # Update total_seq_len after padding
-                total_seq_len += to_pad
+                total_seq_len += pad_len
+                logits_to_keep += pad_len
 
-            # Get local (start, end) for sequence parallelism slicing
+            # Split the sequence
             slice_size = total_seq_len // self.local_world_size
             start = self.local_rank * slice_size
             end = start + slice_size
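+
+            # Note: this slice arithmetic mirrors the sequence-parallel split in
+            # _prepare_inputs above; the two code paths must stay in sync so that
+            # each rank scores exactly the completion tokens it was assigned there.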
 
-            # Slice data for sequence parallel processing
-            input_ids = input_ids[:, start:end]
-            attention_mask = attention_mask[:, start:end]
+            # Get our slice
+            input_ids_slice = input_ids[:, start:end]
+            attention_mask_slice = attention_mask[:, start:end]
 
-            # Calculate if this rank contains any tokens we need to keep
-            tokens_before_our_slice = self.local_rank * slice_size
-            print(f"{self.local_rank}: slice_size: {slice_size}")
-            print(
-                f"{self.local_rank}: tokens_before_our_slice: {tokens_before_our_slice}"
-            )
-            if tokens_before_our_slice < logits_to_keep:
-                # How many tokens from our slice are needed
-                tokens_needed_from_slice = logits_to_keep - tokens_before_our_slice
-                logits_to_keep = min(slice_size, tokens_needed_from_slice)
-            else:
-                # This rank doesn't contain any tokens we need to keep
-                logits_to_keep = 0
+            # Calculate how many completion tokens each rank should process
+            prompt_len = prompt_ids.size(1)
+            completion_len = completion_ids.size(1)  # equal to logits_to_keep
 
-            print(f"{self.local_rank}: logits_to_keep: {logits_to_keep}")
-
-            # We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
-            logits = model(
-                input_ids=input_ids,
-                attention_mask=attention_mask,
-                logits_to_keep=logits_to_keep + 1,
-            ).logits
-            logits = logits[
-                :, :-1, :
-            ]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
-
-            print(f"{self.local_rank}: logits.shape: {logits.shape}")
-
-            # First, let all ranks know the shape of each rank's tensor
-            local_shape = torch.tensor(
-                [logits.shape[0], logits.shape[1], logits.shape[2]],
-                device=logits.device,
-            )
-            all_shapes = [
-                torch.zeros_like(local_shape) for _ in range(self.local_world_size)
-            ]
-            dist.all_gather(all_shapes, local_shape, group=self.sp_group)
-
-            # Use a list-based approach to collect logits of different sizes
-            if self.local_rank == 0:
-                # Root process allocates space for receiving
-                gathered_logits = []
-                for shape in all_shapes:
-                    b, s, v = shape.tolist()
-                    gathered_logits.append(
-                        torch.zeros((b, s, v), dtype=logits.dtype, device=logits.device)
-                    )
-            else:
-                gathered_logits = None
-
-            # Gather to rank 0
-            dist.gather(logits, gathered_logits, dst=0, group=self.sp_group)
-
-            # On rank 0, concatenate and distribute the result
-            if self.local_rank == 0:
-                concatenated_logits = torch.cat(gathered_logits, dim=1)
-                # Trim to keep only what we need
-                if concatenated_logits.shape[1] > logits_to_keep:
-                    concatenated_logits = concatenated_logits[:, -logits_to_keep:, :]
-            else:
-                concatenated_logits = torch.zeros(
-                    (logits.shape[0], logits_to_keep, logits.shape[2]),
-                    dtype=logits.dtype,
-                    device=logits.device,
+            # Calculate where our slice starts and ends relative to the completion tokens
+            if start >= prompt_len:
+                # Slice lies entirely within the completion section
+                start_in_completion = start - prompt_len
+                end_in_completion = min(end - prompt_len, completion_len)
+                local_logits_to_keep = end_in_completion - start_in_completion
+                completion_mask = completion_mask[
+                    :, start_in_completion:end_in_completion
+                ]
+            elif end <= prompt_len:
+                # Slice is entirely within the prompt section (no completion tokens)
+                local_logits_to_keep = 0
+                completion_mask = torch.zeros(
+                    (completion_mask.size(0), 0), device=completion_mask.device
                 )
+            else:
+                # Slice contains the boundary between prompt and completion
+                start_in_completion = 0
+                end_in_completion = min(end - prompt_len, completion_len)
+                local_logits_to_keep = end_in_completion - start_in_completion
+                completion_mask = completion_mask[
+                    :, start_in_completion:end_in_completion
+                ]
 
-            # Broadcast the result back to all ranks
-            dist.broadcast(concatenated_logits, src=0, group=self.sp_group)
-            logits = concatenated_logits
+            # Run model on our slice
+            if local_logits_to_keep > 0:
+                # Request one extra logit: position i predicts token i + 1, so the
+                # final logit refers to the token after our slice
+                logits = model(
+                    input_ids=input_ids_slice,
+                    attention_mask=attention_mask_slice,
+                    logits_to_keep=local_logits_to_keep + 1,
+                ).logits
 
-            input_ids = input_ids[:, -logits_to_keep:]
-            # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
-            # See https://github.com/huggingface/trl/issues/2770
-            logits = logits[:, -logits_to_keep:]
-            # Divide logits by sampling temperature.
-            # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
-            logits = logits / self.temperature
+                # Drop the trailing next-token logit on every rank
+                logits = logits[:, :-1, :]
 
-            dist.barrier()
-
-            return selective_log_softmax(
-                logits, input_ids
-            )  # compute logprobs for the input tokens
+                # Compute log probabilities on our local slice
+                local_input_ids = input_ids_slice[:, -local_logits_to_keep:]
+                logits = logits / self.temperature
+                per_token_logps = selective_log_softmax(logits, local_input_ids)
+            else:
+                # This rank doesn't have any tokens to keep
+                per_token_logps = torch.zeros(
+                    (input_ids.shape[0], 0),
+                    dtype=torch.float32,
+                    device=input_ids.device,
+                )
         else:
-            super()._get_per_token_logps(
+            per_token_logps = super()._get_per_token_logps(
                 model, input_ids, attention_mask, logits_to_keep
             )
+
+        # Compute the KL divergence between the model and the reference model
+        if self.beta != 0.0:
+            ref_per_token_logps = inputs["ref_per_token_logps"]
+            per_token_kl = (
+                torch.exp(ref_per_token_logps - per_token_logps)
+                - (ref_per_token_logps - per_token_logps)
+                - 1
+            )
+
+        # Compute the loss
+        advantages = inputs["advantages"]
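+
+        # Per-token objective as implemented below (PPO-style clipping):
+        #   ratio_t = exp(logp_t - old_logp_t)
+        #   loss_t  = -min(ratio_t * A, clip(ratio_t, 1 - eps_low, 1 + eps_high) * A)
+        # with an optional beta-weighted KL penalty toward the reference model.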
+        # When using num_iterations == 1, old_per_token_logps == per_token_logps,
+        # so we can skip its computation and use per_token_logps.detach() instead.
+        old_per_token_logps = (
+            inputs["old_per_token_logps"]
+            if self.num_iterations > 1
+            else per_token_logps.detach()
+        )
+        coef_1 = torch.exp(per_token_logps - old_per_token_logps)
+        coef_2 = torch.clamp(coef_1, 1 - self.epsilon_low, 1 + self.epsilon_high)
+        per_token_loss1 = coef_1 * advantages.unsqueeze(1)
+        per_token_loss2 = coef_2 * advantages.unsqueeze(1)
+        per_token_loss = -torch.min(per_token_loss1, per_token_loss2)
+
+        if self.beta != 0.0:
+            per_token_loss = per_token_loss + self.beta * per_token_kl
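+
+        # Masked mean over completion tokens. Note: under sequence parallelism,
+        # completion_mask was sliced per rank above, so this normalizes over the
+        # local shard's completion tokens only (and would divide by zero on a
+        # rank whose slice holds no completion tokens).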
+        loss = (per_token_loss * completion_mask).sum() / completion_mask.sum()
+
+        # Log metrics
+        mode = "eval" if self.control.should_evaluate else "train"
+
+        if self.beta != 0.0:
+            mean_kl = (per_token_kl * completion_mask).sum() / completion_mask.sum()
+            self._metrics[mode]["kl"].append(
+                self.accelerator.gather_for_metrics(mean_kl).mean().item()
+            )
+
+        is_clipped = (per_token_loss1 < per_token_loss2).float()
+        clip_ratio = (is_clipped * completion_mask).sum() / completion_mask.sum()
+        self._metrics[mode]["clip_ratio"].append(
+            self.accelerator.gather_for_metrics(clip_ratio).mean().item()
+        )
+
+        return loss