GRPO sequence parallel (SP) support
This commit is contained in:
@@ -1048,6 +1048,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.rpo_alpha is not None:
|
if self.cfg.rpo_alpha is not None:
|
||||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
||||||
|
|
||||||
|
training_args_kwargs["sequence_parallel_degree"] = (
|
||||||
|
self.cfg.sequence_parallel_degree
|
||||||
|
)
|
||||||
|
|
||||||
training_args_cls = None
|
training_args_cls = None
|
||||||
blocklist_args_kwargs = []
|
blocklist_args_kwargs = []
|
||||||
if self.cfg.rl == "simpo":
|
if self.cfg.rl == "simpo":
|
||||||
|
|||||||
@@ -1,23 +1,25 @@
|
|||||||
"""
|
"""Axolotl GRPO trainer"""
|
||||||
Axolotl GRPO trainer
|
|
||||||
"""
|
|
||||||
|
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
from accelerate.utils import is_deepspeed_available, is_peft_model
|
from accelerate.utils import is_deepspeed_available, is_peft_model
|
||||||
from trl import GRPOTrainer
|
from trl import GRPOTrainer
|
||||||
from trl.extras.profiling import profiling_decorator
|
from trl.extras.profiling import profiling_decorator
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||||
|
from axolotl.monkeypatch.attention.ring_attn import (
|
||||||
|
get_ring_attn_group,
|
||||||
|
update_ring_attn_params,
|
||||||
|
)
|
||||||
|
|
||||||
if is_deepspeed_available():
|
if is_deepspeed_available():
|
||||||
import deepspeed
|
import deepspeed
|
||||||
|
|
||||||
|
|
||||||
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
||||||
"""
|
"""Extend the base GRPOTrainer for axolotl helpers"""
|
||||||
Extend the base GRPOTrainer for axolotl helpers
|
|
||||||
"""
|
|
||||||
|
|
||||||
_tag_names = ["trl", "grpo", "axolotl"]
|
_tag_names = ["trl", "grpo", "axolotl"]
|
||||||
|
|
||||||
@@ -67,3 +69,54 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
|||||||
# Reset cache on main process
|
# Reset cache on main process
|
||||||
if self.accelerator.is_main_process:
|
if self.accelerator.is_main_process:
|
||||||
self.vllm_client.reset_prefix_cache()
|
self.vllm_client.reset_prefix_cache()
|
||||||
|
|
||||||
|
# Get the per-token log probabilities for the completions for the model and the reference model
|
||||||
|
def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
    """Compute per-token log-probs, sharding the sequence when SP is enabled.

    When ``self.args.sequence_parallel_degree`` is greater than 1, the
    sequence dimension (dim 1) of ``input_ids`` and ``attention_mask`` is
    right-padded to a multiple of the ring-attention group size and then
    sliced so that this rank keeps only its own contiguous chunk. The
    (possibly sliced) tensors are then handed to the parent implementation.

    Args:
        model: Model forwarded unchanged to the parent implementation.
        input_ids: Token ids; dim 1 is treated as the sequence dimension.
        attention_mask: Mask aligned with ``input_ids``, or ``None``.
        logits_to_keep: Forwarded unchanged to the parent implementation.

    Returns:
        Whatever the parent ``_get_per_token_logps`` returns for the
        (possibly sliced) inputs.
    """
    # NOTE(review): `logits_to_keep` is forwarded untouched and nothing here
    # gathers the per-rank results back together -- confirm the parent /
    # caller copes with a sequence-sharded input.
    if self.args.sequence_parallel_degree > 1:
        sp_group = get_ring_attn_group()
        self.local_rank = dist.get_rank(group=sp_group)
        self.local_world_size = dist.get_world_size(group=sp_group)

        # Pad the sequence so it divides evenly across the SP group.
        total_seq_len = input_ids.shape[1]
        remainder = total_seq_len % self.local_world_size
        if remainder != 0:
            # Keep the pad *length* in its own name: it is needed again
            # below as a size and as an integer increment, so it must not
            # be rebound to the padding tensor.
            pad_len = self.local_world_size - remainder

            pad_token_id = self.processing_class.pad_token_id or 0
            id_padding = torch.full(
                (input_ids.shape[0], pad_len),
                pad_token_id,
                dtype=input_ids.dtype,
                device=input_ids.device,
            )
            input_ids = torch.cat([input_ids, id_padding], dim=1)

            # Padded positions are masked out (zeros) in the attention mask.
            if attention_mask is not None:
                attn_padding = torch.zeros(
                    (attention_mask.shape[0], pad_len),
                    dtype=attention_mask.dtype,
                    device=attention_mask.device,
                )
                attention_mask = torch.cat([attention_mask, attn_padding], dim=1)

            total_seq_len += pad_len

        # Local (start, end) of this rank's contiguous sequence chunk.
        slice_size = total_seq_len // self.local_world_size
        start = self.local_rank * slice_size
        end = start + slice_size

        input_ids = input_ids[:, start:end]
        if attention_mask is not None:
            attention_mask = attention_mask[:, start:end]

    # Propagate the parent's result; without `return` callers receive None.
    return super()._get_per_token_logps(
        model, input_ids, attention_mask, logits_to_keep
    )
|
||||||
|
|||||||
Reference in New Issue
Block a user