progress
@@ -7,11 +7,11 @@ import torch.distributed as dist
 from accelerate.utils import is_deepspeed_available, is_peft_model
 from trl import GRPOTrainer
 from trl.extras.profiling import profiling_decorator
 from trl.trainer.utils import selective_log_softmax

 from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
 from axolotl.monkeypatch.attention.ring_attn import (
     get_ring_attn_group,
     update_ring_attn_params,
 )

 if is_deepspeed_available():
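Note: `selective_log_softmax` (imported above from trl.trainer.utils) returns the log-probability of each target token under the given logits. A minimal sketch of the equivalent computation, for orientation only (not TRL's exact implementation):

    import torch
    import torch.nn.functional as F

    def selective_log_softmax_sketch(logits: torch.Tensor, index: torch.Tensor) -> torch.Tensor:
        # logits: (B, L, V) model outputs; index: (B, L) target token ids.
        # Picks log_softmax(logits) at each target id along the vocab axis.
        logps = F.log_softmax(logits, dim=-1)
        return logps.gather(dim=-1, index=index.unsqueeze(-1)).squeeze(-1)  # (B, L)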
@@ -72,6 +72,14 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):

     # Get the per-token log probabilities for the completions for the model and the reference model
     def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
+        if dist.get_rank() == 0:
+            import ipdb; ipdb.set_trace()
+        dist.barrier()
+
+        if dist.get_rank() == 1:
+            import ipdb; ipdb.set_trace()
+        dist.barrier()
+
         if self.args.sequence_parallel_degree > 1:
             sp_group = get_ring_attn_group()
             self.local_rank = dist.get_rank(group=sp_group)
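The rank-gated `ipdb.set_trace()` followed by `dist.barrier()` above is a standard trick for debugging one process at a time in a multi-rank job: the chosen rank drops into the debugger while every other rank waits at the barrier. A minimal self-contained sketch of the same pattern (`debug_rank` is an illustrative helper, not part of axolotl):

    import torch.distributed as dist

    def debug_rank(rank_to_debug: int) -> None:
        # Only the chosen rank opens the debugger; all ranks then meet at
        # the barrier so no process runs ahead while one is being inspected.
        if dist.get_rank() == rank_to_debug:
            import ipdb
            ipdb.set_trace()
        dist.barrier()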
@@ -81,17 +89,10 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
             total_seq_len = input_ids.shape[1]
             remainder = total_seq_len % self.local_world_size
             if remainder != 0:
-                padding = self.local_world_size - remainder
-
-                if dist.get_rank() == 0:
-                    import ipdb
-
-                    ipdb.set_trace()
-                dist.barrier()
-
+                to_pad = self.local_world_size - remainder
                 pad_token_id = self.processing_class.pad_token_id or 0
                 padding = torch.full(
-                    (input_ids.shape[0], padding),
+                    (input_ids.shape[0], to_pad),
                     pad_token_id,
                     dtype=input_ids.dtype,
                     device=input_ids.device,
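The hunk above renames the pad-width variable to `to_pad` so it no longer collides with the `padding` tensor built by `torch.full`. The underlying idea is to right-pad the sequence dimension until it divides evenly by the sequence-parallel world size. A standalone sketch of that padding step (hypothetical helper, mirroring the diff's logic):

    import torch

    def pad_to_multiple(input_ids: torch.Tensor, world_size: int, pad_token_id: int = 0) -> torch.Tensor:
        # Right-pad so seq_len % world_size == 0, letting each
        # sequence-parallel rank later take an equal-width slice.
        remainder = input_ids.shape[1] % world_size
        if remainder == 0:
            return input_ids
        to_pad = world_size - remainder
        padding = torch.full(
            (input_ids.shape[0], to_pad),
            pad_token_id,
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        return torch.cat([input_ids, padding], dim=1)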
@@ -101,14 +102,14 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
                 # Also pad attention mask if it exists
                 if attention_mask is not None:
                     attn_padding = torch.zeros(
-                        (attention_mask.shape[0], padding),
+                        (attention_mask.shape[0], to_pad),
                         dtype=attention_mask.dtype,
                         device=attention_mask.device,
                     )
                     attention_mask = torch.cat([attention_mask, attn_padding], dim=1)

                 # Update total_seq_len after padding
-                total_seq_len += padding
+                total_seq_len += to_pad

             # Get local (start, end) for sequence parallelism slicing
             slice_size = total_seq_len // self.local_world_size
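Once the padded `total_seq_len` divides evenly, each rank takes a contiguous slice of width `slice_size`. A worked example of the arithmetic with illustrative numbers:

    # total_seq_len = 16, local_world_size = 4 -> slice_size = 4
    total_seq_len, local_world_size = 16, 4
    slice_size = total_seq_len // local_world_size
    for local_rank in range(local_world_size):
        start = local_rank * slice_size
        end = start + slice_size
        print(local_rank, (start, end))  # 0 (0, 4), 1 (4, 8), 2 (8, 12), 3 (12, 16)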
@@ -119,4 +120,71 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
             input_ids = input_ids[:, start:end]
             attention_mask = attention_mask[:, start:end]

-        super()._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
+            # Calculate if this rank contains any tokens we need to keep
+            tokens_before_our_slice = self.local_rank * slice_size
+            print(f"{self.local_rank}: slice_size: {slice_size}")
+            print(f"{self.local_rank}: tokens_before_our_slice: {tokens_before_our_slice}")
+            if tokens_before_our_slice < logits_to_keep:
+                # How many tokens from our slice are needed
+                tokens_needed_from_slice = logits_to_keep - tokens_before_our_slice
+                logits_to_keep = min(slice_size, tokens_needed_from_slice)
+            else:
+                # This rank doesn't contain any tokens we need to keep
+                logits_to_keep = 0
+
+            print(f"{self.local_rank}: logits_to_keep: {logits_to_keep}")
+
+        # We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
+        logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
+        logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
+
+        print(f"{self.local_rank}: logits.shape: {logits.shape}")
+
+        # First, let all ranks know the shape of each rank's tensor
+        local_shape = torch.tensor([logits.shape[0], logits.shape[1], logits.shape[2]], device=logits.device)
+        all_shapes = [torch.zeros_like(local_shape) for _ in range(self.local_world_size)]
+        dist.all_gather(all_shapes, local_shape, group=sp_group)
+
+        # Use a list-based approach to collect logits of different sizes
+        if self.local_rank == 0:
+            # Root process allocates space for receiving
+            gathered_logits = []
+            for shape in all_shapes:
+                b, s, v = shape.tolist()
+                gathered_logits.append(torch.zeros((b, s, v), dtype=logits.dtype, device=logits.device))
+        else:
+            gathered_logits = None
+
+        # Gather to rank 0
+        dist.gather(logits, gathered_logits, dst=0, group=sp_group)
+
+        # On rank 0, concatenate and distribute the result
+        if self.local_rank == 0:
+            concatenated_logits = torch.cat(gathered_logits, dim=1)
+            # Trim to keep only what we need
+            if concatenated_logits.shape[1] > logits_to_keep:
+                concatenated_logits = concatenated_logits[:, -logits_to_keep:, :]
+        else:
+            concatenated_logits = torch.zeros(
+                (logits.shape[0], logits_to_keep, logits.shape[2]),
+                dtype=logits.dtype,
+                device=logits.device
+            )
+
+        # Broadcast the result back to all ranks
+        dist.broadcast(concatenated_logits, src=0, group=sp_group)
+        logits = concatenated_logits
+
+        input_ids = input_ids[:, -logits_to_keep:]
+        # For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
+        # See https://github.com/huggingface/trl/issues/2770
+        logits = logits[:, -logits_to_keep:]
+        # Divide logits by sampling temperature.
+        # See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
+        logits = logits / self.temperature
+
+        dist.barrier()
+
+        return selective_log_softmax(logits, input_ids)  # compute logprobs for the input tokens
+
+        # super()._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
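Because `logits_to_keep` differs per rank after the slice bookkeeping, each rank's logits tensor can have a different sequence length, and `dist.gather` needs correctly sized receive buffers on the destination rank. The hunk handles this by all-gathering the shapes first. A minimal standalone sketch of that shape-exchange-then-gather pattern (assumes an initialized process group; `gather_variable_seq` is illustrative, not an axolotl or torch API):

    import torch
    import torch.distributed as dist

    def gather_variable_seq(local: torch.Tensor, group=None):
        # Exchange every rank's (B, S, V) shape so the destination rank can
        # allocate one receive buffer per rank before calling dist.gather.
        world_size = dist.get_world_size(group=group)
        shape = torch.tensor(list(local.shape), device=local.device)
        shapes = [torch.zeros_like(shape) for _ in range(world_size)]
        dist.all_gather(shapes, shape, group=group)

        if dist.get_rank(group=group) == 0:
            buffers = [
                torch.zeros(*s.tolist(), dtype=local.dtype, device=local.device)
                for s in shapes
            ]
        else:
            buffers = None
        # NB: dst is a global rank; with subgroups, map it accordingly.
        dist.gather(local, buffers, dst=0, group=group)
        return buffers  # list of tensors on rank 0, None elsewhere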
@@ -4,6 +4,7 @@ Module for Axolotl trainer sequence parallelism mixin and training context manager
 import functools
 import logging
+from contextlib import contextmanager

 import torch
 import torch.distributed as dist
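The new `contextlib.contextmanager` import supports the generator-based context manager added below: code before the `yield` runs on entry to the `with` block and code after it runs on exit. A minimal illustration of the mechanism:

    from contextlib import contextmanager

    @contextmanager
    def managed(batch: dict):
        # Entry: prepare or slice the batch here.
        yield batch
        # Exit: restore or clean up here if needed.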
@@ -159,6 +160,7 @@ class SequenceParallelMixin:
         )


+<<<<<<< HEAD
 class SequenceParallelContextManager:
     """
     Context manager for sequence parallelism operations.
@@ -311,3 +313,40 @@ class SequenceParallelContextManager:
             result[:, pos] = gathered_tensor[:, i]

         return result
+=======
+class SequenceParallelismManager:
+    def __init__(self, local_rank, local_world_size):
+        self.local_rank = local_rank
+        self.local_world_size = local_world_size
+
+    @contextmanager
+    def apply(self, batch):
+        """
+        Context manager that applies sequence parallelism slicing to a batch,
+        and restores the original batch afterward if needed.
+
+        Args:
+            batch: Batch dictionary from parent collator.
+
+        Yields:
+            Sliced batch dictionary for use in the model.
+        """
+        # Get local (start, end) for sequence parallelism slicing
+        total_seq_len = batch["input_ids"].size(1)
+        slice_size = total_seq_len // self.local_world_size
+        start = self.local_rank * slice_size
+        end = start + slice_size
+
+        # Update params for varlen ring attention calculation
+        if batch.get("position_ids") is not None:
+            update_ring_attn_params(
+                input_ids=batch["input_ids"], position_ids=batch["position_ids"]
+            )
+
+        # Slice batch for sequence parallel processing
+        for key in batch:
+            if isinstance(batch[key], torch.Tensor) and batch[key].size(1) == total_seq_len:
+                batch[key] = batch[key][:, start:end]
+
+        yield batch
+>>>>>>> c0054f07 (progress)
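Call-site usage of the new manager would look roughly like this (illustrative wiring; the trainer integration is assumed, not shown in this commit):

    import torch.distributed as dist

    manager = SequenceParallelismManager(
        local_rank=dist.get_rank(), local_world_size=dist.get_world_size()
    )
    with manager.apply(batch) as sliced_batch:
        # Each rank sees only its contiguous slice of the sequence dimension.
        outputs = model(**sliced_batch)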