some refactoring

This commit is contained in:
Salman Mohammadi
2025-02-19 17:35:35 +00:00
parent cf61b4aba7
commit 1a09d5e844
4 changed files with 31 additions and 35 deletions

View File

@@ -9,6 +9,7 @@ import logging
from trl.trainer.grpo_trainer import RewardFunc from trl.trainer.grpo_trainer import RewardFunc
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@@ -30,31 +31,21 @@ class GRPOStrategy:
@classmethod @classmethod
def set_training_args_kwargs(cls, cfg): def set_training_args_kwargs(cls, cfg):
grpo_args_kwargs = {} training_kwargs = [
if cfg.trl and cfg.trl.use_vllm: "use_vllm",
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm "vllm_device",
if cfg.trl and cfg.trl.vllm_device: "vllm_gpu_memory_utilization",
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device "vllm_max_model_len",
else: "vllm_dtype",
grpo_args_kwargs["vllm_device"] = "auto" "use_liger_loss",
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization: "num_generations",
grpo_args_kwargs[ "log_completions",
"vllm_gpu_memory_utilization" "sync_ref_model",
] = cfg.trl.vllm_gpu_memory_utilization "ref_model_mixup_alpha",
if cfg.trl and cfg.trl.vllm_max_model_len: "ref_model_sync_steps",
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len "max_completion_length",
if cfg.trl and cfg.trl.num_generations: ]
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations grpo_args_kwargs = {k: cfg.trl[k] for k in training_kwargs if cfg.trl[k]}
if cfg.trl and cfg.trl.sync_ref_model:
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
grpo_args_kwargs[
"ref_model_mixup_alpha"
] = cfg.trl.ref_model_mixup_alpha
if cfg.trl and cfg.trl.ref_model_sync_steps:
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
return grpo_args_kwargs return grpo_args_kwargs
@classmethod @classmethod
@@ -71,9 +62,7 @@ class GRPOStrategy:
def set_trainer_kwargs(cls, cfg): def set_trainer_kwargs(cls, cfg):
trainer_kwargs = {} trainer_kwargs = {}
if cfg.trl and cfg.trl.reward_processing_classes: if cfg.trl and cfg.trl.reward_processing_classes:
trainer_kwargs[ trainer_kwargs["reward_processing_classes"] = cfg.trl.reward_processing_classes
"reward_processing_classes"
] = cfg.trl.reward_processing_classes
return trainer_kwargs return trainer_kwargs
@classmethod @classmethod

View File

@@ -13,3 +13,4 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
""" """
Axolotl GRPO Config for GRPO training Axolotl GRPO Config for GRPO training
""" """
use_liger_loss: bool = False

View File

@@ -16,6 +16,10 @@ from transformers.utils import is_liger_kernel_available
if is_liger_kernel_available(): if is_liger_kernel_available():
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
from accelerate.utils import broadcast_object_list, gather_object
from trl.trainer.utils import pad
# mypy: ignore-errors # mypy: ignore-errors
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer): class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
@@ -27,9 +31,9 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
# Import Liger loss if enabled self.use_liger_loss = kwargs["args"].use_liger_loss
if self.args.use_liger_loss: if self.use_liger_loss:
if not is_liger_kernel_available(): if not is_liger_kernel_available():
raise ValueError( raise ValueError(
"You set `use_liger_loss=True` but the liger kernel is not available. " "You set `use_liger_loss=True` but the liger kernel is not available. "
@@ -110,7 +114,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
llm_model.load_weights(state_dict.items()) llm_model.load_weights(state_dict.items())
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None): def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
if self.args.use_liger_loss: if self.use_liger_loss:
if return_outputs: if return_outputs:
raise ValueError("The GRPOTrainer does not support returning outputs") raise ValueError("The GRPOTrainer does not support returning outputs")
@@ -199,7 +203,9 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
) )
else: else:
with self.accelerator.unwrap_model(model).disable_adapter(): with self.accelerator.unwrap_model(model).disable_adapter():
ref_per_token_logps, ref_hidden_states = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep) ref_per_token_logps, ref_hidden_states = get_per_token_logps(
model, prompt_completion_ids, num_logits_to_keep
)
# done in liger # done in liger
# Compute the KL divergence between the model and the reference model # Compute the KL divergence between the model and the reference model
@@ -261,7 +267,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
# mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0) # mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
# std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0) # std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
# advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4) # advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
# done in liger # done in liger
# x - x.detach() allows for preserving gradients from x # x - x.detach() allows for preserving gradients from x
# per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1) # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
@@ -282,7 +288,6 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item()) self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
lm_head = model.get_output_embeddings() lm_head = model.get_output_embeddings()
if self.ref_model is not None: if self.ref_model is not None:

View File

@@ -33,3 +33,4 @@ class TRLConfig(BaseModel):
sync_ref_model: Optional[bool] = False sync_ref_model: Optional[bool] = False
ref_model_mixup_alpha: Optional[float] = 0.9 ref_model_mixup_alpha: Optional[float] = 0.9
ref_model_sync_steps: Optional[int] = 64 ref_model_sync_steps: Optional[int] = 64
use_liger_loss: Optional[bool] = False