some refactoring
This commit is contained in:
@@ -9,6 +9,7 @@ import logging
|
||||
from trl.trainer.grpo_trainer import RewardFunc
|
||||
|
||||
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
@@ -30,31 +31,21 @@ class GRPOStrategy:
|
||||
|
||||
@classmethod
|
||||
def set_training_args_kwargs(cls, cfg):
|
||||
grpo_args_kwargs = {}
|
||||
if cfg.trl and cfg.trl.use_vllm:
|
||||
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
|
||||
if cfg.trl and cfg.trl.vllm_device:
|
||||
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
|
||||
else:
|
||||
grpo_args_kwargs["vllm_device"] = "auto"
|
||||
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
|
||||
grpo_args_kwargs[
|
||||
"vllm_gpu_memory_utilization"
|
||||
] = cfg.trl.vllm_gpu_memory_utilization
|
||||
if cfg.trl and cfg.trl.vllm_max_model_len:
|
||||
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
|
||||
if cfg.trl and cfg.trl.num_generations:
|
||||
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
|
||||
if cfg.trl and cfg.trl.sync_ref_model:
|
||||
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
|
||||
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
|
||||
grpo_args_kwargs[
|
||||
"ref_model_mixup_alpha"
|
||||
] = cfg.trl.ref_model_mixup_alpha
|
||||
if cfg.trl and cfg.trl.ref_model_sync_steps:
|
||||
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
|
||||
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
|
||||
grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
|
||||
training_kwargs = [
|
||||
"use_vllm",
|
||||
"vllm_device",
|
||||
"vllm_gpu_memory_utilization",
|
||||
"vllm_max_model_len",
|
||||
"vllm_dtype",
|
||||
"use_liger_loss",
|
||||
"num_generations",
|
||||
"log_completions",
|
||||
"sync_ref_model",
|
||||
"ref_model_mixup_alpha",
|
||||
"ref_model_sync_steps",
|
||||
"max_completion_length",
|
||||
]
|
||||
grpo_args_kwargs = {k: cfg.trl[k] for k in training_kwargs if cfg.trl[k]}
|
||||
return grpo_args_kwargs
|
||||
|
||||
@classmethod
|
||||
@@ -71,9 +62,7 @@ class GRPOStrategy:
|
||||
def set_trainer_kwargs(cls, cfg):
|
||||
trainer_kwargs = {}
|
||||
if cfg.trl and cfg.trl.reward_processing_classes:
|
||||
trainer_kwargs[
|
||||
"reward_processing_classes"
|
||||
] = cfg.trl.reward_processing_classes
|
||||
trainer_kwargs["reward_processing_classes"] = cfg.trl.reward_processing_classes
|
||||
return trainer_kwargs
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -13,3 +13,4 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
||||
"""
|
||||
Axolotl GRPO Config for GRPO training
|
||||
"""
|
||||
use_liger_loss: bool = False
|
||||
|
||||
@@ -16,6 +16,10 @@ from transformers.utils import is_liger_kernel_available
|
||||
if is_liger_kernel_available():
|
||||
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss
|
||||
|
||||
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
|
||||
from accelerate.utils import broadcast_object_list, gather_object
|
||||
from trl.trainer.utils import pad
|
||||
|
||||
|
||||
# mypy: ignore-errors
|
||||
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||
@@ -27,9 +31,9 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Import Liger loss if enabled
|
||||
if self.args.use_liger_loss:
|
||||
|
||||
self.use_liger_loss = kwargs["args"].use_liger_loss
|
||||
if self.use_liger_loss:
|
||||
if not is_liger_kernel_available():
|
||||
raise ValueError(
|
||||
"You set `use_liger_loss=True` but the liger kernel is not available. "
|
||||
@@ -110,7 +114,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||
llm_model.load_weights(state_dict.items())
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
||||
if self.args.use_liger_loss:
|
||||
if self.use_liger_loss:
|
||||
if return_outputs:
|
||||
raise ValueError("The GRPOTrainer does not support returning outputs")
|
||||
|
||||
@@ -199,7 +203,9 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||
)
|
||||
else:
|
||||
with self.accelerator.unwrap_model(model).disable_adapter():
|
||||
ref_per_token_logps, ref_hidden_states = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep)
|
||||
ref_per_token_logps, ref_hidden_states = get_per_token_logps(
|
||||
model, prompt_completion_ids, num_logits_to_keep
|
||||
)
|
||||
|
||||
# done in liger
|
||||
# Compute the KL divergence between the model and the reference model
|
||||
@@ -261,7 +267,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||
# mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
||||
# std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
||||
# advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
|
||||
|
||||
|
||||
# done in liger
|
||||
# x - x.detach() allows for preserving gradients from x
|
||||
# per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
|
||||
@@ -282,7 +288,6 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||
|
||||
self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
|
||||
|
||||
|
||||
lm_head = model.get_output_embeddings()
|
||||
|
||||
if self.ref_model is not None:
|
||||
|
||||
@@ -33,3 +33,4 @@ class TRLConfig(BaseModel):
|
||||
sync_ref_model: Optional[bool] = False
|
||||
ref_model_mixup_alpha: Optional[float] = 0.9
|
||||
ref_model_sync_steps: Optional[int] = 64
|
||||
use_liger_loss: Optional[bool] = False
|
||||
|
||||
Reference in New Issue
Block a user