From 76e2d2e60bbc0d0af640825b04a84e278d2347aa Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Mon, 14 Apr 2025 21:02:30 +0000
Subject: [PATCH] progress

---
 src/axolotl/core/trainers/grpo/trainer.py     | 110 ++++++++++++++++----
 .../core/trainers/mixins/sequence_parallel.py |  51 +++++++++++
 2 files changed, 148 insertions(+), 13 deletions(-)

diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py
index c7b6715d2..ea15088a4 100644
--- a/src/axolotl/core/trainers/grpo/trainer.py
+++ b/src/axolotl/core/trainers/grpo/trainer.py
@@ -7,11 +7,11 @@
 import torch.distributed as dist
 from accelerate.utils import is_deepspeed_available, is_peft_model
 from trl import GRPOTrainer
 from trl.extras.profiling import profiling_decorator
+from trl.trainer.utils import selective_log_softmax
 
 from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
 from axolotl.monkeypatch.attention.ring_attn import (
     get_ring_attn_group,
-    update_ring_attn_params,
 )
 if is_deepspeed_available():
@@ -81,18 +81,11 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
             total_seq_len = input_ids.shape[1]
             remainder = total_seq_len % self.local_world_size
             if remainder != 0:
-                padding = self.local_world_size - remainder
-
-                if dist.get_rank() == 0:
-                    import ipdb
-
-                    ipdb.set_trace()
-                dist.barrier()
-
+                to_pad = self.local_world_size - remainder
                 pad_token_id = self.processing_class.pad_token_id or 0
                 padding = torch.full(
-                    (input_ids.shape[0], padding),
+                    (input_ids.shape[0], to_pad),
                     pad_token_id,
                     dtype=input_ids.dtype,
                     device=input_ids.device,
                 )
@@ -101,14 +94,14 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
                 # Also pad attention mask if it exists
                 if attention_mask is not None:
                     attn_padding = torch.zeros(
-                        (attention_mask.shape[0], padding),
+                        (attention_mask.shape[0], to_pad),
                         dtype=attention_mask.dtype,
                         device=attention_mask.device,
                     )
                     attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
 
                 # Update total_seq_len after padding
-                total_seq_len += padding
+                total_seq_len += to_pad
 
             # Get local (start, end) for sequence parallelism slicing
             slice_size = total_seq_len // self.local_world_size
@@ -119,4 +112,95 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
+            # Keep a reference to the full (padded) ids; the final logprob
+            # computation needs them after the logits are gathered
+            full_input_ids = input_ids
             input_ids = input_ids[:, start:end]
             attention_mask = attention_mask[:, start:end]
 
-        super()._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
+            # Calculate how many of the tokens to keep fall in this rank's
+            # slice; `logits_to_keep` counts from the end of the sequence
+            total_logits_to_keep = logits_to_keep
+            tokens_after_our_slice = total_seq_len - (self.local_rank + 1) * slice_size
+            if tokens_after_our_slice < logits_to_keep:
+                # How many tokens from our slice are needed
+                tokens_needed_from_slice = logits_to_keep - tokens_after_our_slice
+                logits_to_keep = min(slice_size, tokens_needed_from_slice)
+            else:
+                # This rank doesn't contain any tokens we need to keep
+                logits_to_keep = 0
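+            # Worked example (hypothetical numbers): with total_seq_len=8,
+            # local_world_size=4 (slice_size=2), and logits_to_keep=3, ranks
+            # 0 and 1 keep nothing, rank 2 keeps 1 logit (for the last token
+            # of its slice), and rank 3 keeps logits for its full slice of 2.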
+
+            # We add 1 to `logits_to_keep` because the last logit of the
+            # sequence is later excluded
+            logits = model(
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                logits_to_keep=logits_to_keep + 1,
+            ).logits
+            # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+            logits = logits[:, :-1, :]
+
+            # First, let all ranks know the shape of each rank's tensor
+            local_shape = torch.tensor(
+                [logits.shape[0], logits.shape[1], logits.shape[2]], device=logits.device
+            )
+            all_shapes = [torch.zeros_like(local_shape) for _ in range(self.local_world_size)]
+            dist.all_gather(all_shapes, local_shape, group=sp_group)
+
+            # Use a list-based approach to collect logits of different sizes
+            if self.local_rank == 0:
+                # Root process allocates space for receiving
+                gathered_logits = []
+                for shape in all_shapes:
+                    b, s, v = shape.tolist()
+                    gathered_logits.append(
+                        torch.zeros((b, s, v), dtype=logits.dtype, device=logits.device)
+                    )
+            else:
+                gathered_logits = None
+
+            # Gather each rank's logits to rank 0
+            dist.gather(logits, gathered_logits, dst=0, group=sp_group)
+
+            # On rank 0, concatenate along the sequence dimension and trim to
+            # the tokens we actually need
+            if self.local_rank == 0:
+                concatenated_logits = torch.cat(gathered_logits, dim=1)
+                if concatenated_logits.shape[1] > total_logits_to_keep:
+                    concatenated_logits = concatenated_logits[:, -total_logits_to_keep:, :]
+            else:
+                # Non-root ranks allocate a matching buffer for the broadcast
+                concatenated_logits = torch.zeros(
+                    (logits.shape[0], total_logits_to_keep, logits.shape[2]),
+                    dtype=logits.dtype,
+                    device=logits.device,
+                )
+
+            # Broadcast the result back to all ranks
+            dist.broadcast(concatenated_logits, src=0, group=sp_group)
+            logits = concatenated_logits
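+            # NOTE: torch.distributed collectives generally expect matching
+            # tensor shapes across ranks; if this variable-size gather proves
+            # unreliable on a given backend, padding each rank's logits to a
+            # common length before gathering is a safer alternative.
+            # TODO: when SP padding was added above, the trailing positions
+            # are pad tokens, which shifts what "last N tokens" means here.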
+
+            input_ids = full_input_ids[:, -total_logits_to_keep:]
+            # For transformers<=4.48, the `logits_to_keep` argument isn't
+            # supported, so here we drop logits ourselves.
+            # See https://github.com/huggingface/trl/issues/2770
+            logits = logits[:, -total_logits_to_keep:]
+            # Divide logits by sampling temperature.
+            # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
+            logits = logits / self.temperature
+
+            dist.barrier()
+
+            # Compute logprobs for the input tokens
+            return selective_log_softmax(logits, input_ids)
+
+        return super()._get_per_token_logps(
+            model, input_ids, attention_mask, logits_to_keep
+        )
diff --git a/src/axolotl/core/trainers/mixins/sequence_parallel.py b/src/axolotl/core/trainers/mixins/sequence_parallel.py
index 362acb88e..87e385b68 100644
--- a/src/axolotl/core/trainers/mixins/sequence_parallel.py
+++ b/src/axolotl/core/trainers/mixins/sequence_parallel.py
@@ -4,6 +4,7 @@ Module for Axolotl trainer sequence parallelism mixin and training context manager
 
 import functools
 import logging
+from contextlib import contextmanager
 
 import torch
 import torch.distributed as dist
@@ -311,3 +312,53 @@ class SequenceParallelContextManager:
             result[:, pos] = gathered_tensor[:, i]
 
         return result
+
+
+class SequenceParallelismManager:
+    """Applies sequence parallelism slicing to batches."""
+
+    def __init__(self, local_rank, local_world_size):
+        self.local_rank = local_rank
+        self.local_world_size = local_world_size
+
+    @contextmanager
+    def apply(self, batch):
+        """
+        Context manager that applies sequence parallelism slicing to a batch.
+        Note that the batch dictionary is modified in place.
+
+        Args:
+            batch: Batch dictionary from the parent collator.
+
+        Yields:
+            Sliced batch dictionary for use in the model.
+        """
+        # Get local (start, end) for sequence parallelism slicing; assumes
+        # the sequence length divides evenly across the SP ranks
+        total_seq_len = batch["input_ids"].size(1)
+        slice_size = total_seq_len // self.local_world_size
+        start = self.local_rank * slice_size
+        end = start + slice_size
+
+        # Update params for varlen ring attention calculation
+        if batch.get("position_ids") is not None:
+            update_ring_attn_params(
+                input_ids=batch["input_ids"], position_ids=batch["position_ids"]
+            )
+
+        # Slice batch for sequence parallel processing
+        for key in batch:
+            if (
+                isinstance(batch[key], torch.Tensor)
+                and batch[key].size(1) == total_seq_len
+            ):
+                batch[key] = batch[key][:, start:end]
+
+        yield batch
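+
+
+# Hypothetical usage sketch (call-site names assumed; not part of this patch):
+#
+#     manager = SequenceParallelismManager(local_rank=0, local_world_size=2)
+#     with manager.apply(batch) as sliced_batch:
+#         outputs = model(**sliced_batch)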