diff --git a/src/axolotl/core/trainers/grpo/__init__.py b/src/axolotl/core/trainers/grpo/__init__.py
index 5202cb09d..f892ff061 100644
--- a/src/axolotl/core/trainers/grpo/__init__.py
+++ b/src/axolotl/core/trainers/grpo/__init__.py
@@ -9,6 +9,7 @@ import logging
 
 from trl.trainer.grpo_trainer import RewardFunc
 
 from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
+from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
 
 LOG = logging.getLogger("axolotl")
@@ -30,31 +31,21 @@ class GRPOStrategy:
     @classmethod
     def set_training_args_kwargs(cls, cfg):
-        grpo_args_kwargs = {}
-        if cfg.trl and cfg.trl.use_vllm:
-            grpo_args_kwargs["use_vllm"] = cfg.trl.use_vllm
-        if cfg.trl and cfg.trl.vllm_device:
-            grpo_args_kwargs["vllm_device"] = cfg.trl.vllm_device
-        else:
-            grpo_args_kwargs["vllm_device"] = "auto"
-        if cfg.trl and cfg.trl.vllm_gpu_memory_utilization:
-            grpo_args_kwargs[
-                "vllm_gpu_memory_utilization"
-            ] = cfg.trl.vllm_gpu_memory_utilization
-        if cfg.trl and cfg.trl.vllm_max_model_len:
-            grpo_args_kwargs["vllm_max_model_len"] = cfg.trl.vllm_max_model_len
-        if cfg.trl and cfg.trl.num_generations:
-            grpo_args_kwargs["num_generations"] = cfg.trl.num_generations
-        if cfg.trl and cfg.trl.sync_ref_model:
-            grpo_args_kwargs["sync_ref_model"] = cfg.trl.sync_ref_model
-        if cfg.trl and cfg.trl.ref_model_mixup_alpha:
-            grpo_args_kwargs[
-                "ref_model_mixup_alpha"
-            ] = cfg.trl.ref_model_mixup_alpha
-        if cfg.trl and cfg.trl.ref_model_sync_steps:
-            grpo_args_kwargs["ref_model_sync_steps"] = cfg.trl.ref_model_sync_steps
-        grpo_args_kwargs["max_completion_length"] = cfg.trl.max_completion_length
-        grpo_args_kwargs["log_completions"] = cfg.trl.log_completions
+        training_kwargs = [
+            "use_vllm",
+            "vllm_device",
+            "vllm_gpu_memory_utilization",
+            "vllm_max_model_len",
+            "vllm_dtype",
+            "use_liger_loss",
+            "num_generations",
+            "log_completions",
+            "sync_ref_model",
+            "ref_model_mixup_alpha",
+            "ref_model_sync_steps",
+            "max_completion_length",
+        ]
+        grpo_args_kwargs = {k: cfg.trl[k] for k in training_kwargs if cfg.trl[k]}
 
         return grpo_args_kwargs
 
     @classmethod
@@ -71,9 +62,7 @@ class GRPOStrategy:
     def set_trainer_kwargs(cls, cfg):
         trainer_kwargs = {}
         if cfg.trl and cfg.trl.reward_processing_classes:
-            trainer_kwargs[
-                "reward_processing_classes"
-            ] = cfg.trl.reward_processing_classes
+            trainer_kwargs["reward_processing_classes"] = cfg.trl.reward_processing_classes
         return trainer_kwargs
 
     @classmethod
diff --git a/src/axolotl/core/trainers/grpo/args.py b/src/axolotl/core/trainers/grpo/args.py
index e14e6b0dc..f1a4c4680 100644
--- a/src/axolotl/core/trainers/grpo/args.py
+++ b/src/axolotl/core/trainers/grpo/args.py
@@ -13,3 +13,4 @@ class AxolotlGRPOConfig(AxolotlTrainingMixins, GRPOConfig):
     """
     Axolotl GRPO Config for GRPO training
     """
+    use_liger_loss: bool = False
diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py
index 43d8c892b..747021db3 100644
--- a/src/axolotl/core/trainers/grpo/trainer.py
+++ b/src/axolotl/core/trainers/grpo/trainer.py
@@ -16,6 +16,10 @@ from transformers.utils import is_liger_kernel_available
 if is_liger_kernel_available():
     from liger_kernel.chunked_loss.grpo_loss import LigerFusedLinearGRPOLoss
 
+from trl.data_utils import apply_chat_template, is_conversational, maybe_apply_chat_template
+from accelerate.utils import broadcast_object_list, gather_object
+from trl.trainer.utils import pad
+
 # mypy: ignore-errors
 
 class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
@@ -27,9 +31,9 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-
-        # Import Liger loss if enabled
-        if self.args.use_liger_loss:
+
+        self.use_liger_loss = kwargs["args"].use_liger_loss
+        if self.use_liger_loss:
             if not is_liger_kernel_available():
                 raise ValueError(
                     "You set `use_liger_loss=True` but the liger kernel is not available. "
@@ -110,7 +114,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
         llm_model.load_weights(state_dict.items())
 
     def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
-        if self.args.use_liger_loss:
+        if self.use_liger_loss:
             if return_outputs:
                 raise ValueError("The GRPOTrainer does not support returning outputs")
 
@@ -199,7 +203,9 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
                 )
             else:
                 with self.accelerator.unwrap_model(model).disable_adapter():
-                    ref_per_token_logps, ref_hidden_states = get_per_token_logps(model, prompt_completion_ids, num_logits_to_keep)
+                    ref_per_token_logps, ref_hidden_states = get_per_token_logps(
+                        model, prompt_completion_ids, num_logits_to_keep
+                    )
 
         # done in liger
         # Compute the KL divergence between the model and the reference model
@@ -261,7 +267,7 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
         # mean_grouped_rewards = mean_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
         # std_grouped_rewards = std_grouped_rewards.repeat_interleave(self.num_generations, dim=0)
         # advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + 1e-4)
-
+
         # done in liger
         # x - x.detach() allows for preserving gradients from x
         # per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
@@ -282,7 +288,6 @@ class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
 
         self._metrics["reward"].append(self.accelerator.gather_for_metrics(rewards).mean().item())
 
-
         lm_head = model.get_output_embeddings()
 
         if self.ref_model is not None:
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/trl.py b/src/axolotl/utils/config/models/input/v0_4_1/trl.py
index 6361bb249..99946cbbf 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/trl.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/trl.py
@@ -33,3 +33,4 @@ class TRLConfig(BaseModel):
     sync_ref_model: Optional[bool] = False
     ref_model_mixup_alpha: Optional[float] = 0.9
     ref_model_sync_steps: Optional[int] = 64
+    use_liger_loss: Optional[bool] = False
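
Note on the refactored `set_training_args_kwargs`: the dict comprehension keeps only keys whose configured value is truthy, so options that are unset (None) or explicitly false/zero are omitted and TRL's own defaults apply downstream. A minimal standalone sketch of that filtering behavior, using a hypothetical plain dict `trl_cfg` in place of `cfg.trl`:

    # `trl_cfg` is a hypothetical stand-in for cfg.trl, which also supports key access.
    trl_cfg = {"use_vllm": True, "num_generations": None, "use_liger_loss": False}
    training_kwargs = ["use_vllm", "num_generations", "use_liger_loss"]

    # Same comprehension as in the diff: falsy values are dropped entirely, so an
    # explicitly configured `use_liger_loss: false` behaves the same as leaving it unset.
    grpo_args_kwargs = {k: trl_cfg[k] for k in training_kwargs if trl_cfg[k]}
    assert grpo_args_kwargs == {"use_vllm": True}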
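
For context on the "done in liger" markers in `compute_loss`: when `use_liger_loss` is enabled, the group-relative advantage normalization and the `x - x.detach()` trick are delegated to `LigerFusedLinearGRPOLoss`, which is why those lines are commented out in the diff. A rough, self-contained sketch of the unfused computation the comments describe, with illustrative shapes and with the KL penalty and completion masking omitted:

    import torch

    num_generations = 4
    rewards = torch.randn(8)  # 2 prompts x 4 generations, one reward per completion
    per_token_logps = torch.randn(8, 16, requires_grad=True)  # (batch, completion_len)

    # Group-relative advantages: normalize each reward within its prompt group.
    mean_g = rewards.view(-1, num_generations).mean(dim=1).repeat_interleave(num_generations, dim=0)
    std_g = rewards.view(-1, num_generations).std(dim=1).repeat_interleave(num_generations, dim=0)
    advantages = (rewards - mean_g) / (std_g + 1e-4)

    # x - x.detach() evaluates to zero but keeps x's gradient path, so the loss
    # value is the negative mean advantage while the gradient is the usual
    # policy-gradient term.
    per_token_loss = torch.exp(per_token_logps - per_token_logps.detach()) * advantages.unsqueeze(1)
    loss = -per_token_loss.mean()
    loss.backward()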