Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
8fc4c420a4 Add kd coefficient scheduler 2025-03-18 09:01:58 -04:00
21 changed files with 146 additions and 167 deletions

View File

@@ -55,7 +55,6 @@ Features:
### Installation
```bash
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# Download example axolotl configs, deepspeed configs

View File

@@ -32,9 +32,8 @@ website:
contents:
- docs/getting-started.qmd
- docs/installation.qmd
- docs/inference.qmd
- docs/cli.qmd
- docs/config.qmd
- docs/inference.qmd
- section: "Dataset Formats"
contents: docs/dataset-formats/*
@@ -75,6 +74,10 @@ website:
- docs/debugging.qmd
- docs/nccl.qmd
- section: "Reference"
contents:
- docs/config.qmd
format:
html:
theme: darkly

View File

@@ -1,5 +1,5 @@
---
title: Config Reference
title: Config options
description: A complete list of all configuration options.
---
@@ -30,8 +30,6 @@ tokenizer_legacy:
# Resize the model embeddings when new tokens are added to multiples of 32
# This is reported to improve training speed on some models
resize_token_embeddings_to_32x:
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
shrink_embeddings:
# (Internal use only)
# Used to identify which the model is based on
@@ -207,46 +205,10 @@ test_datasets:
data_files:
- /workspace/data/eval.jsonl
# use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'
# use RL training: 'dpo', 'ipo', 'kto'
rl:
rl_beta: # Optional[float]. The beta parameter for the RL training.
# dpo
dpo_use_weighting: # Optional[bool]. Whether to perform weighting.
rpo_alpha: # Optional[float]. Weighting of NLL term in loss from RPO paper.
# orpo
orpo_alpha: 0.1 # Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping.
# kto
kto_desirable_weight: # Optional[float]. Factor for desirable loss term in KTO loss.
kto_undesirable_weight: # Optional[float]. Factor for undesirable loss term in KTO loss.
# simpo
cpo_alpha: 1.0 # Weight of the BC regularizer
simpo_gamma: 0.5 # Target reward margin for the SimPO loss
# grpo
trl:
use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
vllm_device: # Optional[str]. Device to use for VLLM.
vllm_gpu_memory_utilization: # Optional[float]. GPU memory utilization for VLLM.
vllm_max_model_len: # Optional[int]. Maximum length of the model for VLLM.
vllm_dtype: # Optional[str]. Data type for VLLM.
beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
reward_funcs: # Optional[list[str]]. List of reward functions to load. Paths must be importable from current dir.
reward_weights: # Optional[list[float]]. List of reward weights for the reward functions.
num_generations: # Optional[int]. Number of generations to sample.
log_completions: # Optional[bool]. Whether to log completions.
sync_ref_model: # Optional[bool]. Whether to sync the reference model.
ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model.
ref_model_sync_steps: # Optional[int]. Sync steps for the reference model.
# whether to perform weighting if doing DPO training. Boolean.
dpo_use_weighting:
# reward modelling: `True` or `False`
reward_model:
@@ -270,7 +232,7 @@ default_system_message: You are a helpful assistant. Please give a long and deta
# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
# Push prepared dataset to hub
push_dataset_to_hub: # Optional[str] repo_org/repo_name
push_dataset_to_hub: # repo path
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# if not set.
dataset_processes: # defaults to os.cpu_count() if not set

View File

@@ -27,16 +27,6 @@ description: Frequently asked questions
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
**Q: Received mismatch error on merge adapters / loading adapters between torch.Size of checkpoint and model.**
> A: This is likely due to vocab size mismatch. By default, Axolotl expands the model's embeddings if the tokenizer has more tokens than the model. Please use the `axolotl merge-lora` command to merge the adapters instead of using your own scripts.
> On the other hand, if the model has more tokens than the tokenizer, Axolotl does not shrink the model's embeddings unless `shrink_embeddings: true` is set in the config.
**Q: How to call Axolotl via custom python scripts?**
> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
### Chat templates
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**

View File

@@ -36,9 +36,7 @@ The YAML configuration file controls everything about your training. Here's what
```yaml
base_model: NousResearch/Llama-3.2-1B
load_in_8bit: true
adapter: lora
# hub_model_id: username/custom_model_name
datasets:
- path: teknium/GPT4-LLM-Cleaned
@@ -46,15 +44,11 @@ datasets:
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: lora
lora_model_dir:
```
::: {.callout-tip}
`load_in_8bit: true` and `adapter: lora` enables LoRA adapter finetuning.
- To perform Full finetuning, remove these two lines.
- To perform QLoRA finetuning, replace with `load_in_4bit: true` and `adapter: qlora`.
:::
See our [Config options](config.qmd) for more details.
### Training {#sec-training}
@@ -62,7 +56,7 @@ See our [Config options](config.qmd) for more details.
When you run `axolotl train`, Axolotl:
1. Downloads the base model
2. (If specified) applies QLoRA/LoRA adapter layers
2. (If specified) applies LoRA adapter layers
3. Loads and processes the dataset
4. Runs the training loop
5. Saves the trained model and / or LoRA weights
@@ -75,8 +69,6 @@ Let's modify the example for your own data:
```yaml
base_model: NousResearch/Nous-Hermes-llama-1b-v1
load_in_8bit: true
adapter: lora
# Training settings
@@ -112,6 +104,8 @@ format):
{"instruction": "Classify this text", "input": "Not good at all", "output": "negative"}
```
Please consult the supported [Dataset Formats](dataset-formats/) for more details.
3. Run the training:
```bash

