This commit is contained in:
Dan Saunders
2025-04-14 21:02:30 +00:00
parent 11b6803ff4
commit 76e2d2e60b
2 changed files with 120 additions and 13 deletions

View File

@@ -7,11 +7,11 @@ import torch.distributed as dist
from accelerate.utils import is_deepspeed_available, is_peft_model from accelerate.utils import is_deepspeed_available, is_peft_model
from trl import GRPOTrainer from trl import GRPOTrainer
from trl.extras.profiling import profiling_decorator from trl.extras.profiling import profiling_decorator
from trl.trainer.utils import selective_log_softmax
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.monkeypatch.attention.ring_attn import ( from axolotl.monkeypatch.attention.ring_attn import (
get_ring_attn_group, get_ring_attn_group,
update_ring_attn_params,
) )
if is_deepspeed_available(): if is_deepspeed_available():
@@ -72,6 +72,14 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
# Get the per-token log probabilities for the completions for the model and the reference model # Get the per-token log probabilities for the completions for the model and the reference model
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep): def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
if dist.get_rank() == 0:
import ipdb; ipdb.set_trace()
dist.barrier()
if dist.get_rank() == 1:
import ipdb; ipdb.set_trace()
dist.barrier()
if self.args.sequence_parallel_degree > 1: if self.args.sequence_parallel_degree > 1:
sp_group = get_ring_attn_group() sp_group = get_ring_attn_group()
self.local_rank = dist.get_rank(group=sp_group) self.local_rank = dist.get_rank(group=sp_group)
@@ -81,17 +89,10 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
total_seq_len = input_ids.shape[1] total_seq_len = input_ids.shape[1]
remainder = total_seq_len % self.local_world_size remainder = total_seq_len % self.local_world_size
if remainder != 0: if remainder != 0:
padding = self.local_world_size - remainder to_pad = self.local_world_size - remainder
if dist.get_rank() == 0:
import ipdb
ipdb.set_trace()
dist.barrier()
pad_token_id = self.processing_class.pad_token_id or 0 pad_token_id = self.processing_class.pad_token_id or 0
padding = torch.full( padding = torch.full(
(input_ids.shape[0], padding), (input_ids.shape[0], to_pad),
pad_token_id, pad_token_id,
dtype=input_ids.dtype, dtype=input_ids.dtype,
device=input_ids.device, device=input_ids.device,
@@ -101,14 +102,14 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
# Also pad attention mask if it exists # Also pad attention mask if it exists
if attention_mask is not None: if attention_mask is not None:
attn_padding = torch.zeros( attn_padding = torch.zeros(
(attention_mask.shape[0], padding), (attention_mask.shape[0], to_pad),
dtype=attention_mask.dtype, dtype=attention_mask.dtype,
device=attention_mask.device, device=attention_mask.device,
) )
attention_mask = torch.cat([attention_mask, attn_padding], dim=1) attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
# Update total_seq_len after padding # Update total_seq_len after padding
total_seq_len += padding total_seq_len += to_pad
# Get local (start, end) for sequence parallelism slicing # Get local (start, end) for sequence parallelism slicing
slice_size = total_seq_len // self.local_world_size slice_size = total_seq_len // self.local_world_size
@@ -119,4 +120,71 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
input_ids = input_ids[:, start:end] input_ids = input_ids[:, start:end]
attention_mask = attention_mask[:, start:end] attention_mask = attention_mask[:, start:end]
super()._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep) # Calculate if this rank contains any tokens we need to keep
tokens_before_our_slice = self.local_rank * slice_size
print(f"{self.local_rank}: slice_size: {slice_size}")
print(f"{self.local_rank}: tokens_before_our_slice: {tokens_before_our_slice}")
if tokens_before_our_slice < logits_to_keep:
# How many tokens from our slice are needed
tokens_needed_from_slice = logits_to_keep - tokens_before_our_slice
logits_to_keep = min(slice_size, tokens_needed_from_slice)
else:
# This rank doesn't contain any tokens we need to keep
logits_to_keep = 0
print(f"{self.local_rank}: logits_to_keep: {logits_to_keep}")
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
print(f"{self.local_rank}: logits.shape: {logits.shape}")
# First, let all ranks know the shape of each rank's tensor
local_shape = torch.tensor([logits.shape[0], logits.shape[1], logits.shape[2]], device=logits.device)
all_shapes = [torch.zeros_like(local_shape) for _ in range(self.local_world_size)]
dist.all_gather(all_shapes, local_shape, group=sp_group)
# Use a list-based approach to collect logits of different sizes
if self.local_rank == 0:
# Root process allocates space for receiving
gathered_logits = []
for shape in all_shapes:
b, s, v = shape.tolist()
gathered_logits.append(torch.zeros((b, s, v), dtype=logits.dtype, device=logits.device))
else:
gathered_logits = None
# Gather to rank 0
dist.gather(logits, gathered_logits, dst=0, group=sp_group)
# On rank 0, concatenate and distribute the result
if self.local_rank == 0:
concatenated_logits = torch.cat(gathered_logits, dim=1)
# Trim to keep only what we need
if concatenated_logits.shape[1] > logits_to_keep:
concatenated_logits = concatenated_logits[:, -logits_to_keep:, :]
else:
concatenated_logits = torch.zeros(
(logits.shape[0], logits_to_keep, logits.shape[2]),
dtype=logits.dtype,
device=logits.device
)
# Broadcast the result back to all ranks
dist.broadcast(concatenated_logits, src=0, group=sp_group)
logits = concatenated_logits
input_ids = input_ids[:, -logits_to_keep:]
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
# See https://github.com/huggingface/trl/issues/2770
logits = logits[:, -logits_to_keep:]
# Divide logits by sampling temperature.
# See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
logits = logits / self.temperature
dist.barrier()
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
# super()._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)

View File

@@ -4,6 +4,7 @@ Module for Axolotl trainer sequence parallelism mixin and training context manag
import functools import functools
import logging import logging
from contextlib import contextmanager
import torch import torch
import torch.distributed as dist import torch.distributed as dist
@@ -159,6 +160,7 @@ class SequenceParallelMixin:
) )
<<<<<<< HEAD
class SequenceParallelContextManager: class SequenceParallelContextManager:
""" """
Context manager for sequence parallelism operations. Context manager for sequence parallelism operations.
@@ -311,3 +313,40 @@ class SequenceParallelContextManager:
result[:, pos] = gathered_tensor[:, i] result[:, pos] = gathered_tensor[:, i]
return result return result
=======
class SequenceParallelismManager:
def __init__(self, local_rank, local_world_size):
self.local_rank = local_rank
self.local_world_size = local_world_size
@contextmanager
def apply(self, batch):
"""
Context manager that applies sequence parallelism slicing to a batch,
and restores the original batch afterward if needed.
Args:
batch: Batch dictionary from parent collator.
Yields:
Sliced batch dictionary for use in the model.
"""
# Get local (start, end) for sequence parallelism slicing
total_seq_len = batch["input_ids"].size(1)
slice_size = total_seq_len // self.local_world_size
start = self.local_rank * slice_size
end = start + slice_size
# Update params for varlen ring attention calculation
if batch.get("position_ids") is not None:
update_ring_attn_params(
input_ids=batch["input_ids"], position_ids=batch["position_ids"]
)
# Slice batch for sequence parallel processing
for key in batch:
if isinstance(batch[key], torch.Tensor) and batch[key].size(1) == total_seq_len:
batch[key] = batch[key][:, start:end]
yield batch
>>>>>>> c0054f07 (progress)