Compare commits

..

78 Commits

Author SHA1 Message Date
Wing Lian
7232cbdeab chore: lint 2025-01-14 22:47:48 -05:00
Wing Lian
e8fceb7091 chore: lint 2025-01-14 22:47:48 -05:00
Wing Lian
a5e0671738 make sure to use tensorboard to capture loss for checks 2025-01-14 22:47:48 -05:00
Wing Lian
b9847553af fix adapter model check 2025-01-14 22:47:48 -05:00
Wing Lian
513ec9e03b make sure to use the correct tokenizer 2025-01-14 22:47:48 -05:00
Wing Lian
530347856d make sure to set tokenizer from l3 70b and save safetensors 2025-01-14 22:47:47 -05:00
Wing Lian
261e4fb619 lower lr 2025-01-14 22:47:47 -05:00
Wing Lian
158071e95f set lora_dropout explicitly 2025-01-14 22:47:47 -05:00
Wing Lian
432f65f5e6 make the kd e2e fit in vram for ci and add lora version 2025-01-14 22:47:47 -05:00
Wing Lian
1d039f5486 rename test files so it gets picked up 2025-01-14 22:47:47 -05:00
Wing Lian
b9a42b396f linting 2025-01-14 22:47:47 -05:00
Wing Lian
ff2fb0fc1b add kd trainer e2e test 2025-01-14 22:47:47 -05:00
Wing Lian
317f290186 reward model doesn't work well with batched 2025-01-14 22:47:46 -05:00
Wing Lian
ab690f3f01 improve check for batched 2025-01-14 22:47:46 -05:00
Wing Lian
47932f21c4 fix reward trainer calls for tokenization 2025-01-14 22:47:46 -05:00
Wing Lian
808328e041 reward can use same batch check 2025-01-14 22:47:46 -05:00
Wing Lian
6784822cfb tweak check for batched prompt data 2025-01-14 22:47:46 -05:00
Wing Lian
684b38291f ensure that batch vs single is done properly 2025-01-14 22:47:46 -05:00
Wing Lian
01896b1bde improve iterable support 2025-01-14 22:47:46 -05:00
Wing Lian
e659c01646 support streaming for processing sft datasts? 2025-01-14 22:47:45 -05:00
Wing Lian
204d6c43b4 make loss torch script compat 2025-01-14 22:47:45 -05:00
Wing Lian
d3c2b7ce9d kd sample packing 2025-01-14 22:47:45 -05:00
Wing Lian
93dfff92f1 be a bit pickier about loading dynamic prompt strategies 2025-01-14 22:47:45 -05:00
Wing Lian
6e409d2d88 more info on preprocess for kd and fix import 2025-01-14 22:47:45 -05:00
Wing Lian
d5bc214300 remove duplicate code 2025-01-14 22:47:45 -05:00
Wing Lian
92c6c1087e add copyrights 2025-01-14 22:47:45 -05:00
Wing Lian
feed96f95e increase logging around loading plugins 2025-01-14 22:47:44 -05:00
Wing Lian
cba6165ae1 make plugin setup concise 2025-01-14 22:47:44 -05:00
Wing Lian
cdfcd69afa remove moved class from import 2025-01-14 22:47:44 -05:00
Wing Lian
885653d52e move more things to kd plugin 2025-01-14 22:47:44 -05:00
Wing Lian
27faacbf5a refactor kd chat template loader 2025-01-14 22:47:44 -05:00
Wing Lian
c51b0337c1 support for custom trainer classes from plugins 2025-01-14 22:47:44 -05:00
Wing Lian
fa055f9f69 handle token/logprob shifting 2025-01-14 22:47:43 -05:00
Wing Lian
f60c623af0 remove references to triton kd for now 2025-01-14 22:47:43 -05:00
Wing Lian
746891eb5c add license block 2025-01-14 22:47:43 -05:00
Wing Lian
f09b5da60b refactor so we can easily add new loss functions 2025-01-14 22:47:43 -05:00
Wing Lian
689e1c10ba chore: lint 2025-01-14 22:47:43 -05:00
Wing Lian
a5c085e003 var naming and add todo 2025-01-14 22:47:43 -05:00
Wing Lian
63146300b7 fix kd loss so it's causal (fixes repeating tokens) 2025-01-14 22:47:43 -05:00
Wing Lian
ca5e397fc5 use kd_alpha in the correct loss method 2025-01-14 22:47:42 -05:00
Wing Lian
3416302b0d hash for temperature too 2025-01-14 22:47:42 -05:00
Wing Lian
7366efc4ca better rescaling for temperatures 2025-01-14 22:47:42 -05:00
Wing Lian
d8d817eaed don't use triton for now 2025-01-14 22:47:42 -05:00
Wing Lian
c0757e8a20 fix kwarg 2025-01-14 22:47:42 -05:00
Wing Lian
e565694914 v3 2025-01-14 22:47:42 -05:00
Wing Lian
081928e55b no torch.tensor 2025-01-14 22:47:42 -05:00
Wing Lian
dc90c93894 no log etc 2025-01-14 22:47:41 -05:00
Wing Lian
18a46c338a no torch.exp inside triton kernel 2025-01-14 22:47:41 -05:00
Wing Lian
119d586cf4 v2 trial 2025-01-14 22:47:41 -05:00
Wing Lian
c73acd7de0 no where support 2025-01-14 22:47:41 -05:00
Wing Lian
0b59a242d4 triton wip 2025-01-14 22:47:41 -05:00
Wing Lian
ed490517da chore: lint 2025-01-14 22:47:41 -05:00
Wing Lian
00ce77e7ef make sure to multiply against the correct loss 2025-01-14 22:47:41 -05:00
Wing Lian
ae545e0165 cross entropy loss coefficient during KD 2025-01-14 22:47:40 -05:00
Wing Lian
b592c05b93 flipped the slice 2025-01-14 22:47:40 -05:00
Wing Lian
7fe0ad088b make it work 2025-01-14 22:47:40 -05:00
Wing Lian
ddcf5c68b3 handle padding/collation for KD datasets 2025-01-14 22:47:40 -05:00
Wing Lian
e633a12dbe make batch smaller 2025-01-14 22:47:40 -05:00
Wing Lian
d584354ee4 filter bad rows 2025-01-14 22:47:40 -05:00
Wing Lian
303cfa71aa KD dataset loading and KD with logprobs 2025-01-14 22:47:40 -05:00
Wing Lian
88b3198894 refactor trainer to prevent circular dependencies later
fix loader default
2025-01-14 22:47:39 -05:00
jwongTensora
8606093921 fix for indexing error from token/embeddings mismatch (#2257)
Co-authored-by: jwong <jwongTensora@gmail.com>
2025-01-14 22:09:29 -05:00
NanoCode012
cba5a457d9 fix: use text_column even when not packing for pretraining (#2254)
* fix: use text_column even when not packing for pretraining

* feat: update test to check when not packing

* chore: lint

* Update src/axolotl/utils/data/pretraining.py

Co-authored-by: Wing Lian <wing.lian@gmail.com>

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2025-01-14 22:08:56 -05:00
Wing Lian
19cd83d408 rename references to dpo dataset prep to pref data (#2258) 2025-01-14 22:07:55 -05:00
Dan Saunders
1ed4de73b6 CLI cleanup and documentation (#2244)
* CLI init refactor

* fix

* cleanup and (partial) docs

* Adding documentation and continuing cleanup (in progress)

* remove finetune.py script

* continued cleanup and documentation

* pytest fixes

* review comments

* fix

* Fix

* typing fixes

* make sure the batch dataset patcher for multipack is always loaded when handling datasets

* review comments

* fix

---------

Co-authored-by: Dan Saunders <dan@axolotl.ai>
Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-01-13 17:55:29 +00:00
Wing Lian
f89e962119 skip over rows in pretraining dataset (#2223)
* skip over rows in pretraining dataset

* update docs
2025-01-13 10:44:45 -05:00
Wing Lian
bc1c9c20e3 assume empty lora dropout means 0.0 and add tests (#2243)
* assume empty lora dropout means 0.0 and add tests

* remove un-necessary arg

* refactor based on pr feedback:

* chore: lint
2025-01-13 10:44:11 -05:00
Wing Lian
dd26cc3c0f add helper to verify the correct model output file exists (#2245)
* add helper to verify the correct model output file exists

* more checks using helper

* chore: lint

* fix import and relora model check

* workaround for trl trainer saves

* remove stray print
2025-01-13 10:43:29 -05:00
Wing Lian
d8b4027200 use 2.5.1 docker images as latest tag as it seems stable (#2198) 2025-01-10 08:35:25 -05:00
Wing Lian
fb3352e21c rename liger test so it properly runs in ci (#2246) 2025-01-09 17:31:43 -05:00
NanoCode012
ed77e7001e feat: add support for data_files in pretraining (#2238) 2025-01-09 21:04:13 +00:00
Wing Lian
7669a03fb4 update upstream HF deps (#2239)
* bump axolotl contribs for upstream main conflicts:

* bump datasets, tokenizer, trl

* remove log workarounds in trl

* bump lm-eval

* remove unsloth_ import from critical path

* remove llama fa2 from conftest

* unsloth breaks with latest upstream
2025-01-09 21:01:59 +00:00
Vincenzo di Cicco
6553683170 Use SequentialSampler if curriculum_sampling is enabled with sample_packing (#2235) 2025-01-09 21:01:22 +00:00
Wing Lian
5e0124e2ab update modal version for ci (#2242) 2025-01-09 21:01:02 +00:00
NanoCode012
2e8d7c1adb fix: mistral nemo does not recognize token_type_ids in forward (#2233) 2025-01-09 21:00:36 +00:00
Wing Lian
3c1921e400 add hf cache caching for GHA (#2247)
* add hf cache caching for GHA

* use modal volume to cache hf data

* make sure to update the cache as we add new fixtures in conftest
2025-01-09 20:59:54 +00:00
Wing Lian
7faf2b6e8e Merge group queue (#2248)
* add support for merge groups

* also lint merge groups
2025-01-09 15:49:00 -05:00
salman
c1b920f291 Fixing OSX installation (#2231)
* bumping version, removing non-osx compatible deps

* updating pylintrc

* fixing linters

* reverting changes
2025-01-07 13:42:01 +00:00
123 changed files with 4285 additions and 3155 deletions

View File

@@ -1,6 +1,7 @@
name: lint
on:
# check on PRs, and manual triggers
merge_group:
pull_request:
paths:
- '**.py'

View File

@@ -25,7 +25,6 @@ jobs:
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras: mamba-ssm
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -36,6 +35,7 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -92,7 +92,6 @@ jobs:
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -103,6 +102,7 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -52,7 +52,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -129,7 +129,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -1,6 +1,7 @@
name: Tests
on:
# check on push/merge to main, PRs, and manual triggers
merge_group:
push:
branches:
- "main"
@@ -60,6 +61,15 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -100,6 +110,15 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
@@ -115,6 +134,15 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -156,6 +184,15 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
docker-e2e-tests-1st:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
@@ -170,7 +207,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
pytorch: 2.5.1
num_gpus: 1
axolotl_extras:
steps:
@@ -183,7 +220,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -216,7 +253,7 @@ jobs:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.5.1
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
steps:
@@ -229,7 +266,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -23,7 +23,7 @@ repos:
hooks:
- id: flake8
- repo: https://github.com/PyCQA/pylint
rev: v2.17.4
rev: v3.3.0
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy

View File

@@ -1,5 +1,5 @@
[MASTER]
init-hook="from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))"
init-hook="from pylint.config import find_default_config_files; import sys; sys.path.append(next(find_default_config_files()).parent.as_posix())"
[TYPECHECK]
@@ -12,3 +12,4 @@ generated-members=numpy.*, torch.*
disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
too-many-positional-arguments, possibly-used-before-assignment

View File

@@ -8,6 +8,7 @@ ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev

View File

@@ -28,6 +28,7 @@ df_args = {
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}
dockerfile_contents = df_template.render(**df_args)
@@ -48,6 +49,12 @@ cicd_image = (
app = App("Axolotl CI/CD", secrets=[])
hf_cache_volume = modal.Volume.from_name(
"axolotl-ci-hf-hub-cache", create_if_missing=True
)
VOLUME_CONFIG = {
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
}
N_GPUS = int(os.environ.get("N_GPUS", 2))
GPU_CONFIG = modal.gpu.H100(count=N_GPUS)
@@ -67,6 +74,7 @@ def run_cmd(cmd: str, run_folder: str):
timeout=60 * 60,
cpu=8.0,
memory=131072 * N_GPUS,
volumes=VOLUME_CONFIG,
)
def cicd_pytest():
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")

View File

@@ -29,6 +29,7 @@ df_args = {
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}
dockerfile_contents = df_template.render(**df_args)
@@ -50,6 +51,12 @@ cicd_image = (
app = App("Axolotl CI/CD", secrets=[])
hf_cache_volume = modal.Volume.from_name(
"axolotl-ci-hf-hub-cache", create_if_missing=True
)
VOLUME_CONFIG = {
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.A10G(count=N_GPUS)
@@ -69,6 +76,7 @@ def run_cmd(cmd: str, run_folder: str):
timeout=60 * 60,
cpu=8.0,
memory=131072,
volumes=VOLUME_CONFIG,
)
def cicd_pytest():
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")

View File

@@ -19,7 +19,14 @@ For pretraining, there is no prompt template or roles. The only required field
Axolotl usually loads the entire dataset into memory. This will be challenging for large datasets. Use the following config to enable streaming:
```{.yaml filename="config.yaml"}
pretraining_dataset: # hf path only
pretraining_dataset:
- name:
path:
split:
text_column: # column in dataset with the data, usually `text`
type: pretrain
trust_remote_code:
skip: # number of rows of data to skip over from the beginning
...
```

View File

@@ -29,7 +29,7 @@ datasets:
type: chatml.intel
- path: argilla/ultrafeedback-binarized-preferences
split: train
type: chatml.argilla
type: chatml
```
#### IPO

View File

@@ -1,58 +0,0 @@
base_model: nvidia/Hymba-1.5B-Base
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
trust_remote_code: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 5
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -1,73 +0,0 @@
base_model: nvidia/Hymba-1.5B-Base
load_in_8bit: false
load_in_4bit: True
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.05
output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
adapter: qlora
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
- gate_proj
- down_proj
- up_proj
- q_proj
- v_proj
- k_proj
- o_proj
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 2
micro_batch_size: 2
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
trust_remote_code: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
warmup_steps: 5
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -2,7 +2,7 @@
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.0
triton>=2.3.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
flash-attn==2.7.0.post2
xformers>=0.0.23.post1
@@ -14,11 +14,11 @@ packaging==23.2
peft==0.14.0
transformers==4.47.1
tokenizers>=0.20.1
tokenizers>=0.21.0
accelerate==1.2.1
datasets==3.1.0
datasets==3.2.0
deepspeed==0.16.1
trl==0.12.1
trl==0.13.0
optimum==1.16.2
hf_transfer
@@ -53,7 +53,7 @@ zstandard==0.22.0
fastcore
# lm eval harness
lm_eval==0.4.4
lm_eval==0.4.7
langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
@@ -61,4 +61,4 @@ antlr4-python3-runtime==4.13.2
torchao==0.7.0
schedulefree==1.3.0
axolotl-contribs-lgpl==0.0.2
axolotl-contribs-lgpl==0.0.3

View File

@@ -1,52 +0,0 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import logging
from pathlib import Path
import fire
import transformers
from axolotl.cli import (
check_accelerate_default_config,
check_user_token,
do_inference,
do_merge_lora,
load_cfg,
load_datasets,
print_axolotl_text_art,
)
from axolotl.cli.shard import shard
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
LOG = logging.getLogger("axolotl.scripts.finetune")
def do_cli(config: Path = Path("examples/"), **kwargs):
print_axolotl_text_art()
LOG.warning(
str(
PendingDeprecationWarning(
"scripts/finetune.py will be replaced with calling axolotl.cli.train"
)
)
)
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if parsed_cli_args.inference:
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.merge_lora:
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.shard:
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
if __name__ == "__main__":
fire.Fire(do_cli)

View File

@@ -1,4 +1,5 @@
"""setup.py for axolotl"""
import ast
import os
import platform
@@ -29,15 +30,30 @@ def parse_requirements():
elif not is_extras and line and line[0] != "#":
# Handle standard packages
_install_requires.append(line)
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
triton_version = [req for req in _install_requires if "triton" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system():
# don't install xformers on MacOS
_install_requires.pop(_install_requires.index(xformers_version))
# skip packages not compatible with OSX
skip_packages = [
"bitsandbytes",
"triton",
"mamba-ssm",
"flash-attn",
"xformers",
"autoawq",
"liger-kernel",
]
_install_requires = [
req
for req in _install_requires
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
]
print(
_install_requires, [req in skip_packages for req in _install_requires]
)
else:
# detect the version of torch already installed
# and set it so dependencies don't clobber the torch version
@@ -73,6 +89,8 @@ def parse_requirements():
_install_requires.append("xformers==0.0.28.post1")
elif (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(triton_version))
_install_requires.append("triton>=2.3.1")
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1")

View File

@@ -1,568 +1,5 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
"""Axolotl CLI module initialization."""
import importlib
import json
import logging
import math
import os
import random
import sys
import tempfile
from pathlib import Path
from threading import Thread
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse
import requests
import torch
import yaml
# add src to the pythonpath so we don't need to pip install this
from accelerate.commands.config import config_args
from art import text2art
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.import_utils import _is_package_available
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import (
get_chat_template,
get_chat_template_from_config,
)
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
normalize_config,
prepare_plugins,
validate_config,
)
from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = logging.getLogger("axolotl.scripts")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
AXOLOTL_LOGO = """
#@@ #@@ @@# @@#
@@ @@ @@ @@ =@@# @@ #@ =@@#.
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
@@@@ @@@@@@@@@@@@@@@@
"""
def print_legacy_axolotl_text_art(suffix=None):
font = "nancyj"
ascii_text = " axolotl"
if suffix:
ascii_text += f" x {suffix}"
ascii_art = text2art(ascii_text, font=font)
if is_main_process():
print(ascii_art)
print_dep_versions()
def print_axolotl_text_art(
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
print(AXOLOTL_LOGO)
def print_dep_versions():
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
max_len = max(len(pkg) for pkg in packages)
if is_main_process():
print("*" * 40)
print("**** Axolotl Dependency Versions *****")
for pkg in packages:
pkg_version = _is_package_available(pkg, return_version=True)
print(f"{pkg: >{max_len}}: {pkg_version[1]: <15}")
print("*" * 40)
def check_remote_config(config: Union[str, Path]):
# Check if the config is a valid HTTPS URL to a .yml or .yaml file
if not (isinstance(config, str) and config.startswith("https://")):
return config # Return the original value if it's not a valid URL
filename = os.path.basename(urlparse(config).path)
temp_dir = tempfile.mkdtemp()
try:
response = requests.get(config, timeout=30)
response.raise_for_status() # Check for HTTP errors
content = response.content
try:
# Try parsing as JSON first to catch cases where JSON content is mistakenly considered YAML
json.loads(content)
# Log a warning but do not raise an error; JSON is technically valid YAML - this can happen when you forget to point to a raw github link
LOG.warning(
f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended."
)
except json.JSONDecodeError:
# If it's not valid JSON, verify it's valid YAML
try:
yaml.safe_load(content)
except yaml.YAMLError as err:
raise ValueError(
f"Failed to parse the content at {config} as YAML: {err}"
) from err
# Write the content to a file if it's valid YAML (or JSON treated as YAML)
output_path = Path(temp_dir) / filename
with open(output_path, "wb") as file:
file.write(content)
LOG.info(
f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n"
)
return output_path
except requests.RequestException as err:
# This catches all requests-related exceptions including HTTPError
raise RuntimeError(f"Failed to download {config}: {err}") from err
except Exception as err:
# Catch-all for any other exceptions
raise err
def get_multi_line_input() -> Optional[str]:
print("Give me an instruction (Ctrl + D to submit): ")
instruction = ""
for line in sys.stdin:
instruction += line # pylint: disable=consider-using-join
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
return instruction
def do_merge_lora(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload(progressbar=True)
try:
model.to(dtype=cfg.torch_dtype)
except RuntimeError:
pass
model.generation_config.do_sample = True
if cfg.local_rank == 0:
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
progressbar=True,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
def do_inference(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
prompter_module = None
chat_template_str = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
while True:
print("=" * 80)
# support for multiline inputs
instruction = get_multi_line_input()
if not instruction:
return
if prompter_module:
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
if chat_template_str:
batch = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": prompt,
}
],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=1024,
temperature=0.9,
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextStreamer(tokenizer)
generated = model.generate(
inputs=batch["input_ids"].to(cfg.device),
generation_config=generation_config,
streamer=streamer,
)
print("=" * 40)
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
def do_inference_gradio(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
import gradio as gr
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
prompter_module = None
chat_template_str = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
def generate(instruction):
if not instruction:
return
if prompter_module:
# pylint: disable=stop-iteration-return
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
if chat_template_str:
batch = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": prompt,
}
],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=cfg.get("gradio_max_new_tokens", 1024),
temperature=cfg.get("gradio_temperature", 0.9),
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = {
"inputs": batch["input_ids"].to(cfg.device),
"attention_mask": batch["attention_mask"].to(cfg.device),
"generation_config": generation_config,
"streamer": streamer,
}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
all_text = ""
for new_text in streamer:
all_text += new_text
yield all_text
demo = gr.Interface(
fn=generate,
inputs="textbox",
outputs="text",
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
)
demo.queue().launch(
show_api=False,
share=cfg.get("gradio_share", True),
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
server_port=cfg.get("gradio_server_port", None),
)
def choose_config(path: Path):
yaml_files = list(path.glob("*.yml"))
if not yaml_files:
raise ValueError(
"No YAML config files found in the specified directory. Are you using a .yml extension?"
)
if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
return str(yaml_files[0])
print("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
print(f"{idx + 1}. {file}")
chosen_file = None
while chosen_file is None:
try:
choice = int(input("Enter the number of your choice: "))
if 1 <= choice <= len(yaml_files):
chosen_file = str(yaml_files[choice - 1])
else:
print("Invalid choice. Please choose a number from the list.")
except ValueError:
print("Invalid input. Please enter a number.")
return chosen_file
def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
return not any(el in list2 for el in list1)
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
config = check_remote_config(config)
if Path(config).is_dir():
config = choose_config(Path(config))
# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
else:
cfg[k] = kwargs[k]
cfg.axolotl_config_path = config
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
except: # pylint: disable=bare-except # noqa: E722
gpu_version = None
prepare_plugins(cfg)
cfg = validate_config(
cfg,
capabilities={
"bf16": is_torch_bf16_gpu_available(),
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version,
},
env_capabilities={
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0],
},
)
prepare_optim_env(cfg)
prepare_opinionated_env(cfg)
normalize_config(cfg)
normalize_cfg_datasets(cfg)
setup_wandb_env_vars(cfg)
setup_mlflow_env_vars(cfg)
setup_comet_env_vars(cfg)
return cfg
def load_datasets(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg,
tokenizer,
processor=processor,
)
if (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
):
LOG.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
[
random.randrange(0, len(train_dataset) - 1) # nosec
for _ in range(cli_args.debug_num_examples)
]
),
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
)
LOG.info("printing prompters...")
for prompter in prompters:
LOG.info(prompter)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
)
def load_rl_datasets(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs, # pylint: disable=unused-argument
) -> TrainDatasetMeta:
train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
tokenizer = load_tokenizer(cfg)
check_dataset_labels(
train_dataset.select(
[
random.randrange(0, len(train_dataset) - 1) # nosec
for _ in range(cli_args.debug_num_examples)
]
),
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
rl_mode=True,
)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
)
def check_accelerate_default_config():
if Path(config_args.default_yaml_config_file).exists():
LOG.warning(
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
)
def check_user_token():
# Skip check if HF_HUB_OFFLINE is set to True
if os.getenv("HF_HUB_OFFLINE") == "1":
LOG.info(
"Skipping HuggingFace token verification because HF_HUB_OFFLINE is set to True. Only local files will be used."
)
return True
# Verify if token is valid
api = HfApi()
try:
user_info = api.whoami()
return bool(user_info)
except LocalTokenNotFoundError:
LOG.warning(
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
)
return False

49
src/axolotl/cli/args.py Normal file
View File

@@ -0,0 +1,49 @@
"""Module for axolotl CLI command arguments."""
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class PreprocessCliArgs:
"""Dataclass with CLI arguments for `axolotl preprocess` command."""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True)
iterable: Optional[bool] = field(
default=None,
metadata={
"help": "Use IterableDataset for streaming processing of large datasets"
},
)
@dataclass
class TrainerCliArgs:
"""Dataclass with CLI arguments for `axolotl train` command."""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
@dataclass
class EvaluateCliArgs:
"""Dataclass with CLI arguments for `axolotl evaluate` command."""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
@dataclass
class InferenceCliArgs:
"""Dataclass with CLI arguments for `axolotl inference` command."""
prompter: Optional[str] = field(default=None)

23
src/axolotl/cli/art.py Normal file
View File

@@ -0,0 +1,23 @@
"""Axolotl ASCII logo utils."""
from axolotl.utils.distributed import is_main_process
AXOLOTL_LOGO = """
#@@ #@@ @@# @@#
@@ @@ @@ @@ =@@# @@ #@ =@@#.
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
@@@@ @@@@@@@@@@@@@@@@
"""
def print_axolotl_text_art():
"""Prints axolotl ASCII art."""
if is_main_process():
print(AXOLOTL_LOGO)

50
src/axolotl/cli/checks.py Normal file
View File

@@ -0,0 +1,50 @@
"""Various checks for Axolotl CLI."""
import logging
import os
from pathlib import Path
from accelerate.commands.config import config_args
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from axolotl.logging_config import configure_logging
configure_logging()
LOG = logging.getLogger(__name__)
def check_accelerate_default_config() -> None:
"""Logs at warning level if no accelerate config file is found."""
if Path(config_args.default_yaml_config_file).exists():
LOG.warning(
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
)
def check_user_token() -> bool:
"""Checks for HF user info. Check is skipped if HF_HUB_OFFLINE=1.
Returns:
Boolean indicating successful check (i.e., HF_HUB_OFFLINE=1 or HF user info is retrieved).
Raises:
LocalTokenNotFoundError: If HF user info can't be retrieved.
"""
# Skip check if HF_HUB_OFFLINE is set to True
if os.getenv("HF_HUB_OFFLINE") == "1":
LOG.info(
"Skipping HuggingFace token verification because HF_HUB_OFFLINE is set to True. Only local files will be used."
)
return True
# Verify if token is valid
api = HfApi()
try:
user_info = api.whoami()
return bool(user_info)
except LocalTokenNotFoundError:
LOG.warning(
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
)
return False

217
src/axolotl/cli/config.py Normal file
View File

@@ -0,0 +1,217 @@
"""Configuration loading and processing."""
import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Union
from urllib.parse import urlparse
import requests
import torch
import yaml
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.integrations.base import PluginManager
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
normalize_config,
validate_config,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = logging.getLogger(__name__)
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
"""
First, determines if the passed config is a valid HTTPS URL. Then, attempts to query
for it and parse its content, first as JSON, then as YAML (YAML is preferred).
Finally, the parsed content is written to a local file and its path is returned.
Args:
config: HTTPS URL to a YAML or JSON file.
Returns:
Either the original `config` if it's not a valid HTTPS URL, or the path to the
downloaded remote config.
Raises:
ValueError: If the remote configuration is neither valid JSON or YAML.
RuntimeError: If some request-related exception occurs from the file download.
Exception: Catch-all for any other exception.
"""
# Check if the config is a valid HTTPS URL to a .yml or .yaml file
if not (isinstance(config, str) and config.startswith("https://")):
return config # Return the original value if it's not a valid URL
filename = os.path.basename(urlparse(config).path)
temp_dir = tempfile.mkdtemp()
try:
response = requests.get(config, timeout=30)
response.raise_for_status() # Check for HTTP errors
content = response.content
try:
# Try parsing as JSON first to catch cases where JSON content is mistakenly
# considered YAML.
json.loads(content)
# Log a warning but do not raise an error; JSON is technically valid YAML.
# This can happen when you forget to point to a raw GitHub link.
LOG.warning(
f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended."
)
except json.JSONDecodeError:
# If it's not valid JSON, verify it's valid YAML
try:
yaml.safe_load(content)
except yaml.YAMLError as err:
raise ValueError(
f"Failed to parse the content at {config} as YAML: {err}"
) from err
# Write the content to a file if it's valid YAML (or JSON treated as YAML)
output_path = Path(temp_dir) / filename
with open(output_path, "wb") as file:
file.write(content)
LOG.info(
f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n"
)
return output_path
except requests.RequestException as err:
# This catches all requests-related exceptions including HTTPError
raise RuntimeError(f"Failed to download {config}: {err}") from err
except Exception as err:
# Catch-all for any other exceptions
raise err
def choose_config(path: Path) -> str:
"""
Helper method for choosing a `axolotl` config YAML file (considering only files
ending with `.yml` or `.yaml`). If more than one config file exists in the passed
`path`, the user is prompted to choose one.
Args:
path: Directory in which config file(s) are stored.
Returns:
Path to either (1) the sole YAML file, or (2) if more than one YAML files exist,
the user-selected YAML file.
Raises:
ValueError: If no YAML files are found in the given `path`.
"""
yaml_files = list(path.glob("*.yml")) + list(path.glob("*.yaml"))
if not yaml_files:
raise ValueError(
"No YAML config files found in the specified directory. Are you using a .yml extension?"
)
if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
return str(yaml_files[0])
print("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
print(f"{idx + 1}. {file}")
chosen_file = None
while chosen_file is None:
try:
choice = int(input("Enter the number of your choice: "))
if 1 <= choice <= len(yaml_files):
chosen_file = str(yaml_files[choice - 1])
else:
print("Invalid choice. Please choose a number from the list.")
except ValueError:
print("Invalid input. Please enter a number.")
return chosen_file
def prepare_plugins(cfg: DictDefault):
"""
Registers the plugins for the given configuration.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
"""
if cfg.get("plugins"):
plugin_manager = PluginManager.get_instance()
for plugin_name in cfg["plugins"]:
plugin_manager.register(plugin_name)
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
"""
Loads the `axolotl` configuration stored at `config`, validates it, and performs
various setup.
Args:
config: Path (local or remote) to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
Returns:
`DictDefault` mapping configuration keys to values.
"""
config = check_remote_config(config)
if Path(config).is_dir():
config = choose_config(Path(config))
# Load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
cfg_keys = cfg.keys()
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
else:
cfg[k] = kwargs[k]
cfg.axolotl_config_path = config
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
except: # pylint: disable=bare-except # noqa: E722
gpu_version = None
prepare_plugins(cfg)
cfg = validate_config(
cfg,
capabilities={
"bf16": is_torch_bf16_gpu_available(),
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version,
},
env_capabilities={
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
},
)
prepare_optim_env(cfg)
prepare_opinionated_env(cfg)
normalize_config(cfg)
normalize_cfg_datasets(cfg)
setup_wandb_env_vars(cfg)
setup_mlflow_env_vars(cfg)
setup_comet_env_vars(cfg)
return cfg

View File

@@ -1,6 +1,5 @@
"""
CLI to run training on a model
"""
"""CLI to run evaluation on a model."""
import logging
from pathlib import Path
from typing import Union
@@ -9,35 +8,48 @@ import fire
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
from axolotl.cli import (
check_accelerate_default_config,
check_user_token,
load_cfg,
load_datasets,
load_rl_datasets,
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.evaluate import evaluate
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.cli.evaluate")
LOG = logging.getLogger(__name__)
def do_evaluate(cfg, cli_args) -> None:
def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
"""
Evaluates a `transformers` model by first loading the dataset(s) specified in the
`axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes
evaluation metrics on the given dataset(s) and writes them to disk.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: CLI arguments.
"""
# pylint: disable=duplicate-code
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
if cfg.rl: # and cfg.rl != "orpo":
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
evaluate(cfg=cfg, dataset_meta=dataset_meta)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, CLI args, and calls `do_evaluate`.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs)

View File

@@ -1,32 +1,267 @@
"""
CLI to run inference on a trained model
"""
"""CLI to run inference on a trained model."""
import importlib
import logging
import sys
from pathlib import Path
from threading import Thread
from typing import Union
import fire
import torch
import transformers
from dotenv import load_dotenv
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from axolotl.cli import (
do_inference,
do_inference_gradio,
load_cfg,
print_axolotl_text_art,
from axolotl.cli.args import InferenceCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.chat_templates import (
get_chat_template,
get_chat_template_from_config,
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
def do_cli(config: Union[Path, str] = Path("examples/"), gradio=False, **kwargs):
def get_multi_line_input() -> str:
"""
Gets multi-line input from terminal.
Returns:
Possibly multi-line, possibly empty stdin input as a string.
"""
print("Give me an instruction (Ctrl + D to submit): ")
instruction = ""
for line in sys.stdin:
instruction += line # pylint: disable=consider-using-join
return instruction
def do_inference(
*,
cfg: DictDefault,
cli_args: InferenceCliArgs,
):
"""
Runs inference on the command line in a loop. User input is accepted, a chat template
is (optionally) applied, and the model specified in the `axolotl` config is used to
generate completions according to a default generation config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Inference-specific CLI arguments.
"""
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
prompter = cli_args.prompter
prompter_module = None
chat_template_str = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
while True:
print("=" * 80)
# support for multiline inputs
instruction = get_multi_line_input()
if not instruction:
return
if prompter_module:
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
if chat_template_str:
batch = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": prompt,
}
],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=1024,
temperature=0.9,
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextStreamer(tokenizer)
generated = model.generate(
inputs=batch["input_ids"].to(cfg.device),
generation_config=generation_config,
streamer=streamer,
)
print("=" * 40)
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
def do_inference_gradio(
*,
cfg: DictDefault,
cli_args: InferenceCliArgs,
):
"""
Runs inference in a Gradio interface. User input is accepted, a chat template is
(optionally) applied, and the model specified in the `axolotl` config is used to
generate completions according to a default generation config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Inference-specific CLI arguments.
"""
import gradio as gr
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
prompter = cli_args.prompter
prompter_module = None
chat_template_str = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
def generate(instruction):
if not instruction:
return
if prompter_module:
# pylint: disable=stop-iteration-return
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
if chat_template_str:
batch = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": prompt,
}
],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=cfg.get("gradio_max_new_tokens", 1024),
temperature=cfg.get("gradio_temperature", 0.9),
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = {
"inputs": batch["input_ids"].to(cfg.device),
"attention_mask": batch["attention_mask"].to(cfg.device),
"generation_config": generation_config,
"streamer": streamer,
}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
all_text = ""
for new_text in streamer:
all_text += new_text
yield all_text
demo = gr.Interface(
fn=generate,
inputs="textbox",
outputs="text",
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
)
demo.queue().launch(
show_api=False,
share=cfg.get("gradio_share", True),
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
server_port=cfg.get("gradio_server_port", None),
)
def do_cli(
config: Union[Path, str] = Path("examples/"), gradio: bool = False, **kwargs
) -> None:
"""
Parses axolotl config, CLI args, and calls `do_inference` or `do_inference_gradio`.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, inference=True, **kwargs)
parsed_cfg.sample_packing = False
parser = transformers.HfArgumentParser((TrainerCliArgs))
parser = transformers.HfArgumentParser(InferenceCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.inference = True
if gradio:
do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)

View File

@@ -1,18 +1,20 @@
"""CLI definition for various axolotl commands."""
"""Click CLI definitions for various axolotl commands."""
# pylint: disable=redefined-outer-name
import subprocess # nosec B404
from typing import Optional
import click
import axolotl
from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.cli.utils import (
add_options_from_config,
add_options_from_dataclass,
build_command,
fetch_from_github,
filter_none_kwargs,
)
from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@@ -27,8 +29,16 @@ def cli():
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig)
def preprocess(config: str, **kwargs):
"""Preprocess datasets before training."""
@filter_none_kwargs
def preprocess(config: str, **kwargs) -> None:
"""
Preprocess datasets before training.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
from axolotl.cli.preprocess import do_cli
@@ -45,10 +55,17 @@ def preprocess(config: str, **kwargs):
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def train(config: str, accelerate: bool, **kwargs):
"""Train or fine-tune a model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
@filter_none_kwargs
def train(config: str, accelerate: bool, **kwargs) -> None:
"""
Train or fine-tune a model.
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
@@ -73,10 +90,17 @@ def train(config: str, accelerate: bool, **kwargs):
)
@add_options_from_dataclass(EvaluateCliArgs)
@add_options_from_config(AxolotlInputConfig)
def evaluate(config: str, accelerate: bool, **kwargs):
"""Evaluate a model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
@filter_none_kwargs
def evaluate(config: str, accelerate: bool, **kwargs) -> None:
"""
Evaluate a model.
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
if config:
@@ -96,81 +120,33 @@ def evaluate(config: str, accelerate: bool, **kwargs):
default=False,
help="Use accelerate launch for multi-GPU inference",
)
@click.option(
"--lora-model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing LoRA model",
)
@click.option(
"--base-model",
type=click.Path(exists=True, path_type=str),
help="Path to base model for non-LoRA models",
)
@click.option("--gradio", is_flag=True, help="Launch Gradio interface")
@click.option("--load-in-8bit", is_flag=True, help="Load model in 8-bit mode")
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def inference(
config: str,
accelerate: bool,
lora_model_dir: Optional[str] = None,
base_model: Optional[str] = None,
**kwargs,
):
"""Run inference with a trained model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
del kwargs["inference"] # interferes with inference.do_cli
if lora_model_dir:
kwargs["lora_model_dir"] = lora_model_dir
if base_model:
kwargs["base_model"] = base_model
@filter_none_kwargs
def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
"""
Run inference with a trained model.
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
gradio: Whether to use Gradio browser interface or command line for inference.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
if config:
base_cmd.append(config)
if gradio:
base_cmd.append("--gradio")
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.inference import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=False,
help="Use accelerate launch for multi-GPU operations",
)
@click.option(
"--model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing model weights to shard",
)
@click.option(
"--save-dir",
type=click.Path(path_type=str),
help="Directory to save sharded weights",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def shard(config: str, accelerate: bool, **kwargs):
"""Shard model weights."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.shard"]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.shard import do_cli
do_cli(config=config, **kwargs)
do_cli(config=config, gradio=gradio, **kwargs)
@cli.command()
@@ -180,20 +156,19 @@ def shard(config: str, accelerate: bool, **kwargs):
default=True,
help="Use accelerate launch for weight merging",
)
@click.option(
"--model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing sharded weights",
)
@click.option(
"--save-path", type=click.Path(path_type=str), help="Path to save merged weights"
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs):
"""Merge sharded FSDP model weights."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
@filter_none_kwargs
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
"""
Merge sharded FSDP model weights.
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
if accelerate:
base_cmd = [
"accelerate",
@@ -213,28 +188,19 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs):
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--lora-model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing the LoRA model to merge",
)
@click.option(
"--output-dir",
type=click.Path(path_type=str),
help="Directory to save the merged model",
)
def merge_lora(
config: str,
lora_model_dir: Optional[str] = None,
output_dir: Optional[str] = None,
):
"""Merge a trained LoRA into a base model"""
kwargs = {}
if lora_model_dir:
kwargs["lora_model_dir"] = lora_model_dir
if output_dir:
kwargs["output_dir"] = output_dir
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def merge_lora(config: str, **kwargs) -> None:
"""
Merge trained LoRA adapters into a base model.
Args:
config: Path to `axolotl` config YAML file.
accelerate: Whether to use `accelerate` launcher.
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options.
"""
from axolotl.cli.merge_lora import do_cli
do_cli(config=config, **kwargs)
@@ -243,13 +209,17 @@ def merge_lora(
@cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.option("--dest", help="Destination directory")
def fetch(directory: str, dest: Optional[str]):
def fetch(directory: str, dest: Optional[str]) -> None:
"""
Fetch example configs or other resources.
Available directories:
- examples: Example configuration files
- deepspeed_configs: DeepSpeed configuration files
Args:
directory: One of `examples`, `deepspeed_configs`.
dest: Optional destination directory.
"""
fetch_from_github(f"{directory}/", dest)

View File

@@ -1,6 +1,6 @@
"""
CLI to run merge a trained LoRA into a base model
"""
"""CLI to merge a trained LoRA into a base model."""
import logging
from pathlib import Path
from typing import Union
@@ -8,14 +8,58 @@ import fire
import transformers
from dotenv import load_dotenv
from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
from axolotl.cli.utils import load_model_and_tokenizer
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
def do_merge_lora(*, cfg: DictDefault) -> None:
"""
Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config
along with the LoRA adapters to combine them into a single base model.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
"""
print_axolotl_text_art()
parser = transformers.HfArgumentParser((TrainerCliArgs))
model, tokenizer = load_model_and_tokenizer(cfg=cfg)
safe_serialization = cfg.save_safetensors is True
LOG.info("Running merge of LoRA with base model...")
model = model.merge_and_unload(progressbar=True)
model.to(dtype=cfg.torch_dtype)
model.generation_config.do_sample = True
if cfg.local_rank == 0:
LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
progressbar=True,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, CLI args, and calls `do_merge_lora`. Note that various
config values will be overwritten to allow the LoRA merge logic to work as expected
(`load_in_8bit=False`, `load_in4bit=False`, `flash_attention=False`, etc.).
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
Raises:
ValueError: If target directory for LoRA merged model does not exist.
"""
# pylint: disable=duplicate-code
parser = transformers.HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
@@ -46,7 +90,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
parsed_cfg.fsdp = None
parsed_cfg.fsdp_config = None
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
do_merge_lora(cfg=parsed_cfg)
if __name__ == "__main__":

View File

@@ -1,6 +1,5 @@
"""
This module provides a CLI to merge sharded FSDP model checkpoints into a single combined checkpoint
"""
"""CLI to merge sharded FSDP model checkpoints into a single combined checkpoint."""
import json
import logging
import os
@@ -25,16 +24,15 @@ from huggingface_hub import split_torch_state_dict_into_shards
from safetensors.torch import save_file as safe_save_file
from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.config import load_cfg
LOG = logging.getLogger("axolotl.cli.merge_sharded_fsdp_weights")
LOG = logging.getLogger(__name__)
class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
"""
A custom planner to cast tensors to bfloat16 on the fly during loading.
"""
"""A custom planner to cast tensors to bfloat16 on the fly during loading."""
def commit_tensor(self, read_item, tensor): # pylint: disable=unused-argument
tensor.copy_(tensor.to(torch.bfloat16))
@@ -45,11 +43,19 @@ def _distributed_checkpoint_to_merged_weights(
save_path: str,
safe_serialization: bool = False,
max_shard_size: str = "5GB",
):
) -> Path:
"""
Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`
Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`. Will
save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
Args:
checkpoint_dir: Directory where distributed checkpoint is saved.
save_path: Path to save model to.
safe_serialization: Whether to save in safetensors format.
max_shard_size: Max size of model shards to save.
Returns:
Path where model is saved.
"""
state_dict: Dict = {}
@@ -79,6 +85,7 @@ def _distributed_checkpoint_to_merged_weights(
state_dict_split = split_torch_state_dict_into_shards(
state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
)
# Save index if sharded
index = None
if state_dict_split.is_sharded:
@@ -135,6 +142,9 @@ def merge_fsdp_weights(
Whether to save the merged weights with safetensors (recommended).
remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
Whether to remove the checkpoint directory after merging.
Raises:
ValueError: If torch version < 2.3.0, or if `checkpoint_dir` does not exist.
"""
checkpoint_dir_ = Path(checkpoint_dir)
from accelerate.state import PartialState
@@ -178,18 +188,21 @@ def merge_fsdp_weights(
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
"""
Parses `axolotl` config, CLI args, and calls `merge_fsdp_weights`.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
print_axolotl_text_art()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parser = transformers.HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.merge_lora = True
parsed_cfg = load_cfg(
config,
**kwargs,
)
parsed_cfg = load_cfg(config, **kwargs)
fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
merge_fsdp_weights(

View File

@@ -1,6 +1,5 @@
"""
CLI to run training on a model
"""
"""CLI to run preprocessing of a dataset."""
import logging
import warnings
from pathlib import Path
@@ -13,34 +12,31 @@ from colorama import Fore
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM
from axolotl.cli import (
check_accelerate_default_config,
check_user_token,
load_cfg,
load_datasets,
load_rl_datasets,
print_axolotl_text_art,
)
from axolotl.common.cli import PreprocessCliArgs
from axolotl.cli.args import PreprocessCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import disable_datasets_caching
LOG = logging.getLogger("axolotl.cli.preprocess")
LOG = logging.getLogger(__name__)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
"""
Preprocesses dataset specified in axolotl config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Preprocessing-specific CLI arguments.
"""
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.is_preprocess = True
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((PreprocessCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if not parsed_cfg.dataset_prepared_path:
if not cfg.dataset_prepared_path:
msg = (
Fore.RED
+ "preprocess CLI called without dataset_prepared_path set, "
@@ -48,16 +44,16 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
+ Fore.RESET
)
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
with disable_datasets_caching():
if parsed_cfg.rl: # and parsed_cfg.rl != "orpo":
load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if cfg.rl:
load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
load_datasets(cfg=cfg, cli_args=cli_args)
if parsed_cli_args.download:
model_name = parsed_cfg.base_model
if cli_args.download:
model_name = cfg.base_model
with warnings.catch_warnings():
# there are a bunch of useless UserWarnings about
# "copying from a non-meta parameter in the checkpoint to a meta parameter in the current model"
@@ -74,11 +70,33 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
LOG.info(
Fore.GREEN
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
+ f"Success! Preprocessed data path: `dataset_prepared_path: {cfg.dataset_prepared_path}`"
+ Fore.RESET
)
def do_cli(
config: Union[Path, str] = Path("examples/"),
**kwargs,
) -> None:
"""
Parses `axolotl` config, CLI args, and calls `do_preprocess`.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.is_preprocess = True
parser = transformers.HfArgumentParser(PreprocessCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
do_preprocess(parsed_cfg, parsed_cli_args)
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -1,45 +0,0 @@
"""
CLI to shard a trained model into 10GiB chunks
"""
import logging
from pathlib import Path
from typing import Union
import fire
import transformers
from dotenv import load_dotenv
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.scripts")
def shard(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.debug("Re-saving model w/ sharding")
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.shard = True
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -1,6 +1,5 @@
"""
CLI to run training on a model
"""
"""CLI to run training on a model."""
import logging
from pathlib import Path
from typing import Union
@@ -9,42 +8,38 @@ import fire
from dotenv import load_dotenv
from transformers.hf_argparser import HfArgumentParser
from axolotl.cli import (
check_accelerate_default_config,
check_user_token,
load_cfg,
load_datasets,
load_rl_datasets,
print_axolotl_text_art,
)
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.cli.art import print_axolotl_text_art
from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.train import train
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.cli.train")
LOG = logging.getLogger(__name__)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
return do_train(parsed_cfg, parsed_cli_args)
def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
"""
Trains a `transformers` model by first loading the dataset(s) specified in the
`axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
manager's `post_train_unload` once training completes.
def do_train(cfg, cli_args) -> None:
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Training-specific CLI arguments.
"""
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
if cfg.rl: # and cfg.rl != "orpo":
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta)
plugin_manager = PluginManager.get_instance()
del model
@@ -53,6 +48,24 @@ def do_train(cfg, cli_args) -> None:
plugin_manager.post_train_unload(cfg)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, CLI args, and calls `do_train`.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
do_train(parsed_cfg, parsed_cli_args)
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -1,32 +1,84 @@
"""Utility methods for axoltl CLI."""
"""Utility methods for axolotl CLI."""
import concurrent.futures
import dataclasses
import hashlib
import json
import logging
import typing
from functools import wraps
from pathlib import Path
from types import NoneType
from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_args, get_origin
from typing import Any, Callable, Type, Union, get_args, get_origin
import click
import requests
from pydantic import BaseModel
from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast
LOG = logging.getLogger("axolotl.cli.utils")
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
configure_logging()
LOG = logging.getLogger(__name__)
def add_options_from_dataclass(config_class: Type[Any]):
"""Create Click options from the fields of a dataclass."""
def strip_optional_type(field_type: type | typing._SpecialForm | None):
"""
Extracts the non-`None` type from an `Optional` / `Union` type.
def decorator(function):
Args:
field_type: Type of field for Axolotl CLI command.
Returns:
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
returns the input type unchanged.
"""
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
return field_type
def filter_none_kwargs(func: Callable) -> Callable:
"""
Wraps function to remove `None`-valued `kwargs`.
Args:
func: Function to wrap.
Returns:
Wrapped function.
"""
@wraps(func)
def wrapper(*args, **kwargs) -> Callable:
"""Filters out `None`-valued `kwargs`."""
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
return func(*args, **filtered_kwargs)
return wrapper
def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
"""
Create Click options from the fields of a dataclass.
Args:
config_class: Dataclass with fields to parse from the CLI.
Returns:
Function decorator for Axolotl CLI command.
"""
def decorator(function: Callable) -> Callable:
# Process dataclass fields in reverse order for correct option ordering
for field in reversed(dataclasses.fields(config_class)):
field_type = field.type
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
field_type = strip_optional_type(field.type)
if field_type == bool:
field_name = field.name.replace("_", "-")
@@ -44,18 +96,29 @@ def add_options_from_dataclass(config_class: Type[Any]):
default=field.default,
help=field.metadata.get("description"),
)(function)
return function
return decorator
def add_options_from_config(config_class: Type[BaseModel]):
"""Create Click options from the fields of a Pydantic model."""
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
"""
Create Click options from the fields of a Pydantic model.
def decorator(function):
Args:
config_class: PyDantic model with fields to parse from the CLI
Returns:
Function decorator for Axolotl CLI command.
"""
def decorator(function: Callable) -> Callable:
# Process model fields in reverse order for correct option ordering
for name, field in reversed(config_class.model_fields.items()):
if field.annotation == bool:
field_type = strip_optional_type(field.annotation)
if field_type == bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
@@ -66,13 +129,23 @@ def add_options_from_config(config_class: Type[BaseModel]):
function = click.option(
option_name, default=None, help=field.description
)(function)
return function
return decorator
def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
"""Build command list from base command and options."""
def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
"""
Build command list from base command and options.
Args:
base_cmd: Command without options.
options: Options to parse and append to base command.
Returns:
List of strings giving shell command.
"""
cmd = base_cmd.copy()
for key, value in options.items():
@@ -92,18 +165,18 @@ def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
def download_file(
file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
) -> Tuple[str, str]:
) -> tuple[str, str]:
"""
Download a single file and return its processing status.
Args:
file_info: Tuple of (file_path, remote_sha)
raw_base_url: Base URL for raw GitHub content
dest_path: Local destination directory
dir_prefix: Directory prefix to filter files
file_info: Tuple of (file_path, remote_sha).
raw_base_url: Base URL for raw GitHub content.
dest_path: Local destination directory.
dir_prefix: Directory prefix to filter files.
Returns:
Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'
Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'.
"""
file_path, remote_sha = file_info
raw_url = f"{raw_base_url}/{file_path}"
@@ -145,16 +218,17 @@ def download_file(
def fetch_from_github(
dir_prefix: str, dest_dir: Optional[str] = None, max_workers: int = 5
dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
) -> None:
"""
Sync files from a specific directory in the GitHub repository.
Only downloads files that don't exist locally or have changed.
Args:
dir_prefix: Directory prefix to filter files (e.g., 'examples/', 'deepspeed_configs/')
dest_dir: Local destination directory
max_workers: Maximum number of concurrent downloads
dir_prefix: Directory prefix to filter files (e.g., 'examples/',
'deepspeed_configs/').
dest_dir: Local destination directory.
max_workers: Maximum number of concurrent downloads.
"""
api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"
@@ -179,7 +253,7 @@ def fetch_from_github(
dest_path = Path(dest_dir) if dest_dir else default_dest
# Keep track of processed files for summary
files_processed: Dict[str, List[str]] = {
files_processed: dict[str, list[str]] = {
"new": [],
"updated": [],
"unchanged": [],
@@ -216,3 +290,28 @@ def fetch_from_github(
LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
if files_processed["error"]:
LOG.info(f"Failed files: {len(files_processed['error'])}")
def load_model_and_tokenizer(
*,
cfg: DictDefault,
inference: bool = False,
) -> tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
"""
Helper function for loading a model and tokenizer specified in the given `axolotl`
config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
inference: Boolean denoting inference mode.
Returns:
`transformers` model and tokenizer.
"""
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
LOG.info("loading model...")
model, _ = load_model(cfg, tokenizer, inference=inference)
return model, tokenizer

View File

@@ -1,69 +0,0 @@
"""
shared module for cli specific things
"""
import logging
from dataclasses import dataclass, field
from typing import Optional
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
configure_logging()
LOG = logging.getLogger("axolotl.common.cli")
@dataclass
class PreprocessCliArgs:
"""
dataclass representing arguments for preprocessing only
"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True)
@dataclass
class TrainerCliArgs:
"""
dataclass representing the various non-training arguments
"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
@dataclass
class EvaluateCliArgs:
"""
dataclass representing the various evaluation arguments
"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
def load_model_and_tokenizer(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
LOG.info("loading model and (optionally) peft_config...")
inference = getattr(cli_args, "inference", False)
model, _ = load_model(cfg, tokenizer, inference=inference)
return model, tokenizer

View File

@@ -0,0 +1,146 @@
"""Dataset loading utilities."""
import logging
import math
import random
from dataclasses import dataclass
from typing import Optional, Union
from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
LOG = logging.getLogger(__name__)
@dataclass
class TrainDatasetMeta:
"""Dataclass with fields for training and validation datasets and metadata."""
train_dataset: Dataset
eval_dataset: Optional[Dataset] = None
total_num_steps: Optional[int] = None
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
"""
Randomly sample `num_samples` samples from `dataset`.
Args:
dataset: Dataset.
num_samples: Number of samples to return.
Returns:
Random sample (with replacement) of examples in `dataset`.
"""
return dataset.select(
[random.randrange(0, len(dataset) - 1) for _ in range(num_samples)] # nosec
)
def load_datasets(
*,
cfg: DictDefault,
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets, calling
`axolotl.utils.data.prepare_dataset`. Optionally, logs out debug information.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Command-specific CLI arguments.
Returns:
Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`.
"""
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = (
hasattr(cli_args, "iterable")
and cli_args.iterable is not None
and cli_args.iterable
)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg,
tokenizer,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
if (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
):
LOG.info("check_dataset_labels...")
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
check_dataset_labels(
train_samples,
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
)
LOG.info("printing prompters...")
for prompter in prompters:
LOG.info(prompter)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
)
def load_preference_datasets(
*,
cfg: DictDefault,
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets for RL training using paired
preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
Optionally, logs out debug information.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Command-specific CLI arguments.
Returns:
Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`.
"""
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
tokenizer = load_tokenizer(cfg)
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
check_dataset_labels(
train_samples,
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
rl_mode=True,
)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,933 @@
"""
module for customized trainers
"""
from __future__ import annotations
# pylint: disable=too-many-lines
import gc
import logging
import os
from collections import defaultdict
from functools import wraps
from typing import Any, Dict, Literal, Optional, Union
import torch
from datasets import Dataset
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import Trainer
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, seed_worker
from transformers.utils import is_sagemaker_mp_enabled
from trl import CPOTrainer, DPOTrainer, KTOTrainer, ORPOTrainer, RewardTrainer
from trl.trainer.utils import pad_to_length
from axolotl.monkeypatch.relora import ReLoRAScheduler
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
from axolotl.utils.schedulers import (
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
get_cosine_schedule_with_warmup_decay_constant,
)
if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
LOG = logging.getLogger("axolotl.core.trainer_builder")
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]
if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names
return kwargs
def _sanitize_kwargs_for_ds_tagging(dataset_tags, kwargs=None):
if isinstance(dataset_tags, str):
dataset_tags = [dataset_tags]
if (dataset_tags is not None) and (kwargs is not None):
if "dataset_tags" not in kwargs:
kwargs["dataset_tags"] = dataset_tags
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], list):
kwargs["dataset_tags"].extend(dataset_tags)
elif "dataset_tags" in kwargs and isinstance(kwargs["dataset_tags"], str):
dataset_tags.append(kwargs["dataset_tags"])
kwargs["dataset_tags"] = dataset_tags
return kwargs
class SchedulerMixin(Trainer):
"""
Mixin class for scheduler setup in CausalTrainer.
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
use_cosine_quadratic = (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
)
use_cosine_min_lr = (
self.args.lr_scheduler_type == "cosine"
and self.args.cosine_min_lr_ratio is not None
)
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if self.args.alternate_lr_scheduler_type == "one_cycle":
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
extra_lr_kwargs = {}
if "pct_start" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["pct_start"] = pct_start
if "anneal_strategy" not in self.args.lr_scheduler_kwargs:
extra_lr_kwargs["anneal_strategy"] = "cos"
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
**extra_lr_kwargs,
**self.args.lr_scheduler_kwargs,
)
elif use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and self.args.cosine_constant_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
assert 0 <= self.args.cosine_constant_lr_ratio <= 1.0, "cosine_constant_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_warmup_decay_constant( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
constant_lr_ratio=self.args.cosine_constant_lr_ratio,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer=optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler
class AxolotlTrainer(SchedulerMixin, Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
tag_names = ["axolotl"]
def __init__(
self,
*_args,
bench_data_collator=None,
eval_data_collator=None,
dataset_tags=None,
**kwargs,
):
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
self.dataset_tags = dataset_tags
self._signature_columns = None # workaround for pylint
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
self._stored_metrics = defaultdict(lambda: defaultdict(list))
if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.torch_compile:
torch._dynamo.config.accumulated_cache_size_limit = ( # pylint: disable=protected-access
256
)
model = torch.compile(
model,
backend=self.args.torch_compile_backend,
mode=self.args.torch_compile_mode,
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None
and self.args.alternate_optimizer
not in [
"optimi_adamw",
"ao_adamw_8bit",
"ao_adamw_4bit",
"ao_adamw_fp8",
"adopt_adamw",
]
):
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
decay_parameters = self.get_decay_parameter_names(opt_model)
params = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", 1e-6
)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
elif (
self.args.embedding_lr_scale is not None
or self.args.embedding_lr is not None
):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW(
optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
)
)
elif self.args.alternate_optimizer == "ao_adamw_4bit":
from torchao.prototype.low_bit_optim import AdamW4bit
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW4bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_8bit":
from torchao.prototype.low_bit_optim import AdamW8bit
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamW8bit(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "ao_adamw_fp8":
from torchao.prototype.low_bit_optim import AdamWFp8
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
AdamWFp8(optimizer_grouped_parameters, **optimizer_kwargs)
)
elif self.args.alternate_optimizer == "adopt_adamw":
from axolotl.utils.optimizers.adopt import ADOPT
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
ADOPT(
optimizer_grouped_parameters,
decouple=True,
**optimizer_kwargs,
)
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining:
if self.args.multipack_real_batches:
batch_size = self.args.per_device_train_batch_size
batch_max_len = self.args.max_seq_length
else:
batch_size = 1
train_batch_size = (
self.state.train_batch_size or self.args.per_device_train_batch_size
)
batch_max_len = train_batch_size * self.args.max_seq_length
if self.args.curriculum_sampling:
sampler = SequentialSampler(self.train_dataset)
else:
sampler = RandomSampler(self.train_dataset)
return MultipackBatchSampler(
sampler,
lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
group_size=self.args.sample_packing_group_size,
bin_size=self.args.sample_packing_bin_size,
drop_last=True,
)
if self.args.curriculum_sampling:
return SequentialSampler(self.train_dataset)
return super()._get_train_sampler()
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
if self.args.multipack_real_batches:
batch_size = self.args.per_device_eval_batch_size
batch_max_len = self.args.max_seq_length
else:
batch_size = 1
batch_max_len = (
self.args.per_device_eval_batch_size * self.args.max_seq_length
)
return MultipackBatchSampler(
SequentialSampler(eval_dataset),
lengths=get_dataset_lengths(self.eval_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
group_size=self.args.sample_packing_group_size,
bin_size=self.args.sample_packing_bin_size,
drop_last=True,
)
return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> DataLoader:
if self.args.sample_packing and not self.args.pretraining:
train_dataset = self.train_dataset
if "length" in train_dataset.features.keys():
train_dataset = train_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
sampler = self._get_train_sampler()
if isinstance(sampler, BatchSampler):
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(train_dataset, **dataloader_params)
)
return super().get_train_dataloader()
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
if self.args.sample_packing and self.args.eval_sample_packing is False:
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
)
if eval_dataset:
eval_dataset = eval_dataset.remove_columns(["length"])
dataloader = super().get_eval_dataloader(eval_dataset)
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.train_data_collator
)
return dataloader
if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset)
eval_dataset = eval_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
if isinstance(eval_sampler, BatchSampler):
dataloader_params["batch_sampler"] = eval_sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = eval_sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(eval_dataset, **dataloader_params)
)
return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
def get_bench_dataloader(
self,
bench_dataset: Dataset,
) -> DataLoader:
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": self.bench_data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(
self, model, inputs, return_outputs=False, num_items_in_batch=None
):
# use one's weighted cross entropy loss calc
# if self.args.sample_packing:
# labels = inputs.pop("labels")
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
if self.args.orpo_alpha:
return self.orpo_compute_loss(
model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
return super().compute_loss(
model,
inputs,
return_outputs=return_outputs,
num_items_in_batch=num_items_in_batch,
)
@staticmethod
def orpo_concatenate_inputs(inputs, label_pad_token=-100, pad_token=0, device=None):
concatenated_batch = {}
max_length = max(
inputs["input_ids"].shape[1], inputs["rejected_input_ids"].shape[1]
)
# Concatenate positive and negative inputs
concatenated_batch["input_ids"] = pad_to_length(
inputs["input_ids"], max_length, pad_token
)
concatenated_batch["rejected_input_ids"] = pad_to_length(
inputs["rejected_input_ids"], max_length, pad_token
)
concatenated_batch["labels"] = pad_to_length(
inputs["labels"], max_length, label_pad_token
)
concatenated_batch["rejected_labels"] = pad_to_length(
inputs["rejected_labels"], max_length, label_pad_token
)
concatenated_batch["attention_mask"] = pad_to_length(
inputs["attention_mask"], max_length, 0
)
concatenated_batch["rejected_attention_mask"] = pad_to_length(
inputs["rejected_attention_mask"], max_length, 0
)
concatenated_batch["prompt_attention_mask"] = pad_to_length(
inputs["prompt_attention_mask"], max_length, 0
).to(device=device)
input_ids = torch.cat(
[concatenated_batch["input_ids"], concatenated_batch["rejected_input_ids"]],
dim=0,
).to(device=device)
attention_mask = torch.cat(
[
concatenated_batch["attention_mask"],
concatenated_batch["rejected_attention_mask"],
],
dim=0,
).to(device=device)
labels = torch.cat(
[concatenated_batch["labels"], concatenated_batch["rejected_labels"]], dim=0
).to(device=device)
return {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"prompt_attention_mask": concatenated_batch["prompt_attention_mask"],
}
def orpo_compute_custom_loss(self, logits, labels):
logits = logits.contiguous()
loss = 0.0
if labels is not None:
# move labels to correct device to enable model parallelism
labels = labels.to(logits.device)
# Shift so that tokens < n predict n
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss = self.loss_fct(shift_logits.transpose(2, 1), shift_labels).mean(
dim=-1
)
return loss
def orpo_compute_logps(
self, prompt_attention_mask, chosen_inputs, chosen_attention_mask, logits
):
# Get the shape of chosen_attention_mask[:, :-1]
chosen_shape = chosen_attention_mask[:, :-1].shape
# Calculate the padding size
pad_length = chosen_shape[1] - (prompt_attention_mask.shape[1] - 1)
# Pad prompt_attention_mask with zeros to match the desired shape
prompt_attention_mask_padded = torch.nn.functional.pad(
prompt_attention_mask[:, 1:], (0, pad_length), mode="constant", value=0
)
# Perform the subtraction operation
mask = chosen_attention_mask[:, :-1] > prompt_attention_mask_padded
per_token_logps = torch.gather(
logits[:, :-1, :].log_softmax(-1),
dim=2,
index=(mask * chosen_inputs[:, 1:]).unsqueeze(2),
).squeeze(2)
return torch.mul(per_token_logps, mask).sum(dim=1) / mask.sum(dim=1)
def orpo_compute_loss(
self,
model,
inputs,
return_outputs=False,
num_items_in_batch=None, # pylint: disable=unused-argument
):
concat_inputs = AxolotlTrainer.orpo_concatenate_inputs(
inputs,
label_pad_token=-100,
pad_token=self.tokenizer.pad_token_id,
device=self.accelerator.device,
)
# Perform a single forward pass
outputs = model(
**{
"input_ids": concat_inputs["input_ids"],
"attention_mask": concat_inputs["attention_mask"],
"labels": concat_inputs["labels"],
},
output_hidden_states=True,
)
# Split the outputs for positive and negative examples
outputs_pos, outputs_neg = outputs.logits.chunk(2)
# Calculate NLL loss
pos_loss = self.orpo_compute_custom_loss(
logits=outputs_pos, labels=concat_inputs["input_ids"].chunk(2)[0]
)
# Calculate Log Probability
pos_prob = self.orpo_compute_logps(
prompt_attention_mask=concat_inputs["prompt_attention_mask"],
chosen_inputs=concat_inputs["input_ids"].chunk(2)[0],
chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[0],
logits=outputs_pos,
)
neg_prob = self.orpo_compute_logps(
prompt_attention_mask=concat_inputs["prompt_attention_mask"],
chosen_inputs=concat_inputs["input_ids"].chunk(2)[1],
chosen_attention_mask=concat_inputs["attention_mask"].chunk(2)[1],
logits=outputs_neg,
)
# Calculate log odds
log_odds = (pos_prob - neg_prob) - (
torch.log(1 - torch.exp(pos_prob)) - torch.log(1 - torch.exp(neg_prob))
)
sig_ratio = torch.nn.functional.sigmoid(log_odds)
ratio = torch.log(sig_ratio)
# Calculate the Final Loss
loss = torch.mean(pos_loss - self.args.orpo_alpha * ratio).to(
dtype=torch.bfloat16
)
metrics = {}
metrics["chosen_geometric_mean"] = torch.mean(pos_prob).cpu().item()
metrics["rejected_geometric_mean"] = torch.mean(neg_prob).cpu().item()
metrics["log_odds_ratio"] = torch.mean(ratio).cpu().item()
metrics["log_odds"] = torch.mean(log_odds).cpu().item()
self.store_metrics(metrics, train_eval="train")
return (loss, outputs_pos) if return_outputs else loss
@wraps(Trainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
)
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs)
@wraps(Trainer.create_accelerator_and_postprocess)
def create_accelerator_and_postprocess(self):
res = super().create_accelerator_and_postprocess()
if self.is_fsdp_enabled:
if (
"limit_all_gathers" in self.args.fsdp_config
and self.args.fsdp_config["limit_all_gathers"]
):
self.accelerator.state.fsdp_plugin.limit_all_gathers = True
return res
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
"""
Log `logs` on the various objects watching training, including stored metrics.
Args:
logs (`Dict[str, float]`):
The values to log.
start_time (`Optional[float]`):
The start of training.
"""
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super().log(logs, start_time)
def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
) -> None:
for key, value in metrics.items():
self._stored_metrics[train_eval][key].append(value)
def _save_checkpoint(self, model, trial, **kwargs):
# make sure the checkpoint dir exists, since trainer is flakey
checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.state.global_step}"
run_dir = self._get_output_dir(trial=trial)
output_dir = os.path.join(run_dir, checkpoint_folder)
os.makedirs(output_dir, exist_ok=True)
return super()._save_checkpoint(model, trial, **kwargs)
class AxolotlMambaTrainer(AxolotlTrainer):
"""
Mamba specific trainer to handle loss calculation
"""
tag_names = ["axolotl", "mamba"]
def compute_loss(
self,
model,
inputs,
return_outputs=False, # pylint: disable=unused-argument
num_items_in_batch=None, # pylint: disable=unused-argument
):
input_ids = inputs.pop("input_ids")
lm_logits = model(input_ids).logits
labels = input_ids.to(lm_logits.device)
shift_logits = lm_logits[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = torch.nn.CrossEntropyLoss()
lm_loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
)
return lm_loss
class ReLoRATrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
tag_names = ["axolotl", "relora"]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
if self.args.relora_steps:
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
anneal_steps = (
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
anneal_steps,
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler
return self.lr_scheduler
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "dpo"]
def __init__(self, *args, dataset_tags=None, **kwargs):
super().__init__(*args, **kwargs)
self.dataset_tags = dataset_tags
self.optimizer = None
def create_optimizer(self):
if self.args.loraplus_lr_ratio is None:
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
if loraplus_lr_ratio:
print("Using lora+")
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
loraplus_lr_ratio=loraplus_lr_ratio,
loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs,
)
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
self.optimizer
)
return self.optimizer
@wraps(DPOTrainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_ds_tagging(
dataset_tags=self.dataset_tags, kwargs=kwargs
)
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs)
@staticmethod
def tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
) -> Dict:
res = DPOTrainer.tokenize_row(
features,
processing_class,
max_prompt_length,
max_completion_length,
add_special_tokens,
)
# fix when the tokenizer doesn't have a bos_token_id, e.g. Qwen
if processing_class.bos_token is None and res["prompt_input_ids"][0] is None:
for key in res.keys():
res[key] = res[key][1:]
if processing_class.bos_token and processing_class.bos_token_id is not None:
# dpo trainer may incorrectly prepend the bos_token_id to the dpo outputs
if res["chosen_input_ids"][0] == processing_class.bos_token_id:
res["chosen_input_ids"] = res["chosen_input_ids"][1:]
res["chosen_labels"] = res["chosen_labels"][1:]
res["chosen_attention_mask"] = res["chosen_attention_mask"][1:]
if res["rejected_input_ids"][0] == processing_class.bos_token_id:
res["rejected_input_ids"] = res["rejected_input_ids"][1:]
res["rejected_labels"] = res["rejected_labels"][1:]
res["rejected_attention_mask"] = res["rejected_attention_mask"][1:]
return res
def training_step(
self,
model: nn.Module,
inputs: Dict[str, Union[torch.Tensor, Any]],
num_items_in_batch=None,
) -> torch.Tensor:
loss: torch.Tensor = super().training_step(model, inputs, num_items_in_batch)
gc.collect()
torch.cuda.empty_cache()
return loss
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "orpo"]
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "cpo"]
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
Extend the base RewardTrainer for axolotl helpers
"""
tag_names = ["axolotl", "reward"]

View File

@@ -0,0 +1,239 @@
"""
extra axolotl specific training args
"""
from dataclasses import dataclass, field
from typing import Optional
from transformers import TrainingArguments
from trl import CPOConfig, DPOConfig, KTOConfig, ORPOConfig, RewardConfig
@dataclass
class AxolotlTrainingMixins:
"""
Mixin class for the Axolotl training args.
"""
# pylint: disable=duplicate-code
model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."}
)
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
pretraining: bool = field(
default=False,
metadata={
"help": "Indicates to trainer whether we are doing continued pretraining."
},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
multipack_real_batches: bool = field(
default=False,
metadata={"help": "Use real batches for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
sample_packing_bin_size: int = field(
default=200,
metadata={
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
},
)
sample_packing_group_size: int = field(
default=100000,
metadata={
"help": "The number of samples to group together for packing. Increase for better packing."
},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_anneal_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_prune_ratio: Optional[float] = field(
default=0.9,
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
do_causal_lm_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
dataloader_prefetch_factor: Optional[int] = field(
default=None,
metadata={"help": "prefetch_factor argument to the dataloader"},
)
cosine_min_lr_ratio: Optional[float] = field(
default=None,
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
)
cosine_constant_lr_ratio: Optional[float] = field(
default=None,
metadata={
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
},
)
loraplus_lr_ratio: Optional[float] = field(
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
)
loraplus_lr_embedding: Optional[float] = field(
default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
embedding_lr_scale: Optional[float] = field(
default=None,
metadata={"help": "Scale the learning rate for the embedding layers."},
)
embedding_lr: Optional[float] = field(
default=None,
metadata={"help": "absolute learning rate for the embedding layers."},
)
qlora: bool = field(
default=False,
metadata={"help": "whether this is a qlora training"},
)
orpo_alpha: Optional[float] = field(
default=None,
)
lisa_n_layers: Optional[int] = field(
default=None,
metadata={"help": "the number of activate layers in LISA"},
)
lisa_step_interval: Optional[int] = field(
default=None,
metadata={"help": "how often to switch layers in LISA"},
)
lisa_layers_attribute: Optional[str] = field(
default=None,
metadata={"help": "path under the model to access the layers"},
)
curriculum_sampling: Optional[bool] = field(
default=None,
metadata={"help": "whether to use sequential sampling for curriculum learning"},
)
alternate_optimizer: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate optimizer to the HF trainer"
},
)
alternate_lr_scheduler_type: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
},
)
chat_template: Optional[str] = field(
default=None,
metadata={"help": "Chat template converting chat messages to text"},
)
kd_ce_alpha: Optional[float] = field(
default=None,
metadata={
"help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
},
)
kd_alpha: Optional[float] = field(
default=1.0,
metadata={"help": "The alpha scaling parameter for KD loss"},
)
kd_temperature: Optional[float] = field(
default=1.0,
metadata={
"help": "the temperature parameter for KL divergence loss when using KD"
},
)
@dataclass
class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
"""
Training arguments for Causal trainer
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
so it can't be used as a mixin.
"""
@dataclass
class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
"""
DPO config for DPO training
"""
@dataclass
class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
"""
ORPO config for ORPO training
"""
@dataclass
class AxolotlKTOConfig(AxolotlTrainingMixins, KTOConfig):
"""
KTO config for KTO training
"""
@dataclass
class AxolotlCPOConfig(AxolotlTrainingMixins, CPOConfig):
"""
CPO config for CPO training
"""
simpo_gamma: Optional[float] = field(
default=None,
metadata={"help": "simpo gamma parameter"},
)
@dataclass
class AxolotlRewardConfig(AxolotlTrainingMixins, RewardConfig):
"""
Reward config for Reward training
"""

View File

@@ -2,7 +2,7 @@
import logging
import os
from typing import List, Optional
from typing import List, Optional, Union
import torch
from datasets import Dataset, IterableDataset
@@ -51,7 +51,13 @@ class TokenizedPromptDataset(Dataset):
map_kwargs = {}
if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
map_kwargs["batch_size"] = 100
map_kwargs["batch_size"] = 1_000
if self.prompt_tokenizer.filter_rows:
dataset = dataset.filter(
self.prompt_tokenizer.filter_rows,
num_proc=num_proc,
desc="Strategy Filtering Rows",
)
return dataset.map(
self.prompt_tokenizer.tokenize_prompt,
num_proc=num_proc,
@@ -62,6 +68,24 @@ class TokenizedPromptDataset(Dataset):
)
def wrap_dataset_for_tokenized_prompt(
prompt_tokenizer: PromptTokenizingStrategy,
dataset: Union[Dataset, IterableDataset],
**kwargs,
):
if isinstance(dataset, IterableDataset):
map_kwargs = {}
if prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True
features = dataset.features.keys()
return dataset.map(
prompt_tokenizer.tokenize_prompt,
remove_columns=features,
**map_kwargs,
)
return TokenizedPromptDataset(prompt_tokenizer, dataset, **kwargs)
# TODO this isn't the best since it can't interleave datasets
class ConstantLengthDataset(IterableDataset):
"""

View File

@@ -9,7 +9,6 @@ from typing import Dict, Optional
import torch
from accelerate.logging import get_logger
from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf
@@ -62,16 +61,13 @@ def evaluate_dataset(
return metrics
def evaluate(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
) -> Dict[str, float]:
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
"""
Evaluate a model on training and validation datasets
Args:
cfg: Configuration dictionary
cli_args: Command line arguments
dataset_meta: Dataset metadata containing training and evaluation datasets
cfg: Dictionary mapping `axolotl` config keys to values.
dataset_meta: Dataset metadata containing training and evaluation datasets.
Returns:
Tuple containing:
@@ -102,9 +98,7 @@ def evaluate(
# Load model
LOG.debug("loading model for evaluation...")
model, _ = load_model(
cfg, tokenizer, processor=processor, inference=cli_args.inference
)
model, _ = load_model(cfg, tokenizer, processor=processor)
# Set up trainer
trainer = setup_trainer(

View File

@@ -111,6 +111,17 @@ class BasePlugin:
None
"""
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument):
"""
Returns a custom class for the trainer.
Parameters:
cfg (dict): The global axolotl configuration.
Returns:
class: The class for the trainer.
"""
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
"""
Creates and returns an optimizer for training.
@@ -212,7 +223,17 @@ def load_plugin(plugin_name: str) -> BasePlugin:
module_name, class_name = plugin_name.rsplit(".", 1)
# import the module
module = importlib.import_module(module_name)
try:
module = importlib.import_module(module_name)
except ModuleNotFoundError as orig_exc:
try:
if not module_name.startswith("axolotl.integrations."):
module = importlib.import_module("axolotl.integrations." + module_name)
else:
raise orig_exc
except ModuleNotFoundError as exc:
raise orig_exc from exc
# instantiate the class
plugin_class = getattr(module, class_name)
# create an instance of the class
@@ -272,8 +293,10 @@ class PluginManager:
ImportError: If the plugin module cannot be imported.
"""
try:
logging.info(f"Attempting to load plugin: {plugin_name}")
plugin = load_plugin(plugin_name)
self.plugins[plugin_name] = plugin
logging.info(f"Plugin loaded successfully: {plugin_name}")
except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}")
@@ -346,6 +369,22 @@ class PluginManager:
for plugin in self.plugins.values():
plugin.post_lora_load(cfg, model)
def get_trainer_cls(self, cfg):
"""
Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class.
Parameters:
cfg (dict): The configuration for the plugins.
Returns:
object: The trainer class, or None if none was found.
"""
for plugin in self.plugins.values():
trainer_cls = plugin.get_trainer_cls(cfg)
if trainer_cls is not None:
return trainer_cls
return None
def create_optimizer(self, cfg, trainer):
"""
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.

View File

@@ -0,0 +1,36 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Plugin init to add KD support to Axolotl.
"""
from axolotl.integrations.base import BasePlugin
from .args import KDArgs # pylint: disable=unused-import. # noqa: F401
class KDPlugin(BasePlugin):
"""
Plugin for KD support in Axolotl.
"""
def get_input_args(self):
return "axolotl.integrations.kd.KDArgs"
def get_trainer_cls(self, cfg):
if cfg.kd_trainer:
from .trainer import AxolotlKDTrainer
return AxolotlKDTrainer
return None

View File

@@ -0,0 +1,33 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Plugin args for KD support.
"""
from typing import Optional
from pydantic import BaseModel
class KDArgs(BaseModel):
"""
Input args for knowledge distillation.
"""
kd_trainer: Optional[bool] = None # whether to use KD trainer
kd_ce_alpha: Optional[
float
] = None # loss coefficient for cross-entropy loss during KD
kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_temperature: Optional[float] = None # temperature for sampling during KD

View File

@@ -0,0 +1,164 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Chat template prompt strategy loader with KD support
"""
from typing import Any, Dict
import torch
from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader
class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
"""
Handle fields for logprob KD
"""
def __init__(
self,
prompter,
tokenizer,
train_on_inputs,
sequence_len,
roles_to_train=None,
train_on_eos=None,
logprobs_field="logprobs",
gen_temperature=1.0,
kd_temperature=1.0,
):
self.logprobs_field = logprobs_field
self.gen_temperature = gen_temperature
self.kd_temperature = kd_temperature
super().__init__(
prompter,
tokenizer,
train_on_inputs,
sequence_len,
roles_to_train=roles_to_train,
train_on_eos=train_on_eos,
)
def transform_logprobs(self, sample):
logprobs = sample.pop(self.logprobs_field)
target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"])
input_padding_len = input_seq_len - target_seq_len
top_k = len(logprobs[0])
target_logprobs = []
target_token_ids = []
target_mask = []
# fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
for _ in range(1, input_padding_len): # start at 1 since this is causal
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
for _ in range(target_seq_len):
# TODO also check against sample["labels"]
target_mask.append([1] * top_k)
for _, token_pos_logprobs in enumerate(logprobs):
# Initialize collections for logprobs and token_ids
position_logprobs = []
position_token_ids = []
# Process each token probability entry
for entry in token_pos_logprobs:
# Extract logprob value
logprob = entry["logprob"]
# Parse token_id from the "token_id:###" format
token_id = int(entry["token"].split(":")[1])
# Append to our collections
position_logprobs.append(logprob)
position_token_ids.append(token_id)
# Convert to a tensor for easier manipulation
# Convert to tensor
position_logprobs_tensor = torch.tensor(
position_logprobs, dtype=torch.float
)
if self.kd_temperature != self.gen_temperature:
#
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
#
# Convert from log to probability
teacher_probs_t1 = position_logprobs_tensor.exp()
# Exponentiate by factor (T1 / T2)
exponent = self.gen_temperature / self.kd_temperature
teacher_probs_t2 = teacher_probs_t1**exponent
# Re-normalize
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
dim=0, keepdim=True
)
# Convert back to log
position_logprobs_tensor = torch.log(teacher_probs_t2)
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
position_logprobs_scaled = position_logprobs_tensor.tolist()
target_logprobs.append(position_logprobs_scaled)
target_token_ids.append(position_token_ids)
# since we started at index 1 for causal, we need one more padding token
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
# Update sample with transformed logprobs
sample["target_logprobs"] = target_logprobs
sample["target_token_ids"] = target_token_ids
sample["target_mask"] = target_mask
return sample
def _tokenize_single_prompt(self, prompt):
logprobs = prompt.pop(self.logprobs_field)
tokenized_prompt = super()._tokenize_single_prompt(prompt)
tokenized_prompt[self.logprobs_field] = logprobs
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
return tokenized_prompt
class KDStrategyLoader(StrategyLoader):
"""
Load ChatTemplateStrategy with KD support using StrategyLoader.
"""
def _get_strategy_cls(self):
return ChatTemplateStrategyWithKD
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
strategy_params = super()._get_strategy_params(cfg, ds_cfg)
if logprobs_field := ds_cfg.get("logprobs_field"):
strategy_params["logprobs_field"] = logprobs_field
if gen_temperature := ds_cfg.get("temperature"):
strategy_params["gen_temperature"] = gen_temperature
if kd_temperature := cfg.get("kd_temperature"):
strategy_params["kd_temperature"] = kd_temperature
return strategy_params
load = KDStrategyLoader()

View File

@@ -0,0 +1,255 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
DataCollator for axolotl to handle KD fields without using -inf for padding,
and with a teacher_mask to identify padded positions.
"""
from dataclasses import dataclass
from typing import Any, Optional, Union
import numpy as np
import torch
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
from axolotl.utils.collators.batching import DataCollatorForSeq2Seq
@dataclass
class DataCollatorForKD(DataCollatorForSeq2Seq):
"""
Data collator for KD, including handling KD-specific fields.
This version avoids using -inf and instead uses a large negative value for padding
target_logprobs. It also creates a teacher_mask to indicate which entries are valid.
"""
# pylint: disable=duplicate-code
tokenizer: PreTrainedTokenizerBase
model: Optional[Any] = None
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
label_pad_token_id: int = -100
position_pad_token_id: int = 0
return_tensors: str = "pt"
def __call__(self, features, return_tensors=None):
if return_tensors is None:
return_tensors = self.return_tensors
padding_side = self.tokenizer.padding_side
# Pad labels and position_ids first
for feature_name, pad_token_id in [
("labels", self.label_pad_token_id),
("position_ids", self.position_pad_token_id),
]:
if feature_name in features[0]:
feat = [f[feature_name] for f in features]
max_len = max(len(x) for x in feat)
if self.pad_to_multiple_of is not None:
max_len = (
(max_len + self.pad_to_multiple_of - 1)
// self.pad_to_multiple_of
) * self.pad_to_multiple_of
for f in features: # pylint: disable=invalid-name
remainder = [pad_token_id] * (max_len - len(f[feature_name]))
if isinstance(f[feature_name], list):
f[feature_name] = (
f[feature_name] + remainder
if padding_side == "right"
else remainder + f[feature_name]
)
else:
# If they are numpy arrays
if padding_side == "right":
f[feature_name] = np.concatenate(
[f[feature_name], remainder]
).astype(np.int64)
else:
f[feature_name] = np.concatenate(
[remainder, f[feature_name]]
).astype(np.int64)
# Handle target_logprobs and target_token_ids manually
target_logprobs_list = []
target_token_ids_list = []
target_mask_list = []
has_teacher_data = ("target_logprobs" in features[0]) and (
"target_token_ids" in features[0]
)
if has_teacher_data:
# Extract and remove from features
for f in features: # pylint: disable=invalid-name
target_logprobs_list.append(f.pop("target_logprobs"))
target_token_ids_list.append(f.pop("target_token_ids"))
target_mask_list.append(f.pop("target_mask"))
# Determine max lengths
max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list)
max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq)
padded_target_logprobs = []
padded_target_token_ids = []
padded_teacher_mask_list = []
for t_logprobs, t_ids, t_mask in zip(
target_logprobs_list, target_token_ids_list, target_mask_list
):
t_logprobs_padded = []
t_ids_padded = []
t_mask_padded = []
for lp, ids, mask in zip( # pylint: disable=invalid-name
t_logprobs, t_ids, t_mask
):
lp_len = len(lp)
if lp_len < max_k:
# Use -1e9 for padding logprobs and 0 for token_ids
pad_len = max_k - lp_len
lp = lp + [-1e9] * pad_len # pylint: disable=invalid-name
ids = ids + [0] * pad_len
mask = mask + [0] * pad_len
else:
lp = lp[:max_k] # pylint: disable=invalid-name
ids = ids[:max_k]
mask = mask[:max_k]
t_logprobs_padded.append(lp)
t_ids_padded.append(ids)
t_mask_padded.append(mask)
seq_len_diff = max_teacher_seq_len - len(t_logprobs_padded)
if seq_len_diff > 0:
# Pad sequences fully if needed
t_logprobs_padded.extend(
[[-1e9] * max_k for _ in range(seq_len_diff)]
)
t_ids_padded.extend([[0] * max_k for _ in range(seq_len_diff)])
t_mask_padded.extend([[0] * max_k for _ in range(seq_len_diff)])
padded_target_logprobs.append(t_logprobs_padded)
padded_target_token_ids.append(t_ids_padded)
padded_teacher_mask_list.append(t_mask_padded)
# Convert to tensors
padded_target_logprobs = torch.tensor(
padded_target_logprobs, dtype=torch.float
)
padded_target_token_ids = torch.tensor(
padded_target_token_ids, dtype=torch.long
)
padded_teacher_mask_list = torch.tensor(
padded_teacher_mask_list, dtype=torch.int
)
# Pad using tokenizer for regular fields
features = self.tokenizer.pad(
features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=return_tensors,
)
# Add back teacher data if present
if has_teacher_data:
features["target_logprobs"] = padded_target_logprobs
features["target_token_ids"] = padded_target_token_ids
features["target_mask"] = padded_teacher_mask_list
# Prepare decoder_input_ids if the model supports it
if (
"labels" in features
and self.model is not None
and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
):
decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(
labels=features["labels"]
)
features["decoder_input_ids"] = decoder_input_ids
return features
class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
"""
Collator for multipack (batch of sub-batches) specifically for KD.
Adapts DataCollatorForKD so it can pack multiple sequences in a single batch item.
"""
def __call__(self, features, return_tensors=None):
"""
Expects that `features` could be either:
- a single list of dicts, OR
- a list of lists of dicts (the "sub-batches" to be packed).
"""
# 1) If we are *not* dealing with multiple sequences per batch element,
# just pass straight to parent.
if not isinstance(features[0], list):
return super().__call__(features, return_tensors=return_tensors)
# 2) Otherwise, we *are* dealing with multiple sequences in each batch item.
# We want to produce a single "merged" feature dict for each sub-batch.
out_features = [{} for _ in features]
for i, sub_features in enumerate(features):
# sub_features is a list of dicts, each dict = one sequences features
# We'll merge them into out_features[i].
#
# NOTE: You can customize how you combine fields as needed (e.g. summation
# or offset for attention_mask). Below is a straightforward concatenation/extension.
for field_name in sub_features[0].keys():
# Some fields you might want to skip or treat specially:
if field_name == "length":
continue
# If its a KD field thats a list-of-lists (e.g. target_logprobs),
# you typically just want to flatten them by extending.
if field_name in ["target_logprobs", "target_token_ids", "target_mask"]:
combined = []
for feat in sub_features:
combined.extend(feat[field_name])
out_features[i][field_name] = combined
elif field_name == "attention_mask":
# Here we apply the (j+1) factor to differentiate each sub-sample
# within this merged batch item.
arrays = []
for j, feat in enumerate(sub_features):
if field_name in feat:
arrays.append((j + 1) * np.array(feat[field_name]))
out_features[i][field_name] = np.concatenate(arrays)
else:
# By default, just concatenate them if they are arrays
# or extend them if they are lists.
# For example, input_ids or labels are often arrays.
arrays = []
for feat in sub_features:
if field_name in feat:
arr = np.array(feat[field_name])
arrays.append(arr)
out_features[i][field_name] = np.concatenate(arrays)
# 3) Now call the parent collator, which will do:
# - padding of labels/position_ids
# - KD-specific padding for target_logprobs, target_token_ids, etc.
# - final conversion to return_tensors
return super().__call__(out_features, return_tensors=return_tensors)

View File

@@ -0,0 +1,58 @@
### AXOLOTL COMMUNITY LICENSE AGREEMENT
This Axolotl Community License Agreement (“Agreement”) is entered into by and between Axolotl AI Corp. (“Axolotl”) and
any individual or entity (“Licensee”) who wishes to use the Software (as defined below) in accordance with the terms
and conditions set forth in this Agreement.
1. Definitions
1.1 “Licensee” refers to any individual or entity who has obtained a copy of the Software under this Agreement.
1.2 “Plugin Integration” means independent integration software modules which may or may not be offered by Axolotl,
which may be licensed separately by their respective authors and/or licensors.
1.3 “Software” refers to the specific sub-directory of the Axolotl, Inc. software located at
https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations and its subdirectories which
permits Plugin Integrations to integrate with the Axolotl service.
2. Grant of License
2.1 Axolotl hereby grants Licensee a worldwide, non-exclusive, royalty-free, license to use, copy, modify, merge,
publish, distribute, sublicense, and/or otherwise exploit the Software, subject to the following conditions:
- Licensee must comply with all the terms and conditions of this Agreement.
- Licensee must include the original copyright notice and disclaimer of warranty in all copies or substantial
portions of the Software.
2.2 Licensee may use the Software for any lawful purpose, except as restricted in Section 3.
3. Restrictions
3.1 Licensee shall not use the Software for any activity that constitutes a commercial activity of offering for
free or for sale any services, platform, or equivalent to third parties for the purposes of allowing such
third parties to fine-tune artificial intelligence models.
3.2 Licensee shall not:
- Use the Software for any illegal or unauthorized purpose.
- Reverse engineer, decompile, or disassemble the Software.
- Remove or modify any copyright, trademark, or other proprietary notices contained in the Software.
- Use the Software in a way that could damage, disable, overburden, or impair the functionality of the
Software or interfere with any third-party use of the Software.
3.3 Axolotl reserves the right to restrict certain Plugin Integrations for use with the Software. To the extent Licensee integrates a permitted, applicable Plugin Integration with the Software, Licensee shall comply with any additional terms and conditions imposed by the licensors of such Plugin Integration for use of such Plugin Integrations. Licensee shall contact Axolotl if it has questions about whether its use of the Software falls beyond the scope of this Agreement.
4. Intellectual Property Rights
4.1 Axolotl and its contributors retain all intellectual property rights in and to the Software. Licensee
acknowledges that this Agreement does not transfer any ownership rights or intellectual property rights to
Licensee.
5. Disclaimer of Warranty
5.1 THE SOFTWARE IS PROVIDED “AS IS,” WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND NON-INFRINGEMENT. IN NO EVENT SHALL
THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN ACTION OF
CONTRACT, TORT, OR OTHERWISE, ARISING FROM, OUT OF, OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
6. Termination
6.1 Axolotl may terminate this Agreement at any time if Licensee fails to comply with any of the terms and
conditions set forth herein. Upon termination, Licensee shall cease all use of the Software and destroy any
copies in its possession.
7. Governing Law
7.1 This Agreement shall be governed by and construed in accordance with the laws of the State of California,
without regards to conflicts of laws provisions thereof.
8. Entire Agreement
8.1 This Agreement constitutes the entire agreement between Axolotl and Licensee with respect to the subject matter
hereof and supersedes all prior or contemporaneous understandings or agreements between the parties concerning
the Software, whether written or oral. Axolotl may update the terms of this Agreement from time to time, and
Licensees continued use of the Software after any such updates shall constitute acceptance of updated terms
on a go-forward basis. Axolotl will use commercially reasonable efforts to provide Licensee notice of any
material updates. By using the Software, Licensee acknowledges that it has read, understood, and agrees to be
bound by the terms and conditions of this Agreement.
This Agreement was last updated on August 23, 2024.

View File

@@ -0,0 +1,82 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Axolotl Community License Agreement (the "License");
# you may not use this file except in compliance with the License.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""
loss for top_k KL divergence
"""
import torch
@torch.jit.script
def loss(
student_logits: torch.Tensor,
target_token_ids: torch.Tensor,
target_logprobs: torch.Tensor,
target_mask: torch.Tensor,
num_items_in_batch: int = -1, # Use -1 to indicate "None"
kd_temperature: float = 1.0,
) -> torch.Tensor:
"""
A KD loss function that is TorchScript-friendly.
"""
# Determine the teacher sequence length
# target_token_ids shape: [B, teacher_seq_len, K]
# student_logits shape: [B, student_seq_len, vocab_size]
teacher_seq_len = target_token_ids.shape[1]
# Slice student logits to match teacher-provided sequence length
student_logits_for_kd = student_logits[
:, :teacher_seq_len, :
] # [B, teacher_seq_len, vocab_size]
# Gather student logits for teacher's top-K tokens
student_logits_topk = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, teacher_seq_len, K]
# Apply KD temperature to students logits
if kd_temperature != 1.0:
student_logits_topk = student_logits_topk / kd_temperature
# Convert student top-k logits to logprobs
student_logprobs_topk = student_logits_topk - torch.logsumexp(
student_logits_topk, dim=-1, keepdim=True
) # [B, teacher_seq_len, K]
# Convert teacher_mask to boolean for indexing
# In TorchScript, .bool() is sometimes unsupported, so we do:
valid_mask = target_mask.to(torch.bool)
# Prune tensors to only keep valid tokens
student_logprobs_topk = student_logprobs_topk[valid_mask]
target_logprobs = target_logprobs[valid_mask]
# Convert teacher logprobs to probabilities
teacher_probs = target_logprobs.exp()
# Compute forward KL
kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
kd_loss = kd_loss_per_token.sum()
# Multiply by T^2 (classical KD scaling)
if kd_temperature != 1.0:
kd_loss = kd_loss * (kd_temperature**2)
# Normalize by number of items (if provided) or by valid tokens
if num_items_in_batch > 0:
kd_loss = kd_loss / float(num_items_in_batch)
else:
# Fall back to average over valid tokens
kd_loss = kd_loss / float(kd_loss_per_token.size(0))
return kd_loss

View File

@@ -0,0 +1,107 @@
# Copyright 2024 Axolotl AI. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
KD trainer
"""
from axolotl.core.trainers.base import AxolotlTrainer
from .topk_logprob.forward_kl import loss as topk_kd_loss
class AxolotlKDTrainer(AxolotlTrainer):
"""
Custom trainer subclass for Knowledge Distillation (KD)
"""
def _set_signature_columns_if_needed(self):
super()._set_signature_columns_if_needed()
columns_to_add = []
if self._signature_columns:
if "target_logprobs" not in self._signature_columns:
columns_to_add.append("target_logprobs")
if "target_token_ids" not in self._signature_columns:
columns_to_add.append("target_token_ids")
if "target_mask" not in self._signature_columns:
columns_to_add.append("target_mask")
if columns_to_add:
self._signature_columns += columns_to_add
def compute_loss(
self,
model,
inputs,
return_outputs=False,
num_items_in_batch=None,
shift_targets=False,
):
"""
How the loss is computed by Trainer. By default, all models return the loss in the first element.
Subclass and override for custom behavior.
"""
target_logprobs = inputs.pop("target_logprobs")
target_token_ids = inputs.pop("target_token_ids")
target_mask = inputs.pop("target_mask")
seq_len = target_token_ids.shape[1]
if self.model_accepts_loss_kwargs:
loss_kwargs = {}
if num_items_in_batch is not None:
loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs}
outputs = model(**inputs)
# FIXME: account for tokenizer.padding_side
student_logits = outputs["logits"][:, :seq_len, :].contiguous()
if shift_targets:
shift_logits = student_logits[..., :-1, :].contiguous()
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
else:
shift_logits = student_logits.contiguous()
target_logprobs_for_loss = target_logprobs.contiguous()
target_token_ids_for_loss = target_token_ids.contiguous()
target_mask_for_loss = target_mask.contiguous()
loss_kd = topk_kd_loss(
shift_logits,
target_token_ids_for_loss,
target_logprobs_for_loss,
target_mask_for_loss,
num_items_in_batch=num_items_in_batch,
kd_temperature=self.args.kd_temperature,
)
if self.args.kd_ce_alpha > 0:
kd_alpha = self.args.kd_alpha
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
else:
loss = loss_kd
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[ # pylint: disable=attribute-defined-outside-init
self.args.past_index
]
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
loss *= self.accelerator.num_processes
return (loss, outputs) if return_outputs else loss

View File

@@ -22,13 +22,6 @@ import inspect
import logging
import sys
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from axolotl.integrations.base import BasePlugin
from ...utils.distributed import zero_only
@@ -46,6 +39,13 @@ class LigerPlugin(BasePlugin):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)

View File

@@ -25,7 +25,6 @@ SUPPORTED_MULTIPACK_MODEL_TYPES = [
"gemmoe",
"starcoder2",
"deepseek_v2",
"hymba",
]

View File

@@ -6,7 +6,7 @@ import logging
from transformers import Trainer
from axolotl.monkeypatch.unsloth_ import detab_code
from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")

View File

@@ -8,7 +8,7 @@ import logging
from transformers import LlamaForCausalLM, Trainer
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from axolotl.monkeypatch.unsloth_ import detab_code
from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")

View File

@@ -1,9 +1,7 @@
"""module for patching with unsloth optimizations"""
import inspect
import re
import types
from typing import Tuple
import torch
from accelerate.logging import get_logger
@@ -11,6 +9,8 @@ from peft import PeftModelForCausalLM
from torch import nn
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
from axolotl.monkeypatch.utils import detab_code
LOG = get_logger("axolotl.monkeypatch.unsloth")
ORIGINAL_QKV_CODE = """
@@ -93,15 +93,6 @@ def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
raise ValueError("Unsupported model type")
def detab_code(code: str) -> Tuple[str, str]:
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
return code, spaces
self_attn_lora_patched = False # pylint: disable=invalid-name

View File

@@ -1,7 +1,8 @@
"""
Shared utils for the monkeypatches
"""
from typing import Optional
import re
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
@@ -223,3 +224,12 @@ def patched_prepare_4d_causal_attention_mask_for_sdpa(
mask_2d_to_4d(attention_mask, dtype=dtype),
*args,
)
def detab_code(code: str) -> Tuple[str, str]:
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
return code, spaces

View File

@@ -16,10 +16,21 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
return messages_load(tokenizer, cfg, ds_cfg, processor=processor)
load_fn = "load"
package = "axolotl.prompt_strategies"
if strategy.split(".")[-1].startswith("load_"):
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies")
elif len(strategy.split(".")) > 1:
try:
importlib.import_module(
"." + strategy.split(".")[-1],
".".join(strategy.split(".")[:-1]),
)
package = ".".join(strategy.split(".")[:-1])
strategy = strategy.split(".")[-1]
except ModuleNotFoundError:
pass
mod = importlib.import_module(f".{strategy}", package)
func = getattr(mod, load_fn)
load_kwargs = {}
if strategy == "user_defined":

View File

@@ -10,6 +10,8 @@ LOG = logging.getLogger("axolotl")
def load(strategy, cfg, module_base=None, **kwargs):
try:
if len(strategy.split(".")) == 1:
strategy = strategy + ".default"
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(f".{strategy}", module_base)

View File

@@ -21,7 +21,11 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
Bradley-Terry reward model pairwise chat template prompt strategy.
"""
def tokenize_prompt(self, prompt):
@property
def supports_batched(self) -> bool:
return False
def _tokenize_single_prompt(self, prompt):
"""
:param prompt: the actual row of data from the underlying dataset
@@ -39,7 +43,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
)
prompt[self.messages].append({"role": "user", "content": prompt["input"]})
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
chosen_tokenized = super().tokenize_prompt(prompt)
chosen_tokenized = super()._tokenize_single_prompt(prompt)
if len(chosen_tokenized["input_ids"]) > max_length:
LOG.warning(
@@ -62,7 +66,7 @@ class BTChatTemplateStrategy(ChatTemplateStrategy):
prompt[self.messages].append(
{"role": "assistant", "content": prompt["rejected"]}
)
rejected_tokenized = super().tokenize_prompt(prompt)
rejected_tokenized = super()._tokenize_single_prompt(prompt)
if len(rejected_tokenized["input_ids"]) > max_length:
LOG.warning(

View File

@@ -3,6 +3,7 @@ HF Chat Templates prompt strategy
"""
import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional
from transformers import ProcessorMixin
@@ -193,7 +194,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def __init__(
self,
prompter,
prompter: ChatTemplatePrompter,
tokenizer,
train_on_inputs,
sequence_len,
@@ -220,22 +221,61 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def messages(self, messages):
self._messages = messages
def tokenize_prompt(self, prompt):
@property
def supports_batched(self) -> bool:
# Let calling code know we can handle lists of examples
return True
def is_prompt_batched(self, prompt: dict[str, Any]) -> bool:
try:
return all(isinstance(v, list) for v in prompt.values()) and all(
isinstance(v, list) for v in prompt[self.messages]
)
except KeyError:
return False
def tokenize_prompt(self, prompt: dict[str, Any]):
"""
Public method that can handle either a single prompt or a batch of prompts.
"""
if not self.is_prompt_batched(prompt) or not self.supports_batched:
return self._tokenize_single_prompt(prompt)
res = defaultdict(lambda: [])
feature_names = list(prompt.keys())
# Process each prompt individually
for row in zip(*prompt.values()):
tokenized_prompt = self._tokenize_single_prompt(
dict(zip(feature_names, row))
)
for key, val in tokenized_prompt.items():
for i in range(0, len(val), self.sequence_len):
res[key].append(val[i : i + self.sequence_len])
# If there are no examples left, return an empty dictionary
if not res:
return {}
return dict(res)
def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
# Old simple legacy behavior that works reliably.
if (
not self.roles_to_train
and not self.train_on_eos
and not self.prompter.message_field_training
and not self.prompter.message_field_training_detail
and not self.prompter.message_field_training # type: ignore
and not self.prompter.message_field_training_detail # type: ignore
):
turns = self.get_conversation_thread(prompt)
images = self.get_images(prompt)
prompt_ids = self.prompter.build_prompt(
prompt_ids = self.prompter.build_prompt( # type: ignore
turns[:-1],
add_generation_prompt=True,
images=images,
)
tokenized_res = self.prompter.build_prompt(turns, images=images)
tokenized_res = self.prompter.build_prompt(turns, images=images) # type: ignore
tokenized_prompt = {}
if isinstance(tokenized_res, list):
input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
@@ -256,7 +296,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return tokenized_prompt
turns = self.get_conversation_thread(prompt)
input_ids = self.prompter.build_prompt(turns)
input_ids = self.prompter.build_prompt(turns) # type: ignore
labels = [IGNORE_TOKEN_ID] * len(input_ids)
last_eos_idx = -1
@@ -286,7 +326,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
if should_train and turn_start_idx != -1 and turn_end_idx != -1:
if train_detail:
token_offsets = self.prompter.get_offsets_for_train_detail(
token_offsets = self.prompter.get_offsets_for_train_detail( # type: ignore
content, train_detail
)
LOG.debug(f"Token offsets: {token_offsets}")
@@ -459,43 +499,62 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return prompt.get(self.images, None)
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None):
# pylint: disable=duplicate-code
ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
class StrategyLoader:
"""
Load chat template strategy based on configuration.
"""
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail",
None,
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
"max_length": cfg.sequence_len + 1,
"processor": processor,
}
def _get_strategy_cls(self):
return ChatTemplateStrategy
strategy_params = {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
}
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
return {
"train_on_inputs": cfg.train_on_inputs,
"sequence_len": cfg.sequence_len,
"roles_to_train": ds_cfg.get("roles_to_train", ["assistant"]),
"train_on_eos": ds_cfg.get("train_on_eos", "turn"),
}
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(**prompter_params), tokenizer=tokenizer, **strategy_params
)
def __call__(
self, tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None, processor=None
):
# pylint: disable=duplicate-code
ds_cfg = ds_cfg or {}
chat_template_string = get_chat_template_from_config(
cfg=cfg, ds_cfg=ds_cfg, tokenizer=tokenizer
)
LOG.info(f"Using chat template:\n---\n{chat_template_string!s}\n---")
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
prompter_params = {
"tokenizer": tokenizer,
"chat_template": chat_template_string,
"message_field_role": ds_cfg.get("message_field_role", "role"),
"message_field_content": ds_cfg.get("message_field_content", "content"),
"message_field_training": ds_cfg.get("message_field_training", None),
"message_field_training_detail": ds_cfg.get(
"message_field_training_detail",
None,
),
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
# we need to add one for detecting sequences with exceeding the `sequence_len` limit.
"max_length": cfg.sequence_len + 1,
"processor": processor,
}
return strategy
strategy_params = self._get_strategy_params(cfg, ds_cfg)
strategy_cls = self._get_strategy_cls()
strategy = strategy_cls(
ChatTemplatePrompter(**prompter_params),
tokenizer=tokenizer,
**strategy_params,
)
if "field_messages" in ds_cfg and hasattr(strategy, "messages"):
strategy.messages = ds_cfg["field_messages"]
return strategy
load = StrategyLoader()

View File

@@ -3,22 +3,41 @@ DPO strategies for chatml
"""
def argilla(
def default(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
if "prompt" in sample.keys():
prompt_key = "prompt"
elif "input" in sample.keys():
prompt_key = "input"
elif "question" in sample.keys():
prompt_key = "question"
else:
prompt_key = "instruction"
if "chosen" in sample.keys():
chosen_key = "chosen"
else:
chosen_key = "chosen_response"
if "rejected" in sample.keys():
rejected_key = "rejected"
else:
rejected_key = "rejected_response"
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample['chosen_response']}<|im_end|>"
sample["rejected"] = f"{sample['rejected_response']}<|im_end|>"
] = f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
sample["chosen"] = f"{sample[chosen_key]}<|im_end|>"
sample["rejected"] = f"{sample[rejected_key]}<|im_end|>"
return sample
return transform_fn

View File

@@ -3,22 +3,42 @@ DPO strategies for llama-3 chat template
"""
def argilla(
def default(
cfg,
**kwargs,
): # pylint: disable=possibly-unused-variable,unused-argument
def transform_fn(sample):
# pylint: disable=duplicate-code
if "prompt" in sample.keys():
prompt_key = "prompt"
elif "input" in sample.keys():
prompt_key = "input"
elif "question" in sample.keys():
prompt_key = "question"
else:
prompt_key = "instruction"
if "chosen" in sample.keys():
chosen_key = "chosen"
else:
chosen_key = "chosen_response"
if "rejected" in sample.keys():
rejected_key = "rejected"
else:
rejected_key = "rejected_response"
if "system" in sample and sample["system"]:
sample["prompt"] = (
f"<|start_header_id|>system<|end_header_id|>\n\n{sample['system']}<|eot_id|>"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = f"{sample['chosen_response']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected_response']}<|eot_id|>"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>"
sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>"
return sample
return transform_fn

View File

@@ -2,7 +2,7 @@
import abc
import logging
from typing import Dict, List, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union
from transformers import BatchEncoding, PreTrainedTokenizer
@@ -34,6 +34,8 @@ class PromptTokenizingStrategy(abc.ABC):
Abstract class for tokenizing strategies
"""
filter_rows: Optional[Callable] = None
def __init__(
self,
prompter: Prompter,

View File

@@ -5,21 +5,19 @@ import os
import signal
import sys
import weakref
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, Union
from typing import Tuple, Union
import torch
import transformers.modelcard
from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model
from datasets import Dataset
from peft import PeftModel
from pkg_resources import get_distribution # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.cli import TrainerCliArgs
from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
@@ -39,22 +37,11 @@ src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = get_logger("axolotl.train")
@dataclass
class TrainDatasetMeta:
"""
dataclass to capture the dataset specific options for training
"""
train_dataset: Dataset
eval_dataset: Optional[Dataset] = None
total_num_steps: Optional[int] = None
LOG = get_logger(__name__)
def train(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
*, cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
# Load tokenizer
LOG.debug(
@@ -93,9 +80,7 @@ def train(
if cfg.adapter:
msg += " and peft_config..."
LOG.debug(msg)
model, peft_config = load_model(
cfg, tokenizer, processor=processor, inference=cli_args.inference
)
model, peft_config = load_model(cfg, tokenizer, processor=processor)
if model.generation_config is not None:
model.generation_config.do_sample = True
@@ -107,9 +92,7 @@ def train(
model_ref = None # explicit setting to None
else:
# load the model again for model_ref/baseline
model_ref, _ = load_model(
cfg, tokenizer, inference=cli_args.inference, reference_model=True
)
model_ref, _ = load_model(cfg, tokenizer, reference_model=True)
safe_serialization = cfg.save_safetensors is True

View File

@@ -43,7 +43,7 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
getattr, self.layers_attribute.split("."), self.trainer.model
)
LOG.info(
f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps"
f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers * 100 / len(layers)}%) every {self.step_interval} steps"
)
def freeze_all_layers(self):

View File

@@ -31,7 +31,6 @@ _CHAT_TEMPLATES = {
"qwen_25": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- message.content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
"exaone": "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '[|system|][|endofturn|]\n' }}{% endif %}{{ '[|' + message['role'] + '|]' + message['content'] }}{% if message['role'] == 'user' %}{{ '\n' }}{% else %}{{ '[|endofturn|]\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[|assistant|]' }}{% endif %}",
"metharme": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'Enter RP mode. You shall reply to the user while staying in character. Your responses must be detailed, creative, immersive, and drive the scenario forward.' %}{% endif %}{{ '<|system|>' + system_message }}{% for message in loop_messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|user|>' + content.strip() }}{% elif message['role'] == 'assistant' %}{{ '<|model|>' + content.strip() }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|model|>' }}{% else %}{{ eos_token }}{% endif %}",
"hymba": "{{'<extra_id_0>System'}}{% for message in messages %}{% if message['role'] == 'system' %}{{'\n' + message['content'].strip()}}{% if tools or contexts %}{{'\n'}}{% endif %}{% endif %}{% endfor %}{% if tools %}{% for tool in tools %}{{ '\n<tool> ' + tool|tojson + ' </tool>' }}{% endfor %}{% endif %}{% if contexts %}{% if tools %}{{'\n'}}{% endif %}{% for context in contexts %}{{ '\n<context> ' + context.strip() + ' </context>' }}{% endfor %}{% endif %}{{'\n\n'}}{% for message in messages %}{% if message['role'] == 'user' %}{{ '<extra_id_1>User\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'assistant' %}{{ '<extra_id_1>Assistant\n' + message['content'].strip() + '\n' }}{% elif message['role'] == 'tool' %}{{ '<extra_id_1>Tool\n' + message['content'].strip() + '\n' }}{% endif %}{% endfor %}{%- if add_generation_prompt %}{{'<extra_id_1>Assistant\n'}}{%- endif %}",
}

View File

@@ -128,6 +128,8 @@ class PretrainingDataset(BaseModel):
text_column: Optional[str] = "text"
type: Optional[str] = "pretrain"
trust_remote_code: Optional[bool] = False
data_files: Optional[str] = None
skip: Optional[int] = None
class UserDefinedPrompterType(BaseModel):
@@ -153,6 +155,7 @@ class SFTDataset(BaseModel):
type: Optional[Union[str, UserDefinedPrompterType]] = None
input_transform: Optional[str] = None
shards: Optional[int] = None
preprocess_shards: Optional[int] = None
conversation: Optional[str] = None
# Do not make this too strict or it will break the validator to choose different dataset class
chat_template: Optional[
@@ -175,6 +178,8 @@ class SFTDataset(BaseModel):
message_field_content: Optional[str] = None
message_field_training: Optional[str] = None
message_field_training_detail: Optional[str] = None
logprobs_field: Optional[str] = None
temperature: Optional[float] = None
roles_to_train: Optional[List[str]] = None
train_on_eos: Optional[str] = None
roles: Optional[Dict[str, List[str]]] = None
@@ -366,6 +371,13 @@ class LoraConfig(BaseModel):
loraplus_lr_embedding = float(loraplus_lr_embedding)
return loraplus_lr_embedding
@model_validator(mode="before")
@classmethod
def validate_lora_dropout(cls, data):
if data.get("adapter") is not None and data.get("lora_dropout") is None:
data["lora_dropout"] = 0.0
return data
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""
@@ -798,6 +810,7 @@ class AxolotlInputConfig(
# INTERNALS - document for now, generally not set externally
is_preprocess: Optional[bool] = None
preprocess_iterable: Optional[bool] = None
total_num_tokens: Optional[int] = None
total_supervised_tokens: Optional[int] = None
@@ -1629,19 +1642,3 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
else:
data["torch_compile"] = False
return data
@model_validator(mode="before")
@classmethod
def check_hymba_torch_version(cls, data):
if "hymba" in data.get("base_model", {}).lower():
env_capabilities = data.get("env_capabilities", {})
torch_version = env_capabilities.get("torch_version")
if torch_version is None:
import torch
torch_version = str(torch.__version__).split("+", maxsplit=1)[0]
if version.parse(torch_version) < version.parse("2.5.0"):
raise ValueError("Hymba requires torch version >= 2.5")
return data

View File

@@ -5,7 +5,7 @@ from axolotl.utils.data.pretraining import ( # noqa: F401
encode_pretraining,
wrap_pretraining_dataset,
)
from axolotl.utils.data.rl import load_prepare_dpo_datasets # noqa: F401
from axolotl.utils.data.rl import load_prepare_preference_datasets # noqa: F401
from axolotl.utils.data.sft import ( # noqa: F401
get_dataset_wrapper,
load_prepare_datasets,

View File

@@ -18,10 +18,13 @@ LOG = logging.getLogger("axolotl")
def encode_pretraining(
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
tokenizer: PreTrainedTokenizerBase,
max_tokens: int,
examples: Dict[str, List],
text_column: str = "text",
) -> Dict[str, List]:
res = tokenizer(
examples["text"],
examples[text_column],
truncation=True,
max_length=max_tokens - 2,
add_special_tokens=True,
@@ -196,7 +199,12 @@ def wrap_pretraining_dataset(
# set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1
else:
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
encode = functools.partial(
encode_pretraining,
tokenizer,
max_tokens,
text_column=cfg.pretraining_dataset[0].text_column or "text",
)
if cfg.shuffle_merged_datasets:
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)

View File

@@ -115,7 +115,7 @@ def drop_long_rl_seq(
raise ValueError("Unknown RL type")
def load_prepare_dpo_datasets(cfg):
def load_prepare_preference_datasets(cfg):
def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = []
for i, ds_cfg in enumerate(dataset_cfgs):

View File

@@ -3,11 +3,12 @@
import functools
import logging
from pathlib import Path
from typing import List, Tuple, Union
from typing import List, Optional, Tuple, Union
from datasets import (
Dataset,
DatasetDict,
IterableDataset,
concatenate_datasets,
load_dataset,
load_from_disk,
@@ -15,7 +16,7 @@ from datasets import (
from transformers import PreTrainedTokenizerBase
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.datasets import TokenizedPromptDataset
from axolotl.datasets import wrap_dataset_for_tokenized_prompt
from axolotl.prompt_strategies import load
from axolotl.prompt_strategies.bradley_terry import load as bradley_terry_load
from axolotl.prompt_tokenizers import (
@@ -57,7 +58,7 @@ LOG = logging.getLogger("axolotl")
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_dataset(cfg, tokenizer, processor=None):
def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
prompters = []
if not cfg.pretraining_dataset:
with zero_first(is_local_main_process()):
@@ -68,6 +69,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
DEFAULT_DATASET_PREPARED_PATH,
split="train",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
_, eval_dataset, _ = load_prepare_datasets(
tokenizer,
@@ -75,6 +77,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
DEFAULT_DATASET_PREPARED_PATH,
split="test",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
else:
train_dataset, eval_dataset, prompters = load_prepare_datasets(
@@ -82,20 +85,26 @@ def prepare_dataset(cfg, tokenizer, processor=None):
cfg,
DEFAULT_DATASET_PREPARED_PATH,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
else:
# Load streaming dataset if pretraining_dataset is given
path = cfg.pretraining_dataset
split = "train"
name = None
data_files = None
skip = 0
if isinstance(cfg.pretraining_dataset, list) and isinstance(
cfg.pretraining_dataset[0], dict
):
path = cfg.pretraining_dataset[0]["path"]
name = cfg.pretraining_dataset[0]["name"]
skip = cfg.pretraining_dataset[0]["skip"]
if "split" in cfg.pretraining_dataset[0]:
split = cfg.pretraining_dataset[0]["split"]
data_files = cfg.pretraining_dataset[0].get("data_files")
ds_wrapper_partial = functools.partial(
get_dataset_wrapper,
cfg.pretraining_dataset[0],
@@ -104,8 +113,14 @@ def prepare_dataset(cfg, tokenizer, processor=None):
cfg.pretraining_dataset[0]["type"] or "pretrain",
)
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
if skip:
LOG.info(f"Skipping {skip} samples from the dataset")
iter_ds = iter_ds.skip(skip)
train_dataset = wrap_pretraining_dataset(
load_dataset(path, streaming=True, split=split, name=name),
iter_ds,
tokenizer,
cfg,
ds_wrapper_partial,
@@ -126,6 +141,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
DEFAULT_DATASET_PREPARED_PATH,
split="test",
processor=processor,
preprocess_iterable=preprocess_iterable,
)
if cfg.dataset_exact_deduplication:
@@ -157,6 +173,7 @@ def load_tokenized_prepared_datasets(
default_dataset_prepared_path,
split="train",
processor=None,
preprocess_iterable: Optional[bool] = None,
) -> Tuple[DatasetDict, List[Prompter]]:
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
tokenizer_name = cfg.tokenizer_config
@@ -171,10 +188,11 @@ def load_tokenized_prepared_datasets(
+ "@"
+ str(cfg.group_by_length)
+ "@"
+ str(cfg.kd_temperature or 1.0)
+ "|".join(
sorted(
[
f"{d.path}:{d.type}:{d.shards}:{d.conversation}{d.split}"
f"{d.path}:{d.type}:{d.shards}:{d.conversation}:{d.split}:{d.temperature or 1.0}"
for d in cfg_datasets
]
)
@@ -249,13 +267,25 @@ def load_tokenized_prepared_datasets(
# at the same time for a given dataset
for name in dataset.name:
yield DictDefault({**dataset, "name": name})
elif dataset.preprocess_shards and not dataset.shards:
for shard in range(dataset.preprocess_shards):
yield DictDefault(
{
**dataset,
"shards": dataset.preprocess_shards,
"shards_idx": shard,
}
)
else:
yield dataset
streaming_ds = False
if preprocess_iterable:
streaming_ds = True
# pylint: disable=invalid-name
for config_dataset in for_d_in_datasets(cfg_datasets):
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
config_dataset, use_auth_token
config_dataset, use_auth_token, streaming=streaming_ds
)
d_base_type = d_prompt_style = None
@@ -312,7 +342,21 @@ def load_tokenized_prepared_datasets(
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(str(prepared_ds_path))
if isinstance(dataset, IterableDataset):
def gen_from_iter_ds(_ds, _=None):
yield from _ds
ds_from_iter = Dataset.from_generator(
functools.partial(gen_from_iter_ds, dataset),
features=dataset.features,
num_proc=cfg.dataset_processes,
split=split,
gen_kwargs={"_": list(range(cfg.dataset_processes))},
)
ds_from_iter.save_to_disk(str(prepared_ds_path))
else:
dataset.save_to_disk(str(prepared_ds_path))
if cfg.push_dataset_to_hub:
LOG.info(
f"Pushing merged prepared dataset to Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..."
@@ -332,6 +376,7 @@ def load_prepare_datasets(
default_dataset_prepared_path,
split="train",
processor=None,
preprocess_iterable: Optional[bool] = False,
) -> Tuple[Dataset, Dataset, List[Prompter]]:
dataset, prompters = load_tokenized_prepared_datasets(
tokenizer,
@@ -339,6 +384,7 @@ def load_prepare_datasets(
default_dataset_prepared_path,
split=split,
processor=processor,
preprocess_iterable=preprocess_iterable,
)
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
@@ -438,7 +484,7 @@ def get_dataset_wrapper(
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
)
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -451,7 +497,7 @@ def get_dataset_wrapper(
config_dataset.type.split(".", 1)[1], tokenizer, cfg, config_dataset
):
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -463,7 +509,7 @@ def get_dataset_wrapper(
dataset_wrapper = ds_strategy.wrap_dataset(dataset, **ds_kwargs)
else:
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(
dataset_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -476,7 +522,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -490,7 +536,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -504,7 +550,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -518,7 +564,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -532,7 +578,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -546,7 +592,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -560,7 +606,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,
@@ -574,7 +620,7 @@ def get_dataset_wrapper(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(
ds_wrapper = wrap_dataset_for_tokenized_prompt(
ds_strategy,
dataset,
**ds_kwargs,

View File

@@ -29,7 +29,9 @@ def get_ds_type(config_dataset: DictDefault):
return ds_type
def load_dataset_w_config(config_dataset, auth_token):
def load_dataset_w_config(
config_dataset, auth_token, streaming=False
) -> Union[Dataset, DatasetDict]:
# pylint: disable=invalid-name
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
ds_from_hub = False
@@ -117,7 +119,7 @@ def load_dataset_w_config(config_dataset, auth_token):
ds_type,
name=config_dataset.name,
data_files=config_dataset.data_files,
streaming=False,
streaming=streaming,
split=None,
)
else:
@@ -153,7 +155,7 @@ def load_dataset_w_config(config_dataset, auth_token):
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=False,
streaming=streaming,
data_files=config_dataset.data_files,
token=auth_token,
revision=config_dataset.revision,
@@ -172,7 +174,7 @@ def load_dataset_w_config(config_dataset, auth_token):
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
streaming=streaming,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
@@ -183,7 +185,7 @@ def load_dataset_w_config(config_dataset, auth_token):
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
streaming=streaming,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
@@ -213,7 +215,7 @@ def load_dataset_w_config(config_dataset, auth_token):
"json",
name=config_dataset.name,
data_files=fp,
streaming=False,
streaming=streaming,
split=None,
)
if not ds:

View File

@@ -270,7 +270,7 @@ def load_sharded_model_quant(
model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config)
if cfg.local_rank == 0 and verbose:
print(f"Loaded model weights in {time.time()-start:.3f} seconds")
print(f"Loaded model weights in {time.time() - start:.3f} seconds")
# cleanup any extra memory usage from parallel loading
torch.cuda.empty_cache()

View File

@@ -409,7 +409,6 @@ class ModelLoader:
and self.cfg.sample_packing
):
if "auto_map" in self.model_config:
# some model config objects are not subscriptable
try:
auto_map_config = self.model_config["auto_map"]
except TypeError:
@@ -1058,7 +1057,7 @@ class ModelLoader:
)
if (
hasattr(self.model, "get_input_embeddings")
and self.model.get_input_embeddings().num_embeddings < embeddings_len
and self.model.get_input_embeddings().num_embeddings != embeddings_len
):
resize_kwargs = {}
if self.cfg.mean_resizing_embeddings is not None:

View File

@@ -26,6 +26,7 @@ def check_example_labels(example, tokenizer, text_only=False):
# Get the input_ids, labels, and attention_mask from the dataset
input_ids = example["input_ids"]
labels = example["labels"]
target_mask = example.pop("target_mask", None)
# You can compare the input_ids and labels element-wise
# Remember to ignore positions with IGNORE_TOKEN_ID (if you use it) or attention_mask equal to 0
@@ -42,6 +43,13 @@ def check_example_labels(example, tokenizer, text_only=False):
delimiter = "" if text_only else " "
LOG.info(delimiter.join(colored_tokens))
LOG.info("\n\n\n")
target_labels_count = sum(label_id != -100 for label_id in labels)
total_len = len(input_ids)
LOG.info(f"Total input len: {total_len}")
LOG.info(f"Count of labels: {target_labels_count}")
if target_mask:
target_mask_positions = sum(m[0] for m in target_mask)
LOG.info(f"Number of positions in target_mask: {target_mask_positions}")
return " ".join(colored_tokens)

View File

@@ -11,7 +11,7 @@ import numpy as np
import torch
import torch.cuda
from accelerate.logging import get_logger
from datasets import disable_caching, enable_caching
from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available
@@ -95,9 +95,46 @@ def disable_datasets_caching():
def add_position_ids(sample):
sample_len = len(sample["input_ids"])
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
sample["length"] = sample_len
"""
Handle both single-example and batched data.
- single example: sample['input_ids'] is a list[int]
- batched data: sample['input_ids'] is a list[list[int]]
"""
if "input_ids" not in sample:
# If there's no "input_ids", just return sample unchanged
return sample
input_ids = sample["input_ids"]
# Detect if it's a single example or a batch
if not input_ids:
# Edge case: empty
return sample
# If first element is an int, its a single example
# If first element is a list, its a batch
if isinstance(input_ids[0], int):
# ---- SINGLE EXAMPLE ----
seq_len = len(input_ids)
# Position IDs for a single example
# As a list
sample["position_ids"] = list(range(seq_len))
sample["length"] = seq_len
else:
# ---- BATCHED EXAMPLES ----
# input_ids is a list of lists
position_ids_batch = []
lengths_batch = []
for seq in input_ids:
seq_len = len(seq)
position_ids_batch.append(list(range(seq_len)))
lengths_batch.append(seq_len)
# Now store them back
sample["position_ids"] = position_ids_batch
sample["length"] = lengths_batch
return sample
@@ -172,10 +209,31 @@ def add_length(sample):
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
return (
len(sample["input_ids"]) <= sequence_len
and len(sample["input_ids"]) >= min_sequence_len
)
"""
Drop samples whose sequence length is either too long (> sequence_len)
or too short (< min_sequence_len).
Works for both single-example (list[int]) or batched (list[list[int]]).
"""
input_ids = sample["input_ids"]
# Edge case: if input_ids is empty
if not input_ids:
# Decide if you want to drop or keep empty. Let's drop.
return False
# Check if single example or batched by looking at the first element
if isinstance(input_ids[0], int):
# Single example (input_ids is a list of int)
length = len(input_ids)
return min_sequence_len <= length <= sequence_len
# Batched (input_ids is a list of lists)
results = []
for seq in input_ids:
length = len(seq)
results.append(min_sequence_len <= length <= sequence_len)
return results
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
@@ -185,10 +243,13 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
min_sequence_len=cfg.min_sample_len or 2,
)
min_input_len = np.min(get_dataset_lengths(train_dataset))
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
try:
min_input_len = np.min(get_dataset_lengths(train_dataset))
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
except AttributeError:
pass
if cfg.model_config_type == "mamba":
LOG.info("dropping attention_mask column")
@@ -196,67 +257,116 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask")
if cfg.model_config_type == "falcon":
if cfg.model_config_type in ["falcon", "mistral"]:
LOG.info("dropping token_type_ids column if it exists")
if "token_type_ids" in train_dataset.column_names:
train_dataset = train_dataset.remove_columns("token_type_ids")
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns("token_type_ids")
prior_len = len(train_dataset)
filter_map_kwargs = {}
if not isinstance(train_dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_processes
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
try:
prior_len = len(train_dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Dropping Long Sequences"
train_dataset = train_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
**filter_map_kwargs,
**drop_long_kwargs,
)
dropped = prior_len - len(train_dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from train dataset")
if prior_len:
dropped = prior_len - len(train_dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from train dataset")
if eval_dataset:
prior_len = len(eval_dataset)
try:
prior_len = len(eval_dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
eval_dataset = eval_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
**filter_map_kwargs,
**drop_long_kwargs,
)
dropped = prior_len - len(eval_dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from eval dataset")
if prior_len:
dropped = prior_len - len(eval_dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from eval dataset")
# drop samples with where the number of elements with labels not equal to -100 is zero
def drop_no_trainable_tokens(sample):
return np.sum(np.array(sample["labels"]) != -100) > 0
"""
Drop samples if all labels are -100 (i.e., zero trainable tokens).
Works for both single-example or batched input.
"""
labels = sample["labels"]
if not labels:
# Edge case: if labels is empty, decide if you want to keep or drop
return True # or False
prior_len = len(train_dataset)
# Check if single example or batch
# If first element is an int, we assume a single example
# If it's a list, we assume we're dealing with a batch
if isinstance(labels[0], int):
# Single example: return a single bool
return np.sum(np.array(labels) != -100) > 0
# Batched: 'labels' is a list of lists
# Return a list of booleans, one per sub-list
results = []
for row_labels in labels:
# Each row_labels is a list[int]
results.append(np.sum(np.array(row_labels) != -100) > 0)
return results
try:
prior_len = len(train_dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
train_dataset = train_dataset.filter(
drop_no_trainable_tokens,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Drop Samples with Zero Trainable Tokens",
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
dropped = prior_len - len(train_dataset)
if dropped:
LOG.warning(
f"Dropped {dropped} samples with no trainable tokens from train dataset"
)
if eval_dataset:
prior_len = len(eval_dataset)
eval_dataset = eval_dataset.filter(
drop_no_trainable_tokens,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Drop Samples with Zero Trainable Tokens",
)
dropped = prior_len - len(eval_dataset)
if prior_len:
dropped = prior_len - len(train_dataset)
if dropped:
LOG.warning(
f"Dropped {dropped} samples with no trainable tokens from eval dataset"
f"Dropped {dropped} samples with no trainable tokens from train dataset"
)
if eval_dataset:
try:
prior_len = len(eval_dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
eval_dataset = eval_dataset.filter(
drop_no_trainable_tokens,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(eval_dataset)
if dropped:
LOG.warning(
f"Dropped {dropped} samples with no trainable tokens from eval dataset"
)
if cfg.group_by_length:
train_dataset = train_dataset.map(
add_length,
@@ -291,19 +401,21 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
desc="Add position_id column (PoSE)",
)
elif cfg.sample_packing:
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
train_dataset = train_dataset.map(
add_position_ids,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (Sample Packing)",
batched=True,
**filter_map_kwargs,
**drop_long_kwargs,
)
if cfg.eval_sample_packing is not False:
if eval_dataset:
eval_dataset = eval_dataset.map(
add_position_ids,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Add position_id column (Sample Packing)",
**filter_map_kwargs,
**drop_long_kwargs,
)
return train_dataset, eval_dataset
@@ -334,7 +446,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
and not cfg.reward_model
):
total_num_tokens = np.sum(
train_dataset.data.column("input_ids")
train_dataset.select_columns("input_ids")
.to_pandas()
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values

View File

@@ -1,4 +1,5 @@
"""Shared pytest fixtures for cli module."""
import pytest
from click.testing import CliRunner

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI fetch command."""
from unittest.mock import patch
from axolotl.cli.main import fetch

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI inference command."""
from unittest.mock import patch
from axolotl.cli.main import cli

View File

@@ -1,4 +1,5 @@
"""General pytest tests for axolotl.cli.main interface."""
from axolotl.cli.main import build_command, cli

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI merge_lora command."""
from unittest.mock import patch
from axolotl.cli.main import cli

View File

@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli
@@ -15,46 +16,3 @@ def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path):
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert result.exit_code == 0
def test_merge_sharded_fsdp_weights_with_model_dir(cli_runner, config_path, tmp_path):
"""Test merge_sharded_fsdp_weights command with model_dir option"""
model_dir = tmp_path / "model"
model_dir.mkdir()
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--no-accelerate",
"--model-dir",
str(model_dir),
],
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["model_dir"] == str(model_dir)
assert result.exit_code == 0
def test_merge_sharded_fsdp_weights_with_save_path(cli_runner, config_path):
"""Test merge_sharded_fsdp_weights command with save_path option"""
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--no-accelerate",
"--save-path",
"/path/to/save",
],
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["save_path"] == "/path/to/save"
assert result.exit_code == 0

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI preprocess command."""
import shutil
from pathlib import Path
from unittest.mock import patch

View File

@@ -1,76 +0,0 @@
"""pytest tests for axolotl CLI shard command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli
def test_shard_with_accelerate(cli_runner, config_path):
"""Test shard command with accelerate"""
with patch("subprocess.run") as mock:
result = cli_runner.invoke(cli, ["shard", str(config_path), "--accelerate"])
assert mock.called
assert mock.call_args.args[0] == [
"accelerate",
"launch",
"-m",
"axolotl.cli.shard",
str(config_path),
"--debug-num-examples",
"0",
]
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0
def test_shard_no_accelerate(cli_runner, config_path):
"""Test shard command without accelerate"""
with patch("axolotl.cli.shard.do_cli") as mock:
result = cli_runner.invoke(cli, ["shard", str(config_path), "--no-accelerate"])
assert mock.called
assert result.exit_code == 0
def test_shard_with_model_dir(cli_runner, config_path, tmp_path):
"""Test shard command with model_dir option"""
model_dir = tmp_path / "model"
model_dir.mkdir()
with patch("axolotl.cli.shard.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"shard",
str(config_path),
"--no-accelerate",
"--model-dir",
str(model_dir),
],
catch_exceptions=False,
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["model_dir"] == str(model_dir)
assert result.exit_code == 0
def test_shard_with_save_dir(cli_runner, config_path):
with patch("axolotl.cli.shard.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"shard",
str(config_path),
"--no-accelerate",
"--save-dir",
"/path/to/save",
],
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["save_dir"] == "/path/to/save"
assert result.exit_code == 0

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI --version"""
from axolotl.cli.main import cli

View File

@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI utils."""
# pylint: disable=redefined-outer-name
import json
from unittest.mock import Mock, patch

View File

@@ -120,13 +120,12 @@ def temp_dir():
@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
from transformers import Trainer
from transformers.models.llama.modeling_llama import (
from transformers.models.llama.modeling_llama import ( # LlamaFlashAttention2,
LlamaAttention,
LlamaFlashAttention2,
LlamaForCausalLM,
)
original_fa2_forward = LlamaFlashAttention2.forward
# original_fa2_forward = LlamaFlashAttention2.forward
original_llama_attn_forward = LlamaAttention.forward
original_llama_forward = LlamaForCausalLM.forward
original_trainer_inner_training_loop = (
@@ -136,7 +135,7 @@ def cleanup_monkeypatches():
# monkey patches can happen inside the tests
yield
# Reset LlamaFlashAttention2 forward
LlamaFlashAttention2.forward = original_fa2_forward
# LlamaFlashAttention2.forward = original_fa2_forward
LlamaAttention.forward = original_llama_attn_forward
LlamaForCausalLM.forward = original_llama_forward
Trainer._inner_training_loop = ( # pylint: disable=protected-access
@@ -149,7 +148,10 @@ def cleanup_monkeypatches():
("transformers.models.llama",),
(
"transformers.models.llama.modeling_llama",
["LlamaFlashAttention2", "LlamaAttention"],
[
# "LlamaFlashAttention2",
"LlamaAttention",
],
),
("transformers.trainer",),
("transformers", ["Trainer"]),

View File

@@ -2,17 +2,17 @@
Simple end-to-end test for Cut Cross Entropy integration
"""
from pathlib import Path
import pytest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils import get_pytorch_version
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists
# pylint: disable=duplicate-code
@@ -64,10 +64,10 @@ class TestCutCrossEntropyIntegration:
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
with pytest.raises(ImportError):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.parametrize(
"attention_type",
@@ -92,7 +92,7 @@ class TestCutCrossEntropyIntegration:
major, minor, _ = get_pytorch_version()
if (major, minor) < (2, 4):
with pytest.raises(ImportError):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -0,0 +1,121 @@
"""
e2e tests for kd trainer support in Axolotl
"""
from pathlib import Path
import pytest
from e2e.utils import check_tensorboard, require_torch_2_5_1
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="kd_min_cfg")
def min_cfg(temp_dir):
return {
"base_model": "osllmai-community/Llama-3.2-1B",
"tokenizer_config": "axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
"plugins": [
"axolotl.integrations.kd.KDPlugin",
"axolotl.integrations.liger.LigerPlugin",
],
"liger_rms_norm": True,
"liger_glu_activation": True,
"torch_compile": True,
"chat_template": "llama3",
"kd_trainer": True,
"kd_ce_alpha": 0.1,
"kd_alpha": 0.9,
"kd_temperature": 2.0,
"dataloader_prefetch_factor": 8,
"dataloader_num_workers": 4,
"dataloader_pin_memory": True,
"datasets": [
{
"path": "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample",
"type": "axolotl.integrations.kd.chat_template",
"field_messages": "messages_combined",
"split": "train",
"logprobs_field": "llm_text_generation_vllm_logprobs",
"temperature": 1.0,
"preprocess_shards": 2,
},
],
"val_set_size": 0.0,
"sequence_len": 2048,
"sample_packing": True,
"pad_to_sequence_len": True,
"gradient_accumulation_steps": 2,
"micro_batch_size": 1,
"num_epochs": 1,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"learning_rate": 0.00001,
"bf16": "auto",
"gradient_checkpointing": True,
"flash_attention": True,
"special_tokens": {
"pad_token": "<|end_of_text|>",
"eos_token": "<|eot_id|>",
},
"max_steps": 5,
"output_dir": temp_dir,
"save_safetensors": True,
"use_tensorboard": True,
}
class TestKnowledgeDistillation:
"""
Test case for Knowledge Distillation
"""
# While this will run on torch 2.4.x without torch_compile enabled
# the VRAM requirement is higher than what is available in CI
@require_torch_2_5_1
def test_llama_kd(self, temp_dir, kd_min_cfg):
cfg = DictDefault(kd_min_cfg)
# pylint: disable=duplicate-code
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
)
@pytest.mark.parametrize(
"load_in_8bit",
[True, False],
)
def test_llama_lora_kd(self, temp_dir, kd_min_cfg, load_in_8bit):
cfg = DictDefault(
{
"load_in_8bit": load_in_8bit,
"torch_compile": False,
"adapter": "lora",
"peft_use_dora": True,
"lora_target_linear": True,
"lora_r": 16,
"lora_alpha": 32,
"lora_dropout": 0.0,
}
| kd_min_cfg
)
# pylint: disable=duplicate-code
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.0, "Train Loss is too high"
)

View File

@@ -1,43 +1,41 @@
"""
Simple end-to-end test for Liger integration
"""
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from e2e.utils import require_torch_2_4_1
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
from ..utils import check_model_output_exists
class LigerIntegrationTestCase(unittest.TestCase):
class LigerIntegrationTestCase:
"""
e2e tests for liger integration with Axolotl
"""
@with_temp_dir
@require_torch_2_4_1
def test_llama_wo_flce(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"plugins": [
"axolotl.integrations.liger.LigerPlugin",
],
"liger_rope": True,
"liger_rms_norm": True,
"liger_swiglu": True,
"liger_glu_activation": True,
"liger_cross_entropy": True,
"liger_fused_linear_cross_entropy": False,
"sequence_len": 1024,
"val_set_size": 0.1,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
@@ -46,45 +44,44 @@ class LigerIntegrationTestCase(unittest.TestCase):
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"max_steps": 10,
"max_steps": 5,
}
)
# pylint: disable=duplicate-code
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
@require_torch_2_4_1
def test_llama_w_flce(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"plugins": [
"axolotl.integrations.liger.LigerPlugin",
],
"liger_rope": True,
"liger_rms_norm": True,
"liger_swiglu": True,
"liger_glu_activation": True,
"liger_cross_entropy": False,
"liger_fused_linear_cross_entropy": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
@@ -93,21 +90,22 @@ class LigerIntegrationTestCase(unittest.TestCase):
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"max_steps": 10,
"max_steps": 5,
}
)
# pylint: disable=duplicate-code
prepare_plugins(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,15 +5,14 @@ E2E tests for multipack fft llama using 4d attention masks
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import require_torch_2_3_1, with_temp_dir
from ..utils import check_model_output_exists, require_torch_2_3_1, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -66,8 +65,8 @@ class Test4dMultipackLlama(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_torch_lora_packing(self, temp_dir):
@@ -110,5 +109,5 @@ class Test4dMultipackLlama(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,7 +5,7 @@ from pathlib import Path
import yaml
from axolotl.cli import load_cfg
from axolotl.cli.config import load_cfg
from axolotl.utils.dict import DictDefault

View File

@@ -4,18 +4,17 @@ E2E tests for lora llama
import logging
import os
from pathlib import Path
import pytest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_tensorboard
from ..utils import check_model_output_exists, check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -81,8 +80,8 @@ class TestFAXentropyLlama:
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high"

View File

@@ -5,15 +5,14 @@ E2E tests for falcon
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -68,8 +67,8 @@ class TestFalconPatched(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_ft(self, temp_dir):
@@ -108,5 +107,5 @@ class TestFalconPatched(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,18 +5,17 @@ E2E tests for lora llama
import logging
import os
import unittest
from pathlib import Path
import pytest
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -72,5 +71,5 @@ class TestFusedLlama(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,17 +5,16 @@ E2E tests for llama w/ S2 attn
import logging
import os
import unittest
from pathlib import Path
import pytest
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -70,8 +69,8 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_fft_s2_attn(self, temp_dir):
@@ -110,5 +109,5 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,18 +5,17 @@ E2E tests for lora llama
import logging
import os
import unittest
from pathlib import Path
import pytest
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -75,8 +74,8 @@ class TestLoraLlama(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
@with_temp_dir
@@ -125,5 +124,5 @@ class TestLoraLlama(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,15 +5,14 @@ E2E tests for lora llama
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -68,8 +67,8 @@ class TestMistral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_ft_packing(self, temp_dir):
@@ -109,5 +108,5 @@ class TestMistral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -5,15 +5,14 @@ E2E tests for mixtral
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -65,8 +64,8 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_ft(self, temp_dir):
@@ -103,9 +102,9 @@ class TestMixtral(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
model, _ = train(cfg=cfg, dataset_meta=dataset_meta)
assert (
"MixtralFlashAttention2"
in model.model.layers[0].self_attn.__class__.__name__
)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,7 +6,6 @@ import unittest
import transformers
from axolotl.common.cli import TrainerCliArgs
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
@@ -49,9 +48,8 @@ class TestModelPatches(unittest.TestCase):
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
tokenizer = load_tokenizer(cfg)
model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
model, _ = load_model(cfg, tokenizer, inference=False)
assert (
"MixtralFlashAttention2"
@@ -87,9 +85,8 @@ class TestModelPatches(unittest.TestCase):
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer, inference=cli_args.inference)
load_model(cfg, tokenizer, inference=False)
assert (
"torch.jit"

View File

@@ -5,15 +5,14 @@ E2E tests for lora llama
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
from ..utils import check_model_output_exists, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -68,8 +67,8 @@ class TestPhiMultipack(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_qlora_packed(self, temp_dir):
@@ -119,5 +118,5 @@ class TestPhiMultipack(unittest.TestCase):
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -6,17 +6,16 @@ import logging
import os
import re
import subprocess
from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import most_recent_subdir
from ..utils import check_model_output_exists, most_recent_subdir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -72,7 +71,7 @@ class TestResumeLlama:
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
train(cfg=cfg, dataset_meta=dataset_meta)
resume_cfg = cfg | DictDefault(
{
@@ -82,8 +81,8 @@ class TestResumeLlama:
normalize_config(resume_cfg)
cli_args = TrainerCliArgs()
train(cfg=resume_cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
train(cfg=resume_cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
cmd = f"tensorboard --inspect --logdir {tb_log_path_1}"

Some files were not shown because too many files have changed in this diff Show More