GRPO sequence parallelism (SP) support

Dan Saunders
2025-04-09 00:46:05 +00:00
parent e55dce9995
commit 11b6803ff4
2 changed files with 63 additions and 6 deletions

View File

@@ -1048,6 +1048,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.cfg.rpo_alpha is not None:
             training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha

+        training_args_kwargs["sequence_parallel_degree"] = (
+            self.cfg.sequence_parallel_degree
+        )
+
         training_args_cls = None
         blocklist_args_kwargs = []
         if self.cfg.rl == "simpo":
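
Note: unlike rpo_alpha, sequence_parallel_degree is forwarded unconditionally, so the training-args class constructed downstream must define that field. A minimal sketch of what such a field could look like on a GRPOConfig subclass; the class name AxolotlGRPOConfig and the default value are illustrative assumptions, not taken from this commit:

from dataclasses import dataclass, field

from trl import GRPOConfig


@dataclass
class AxolotlGRPOConfig(GRPOConfig):  # hypothetical name, for illustration only
    # Number of ranks that jointly process each sequence; 1 disables SP
    sequence_parallel_degree: int = field(
        default=1,
        metadata={"help": "Degree of sequence parallelism for ring attention."},
    )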

View File

@@ -1,23 +1,25 @@
-"""
-Axolotl GRPO trainer
-"""
+"""Axolotl GRPO trainer"""
+
+from contextlib import nullcontext
+
+import torch
+import torch.distributed as dist
 from accelerate.utils import is_deepspeed_available, is_peft_model
 from trl import GRPOTrainer
 from trl.extras.profiling import profiling_decorator

 from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
+from axolotl.monkeypatch.attention.ring_attn import (
+    get_ring_attn_group,
+    update_ring_attn_params,
+)

 if is_deepspeed_available():
     import deepspeed


 class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
-    """
-    Extend the base GRPOTrainer for axolotl helpers
-    """
+    """Extend the base GRPOTrainer for axolotl helpers"""

     _tag_names = ["trl", "grpo", "axolotl"]
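
Note: the get_ring_attn_group import above is what ties the trainer to the ring-attention patch. As a hedged sketch (not axolotl's actual implementation), such a helper conceptually partitions all ranks into contiguous process groups of size sequence_parallel_degree and returns the group containing the current rank:

import torch.distributed as dist


def build_ring_attn_group(sp_degree: int):
    # Illustrative only; the real monkeypatch also rewires attention to
    # communicate along this group.
    world_size = dist.get_world_size()
    assert world_size % sp_degree == 0, "world size must be divisible by sp_degree"
    my_group = None
    for start in range(0, world_size, sp_degree):
        ranks = list(range(start, start + sp_degree))
        # new_group is collective: every rank must call it for every group
        group = dist.new_group(ranks=ranks)
        if dist.get_rank() in ranks:
            my_group = group
    return my_group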
@@ -67,3 +69,54 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
         # Reset cache on main process
         if self.accelerator.is_main_process:
             self.vllm_client.reset_prefix_cache()
+
+    # Get the per-token log probabilities for the completions for the model and the reference model
+    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
+        if self.args.sequence_parallel_degree > 1:
+            sp_group = get_ring_attn_group()
+            self.local_rank = dist.get_rank(group=sp_group)
+            self.local_world_size = dist.get_world_size(group=sp_group)
+
+            # Pad the sequence so it divides evenly across the SP group
+            total_seq_len = input_ids.shape[1]
+            remainder = total_seq_len % self.local_world_size
+            if remainder != 0:
+                pad_len = self.local_world_size - remainder
+                pad_token_id = self.processing_class.pad_token_id or 0
+                padding = torch.full(
+                    (input_ids.shape[0], pad_len),
+                    pad_token_id,
+                    dtype=input_ids.dtype,
+                    device=input_ids.device,
+                )
+                input_ids = torch.cat([input_ids, padding], dim=1)
+
+                # Also pad the attention mask if it exists
+                if attention_mask is not None:
+                    attn_padding = torch.zeros(
+                        (attention_mask.shape[0], pad_len),
+                        dtype=attention_mask.dtype,
+                        device=attention_mask.device,
+                    )
+                    attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
+
+                # Update total_seq_len after padding
+                total_seq_len += pad_len
+
+            # Get local (start, end) for sequence parallelism slicing
+            slice_size = total_seq_len // self.local_world_size
+            start = self.local_rank * slice_size
+            end = start + slice_size
+
+            # Slice data for sequence parallel processing
+            input_ids = input_ids[:, start:end]
+            if attention_mask is not None:
+                attention_mask = attention_mask[:, start:end]
+
+        return super()._get_per_token_logps(model, input_ids, attention_mask, logits_to_keep)
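
To sanity-check the pad-and-slice arithmetic above, here is a self-contained sketch with no distributed setup (ranks are simulated in a loop; pad_and_slice is a hypothetical helper mirroring the method's logic): a batch of length-10 sequences is padded to length 12 and shards evenly across a 4-rank SP group.

import torch


def pad_and_slice(input_ids, attention_mask, local_rank, local_world_size, pad_token_id=0):
    # Pad the sequence length up to a multiple of the SP world size
    total_seq_len = input_ids.shape[1]
    remainder = total_seq_len % local_world_size
    if remainder != 0:
        pad_len = local_world_size - remainder
        pad = torch.full((input_ids.shape[0], pad_len), pad_token_id, dtype=input_ids.dtype)
        input_ids = torch.cat([input_ids, pad], dim=1)
        if attention_mask is not None:
            zeros = torch.zeros((attention_mask.shape[0], pad_len), dtype=attention_mask.dtype)
            attention_mask = torch.cat([attention_mask, zeros], dim=1)
        total_seq_len += pad_len
    # Each rank takes an equal, contiguous shard of the padded sequence
    slice_size = total_seq_len // local_world_size
    start = local_rank * slice_size
    end = start + slice_size
    mask_shard = attention_mask[:, start:end] if attention_mask is not None else None
    return input_ids[:, start:end], mask_shard


ids = torch.randint(5, 100, (2, 10))  # batch of 2, sequence length 10
mask = torch.ones(2, 10, dtype=torch.long)
for rank in range(4):  # simulate a 4-rank sequence-parallel group
    shard_ids, shard_mask = pad_and_slice(ids, mask, rank, 4)
    assert shard_ids.shape == (2, 3)  # 10 padded to 12; 12 // 4 == 3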