Compare commits

...

17 Commits

Author SHA1 Message Date
Wing Lian
f24efd77a1 lint docs 2025-02-12 10:04:01 -05:00
Sung Ching Liu
44f64ab627 Update faq.qmd (#2319)
* Update faq.qmd

Added Q&A for being stuck on saving preprocessed datasets

* Update faq.qmd

added details on preprocessing on cpu

* Update faq.qmd

* Update faq.qmd
2025-02-11 13:18:31 -05:00
NanoCode012
826f1b1494 feat(doc): Add multi-node torchrun info (#2304) 2025-02-08 06:02:02 -05:00
NanoCode012
526e5ee8b8 fix(config): missing config not being documented and fix model_ override (#2317)
* fix(config): missing config not being documented and fix model_ space override

* fix: delete redundant field
2025-02-08 06:01:48 -05:00
NanoCode012
fd8cb32547 chore: remove redundant py310 from tests (#2316) 2025-02-07 21:34:16 -05:00
NanoCode012
e48e2df4dd feat: update FA to 2.7.4.post1 which includes torch2.6 binary (#2315) 2025-02-07 21:34:01 -05:00
Wing Lian
b7616022ab bump transformers to 4.48.3 (#2318) 2025-02-07 21:33:44 -05:00
Wing Lian
1faf1a5c5a batch add of spectrum snr results (#2320) 2025-02-07 21:33:14 -05:00
NanoCode012
5bbad5ef93 feat: add torch2.6 to ci (#2311) 2025-02-07 07:28:54 -05:00
Wing Lian
a971eb4ce6 Torch 2.6 support for base docker image (#2312) 2025-02-05 09:24:02 -05:00
NanoCode012
a620d481e2 fix: drop long seq even if not sample packing (#2211)
* fix: drop long seq even if not sample packing

* fix: logging import

* fix: cfg passed being none

* fix: try to fix logging

* fix: refactor call to not use accelerate log

* fix: try to fix circular import issue

* fix: don't drop when skip prepare

* chore: remove duplicate line

* fix: update warning to mention that sequences will be trimmed

* fix: do not drop seq if input_ids don't exist

* fix: increase RM unittest sequence length to reduce trim warnings

* fix: solve conflicts

* fix: default min_seq_len in case of None
2025-02-04 09:43:35 -05:00
Wing Lian
158330ab60 [feature] sweeps (#2171) 2025-02-01 21:11:18 -05:00
Wing Lian
80e1468b8d better handling of multipack dataset length (#2296) 2025-02-01 21:10:34 -05:00
Wing Lian
a20f17689b set MODAL_IMAGE_BUILDER_VERSION=2024.10 to 2024.10 to test latest builder (#2302)
* set MODAL_IMAGE_BUILDER_VERSION=2024.10 to 2024.10 to test latest builder

* chore: lint

* remove fastapi and pydantic extras
2025-01-31 20:19:20 -05:00
Wing Lian
78ce268848 KD Trainer w logprobs (#2303)
* refactor trainer to prevent circular dependencies later

fix loader default
KD dataset loading and KD with logprobs
filter bad rows
make batch smaller
handle padding/collation for KD datasets
make it work
flipped the slice
cross entropy loss coefficient during KD
make sure to multiply against the correct loss
chore: lint
triton wip
no where support
v2 trial
no torch.exp inside triton kernel
no log etc
no torch.tensor
v3
fix kwarg
don't use triton for now
better rescaling for temperatures
hash for temperature too
use kd_alpha in the correct loss method
fix kd loss so it's causal (fixes repeating tokens)
var naming and add todo
chore: lint
refactor so we can easily add new loss functions
add license block
remove references to triton kd for now
handle token/logprob shifting
support for custom trainer classes from plugins
refactor kd chat template loader
move more things to kd plugin
remove moved class from import
make plugin setup concise
increase logging around loading plugins
add copyrights
remove duplicate code
more info on preprocess for kd and fix import
be a bit pickier about loading dynamic prompt strategies
kd sample packing
make loss torch script compat
support streaming for processing sft datasts?
improve iterable support
ensure that batch vs single is done properly
tweak check for batched prompt data
reward can use same batch check
fix reward trainer calls for tokenization
improve check for batched
reward model doesn't work well with batched
add kd trainer e2e test
linting
rename test files so it gets picked up
make the kd e2e fit in vram for ci and add lora version
set lora_dropout explicitly
lower lr
make sure to set tokenizer from l3 70b and save safetensors
make sure to use the correct tokenizer
fix adapter model check
make sure to use tensorboard to capture loss for checks
chore: lint
chore: lint
improve logprob masking and shift in trainer
more fixes
try tests for kd on l40s
don't shift student logits for kd
no batching for kd chat templates
make sure to truncate logprobs if there are more than top_k
change up logic so we always truncate to top_k
use iter instead of tuple
fix finding the top-k rather than assuming first position has the correct val
apply z-score scaling to kd
kd loss needs to be calculated in full precision
Always re-normalize teacher distribution
various fixes

* support for configurable top-k/softmax ordering

* add attribute check for filter rows and lint

* fix logic

* handle none case for conversion to int

* fix student logit off by one

* set kd_temp to 1.0 for test loss

* address PR feedback
2025-01-31 20:18:52 -05:00
NanoCode012
d425d5d3c3 fix: add warning for invalid eval_steps or save_steps (#2298) 2025-01-31 08:58:25 -05:00
Wing Lian
cf17649ef3 Misc fixes 20250130 (#2301)
* misc fixes for garbage collection and L40S w NCCL P2P

* patch bnb fix for triton check

* chore: lint

* change up import

* try patching differently

* remove patch for bnb fix for now

* more verbose checks and tweak train loss threshold
2025-01-31 08:58:04 -05:00
72 changed files with 14378 additions and 1497 deletions

View File

@@ -22,12 +22,6 @@ jobs:
fail-fast: false
matrix:
include:
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.10"
pytorch: 2.4.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
@@ -40,6 +34,12 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "124"
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -19,7 +19,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
python-version: '3.11'
- name: install dependencies
run: |
python3 -m pip install jupyter

View File

@@ -19,6 +19,6 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.1

View File

@@ -26,6 +26,11 @@ jobs:
pytorch: 2.5.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -34,6 +34,13 @@ jobs:
axolotl_extras:
num_gpus: 2
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
num_gpus: 2
nightly_build: "true"
runs-on: [self-hosted, modal]
timeout-minutes: 120
steps:
@@ -42,7 +49,7 @@ jobs:
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"
- name: Install Modal
run: |
python -m pip install --upgrade pip

View File

@@ -22,6 +22,11 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -36,7 +36,7 @@ jobs:
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"
- name: Install dependencies
run: |

View File

@@ -12,7 +12,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.1
env:
@@ -25,13 +25,8 @@ jobs:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.4.1", "2.5.1"]
exclude:
- python_version: "3.10"
pytorch_version: "2.4.1"
- python_version: "3.10"
pytorch_version: "2.5.1"
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
timeout-minutes: 20
steps:
@@ -112,13 +107,20 @@ jobs:
num_gpus: 1
axolotl_extras:
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
nightly_build: "true"
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"
- name: Install Modal
run: |
python -m pip install --upgrade pip

View File

@@ -35,7 +35,7 @@ jobs:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"
cache: 'pip' # caching pip dependencies
- uses: pre-commit/action@v3.0.1
env:
@@ -48,13 +48,8 @@ jobs:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.10", "3.11"]
pytorch_version: ["2.4.1", "2.5.1"]
exclude:
- python_version: "3.10"
pytorch_version: "2.4.1"
- python_version: "3.10"
pytorch_version: "2.5.1"
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
timeout-minutes: 20
steps:
@@ -127,7 +122,7 @@ jobs:
max-parallel: 1
matrix:
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1"]
pytorch_version: ["2.4.1", "2.5.1", "2.6.0"]
timeout-minutes: 20
steps:
@@ -207,7 +202,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
pytorch: 2.5.1
num_gpus: 1
axolotl_extras:
steps:
@@ -216,7 +211,7 @@ jobs:
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"
- name: Install Modal
run: |
python -m pip install --upgrade pip
@@ -228,6 +223,7 @@ jobs:
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
@@ -247,7 +243,13 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
num_gpus: 1
axolotl_extras:
steps:
@@ -256,7 +258,7 @@ jobs:
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
python-version: "3.11"
- name: Install Modal
run: |
python -m pip install --upgrade pip
@@ -268,6 +270,7 @@ jobs:
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "MODAL_IMAGE_BUILDER_VERSION=2024.10" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |

View File

@@ -51,7 +51,7 @@ Features:
**Requirements**:
- NVIDIA GPU (Ampere or newer for `bf16` and Flash Attention) or AMD GPU
- Python 3.10
- Python 3.11
- PyTorch ≥2.4.1
### Installation

View File

@@ -38,16 +38,12 @@ temp_dir = tempfile.mkdtemp()
with open(pathlib.Path(temp_dir) / "Dockerfile", "w", encoding="utf-8") as f:
f.write(dockerfile_contents)
cicd_image = (
Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
context_mount=None,
force_build=True,
gpu="A10G",
)
.env(df_args)
.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
)
cicd_image = Image.from_dockerfile(
pathlib.Path(temp_dir) / "Dockerfile",
context_mount=None,
force_build=True,
gpu="A10G",
).env(df_args)
app = App("Axolotl CI/CD", secrets=[])
@@ -59,7 +55,7 @@ VOLUME_CONFIG = {
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.A10G(count=N_GPUS)
GPU_CONFIG = modal.gpu.L40S(count=N_GPUS)
def run_cmd(cmd: str, run_folder: str):

View File

@@ -46,6 +46,10 @@ overrides_of_model_config:
type: # linear | dynamic
factor: # float
# optional overrides the base model loading from_pretrained
overrides_of_model_kwargs:
# use_cache: False
# optional overrides to the bnb 4bit quantization configuration
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
bnb_config_kwargs:

View File

@@ -23,4 +23,4 @@ Here's a simple example of a stepwise supervised dataset entry:
],
"labels": [true, false]
}
```
```

View File

@@ -19,3 +19,7 @@ description: Frequently asked questions
**Q: AttributeError: 'DummyOptim' object has no attribute 'step'**
> A: You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.
**Q: The codes is stuck on saving preprocessed datasets.**
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.

View File

@@ -3,6 +3,18 @@ title: Multi Node
description: How to use Axolotl on multiple machines
---
The below are three ways to train multi-node in Axolotl.
::: {.callout-important}
Each machine needs a copy of Axolotl, we suggest using the same commit to ensure compatibility.
You will also need to have the same configuration file for your model on each machine.
Make sure the main machine is reachable by other machines.
:::
# Accelerate
You will need to create a configuration for accelerate, either by using `accelerate config` and follow the instructions or you can use one of the preset below:
~/.cache/huggingface/accelerate/default_config.yaml
@@ -26,7 +38,7 @@ tpu_use_sudo: false
use_cpu: false
```
Configure your model to use FSDP with for example:
Configure your model to use FSDP in the Axolotl yaml. For example:
```yaml
fsdp:
- full_shard
@@ -37,12 +49,40 @@ fsdp_config:
fsdp_transformer_layer_cls_to_wrap: LlamaDecoderLayer
```
## Machine configuration
On each machine you need a copy of Axolotl, we suggest using the same commit to ensure compatibility.
You will also need to have the same configuration file for your model on each machine.
On the main machine only, make sure the port you set as `main_process_port` is open in TCP and reachable by other machines.
All you have to do now is launch using accelerate as you would usually do on each machine and voila, the processes will start once you have launched accelerate on every machine.
# Raytrain
Please see ray train doc [here](ray-integration.qmd).
# Torchrun
If you are using Infiniband, we recommend torchrun to utilize the full bandwidth.
Set the following env (change buffersize/socketname depending on your system):
```yaml
export NCCL_IB_DISABLE=0
export NCCL_SOCKET_IFNAME="eth0,en,eth,em,bond"
export NCCL_BUFFSIZE=2097152
```
Run the following on each node:
```bash
torchrun --nnodes $num_nodes --nproc_per_node $gpu_per_node --rdzv_id $rdzv_id --rdzv_backend c10d --rdzv_endpoint "$head_node_ip:$head_node_port" -m axolotl.cli.train config.yaml
```
Please make sure to substitute the placeholder variables.
- `num_nodes`: Number of nodes (containing GPUs)
- `gpu_per_node`: Number of gpus per node
- `head_node_ip`: IP of the head node (make sure other machines can connect to this)
- `head_node_port`: Port of the head node (make sure other machines can connect to this. Default 29400)
- `rdzv_id`: A unique job ID that is used by the job across nodes.
::: {.callout-note}
You need to call `axolotl.cli.train` instead of `axolotl train` as the latter calls accelerate under the hood
:::
More info on the available configs can be found on the Pytorch docs [here](https://pytorch.org/docs/stable/elastic/run.html)

View File

@@ -29,7 +29,7 @@ datasets:
type: chatml.intel
- path: argilla/ultrafeedback-binarized-preferences
split: train
type: chatml.argilla
type: chatml
```
#### IPO

View File

@@ -1,10 +1,10 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.1
bitsandbytes==0.45.2
triton>=3.0.0
mamba-ssm==1.2.0.post1
flash-attn==2.7.0.post2
flash-attn==2.7.4.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.2
@@ -13,7 +13,7 @@ liger-kernel==0.5.2
packaging==23.2
peft==0.14.0
transformers==4.48.1
transformers==4.48.3
tokenizers>=0.21.0
accelerate==1.3.0
datasets==3.2.0

View File

@@ -71,12 +71,15 @@ def parse_requirements():
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 5):
if (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers==0.0.29.post2")
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
_install_requires.append("xformers==0.0.28.post2")
else:
_install_requires.append("xformers==0.0.28.post3")
_install_requires.append("xformers==0.0.29")
_install_requires.pop(_install_requires.index(autoawq_version))
elif (major, minor) >= (2, 4):
if patch == 0:

View File

@@ -13,6 +13,12 @@ class PreprocessCliArgs:
debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True)
iterable: Optional[bool] = field(
default=None,
metadata={
"help": "Use IterableDataset for streaming processing of large datasets"
},
)
@dataclass

View File

@@ -1,10 +1,17 @@
"""Click CLI definitions for various axolotl commands."""
# pylint: disable=redefined-outer-name
import logging
import random
import subprocess # nosec B404
import tempfile
from copy import deepcopy
from itertools import product
from pathlib import Path
from typing import Optional
import click
import yaml
import axolotl
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
@@ -20,6 +27,76 @@ from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
def generate_sweep_configs(base_config, sweeps_config):
"""
Recursively generates all possible configurations by applying sweeps to the base config.
Args:
base_config (dict): The original configuration dictionary
sweeps_config (dict): Dictionary where keys are parameters and values are either:
- lists of values to sweep independently
- or for paired values, a list of dicts under the '_' key
Returns:
list: List of all possible configuration dictionaries
Example:
sweeps_config = {
'learning_rate': [0.1, 0.01],
'_': [
{'load_in_8bit': True, 'adapter': 'lora'},
{'load_in_4bit': True, 'adapter': 'qlora'}
]
}
"""
# Separate paired values from regular sweeps
paired_values = sweeps_config.get("_", [])
regular_sweeps = {k: v for k, v in sweeps_config.items() if k != "_"}
# Process regular sweeps
param_names = list(regular_sweeps.keys())
param_values = list(regular_sweeps.values())
# Generate combinations for regular sweeps
regular_combinations = list(product(*param_values)) if param_values else [()]
# Combine regular sweeps with paired values
all_combinations = []
for reg_combo in regular_combinations:
if paired_values:
for paired_set in paired_values:
new_config = {}
# new_config = deepcopy(base_config)
# Combine regular parameters with paired parameters
full_combo = {**dict(zip(param_names, reg_combo)), **paired_set}
for param_name, param_value in full_combo.items():
new_config[param_name] = param_value
print(new_config)
all_combinations.append(new_config)
else:
# If no paired values, just use regular combinations
# new_config = deepcopy(base_config)
new_config = {}
for param_name, param_value in zip(param_names, reg_combo):
new_config[param_name] = param_value
print(new_config)
all_combinations.append(new_config)
# randomize the order of trials
random.seed(42)
random.shuffle(all_combinations)
# Generate a new config for each combination
result_configs = []
for combination in all_combinations:
new_config = deepcopy(base_config)
for param_name, param_value in combination.items():
new_config[param_name] = param_value
result_configs.append(new_config)
return result_configs
@click.group()
@click.version_option(version=axolotl.__version__, prog_name="axolotl")
def cli():
@@ -60,10 +137,21 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs) -> None:
help="Use accelerate launch for multi-GPU training",
)
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
@click.option(
"--sweep",
type=click.Path(exists=True, path_type=str),
help="YAML config for sweeping hyperparameters",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def train(config: str, accelerate: bool, cloud: Optional[str] = None, **kwargs) -> None:
def train(
config: str,
accelerate: bool,
cloud: Optional[str] = None,
sweep: Optional[str] = None,
**kwargs,
) -> None:
"""
Train or fine-tune a model.
@@ -71,6 +159,7 @@ def train(config: str, accelerate: bool, cloud: Optional[str] = None, **kwargs)
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
cloud: Path to a cloud accelerator configuration file
sweep: Path to YAML config for sweeping hyperparameters.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
@@ -80,35 +169,66 @@ def train(config: str, accelerate: bool, cloud: Optional[str] = None, **kwargs)
if "use_ray" in kwargs and kwargs["use_ray"]:
accelerate = False
if sweep:
# load the sweep configuration yaml file
with open(sweep, "r", encoding="utf-8") as fin:
sweep_config: dict[str, list] = yaml.safe_load(fin)
with open(config, "r", encoding="utf-8") as fin:
base_config: dict[str, list] = yaml.safe_load(fin)
if accelerate:
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
else:
accelerate_args = []
if "main_process_port" in kwargs:
main_process_port = kwargs.pop("main_process_port", None)
accelerate_args.append("--main_process_port")
accelerate_args.append(str(main_process_port))
if "num_processes" in kwargs:
num_processes = kwargs.pop("num_processes", None)
accelerate_args.append("--num-processes")
accelerate_args.append(str(num_processes))
# generate all possible configurations
permutations = generate_sweep_configs(base_config, sweep_config)
def iter_configs():
for perm in permutations:
# open temp directory for temporary configurations
with tempfile.TemporaryDirectory() as temp_dir:
with open(
Path(temp_dir) / "config.yaml", "w", encoding="utf-8"
) as fout:
yaml.dump(perm, fout)
yield str(Path(temp_dir) / "config.yaml")
base_cmd = ["accelerate", "launch"]
base_cmd.extend(accelerate_args)
base_cmd.extend(["-m", "axolotl.cli.train"])
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
else:
from axolotl.cli.train import do_cli
do_cli(config=config, **kwargs)
def iter_configs():
yield config
for cfg_file in iter_configs():
# handle errors from subprocess so we can continue rest of sweeps
try:
if accelerate:
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
else:
accelerate_args = []
if "main_process_port" in kwargs:
main_process_port = kwargs.pop("main_process_port", None)
accelerate_args.append("--main_process_port")
accelerate_args.append(str(main_process_port))
if "num_processes" in kwargs:
num_processes = kwargs.pop("num_processes", None)
accelerate_args.append("--num-processes")
accelerate_args.append(str(num_processes))
base_cmd = ["accelerate", "launch"]
base_cmd.extend(accelerate_args)
base_cmd.extend(["-m", "axolotl.cli.train"])
if cfg_file:
base_cmd.append(cfg_file)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
if cloud:
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
else:
from axolotl.cli.train import do_cli
do_cli(config=cfg_file, **kwargs)
except subprocess.CalledProcessError as exc:
logging.error(f"Failed to train/fine-tune config '{cfg_file}': {exc}")
if not sweep:
raise exc
@cli.command()

View File

@@ -75,7 +75,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
def do_cli(
config: Union[Path, str] = Path("examples/"),
**kwargs,
) -> None:
"""
Parses `axolotl` config, CLI args, and calls `do_preprocess`.

View File

@@ -63,11 +63,17 @@ def load_datasets(
"""
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = (
hasattr(cli_args, "iterable")
and cli_args.iterable is not None
and cli_args.iterable
)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg,
tokenizer,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
if (

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,988 @@
"""
module for customized trainers
"""
from __future__ import annotations
# pylint: disable=too-many-lines
import gc
import logging
import os
from collections import defaultdict
from functools import wraps
from typing import Any, Dict, Literal, Optional, Union
import torch
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import (
CPOTrainer,
DPOTrainer,
KTOTrainer,
ORPOTrainer,
PRMTrainer,
RewardTrainer,
)
from trl.trainer.utils import pad_to_length
from axolotl.monkeypatch.relora import ReLoRAScheduler
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
get_cosine_schedule_with_warmup_decay_constant,
)
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
LOG = logging.getLogger("axolotl.core.trainer_builder")
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]
if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names
return kwargs
def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
if isinstance(dataset_tags, str):
dataset_tags = [dataset_tags]
if (dataset_tags is not None) and (kwargs is not None):
if "dataset_tags" not in kwargs:
kwargs["dataset_tags"] = dataset_tags
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
kwargs["dataset_tags"].extend(dataset_tags)
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
dataset_tags.append(kwargs["dataset_tags"])
kwargs["dataset_tags"] = dataset_tags
return kwargs
class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
use_cosine_quadratic = (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
)
use_cosine_min_lr = (
self.args.lr_scheduler_type == "cosine"
and self.args.cosine_min_lr_ratio is not None
)
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if self.args.alternate_lr_scheduler_type == "one_cycle":
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
extra_lr_kwargs = {}
if "pct_start" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["pct_start"] = pct_start
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["anneal_strategy"] = "cos"
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
**extra_lr_kwargs,
**self.args.lr_scheduler_kwargs,
)
elif use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer=optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler
class AxolotlTrainer(SchedulerMixin, Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
def __init__(
self,
*_args,
bench_data_collator=None,
eval_data_collator=None,
dataset_tags=None,
**kwargs,
):
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
model = torch.compile(
model,
backend=self.args.torch_compile_backend,
mode=self.args.torch_compile_mode,
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
decay_parameters = self.get_decay_parameter_names(opt_model)
params = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
lr_groups_lookup = {}
lr_groups_learning_rates = {}
if self.args.lr_groups:
for lr_group in self.args.lr_groups:
group_name = lr_group["name"]
group_modules = lr_group["modules"]
for module in group_modules:
lr_groups_lookup[module] = group_name
lr_groups_learning_rates[group_name] = lr_group["lr"]
params[f"to_weight_decay_{group_name}"] = {}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
lr_group_modules = [
group_modules
for group_modules in lr_groups_lookup
if group_modules in name
]
if lr_groups_lookup and any(lr_group_modules):
lr_group_module = lr_group_modules[0]
group_name = lr_groups_lookup[lr_group_module]
params[f"to_weight_decay_{group_name}"][name] = param
else:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
for group_name, group_lr in lr_groups_learning_rates.items():
if params[f"to_weight_decay_{group_name}"]:
optimizer_grouped_parameters.append(
{
"params": list(
params[f"to_weight_decay_{group_name}"].values()
),
"weight_decay": self.args.weight_decay,
"lr": group_lr,
}
)
return optimizer_grouped_parameters
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None
and self.args.lr_groups is None
and self.args.alternate_optimizer
not in [
"optimi_adamw",
"ao_adamw_8bit",
"ao_adamw_4bit",
"ao_adamw_fp8",
"adopt_adamw",
]
):
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
opt_model, optimizer_kwargs
)
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", 1e-6
)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
elif (
self.args.embedding_lr_scale is not None
or self.args.embedding_lr is not None
or self.args.lr_groups is not None
):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW(
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
)
)
elif self.args.alternate_optimizer == "ao_adamw_4bit":
from torchao.prototype.low_bit_optim import AdamW4bit
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
ADOPT(
optimizer_grouped_parameters,
decouple=True,
**optimizer_kwargs,
)
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining:
if self.args.multipack_real_batches:
batch_size = self.args.per_device_train_batch_size
batch_max_len = self.args.max_seq_length
else:
batch_size = 1
train_batch_size = (
self.state.train_batch_size or self.args.per_device_train_batch_size
)
batch_max_len = train_batch_size * self.args.max_seq_length
if self.args.curriculum_sampling:
sampler = SequentialSampler(self.train_dataset)
else:
sampler = RandomSampler(self.train_dataset)
return MultipackBatchSampler(
sampler,
lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
group_size=self.args.sample_packing_group_size,
bin_size=self.args.sample_packing_bin_size,
drop_last=True,
)
if self.args.curriculum_sampling:
return SequentialSampler(self.train_dataset)
return super()._get_train_sampler()
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
if self.args.multipack_real_batches:
batch_size = self.args.per_device_eval_batch_size
batch_max_len = self.args.max_seq_length
else:
batch_size = 1
batch_max_len = (
self.args.per_device_eval_batch_size * self.args.max_seq_length
)
return MultipackBatchSampler(
SequentialSampler(eval_dataset),
lengths=get_dataset_lengths(self.eval_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
group_size=self.args.sample_packing_group_size,
bin_size=self.args.sample_packing_bin_size,
drop_last=True,
)
return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> DataLoader:
if self.args.sample_packing and not self.args.pretraining:
train_dataset = self.train_dataset
if "length" in train_dataset.features.keys():
train_dataset = train_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
sampler = self._get_train_sampler()
if isinstance(sampler, BatchSampler):
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(train_dataset, **dataloader_params)
)
return super().get_train_dataloader()
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
if self.args.sample_packing and self.args.eval_sample_packing is False:
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
)
if eval_dataset:
eval_dataset = eval_dataset.remove_columns(["length"])
dataloader = super().get_eval_dataloader(eval_dataset)
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.train_data_collator
)
return dataloader
if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset)
eval_dataset = eval_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
if isinstance(eval_sampler, BatchSampler):
dataloader_params["batch_sampler"] = eval_sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = eval_sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(eval_dataset, **dataloader_params)
)
return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
def get_bench_dataloader(
self,
bench_dataset: Dataset,
) -> DataLoader:
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": self.bench_data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
):
# use one's weighted cross entropy loss calc
# if self.args.sample_packing:
# labels = inputs.pop("labels")
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
if self.args.orpo_alpha:
return self.orpo_compute_loss(
model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
return super().compute_loss(
model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
@staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
concatenated_batch = {}
max_length = max(
inputs["input_ids"].shape[1], inputs["rejected_input_ids"].shape[1]
)
# Concatenate positive and negative inputs
concatenated_batch["input_ids"] = pad_to_length(
inputs["input_ids"], max_length, pad_token
)
concatenated_batch["rejected_input_ids"] = pad_to_length(
inputs["rejected_input_ids"], max_length, pad_token
)
concatenated_batch["labels"] = pad_to_length(
inputs["labels"], max_length, label_pad_token
)
concatenated_batch["rejected_labels"] = pad_to_length(
inputs["rejected_labels"], max_length, label_pad_token
)
concatenated_batch["attention_mask"] = pad_to_length(
inputs["attention_mask"], max_length, 0
)
concatenated_batch["rejected_attention_mask"] = pad_to_length(
inputs["rejected_attention_mask"], max_length, 0
)
concatenated_batch["prompt_attention_mask"] = pad_to_length(
inputs["prompt_attention_mask"], max_length, 0
).to(device=device)
input_ids = torch.cat(
[concatenated_batch["input_ids"], concatenated_batch["rejected_input_ids"]],
dim=0,
).to(device=device)
attention_mask = torch.cat(
[
concatenated_batch["attention_mask"],
concatenated_batch["rejected_attention_mask"],
],
dim=0,
).to(device=device)
labels = torch.cat(
[concatenated_batch["labels"], concatenated_batch["rejected_labels"]], dim=0
).to(device=device)
return {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"prompt_attention_mask": concatenated_batch["prompt_attention_mask"],
}
def orpo_compute_custom_loss(self, logits, labels):
logits = logits.contiguous()
loss = 0.0
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean(
dim=-1
)
return loss
def orpo_compute_logps(
self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits
):
# Get the shape of chosen_attention_mask[:, :-1]
chosen_shape = chosen_attention_mask[:, :-1].shape
# Calculate the padding size
pad_length = chosen_shape[1] - (prompt_attention_mask.shape[1] - 1)
# Pad prompt_attention_mask with zeros to match the desired shape
prompt_attention_mask_padded = torch.nn.functional.pad(
prompt_attention_mask[:, 1:], (0, pad_length), mode="constant", value=0
)
# Perform the subtraction operation
mask = chosen_attention_mask[:, :-1] > prompt_attention_mask_padded
per_token_logps = torch.gather(
logits[:, :-1, :].log_softmax(-1),
dim=2,
index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
).squeeze(2)
return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)
def orpo_compute_loss(
self,
model,
inputs,
return_outputs=False,
num_items_in_batch=None, # pylint: disable=unused-argument
):
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
inputs,
label_pad_token=-100,
pad_token=self.tokenizer.pad_token_id,
device=self.accelerator.device,
)
# Perform a single forward pass
outputs = model(
**{
"input_ids": concat_inputs["input_ids"],
"attention_mask": concat_inputs["attention_mask"],
"labels": concat_inputs["labels"],
},
output_hidden_states=True,
)
# Split the outputs for positive and negative examples
outputs_pos, outputs_neg = outputs.logits.chunk(2)
# Calculate NLL loss
pos_loss = self.orpo_compute_custom_loss(
logits=outputs_pos, labels=concat_inputs["input_ids"].chunk(2)[0]
)
# Calculate Log Probability
pos_prob = self.orpo_compute_logps(
prompt_attention_mask=concat_inputs["prompt_attention_mask"],
chosen_inputs=concat_inputs["input_ids"].chunk(2)[0],
chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[0],
logits=outputs_pos,
)
neg_prob = self.orpo_compute_logps(
prompt_attention_mask=concat_inputs["prompt_attention_mask"],
chosen_inputs=concat_inputs["input_ids"].chunk(2)[1],
chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[1],
logits=outputs_neg,
)
# Calculate log odds
log_odds = (pos_prob - neg_prob) - (
torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob))
)
sig_ratio = torch.nn.functional.sigmoid(log_odds)
ratio = torch.log(sig_ratio)
# Calculate the Final Loss
loss = torch.mean(pos_loss - self.args.orpo_alpha * ratio).to(
dtype=torch.bfloat16
)
metrics = {}
metrics["chosen_geometric_mean"] = torch.mean(pos_prob).cpu().item()
metrics["rejected_geometric_mean"] = torch.mean(neg_prob).cpu().item()
metrics["log_odds_ratio"] = torch.mean(ratio).cpu().item()
metrics["log_odds"] = torch.mean(log_odds).cpu().item()
self.store_metrics(metrics, train_eval="train")
return (loss, outputs_pos) if return_outputs else loss
@wraps(Trainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
)
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs)
@wraps(Trainer.create_accelerator_and_postprocess)
def create_accelerator_and_postprocess(self):
res = super().create_accelerator_and_postprocess()
if self.is_fsdp_enabled:
if (
"limit_all_gathers" in self.args.fsdp_config
and self.args.fsdp_config["limit_all_gathers"]
):
self.accelerator.state.fsdp_plugin.limit_all_gathers = True
return res
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.
Args:
logs (`Dict[str, float]`):
The values to log.
start_time (`Optional[float]`):
The start of training.
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super().log(logs, start_time)
def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
) -> None:
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
def _save_checkpoint(self, model, trial, **kwargs):
# make sure the checkpoint dir exists, since trainer is flakey
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, **kwargs)
class AxolotlMambaTrainer(AxolotlTrainer):
"""
Mamba specific trainer to handle loss calculation
"""
tag_names = ["axolotl", "mamba"]
def compute_loss(
self,
model,
inputs,
return_outputs=False, # pylint: disable=unused-argument
num_items_in_batch=None, # pylint: disable=unused-argument
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits
labels = input_ids.to(lm_logits.device)
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = torch.nn.CrossEntropyLoss()
lm_loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
)
return lm_loss
class ReLoRATrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
tag_names = ["axolotl", "relora"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
if self.args.relora_steps:
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
anneal_steps = (
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
anneal_steps,
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler
return self.lr_scheduler
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "dpo"]
def __init__(self, *args, dataset_tags=None, **kwargs):
super().__init__(*args, **kwargs)
self.dataset_tags = dataset_tags
self.optimizer = None
self.model_accepts_loss_kwargs = False
def create_optimizer(self):
if self.args.loraplus_lr_ratio is None:
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
if loraplus_lr_ratio:
print("Using lora+")
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
@wraps(DPOTrainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
)
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs)
@staticmethod
def tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
) -> Dict:
res = DPOTrainer.tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
)
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
for key in res.keys():
res[key] = res[key][1:]
if processing_class.bos_token and processing_class.bos_token_id is not None:
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
res["chosen_labels"] = res["chosen_labels"][1:]
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
res["rejected_labels"] = res["rejected_labels"][1:]
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
return res
def training_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
num_items_in_batch=None,
) -> torch.Tensor:
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
gc.collect()
torch.cuda.empty_cache()
return loss
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "orpo"]
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "cpo"]
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
Extend the base RewardTrainer for axolotl helpers
"""
tag_names = ["axolotl", "reward"]
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""
tag_names = ["axolotl", "prm"]

View File

@@ -0,0 +1,264 @@
"""
extra axolotl specific training args
"""
from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments
from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
@dataclass
class AxolotlTrainingMixins:
"""
Mixin class for the Axolotl training args.
"""
# pylint: disable=duplicate-code
model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."}
)
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
pretraining: bool = field(
default=False,
metadata={
"help": "Indicates to trainer whether we are doing continued pretraining."
},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
multipack_real_batches: bool = field(
default=False,
metadata={"help": "Use real batches for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
sample_packing_bin_size: int = field(
default=200,
metadata={
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
},
)
sample_packing_group_size: int = field(
default=100000,
metadata={
"help": "The number of samples to group together for packing. Increase for better packing."
},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_anneal_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_prune_ratio: Optional[float] = field(
default=0.9,
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
do_causal_lm_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
dataloader_prefetch_factor: Optional[int] = field(
default=None,
metadata={"help": "prefetch_factor argument to the dataloader"},
)
cosine_min_lr_ratio: Optional[float] = field(
default=None,
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
)
cosine_constant_lr_ratio: Optional[float] = field(
default=None,
metadata={
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
},
)
loraplus_lr_ratio: Optional[float] = field(
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
)
loraplus_lr_embedding: Optional[float] = field(
default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
embedding_lr_scale: Optional[float] = field(
default=None,
metadata={"help": "Scale the learning rate for the embedding layers."},
)
lr_groups: Optional[list[dict]] = field(
default=None,
metadata={"help": "Specify learning rate groups for with different LRs."},
)
embedding_lr: Optional[float] = field(
default=None,
metadata={"help": "absolute learning rate for the embedding layers."},
)
qlora: bool = field(
default=False,
metadata={"help": "whether this is a qlora training"},
)
orpo_alpha: Optional[float] = field(
default=None,
)
lisa_n_layers: Optional[int] = field(
default=None,
metadata={"help": "the number of activate layers in LISA"},
)
lisa_step_interval: Optional[int] = field(
default=None,
metadata={"help": "how often to switch layers in LISA"},
)
lisa_layers_attribute: Optional[str] = field(
default=None,
metadata={"help": "path under the model to access the layers"},
)
curriculum_sampling: Optional[bool] = field(
default=None,
metadata={"help": "whether to use sequential sampling for curriculum learning"},
)
alternate_optimizer: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate optimizer to the HF trainer"
},
)
alternate_lr_scheduler_type: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
},
)
chat_template: Optional[str] = field(
default=None,
metadata={"help": "Chat template converting chat messages to text"},
)
kd_ce_alpha: Optional[float] = field(
default=None,
metadata={
"help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
},
)
kd_alpha: Optional[float] = field(
default=1.0,
metadata={"help": "The alpha scaling parameter for KD loss"},
)
kd_temperature: Optional[float] = field(
default=1.0,
metadata={
"help": "the temperature parameter for KL divergence loss when using KD"
},
)
kd_zscore_base_temp: Optional[float] = field(
default=None,
metadata={
"help": "the base temperature parameter for KL divergence with z-score when using KD"
},
)
kd_top_k_before_softmax: Optional[bool] = field(
default=None,
metadata={
"help": "Whether to apply top_k_before_softmax to the logits when using KD"
},
)
@dataclass
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
"""
Training arguments for Causal trainer
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
so it can't be used as a mixin.
"""
@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
DPO config for DPO training
"""
@dataclass
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
"""
ORPO config for ORPO training
"""
@dataclass
class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
"""
KTO config for KTO training
"""
@dataclass
class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
"""
CPO config for CPO training
"""
simpo_gamma: Optional[float] = field(
default=None,
metadata={"help": "simpo gamma parameter"},
)
@dataclass
class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
"""
Reward config for Reward training
"""
@dataclass
class AxolotlPRMConfig(AxolotlTrainingMixins, PRMConfig):
"""
PRM config for PRM training
"""

View File

@@ -2,7 +2,7 @@
import logging
import os
from typing import List, Optional
from typing import List, Optional, Union
import torch
from datasets import Dataset, IterableDataset
@@ -51,7 +51,17 @@ class TokenizedPromptDataset(Dataset):
map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
map_kwargs["batch_size"] = 100
map_kwargs["batch_size"] = 1_000
if (
hasattr(self.prompt_tokenizer, "filter_rows")
and self.prompt_tokenizer.filter_rows
):
dataset = dataset.filter(
self.prompt_tokenizer.filter_rows,
num_proc=num_proc,
desc="Strategy Filtering Rows",
)
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
@@ -63,6 +73,24 @@ class TokenizedPromptDataset(Dataset):
)
def wrap_dataset_for_tokenized_prompt(
prompt_tokenizer: PromptTokenizingStrategy,
dataset: Union[Dataset, IterableDataset],
**kwargs,
):
if isinstance(dataset, IterableDataset):
map_kwargs = {}
if prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
features = dataset.features.keys()
return dataset.map(
prompt_tokenizer.tokenize_prompt,
remove_columns=features,
**map_kwargs,
)
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
# TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset):
"""

View File

@@ -111,6 +111,17 @@ class BasePlugin:
None
"""
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument):
"""
Returns a custom class for the trainer.
Parameters:
cfg (dict): The global axolotl configuration.
Returns:
class: The class for the trainer.
"""
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
"""
Creates and returns an optimizer for training.
@@ -212,7 +223,17 @@ def load_plugin(plugin_name: str) -> BasePlugin:
module_name, class_name = plugin_name.rsplit(".", 1)
# import the module
module = importlib.import_module(module_name)
try:
module = importlib.import_module(module_name)
except ModuleNotFoundError as orig_exc:
try:
if not module_name.startswith("axolotl.integrations."):
module = importlib.import_module("axolotl.integrations." + module_name)
else:
raise orig_exc
except ModuleNotFoundError as exc:
raise orig_exc from exc
# instantiate the class
plugin_class = getattr(module, class_name)
# create an instance of the class
@@ -272,8 +293,10 @@ class PluginManager:
ImportError: If the plugin module cannot be imported.
"""
try:
logging.info(f"Attempting to load plugin: {plugin_name}")
plugin = load_plugin(plugin_name)
self.plugins[plugin_name] = plugin
logging.info(f"Plugin loaded successfully: {plugin_name}")
except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}")
@@ -346,6 +369,22 @@ class PluginManager:
for plugin in self.plugins.values():
plugin.post_lora_load(cfg, model)
def get_trainer_cls(self, cfg):
"""
Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class.
Parameters:
cfg (dict): The configuration for the plugins.
Returns:
object: The trainer class, or None if none was found.
"""
for plugin in self.plugins.values():
trainer_cls = plugin.get_trainer_cls(cfg)
if trainer_cls is not None:
return trainer_cls
return None
def create_optimizer(self, cfg, trainer):
"""
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.

View File

@@ -0,0 +1,36 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Plugin init to add KD support to Axolotl.
"""
from axolotl.integrations.base import BasePlugin
from .args import KDArgs # pylint: disable=unused-import. # noqa: F401
class KDPlugin(BasePlugin):
"""
Plugin for KD support in Axolotl.
"""
def get_input_args(self):
return "axolotl.integrations.kd.KDArgs"
def get_trainer_cls(self, cfg):
if cfg.kd_trainer:
from .trainer import AxolotlKDTrainer
return AxolotlKDTrainer
return None

View File

@@ -0,0 +1,37 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Plugin args for KD support.
"""
from typing import Optional
from pydantic import BaseModel
class KDArgs(BaseModel):
"""
Input args for knowledge distillation.
"""
kd_trainer: Optional[bool] = None # whether to use KD trainer
kd_ce_alpha: Optional[
float
] = None # loss coefficient for cross-entropy loss during KD
kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_temperature: Optional[float] = None # temperature for sampling during KD
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
kd_top_k_before_softmax: Optional[
bool
] = None # whether to sample top k before softmax during KD

View File

@@ -0,0 +1,201 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Chat template prompt strategy loader with KD support
"""
from typing import Any, Dict
import torch
from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader
class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
"""
Handle fields for logprob KD
"""
def __init__(
self,
prompter,
tokenizer,
train_on_inputs,
sequence_len,
roles_to_train=None,
train_on_eos=None,
logprobs_field="logprobs",
gen_temperature=1.0,
kd_temperature=1.0,
):
self.logprobs_field = logprobs_field
self.gen_temperature = gen_temperature
self.kd_temperature = kd_temperature
super().__init__(
prompter,
tokenizer,
train_on_inputs,
sequence_len,
roles_to_train=roles_to_train,
train_on_eos=train_on_eos,
)
@property
def supports_batched(self) -> bool:
# batching doesn't work well for logprob data
return False
def transform_logprobs(self, sample):
"""
Transform logprobs to target format for KD training
"""
logprobs = sample.pop(self.logprobs_field)
target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"])
input_padding_len = input_seq_len - target_seq_len
# get non-zero top-k (prune None logprobs from vllm data step)
top_k_vals = [
len(logprobs[i])
for i in range(len(logprobs))
if logprobs[i] is not None and len(logprobs[i])
]
max_top_k = max(set(top_k_vals), key=top_k_vals.count)
min_top_k = min(set(top_k_vals), key=top_k_vals.count)
top_k = min(max_top_k, min_top_k)
if top_k == 0:
raise ValueError("No non-zero top-k logprobs found.")
target_logprobs = []
target_token_ids = []
target_mask = []
if input_padding_len < 0:
# logprobs is longer than target_seq_len,
# so we need to slice from the left/beginning of logprobs
logprobs = logprobs[:-input_seq_len]
input_padding_len = 0
# target_seq_len = input_seq_len
# truncate the second dimension of the logprobs to top_k
logprobs = [row[:top_k] for row in logprobs]
# fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
# otherwise, we need to shift in the trainer
shift = 0
for _ in range(shift, input_padding_len):
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
for position in range(input_padding_len, input_seq_len):
if sample["labels"][position] == -100:
target_mask.append([0] * top_k)
else:
target_mask.append([1] * top_k)
for _, token_pos_logprobs in enumerate(logprobs):
# Initialize collections for logprobs and token_ids
position_logprobs = []
position_token_ids = []
# Process each token probability entry
for entry in token_pos_logprobs:
# Extract logprob value
logprob = entry["logprob"]
# Parse token_id from the "token_id:###" format
token_id = int(entry["token"].split(":")[1])
# Append to our collections
position_logprobs.append(logprob)
position_token_ids.append(token_id)
# Convert to a tensor for easier manipulation
position_logprobs_tensor = torch.tensor(
position_logprobs, dtype=torch.float
)
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
#
# Convert from log to probability
teacher_probs_t1 = position_logprobs_tensor.exp()
if self.kd_temperature != self.gen_temperature:
# Exponentiate by factor (T1 / T2)
exponent = self.gen_temperature / self.kd_temperature
teacher_probs_t2 = teacher_probs_t1**exponent
else:
teacher_probs_t2 = teacher_probs_t1
# Re-normalize
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
dim=0, keepdim=True
)
# Convert back to log
position_logprobs_tensor = torch.log(teacher_probs_t2)
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
position_logprobs_scaled = position_logprobs_tensor.tolist()
target_logprobs.append(position_logprobs_scaled)
target_token_ids.append(position_token_ids)
if shift == 1:
# since we started at index 1 for causal, we need one more padding token
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
# Update sample with transformed logprobs
sample["target_logprobs"] = target_logprobs
sample["target_token_ids"] = target_token_ids
sample["target_mask"] = target_mask
return sample
def _tokenize_single_prompt(self, prompt):
logprobs = prompt.pop(self.logprobs_field)
tokenized_prompt = super()._tokenize_single_prompt(prompt)
tokenized_prompt[self.logprobs_field] = logprobs
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
return tokenized_prompt
class KDStrategyLoader(StrategyLoader):
"""
Load ChatTemplateStrategy with KD support using StrategyLoader.
"""
def _get_strategy_cls(self):
return ChatTemplateStrategyWithKD
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
strategy_params = super()._get_strategy_params(cfg, ds_cfg)
if logprobs_field := ds_cfg.get("logprobs_field"):
strategy_params["logprobs_field"] = logprobs_field
if gen_temperature := ds_cfg.get("temperature"):
strategy_params["gen_temperature"] = gen_temperature
if kd_temperature := cfg.get("kd_temperature"):
strategy_params["kd_temperature"] = kd_temperature
return strategy_params
load = KDStrategyLoader()

View File

@@ -0,0 +1,255 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
DataCollator for axolotl to handle KD fields without using -inf for padding,
and with a teacher_mask to identify padded positions.
"""
from dataclasses import dataclass
from typing import Any, Optional, Union
import numpy as np
import torch
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from axolotl.utils.collators.batching import DataCollatorForSeq2Seq
@dataclass
class DataCollatorForKD(DataCollatorForSeq2Seq):
"""
Data collator for KD, including handling KD-specific fields.
This version avoids using -inf and instead uses a large negative value for padding
target_logprobs. It also creates a teacher_mask to indicate which entries are valid.
"""
# pylint: disable=duplicate-code
tokenizer: PreTrainedTokenizerBase
model: Optional[Any] = None
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
label_pad_token_id: int = -100
position_pad_token_id: int = 0
return_tensors: str = "pt"
def __call__(self, features, return_tensors=None):
if return_tensors is None:
return_tensors = self.return_tensors
padding_side = self.tokenizer.padding_side
# Pad labels and position_ids first
for feature_name, pad_token_id in [
("labels", self.label_pad_token_id),
("position_ids", self.position_pad_token_id),
]:
if feature_name in features[0]:
feat = [f[feature_name] for f in features]
max_len = max(len(x) for x in feat)
if self.pad_to_multiple_of is not None:
max_len = (
(max_len + self.pad_to_multiple_of - 1)
// self.pad_to_multiple_of
) * self.pad_to_multiple_of
for f in features: # pylint: disable=invalid-name
remainder = [pad_token_id] * (max_len - len(f[feature_name]))
if isinstance(f[feature_name], list):
f[feature_name] = (
f[feature_name] + remainder
if padding_side == "right"
else remainder + f[feature_name]
)
else:
# If they are numpy arrays
if padding_side == "right":
f[feature_name] = np.concatenate(
[f[feature_name], remainder]
).astype(np.int64)
else:
f[feature_name] = np.concatenate(
[remainder, f[feature_name]]
).astype(np.int64)
# Handle target_logprobs and target_token_ids manually
target_logprobs_list = []
target_token_ids_list = []
target_mask_list = []
has_teacher_data = ("target_logprobs" in features[0]) and (
"target_token_ids" in features[0]
)
if has_teacher_data:
# Extract and remove from features
for f in features: # pylint: disable=invalid-name
target_logprobs_list.append(f.pop("target_logprobs"))
target_token_ids_list.append(f.pop("target_token_ids"))
target_mask_list.append(f.pop("target_mask"))
# Determine max lengths
max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list)
max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq)
padded_target_logprobs = []
padded_target_token_ids = []
padded_teacher_mask_list = []
for t_logprobs, t_ids, t_mask in zip(
target_logprobs_list, target_token_ids_list, target_mask_list
):
t_logprobs_padded = []
t_ids_padded = []
t_mask_padded = []
for lp, ids, mask in zip( # pylint: disable=invalid-name
t_logprobs, t_ids, t_mask
):
lp_len = len(lp)
if lp_len < max_k:
# Use -1e9 for padding logprobs and 0 for token_ids
pad_len = max_k - lp_len
lp = lp + [-1e9] * pad_len # pylint: disable=invalid-name
ids = ids + [0] * pad_len
mask = mask + [0] * pad_len
else:
lp = lp[:max_k] # pylint: disable=invalid-name
ids = ids[:max_k]
mask = mask[:max_k]
t_logprobs_padded.append(lp)
t_ids_padded.append(ids)
t_mask_padded.append(mask)
seq_len_diff = max_teacher_seq_len - len(t_logprobs_padded)
if seq_len_diff > 0:
# Pad sequences fully if needed
t_logprobs_padded.extend(
[[-1e9] * max_k for _ in range(seq_len_diff)]
)
t_ids_padded.extend([[0] * max_k for _ in range(seq_len_diff)])
t_mask_padded.extend([[0] * max_k for _ in range(seq_len_diff)])
padded_target_logprobs.append(t_logprobs_padded)
padded_target_token_ids.append(t_ids_padded)
padded_teacher_mask_list.append(t_mask_padded)
# Convert to tensors
padded_target_logprobs = torch.tensor(
padded_target_logprobs, dtype=torch.float
)
padded_target_token_ids = torch.tensor(
padded_target_token_ids, dtype=torch.long
)
padded_teacher_mask_list = torch.tensor(
padded_teacher_mask_list, dtype=torch.int
)
# Pad using tokenizer for regular fields
features = self.tokenizer.pad(
features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=return_tensors,
)
# Add back teacher data if present
if has_teacher_data:
features["target_logprobs"] = padded_target_logprobs
features["target_token_ids"] = padded_target_token_ids
features["target_mask"] = padded_teacher_mask_list
# Prepare decoder_input_ids if the model supports it
if (
"labels" in features
and self.model is not None
and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
):
decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(
labels=features["labels"]
)
features["decoder_input_ids"] = decoder_input_ids
return features
class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
"""
Collator for multipack (batch of sub-batches) specifically for KD.
Adapts DataCollatorForKD so it can pack multiple sequences in a single batch item.
"""
def __call__(self, features, return_tensors=None):
"""
Expects that `features` could be either:
- a single list of dicts, OR
- a list of lists of dicts (the "sub-batches" to be packed).
"""
# 1) If we are *not* dealing with multiple sequences per batch element,
# just pass straight to parent.
if not isinstance(features[0], list):
return super().__call__(features, return_tensors=return_tensors)
# 2) Otherwise, we *are* dealing with multiple sequences in each batch item.
# We want to produce a single "merged" feature dict for each sub-batch.
out_features = [{} for _ in features]
for i, sub_features in enumerate(features):
# sub_features is a list of dicts, each dict = one sequences features
# We'll merge them into out_features[i].
#
# NOTE: You can customize how you combine fields as needed (e.g. summation
# or offset for attention_mask). Below is a straightforward concatenation/extension.
for field_name in sub_features[0].keys():
# Some fields you might want to skip or treat specially:
if field_name == "length":
continue
# If its a KD field thats a list-of-lists (e.g. target_logprobs),
# you typically just want to flatten them by extending.
if field_name in ["target_logprobs", "target_token_ids", "target_mask"]:
combined = []
for feat in sub_features:
combined.extend(feat[field_name])
out_features[i][field_name] = combined
elif field_name == "attention_mask":
# Here we apply the (j+1) factor to differentiate each sub-sample
# within this merged batch item.
arrays = []
for j, feat in enumerate(sub_features):
if field_name in feat:
arrays.append((j + 1) * np.array(feat[field_name]))
out_features[i][field_name] = np.concatenate(arrays)
else:
# By default, just concatenate them if they are arrays
# or extend them if they are lists.
# For example, input_ids or labels are often arrays.
arrays = []
for feat in sub_features:
if field_name in feat:
arr = np.array(feat[field_name])
arrays.append(arr)
out_features[i][field_name] = np.concatenate(arrays)
# 3) Now call the parent collator, which will do:
# - padding of labels/position_ids
# - KD-specific padding for target_logprobs, target_token_ids, etc.
# - final conversion to return_tensors
return super().__call__(out_features, return_tensors=return_tensors)

View File

@@ -0,0 +1,58 @@
### AXOLOTL COMMUNITY LICENSE AGREEMENT
This Axolotl Community License Agreement (“Agreement”) is entered into by and between Axolotl AI Corp. (“Axolotl”) and
any individual or entity (“Licensee”) who wishes to use the Software (as defined below) in accordance with the terms
and conditions set forth in this Agreement.
1. Definitions
1.1 “Licensee” refers to any individual or entity who has obtained a copy of the Software under this Agreement.
1.2 “Plugin Integration” means independent integration software modules which may or may not be offered by Axolotl,
which may be licensed separately by their respective authors and/or licensors.
1.3 “Software” refers to the specific sub-directory of the Axolotl, Inc. software located at
https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations and its subdirectories which
permits Plugin Integrations to integrate with the Axolotl service.
2. Grant of License
2.1 Axolotl hereby grants Licensee a worldwide, non-exclusive, royalty-free, license to use, copy, modify, merge,
publish, distribute, sublicense, and/or otherwise exploit the Software, subject to the following conditions:
- Licensee must comply with all the terms and conditions of this Agreement.
- Licensee must include the original copyright notice and disclaimer of warranty in all copies or substantial
portions of the Software.
2.2 Licensee may use the Software for any lawful purpose, except as restricted in Section 3.
3. Restrictions
3.1 Licensee shall not use the Software for any activity that constitutes a commercial activity of offering for
free or for sale any services, platform, or equivalent to third parties for the purposes of allowing such
third parties to fine-tune artificial intelligence models.
3.2 Licensee shall not:
- Use the Software for any illegal or unauthorized purpose.
- Reverse engineer, decompile, or disassemble the Software.
- Remove or modify any copyright, trademark, or other proprietary notices contained in the Software.
- Use the Software in a way that could damage, disable, overburden, or impair the functionality of the
Software or interfere with any third-party use of the Software.
3.3 Axolotl reserves the right to restrict certain Plugin Integrations for use with the Software. To the extent Licensee integrates a permitted, applicable Plugin Integration with the Software, Licensee shall comply with any additional terms and conditions imposed by the licensors of such Plugin Integration for use of such Plugin Integrations. Licensee shall contact Axolotl if it has questions about whether its use of the Software falls beyond the scope of this Agreement.
4. Intellectual Property Rights
4.1 Axolotl and its contributors retain all intellectual property rights in and to the Software. Licensee
acknowledges that this Agreement does not transfer any ownership rights or intellectual property rights to
Licensee.
5. Disclaimer of Warranty
5.1 THE SOFTWARE IS PROVIDED “AS IS,” WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. IN NO EVENT SHALL
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN ACTION OF
CONTRACT, TORT, OR OTHERWISE, ARISING FROM, OUT OF, OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
6. Termination
6.1 Axolotl may terminate this Agreement at any time if Licensee fails to comply with any of the terms and
conditions set forth herein. Upon termination, Licensee shall cease all use of the Software and destroy any
copies in its possession.
7. Governing Law
7.1 This Agreement shall be governed by and construed in accordance with the laws of the State of California,
without regards to conflicts of laws provisions thereof.
8. Entire Agreement
8.1 This Agreement constitutes the entire agreement between Axolotl and Licensee with respect to the subject matter
hereof and supersedes all prior or contemporaneous understandings or agreements between the parties concerning
the Software, whether written or oral. Axolotl may update the terms of this Agreement from time to time, and
Licensees continued use of the Software after any such updates shall constitute acceptance of updated terms
on a go-forward basis. Axolotl will use commercially reasonable efforts to provide Licensee notice of any
material updates. By using the Software, Licensee acknowledges that it has read, understood, and agrees to be
bound by the terms and conditions of this Agreement.
This Agreement was last updated on August 23, 2024.

View File

@@ -0,0 +1,235 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""
loss for top_k KL divergence
"""
import torch
def zscore_standardize(
logits: torch.Tensor,
mask: torch.Tensor = None,
base_temperature: float = 1.0,
eps: float = 1e-9,
):
"""
Z-score standardize along the last dimension of `logits`.
i.e., for each [B, seq_len] row, across K entries:
z = (logits - mean) / std,
then scale by 1 / base_temperature if desired.
mask can be broadcastable or None. If None, we standardize all elements.
"""
if mask is None:
# shape: [B, seq_len, K]
# Mean and std over dim=-1
mean = logits.mean(dim=-1, keepdim=True)
var = logits.var(dim=-1, unbiased=False, keepdim=True)
else:
# If you have to exclude some tokens, multiply by mask, etc.
float_mask = mask.to(logits.dtype)
count = float_mask.sum(dim=-1, keepdim=True).clamp_min(1.0)
mean = (logits * float_mask).sum(dim=-1, keepdim=True) / count
var = (float_mask * (logits - mean) ** 2).sum(dim=-1, keepdim=True) / count
std = torch.sqrt(var.clamp_min(eps))
z = (logits - mean) / std
# Scale by 1 / base_temperature
z = z / base_temperature
return z
@torch.jit.script
def loss(
student_logits: torch.Tensor,
target_token_ids: torch.Tensor,
target_logprobs: torch.Tensor,
target_mask: torch.Tensor,
num_items_in_batch: int = -1, # Use -1 to indicate "None"
kd_temperature: float = 1.0,
top_k_before_softmax: int = 0,
) -> torch.Tensor:
"""
A KD loss function that is TorchScript-friendly.
Arguments:
student_logits (torch.Tensor): The logits of the student model.
Shape: [B, student_seq_len, vocab_size]
target_token_ids (torch.Tensor): The top-k teacher/target token IDs
Shape: [B, teacher_seq_len, top_k]
target_logprobs (torch.Tensor): The top-k teacher/target logprobs, these should already be re-normalized.
Shape: [B, teacher_seq_len, top_k]
target_mask (torch.Tensor): The mask for valid tokens.
Shape: [B, teacher_seq_len, top_k]
num_items_in_batch (int, optional): The number of items in the batch.
kd_temperature (float, optional): The temperature for KD.
Default: 1.0
top_k_before_softmax (int, optional): Flag of whether to apply softmax before gathering student top-k logits
Default: 0
"""
target_logprobs = target_logprobs.float()
# Determine the teacher sequence length
# target_token_ids shape: [B, teacher_seq_len, K]
# student_logits shape: [B, student_seq_len, vocab_size]
teacher_seq_len = target_token_ids.shape[1]
if top_k_before_softmax:
# Slice student logits to match teacher-provided sequence length
student_logits_for_kd = student_logits[
:, :teacher_seq_len, :
] # [B, teacher_seq_len, vocab_size]
# Gather student logits for teacher's top-K tokens
student_logits_topk = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, teacher_seq_len, K]
student_logits_topk = student_logits_topk.float()
# Apply KD temperature to students logits
if kd_temperature != 1.0:
student_logits_topk = student_logits_topk / kd_temperature
# Convert student top-k logits to logprobs
student_logprobs_topk = student_logits_topk - torch.logsumexp(
student_logits_topk, dim=-1, keepdim=True
) # [B, teacher_seq_len, K]
else:
# Slice student logits to match teacher-provided sequence length
student_logits_for_kd = (
student_logits[:, :teacher_seq_len, :] / kd_temperature
) # [B, teacher_seq_len, vocab_size]
# keep in full precision for numerical stability of loss
student_logits_for_kd = student_logits_for_kd.float()
# Gather student logits for teacher's top-K tokens
student_logits_topk = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, teacher_seq_len, K]
# Compute logsumexp across full vocabulary
student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
# Convert just the top-k logits to logprobs
student_logprobs_topk = student_logits_topk - student_lse
# Convert teacher_mask to boolean for indexing
# In TorchScript, .bool() is sometimes unsupported, so we do:
valid_mask = target_mask.to(torch.bool)
# Prune tensors to only keep valid tokens
student_logprobs_topk = student_logprobs_topk[valid_mask]
target_logprobs = target_logprobs[valid_mask]
# Convert teacher logprobs to probabilities
teacher_probs = target_logprobs.exp()
# Compute forward KL
kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
kd_loss = kd_loss_per_token.sum()
# Multiply by T^2 (classical KD scaling)
if kd_temperature != 1.0:
kd_loss = kd_loss * (kd_temperature**2)
# Normalize by number of items (if provided) or by valid tokens
if num_items_in_batch > 0:
kd_loss = kd_loss / float(num_items_in_batch)
else:
# Fall back to average over valid tokens
kd_loss = kd_loss / float(kd_loss_per_token.size(0))
return kd_loss
def topk_kd_loss_with_zscore(
student_logits: torch.Tensor, # [B, seq_len, vocab_size]
target_token_ids: torch.Tensor, # [B, seq_len, K]
target_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space
target_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len]
kd_temperature: float = 1.0, # classic KD temperature
zscore_base_temp: float = 1.0, # from the paper
num_items_in_batch: int = -1,
):
"""
A variant of top_k KL divergence with Z-score scaling
from "Logit Standardization in Knowledge Distillation".
"""
target_logprobs = target_logprobs.float()
B, teacher_seq_len, K = target_logprobs.shape # pylint: disable=invalid-name
# 1) Gather the student's top-k logits to match teacher
student_logits_for_kd = student_logits[
:, :teacher_seq_len, :
] # [B, seq_len, vocab]
student_topk_logits = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, seq_len, K]
student_topk_logits = student_topk_logits.float()
# 2) If you want to keep the "classical" T scaling, apply it first
if kd_temperature != 1.0:
student_topk_logits = student_topk_logits / kd_temperature
# 3) Convert teacher logprobs -> treat them as “logits” for z-score
# (They differ by +some_constant from real logits, but in z-score
# that constant is subtracted out anyway.)
teacher_logits_for_zscore = target_logprobs # rename variable for clarity
# 4) Z-score teacher and student
# If target_mask is 2D, expand to 3D for the K dimension
if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len):
target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K)
teacher_z = zscore_standardize(
teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp
)
student_z = zscore_standardize(
student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp
)
# 5) Convert to log-probs for KL
teacher_logprobs_z = teacher_z - torch.logsumexp(teacher_z, dim=-1, keepdim=True)
student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)
# 6) Restrict to valid tokens if needed
valid_mask = target_mask.bool() # shape [B, seq_len, K]
teacher_probs_z = teacher_logprobs_z.exp()
teacher_probs_z = teacher_probs_z[valid_mask]
teacher_logprobs_z = teacher_logprobs_z[valid_mask]
student_logprobs_z = student_logprobs_z[valid_mask]
# 7) forward KL: sum( p_teacher * [log(p_teacher) - log(p_student)] )
kd_loss_per_token = teacher_probs_z * (teacher_logprobs_z - student_logprobs_z)
kd_loss = kd_loss_per_token.sum()
# 8) If using classical KD scaling by T^2
if kd_temperature != 1.0:
kd_loss = kd_loss * (kd_temperature**2)
# Optionally scale by zscore_base_temp**2 if you want (paper might differ).
# kd_loss = kd_loss * (zscore_base_temp**2)
# 9) Normalize
if num_items_in_batch is not None and num_items_in_batch > 0:
kd_loss = kd_loss / float(num_items_in_batch)
else:
kd_loss = kd_loss / float(kd_loss_per_token.size(0))
return kd_loss

View File

@@ -0,0 +1,113 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
KD trainer
"""
from axolotl.core.trainers.base import AxolotlTrainer
from .topk_logprob.forward_kl import loss as topk_kd_loss
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
class AxolotlKDTrainer(AxolotlTrainer):
"""
Custom trainer subclass for Knowledge Distillation (KD)
"""
def _set_signature_columns_if_needed(self):
super()._set_signature_columns_if_needed()
columns_to_add = []
if self._signature_columns:
if "target_logprobs" not in self._signature_columns:
columns_to_add.append("target_logprobs")
if "target_token_ids" not in self._signature_columns:
columns_to_add.append("target_token_ids")
if "target_mask" not in self._signature_columns:
columns_to_add.append("target_mask")
if columns_to_add:
self._signature_columns += columns_to_add
def compute_loss(
self,
model,
inputs,
return_outputs=False,
num_items_in_batch=None,
):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.
Subclass and override for custom behavior.
"""
target_logprobs = inputs.pop("target_logprobs")
target_token_ids = inputs.pop("target_token_ids")
target_mask = inputs.pop("target_mask")
seq_len = target_token_ids.shape[1]
if self.model_accepts_loss_kwargs:
loss_kwargs = {}
if num_items_in_batch is not None:
loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs}
outputs = model(**inputs)
# FIXME: account for tokenizer.padding_side
student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous()
shift_logits = student_logits.contiguous()
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
if self.args.kd_zscore_base_temp:
loss_kd = topk_kd_loss_with_zscore(
shift_logits,
target_token_ids_for_loss,
target_logprobs_for_loss,
target_mask_for_loss,
kd_temperature=self.args.kd_temperature,
zscore_base_temp=self.args.kd_zscore_base_temp,
num_items_in_batch=num_items_in_batch,
)
else:
loss_kd = topk_kd_loss(
shift_logits,
target_token_ids_for_loss,
target_logprobs_for_loss,
target_mask_for_loss,
num_items_in_batch=num_items_in_batch,
kd_temperature=self.args.kd_temperature,
top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
)
if self.args.kd_ce_alpha > 0:
kd_alpha = self.args.kd_alpha
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
else:
loss = loss_kd
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[ # pylint: disable=attribute-defined-outside-init
self.args.past_index
]
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
loss *= self.accelerator.num_processes
return (loss, outputs) if return_outputs else loss

View File

@@ -0,0 +1,590 @@
{
"model.layers.0.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.1.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.2.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.3.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.4.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.5.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.6.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.7.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.8.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.9.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.10.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.11.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.12.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.13.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.14.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.15.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"lm_head": {
"snr": Infinity,
"type": "lm_head"
},
"model.layers.0.mlp.down_proj": {
"snr": 70.0594253540039,
"type": "mlp.down_proj"
},
"model.layers.1.mlp.down_proj": {
"snr": 11.135851860046387,
"type": "mlp.down_proj"
},
"model.layers.2.mlp.down_proj": {
"snr": 7.035482883453369,
"type": "mlp.down_proj"
},
"model.layers.3.mlp.down_proj": {
"snr": 6.422532081604004,
"type": "mlp.down_proj"
},
"model.layers.4.mlp.down_proj": {
"snr": 5.748020172119141,
"type": "mlp.down_proj"
},
"model.layers.5.mlp.down_proj": {
"snr": 3.885556697845459,
"type": "mlp.down_proj"
},
"model.layers.6.mlp.down_proj": {
"snr": 3.4336745738983154,
"type": "mlp.down_proj"
},
"model.layers.7.mlp.down_proj": {
"snr": 2.791595935821533,
"type": "mlp.down_proj"
},
"model.layers.8.mlp.down_proj": {
"snr": 5.36277961730957,
"type": "mlp.down_proj"
},
"model.layers.9.mlp.down_proj": {
"snr": 4.459208011627197,
"type": "mlp.down_proj"
},
"model.layers.10.mlp.down_proj": {
"snr": 6.272170066833496,
"type": "mlp.down_proj"
},
"model.layers.11.mlp.down_proj": {
"snr": 5.264761447906494,
"type": "mlp.down_proj"
},
"model.layers.12.mlp.down_proj": {
"snr": 4.324735641479492,
"type": "mlp.down_proj"
},
"model.layers.13.mlp.down_proj": {
"snr": 3.878648042678833,
"type": "mlp.down_proj"
},
"model.layers.14.mlp.down_proj": {
"snr": 2.9773054122924805,
"type": "mlp.down_proj"
},
"model.layers.15.mlp.down_proj": {
"snr": 4.471445560455322,
"type": "mlp.down_proj"
},
"model.layers.0.mlp.gate_proj": {
"snr": 25.227100372314453,
"type": "mlp.gate_proj"
},
"model.layers.1.mlp.gate_proj": {
"snr": 6.58299446105957,
"type": "mlp.gate_proj"
},
"model.layers.2.mlp.gate_proj": {
"snr": 3.4688243865966797,
"type": "mlp.gate_proj"
},
"model.layers.3.mlp.gate_proj": {
"snr": 1.555246114730835,
"type": "mlp.gate_proj"
},
"model.layers.4.mlp.gate_proj": {
"snr": 0.7770601511001587,
"type": "mlp.gate_proj"
},
"model.layers.5.mlp.gate_proj": {
"snr": 0.6239906549453735,
"type": "mlp.gate_proj"
},
"model.layers.6.mlp.gate_proj": {
"snr": 0.6440379023551941,
"type": "mlp.gate_proj"
},
"model.layers.7.mlp.gate_proj": {
"snr": 0.5120116472244263,
"type": "mlp.gate_proj"
},
"model.layers.8.mlp.gate_proj": {
"snr": 0.6544050574302673,
"type": "mlp.gate_proj"
},
"model.layers.9.mlp.gate_proj": {
"snr": 0.5381016731262207,
"type": "mlp.gate_proj"
},
"model.layers.10.mlp.gate_proj": {
"snr": 0.622873842716217,
"type": "mlp.gate_proj"
},
"model.layers.11.mlp.gate_proj": {
"snr": 0.9361700415611267,
"type": "mlp.gate_proj"
},
"model.layers.12.mlp.gate_proj": {
"snr": 1.475605845451355,
"type": "mlp.gate_proj"
},
"model.layers.13.mlp.gate_proj": {
"snr": 1.608325719833374,
"type": "mlp.gate_proj"
},
"model.layers.14.mlp.gate_proj": {
"snr": 1.0720024108886719,
"type": "mlp.gate_proj"
},
"model.layers.15.mlp.gate_proj": {
"snr": 0.7111338973045349,
"type": "mlp.gate_proj"
},
"model.layers.0.mlp.up_proj": {
"snr": 28.431896209716797,
"type": "mlp.up_proj"
},
"model.layers.1.mlp.up_proj": {
"snr": 15.546019554138184,
"type": "mlp.up_proj"
},
"model.layers.2.mlp.up_proj": {
"snr": 23.048023223876953,
"type": "mlp.up_proj"
},
"model.layers.3.mlp.up_proj": {
"snr": 25.790977478027344,
"type": "mlp.up_proj"
},
"model.layers.4.mlp.up_proj": {
"snr": 18.552549362182617,
"type": "mlp.up_proj"
},
"model.layers.5.mlp.up_proj": {
"snr": 8.85106372833252,
"type": "mlp.up_proj"
},
"model.layers.6.mlp.up_proj": {
"snr": 10.653799057006836,
"type": "mlp.up_proj"
},
"model.layers.7.mlp.up_proj": {
"snr": 7.365357875823975,
"type": "mlp.up_proj"
},
"model.layers.8.mlp.up_proj": {
"snr": 11.98373794555664,
"type": "mlp.up_proj"
},
"model.layers.9.mlp.up_proj": {
"snr": 8.04493236541748,
"type": "mlp.up_proj"
},
"model.layers.10.mlp.up_proj": {
"snr": 8.523039817810059,
"type": "mlp.up_proj"
},
"model.layers.11.mlp.up_proj": {
"snr": 5.381742477416992,
"type": "mlp.up_proj"
},
"model.layers.12.mlp.up_proj": {
"snr": 3.9845118522644043,
"type": "mlp.up_proj"
},
"model.layers.13.mlp.up_proj": {
"snr": 3.4893221855163574,
"type": "mlp.up_proj"
},
"model.layers.14.mlp.up_proj": {
"snr": 1.764201045036316,
"type": "mlp.up_proj"
},
"model.layers.15.mlp.up_proj": {
"snr": 0.9730708599090576,
"type": "mlp.up_proj"
},
"model.embed_tokens": {
"snr": Infinity,
"type": "model.embed_tokens"
},
"model.norm": {
"snr": Infinity,
"type": "model.norm"
},
"model.layers.0.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.1.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.2.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.3.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.4.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.5.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.6.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.7.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.8.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.9.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.10.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.11.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.12.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.13.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.14.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.15.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.0.self_attn.k_proj": {
"snr": 0.11727584153413773,
"type": "self_attn.k_proj"
},
"model.layers.1.self_attn.k_proj": {
"snr": 0.24786807596683502,
"type": "self_attn.k_proj"
},
"model.layers.2.self_attn.k_proj": {
"snr": 0.36378130316734314,
"type": "self_attn.k_proj"
},
"model.layers.3.self_attn.k_proj": {
"snr": 0.2983120381832123,
"type": "self_attn.k_proj"
},
"model.layers.4.self_attn.k_proj": {
"snr": 0.33789733052253723,
"type": "self_attn.k_proj"
},
"model.layers.5.self_attn.k_proj": {
"snr": 0.29155924916267395,
"type": "self_attn.k_proj"
},
"model.layers.6.self_attn.k_proj": {
"snr": 0.2537297010421753,
"type": "self_attn.k_proj"
},
"model.layers.7.self_attn.k_proj": {
"snr": 0.28204113245010376,
"type": "self_attn.k_proj"
},
"model.layers.8.self_attn.k_proj": {
"snr": 0.2776711583137512,
"type": "self_attn.k_proj"
},
"model.layers.9.self_attn.k_proj": {
"snr": 0.2927376627922058,
"type": "self_attn.k_proj"
},
"model.layers.10.self_attn.k_proj": {
"snr": 0.31486213207244873,
"type": "self_attn.k_proj"
},
"model.layers.11.self_attn.k_proj": {
"snr": 0.32363659143447876,
"type": "self_attn.k_proj"
},
"model.layers.12.self_attn.k_proj": {
"snr": 0.31382912397384644,
"type": "self_attn.k_proj"
},
"model.layers.13.self_attn.k_proj": {
"snr": 0.4635234773159027,
"type": "self_attn.k_proj"
},
"model.layers.14.self_attn.k_proj": {
"snr": 0.25379249453544617,
"type": "self_attn.k_proj"
},
"model.layers.15.self_attn.k_proj": {
"snr": 0.2628238797187805,
"type": "self_attn.k_proj"
},
"model.layers.0.self_attn.o_proj": {
"snr": 0.27602291107177734,
"type": "self_attn.o_proj"
},
"model.layers.1.self_attn.o_proj": {
"snr": 0.2149604707956314,
"type": "self_attn.o_proj"
},
"model.layers.2.self_attn.o_proj": {
"snr": 0.2540294826030731,
"type": "self_attn.o_proj"
},
"model.layers.3.self_attn.o_proj": {
"snr": 0.27978822588920593,
"type": "self_attn.o_proj"
},
"model.layers.4.self_attn.o_proj": {
"snr": 0.3121289908885956,
"type": "self_attn.o_proj"
},
"model.layers.5.self_attn.o_proj": {
"snr": 0.35037684440612793,
"type": "self_attn.o_proj"
},
"model.layers.6.self_attn.o_proj": {
"snr": 0.366205096244812,
"type": "self_attn.o_proj"
},
"model.layers.7.self_attn.o_proj": {
"snr": 0.3692712187767029,
"type": "self_attn.o_proj"
},
"model.layers.8.self_attn.o_proj": {
"snr": 0.3301038146018982,
"type": "self_attn.o_proj"
},
"model.layers.9.self_attn.o_proj": {
"snr": 0.3003396987915039,
"type": "self_attn.o_proj"
},
"model.layers.10.self_attn.o_proj": {
"snr": 0.30804169178009033,
"type": "self_attn.o_proj"
},
"model.layers.11.self_attn.o_proj": {
"snr": 0.28501132130622864,
"type": "self_attn.o_proj"
},
"model.layers.12.self_attn.o_proj": {
"snr": 0.2171541005373001,
"type": "self_attn.o_proj"
},
"model.layers.13.self_attn.o_proj": {
"snr": 0.19183959066867828,
"type": "self_attn.o_proj"
},
"model.layers.14.self_attn.o_proj": {
"snr": 0.19215913116931915,
"type": "self_attn.o_proj"
},
"model.layers.15.self_attn.o_proj": {
"snr": 0.25486502051353455,
"type": "self_attn.o_proj"
},
"model.layers.0.self_attn.q_proj": {
"snr": 0.03850084915757179,
"type": "self_attn.q_proj"
},
"model.layers.1.self_attn.q_proj": {
"snr": 0.0713055431842804,
"type": "self_attn.q_proj"
},
"model.layers.2.self_attn.q_proj": {
"snr": 0.07948919385671616,
"type": "self_attn.q_proj"
},
"model.layers.3.self_attn.q_proj": {
"snr": 0.08047746121883392,
"type": "self_attn.q_proj"
},
"model.layers.4.self_attn.q_proj": {
"snr": 0.0852593332529068,
"type": "self_attn.q_proj"
},
"model.layers.5.self_attn.q_proj": {
"snr": 0.09794823825359344,
"type": "self_attn.q_proj"
},
"model.layers.6.self_attn.q_proj": {
"snr": 0.09627152234315872,
"type": "self_attn.q_proj"
},
"model.layers.7.self_attn.q_proj": {
"snr": 0.11065381020307541,
"type": "self_attn.q_proj"
},
"model.layers.8.self_attn.q_proj": {
"snr": 0.12031875550746918,
"type": "self_attn.q_proj"
},
"model.layers.9.self_attn.q_proj": {
"snr": 0.09804573655128479,
"type": "self_attn.q_proj"
},
"model.layers.10.self_attn.q_proj": {
"snr": 0.10897502303123474,
"type": "self_attn.q_proj"
},
"model.layers.11.self_attn.q_proj": {
"snr": 0.09267337620258331,
"type": "self_attn.q_proj"
},
"model.layers.12.self_attn.q_proj": {
"snr": 0.08803492039442062,
"type": "self_attn.q_proj"
},
"model.layers.13.self_attn.q_proj": {
"snr": 0.0902542844414711,
"type": "self_attn.q_proj"
},
"model.layers.14.self_attn.q_proj": {
"snr": 0.10154066979885101,
"type": "self_attn.q_proj"
},
"model.layers.15.self_attn.q_proj": {
"snr": 0.09083802253007889,
"type": "self_attn.q_proj"
},
"model.layers.0.self_attn.v_proj": {
"snr": 2.842210054397583,
"type": "self_attn.v_proj"
},
"model.layers.1.self_attn.v_proj": {
"snr": 10.59461498260498,
"type": "self_attn.v_proj"
},
"model.layers.2.self_attn.v_proj": {
"snr": 8.993025779724121,
"type": "self_attn.v_proj"
},
"model.layers.3.self_attn.v_proj": {
"snr": 62.567787170410156,
"type": "self_attn.v_proj"
},
"model.layers.4.self_attn.v_proj": {
"snr": 23.80082893371582,
"type": "self_attn.v_proj"
},
"model.layers.5.self_attn.v_proj": {
"snr": 7.957369804382324,
"type": "self_attn.v_proj"
},
"model.layers.6.self_attn.v_proj": {
"snr": 12.01815414428711,
"type": "self_attn.v_proj"
},
"model.layers.7.self_attn.v_proj": {
"snr": 5.095500469207764,
"type": "self_attn.v_proj"
},
"model.layers.8.self_attn.v_proj": {
"snr": 11.719332695007324,
"type": "self_attn.v_proj"
},
"model.layers.9.self_attn.v_proj": {
"snr": 555.0869750976562,
"type": "self_attn.v_proj"
},
"model.layers.10.self_attn.v_proj": {
"snr": 22.95538330078125,
"type": "self_attn.v_proj"
},
"model.layers.11.self_attn.v_proj": {
"snr": 30.042158126831055,
"type": "self_attn.v_proj"
},
"model.layers.12.self_attn.v_proj": {
"snr": 9.577271461486816,
"type": "self_attn.v_proj"
},
"model.layers.13.self_attn.v_proj": {
"snr": 18.176361083984375,
"type": "self_attn.v_proj"
},
"model.layers.14.self_attn.v_proj": {
"snr": 1.5695856809616089,
"type": "self_attn.v_proj"
},
"model.layers.15.self_attn.v_proj": {
"snr": 2.7235565185546875,
"type": "self_attn.v_proj"
}
}

View File

@@ -0,0 +1,590 @@
{
"model.layers.0.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.1.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.2.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.3.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.4.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.5.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.6.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.7.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.8.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.9.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.10.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.11.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.12.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.13.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.14.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"model.layers.15.input_layernorm": {
"snr": Infinity,
"type": "input_layernorm"
},
"lm_head": {
"snr": Infinity,
"type": "lm_head"
},
"model.layers.0.mlp.down_proj": {
"snr": 57.09797286987305,
"type": "mlp.down_proj"
},
"model.layers.1.mlp.down_proj": {
"snr": 9.538983345031738,
"type": "mlp.down_proj"
},
"model.layers.2.mlp.down_proj": {
"snr": 6.227016925811768,
"type": "mlp.down_proj"
},
"model.layers.3.mlp.down_proj": {
"snr": 5.660686492919922,
"type": "mlp.down_proj"
},
"model.layers.4.mlp.down_proj": {
"snr": 5.178432464599609,
"type": "mlp.down_proj"
},
"model.layers.5.mlp.down_proj": {
"snr": 3.5638349056243896,
"type": "mlp.down_proj"
},
"model.layers.6.mlp.down_proj": {
"snr": 3.0918056964874268,
"type": "mlp.down_proj"
},
"model.layers.7.mlp.down_proj": {
"snr": 2.456392288208008,
"type": "mlp.down_proj"
},
"model.layers.8.mlp.down_proj": {
"snr": 4.525328636169434,
"type": "mlp.down_proj"
},
"model.layers.9.mlp.down_proj": {
"snr": 3.9409055709838867,
"type": "mlp.down_proj"
},
"model.layers.10.mlp.down_proj": {
"snr": 5.447249412536621,
"type": "mlp.down_proj"
},
"model.layers.11.mlp.down_proj": {
"snr": 4.807600975036621,
"type": "mlp.down_proj"
},
"model.layers.12.mlp.down_proj": {
"snr": 3.915374517440796,
"type": "mlp.down_proj"
},
"model.layers.13.mlp.down_proj": {
"snr": 3.4820363521575928,
"type": "mlp.down_proj"
},
"model.layers.14.mlp.down_proj": {
"snr": 2.6045074462890625,
"type": "mlp.down_proj"
},
"model.layers.15.mlp.down_proj": {
"snr": 3.7237701416015625,
"type": "mlp.down_proj"
},
"model.layers.0.mlp.gate_proj": {
"snr": 22.160131454467773,
"type": "mlp.gate_proj"
},
"model.layers.1.mlp.gate_proj": {
"snr": 6.072206020355225,
"type": "mlp.gate_proj"
},
"model.layers.2.mlp.gate_proj": {
"snr": 3.2467362880706787,
"type": "mlp.gate_proj"
},
"model.layers.3.mlp.gate_proj": {
"snr": 1.4111896753311157,
"type": "mlp.gate_proj"
},
"model.layers.4.mlp.gate_proj": {
"snr": 0.7405938506126404,
"type": "mlp.gate_proj"
},
"model.layers.5.mlp.gate_proj": {
"snr": 0.5916463136672974,
"type": "mlp.gate_proj"
},
"model.layers.6.mlp.gate_proj": {
"snr": 0.6149423718452454,
"type": "mlp.gate_proj"
},
"model.layers.7.mlp.gate_proj": {
"snr": 0.48369669914245605,
"type": "mlp.gate_proj"
},
"model.layers.8.mlp.gate_proj": {
"snr": 0.6047574877738953,
"type": "mlp.gate_proj"
},
"model.layers.9.mlp.gate_proj": {
"snr": 0.5092479586601257,
"type": "mlp.gate_proj"
},
"model.layers.10.mlp.gate_proj": {
"snr": 0.5999670624732971,
"type": "mlp.gate_proj"
},
"model.layers.11.mlp.gate_proj": {
"snr": 0.8980127573013306,
"type": "mlp.gate_proj"
},
"model.layers.12.mlp.gate_proj": {
"snr": 1.4252448081970215,
"type": "mlp.gate_proj"
},
"model.layers.13.mlp.gate_proj": {
"snr": 1.509937047958374,
"type": "mlp.gate_proj"
},
"model.layers.14.mlp.gate_proj": {
"snr": 1.0066585540771484,
"type": "mlp.gate_proj"
},
"model.layers.15.mlp.gate_proj": {
"snr": 0.6413647532463074,
"type": "mlp.gate_proj"
},
"model.layers.0.mlp.up_proj": {
"snr": 26.08852195739746,
"type": "mlp.up_proj"
},
"model.layers.1.mlp.up_proj": {
"snr": 13.382951736450195,
"type": "mlp.up_proj"
},
"model.layers.2.mlp.up_proj": {
"snr": 20.088768005371094,
"type": "mlp.up_proj"
},
"model.layers.3.mlp.up_proj": {
"snr": 23.0632381439209,
"type": "mlp.up_proj"
},
"model.layers.4.mlp.up_proj": {
"snr": 16.07433319091797,
"type": "mlp.up_proj"
},
"model.layers.5.mlp.up_proj": {
"snr": 8.00507640838623,
"type": "mlp.up_proj"
},
"model.layers.6.mlp.up_proj": {
"snr": 9.538354873657227,
"type": "mlp.up_proj"
},
"model.layers.7.mlp.up_proj": {
"snr": 6.286602973937988,
"type": "mlp.up_proj"
},
"model.layers.8.mlp.up_proj": {
"snr": 10.092820167541504,
"type": "mlp.up_proj"
},
"model.layers.9.mlp.up_proj": {
"snr": 7.193963527679443,
"type": "mlp.up_proj"
},
"model.layers.10.mlp.up_proj": {
"snr": 7.320116996765137,
"type": "mlp.up_proj"
},
"model.layers.11.mlp.up_proj": {
"snr": 4.8728532791137695,
"type": "mlp.up_proj"
},
"model.layers.12.mlp.up_proj": {
"snr": 3.596583366394043,
"type": "mlp.up_proj"
},
"model.layers.13.mlp.up_proj": {
"snr": 3.166161298751831,
"type": "mlp.up_proj"
},
"model.layers.14.mlp.up_proj": {
"snr": 1.5600818395614624,
"type": "mlp.up_proj"
},
"model.layers.15.mlp.up_proj": {
"snr": 0.8726214170455933,
"type": "mlp.up_proj"
},
"model.embed_tokens": {
"snr": Infinity,
"type": "model.embed_tokens"
},
"model.norm": {
"snr": Infinity,
"type": "model.norm"
},
"model.layers.0.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.1.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.2.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.3.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.4.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.5.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.6.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.7.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.8.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.9.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.10.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.11.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.12.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.13.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.14.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.15.post_attention_layernorm": {
"snr": Infinity,
"type": "post_attention_layernorm"
},
"model.layers.0.self_attn.k_proj": {
"snr": 0.1154392883181572,
"type": "self_attn.k_proj"
},
"model.layers.1.self_attn.k_proj": {
"snr": 0.24299409985542297,
"type": "self_attn.k_proj"
},
"model.layers.2.self_attn.k_proj": {
"snr": 0.3624322712421417,
"type": "self_attn.k_proj"
},
"model.layers.3.self_attn.k_proj": {
"snr": 0.29509487748146057,
"type": "self_attn.k_proj"
},
"model.layers.4.self_attn.k_proj": {
"snr": 0.32953736186027527,
"type": "self_attn.k_proj"
},
"model.layers.5.self_attn.k_proj": {
"snr": 0.2908833622932434,
"type": "self_attn.k_proj"
},
"model.layers.6.self_attn.k_proj": {
"snr": 0.2488437294960022,
"type": "self_attn.k_proj"
},
"model.layers.7.self_attn.k_proj": {
"snr": 0.27847856283187866,
"type": "self_attn.k_proj"
},
"model.layers.8.self_attn.k_proj": {
"snr": 0.27143892645835876,
"type": "self_attn.k_proj"
},
"model.layers.9.self_attn.k_proj": {
"snr": 0.28804272413253784,
"type": "self_attn.k_proj"
},
"model.layers.10.self_attn.k_proj": {
"snr": 0.31197959184646606,
"type": "self_attn.k_proj"
},
"model.layers.11.self_attn.k_proj": {
"snr": 0.3203586935997009,
"type": "self_attn.k_proj"
},
"model.layers.12.self_attn.k_proj": {
"snr": 0.30905747413635254,
"type": "self_attn.k_proj"
},
"model.layers.13.self_attn.k_proj": {
"snr": 0.46828722953796387,
"type": "self_attn.k_proj"
},
"model.layers.14.self_attn.k_proj": {
"snr": 0.24205778539180756,
"type": "self_attn.k_proj"
},
"model.layers.15.self_attn.k_proj": {
"snr": 0.2559327781200409,
"type": "self_attn.k_proj"
},
"model.layers.0.self_attn.o_proj": {
"snr": 0.2638678550720215,
"type": "self_attn.o_proj"
},
"model.layers.1.self_attn.o_proj": {
"snr": 0.21109595894813538,
"type": "self_attn.o_proj"
},
"model.layers.2.self_attn.o_proj": {
"snr": 0.24751724302768707,
"type": "self_attn.o_proj"
},
"model.layers.3.self_attn.o_proj": {
"snr": 0.2728094160556793,
"type": "self_attn.o_proj"
},
"model.layers.4.self_attn.o_proj": {
"snr": 0.3001374304294586,
"type": "self_attn.o_proj"
},
"model.layers.5.self_attn.o_proj": {
"snr": 0.33903488516807556,
"type": "self_attn.o_proj"
},
"model.layers.6.self_attn.o_proj": {
"snr": 0.3530929982662201,
"type": "self_attn.o_proj"
},
"model.layers.7.self_attn.o_proj": {
"snr": 0.36753255128860474,
"type": "self_attn.o_proj"
},
"model.layers.8.self_attn.o_proj": {
"snr": 0.3373180329799652,
"type": "self_attn.o_proj"
},
"model.layers.9.self_attn.o_proj": {
"snr": 0.2970578670501709,
"type": "self_attn.o_proj"
},
"model.layers.10.self_attn.o_proj": {
"snr": 0.3076324760913849,
"type": "self_attn.o_proj"
},
"model.layers.11.self_attn.o_proj": {
"snr": 0.2766900658607483,
"type": "self_attn.o_proj"
},
"model.layers.12.self_attn.o_proj": {
"snr": 0.20973259210586548,
"type": "self_attn.o_proj"
},
"model.layers.13.self_attn.o_proj": {
"snr": 0.18185566365718842,
"type": "self_attn.o_proj"
},
"model.layers.14.self_attn.o_proj": {
"snr": 0.18329747021198273,
"type": "self_attn.o_proj"
},
"model.layers.15.self_attn.o_proj": {
"snr": 0.2437991499900818,
"type": "self_attn.o_proj"
},
"model.layers.0.self_attn.q_proj": {
"snr": 0.038040731102228165,
"type": "self_attn.q_proj"
},
"model.layers.1.self_attn.q_proj": {
"snr": 0.0707998052239418,
"type": "self_attn.q_proj"
},
"model.layers.2.self_attn.q_proj": {
"snr": 0.0787411704659462,
"type": "self_attn.q_proj"
},
"model.layers.3.self_attn.q_proj": {
"snr": 0.08089710026979446,
"type": "self_attn.q_proj"
},
"model.layers.4.self_attn.q_proj": {
"snr": 0.08591937273740768,
"type": "self_attn.q_proj"
},
"model.layers.5.self_attn.q_proj": {
"snr": 0.09852176159620285,
"type": "self_attn.q_proj"
},
"model.layers.6.self_attn.q_proj": {
"snr": 0.09690654277801514,
"type": "self_attn.q_proj"
},
"model.layers.7.self_attn.q_proj": {
"snr": 0.11181341856718063,
"type": "self_attn.q_proj"
},
"model.layers.8.self_attn.q_proj": {
"snr": 0.12042108923196793,
"type": "self_attn.q_proj"
},
"model.layers.9.self_attn.q_proj": {
"snr": 0.09799323976039886,
"type": "self_attn.q_proj"
},
"model.layers.10.self_attn.q_proj": {
"snr": 0.10901063680648804,
"type": "self_attn.q_proj"
},
"model.layers.11.self_attn.q_proj": {
"snr": 0.09307146072387695,
"type": "self_attn.q_proj"
},
"model.layers.12.self_attn.q_proj": {
"snr": 0.0880950540304184,
"type": "self_attn.q_proj"
},
"model.layers.13.self_attn.q_proj": {
"snr": 0.08886399120092392,
"type": "self_attn.q_proj"
},
"model.layers.14.self_attn.q_proj": {
"snr": 0.09955056011676788,
"type": "self_attn.q_proj"
},
"model.layers.15.self_attn.q_proj": {
"snr": 0.08929339051246643,
"type": "self_attn.q_proj"
},
"model.layers.0.self_attn.v_proj": {
"snr": 2.5501928329467773,
"type": "self_attn.v_proj"
},
"model.layers.1.self_attn.v_proj": {
"snr": 9.449499130249023,
"type": "self_attn.v_proj"
},
"model.layers.2.self_attn.v_proj": {
"snr": 7.9920830726623535,
"type": "self_attn.v_proj"
},
"model.layers.3.self_attn.v_proj": {
"snr": 50.69462585449219,
"type": "self_attn.v_proj"
},
"model.layers.4.self_attn.v_proj": {
"snr": 19.083511352539062,
"type": "self_attn.v_proj"
},
"model.layers.5.self_attn.v_proj": {
"snr": 7.21597146987915,
"type": "self_attn.v_proj"
},
"model.layers.6.self_attn.v_proj": {
"snr": 11.27744197845459,
"type": "self_attn.v_proj"
},
"model.layers.7.self_attn.v_proj": {
"snr": 4.579711437225342,
"type": "self_attn.v_proj"
},
"model.layers.8.self_attn.v_proj": {
"snr": 10.940719604492188,
"type": "self_attn.v_proj"
},
"model.layers.9.self_attn.v_proj": {
"snr": 553.4417724609375,
"type": "self_attn.v_proj"
},
"model.layers.10.self_attn.v_proj": {
"snr": 20.59434700012207,
"type": "self_attn.v_proj"
},
"model.layers.11.self_attn.v_proj": {
"snr": 26.636865615844727,
"type": "self_attn.v_proj"
},
"model.layers.12.self_attn.v_proj": {
"snr": 8.614749908447266,
"type": "self_attn.v_proj"
},
"model.layers.13.self_attn.v_proj": {
"snr": 17.722007751464844,
"type": "self_attn.v_proj"
},
"model.layers.14.self_attn.v_proj": {
"snr": 1.48500657081604,
"type": "self_attn.v_proj"
},
"model.layers.15.self_attn.v_proj": {
"snr": 2.5776851177215576,
"type": "self_attn.v_proj"
}
}

View File

@@ -16,10 +16,21 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
return messages_load(tokenizer, cfg, ds_cfg, processor=processor)
load_fn = "load"
package = "axolotl.prompt_strategies"
if strategy.split(".")[-1].startswith("load_"):
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
elif len(strategy.split(".")) > 1:
try:
importlib.import_module(
"." + strategy.split(".")[-1],
".".join(strategy.split(".")[:-1]),
)
package = ".".join(strategy.split(".")[:-1])
strategy = strategy.split(".")[-1]
except ModuleNotFoundError:
pass
mod = importlib.import_module(f".{strategy}", package)
func = getattr(mod, load_fn)
load_kwargs = {}
if strategy == "user_defined":

View File

@@ -10,6 +10,8 @@ LOG = logging.getLogger("axolotl")
def load(strategy, cfg, module_base=None, **kwargs):
try:
if len(strategy.split(".")) == 1:
strategy = strategy + ".default"
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(f".{strategy}", module_base)

View File

@@ -21,7 +21,11 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
Bradley-Terry reward model pairwise chat template prompt strategy.
"""
def tokenize_prompt(self, prompt):
@property
def supports_batched(self) -> bool:
return False
def _tokenize_single_prompt(self, prompt):
"""
:param prompt: the actual row of data from the underlying dataset
@@ -39,11 +43,11 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
chosen_tokenized = super().tokenize_prompt(prompt)
chosen_tokenized = super()._tokenize_single_prompt(prompt)
if len(chosen_tokenized["input_ids"]) > max_length:
LOG.warning(
f"Chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
f"To-be-trimmed chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
)
chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
@@ -62,11 +66,11 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
prompt[self.messages].append(
{"role": "assistant", "content": prompt["rejected"]}
)
rejected_tokenized = super().tokenize_prompt(prompt)
rejected_tokenized = super()._tokenize_single_prompt(prompt)
if len(rejected_tokenized["input_ids"]) > max_length:
LOG.warning(
f"Rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
f"To-be-trimmed rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
)
rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][

View File

@@ -3,6 +3,7 @@ HF Chat Templates prompt strategy
"""
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional
from transformers import ProcessorMixin
@@ -193,7 +194,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def __init__(
self,
prompter,
prompter: ChatTemplatePrompter,
tokenizer,
train_on_inputs,
sequence_len,
@@ -220,22 +221,61 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def messages(self, messages):
self._messages = messages
def tokenize_prompt(self, prompt):
@property
def supports_batched(self) -> bool:
# Let calling code know we can handle lists of examples
return True
def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
try:
return all(isinstance(v, list) for v in prompt.values()) and all(
isinstance(v, list) for v in prompt[self.messages]
)
except KeyError:
return False
def tokenize_prompt(self, prompt: dict[str, Any]):
"""
Public method that can handle either a single prompt or a batch of prompts.
"""
if not self.is_prompt_batched(prompt) or not self.supports_batched:
return self._tokenize_single_prompt(prompt)
res = defaultdict(lambda: [])
feature_names = list(prompt.keys())
# Process each prompt individually
for row in zip(*prompt.values()):
tokenized_prompt = self._tokenize_single_prompt(
dict(zip(feature_names, row))
)
for key, val in tokenized_prompt.items():
for i in range(0, len(val), self.sequence_len):
res[key].append(val[i : i + self.sequence_len])
# If there are no examples left, return an empty dictionary
if not res:
return {}
return dict(res)
def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
# Old simple legacy behavior that works reliably.
if (
not self.roles_to_train
and not self.train_on_eos
and not self.prompter.message_field_training
and not self.prompter.message_field_training_detail
and not self.prompter.message_field_training # type: ignore
and not self.prompter.message_field_training_detail # type: ignore
):
turns = self.get_conversation_thread(prompt)
images = self.get_images(prompt)
prompt_ids = self.prompter.build_prompt(
prompt_ids = self.prompter.build_prompt( # type: ignore
turns[:-1],
add_generation_prompt=True,
images=images,
)
tokenized_res = self.prompter.build_prompt(turns, images=images)
tokenized_res = self.prompter.build_prompt(turns, images=images) # type: ignore
tokenized_prompt = {}
if isinstance(tokenized_res, list):
input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
@@ -256,7 +296,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return tokenized_prompt
turns = self.get_conversation_thread(prompt)
input_ids = self.prompter.build_prompt(turns)
input_ids = self.prompter.build_prompt(turns) # type: ignore
labels = [IGNORE_TOKEN_ID] * len(input_ids)
last_eos_idx = -1
@@ -286,7 +326,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
if should_train and turn_start_idx != -1 and turn_end_idx != -1:
if train_detail:
token_offsets = self.prompter.get_offsets_for_train_detail(
token_offsets = self.prompter.get_offsets_for_train_detail( # type: ignore
content, train_detail
)
LOG.debug(f"Token offsets: {token_offsets}")
@@ -459,43 +499,62 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return prompt.get(self.images, None)
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
# pylint: disable=duplicate-code
ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
class StrategyLoader:
"""
Load chat template strategy based on configuration.
"""
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail",
None,
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
"max_length": cfg.sequence_len + 1,
"processor": processor,
}
def _get_strategy_cls(self):
return ChatTemplateStrategy
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
}
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
return {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
}
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
)
def __call__(
self, tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
):
# pylint: disable=duplicate-code
ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail",
None,
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
"max_length": cfg.sequence_len + 1,
"processor": processor,
}
return strategy
strategy_params = self._get_strategy_params(cfg, ds_cfg)
strategy_cls = self._get_strategy_cls()
strategy = strategy_cls(
ChatTemplatePrompter(**prompter_params),
tokenizer=tokenizer,
**strategy_params,
)
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
return strategy
load = StrategyLoader()

View File

@@ -3,22 +3,41 @@ DPO strategies for chatml
"""
def argilla(
def default(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "prompt" in sample.keys():
prompt_key = "prompt"
elif "input" in sample.keys():
prompt_key = "input"
elif "question" in sample.keys():
prompt_key = "question"
else:
prompt_key = "instruction"
if "chosen" in sample.keys():
chosen_key = "chosen"
else:
chosen_key = "chosen_response"
if "rejected" in sample.keys():
rejected_key = "rejected"
else:
rejected_key = "rejected_response"
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
] = f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample[chosen_key]}<|im_end|>"
sample["rejected"] = f"{sample[rejected_key]}<|im_end|>"
return sample
return transform_fn

View File

@@ -3,22 +3,42 @@ DPO strategies for llama-3 chat template
"""
def argilla(
def default(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
# pylint: disable=duplicate-code
if "prompt" in sample.keys():
prompt_key = "prompt"
elif "input" in sample.keys():
prompt_key = "input"
elif "question" in sample.keys():
prompt_key = "question"
else:
prompt_key = "instruction"
if "chosen" in sample.keys():
chosen_key = "chosen"
else:
chosen_key = "chosen_response"
if "rejected" in sample.keys():
rejected_key = "rejected"
else:
rejected_key = "rejected_response"
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = f"{sample['chosen_response']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected_response']}<|eot_id|>"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>"
sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>"
return sample
return transform_fn

View File

@@ -2,7 +2,7 @@
import abc
import logging
from typing import Dict, List, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union
from transformers import BatchEncoding, PreTrainedTokenizer
@@ -34,6 +34,8 @@ class PromptTokenizingStrategy(abc.ABC):
Abstract class for tokenizing strategies
"""
filter_rows: Optional[Callable] = None
def __init__(
self,
prompter: Prompter,

View File

@@ -846,6 +846,12 @@ class GCCallback(TrainerCallback):
def on_step_end(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
if state.global_step % self.gc_steps == 0:
if self.gc_steps > 0 and state.global_step % self.gc_steps == 0:
torch.cuda.empty_cache()
gc.collect()
def on_epoch_end(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
torch.cuda.empty_cache()
gc.collect()

View File

@@ -1,4 +1,5 @@
"""Module for working with config dicts"""
import json
import logging
import os
@@ -129,10 +130,18 @@ def normalize_config(cfg):
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
if save_steps < 1.0: # prevent saves on every step
cfg.save_steps = save_steps
elif save_steps > 1:
LOG.warning(
f"Invalid value for save_steps ({save_steps}) from saves_per_epoch and/or num_epochs. Saving at training end only."
)
if (cfg.val_set_size or cfg.test_datasets) and cfg.evals_per_epoch:
eval_steps = 1.0 / (cfg.evals_per_epoch * cfg.num_epochs)
if eval_steps < 1.0: # prevent evals on every step
cfg.eval_steps = eval_steps
elif eval_steps > 1:
LOG.warning(
f"Invalid value for eval_steps ({eval_steps}) from evals_per_epoch and/or num_epochs. Skipping evaluations."
)
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()

View File

@@ -115,6 +115,9 @@ class RemappedParameters(BaseModel):
overrides_of_model_config: Optional[Dict[str, Any]] = Field(
default=None, alias="model_config"
)
overrides_of_model_kwargs: Optional[Dict[str, Any]] = Field(
default=None, alias="model_kwargs"
)
type_of_model: Optional[str] = Field(default=None, alias="model_type")
revision_of_model: Optional[str] = Field(default=None, alias="model_revision")
@@ -163,6 +166,7 @@ class SFTDataset(BaseModel):
type: Optional[Union[str, UserDefinedPrompterType]] = None
input_transform: Optional[str] = None
shards: Optional[int] = None
preprocess_shards: Optional[int] = None
conversation: Optional[str] = None
# Do not make this too strict or it will break the validator to choose different dataset class
chat_template: Optional[
@@ -185,6 +189,8 @@ class SFTDataset(BaseModel):
message_field_content: Optional[str] = None
message_field_training: Optional[str] = None
message_field_training_detail: Optional[str] = None
logprobs_field: Optional[str] = None
temperature: Optional[float] = None
roles_to_train: Optional[List[str]] = None
train_on_eos: Optional[str] = None
roles: Optional[Dict[str, List[str]]] = None
@@ -423,8 +429,6 @@ class ModelInputConfig(BaseModel):
)
trust_remote_code: Optional[bool] = None
model_kwargs: Optional[Dict[str, Any]] = None
@field_validator("trust_remote_code")
@classmethod
def hint_trust_remote_code(cls, trust_remote_code):
@@ -861,6 +865,7 @@ class AxolotlInputConfig(
# INTERNALS - document for now, generally not set externally
is_preprocess: Optional[bool] = None
preprocess_iterable: Optional[bool] = None
total_num_tokens: Optional[int] = None
total_supervised_tokens: Optional[int] = None

View File

@@ -3,11 +3,12 @@
import functools
import logging
from pathlib import Path
from typing import List, Tuple, Union
from typing import List, Optional, Tuple, Union
from datasets import (
Dataset,
DatasetDict,
IterableDataset,
Sequence,
Value,
concatenate_datasets,
@@ -17,7 +18,7 @@ from datasets import (
from transformers import PreTrainedTokenizerBase
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.datasets import TokenizedPromptDataset
from axolotl.datasets import TokenizedPromptDataset, wrap_dataset_for_tokenized_prompt
from axolotl.prompt_strategies import load
from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
from axolotl.prompt_tokenizers import (
@@ -45,6 +46,7 @@ from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.shared import load_dataset_w_config
from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
drop_long_seq_in_dataset,
md5,
retry_on_request_exceptions,
)
@@ -55,11 +57,11 @@ from axolotl.utils.trainer import (
process_datasets_for_packing,
)
LOG = logging.getLogger("axolotl")
LOG = logging.getLogger(__name__)
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_dataset(cfg, tokenizer, processor=None):
def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
prompters = []
if not cfg.pretraining_dataset:
with zero_first(is_local_main_process()):
@@ -70,6 +72,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
DEFAULT_DATASET_PREPARED_PATH,
split="train",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
_, eval_dataset, _ = load_prepare_datasets(
tokenizer,
@@ -77,6 +80,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
DEFAULT_DATASET_PREPARED_PATH,
split="test",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
else:
train_dataset, eval_dataset, prompters = load_prepare_datasets(
@@ -84,6 +88,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
cfg,
DEFAULT_DATASET_PREPARED_PATH,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
else:
# Load streaming dataset if pretraining_dataset is given
@@ -139,6 +144,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
DEFAULT_DATASET_PREPARED_PATH,
split="test",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
if cfg.dataset_exact_deduplication:
@@ -170,6 +176,7 @@ def load_tokenized_prepared_datasets(
default_dataset_prepared_path,
split="train",
processor=None,
preprocess_iterable: Optional[bool] = None,
) -> Tuple[DatasetDict, List[Prompter]]:
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
tokenizer_name = cfg.tokenizer_config
@@ -184,10 +191,11 @@ def load_tokenized_prepared_datasets(
+ "@"
+ str(cfg.group_by_length)
+ "@"
+ str(cfg.kd_temperature or 1.0)
+ "|".join(
sorted(
[
f"{d.path}:{d.type}:{d.shards}:{d.conversation}{d.split}"
f"{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}"
for d in cfg_datasets
]
)
@@ -262,13 +270,25 @@ def load_tokenized_prepared_datasets(
# at the same time for a given dataset
for name in dataset.name:
yield DictDefault({**dataset, "name": name})
elif dataset.preprocess_shards and not dataset.shards:
for shard in range(dataset.preprocess_shards):
yield DictDefault(
{
**dataset,
"shards": dataset.preprocess_shards,
"shards_idx": shard,
}
)
else:
yield dataset
streaming_ds = False
if preprocess_iterable:
streaming_ds = True
# pylint: disable=invalid-name
for config_dataset in for_d_in_datasets(cfg_datasets):
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
config_dataset, use_auth_token
config_dataset, use_auth_token, streaming=streaming_ds
)
d_base_type = d_prompt_style = None
@@ -320,12 +340,29 @@ def load_tokenized_prepared_datasets(
else:
LOG.debug("NOT shuffling merged datasets")
if cfg.sample_packing and not cfg.skip_prepare_dataset:
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
if not cfg.skip_prepare_dataset:
dataset = drop_long_seq_in_dataset(dataset, cfg)
if cfg.sample_packing:
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(str(prepared_ds_path))
if isinstance(dataset, IterableDataset):
def gen_from_iter_ds(_ds, _=None):
yield from _ds
ds_from_iter = Dataset.from_generator(
functools.partial(gen_from_iter_ds, dataset),
features=dataset.features,
num_proc=cfg.dataset_processes,
split=split,
gen_kwargs={"_": list(range(cfg.dataset_processes))},
)
ds_from_iter.save_to_disk(str(prepared_ds_path))
else:
dataset.save_to_disk(str(prepared_ds_path))
if cfg.push_dataset_to_hub:
LOG.info(
f"Pushing merged prepared dataset to Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..."
@@ -345,6 +382,7 @@ def load_prepare_datasets(
default_dataset_prepared_path,
split="train",
processor=None,
preprocess_iterable: Optional[bool] = False,
) -> Tuple[Dataset, Dataset, List[Prompter]]:
dataset, prompters = load_tokenized_prepared_datasets(
tokenizer,
@@ -352,6 +390,7 @@ def load_prepare_datasets(
default_dataset_prepared_path,
split=split,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
@@ -451,7 +490,7 @@ def get_dataset_wrapper(
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
)
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -464,7 +503,7 @@ def get_dataset_wrapper(
config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset
):
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -487,7 +526,7 @@ def get_dataset_wrapper(
dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs)
else:
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -500,7 +539,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -514,7 +553,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -528,7 +567,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -542,7 +581,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -556,7 +595,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -570,7 +609,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -584,7 +623,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -598,7 +637,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,

View File

@@ -29,7 +29,9 @@ def get_ds_type(config_dataset: DictDefault):
return ds_type
def load_dataset_w_config(config_dataset, auth_token):
def load_dataset_w_config(
config_dataset, auth_token, streaming=False
) -> Union[Dataset, DatasetDict]:
# pylint: disable=invalid-name
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
ds_from_hub = False
@@ -124,7 +126,7 @@ def load_dataset_w_config(config_dataset, auth_token):
ds_type,
name=config_dataset.name,
data_files=config_dataset.data_files,
streaming=False,
streaming=streaming,
**load_ds_kwargs,
)
else:
@@ -157,7 +159,7 @@ def load_dataset_w_config(config_dataset, auth_token):
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=False,
streaming=streaming,
data_files=config_dataset.data_files,
token=auth_token,
revision=config_dataset.revision,
@@ -176,7 +178,7 @@ def load_dataset_w_config(config_dataset, auth_token):
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
streaming=streaming,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
@@ -187,7 +189,7 @@ def load_dataset_w_config(config_dataset, auth_token):
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
streaming=streaming,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
@@ -217,7 +219,7 @@ def load_dataset_w_config(config_dataset, auth_token):
"json",
name=config_dataset.name,
data_files=fp,
streaming=False,
streaming=streaming,
**load_ds_kwargs,
)
if not ds:

View File

@@ -1,4 +1,5 @@
"""data handling helpers"""
import functools
import hashlib
import logging
@@ -6,10 +7,15 @@ import time
from enum import Enum
import huggingface_hub
import numpy as np
import requests
from datasets import Dataset
from datasets import Dataset, IterableDataset
LOG = logging.getLogger("axolotl")
from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers.utils import get_dataset_lengths
from axolotl.utils.trainer import drop_long_seq
LOG = logging.getLogger(__name__)
class RetryStrategy(Enum):
@@ -150,3 +156,53 @@ def deduplicate_and_log_datasets(
)
return train_dataset, eval_dataset, dataset
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault):
if "input_ids" not in dataset.column_names:
LOG.warning(
"Dataset does not contain 'input_ids' column. Skip drop long seq. This is expected for RewardModeling."
)
return dataset
drop_long = functools.partial(
drop_long_seq,
sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len,
)
try:
min_input_len = np.min(get_dataset_lengths(dataset))
LOG.debug(f"min_input_len: {min_input_len}")
max_input_len = np.max(get_dataset_lengths(dataset))
LOG.debug(f"max_input_len: {max_input_len}")
except AttributeError:
pass
try:
prior_len = len(dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
filter_map_kwargs = {}
if not isinstance(dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_processes
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Dropping Long Sequences"
dataset = dataset.filter(
drop_long,
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset")
return dataset

View File

@@ -10,7 +10,7 @@ from accelerate.utils.environment import get_gpu_info
def check_cuda_p2p_ib_support():
if not accelerate_check_cuda_p2p_ib_support():
return False
unsupported_devices = {"RTX 6000 Ada"}
unsupported_devices = {"RTX 6000 Ada", "L40S"}
try:
device_names, device_count = get_gpu_info()
if 1 < device_count < 8:

View File

@@ -357,8 +357,8 @@ class ModelLoader:
# init model kwargs
self.model_kwargs: Dict[str, Any] = {}
if cfg.model_kwargs:
for key, val in cfg.model_kwargs.items():
if cfg.overrides_of_model_kwargs:
for key, val in cfg.overrides_of_model_kwargs.items():
self.model_kwargs[key] = val
# init model

View File

@@ -4,7 +4,6 @@ Multipack Batch Sampler
"""
import logging
import math
import os
from typing import Any, Iterable, List, Union
import numba
@@ -117,6 +116,7 @@ class MultipackBatchSampler(BatchSampler):
lengths: np.ndarray,
packing_efficiency_estimate: float = 1.0,
drop_last: bool = False,
num_count_samples: int = 16,
**kwargs,
):
super().__init__(sampler, batch_size, drop_last)
@@ -133,6 +133,9 @@ class MultipackBatchSampler(BatchSampler):
self.eff_total_used = 0
self.eff_total_slots = 0
# The number of times to calculate the batches to determine the minimum packed dataset length for the local rank
self.num_count_samples = num_count_samples
# the minimum packed dataset length across all ranks determined by a gather/broadcast
self.len_across_ranks = None
def set_epoch(self, epoch: int):
@@ -169,6 +172,9 @@ class MultipackBatchSampler(BatchSampler):
def __iter__(self):
batches = self.generate_batches(set_stats=True)
if self.len_across_ranks:
# make sure the batches we iterate over is truncated to the same min length across all ranks
batches = batches[: self.len_across_ranks]
return iter(batches)
def num_batches(self):
@@ -195,42 +201,15 @@ class MultipackBatchSampler(BatchSampler):
def gather_len_batches(self, num):
def calc_min_len(estimates: list[(int, float)]):
LOG.info(f"gather_len_batches: {repr(estimates)}")
return math.floor(0.998 * min(estimates))
return math.floor(min(estimates))
min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
return min_len_batches
def __len__(self):
if not self.len_across_ranks:
len_batches = self.num_batches()
len_batches = min(
[self.num_batches() for _ in range(self.num_count_samples)]
)
self.len_across_ranks = self.gather_len_batches(len_batches)
return self.len_across_ranks
def _len_est(self):
efficiency = (
self.packing_efficiency_estimate
if self.packing_efficiency_estimate
else self.gather_efficiency()
)
world_size = int(os.getenv("WORLD_SIZE", "1"))
lengths_sum = np.sum(self.lengths)
lengths_sum_per_device = lengths_sum // world_size
LOG.info(
f"packing_efficiency_estimate: {efficiency} "
f"total_num_tokens per device: {lengths_sum_per_device}"
)
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
return max(
0,
(
world_size
* math.floor(
0.99
* lengths_sum_per_device
/ efficiency
// (self.batch_max_len * self.batch_size)
)
- 1
),
)

View File

@@ -13,5 +13,4 @@ def get_dataset_lengths(dataset):
else:
input_ids = dataset.data.column("input_ids")
lengths = np.vectorize(len)(np.array(input_ids, dtype=object))
return lengths
return lengths

View File

@@ -26,6 +26,7 @@ def check_example_labels(example, tokenizer, text_only=False):
# Get the input_ids, labels, and attention_mask from the dataset
input_ids = example["input_ids"]
labels = example["labels"]
target_mask = example.pop("target_mask", None)
# You can compare the input_ids and labels element-wise
# Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
@@ -42,6 +43,13 @@ def check_example_labels(example, tokenizer, text_only=False):
delimiter = "" if text_only else " "
LOG.info(delimiter.join(colored_tokens))
LOG.info("\n\n\n")
target_labels_count = sum(label_id != -100 for label_id in labels)
total_len = len(input_ids)
LOG.info(f"Total input len: {total_len}")
LOG.info(f"Count of labels: {target_labels_count}")
if target_mask:
target_mask_positions = sum(m[0] for m in target_mask)
LOG.info(f"Number of positions in target_mask: {target_mask_positions}")
return " ".join(colored_tokens)

View File

@@ -1,4 +1,5 @@
"""Module containing the Trainer class and related functions"""
import json
import math
import os
@@ -11,7 +12,7 @@ import numpy as np
import torch
import torch.cuda
from accelerate.logging import get_logger
from datasets import disable_caching, enable_caching
from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available
@@ -95,9 +96,41 @@ def disable_datasets_caching():
def add_position_ids(sample):
sample_len = len(sample["input_ids"])
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
sample["length"] = sample_len
"""
Handle both single-example and batched data.
- single example: sample['input_ids'] is a list[int]
- batched data: sample['input_ids'] is a list[list[int]]
"""
# Return sample unchanged if "input_ids" is not present, or is empty
if "input_ids" not in sample or not sample["input_ids"]:
return sample
input_ids = sample["input_ids"]
# If first element is an int, its a single example
# If first element is a list, its a batch
if isinstance(input_ids[0], int):
# ---- SINGLE EXAMPLE ----
seq_len = len(input_ids)
# Position IDs for a single example
# As a list
sample["position_ids"] = list(range(seq_len))
sample["length"] = seq_len
else:
# ---- BATCHED EXAMPLES ----
# input_ids is a list of lists
position_ids_batch = []
lengths_batch = []
for seq in input_ids:
seq_len = len(seq)
position_ids_batch.append(list(range(seq_len)))
lengths_batch.append(seq_len)
# Now store them back
sample["position_ids"] = position_ids_batch
sample["length"] = lengths_batch
return sample
@@ -172,24 +205,36 @@ def add_length(sample):
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
return (
len(sample["input_ids"]) <= sequence_len
and len(sample["input_ids"]) >= min_sequence_len
)
"""
Drop samples whose sequence length is either too long (> sequence_len)
or too short (< min_sequence_len).
Works for both single-example (list[int]) or batched (list[list[int]]).
"""
min_sequence_len = min_sequence_len or 2
input_ids = sample["input_ids"]
# Edge case: if input_ids is empty
if not input_ids:
# Decide if you want to drop or keep empty. Let's drop.
return False
# Check if single example or batched by looking at the first element
if isinstance(input_ids[0], int):
# Single example (input_ids is a list of int)
length = len(input_ids)
return min_sequence_len <= length <= sequence_len
# Batched (input_ids is a list of lists)
results = []
for seq in input_ids:
length = len(seq)
results.append(min_sequence_len <= length <= sequence_len)
return results
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
drop_long = partial(
drop_long_seq,
sequence_len=cfg.sequence_len,
min_sequence_len=cfg.min_sample_len or 2,
)
min_input_len = np.min(get_dataset_lengths(train_dataset))
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
if cfg.model_config_type == "mamba":
LOG.info("dropping attention_mask column")
train_dataset = train_dataset.remove_columns("attention_mask")
@@ -203,60 +248,71 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns("token_type_ids")
prior_len = len(train_dataset)
train_dataset = train_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(train_dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from train dataset")
if eval_dataset:
prior_len = len(eval_dataset)
eval_dataset = eval_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(eval_dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from eval dataset")
# drop samples with where the number of elements with labels not equal to -100 is zero
def drop_no_trainable_tokens(sample):
return np.sum(np.array(sample["labels"]) != -100) > 0
"""
Drop samples if all labels are -100 (i.e., zero trainable tokens).
Works for both single-example or batched input.
"""
labels = sample["labels"]
if not labels:
return True
prior_len = len(train_dataset)
# Check if single example or batch
# If first element is an int, we assume a single example
# If it's a list, we assume we're dealing with a batch
if isinstance(labels[0], int):
# Single example: return a single bool
return np.any(labels != -100)
# Batched: 'labels' is a list of lists
# Return a list of booleans, one per sub-list
results = [np.any(row_labels != -100) for row_labels in labels]
return results
try:
prior_len = len(train_dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
filter_map_kwargs = {}
if not isinstance(train_dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_processes
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
train_dataset = train_dataset.filter(
drop_no_trainable_tokens,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Drop Samples with Zero Trainable Tokens",
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
dropped = prior_len - len(train_dataset)
if dropped:
LOG.warning(
f"Dropped {dropped} samples with no trainable tokens from train dataset"
)
if eval_dataset:
prior_len = len(eval_dataset)
eval_dataset = eval_dataset.filter(
drop_no_trainable_tokens,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Drop Samples with Zero Trainable Tokens",
)
dropped = prior_len - len(eval_dataset)
if prior_len:
dropped = prior_len - len(train_dataset)
if dropped:
LOG.warning(
f"Dropped {dropped} samples with no trainable tokens from eval dataset"
f"Dropped {dropped} samples with no trainable tokens from train dataset"
)
if eval_dataset:
try:
prior_len = len(eval_dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
eval_dataset = eval_dataset.filter(
drop_no_trainable_tokens,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(eval_dataset)
if dropped:
LOG.warning(
f"Dropped {dropped} samples with no trainable tokens from eval dataset"
)
if cfg.group_by_length:
train_dataset = train_dataset.map(
add_length,
@@ -291,19 +347,21 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
desc="Add position_id column (PoSE)",
)
elif cfg.sample_packing:
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
train_dataset = train_dataset.map(
add_position_ids,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (Sample Packing)",
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
add_position_ids,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (Sample Packing)",
**filter_map_kwargs,
**drop_long_kwargs,
)
return train_dataset, eval_dataset
@@ -337,7 +395,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
and not cfg.reward_model
):
total_num_tokens = np.sum(
train_dataset.data.column("input_ids")
train_dataset.select_columns("input_ids")
.to_pandas()
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values

View File

@@ -0,0 +1,68 @@
"""
unit tests for generating sweep configurations
"""
from axolotl.cli.main import generate_sweep_configs
def test_generate_sweep_configs_no_pairs():
base_config = {
"learning_rate": 0.1,
"micro_batch_size": 1,
"sample_packing": True,
}
sweeps_config = {"micro_batch_size": [1, 2, 4], "weight_decay": [0.0, 0.1]}
generate_sweep_configs(base_config, sweeps_config)
assert len(generate_sweep_configs(base_config, sweeps_config)) == 6
cfg_1 = {
"learning_rate": 0.1,
"micro_batch_size": 2,
"weight_decay": 0.0,
"sample_packing": True,
}
assert any(
cfg_1 == cfg for cfg in generate_sweep_configs(base_config, sweeps_config)
)
def test_generate_sweep_configs_with_pairs():
base_config = {
"learning_rate": 0.1,
"micro_batch_size": 1,
"sample_packing": True,
}
sweeps_config = {
"_": [
{
"micro_batch_size": 1,
"gradient_accumulation_steps": 8,
},
{
"micro_batch_size": 2,
"gradient_accumulation_steps": 4,
},
{
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,
},
{
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
},
],
"weight_decay": [0.0, 0.1],
}
generate_sweep_configs(base_config, sweeps_config)
assert len(generate_sweep_configs(base_config, sweeps_config)) == 8
assert all(
cfg["gradient_accumulation_steps"] * cfg["micro_batch_size"] == 8
for cfg in generate_sweep_configs(base_config, sweeps_config)
)

View File

@@ -0,0 +1,121 @@
"""
e2e tests for kd trainer support in Axolotl
"""
from pathlib import Path
import pytest
from e2e.utils import check_tensorboard, require_torch_2_5_1
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="kd_min_cfg")
def min_cfg(temp_dir):
return {
"base_model": "osllmai-community/Llama-3.2-1B",
"tokenizer_config": "axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
"plugins": [
"axolotl.integrations.kd.KDPlugin",
"axolotl.integrations.liger.LigerPlugin",
],
"liger_rms_norm": True,
"liger_glu_activation": True,
"torch_compile": True,
"chat_template": "llama3",
"kd_trainer": True,
"kd_ce_alpha": 0.1,
"kd_alpha": 0.9,
"kd_temperature": 1.0,
"dataloader_prefetch_factor": 8,
"dataloader_num_workers": 4,
"dataloader_pin_memory": True,
"datasets": [
{
"path": "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample",
"type": "axolotl.integrations.kd.chat_template",
"field_messages": "messages_combined",
"split": "train",
"logprobs_field": "llm_text_generation_vllm_logprobs",
"temperature": 1.0,
"preprocess_shards": 2,
},
],
"val_set_size": 0.0,
"sequence_len": 2048,
"sample_packing": True,
"pad_to_sequence_len": True,
"gradient_accumulation_steps": 2,
"micro_batch_size": 1,
"num_epochs": 1,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"learning_rate": 0.00001,
"bf16": "auto",
"gradient_checkpointing": True,
"flash_attention": True,
"special_tokens": {
"pad_token": "<|end_of_text|>",
"eos_token": "<|eot_id|>",
},
"max_steps": 5,
"output_dir": temp_dir,
"save_safetensors": True,
"use_tensorboard": True,
}
class TestKnowledgeDistillation:
"""
Test case for Knowledge Distillation
"""
# While this will run on torch 2.4.x without torch_compile enabled
# the VRAM requirement is higher than what is available in CI
@require_torch_2_5_1
def test_llama_kd(self, temp_dir, kd_min_cfg):
cfg = DictDefault(kd_min_cfg)
# pylint: disable=duplicate-code
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
)
@pytest.mark.parametrize(
"load_in_8bit",
[True, False],
)
def test_llama_lora_kd(self, temp_dir, kd_min_cfg, load_in_8bit):
cfg = DictDefault(
{
"load_in_8bit": load_in_8bit,
"torch_compile": False,
"adapter": "lora",
"peft_use_dora": True,
"lora_target_linear": True,
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.0,
}
| kd_min_cfg
)
# pylint: disable=duplicate-code
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
)

View File

@@ -55,6 +55,7 @@ class LigerIntegrationTestCase:
"max_steps": 5,
}
)
# pylint: disable=duplicate-code
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
@@ -100,6 +101,7 @@ class LigerIntegrationTestCase:
"max_steps": 5,
}
)
# pylint: disable=duplicate-code
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()

View File

@@ -63,7 +63,7 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.5, "Train Loss is too high"
temp_dir + "/runs", "train/train_loss", 2.7, "Train Loss (%s) is too high"
)
check_model_output_exists(temp_dir, cfg)

View File

@@ -33,7 +33,7 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
"num_labels": 1,
"chat_template": "alpaca",
"reward_model": True,
"sequence_len": 1024,
"sequence_len": 2048,
"pad_to_sequence_len": True,
"adapter": "lora",
"lora_r": 8,

View File

@@ -82,7 +82,10 @@ def check_tensorboard(
reader = SummaryReader(event_file)
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == tag)] # pylint: disable=invalid-name
assert df.value.values[-1] < lt_val, assertion_err
if "%s" in assertion_err:
assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
else:
assert df.value.values[-1] < lt_val, assertion_err
def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None: