This commit is contained in:
Dan Saunders
2025-04-14 21:02:30 +00:00
parent 11b6803ff4
commit 76e2d2e60b
2 changed files with 120 additions and 13 deletions

View File

@@ -7,11 +7,11 @@ import torch.distributed as dist
from accelerate.utils import is_deepspeed_available, is_peft_model
from trl import GRPOTrainer
from trl.extras.profiling import profiling_decorator
from trl.trainer.utils import selective_log_softmax
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.monkeypatch.attention.ring_attn import (
get_ring_attn_group,
update_ring_attn_params,
)
if is_deepspeed_available():
@@ -72,6 +72,14 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
# Get the per-token log probabilities of the completions, for both the model and the reference model
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
if dist.get_rank() == 0:
import ipdb; ipdb.set_trace()
dist.barrier()
if dist.get_rank() == 1:
import ipdb; ipdb.set_trace()
dist.barrier()
if self.args.sequence_parallel_degree > 1:
sp_group = get_ring_attn_group()
self.local_rank = dist.get_rank(group=sp_group)
@@ -81,17 +89,10 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
total_seq_len = input_ids.shape[1]
remainder = total_seq_len % self.local_world_size
if remainder != 0:
padding = self.local_world_size - remainder
if dist.get_rank() == 0:
import ipdb
ipdb.set_trace()
dist.barrier()
to_pad = self.local_world_size - remainder
pad_token_id = self.processing_class.pad_token_id or 0
padding = torch.full(
(input_ids.shape[0], padding),
(input_ids.shape[0], to_pad),
pad_token_id,
dtype=input_ids.dtype,
device=input_ids.device,
@@ -101,14 +102,14 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
# Also pad attention mask if it exists
if attention_mask is not None:
attn_padding = torch.zeros(
(attention_mask.shape[0], padding),
(attention_mask.shape[0], to_pad),
dtype=attention_mask.dtype,
device=attention_mask.device,
)
attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
# Update total_seq_len after padding
total_seq_len += padding
total_seq_len += to_pad
# Get local (start, end) for sequence parallelism slicing
slice_size = total_seq_len // self.local_world_size
@@ -119,4 +120,71 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
input_ids = input_ids[:, start:end]
attention_mask = attention_mask[:, start:end]
super()._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
# Calculate if this rank contains any tokens we need to keep
tokens_before_our_slice = self.local_rank * slice_size
print(f"{self.local_rank}: slice_size: {slice_size}")
print(f"{self.local_rank}: tokens_before_our_slice: {tokens_before_our_slice}")
if tokens_before_our_slice < logits_to_keep:
# How many tokens from our slice are needed
tokens_needed_from_slice = logits_to_keep - tokens_before_our_slice
logits_to_keep = min(slice_size, tokens_needed_from_slice)
else:
# This rank doesn't contain any tokens we need to keep
logits_to_keep = 0
print(f"{self.local_rank}: logits_to_keep: {logits_to_keep}")
# We add 1 to `logits_to_keep` because the last logit of the sequence is later excluded
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
print(f"{self.local_rank}: logits.shape: {logits.shape}")
# First, let all ranks know the shape of each rank's tensor
local_shape = torch.tensor([logits.shape[0], logits.shape[1], logits.shape[2]], device=logits.device)
all_shapes = [torch.zeros_like(local_shape) for _ in range(self.local_world_size)]
dist.all_gather(all_shapes, local_shape, group=sp_group)
# Use a list-based approach to collect logits of different sizes
if self.local_rank == 0:
# Root process allocates space for receiving
gathered_logits = []
for shape in all_shapes:
b, s, v = shape.tolist()
gathered_logits.append(torch.zeros((b, s, v), dtype=logits.dtype, device=logits.device))
else:
gathered_logits = None
# Gather to rank 0
dist.gather(logits, gathered_logits, dst=0, group=sp_group)
# On rank 0, concatenate and distribute the result
if self.local_rank == 0:
concatenated_logits = torch.cat(gathered_logits, dim=1)
# Trim to keep only what we need
if concatenated_logits.shape[1] > logits_to_keep:
concatenated_logits = concatenated_logits[:, -logits_to_keep:, :]
else:
concatenated_logits = torch.zeros(
(logits.shape[0], logits_to_keep, logits.shape[2]),
dtype=logits.dtype,
device=logits.device
)
# Broadcast the result back to all ranks
dist.broadcast(concatenated_logits, src=0, group=sp_group)
logits = concatenated_logits
input_ids = input_ids[:, -logits_to_keep:]
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
logits = logits[:, -logits_to_keep:]
# Divide logits by sampling temperature.
# See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
logits = logits / self.temperature
dist.barrier()
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
# super()._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

View File

@@ -4,6 +4,7 @@ Module for Axolotl trainer sequence parallelism mixin and training context manag
import functools
import logging
from contextlib import contextmanager
import torch
import torch.distributed as dist
@@ -159,6 +160,7 @@ class SequenceParallelMixin:
)
<<<<<<< HEAD
class SequenceParallelContextManager:
"""
Context manager for sequence parallelism operations.
@@ -311,3 +313,40 @@ class SequenceParallelContextManager:
result[:, pos] = gathered_tensor[:, i]
return result
=======
class SequenceParallelismManager:
    """Slices collated batches across a sequence-parallel process group.

    Each rank receives a contiguous ``1 / local_world_size`` fraction of the
    sequence dimension of every full-length tensor in the batch.
    """

    def __init__(self, local_rank, local_world_size):
        # Rank index and group size within the sequence-parallel group.
        self.local_rank = local_rank
        self.local_world_size = local_world_size

    @contextmanager
    def apply(self, batch):
        """
        Context manager that applies sequence parallelism slicing to a batch.

        NOTE: the batch dictionary is mutated *in place*; the full-length
        tensors are NOT restored on exit. Callers that need the original
        batch must keep their own copy.

        Args:
            batch: Batch dictionary from parent collator.

        Yields:
            Sliced batch dictionary for use in the model.
        """
        # Each rank takes a contiguous slice of the sequence dimension.
        # Assumes total_seq_len is divisible by local_world_size; trailing
        # remainder tokens would be silently dropped otherwise.
        total_seq_len = batch["input_ids"].size(1)
        slice_size = total_seq_len // self.local_world_size
        start = self.local_rank * slice_size
        end = start + slice_size

        # Update params for varlen ring attention calculation
        if batch.get("position_ids") is not None:
            update_ring_attn_params(
                input_ids=batch["input_ids"], position_ids=batch["position_ids"]
            )

        # Slice every tensor that spans the full sequence length. Guard on
        # dim() >= 2 so 1-D tensors in the batch (e.g. per-sample scalars)
        # don't raise IndexError from size(1).
        for key in batch:
            value = batch[key]
            if (
                isinstance(value, torch.Tensor)
                and value.dim() >= 2
                and value.size(1) == total_seq_len
            ):
                batch[key] = value[:, start:end]

        yield batch
>>>>>>> c0054f07 (progress)