Compare commits

..

18 Commits

Author SHA1 Message Date
Dan Saunders
156fede4f7 Update .pre-commit-config.yaml
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2025-03-21 10:36:18 -04:00
Dan Saunders
dcbbd7af79 sorry to revert, but pylint complained 2025-03-21 10:36:18 -04:00
Dan Saunders
21bac7ce1a running updated pre-commit plugins 2025-03-21 10:36:18 -04:00
Dan Saunders
aaa4571826 adding pre-commit auto-update GH action and bumping plugin versions 2025-03-21 10:36:17 -04:00
salman
187227d837 Fixing KTO+QLoRA+multi-GPU (#2420)
* WIP

* removing artifacts

* adding error

* adding adapter check

* linting

* simplifying check

* linting v2

* config fix -___-
2025-03-21 10:18:28 -04:00
NanoCode012
f8de8bb4f2 chore(doc): add instructions on adding custom integrations (#2422) [skip ci]
* chore(doc): add instructions on adding custom integrations

* chore: add warning help

* feat: add note about integration path

* fix: adjust text per suggestion
2025-03-21 10:18:01 -04:00
hugo
8e604848a4 add run on novita ai (#2421) [skip ci]
* add run on novita ai

* Revert "add run on novita ai"

This reverts commit 4d5df1ac6b.

* add run axolotl on novita ai
2025-03-21 10:17:47 -04:00
Wing Lian
aae4337f40 add 12.8.1 cuda to the base matrix (#2426)
* add 12.8.1 cuda to the base matrix

* use nightly

* bump deepspeed and set no binary

* deepspeed binary fixes hopefully

* install deepspeed by itself

* multiline fix

* make sure ninja is installed

* try with reversion of packaging/setuptools/wheel install

* use license instead of license-file

* try rolling back packaging and setuptools versions

* comment out license for validation for now

* make sure packaging version is consistent

* more parity across tests and docker images for packaging/setuptools
2025-03-21 10:17:25 -04:00
Wing Lian
38df5a36ea bump HF versions except for trl (#2427) 2025-03-20 10:22:05 -04:00
Wing Lian
4d92a68a96 use default torch fused adamw optimizer as default as adamw_hf is deprecated (#2425)
* use default torch fused adamw optimizer as default as adamw_hf is deprecated

* make sure to have latest packaging installed

* bump packagingin requirements.txt too
2025-03-19 23:58:33 -04:00
SicariusSicariiStuff
85147ec430 Update README.md (#2360)
* Update README.md

wheel is needed

* feat: add ninja, setuptools, packing to installation steps

* fix: add missing instruction

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-03-17 08:39:17 -04:00
NanoCode012
51cd409488 Feat: minor docs improvements for RLHF and faq on embeddings (#2401) [skip ci]
* feat: add doc on shrink_embeddings and custom calling

* chore: rename inference doc

* fix: clarify same config is used for all cli

* chore: rearrange order inference qmd

* feat: add simpo to doc

* fix: update defaults

* feat: add rl configs to doc

* fix: ensure beta consistent with trl.beta

* fix: clarify about lora/fft

* chore: rename title

* chore: fix language

* feat: move config reference higher

* Update docs/getting-started.qmd

Co-authored-by: salman <salman.mohammadi@outlook.com>

* Update docs/rlhf.qmd

Co-authored-by: salman <salman.mohammadi@outlook.com>

---------

Co-authored-by: salman <salman.mohammadi@outlook.com>
2025-03-17 08:39:04 -04:00
NanoCode012
7235123d44 chore(docs): add cookbook/blog link to docs (#2410) [skip ci] 2025-03-17 08:38:19 -04:00
Wing Lian
4f5eb42a73 remove reference to deprecated import (#2407) 2025-03-15 08:49:41 -04:00
Wing Lian
fbe54be6b8 only validate hf user token on rank 0 (#2408) 2025-03-13 23:29:06 -04:00
Wing Lian
04f6324833 build cloud images with torch 2.6.0 (#2413)
* build cloud images with torch 2.6.0

* nightlies too
2025-03-13 23:28:51 -04:00
Wing Lian
f0072f3b9d use max of 32 dataset processes if not explicit (#2403)
* use max of 32 dataset processes if not explicit

* change alternate min val for consistency
2025-03-11 12:02:58 -04:00
Wing Lian
59899b9817 pass additional info for fix untrained tokens when using distributed + offloading (#2388)
* pass additional info for fix untrained tokens when using distributed + offloading

* use latest version of vendored lib

* use v0.0.5 of contribs lgpl

* fix for no bad tokens and add tests

* use release

* add multigpu test too

* make sure the multigpu zero3 test actually uses zero3
2025-03-11 12:02:43 -04:00
162 changed files with 888 additions and 373 deletions

View File

@@ -40,6 +40,12 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
python_version: "3.11"
pytorch: nightly
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -61,7 +67,7 @@ jobs:
uses: docker/build-push-action@v4
with:
context: .
file: ./docker/Dockerfile-base
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || './docker/Dockerfile-base' }}
push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
labels: ${{ steps.metadata.outputs.labels }}

View File

@@ -88,6 +88,11 @@ jobs:
pytorch: 2.5.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -80,6 +80,11 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -0,0 +1,49 @@
name: Pre-commit auto-update
on:
schedule:
- cron: '0 0 * * 0' # Run weekly
workflow_dispatch: # Manual kickoff
jobs:
auto-update:
runs-on: ubuntu-latest
permissions:
contents: write
pull-requests: write
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: '3.11'
- name: Update pre-commit hooks
id: update
run: |
pip install pre-commit
pre-commit autoupdate
if [[ -n $(git status --porcelain) ]]; then
echo "changes=true" >> $GITHUB_OUTPUT
git diff .pre-commit-config.yaml > pre-commit-update.diff
fi
- name: Create Pull Request
if: steps.update.outputs.changes == 'true'
uses: peter-evans/create-pull-request@v6
with:
token: ${{ secrets.GITHUB_TOKEN }}
branch: update/pre-commit-hooks
delete-branch: true
title: "chore: update pre-commit hooks"
commit-message: "chore: update pre-commit hooks"
body: |
Automated PR to update pre-commit hooks to their latest versions.
<details>
<summary>Changes:</summary>
```diff
${{ steps.update.outputs.diff }}
```
</details>

View File

@@ -40,7 +40,7 @@ jobs:
- name: Install dependencies
run: |
pip3 install wheel packaging
pip3 install wheel packaging==23.2
pip3 install --no-build-isolation -e .
pip3 install -r requirements-dev.txt -r requirements-tests.txt

View File

@@ -42,7 +42,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools wheel
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
@@ -59,7 +59,7 @@ jobs:
- name: Install dependencies
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging
pip3 install --upgrade packaging==23.2
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh

View File

@@ -74,7 +74,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools wheel
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
@@ -147,7 +147,7 @@ jobs:
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools setuptools_scm build wheel
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel
- name: Install PyTorch
run: |

View File

@@ -3,7 +3,7 @@ default_language_version:
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.4.0
rev: v5.0.0
hooks:
- id: check-yaml
- id: end-of-file-fixer
@@ -11,23 +11,23 @@ repos:
- id: no-commit-to-branch
args: ['--branch', 'main']
- repo: https://github.com/psf/black
rev: 23.3.0
rev: 25.1.0
hooks:
- id: black
- repo: https://github.com/pycqa/isort
rev: 5.12.0
rev: 6.0.1
hooks:
- id: isort
- repo: https://github.com/PyCQA/flake8
rev: 6.1.0
rev: 7.1.2
hooks:
- id: flake8
- repo: https://github.com/PyCQA/pylint
rev: v3.3.0
- repo: https://github.com/pylint-dev/pylint
rev: v3.3.6
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.3.0
rev: v1.15.0
hooks:
- id: mypy
additional_dependencies:
@@ -36,7 +36,7 @@ repos:
'pydantic>=2.5.3',
]
- repo: https://github.com/PyCQA/bandit
rev: 1.7.5
rev: 1.8.3
hooks:
- id: bandit
args: [

View File

@@ -55,6 +55,7 @@ Features:
### Installation
```bash
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
# Download example axolotl configs, deepspeed configs

View File

@@ -32,8 +32,9 @@ website:
contents:
- docs/getting-started.qmd
- docs/installation.qmd
- docs/cli.qmd
- docs/inference.qmd
- docs/cli.qmd
- docs/config.qmd
- section: "Dataset Formats"
contents: docs/dataset-formats/*
@@ -74,10 +75,6 @@ website:
- docs/debugging.qmd
- docs/nccl.qmd
- section: "Reference"
contents:
- docs/config.qmd
format:
html:
theme: darkly

View File

@@ -31,6 +31,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
fi
RUN pip install packaging==23.2 setuptools==75.8.0
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
else \

View File

@@ -1,6 +1,7 @@
"""
modal application to run axolotl gpu tests in Modal
"""
modal application to run axolotl gpu tests in Modal
"""
# pylint: disable=duplicate-code
import os

View File

@@ -1,4 +1,5 @@
"""Modal app to run axolotl GPU tests"""
# pylint: disable=duplicate-code
import os

View File

@@ -28,7 +28,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"

View File

@@ -0,0 +1,39 @@
ARG CUDA_VERSION="12.8.1"
ARG CUDNN_VERSION="8"
ARG UBUNTU_VERSION="22.04"
ARG MAX_JOBS=4
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
ENV PATH="/root/miniconda3/bin:${PATH}"
ARG PYTHON_VERSION="3.11"
ARG PYTORCH_VERSION="nightly"
ARG CUDA="128"
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV PYTHON_VERSION=$PYTHON_VERSION
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
&& wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
WORKDIR /workspace
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
RUN git lfs install --skip-repo && \
pip3 install awscli && \
# The base image ships with `pydantic==1.8.2` which is not working
pip3 install -U --no-cache-dir pydantic==1.10.10

View File

@@ -1,5 +1,5 @@
---
title: Config options
title: Config Reference
description: A complete list of all configuration options.
---
@@ -30,6 +30,8 @@ tokenizer_legacy:
# Resize the model embeddings when new tokens are added to multiples of 32
# This is reported to improve training speed on some models
resize_token_embeddings_to_32x:
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
shrink_embeddings:
# (Internal use only)
# Used to identify which the model is based on
@@ -83,6 +85,12 @@ gpu_memory_limit: 20GiB
# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
lora_on_cpu: true
# List[str]. Add plugins to extend the pipeline.
# See `src/axolotl/integrations` for the available plugins or doc below for more details.
# https://axolotl-ai-cloud.github.io/axolotl/docs/custom_integrations.html
plugins:
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
# A list of one or more datasets to finetune the model with
datasets:
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
@@ -205,10 +213,46 @@ test_datasets:
data_files:
- /workspace/data/eval.jsonl
# use RL training: 'dpo', 'ipo', 'kto'
# use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'
rl:
# whether to perform weighting if doing DPO training. Boolean.
dpo_use_weighting:
rl_beta: # Optional[float]. The beta parameter for the RL training.
# dpo
dpo_use_weighting: # Optional[bool]. Whether to perform weighting.
rpo_alpha: # Optional[float]. Weighting of NLL term in loss from RPO paper.
# orpo
orpo_alpha: 0.1 # Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping.
# kto
kto_desirable_weight: # Optional[float]. Factor for desirable loss term in KTO loss.
kto_undesirable_weight: # Optional[float]. Factor for undesirable loss term in KTO loss.
# simpo
cpo_alpha: 1.0 # Weight of the BC regularizer
simpo_gamma: 0.5 # Target reward margin for the SimPO loss
# grpo
trl:
use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
vllm_device: # Optional[str]. Device to use for VLLM.
vllm_gpu_memory_utilization: # Optional[float]. GPU memory utilization for VLLM.
vllm_max_model_len: # Optional[int]. Maximum length of the model for VLLM.
vllm_dtype: # Optional[str]. Data type for VLLM.
beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
reward_funcs: # Optional[list[str]]. List of reward functions to load. Paths must be importable from current dir.
reward_weights: # Optional[list[float]]. List of reward weights for the reward functions.
num_generations: # Optional[int]. Number of generations to sample.
log_completions: # Optional[bool]. Whether to log completions.
sync_ref_model: # Optional[bool]. Whether to sync the reference model.
ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model.
ref_model_sync_steps: # Optional[int]. Sync steps for the reference model.
# reward modelling: `True` or `False`
reward_model:
@@ -232,7 +276,7 @@ default_system_message: You are a helpful assistant. Please give a long and deta
# subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared
# Push prepared dataset to hub
push_dataset_to_hub: # repo path
push_dataset_to_hub: # Optional[str] repo_org/repo_name
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# if not set.
dataset_processes: # defaults to os.cpu_count() if not set

View File

@@ -55,3 +55,47 @@ sections = [
for section_name, folder_name in sections:
print(print_section(section_name, folder_name))
```
## Adding a new integration
Plugins can be used to customize the behavior of the training pipeline through [hooks](https://en.wikipedia.org/wiki/Hooking). See [`axolotl.integrations.BasePlugin`](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/integrations/base.py) for the possible hooks.
To add a new integration, please follow these steps:
1. Create a new folder in the `src/axolotl/integrations` directory.
2. Add any relevant files (`LICENSE`, `README.md`, `ACKNOWLEDGEMENTS.md`, etc.) to the new folder.
3. Add `__init__.py` and `args.py` files to the new folder.
- `__init__.py` should import the integration and hook into the appropriate functions.
- `args.py` should define the arguments for the integration.
4. (If applicable) Add CPU tests under `tests/integrations` or GPU tests under `tests/e2e/integrations`.
::: {.callout-tip}
See [src/axolotl/integrations/cut_cross_entropy](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/cut_cross_entropy) for a minimal integration example.
:::
::: {.callout-warning}
If you could not load your integration, please ensure you are pip installing in editable mode.
```bash
pip install -e .
```
and correctly spelled the integration name in the config file.
```yaml
plugins:
- axolotl.integrations.your_integration_name.YourIntegrationPlugin
```
:::
::: {.callout-note}
It is not necessary to place your integration in the `integrations` folder. It can be in any location, so long as it's installed in a package in your python env.
See this repo for an example: [https://github.com/axolotl-ai-cloud/diff-transformer](https://github.com/axolotl-ai-cloud/diff-transformer)
:::

View File

@@ -27,6 +27,16 @@ description: Frequently asked questions
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
**Q: Received mismatch error on merge adapters / loading adapters between torch.Size of checkpoint and model.**
> A: This is likely due to vocab size mismatch. By default, Axolotl expands the model's embeddings if the tokenizer has more tokens than the model. Please use the `axolotl merge-lora` command to merge the adapters instead of using your own scripts.
> On the other hand, if the model has more tokens than the tokenizer, Axolotl does not shrink the model's embeddings unless `shrink_embeddings: true` is set in the config.
**Q: How to call Axolotl via custom python scripts?**
> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
### Chat templates
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**

View File

@@ -36,7 +36,9 @@ The YAML configuration file controls everything about your training. Here's what
```yaml
base_model: NousResearch/Llama-3.2-1B
# hub_model_id: username/custom_model_name
load_in_8bit: true
adapter: lora
datasets:
- path: teknium/GPT4-LLM-Cleaned
@@ -44,11 +46,15 @@ datasets:
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/lora-out
adapter: lora
lora_model_dir:
```
::: {.callout-tip}
`load_in_8bit: true` and `adapter: lora` enables LoRA adapter finetuning.
- To perform Full finetuning, remove these two lines.
- To perform QLoRA finetuning, replace with `load_in_4bit: true` and `adapter: qlora`.
:::
See our [Config options](config.qmd) for more details.
### Training {#sec-training}
@@ -56,7 +62,7 @@ See our [Config options](config.qmd) for more details.
When you run `axolotl train`, Axolotl:
1. Downloads the base model
2. (If specified) applies LoRA adapter layers
2. (If specified) applies QLoRA/LoRA adapter layers
3. Loads and processes the dataset
4. Runs the training loop
5. Saves the trained model and / or LoRA weights
@@ -69,6 +75,8 @@ Let's modify the example for your own data:
```yaml
base_model: NousResearch/Nous-Hermes-llama-1b-v1
load_in_8bit: true
adapter: lora
# Training settings
@@ -104,8 +112,6 @@ format):
{"instruction": "Classify this text", "input": "Not good at all", "output": "negative"}
```
Please consult the supported [Dataset Formats](dataset-formats/) for more details.
3. Run the training:
```bash

View File

@@ -1,5 +1,5 @@
---
title: "Inference"
title: "Inference and Merging"
format:
html:
toc: true
@@ -9,10 +9,14 @@ execute:
enabled: false
---
This guide covers how to use your trained models for inference, including model loading, interactive testing, and common troubleshooting steps.
This guide covers how to use your trained models for inference, including model loading, interactive testing, merging adapters, and common troubleshooting steps.
## Quick Start {#sec-quickstart}
::: {.callout-tip}
Use the same config used for training on inference/merging.
:::
### Basic Inference {#sec-basic}
::: {.panel-tabset}

View File

@@ -22,6 +22,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
### PyPI Installation (Recommended) {#sec-pypi}
```{.bash}
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
```
@@ -37,7 +38,7 @@ For the latest features between releases:
```{.bash}
git clone https://github.com/axolotl-ai-cloud/axolotl.git
cd axolotl
pip3 install packaging ninja
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
@@ -78,6 +79,7 @@ For providers supporting Docker:
- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
- [Novita](https://novita.ai/gpus-console?templateId=311)
### Google Colab {#sec-colab}
@@ -107,7 +109,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
2. Install PyTorch: https://pytorch.org/get-started/locally/
3. Install Axolotl:
```{.bash}
pip3 install packaging
pip3 install -U packaging setuptools wheel ninja
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
```
4. (Optional) Login to Hugging Face:

View File

@@ -66,6 +66,10 @@ logic to be compatible with more of them.
</details>
::: {.callout-tip}
Check out our [LoRA optimizations blog](https://axolotlai.substack.com/p/accelerating-lora-fine-tuning-with).
:::
## Usage
These optimizations can be enabled in your Axolotl config YAML file. The

View File

@@ -41,6 +41,10 @@ Bradley-Terry chat templates expect single-turn conversations in the following f
### Process Reward Models (PRM)
::: {.callout-tip}
Check out our [PRM blog](https://axolotlai.substack.com/p/process-reward-models).
:::
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
```yaml
base_model: Qwen/Qwen2.5-3B

View File

@@ -298,7 +298,7 @@ The input format is a simple JSON input with customizable fields based on the ab
### IPO
As IPO is just DPO with a different loss function, all supported options for DPO works here.
As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.
```yaml
rl: ipo
@@ -344,8 +344,9 @@ ORPO supports the following types with the following dataset format:
```yaml
rl: kto
rl_beta: 0.5
kto_desirable_weight: 0.2
rl_beta: 0.1 # default
kto_desirable_weight: 1.0 # default
kto_undesirable_weight: 1.0 # default
remove_unused_columns: false
@@ -497,6 +498,10 @@ The input format is a simple JSON input with customizable fields based on the ab
### GRPO
::: {.callout-tip}
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
:::
GRPO uses custom reward functions and transformations. Please have them ready locally.
For ex, to load OpenAI's GSM8K and use a random reward for completions:
@@ -540,6 +545,19 @@ To see other examples of custom reward functions, please see [TRL GRPO Docs](htt
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
### SimPO
SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with alternative loss function.
```yaml
rl: simpo
rl_beta: 0.1 # default in CPOTrainer
cpo_alpha: 1.0 # default in CPOTrainer
simpo_gamma: 0.5 # default in CPOTrainer
```
This method uses the same dataset format as [DPO](#dpo).
### Using local dataset files
```yaml

View File

@@ -55,7 +55,7 @@ tf32: true
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: true
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
local_rank:

View File

@@ -1,5 +1,5 @@
[build-system]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==23.2"]
build-backend = "setuptools.build_meta"
[project]
@@ -8,6 +8,7 @@ dynamic = ["version", "dependencies", "optional-dependencies"]
description = "LLM Trainer"
readme = "README.md"
requires-python = ">=3.10"
# license = "Apache-2.0"
[project.scripts]
axolotl = "axolotl.cli.main:main"

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.2
bitsandbytes==0.45.3
triton>=3.0.0
mamba-ssm==1.2.0.post1
flash-attn==2.7.4.post1
@@ -12,12 +12,12 @@ liger-kernel==0.5.3
packaging==23.2
peft==0.14.0
peft==0.15.0
transformers==4.49.0
tokenizers>=0.21.0
accelerate==1.3.0
datasets==3.2.0
deepspeed==0.16.1
tokenizers>=0.21.1
accelerate==1.5.2
datasets==3.4.1
deepspeed==0.16.4
trl==0.15.1
optimum==1.16.2
@@ -62,5 +62,5 @@ antlr4-python3-runtime==4.13.2
torchao==0.7.0
schedulefree==1.3.0
axolotl-contribs-lgpl @ git+https://github.com/axolotl-ai-cloud/axolotl-contribs-lgpl.git@import-issues-v2
axolotl-contribs-lgpl==0.0.6
axolotl-contribs-mit==0.0.3

View File

@@ -1,6 +1,7 @@
"""
helper script to parse chat datasets into a usable yaml
"""
import click
import yaml
from datasets import load_dataset

View File

@@ -1,4 +1,5 @@
"""Script to output the correct installation command for cut-cross-entropy."""
import importlib.util
import sys

View File

@@ -128,7 +128,7 @@ setup(
"flash-attn==2.7.4.post1",
],
"deepspeed": [
"deepspeed==0.16.1",
"deepspeed==0.16.4",
"deepspeed-kernels",
],
"mamba-ssm": [

View File

@@ -1,6 +1,7 @@
"""
launch axolotl in supported cloud platforms
"""
from pathlib import Path
from typing import Union

View File

@@ -1,6 +1,7 @@
"""
base class for cloud platforms from cli
"""
from abc import ABC, abstractmethod

View File

@@ -1,6 +1,7 @@
"""
Modal Cloud support from CLI
"""
import copy
import json
import os

View File

@@ -1,4 +1,5 @@
"""Click CLI definitions for various axolotl commands."""
# pylint: disable=redefined-outer-name
import logging

View File

@@ -1,6 +1,7 @@
"""CLI to run training on a model."""
import logging
import os
from pathlib import Path
from typing import Union
@@ -34,7 +35,8 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
"""
print_axolotl_text_art()
check_accelerate_default_config()
check_user_token()
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -5,7 +5,6 @@ import dataclasses
import hashlib
import json
import logging
import typing
from functools import wraps
from pathlib import Path
from types import NoneType
@@ -24,7 +23,7 @@ configure_logging()
LOG = logging.getLogger(__name__)
def strip_optional_type(field_type: type | typing._SpecialForm | None):
def strip_optional_type(field_type: type | str | None):
"""
Extracts the non-`None` type from an `Optional` / `Union` type.

View File

@@ -1,6 +1,5 @@
"""Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes"""
import json
import sys

View File

@@ -1,6 +1,7 @@
"""
ChatML transformation functions for MessageContents
"""
from typing import Optional
from ..messages import MessageContents, Messages

View File

@@ -1,6 +1,7 @@
"""
Llama 3.x chat formatting functions for MessageContents
"""
from typing import Optional
from ..messages import MessageContents, Messages

View File

@@ -1,6 +1,7 @@
"""
shared functions for format transforms
"""
from axolotl.core.chat.messages import MessageContents, Messages

View File

@@ -1,6 +1,7 @@
"""
internal message representations of chat messages
"""
import json
from enum import Enum
from typing import Any, Callable, List, Optional, Union

View File

@@ -1,6 +1,7 @@
"""
chat dataset module
"""
import os
from typing import Callable, Optional, Union
@@ -43,7 +44,7 @@ class TokenizedChatDataset(Dataset):
process_or_cpu_count: int = (
process_count or os.cpu_count() # type: ignore[assignment]
)
num_proc = min(64, process_or_cpu_count)
num_proc = min(32, process_or_cpu_count)
features = data.features.keys()
tokenized_data = data.map(
map_fn,

View File

@@ -1,6 +1,7 @@
"""
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
"""
from typing import Any, Mapping, Union

View File

@@ -332,9 +332,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs = {}
if self.cfg.include_tokens_per_second is not None:
training_arguments_kwargs[
"include_tokens_per_second"
] = self.cfg.include_tokens_per_second
training_arguments_kwargs["include_tokens_per_second"] = (
self.cfg.include_tokens_per_second
)
if self.cfg.bf16 == "full":
training_arguments_kwargs["bf16_full_eval"] = True
@@ -351,13 +351,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["seed"] = self.cfg.seed
if self.cfg.gradient_checkpointing:
training_arguments_kwargs[
"gradient_checkpointing"
] = self.cfg.gradient_checkpointing
training_arguments_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
)
if self.cfg.gradient_checkpointing_kwargs is not None:
training_arguments_kwargs[
"gradient_checkpointing_kwargs"
] = self.cfg.gradient_checkpointing_kwargs
training_arguments_kwargs["gradient_checkpointing_kwargs"] = (
self.cfg.gradient_checkpointing_kwargs
)
if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config:
@@ -373,9 +373,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
if self.cfg.lr_quadratic_warmup is not None:
training_arguments_kwargs[
"lr_quadratic_warmup"
] = self.cfg.lr_quadratic_warmup
training_arguments_kwargs["lr_quadratic_warmup"] = (
self.cfg.lr_quadratic_warmup
)
if self.cfg.adam_beta1:
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
@@ -399,28 +399,28 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
if self.cfg.dataloader_pin_memory is not None:
training_arguments_kwargs[
"dataloader_pin_memory"
] = self.cfg.dataloader_pin_memory
training_arguments_kwargs["dataloader_pin_memory"] = (
self.cfg.dataloader_pin_memory
)
if self.cfg.dataloader_num_workers is not None:
training_arguments_kwargs[
"dataloader_num_workers"
] = self.cfg.dataloader_num_workers
training_arguments_kwargs["dataloader_num_workers"] = (
self.cfg.dataloader_num_workers
)
if self.cfg.dataloader_prefetch_factor is not None:
training_arguments_kwargs[
"dataloader_prefetch_factor"
] = self.cfg.dataloader_prefetch_factor
training_arguments_kwargs["dataloader_prefetch_factor"] = (
self.cfg.dataloader_prefetch_factor
)
if self.cfg.dataloader_drop_last is not None:
training_arguments_kwargs[
"dataloader_drop_last"
] = self.cfg.dataloader_drop_last
training_arguments_kwargs["dataloader_drop_last"] = (
self.cfg.dataloader_drop_last
)
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
training_arguments_kwargs["dataloader_drop_last"] = True
if self.cfg.remove_unused_columns is not None:
training_arguments_kwargs[
"remove_unused_columns"
] = self.cfg.remove_unused_columns
training_arguments_kwargs["remove_unused_columns"] = (
self.cfg.remove_unused_columns
)
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
# no eval set, so don't eval
@@ -452,9 +452,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.do_causal_lm_eval:
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
if self.cfg.metric_for_best_model:
training_arguments_kwargs[
"metric_for_best_model"
] = self.cfg.metric_for_best_model
training_arguments_kwargs["metric_for_best_model"] = (
self.cfg.metric_for_best_model
)
if self.cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
@@ -467,13 +467,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
)
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend:
training_arguments_kwargs[
"torch_compile_backend"
] = self.cfg.torch_compile_backend
training_arguments_kwargs["torch_compile_backend"] = (
self.cfg.torch_compile_backend
)
if self.cfg.torch_compile_mode:
training_arguments_kwargs[
"torch_compile_mode"
] = self.cfg.torch_compile_mode
training_arguments_kwargs["torch_compile_mode"] = (
self.cfg.torch_compile_mode
)
# DDP Config
if self.cfg.ddp_timeout:
@@ -482,32 +482,32 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.ddp_bucket_cap_mb:
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
if self.cfg.ddp_broadcast_buffers is not None:
training_arguments_kwargs[
"ddp_broadcast_buffers"
] = self.cfg.ddp_broadcast_buffers
training_arguments_kwargs["ddp_broadcast_buffers"] = (
self.cfg.ddp_broadcast_buffers
)
# these are all the "standard" kwargs that are def used
training_arguments_kwargs["max_steps"] = (
total_num_steps if self.cfg.max_steps else -1
)
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
training_arguments_kwargs[
"per_device_train_batch_size"
] = self.cfg.micro_batch_size
training_arguments_kwargs["per_device_train_batch_size"] = (
self.cfg.micro_batch_size
)
if self.cfg.eval_batch_size:
training_arguments_kwargs[
"per_device_eval_batch_size"
] = self.cfg.eval_batch_size
training_arguments_kwargs["per_device_eval_batch_size"] = (
self.cfg.eval_batch_size
)
if self.cfg.auto_find_batch_size is not None:
training_arguments_kwargs[
"auto_find_batch_size"
] = self.cfg.auto_find_batch_size
training_arguments_kwargs[
"gradient_accumulation_steps"
] = self.cfg.gradient_accumulation_steps
training_arguments_kwargs[
"eval_accumulation_steps"
] = self.cfg.gradient_accumulation_steps
training_arguments_kwargs["auto_find_batch_size"] = (
self.cfg.auto_find_batch_size
)
training_arguments_kwargs["gradient_accumulation_steps"] = (
self.cfg.gradient_accumulation_steps
)
training_arguments_kwargs["eval_accumulation_steps"] = (
self.cfg.gradient_accumulation_steps
)
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
training_arguments_kwargs["output_dir"] = self.cfg.output_dir
@@ -554,9 +554,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
training_arguments_kwargs[
"alternate_lr_scheduler_type"
] = self.cfg.lr_scheduler
training_arguments_kwargs["alternate_lr_scheduler_type"] = (
self.cfg.lr_scheduler
)
else:
training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
@@ -565,9 +565,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
training_arguments_kwargs[
"cosine_constant_lr_ratio"
] = self.cfg.cosine_constant_lr_ratio
training_arguments_kwargs["cosine_constant_lr_ratio"] = (
self.cfg.cosine_constant_lr_ratio
)
training_arguments_kwargs["weight_decay"] = (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
)
@@ -580,40 +580,40 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self.cfg.eval_sample_packing
)
if self.cfg.sample_packing_bin_size is not None:
training_arguments_kwargs[
"sample_packing_bin_size"
] = self.cfg.sample_packing_bin_size
training_arguments_kwargs["sample_packing_bin_size"] = (
self.cfg.sample_packing_bin_size
)
if self.cfg.sample_packing_group_size is not None:
training_arguments_kwargs[
"sample_packing_group_size"
] = self.cfg.sample_packing_group_size
training_arguments_kwargs["sample_packing_group_size"] = (
self.cfg.sample_packing_group_size
)
if self.cfg.sample_packing_eff_est:
training_arguments_kwargs[
"sample_packing_efficiency"
] = self.cfg.sample_packing_eff_est
training_arguments_kwargs["sample_packing_efficiency"] = (
self.cfg.sample_packing_eff_est
)
if self.cfg.relora_steps:
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs[
"relora_warmup_steps"
] = self.cfg.relora_warmup_steps
training_arguments_kwargs["relora_warmup_steps"] = (
self.cfg.relora_warmup_steps
)
if self.cfg.relora_anneal_steps:
training_arguments_kwargs[
"relora_anneal_steps"
] = self.cfg.relora_anneal_steps
training_arguments_kwargs["relora_anneal_steps"] = (
self.cfg.relora_anneal_steps
)
if self.cfg.relora_prune_ratio:
training_arguments_kwargs[
"relora_prune_ratio"
] = self.cfg.relora_prune_ratio
training_arguments_kwargs["relora_prune_ratio"] = (
self.cfg.relora_prune_ratio
)
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
training_arguments_kwargs[
"lisa_step_interval"
] = self.cfg.lisa_step_interval
training_arguments_kwargs[
"lisa_layers_attribute"
] = self.cfg.lisa_layers_attribute
training_arguments_kwargs["lisa_step_interval"] = (
self.cfg.lisa_step_interval
)
training_arguments_kwargs["lisa_layers_attribute"] = (
self.cfg.lisa_layers_attribute
)
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
@@ -627,9 +627,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
)
if self.cfg.neftune_noise_alpha is not None:
training_arguments_kwargs[
"neftune_noise_alpha"
] = self.cfg.neftune_noise_alpha
training_arguments_kwargs["neftune_noise_alpha"] = (
self.cfg.neftune_noise_alpha
)
trainer_kwargs = {}
@@ -731,23 +731,23 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
importlib.import_module("torchdistx")
if self.cfg.optim_target_modules:
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
training_arguments_kwargs["optim_target_modules"] = (
self.cfg.optim_target_modules
)
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["loraplus_lr_embedding"] = (
self.cfg.loraplus_lr_embedding
)
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.accelerator_config:
training_arguments_kwargs[
"accelerator_config"
] = self.cfg.accelerator_config
training_arguments_kwargs["accelerator_config"] = (
self.cfg.accelerator_config
)
if self.cfg.kd_ce_alpha is not None:
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
@@ -756,13 +756,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.kd_temperature is not None:
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
if self.cfg.kd_zscore_base_temp is not None:
training_arguments_kwargs[
"kd_zscore_base_temp"
] = self.cfg.kd_zscore_base_temp
training_arguments_kwargs["kd_zscore_base_temp"] = (
self.cfg.kd_zscore_base_temp
)
if self.cfg.kd_top_k_before_softmax is not None:
training_arguments_kwargs[
"kd_top_k_before_softmax"
] = self.cfg.kd_top_k_before_softmax
training_arguments_kwargs["kd_top_k_before_softmax"] = (
self.cfg.kd_top_k_before_softmax
)
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
@@ -972,32 +972,32 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
)
if self.cfg.remove_unused_columns is not None:
training_args_kwargs[
"remove_unused_columns"
] = self.cfg.remove_unused_columns
training_args_kwargs["remove_unused_columns"] = (
self.cfg.remove_unused_columns
)
else:
training_args_kwargs["remove_unused_columns"] = False
if self.cfg.dataloader_pin_memory is not None:
training_args_kwargs[
"dataloader_pin_memory"
] = self.cfg.dataloader_pin_memory
training_args_kwargs["dataloader_pin_memory"] = (
self.cfg.dataloader_pin_memory
)
if self.cfg.dataloader_num_workers is not None:
training_args_kwargs[
"dataloader_num_workers"
] = self.cfg.dataloader_num_workers
training_args_kwargs["dataloader_num_workers"] = (
self.cfg.dataloader_num_workers
)
if self.cfg.dataloader_prefetch_factor is not None:
training_args_kwargs[
"dataloader_prefetch_factor"
] = self.cfg.dataloader_prefetch_factor
training_args_kwargs["dataloader_prefetch_factor"] = (
self.cfg.dataloader_prefetch_factor
)
if self.cfg.gradient_checkpointing:
training_args_kwargs[
"gradient_checkpointing"
] = self.cfg.gradient_checkpointing
training_args_kwargs["gradient_checkpointing"] = (
self.cfg.gradient_checkpointing
)
if self.cfg.gradient_checkpointing_kwargs is not None:
training_args_kwargs[
"gradient_checkpointing_kwargs"
] = self.cfg.gradient_checkpointing_kwargs
training_args_kwargs["gradient_checkpointing_kwargs"] = (
self.cfg.gradient_checkpointing_kwargs
)
else:
training_args_kwargs["gradient_checkpointing_kwargs"] = {
"use_reentrant": False
@@ -1071,9 +1071,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
if self.cfg.dpo_use_logits_to_keep is not None:
training_args_kwargs[
"use_logits_to_keep"
] = self.cfg.dpo_use_logits_to_keep
training_args_kwargs["use_logits_to_keep"] = (
self.cfg.dpo_use_logits_to_keep
)
for blocklist_key in blocklist_args_kwargs:
if blocklist_key in training_args_kwargs:
@@ -1108,9 +1108,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.adapter and self.peft_config:
dpo_trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:
dpo_trainer_kwargs[
"precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs
dpo_trainer_kwargs["precompute_ref_log_probs"] = (
self.cfg.precompute_ref_log_probs
)
if self.cfg.rl == "grpo":
trainer_cls = GRPOStrategy.get_trainer_class()
trainer_cls_args = [self.model]

View File

@@ -462,9 +462,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
dataloader_params["prefetch_factor"] = (
self.args.dataloader_prefetch_factor
)
sampler = self._get_train_sampler()
if isinstance(sampler, BatchSampler):
@@ -509,9 +509,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
dataloader_params["prefetch_factor"] = (
self.args.dataloader_prefetch_factor
)
if isinstance(eval_sampler, BatchSampler):
dataloader_params["batch_sampler"] = eval_sampler

View File

@@ -1,6 +1,7 @@
"""
DPO Specific Strategy for training
"""
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer

View File

@@ -1,6 +1,7 @@
"""
Axolotl specific DPO args
"""
from dataclasses import dataclass
from trl import DPOConfig

View File

@@ -1,6 +1,7 @@
"""
DPO trainer for axolotl
"""
import gc
from functools import wraps
from typing import Any, Dict, Union

View File

@@ -45,9 +45,9 @@ class GRPOStrategy:
)
if trl.vllm_gpu_memory_utilization:
grpo_args_kwargs[
"vllm_gpu_memory_utilization"
] = trl.vllm_gpu_memory_utilization
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
trl.vllm_gpu_memory_utilization
)
if trl.vllm_max_model_len:
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
@@ -86,9 +86,9 @@ class GRPOStrategy:
def set_trainer_kwargs(cls, cfg):
trainer_kwargs = {}
if cfg.trl and cfg.trl.reward_processing_classes:
trainer_kwargs[
"reward_processing_classes"
] = cfg.trl.reward_processing_classes
trainer_kwargs["reward_processing_classes"] = (
cfg.trl.reward_processing_classes
)
return trainer_kwargs
@classmethod

View File

@@ -1,6 +1,7 @@
"""
Axolotl Specific Training Args
"""
from dataclasses import dataclass
from trl import GRPOConfig

View File

@@ -1,6 +1,7 @@
"""
Axolotl GRPO trainer
"""
from accelerate.utils import is_peft_model
from accelerate.utils.other import is_compiled_module
from transformers import PreTrainedModel

View File

@@ -1,6 +1,7 @@
"""
module for TRL PPO training
"""
import torch
from tqdm import tqdm
from trl import PPOTrainer

View File

@@ -1,6 +1,7 @@
"""
extra axolotl specific training args
"""
from dataclasses import dataclass, field
from typing import Optional

View File

@@ -1,6 +1,7 @@
"""
Grokfast plugin for Axolotl
"""
import logging
from transformers.trainer_callback import TrainerCallback

View File

@@ -1,6 +1,7 @@
"""
config args for grokfast plugin
"""
from typing import Optional
from pydantic import BaseModel

View File

@@ -26,12 +26,12 @@ class KDArgs(BaseModel):
"""
kd_trainer: Optional[bool] = None # whether to use KD trainer
kd_ce_alpha: Optional[
float
] = None # loss coefficient for cross-entropy loss during KD
kd_ce_alpha: Optional[float] = (
None # loss coefficient for cross-entropy loss during KD
)
kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_temperature: Optional[float] = None # temperature for sampling during KD
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
kd_top_k_before_softmax: Optional[
bool
] = None # whether to sample top k before softmax during KD
kd_top_k_before_softmax: Optional[bool] = (
None # whether to sample top k before softmax during KD
)

View File

@@ -55,9 +55,9 @@ class LigerPlugin(BasePlugin):
if "cross_entropy" in liger_fn_sig.parameters:
kwargs["cross_entropy"] = cfg.liger_cross_entropy
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
kwargs[
"fused_linear_cross_entropy"
] = cfg.liger_fused_linear_cross_entropy
kwargs["fused_linear_cross_entropy"] = (
cfg.liger_fused_linear_cross_entropy
)
if "rms_norm" in liger_fn_sig.parameters:
kwargs["rms_norm"] = cfg.liger_rms_norm
if "layer_norm" in liger_fn_sig.parameters:

View File

@@ -1,6 +1,7 @@
"""
DeepseekV2 model with LigerFusedLinearCrossEntropyLoss
"""
# pylint: disable=duplicate-code
from typing import List, Optional, Tuple, Union

View File

@@ -1,6 +1,7 @@
"""
Jamba model with LigerFusedLinearCrossEntropyLoss
"""
# pylint: disable=duplicate-code
from typing import Optional, Tuple, Union

View File

@@ -1,6 +1,7 @@
"""
Module for the Plugin for LM Eval Harness
"""
import subprocess # nosec
from axolotl.integrations.base import BasePlugin

View File

@@ -1,6 +1,7 @@
"""
Module for handling lm eval harness input arguments.
"""
from typing import List, Optional
from pydantic import BaseModel

View File

@@ -1,6 +1,7 @@
"""
axolotl CLI for running lm_eval tasks
"""
import subprocess # nosec
from collections import defaultdict
from datetime import datetime

View File

@@ -5,6 +5,7 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
# pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code
import torch

View File

@@ -6,6 +6,7 @@ See "LoRA: Low-Rank Adaptation of Large Language Models"
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
# pylint: disable=invalid-name
from typing import Callable

View File

@@ -1,4 +1,5 @@
"""Dequantization utilities for `bitsandbytes` integration."""
# pylint: disable=invalid-name,global-statement
import ctypes

View File

@@ -5,6 +5,7 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
"""
import torch
import triton
import triton.language as tl

View File

@@ -1,6 +1,7 @@
"""
HF Transformers MambaConfig
"""
from transformers import PretrainedConfig

View File

@@ -1,6 +1,7 @@
"""
Monkeypatch for Vision Llama for FA2 support
"""
# pylint: disable=duplicate-code
from typing import Optional, Tuple
@@ -220,10 +221,10 @@ def patch_mllama():
True
)
MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[
"flash_attention_2"
] = MllamaTextCrossFlashAttention2
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES["flash_attention_2"] = (
MllamaTextCrossFlashAttention2
)
# fallback to SDPA
MLLAMA_VISION_ATTENTION_CLASSES[
"flash_attention_2"
] = MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]
MLLAMA_VISION_ATTENTION_CLASSES["flash_attention_2"] = (
MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]
)

View File

@@ -1,4 +1,5 @@
"""monkey patches for the dataset fetcher to handle batches of packed indexes"""
# pylint: disable=protected-access
import torch

View File

@@ -12,7 +12,9 @@ import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.models.llama.modeling_llama import (
LlamaAttention,
)
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
)
@@ -490,9 +492,11 @@ def flashattn_forward(
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
query_padding_mask=(
attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None
),
)
output_unpad = flash_attn_varlen_qkvpacked_func(
qkv_unpad,
@@ -531,9 +535,11 @@ def flashattn_forward(
value_states,
kvpacked=True,
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
query_padding_mask=(
attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None
),
)
if q_unpad.dtype != kv_unpad.dtype:
kv_unpad = kv_unpad.to(q_unpad.dtype)

View File

@@ -1,6 +1,7 @@
"""
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
"""
from typing import Optional
import torch

View File

@@ -1,4 +1,5 @@
"""Flash attention monkey patch for mistral model"""
# pylint: disable=duplicate-code
import logging
@@ -21,7 +22,10 @@ from transformers.models.mistral.modeling_mistral import (
from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OriginalMistralDecoderLayer,
)
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
from transformers.models.mistral.modeling_mistral import (
apply_rotary_pos_emb,
repeat_kv,
)
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
@@ -243,9 +247,11 @@ def flashattn_forward(
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
query_padding_mask=(
attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None
),
)
output_unpad = flash_attn_varlen_qkvpacked_func(
qkv_unpad,
@@ -286,9 +292,11 @@ def flashattn_forward(
value_states,
kvpacked=True,
key_padding_mask=attention_mask,
query_padding_mask=attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None,
query_padding_mask=(
attention_mask[:, -query_states.size(1) :]
if attention_mask is not None
else None
),
)
if q_unpad.dtype != kv_unpad.dtype:
kv_unpad = kv_unpad.to(q_unpad.dtype)

View File

@@ -1,6 +1,7 @@
"""
Patches to support multipack for mixtral
"""
import torch

View File

@@ -1,4 +1,5 @@
"""Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune."""
import glob
import json
import logging
@@ -411,7 +412,10 @@ def merge_and_save(
if shard_path.endswith(".safetensors"):
in_tensors = st.load_file(str(Path(model_src) / shard_path))
else:
in_tensors = torch.load(Path(model_src) / shard_path)
in_tensors = torch.load(
Path(model_src) / shard_path,
weights_only=True, # to prevent arbitrary code execution
)
if "state_dict" in in_tensors:
in_tensors = in_tensors["state_dict"]

View File

@@ -17,7 +17,7 @@
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
# pylint: disable=duplicate-code
""" PyTorch StableLM Epoch model. """
"""PyTorch StableLM Epoch model."""
import importlib
import math
from typing import Optional, Tuple, Union

View File

@@ -1,6 +1,7 @@
"""
fix for FSDP optimizer save in trainer w 4.47.0
"""
import inspect
import logging

View File

@@ -1,6 +1,7 @@
"""
Shared utils for the monkeypatches
"""
import re
from typing import Optional, Tuple

View File

@@ -1,6 +1,7 @@
"""
Fused MLP layer for incrementally improved training efficiency
"""
import torch
from transformers.models.llama.modeling_llama import LlamaMLP
from xformers.ops import SwiGLU

View File

@@ -1,6 +1,7 @@
"""
Prompt strategies loader for alpaca instruction datasets with system prompts
"""
from typing import Generator, Tuple, Union
from axolotl.prompt_tokenizers import PromptTokenizingStrategy

View File

@@ -1,6 +1,7 @@
"""
Basic completion text
"""
from collections import defaultdict
from typing import Any, Dict, Generator, Optional, Tuple

View File

@@ -1,4 +1,5 @@
"""Module containing the classes for Context QA Prompt Tokenization Strategies"""
from typing import Tuple
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy

View File

@@ -1,6 +1,7 @@
"""
module for DPO style dataset transform strategies
"""
from functools import partial
from ..base import load as load_base

View File

@@ -33,9 +33,9 @@ def default(
f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
sample["prompt"] = (
f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
)
sample["chosen"] = f"{sample[chosen_key]}<|im_end|>"
sample["rejected"] = f"{sample[rejected_key]}<|im_end|>"
return sample
@@ -52,9 +52,9 @@ def argilla_chat(
"""
def transform_fn(sample):
sample[
"prompt"
] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
sample["prompt"] = (
f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
)
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
return sample
@@ -78,9 +78,9 @@ def icr(
f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
sample["prompt"] = (
f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
)
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample
@@ -100,9 +100,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
sample["prompt"] = (
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
)
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample
@@ -120,9 +120,9 @@ def prompt_pairs(
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
sample["prompt"] = (
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
return sample
@@ -142,9 +142,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
sample["prompt"] = (
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
return sample

View File

@@ -34,9 +34,9 @@ def default(
f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["prompt"] = (
f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>"
sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>"
return sample
@@ -53,9 +53,9 @@ def argilla_chat(
"""
def transform_fn(sample):
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["prompt"] = (
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
return sample
@@ -79,9 +79,9 @@ def icr(
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["prompt"] = (
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample
@@ -101,9 +101,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["prompt"] = (
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample
@@ -121,9 +121,9 @@ def prompt_pairs(
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["prompt"] = (
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
return sample
@@ -143,9 +143,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["prompt"] = (
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
return sample

View File

@@ -1,4 +1,5 @@
"""Module for plain input/output prompt pairs"""
from typing import Generator, Tuple
from axolotl.prompt_tokenizers import PromptTokenizingStrategy

View File

@@ -1,4 +1,5 @@
"""Module for inspect jinja templates for the variables they use"""
from typing import Dict, Optional, Set, TypedDict, Union
from jinja2 import Environment, meta, nodes

View File

@@ -1,6 +1,7 @@
"""
KTO strategies for chatml
"""
# pylint: disable=duplicate-code
@@ -15,9 +16,9 @@ def argilla(
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
sample["prompt"] = (
f"<|im_start|>user\n{sample['instruction']}<|im_end|>\n<|im_start|>assistant\n"
)
sample["completion"] = f"{sample['completion']}<|im_end|>"
return sample
@@ -33,9 +34,9 @@ def argilla_chat(
"""
def transform_fn(sample):
sample[
"prompt"
] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
sample["prompt"] = (
f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
)
sample["completion"] = f"{sample['completion'][1]['content']}<|im_end|>"
return sample
@@ -55,9 +56,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
sample["prompt"] = (
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
)
sample["completion"] = f"{sample['completion']}<|im_end|>"
return sample
@@ -74,9 +75,9 @@ def prompt_pairs(
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
sample["prompt"] = (
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
sample["completion"] = f"{sample['completion']}<|im_end|>"
return sample
@@ -96,9 +97,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
else:
sample[
"prompt"
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
sample["prompt"] = (
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
)
sample["completion"] = f"{sample['completion']}<|im_end|>"
return sample

View File

@@ -1,6 +1,7 @@
"""
KTO strategies for llama-3 chat template
"""
# pylint: disable=duplicate-code
@@ -15,9 +16,9 @@ def argilla(
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["prompt"] = (
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['instruction']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
sample["completion"] = f"{sample['completion']}<|eot_id|>"
return sample
@@ -33,9 +34,9 @@ def argilla_chat(
"""
def transform_fn(sample):
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["prompt"] = (
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['completion'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
sample["completion"] = f"{sample['completion'][1]['content']}<|eot_id|>"
return sample
@@ -55,9 +56,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["prompt"] = (
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
sample["completion"] = f"{sample['completion']}<|eot_id|>"
return sample
@@ -74,9 +75,9 @@ def prompt_pairs(
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["prompt"] = (
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
sample["completion"] = f"{sample['completion']}<|eot_id|>"
return sample
@@ -96,9 +97,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
else:
sample[
"prompt"
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
sample["prompt"] = (
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
)
sample["completion"] = f"{sample['completion']}<|eot_id|>"
return sample

View File

@@ -1,6 +1,7 @@
"""
User-defined KTO strategies
"""
# pylint: disable=duplicate-code

View File

@@ -1,6 +1,7 @@
"""
Chat dataset wrapping strategy for new internal messages representations
"""
from typing import Any, Callable, Dict, Optional
from axolotl.core.datasets.chat import TokenizedChatDataset

View File

@@ -9,6 +9,7 @@ this one specifies the system prompt with "### System:".
Not suited/tested for multiple-turn conversations without further adjustments.
"""
from typing import Generator, Union
from axolotl.prompt_strategies.alpaca_w_system import OpenOrcaPromptTokenizingStrategy

View File

@@ -1,4 +1,5 @@
"""chatml prompt tokenization strategy for ORPO"""
from typing import Any, Dict, Generator, List, Optional, Tuple
from pydantic import BaseModel

View File

@@ -1,4 +1,5 @@
"""pretraining prompt strategies"""
from typing import Generator
from transformers import BatchEncoding

View File

@@ -7,7 +7,7 @@ import signal
import sys
import weakref
from pathlib import Path
from typing import Any
from typing import Any, Dict
import torch
import transformers.modelcard
@@ -20,7 +20,7 @@ from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from transformers.trainer import Trainer
from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
@@ -382,21 +382,23 @@ def handle_untrained_tokens_fix(
if not cfg.fix_untrained_tokens:
return
is_ds_zero3: bool = False
if os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE") == "3":
is_ds_zero3 = True
# Check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens)
fix_kwargs: Dict[str, Any] = {}
# If the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list
):
fix_untrained_tokens(
model,
tokenizer,
train_dataset,
token_ids_to_fix=cfg.fix_untrained_tokens,
)
else:
fix_untrained_tokens(model, tokenizer, train_dataset)
fix_kwargs["token_ids_to_fix"] = cfg.fix_untrained_tokens
if "is_ds_zero3" in sig.parameters:
fix_kwargs["is_ds_zero3"] = is_ds_zero3
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
if cfg.local_rank == 0:
model.save_pretrained(
@@ -404,9 +406,7 @@ def handle_untrained_tokens_fix(
)
def setup_model_and_trainer(
cfg: DictDefault, dataset_meta: TrainDatasetMeta
) -> tuple[
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
HFRLTrainerBuilder | HFCausalTrainerBuilder,
PeftModel | PreTrainedModel,
PreTrainedTokenizer,

View File

@@ -40,6 +40,6 @@ def set_pytorch_cuda_alloc_conf():
torch_major, torch_minor = int(torch_version[0]), int(torch_version[1])
if torch_major == 2 and torch_minor >= 2:
if os.getenv("PYTORCH_CUDA_ALLOC_CONF") is None:
os.environ[
"PYTORCH_CUDA_ALLOC_CONF"
] = "expandable_segments:True,roundup_power2_divisions:16"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = (
"expandable_segments:True,roundup_power2_divisions:16"
)

View File

@@ -1,4 +1,5 @@
"""Benchmarking and measurement utilities"""
import functools
import torch

View File

@@ -343,9 +343,9 @@ def bench_eval_callback_factory(trainer, tokenizer):
bench_refs.extend(combined_bench_names[bench_name]["refs"])
bench_preds.extend(combined_bench_names[bench_name]["preds"])
if not pd.isna(bench_score):
results[
f"{bench_split}_bench_accuracy_{bench_name}"
] = bench_score
results[f"{bench_split}_bench_accuracy_{bench_name}"] = (
bench_score
)
bench_scores.append(bench_score)
else:
results[f"{bench_split}_bench_accuracy_{bench_name}"] = 0.0

View File

@@ -1,4 +1,5 @@
"""MLFlow module for trainer callbacks"""
import logging
from shutil import copyfile
from tempfile import NamedTemporaryFile

View File

@@ -1,4 +1,5 @@
"""callback to calculate perplexity as an evaluation metric."""
from typing import Dict, List, Optional
import torch

View File

@@ -1,6 +1,7 @@
"""
HF Trainer callback for creating pytorch profiling snapshots
"""
from pathlib import Path
from pickle import dump # nosec B403

Some files were not shown because too many files have changed in this diff Show More