some refactoring
This commit is contained in:
@@ -9,6 +9,7 @@ import logging
|
|||||||
from trl.trainer.grpo_trainer import RewardFunc
|
from trl.trainer.grpo_trainer import RewardFunc
|
||||||
|
|
||||||
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||||
|
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
@@ -30,31 +31,21 @@ class GRPOStrategy:
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def set_training_args_kwargs(cls, cfg):
|
def set_training_args_kwargs(cls, cfg):
|
||||||
grpo_args_kwargs = {}
|
training_kwargs = [
|
||||||
if cfg.trl and cfg.trl.use_vllm:
|
"use_vllm",
|
||||||
grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
|
"vllm_device",
|
||||||
if cfg.trl and cfg.trl.vllm_device:
|
"vllm_gpu_memory_utilization",
|
||||||
grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
|
"vllm_max_model_len",
|
||||||
else:
|
"vllm_dtype",
|
||||||
grpo_args_kwargs["vllm_device"] = "auto"
|
"use_liger_loss",
|
||||||
if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
|
"num_generations",
|
||||||
grpo_args_kwargs[
|
"log_completions",
|
||||||
"vllm_gpu_memory_utilization"
|
"sync_ref_model",
|
||||||
] = cfg.trl.vllm_gpu_memory_utilization
|
"ref_model_mixup_alpha",
|
||||||
if cfg.trl and cfg.trl.vllm_max_model_len:
|
"ref_model_sync_steps",
|
||||||
grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
|
"max_completion_length",
|
||||||
if cfg.trl and cfg.trl.num_generations:
|
]
|
||||||
grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
|
grpo_args_kwargs = {k: cfg.trl[k] for k in training_kwargs if cfg.trl[k]}
|
||||||
if cfg.trl and cfg.trl.sync_ref_model:
|
|
||||||
grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
|
|
||||||
if cfg.trl and cfg.trl.ref_model_mixup_alpha:
|
|
||||||
grpo_args_kwargs[
|
|
||||||
"ref_model_mixup_alpha"
|
|
||||||
] = cfg.trl.ref_model_mixup_alpha
|
|
||||||
if cfg.trl and cfg.trl.ref_model_sync_steps:
|
|
||||||
grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
|
|
||||||
grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
|
|
||||||
grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
|
|
||||||
return grpo_args_kwargs
|
return grpo_args_kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -71,9 +62,7 @@ class GRPOStrategy:
|
|||||||
def set_trainer_kwargs(cls, cfg):
|
def set_trainer_kwargs(cls, cfg):
|
||||||
trainer_kwargs = {}
|
trainer_kwargs = {}
|
||||||
if cfg.trl and cfg.trl.reward_processing_classes:
|
if cfg.trl and cfg.trl.reward_processing_classes:
|
||||||
trainer_kwargs[
|
trainer_kwargs["reward_processing_classes"] = cfg.trl.reward_processing_classes
|
||||||
"reward_processing_classes"
|
|
||||||
] = cfg.trl.reward_processing_classes
|
|
||||||
return trainer_kwargs
|
return trainer_kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -13,3 +13,4 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
|
|||||||
"""
|
"""
|
||||||
Axolotl GRPO Config for GRPO training
|
Axolotl GRPO Config for GRPO training
|
||||||
"""
|
"""
|
||||||
|
use_liger_loss: bool = False
|
||||||
|
|||||||
@@ -16,6 +16,10 @@ from transformers.utils import is_liger_kernel_available
|
|||||||
if is_liger_kernel_available():
|
if is_liger_kernel_available():
|
||||||
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss
|
from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss
|
||||||
|
|
||||||
|
from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
|
||||||
|
from accelerate.utils import broadcast_object_list, gather_object
|
||||||
|
from trl.trainer.utils import pad
|
||||||
|
|
||||||
|
|
||||||
# mypy: ignore-errors
|
# mypy: ignore-errors
|
||||||
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
||||||
@@ -27,9 +31,9 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
# Import Liger loss if enabled
|
self.use_liger_loss = kwargs["args"].use_liger_loss
|
||||||
if self.args.use_liger_loss:
|
if self.use_liger_loss:
|
||||||
if not is_liger_kernel_available():
|
if not is_liger_kernel_available():
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"You set `use_liger_loss=True` but the liger kernel is not available. "
|
"You set `use_liger_loss=True` but the liger kernel is not available. "
|
||||||
@@ -110,7 +114,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
|||||||
llm_model.load_weights(state_dict.items())
|
llm_model.load_weights(state_dict.items())
|
||||||
|
|
||||||
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
|
||||||
if self.args.use_liger_loss:
|
if self.use_liger_loss:
|
||||||
if return_outputs:
|
if return_outputs:
|
||||||
raise ValueError("The GRPOTrainer does not support returning outputs")
|
raise ValueError("The GRPOTrainer does not support returning outputs")
|
||||||
|
|
||||||
@@ -199,7 +203,9 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
with self.accelerator.unwrap_model(model).disable_adapter():
|
with self.accelerator.unwrap_model(model).disable_adapter():
|
||||||
ref_per_token_logps, ref_hidden_states = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep)
|
ref_per_token_logps, ref_hidden_states = get_per_token_logps(
|
||||||
|
model, prompt_completion_ids, num_logits_to_keep
|
||||||
|
)
|
||||||
|
|
||||||
# done in liger
|
# done in liger
|
||||||
# Compute the KL divergence between the model and the reference model
|
# Compute the KL divergence between the model and the reference model
|
||||||
@@ -261,7 +267,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
|||||||
# mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
# mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
||||||
# std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
# std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
|
||||||
# advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
|
# advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
|
||||||
|
|
||||||
# done in liger
|
# done in liger
|
||||||
# x - x.detach() allows for preserving gradients from x
|
# x - x.detach() allows for preserving gradients from x
|
||||||
# per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
|
# per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
|
||||||
@@ -282,7 +288,6 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
|||||||
|
|
||||||
self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
|
self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
|
||||||
|
|
||||||
|
|
||||||
lm_head = model.get_output_embeddings()
|
lm_head = model.get_output_embeddings()
|
||||||
|
|
||||||
if self.ref_model is not None:
|
if self.ref_model is not None:
|
||||||
|
|||||||
@@ -33,3 +33,4 @@ class TRLConfig(BaseModel):
|
|||||||
sync_ref_model: Optional[bool] = False
|
sync_ref_model: Optional[bool] = False
|
||||||
ref_model_mixup_alpha: Optional[float] = 0.9
|
ref_model_mixup_alpha: Optional[float] = 0.9
|
||||||
ref_model_sync_steps: Optional[int] = 64
|
ref_model_sync_steps: Optional[int] = 64
|
||||||
|
use_liger_loss: Optional[bool] = False
|
||||||
|
|||||||
Reference in New Issue
Block a user