View File

@@ -1,5 +1,5 @@
---
title: "Inference and Merging"
title: "Inference"
format:
html:
toc: true
@@ -9,14 +9,10 @@ execute:
enabled: false
---
This guide covers how to use your trained models for inference, including model loading, interactive testing, merging adapters, and common troubleshooting steps.
This guide covers how to use your trained models for inference, including model loading, interactive testing, and common troubleshooting steps.
## Quick Start {#sec-quickstart}
::: {.callout-tip}
Use the same config used for training on inference/merging.
:::
### Basic Inference {#sec-basic}
::: {.panel-tabset}

View File

@@ -22,7 +22,6 @@ This guide covers all the ways you can install and set up Axolotl for your envir
### PyPI Installation (Recommended) {#sec-pypi}
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
```
@@ -38,7 +37,7 @@ For the latest features between releases:
```{.bash}
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install -U packaging setuptools wheel ninja
pip3 install packaging ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
@@ -108,7 +107,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
2. Install PyTorch: https://pytorch.org/get-started/locally/
3. Install Axolotl:
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install packaging
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
4. (Optional) Login to Hugging Face:

View File

@@ -66,10 +66,6 @@ logic to be compatible with more of them.
</details>
::: {.callout-tip}
Check out our [LoRA optimizations blog](https://axolotlai.substack.com/p/accelerating-lora-fine-tuning-with).
:::
## Usage
These optimizations can be enabled in your Axolotl config YAML file. The

View File

@@ -41,10 +41,6 @@ Bradley-Terry chat templates expect single-turn conversations in the following f
### Process Reward Models (PRM)
::: {.callout-tip}
Check out our [PRM blog](https://axolotlai.substack.com/p/process-reward-models).
:::
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
```yaml
base_model: Qwen/Qwen2.5-3B

View File

@@ -298,7 +298,7 @@ The input format is a simple JSON input with customizable fields based on the ab
### IPO
As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.
As IPO is just DPO with a different loss function, all supported options for DPO works here.
```yaml
rl: ipo
@@ -344,9 +344,8 @@ ORPO supports the following types with the following dataset format:
```yaml
rl: kto
rl_beta: 0.1 # default
kto_desirable_weight: 1.0 # default
kto_undesirable_weight: 1.0 # default
rl_beta: 0.5
kto_desirable_weight: 0.2
remove_unused_columns: false
@@ -498,10 +497,6 @@ The input format is a simple JSON input with customizable fields based on the ab
### GRPO
::: {.callout-tip}
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
:::
GRPO uses custom reward functions and transformations. Please have them ready locally.
For ex, to load OpenAI's GSM8K and use a random reward for completions:
@@ -545,19 +540,6 @@ To see other examples of custom reward functions, please see [TRL GRPO Docs](htt
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
### SimPO
SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with alternative loss function.
```yaml
rl: simpo
rl_beta: 0.1 # default in CPOTrainer
cpo_alpha: 1.0 # default in CPOTrainer
simpo_gamma: 0.5 # default in CPOTrainer
```
This method uses the same dataset format as [DPO](#dpo).
### Using local dataset files
```yaml

View File

@@ -55,7 +55,7 @@ tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
use_reentrant: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:

View File

@@ -751,8 +751,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.kd_ce_alpha is not None:
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
if self.cfg.kd_ce_alpha_end is not None:
training_arguments_kwargs["kd_ce_alpha_end"] = self.cfg.kd_ce_alpha_end
if self.cfg.kd_alpha is not None:
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
if self.cfg.kd_alpha_end is not None:
training_arguments_kwargs["kd_alpha_end"] = self.cfg.kd_alpha_end
if self.cfg.kd_temperature is not None:
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
if self.cfg.kd_zscore_base_temp is not None:

View File

@@ -34,3 +34,12 @@ class KDPlugin(BasePlugin):
return AxolotlKDTrainer
return None
def add_callbacks_post_trainer(self, cfg, trainer):
callbacks = []
if cfg.kd_trainer:
from .callbacks import KDAlphaSchedulerCallback
callbacks.append(KDAlphaSchedulerCallback())
return callbacks

View File

@@ -30,6 +30,8 @@ class KDArgs(BaseModel):
float
] = None # loss coefficient for cross-entropy loss during KD
kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_ce_alpha_end: Optional[float] = None # end value for kd_ce_alpha
kd_alpha_end: Optional[float] = None # end value for kd_alpha
kd_temperature: Optional[float] = None # temperature for sampling during KD
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
kd_top_k_before_softmax: Optional[

View File

@@ -0,0 +1,28 @@
from transformers import TrainerCallback
class KDAlphaSchedulerCallback(TrainerCallback):
"""Callback to for scheduling KD alpha during training."""
def on_epoch_begin(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
if int(state.epoch) == 0:
state.kd_alpha = args.kd_alpha
state.kd_ce_alpha = args.kd_ce_alpha
elif int(state.epoch) == state.num_train_epochs - 1:
if args.kd_alpha_end is not None:
control.kd_alpha = args.kd_alpha_end
if args.kd_ce_alpha_end is not None:
control.kd_ce_alpha = args.kd_ce_alpha_end
else:
epoch_steps = state.num_train_epochs - 1
scale = int(state.epoch) / epoch_steps
if args.kd_alpha_end is not None:
control.kd_alpha = (
args.kd_alpha + (args.kd_alpha_end - args.kd_alpha) * scale
)
if args.kd_ce_alpha_end is not None:
control.kd_ce_alpha = (
args.kd_ce_alpha + (args.kd_ce_alpha_end - args.kd_ce_alpha) * scale
)

View File

@@ -62,10 +62,16 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
Transform logprobs to target format for KD training
"""
logprobs = sample.pop(self.logprobs_field)
if "target_logprobs" in sample.keys() and "target_token_ids" in sample.keys():
logprobs = sample.pop("target_logprobs")
token_ids = sample.pop("target_token_ids")
else:
logprobs = sample.pop(self.logprobs_field)
token_ids = [None] * len(logprobs)
target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"])
input_padding_len = input_seq_len - target_seq_len
target_padding_len = input_seq_len - target_seq_len
# get non-zero top-k (prune None logprobs from vllm data step)
top_k_vals = [
len(logprobs[i])
@@ -82,11 +88,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_token_ids = []
target_mask = []
if input_padding_len < 0:
if target_padding_len < 0:
# logprobs is longer than target_seq_len,
# so we need to slice from the left/beginning of logprobs
logprobs = logprobs[:-input_seq_len]
input_padding_len = 0
target_padding_len = 0
# target_seq_len = input_seq_len
# truncate the second dimension of the logprobs to top_k
@@ -98,33 +104,37 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
# otherwise, we need to shift in the trainer
shift = 0
for _ in range(shift, input_padding_len):
for _ in range(shift, target_padding_len):
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
for position in range(input_padding_len, input_seq_len):
for position in range(target_padding_len, input_seq_len):
if sample["labels"][position] == -100:
target_mask.append([0] * top_k)
else:
target_mask.append([1] * top_k)
for _, token_pos_logprobs in enumerate(logprobs):
for token_pos_logprobs, token_pos_token_ids in zip(logprobs, token_ids):
# Initialize collections for logprobs and token_ids
position_logprobs = []
position_token_ids = []
# Process each token probability entry
for entry in token_pos_logprobs:
# Extract logprob value
logprob = entry["logprob"]
if token_pos_token_ids is None:
for entry in token_pos_logprobs:
# Extract logprob value
logprob = entry["logprob"]
# Parse token_id from the "token_id:###" format
token_id = int(entry["token"].split(":")[1])
# Parse token_id from the "token_id:###" format
token_id = int(entry["token"].split(":")[1])
# Append to our collections
position_logprobs.append(logprob)
position_token_ids.append(token_id)
# Append to our collections
position_logprobs.append(logprob)
position_token_ids.append(token_id)
else:
position_logprobs = token_pos_logprobs
position_token_ids = token_pos_token_ids
# Convert to a tensor for easier manipulation
position_logprobs_tensor = torch.tensor(
@@ -143,6 +153,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
teacher_probs_t2 = teacher_probs_t1**exponent
else:
teacher_probs_t2 = teacher_probs_t1
# Re-normalize
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
dim=0, keepdim=True

View File

@@ -16,17 +16,35 @@
KD trainer
"""
from transformers import TrainerControl
from axolotl.core.trainers.base import AxolotlTrainer
from .topk_logprob.forward_kl import loss as topk_kd_loss
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
class AxolotlKDTrainerControl(TrainerControl):
kd_alpha: float = 1.0
kd_ce_alpha: float = 0.0
def state(self) -> dict:
state_val = super().state()
state_val["args"]["kd_alpha"] = self.kd_alpha
state_val["args"]["kd_ce_alpha"] = self.kd_ce_alpha
class AxolotlKDTrainer(AxolotlTrainer):
"""
Custom trainer subclass for Knowledge Distillation (KD)
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.kd_alpha = self.args.kd_alpha
self.kd_ce_alpha = self.args.kd_ce_alpha
self.control = AxolotlKDTrainerControl()
def _set_signature_columns_if_needed(self):
super()._set_signature_columns_if_needed()
columns_to_add = []
@@ -95,9 +113,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
)
if self.args.kd_ce_alpha > 0:
kd_alpha = self.args.kd_alpha
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
if self.kd_ce_alpha > 0:
loss = self.kd_ce_alpha * outputs["loss"] + self.kd_alpha * loss_kd
else:
loss = loss_kd
# Save past state if it exists

View File

@@ -813,6 +813,15 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
)
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
# TODO if using deepspeed and it's a file, save deepspeed config too
if args.deepspeed and os.path.isfile(args.deepspeed):
LOG.info(f"DeepSpeed config has been saved to the WandB run.")
artifact = wandb.Artifact(
f"deepspeed-{wandb.run.id}", type="deepspeed-config"
)
artifact.add_file(args.deepspeed)
wandb.log_artifact(artifact)
wandb.save(args.deepspeed)
return control

View File

@@ -173,10 +173,16 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
]
out_features[i][feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature]) for item in features_ if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
try:
arrays = [
np.array(item[feature])
for item in features_
if feature in item
]
if arrays[0].dtype != "object":
out_features[i][feature] = np.concatenate(arrays)
except ValueError:
pass
return super().__call__(out_features, return_tensors=return_tensors)

View File

@@ -1,5 +1,4 @@
"""Module with Pydantic models for configuration."""
# pylint: disable=too-many-lines
import logging
@@ -1679,30 +1678,6 @@ class AxolotlInputConfig(
return data
@model_validator(mode="before")
@classmethod
def check_rl_config_gradient_checkpointing(cls, data):
# TODO: SalmanMohammadi
# Distributed RL with QLoRA + gradient checkpointing
# and use_reentrant = True is broken upstream in TRL
# pylint: disable=too-many-boolean-expressions
if (
data.get("rl")
and data.get("gradient_checkpointing")
and data.get("gradient_checkpointing_kwargs")
and data.get("gradient_checkpointing_kwargs").get("use_reentrant")
and data.get("load_in_4bit")
and data.get("adapter") == "qlora"
and data.get("capabilities")
and data.get("capabilities").get("n_gpu", 1) > 1
):
raise ValueError(
"The `use_reentrant: True` implementation of gradient checkpointing "
"is not supported for distributed RL training with QLoRA. Please set "
"`use_reentrant: False` in `gradient_checkpointing_kwargs`."
)
return data
@model_validator(mode="before")
@classmethod
def check_kto_config(cls, data):
@@ -1713,6 +1688,15 @@ class AxolotlInputConfig(
if data.get("remove_unused_columns") is not False:
raise ValueError("Set `remove_unused_columns: False` when using kto")
if data.get("gradient_checkpointing") and not (
data.get("gradient_checkpointing_kwargs")
and isinstance(data.get("gradient_checkpointing_kwargs"), dict)
and data["gradient_checkpointing_kwargs"].get("use_reentrant")
):
raise ValueError(
"Set `gradient_checkpointing_kwargs: {use_reentrant: true}` for when kto is enabled"
)
return data
@@ -1843,14 +1827,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
data["torch_compile"] = False
return data
@model_validator(mode="before")
@classmethod
def check_beta_and_trl_beta_match(cls, data):
if data.get("beta") and data.get("trl", {}).get("beta"):
if data["beta"] != data["trl"]["beta"]:
raise ValueError("beta and trl.beta must match or one must be removed")
return data
def handle_legacy_message_fields_logic(data: dict) -> dict:
"""

View File

@@ -25,8 +25,8 @@ def fixture_cfg():
"optimizer": "adamw_torch_fused",
"sequence_len": 2048,
"rl": True,
"adam_beta1": 0.998,
"adam_beta2": 0.9,
"adam_beta1": 0.91,
"adam_beta2": 0.998,
"adam_epsilon": 0.00001,
"dataloader_num_workers": 1,
"dataloader_pin_memory": True,
@@ -60,8 +60,8 @@ class TestHFRLTrainerBuilder:
def test_build_training_arguments(self, cfg, model, tokenizer):
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
training_arguments = builder.build_training_arguments(100)
assert training_arguments.adam_beta1 == 0.998
assert training_arguments.adam_beta2 == 0.9
assert training_arguments.adam_beta1 == 0.91
assert training_arguments.adam_beta2 == 0.998
assert training_arguments.adam_epsilon == 0.00001
assert training_arguments.dataloader_num_workers == 1
assert training_arguments.dataloader_pin_memory is True