some refactoring

Salman Mohammadi
2025-02-19 17:35:35 +00:00
parent cf61b4aba7
commit 1a09d5e844
4 changed files with 31 additions and 35 deletions

View File

@@ -9,6 +9,7 @@ import logging
 from trl.trainer.grpo_trainer import RewardFunc
 from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
+from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
 LOG = logging.getLogger("axolotl")
@@ -30,31 +31,21 @@ class GRPOStrategy:
     @classmethod
     def set_training_args_kwargs(cls, cfg):
-        grpo_args_kwargs = {}
-        if cfg.trl and cfg.trl.use_vllm:
-            grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
-        if cfg.trl and cfg.trl.vllm_device:
-            grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
-        else:
-            grpo_args_kwargs["vllm_device"] = "auto"
-        if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
-            grpo_args_kwargs[
-                "vllm_gpu_memory_utilization"
-            ] = cfg.trl.vllm_gpu_memory_utilization
-        if cfg.trl and cfg.trl.vllm_max_model_len:
-            grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
-        if cfg.trl and cfg.trl.num_generations:
-            grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
-        if cfg.trl and cfg.trl.sync_ref_model:
-            grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
-        if cfg.trl and cfg.trl.ref_model_mixup_alpha:
-            grpo_args_kwargs[
-                "ref_model_mixup_alpha"
-            ] = cfg.trl.ref_model_mixup_alpha
-        if cfg.trl and cfg.trl.ref_model_sync_steps:
-            grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
-        grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
-        grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
+        training_kwargs = [
+            "use_vllm",
+            "vllm_device",
+            "vllm_gpu_memory_utilization",
+            "vllm_max_model_len",
+            "vllm_dtype",
+            "use_liger_loss",
+            "num_generations",
+            "log_completions",
+            "sync_ref_model",
+            "ref_model_mixup_alpha",
+            "ref_model_sync_steps",
+            "max_completion_length",
+        ]
+        grpo_args_kwargs = {k: cfg.trl[k] for k in training_kwargs if cfg.trl[k]}
         return grpo_args_kwargs

     @classmethod
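
The refactor collapses the per-field if chain into a key list plus a dict comprehension. Below is a standalone sketch of that pattern with a plain dict standing in for cfg.trl (values are illustrative); note that the `if cfg.trl[k]` guard filters on truthiness, so fields that are unset, False, or 0 are dropped, matching the old `if cfg.trl and cfg.trl.X` checks.

# Standalone sketch of the key-list mapping; `trl` stands in for cfg.trl.
trl = {
    "use_vllm": True,
    "vllm_device": None,       # unset -> filtered out
    "num_generations": 4,
    "log_completions": False,  # falsy -> also filtered out
}
training_kwargs = ["use_vllm", "vllm_device", "num_generations", "log_completions"]
grpo_args_kwargs = {k: trl[k] for k in training_kwargs if trl[k]}
assert grpo_args_kwargs == {"use_vllm": True, "num_generations": 4}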
@@ -71,9 +62,7 @@ class GRPOStrategy:
     def set_trainer_kwargs(cls, cfg):
         trainer_kwargs = {}
         if cfg.trl and cfg.trl.reward_processing_classes:
-            trainer_kwargs[
-                "reward_processing_classes"
-            ] = cfg.trl.reward_processing_classes
+            trainer_kwargs["reward_processing_classes"] = cfg.trl.reward_processing_classes
         return trainer_kwargs

     @classmethod

View File

@@ -13,3 +13,4 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
     """
     Axolotl GRPO Config for GRPO training
     """
+    use_liger_loss: bool = False
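
For context, a minimal sketch of how a flag like this rides on a dataclass-style training config; the `GRPOConfig` below is a stand-in for trl's real class, and the field name is the only part taken from this diff.

from dataclasses import dataclass

@dataclass
class GRPOConfig:  # stand-in for trl.GRPOConfig
    max_completion_length: int = 256

@dataclass
class AxolotlGRPOConfigSketch(GRPOConfig):
    use_liger_loss: bool = False

args = AxolotlGRPOConfigSketch(use_liger_loss=True)
assert args.use_liger_loss and args.max_completion_length == 256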

View File

@@ -16,6 +16,10 @@ from transformers.utils import is_liger_kernel_available
+if is_liger_kernel_available():
+    from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss
 from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
 from accelerate.utils import broadcast_object_list, gather_object
 from trl.trainer.utils import pad

 # mypy: ignore-errors
 class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
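
The guarded import is the usual optional-dependency pattern: probe for the package, then import only when it is installed. A sketch of the idea; the `else` fallback is an assumption, not part of this diff.

from transformers.utils import is_liger_kernel_available

if is_liger_kernel_available():
    from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss
else:
    LigerFusedLinearGRPOLoss = None  # name still resolves; use sites check the flag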
@@ -27,9 +31,9 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        # Import Liger loss if enabled
-        if self.args.use_liger_loss:
+        self.use_liger_loss = kwargs["args"].use_liger_loss
+        if self.use_liger_loss:
             if not is_liger_kernel_available():
                 raise ValueError(
                     "You set `use_liger_loss=True` but the liger kernel is not available. "
@@ -110,7 +114,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
         llm_model.load_weights(state_dict.items())

     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
-        if self.args.use_liger_loss:
+        if self.use_liger_loss:
             if return_outputs:
                 raise ValueError("The GRPOTrainer does not support returning outputs")
@@ -199,7 +203,9 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
                 )
             else:
                 with self.accelerator.unwrap_model(model).disable_adapter():
-                    ref_per_token_logps, ref_hidden_states = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep)
+                    ref_per_token_logps, ref_hidden_states = get_per_token_logps(
+                        model, prompt_completion_ids, num_logits_to_keep
+                    )

         # done in liger
         # Compute the KL divergence between the model and the reference model
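
The `disable_adapter()` context manager is the LoRA trick for reference logprobs: with the adapter switched off, the wrapped model behaves as the frozen base policy, so no second model copy is needed. A sketch using `get_per_token_logps` from this diff (its exact signature is assumed):

import torch

def reference_logps(accelerator, model, prompt_completion_ids, num_logits_to_keep):
    # adapter off => forward pass through the frozen base (reference) model
    with torch.no_grad():
        with accelerator.unwrap_model(model).disable_adapter():
            return get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep)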
@@ -261,7 +267,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
         # mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
         # std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
         # advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
+        # done in liger
         # x - x.detach() allows for preserving gradients from x
         # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
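
The commented lines spell out the group-relative advantage that the fused liger loss now computes internally: rewards are reshaped into groups of `num_generations` completions per prompt and normalized within each group. A runnable sketch of that arithmetic with toy numbers:

import torch

num_generations = 4  # completions sampled per prompt
rewards = torch.tensor([1.0, 0.0, 0.5, 0.5,   # prompt 1
                        2.0, 2.0, 0.0, 0.0])  # prompt 2
grouped = rewards.view(-1, num_generations)
mean_grouped_rewards = grouped.mean(dim=1).repeat_interleave(num_generations, dim=0)
std_grouped_rewards = grouped.std(dim=1).repeat_interleave(num_generations, dim=0)
advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)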
@@ -282,7 +288,6 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
         self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())

-        lm_head = model.get_output_embeddings()
         if self.ref_model is not None:

View File

@@ -33,3 +33,4 @@ class TRLConfig(BaseModel):
     sync_ref_model: Optional[bool] = False
     ref_model_mixup_alpha: Optional[float] = 0.9
     ref_model_sync_steps: Optional[int] = 64
+    use_liger_loss: Optional[bool] = False
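
TRLConfig is a pydantic model, so the new flag gets type validation and a default for free. A trimmed sketch with just two of the fields above:

from typing import Optional
from pydantic import BaseModel

class TRLConfigSketch(BaseModel):
    ref_model_sync_steps: Optional[int] = 64
    use_liger_loss: Optional[bool] = False

cfg = TRLConfigSketch(use_liger_loss=True)
assert cfg.use_liger_loss is True and cfg.ref_model_sync_steps == 64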