Compare commits
1 Commits
fix_kto
...
kd-logprob
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8fc4c420a4 |
@@ -55,7 +55,6 @@ Features:
|
|||||||
### Installation
|
### Installation
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip3 install -U packaging setuptools wheel ninja
|
|
||||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||||
|
|
||||||
# Download example axolotl configs, deepspeed configs
|
# Download example axolotl configs, deepspeed configs
|
||||||
|
|||||||
@@ -32,9 +32,8 @@ website:
|
|||||||
contents:
|
contents:
|
||||||
- docs/getting-started.qmd
|
- docs/getting-started.qmd
|
||||||
- docs/installation.qmd
|
- docs/installation.qmd
|
||||||
- docs/inference.qmd
|
|
||||||
- docs/cli.qmd
|
- docs/cli.qmd
|
||||||
- docs/config.qmd
|
- docs/inference.qmd
|
||||||
|
|
||||||
- section: "Dataset Formats"
|
- section: "Dataset Formats"
|
||||||
contents: docs/dataset-formats/*
|
contents: docs/dataset-formats/*
|
||||||
@@ -75,6 +74,10 @@ website:
|
|||||||
- docs/debugging.qmd
|
- docs/debugging.qmd
|
||||||
- docs/nccl.qmd
|
- docs/nccl.qmd
|
||||||
|
|
||||||
|
- section: "Reference"
|
||||||
|
contents:
|
||||||
|
- docs/config.qmd
|
||||||
|
|
||||||
format:
|
format:
|
||||||
html:
|
html:
|
||||||
theme: darkly
|
theme: darkly
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
---
|
---
|
||||||
title: Config Reference
|
title: Config options
|
||||||
description: A complete list of all configuration options.
|
description: A complete list of all configuration options.
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -30,8 +30,6 @@ tokenizer_legacy:
|
|||||||
# Resize the model embeddings when new tokens are added to multiples of 32
|
# Resize the model embeddings when new tokens are added to multiples of 32
|
||||||
# This is reported to improve training speed on some models
|
# This is reported to improve training speed on some models
|
||||||
resize_token_embeddings_to_32x:
|
resize_token_embeddings_to_32x:
|
||||||
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
|
|
||||||
shrink_embeddings:
|
|
||||||
|
|
||||||
# (Internal use only)
|
# (Internal use only)
|
||||||
# Used to identify which the model is based on
|
# Used to identify which the model is based on
|
||||||
@@ -207,46 +205,10 @@ test_datasets:
|
|||||||
data_files:
|
data_files:
|
||||||
- /workspace/data/eval.jsonl
|
- /workspace/data/eval.jsonl
|
||||||
|
|
||||||
# use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'
|
# use RL training: 'dpo', 'ipo', 'kto'
|
||||||
rl:
|
rl:
|
||||||
rl_beta: # Optional[float]. The beta parameter for the RL training.
|
# whether to perform weighting if doing DPO training. Boolean.
|
||||||
|
dpo_use_weighting:
|
||||||
# dpo
|
|
||||||
dpo_use_weighting: # Optional[bool]. Whether to perform weighting.
|
|
||||||
rpo_alpha: # Optional[float]. Weighting of NLL term in loss from RPO paper.
|
|
||||||
|
|
||||||
# orpo
|
|
||||||
orpo_alpha: 0.1 # Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping.
|
|
||||||
|
|
||||||
# kto
|
|
||||||
kto_desirable_weight: # Optional[float]. Factor for desirable loss term in KTO loss.
|
|
||||||
kto_undesirable_weight: # Optional[float]. Factor for undesirable loss term in KTO loss.
|
|
||||||
|
|
||||||
# simpo
|
|
||||||
cpo_alpha: 1.0 # Weight of the BC regularizer
|
|
||||||
simpo_gamma: 0.5 # Target reward margin for the SimPO loss
|
|
||||||
|
|
||||||
# grpo
|
|
||||||
trl:
|
|
||||||
use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
|
|
||||||
vllm_device: # Optional[str]. Device to use for VLLM.
|
|
||||||
vllm_gpu_memory_utilization: # Optional[float]. GPU memory utilization for VLLM.
|
|
||||||
vllm_max_model_len: # Optional[int]. Maximum length of the model for VLLM.
|
|
||||||
vllm_dtype: # Optional[str]. Data type for VLLM.
|
|
||||||
|
|
||||||
beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
|
|
||||||
max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
|
|
||||||
|
|
||||||
reward_funcs: # Optional[list[str]]. List of reward functions to load. Paths must be importable from current dir.
|
|
||||||
reward_weights: # Optional[list[float]]. List of reward weights for the reward functions.
|
|
||||||
|
|
||||||
num_generations: # Optional[int]. Number of generations to sample.
|
|
||||||
log_completions: # Optional[bool]. Whether to log completions.
|
|
||||||
|
|
||||||
sync_ref_model: # Optional[bool]. Whether to sync the reference model.
|
|
||||||
ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model.
|
|
||||||
ref_model_sync_steps: # Optional[int]. Sync steps for the reference model.
|
|
||||||
|
|
||||||
|
|
||||||
# reward modelling: `True` or `False`
|
# reward modelling: `True` or `False`
|
||||||
reward_model:
|
reward_model:
|
||||||
@@ -270,7 +232,7 @@ default_system_message: You are a helpful assistant. Please give a long and deta
|
|||||||
# subsequent training attempts load faster, relative path
|
# subsequent training attempts load faster, relative path
|
||||||
dataset_prepared_path: data/last_run_prepared
|
dataset_prepared_path: data/last_run_prepared
|
||||||
# Push prepared dataset to hub
|
# Push prepared dataset to hub
|
||||||
push_dataset_to_hub: # Optional[str] repo_org/repo_name
|
push_dataset_to_hub: # repo path
|
||||||
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
|
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
|
||||||
# if not set.
|
# if not set.
|
||||||
dataset_processes: # defaults to os.cpu_count() if not set
|
dataset_processes: # defaults to os.cpu_count() if not set
|
||||||
|
|||||||
10
docs/faq.qmd
10
docs/faq.qmd
@@ -27,16 +27,6 @@ description: Frequently asked questions
|
|||||||
|
|
||||||
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
|
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
|
||||||
|
|
||||||
**Q: Received mismatch error on merge adapters / loading adapters between torch.Size of checkpoint and model.**
|
|
||||||
|
|
||||||
> A: This is likely due to vocab size mismatch. By default, Axolotl expands the model's embeddings if the tokenizer has more tokens than the model. Please use the `axolotl merge-lora` command to merge the adapters instead of using your own scripts.
|
|
||||||
|
|
||||||
> On the other hand, if the model has more tokens than the tokenizer, Axolotl does not shrink the model's embeddings unless `shrink_embeddings: true` is set in the config.
|
|
||||||
|
|
||||||
**Q: How to call Axolotl via custom python scripts?**
|
|
||||||
|
|
||||||
> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
|
|
||||||
|
|
||||||
### Chat templates
|
### Chat templates
|
||||||
|
|
||||||
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
||||||
|
|||||||
@@ -36,9 +36,7 @@ The YAML configuration file controls everything about your training. Here's what
|
|||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
base_model: NousResearch/Llama-3.2-1B
|
base_model: NousResearch/Llama-3.2-1B
|
||||||
|
# hub_model_id: username/custom_model_name
|
||||||
load_in_8bit: true
|
|
||||||
adapter: lora
|
|
||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: teknium/GPT4-LLM-Cleaned
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
@@ -46,15 +44,11 @@ datasets:
|
|||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.1
|
val_set_size: 0.1
|
||||||
output_dir: ./outputs/lora-out
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
|
adapter: lora
|
||||||
|
lora_model_dir:
|
||||||
```
|
```
|
||||||
|
|
||||||
::: {.callout-tip}
|
|
||||||
`load_in_8bit: true` and `adapter: lora` enables LoRA adapter finetuning.
|
|
||||||
|
|
||||||
- To perform Full finetuning, remove these two lines.
|
|
||||||
- To perform QLoRA finetuning, replace with `load_in_4bit: true` and `adapter: qlora`.
|
|
||||||
:::
|
|
||||||
|
|
||||||
See our [Config options](config.qmd) for more details.
|
See our [Config options](config.qmd) for more details.
|
||||||
|
|
||||||
### Training {#sec-training}
|
### Training {#sec-training}
|
||||||
@@ -62,7 +56,7 @@ See our [Config options](config.qmd) for more details.
|
|||||||
When you run `axolotl train`, Axolotl:
|
When you run `axolotl train`, Axolotl:
|
||||||
|
|
||||||
1. Downloads the base model
|
1. Downloads the base model
|
||||||
2. (If specified) applies QLoRA/LoRA adapter layers
|
2. (If specified) applies LoRA adapter layers
|
||||||
3. Loads and processes the dataset
|
3. Loads and processes the dataset
|
||||||
4. Runs the training loop
|
4. Runs the training loop
|
||||||
5. Saves the trained model and / or LoRA weights
|
5. Saves the trained model and / or LoRA weights
|
||||||
@@ -75,8 +69,6 @@ Let's modify the example for your own data:
|
|||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
base_model: NousResearch/Nous-Hermes-llama-1b-v1
|
base_model: NousResearch/Nous-Hermes-llama-1b-v1
|
||||||
|
|
||||||
load_in_8bit: true
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
|
|
||||||
# Training settings
|
# Training settings
|
||||||
@@ -112,6 +104,8 @@ format):
|
|||||||
{"instruction": "Classify this text", "input": "Not good at all", "output": "negative"}
|
{"instruction": "Classify this text", "input": "Not good at all", "output": "negative"}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Please consult the supported [Dataset Formats](dataset-formats/) for more details.
|
||||||
|
|
||||||
3. Run the training:
|
3. Run the training:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
---
|
---
|
||||||
title: "Inference and Merging"
|
title: "Inference"
|
||||||
format:
|
format:
|
||||||
html:
|
html:
|
||||||
toc: true
|
toc: true
|
||||||
@@ -9,14 +9,10 @@ execute:
|
|||||||
enabled: false
|
enabled: false
|
||||||
---
|
---
|
||||||
|
|
||||||
This guide covers how to use your trained models for inference, including model loading, interactive testing, merging adapters, and common troubleshooting steps.
|
This guide covers how to use your trained models for inference, including model loading, interactive testing, and common troubleshooting steps.
|
||||||
|
|
||||||
## Quick Start {#sec-quickstart}
|
## Quick Start {#sec-quickstart}
|
||||||
|
|
||||||
::: {.callout-tip}
|
|
||||||
Use the same config used for training on inference/merging.
|
|
||||||
:::
|
|
||||||
|
|
||||||
### Basic Inference {#sec-basic}
|
### Basic Inference {#sec-basic}
|
||||||
|
|
||||||
::: {.panel-tabset}
|
::: {.panel-tabset}
|
||||||
|
|||||||
@@ -22,7 +22,6 @@ This guide covers all the ways you can install and set up Axolotl for your envir
|
|||||||
### PyPI Installation (Recommended) {#sec-pypi}
|
### PyPI Installation (Recommended) {#sec-pypi}
|
||||||
|
|
||||||
```{.bash}
|
```{.bash}
|
||||||
pip3 install -U packaging setuptools wheel ninja
|
|
||||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -38,7 +37,7 @@ For the latest features between releases:
|
|||||||
```{.bash}
|
```{.bash}
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
pip3 install -U packaging setuptools wheel ninja
|
pip3 install packaging ninja
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -108,7 +107,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
|
|||||||
2. Install PyTorch: https://pytorch.org/get-started/locally/
|
2. Install PyTorch: https://pytorch.org/get-started/locally/
|
||||||
3. Install Axolotl:
|
3. Install Axolotl:
|
||||||
```{.bash}
|
```{.bash}
|
||||||
pip3 install -U packaging setuptools wheel ninja
|
pip3 install packaging
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||||
```
|
```
|
||||||
4. (Optional) Login to Hugging Face:
|
4. (Optional) Login to Hugging Face:
|
||||||
|
|||||||
@@ -66,10 +66,6 @@ logic to be compatible with more of them.
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
::: {.callout-tip}
|
|
||||||
Check out our [LoRA optimizations blog](https://axolotlai.substack.com/p/accelerating-lora-fine-tuning-with).
|
|
||||||
:::
|
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
These optimizations can be enabled in your Axolotl config YAML file. The
|
These optimizations can be enabled in your Axolotl config YAML file. The
|
||||||
|
|||||||
@@ -41,10 +41,6 @@ Bradley-Terry chat templates expect single-turn conversations in the following f
|
|||||||
|
|
||||||
### Process Reward Models (PRM)
|
### Process Reward Models (PRM)
|
||||||
|
|
||||||
::: {.callout-tip}
|
|
||||||
Check out our [PRM blog](https://axolotlai.substack.com/p/process-reward-models).
|
|
||||||
:::
|
|
||||||
|
|
||||||
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
|
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
|
||||||
```yaml
|
```yaml
|
||||||
base_model: Qwen/Qwen2.5-3B
|
base_model: Qwen/Qwen2.5-3B
|
||||||
|
|||||||
@@ -298,7 +298,7 @@ The input format is a simple JSON input with customizable fields based on the ab
|
|||||||
|
|
||||||
### IPO
|
### IPO
|
||||||
|
|
||||||
As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.
|
As IPO is just DPO with a different loss function, all supported options for DPO works here.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
rl: ipo
|
rl: ipo
|
||||||
@@ -344,9 +344,8 @@ ORPO supports the following types with the following dataset format:
|
|||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
rl: kto
|
rl: kto
|
||||||
rl_beta: 0.1 # default
|
rl_beta: 0.5
|
||||||
kto_desirable_weight: 1.0 # default
|
kto_desirable_weight: 0.2
|
||||||
kto_undesirable_weight: 1.0 # default
|
|
||||||
|
|
||||||
remove_unused_columns: false
|
remove_unused_columns: false
|
||||||
|
|
||||||
@@ -498,10 +497,6 @@ The input format is a simple JSON input with customizable fields based on the ab
|
|||||||
|
|
||||||
### GRPO
|
### GRPO
|
||||||
|
|
||||||
::: {.callout-tip}
|
|
||||||
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
|
|
||||||
:::
|
|
||||||
|
|
||||||
GRPO uses custom reward functions and transformations. Please have them ready locally.
|
GRPO uses custom reward functions and transformations. Please have them ready locally.
|
||||||
|
|
||||||
For ex, to load OpenAI's GSM8K and use a random reward for completions:
|
For ex, to load OpenAI's GSM8K and use a random reward for completions:
|
||||||
@@ -545,19 +540,6 @@ To see other examples of custom reward functions, please see [TRL GRPO Docs](htt
|
|||||||
|
|
||||||
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
|
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
|
||||||
|
|
||||||
### SimPO
|
|
||||||
|
|
||||||
SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with alternative loss function.
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
rl: simpo
|
|
||||||
rl_beta: 0.1 # default in CPOTrainer
|
|
||||||
cpo_alpha: 1.0 # default in CPOTrainer
|
|
||||||
simpo_gamma: 0.5 # default in CPOTrainer
|
|
||||||
```
|
|
||||||
|
|
||||||
This method uses the same dataset format as [DPO](#dpo).
|
|
||||||
|
|
||||||
### Using local dataset files
|
### Using local dataset files
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ tf32: true
|
|||||||
|
|
||||||
gradient_checkpointing: true
|
gradient_checkpointing: true
|
||||||
gradient_checkpointing_kwargs:
|
gradient_checkpointing_kwargs:
|
||||||
use_reentrant: false
|
use_reentrant: true
|
||||||
early_stopping_patience:
|
early_stopping_patience:
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
local_rank:
|
local_rank:
|
||||||
|
|||||||
@@ -751,8 +751,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
if self.cfg.kd_ce_alpha is not None:
|
if self.cfg.kd_ce_alpha is not None:
|
||||||
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
||||||
|
if self.cfg.kd_ce_alpha_end is not None:
|
||||||
|
training_arguments_kwargs["kd_ce_alpha_end"] = self.cfg.kd_ce_alpha_end
|
||||||
if self.cfg.kd_alpha is not None:
|
if self.cfg.kd_alpha is not None:
|
||||||
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
|
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
|
||||||
|
if self.cfg.kd_alpha_end is not None:
|
||||||
|
training_arguments_kwargs["kd_alpha_end"] = self.cfg.kd_alpha_end
|
||||||
if self.cfg.kd_temperature is not None:
|
if self.cfg.kd_temperature is not None:
|
||||||
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
|
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
|
||||||
if self.cfg.kd_zscore_base_temp is not None:
|
if self.cfg.kd_zscore_base_temp is not None:
|
||||||
|
|||||||
@@ -34,3 +34,12 @@ class KDPlugin(BasePlugin):
|
|||||||
|
|
||||||
return AxolotlKDTrainer
|
return AxolotlKDTrainer
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||||
|
callbacks = []
|
||||||
|
if cfg.kd_trainer:
|
||||||
|
from .callbacks import KDAlphaSchedulerCallback
|
||||||
|
|
||||||
|
callbacks.append(KDAlphaSchedulerCallback())
|
||||||
|
|
||||||
|
return callbacks
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ class KDArgs(BaseModel):
|
|||||||
float
|
float
|
||||||
] = None # loss coefficient for cross-entropy loss during KD
|
] = None # loss coefficient for cross-entropy loss during KD
|
||||||
kd_alpha: Optional[float] = None # loss coefficient for KD loss
|
kd_alpha: Optional[float] = None # loss coefficient for KD loss
|
||||||
|
kd_ce_alpha_end: Optional[float] = None # end value for kd_ce_alpha
|
||||||
|
kd_alpha_end: Optional[float] = None # end value for kd_alpha
|
||||||
kd_temperature: Optional[float] = None # temperature for sampling during KD
|
kd_temperature: Optional[float] = None # temperature for sampling during KD
|
||||||
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
|
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
|
||||||
kd_top_k_before_softmax: Optional[
|
kd_top_k_before_softmax: Optional[
|
||||||
|
|||||||
28
src/axolotl/integrations/kd/callbacks.py
Normal file
28
src/axolotl/integrations/kd/callbacks.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
from transformers import TrainerCallback
|
||||||
|
|
||||||
|
|
||||||
|
class KDAlphaSchedulerCallback(TrainerCallback):
|
||||||
|
"""Callback to for scheduling KD alpha during training."""
|
||||||
|
|
||||||
|
def on_epoch_begin(
|
||||||
|
self, args, state, control, **kwargs # pylint: disable=unused-argument
|
||||||
|
):
|
||||||
|
if int(state.epoch) == 0:
|
||||||
|
state.kd_alpha = args.kd_alpha
|
||||||
|
state.kd_ce_alpha = args.kd_ce_alpha
|
||||||
|
elif int(state.epoch) == state.num_train_epochs - 1:
|
||||||
|
if args.kd_alpha_end is not None:
|
||||||
|
control.kd_alpha = args.kd_alpha_end
|
||||||
|
if args.kd_ce_alpha_end is not None:
|
||||||
|
control.kd_ce_alpha = args.kd_ce_alpha_end
|
||||||
|
else:
|
||||||
|
epoch_steps = state.num_train_epochs - 1
|
||||||
|
scale = int(state.epoch) / epoch_steps
|
||||||
|
if args.kd_alpha_end is not None:
|
||||||
|
control.kd_alpha = (
|
||||||
|
args.kd_alpha + (args.kd_alpha_end - args.kd_alpha) * scale
|
||||||
|
)
|
||||||
|
if args.kd_ce_alpha_end is not None:
|
||||||
|
control.kd_ce_alpha = (
|
||||||
|
args.kd_ce_alpha + (args.kd_ce_alpha_end - args.kd_ce_alpha) * scale
|
||||||
|
)
|
||||||
@@ -62,10 +62,16 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
Transform logprobs to target format for KD training
|
Transform logprobs to target format for KD training
|
||||||
"""
|
"""
|
||||||
|
|
||||||
logprobs = sample.pop(self.logprobs_field)
|
if "target_logprobs" in sample.keys() and "target_token_ids" in sample.keys():
|
||||||
|
logprobs = sample.pop("target_logprobs")
|
||||||
|
token_ids = sample.pop("target_token_ids")
|
||||||
|
else:
|
||||||
|
logprobs = sample.pop(self.logprobs_field)
|
||||||
|
token_ids = [None] * len(logprobs)
|
||||||
|
|
||||||
target_seq_len = len(logprobs)
|
target_seq_len = len(logprobs)
|
||||||
input_seq_len = len(sample["input_ids"])
|
input_seq_len = len(sample["input_ids"])
|
||||||
input_padding_len = input_seq_len - target_seq_len
|
target_padding_len = input_seq_len - target_seq_len
|
||||||
# get non-zero top-k (prune None logprobs from vllm data step)
|
# get non-zero top-k (prune None logprobs from vllm data step)
|
||||||
top_k_vals = [
|
top_k_vals = [
|
||||||
len(logprobs[i])
|
len(logprobs[i])
|
||||||
@@ -82,11 +88,11 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
target_token_ids = []
|
target_token_ids = []
|
||||||
target_mask = []
|
target_mask = []
|
||||||
|
|
||||||
if input_padding_len < 0:
|
if target_padding_len < 0:
|
||||||
# logprobs is longer than target_seq_len,
|
# logprobs is longer than target_seq_len,
|
||||||
# so we need to slice from the left/beginning of logprobs
|
# so we need to slice from the left/beginning of logprobs
|
||||||
logprobs = logprobs[:-input_seq_len]
|
logprobs = logprobs[:-input_seq_len]
|
||||||
input_padding_len = 0
|
target_padding_len = 0
|
||||||
# target_seq_len = input_seq_len
|
# target_seq_len = input_seq_len
|
||||||
|
|
||||||
# truncate the second dimension of the logprobs to top_k
|
# truncate the second dimension of the logprobs to top_k
|
||||||
@@ -98,33 +104,37 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
|
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
|
||||||
# otherwise, we need to shift in the trainer
|
# otherwise, we need to shift in the trainer
|
||||||
shift = 0
|
shift = 0
|
||||||
for _ in range(shift, input_padding_len):
|
for _ in range(shift, target_padding_len):
|
||||||
target_logprobs.append([-float("inf")] * top_k)
|
target_logprobs.append([-float("inf")] * top_k)
|
||||||
target_token_ids.append(list(range(top_k)))
|
target_token_ids.append(list(range(top_k)))
|
||||||
target_mask.append([0] * top_k)
|
target_mask.append([0] * top_k)
|
||||||
|
|
||||||
for position in range(input_padding_len, input_seq_len):
|
for position in range(target_padding_len, input_seq_len):
|
||||||
if sample["labels"][position] == -100:
|
if sample["labels"][position] == -100:
|
||||||
target_mask.append([0] * top_k)
|
target_mask.append([0] * top_k)
|
||||||
else:
|
else:
|
||||||
target_mask.append([1] * top_k)
|
target_mask.append([1] * top_k)
|
||||||
|
|
||||||
for _, token_pos_logprobs in enumerate(logprobs):
|
for token_pos_logprobs, token_pos_token_ids in zip(logprobs, token_ids):
|
||||||
# Initialize collections for logprobs and token_ids
|
# Initialize collections for logprobs and token_ids
|
||||||
position_logprobs = []
|
position_logprobs = []
|
||||||
position_token_ids = []
|
position_token_ids = []
|
||||||
|
|
||||||
# Process each token probability entry
|
# Process each token probability entry
|
||||||
for entry in token_pos_logprobs:
|
if token_pos_token_ids is None:
|
||||||
# Extract logprob value
|
for entry in token_pos_logprobs:
|
||||||
logprob = entry["logprob"]
|
# Extract logprob value
|
||||||
|
logprob = entry["logprob"]
|
||||||
|
|
||||||
# Parse token_id from the "token_id:###" format
|
# Parse token_id from the "token_id:###" format
|
||||||
token_id = int(entry["token"].split(":")[1])
|
token_id = int(entry["token"].split(":")[1])
|
||||||
|
|
||||||
# Append to our collections
|
# Append to our collections
|
||||||
position_logprobs.append(logprob)
|
position_logprobs.append(logprob)
|
||||||
position_token_ids.append(token_id)
|
position_token_ids.append(token_id)
|
||||||
|
else:
|
||||||
|
position_logprobs = token_pos_logprobs
|
||||||
|
position_token_ids = token_pos_token_ids
|
||||||
|
|
||||||
# Convert to a tensor for easier manipulation
|
# Convert to a tensor for easier manipulation
|
||||||
position_logprobs_tensor = torch.tensor(
|
position_logprobs_tensor = torch.tensor(
|
||||||
@@ -143,6 +153,7 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
teacher_probs_t2 = teacher_probs_t1**exponent
|
teacher_probs_t2 = teacher_probs_t1**exponent
|
||||||
else:
|
else:
|
||||||
teacher_probs_t2 = teacher_probs_t1
|
teacher_probs_t2 = teacher_probs_t1
|
||||||
|
|
||||||
# Re-normalize
|
# Re-normalize
|
||||||
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
|
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
|
||||||
dim=0, keepdim=True
|
dim=0, keepdim=True
|
||||||
|
|||||||
@@ -16,17 +16,35 @@
|
|||||||
KD trainer
|
KD trainer
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from transformers import TrainerControl
|
||||||
|
|
||||||
from axolotl.core.trainers.base import AxolotlTrainer
|
from axolotl.core.trainers.base import AxolotlTrainer
|
||||||
|
|
||||||
from .topk_logprob.forward_kl import loss as topk_kd_loss
|
from .topk_logprob.forward_kl import loss as topk_kd_loss
|
||||||
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
|
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlKDTrainerControl(TrainerControl):
|
||||||
|
kd_alpha: float = 1.0
|
||||||
|
kd_ce_alpha: float = 0.0
|
||||||
|
|
||||||
|
def state(self) -> dict:
|
||||||
|
state_val = super().state()
|
||||||
|
state_val["args"]["kd_alpha"] = self.kd_alpha
|
||||||
|
state_val["args"]["kd_ce_alpha"] = self.kd_ce_alpha
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKDTrainer(AxolotlTrainer):
|
class AxolotlKDTrainer(AxolotlTrainer):
|
||||||
"""
|
"""
|
||||||
Custom trainer subclass for Knowledge Distillation (KD)
|
Custom trainer subclass for Knowledge Distillation (KD)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.kd_alpha = self.args.kd_alpha
|
||||||
|
self.kd_ce_alpha = self.args.kd_ce_alpha
|
||||||
|
self.control = AxolotlKDTrainerControl()
|
||||||
|
|
||||||
def _set_signature_columns_if_needed(self):
|
def _set_signature_columns_if_needed(self):
|
||||||
super()._set_signature_columns_if_needed()
|
super()._set_signature_columns_if_needed()
|
||||||
columns_to_add = []
|
columns_to_add = []
|
||||||
@@ -95,9 +113,8 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
|
top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.args.kd_ce_alpha > 0:
|
if self.kd_ce_alpha > 0:
|
||||||
kd_alpha = self.args.kd_alpha
|
loss = self.kd_ce_alpha * outputs["loss"] + self.kd_alpha * loss_kd
|
||||||
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
|
|
||||||
else:
|
else:
|
||||||
loss = loss_kd
|
loss = loss_kd
|
||||||
# Save past state if it exists
|
# Save past state if it exists
|
||||||
|
|||||||
@@ -813,6 +813,15 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|||||||
)
|
)
|
||||||
except (FileNotFoundError, ConnectionError) as err:
|
except (FileNotFoundError, ConnectionError) as err:
|
||||||
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
||||||
|
# TODO if using deepspeed and it's a file, save deepspeed config too
|
||||||
|
if args.deepspeed and os.path.isfile(args.deepspeed):
|
||||||
|
LOG.info(f"DeepSpeed config has been saved to the WandB run.")
|
||||||
|
artifact = wandb.Artifact(
|
||||||
|
f"deepspeed-{wandb.run.id}", type="deepspeed-config"
|
||||||
|
)
|
||||||
|
artifact.add_file(args.deepspeed)
|
||||||
|
wandb.log_artifact(artifact)
|
||||||
|
wandb.save(args.deepspeed)
|
||||||
return control
|
return control
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -173,10 +173,16 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
]
|
]
|
||||||
out_features[i][feature] = np.concatenate(arrays)
|
out_features[i][feature] = np.concatenate(arrays)
|
||||||
else:
|
else:
|
||||||
arrays = [
|
try:
|
||||||
np.array(item[feature]) for item in features_ if feature in item
|
arrays = [
|
||||||
]
|
np.array(item[feature])
|
||||||
out_features[i][feature] = np.concatenate(arrays)
|
for item in features_
|
||||||
|
if feature in item
|
||||||
|
]
|
||||||
|
if arrays[0].dtype != "object":
|
||||||
|
out_features[i][feature] = np.concatenate(arrays)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
return super().__call__(out_features, return_tensors=return_tensors)
|
return super().__call__(out_features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
"""Module with Pydantic models for configuration."""
|
"""Module with Pydantic models for configuration."""
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -1679,30 +1678,6 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_rl_config_gradient_checkpointing(cls, data):
|
|
||||||
# TODO: SalmanMohammadi
|
|
||||||
# Distributed RL with QLoRA + gradient checkpointing
|
|
||||||
# and use_reentrant = True is broken upstream in TRL
|
|
||||||
# pylint: disable=too-many-boolean-expressions
|
|
||||||
if (
|
|
||||||
data.get("rl")
|
|
||||||
and data.get("gradient_checkpointing")
|
|
||||||
and data.get("gradient_checkpointing_kwargs")
|
|
||||||
and data.get("gradient_checkpointing_kwargs").get("use_reentrant")
|
|
||||||
and data.get("load_in_4bit")
|
|
||||||
and data.get("adapter") == "qlora"
|
|
||||||
and data.get("capabilities")
|
|
||||||
and data.get("capabilities").get("n_gpu", 1) > 1
|
|
||||||
):
|
|
||||||
raise ValueError(
|
|
||||||
"The `use_reentrant: True` implementation of gradient checkpointing "
|
|
||||||
"is not supported for distributed RL training with QLoRA. Please set "
|
|
||||||
"`use_reentrant: False` in `gradient_checkpointing_kwargs`."
|
|
||||||
)
|
|
||||||
return data
|
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_kto_config(cls, data):
|
def check_kto_config(cls, data):
|
||||||
@@ -1713,6 +1688,15 @@ class AxolotlInputConfig(
|
|||||||
if data.get("remove_unused_columns") is not False:
|
if data.get("remove_unused_columns") is not False:
|
||||||
raise ValueError("Set `remove_unused_columns: False` when using kto")
|
raise ValueError("Set `remove_unused_columns: False` when using kto")
|
||||||
|
|
||||||
|
if data.get("gradient_checkpointing") and not (
|
||||||
|
data.get("gradient_checkpointing_kwargs")
|
||||||
|
and isinstance(data.get("gradient_checkpointing_kwargs"), dict)
|
||||||
|
and data["gradient_checkpointing_kwargs"].get("use_reentrant")
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Set `gradient_checkpointing_kwargs: {use_reentrant: true}` for when kto is enabled"
|
||||||
|
)
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
@@ -1843,14 +1827,6 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
data["torch_compile"] = False
|
data["torch_compile"] = False
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@model_validator(mode="before")
|
|
||||||
@classmethod
|
|
||||||
def check_beta_and_trl_beta_match(cls, data):
|
|
||||||
if data.get("beta") and data.get("trl", {}).get("beta"):
|
|
||||||
if data["beta"] != data["trl"]["beta"]:
|
|
||||||
raise ValueError("beta and trl.beta must match or one must be removed")
|
|
||||||
return data
|
|
||||||
|
|
||||||
|
|
||||||
def handle_legacy_message_fields_logic(data: dict) -> dict:
|
def handle_legacy_message_fields_logic(data: dict) -> dict:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -25,8 +25,8 @@ def fixture_cfg():
|
|||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"sequence_len": 2048,
|
"sequence_len": 2048,
|
||||||
"rl": True,
|
"rl": True,
|
||||||
"adam_beta1": 0.998,
|
"adam_beta1": 0.91,
|
||||||
"adam_beta2": 0.9,
|
"adam_beta2": 0.998,
|
||||||
"adam_epsilon": 0.00001,
|
"adam_epsilon": 0.00001,
|
||||||
"dataloader_num_workers": 1,
|
"dataloader_num_workers": 1,
|
||||||
"dataloader_pin_memory": True,
|
"dataloader_pin_memory": True,
|
||||||
@@ -60,8 +60,8 @@ class TestHFRLTrainerBuilder:
|
|||||||
def test_build_training_arguments(self, cfg, model, tokenizer):
|
def test_build_training_arguments(self, cfg, model, tokenizer):
|
||||||
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
|
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
|
||||||
training_arguments = builder.build_training_arguments(100)
|
training_arguments = builder.build_training_arguments(100)
|
||||||
assert training_arguments.adam_beta1 == 0.998
|
assert training_arguments.adam_beta1 == 0.91
|
||||||
assert training_arguments.adam_beta2 == 0.9
|
assert training_arguments.adam_beta2 == 0.998
|
||||||
assert training_arguments.adam_epsilon == 0.00001
|
assert training_arguments.adam_epsilon == 0.00001
|
||||||
assert training_arguments.dataloader_num_workers == 1
|
assert training_arguments.dataloader_num_workers == 1
|
||||||
assert training_arguments.dataloader_pin_memory is True
|
assert training_arguments.dataloader_pin_memory is True
|
||||||
|
|||||||
Reference in New Issue
Block a user