Compare commits

1 commit

| Author | SHA1 | Date |
|---|---|---|
|  | 92c217677c |  |
@@ -55,7 +55,6 @@ Features:

### Installation

```bash
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]

# Download example axolotl configs, deepspeed configs
@@ -32,9 +32,8 @@ website:
contents:
  - docs/getting-started.qmd
  - docs/installation.qmd
  - docs/inference.qmd
  - docs/cli.qmd
  - docs/config.qmd
  - docs/inference.qmd

  - section: "Dataset Formats"
    contents: docs/dataset-formats/*

@@ -75,6 +74,10 @@ website:
  - docs/debugging.qmd
  - docs/nccl.qmd

  - section: "Reference"
    contents:
      - docs/config.qmd

format:
  html:
    theme: darkly
@@ -1,5 +1,5 @@
---
title: Config Reference
title: Config options
description: A complete list of all configuration options.
---

@@ -30,8 +30,6 @@ tokenizer_legacy:
# Resize the model embeddings when new tokens are added to multiples of 32
# This is reported to improve training speed on some models
resize_token_embeddings_to_32x:
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
shrink_embeddings:

# (Internal use only)
# Used to identify which model this one is based on
@@ -207,46 +205,10 @@ test_datasets:
  data_files:
    - /workspace/data/eval.jsonl

# use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'
# use RL training: 'dpo', 'ipo', 'kto'
rl:
rl_beta: # Optional[float]. The beta parameter for the RL training.

# dpo
dpo_use_weighting: # Optional[bool]. Whether to perform weighting.
rpo_alpha: # Optional[float]. Weighting of NLL term in loss from RPO paper.

# orpo
orpo_alpha: 0.1 # Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping.

# kto
kto_desirable_weight: # Optional[float]. Factor for desirable loss term in KTO loss.
kto_undesirable_weight: # Optional[float]. Factor for undesirable loss term in KTO loss.

# simpo
cpo_alpha: 1.0 # Weight of the BC regularizer
simpo_gamma: 0.5 # Target reward margin for the SimPO loss

# grpo
trl:
  use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
  vllm_device: # Optional[str]. Device to use for VLLM.
  vllm_gpu_memory_utilization: # Optional[float]. GPU memory utilization for VLLM.
  vllm_max_model_len: # Optional[int]. Maximum length of the model for VLLM.
  vllm_dtype: # Optional[str]. Data type for VLLM.

  beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use only one of the two.
  max_completion_length: # Optional[int]. Maximum length of the completion for RL training.

  reward_funcs: # Optional[list[str]]. List of reward functions to load. Paths must be importable from current dir.
  reward_weights: # Optional[list[float]]. List of reward weights for the reward functions.

  num_generations: # Optional[int]. Number of generations to sample.
  log_completions: # Optional[bool]. Whether to log completions.

  sync_ref_model: # Optional[bool]. Whether to sync the reference model.
  ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model.
  ref_model_sync_steps: # Optional[int]. Sync steps for the reference model.

# whether to perform weighting if doing DPO training. Boolean.
dpo_use_weighting:

# reward modelling: `True` or `False`
reward_model:
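As an illustration (not part of the reference listing itself; the values are placeholders), a minimal DPO setup using the options above might look like:

```yaml
# Hypothetical minimal RL section; tune rl_beta for your own run.
rl: dpo
rl_beta: 0.1            # strength of the constraint toward the reference model
dpo_use_weighting: false
```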
@@ -270,7 +232,7 @@ default_system_message: You are a helpful assistant. Please give a long and deta

# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
# Push prepared dataset to hub
push_dataset_to_hub: # Optional[str] repo_org/repo_name
push_dataset_to_hub: # repo path
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# if not set.
dataset_processes: # defaults to os.cpu_count() if not set

docs/faq.qmd
@@ -27,16 +27,6 @@ description: Frequently asked questions

> A: This is usually an issue with the GPU. It can be resolved by setting the environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on RunPod, this is usually a pod issue; starting a new pod should take care of it.
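For example (a minimal sketch; the config path is a placeholder):

```bash
# Pin the run to the first GPU before launching training.
CUDA_VISIBLE_DEVICES=0 axolotl train config.yml
```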
**Q: Received mismatch error on merge adapters / loading adapters between torch.Size of checkpoint and model.**

> A: This is likely due to a vocab size mismatch. By default, Axolotl expands the model's embeddings if the tokenizer has more tokens than the model. Please use the `axolotl merge-lora` command to merge the adapters instead of using your own scripts.

> On the other hand, if the model has more tokens than the tokenizer, Axolotl does not shrink the model's embeddings unless `shrink_embeddings: true` is set in the config.
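For example (a minimal sketch, assuming the same config file used for training):

```bash
# Merge the trained LoRA adapters into the base model.
axolotl merge-lora config.yml
```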
**Q: How to call Axolotl via custom Python scripts?**

> A: You can; since Axolotl is just Python, see `src/axolotl/cli/main.py` for how each command is called.
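As an illustration (a minimal sketch, not an official API; shelling out to the CLI keeps you independent of internal entry points):

```python
# Drive the axolotl CLI from a custom script; "config.yml" is a placeholder.
import subprocess


def train(config_path: str) -> None:
    # Equivalent to running `axolotl train config.yml` in a shell.
    subprocess.run(["axolotl", "train", config_path], check=True)


if __name__ == "__main__":
    train("config.yml")
```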
### Chat templates

**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
@@ -36,9 +36,7 @@ The YAML configuration file controls everything about your training. Here's what

```yaml
base_model: NousResearch/Llama-3.2-1B

load_in_8bit: true
adapter: lora
# hub_model_id: username/custom_model_name

datasets:
  - path: teknium/GPT4-LLM-Cleaned
@@ -46,15 +44,11 @@ datasets:
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out

adapter: lora
lora_model_dir:
```

::: {.callout-tip}
`load_in_8bit: true` and `adapter: lora` enable LoRA adapter finetuning.

- To perform full finetuning, remove these two lines.
- To perform QLoRA finetuning, replace them with `load_in_4bit: true` and `adapter: qlora`, as shown below.
:::
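For example, the QLoRA variant of the snippet above swaps in (all other options unchanged):

```yaml
load_in_4bit: true
adapter: qlora
```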
See our [Config options](config.qmd) for more details.

### Training {#sec-training}

@@ -62,7 +56,7 @@ See our [Config options](config.qmd) for more details.

When you run `axolotl train` (example invocation below), Axolotl:

1. Downloads the base model
2. (If specified) applies QLoRA/LoRA adapter layers
2. (If specified) applies LoRA adapter layers
3. Loads and processes the dataset
4. Runs the training loop
5. Saves the trained model and/or LoRA weights
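For example, assuming your edited config is saved as `config.yml`:

```bash
axolotl train config.yml
```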
@@ -75,8 +69,6 @@ Let's modify the example for your own data:

```yaml
base_model: NousResearch/Nous-Hermes-llama-1b-v1

load_in_8bit: true
adapter: lora

# Training settings
@@ -112,6 +104,8 @@ format):
{"instruction": "Classify this text", "input": "Not good at all", "output": "negative"}
```

Please consult the supported [Dataset Formats](dataset-formats/) for more details.

3. Run the training:

```bash
@@ -1,5 +1,5 @@
---
title: "Inference and Merging"
title: "Inference"
format:
  html:
    toc: true
@@ -9,14 +9,10 @@ execute:
  enabled: false
---

This guide covers how to use your trained models for inference, including model loading, interactive testing, merging adapters, and common troubleshooting steps.
This guide covers how to use your trained models for inference, including model loading, interactive testing, and common troubleshooting steps.

## Quick Start {#sec-quickstart}

::: {.callout-tip}
Use the same config for inference/merging that you used for training.
:::

### Basic Inference {#sec-basic}

::: {.panel-tabset}
@@ -22,7 +22,6 @@ This guide covers all the ways you can install and set up Axolotl for your envir

### PyPI Installation (Recommended) {#sec-pypi}

```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
```
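A quick sanity check after installing (assuming the `axolotl` entry point is on your `PATH`):

```{.bash}
axolotl --help
```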
@@ -38,7 +37,7 @@ For the latest features between releases:

```{.bash}
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install -U packaging setuptools wheel ninja
pip3 install packaging ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
@@ -108,7 +107,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.

2. Install PyTorch: https://pytorch.org/get-started/locally/
3. Install Axolotl:
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
4. (Optional) Log in to Hugging Face:
@@ -66,10 +66,6 @@ logic to be compatible with more of them.

</details>

::: {.callout-tip}
Check out our [LoRA optimizations blog](https://axolotlai.substack.com/p/accelerating-lora-fine-tuning-with).
:::

## Usage

These optimizations can be enabled in your Axolotl config YAML file. The
@@ -41,10 +41,6 @@ Bradley-Terry chat templates expect single-turn conversations in the following f

### Process Reward Models (PRM)

::: {.callout-tip}
Check out our [PRM blog](https://axolotlai.substack.com/p/process-reward-models).
:::

Process reward models are trained on data that contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
```yaml
base_model: Qwen/Qwen2.5-3B
@@ -298,7 +298,7 @@ The input format is a simple JSON input with customizable fields based on the ab

### IPO

As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.
As IPO is just DPO with a different loss function, all supported options for DPO work here.

```yaml
rl: ipo
@@ -344,9 +344,8 @@ ORPO supports the following types with the following dataset format:

```yaml
rl: kto
rl_beta: 0.1 # default
kto_desirable_weight: 1.0 # default
kto_undesirable_weight: 1.0 # default
rl_beta: 0.5
kto_desirable_weight: 0.2

remove_unused_columns: false
@@ -498,10 +497,6 @@ The input format is a simple JSON input with customizable fields based on the ab

### GRPO

::: {.callout-tip}
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
:::

GRPO uses custom reward functions and transformations. Please have them ready locally.

For example, to load OpenAI's GSM8K and use a random reward for completions:
@@ -545,19 +540,6 @@ To see other examples of custom reward functions, please see [TRL GRPO Docs](htt

To see descriptions of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
### SimPO

SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with an alternative loss function.

```yaml
rl: simpo
rl_beta: 0.1 # default in CPOTrainer
cpo_alpha: 1.0 # default in CPOTrainer
simpo_gamma: 0.5 # default in CPOTrainer
```

This method uses the same dataset format as [DPO](#dpo).

### Using local dataset files

```yaml
@@ -55,7 +55,7 @@ tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
  use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
@@ -40,7 +40,6 @@ from trl.trainer.utils import RewardDataCollatorWithPadding

from axolotl.core.trainers.base import (
    AxolotlCPOTrainer,
    AxolotlKTOTrainer,
    AxolotlMambaTrainer,
    AxolotlORPOTrainer,
    AxolotlPRMTrainer,
@@ -51,6 +50,7 @@ from axolotl.core.trainers.base import (
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.trainers.kto import AxolotlKTOTrainer
from axolotl.core.training_args import (
    AxolotlCPOConfig,
    AxolotlKTOConfig,
@@ -20,9 +20,10 @@ from torch.utils.data import BatchSampler, DataLoader, RandomSampler, Sequential

from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import CPOTrainer, KTOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
from trl import CPOTrainer, ORPOTrainer, PRMTrainer, RewardTrainer
from trl.trainer.utils import pad_to_length

from axolotl.core.trainers.kto import AxolotlKTOTrainer
from axolotl.integrations.base import BaseOptimizerFactory
from axolotl.monkeypatch.relora import ReLoRAScheduler
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -874,14 +875,6 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):

    tag_names = ["axolotl", "orpo"]


class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
    """
    Extend the base KTOTrainer for axolotl helpers
    """

    tag_names = ["axolotl", "kto"]


class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
    """
    Extend the base CPOTrainer for axolotl helpers
src/axolotl/core/trainers/kto/__init__.py (new file)
@@ -0,0 +1,7 @@
"""
KTO package initialization.
"""

from axolotl.core.trainers.kto.trainer import AxolotlKTOTrainer

__all__ = ["AxolotlKTOTrainer"]
src/axolotl/core/trainers/kto/trainer.py (new file)
@@ -0,0 +1,512 @@
"""
KTO trainer implementation for Axolotl.
"""

from __future__ import annotations

import inspect
import logging
import warnings
from collections import defaultdict
from contextlib import nullcontext
from typing import Any, Callable, Literal, Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate import PartialState
from datasets import Dataset, concatenate_datasets
from torch.utils.data import DataLoader, SequentialSampler
from transformers import (
    AutoModelForCausalLM,
    BaseImageProcessor,
    DataCollator,
    FeatureExtractionMixin,
    PreTrainedModel,
    PreTrainedTokenizerBase,
    ProcessorMixin,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)
from transformers.integrations import is_comet_available, is_wandb_available
from transformers.trainer_utils import EvalLoopOutput
from trl import KTOTrainer, create_reference_model
from trl.data_utils import (
    maybe_apply_chat_template,
    maybe_extract_prompt,
    maybe_unpair_preference_dataset,
)
from trl.trainer.kto_config import KTOConfig
from trl.trainer.kto_trainer import _get_kl_dataset, _process_tokens, _tokenize
from trl.trainer.utils import (
    DPODataCollatorWithPadding,
    disable_dropout_in_model,
    pad_to_length,
    peft_module_casting_to_bf16,
)

from axolotl.core.trainers.base import SchedulerMixin

# Check if PEFT is available. Expose `is_peft_available` as a callable so the
# call sites below (`is_peft_available()`) work whether or not peft is installed.
try:
    from peft import PeftModel, get_peft_model, prepare_model_for_kbit_training

    def is_peft_available() -> bool:
        return True

except ImportError:

    def is_peft_available() -> bool:
        return False


LOG = logging.getLogger("axolotl.core.trainers.kto")

class AxolotlKTOTrainer(SchedulerMixin, Trainer):
    """
    Axolotl port of trl's KTOTrainer, rebuilt on the base Trainer with
    axolotl helpers
    """

    tag_names = ["axolotl", "kto"]

    def __init__(
        self,
        model: Union[PreTrainedModel, nn.Module, str] = None,
        ref_model: Optional[Union[PreTrainedModel, nn.Module, str]] = None,
        args: KTOConfig = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Union[Dataset, dict[str, Dataset]]] = None,
        processing_class: Optional[
            Union[PreTrainedTokenizerBase, BaseImageProcessor, FeatureExtractionMixin, ProcessorMixin]
        ] = None,
        data_collator: Optional[DataCollator] = None,
        model_init: Optional[Callable[[], PreTrainedModel]] = None,
        callbacks: Optional[list[TrainerCallback]] = None,
        optimizers: tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR] = (None, None),
        preprocess_logits_for_metrics: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
        peft_config: Optional[dict] = None,
        compute_metrics: Optional[Callable[[EvalLoopOutput], dict]] = None,
        dataset_tags=None,
        model_adapter_name: Optional[str] = None,
        ref_adapter_name: Optional[str] = None,
    ):
        self.dataset_tags = dataset_tags
        self._tag_names = ["trl", "kto"]
        if hasattr(self, "tag_names"):
            self._tag_names.extend(self.tag_names)

        if type(args) is TrainingArguments:
            raise ValueError("Please use `KTOConfig` instead of `TrainingArguments`.")

        if args.model_init_kwargs is None:
            model_init_kwargs = {}
        elif not isinstance(model, str):
            raise ValueError("You passed model_kwargs to the KTOTrainer. But your model is already instantiated.")
        else:
            model_init_kwargs = args.model_init_kwargs
            torch_dtype = model_init_kwargs.get("torch_dtype")
            if torch_dtype is not None:
                # Convert to `torch.dtype` if a str is passed
                if isinstance(torch_dtype, str) and torch_dtype != "auto":
                    torch_dtype = getattr(torch, torch_dtype)
                if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
                    raise ValueError(
                        f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
                    )
                model_init_kwargs["torch_dtype"] = torch_dtype

        if args.ref_model_init_kwargs is None:
            ref_model_init_kwargs = {}
        elif not isinstance(ref_model, str):
            raise ValueError(
                "You passed ref_model_kwargs to the KTOTrainer. But your ref_model is already instantiated."
            )
        else:
            ref_model_init_kwargs = args.ref_model_init_kwargs
            torch_dtype = ref_model_init_kwargs.get("torch_dtype")
            if torch_dtype is not None:
                # Convert to `torch.dtype` if a str is passed
                if isinstance(torch_dtype, str) and torch_dtype != "auto":
                    torch_dtype = getattr(torch, torch_dtype)
                if torch_dtype != "auto" and not isinstance(torch_dtype, torch.dtype):
                    raise ValueError(
                        f"Invalid `torch_dtype` passed to the KTOConfig. Expected a string with either `torch.dtype` or 'auto', but got {torch_dtype}."
                    )
                ref_model_init_kwargs["torch_dtype"] = torch_dtype

        if isinstance(model, str):
            model = AutoModelForCausalLM.from_pretrained(model, **model_init_kwargs)

        if isinstance(ref_model, str):
            ref_model = AutoModelForCausalLM.from_pretrained(ref_model, **ref_model_init_kwargs)

        # Initialize this variable to False. This helps tracking the case when `peft_module_casting_to_bf16`
        # has been called in order to properly call autocast if needed.
        self._peft_has_been_casted_to_bf16 = False

        if not is_peft_available() and peft_config is not None:
            raise ValueError(
                "PEFT is not installed and you passed a `peft_config` in the trainer's kwargs, please install it with `pip install peft` to use the PEFT models"
            )
        elif is_peft_available() and peft_config is not None:
            # if model is a peft model and we have a peft_config, we merge and unload it first
            if isinstance(model, PeftModel):
                model = model.merge_and_unload()

            if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False):
                _support_gc_kwargs = hasattr(
                    args, "gradient_checkpointing_kwargs"
                ) and "gradient_checkpointing_kwargs" in list(
                    inspect.signature(prepare_model_for_kbit_training).parameters
                )

                prepare_model_kwargs = {"use_gradient_checkpointing": args.gradient_checkpointing}

                if _support_gc_kwargs:
                    prepare_model_kwargs["gradient_checkpointing_kwargs"] = args.gradient_checkpointing_kwargs

                model = prepare_model_for_kbit_training(model, **prepare_model_kwargs)
            elif getattr(args, "gradient_checkpointing", False):
                # For backward compatibility with older versions of transformers
                if hasattr(model, "enable_input_require_grads"):
                    model.enable_input_require_grads()
                else:

                    def make_inputs_require_grad(module, input, output):
                        output.requires_grad_(True)

                    model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

            # get peft model with the given config
            model = get_peft_model(model, peft_config)
            if args.bf16 and getattr(model, "is_loaded_in_4bit", False):
                peft_module_casting_to_bf16(model)
                # If args.bf16 we need to explicitly call `generate` with torch amp autocast context manager
                self._peft_has_been_casted_to_bf16 = True

        # For models that use gradient_checkpointing, we need to attach a hook that enables input
        # to explicitly have `requires_grad=True`, otherwise training will either silently
        # fail or completely fail.
        elif getattr(args, "gradient_checkpointing", False):
            # For backward compatibility with older versions of transformers
            if hasattr(model, "enable_input_require_grads"):
                model.enable_input_require_grads()
            else:

                def make_inputs_require_grad(module, input, output):
                    output.requires_grad_(True)

                model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)

        if args.generate_during_eval and not (is_wandb_available() or is_comet_available()):
            raise ValueError(
                "`generate_during_eval=True` requires Weights and Biases or Comet to be installed."
                " Please install `wandb` or `comet-ml` to resolve."
            )

        if model is not None:
            self.is_encoder_decoder = model.config.is_encoder_decoder
        elif args.is_encoder_decoder is None:
            raise ValueError("When no model is provided, you need to pass the parameter is_encoder_decoder.")
        else:
            self.is_encoder_decoder = args.is_encoder_decoder

        self.is_peft_model = is_peft_available() and isinstance(model, PeftModel)
        self.model_adapter_name = model_adapter_name
        self.ref_adapter_name = ref_adapter_name

        if ref_model:
            self.ref_model = ref_model
        elif self.is_peft_model or args.precompute_ref_log_probs:
            # The `model` with adapters turned off will be used as the reference model
            self.ref_model = None
        else:
            self.ref_model = create_reference_model(model)

        if processing_class is None:
            raise ValueError(
                "max_length or a processing_class must be specified when using the default DPODataCollatorWithPadding"
            )
        if args.max_length is None:
            warnings.warn(
                "When using DPODataCollatorWithPadding, you should set `max_length` in the KTOTrainer's init"
                " it will be set to `512` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            max_length = 512
        if args.max_length is not None:
            max_length = args.max_length

        if args.max_prompt_length is None:
            warnings.warn(
                "When using DPODataCollatorWithPadding, you should set `max_prompt_length` in the KTOTrainer's init"
                " it will be set to `128` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            max_prompt_length = 128
        if args.max_prompt_length is not None:
            max_prompt_length = args.max_prompt_length

        max_completion_length = None
        if args.max_completion_length is None and self.is_encoder_decoder:
            warnings.warn(
                "When using DPODataCollatorWithPadding with an encoder decoder architecture, you should set `max_completion_length` in the KTOTrainer's init"
                " it will be set to `128` by default, but you should do it yourself in the future.",
                UserWarning,
            )
            max_completion_length = 128
        if args.max_completion_length is not None and self.is_encoder_decoder:
            max_completion_length = args.max_completion_length

        if data_collator is None:
            data_collator = DPODataCollatorWithPadding(
                pad_token_id=processing_class.pad_token_id,
                label_pad_token_id=args.label_pad_token_id,
                is_encoder_decoder=self.is_encoder_decoder,
            )

            if args.remove_unused_columns:
                args.remove_unused_columns = False
                # warn users
                warnings.warn(
                    "When using DPODataCollatorWithPadding, you should set `remove_unused_columns=False` in your KTOConfig"
                    " we have set it for you, but you should do it yourself in the future.",
                    UserWarning,
                )

            self.use_dpo_data_collator = True
        else:
            self.use_dpo_data_collator = False

        # Disable dropout in the model and reference model
        if args.disable_dropout:
            disable_dropout_in_model(model)
            if self.ref_model is not None:
                disable_dropout_in_model(self.ref_model)

        self.loss_type = args.loss_type
        self.max_length = max_length
        self.generate_during_eval = args.generate_during_eval
        self.label_pad_token_id = args.label_pad_token_id
        self.padding_value = args.padding_value if args.padding_value is not None else processing_class.pad_token_id
        self.max_prompt_length = max_prompt_length
        self.truncation_mode = args.truncation_mode
        self.max_completion_length = max_completion_length
        self.processing_class = processing_class
        self.precompute_ref_log_probs = args.precompute_ref_log_probs

        # Not all losses require a KL calculation
        self.calculate_KL = True
        if self.loss_type in ["apo_zero_unpaired"]:
            self.calculate_KL = False

        # Since ref_logs are precomputed on the first call to get_train/eval_dataloader
        # keep track of first called to avoid computation of future calls
        self._precomputed_train_ref_log_probs = False
        self._precomputed_eval_ref_log_probs = False

        # metric
        self._stored_metrics = defaultdict(lambda: defaultdict(list))

        # KTO parameter
        self.beta = args.beta
        self.desirable_weight = args.desirable_weight
        self.undesirable_weight = args.undesirable_weight
        self.aux_loss_enabled = getattr(model.config, "output_router_logits", False)
        self.aux_loss_coef = getattr(model.config, "router_aux_loss_coef", 0.0)
        if self.aux_loss_enabled and self.aux_loss_coef == 0.0:
            warnings.warn(
                "You set `output_router_logits` to `True` in the model config, but `router_aux_loss_coef` is set to "
                "`0.0`, meaning the auxiliary loss will not be used. Either set `router_aux_loss_coef` to a value "
                "greater than `0.0`, or set `output_router_logits` to `False` if you don't want to use the auxiliary "
                "loss.",
                UserWarning,
            )

        # The trainer estimates the number of FLOPs (floating-point operations) using the number of elements in the
        # input tensor associated with the key "input_ids". However, in KTO, the sampled data does not include the
        # "input_ids" key. Instead, the available keys are "prompt_input_ids" and "completion_input_ids". As a result,
        # the trainer issues the warning: "Could not estimate the number of tokens of the input, floating-point
        # operations will not be computed." To suppress this warning, we set the "estimate_tokens" key in the model's
        # "warnings_issued" dictionary to True. This acts as a flag to indicate that the warning has already been
        # issued.
        model.warnings_issued["estimate_tokens"] = True

        # Compute that only on the main process for faster data processing.
        # see: https://github.com/huggingface/trl/pull/1255
        with PartialState().local_main_process_first():
            # Extract the prompt if needed
            train_dataset = train_dataset.map(
                maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from train dataset"
            )
            # Unpair the dataset if needed
            train_dataset = maybe_unpair_preference_dataset(
                train_dataset, args.dataset_num_proc, desc="Unpairing train dataset"
            )
            # Apply the chat template if needed
            train_dataset = train_dataset.map(
                maybe_apply_chat_template,
                fn_kwargs={"tokenizer": processing_class},
                num_proc=args.dataset_num_proc,
                desc="Applying chat template to train dataset",
            )
            if eval_dataset is not None:
                eval_dataset = eval_dataset.map(
                    maybe_extract_prompt, num_proc=args.dataset_num_proc, desc="Extracting prompt from eval dataset"
                )
                eval_dataset = maybe_unpair_preference_dataset(
                    eval_dataset, args.dataset_num_proc, desc="Unpairing eval dataset"
                )
                eval_dataset = eval_dataset.map(
                    maybe_apply_chat_template,
                    fn_kwargs={"tokenizer": processing_class},
                    num_proc=args.dataset_num_proc,
                    desc="Applying chat template to eval dataset",
                )

            # Tokenize and prepare the training datasets
            train_dataset = train_dataset.map(
                _tokenize,
                batched=True,
                fn_kwargs={"tokenizer": self.processing_class},
                num_proc=args.dataset_num_proc,
                desc="Tokenizing train dataset",
            )

            fn_kwargs = {
                "prefix": "",
                "is_encoder_decoder": self.is_encoder_decoder,
                "tokenizer": self.processing_class,
                "max_length": self.max_length,
                "truncation_mode": self.truncation_mode,
                "label_pad_token_id": self.label_pad_token_id,
                "max_prompt_length": self.max_prompt_length,
                "max_completion_length": self.max_completion_length,
            }

            train_dataset = train_dataset.map(
                _process_tokens,
                fn_kwargs=fn_kwargs,
                num_proc=args.dataset_num_proc,
                desc="Processing tokenized train dataset",
            )

            # Tokenize and prepare the eval datasets
            if eval_dataset is not None:
                eval_dataset = eval_dataset.map(
                    _tokenize,
                    fn_kwargs={"tokenizer": self.processing_class},
                    batched=True,
                    num_proc=args.dataset_num_proc,
                    desc="Tokenizing eval dataset",
                )

                eval_dataset = eval_dataset.map(
                    _process_tokens,
                    fn_kwargs=fn_kwargs,
                    num_proc=args.dataset_num_proc,
                    desc="Processing tokenized eval dataset",
                )

            # Get KL datasets if needed
            if self.calculate_KL:
                if args.per_device_train_batch_size <= 1:
                    raise ValueError(
                        "Actual (not effective) batch size must be > 1. KTO will not work properly because the KL term will be equivalent to the implied reward."
                    )

                # create pairs for estimating the KL term by flipping the matched pairs in each batch of size total_batch_size
                # i.e., (x_1, y_1), ..., (x_n, y_n) --> (x_1, y_n), ..., (x_n, y_1) = (x'_1, y'_1), ..., (x'_n, y'_n)
                train_kl_dataset = train_dataset.map(
                    _get_kl_dataset,
                    batched=True,
                    batch_size=args.per_device_train_batch_size,
                    num_proc=args.dataset_num_proc,
                    desc="Extracting KL train dataset",
                )

                fn_kwargs["prefix"] = "KL_"
                train_kl_dataset = train_kl_dataset.map(
                    _process_tokens,
                    fn_kwargs=fn_kwargs,
                    num_proc=args.dataset_num_proc,
                    remove_columns=[c for c in train_kl_dataset.column_names if c in train_dataset.column_names],
                    desc="Processing tokenized train KL dataset",
                )

                # merge the datasets
                train_dataset = concatenate_datasets([train_dataset, train_kl_dataset], axis=1)

                if eval_dataset is not None:
                    # Get KL dataset
                    eval_kl_dataset = eval_dataset.map(
                        _get_kl_dataset,
                        batched=True,
                        batch_size=args.per_device_train_batch_size,
                        num_proc=args.dataset_num_proc,
                        desc="Extracting eval KL dataset",
                    )

                    eval_kl_dataset = eval_kl_dataset.map(
                        _process_tokens,
                        fn_kwargs=fn_kwargs,
                        num_proc=args.dataset_num_proc,
                        remove_columns=[c for c in eval_kl_dataset.column_names if c in eval_dataset.column_names],
                        desc="Processing tokenized eval KL dataset",
                    )

                    # merge the datasets
                    eval_dataset = concatenate_datasets([eval_dataset, eval_kl_dataset], axis=1)

            # calculate dataset desirability balance
            num_desirable = max(sum(train_dataset["label"]), 1)
            num_undesirable = max(len(train_dataset["label"]) - num_desirable, 1)  # "label" is binary

            if num_desirable != num_undesirable:
                # The lower and upper bounds come from Eq. (8) of https://huggingface.co/papers/2402.01306
                des_weight_lower_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1, 2)
                des_weight_upper_bound = round((num_undesirable * self.undesirable_weight / num_desirable) * 1.33, 2)
                und_weight_lower_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1.33, 2)
                und_weight_upper_bound = round((num_desirable * self.desirable_weight / num_undesirable) / 1, 2)

                des_weight_in_range = des_weight_lower_bound <= self.desirable_weight <= des_weight_upper_bound
                und_weight_in_range = und_weight_lower_bound <= self.undesirable_weight <= und_weight_upper_bound

                if not (des_weight_in_range or und_weight_in_range):
                    warnings.warn(
                        "You have different amounts of desirable/positive and undesirable/negative examples but the "
                        "weights on the desirable and undesirable losses don't seem to be in an ideal range. Based "
                        f"on your data, we recommend EITHER "
                        f"desirable_weight in [{des_weight_lower_bound}, {des_weight_upper_bound}] or "
                        f"undesirable_weight in [{und_weight_lower_bound}, {und_weight_upper_bound}] (but NOT BOTH). "
                        "See the documentation on how to optimally set these weights.",
                        UserWarning,
                    )

        super().__init__(
            model=model,
            args=args,
            data_collator=data_collator,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            processing_class=processing_class,
            model_init=model_init,
            compute_metrics=compute_metrics,
            callbacks=callbacks,
            optimizers=optimizers,
            preprocess_logits_for_metrics=preprocess_logits_for_metrics,
        )

        # Gradient accumulation requires scaled loss. Normally, loss scaling in the parent class depends on whether the
        # model accepts loss-related kwargs. Since we compute our own loss, this check is irrelevant. We set
        # self.model_accepts_loss_kwargs to False to enable scaling.
        self.model_accepts_loss_kwargs = False

        # Add tags for models that have been loaded with the correct transformers version
        if hasattr(self.model, "add_model_tags"):
            self.model.add_model_tags(self._tag_names)

        if not hasattr(self, "accelerator"):
            raise AttributeError(
                "Your `Trainer` does not have an `accelerator` object. Consider upgrading `transformers`."
            )

        # Deepspeed Zero-3 does not support precompute_ref_log_probs
        if self.is_deepspeed_enabled:
            if self.accelerator.state.deepspeed_plugin.zero_stage == 3 and self.precompute_ref_log_probs:
                raise ValueError(
                    "You cannot use `precompute_ref_log_probs=True` with Deepspeed ZeRO-3. Please set `precompute_ref_log_probs=False`."
                )

        if self.ref_model is None:
            if not (self.is_peft_model or self.precompute_ref_log_probs):
                raise ValueError(
                    "No reference model and model is not a Peft model. Try setting `precompute_ref_log_probs=True`"
                )
        else:
            if self.is_deepspeed_enabled:
                self.ref_model = self._prepare_deepspeed(self.ref_model)
            else:
                self.ref_model = self.accelerator.prepare_model(self.ref_model, evaluation_mode=True)
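A minimal usage sketch for the new trainer (illustrative only: the model, toy dataset, and hyperparameters are placeholders, and `precompute_ref_log_probs=True` stands in for passing a separate reference model):

```python
# Sketch of driving AxolotlKTOTrainer directly; not an official example.
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer
from trl.trainer.kto_config import KTOConfig

from axolotl.core.trainers.kto import AxolotlKTOTrainer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
model = AutoModelForCausalLM.from_pretrained("gpt2")

# TRL's unpaired KTO format: prompt, completion, and a binary desirability label.
train_ds = Dataset.from_list(
    [
        {"prompt": "2 + 2 = ", "completion": "4", "label": True},
        {"prompt": "2 + 2 = ", "completion": "5", "label": False},
    ]
)

args = KTOConfig(
    output_dir="./outputs/kto-demo",
    per_device_train_batch_size=2,  # must be > 1 so the KL estimate is meaningful
    max_length=64,
    max_prompt_length=32,
    remove_unused_columns=False,
    precompute_ref_log_probs=True,  # reference log-probs computed once up front
)

trainer = AxolotlKTOTrainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    processing_class=tokenizer,
)
trainer.train()
```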
@@ -1,5 +1,4 @@
"""Module with Pydantic models for configuration."""

# pylint: disable=too-many-lines

import logging
@@ -1679,30 +1678,6 @@ class AxolotlInputConfig(

        return data

    @model_validator(mode="before")
    @classmethod
    def check_rl_config_gradient_checkpointing(cls, data):
        # TODO: SalmanMohammadi
        # Distributed RL with QLoRA + gradient checkpointing
        # and use_reentrant = True is broken upstream in TRL
        # pylint: disable=too-many-boolean-expressions
        if (
            data.get("rl")
            and data.get("gradient_checkpointing")
            and data.get("gradient_checkpointing_kwargs")
            and data.get("gradient_checkpointing_kwargs").get("use_reentrant")
            and data.get("load_in_4bit")
            and data.get("adapter") == "qlora"
            and data.get("capabilities")
            and data.get("capabilities").get("n_gpu", 1) > 1
        ):
            raise ValueError(
                "The `use_reentrant: True` implementation of gradient checkpointing "
                "is not supported for distributed RL training with QLoRA. Please set "
                "`use_reentrant: False` in `gradient_checkpointing_kwargs`."
            )
        return data

    @model_validator(mode="before")
    @classmethod
    def check_kto_config(cls, data):
@@ -1713,6 +1688,15 @@ class AxolotlInputConfig(

        if data.get("remove_unused_columns") is not False:
            raise ValueError("Set `remove_unused_columns: False` when using kto")

        if data.get("gradient_checkpointing") and not (
            data.get("gradient_checkpointing_kwargs")
            and isinstance(data.get("gradient_checkpointing_kwargs"), dict)
            and data["gradient_checkpointing_kwargs"].get("use_reentrant")
        ):
            raise ValueError(
                "Set `gradient_checkpointing_kwargs: {use_reentrant: true}` when kto is enabled"
            )

        return data

@@ -1843,14 +1827,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):

        data["torch_compile"] = False
        return data

    @model_validator(mode="before")
    @classmethod
    def check_beta_and_trl_beta_match(cls, data):
        if data.get("beta") and data.get("trl", {}).get("beta"):
            if data["beta"] != data["trl"]["beta"]:
                raise ValueError("beta and trl.beta must match or one must be removed")
        return data


def handle_legacy_message_fields_logic(data: dict) -> dict:
    """
@@ -121,6 +121,7 @@ def drop_long_rl_seq(


def load_prepare_preference_datasets(cfg):
    import pdb; pdb.set_trace()
    def load_split(dataset_cfgs, _cfg):
        split_datasets: List[Any] = []
        use_auth_token = _cfg.hf_use_auth_token
@@ -24,6 +24,7 @@ from peft import (
    PeftModelForCausalLM,
    prepare_model_for_kbit_training,
)
from peft.tuners.lora import QuantLinear
from torch import nn
from transformers import (  # noqa: F401
    AddedToken,
@@ -1359,7 +1360,7 @@ def load_llama_adapter(model, cfg):


def find_all_linear_names(model):
    cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear)
    cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
    lora_module_names = set()
    for name, module in model.named_modules():
        if (