progress
This commit is contained in:
@@ -7,11 +7,11 @@ import torch.distributed as dist
|
|||||||
from accelerate.utils import is_deepspeed_available, is_peft_model
|
from accelerate.utils import is_deepspeed_available, is_peft_model
|
||||||
from trl import GRPOTrainer
|
from trl import GRPOTrainer
|
||||||
from trl.extras.profiling import profiling_decorator
|
from trl.extras.profiling import profiling_decorator
|
||||||
|
from trl.trainer.utils import selective_log_softmax
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||||
from axolotl.monkeypatch.attention.ring_attn import (
|
from axolotl.monkeypatch.attention.ring_attn import (
|
||||||
get_ring_attn_group,
|
get_ring_attn_group,
|
||||||
update_ring_attn_params,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_deepspeed_available():
|
if is_deepspeed_available():
|
||||||
@@ -72,6 +72,14 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
|||||||
|
|
||||||
# Get the per-token log probabilities for the completions for the model and the reference model
|
# Get the per-token log probabilities for the completions for the model and the reference model
|
||||||
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
|
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
|
||||||
|
if dist.get_rank() == 0:
|
||||||
|
import ipdb; ipdb.set_trace()
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
if dist.get_rank() == 1:
|
||||||
|
import ipdb; ipdb.set_trace()
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
if self.args.sequence_parallel_degree > 1:
|
if self.args.sequence_parallel_degree > 1:
|
||||||
sp_group = get_ring_attn_group()
|
sp_group = get_ring_attn_group()
|
||||||
self.local_rank = dist.get_rank(group=sp_group)
|
self.local_rank = dist.get_rank(group=sp_group)
|
||||||
@@ -81,17 +89,10 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
|||||||
total_seq_len = input_ids.shape[1]
|
total_seq_len = input_ids.shape[1]
|
||||||
remainder = total_seq_len % self.local_world_size
|
remainder = total_seq_len % self.local_world_size
|
||||||
if remainder != 0:
|
if remainder != 0:
|
||||||
padding = self.local_world_size - remainder
|
to_pad = self.local_world_size - remainder
|
||||||
|
|
||||||
if dist.get_rank() == 0:
|
|
||||||
import ipdb
|
|
||||||
|
|
||||||
ipdb.set_trace()
|
|
||||||
dist.barrier()
|
|
||||||
|
|
||||||
pad_token_id = self.processing_class.pad_token_id or 0
|
pad_token_id = self.processing_class.pad_token_id or 0
|
||||||
padding = torch.full(
|
padding = torch.full(
|
||||||
(input_ids.shape[0], padding),
|
(input_ids.shape[0], to_pad),
|
||||||
pad_token_id,
|
pad_token_id,
|
||||||
dtype=input_ids.dtype,
|
dtype=input_ids.dtype,
|
||||||
device=input_ids.device,
|
device=input_ids.device,
|
||||||
@@ -101,14 +102,14 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
|||||||
# Also pad attention mask if it exists
|
# Also pad attention mask if it exists
|
||||||
if attention_mask is not None:
|
if attention_mask is not None:
|
||||||
attn_padding = torch.zeros(
|
attn_padding = torch.zeros(
|
||||||
(attention_mask.shape[0], padding),
|
(attention_mask.shape[0], to_pad),
|
||||||
dtype=attention_mask.dtype,
|
dtype=attention_mask.dtype,
|
||||||
device=attention_mask.device,
|
device=attention_mask.device,
|
||||||
)
|
)
|
||||||
attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
|
attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
|
||||||
|
|
||||||
# Update total_seq_len after padding
|
# Update total_seq_len after padding
|
||||||
total_seq_len += padding
|
total_seq_len += to_pad
|
||||||
|
|
||||||
# Get local (start, end) for sequence parallelism slicing
|
# Get local (start, end) for sequence parallelism slicing
|
||||||
slice_size = total_seq_len // self.local_world_size
|
slice_size = total_seq_len // self.local_world_size
|
||||||
@@ -119,4 +120,71 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
|||||||
input_ids = input_ids[:, start:end]
|
input_ids = input_ids[:, start:end]
|
||||||
attention_mask = attention_mask[:, start:end]
|
attention_mask = attention_mask[:, start:end]
|
||||||
|
|
||||||
super()._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
|
# Calculate if this rank contains any tokens we need to keep
|
||||||
|
tokens_before_our_slice = self.local_rank * slice_size
|
||||||
|
print(f"{self.local_rank}: slice_size: {slice_size}")
|
||||||
|
print(f"{self.local_rank}: tokens_before_our_slice: {tokens_before_our_slice}")
|
||||||
|
if tokens_before_our_slice < logits_to_keep:
|
||||||
|
# How many tokens from our slice are needed
|
||||||
|
tokens_needed_from_slice = logits_to_keep - tokens_before_our_slice
|
||||||
|
logits_to_keep = min(slice_size, tokens_needed_from_slice)
|
||||||
|
else:
|
||||||
|
# This rank doesn't contain any tokens we need to keep
|
||||||
|
logits_to_keep = 0
|
||||||
|
|
||||||
|
print(f"{self.local_rank}: logits_to_keep: {logits_to_keep}")
|
||||||
|
|
||||||
|
# We add 1 to `logits_to_keep` because the last logits of the sequence is later excluded
|
||||||
|
logits = model(input_ids=input_ids, attention_mask=attention_mask, logits_to_keep=logits_to_keep + 1).logits
|
||||||
|
logits = logits[:, :-1, :] # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
|
||||||
|
|
||||||
|
print(f"{self.local_rank}: logits.shape: {logits.shape}")
|
||||||
|
|
||||||
|
# First, let all ranks know the shape of each rank's tensor
|
||||||
|
local_shape = torch.tensor([logits.shape[0], logits.shape[1], logits.shape[2]], device=logits.device)
|
||||||
|
all_shapes = [torch.zeros_like(local_shape) for _ in range(self.local_world_size)]
|
||||||
|
dist.all_gather(all_shapes, local_shape, group=sp_group)
|
||||||
|
|
||||||
|
# Use a list-based approach to collect logits of different sizes
|
||||||
|
if self.local_rank == 0:
|
||||||
|
# Root process allocates space for receiving
|
||||||
|
gathered_logits = []
|
||||||
|
for shape in all_shapes:
|
||||||
|
b, s, v = shape.tolist()
|
||||||
|
gathered_logits.append(torch.zeros((b, s, v), dtype=logits.dtype, device=logits.device))
|
||||||
|
else:
|
||||||
|
gathered_logits = None
|
||||||
|
|
||||||
|
# Gather to rank 0
|
||||||
|
dist.gather(logits, gathered_logits, dst=0, group=sp_group)
|
||||||
|
|
||||||
|
# On rank 0, concatenate and distribute the result
|
||||||
|
if self.local_rank == 0:
|
||||||
|
concatenated_logits = torch.cat(gathered_logits, dim=1)
|
||||||
|
# Trim to keep only what we need
|
||||||
|
if concatenated_logits.shape[1] > logits_to_keep:
|
||||||
|
concatenated_logits = concatenated_logits[:, -logits_to_keep:, :]
|
||||||
|
else:
|
||||||
|
concatenated_logits = torch.zeros(
|
||||||
|
(logits.shape[0], logits_to_keep, logits.shape[2]),
|
||||||
|
dtype=logits.dtype,
|
||||||
|
device=logits.device
|
||||||
|
)
|
||||||
|
|
||||||
|
# Broadcast the result back to all ranks
|
||||||
|
dist.broadcast(concatenated_logits, src=0, group=sp_group)
|
||||||
|
logits = concatenated_logits
|
||||||
|
|
||||||
|
input_ids = input_ids[:, -logits_to_keep:]
|
||||||
|
# For transformers<=4.48, logits_to_keep argument isn't supported, so here we drop logits ourselves.
|
||||||
|
# See https://github.com/huggingface/trl/issues/2770
|
||||||
|
logits = logits[:, -logits_to_keep:]
|
||||||
|
# Divide logits by sampling temperature.
|
||||||
|
# See https://huggingface.co/blog/the_n_implementation_details_of_rlhf_with_ppo#policy-training-implementation-details
|
||||||
|
logits = logits / self.temperature
|
||||||
|
|
||||||
|
dist.barrier()
|
||||||
|
|
||||||
|
return selective_log_softmax(logits, input_ids) # compute logprobs for the input tokens
|
||||||
|
|
||||||
|
# super()._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ Module for Axolotl trainer sequence parallelism mixin and training context manag
|
|||||||
|
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
@@ -159,6 +160,7 @@ class SequenceParallelMixin:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
<<<<<<< HEAD
|
||||||
class SequenceParallelContextManager:
|
class SequenceParallelContextManager:
|
||||||
"""
|
"""
|
||||||
Context manager for sequence parallelism operations.
|
Context manager for sequence parallelism operations.
|
||||||
@@ -311,3 +313,40 @@ class SequenceParallelContextManager:
|
|||||||
result[:, pos] = gathered_tensor[:, i]
|
result[:, pos] = gathered_tensor[:, i]
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
=======
|
||||||
|
class SequenceParallelismManager:
|
||||||
|
def __init__(self, local_rank, local_world_size):
|
||||||
|
self.local_rank = local_rank
|
||||||
|
self.local_world_size = local_world_size
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def apply(self, batch):
|
||||||
|
"""
|
||||||
|
Context manager that applies sequence parallelism slicing to a batch,
|
||||||
|
and restores the original batch afterward if needed.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: Batch dictionary from parent collator.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
Sliced batch dictionary for use in the model.
|
||||||
|
"""
|
||||||
|
# Get local (start, end) for sequence parallelism slicing
|
||||||
|
total_seq_len = batch["input_ids"].size(1)
|
||||||
|
slice_size = total_seq_len // self.local_world_size
|
||||||
|
start = self.local_rank * slice_size
|
||||||
|
end = start + slice_size
|
||||||
|
|
||||||
|
# Update params for varlen ring attention calculation
|
||||||
|
if batch.get("position_ids") is not None:
|
||||||
|
update_ring_attn_params(
|
||||||
|
input_ids=batch["input_ids"], position_ids=batch["position_ids"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Slice batch for sequence parallel processing
|
||||||
|
for key in batch:
|
||||||
|
if isinstance(batch[key], torch.Tensor) and batch[key].size(1) == total_seq_len:
|
||||||
|
batch[key] = batch[key][:, start:end]
|
||||||
|
|
||||||
|
yield batch
|
||||||
|
>>>>>>> c0054f07 (progress)
|
||||||
|
|||||||
Reference in New Issue
Block a user