From 78ce2688484b5e77244b241d5542445cdee3882e Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 31 Jan 2025 20:18:52 -0500
Subject: [PATCH] KD Trainer w logprobs (#2303)

* refactor trainer to prevent circular dependencies later
fix loader
default KD dataset loading and KD with logprobs
filter bad rows
make batch smaller
handle padding/collation for KD datasets
make it work
flipped the slice
cross entropy loss coefficient during KD
make sure to multiply against the correct loss
chore: lint
triton wip
no where support
v2 trial
no torch.exp inside triton kernel
no log etc
no torch.tensor
v3
fix kwarg
don't use triton for now
better rescaling for temperatures
hash for temperature too
use kd_alpha in the correct loss method
fix kd loss so it's causal (fixes repeating tokens)
var naming and add todo
chore: lint
refactor so we can easily add new loss functions
add license block
remove references to triton kd for now
handle token/logprob shifting
support for custom trainer classes from plugins
refactor kd chat template loader
move more things to kd plugin
remove moved class from import
make plugin setup concise
increase logging around loading plugins
add copyrights
remove duplicate code
more info on preprocess for kd and fix import
be a bit pickier about loading dynamic prompt strategies
kd sample packing
make loss torch script compat
support streaming for processing sft datasets?
improve iterable support
ensure that batch vs single is done properly
tweak check for batched prompt data
reward can use same batch check
fix reward trainer calls for tokenization
improve check for batched reward model
doesn't work well with batched
add kd trainer e2e test
linting
rename test files so it gets picked up
make the kd e2e fit in vram for ci and add lora version
set lora_dropout explicitly
lower lr
make sure to set tokenizer from l3 70b and save safetensors
make sure to use the correct tokenizer
fix adapter model check
make sure to use tensorboard to capture loss for checks
chore: lint
chore: lint
improve logprob masking and shift in trainer
more fixes
try tests for kd on l40s
don't shift student logits for kd
no batching for kd chat templates
make sure to truncate logprobs if there are more than top_k
change up logic so we always truncate to top_k
use iter instead of tuple
fix finding the top-k rather than assuming first position has the correct val
apply z-score scaling to kd
kd loss needs to be calculated in full precision
Always re-normalize teacher distribution
various fixes

* support for configurable top-k/softmax ordering

* add attribute check for filter rows and lint

* fix logic

* handle none case for conversion to int

* fix student logit off by one

* set kd_temp to 1.0 for test loss

* address PR feedback
---
 .github/workflows/tests.yml                   |    4 +-
 cicd/tests.py                                 |    2 +-
 docs/rlhf.qmd                                 |    2 +-
 src/axolotl/cli/args.py                       |    6 +
 src/axolotl/cli/preprocess.py                 |    5 +-
 src/axolotl/common/datasets.py                |    6 +
 src/axolotl/core/trainer_builder.py           | 1273 +----------------
 src/axolotl/core/trainers/base.py             |  988 +++++++++++++
 src/axolotl/core/training_args.py             |  264 ++++
 src/axolotl/datasets.py                       |   32 +-
 src/axolotl/integrations/base.py              |   41 +-
 src/axolotl/integrations/kd/__init__.py       |   36 +
 src/axolotl/integrations/kd/args.py           |   37 +
 src/axolotl/integrations/kd/chat_template.py  |  201 +++
 src/axolotl/integrations/kd/collator.py       |  255 ++++
 .../integrations/kd/kernels/__init__.py       |    0
 .../integrations/kd/topk_logprob/LICENSE.md   |   58 +
 .../integrations/kd/topk_logprob/__init__.py  |    0
 .../kd/topk_logprob/forward_kl.py             |  235 +++
 src/axolotl/integrations/kd/trainer.py        |  113 ++
 src/axolotl/prompt_strategies/__init__.py     |   13 +-
 src/axolotl/prompt_strategies/base.py         |    2 +
 .../bradley_terry/chat_template.py            |   10 +-
 .../prompt_strategies/chat_template.py        |  145 +-
 src/axolotl/prompt_strategies/dpo/chatml.py   |   29 +-
 src/axolotl/prompt_strategies/dpo/llama3.py   |   30 +-
 src/axolotl/prompt_tokenizers.py              |    4 +-
 .../config/models/input/v0_4_1/__init__.py    |    4 +
 src/axolotl/utils/data/sft.py                 |   69 +-
 src/axolotl/utils/data/shared.py              |   14 +-
 src/axolotl/utils/tokenization.py             |    8 +
 src/axolotl/utils/trainer.py                  |  214 ++-
 tests/e2e/integrations/test_kd.py             |  121 ++
 tests/e2e/integrations/test_liger.py          |    2 +
 34 files changed, 2873 insertions(+), 1350 deletions(-)
 create mode 100644 src/axolotl/core/trainers/base.py
 create mode 100644 src/axolotl/core/training_args.py
 create mode 100644 src/axolotl/integrations/kd/__init__.py
 create mode 100644 src/axolotl/integrations/kd/args.py
 create mode 100644 src/axolotl/integrations/kd/chat_template.py
 create mode 100644 src/axolotl/integrations/kd/collator.py
 create mode 100644 src/axolotl/integrations/kd/kernels/__init__.py
 create mode 100644 src/axolotl/integrations/kd/topk_logprob/LICENSE.md
 create mode 100644 src/axolotl/integrations/kd/topk_logprob/__init__.py
 create mode 100644 src/axolotl/integrations/kd/topk_logprob/forward_kl.py
 create mode 100644 src/axolotl/integrations/kd/trainer.py
 create mode 100644 tests/e2e/integrations/test_kd.py

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 87d532a3b..b530908fd 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -207,7 +207,7 @@ jobs:
           - cuda: 124
             cuda_version: 12.4.1
            python_version: "3.11"
-            pytorch: 2.4.1
+            pytorch: 2.5.1
             num_gpus: 1
             axolotl_extras:
     steps:
@@ -247,7 +247,7 @@
          - cuda: 124
            cuda_version: 12.4.1
            python_version: "3.11"
-            pytorch: 2.5.1
+            pytorch: 2.4.1
            num_gpus: 1
            axolotl_extras:
     steps:
diff --git a/cicd/tests.py b/cicd/tests.py
index 616554e64..6fe701632 100644
--- a/cicd/tests.py
+++ b/cicd/tests.py
@@ -59,7 +59,7 @@ VOLUME_CONFIG = {
 }

 N_GPUS = int(os.environ.get("N_GPUS", 1))
-GPU_CONFIG = modal.gpu.A10G(count=N_GPUS)
+GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)


 def run_cmd(cmd: str, run_folder: str):
diff --git a/docs/rlhf.qmd b/docs/rlhf.qmd
index 48701c87a..8f52876c1 100644
--- a/docs/rlhf.qmd
+++ b/docs/rlhf.qmd
@@ -29,7 +29,7 @@ datasets:
     type: chatml.intel
   - path: argilla/ultrafeedback-binarized-preferences
     split: train
-    type: chatml.argilla
+    type: chatml
 ```

 #### IPO
diff --git a/src/axolotl/cli/args.py b/src/axolotl/cli/args.py
index a5865be1c..a39ffc308 100644
--- a/src/axolotl/cli/args.py
+++ b/src/axolotl/cli/args.py
@@ -13,6 +13,12 @@ class PreprocessCliArgs:
     debug_num_examples: int = field(default=1)
     prompter: Optional[str] = field(default=None)
     download: Optional[bool] = field(default=True)
+    iterable: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": "Use IterableDataset for streaming processing of large datasets"
+        },
+    )


 @dataclass
diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py
index 760fe76fa..5585c88a7 100644
--- a/src/axolotl/cli/preprocess.py
+++ b/src/axolotl/cli/preprocess.py
@@ -75,7 +75,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
     )


-def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
+def do_cli(
+    config: Union[Path, str] = Path("examples/"),
+    **kwargs,
+) -> None:
     """
     Parses `axolotl` config, CLI args, and calls `do_preprocess`.
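The squashed commit log above records the subtleties that shaped the KD loss added in src/axolotl/integrations/kd/topk_logprob/forward_kl.py: the loss must be causal (student logits at position t predict token t+1), the teacher's truncated top-k distribution must always be re-normalized, and the KL term must be computed in full precision. The following minimal sketch illustrates that style of loss; the function name, signature, and the exact shift/temperature handling here are illustrative assumptions, not the patch's actual implementation.

# Illustrative sketch of a top-k forward-KL distillation loss; all names and
# the precise shift/temperature handling are assumptions for exposition.
import torch
import torch.nn.functional as F


def topk_forward_kl_loss(
    student_logits: torch.Tensor,    # (batch, seq, vocab) raw student outputs
    target_token_ids: torch.Tensor,  # (batch, seq, k) teacher top-k token ids (int64)
    target_logprobs: torch.Tensor,   # (batch, seq, k) teacher top-k log-probs
    target_mask: torch.Tensor,       # (batch, seq) nonzero where KD loss applies
    kd_temperature: float = 1.0,
) -> torch.Tensor:
    # Causal alignment: drop the last student position and the first teacher
    # position so that logits at position t are matched against token t+1.
    student_logits = student_logits[:, :-1, :].float() / kd_temperature
    target_token_ids = target_token_ids[:, 1:, :]
    target_logprobs = target_logprobs[:, 1:, :]
    target_mask = target_mask[:, 1:].bool()

    # Student log-probs over the full vocabulary, then gathered at the
    # teacher's top-k token ids (softmax before the top-k slice, one of the
    # two orderings the kd_top_k_before_softmax option toggles).
    student_logprobs = F.log_softmax(student_logits, dim=-1)
    student_topk = torch.gather(student_logprobs, -1, target_token_ids)

    # The teacher distribution was truncated to its top-k tokens, so
    # re-normalize it to sum to 1 over the retained support.
    teacher_probs = target_logprobs.exp()
    teacher_probs = teacher_probs / teacher_probs.sum(dim=-1, keepdim=True)

    # Forward KL on the top-k support: sum_k p_T * (log p_T - log p_S),
    # averaged over unmasked positions, computed in fp32 throughout.
    kl_per_pos = (teacher_probs * (teacher_probs.log() - student_topk)).sum(-1)
    return kl_per_pos[target_mask].mean()

In the patch, this distillation term is blended with the ordinary token-level cross-entropy loss via the kd_alpha and kd_ce_alpha coefficients, with kd_temperature (and, optionally, z-score rescaling via kd_zscore_base_temp) controlling how sharp the matched distributions are.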
diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py
index c693c26d8..cbc0d127c 100644
--- a/src/axolotl/common/datasets.py
+++ b/src/axolotl/common/datasets.py
@@ -63,11 +63,17 @@ def load_datasets(
     """
     tokenizer = load_tokenizer(cfg)
     processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
+    preprocess_iterable = (
+        hasattr(cli_args, "iterable")
+        and cli_args.iterable is not None
+        and cli_args.iterable
+    )

     train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
         cfg,
         tokenizer,
         processor=processor,
+        preprocess_iterable=preprocess_iterable,
     )

     if (
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 6bf03d78c..89480d775 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1,59 +1,65 @@
+# Copyright 2024 Axolotl AI. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 # pylint: disable=too-many-lines
 """
 Builder for the training args and trainer
 """

 import abc
-import gc
 import importlib
 import importlib.util
 import inspect
 import logging
 import math
-import os
 import sys
 from abc import abstractmethod
-from collections import defaultdict
-from dataclasses import dataclass, field
-from functools import wraps
 from pathlib import Path
-from typing import Any, Dict, List, Literal, Optional, Type, Union
+from typing import List, Type, Union

 import torch
 import transformers
-from datasets import Dataset
-from peft.optimizers import create_loraplus_optimizer
-from torch import nn
-from torch.optim.lr_scheduler import OneCycleLR
-from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import (
     DataCollatorWithFlattening,
     EarlyStoppingCallback,
-    Trainer,
     TrainerCallback,
-    TrainingArguments,
 )
-from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
-from transformers.utils import is_sagemaker_mp_enabled
-from trl import (
-    CPOConfig,
-    CPOTrainer,
-    DPOConfig,
-    DPOTrainer,
-    KTOConfig,
-    KTOTrainer,
-    ORPOConfig,
-    ORPOTrainer,
-    PRMConfig,
-    PRMTrainer,
-    RewardConfig,
-    RewardTrainer,
-)
-from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
+from trl.trainer.utils import RewardDataCollatorWithPadding

+from axolotl.core.trainers.base import (
+    AxolotlCPOTrainer,
+    AxolotlDPOTrainer,
+    AxolotlKTOTrainer,
+    AxolotlMambaTrainer,
+    AxolotlORPOTrainer,
+    AxolotlPRMTrainer,
+    AxolotlRewardTrainer,
+    AxolotlTrainer,
+    ReLoRATrainer,
+)
+from axolotl.core.training_args import (
+    AxolotlCPOConfig,
+    AxolotlDPOConfig,
+    AxolotlKTOConfig,
+    AxolotlORPOConfig,
+    AxolotlPRMConfig,
+    AxolotlRewardConfig,
+    AxolotlTrainingArguments,
+)
 from axolotl.integrations.base import PluginManager
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
-from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
+from axolotl.monkeypatch.relora import ReLoRACallback
 from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import ( EvalFirstStepCallback, @@ -78,15 +84,6 @@ from axolotl.utils.collators import ( ) from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator from axolotl.utils.models import ensure_dtype -from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths -from axolotl.utils.schedulers import ( - get_cosine_schedule_with_min_lr, - get_cosine_schedule_with_quadratic_warmup, - get_cosine_schedule_with_warmup_decay_constant, -) - -if is_sagemaker_mp_enabled(): - import smdistributed.modelparallel.torch as smp try: import torch._dynamo # pylint: disable=ungrouped-imports @@ -96,1171 +93,6 @@ except ImportError: LOG = logging.getLogger("axolotl.core.trainer_builder") -def _sanitize_kwargs_for_tagging(tag_names, kwargs=None): - if isinstance(tag_names, str): - tag_names = [tag_names] - - if kwargs is not None: - if "tags" not in kwargs: - kwargs["tags"] = tag_names - elif "tags" in kwargs and isinstance(kwargs["tags"], list): - kwargs["tags"].extend(tag_names) - elif "tags" in kwargs and isinstance(kwargs["tags"], str): - tag_names.append(kwargs["tags"]) - kwargs["tags"] = tag_names - - return kwargs - - -def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None): - if isinstance(dataset_tags, str): - dataset_tags = [dataset_tags] - - if (dataset_tags is not None) and (kwargs is not None): - if "dataset_tags" not in kwargs: - kwargs["dataset_tags"] = dataset_tags - elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list): - kwargs["dataset_tags"].extend(dataset_tags) - elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str): - dataset_tags.append(kwargs["dataset_tags"]) - kwargs["dataset_tags"] = dataset_tags - - return kwargs - - -@dataclass -class AxolotlTrainingMixins: - """ - Mixin class for the Axolotl training args. - """ - - # pylint: disable=duplicate-code - model_type: Optional[str] = field( - default=None, metadata={"help": "HF model configuration model_type."} - ) - lr_quadratic_warmup: bool = field( - default=False, - metadata={"help": "Use quadratic warmup for cosine scheduling."}, - ) - pretraining: bool = field( - default=False, - metadata={ - "help": "Indicates to trainer whether we are doing continued pretraining." - }, - ) - sample_packing: bool = field( - default=False, - metadata={"help": "Use sample packing for efficient training."}, - ) - multipack_real_batches: bool = field( - default=False, - metadata={"help": "Use real batches for efficient training."}, - ) - eval_sample_packing: Optional[bool] = field( - default=None, - metadata={"help": "Use sample packing for efficient evals."}, - ) - sample_packing_efficiency: float = field( - default=1.0, - metadata={"help": "Sample packing efficiency for calculating batch length."}, - ) - sample_packing_bin_size: int = field( - default=200, - metadata={ - "help": "The max number of samples that packed sample can contain after packing. Increase for better packing." - }, - ) - sample_packing_group_size: int = field( - default=100000, - metadata={ - "help": "The number of samples to group together for packing. Increase for better packing." 
- }, - ) - max_seq_length: int = field( - default=2048, - metadata={"help": "The maximum sequence length the model can handle"}, - ) - relora_steps: Optional[int] = field( - default=None, - metadata={"help": "how often to reset for ReLoRA"}, - ) - relora_warmup_steps: Optional[int] = field( - default=None, - metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, - ) - relora_anneal_steps: Optional[int] = field( - default=None, - metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, - ) - relora_prune_ratio: Optional[float] = field( - default=0.9, - metadata={"help": "prune ratio for magnitude pruning of the optimizer"}, - ) - bench_split: Optional[str] = field( - default="eval", metadata={"help": "The benchmark split to run on"} - ) - bench_dataset: Optional[str] = field( - default="pharaouk/dharma-1/dharma_1_mini.json", - metadata={ - "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file" - }, - ) - do_bench_eval: Optional[bool] = field( - default=False, metadata={"help": "Whether to run the Benchmark evaluation."} - ) - do_causal_lm_eval: Optional[bool] = field( - default=False, metadata={"help": "Whether to run the Causal LM evaluation."} - ) - max_bench_samples: Optional[int] = field( - default=None, - metadata={ - "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset." - }, - ) - bench_source_max_len: int = field( - default=2048, metadata={"help": "Maximum source sequence length for bench."} - ) - dataloader_prefetch_factor: Optional[int] = field( - default=None, - metadata={"help": "prefetch_factor argument to the dataloader"}, - ) - cosine_min_lr_ratio: Optional[float] = field( - default=None, - metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"}, - ) - cosine_constant_lr_ratio: Optional[float] = field( - default=None, - metadata={ - "help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps" - }, - ) - loraplus_lr_ratio: Optional[float] = field( - default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."} - ) - loraplus_lr_embedding: Optional[float] = field( - default=1e-6, - metadata={"help": "loraplus learning rate for lora embedding layers."}, - ) - embedding_lr_scale: Optional[float] = field( - default=None, - metadata={"help": "Scale the learning rate for the embedding layers."}, - ) - lr_groups: Optional[list[dict]] = field( - default=None, - metadata={"help": "Specify learning rate groups for with different LRs."}, - ) - embedding_lr: Optional[float] = field( - default=None, - metadata={"help": "absolute learning rate for the embedding layers."}, - ) - qlora: bool = field( - default=False, - metadata={"help": "whether this is a qlora training"}, - ) - orpo_alpha: Optional[float] = field( - default=None, - ) - lisa_n_layers: Optional[int] = field( - default=None, - metadata={"help": "the number of activate layers in LISA"}, - ) - lisa_step_interval: Optional[int] = field( - default=None, - metadata={"help": "how often to switch layers in LISA"}, - ) - lisa_layers_attribute: Optional[str] = field( - default=None, - metadata={"help": "path under the model to access the layers"}, - ) - curriculum_sampling: Optional[bool] = field( - default=None, - metadata={"help": "whether to use sequential sampling for curriculum learning"}, - ) - alternate_optimizer: Optional[str] = field( - default=None, - metadata={ - "help": "workaround to pass an alternate optimizer to the HF trainer" - }, - ) - 
alternate_lr_scheduler_type: Optional[str] = field( - default=None, - metadata={ - "help": "workaround to pass an alternate lr scheduler to the HF trainer" - }, - ) - chat_template: Optional[str] = field( - default=None, - metadata={"help": "Chat template converting chat messages to text"}, - ) - - -@dataclass -class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments): - """ - Training arguments for Causal trainer - - This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value - so it can't be used as a mixin. - """ - - -@dataclass -class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig): - """ - DPO config for DPO training - """ - - -@dataclass -class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig): - """ - ORPO config for ORPO training - """ - - -@dataclass -class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig): - """ - KTO config for KTO training - """ - - -@dataclass -class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig): - """ - CPO config for CPO training - """ - - simpo_gamma: Optional[float] = field( - default=None, - metadata={"help": "simpo gamma parameter"}, - ) - - -@dataclass -class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig): - """ - Reward config for Reward training - """ - - -@dataclass -class AxolotlPRMConfig(AxolotlTrainingMixins, PRMConfig): - """ - PRM config for PRM training - """ - - -class SchedulerMixin(Trainer): - """ - Mixin class for scheduler setup in CausalTrainer. - """ - - args = None # type: AxolotlTrainingArguments - - def create_scheduler( - self, num_training_steps: int, optimizer: torch.optim.Optimizer = None - ): - """ - Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or - passed as an argument. - - Args: - num_training_steps (int): The number of training steps to do. - optimizer (torch.optim.Optimizer): The training optimizer - """ - use_cosine_quadratic = ( - self.args.lr_scheduler_type == "cosine" - and self.args.lr_quadratic_warmup is True - ) - - use_cosine_min_lr = ( - self.args.lr_scheduler_type == "cosine" - and self.args.cosine_min_lr_ratio is not None - ) - - # fmt: off - if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition - # fmt: on - if self.args.alternate_lr_scheduler_type == "one_cycle": - num_warmup_steps = self.args.get_warmup_steps(num_training_steps) - pct_start = num_warmup_steps / num_training_steps - extra_lr_kwargs = {} - if "pct_start" not in self.args.lr_scheduler_kwargs: - extra_lr_kwargs["pct_start"] = pct_start - if "anneal_strategy" not in self.args.lr_scheduler_kwargs: - extra_lr_kwargs["anneal_strategy"] = "cos" - - self.lr_scheduler = OneCycleLR( - optimizer, - max_lr=self.args.learning_rate, - total_steps=num_training_steps, - **extra_lr_kwargs, - **self.args.lr_scheduler_kwargs, - ) - elif use_cosine_quadratic: - if use_cosine_min_lr: - LOG.warning("Both cosine quadratic warmup and min lr detected. 
Using quadratic warmup.") - - self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init - optimizer, - num_warmup_steps=self.args.get_warmup_steps(num_training_steps), - num_training_steps=num_training_steps, - ) - elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr: - assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" - assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0" - self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init - optimizer, - num_warmup_steps=self.args.get_warmup_steps(num_training_steps), - num_training_steps=num_training_steps, - min_lr_ratio=self.args.cosine_min_lr_ratio, - constant_lr_ratio=self.args.cosine_constant_lr_ratio, - ) - elif self.args.cosine_min_lr_ratio and use_cosine_min_lr: - assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" - self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init - optimizer, - num_warmup_steps=self.args.get_warmup_steps(num_training_steps), - num_training_steps=num_training_steps, - min_lr_ratio=self.args.cosine_min_lr_ratio, - ) - else: - return super().create_scheduler(num_training_steps, optimizer=optimizer) - else: - if use_cosine_quadratic: - LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).") - - if use_cosine_min_lr: - LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).") - - return self.lr_scheduler - - -class AxolotlTrainer(SchedulerMixin, Trainer): - """ - Extend the base Trainer for axolotl helpers - """ - - args = None # type: AxolotlTrainingArguments - tag_names = ["axolotl"] - - def __init__( - self, - *_args, - bench_data_collator=None, - eval_data_collator=None, - dataset_tags=None, - **kwargs, - ): - self.bench_data_collator = bench_data_collator - self.eval_data_collator = eval_data_collator - self.dataset_tags = dataset_tags - super().__init__(*_args, **kwargs) - self.train_data_collator = self.data_collator - self._stored_metrics = defaultdict(lambda: defaultdict(list)) - if self.args.orpo_alpha: - self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") - - def _wrap_model(self, model, training=True, dataloader=None): - if self.args.torch_compile: - torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access - 256 - ) - model = torch.compile( - model, - backend=self.args.torch_compile_backend, - mode=self.args.torch_compile_mode, - ) - return super()._wrap_model(model, training=training, dataloader=dataloader) - - def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs): - decay_parameters = self.get_decay_parameter_names(opt_model) - params = { - "to_weight_decay": {}, # LayerNorm and bias - "embeddings": {}, # lm_head, embed_tokens, - "no_weight_decay": {}, - } - lr_groups_lookup = {} - lr_groups_learning_rates = {} - if self.args.lr_groups: - for lr_group in self.args.lr_groups: - group_name = lr_group["name"] - group_modules = lr_group["modules"] - for module in group_modules: - lr_groups_lookup[module] = group_name - lr_groups_learning_rates[group_name] = lr_group["lr"] - params[f"to_weight_decay_{group_name}"] = {} - - for name, param in opt_model.named_parameters(): - if not param.requires_grad: - continue - if 
name.endswith("modules_to_save.default.weight") or any( - embed_name in name for embed_name in ["embed_tokens", "lm_head"] - ): - params["embeddings"][name] = param - elif name in decay_parameters: - lr_group_modules = [ - group_modules - for group_modules in lr_groups_lookup - if group_modules in name - ] - if lr_groups_lookup and any(lr_group_modules): - lr_group_module = lr_group_modules[0] - group_name = lr_groups_lookup[lr_group_module] - params[f"to_weight_decay_{group_name}"][name] = param - else: - params["to_weight_decay"][name] = param - else: - params["no_weight_decay"][name] = param - optimizer_grouped_parameters = [] - if params["to_weight_decay"]: - optimizer_grouped_parameters.append( - { - "params": list(params["to_weight_decay"].values()), - "weight_decay": self.args.weight_decay, - "lr": optimizer_kwargs["lr"], - } - ) - if params["embeddings"]: - lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name - if self.args.embedding_lr_scale: - lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name - elif self.args.embedding_lr: - lr = self.args.embedding_lr # pylint: disable=invalid-name - optimizer_grouped_parameters.append( - { - "params": list(params["embeddings"].values()), - "weight_decay": 0.0, - "lr": lr, - } - ) - if params["no_weight_decay"]: - optimizer_grouped_parameters.append( - { - "params": list(params["no_weight_decay"].values()), - "weight_decay": 0.0, - "lr": optimizer_kwargs["lr"], - } - ) - for group_name, group_lr in lr_groups_learning_rates.items(): - if params[f"to_weight_decay_{group_name}"]: - optimizer_grouped_parameters.append( - { - "params": list( - params[f"to_weight_decay_{group_name}"].values() - ), - "weight_decay": self.args.weight_decay, - "lr": group_lr, - } - ) - - return optimizer_grouped_parameters - - def create_optimizer(self): - if ( - self.args.loraplus_lr_ratio is None - and self.args.embedding_lr_scale is None - and self.args.embedding_lr is None - and self.args.lr_groups is None - and self.args.alternate_optimizer - not in [ - "optimi_adamw", - "ao_adamw_8bit", - "ao_adamw_4bit", - "ao_adamw_fp8", - "adopt_adamw", - ] - ): - return super().create_optimizer() - - opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model - if self.optimizer is None: # pylint: disable=access-member-before-definition - optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( - self.args, - opt_model, - ) - optimizer_grouped_parameters = self.create_optimizer_grouped_parameters( - opt_model, optimizer_kwargs - ) - - if self.args.loraplus_lr_ratio is not None: - loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) - loraplus_lr_embedding = getattr( - self.args, "loraplus_lr_embedding", 1e-6 - ) - self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init - opt_model, - optimizer_cls, - loraplus_lr_ratio=loraplus_lr_ratio, - loraplus_lr_embedding=loraplus_lr_embedding, - **optimizer_kwargs, - ) - elif ( - self.args.embedding_lr_scale is not None - or self.args.embedding_lr is not None - or self.args.lr_groups is not None - ): - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) - ) - elif self.args.alternate_optimizer == "optimi_adamw": - from optimi import AdamW - - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - AdamW( - optimizer_grouped_parameters, foreach=False, **optimizer_kwargs - ) - ) - elif self.args.alternate_optimizer == "ao_adamw_4bit": - from 
torchao.prototype.low_bit_optim import AdamW4bit - - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs) - ) - elif self.args.alternate_optimizer == "ao_adamw_8bit": - from torchao.prototype.low_bit_optim import AdamW8bit - - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs) - ) - elif self.args.alternate_optimizer == "ao_adamw_fp8": - from torchao.prototype.low_bit_optim import AdamWFp8 - - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs) - ) - elif self.args.alternate_optimizer == "adopt_adamw": - from axolotl.utils.optimizers.adopt import ADOPT - - self.optimizer = ( # pylint: disable=attribute-defined-outside-init - ADOPT( - optimizer_grouped_parameters, - decouple=True, - **optimizer_kwargs, - ) - ) - - if is_sagemaker_mp_enabled(): - self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init - self.optimizer - ) - - return self.optimizer - - def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: - if self.args.sample_packing and not self.args.pretraining: - if self.args.multipack_real_batches: - batch_size = self.args.per_device_train_batch_size - batch_max_len = self.args.max_seq_length - else: - batch_size = 1 - train_batch_size = ( - self.state.train_batch_size or self.args.per_device_train_batch_size - ) - batch_max_len = train_batch_size * self.args.max_seq_length - - if self.args.curriculum_sampling: - sampler = SequentialSampler(self.train_dataset) - else: - sampler = RandomSampler(self.train_dataset) - - return MultipackBatchSampler( - sampler, - lengths=get_dataset_lengths(self.train_dataset), - packing_efficiency_estimate=self.args.sample_packing_efficiency, - batch_max_len=batch_max_len, - batch_size=batch_size, - group_size=self.args.sample_packing_group_size, - bin_size=self.args.sample_packing_bin_size, - drop_last=True, - ) - if self.args.curriculum_sampling: - return SequentialSampler(self.train_dataset) - return super()._get_train_sampler() - - def _get_eval_sampler( - self, eval_dataset: Dataset - ) -> Optional[torch.utils.data.Sampler]: - if self.args.sample_packing and self.args.eval_sample_packing is not False: - if self.args.multipack_real_batches: - batch_size = self.args.per_device_eval_batch_size - batch_max_len = self.args.max_seq_length - else: - batch_size = 1 - batch_max_len = ( - self.args.per_device_eval_batch_size * self.args.max_seq_length - ) - return MultipackBatchSampler( - SequentialSampler(eval_dataset), - lengths=get_dataset_lengths(self.eval_dataset), - packing_efficiency_estimate=self.args.sample_packing_efficiency, - batch_max_len=batch_max_len, - batch_size=batch_size, - group_size=self.args.sample_packing_group_size, - bin_size=self.args.sample_packing_bin_size, - drop_last=True, - ) - return super()._get_eval_sampler(eval_dataset) - - def get_train_dataloader(self) -> DataLoader: - if self.args.sample_packing and not self.args.pretraining: - train_dataset = self.train_dataset - if "length" in train_dataset.features.keys(): - train_dataset = train_dataset.remove_columns(["length"]) - data_collator = self.data_collator - dataloader_params = { - "batch_size": self._train_batch_size, - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - } - if self.args.dataloader_prefetch_factor: - 
dataloader_params[ - "prefetch_factor" - ] = self.args.dataloader_prefetch_factor - - sampler = self._get_train_sampler() - if isinstance(sampler, BatchSampler): - dataloader_params["batch_sampler"] = sampler - del dataloader_params["batch_size"] - else: - dataloader_params["sampler"] = sampler - dataloader_params["drop_last"] = self.args.dataloader_drop_last - dataloader_params["worker_init_fn"] = seed_worker - - self.accelerator.even_batches = False - return self.accelerator.prepare_data_loader( - DataLoader(train_dataset, **dataloader_params) - ) - return super().get_train_dataloader() - - def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: - if self.args.sample_packing and self.args.eval_sample_packing is False: - self.data_collator = ( # pylint: disable=attribute-defined-outside-init - self.eval_data_collator - ) - if eval_dataset: - eval_dataset = eval_dataset.remove_columns(["length"]) - dataloader = super().get_eval_dataloader(eval_dataset) - self.data_collator = ( # pylint: disable=attribute-defined-outside-init - self.train_data_collator - ) - return dataloader - - if self.args.sample_packing and self.args.eval_sample_packing is not False: - eval_dataset = ( - eval_dataset if eval_dataset is not None else self.eval_dataset - ) - - eval_sampler = self._get_eval_sampler(eval_dataset) - eval_dataset = eval_dataset.remove_columns(["length"]) - data_collator = self.data_collator - dataloader_params = { - "batch_size": self.args.eval_batch_size, - "collate_fn": data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - } - if self.args.dataloader_prefetch_factor: - dataloader_params[ - "prefetch_factor" - ] = self.args.dataloader_prefetch_factor - - if isinstance(eval_sampler, BatchSampler): - dataloader_params["batch_sampler"] = eval_sampler - del dataloader_params["batch_size"] - else: - dataloader_params["sampler"] = eval_sampler - dataloader_params["drop_last"] = self.args.dataloader_drop_last - - self.accelerator.even_batches = False - return self.accelerator.prepare_data_loader( - DataLoader(eval_dataset, **dataloader_params) - ) - - return super().get_eval_dataloader(eval_dataset) - - def _get_bench_sampler( - self, bench_dataset: Dataset - ) -> Optional[torch.utils.data.Sampler]: - if self.args.world_size <= 1: - return SequentialSampler(bench_dataset) - return None - - def get_bench_dataloader( - self, - bench_dataset: Dataset, - ) -> DataLoader: - dataloader_params = { - "batch_size": self.args.eval_batch_size, - "collate_fn": self.bench_data_collator, - "num_workers": self.args.dataloader_num_workers, - "pin_memory": self.args.dataloader_pin_memory, - } - if self.args.dataloader_prefetch_factor: - dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor - - if not isinstance(bench_dataset, torch.utils.data.IterableDataset): - dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset) - dataloader_params["drop_last"] = self.args.dataloader_drop_last - - return DataLoader(bench_dataset, **dataloader_params) - # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params)) - - def compute_loss( - self, model, inputs, return_outputs=False, num_items_in_batch=None - ): - # use one's weighted cross entropy loss calc - # if self.args.sample_packing: - # labels = inputs.pop("labels") - # outputs = model(**inputs) - # loss = trainer_weighted_loss(outputs, labels, shift_labels=True) - # return (loss, outputs) if return_outputs else loss - if 
self.args.orpo_alpha: - return self.orpo_compute_loss( - model, - inputs, - return_outputs=return_outputs, - num_items_in_batch=num_items_in_batch, - ) - return super().compute_loss( - model, - inputs, - return_outputs=return_outputs, - num_items_in_batch=num_items_in_batch, - ) - - @staticmethod - def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None): - concatenated_batch = {} - - max_length = max( - inputs["input_ids"].shape[1], inputs["rejected_input_ids"].shape[1] - ) - # Concatenate positive and negative inputs - concatenated_batch["input_ids"] = pad_to_length( - inputs["input_ids"], max_length, pad_token - ) - concatenated_batch["rejected_input_ids"] = pad_to_length( - inputs["rejected_input_ids"], max_length, pad_token - ) - concatenated_batch["labels"] = pad_to_length( - inputs["labels"], max_length, label_pad_token - ) - concatenated_batch["rejected_labels"] = pad_to_length( - inputs["rejected_labels"], max_length, label_pad_token - ) - concatenated_batch["attention_mask"] = pad_to_length( - inputs["attention_mask"], max_length, 0 - ) - concatenated_batch["rejected_attention_mask"] = pad_to_length( - inputs["rejected_attention_mask"], max_length, 0 - ) - concatenated_batch["prompt_attention_mask"] = pad_to_length( - inputs["prompt_attention_mask"], max_length, 0 - ).to(device=device) - - input_ids = torch.cat( - [concatenated_batch["input_ids"], concatenated_batch["rejected_input_ids"]], - dim=0, - ).to(device=device) - attention_mask = torch.cat( - [ - concatenated_batch["attention_mask"], - concatenated_batch["rejected_attention_mask"], - ], - dim=0, - ).to(device=device) - labels = torch.cat( - [concatenated_batch["labels"], concatenated_batch["rejected_labels"]], dim=0 - ).to(device=device) - - return { - "input_ids": input_ids, - "labels": labels, - "attention_mask": attention_mask, - "prompt_attention_mask": concatenated_batch["prompt_attention_mask"], - } - - def orpo_compute_custom_loss(self, logits, labels): - logits = logits.contiguous() - loss = 0.0 - - if labels is not None: - # move labels to correct device to enable model parallelism - labels = labels.to(logits.device) - # Shift so that tokens < n predict n - shift_logits = logits[..., :-1, :].contiguous() - shift_labels = labels[..., 1:].contiguous() - - # Flatten the tokens - loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean( - dim=-1 - ) - - return loss - - def orpo_compute_logps( - self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits - ): - # Get the shape of chosen_attention_mask[:, :-1] - chosen_shape = chosen_attention_mask[:, :-1].shape - - # Calculate the padding size - pad_length = chosen_shape[1] - (prompt_attention_mask.shape[1] - 1) - - # Pad prompt_attention_mask with zeros to match the desired shape - prompt_attention_mask_padded = torch.nn.functional.pad( - prompt_attention_mask[:, 1:], (0, pad_length), mode="constant", value=0 - ) - - # Perform the subtraction operation - mask = chosen_attention_mask[:, :-1] > prompt_attention_mask_padded - - per_token_logps = torch.gather( - logits[:, :-1, :].log_softmax(-1), - dim=2, - index=(mask * chosen_inputs[:, 1:]).unsqueeze(2), - ).squeeze(2) - return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1) - - def orpo_compute_loss( - self, - model, - inputs, - return_outputs=False, - num_items_in_batch=None, # pylint: disable=unused-argument - ): - concat_inputs = AxolotlTrainer.orpo_concatenate_inputs( - inputs, - label_pad_token=-100, - pad_token=self.tokenizer.pad_token_id, - 
device=self.accelerator.device, - ) - - # Perform a single forward pass - outputs = model( - **{ - "input_ids": concat_inputs["input_ids"], - "attention_mask": concat_inputs["attention_mask"], - "labels": concat_inputs["labels"], - }, - output_hidden_states=True, - ) - - # Split the outputs for positive and negative examples - outputs_pos, outputs_neg = outputs.logits.chunk(2) - - # Calculate NLL loss - pos_loss = self.orpo_compute_custom_loss( - logits=outputs_pos, labels=concat_inputs["input_ids"].chunk(2)[0] - ) - - # Calculate Log Probability - pos_prob = self.orpo_compute_logps( - prompt_attention_mask=concat_inputs["prompt_attention_mask"], - chosen_inputs=concat_inputs["input_ids"].chunk(2)[0], - chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[0], - logits=outputs_pos, - ) - neg_prob = self.orpo_compute_logps( - prompt_attention_mask=concat_inputs["prompt_attention_mask"], - chosen_inputs=concat_inputs["input_ids"].chunk(2)[1], - chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[1], - logits=outputs_neg, - ) - - # Calculate log odds - log_odds = (pos_prob - neg_prob) - ( - torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob)) - ) - sig_ratio = torch.nn.functional.sigmoid(log_odds) - ratio = torch.log(sig_ratio) - - # Calculate the Final Loss - loss = torch.mean(pos_loss - self.args.orpo_alpha * ratio).to( - dtype=torch.bfloat16 - ) - - metrics = {} - metrics["chosen_geometric_mean"] = torch.mean(pos_prob).cpu().item() - metrics["rejected_geometric_mean"] = torch.mean(neg_prob).cpu().item() - metrics["log_odds_ratio"] = torch.mean(ratio).cpu().item() - metrics["log_odds"] = torch.mean(log_odds).cpu().item() - self.store_metrics(metrics, train_eval="train") - - return (loss, outputs_pos) if return_outputs else loss - - @wraps(Trainer.push_to_hub) - def push_to_hub(self, *args, **kwargs) -> str: - """ - Overwrite the `push_to_hub` method in order to force-add the tags when pushing the - model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. - """ - kwargs = _sanitize_kwargs_for_ds_tagging( - dataset_tags=self.dataset_tags, kwargs=kwargs - ) - kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs) - - return super().push_to_hub(*args, **kwargs) - - @wraps(Trainer.create_accelerator_and_postprocess) - def create_accelerator_and_postprocess(self): - res = super().create_accelerator_and_postprocess() - - if self.is_fsdp_enabled: - if ( - "limit_all_gathers" in self.args.fsdp_config - and self.args.fsdp_config["limit_all_gathers"] - ): - self.accelerator.state.fsdp_plugin.limit_all_gathers = True - - return res - - def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: - """ - Log `logs` on the various objects watching training, including stored metrics. - - Args: - logs (`Dict[str, float]`): - The values to log. - start_time (`Optional[float]`): - The start of training. 
- """ - # logs either has 'loss' or 'eval_loss' - train_eval = "train" if "loss" in logs else "eval" - # Add averaged stored metrics to logs - for key, metrics in self._stored_metrics[train_eval].items(): - logs[key] = torch.tensor(metrics).mean().item() - del self._stored_metrics[train_eval] - - return super().log(logs, start_time) - - def store_metrics( - self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train" - ) -> None: - for key, value in metrics.items(): - self._stored_metrics[train_eval][key].append(value) - - def _save_checkpoint(self, model, trial, **kwargs): - # make sure the checkpoint dir exists, since trainer is flakey - checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" - run_dir = self._get_output_dir(trial=trial) - output_dir = os.path.join(run_dir, checkpoint_folder) - os.makedirs(output_dir, exist_ok=True) - return super()._save_checkpoint(model, trial, **kwargs) - - -class AxolotlMambaTrainer(AxolotlTrainer): - """ - Mamba specific trainer to handle loss calculation - """ - - tag_names = ["axolotl", "mamba"] - - def compute_loss( - self, - model, - inputs, - return_outputs=False, # pylint: disable=unused-argument - num_items_in_batch=None, # pylint: disable=unused-argument - ): - input_ids = inputs.pop("input_ids") - lm_logits = model(input_ids).logits - - labels = input_ids.to(lm_logits.device) - shift_logits = lm_logits[:, :-1, :].contiguous() - labels = labels[:, 1:].contiguous() - - loss_fct = torch.nn.CrossEntropyLoss() - lm_loss = loss_fct( - shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1) - ) - - return lm_loss - - -class ReLoRATrainer(AxolotlTrainer): - """ - Trainer subclass that uses the OneCycleLR scheduler - """ - - tag_names = ["axolotl", "relora"] - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.lr_scheduler = None - - def create_scheduler( - self, - num_training_steps: int, - optimizer: Optional[torch.optim.Optimizer] = None, - ): - optimizer = self.optimizer if optimizer is None else optimizer - lr_scheduler = super().create_scheduler(num_training_steps, optimizer) - - if self.args.relora_steps: - warmup_steps = ( - self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10 - ) - anneal_steps = ( - self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1 - ) - self.lr_scheduler = ReLoRAScheduler( - optimizer, - lr_scheduler, - self.args.relora_steps, - anneal_steps, - warmup_steps, - ) - else: - self.lr_scheduler = lr_scheduler - - return self.lr_scheduler - - -class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer): - """ - Extend the base DPOTrainer for axolotl helpers - """ - - tag_names = ["axolotl", "dpo"] - - def __init__(self, *args, dataset_tags=None, **kwargs): - super().__init__(*args, **kwargs) - self.dataset_tags = dataset_tags - self.optimizer = None - self.model_accepts_loss_kwargs = False - - def create_optimizer(self): - if self.args.loraplus_lr_ratio is None: - return super().create_optimizer() - - opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model - if self.optimizer is None: # pylint: disable=access-member-before-definition - optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( - self.args, - opt_model, - ) - - loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) - if loraplus_lr_ratio: - print("Using lora+") - loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None) - self.optimizer = create_loraplus_optimizer( # pylint: 
disable=attribute-defined-outside-init - opt_model, - optimizer_cls, - loraplus_lr_ratio=loraplus_lr_ratio, - loraplus_lr_embedding=loraplus_lr_embedding, - **optimizer_kwargs, - ) - - if is_sagemaker_mp_enabled(): - self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init - self.optimizer - ) - - return self.optimizer - - @wraps(DPOTrainer.push_to_hub) - def push_to_hub(self, *args, **kwargs) -> str: - """ - Overwrite the `push_to_hub` method in order to force-add the tags when pushing the - model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. - """ - kwargs = _sanitize_kwargs_for_ds_tagging( - dataset_tags=self.dataset_tags, kwargs=kwargs - ) - kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs) - - return super().push_to_hub(*args, **kwargs) - - @staticmethod - def tokenize_row( - features, - processing_class, - max_prompt_length, - max_completion_length, - add_special_tokens, - ) -> Dict: - res = DPOTrainer.tokenize_row( - features, - processing_class, - max_prompt_length, - max_completion_length, - add_special_tokens, - ) - # fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen - if processing_class.bos_token is None and res["prompt_input_ids"][0] is None: - for key in res.keys(): - res[key] = res[key][1:] - - if processing_class.bos_token and processing_class.bos_token_id is not None: - # dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs - if res["chosen_input_ids"][0] == processing_class.bos_token_id: - res["chosen_input_ids"] = res["chosen_input_ids"][1:] - res["chosen_labels"] = res["chosen_labels"][1:] - res["chosen_attention_mask"] = res["chosen_attention_mask"][1:] - if res["rejected_input_ids"][0] == processing_class.bos_token_id: - res["rejected_input_ids"] = res["rejected_input_ids"][1:] - res["rejected_labels"] = res["rejected_labels"][1:] - res["rejected_attention_mask"] = res["rejected_attention_mask"][1:] - - return res - - def training_step( - self, - model: nn.Module, - inputs: Dict[str, Union[torch.Tensor, Any]], - num_items_in_batch=None, - ) -> torch.Tensor: - loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch) - gc.collect() - torch.cuda.empty_cache() - return loss - - -class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): - """ - Extend the base ORPOTrainer for axolotl helpers - """ - - tag_names = ["axolotl", "orpo"] - - -class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): - """ - Extend the base KTOTrainer for axolotl helpers - """ - - tag_names = ["axolotl", "kto"] - - -class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): - """ - Extend the base CPOTrainer for axolotl helpers - """ - - tag_names = ["axolotl", "cpo"] - - -class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): - """ - Extend the base RewardTrainer for axolotl helpers - """ - - tag_names = ["axolotl", "reward"] - - -class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer): - """ - Extend the base trl.PRMTrainer for axolotl helpers - """ - - tag_names = ["axolotl", "prm"] - - class TrainerBuilderBase(abc.ABC): """ Base class for trainer builder @@ -1464,6 +296,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): return callbacks def _get_trainer_cls(self): + if self.cfg.plugins: + plugin_manager = PluginManager.get_instance() + trainer_cls = plugin_manager.get_trainer_cls(self.cfg) + if trainer_cls: + return trainer_cls if self.cfg.relora_steps: return ReLoRATrainer if self.cfg.model_config_type == "mamba": @@ -1862,13 +699,27 @@ class 
HFCausalTrainerBuilder(TrainerBuilderBase): "accelerator_config" ] = self.cfg.accelerator_config + if self.cfg.kd_ce_alpha is not None: + training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha + if self.cfg.kd_alpha is not None: + training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha + if self.cfg.kd_temperature is not None: + training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature + if self.cfg.kd_zscore_base_temp is not None: + training_arguments_kwargs[ + "kd_zscore_base_temp" + ] = self.cfg.kd_zscore_base_temp + if self.cfg.kd_top_k_before_softmax is not None: + training_arguments_kwargs[ + "kd_top_k_before_softmax" + ] = self.cfg.kd_top_k_before_softmax + if self.cfg.reward_model: training_args_cls = AxolotlRewardConfig elif self.cfg.process_reward_model: training_args_cls = AxolotlPRMConfig else: training_args_cls = AxolotlTrainingArguments - training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg **training_arguments_kwargs, ) @@ -1995,6 +846,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): collator_args.pop(0) kwargs.pop("pad_to_multiple_of", None) kwargs.pop("padding", None) + elif self.cfg.kd_trainer: + from axolotl.integrations.kd.collator import ( + DataCollatorForKD, + KDBatchSamplerDataCollatorForSeq2Seq, + ) + + if self.cfg.sample_packing: + collator = KDBatchSamplerDataCollatorForSeq2Seq + else: + collator = DataCollatorForKD else: collator = DataCollatorForSeq2Seq diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py new file mode 100644 index 000000000..44a6d54d7 --- /dev/null +++ b/src/axolotl/core/trainers/base.py @@ -0,0 +1,988 @@ +""" +module for customized trainers +""" + +from __future__ import annotations + +# pylint: disable=too-many-lines +import gc +import logging +import os +from collections import defaultdict +from functools import wraps +from typing import Any, Dict, Literal, Optional, Union + +import torch +from datasets import Dataset +from peft.optimizers import create_loraplus_optimizer +from torch import nn +from torch.optim.lr_scheduler import OneCycleLR +from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler +from transformers import Trainer +from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker +from transformers.utils import is_sagemaker_mp_enabled +from trl import ( + CPOTrainer, + DPOTrainer, + KTOTrainer, + ORPOTrainer, + PRMTrainer, + RewardTrainer, +) +from trl.trainer.utils import pad_to_length + +from axolotl.monkeypatch.relora import ReLoRAScheduler +from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths +from axolotl.utils.schedulers import ( + get_cosine_schedule_with_min_lr, + get_cosine_schedule_with_quadratic_warmup, + get_cosine_schedule_with_warmup_decay_constant, +) + +if is_sagemaker_mp_enabled(): + import smdistributed.modelparallel.torch as smp + +LOG = logging.getLogger("axolotl.core.trainer_builder") + + +def _sanitize_kwargs_for_tagging(tag_names, kwargs=None): + if isinstance(tag_names, str): + tag_names = [tag_names] + + if kwargs is not None: + if "tags" not in kwargs: + kwargs["tags"] = tag_names + elif "tags" in kwargs and isinstance(kwargs["tags"], list): + kwargs["tags"].extend(tag_names) + elif "tags" in kwargs and isinstance(kwargs["tags"], str): + tag_names.append(kwargs["tags"]) + kwargs["tags"] = tag_names + + return kwargs + + +def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None): + if isinstance(dataset_tags, str): + dataset_tags = [dataset_tags] + 
+ if (dataset_tags is not None) and (kwargs is not None): + if "dataset_tags" not in kwargs: + kwargs["dataset_tags"] = dataset_tags + elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list): + kwargs["dataset_tags"].extend(dataset_tags) + elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str): + dataset_tags.append(kwargs["dataset_tags"]) + kwargs["dataset_tags"] = dataset_tags + + return kwargs + + +class SchedulerMixin(Trainer): + """ + Mixin class for scheduler setup in CausalTrainer. + """ + + args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] + + def create_scheduler( + self, num_training_steps: int, optimizer: torch.optim.Optimizer = None + ): + """ + Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or + passed as an argument. + + Args: + num_training_steps (int): The number of training steps to do. + optimizer (torch.optim.Optimizer): The training optimizer + """ + use_cosine_quadratic = ( + self.args.lr_scheduler_type == "cosine" + and self.args.lr_quadratic_warmup is True + ) + + use_cosine_min_lr = ( + self.args.lr_scheduler_type == "cosine" + and self.args.cosine_min_lr_ratio is not None + ) + + # fmt: off + if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition + # fmt: on + if self.args.alternate_lr_scheduler_type == "one_cycle": + num_warmup_steps = self.args.get_warmup_steps(num_training_steps) + pct_start = num_warmup_steps / num_training_steps + extra_lr_kwargs = {} + if "pct_start" not in self.args.lr_scheduler_kwargs: + extra_lr_kwargs["pct_start"] = pct_start + if "anneal_strategy" not in self.args.lr_scheduler_kwargs: + extra_lr_kwargs["anneal_strategy"] = "cos" + + self.lr_scheduler = OneCycleLR( + optimizer, + max_lr=self.args.learning_rate, + total_steps=num_training_steps, + **extra_lr_kwargs, + **self.args.lr_scheduler_kwargs, + ) + elif use_cosine_quadratic: + if use_cosine_min_lr: + LOG.warning("Both cosine quadratic warmup and min lr detected. 
Using quadratic warmup.") + + self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init + optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + ) + elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr: + assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" + assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0" + self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init + optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + min_lr_ratio=self.args.cosine_min_lr_ratio, + constant_lr_ratio=self.args.cosine_constant_lr_ratio, + ) + elif self.args.cosine_min_lr_ratio and use_cosine_min_lr: + assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0" + self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init + optimizer, + num_warmup_steps=self.args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + min_lr_ratio=self.args.cosine_min_lr_ratio, + ) + else: + return super().create_scheduler(num_training_steps, optimizer=optimizer) + else: + if use_cosine_quadratic: + LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).") + + if use_cosine_min_lr: + LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).") + + return self.lr_scheduler + + +class AxolotlTrainer(SchedulerMixin, Trainer): + """ + Extend the base Trainer for axolotl helpers + """ + + args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] + tag_names = ["axolotl"] + + def __init__( + self, + *_args, + bench_data_collator=None, + eval_data_collator=None, + dataset_tags=None, + **kwargs, + ): + self.bench_data_collator = bench_data_collator + self.eval_data_collator = eval_data_collator + self.dataset_tags = dataset_tags + self._signature_columns = None # workaround for pylint + super().__init__(*_args, **kwargs) + self.train_data_collator = self.data_collator + self._stored_metrics = defaultdict(lambda: defaultdict(list)) + if self.args.orpo_alpha: + self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") + + def _wrap_model(self, model, training=True, dataloader=None): + if self.args.torch_compile: + torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access + 256 + ) + model = torch.compile( + model, + backend=self.args.torch_compile_backend, + mode=self.args.torch_compile_mode, + ) + return super()._wrap_model(model, training=training, dataloader=dataloader) + + def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs): + decay_parameters = self.get_decay_parameter_names(opt_model) + params = { + "to_weight_decay": {}, # LayerNorm and bias + "embeddings": {}, # lm_head, embed_tokens, + "no_weight_decay": {}, + } + lr_groups_lookup = {} + lr_groups_learning_rates = {} + if self.args.lr_groups: + for lr_group in self.args.lr_groups: + group_name = lr_group["name"] + group_modules = lr_group["modules"] + for module in group_modules: + lr_groups_lookup[module] = group_name + lr_groups_learning_rates[group_name] = lr_group["lr"] + params[f"to_weight_decay_{group_name}"] = {} + + for name, 
param in opt_model.named_parameters(): + if not param.requires_grad: + continue + if name.endswith("modules_to_save.default.weight") or any( + embed_name in name for embed_name in ["embed_tokens", "lm_head"] + ): + params["embeddings"][name] = param + elif name in decay_parameters: + lr_group_modules = [ + group_modules + for group_modules in lr_groups_lookup + if group_modules in name + ] + if lr_groups_lookup and any(lr_group_modules): + lr_group_module = lr_group_modules[0] + group_name = lr_groups_lookup[lr_group_module] + params[f"to_weight_decay_{group_name}"][name] = param + else: + params["to_weight_decay"][name] = param + else: + params["no_weight_decay"][name] = param + optimizer_grouped_parameters = [] + if params["to_weight_decay"]: + optimizer_grouped_parameters.append( + { + "params": list(params["to_weight_decay"].values()), + "weight_decay": self.args.weight_decay, + "lr": optimizer_kwargs["lr"], + } + ) + if params["embeddings"]: + lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name + if self.args.embedding_lr_scale: + lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name + elif self.args.embedding_lr: + lr = self.args.embedding_lr # pylint: disable=invalid-name + optimizer_grouped_parameters.append( + { + "params": list(params["embeddings"].values()), + "weight_decay": 0.0, + "lr": lr, + } + ) + if params["no_weight_decay"]: + optimizer_grouped_parameters.append( + { + "params": list(params["no_weight_decay"].values()), + "weight_decay": 0.0, + "lr": optimizer_kwargs["lr"], + } + ) + for group_name, group_lr in lr_groups_learning_rates.items(): + if params[f"to_weight_decay_{group_name}"]: + optimizer_grouped_parameters.append( + { + "params": list( + params[f"to_weight_decay_{group_name}"].values() + ), + "weight_decay": self.args.weight_decay, + "lr": group_lr, + } + ) + + return optimizer_grouped_parameters + + def create_optimizer(self): + if ( + self.args.loraplus_lr_ratio is None + and self.args.embedding_lr_scale is None + and self.args.embedding_lr is None + and self.args.lr_groups is None + and self.args.alternate_optimizer + not in [ + "optimi_adamw", + "ao_adamw_8bit", + "ao_adamw_4bit", + "ao_adamw_fp8", + "adopt_adamw", + ] + ): + return super().create_optimizer() + + opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + if self.optimizer is None: # pylint: disable=access-member-before-definition + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( + self.args, + opt_model, + ) + optimizer_grouped_parameters = self.create_optimizer_grouped_parameters( + opt_model, optimizer_kwargs + ) + + if self.args.loraplus_lr_ratio is not None: + loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) + loraplus_lr_embedding = getattr( + self.args, "loraplus_lr_embedding", 1e-6 + ) + self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init + opt_model, + optimizer_cls, + loraplus_lr_ratio=loraplus_lr_ratio, + loraplus_lr_embedding=loraplus_lr_embedding, + **optimizer_kwargs, + ) + elif ( + self.args.embedding_lr_scale is not None + or self.args.embedding_lr is not None + or self.args.lr_groups is not None + ): + self.optimizer = ( # pylint: disable=attribute-defined-outside-init + optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) + ) + elif self.args.alternate_optimizer == "optimi_adamw": + from optimi import AdamW + + self.optimizer = ( # pylint: disable=attribute-defined-outside-init + AdamW( + optimizer_grouped_parameters, foreach=False, 
**optimizer_kwargs + ) + ) + elif self.args.alternate_optimizer == "ao_adamw_4bit": + from torchao.prototype.low_bit_optim import AdamW4bit + + self.optimizer = ( # pylint: disable=attribute-defined-outside-init + AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs) + ) + elif self.args.alternate_optimizer == "ao_adamw_8bit": + from torchao.prototype.low_bit_optim import AdamW8bit + + self.optimizer = ( # pylint: disable=attribute-defined-outside-init + AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs) + ) + elif self.args.alternate_optimizer == "ao_adamw_fp8": + from torchao.prototype.low_bit_optim import AdamWFp8 + + self.optimizer = ( # pylint: disable=attribute-defined-outside-init + AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs) + ) + elif self.args.alternate_optimizer == "adopt_adamw": + from axolotl.utils.optimizers.adopt import ADOPT + + self.optimizer = ( # pylint: disable=attribute-defined-outside-init + ADOPT( + optimizer_grouped_parameters, + decouple=True, + **optimizer_kwargs, + ) + ) + + if is_sagemaker_mp_enabled(): + self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init + self.optimizer + ) + + return self.optimizer + + def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: + if self.args.sample_packing and not self.args.pretraining: + if self.args.multipack_real_batches: + batch_size = self.args.per_device_train_batch_size + batch_max_len = self.args.max_seq_length + else: + batch_size = 1 + train_batch_size = ( + self.state.train_batch_size or self.args.per_device_train_batch_size + ) + batch_max_len = train_batch_size * self.args.max_seq_length + + if self.args.curriculum_sampling: + sampler = SequentialSampler(self.train_dataset) + else: + sampler = RandomSampler(self.train_dataset) + + return MultipackBatchSampler( + sampler, + lengths=get_dataset_lengths(self.train_dataset), + packing_efficiency_estimate=self.args.sample_packing_efficiency, + batch_max_len=batch_max_len, + batch_size=batch_size, + group_size=self.args.sample_packing_group_size, + bin_size=self.args.sample_packing_bin_size, + drop_last=True, + ) + if self.args.curriculum_sampling: + return SequentialSampler(self.train_dataset) + return super()._get_train_sampler() + + def _get_eval_sampler( + self, eval_dataset: Dataset + ) -> Optional[torch.utils.data.Sampler]: + if self.args.sample_packing and self.args.eval_sample_packing is not False: + if self.args.multipack_real_batches: + batch_size = self.args.per_device_eval_batch_size + batch_max_len = self.args.max_seq_length + else: + batch_size = 1 + batch_max_len = ( + self.args.per_device_eval_batch_size * self.args.max_seq_length + ) + return MultipackBatchSampler( + SequentialSampler(eval_dataset), + lengths=get_dataset_lengths(self.eval_dataset), + packing_efficiency_estimate=self.args.sample_packing_efficiency, + batch_max_len=batch_max_len, + batch_size=batch_size, + group_size=self.args.sample_packing_group_size, + bin_size=self.args.sample_packing_bin_size, + drop_last=True, + ) + return super()._get_eval_sampler(eval_dataset) + + def get_train_dataloader(self) -> DataLoader: + if self.args.sample_packing and not self.args.pretraining: + train_dataset = self.train_dataset + if "length" in train_dataset.features.keys(): + train_dataset = train_dataset.remove_columns(["length"]) + data_collator = self.data_collator + dataloader_params = { + "batch_size": self._train_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + 
"pin_memory": self.args.dataloader_pin_memory, + } + if self.args.dataloader_prefetch_factor: + dataloader_params[ + "prefetch_factor" + ] = self.args.dataloader_prefetch_factor + + sampler = self._get_train_sampler() + if isinstance(sampler, BatchSampler): + dataloader_params["batch_sampler"] = sampler + del dataloader_params["batch_size"] + else: + dataloader_params["sampler"] = sampler + dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["worker_init_fn"] = seed_worker + + self.accelerator.even_batches = False + return self.accelerator.prepare_data_loader( + DataLoader(train_dataset, **dataloader_params) + ) + return super().get_train_dataloader() + + def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: + if self.args.sample_packing and self.args.eval_sample_packing is False: + self.data_collator = ( # pylint: disable=attribute-defined-outside-init + self.eval_data_collator + ) + if eval_dataset: + eval_dataset = eval_dataset.remove_columns(["length"]) + dataloader = super().get_eval_dataloader(eval_dataset) + self.data_collator = ( # pylint: disable=attribute-defined-outside-init + self.train_data_collator + ) + return dataloader + + if self.args.sample_packing and self.args.eval_sample_packing is not False: + eval_dataset = ( + eval_dataset if eval_dataset is not None else self.eval_dataset + ) + + eval_sampler = self._get_eval_sampler(eval_dataset) + eval_dataset = eval_dataset.remove_columns(["length"]) + data_collator = self.data_collator + dataloader_params = { + "batch_size": self.args.eval_batch_size, + "collate_fn": data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + } + if self.args.dataloader_prefetch_factor: + dataloader_params[ + "prefetch_factor" + ] = self.args.dataloader_prefetch_factor + + if isinstance(eval_sampler, BatchSampler): + dataloader_params["batch_sampler"] = eval_sampler + del dataloader_params["batch_size"] + else: + dataloader_params["sampler"] = eval_sampler + dataloader_params["drop_last"] = self.args.dataloader_drop_last + + self.accelerator.even_batches = False + return self.accelerator.prepare_data_loader( + DataLoader(eval_dataset, **dataloader_params) + ) + + return super().get_eval_dataloader(eval_dataset) + + def _get_bench_sampler( + self, bench_dataset: Dataset + ) -> Optional[torch.utils.data.Sampler]: + if self.args.world_size <= 1: + return SequentialSampler(bench_dataset) + return None + + def get_bench_dataloader( + self, + bench_dataset: Dataset, + ) -> DataLoader: + dataloader_params = { + "batch_size": self.args.eval_batch_size, + "collate_fn": self.bench_data_collator, + "num_workers": self.args.dataloader_num_workers, + "pin_memory": self.args.dataloader_pin_memory, + } + if self.args.dataloader_prefetch_factor: + dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor + + if not isinstance(bench_dataset, torch.utils.data.IterableDataset): + dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset) + dataloader_params["drop_last"] = self.args.dataloader_drop_last + + return DataLoader(bench_dataset, **dataloader_params) + # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params)) + + def compute_loss( + self, model, inputs, return_outputs=False, num_items_in_batch=None + ): + # use one's weighted cross entropy loss calc + # if self.args.sample_packing: + # labels = inputs.pop("labels") + # outputs = model(**inputs) + # loss = 
trainer_weighted_loss(outputs, labels, shift_labels=True) + # return (loss, outputs) if return_outputs else loss + if self.args.orpo_alpha: + return self.orpo_compute_loss( + model, + inputs, + return_outputs=return_outputs, + num_items_in_batch=num_items_in_batch, + ) + return super().compute_loss( + model, + inputs, + return_outputs=return_outputs, + num_items_in_batch=num_items_in_batch, + ) + + @staticmethod + def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None): + concatenated_batch = {} + + max_length = max( + inputs["input_ids"].shape[1], inputs["rejected_input_ids"].shape[1] + ) + # Concatenate positive and negative inputs + concatenated_batch["input_ids"] = pad_to_length( + inputs["input_ids"], max_length, pad_token + ) + concatenated_batch["rejected_input_ids"] = pad_to_length( + inputs["rejected_input_ids"], max_length, pad_token + ) + concatenated_batch["labels"] = pad_to_length( + inputs["labels"], max_length, label_pad_token + ) + concatenated_batch["rejected_labels"] = pad_to_length( + inputs["rejected_labels"], max_length, label_pad_token + ) + concatenated_batch["attention_mask"] = pad_to_length( + inputs["attention_mask"], max_length, 0 + ) + concatenated_batch["rejected_attention_mask"] = pad_to_length( + inputs["rejected_attention_mask"], max_length, 0 + ) + concatenated_batch["prompt_attention_mask"] = pad_to_length( + inputs["prompt_attention_mask"], max_length, 0 + ).to(device=device) + + input_ids = torch.cat( + [concatenated_batch["input_ids"], concatenated_batch["rejected_input_ids"]], + dim=0, + ).to(device=device) + attention_mask = torch.cat( + [ + concatenated_batch["attention_mask"], + concatenated_batch["rejected_attention_mask"], + ], + dim=0, + ).to(device=device) + labels = torch.cat( + [concatenated_batch["labels"], concatenated_batch["rejected_labels"]], dim=0 + ).to(device=device) + + return { + "input_ids": input_ids, + "labels": labels, + "attention_mask": attention_mask, + "prompt_attention_mask": concatenated_batch["prompt_attention_mask"], + } + + def orpo_compute_custom_loss(self, logits, labels): + logits = logits.contiguous() + loss = 0.0 + + if labels is not None: + # move labels to correct device to enable model parallelism + labels = labels.to(logits.device) + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + + # Flatten the tokens + loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean( + dim=-1 + ) + + return loss + + def orpo_compute_logps( + self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits + ): + # Get the shape of chosen_attention_mask[:, :-1] + chosen_shape = chosen_attention_mask[:, :-1].shape + + # Calculate the padding size + pad_length = chosen_shape[1] - (prompt_attention_mask.shape[1] - 1) + + # Pad prompt_attention_mask with zeros to match the desired shape + prompt_attention_mask_padded = torch.nn.functional.pad( + prompt_attention_mask[:, 1:], (0, pad_length), mode="constant", value=0 + ) + + # Perform the subtraction operation + mask = chosen_attention_mask[:, :-1] > prompt_attention_mask_padded + + per_token_logps = torch.gather( + logits[:, :-1, :].log_softmax(-1), + dim=2, + index=(mask * chosen_inputs[:, 1:]).unsqueeze(2), + ).squeeze(2) + return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1) + + def orpo_compute_loss( + self, + model, + inputs, + return_outputs=False, + num_items_in_batch=None, # pylint: disable=unused-argument + ): + concat_inputs = 
AxolotlTrainer.orpo_concatenate_inputs( + inputs, + label_pad_token=-100, + pad_token=self.tokenizer.pad_token_id, + device=self.accelerator.device, + ) + + # Perform a single forward pass + outputs = model( + **{ + "input_ids": concat_inputs["input_ids"], + "attention_mask": concat_inputs["attention_mask"], + "labels": concat_inputs["labels"], + }, + output_hidden_states=True, + ) + + # Split the outputs for positive and negative examples + outputs_pos, outputs_neg = outputs.logits.chunk(2) + + # Calculate NLL loss + pos_loss = self.orpo_compute_custom_loss( + logits=outputs_pos, labels=concat_inputs["input_ids"].chunk(2)[0] + ) + + # Calculate Log Probability + pos_prob = self.orpo_compute_logps( + prompt_attention_mask=concat_inputs["prompt_attention_mask"], + chosen_inputs=concat_inputs["input_ids"].chunk(2)[0], + chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[0], + logits=outputs_pos, + ) + neg_prob = self.orpo_compute_logps( + prompt_attention_mask=concat_inputs["prompt_attention_mask"], + chosen_inputs=concat_inputs["input_ids"].chunk(2)[1], + chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[1], + logits=outputs_neg, + ) + + # Calculate log odds + log_odds = (pos_prob - neg_prob) - ( + torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob)) + ) + sig_ratio = torch.nn.functional.sigmoid(log_odds) + ratio = torch.log(sig_ratio) + + # Calculate the Final Loss + loss = torch.mean(pos_loss - self.args.orpo_alpha * ratio).to( + dtype=torch.bfloat16 + ) + + metrics = {} + metrics["chosen_geometric_mean"] = torch.mean(pos_prob).cpu().item() + metrics["rejected_geometric_mean"] = torch.mean(neg_prob).cpu().item() + metrics["log_odds_ratio"] = torch.mean(ratio).cpu().item() + metrics["log_odds"] = torch.mean(log_odds).cpu().item() + self.store_metrics(metrics, train_eval="train") + + return (loss, outputs_pos) if return_outputs else loss + + @wraps(Trainer.push_to_hub) + def push_to_hub(self, *args, **kwargs) -> str: + """ + Overwrite the `push_to_hub` method in order to force-add the tags when pushing the + model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. + """ + kwargs = _sanitize_kwargs_for_ds_tagging( + dataset_tags=self.dataset_tags, kwargs=kwargs + ) + kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs) + + return super().push_to_hub(*args, **kwargs) + + @wraps(Trainer.create_accelerator_and_postprocess) + def create_accelerator_and_postprocess(self): + res = super().create_accelerator_and_postprocess() + + if self.is_fsdp_enabled: + if ( + "limit_all_gathers" in self.args.fsdp_config + and self.args.fsdp_config["limit_all_gathers"] + ): + self.accelerator.state.fsdp_plugin.limit_all_gathers = True + + return res + + def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None: + """ + Log `logs` on the various objects watching training, including stored metrics. + + Args: + logs (`Dict[str, float]`): + The values to log. + start_time (`Optional[float]`): + The start of training. 
+ """ + # logs either has 'loss' or 'eval_loss' + train_eval = "train" if "loss" in logs else "eval" + # Add averaged stored metrics to logs + for key, metrics in self._stored_metrics[train_eval].items(): + logs[key] = torch.tensor(metrics).mean().item() + del self._stored_metrics[train_eval] + + return super().log(logs, start_time) + + def store_metrics( + self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train" + ) -> None: + for key, value in metrics.items(): + self._stored_metrics[train_eval][key].append(value) + + def _save_checkpoint(self, model, trial, **kwargs): + # make sure the checkpoint dir exists, since trainer is flakey + checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}" + run_dir = self._get_output_dir(trial=trial) + output_dir = os.path.join(run_dir, checkpoint_folder) + os.makedirs(output_dir, exist_ok=True) + return super()._save_checkpoint(model, trial, **kwargs) + + +class AxolotlMambaTrainer(AxolotlTrainer): + """ + Mamba specific trainer to handle loss calculation + """ + + tag_names = ["axolotl", "mamba"] + + def compute_loss( + self, + model, + inputs, + return_outputs=False, # pylint: disable=unused-argument + num_items_in_batch=None, # pylint: disable=unused-argument + ): + input_ids = inputs.pop("input_ids") + lm_logits = model(input_ids).logits + + labels = input_ids.to(lm_logits.device) + shift_logits = lm_logits[:, :-1, :].contiguous() + labels = labels[:, 1:].contiguous() + + loss_fct = torch.nn.CrossEntropyLoss() + lm_loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1) + ) + + return lm_loss + + +class ReLoRATrainer(AxolotlTrainer): + """ + Trainer subclass that uses the OneCycleLR scheduler + """ + + tag_names = ["axolotl", "relora"] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.lr_scheduler = None + + def create_scheduler( + self, + num_training_steps: int, + optimizer: Optional[torch.optim.Optimizer] = None, + ): + optimizer = self.optimizer if optimizer is None else optimizer + lr_scheduler = super().create_scheduler(num_training_steps, optimizer) + + if self.args.relora_steps: + warmup_steps = ( + self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10 + ) + anneal_steps = ( + self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1 + ) + self.lr_scheduler = ReLoRAScheduler( + optimizer, + lr_scheduler, + self.args.relora_steps, + anneal_steps, + warmup_steps, + ) + else: + self.lr_scheduler = lr_scheduler + + return self.lr_scheduler + + +class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer): + """ + Extend the base DPOTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "dpo"] + + def __init__(self, *args, dataset_tags=None, **kwargs): + super().__init__(*args, **kwargs) + self.dataset_tags = dataset_tags + self.optimizer = None + self.model_accepts_loss_kwargs = False + + def create_optimizer(self): + if self.args.loraplus_lr_ratio is None: + return super().create_optimizer() + + opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model + if self.optimizer is None: # pylint: disable=access-member-before-definition + optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( + self.args, + opt_model, + ) + + loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) + if loraplus_lr_ratio: + print("Using lora+") + loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None) + self.optimizer = create_loraplus_optimizer( # pylint: 
disable=attribute-defined-outside-init + opt_model, + optimizer_cls, + loraplus_lr_ratio=loraplus_lr_ratio, + loraplus_lr_embedding=loraplus_lr_embedding, + **optimizer_kwargs, + ) + + if is_sagemaker_mp_enabled(): + self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init + self.optimizer + ) + + return self.optimizer + + @wraps(DPOTrainer.push_to_hub) + def push_to_hub(self, *args, **kwargs) -> str: + """ + Overwrite the `push_to_hub` method in order to force-add the tags when pushing the + model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. + """ + kwargs = _sanitize_kwargs_for_ds_tagging( + dataset_tags=self.dataset_tags, kwargs=kwargs + ) + kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs) + + return super().push_to_hub(*args, **kwargs) + + @staticmethod + def tokenize_row( + features, + processing_class, + max_prompt_length, + max_completion_length, + add_special_tokens, + ) -> Dict: + res = DPOTrainer.tokenize_row( + features, + processing_class, + max_prompt_length, + max_completion_length, + add_special_tokens, + ) + # fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen + if processing_class.bos_token is None and res["prompt_input_ids"][0] is None: + for key in res.keys(): + res[key] = res[key][1:] + + if processing_class.bos_token and processing_class.bos_token_id is not None: + # dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs + if res["chosen_input_ids"][0] == processing_class.bos_token_id: + res["chosen_input_ids"] = res["chosen_input_ids"][1:] + res["chosen_labels"] = res["chosen_labels"][1:] + res["chosen_attention_mask"] = res["chosen_attention_mask"][1:] + if res["rejected_input_ids"][0] == processing_class.bos_token_id: + res["rejected_input_ids"] = res["rejected_input_ids"][1:] + res["rejected_labels"] = res["rejected_labels"][1:] + res["rejected_attention_mask"] = res["rejected_attention_mask"][1:] + + return res + + def training_step( + self, + model: nn.Module, + inputs: Dict[str, Union[torch.Tensor, Any]], + num_items_in_batch=None, + ) -> torch.Tensor: + loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch) + gc.collect() + torch.cuda.empty_cache() + return loss + + +class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): + """ + Extend the base ORPOTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "orpo"] + + +class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): + """ + Extend the base KTOTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "kto"] + + +class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): + """ + Extend the base CPOTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "cpo"] + + +class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): + """ + Extend the base RewardTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "reward"] + + +class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer): + """ + Extend the base trl.PRMTrainer for axolotl helpers + """ + + tag_names = ["axolotl", "prm"] diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py new file mode 100644 index 000000000..9eae52162 --- /dev/null +++ b/src/axolotl/core/training_args.py @@ -0,0 +1,264 @@ +""" +extra axolotl specific training args +""" +from dataclasses import dataclass, field +from typing import Optional + +from transformers import TrainingArguments +from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig + + +@dataclass +class 
AxolotlTrainingMixins:
+    """
+    Mixin class for the Axolotl training args.
+    """
+
+    # pylint: disable=duplicate-code
+    model_type: Optional[str] = field(
+        default=None, metadata={"help": "HF model configuration model_type."}
+    )
+    lr_quadratic_warmup: bool = field(
+        default=False,
+        metadata={"help": "Use quadratic warmup for cosine scheduling."},
+    )
+    pretraining: bool = field(
+        default=False,
+        metadata={
+            "help": "Indicates to trainer whether we are doing continued pretraining."
+        },
+    )
+    sample_packing: bool = field(
+        default=False,
+        metadata={"help": "Use sample packing for efficient training."},
+    )
+    multipack_real_batches: bool = field(
+        default=False,
+        metadata={"help": "Use real batches for efficient training."},
+    )
+    eval_sample_packing: Optional[bool] = field(
+        default=None,
+        metadata={"help": "Use sample packing for efficient evals."},
+    )
+    sample_packing_efficiency: float = field(
+        default=1.0,
+        metadata={"help": "Sample packing efficiency for calculating batch length."},
+    )
+    sample_packing_bin_size: int = field(
+        default=200,
+        metadata={
+            "help": "The max number of samples that a packed sample can contain after packing. Increase for better packing."
+        },
+    )
+    sample_packing_group_size: int = field(
+        default=100000,
+        metadata={
+            "help": "The number of samples to group together for packing. Increase for better packing."
+        },
+    )
+    max_seq_length: int = field(
+        default=2048,
+        metadata={"help": "The maximum sequence length the model can handle"},
+    )
+    relora_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how often to reset for ReLoRA"},
+    )
+    relora_warmup_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
+    )
+    relora_anneal_steps: Optional[int] = field(
+        default=None,
+        metadata={"help": "how many anneal steps to take after reset for ReLoRA"},
+    )
+    relora_prune_ratio: Optional[float] = field(
+        default=0.9,
+        metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
+    )
+    bench_split: Optional[str] = field(
+        default="eval", metadata={"help": "The benchmark split to run on"}
+    )
+    bench_dataset: Optional[str] = field(
+        default="pharaouk/dharma-1/dharma_1_mini.json",
+        metadata={
+            "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
+        },
+    )
+    do_bench_eval: Optional[bool] = field(
+        default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
+    )
+    do_causal_lm_eval: Optional[bool] = field(
+        default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
+    )
+    max_bench_samples: Optional[int] = field(
+        default=None,
+        metadata={
+            "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
+        },
+    )
+    bench_source_max_len: int = field(
+        default=2048, metadata={"help": "Maximum source sequence length for bench."}
+    )
+    dataloader_prefetch_factor: Optional[int] = field(
+        default=None,
+        metadata={"help": "prefetch_factor argument to the dataloader"},
+    )
+    cosine_min_lr_ratio: Optional[float] = field(
+        default=None,
+        metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
+    )
+    cosine_constant_lr_ratio: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
+        },
+    )
+    loraplus_lr_ratio: Optional[float] = field(
+        default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
+    )
+    loraplus_lr_embedding: Optional[float] = field(
+        default=1e-6,
+        metadata={"help": "loraplus learning rate for lora embedding layers."},
+    )
+    embedding_lr_scale: Optional[float] = field(
+        default=None,
+        metadata={"help": "Scale the learning rate for the embedding layers."},
+    )
+    lr_groups: Optional[list[dict]] = field(
+        default=None,
+        metadata={"help": "Specify learning rate groups with different LRs."},
+    )
+    embedding_lr: Optional[float] = field(
+        default=None,
+        metadata={"help": "absolute learning rate for the embedding layers."},
+    )
+    qlora: bool = field(
+        default=False,
+        metadata={"help": "whether this is a QLoRA training run"},
+    )
+    orpo_alpha: Optional[float] = field(
+        default=None,
+    )
+    lisa_n_layers: Optional[int] = field(
+        default=None,
+        metadata={"help": "the number of active layers in LISA"},
+    )
+    lisa_step_interval: Optional[int] = field(
+        default=None,
+        metadata={"help": "how often to switch layers in LISA"},
+    )
+    lisa_layers_attribute: Optional[str] = field(
+        default=None,
+        metadata={"help": "path under the model to access the layers"},
+    )
+    curriculum_sampling: Optional[bool] = field(
+        default=None,
+        metadata={"help": "whether to use sequential sampling for curriculum learning"},
+    )
+    alternate_optimizer: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "workaround to pass an alternate optimizer to the HF trainer"
+        },
+    )
+    alternate_lr_scheduler_type: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "workaround to pass an alternate lr scheduler to the HF trainer"
+        },
+    )
+    chat_template: Optional[str] = field(
+        default=None,
+        metadata={"help": "Chat template converting chat messages to text"},
+    )
+
+    kd_ce_alpha: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
+        },
+    )
+
+    kd_alpha: Optional[float] = field(
+        default=1.0,
+        metadata={"help": "The alpha scaling parameter for KD loss"},
+    )
+
+    kd_temperature: Optional[float] = field(
+        default=1.0,
+        metadata={
+            "help": "the temperature parameter for KL divergence loss when using KD"
+        },
+    )
+
+    kd_zscore_base_temp: Optional[float] = field(
+        default=None,
+        metadata={
+            "help": "the base temperature parameter for KL divergence with z-score when using KD"
+        },
+    )
+
+    kd_top_k_before_softmax: Optional[bool] = field(
+        default=None,
+        metadata={
+            "help": "Whether to gather the top-k logits before applying softmax when using KD"
+        },
+    )
+
+
+@dataclass
+class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
+    """
+    Training arguments for Causal trainer
+
+    This code is duplicated due to HF TrainingArguments not setting output_dir
+    with a default value, so it can't be used as a mixin.
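+
+    A minimal construction sketch (values illustrative; `output_dir` comes from
+    HF TrainingArguments, the remaining fields from AxolotlTrainingMixins):
+
+        args = AxolotlTrainingArguments(
+            output_dir="./outputs",
+            sample_packing=True,
+            max_seq_length=2048,
+        )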
+ """ + + +@dataclass +class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig): + """ + DPO config for DPO training + """ + + +@dataclass +class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig): + """ + ORPO config for ORPO training + """ + + +@dataclass +class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig): + """ + KTO config for KTO training + """ + + +@dataclass +class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig): + """ + CPO config for CPO training + """ + + simpo_gamma: Optional[float] = field( + default=None, + metadata={"help": "simpo gamma parameter"}, + ) + + +@dataclass +class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig): + """ + Reward config for Reward training + """ + + +@dataclass +class AxolotlPRMConfig(AxolotlTrainingMixins, PRMConfig): + """ + PRM config for PRM training + """ diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index e4531930f..143928019 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -2,7 +2,7 @@ import logging import os -from typing import List, Optional +from typing import List, Optional, Union import torch from datasets import Dataset, IterableDataset @@ -51,7 +51,17 @@ class TokenizedPromptDataset(Dataset): map_kwargs = {} if self.prompt_tokenizer.supports_batched: map_kwargs["batched"] = True - map_kwargs["batch_size"] = 100 + map_kwargs["batch_size"] = 1_000 + + if ( + hasattr(self.prompt_tokenizer, "filter_rows") + and self.prompt_tokenizer.filter_rows + ): + dataset = dataset.filter( + self.prompt_tokenizer.filter_rows, + num_proc=num_proc, + desc="Strategy Filtering Rows", + ) return dataset.map( self.prompt_tokenizer.tokenize_prompt, @@ -63,6 +73,24 @@ class TokenizedPromptDataset(Dataset): ) +def wrap_dataset_for_tokenized_prompt( + prompt_tokenizer: PromptTokenizingStrategy, + dataset: Union[Dataset, IterableDataset], + **kwargs, +): + if isinstance(dataset, IterableDataset): + map_kwargs = {} + if prompt_tokenizer.supports_batched: + map_kwargs["batched"] = True + features = dataset.features.keys() + return dataset.map( + prompt_tokenizer.tokenize_prompt, + remove_columns=features, + **map_kwargs, + ) + return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs) + + # TODO this isn't the best since it can't interleave datasets class ConstantLengthDataset(IterableDataset): """ diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index a271c59d1..211d5e51b 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -111,6 +111,17 @@ class BasePlugin: None """ + def get_trainer_cls(self, cfg): # pylint: disable=unused-argument): + """ + Returns a custom class for the trainer. + + Parameters: + cfg (dict): The global axolotl configuration. + + Returns: + class: The class for the trainer. + """ + def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument """ Creates and returns an optimizer for training. @@ -212,7 +223,17 @@ def load_plugin(plugin_name: str) -> BasePlugin: module_name, class_name = plugin_name.rsplit(".", 1) # import the module - module = importlib.import_module(module_name) + try: + module = importlib.import_module(module_name) + except ModuleNotFoundError as orig_exc: + try: + if not module_name.startswith("axolotl.integrations."): + module = importlib.import_module("axolotl.integrations." 
+ module_name)
+            else:
+                raise orig_exc
+        except ModuleNotFoundError as exc:
+            raise orig_exc from exc
+
     # instantiate the class
     plugin_class = getattr(module, class_name)
     # create an instance of the class
@@ -272,8 +293,10 @@ class PluginManager:
             ImportError: If the plugin module cannot be imported.
         """
         try:
+            logging.info(f"Attempting to load plugin: {plugin_name}")
             plugin = load_plugin(plugin_name)
             self.plugins[plugin_name] = plugin
+            logging.info(f"Plugin loaded successfully: {plugin_name}")
         except ImportError:
             logging.error(f"Failed to load plugin: {plugin_name}")
 
@@ -346,6 +369,22 @@ class PluginManager:
         for plugin in self.plugins.values():
             plugin.post_lora_load(cfg, model)
 
+    def get_trainer_cls(self, cfg):
+        """
+        Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class.
+
+        Parameters:
+        cfg (dict): The configuration for the plugins.
+
+        Returns:
+        object: The trainer class, or None if none was found.
+        """
+        for plugin in self.plugins.values():
+            trainer_cls = plugin.get_trainer_cls(cfg)
+            if trainer_cls is not None:
+                return trainer_cls
+        return None
+
     def create_optimizer(self, cfg, trainer):
         """
         Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
diff --git a/src/axolotl/integrations/kd/__init__.py b/src/axolotl/integrations/kd/__init__.py
new file mode 100644
index 000000000..8a6e3eda1
--- /dev/null
+++ b/src/axolotl/integrations/kd/__init__.py
@@ -0,0 +1,36 @@
+# Copyright 2024 Axolotl AI. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Plugin init to add KD support to Axolotl.
+"""
+from axolotl.integrations.base import BasePlugin
+
+from .args import KDArgs  # pylint: disable=unused-import  # noqa: F401
+
+
+class KDPlugin(BasePlugin):
+    """
+    Plugin for KD support in Axolotl.
+    """
+
+    def get_input_args(self):
+        return "axolotl.integrations.kd.KDArgs"
+
+    def get_trainer_cls(self, cfg):
+        if cfg.kd_trainer:
+            from .trainer import AxolotlKDTrainer
+
+            return AxolotlKDTrainer
+        return None
diff --git a/src/axolotl/integrations/kd/args.py b/src/axolotl/integrations/kd/args.py
new file mode 100644
index 000000000..a88a0dc48
--- /dev/null
+++ b/src/axolotl/integrations/kd/args.py
@@ -0,0 +1,37 @@
+# Copyright 2024 Axolotl AI. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Plugin args for KD support.
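+
+A config sketch for enabling these args in an axolotl YAML (values are
+illustrative; the plugin path points at the KDPlugin defined in this package):
+
+    plugins:
+      - axolotl.integrations.kd.KDPlugin
+    kd_trainer: true
+    kd_ce_alpha: 0.1
+    kd_alpha: 0.9
+    kd_temperature: 1.0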
+""" +from typing import Optional + +from pydantic import BaseModel + + +class KDArgs(BaseModel): + """ + Input args for knowledge distillation. + """ + + kd_trainer: Optional[bool] = None # whether to use KD trainer + kd_ce_alpha: Optional[ + float + ] = None # loss coefficient for cross-entropy loss during KD + kd_alpha: Optional[float] = None # loss coefficient for KD loss + kd_temperature: Optional[float] = None # temperature for sampling during KD + kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling + kd_top_k_before_softmax: Optional[ + bool + ] = None # whether to sample top k before softmax during KD diff --git a/src/axolotl/integrations/kd/chat_template.py b/src/axolotl/integrations/kd/chat_template.py new file mode 100644 index 000000000..699728e9f --- /dev/null +++ b/src/axolotl/integrations/kd/chat_template.py @@ -0,0 +1,201 @@ +# Copyright 2024 Axolotl AI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Chat template prompt strategy loader with KD support +""" +from typing import Any, Dict + +import torch + +from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader + + +class ChatTemplateStrategyWithKD(ChatTemplateStrategy): + """ + Handle fields for logprob KD + """ + + def __init__( + self, + prompter, + tokenizer, + train_on_inputs, + sequence_len, + roles_to_train=None, + train_on_eos=None, + logprobs_field="logprobs", + gen_temperature=1.0, + kd_temperature=1.0, + ): + self.logprobs_field = logprobs_field + self.gen_temperature = gen_temperature + self.kd_temperature = kd_temperature + + super().__init__( + prompter, + tokenizer, + train_on_inputs, + sequence_len, + roles_to_train=roles_to_train, + train_on_eos=train_on_eos, + ) + + @property + def supports_batched(self) -> bool: + # batching doesn't work well for logprob data + return False + + def transform_logprobs(self, sample): + """ + Transform logprobs to target format for KD training + """ + + logprobs = sample.pop(self.logprobs_field) + target_seq_len = len(logprobs) + input_seq_len = len(sample["input_ids"]) + input_padding_len = input_seq_len - target_seq_len + # get non-zero top-k (prune None logprobs from vllm data step) + top_k_vals = [ + len(logprobs[i]) + for i in range(len(logprobs)) + if logprobs[i] is not None and len(logprobs[i]) + ] + max_top_k = max(set(top_k_vals), key=top_k_vals.count) + min_top_k = min(set(top_k_vals), key=top_k_vals.count) + top_k = min(max_top_k, min_top_k) + if top_k == 0: + raise ValueError("No non-zero top-k logprobs found.") + + target_logprobs = [] + target_token_ids = [] + target_mask = [] + + if input_padding_len < 0: + # logprobs is longer than target_seq_len, + # so we need to slice from the left/beginning of logprobs + logprobs = logprobs[:-input_seq_len] + input_padding_len = 0 + # target_seq_len = input_seq_len + + # truncate the second dimension of the logprobs to top_k + logprobs = [row[:top_k] for row in logprobs] + + # fill with -inf for padding_len tokens for top_k tokens + # 
extend target_logprobs with a padding_len x top_k 2D list filled with -inf + + # for causal models, if we start the range at 1, then we don't need to shift in the trainer + # otherwise, we need to shift in the trainer + shift = 0 + for _ in range(shift, input_padding_len): + target_logprobs.append([-float("inf")] * top_k) + target_token_ids.append(list(range(top_k))) + target_mask.append([0] * top_k) + + for position in range(input_padding_len, input_seq_len): + if sample["labels"][position] == -100: + target_mask.append([0] * top_k) + else: + target_mask.append([1] * top_k) + + for _, token_pos_logprobs in enumerate(logprobs): + # Initialize collections for logprobs and token_ids + position_logprobs = [] + position_token_ids = [] + + # Process each token probability entry + for entry in token_pos_logprobs: + # Extract logprob value + logprob = entry["logprob"] + + # Parse token_id from the "token_id:###" format + token_id = int(entry["token"].split(":")[1]) + + # Append to our collections + position_logprobs.append(logprob) + position_token_ids.append(token_id) + + # Convert to a tensor for easier manipulation + position_logprobs_tensor = torch.tensor( + position_logprobs, dtype=torch.float + ) + + # Now we have distribution at T1 in log form, i.e. log p_{T1}(k). + # Next, re-scale to T2 = self.kd_temperature via exponent-based trick + # p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z + # + # Convert from log to probability + teacher_probs_t1 = position_logprobs_tensor.exp() + if self.kd_temperature != self.gen_temperature: + # Exponentiate by factor (T1 / T2) + exponent = self.gen_temperature / self.kd_temperature + teacher_probs_t2 = teacher_probs_t1**exponent + else: + teacher_probs_t2 = teacher_probs_t1 + # Re-normalize + teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum( + dim=0, keepdim=True + ) + # Convert back to log + position_logprobs_tensor = torch.log(teacher_probs_t2) + + # Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor + position_logprobs_scaled = position_logprobs_tensor.tolist() + + target_logprobs.append(position_logprobs_scaled) + target_token_ids.append(position_token_ids) + + if shift == 1: + # since we started at index 1 for causal, we need one more padding token + target_logprobs.append([-float("inf")] * top_k) + target_token_ids.append(list(range(top_k))) + target_mask.append([0] * top_k) + + # Update sample with transformed logprobs + sample["target_logprobs"] = target_logprobs + sample["target_token_ids"] = target_token_ids + sample["target_mask"] = target_mask + + return sample + + def _tokenize_single_prompt(self, prompt): + logprobs = prompt.pop(self.logprobs_field) + tokenized_prompt = super()._tokenize_single_prompt(prompt) + tokenized_prompt[self.logprobs_field] = logprobs + tokenized_prompt = self.transform_logprobs(tokenized_prompt) + + return tokenized_prompt + + +class KDStrategyLoader(StrategyLoader): + """ + Load ChatTemplateStrategy with KD support using StrategyLoader. 
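+
+    A dataset config sketch (the `logprobs_field` and `temperature` keys are
+    read in _get_strategy_params below; the path and values are illustrative):
+
+        datasets:
+          - path: some-org/kd-dataset
+            type: axolotl.integrations.kd.chat_template
+            logprobs_field: logprobs
+            temperature: 1.0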
+ """ + + def _get_strategy_cls(self): + return ChatTemplateStrategyWithKD + + def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]): + strategy_params = super()._get_strategy_params(cfg, ds_cfg) + if logprobs_field := ds_cfg.get("logprobs_field"): + strategy_params["logprobs_field"] = logprobs_field + if gen_temperature := ds_cfg.get("temperature"): + strategy_params["gen_temperature"] = gen_temperature + if kd_temperature := cfg.get("kd_temperature"): + strategy_params["kd_temperature"] = kd_temperature + + return strategy_params + + +load = KDStrategyLoader() diff --git a/src/axolotl/integrations/kd/collator.py b/src/axolotl/integrations/kd/collator.py new file mode 100644 index 000000000..de63869c7 --- /dev/null +++ b/src/axolotl/integrations/kd/collator.py @@ -0,0 +1,255 @@ +# Copyright 2024 Axolotl AI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +DataCollator for axolotl to handle KD fields without using -inf for padding, +and with a teacher_mask to identify padded positions. +""" + +from dataclasses import dataclass +from typing import Any, Optional, Union + +import numpy as np +import torch +from transformers import PreTrainedTokenizerBase +from transformers.utils import PaddingStrategy + +from axolotl.utils.collators.batching import DataCollatorForSeq2Seq + + +@dataclass +class DataCollatorForKD(DataCollatorForSeq2Seq): + """ + Data collator for KD, including handling KD-specific fields. + + This version avoids using -inf and instead uses a large negative value for padding + target_logprobs. It also creates a teacher_mask to indicate which entries are valid. 
+ """ + + # pylint: disable=duplicate-code + tokenizer: PreTrainedTokenizerBase + model: Optional[Any] = None + padding: Union[bool, str, PaddingStrategy] = True + max_length: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + label_pad_token_id: int = -100 + position_pad_token_id: int = 0 + return_tensors: str = "pt" + + def __call__(self, features, return_tensors=None): + if return_tensors is None: + return_tensors = self.return_tensors + + padding_side = self.tokenizer.padding_side + + # Pad labels and position_ids first + for feature_name, pad_token_id in [ + ("labels", self.label_pad_token_id), + ("position_ids", self.position_pad_token_id), + ]: + if feature_name in features[0]: + feat = [f[feature_name] for f in features] + max_len = max(len(x) for x in feat) + if self.pad_to_multiple_of is not None: + max_len = ( + (max_len + self.pad_to_multiple_of - 1) + // self.pad_to_multiple_of + ) * self.pad_to_multiple_of + + for f in features: # pylint: disable=invalid-name + remainder = [pad_token_id] * (max_len - len(f[feature_name])) + if isinstance(f[feature_name], list): + f[feature_name] = ( + f[feature_name] + remainder + if padding_side == "right" + else remainder + f[feature_name] + ) + else: + # If they are numpy arrays + if padding_side == "right": + f[feature_name] = np.concatenate( + [f[feature_name], remainder] + ).astype(np.int64) + else: + f[feature_name] = np.concatenate( + [remainder, f[feature_name]] + ).astype(np.int64) + + # Handle target_logprobs and target_token_ids manually + target_logprobs_list = [] + target_token_ids_list = [] + target_mask_list = [] + has_teacher_data = ("target_logprobs" in features[0]) and ( + "target_token_ids" in features[0] + ) + + if has_teacher_data: + # Extract and remove from features + for f in features: # pylint: disable=invalid-name + target_logprobs_list.append(f.pop("target_logprobs")) + target_token_ids_list.append(f.pop("target_token_ids")) + target_mask_list.append(f.pop("target_mask")) + + # Determine max lengths + max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list) + max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq) + + padded_target_logprobs = [] + padded_target_token_ids = [] + padded_teacher_mask_list = [] + + for t_logprobs, t_ids, t_mask in zip( + target_logprobs_list, target_token_ids_list, target_mask_list + ): + t_logprobs_padded = [] + t_ids_padded = [] + t_mask_padded = [] + + for lp, ids, mask in zip( # pylint: disable=invalid-name + t_logprobs, t_ids, t_mask + ): + lp_len = len(lp) + if lp_len < max_k: + # Use -1e9 for padding logprobs and 0 for token_ids + pad_len = max_k - lp_len + lp = lp + [-1e9] * pad_len # pylint: disable=invalid-name + ids = ids + [0] * pad_len + mask = mask + [0] * pad_len + else: + lp = lp[:max_k] # pylint: disable=invalid-name + ids = ids[:max_k] + mask = mask[:max_k] + + t_logprobs_padded.append(lp) + t_ids_padded.append(ids) + t_mask_padded.append(mask) + + seq_len_diff = max_teacher_seq_len - len(t_logprobs_padded) + if seq_len_diff > 0: + # Pad sequences fully if needed + t_logprobs_padded.extend( + [[-1e9] * max_k for _ in range(seq_len_diff)] + ) + t_ids_padded.extend([[0] * max_k for _ in range(seq_len_diff)]) + t_mask_padded.extend([[0] * max_k for _ in range(seq_len_diff)]) + + padded_target_logprobs.append(t_logprobs_padded) + padded_target_token_ids.append(t_ids_padded) + padded_teacher_mask_list.append(t_mask_padded) + + # Convert to tensors + padded_target_logprobs = torch.tensor( + padded_target_logprobs, dtype=torch.float 
+ ) + padded_target_token_ids = torch.tensor( + padded_target_token_ids, dtype=torch.long + ) + padded_teacher_mask_list = torch.tensor( + padded_teacher_mask_list, dtype=torch.int + ) + + # Pad using tokenizer for regular fields + features = self.tokenizer.pad( + features, + padding=self.padding, + max_length=self.max_length, + pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=return_tensors, + ) + + # Add back teacher data if present + if has_teacher_data: + features["target_logprobs"] = padded_target_logprobs + features["target_token_ids"] = padded_target_token_ids + features["target_mask"] = padded_teacher_mask_list + + # Prepare decoder_input_ids if the model supports it + if ( + "labels" in features + and self.model is not None + and hasattr(self.model, "prepare_decoder_input_ids_from_labels") + ): + decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels( + labels=features["labels"] + ) + features["decoder_input_ids"] = decoder_input_ids + + return features + + +class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD): + """ + Collator for multipack (batch of sub-batches) specifically for KD. + Adapts DataCollatorForKD so it can pack multiple sequences in a single batch item. + """ + + def __call__(self, features, return_tensors=None): + """ + Expects that `features` could be either: + - a single list of dicts, OR + - a list of lists of dicts (the "sub-batches" to be packed). + """ + # 1) If we are *not* dealing with multiple sequences per batch element, + # just pass straight to parent. + if not isinstance(features[0], list): + return super().__call__(features, return_tensors=return_tensors) + + # 2) Otherwise, we *are* dealing with multiple sequences in each batch item. + # We want to produce a single "merged" feature dict for each sub-batch. + out_features = [{} for _ in features] + + for i, sub_features in enumerate(features): + # sub_features is a list of dicts, each dict = one sequence’s features + # We'll merge them into out_features[i]. + # + # NOTE: You can customize how you combine fields as needed (e.g. summation + # or offset for attention_mask). Below is a straightforward concatenation/extension. + + for field_name in sub_features[0].keys(): + # Some fields you might want to skip or treat specially: + if field_name == "length": + continue + + # If it’s a KD field that’s a list-of-lists (e.g. target_logprobs), + # you typically just want to flatten them by extending. + if field_name in ["target_logprobs", "target_token_ids", "target_mask"]: + combined = [] + for feat in sub_features: + combined.extend(feat[field_name]) + out_features[i][field_name] = combined + + elif field_name == "attention_mask": + # Here we apply the (j+1) factor to differentiate each sub-sample + # within this merged batch item. + arrays = [] + for j, feat in enumerate(sub_features): + if field_name in feat: + arrays.append((j + 1) * np.array(feat[field_name])) + out_features[i][field_name] = np.concatenate(arrays) + else: + # By default, just concatenate them if they are arrays + # or extend them if they are lists. + # For example, input_ids or labels are often arrays. + arrays = [] + for feat in sub_features: + if field_name in feat: + arr = np.array(feat[field_name]) + arrays.append(arr) + out_features[i][field_name] = np.concatenate(arrays) + + # 3) Now call the parent collator, which will do: + # - padding of labels/position_ids + # - KD-specific padding for target_logprobs, target_token_ids, etc. 
+ # - final conversion to return_tensors + return super().__call__(out_features, return_tensors=return_tensors) diff --git a/src/axolotl/integrations/kd/kernels/__init__.py b/src/axolotl/integrations/kd/kernels/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/integrations/kd/topk_logprob/LICENSE.md b/src/axolotl/integrations/kd/topk_logprob/LICENSE.md new file mode 100644 index 000000000..435d36d75 --- /dev/null +++ b/src/axolotl/integrations/kd/topk_logprob/LICENSE.md @@ -0,0 +1,58 @@ +### AXOLOTL COMMUNITY LICENSE AGREEMENT + +This Axolotl Community License Agreement (“Agreement”) is entered into by and between Axolotl AI Corp. (“Axolotl”) and +any individual or entity (“Licensee”) who wishes to use the Software (as defined below) in accordance with the terms +and conditions set forth in this Agreement. + +1. Definitions + 1.1 “Licensee” refers to any individual or entity who has obtained a copy of the Software under this Agreement. + 1.2 “Plugin Integration” means independent integration software modules which may or may not be offered by Axolotl, + which may be licensed separately by their respective authors and/or licensors. + 1.3 “Software” refers to the specific sub-directory of the Axolotl, Inc. software located at + https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations and its subdirectories which + permits Plugin Integrations to integrate with the Axolotl service. +2. Grant of License + 2.1 Axolotl hereby grants Licensee a worldwide, non-exclusive, royalty-free, license to use, copy, modify, merge, + publish, distribute, sublicense, and/or otherwise exploit the Software, subject to the following conditions: + - Licensee must comply with all the terms and conditions of this Agreement. + - Licensee must include the original copyright notice and disclaimer of warranty in all copies or substantial + portions of the Software. + 2.2 Licensee may use the Software for any lawful purpose, except as restricted in Section 3. +3. Restrictions + 3.1 Licensee shall not use the Software for any activity that constitutes a commercial activity of offering for + free or for sale any services, platform, or equivalent to third parties for the purposes of allowing such + third parties to fine-tune artificial intelligence models. + 3.2 Licensee shall not: + - Use the Software for any illegal or unauthorized purpose. + - Reverse engineer, decompile, or disassemble the Software. + - Remove or modify any copyright, trademark, or other proprietary notices contained in the Software. + - Use the Software in a way that could damage, disable, overburden, or impair the functionality of the + Software or interfere with any third-party use of the Software. + 3.3 Axolotl reserves the right to restrict certain Plugin Integrations for use with the Software. To the extent Licensee integrates a permitted, applicable Plugin Integration with the Software, Licensee shall comply with any additional terms and conditions imposed by the licensors of such Plugin Integration for use of such Plugin Integrations. Licensee shall contact Axolotl if it has questions about whether its use of the Software falls beyond the scope of this Agreement. +4. Intellectual Property Rights + 4.1 Axolotl and its contributors retain all intellectual property rights in and to the Software. Licensee + acknowledges that this Agreement does not transfer any ownership rights or intellectual property rights to + Licensee. +5. 
Disclaimer of Warranty + 5.1 THE SOFTWARE IS PROVIDED “AS IS,” WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED + TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. IN NO EVENT SHALL + THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN ACTION OF + CONTRACT, TORT, OR OTHERWISE, ARISING FROM, OUT OF, OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER + DEALINGS IN THE SOFTWARE. +6. Termination + 6.1 Axolotl may terminate this Agreement at any time if Licensee fails to comply with any of the terms and + conditions set forth herein. Upon termination, Licensee shall cease all use of the Software and destroy any + copies in its possession. +7. Governing Law + 7.1 This Agreement shall be governed by and construed in accordance with the laws of the State of California, + without regards to conflicts of laws provisions thereof. +8. Entire Agreement + 8.1 This Agreement constitutes the entire agreement between Axolotl and Licensee with respect to the subject matter + hereof and supersedes all prior or contemporaneous understandings or agreements between the parties concerning + the Software, whether written or oral. Axolotl may update the terms of this Agreement from time to time, and + Licensee’s continued use of the Software after any such updates shall constitute acceptance of updated terms + on a go-forward basis. Axolotl will use commercially reasonable efforts to provide Licensee notice of any + material updates. By using the Software, Licensee acknowledges that it has read, understood, and agrees to be + bound by the terms and conditions of this Agreement. + +This Agreement was last updated on August 23, 2024. diff --git a/src/axolotl/integrations/kd/topk_logprob/__init__.py b/src/axolotl/integrations/kd/topk_logprob/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py new file mode 100644 index 000000000..ab9a54d33 --- /dev/null +++ b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py @@ -0,0 +1,235 @@ +# Copyright 2024 Axolotl AI. All rights reserved. +# +# This software may be used and distributed according to +# the terms of the Axolotl Community License Agreement (the "License"); +# you may not use this file except in compliance with the License. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. + +""" +loss for top_k KL divergence +""" +import torch + + +def zscore_standardize( + logits: torch.Tensor, + mask: torch.Tensor = None, + base_temperature: float = 1.0, + eps: float = 1e-9, +): + """ + Z-score standardize along the last dimension of `logits`. + i.e., for each [B, seq_len] row, across K entries: + z = (logits - mean) / std, + then scale by 1 / base_temperature if desired. + + mask can be broadcastable or None. If None, we standardize all elements. + """ + if mask is None: + # shape: [B, seq_len, K] + # Mean and std over dim=-1 + mean = logits.mean(dim=-1, keepdim=True) + var = logits.var(dim=-1, unbiased=False, keepdim=True) + else: + # If you have to exclude some tokens, multiply by mask, etc. 
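+        # Masked statistics: the mean and variance are computed only over
+        # entries where the mask is nonzero; the count is clamped to at least
+        # 1 so rows with no valid entries don't divide by zero.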
+        float_mask = mask.to(logits.dtype)
+        count = float_mask.sum(dim=-1, keepdim=True).clamp_min(1.0)
+        mean = (logits * float_mask).sum(dim=-1, keepdim=True) / count
+        var = (float_mask * (logits - mean) ** 2).sum(dim=-1, keepdim=True) / count
+
+    std = torch.sqrt(var.clamp_min(eps))
+    z = (logits - mean) / std
+
+    # Scale by 1 / base_temperature
+    z = z / base_temperature
+    return z
+
+
+@torch.jit.script
+def loss(
+    student_logits: torch.Tensor,
+    target_token_ids: torch.Tensor,
+    target_logprobs: torch.Tensor,
+    target_mask: torch.Tensor,
+    num_items_in_batch: int = -1,  # Use -1 to indicate "None"
+    kd_temperature: float = 1.0,
+    top_k_before_softmax: int = 0,
+) -> torch.Tensor:
+    """
+    A KD loss function that is TorchScript-friendly.
+
+    Arguments:
+        student_logits (torch.Tensor): The logits of the student model.
+            Shape: [B, student_seq_len, vocab_size]
+        target_token_ids (torch.Tensor): The top-k teacher/target token IDs.
+            Shape: [B, teacher_seq_len, top_k]
+        target_logprobs (torch.Tensor): The top-k teacher/target logprobs; these should already be re-normalized.
+            Shape: [B, teacher_seq_len, top_k]
+        target_mask (torch.Tensor): The mask for valid tokens.
+            Shape: [B, teacher_seq_len, top_k]
+        num_items_in_batch (int, optional): The number of items in the batch;
+            -1 falls back to averaging over the valid tokens.
+        kd_temperature (float, optional): The temperature for KD.
+            Default: 1.0
+        top_k_before_softmax (int, optional): If nonzero, gather the student's
+            top-k logits first and normalize over just those k entries; if 0,
+            normalize over the full vocabulary before gathering.
+            Default: 0
+    """
+
+    target_logprobs = target_logprobs.float()
+
+    # Determine the teacher sequence length
+    # target_token_ids shape: [B, teacher_seq_len, K]
+    # student_logits shape: [B, student_seq_len, vocab_size]
+    teacher_seq_len = target_token_ids.shape[1]
+
+    if top_k_before_softmax:
+        # Slice student logits to match teacher-provided sequence length
+        student_logits_for_kd = student_logits[
+            :, :teacher_seq_len, :
+        ]  # [B, teacher_seq_len, vocab_size]
+
+        # Gather student logits for teacher's top-K tokens
+        student_logits_topk = torch.gather(
+            student_logits_for_kd, dim=-1, index=target_token_ids
+        )  # [B, teacher_seq_len, K]
+
+        student_logits_topk = student_logits_topk.float()
+
+        # Apply KD temperature to student’s logits
+        if kd_temperature != 1.0:
+            student_logits_topk = student_logits_topk / kd_temperature
+
+        # Convert student top-k logits to logprobs
+        student_logprobs_topk = student_logits_topk - torch.logsumexp(
+            student_logits_topk, dim=-1, keepdim=True
+        )  # [B, teacher_seq_len, K]
+    else:
+        # Slice student logits to match teacher-provided sequence length
+        student_logits_for_kd = (
+            student_logits[:, :teacher_seq_len, :] / kd_temperature
+        )  # [B, teacher_seq_len, vocab_size]
+
+        # keep in full precision for numerical stability of loss
+        student_logits_for_kd = student_logits_for_kd.float()
+
+        # Gather student logits for teacher's top-K tokens
+        student_logits_topk = torch.gather(
+            student_logits_for_kd, dim=-1, index=target_token_ids
+        )  # [B, teacher_seq_len, K]
+
+        # Compute logsumexp across full vocabulary
+        student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
+
+        # Convert just the top-k logits to logprobs
+        student_logprobs_topk = student_logits_topk - student_lse
+
+    # Convert teacher_mask to boolean for indexing
+    # In TorchScript, .bool() is sometimes unsupported, so we do:
+    valid_mask = target_mask.to(torch.bool)
+
+    # Prune tensors to only keep valid tokens
+    student_logprobs_topk = student_logprobs_topk[valid_mask]
+    target_logprobs = target_logprobs[valid_mask]
+
+    # Convert
teacher logprobs to probabilities + teacher_probs = target_logprobs.exp() + + # Compute forward KL + kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk) + kd_loss = kd_loss_per_token.sum() + + # Multiply by T^2 (classical KD scaling) + if kd_temperature != 1.0: + kd_loss = kd_loss * (kd_temperature**2) + + # Normalize by number of items (if provided) or by valid tokens + if num_items_in_batch > 0: + kd_loss = kd_loss / float(num_items_in_batch) + else: + # Fall back to average over valid tokens + kd_loss = kd_loss / float(kd_loss_per_token.size(0)) + + return kd_loss + + +def topk_kd_loss_with_zscore( + student_logits: torch.Tensor, # [B, seq_len, vocab_size] + target_token_ids: torch.Tensor, # [B, seq_len, K] + target_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space + target_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len] + kd_temperature: float = 1.0, # classic KD temperature + zscore_base_temp: float = 1.0, # from the paper + num_items_in_batch: int = -1, +): + """ + A variant of top_k KL divergence with Z-score scaling + from "Logit Standardization in Knowledge Distillation". + """ + + target_logprobs = target_logprobs.float() + + B, teacher_seq_len, K = target_logprobs.shape # pylint: disable=invalid-name + # 1) Gather the student's top-k logits to match teacher + student_logits_for_kd = student_logits[ + :, :teacher_seq_len, : + ] # [B, seq_len, vocab] + student_topk_logits = torch.gather( + student_logits_for_kd, dim=-1, index=target_token_ids + ) # [B, seq_len, K] + + student_topk_logits = student_topk_logits.float() + + # 2) If you want to keep the "classical" T scaling, apply it first + if kd_temperature != 1.0: + student_topk_logits = student_topk_logits / kd_temperature + + # 3) Convert teacher logprobs -> treat them as “logits” for z-score + # (They differ by +some_constant from real logits, but in z-score + # that constant is subtracted out anyway.) + teacher_logits_for_zscore = target_logprobs # rename variable for clarity + + # 4) Z-score teacher and student + # If target_mask is 2D, expand to 3D for the K dimension + if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len): + target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K) + + teacher_z = zscore_standardize( + teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp + ) + student_z = zscore_standardize( + student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp + ) + + # 5) Convert to log-probs for KL + teacher_logprobs_z = teacher_z - torch.logsumexp(teacher_z, dim=-1, keepdim=True) + student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True) + + # 6) Restrict to valid tokens if needed + valid_mask = target_mask.bool() # shape [B, seq_len, K] + teacher_probs_z = teacher_logprobs_z.exp() + teacher_probs_z = teacher_probs_z[valid_mask] + teacher_logprobs_z = teacher_logprobs_z[valid_mask] + student_logprobs_z = student_logprobs_z[valid_mask] + + # 7) forward KL: sum( p_teacher * [log(p_teacher) - log(p_student)] ) + kd_loss_per_token = teacher_probs_z * (teacher_logprobs_z - student_logprobs_z) + kd_loss = kd_loss_per_token.sum() + + # 8) If using classical KD scaling by T^2 + if kd_temperature != 1.0: + kd_loss = kd_loss * (kd_temperature**2) + + # Optionally scale by zscore_base_temp**2 if you want (paper might differ). 
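+    # The commented-out line below shows where that scaling would go: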
+ # kd_loss = kd_loss * (zscore_base_temp**2) + + # 9) Normalize + if num_items_in_batch is not None and num_items_in_batch > 0: + kd_loss = kd_loss / float(num_items_in_batch) + else: + kd_loss = kd_loss / float(kd_loss_per_token.size(0)) + + return kd_loss diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py new file mode 100644 index 000000000..f99f2ca28 --- /dev/null +++ b/src/axolotl/integrations/kd/trainer.py @@ -0,0 +1,113 @@ +# Copyright 2024 Axolotl AI. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +KD trainer +""" + +from axolotl.core.trainers.base import AxolotlTrainer + +from .topk_logprob.forward_kl import loss as topk_kd_loss +from .topk_logprob.forward_kl import topk_kd_loss_with_zscore + + +class AxolotlKDTrainer(AxolotlTrainer): + """ + Custom trainer subclass for Knowledge Distillation (KD) + """ + + def _set_signature_columns_if_needed(self): + super()._set_signature_columns_if_needed() + columns_to_add = [] + if self._signature_columns: + if "target_logprobs" not in self._signature_columns: + columns_to_add.append("target_logprobs") + if "target_token_ids" not in self._signature_columns: + columns_to_add.append("target_token_ids") + if "target_mask" not in self._signature_columns: + columns_to_add.append("target_mask") + if columns_to_add: + self._signature_columns += columns_to_add + + def compute_loss( + self, + model, + inputs, + return_outputs=False, + num_items_in_batch=None, + ): + """ + How the loss is computed by Trainer. By default, all models return the loss in the first element. + + Subclass and override for custom behavior. 
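+
+        For KD, the teacher's top-k `target_logprobs`, `target_token_ids`, and
+        `target_mask` are popped from the inputs, the student logits are shifted
+        by one position to stay causal with the teacher targets, and the result
+        is combined as `kd_ce_alpha * cross_entropy + kd_alpha * kd_loss` when
+        `kd_ce_alpha > 0` (otherwise the KD loss is used alone).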
+ """ + + target_logprobs = inputs.pop("target_logprobs") + target_token_ids = inputs.pop("target_token_ids") + target_mask = inputs.pop("target_mask") + + seq_len = target_token_ids.shape[1] + + if self.model_accepts_loss_kwargs: + loss_kwargs = {} + if num_items_in_batch is not None: + loss_kwargs["num_items_in_batch"] = num_items_in_batch + inputs = {**inputs, **loss_kwargs} + outputs = model(**inputs) + + # FIXME: account for tokenizer.padding_side + student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous() + + shift_logits = student_logits.contiguous() + target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous() + target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous() + target_mask_for_loss = target_mask[..., 1:, :].contiguous() + + if self.args.kd_zscore_base_temp: + loss_kd = topk_kd_loss_with_zscore( + shift_logits, + target_token_ids_for_loss, + target_logprobs_for_loss, + target_mask_for_loss, + kd_temperature=self.args.kd_temperature, + zscore_base_temp=self.args.kd_zscore_base_temp, + num_items_in_batch=num_items_in_batch, + ) + else: + loss_kd = topk_kd_loss( + shift_logits, + target_token_ids_for_loss, + target_logprobs_for_loss, + target_mask_for_loss, + num_items_in_batch=num_items_in_batch, + kd_temperature=self.args.kd_temperature, + top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0, + ) + + if self.args.kd_ce_alpha > 0: + kd_alpha = self.args.kd_alpha + loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd + else: + loss = loss_kd + # Save past state if it exists + # TODO: this needs to be fixed and made cleaner later. + if self.args.past_index >= 0: + self._past = outputs[ # pylint: disable=attribute-defined-outside-init + self.args.past_index + ] + + if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs: + loss *= self.accelerator.num_processes + + return (loss, outputs) if return_outputs else loss diff --git a/src/axolotl/prompt_strategies/__init__.py b/src/axolotl/prompt_strategies/__init__.py index 74da20c5e..645a9329c 100644 --- a/src/axolotl/prompt_strategies/__init__.py +++ b/src/axolotl/prompt_strategies/__init__.py @@ -16,10 +16,21 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None): return messages_load(tokenizer, cfg, ds_cfg, processor=processor) load_fn = "load" + package = "axolotl.prompt_strategies" if strategy.split(".")[-1].startswith("load_"): load_fn = strategy.split(".")[-1] strategy = ".".join(strategy.split(".")[:-1]) - mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies") + elif len(strategy.split(".")) > 1: + try: + importlib.import_module( + "." 
+ strategy.split(".")[-1], + ".".join(strategy.split(".")[:-1]), + ) + package = ".".join(strategy.split(".")[:-1]) + strategy = strategy.split(".")[-1] + except ModuleNotFoundError: + pass + mod = importlib.import_module(f".{strategy}", package) func = getattr(mod, load_fn) load_kwargs = {} if strategy == "user_defined": diff --git a/src/axolotl/prompt_strategies/base.py b/src/axolotl/prompt_strategies/base.py index fce2aba14..cddb3d0e1 100644 --- a/src/axolotl/prompt_strategies/base.py +++ b/src/axolotl/prompt_strategies/base.py @@ -10,6 +10,8 @@ LOG = logging.getLogger("axolotl") def load(strategy, cfg, module_base=None, **kwargs): try: + if len(strategy.split(".")) == 1: + strategy = strategy + ".default" load_fn = strategy.split(".")[-1] strategy = ".".join(strategy.split(".")[:-1]) mod = importlib.import_module(f".{strategy}", module_base) diff --git a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py index 4f60842c5..c6b0fe2cf 100644 --- a/src/axolotl/prompt_strategies/bradley_terry/chat_template.py +++ b/src/axolotl/prompt_strategies/bradley_terry/chat_template.py @@ -21,7 +21,11 @@ class BTChatTemplateStrategy(ChatTemplateStrategy): Bradley-Terry reward model pairwise chat template prompt strategy. """ - def tokenize_prompt(self, prompt): + @property + def supports_batched(self) -> bool: + return False + + def _tokenize_single_prompt(self, prompt): """ :param prompt: the actual row of data from the underlying dataset @@ -39,7 +43,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy): ) prompt[self.messages].append({"role": "user", "content": prompt["input"]}) prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]}) - chosen_tokenized = super().tokenize_prompt(prompt) + chosen_tokenized = super()._tokenize_single_prompt(prompt) if len(chosen_tokenized["input_ids"]) > max_length: LOG.warning( @@ -62,7 +66,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy): prompt[self.messages].append( {"role": "assistant", "content": prompt["rejected"]} ) - rejected_tokenized = super().tokenize_prompt(prompt) + rejected_tokenized = super()._tokenize_single_prompt(prompt) if len(rejected_tokenized["input_ids"]) > max_length: LOG.warning( diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py index 5b12130d7..bb87ee45b 100644 --- a/src/axolotl/prompt_strategies/chat_template.py +++ b/src/axolotl/prompt_strategies/chat_template.py @@ -3,6 +3,7 @@ HF Chat Templates prompt strategy """ import logging +from collections import defaultdict from typing import Any, Dict, List, Optional from transformers import ProcessorMixin @@ -193,7 +194,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): def __init__( self, - prompter, + prompter: ChatTemplatePrompter, tokenizer, train_on_inputs, sequence_len, @@ -220,22 +221,61 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): def messages(self, messages): self._messages = messages - def tokenize_prompt(self, prompt): + @property + def supports_batched(self) -> bool: + # Let calling code know we can handle lists of examples + return True + + def is_prompt_batched(self, prompt: dict[str, Any]) -> bool: + try: + return all(isinstance(v, list) for v in prompt.values()) and all( + isinstance(v, list) for v in prompt[self.messages] + ) + except KeyError: + return False + + def tokenize_prompt(self, prompt: dict[str, Any]): + """ + Public method that can handle either a single prompt or a batch of prompts. 
+ """ + + if not self.is_prompt_batched(prompt) or not self.supports_batched: + return self._tokenize_single_prompt(prompt) + + res = defaultdict(lambda: []) + feature_names = list(prompt.keys()) + + # Process each prompt individually + for row in zip(*prompt.values()): + tokenized_prompt = self._tokenize_single_prompt( + dict(zip(feature_names, row)) + ) + for key, val in tokenized_prompt.items(): + for i in range(0, len(val), self.sequence_len): + res[key].append(val[i : i + self.sequence_len]) + + # If there are no examples left, return an empty dictionary + if not res: + return {} + + return dict(res) + + def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]: # Old simple legacy behavior that works reliably. if ( not self.roles_to_train and not self.train_on_eos - and not self.prompter.message_field_training - and not self.prompter.message_field_training_detail + and not self.prompter.message_field_training # type: ignore + and not self.prompter.message_field_training_detail # type: ignore ): turns = self.get_conversation_thread(prompt) images = self.get_images(prompt) - prompt_ids = self.prompter.build_prompt( + prompt_ids = self.prompter.build_prompt( # type: ignore turns[:-1], add_generation_prompt=True, images=images, ) - tokenized_res = self.prompter.build_prompt(turns, images=images) + tokenized_res = self.prompter.build_prompt(turns, images=images) # type: ignore tokenized_prompt = {} if isinstance(tokenized_res, list): input_ids = prompt_ids + tokenized_res[len(prompt_ids) :] @@ -256,7 +296,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): return tokenized_prompt turns = self.get_conversation_thread(prompt) - input_ids = self.prompter.build_prompt(turns) + input_ids = self.prompter.build_prompt(turns) # type: ignore labels = [IGNORE_TOKEN_ID] * len(input_ids) last_eos_idx = -1 @@ -286,7 +326,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): if should_train and turn_start_idx != -1 and turn_end_idx != -1: if train_detail: - token_offsets = self.prompter.get_offsets_for_train_detail( + token_offsets = self.prompter.get_offsets_for_train_detail( # type: ignore content, train_detail ) LOG.debug(f"Token offsets: {token_offsets}") @@ -459,43 +499,62 @@ class ChatTemplateStrategy(PromptTokenizingStrategy): return prompt.get(self.images, None) -def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None): - # pylint: disable=duplicate-code - ds_cfg = ds_cfg or {} - chat_template_string = get_chat_template_from_config( - cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer - ) - LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---") +class StrategyLoader: + """ + Load chat template strategy based on configuration. + """ - prompter_params = { - "tokenizer": tokenizer, - "chat_template": chat_template_string, - "message_field_role": ds_cfg.get("message_field_role", "role"), - "message_field_content": ds_cfg.get("message_field_content", "content"), - "message_field_training": ds_cfg.get("message_field_training", None), - "message_field_training_detail": ds_cfg.get( - "message_field_training_detail", - None, - ), - "roles": ds_cfg.get("roles"), - "drop_system_message": ds_cfg.get("drop_system_message", False), - # we need to add one for detecting sequences with exceeding the `sequence_len` limit. 
- "max_length": cfg.sequence_len + 1, - "processor": processor, - } + def _get_strategy_cls(self): + return ChatTemplateStrategy - strategy_params = { - "train_on_inputs": cfg.train_on_inputs, - "sequence_len": cfg.sequence_len, - "roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]), - "train_on_eos": ds_cfg.get("train_on_eos", "turn"), - } + def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]): + return { + "train_on_inputs": cfg.train_on_inputs, + "sequence_len": cfg.sequence_len, + "roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]), + "train_on_eos": ds_cfg.get("train_on_eos", "turn"), + } - strategy = ChatTemplateStrategy( - ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params - ) + def __call__( + self, tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None + ): + # pylint: disable=duplicate-code + ds_cfg = ds_cfg or {} + chat_template_string = get_chat_template_from_config( + cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer + ) + LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---") - if "field_messages" in ds_cfg and hasattr(strategy, "messages"): - strategy.messages = ds_cfg["field_messages"] + prompter_params = { + "tokenizer": tokenizer, + "chat_template": chat_template_string, + "message_field_role": ds_cfg.get("message_field_role", "role"), + "message_field_content": ds_cfg.get("message_field_content", "content"), + "message_field_training": ds_cfg.get("message_field_training", None), + "message_field_training_detail": ds_cfg.get( + "message_field_training_detail", + None, + ), + "roles": ds_cfg.get("roles"), + "drop_system_message": ds_cfg.get("drop_system_message", False), + # we need to add one for detecting sequences with exceeding the `sequence_len` limit. 
+ "max_length": cfg.sequence_len + 1, + "processor": processor, + } - return strategy + strategy_params = self._get_strategy_params(cfg, ds_cfg) + strategy_cls = self._get_strategy_cls() + + strategy = strategy_cls( + ChatTemplatePrompter(**prompter_params), + tokenizer=tokenizer, + **strategy_params, + ) + + if "field_messages" in ds_cfg and hasattr(strategy, "messages"): + strategy.messages = ds_cfg["field_messages"] + + return strategy + + +load = StrategyLoader() diff --git a/src/axolotl/prompt_strategies/dpo/chatml.py b/src/axolotl/prompt_strategies/dpo/chatml.py index 585696e29..5043a501e 100644 --- a/src/axolotl/prompt_strategies/dpo/chatml.py +++ b/src/axolotl/prompt_strategies/dpo/chatml.py @@ -3,22 +3,41 @@ DPO strategies for chatml """ -def argilla( +def default( cfg, **kwargs, ): # pylint: disable=possibly-unused-variable,unused-argument def transform_fn(sample): + if "prompt" in sample.keys(): + prompt_key = "prompt" + elif "input" in sample.keys(): + prompt_key = "input" + elif "question" in sample.keys(): + prompt_key = "question" + else: + prompt_key = "instruction" + + if "chosen" in sample.keys(): + chosen_key = "chosen" + else: + chosen_key = "chosen_response" + + if "rejected" in sample.keys(): + rejected_key = "rejected" + else: + rejected_key = "rejected_response" + if "system" in sample and sample["system"]: sample["prompt"] = ( f"<|im_start|>system\n{sample['system']}<|im_end|>\n" - f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" + f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n" ) else: sample[ "prompt" - ] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n" - sample["chosen"] = f"{sample['chosen_response']}<|im_end|>" - sample["rejected"] = f"{sample['rejected_response']}<|im_end|>" + ] = f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n" + sample["chosen"] = f"{sample[chosen_key]}<|im_end|>" + sample["rejected"] = f"{sample[rejected_key]}<|im_end|>" return sample return transform_fn diff --git a/src/axolotl/prompt_strategies/dpo/llama3.py b/src/axolotl/prompt_strategies/dpo/llama3.py index cb394cc22..d10aa223b 100644 --- a/src/axolotl/prompt_strategies/dpo/llama3.py +++ b/src/axolotl/prompt_strategies/dpo/llama3.py @@ -3,22 +3,42 @@ DPO strategies for llama-3 chat template """ -def argilla( +def default( cfg, **kwargs, ): # pylint: disable=possibly-unused-variable,unused-argument def transform_fn(sample): + # pylint: disable=duplicate-code + if "prompt" in sample.keys(): + prompt_key = "prompt" + elif "input" in sample.keys(): + prompt_key = "input" + elif "question" in sample.keys(): + prompt_key = "question" + else: + prompt_key = "instruction" + + if "chosen" in sample.keys(): + chosen_key = "chosen" + else: + chosen_key = "chosen_response" + + if "rejected" in sample.keys(): + rejected_key = "rejected" + else: + rejected_key = "rejected_response" + if "system" in sample and sample["system"]: sample["prompt"] = ( f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>" - f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" ) else: sample[ "prompt" - ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" - sample["chosen"] = 
f"{sample['chosen_response']}<|eot_id|>" - sample["rejected"] = f"{sample['rejected_response']}<|eot_id|>" + ] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>" + sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>" return sample return transform_fn diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index bd6e3f9dc..c29fd05a4 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -2,7 +2,7 @@ import abc import logging -from typing import Dict, List, Tuple, Union +from typing import Callable, Dict, List, Optional, Tuple, Union from transformers import BatchEncoding, PreTrainedTokenizer @@ -34,6 +34,8 @@ class PromptTokenizingStrategy(abc.ABC): Abstract class for tokenizing strategies """ + filter_rows: Optional[Callable] = None + def __init__( self, prompter: Prompter, diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index f05b259b7..028b7ea18 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -163,6 +163,7 @@ class SFTDataset(BaseModel): type: Optional[Union[str, UserDefinedPrompterType]] = None input_transform: Optional[str] = None shards: Optional[int] = None + preprocess_shards: Optional[int] = None conversation: Optional[str] = None # Do not make this too strict or it will break the validator to choose different dataset class chat_template: Optional[ @@ -185,6 +186,8 @@ class SFTDataset(BaseModel): message_field_content: Optional[str] = None message_field_training: Optional[str] = None message_field_training_detail: Optional[str] = None + logprobs_field: Optional[str] = None + temperature: Optional[float] = None roles_to_train: Optional[List[str]] = None train_on_eos: Optional[str] = None roles: Optional[Dict[str, List[str]]] = None @@ -861,6 +864,7 @@ class AxolotlInputConfig( # INTERNALS - document for now, generally not set externally is_preprocess: Optional[bool] = None + preprocess_iterable: Optional[bool] = None total_num_tokens: Optional[int] = None total_supervised_tokens: Optional[int] = None diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py index ba5d0c54d..79bbb2972 100644 --- a/src/axolotl/utils/data/sft.py +++ b/src/axolotl/utils/data/sft.py @@ -3,11 +3,12 @@ import functools import logging from pathlib import Path -from typing import List, Tuple, Union +from typing import List, Optional, Tuple, Union from datasets import ( Dataset, DatasetDict, + IterableDataset, Sequence, Value, concatenate_datasets, @@ -17,7 +18,7 @@ from datasets import ( from transformers import PreTrainedTokenizerBase from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH -from axolotl.datasets import TokenizedPromptDataset +from axolotl.datasets import TokenizedPromptDataset, wrap_dataset_for_tokenized_prompt from axolotl.prompt_strategies import load from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load from axolotl.prompt_tokenizers import ( @@ -59,7 +60,7 @@ LOG = logging.getLogger("axolotl") @retry_on_request_exceptions(max_retries=3, delay=5) -def prepare_dataset(cfg, tokenizer, processor=None): +def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None): prompters = [] if not cfg.pretraining_dataset: with zero_first(is_local_main_process()): @@ -70,6 +71,7 @@ 
def prepare_dataset(cfg, tokenizer, processor=None): DEFAULT_DATASET_PREPARED_PATH, split="train", processor=processor, + preprocess_iterable=preprocess_iterable, ) _, eval_dataset, _ = load_prepare_datasets( tokenizer, @@ -77,6 +79,7 @@ def prepare_dataset(cfg, tokenizer, processor=None): DEFAULT_DATASET_PREPARED_PATH, split="test", processor=processor, + preprocess_iterable=preprocess_iterable, ) else: train_dataset, eval_dataset, prompters = load_prepare_datasets( @@ -84,6 +87,7 @@ def prepare_dataset(cfg, tokenizer, processor=None): cfg, DEFAULT_DATASET_PREPARED_PATH, processor=processor, + preprocess_iterable=preprocess_iterable, ) else: # Load streaming dataset if pretraining_dataset is given @@ -139,6 +143,7 @@ def prepare_dataset(cfg, tokenizer, processor=None): DEFAULT_DATASET_PREPARED_PATH, split="test", processor=processor, + preprocess_iterable=preprocess_iterable, ) if cfg.dataset_exact_deduplication: @@ -170,6 +175,7 @@ def load_tokenized_prepared_datasets( default_dataset_prepared_path, split="train", processor=None, + preprocess_iterable: Optional[bool] = None, ) -> Tuple[DatasetDict, List[Prompter]]: cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets tokenizer_name = cfg.tokenizer_config @@ -184,10 +190,11 @@ def load_tokenized_prepared_datasets( + "@" + str(cfg.group_by_length) + "@" + + str(cfg.kd_temperature or 1.0) + "|".join( sorted( [ - f"{d.path}:{d.type}:{d.shards}:{d.conversation}{d.split}" + f"{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}" for d in cfg_datasets ] ) @@ -262,13 +269,25 @@ def load_tokenized_prepared_datasets( # at the same time for a given dataset for name in dataset.name: yield DictDefault({**dataset, "name": name}) + elif dataset.preprocess_shards and not dataset.shards: + for shard in range(dataset.preprocess_shards): + yield DictDefault( + { + **dataset, + "shards": dataset.preprocess_shards, + "shards_idx": shard, + } + ) else: yield dataset + streaming_ds = False + if preprocess_iterable: + streaming_ds = True # pylint: disable=invalid-name for config_dataset in for_d_in_datasets(cfg_datasets): ds: Union[Dataset, DatasetDict] = load_dataset_w_config( - config_dataset, use_auth_token + config_dataset, use_auth_token, streaming=streaming_ds ) d_base_type = d_prompt_style = None @@ -325,7 +344,21 @@ def load_tokenized_prepared_datasets( if cfg.local_rank == 0 and not cfg.skip_prepare_dataset: LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") - dataset.save_to_disk(str(prepared_ds_path)) + if isinstance(dataset, IterableDataset): + + def gen_from_iter_ds(_ds, _=None): + yield from _ds + + ds_from_iter = Dataset.from_generator( + functools.partial(gen_from_iter_ds, dataset), + features=dataset.features, + num_proc=cfg.dataset_processes, + split=split, + gen_kwargs={"_": list(range(cfg.dataset_processes))}, + ) + ds_from_iter.save_to_disk(str(prepared_ds_path)) + else: + dataset.save_to_disk(str(prepared_ds_path)) if cfg.push_dataset_to_hub: LOG.info( f"Pushing merged prepared dataset to Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..." 
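The `IterableDataset` branch above re-materializes a streaming dataset so it can be saved with `save_to_disk`. As a minimal sketch of the same pattern in isolation (the `materialize` helper is illustrative, not part of this patch; the patch version additionally passes `num_proc` together with a sharded `gen_kwargs` list so `datasets` can run several generator workers):

```python
import functools

from datasets import Dataset, IterableDataset


def materialize(iterable_ds: IterableDataset) -> Dataset:
    """Replay a streaming IterableDataset once and write it out as a map-style Dataset."""

    def gen(_ds):
        yield from _ds  # rows are consumed from the stream and cached as Arrow

    return Dataset.from_generator(
        functools.partial(gen, iterable_ds),
        features=iterable_ds.features,
    )
```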
@@ -345,6 +378,7 @@ def load_prepare_datasets( default_dataset_prepared_path, split="train", processor=None, + preprocess_iterable: Optional[bool] = False, ) -> Tuple[Dataset, Dataset, List[Prompter]]: dataset, prompters = load_tokenized_prepared_datasets( tokenizer, @@ -352,6 +386,7 @@ def load_prepare_datasets( default_dataset_prepared_path, split=split, processor=processor, + preprocess_iterable=preprocess_iterable, ) if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None: @@ -451,7 +486,7 @@ def get_dataset_wrapper( "user_defined", tokenizer, cfg, config_dataset.type.to_dict() ) dataset_prompter = UnsupportedPrompter() - dataset_wrapper = TokenizedPromptDataset( + dataset_wrapper = wrap_dataset_for_tokenized_prompt( ds_strategy, dataset, **ds_kwargs, @@ -464,7 +499,7 @@ def get_dataset_wrapper( config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset ): dataset_prompter = UnsupportedPrompter() - dataset_wrapper = TokenizedPromptDataset( + dataset_wrapper = wrap_dataset_for_tokenized_prompt( ds_strategy, dataset, **ds_kwargs, @@ -487,7 +522,7 @@ def get_dataset_wrapper( dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs) else: dataset_prompter = UnsupportedPrompter() - dataset_wrapper = TokenizedPromptDataset( + dataset_wrapper = wrap_dataset_for_tokenized_prompt( ds_strategy, dataset, **ds_kwargs, @@ -500,7 +535,7 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = TokenizedPromptDataset( + ds_wrapper = wrap_dataset_for_tokenized_prompt( ds_strategy, dataset, **ds_kwargs, @@ -514,7 +549,7 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = TokenizedPromptDataset( + ds_wrapper = wrap_dataset_for_tokenized_prompt( ds_strategy, dataset, **ds_kwargs, @@ -528,7 +563,7 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = TokenizedPromptDataset( + ds_wrapper = wrap_dataset_for_tokenized_prompt( ds_strategy, dataset, **ds_kwargs, @@ -542,7 +577,7 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = TokenizedPromptDataset( + ds_wrapper = wrap_dataset_for_tokenized_prompt( ds_strategy, dataset, **ds_kwargs, @@ -556,7 +591,7 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = TokenizedPromptDataset( + ds_wrapper = wrap_dataset_for_tokenized_prompt( ds_strategy, dataset, **ds_kwargs, @@ -570,7 +605,7 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = TokenizedPromptDataset( + ds_wrapper = wrap_dataset_for_tokenized_prompt( ds_strategy, dataset, **ds_kwargs, @@ -584,7 +619,7 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = TokenizedPromptDataset( + ds_wrapper = wrap_dataset_for_tokenized_prompt( ds_strategy, dataset, **ds_kwargs, @@ -598,7 +633,7 @@ def get_dataset_wrapper( cfg.train_on_inputs, cfg.sequence_len, ) - ds_wrapper = TokenizedPromptDataset( + ds_wrapper = wrap_dataset_for_tokenized_prompt( ds_strategy, dataset, **ds_kwargs, diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py index e4f31a184..013d7a895 100644 --- a/src/axolotl/utils/data/shared.py +++ b/src/axolotl/utils/data/shared.py @@ -29,7 +29,9 @@ def get_ds_type(config_dataset: DictDefault): return ds_type -def load_dataset_w_config(config_dataset, auth_token): +def load_dataset_w_config( + config_dataset, auth_token, streaming=False +) -> Union[Dataset, DatasetDict]: # pylint: disable=invalid-name ds: Optional[Union[Dataset, DatasetDict]] = None # 
pylint: disable=invalid-name ds_from_hub = False @@ -124,7 +126,7 @@ def load_dataset_w_config(config_dataset, auth_token): ds_type, name=config_dataset.name, data_files=config_dataset.data_files, - streaming=False, + streaming=streaming, **load_ds_kwargs, ) else: @@ -157,7 +159,7 @@ def load_dataset_w_config(config_dataset, auth_token): ds = load_dataset( config_dataset.path, name=config_dataset.name, - streaming=False, + streaming=streaming, data_files=config_dataset.data_files, token=auth_token, revision=config_dataset.revision, @@ -176,7 +178,7 @@ def load_dataset_w_config(config_dataset, auth_token): ds_type, name=config_dataset.name, data_files=config_dataset.path, - streaming=False, + streaming=streaming, storage_options=storage_options, trust_remote_code=config_dataset.trust_remote_code, **load_ds_kwargs, @@ -187,7 +189,7 @@ def load_dataset_w_config(config_dataset, auth_token): ds_type, name=config_dataset.name, data_files=config_dataset.path, - streaming=False, + streaming=streaming, storage_options=storage_options, trust_remote_code=config_dataset.trust_remote_code, **load_ds_kwargs, @@ -217,7 +219,7 @@ def load_dataset_w_config(config_dataset, auth_token): "json", name=config_dataset.name, data_files=fp, - streaming=False, + streaming=streaming, **load_ds_kwargs, ) if not ds: diff --git a/src/axolotl/utils/tokenization.py b/src/axolotl/utils/tokenization.py index 139d50110..e0b21a9f0 100644 --- a/src/axolotl/utils/tokenization.py +++ b/src/axolotl/utils/tokenization.py @@ -26,6 +26,7 @@ def check_example_labels(example, tokenizer, text_only=False): # Get the input_ids, labels, and attention_mask from the dataset input_ids = example["input_ids"] labels = example["labels"] + target_mask = example.pop("target_mask", None) # You can compare the input_ids and labels element-wise # Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0 @@ -42,6 +43,13 @@ def check_example_labels(example, tokenizer, text_only=False): delimiter = "" if text_only else " " LOG.info(delimiter.join(colored_tokens)) LOG.info("\n\n\n") + target_labels_count = sum(label_id != -100 for label_id in labels) + total_len = len(input_ids) + LOG.info(f"Total input len: {total_len}") + LOG.info(f"Count of labels: {target_labels_count}") + if target_mask: + target_mask_positions = sum(m[0] for m in target_mask) + LOG.info(f"Number of positions in target_mask: {target_mask_positions}") return " ".join(colored_tokens) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index caa74fccc..32e9bdfb4 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -11,7 +11,7 @@ import numpy as np import torch import torch.cuda from accelerate.logging import get_logger -from datasets import disable_caching, enable_caching +from datasets import IterableDataset, disable_caching, enable_caching from torch.utils.data import DataLoader, RandomSampler from transformers.utils import is_torch_bf16_gpu_available @@ -95,9 +95,41 @@ def disable_datasets_caching(): def add_position_ids(sample): - sample_len = len(sample["input_ids"]) - sample["position_ids"] = torch.arange(len(sample["input_ids"])) - sample["length"] = sample_len + """ + Handle both single-example and batched data. 
+      - single example: sample['input_ids'] is a list[int]
+      - batched data: sample['input_ids'] is a list[list[int]]
+    """
+    # Return sample unchanged if "input_ids" is not present, or is empty
+    if "input_ids" not in sample or not sample["input_ids"]:
+        return sample
+
+    input_ids = sample["input_ids"]
+
+    # If first element is an int, it’s a single example
+    # If first element is a list, it’s a batch
+    if isinstance(input_ids[0], int):
+        # ---- SINGLE EXAMPLE ----
+        seq_len = len(input_ids)
+        # Position IDs for a single example, stored as a plain list
+        sample["position_ids"] = list(range(seq_len))
+        sample["length"] = seq_len
+
+    else:
+        # ---- BATCHED EXAMPLES ----
+        # input_ids is a list of lists
+        position_ids_batch = []
+        lengths_batch = []
+        for seq in input_ids:
+            seq_len = len(seq)
+            position_ids_batch.append(list(range(seq_len)))
+            lengths_batch.append(seq_len)
+
+        # Now store them back
+        sample["position_ids"] = position_ids_batch
+        sample["length"] = lengths_batch
+
     return sample
 
 
@@ -172,10 +204,31 @@ def add_length(sample):
 
 
 def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
-    return (
-        len(sample["input_ids"]) <= sequence_len
-        and len(sample["input_ids"]) >= min_sequence_len
-    )
+    """
+    Drop samples whose sequence length is either too long (> sequence_len)
+    or too short (< min_sequence_len).
+
+    Works for both single-example (list[int]) or batched (list[list[int]]).
+    """
+    input_ids = sample["input_ids"]
+
+    # Edge case: empty input_ids carries no usable tokens, so drop the sample
+    if not input_ids:
+        return False
+
+    # Check if single example or batched by looking at the first element
+    if isinstance(input_ids[0], int):
+        # Single example (input_ids is a list of int)
+        length = len(input_ids)
+        return min_sequence_len <= length <= sequence_len
+
+    # Batched (input_ids is a list of lists)
+    results = []
+    for seq in input_ids:
+        length = len(seq)
+        results.append(min_sequence_len <= length <= sequence_len)
+    return results
 
 
 def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
@@ -185,10 +238,13 @@
         min_sequence_len=cfg.min_sample_len or 2,
     )
 
-    min_input_len = np.min(get_dataset_lengths(train_dataset))
-    LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
-    max_input_len = np.max(get_dataset_lengths(train_dataset))
-    LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
+    try:
+        min_input_len = np.min(get_dataset_lengths(train_dataset))
+        LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
+        max_input_len = np.max(get_dataset_lengths(train_dataset))
+        LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
+    except AttributeError:
+        # iterable datasets don't expose lengths up front; skip the stats
+        pass
 
     if cfg.model_config_type == "mamba":
         LOG.info("dropping attention_mask column")
@@ -203,60 +259,106 @@
     if eval_dataset and "token_type_ids" in eval_dataset.column_names:
         eval_dataset = eval_dataset.remove_columns("token_type_ids")
 
-    prior_len = len(train_dataset)
+    filter_map_kwargs = {}
+    if not isinstance(train_dataset, IterableDataset):
+        filter_map_kwargs["num_proc"] = cfg.dataset_processes
+        filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
+
+    try:
+        prior_len = len(train_dataset)
+    except TypeError:
+        # handle iterable datasets case
+        prior_len = None
+    drop_long_kwargs = {}
+    if filter_map_kwargs:
+        drop_long_kwargs["desc"] = "Dropping Long Sequences"
     train_dataset = train_dataset.filter(
         drop_long,
-        num_proc=cfg.dataset_processes,
-        load_from_cache_file=not cfg.is_preprocess,
-        desc="Dropping Long Sequences",
+        batched=True,
+        **filter_map_kwargs,
+        **drop_long_kwargs,
     )
-    dropped = prior_len - len(train_dataset)
-    if dropped:
-        LOG.warning(f"Dropped {dropped} long samples from train dataset")
+    if prior_len:
+        dropped = prior_len - len(train_dataset)
+        if dropped:
+            LOG.warning(f"Dropped {dropped} long samples from train dataset")
 
     if eval_dataset:
-        prior_len = len(eval_dataset)
+        try:
+            prior_len = len(eval_dataset)
+        except TypeError:
+            # handle iterable datasets case
+            prior_len = None
         eval_dataset = eval_dataset.filter(
             drop_long,
-            num_proc=cfg.dataset_processes,
-            load_from_cache_file=not cfg.is_preprocess,
-            desc="Dropping Long Sequences",
+            **filter_map_kwargs,
+            **drop_long_kwargs,
         )
-        dropped = prior_len - len(eval_dataset)
-        if dropped:
-            LOG.warning(f"Dropped {dropped} long samples from eval dataset")
+        if prior_len:
+            dropped = prior_len - len(eval_dataset)
+            if dropped:
+                LOG.warning(f"Dropped {dropped} long samples from eval dataset")
 
-    # drop samples with where the number of elements with labels not equal to -100 is zero
     def drop_no_trainable_tokens(sample):
-        return np.sum(np.array(sample["labels"]) != -100) > 0
+        """
+        Drop samples if all labels are -100 (i.e., zero trainable tokens).
+        Works for both single-example or batched input.
+        """
+        labels = sample["labels"]
+        if not labels:
+            return True
 
-    prior_len = len(train_dataset)
+        # Check if single example or batch
+        # If first element is an int, we assume a single example
+        # If it's a list, we assume we're dealing with a batch
+        if isinstance(labels[0], int):
+            # Single example: labels is a plain list, so convert to an array
+            # for the elementwise comparison, then return a single bool
+            return np.any(np.asarray(labels) != -100)
+
+        # Batched: 'labels' is a list of lists
+        # Return a list of booleans, one per sub-list
+        results = [np.any(np.asarray(row_labels) != -100) for row_labels in labels]
+        return results
+
+    try:
+        prior_len = len(train_dataset)
+    except TypeError:
+        # handle iterable datasets case
+        prior_len = None
+    drop_long_kwargs = {}
+    if filter_map_kwargs:
+        drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
     train_dataset = train_dataset.filter(
         drop_no_trainable_tokens,
-        num_proc=cfg.dataset_processes,
-        load_from_cache_file=not cfg.is_preprocess,
-        desc="Drop Samples with Zero Trainable Tokens",
+        batched=True,
+        **filter_map_kwargs,
+        **drop_long_kwargs,
    )
-    dropped = prior_len - len(train_dataset)
-    if dropped:
-        LOG.warning(
-            f"Dropped {dropped} samples with no trainable tokens from train dataset"
-        )
-
-    if eval_dataset:
-        prior_len = len(eval_dataset)
-        eval_dataset = eval_dataset.filter(
-            drop_no_trainable_tokens,
-            num_proc=cfg.dataset_processes,
-            load_from_cache_file=not cfg.is_preprocess,
-            desc="Drop Samples with Zero Trainable Tokens",
-        )
-        dropped = prior_len - len(eval_dataset)
+    if prior_len:
+        dropped = prior_len - len(train_dataset)
         if dropped:
             LOG.warning(
-                f"Dropped {dropped} samples with no trainable tokens from eval dataset"
+                f"Dropped {dropped} samples with no trainable tokens from train dataset"
             )
 
+    if eval_dataset:
+        try:
+            prior_len = len(eval_dataset)
+        except TypeError:
+            # handle iterable datasets case
+            prior_len = None
+        eval_dataset = eval_dataset.filter(
+            drop_no_trainable_tokens,
+            **filter_map_kwargs,
+            **drop_long_kwargs,
+        )
+        if prior_len:
+            dropped = prior_len - len(eval_dataset)
+            if dropped:
+                LOG.warning(
+                    f"Dropped {dropped} samples with no trainable tokens from eval dataset"
+                )
+
     if cfg.group_by_length:
train_dataset = train_dataset.map( add_length, @@ -291,19 +393,21 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): desc="Add position_id column (PoSE)", ) elif cfg.sample_packing: + drop_long_kwargs = {} + if filter_map_kwargs: + drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)" train_dataset = train_dataset.map( add_position_ids, - num_proc=cfg.dataset_processes, - load_from_cache_file=not cfg.is_preprocess, - desc="Add position_id column (Sample Packing)", + batched=True, + **filter_map_kwargs, + **drop_long_kwargs, ) if cfg.eval_sample_packing is not False: if eval_dataset: eval_dataset = eval_dataset.map( add_position_ids, - num_proc=cfg.dataset_processes, - load_from_cache_file=not cfg.is_preprocess, - desc="Add position_id column (Sample Packing)", + **filter_map_kwargs, + **drop_long_kwargs, ) return train_dataset, eval_dataset @@ -337,7 +441,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): and not cfg.reward_model ): total_num_tokens = np.sum( - train_dataset.data.column("input_ids") + train_dataset.select_columns("input_ids") .to_pandas() .apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda .values diff --git a/tests/e2e/integrations/test_kd.py b/tests/e2e/integrations/test_kd.py new file mode 100644 index 000000000..e885d3b48 --- /dev/null +++ b/tests/e2e/integrations/test_kd.py @@ -0,0 +1,121 @@ +""" +e2e tests for kd trainer support in Axolotl +""" +from pathlib import Path + +import pytest +from e2e.utils import check_tensorboard, require_torch_2_5_1 + +from axolotl.cli.args import TrainerCliArgs +from axolotl.common.datasets import load_datasets +from axolotl.train import train +from axolotl.utils.config import normalize_config, prepare_plugins +from axolotl.utils.dict import DictDefault + + +@pytest.fixture(name="kd_min_cfg") +def min_cfg(temp_dir): + return { + "base_model": "osllmai-community/Llama-3.2-1B", + "tokenizer_config": "axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer", + "plugins": [ + "axolotl.integrations.kd.KDPlugin", + "axolotl.integrations.liger.LigerPlugin", + ], + "liger_rms_norm": True, + "liger_glu_activation": True, + "torch_compile": True, + "chat_template": "llama3", + "kd_trainer": True, + "kd_ce_alpha": 0.1, + "kd_alpha": 0.9, + "kd_temperature": 1.0, + "dataloader_prefetch_factor": 8, + "dataloader_num_workers": 4, + "dataloader_pin_memory": True, + "datasets": [ + { + "path": "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample", + "type": "axolotl.integrations.kd.chat_template", + "field_messages": "messages_combined", + "split": "train", + "logprobs_field": "llm_text_generation_vllm_logprobs", + "temperature": 1.0, + "preprocess_shards": 2, + }, + ], + "val_set_size": 0.0, + "sequence_len": 2048, + "sample_packing": True, + "pad_to_sequence_len": True, + "gradient_accumulation_steps": 2, + "micro_batch_size": 1, + "num_epochs": 1, + "optimizer": "adamw_8bit", + "lr_scheduler": "cosine", + "learning_rate": 0.00001, + "bf16": "auto", + "gradient_checkpointing": True, + "flash_attention": True, + "special_tokens": { + "pad_token": "<|end_of_text|>", + "eos_token": "<|eot_id|>", + }, + "max_steps": 5, + "output_dir": temp_dir, + "save_safetensors": True, + "use_tensorboard": True, + } + + +class TestKnowledgeDistillation: + """ + Test case for Knowledge Distillation + """ + + # While this will run on torch 2.4.x without torch_compile enabled + # the VRAM requirement is higher than what is available in CI + @require_torch_2_5_1 + def test_llama_kd(self, temp_dir, kd_min_cfg): + 
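+        # End-to-end: prepare plugins and config, load the KD dataset, train for
+        # a few steps, then check the saved model and the tensorboard train loss.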
cfg = DictDefault(kd_min_cfg) + # pylint: disable=duplicate-code + prepare_plugins(cfg) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "model.safetensors").exists() + check_tensorboard( + temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high" + ) + + @pytest.mark.parametrize( + "load_in_8bit", + [True, False], + ) + def test_llama_lora_kd(self, temp_dir, kd_min_cfg, load_in_8bit): + cfg = DictDefault( + { + "load_in_8bit": load_in_8bit, + "torch_compile": False, + "adapter": "lora", + "peft_use_dora": True, + "lora_target_linear": True, + "lora_r": 16, + "lora_alpha": 32, + "lora_dropout": 0.0, + } + | kd_min_cfg + ) + # pylint: disable=duplicate-code + prepare_plugins(cfg) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, dataset_meta=dataset_meta) + assert (Path(temp_dir) / "adapter_model.safetensors").exists() + check_tensorboard( + temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high" + ) diff --git a/tests/e2e/integrations/test_liger.py b/tests/e2e/integrations/test_liger.py index cf673cab2..226ed46f8 100644 --- a/tests/e2e/integrations/test_liger.py +++ b/tests/e2e/integrations/test_liger.py @@ -55,6 +55,7 @@ class LigerIntegrationTestCase: "max_steps": 5, } ) + # pylint: disable=duplicate-code prepare_plugins(cfg) normalize_config(cfg) cli_args = TrainerCliArgs() @@ -100,6 +101,7 @@ class LigerIntegrationTestCase: "max_steps": 5, } ) + # pylint: disable=duplicate-code prepare_plugins(cfg) normalize_config(cfg) cli_args = TrainerCliArgs()
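For reference, a small usage sketch of the top-k forward-KL loss introduced in `forward_kl.py`, on toy tensors (shapes and values here are illustrative only; during training the KD trainer pops these fields from the collated batch and shifts them for causality):

```python
import torch

from axolotl.integrations.kd.topk_logprob.forward_kl import loss as topk_kd_loss

B, seq_len, vocab, K = 2, 8, 128, 4
student_logits = torch.randn(B, seq_len, vocab)
# Teacher-provided top-k token ids and (already re-normalized) logprobs.
target_token_ids = torch.randint(0, vocab, (B, seq_len, K))
target_logprobs = torch.log_softmax(torch.randn(B, seq_len, K), dim=-1)
target_mask = torch.ones(B, seq_len, K)  # 1 = position counts toward the loss

kd = topk_kd_loss(
    student_logits,
    target_token_ids,
    target_logprobs,
    target_mask,
    num_items_in_batch=-1,  # -1 -> average over the valid tokens instead
    kd_temperature=1.0,
)
print(kd.item())
```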