From d495e41ba12002f6cc3c4634b6cacadf1cea3c35 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 2 Feb 2025 18:13:09 -0500
Subject: [PATCH] refactor dpo trainer into own module

---
 src/axolotl/core/trainer_builder.py       |   2 +-
 src/axolotl/core/trainers/base.py         | 114 +--------------------
 src/axolotl/core/trainers/dpo_trainer.py  | 125 +++++++++++++++++++++
 3 files changed, 128 insertions(+), 113 deletions(-)
 create mode 100644 src/axolotl/core/trainers/dpo_trainer.py

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 89480d775..aeee49f5a 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -39,7 +39,6 @@ from trl.trainer.utils import RewardDataCollatorWithPadding
 
 from axolotl.core.trainers.base import (
     AxolotlCPOTrainer,
-    AxolotlDPOTrainer,
     AxolotlKTOTrainer,
     AxolotlMambaTrainer,
     AxolotlORPOTrainer,
@@ -48,6 +47,7 @@ from axolotl.core.trainers.base import (
     AxolotlTrainer,
     ReLoRATrainer,
 )
+from axolotl.core.trainers.dpo_trainer import AxolotlDPOTrainer
 from axolotl.core.training_args import (
     AxolotlCPOConfig,
     AxolotlDPOConfig,
diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 44a6d54d7..ee2545b21 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -5,30 +5,21 @@ module for customized trainers
 from __future__ import annotations
 
 # pylint: disable=too-many-lines
-import gc
 import logging
 import os
 from collections import defaultdict
 from functools import wraps
-from typing import Any, Dict, Literal, Optional, Union
+from typing import Dict, Literal, Optional
 
 import torch
 from datasets import Dataset
 from peft.optimizers import create_loraplus_optimizer
-from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import Trainer
 from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
-from trl import (
-    CPOTrainer,
-    DPOTrainer,
-    KTOTrainer,
-    ORPOTrainer,
-    PRMTrainer,
-    RewardTrainer,
-)
+from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
 from trl.trainer.utils import pad_to_length
 
 from axolotl.monkeypatch.relora import ReLoRAScheduler
@@ -847,107 +838,6 @@ class ReLoRATrainer(AxolotlTrainer):
         return self.lr_scheduler
 
 
-class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
-    """
-    Extend the base DPOTrainer for axolotl helpers
-    """
-
-    tag_names = ["axolotl", "dpo"]
-
-    def __init__(self, *args, dataset_tags=None, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.dataset_tags = dataset_tags
-        self.optimizer = None
-        self.model_accepts_loss_kwargs = False
-
-    def create_optimizer(self):
-        if self.args.loraplus_lr_ratio is None:
-            return super().create_optimizer()
-
-        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
-        if self.optimizer is None:  # pylint: disable=access-member-before-definition
-            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
-                self.args,
-                opt_model,
-            )
-
-            loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
-            if loraplus_lr_ratio:
-                print("Using lora+")
-            loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
-            self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
-                opt_model,
-                optimizer_cls,
-                loraplus_lr_ratio=loraplus_lr_ratio,
-                loraplus_lr_embedding=loraplus_lr_embedding,
-                **optimizer_kwargs,
-            )
-
-        if is_sagemaker_mp_enabled():
-            self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
-                self.optimizer
-            )
-
-        return self.optimizer
-
-    @wraps(DPOTrainer.push_to_hub)
-    def push_to_hub(self, *args, **kwargs) -> str:
-        """
-        Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
-        model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
-        """
-        kwargs = _sanitize_kwargs_for_ds_tagging(
-            dataset_tags=self.dataset_tags, kwargs=kwargs
-        )
-        kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
-
-        return super().push_to_hub(*args, **kwargs)
-
-    @staticmethod
-    def tokenize_row(
-        features,
-        processing_class,
-        max_prompt_length,
-        max_completion_length,
-        add_special_tokens,
-    ) -> Dict:
-        res = DPOTrainer.tokenize_row(
-            features,
-            processing_class,
-            max_prompt_length,
-            max_completion_length,
-            add_special_tokens,
-        )
-        # fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
-        if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
-            for key in res.keys():
-                res[key] = res[key][1:]
-
-        if processing_class.bos_token and processing_class.bos_token_id is not None:
-            # dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
-            if res["chosen_input_ids"][0] == processing_class.bos_token_id:
-                res["chosen_input_ids"] = res["chosen_input_ids"][1:]
-                res["chosen_labels"] = res["chosen_labels"][1:]
-                res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
-            if res["rejected_input_ids"][0] == processing_class.bos_token_id:
-                res["rejected_input_ids"] = res["rejected_input_ids"][1:]
-                res["rejected_labels"] = res["rejected_labels"][1:]
-                res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
-
-        return res
-
-    def training_step(
-        self,
-        model: nn.Module,
-        inputs: Dict[str, Union[torch.Tensor, Any]],
-        num_items_in_batch=None,
-    ) -> torch.Tensor:
-        loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
-        gc.collect()
-        torch.cuda.empty_cache()
-        return loss
-
-
 class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
     """
     Extend the base ORPOTrainer for axolotl helpers
diff --git a/src/axolotl/core/trainers/dpo_trainer.py b/src/axolotl/core/trainers/dpo_trainer.py
new file mode 100644
index 000000000..a1de4cc82
--- /dev/null
+++ b/src/axolotl/core/trainers/dpo_trainer.py
@@ -0,0 +1,125 @@
+"""
+DPO trainer for axolotl
+"""
+import gc
+from functools import wraps
+from typing import Any, Dict, Union
+
+import torch
+from peft.optimizers import create_loraplus_optimizer
+from torch import nn
+from transformers import Trainer
+from transformers.utils import is_sagemaker_mp_enabled
+from trl import DPOTrainer
+
+from axolotl.core.trainers.base import (
+    SchedulerMixin,
+    _sanitize_kwargs_for_ds_tagging,
+    _sanitize_kwargs_for_tagging,
+)
+
+if is_sagemaker_mp_enabled():
+    import smdistributed.modelparallel.torch as smp
+
+
+class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
+    """
+    Extend the base DPOTrainer for axolotl helpers
+    """
+
+    tag_names = ["axolotl", "dpo"]
+
+    def __init__(self, *args, dataset_tags=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.dataset_tags = dataset_tags
+        self.optimizer = None
+        self.model_accepts_loss_kwargs = False
+
+    def create_optimizer(self):
+        # pylint: disable=duplicate-code
+        if self.args.loraplus_lr_ratio is None:
+            return super().create_optimizer()
+
+        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+        if self.optimizer is None:  # pylint: disable=access-member-before-definition
+            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
+                self.args,
+                opt_model,
+            )
+
+            loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
+            if loraplus_lr_ratio:
+                print("Using lora+")
+            loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
+            # pylint: disable=duplicate-code
+            self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
+                opt_model,
+                optimizer_cls,
+                loraplus_lr_ratio=loraplus_lr_ratio,
+                loraplus_lr_embedding=loraplus_lr_embedding,
+                **optimizer_kwargs,
+            )
+
+        if is_sagemaker_mp_enabled():
+            self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
+                self.optimizer
+            )
+
+        return self.optimizer
+
+    @wraps(DPOTrainer.push_to_hub)
+    def push_to_hub(self, *args, **kwargs) -> str:
+        """
+        Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
+        model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
+        """
+        kwargs = _sanitize_kwargs_for_ds_tagging(
+            dataset_tags=self.dataset_tags, kwargs=kwargs
+        )
+        kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
+
+        return super().push_to_hub(*args, **kwargs)
+
+    @staticmethod
+    def tokenize_row(
+        features,
+        processing_class,
+        max_prompt_length,
+        max_completion_length,
+        add_special_tokens,
+    ) -> Dict:
+        res = DPOTrainer.tokenize_row(
+            features,
+            processing_class,
+            max_prompt_length,
+            max_completion_length,
+            add_special_tokens,
+        )
+        # fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
+        if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
+            for key in res.keys():
+                res[key] = res[key][1:]
+
+        if processing_class.bos_token and processing_class.bos_token_id is not None:
+            # dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
+            if res["chosen_input_ids"][0] == processing_class.bos_token_id:
+                res["chosen_input_ids"] = res["chosen_input_ids"][1:]
+                res["chosen_labels"] = res["chosen_labels"][1:]
+                res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
+            if res["rejected_input_ids"][0] == processing_class.bos_token_id:
+                res["rejected_input_ids"] = res["rejected_input_ids"][1:]
+                res["rejected_labels"] = res["rejected_labels"][1:]
+                res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
+
+        return res
+
+    def training_step(
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        num_items_in_batch=None,
+    ) -> torch.Tensor:
+        loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
+        gc.collect()
+        torch.cuda.empty_cache()
+        return loss
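
For downstream code, the only visible change is the import path. A minimal sketch of the migration; the module and class names are taken from the diff above, while the try/except fallback is an illustrative compatibility pattern, not part of the patch:

    # New location (this patch):
    from axolotl.core.trainers.dpo_trainer import AxolotlDPOTrainer

    # Hypothetical fallback for code that must also run against
    # pre-refactor axolotl, where the class lived in the base module:
    try:
        from axolotl.core.trainers.dpo_trainer import AxolotlDPOTrainer
    except ImportError:
        from axolotl.core.trainers.base import AxolotlDPOTrainer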
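The tokenize_row override carried over in the new module guards against the upstream trainer prepending a second BOS token to the chosen/rejected sequences. A self-contained sketch of that stripping logic on plain lists (toy token ids, no tokenizer or TRL required; the helper name and values are illustrative):

    BOS_ID = 1  # toy stand-in for processing_class.bos_token_id

    def strip_duplicate_bos(ids, labels, mask, bos_id=BOS_ID):
        """Drop one leading BOS that may have been incorrectly prepended."""
        if ids and ids[0] == bos_id:
            return ids[1:], labels[1:], mask[1:]
        return ids, labels, mask

    # Doubled BOS at the front of a toy "chosen" sequence:
    chosen = ([1, 1, 42, 7], [-100, -100, 42, 7], [1, 1, 1, 1])
    print(strip_duplicate_bos(*chosen))
    # -> ([1, 42, 7], [-100, 42, 7], [1, 1, 1])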