diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 44f8c5d2b..7527069ee 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1048,6 +1048,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
         if self.cfg.rpo_alpha is not None:
             training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
 
+        training_args_kwargs["sequence_parallel_degree"] = (
+            self.cfg.sequence_parallel_degree
+        )
+
         training_args_cls = None
         blocklist_args_kwargs = []
         if self.cfg.rl == "simpo":
diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py
index 25aafa6a7..c7b6715d2 100644
--- a/src/axolotl/core/trainers/grpo/trainer.py
+++ b/src/axolotl/core/trainers/grpo/trainer.py
@@ -1,23 +1,25 @@
-"""
-Axolotl GRPO trainer
-"""
+"""Axolotl GRPO trainer"""
 
 from contextlib import nullcontext
 
+import torch
+import torch.distributed as dist
 from accelerate.utils import is_deepspeed_available, is_peft_model
 from trl import GRPOTrainer
 from trl.extras.profiling import profiling_decorator
 
 from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
+from axolotl.monkeypatch.attention.ring_attn import (
+    get_ring_attn_group,
+    update_ring_attn_params,
+)
 
 if is_deepspeed_available():
     import deepspeed
 
 
 class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
-    """
-    Extend the base GRPOTrainer for axolotl helpers
-    """
+    """Extend the base GRPOTrainer for axolotl helpers"""
 
     _tag_names = ["trl", "grpo", "axolotl"]
 
@@ -67,3 +69,50 @@ class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
             # Reset cache on main process
             if self.accelerator.is_main_process:
                 self.vllm_client.reset_prefix_cache()
+
+    # Get the per-token log probabilities for the completions from the model and the reference model
+    def _get_per_token_logps(self, model, input_ids, attention_mask, logits_to_keep):
+        if self.args.sequence_parallel_degree > 1:
+            sp_group = get_ring_attn_group()
+            self.local_rank = dist.get_rank(group=sp_group)
+            self.local_world_size = dist.get_world_size(group=sp_group)
+
+            # Pad the sequence so it divides evenly across the sequence parallel ranks
+            total_seq_len = input_ids.shape[1]
+            remainder = total_seq_len % self.local_world_size
+            if remainder != 0:
+                pad_len = self.local_world_size - remainder
+                pad_token_id = self.processing_class.pad_token_id or 0
+                padding = torch.full(
+                    (input_ids.shape[0], pad_len),
+                    pad_token_id,
+                    dtype=input_ids.dtype,
+                    device=input_ids.device,
+                )
+                input_ids = torch.cat([input_ids, padding], dim=1)
+
+                # Also pad the attention mask if it exists
+                if attention_mask is not None:
+                    attn_padding = torch.zeros(
+                        (attention_mask.shape[0], pad_len),
+                        dtype=attention_mask.dtype,
+                        device=attention_mask.device,
+                    )
+                    attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
+
+                # Update total_seq_len after padding
+                total_seq_len += pad_len
+
+            # Get the local (start, end) slice for this sequence parallel rank
+            slice_size = total_seq_len // self.local_world_size
+            start = self.local_rank * slice_size
+            end = start + slice_size
+
+            # Slice data for sequence parallel processing
+            input_ids = input_ids[:, start:end]
+            if attention_mask is not None:
+                attention_mask = attention_mask[:, start:end]
+
+        return super()._get_per_token_logps(
+            model, input_ids, attention_mask, logits_to_keep
+        )
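
For reviewers: the pad-and-slice arithmetic in `_get_per_token_logps` can be exercised without a distributed run. The sketch below is not part of the patch; it mirrors that logic with plain tensors, where `rank`/`world_size` stand in for `dist.get_rank(group=sp_group)`/`dist.get_world_size(group=sp_group)` and the helper name `pad_and_slice` is purely illustrative.

```python
# Standalone sketch of the pad-and-slice step above; not axolotl code.
import torch

def pad_and_slice(input_ids, attention_mask, rank, world_size, pad_token_id=0):
    """Pad the sequence dim to a multiple of world_size, then take this rank's shard."""
    total_seq_len = input_ids.shape[1]
    remainder = total_seq_len % world_size
    if remainder != 0:
        pad_len = world_size - remainder
        padding = torch.full(
            (input_ids.shape[0], pad_len), pad_token_id, dtype=input_ids.dtype
        )
        input_ids = torch.cat([input_ids, padding], dim=1)
        if attention_mask is not None:
            attn_padding = torch.zeros(
                (attention_mask.shape[0], pad_len), dtype=attention_mask.dtype
            )
            attention_mask = torch.cat([attention_mask, attn_padding], dim=1)
        total_seq_len += pad_len

    # Every rank gets an equal contiguous slice of the padded sequence
    slice_size = total_seq_len // world_size
    start, end = rank * slice_size, (rank + 1) * slice_size
    mask_shard = attention_mask[:, start:end] if attention_mask is not None else None
    return input_ids[:, start:end], mask_shard

# A 10-token batch on a 4-way ring pads to 12 tokens and shards into 3-token slices.
ids = torch.arange(1, 11).unsqueeze(0)  # tokens 1..10, so pad id 0 is visible
mask = torch.ones_like(ids)
for rank in range(4):
    id_shard, mask_shard = pad_and_slice(ids, mask, rank, world_size=4)
    print(rank, id_shard.tolist(), mask_shard.tolist())
# rank 3 prints [[10, 0, 0]] with mask [[1, 0, 0]]
```

The last rank's shard ends in pad tokens, which is why the patch pads the attention mask with zeros rather than ones: those positions must be masked out of the per-token log-probability computation.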