From 17ba9dcfdbe0b40e52ecee43ccb0bb9a5700bf50 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Mon, 16 Dec 2024 14:16:36 -0500
Subject: [PATCH] refactor trainer to prevent circular dependencies later; fix
 loader default

---
 docs/rlhf.qmd                               |    2 +-
 src/axolotl/core/trainer_builder.py         | 1177 +------------------
 src/axolotl/core/trainers/base.py           |  933 +++++++++++++++
 src/axolotl/core/training_args.py           |  220 ++++
 src/axolotl/prompt_strategies/base.py       |    2 +
 src/axolotl/prompt_strategies/dpo/chatml.py |   29 +-
 src/axolotl/prompt_strategies/dpo/llama3.py |   30 +-
 7 files changed, 1232 insertions(+), 1161 deletions(-)
 create mode 100644 src/axolotl/core/trainers/base.py
 create mode 100644 src/axolotl/core/training_args.py

diff --git a/docs/rlhf.qmd b/docs/rlhf.qmd
index 48701c87a..8f52876c1 100644
--- a/docs/rlhf.qmd
+++ b/docs/rlhf.qmd
@@ -29,7 +29,7 @@ datasets:
     type: chatml.intel
   - path: argilla/ultrafeedback-binarized-preferences
     split: train
-    type: chatml.argilla
+    type: chatml
 ```

 #### IPO
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 176ce4174..dfc750c44 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -4,54 +4,43 @@ Builder for the training args and trainer
 """

 import abc
-import gc
 import importlib
 import importlib.util
 import inspect
 import logging
 import math
-import os
 import sys
 from abc import abstractmethod
-from collections import defaultdict
-from dataclasses import dataclass, field
-from functools import wraps
 from pathlib import Path
-from typing import Any, Dict, List, Literal, Optional, Type, Union
+from typing import List, Type, Union

 import torch
 import transformers
-from datasets import Dataset
-from peft.optimizers import create_loraplus_optimizer
-from torch import nn
-from torch.optim.lr_scheduler import OneCycleLR
-from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
-from transformers import (
-    DataCollatorWithFlattening,
-    EarlyStoppingCallback,
-    Trainer,
-    TrainerCallback,
-    TrainingArguments,
-)
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
-from transformers.utils import is_sagemaker_mp_enabled
-from trl import (
-    CPOConfig,
-    CPOTrainer,
-    DPOConfig,
-    DPOTrainer,
-    KTOConfig,
-    KTOTrainer,
-    ORPOConfig,
-    ORPOTrainer,
-    RewardConfig,
-    RewardTrainer,
-)
-from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
+from transformers import DataCollatorWithFlattening, EarlyStoppingCallback, TrainerCallback
+from trl.trainer.utils import RewardDataCollatorWithPadding

+from axolotl.core.trainers.base import (
+    AxolotlCPOTrainer,
+    AxolotlDPOTrainer,
+    AxolotlKTOTrainer,
+    AxolotlMambaTrainer,
+    AxolotlORPOTrainer,
+    AxolotlRewardTrainer,
+    AxolotlTrainer,
+    ReLoRATrainer,
+)
+from axolotl.core.training_args import (
+    AxolotlCPOConfig,
+    AxolotlDPOConfig,
+    AxolotlKTOConfig,
+    AxolotlORPOConfig,
+    AxolotlRewardConfig,
+    AxolotlTrainingArguments,
+)
 from axolotl.integrations.base import PluginManager
+from axolotl.integrations.liger.trainer.dpo_trainer import AxolotlLigerDPOTrainer
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
-from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
+from axolotl.monkeypatch.relora import ReLoRACallback
 from axolotl.utils import is_comet_available, is_mlflow_available
 from axolotl.utils.callbacks import (
     EvalFirstStepCallback,
@@ -76,15 +65,6 @@ from axolotl.utils.collators import (
 )
 from axolotl.utils.collators.mm_chat import 
MultiModalChatDataCollator from axolotl.utils.models import ensure_dtype -from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths -from axolotl.utils.schedulers import ( - get_cosine_schedule_with_min_lr, - get_cosine_schedule_with_quadratic_warmup, - get_cosine_schedule_with_warmup_decay_constant, -) - -if is_sagemaker_mp_enabled(): - import smdistributed.modelparallel.torch as smp try: import torch._dynamo # pylint: disable=ungrouped-imports @@ -94,1112 +74,6 @@ except ImportError: LOG = logging.getLogger("axolotl.core.trainer_builder") -def _sanitize_kwargs_for_tagging(tag_names, kwargs=None): - if isinstance(tag_names, str): - tag_names = [tag_names] - - if kwargs is not None: - if "tags" not in kwargs: - kwargs["tags"] = tag_names - elif "tags" in kwargs and isinstance(kwargs["tags"], list): - kwargs["tags"].extend(tag_names) - elif "tags" in kwargs and isinstance(kwargs["tags"], str): - tag_names.append(kwargs["tags"]) - kwargs["tags"] = tag_names - - return kwargs - - -def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None): - if isinstance(dataset_tags, str): - dataset_tags = [dataset_tags] - - if (dataset_tags is not None) and (kwargs is not None): - if "dataset_tags" not in kwargs: - kwargs["dataset_tags"] = dataset_tags - elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list): - kwargs["dataset_tags"].extend(dataset_tags) - elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str): - dataset_tags.append(kwargs["dataset_tags"]) - kwargs["dataset_tags"] = dataset_tags - - return kwargs - - -@dataclass -class AxolotlTrainingMixins: - """ - Mixin class for the Axolotl training args. - """ - - # pylint: disable=duplicate-code - model_type: Optional[str] = field( - default=None, metadata={"help": "HF model configuration model_type."} - ) - lr_quadratic_warmup: bool = field( - default=False, - metadata={"help": "Use quadratic warmup for cosine scheduling."}, - ) - pretraining: bool = field( - default=False, - metadata={ - "help": "Indicates to trainer whether we are doing continued pretraining." - }, - ) - sample_packing: bool = field( - default=False, - metadata={"help": "Use sample packing for efficient training."}, - ) - multipack_real_batches: bool = field( - default=False, - metadata={"help": "Use real batches for efficient training."}, - ) - eval_sample_packing: Optional[bool] = field( - default=None, - metadata={"help": "Use sample packing for efficient evals."}, - ) - sample_packing_efficiency: float = field( - default=1.0, - metadata={"help": "Sample packing efficiency for calculating batch length."}, - ) - sample_packing_bin_size: int = field( - default=200, - metadata={ - "help": "The max number of samples that packed sample can contain after packing. Increase for better packing." - }, - ) - sample_packing_group_size: int = field( - default=100000, - metadata={ - "help": "The number of samples to group together for packing. Increase for better packing." 
- }, - ) - max_seq_length: int = field( - default=2048, - metadata={"help": "The maximum sequence length the model can handle"}, - ) - relora_steps: Optional[int] = field( - default=None, - metadata={"help": "how often to reset for ReLoRA"}, - ) - relora_warmup_steps: Optional[int] = field( - default=None, - metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, - ) - relora_anneal_steps: Optional[int] = field( - default=None, - metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, - ) - relora_prune_ratio: Optional[float] = field( - default=0.9, - metadata={"help": "prune ratio for magnitude pruning of the optimizer"}, - ) - bench_split: Optional[str] = field( - default="eval", metadata={"help": "The benchmark split to run on"} - ) - bench_dataset: Optional[str] = field( - default="pharaouk/dharma-1/dharma_1_mini.json", - metadata={ - "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file" - }, - ) - do_bench_eval: Optional[bool] = field( - default=False, metadata={"help": "Whether to run the Benchmark evaluation."} - ) - do_causal_lm_eval: Optional[bool] = field( - default=False, metadata={"help": "Whether to run the Causal LM evaluation."} - ) - max_bench_samples: Optional[int] = field( - default=None, - metadata={ - "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset." - }, - ) - bench_source_max_len: int = field( - default=2048, metadata={"help": "Maximum source sequence length for bench."} - ) - dataloader_prefetch_factor: Optional[int] = field( - default=None, - metadata={"help": "prefetch_factor argument to the dataloader"}, - ) - cosine_min_lr_ratio: Optional[float] = field( - default=None, - metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"}, - ) - cosine_constant_lr_ratio: Optional[float] = field( - default=None, - metadata={ - "help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps" - }, - ) - loraplus_lr_ratio: Optional[float] = field( - default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."} - ) - loraplus_lr_embedding: Optional[float] = field( - default=1e-6, - metadata={"help": "loraplus learning rate for lora embedding layers."}, - ) - embedding_lr_scale: Optional[float] = field( - default=None, - metadata={"help": "Scale the learning rate for the embedding layers."}, - ) - embedding_lr: Optional[float] = field( - default=None, - metadata={"help": "absolute learning rate for the embedding layers."}, - ) - qlora: bool = field( - default=False, - metadata={"help": "whether this is a qlora training"}, - ) - orpo_alpha: Optional[float] = field( - default=None, - ) - lisa_n_layers: Optional[int] = field( - default=None, - metadata={"help": "the number of activate layers in LISA"}, - ) - lisa_step_interval: Optional[int] = field( - default=None, - metadata={"help": "how often to switch layers in LISA"}, - ) - lisa_layers_attribute: Optional[str] = field( - default=None, - metadata={"help": "path under the model to access the layers"}, - ) - curriculum_sampling: Optional[bool] = field( - default=None, - metadata={"help": "whether to use sequential sampling for curriculum learning"}, - ) - alternate_optimizer: Optional[str] = field( - default=None, - metadata={ - "help": "workaround to pass an alternate optimizer to the HF trainer" - }, - ) - alternate_lr_scheduler_type: Optional[str] = field( - default=None, - metadata={ - "help": "workaround to pass an alternate lr scheduler to the HF 
trainer" - }, - ) - chat_template: Optional[str] = field( - default=None, - metadata={"help": "Chat template converting chat messages to text"}, - ) - - -@dataclass -class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments): - """ - Training arguments for Causal trainer - - This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value - so it can't be used as a mixin. - """ - - -@dataclass -class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig): - """ - DPO config for DPO training - """ - - -@dataclass -class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig): - """ - ORPO config for ORPO training - """ - - -@dataclass -class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig): - """ - KTO config for KTO training - """ - - -@dataclass -class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig): - """ - CPO config for CPO training - """ - - simpo_gamma: Optional[float] = field( - default=None, - metadata={"help": "simpo gamma parameter"}, - ) - - -@dataclass -class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig): - """ - Reward config for Reward training - """ - - -class SchedulerMixin(Trainer): - """ - Mixin class for scheduler setup in CausalTrainer. - """ - - args = None # type: AxolotlTrainingArguments - - def create_scheduler( - self, num_training_steps: int, optimizer: torch.optim.Optimizer = None - ): - """ - Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or - passed as an argument. - - Args: - num_training_steps (int): The number of training steps to do. - optimizer (torch.optim.Optimizer): The training optimizer - """ - use_cosine_quadratic = ( - self.args.lr_scheduler_type == "cosine" - and self.args.lr_quadratic_warmup is True - ) - - use_cosine_min_lr = ( - self.args.lr_scheduler_type == "cosine" - and self.args.cosine_min_lr_ratio is not None - ) - - # fmt: off - if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition - # fmt: on - if self.args.alternate_lr_scheduler_type == "one_cycle": - num_warmup_steps = self.args.get_warmup_steps(num_training_steps) - pct_start = num_warmup_steps / num_training_steps - extra_lr_kwargs = {} - if "pct_start" not in self.args.lr_scheduler_kwargs: - extra_lr_kwargs["pct_start"] = pct_start - if "anneal_strategy" not in self.args.lr_scheduler_kwargs: - extra_lr_kwargs["anneal_strategy"] = "cos" - - self.lr_scheduler = OneCycleLR( - optimizer, - max_lr=self.args.learning_rate, - total_steps=num_training_steps, - **extra_lr_kwargs, - **self.args.lr_scheduler_kwargs, - ) - elif use_cosine_quadratic: - if use_cosine_min_lr: - LOG.warning("Both cosine quadratic warmup and min lr detected. 
Using quadratic warmup.") - - self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init - optimizer, - num_warmup_steps=self.args.get_warmup_steps(num_training_steps), - num_training_steps=num_training_steps, - ) - elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr: - assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" - assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0" - self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init - optimizer, - num_warmup_steps=self.args.get_warmup_steps(num_training_steps), - num_training_steps=num_training_steps, - min_lr_ratio=self.args.cosine_min_lr_ratio, - constant_lr_ratio=self.args.cosine_constant_lr_ratio, - ) - elif self.args.cosine_min_lr_ratio and use_cosine_min_lr: - assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" - self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init - optimizer, - num_warmup_steps=self.args.get_warmup_steps(num_training_steps), - num_training_steps=num_training_steps, - min_lr_ratio=self.args.cosine_min_lr_ratio, - ) - else: - return super().create_scheduler(num_training_steps, optimizer=optimizer) - else: - if use_cosine_quadratic: - LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).") - - if use_cosine_min_lr: - LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).") - - return self.lr_scheduler - - -class AxolotlTrainer(SchedulerMixin, Trainer): - """ - Extend the base Trainer for axolotl helpers - """ - - args = None # type: AxolotlTrainingArguments - tag_names = ["axolotl"] - - def __init__( - self, - *_args, - bench_data_collator=None, - eval_data_collator=None, - dataset_tags=None, - **kwargs, - ): - self.bench_data_collator = bench_data_collator - self.eval_data_collator = eval_data_collator - self.dataset_tags = dataset_tags - super().__init__(*_args, **kwargs) - self.train_data_collator = self.data_collator - self._stored_metrics = defaultdict(lambda: defaultdict(list)) - if self.args.orpo_alpha: - self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") - - def _wrap_model(self, model, training=True, dataloader=None): - if self.args.torch_compile: - torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access - 256 - ) - model = torch.compile( - model, - backend=self.args.torch_compile_backend, - mode=self.args.torch_compile_mode, - ) - return super()._wrap_model(model, training=training, dataloader=dataloader) - - def create_optimizer(self): - if ( - self.args.loraplus_lr_ratio is None - and self.args.embedding_lr_scale is None - and self.args.embedding_lr is None - and self.args.alternate_optimizer - not in [ - "optimi_adamw", - "ao_adamw_8bit", - "ao_adamw_4bit", - "ao_adamw_fp8", - "adopt_adamw", - ] - ): - return super().create_optimizer() - - opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model - if self.optimizer is None: # pylint: disable=access-member-before-definition - decay_parameters = self.get_decay_parameter_names(opt_model) - params = { - "to_weight_decay": {}, # LayerNorm and bias - "embeddings": {}, # lm_head, embed_tokens, - "no_weight_decay": {}, - } - - optimizer_cls, optimizer_kwargs = 
Trainer.get_optimizer_cls_and_kwargs( - self.args, - opt_model, - ) - - for name, param in opt_model.named_parameters(): - if not param.requires_grad: - continue - if name.endswith("modules_to_save.default.weight") or any( - embed_name in name for embed_name in ["embed_tokens", "lm_head"] - ): - params["embeddings"][name] = param - elif name in decay_parameters: - params["to_weight_decay"][name] = param - else: - params["no_weight_decay"][name] = param - optimizer_grouped_parameters = [] - if params["to_weight_decay"]: - optimizer_grouped_parameters.append( - { - "params": list(params["to_weight_decay"].values()), - "weight_decay": self.args.weight_decay, - "lr": optimizer_kwargs["lr"], - } - ) - if params["embeddings"]: - lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name - if self.args.embedding_lr_scale: - lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name - elif self.args.embedding_lr: - lr = self.args.embedding_lr # pylint: disable=invalid-name - optimizer_grouped_parameters.append( - { - "params": list(params["embeddings"].values()), - "weight_decay": 0.0, - "lr": lr, - } - ) - if params["no_weight_decay"]: - optimizer_grouped_parameters.append( - { - "params": list(params["no_weight_decay"].values()), - "weight_decay": 0.0, - "lr": optimizer_kwargs["lr"], - } - ) - - if self.args.loraplus_lr_ratio is not None: - loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) - loraplus_lr_embedding = getattr( - self.args, "loraplus_lr_embedding", 1e-6 - ) - self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init - opt_model, - optimizer_cls, - loraplus_lr_ratio=loraplus_lr_ratio, - loraplus_lr_embedding=loraplus_lr_embedding, - **optimizer_kwargs, - ) - elif ( - self.args.embedding_lr_scale is not None - or self.args.embedding_lr is not None - ): - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) - ) - elif self.args.alternate_optimizer == "optimi_adamw": - from optimi import AdamW - - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - AdamW( - optimizer_grouped_parameters, foreach=False, **optimizer_kwargs - ) - ) - elif self.args.alternate_optimizer == "ao_adamw_4bit": - from torchao.prototype.low_bit_optim import AdamW4bit - - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs) - ) - elif self.args.alternate_optimizer == "ao_adamw_8bit": - from torchao.prototype.low_bit_optim import AdamW8bit - - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs) - ) - elif self.args.alternate_optimizer == "ao_adamw_fp8": - from torchao.prototype.low_bit_optim import AdamWFp8 - - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs) - ) - elif self.args.alternate_optimizer == "adopt_adamw": - from axolotl.utils.optimizers.adopt import ADOPT - - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - ADOPT( - optimizer_grouped_parameters, - decouple=True, - **optimizer_kwargs, - ) - ) - - if is_sagemaker_mp_enabled(): - self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init - self.optimizer - ) - - return self.optimizer - - def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: - if self.args.sample_packing and not self.args.pretraining: - if 
self.args.multipack_real_batches: - batch_size = self.args.per_device_train_batch_size - batch_max_len = self.args.max_seq_length - else: - batch_size = 1 - train_batch_size = ( - self.state.train_batch_size or self.args.per_device_train_batch_size - ) - batch_max_len = train_batch_size * self.args.max_seq_length - - if self.args.curriculum_sampling: - sampler = SequentialSampler(self.train_dataset) - else: - sampler = RandomSampler(self.train_dataset) - - return MultipackBatchSampler( - sampler, - lengths=get_dataset_lengths(self.train_dataset), - packing_efficiency_estimate=self.args.sample_packing_efficiency, - batch_max_len=batch_max_len, - batch_size=batch_size, - group_size=self.args.sample_packing_group_size, - bin_size=self.args.sample_packing_bin_size, - drop_last=True, - ) - if self.args.curriculum_sampling: - return SequentialSampler(self.train_dataset) - return super()._get_train_sampler() - - def _get_eval_sampler( - self, eval_dataset: Dataset - ) -> Optional[torch.utils.data.Sampler]: - if self.args.sample_packing and self.args.eval_sample_packing is not False: - if self.args.multipack_real_batches: - batch_size = self.args.per_device_eval_batch_size - batch_max_len = self.args.max_seq_length - else: - batch_size = 1 - batch_max_len = ( - self.args.per_device_eval_batch_size * self.args.max_seq_length - ) - return MultipackBatchSampler( - SequentialSampler(eval_dataset), - lengths=get_dataset_lengths(self.eval_dataset), - packing_efficiency_estimate=self.args.sample_packing_efficiency, - batch_max_len=batch_max_len, - batch_size=batch_size, - group_size=self.args.sample_packing_group_size, - bin_size=self.args.sample_packing_bin_size, - drop_last=True, - ) - return super()._get_eval_sampler(eval_dataset) - - def get_train_dataloader(self) -> DataLoader: - if self.args.sample_packing and not self.args.pretraining: - train_dataset = self.train_dataset - if "length" in train_dataset.features.keys(): - train_dataset = train_dataset.remove_columns(["length"]) - data_collator = self.data_collator - dataloader_params = { - "batch_size": self._train_batch_size, - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - } - if self.args.dataloader_prefetch_factor: - dataloader_params[ - "prefetch_factor" - ] = self.args.dataloader_prefetch_factor - - sampler = self._get_train_sampler() - if isinstance(sampler, BatchSampler): - dataloader_params["batch_sampler"] = sampler - del dataloader_params["batch_size"] - else: - dataloader_params["sampler"] = sampler - dataloader_params["drop_last"] = self.args.dataloader_drop_last - dataloader_params["worker_init_fn"] = seed_worker - - self.accelerator.even_batches = False - return self.accelerator.prepare_data_loader( - DataLoader(train_dataset, **dataloader_params) - ) - return super().get_train_dataloader() - - def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: - if self.args.sample_packing and self.args.eval_sample_packing is False: - self.data_collator = ( # pylint: disable=attribute-defined-outside-init - self.eval_data_collator - ) - if eval_dataset: - eval_dataset = eval_dataset.remove_columns(["length"]) - dataloader = super().get_eval_dataloader(eval_dataset) - self.data_collator = ( # pylint: disable=attribute-defined-outside-init - self.train_data_collator - ) - return dataloader - - if self.args.sample_packing and self.args.eval_sample_packing is not False: - eval_dataset = ( - eval_dataset if eval_dataset is not None else 
self.eval_dataset - ) - - eval_sampler = self._get_eval_sampler(eval_dataset) - eval_dataset = eval_dataset.remove_columns(["length"]) - data_collator = self.data_collator - dataloader_params = { - "batch_size": self.args.eval_batch_size, - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - } - if self.args.dataloader_prefetch_factor: - dataloader_params[ - "prefetch_factor" - ] = self.args.dataloader_prefetch_factor - - if isinstance(eval_sampler, BatchSampler): - dataloader_params["batch_sampler"] = eval_sampler - del dataloader_params["batch_size"] - else: - dataloader_params["sampler"] = eval_sampler - dataloader_params["drop_last"] = self.args.dataloader_drop_last - - self.accelerator.even_batches = False - return self.accelerator.prepare_data_loader( - DataLoader(eval_dataset, **dataloader_params) - ) - - return super().get_eval_dataloader(eval_dataset) - - def _get_bench_sampler( - self, bench_dataset: Dataset - ) -> Optional[torch.utils.data.Sampler]: - if self.args.world_size <= 1: - return SequentialSampler(bench_dataset) - return None - - def get_bench_dataloader( - self, - bench_dataset: Dataset, - ) -> DataLoader: - dataloader_params = { - "batch_size": self.args.eval_batch_size, - "collate_fn": self.bench_data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - } - if self.args.dataloader_prefetch_factor: - dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor - - if not isinstance(bench_dataset, torch.utils.data.IterableDataset): - dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset) - dataloader_params["drop_last"] = self.args.dataloader_drop_last - - return DataLoader(bench_dataset, **dataloader_params) - # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params)) - - def compute_loss( - self, model, inputs, return_outputs=False, num_items_in_batch=None - ): - # use one's weighted cross entropy loss calc - # if self.args.sample_packing: - # labels = inputs.pop("labels") - # outputs = model(**inputs) - # loss = trainer_weighted_loss(outputs, labels, shift_labels=True) - # return (loss, outputs) if return_outputs else loss - if self.args.orpo_alpha: - return self.orpo_compute_loss( - model, - inputs, - return_outputs=return_outputs, - num_items_in_batch=num_items_in_batch, - ) - return super().compute_loss( - model, - inputs, - return_outputs=return_outputs, - num_items_in_batch=num_items_in_batch, - ) - - @staticmethod - def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None): - concatenated_batch = {} - - max_length = max( - inputs["input_ids"].shape[1], inputs["rejected_input_ids"].shape[1] - ) - # Concatenate positive and negative inputs - concatenated_batch["input_ids"] = pad_to_length( - inputs["input_ids"], max_length, pad_token - ) - concatenated_batch["rejected_input_ids"] = pad_to_length( - inputs["rejected_input_ids"], max_length, pad_token - ) - concatenated_batch["labels"] = pad_to_length( - inputs["labels"], max_length, label_pad_token - ) - concatenated_batch["rejected_labels"] = pad_to_length( - inputs["rejected_labels"], max_length, label_pad_token - ) - concatenated_batch["attention_mask"] = pad_to_length( - inputs["attention_mask"], max_length, 0 - ) - concatenated_batch["rejected_attention_mask"] = pad_to_length( - inputs["rejected_attention_mask"], max_length, 0 - ) - concatenated_batch["prompt_attention_mask"] = pad_to_length( - 
inputs["prompt_attention_mask"], max_length, 0 - ).to(device=device) - - input_ids = torch.cat( - [concatenated_batch["input_ids"], concatenated_batch["rejected_input_ids"]], - dim=0, - ).to(device=device) - attention_mask = torch.cat( - [ - concatenated_batch["attention_mask"], - concatenated_batch["rejected_attention_mask"], - ], - dim=0, - ).to(device=device) - labels = torch.cat( - [concatenated_batch["labels"], concatenated_batch["rejected_labels"]], dim=0 - ).to(device=device) - - return { - "input_ids": input_ids, - "labels": labels, - "attention_mask": attention_mask, - "prompt_attention_mask": concatenated_batch["prompt_attention_mask"], - } - - def orpo_compute_custom_loss(self, logits, labels): - logits = logits.contiguous() - loss = 0.0 - - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(logits.device) - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - - # Flatten the tokens - loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean( - dim=-1 - ) - - return loss - - def orpo_compute_logps( - self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits - ): - # Get the shape of chosen_attention_mask[:, :-1] - chosen_shape = chosen_attention_mask[:, :-1].shape - - # Calculate the padding size - pad_length = chosen_shape[1] - (prompt_attention_mask.shape[1] - 1) - - # Pad prompt_attention_mask with zeros to match the desired shape - prompt_attention_mask_padded = torch.nn.functional.pad( - prompt_attention_mask[:, 1:], (0, pad_length), mode="constant", value=0 - ) - - # Perform the subtraction operation - mask = chosen_attention_mask[:, :-1] > prompt_attention_mask_padded - - per_token_logps = torch.gather( - logits[:, :-1, :].log_softmax(-1), - dim=2, - index=(mask * chosen_inputs[:, 1:]).unsqueeze(2), - ).squeeze(2) - return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1) - - def orpo_compute_loss( - self, - model, - inputs, - return_outputs=False, - num_items_in_batch=None, # pylint: disable=unused-argument - ): - concat_inputs = AxolotlTrainer.orpo_concatenate_inputs( - inputs, - label_pad_token=-100, - pad_token=self.tokenizer.pad_token_id, - device=self.accelerator.device, - ) - - # Perform a single forward pass - outputs = model( - **{ - "input_ids": concat_inputs["input_ids"], - "attention_mask": concat_inputs["attention_mask"], - "labels": concat_inputs["labels"], - }, - output_hidden_states=True, - ) - - # Split the outputs for positive and negative examples - outputs_pos, outputs_neg = outputs.logits.chunk(2) - - # Calculate NLL loss - pos_loss = self.orpo_compute_custom_loss( - logits=outputs_pos, labels=concat_inputs["input_ids"].chunk(2)[0] - ) - - # Calculate Log Probability - pos_prob = self.orpo_compute_logps( - prompt_attention_mask=concat_inputs["prompt_attention_mask"], - chosen_inputs=concat_inputs["input_ids"].chunk(2)[0], - chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[0], - logits=outputs_pos, - ) - neg_prob = self.orpo_compute_logps( - prompt_attention_mask=concat_inputs["prompt_attention_mask"], - chosen_inputs=concat_inputs["input_ids"].chunk(2)[1], - chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[1], - logits=outputs_neg, - ) - - # Calculate log odds - log_odds = (pos_prob - neg_prob) - ( - torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob)) - ) - sig_ratio = torch.nn.functional.sigmoid(log_odds) - ratio = 
torch.log(sig_ratio) - - # Calculate the Final Loss - loss = torch.mean(pos_loss - self.args.orpo_alpha * ratio).to( - dtype=torch.bfloat16 - ) - - metrics = {} - metrics["chosen_geometric_mean"] = torch.mean(pos_prob).cpu().item() - metrics["rejected_geometric_mean"] = torch.mean(neg_prob).cpu().item() - metrics["log_odds_ratio"] = torch.mean(ratio).cpu().item() - metrics["log_odds"] = torch.mean(log_odds).cpu().item() - self.store_metrics(metrics, train_eval="train") - - return (loss, outputs_pos) if return_outputs else loss - - @wraps(Trainer.push_to_hub) - def push_to_hub(self, *args, **kwargs) -> str: - """ - Overwrite the `push_to_hub` method in order to force-add the tags when pushing the - model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. - """ - kwargs = _sanitize_kwargs_for_ds_tagging( - dataset_tags=self.dataset_tags, kwargs=kwargs - ) - kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs) - - return super().push_to_hub(*args, **kwargs) - - @wraps(Trainer.create_accelerator_and_postprocess) - def create_accelerator_and_postprocess(self): - res = super().create_accelerator_and_postprocess() - - if self.is_fsdp_enabled: - if ( - "limit_all_gathers" in self.args.fsdp_config - and self.args.fsdp_config["limit_all_gathers"] - ): - self.accelerator.state.fsdp_plugin.limit_all_gathers = True - - return res - - def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: - """ - Log `logs` on the various objects watching training, including stored metrics. - - Args: - logs (`Dict[str, float]`): - The values to log. - start_time (`Optional[float]`): - The start of training. - """ - # logs either has 'loss' or 'eval_loss' - train_eval = "train" if "loss" in logs else "eval" - # Add averaged stored metrics to logs - for key, metrics in self._stored_metrics[train_eval].items(): - logs[key] = torch.tensor(metrics).mean().item() - del self._stored_metrics[train_eval] - - return super().log(logs, start_time) - - def store_metrics( - self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train" - ) -> None: - for key, value in metrics.items(): - self._stored_metrics[train_eval][key].append(value) - - def _save_checkpoint(self, model, trial, **kwargs): - # make sure the checkpoint dir exists, since trainer is flakey - checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" - run_dir = self._get_output_dir(trial=trial) - output_dir = os.path.join(run_dir, checkpoint_folder) - os.makedirs(output_dir, exist_ok=True) - return super()._save_checkpoint(model, trial, **kwargs) - - -class AxolotlMambaTrainer(AxolotlTrainer): - """ - Mamba specific trainer to handle loss calculation - """ - - tag_names = ["axolotl", "mamba"] - - def compute_loss( - self, - model, - inputs, - return_outputs=False, # pylint: disable=unused-argument - num_items_in_batch=None, # pylint: disable=unused-argument - ): - input_ids = inputs.pop("input_ids") - lm_logits = model(input_ids).logits - - labels = input_ids.to(lm_logits.device) - shift_logits = lm_logits[:, :-1, :].contiguous() - labels = labels[:, 1:].contiguous() - - loss_fct = torch.nn.CrossEntropyLoss() - lm_loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1) - ) - - return lm_loss - - -class ReLoRATrainer(AxolotlTrainer): - """ - Trainer subclass that uses the OneCycleLR scheduler - """ - - tag_names = ["axolotl", "relora"] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.lr_scheduler 
= None - - def create_scheduler( - self, - num_training_steps: int, - optimizer: Optional[torch.optim.Optimizer] = None, - ): - optimizer = self.optimizer if optimizer is None else optimizer - lr_scheduler = super().create_scheduler(num_training_steps, optimizer) - - if self.args.relora_steps: - warmup_steps = ( - self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10 - ) - anneal_steps = ( - self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1 - ) - self.lr_scheduler = ReLoRAScheduler( - optimizer, - lr_scheduler, - self.args.relora_steps, - anneal_steps, - warmup_steps, - ) - else: - self.lr_scheduler = lr_scheduler - - return self.lr_scheduler - - -class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer): - """ - Extend the base DPOTrainer for axolotl helpers - """ - - tag_names = ["axolotl", "dpo"] - - def __init__(self, *args, dataset_tags=None, **kwargs): - super().__init__(*args, **kwargs) - self.dataset_tags = dataset_tags - self.optimizer = None - - def create_optimizer(self): - if self.args.loraplus_lr_ratio is None: - return super().create_optimizer() - - opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model - if self.optimizer is None: # pylint: disable=access-member-before-definition - optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( - self.args, - opt_model, - ) - - loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) - if loraplus_lr_ratio: - print("Using lora+") - loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None) - self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init - opt_model, - optimizer_cls, - loraplus_lr_ratio=loraplus_lr_ratio, - loraplus_lr_embedding=loraplus_lr_embedding, - **optimizer_kwargs, - ) - - if is_sagemaker_mp_enabled(): - self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init - self.optimizer - ) - - return self.optimizer - - @wraps(DPOTrainer.push_to_hub) - def push_to_hub(self, *args, **kwargs) -> str: - """ - Overwrite the `push_to_hub` method in order to force-add the tags when pushing the - model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. - """ - kwargs = _sanitize_kwargs_for_ds_tagging( - dataset_tags=self.dataset_tags, kwargs=kwargs - ) - kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs) - - return super().push_to_hub(*args, **kwargs) - - @staticmethod - def tokenize_row( - features, - processing_class, - max_prompt_length, - max_completion_length, - add_special_tokens, - ) -> Dict: - res = DPOTrainer.tokenize_row( - features, - processing_class, - max_prompt_length, - max_completion_length, - add_special_tokens, - ) - # fix when the tokenizer doesn't have a bos_token_id, e.g. 
Qwen - if processing_class.bos_token is None and res["prompt_input_ids"][0] is None: - for key in res.keys(): - res[key] = res[key][1:] - - if processing_class.bos_token and processing_class.bos_token_id is not None: - # dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs - if res["chosen_input_ids"][0] == processing_class.bos_token_id: - res["chosen_input_ids"] = res["chosen_input_ids"][1:] - res["chosen_labels"] = res["chosen_labels"][1:] - res["chosen_attention_mask"] = res["chosen_attention_mask"][1:] - if res["rejected_input_ids"][0] == processing_class.bos_token_id: - res["rejected_input_ids"] = res["rejected_input_ids"][1:] - res["rejected_labels"] = res["rejected_labels"][1:] - res["rejected_attention_mask"] = res["rejected_attention_mask"][1:] - - return res - - def training_step( - self, - model: nn.Module, - inputs: Dict[str, Union[torch.Tensor, Any]], - num_items_in_batch=None, - ) -> torch.Tensor: - loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch) - gc.collect() - torch.cuda.empty_cache() - return loss - - -class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): - """ - Extend the base ORPOTrainer for axolotl helpers - """ - - tag_names = ["axolotl", "orpo"] - - -class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): - """ - Extend the base KTOTrainer for axolotl helpers - """ - - tag_names = ["axolotl", "kto"] - - -class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): - """ - Extend the base CPOTrainer for axolotl helpers - """ - - tag_names = ["axolotl", "cpo"] - - -class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): - """ - Extend the base RewardTrainer for axolotl helpers - """ - - tag_names = ["axolotl", "reward"] - - class TrainerBuilderBase(abc.ABC): """ Base class for trainer builder @@ -2114,7 +988,10 @@ class HFRLTrainerBuilder(TrainerBuilderBase): "precompute_ref_log_probs" ] = self.cfg.precompute_ref_log_probs if self.cfg.rl in ["dpo", "ipo"]: - trainer_cls = AxolotlDPOTrainer + if self.cfg.liger_pref_rl: + trainer_cls = AxolotlLigerDPOTrainer + else: + trainer_cls = AxolotlDPOTrainer trainer_cls_args = [self.model, self.model_ref] elif self.cfg.rl == "orpo": trainer_cls = AxolotlORPOTrainer diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py new file mode 100644 index 000000000..8014d67a7 --- /dev/null +++ b/src/axolotl/core/trainers/base.py @@ -0,0 +1,933 @@ +""" +module for customized trainers +""" + +from __future__ import annotations + +# pylint: disable=too-many-lines +import gc +import logging +import os +from collections import defaultdict +from functools import wraps +from typing import Any, Dict, Literal, Optional, Union + +import torch +import transformers +from datasets import Dataset +from peft.optimizers import create_loraplus_optimizer +from torch import nn +from torch.optim.lr_scheduler import OneCycleLR +from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler +from transformers import Trainer +from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker +from transformers.utils import is_sagemaker_mp_enabled +from trl import CPOTrainer, DPOTrainer, KTOTrainer, ORPOTrainer, RewardTrainer +from trl.trainer.utils import pad_to_length + +from axolotl.monkeypatch.relora import ReLoRAScheduler +from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths +from axolotl.utils.schedulers import ( + get_cosine_schedule_with_min_lr, + get_cosine_schedule_with_quadratic_warmup, + 
get_cosine_schedule_with_warmup_decay_constant, +) + +if is_sagemaker_mp_enabled(): + import smdistributed.modelparallel.torch as smp + +LOG = logging.getLogger("axolotl.core.trainer_builder") + + +def _sanitize_kwargs_for_tagging(tag_names, kwargs=None): + if isinstance(tag_names, str): + tag_names = [tag_names] + + if kwargs is not None: + if "tags" not in kwargs: + kwargs["tags"] = tag_names + elif "tags" in kwargs and isinstance(kwargs["tags"], list): + kwargs["tags"].extend(tag_names) + elif "tags" in kwargs and isinstance(kwargs["tags"], str): + tag_names.append(kwargs["tags"]) + kwargs["tags"] = tag_names + + return kwargs + + +def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None): + if isinstance(dataset_tags, str): + dataset_tags = [dataset_tags] + + if (dataset_tags is not None) and (kwargs is not None): + if "dataset_tags" not in kwargs: + kwargs["dataset_tags"] = dataset_tags + elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list): + kwargs["dataset_tags"].extend(dataset_tags) + elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str): + dataset_tags.append(kwargs["dataset_tags"]) + kwargs["dataset_tags"] = dataset_tags + + return kwargs + + +class SchedulerMixin(Trainer): + """ + Mixin class for scheduler setup in CausalTrainer. + """ + + args = None # type: AxolotlTrainingArguments + + def create_scheduler( + self, num_training_steps: int, optimizer: torch.optim.Optimizer = None + ): + """ + Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or + passed as an argument. + + Args: + num_training_steps (int): The number of training steps to do. + optimizer (torch.optim.Optimizer): The training optimizer + """ + use_cosine_quadratic = ( + self.args.lr_scheduler_type == "cosine" + and self.args.lr_quadratic_warmup is True + ) + + use_cosine_min_lr = ( + self.args.lr_scheduler_type == "cosine" + and self.args.cosine_min_lr_ratio is not None + ) + + # fmt: off + if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition + # fmt: on + if self.args.alternate_lr_scheduler_type == "one_cycle": + num_warmup_steps = self.args.get_warmup_steps(num_training_steps) + pct_start = num_warmup_steps / num_training_steps + extra_lr_kwargs = {} + if "pct_start" not in self.args.lr_scheduler_kwargs: + extra_lr_kwargs["pct_start"] = pct_start + if "anneal_strategy" not in self.args.lr_scheduler_kwargs: + extra_lr_kwargs["anneal_strategy"] = "cos" + + self.lr_scheduler = OneCycleLR( + optimizer, + max_lr=self.args.learning_rate, + total_steps=num_training_steps, + **extra_lr_kwargs, + **self.args.lr_scheduler_kwargs, + ) + elif use_cosine_quadratic: + if use_cosine_min_lr: + LOG.warning("Both cosine quadratic warmup and min lr detected. 
Using quadratic warmup.") + + self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init + optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + ) + elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr: + assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" + assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0" + self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init + optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + min_lr_ratio=self.args.cosine_min_lr_ratio, + constant_lr_ratio=self.args.cosine_constant_lr_ratio, + ) + elif self.args.cosine_min_lr_ratio and use_cosine_min_lr: + assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" + self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init + optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + min_lr_ratio=self.args.cosine_min_lr_ratio, + ) + else: + return super().create_scheduler(num_training_steps, optimizer=optimizer) + else: + if use_cosine_quadratic: + LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).") + + if use_cosine_min_lr: + LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).") + + return self.lr_scheduler + + +class AxolotlTrainer(SchedulerMixin, Trainer): + """ + Extend the base Trainer for axolotl helpers + """ + + args = None # type: AxolotlTrainingArguments + tag_names = ["axolotl"] + + def __init__( + self, + *_args, + bench_data_collator=None, + eval_data_collator=None, + dataset_tags=None, + **kwargs, + ): + self.bench_data_collator = bench_data_collator + self.eval_data_collator = eval_data_collator + self.dataset_tags = dataset_tags + super().__init__(*_args, **kwargs) + self.train_data_collator = self.data_collator + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + if self.args.orpo_alpha: + self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + + def _wrap_model(self, model, training=True, dataloader=None): + if self.args.torch_compile: + torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access + 256 + ) + model = torch.compile( + model, + backend=self.args.torch_compile_backend, + mode=self.args.torch_compile_mode, + ) + return super()._wrap_model(model, training=training, dataloader=dataloader) + + def create_optimizer(self): + if ( + self.args.loraplus_lr_ratio is None + and self.args.embedding_lr_scale is None + and self.args.embedding_lr is None + and self.args.alternate_optimizer + not in [ + "optimi_adamw", + "ao_adamw_8bit", + "ao_adamw_4bit", + "ao_adamw_fp8", + "adopt_adamw", + ] + ): + return super().create_optimizer() + + opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + if self.optimizer is None: # pylint: disable=access-member-before-definition + decay_parameters = self.get_decay_parameter_names(opt_model) + params = { + "to_weight_decay": {}, # LayerNorm and bias + "embeddings": {}, # lm_head, embed_tokens, + "no_weight_decay": {}, + } + + optimizer_cls, optimizer_kwargs = 
Trainer.get_optimizer_cls_and_kwargs( + self.args, + opt_model, + ) + + for name, param in opt_model.named_parameters(): + if not param.requires_grad: + continue + if name.endswith("modules_to_save.default.weight") or any( + embed_name in name for embed_name in ["embed_tokens", "lm_head"] + ): + params["embeddings"][name] = param + elif name in decay_parameters: + params["to_weight_decay"][name] = param + else: + params["no_weight_decay"][name] = param + optimizer_grouped_parameters = [] + if params["to_weight_decay"]: + optimizer_grouped_parameters.append( + { + "params": list(params["to_weight_decay"].values()), + "weight_decay": self.args.weight_decay, + "lr": optimizer_kwargs["lr"], + } + ) + if params["embeddings"]: + lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name + if self.args.embedding_lr_scale: + lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name + elif self.args.embedding_lr: + lr = self.args.embedding_lr # pylint: disable=invalid-name + optimizer_grouped_parameters.append( + { + "params": list(params["embeddings"].values()), + "weight_decay": 0.0, + "lr": lr, + } + ) + if params["no_weight_decay"]: + optimizer_grouped_parameters.append( + { + "params": list(params["no_weight_decay"].values()), + "weight_decay": 0.0, + "lr": optimizer_kwargs["lr"], + } + ) + + if self.args.loraplus_lr_ratio is not None: + loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) + loraplus_lr_embedding = getattr( + self.args, "loraplus_lr_embedding", 1e-6 + ) + self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init + opt_model, + optimizer_cls, + loraplus_lr_ratio=loraplus_lr_ratio, + loraplus_lr_embedding=loraplus_lr_embedding, + **optimizer_kwargs, + ) + elif ( + self.args.embedding_lr_scale is not None + or self.args.embedding_lr is not None + ): + self.optimizer = ( # pylint: disable=attribute-defined-outside-init + optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) + ) + elif self.args.alternate_optimizer == "optimi_adamw": + from optimi import AdamW + + self.optimizer = ( # pylint: disable=attribute-defined-outside-init + AdamW( + optimizer_grouped_parameters, foreach=False, **optimizer_kwargs + ) + ) + elif self.args.alternate_optimizer == "ao_adamw_4bit": + from torchao.prototype.low_bit_optim import AdamW4bit + + self.optimizer = ( # pylint: disable=attribute-defined-outside-init + AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs) + ) + elif self.args.alternate_optimizer == "ao_adamw_8bit": + from torchao.prototype.low_bit_optim import AdamW8bit + + self.optimizer = ( # pylint: disable=attribute-defined-outside-init + AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs) + ) + elif self.args.alternate_optimizer == "ao_adamw_fp8": + from torchao.prototype.low_bit_optim import AdamWFp8 + + self.optimizer = ( # pylint: disable=attribute-defined-outside-init + AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs) + ) + elif self.args.alternate_optimizer == "adopt_adamw": + from axolotl.utils.optimizers.adopt import ADOPT + + self.optimizer = ( # pylint: disable=attribute-defined-outside-init + ADOPT( + optimizer_grouped_parameters, + decouple=True, + **optimizer_kwargs, + ) + ) + + if is_sagemaker_mp_enabled(): + self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init + self.optimizer + ) + + return self.optimizer + + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + if self.args.sample_packing and not self.args.pretraining: + if 
self.args.multipack_real_batches: + batch_size = self.args.per_device_train_batch_size + batch_max_len = self.args.max_seq_length + else: + batch_size = 1 + train_batch_size = ( + self.state.train_batch_size or self.args.per_device_train_batch_size + ) + batch_max_len = train_batch_size * self.args.max_seq_length + + if self.args.curriculum_sampling: + sampler = SequentialSampler(self.train_dataset) + else: + sampler = RandomSampler(self.train_dataset) + + return MultipackBatchSampler( + sampler, + lengths=get_dataset_lengths(self.train_dataset), + packing_efficiency_estimate=self.args.sample_packing_efficiency, + batch_max_len=batch_max_len, + batch_size=batch_size, + group_size=self.args.sample_packing_group_size, + bin_size=self.args.sample_packing_bin_size, + drop_last=True, + ) + if self.args.curriculum_sampling: + return SequentialSampler(self.train_dataset) + return super()._get_train_sampler() + + def _get_eval_sampler( + self, eval_dataset: Dataset + ) -> Optional[torch.utils.data.Sampler]: + if self.args.sample_packing and self.args.eval_sample_packing is not False: + if self.args.multipack_real_batches: + batch_size = self.args.per_device_eval_batch_size + batch_max_len = self.args.max_seq_length + else: + batch_size = 1 + batch_max_len = ( + self.args.per_device_eval_batch_size * self.args.max_seq_length + ) + return MultipackBatchSampler( + SequentialSampler(eval_dataset), + lengths=get_dataset_lengths(self.eval_dataset), + packing_efficiency_estimate=self.args.sample_packing_efficiency, + batch_max_len=batch_max_len, + batch_size=batch_size, + group_size=self.args.sample_packing_group_size, + bin_size=self.args.sample_packing_bin_size, + drop_last=True, + ) + return super()._get_eval_sampler(eval_dataset) + + def get_train_dataloader(self) -> DataLoader: + if self.args.sample_packing and not self.args.pretraining: + train_dataset = self.train_dataset + if "length" in train_dataset.features.keys(): + train_dataset = train_dataset.remove_columns(["length"]) + data_collator = self.data_collator + dataloader_params = { + "batch_size": self._train_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + } + if self.args.dataloader_prefetch_factor: + dataloader_params[ + "prefetch_factor" + ] = self.args.dataloader_prefetch_factor + + sampler = self._get_train_sampler() + if isinstance(sampler, BatchSampler): + dataloader_params["batch_sampler"] = sampler + del dataloader_params["batch_size"] + else: + dataloader_params["sampler"] = sampler + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = seed_worker + + self.accelerator.even_batches = False + return self.accelerator.prepare_data_loader( + DataLoader(train_dataset, **dataloader_params) + ) + return super().get_train_dataloader() + + def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: + if self.args.sample_packing and self.args.eval_sample_packing is False: + self.data_collator = ( # pylint: disable=attribute-defined-outside-init + self.eval_data_collator + ) + if eval_dataset: + eval_dataset = eval_dataset.remove_columns(["length"]) + dataloader = super().get_eval_dataloader(eval_dataset) + self.data_collator = ( # pylint: disable=attribute-defined-outside-init + self.train_data_collator + ) + return dataloader + + if self.args.sample_packing and self.args.eval_sample_packing is not False: + eval_dataset = ( + eval_dataset if eval_dataset is not None else 
self.eval_dataset + ) + + eval_sampler = self._get_eval_sampler(eval_dataset) + eval_dataset = eval_dataset.remove_columns(["length"]) + data_collator = self.data_collator + dataloader_params = { + "batch_size": self.args.eval_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + } + if self.args.dataloader_prefetch_factor: + dataloader_params[ + "prefetch_factor" + ] = self.args.dataloader_prefetch_factor + + if isinstance(eval_sampler, BatchSampler): + dataloader_params["batch_sampler"] = eval_sampler + del dataloader_params["batch_size"] + else: + dataloader_params["sampler"] = eval_sampler + dataloader_params["drop_last"] = self.args.dataloader_drop_last + + self.accelerator.even_batches = False + return self.accelerator.prepare_data_loader( + DataLoader(eval_dataset, **dataloader_params) + ) + + return super().get_eval_dataloader(eval_dataset) + + def _get_bench_sampler( + self, bench_dataset: Dataset + ) -> Optional[torch.utils.data.Sampler]: + if self.args.world_size <= 1: + return SequentialSampler(bench_dataset) + return None + + def get_bench_dataloader( + self, + bench_dataset: Dataset, + ) -> DataLoader: + dataloader_params = { + "batch_size": self.args.eval_batch_size, + "collate_fn": self.bench_data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + } + if self.args.dataloader_prefetch_factor: + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + if not isinstance(bench_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset) + dataloader_params["drop_last"] = self.args.dataloader_drop_last + + return DataLoader(bench_dataset, **dataloader_params) + # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params)) + + def compute_loss( + self, model, inputs, return_outputs=False, num_items_in_batch=None + ): + # use one's weighted cross entropy loss calc + # if self.args.sample_packing: + # labels = inputs.pop("labels") + # outputs = model(**inputs) + # loss = trainer_weighted_loss(outputs, labels, shift_labels=True) + # return (loss, outputs) if return_outputs else loss + if self.args.orpo_alpha: + return self.orpo_compute_loss( + model, + inputs, + return_outputs=return_outputs, + num_items_in_batch=num_items_in_batch, + ) + return super().compute_loss( + model, + inputs, + return_outputs=return_outputs, + num_items_in_batch=num_items_in_batch, + ) + + @staticmethod + def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None): + concatenated_batch = {} + + max_length = max( + inputs["input_ids"].shape[1], inputs["rejected_input_ids"].shape[1] + ) + # Concatenate positive and negative inputs + concatenated_batch["input_ids"] = pad_to_length( + inputs["input_ids"], max_length, pad_token + ) + concatenated_batch["rejected_input_ids"] = pad_to_length( + inputs["rejected_input_ids"], max_length, pad_token + ) + concatenated_batch["labels"] = pad_to_length( + inputs["labels"], max_length, label_pad_token + ) + concatenated_batch["rejected_labels"] = pad_to_length( + inputs["rejected_labels"], max_length, label_pad_token + ) + concatenated_batch["attention_mask"] = pad_to_length( + inputs["attention_mask"], max_length, 0 + ) + concatenated_batch["rejected_attention_mask"] = pad_to_length( + inputs["rejected_attention_mask"], max_length, 0 + ) + concatenated_batch["prompt_attention_mask"] = pad_to_length( + 
+            inputs["prompt_attention_mask"], max_length, 0
+        ).to(device=device)
+
+        input_ids = torch.cat(
+            [concatenated_batch["input_ids"], concatenated_batch["rejected_input_ids"]],
+            dim=0,
+        ).to(device=device)
+        attention_mask = torch.cat(
+            [
+                concatenated_batch["attention_mask"],
+                concatenated_batch["rejected_attention_mask"],
+            ],
+            dim=0,
+        ).to(device=device)
+        labels = torch.cat(
+            [concatenated_batch["labels"], concatenated_batch["rejected_labels"]], dim=0
+        ).to(device=device)
+
+        return {
+            "input_ids": input_ids,
+            "labels": labels,
+            "attention_mask": attention_mask,
+            "prompt_attention_mask": concatenated_batch["prompt_attention_mask"],
+        }
+
+    def orpo_compute_custom_loss(self, logits, labels):
+        logits = logits.contiguous()
+        loss = 0.0
+
+        if labels is not None:
+            # move labels to correct device to enable model parallelism
+            labels = labels.to(logits.device)
+            # Shift so that tokens < n predict n
+            shift_logits = logits[..., :-1, :].contiguous()
+            shift_labels = labels[..., 1:].contiguous()
+
+            # Flatten the tokens
+            loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean(
+                dim=-1
+            )
+
+        return loss
+
+    def orpo_compute_logps(
+        self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits
+    ):
+        # Get the shape of chosen_attention_mask[:, :-1]
+        chosen_shape = chosen_attention_mask[:, :-1].shape
+
+        # Calculate the padding size
+        pad_length = chosen_shape[1] - (prompt_attention_mask.shape[1] - 1)
+
+        # Pad prompt_attention_mask with zeros to match the desired shape
+        prompt_attention_mask_padded = torch.nn.functional.pad(
+            prompt_attention_mask[:, 1:], (0, pad_length), mode="constant", value=0
+        )
+
+        # Perform the subtraction operation
+        mask = chosen_attention_mask[:, :-1] > prompt_attention_mask_padded
+
+        per_token_logps = torch.gather(
+            logits[:, :-1, :].log_softmax(-1),
+            dim=2,
+            index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
+        ).squeeze(2)
+        return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)
+
+    def orpo_compute_loss(
+        self,
+        model,
+        inputs,
+        return_outputs=False,
+        num_items_in_batch=None,  # pylint: disable=unused-argument
+    ):
+        concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
+            inputs,
+            label_pad_token=-100,
+            pad_token=self.tokenizer.pad_token_id,
+            device=self.accelerator.device,
+        )
+
+        # Perform a single forward pass
+        outputs = model(
+            **{
+                "input_ids": concat_inputs["input_ids"],
+                "attention_mask": concat_inputs["attention_mask"],
+                "labels": concat_inputs["labels"],
+            },
+            output_hidden_states=True,
+        )
+
+        # Split the outputs for positive and negative examples
+        outputs_pos, outputs_neg = outputs.logits.chunk(2)
+
+        # Calculate NLL loss
+        pos_loss = self.orpo_compute_custom_loss(
+            logits=outputs_pos, labels=concat_inputs["input_ids"].chunk(2)[0]
+        )
+
+        # Calculate Log Probability
+        pos_prob = self.orpo_compute_logps(
+            prompt_attention_mask=concat_inputs["prompt_attention_mask"],
+            chosen_inputs=concat_inputs["input_ids"].chunk(2)[0],
+            chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[0],
+            logits=outputs_pos,
+        )
+        neg_prob = self.orpo_compute_logps(
+            prompt_attention_mask=concat_inputs["prompt_attention_mask"],
+            chosen_inputs=concat_inputs["input_ids"].chunk(2)[1],
+            chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[1],
+            logits=outputs_neg,
+        )
+
+        # Calculate log odds
+        log_odds = (pos_prob - neg_prob) - (
+            torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob))
+        )
+        sig_ratio = torch.nn.functional.sigmoid(log_odds)
+        ratio = torch.log(sig_ratio)
+
+        # Calculate the Final Loss
+        loss = torch.mean(pos_loss - self.args.orpo_alpha * ratio).to(
+            dtype=torch.bfloat16
+        )
+
+        metrics = {}
+        metrics["chosen_geometric_mean"] = torch.mean(pos_prob).cpu().item()
+        metrics["rejected_geometric_mean"] = torch.mean(neg_prob).cpu().item()
+        metrics["log_odds_ratio"] = torch.mean(ratio).cpu().item()
+        metrics["log_odds"] = torch.mean(log_odds).cpu().item()
+        self.store_metrics(metrics, train_eval="train")
+
+        return (loss, outputs_pos) if return_outputs else loss
+
+    @wraps(Trainer.push_to_hub)
+    def push_to_hub(self, *args, **kwargs) -> str:
+        """
+        Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
+        model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
+        """
+        kwargs = _sanitize_kwargs_for_ds_tagging(
+            dataset_tags=self.dataset_tags, kwargs=kwargs
+        )
+        kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
+
+        return super().push_to_hub(*args, **kwargs)
+
+    @wraps(Trainer.create_accelerator_and_postprocess)
+    def create_accelerator_and_postprocess(self):
+        res = super().create_accelerator_and_postprocess()
+
+        if self.is_fsdp_enabled:
+            if (
+                "limit_all_gathers" in self.args.fsdp_config
+                and self.args.fsdp_config["limit_all_gathers"]
+            ):
+                self.accelerator.state.fsdp_plugin.limit_all_gathers = True
+
+        return res
+
+    def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
+        """
+        Log `logs` on the various objects watching training, including stored metrics.
+
+        Args:
+            logs (`Dict[str, float]`):
+                The values to log.
+            start_time (`Optional[float]`):
+                The start of training.
+        """
+        # logs either has 'loss' or 'eval_loss'
+        train_eval = "train" if "loss" in logs else "eval"
+        # Add averaged stored metrics to logs
+        for key, metrics in self._stored_metrics[train_eval].items():
+            logs[key] = torch.tensor(metrics).mean().item()
+        del self._stored_metrics[train_eval]
+
+        return super().log(logs, start_time)
+
+    def store_metrics(
+        self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
+    ) -> None:
+        for key, value in metrics.items():
+            self._stored_metrics[train_eval][key].append(value)
+
+    def _save_checkpoint(self, model, trial, **kwargs):
+        # make sure the checkpoint dir exists, since the trainer is flaky
+        checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
+        run_dir = self._get_output_dir(trial=trial)
+        output_dir = os.path.join(run_dir, checkpoint_folder)
+        os.makedirs(output_dir, exist_ok=True)
+        return super()._save_checkpoint(model, trial, **kwargs)
+
+
+class AxolotlMambaTrainer(AxolotlTrainer):
+    """
+    Mamba specific trainer to handle loss calculation
+    """
+
+    tag_names = ["axolotl", "mamba"]
+
+    def compute_loss(
+        self,
+        model,
+        inputs,
+        return_outputs=False,  # pylint: disable=unused-argument
+        num_items_in_batch=None,  # pylint: disable=unused-argument
+    ):
+        input_ids = inputs.pop("input_ids")
+        lm_logits = model(input_ids).logits
+
+        labels = input_ids.to(lm_logits.device)
+        shift_logits = lm_logits[:, :-1, :].contiguous()
+        labels = labels[:, 1:].contiguous()
+
+        loss_fct = torch.nn.CrossEntropyLoss()
+        lm_loss = loss_fct(
+            shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
+        )
+
+        return lm_loss
+
+
+class ReLoRATrainer(AxolotlTrainer):
+    """
+    Trainer subclass that uses the OneCycleLR scheduler
+    """
+
+    tag_names = ["axolotl", "relora"]
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.lr_scheduler = None
+
+    def create_scheduler(
+        self,
+        num_training_steps: int,
+        optimizer: Optional[torch.optim.Optimizer] = None,
+    ):
+        optimizer = self.optimizer if optimizer is None else optimizer
+        lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
+
+        if self.args.relora_steps:
+            warmup_steps = (
+                self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
+            )
+            anneal_steps = (
+                self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
+            )
+            self.lr_scheduler = ReLoRAScheduler(
+                optimizer,
+                lr_scheduler,
+                self.args.relora_steps,
+                anneal_steps,
+                warmup_steps,
+            )
+        else:
+            self.lr_scheduler = lr_scheduler
+
+        return self.lr_scheduler
+
+
+class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
+    """
+    Extend the base DPOTrainer for axolotl helpers
+    """
+
+    tag_names = ["axolotl", "dpo"]
+
+    def __init__(self, *args, dataset_tags=None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.dataset_tags = dataset_tags
+        self.optimizer = None
+
+    def create_optimizer(self):
+        if self.args.loraplus_lr_ratio is None:
+            return super().create_optimizer()
+
+        opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
+        if self.optimizer is None:  # pylint: disable=access-member-before-definition
+            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
+                self.args,
+                opt_model,
+            )
+
+            loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
+            if loraplus_lr_ratio:
+                print("Using lora+")
+            loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
+            self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
+                opt_model,
+                optimizer_cls,
+                loraplus_lr_ratio=loraplus_lr_ratio,
+                loraplus_lr_embedding=loraplus_lr_embedding,
+                **optimizer_kwargs,
+            )
+
+        if is_sagemaker_mp_enabled():
+            self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
+                self.optimizer
+            )
+
+        return self.optimizer
+
+    @wraps(DPOTrainer.push_to_hub)
+    def push_to_hub(self, *args, **kwargs) -> str:
+        """
+        Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
+        model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
+        """
+        kwargs = _sanitize_kwargs_for_ds_tagging(
+            dataset_tags=self.dataset_tags, kwargs=kwargs
+        )
+        kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
+
+        return super().push_to_hub(*args, **kwargs)
+
+    @staticmethod
+    def tokenize_row(
+        features,
+        processing_class,
+        max_prompt_length,
+        max_completion_length,
+        add_special_tokens,
+    ) -> Dict:
+        res = DPOTrainer.tokenize_row(
+            features,
+            processing_class,
+            max_prompt_length,
+            max_completion_length,
+            add_special_tokens,
+        )
+        # fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
+        if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
+            for key in res.keys():
+                res[key] = res[key][1:]
+
+        if processing_class.bos_token and processing_class.bos_token_id is not None:
+            # dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
+            if res["chosen_input_ids"][0] == processing_class.bos_token_id:
+                res["chosen_input_ids"] = res["chosen_input_ids"][1:]
+                res["chosen_labels"] = res["chosen_labels"][1:]
+                res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
+            if res["rejected_input_ids"][0] == processing_class.bos_token_id:
+                res["rejected_input_ids"] = res["rejected_input_ids"][1:]
+                res["rejected_labels"] = res["rejected_labels"][1:]
+                res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
+
+        return res
+
+    def training_step(
+        self,
+        model: nn.Module,
+        inputs: Dict[str, Union[torch.Tensor, Any]],
+        num_items_in_batch=None,
+    ) -> torch.Tensor:
+        loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
+        gc.collect()
+        torch.cuda.empty_cache()
+        return loss
+
+
+class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
+    """
+    Extend the base ORPOTrainer for axolotl helpers
+    """
+
+    tag_names = ["axolotl", "orpo"]
+
+
+class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
+    """
+    Extend the base KTOTrainer for axolotl helpers
+    """
+
+    tag_names = ["axolotl", "kto"]
+
+
+class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
+    """
+    Extend the base CPOTrainer for axolotl helpers
+    """
+
+    tag_names = ["axolotl", "cpo"]
+
+
+class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
+    """
+    Extend the base RewardTrainer for axolotl helpers
+    """
+
+    tag_names = ["axolotl", "reward"]
diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py
new file mode 100644
index 000000000..6a8753e23
--- /dev/null
+++ b/src/axolotl/core/training_args.py
@@ -0,0 +1,220 @@
+"""
+extra axolotl specific training args
+"""
+from dataclasses import dataclass, field
+from typing import Optional
+
+from transformers import TrainingArguments
+from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, RewardConfig
+
+
+@dataclass
+class AxolotlTrainingMixins:
+    """
+    Mixin class for the Axolotl training args.
+    """
+
+    # pylint: disable=duplicate-code
+    model_type: Optional[str] = field(
+        default=None, metadata={"help": "HF model configuration model_type."}
+    )
+    lr_quadratic_warmup: bool = field(
+        default=False,
+        metadata={"help": "Use quadratic warmup for cosine scheduling."},
+    )
+    pretraining: bool = field(
+        default=False,
+        metadata={
+            "help": "Indicates to trainer whether we are doing continued pretraining."
+        },
+    )
+    sample_packing: bool = field(
+        default=False,
+        metadata={"help": "Use sample packing for efficient training."},
+    )
+    multipack_real_batches: bool = field(
+        default=False,
+        metadata={"help": "Use real batches for efficient training."},
+    )
+    eval_sample_packing: Optional[bool] = field(
+        default=None,
+        metadata={"help": "Use sample packing for efficient evals."},
+    )
+    sample_packing_efficiency: float = field(
+        default=1.0,
+        metadata={"help": "Sample packing efficiency for calculating batch length."},
+    )
+    sample_packing_bin_size: int = field(
+        default=200,
+        metadata={
+            "help": "The max number of samples that a packed sample can contain after packing. Increase for better packing."
+        },
+    )
+    sample_packing_group_size: int = field(
+        default=100000,
+        metadata={
+            "help": "The number of samples to group together for packing. Increase for better packing."
+        },
+    )
+    max_seq_length: int = field(
+        default=2048,
+        metadata={"help": "The maximum sequence length the model can handle"},
+    )
+    relora_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how often to reset for ReLoRA"},
+    )
+    relora_warmup_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
+    )
+    relora_anneal_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how many anneal steps to take after reset for ReLoRA"},
+    )
+    relora_prune_ratio: Optional[float] = field(
+        default=0.9,
+        metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
+    )
+    bench_split: Optional[str] = field(
+        default="eval", metadata={"help": "The benchmark split to run on"}
+    )
+    bench_dataset: Optional[str] = field(
+        default="pharaouk/dharma-1/dharma_1_mini.json",
+        metadata={
+            "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
+        },
+    )
+    do_bench_eval: Optional[bool] = field(
+        default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
+    )
+    do_causal_lm_eval: Optional[bool] = field(
+        default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
+    )
+    max_bench_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
+        },
+    )
+    bench_source_max_len: int = field(
+        default=2048, metadata={"help": "Maximum source sequence length for bench."}
+    )
+    dataloader_prefetch_factor: Optional[int] = field(
+        default=None,
+        metadata={"help": "prefetch_factor argument to the dataloader"},
+    )
+    cosine_min_lr_ratio: Optional[float] = field(
+        default=None,
+        metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
+    )
+    cosine_constant_lr_ratio: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
+        },
+    )
+    loraplus_lr_ratio: Optional[float] = field(
+        default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
+    )
+    loraplus_lr_embedding: Optional[float] = field(
+        default=1e-6,
+        metadata={"help": "loraplus learning rate for lora embedding layers."},
+    )
+    embedding_lr_scale: Optional[float] = field(
+        default=None,
+        metadata={"help": "Scale the learning rate for the embedding layers."},
+    )
+    embedding_lr: Optional[float] = field(
+        default=None,
+        metadata={"help": "absolute learning rate for the embedding layers."},
+    )
+    qlora: bool = field(
+        default=False,
+        metadata={"help": "whether this is a qlora training run"},
+    )
+    orpo_alpha: Optional[float] = field(
+        default=None,
+    )
+    lisa_n_layers: Optional[int] = field(
+        default=None,
+        metadata={"help": "the number of active layers in LISA"},
+    )
+    lisa_step_interval: Optional[int] = field(
+        default=None,
+        metadata={"help": "how often to switch layers in LISA"},
+    )
+    lisa_layers_attribute: Optional[str] = field(
+        default=None,
+        metadata={"help": "path under the model to access the layers"},
+    )
+    curriculum_sampling: Optional[bool] = field(
+        default=None,
+        metadata={"help": "whether to use sequential sampling for curriculum learning"},
+    )
+    alternate_optimizer: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "workaround to pass an alternate optimizer to the HF trainer"
+        },
+    )
+    alternate_lr_scheduler_type: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "workaround to pass an alternate lr scheduler to the HF trainer"
+        },
+    )
+    chat_template: Optional[str] = field(
+        default=None,
+        metadata={"help": "Chat template converting chat messages to text"},
+    )
+
+
+@dataclass
+class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
+    """
+    Training arguments for Causal trainer
+
+    This code is duplicated due to HF TrainingArguments not setting output_dir with a default value,
+    so it can't be used as a mixin.
+    """
+
+
+@dataclass
+class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
+    """
+    DPO config for DPO training
+    """
+
+
+@dataclass
+class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
+    """
+    ORPO config for ORPO training
+    """
+
+
+@dataclass
+class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
+    """
+    KTO config for KTO training
+    """
+
+
+@dataclass
+class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
+    """
+    CPO config for CPO training
+    """
+
+    simpo_gamma: Optional[float] = field(
+        default=None,
+        metadata={"help": "simpo gamma parameter"},
+    )
+
+
+@dataclass
+class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
+    """
+    Reward config for Reward training
+    """
diff --git a/src/axolotl/prompt_strategies/base.py b/src/axolotl/prompt_strategies/base.py
index fce2aba14..cddb3d0e1 100644
--- a/src/axolotl/prompt_strategies/base.py
+++ b/src/axolotl/prompt_strategies/base.py
@@ -10,6 +10,8 @@ LOG = logging.getLogger("axolotl")
 
 def load(strategy, cfg, module_base=None, **kwargs):
     try:
+        if len(strategy.split(".")) == 1:
+            strategy = strategy + ".default"
         load_fn = strategy.split(".")[-1]
         strategy = ".".join(strategy.split(".")[:-1])
         mod = importlib.import_module(f".{strategy}", module_base)
diff --git a/src/axolotl/prompt_strategies/dpo/chatml.py b/src/axolotl/prompt_strategies/dpo/chatml.py
index 585696e29..5043a501e 100644
--- a/src/axolotl/prompt_strategies/dpo/chatml.py
+++ b/src/axolotl/prompt_strategies/dpo/chatml.py
@@ -3,22 +3,41 @@ DPO strategies for chatml
 """
 
 
-def argilla(
+def default(
     cfg,
     **kwargs,
 ):  # pylint: disable=possibly-unused-variable,unused-argument
     def transform_fn(sample):
+        if "prompt" in sample.keys():
+            prompt_key = "prompt"
+        elif "input" in sample.keys():
+            prompt_key = "input"
+        elif "question" in sample.keys():
+            prompt_key = "question"
+        else:
+            prompt_key = "instruction"
+
+        if "chosen" in sample.keys():
+            chosen_key = "chosen"
+        else:
+            chosen_key = "chosen_response"
+
+        if "rejected" in sample.keys():
+            rejected_key = "rejected"
+        else:
+            rejected_key = "rejected_response"
+
         if "system" in sample and sample["system"]:
             sample["prompt"] = (
                 f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
-                f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
+                f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
            )
         else:
             sample[
                 "prompt"
-            ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
-        sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
-        sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
+            ] = f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
+        sample["chosen"] = f"{sample[chosen_key]}<|im_end|>"
+        sample["rejected"] = f"{sample[rejected_key]}<|im_end|>"
         return sample
 
     return transform_fn
diff --git a/src/axolotl/prompt_strategies/dpo/llama3.py b/src/axolotl/prompt_strategies/dpo/llama3.py
index cb394cc22..d10aa223b 100644
--- a/src/axolotl/prompt_strategies/dpo/llama3.py
+++ b/src/axolotl/prompt_strategies/dpo/llama3.py
@@ -3,22 +3,42 @@ DPO strategies for llama-3 chat template
 """
 
 
-def argilla(
+def default(
     cfg,
     **kwargs,
 ):  # pylint: disable=possibly-unused-variable,unused-argument
     def transform_fn(sample):
+        # pylint: disable=duplicate-code
+        if "prompt" in sample.keys():
+            prompt_key = "prompt"
+        elif "input" in sample.keys():
+            prompt_key = "input"
+        elif "question" in sample.keys():
+            prompt_key = "question"
+        else:
+            prompt_key = "instruction"
+
+        if "chosen" in sample.keys():
+            chosen_key = "chosen"
+        else:
+            chosen_key = "chosen_response"
+
+        if "rejected" in sample.keys():
+            rejected_key = "rejected"
+        else:
+            rejected_key = "rejected_response"
+
         if "system" in sample and sample["system"]:
             sample["prompt"] = (
                 f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
-                f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+                f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
             )
         else:
             sample[
                 "prompt"
-            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
-        sample["chosen"] = f"{sample['chosen_response']}<|eot_id|>"
-        sample["rejected"] = f"{sample['rejected_response']}<|eot_id|>"
+            ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
+        sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>"
+        sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>"
         return sample
 
     return transform_fn
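
For context on the ORPO path moved into `AxolotlTrainer` above: `orpo_compute_loss` combines the NLL loss on the chosen completion with a scaled log-odds-ratio term. As a sketch, assuming P(y|x) denotes the exponential of the length-normalized log-probability returned by `orpo_compute_logps` and alpha is `args.orpo_alpha`:

```latex
% Sketch of the objective computed by orpo_compute_loss; in the code,
% `log_odds` is the argument of sigma and `ratio` is log(sigma(.)).
\mathcal{L}_{\mathrm{ORPO}}
  = \mathbb{E}\!\left[\mathcal{L}_{\mathrm{NLL}}(y_w)
    - \alpha \,\log \sigma\!\big(\log \operatorname{odds}(y_w \mid x)
    - \log \operatorname{odds}(y_l \mid x)\big)\right],
\qquad
\operatorname{odds}(y \mid x) = \frac{P(y \mid x)}{1 - P(y \mid x)}
```

Here y_w is the chosen and y_l the rejected completion; the stored metrics (`log_odds`, `log_odds_ratio`, and the geometric means) are the batch averages of these same quantities.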
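The `tokenize_row` override in `AxolotlDPOTrainer` guards against TRL's base `tokenize_row` prepending a BOS token that the sequence should not carry. A minimal sketch of that guard, using a hypothetical tokenizer whose `bos_token_id` is 1 and made-up token ids:

```python
# Hypothetical ids; 1 stands in for a real tokenizer's bos_token_id.
bos_token_id = 1
chosen_input_ids = [1, 9906, 1917]  # BOS wrongly prepended by the base tokenize_row
chosen_attention_mask = [1, 1, 1]

if chosen_input_ids[0] == bos_token_id:
    # the override trims ids, labels, and attention mask in lockstep
    chosen_input_ids = chosen_input_ids[1:]
    chosen_attention_mask = chosen_attention_mask[1:]

assert chosen_input_ids == [9906, 1917]
```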
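The two lines added to `prompt_strategies/base.py` are what allow a bare strategy name in a dataset config: a `type` with no dot now falls back to the module's `default` loader. A self-contained sketch of just that resolution step (the real `load` goes on to import the module and inspect the loader's signature):

```python
def resolve(strategy: str) -> tuple:
    """Mirror of the dotted-path handling in prompt_strategies.base.load."""
    if len(strategy.split(".")) == 1:
        strategy = strategy + ".default"
    load_fn = strategy.split(".")[-1]
    module = ".".join(strategy.split(".")[:-1])
    return module, load_fn

assert resolve("chatml") == ("chatml", "default")      # bare name -> default loader
assert resolve("chatml.intel") == ("chatml", "intel")  # dotted names unchanged
```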
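The renamed `default` strategies detect which prompt/response columns a dataset provides rather than hard-coding `instruction`/`chosen_response`/`rejected_response`. A hypothetical sample walked through the same key detection (condensed here into `next(...)`; the real `transform_fn` uses if/elif chains):

```python
# Hypothetical dataset row to illustrate the column-name fallback.
sample = {"question": "What is 2 + 2?", "chosen": "4", "rejected": "5"}

prompt_key = next(
    key for key in ("prompt", "input", "question", "instruction") if key in sample
)
chosen_key = "chosen" if "chosen" in sample else "chosen_response"
rejected_key = "rejected" if "rejected" in sample else "rejected_response"

assert (prompt_key, chosen_key, rejected_key) == ("question", "chosen", "rejected")
# The chatml transform would then emit:
#   prompt   == "<|im_start|>user\nWhat is 2 + 2?<|im_end|>\n<|im_start|>assistant\n"
#   chosen   == "4<|im_end|>"
#   rejected == "5<|im_end|>"
```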