Compare commits
18 Commits
kd-logprob
...
cuda-12.8.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
31799bdcc0 | ||
|
|
25455ac25f | ||
|
|
edea25bd58 | ||
|
|
42e32223c9 | ||
|
|
6e0fed0ce7 | ||
|
|
5ece44b4a8 | ||
|
|
e7532c9b0c | ||
|
|
2518a9b2a2 | ||
|
|
faeae323cb | ||
|
|
bb683644c3 | ||
|
|
7009a48398 | ||
|
|
ee529e2354 | ||
|
|
b2976e64ec | ||
|
|
38df5a36ea | ||
|
|
4d92a68a96 | ||
|
|
85147ec430 | ||
|
|
51cd409488 | ||
|
|
7235123d44 |
8
.github/workflows/base.yml
vendored
8
.github/workflows/base.yml
vendored
@@ -40,6 +40,12 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
|
- cuda: "128"
|
||||||
|
cuda_version: 12.8.1
|
||||||
|
cudnn_version: ""
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: nightly
|
||||||
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
@@ -61,7 +67,7 @@ jobs:
|
|||||||
uses: docker/build-push-action@v4
|
uses: docker/build-push-action@v4
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
file: ./docker/Dockerfile-base
|
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || './docker/Dockerfile-base' }}
|
||||||
push: ${{ github.event_name != 'pull_request' }}
|
push: ${{ github.event_name != 'pull_request' }}
|
||||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
labels: ${{ steps.metadata.outputs.labels }}
|
labels: ${{ steps.metadata.outputs.labels }}
|
||||||
|
|||||||
2
.github/workflows/pypi.yml
vendored
2
.github/workflows/pypi.yml
vendored
@@ -40,7 +40,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip3 install wheel packaging
|
pip3 install wheel packaging==23.2
|
||||||
pip3 install --no-build-isolation -e .
|
pip3 install --no-build-isolation -e .
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||||
|
|
||||||
|
|||||||
4
.github/workflows/tests-nightly.yml
vendored
4
.github/workflows/tests-nightly.yml
vendored
@@ -42,7 +42,7 @@ jobs:
|
|||||||
- name: upgrade pip
|
- name: upgrade pip
|
||||||
run: |
|
run: |
|
||||||
pip3 install --upgrade pip
|
pip3 install --upgrade pip
|
||||||
pip3 install --upgrade packaging setuptools wheel
|
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
|
||||||
|
|
||||||
- name: Install PyTorch
|
- name: Install PyTorch
|
||||||
run: |
|
run: |
|
||||||
@@ -59,7 +59,7 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip3 install --upgrade pip
|
pip3 install --upgrade pip
|
||||||
pip3 install --upgrade packaging
|
pip3 install --upgrade packaging==23.2
|
||||||
pip3 install --no-build-isolation -U -e .
|
pip3 install --no-build-isolation -U -e .
|
||||||
python scripts/unsloth_install.py | sh
|
python scripts/unsloth_install.py | sh
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
|
|||||||
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -74,7 +74,7 @@ jobs:
|
|||||||
- name: upgrade pip
|
- name: upgrade pip
|
||||||
run: |
|
run: |
|
||||||
pip3 install --upgrade pip
|
pip3 install --upgrade pip
|
||||||
pip3 install --upgrade packaging setuptools wheel
|
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
|
||||||
|
|
||||||
- name: Install PyTorch
|
- name: Install PyTorch
|
||||||
run: |
|
run: |
|
||||||
@@ -147,7 +147,7 @@ jobs:
|
|||||||
- name: upgrade pip
|
- name: upgrade pip
|
||||||
run: |
|
run: |
|
||||||
pip3 install --upgrade pip
|
pip3 install --upgrade pip
|
||||||
pip3 install --upgrade packaging setuptools setuptools_scm build wheel
|
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel
|
||||||
|
|
||||||
- name: Install PyTorch
|
- name: Install PyTorch
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ repos:
|
|||||||
rev: 6.1.0
|
rev: 6.1.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: flake8
|
- id: flake8
|
||||||
- repo: https://github.com/PyCQA/pylint
|
- repo: https://github.com/pylint-dev/pylint
|
||||||
rev: v3.3.0
|
rev: c8c96d20cde3552a79858c7456bb1483bf83d633
|
||||||
hooks:
|
hooks:
|
||||||
- id: pylint
|
- id: pylint
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ Features:
|
|||||||
### Installation
|
### Installation
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||||
|
|
||||||
# Download example axolotl configs, deepspeed configs
|
# Download example axolotl configs, deepspeed configs
|
||||||
|
|||||||
@@ -32,8 +32,9 @@ website:
|
|||||||
contents:
|
contents:
|
||||||
- docs/getting-started.qmd
|
- docs/getting-started.qmd
|
||||||
- docs/installation.qmd
|
- docs/installation.qmd
|
||||||
- docs/cli.qmd
|
|
||||||
- docs/inference.qmd
|
- docs/inference.qmd
|
||||||
|
- docs/cli.qmd
|
||||||
|
- docs/config.qmd
|
||||||
|
|
||||||
- section: "Dataset Formats"
|
- section: "Dataset Formats"
|
||||||
contents: docs/dataset-formats/*
|
contents: docs/dataset-formats/*
|
||||||
@@ -74,10 +75,6 @@ website:
|
|||||||
- docs/debugging.qmd
|
- docs/debugging.qmd
|
||||||
- docs/nccl.qmd
|
- docs/nccl.qmd
|
||||||
|
|
||||||
- section: "Reference"
|
|
||||||
contents:
|
|
||||||
- docs/config.qmd
|
|
||||||
|
|
||||||
format:
|
format:
|
||||||
html:
|
html:
|
||||||
theme: darkly
|
theme: darkly
|
||||||
|
|||||||
@@ -31,6 +31,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
|||||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
RUN pip install packaging==23.2 setuptools==75.8.0
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||||
else \
|
else \
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
|||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
|
||||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
||||||
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
||||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
|
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
|
||||||
|
|||||||
39
docker/Dockerfile-base-nightly
Normal file
39
docker/Dockerfile-base-nightly
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
ARG CUDA_VERSION="12.8.1"
|
||||||
|
ARG CUDNN_VERSION="8"
|
||||||
|
ARG UBUNTU_VERSION="22.04"
|
||||||
|
ARG MAX_JOBS=4
|
||||||
|
|
||||||
|
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||||
|
|
||||||
|
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||||
|
|
||||||
|
ARG PYTHON_VERSION="3.11"
|
||||||
|
ARG PYTORCH_VERSION="nightly"
|
||||||
|
ARG CUDA="128"
|
||||||
|
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||||
|
|
||||||
|
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||||
|
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||||
|
|
||||||
|
RUN apt-get update \
|
||||||
|
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
|
||||||
|
&& wget \
|
||||||
|
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||||
|
&& mkdir /root/.conda \
|
||||||
|
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||||
|
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
||||||
|
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||||
|
|
||||||
|
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||||
|
|
||||||
|
WORKDIR /workspace
|
||||||
|
|
||||||
|
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||||
|
python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
|
||||||
|
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
||||||
|
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
|
||||||
|
|
||||||
|
RUN git lfs install --skip-repo && \
|
||||||
|
pip3 install awscli && \
|
||||||
|
# The base image ships with `pydantic==1.8.2` which is not working
|
||||||
|
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
---
|
---
|
||||||
title: Config options
|
title: Config Reference
|
||||||
description: A complete list of all configuration options.
|
description: A complete list of all configuration options.
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -30,6 +30,8 @@ tokenizer_legacy:
|
|||||||
# Resize the model embeddings when new tokens are added to multiples of 32
|
# Resize the model embeddings when new tokens are added to multiples of 32
|
||||||
# This is reported to improve training speed on some models
|
# This is reported to improve training speed on some models
|
||||||
resize_token_embeddings_to_32x:
|
resize_token_embeddings_to_32x:
|
||||||
|
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
|
||||||
|
shrink_embeddings:
|
||||||
|
|
||||||
# (Internal use only)
|
# (Internal use only)
|
||||||
# Used to identify which the model is based on
|
# Used to identify which the model is based on
|
||||||
@@ -205,10 +207,46 @@ test_datasets:
|
|||||||
data_files:
|
data_files:
|
||||||
- /workspace/data/eval.jsonl
|
- /workspace/data/eval.jsonl
|
||||||
|
|
||||||
# use RL training: 'dpo', 'ipo', 'kto'
|
# use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'
|
||||||
rl:
|
rl:
|
||||||
# whether to perform weighting if doing DPO training. Boolean.
|
rl_beta: # Optional[float]. The beta parameter for the RL training.
|
||||||
dpo_use_weighting:
|
|
||||||
|
# dpo
|
||||||
|
dpo_use_weighting: # Optional[bool]. Whether to perform weighting.
|
||||||
|
rpo_alpha: # Optional[float]. Weighting of NLL term in loss from RPO paper.
|
||||||
|
|
||||||
|
# orpo
|
||||||
|
orpo_alpha: 0.1 # Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping.
|
||||||
|
|
||||||
|
# kto
|
||||||
|
kto_desirable_weight: # Optional[float]. Factor for desirable loss term in KTO loss.
|
||||||
|
kto_undesirable_weight: # Optional[float]. Factor for undesirable loss term in KTO loss.
|
||||||
|
|
||||||
|
# simpo
|
||||||
|
cpo_alpha: 1.0 # Weight of the BC regularizer
|
||||||
|
simpo_gamma: 0.5 # Target reward margin for the SimPO loss
|
||||||
|
|
||||||
|
# grpo
|
||||||
|
trl:
|
||||||
|
use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
|
||||||
|
vllm_device: # Optional[str]. Device to use for VLLM.
|
||||||
|
vllm_gpu_memory_utilization: # Optional[float]. GPU memory utilization for VLLM.
|
||||||
|
vllm_max_model_len: # Optional[int]. Maximum length of the model for VLLM.
|
||||||
|
vllm_dtype: # Optional[str]. Data type for VLLM.
|
||||||
|
|
||||||
|
beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
|
||||||
|
max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
|
||||||
|
|
||||||
|
reward_funcs: # Optional[list[str]]. List of reward functions to load. Paths must be importable from current dir.
|
||||||
|
reward_weights: # Optional[list[float]]. List of reward weights for the reward functions.
|
||||||
|
|
||||||
|
num_generations: # Optional[int]. Number of generations to sample.
|
||||||
|
log_completions: # Optional[bool]. Whether to log completions.
|
||||||
|
|
||||||
|
sync_ref_model: # Optional[bool]. Whether to sync the reference model.
|
||||||
|
ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model.
|
||||||
|
ref_model_sync_steps: # Optional[int]. Sync steps for the reference model.
|
||||||
|
|
||||||
|
|
||||||
# reward modelling: `True` or `False`
|
# reward modelling: `True` or `False`
|
||||||
reward_model:
|
reward_model:
|
||||||
@@ -232,7 +270,7 @@ default_system_message: You are a helpful assistant. Please give a long and deta
|
|||||||
# subsequent training attempts load faster, relative path
|
# subsequent training attempts load faster, relative path
|
||||||
dataset_prepared_path: data/last_run_prepared
|
dataset_prepared_path: data/last_run_prepared
|
||||||
# Push prepared dataset to hub
|
# Push prepared dataset to hub
|
||||||
push_dataset_to_hub: # repo path
|
push_dataset_to_hub: # Optional[str] repo_org/repo_name
|
||||||
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
|
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
|
||||||
# if not set.
|
# if not set.
|
||||||
dataset_processes: # defaults to os.cpu_count() if not set
|
dataset_processes: # defaults to os.cpu_count() if not set
|
||||||
|
|||||||
10
docs/faq.qmd
10
docs/faq.qmd
@@ -27,6 +27,16 @@ description: Frequently asked questions
|
|||||||
|
|
||||||
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
|
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
|
||||||
|
|
||||||
|
**Q: Received mismatch error on merge adapters / loading adapters between torch.Size of checkpoint and model.**
|
||||||
|
|
||||||
|
> A: This is likely due to vocab size mismatch. By default, Axolotl expands the model's embeddings if the tokenizer has more tokens than the model. Please use the `axolotl merge-lora` command to merge the adapters instead of using your own scripts.
|
||||||
|
|
||||||
|
> On the other hand, if the model has more tokens than the tokenizer, Axolotl does not shrink the model's embeddings unless `shrink_embeddings: true` is set in the config.
|
||||||
|
|
||||||
|
**Q: How to call Axolotl via custom python scripts?**
|
||||||
|
|
||||||
|
> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
|
||||||
|
|
||||||
### Chat templates
|
### Chat templates
|
||||||
|
|
||||||
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
||||||
|
|||||||
@@ -36,7 +36,9 @@ The YAML configuration file controls everything about your training. Here's what
|
|||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
base_model: NousResearch/Llama-3.2-1B
|
base_model: NousResearch/Llama-3.2-1B
|
||||||
# hub_model_id: username/custom_model_name
|
|
||||||
|
load_in_8bit: true
|
||||||
|
adapter: lora
|
||||||
|
|
||||||
datasets:
|
datasets:
|
||||||
- path: teknium/GPT4-LLM-Cleaned
|
- path: teknium/GPT4-LLM-Cleaned
|
||||||
@@ -44,11 +46,15 @@ datasets:
|
|||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.1
|
val_set_size: 0.1
|
||||||
output_dir: ./outputs/lora-out
|
output_dir: ./outputs/lora-out
|
||||||
|
|
||||||
adapter: lora
|
|
||||||
lora_model_dir:
|
|
||||||
```
|
```
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
`load_in_8bit: true` and `adapter: lora` enables LoRA adapter finetuning.
|
||||||
|
|
||||||
|
- To perform Full finetuning, remove these two lines.
|
||||||
|
- To perform QLoRA finetuning, replace with `load_in_4bit: true` and `adapter: qlora`.
|
||||||
|
:::
|
||||||
|
|
||||||
See our [Config options](config.qmd) for more details.
|
See our [Config options](config.qmd) for more details.
|
||||||
|
|
||||||
### Training {#sec-training}
|
### Training {#sec-training}
|
||||||
@@ -56,7 +62,7 @@ See our [Config options](config.qmd) for more details.
|
|||||||
When you run `axolotl train`, Axolotl:
|
When you run `axolotl train`, Axolotl:
|
||||||
|
|
||||||
1. Downloads the base model
|
1. Downloads the base model
|
||||||
2. (If specified) applies LoRA adapter layers
|
2. (If specified) applies QLoRA/LoRA adapter layers
|
||||||
3. Loads and processes the dataset
|
3. Loads and processes the dataset
|
||||||
4. Runs the training loop
|
4. Runs the training loop
|
||||||
5. Saves the trained model and / or LoRA weights
|
5. Saves the trained model and / or LoRA weights
|
||||||
@@ -69,6 +75,8 @@ Let's modify the example for your own data:
|
|||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
base_model: NousResearch/Nous-Hermes-llama-1b-v1
|
base_model: NousResearch/Nous-Hermes-llama-1b-v1
|
||||||
|
|
||||||
|
load_in_8bit: true
|
||||||
adapter: lora
|
adapter: lora
|
||||||
|
|
||||||
# Training settings
|
# Training settings
|
||||||
@@ -104,8 +112,6 @@ format):
|
|||||||
{"instruction": "Classify this text", "input": "Not good at all", "output": "negative"}
|
{"instruction": "Classify this text", "input": "Not good at all", "output": "negative"}
|
||||||
```
|
```
|
||||||
|
|
||||||
Please consult the supported [Dataset Formats](dataset-formats/) for more details.
|
|
||||||
|
|
||||||
3. Run the training:
|
3. Run the training:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
---
|
---
|
||||||
title: "Inference"
|
title: "Inference and Merging"
|
||||||
format:
|
format:
|
||||||
html:
|
html:
|
||||||
toc: true
|
toc: true
|
||||||
@@ -9,10 +9,14 @@ execute:
|
|||||||
enabled: false
|
enabled: false
|
||||||
---
|
---
|
||||||
|
|
||||||
This guide covers how to use your trained models for inference, including model loading, interactive testing, and common troubleshooting steps.
|
This guide covers how to use your trained models for inference, including model loading, interactive testing, merging adapters, and common troubleshooting steps.
|
||||||
|
|
||||||
## Quick Start {#sec-quickstart}
|
## Quick Start {#sec-quickstart}
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
Use the same config used for training on inference/merging.
|
||||||
|
:::
|
||||||
|
|
||||||
### Basic Inference {#sec-basic}
|
### Basic Inference {#sec-basic}
|
||||||
|
|
||||||
::: {.panel-tabset}
|
::: {.panel-tabset}
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
|
|||||||
### PyPI Installation (Recommended) {#sec-pypi}
|
### PyPI Installation (Recommended) {#sec-pypi}
|
||||||
|
|
||||||
```{.bash}
|
```{.bash}
|
||||||
|
pip3 install -U packaging setuptools wheel ninja
|
||||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -37,7 +38,7 @@ For the latest features between releases:
|
|||||||
```{.bash}
|
```{.bash}
|
||||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||||
cd axolotl
|
cd axolotl
|
||||||
pip3 install packaging ninja
|
pip3 install -U packaging setuptools wheel ninja
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -107,7 +108,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
|
|||||||
2. Install PyTorch: https://pytorch.org/get-started/locally/
|
2. Install PyTorch: https://pytorch.org/get-started/locally/
|
||||||
3. Install Axolotl:
|
3. Install Axolotl:
|
||||||
```{.bash}
|
```{.bash}
|
||||||
pip3 install packaging
|
pip3 install -U packaging setuptools wheel ninja
|
||||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||||
```
|
```
|
||||||
4. (Optional) Login to Hugging Face:
|
4. (Optional) Login to Hugging Face:
|
||||||
|
|||||||
@@ -66,6 +66,10 @@ logic to be compatible with more of them.
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
Check out our [LoRA optimizations blog](https://axolotlai.substack.com/p/accelerating-lora-fine-tuning-with).
|
||||||
|
:::
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
These optimizations can be enabled in your Axolotl config YAML file. The
|
These optimizations can be enabled in your Axolotl config YAML file. The
|
||||||
|
|||||||
@@ -41,6 +41,10 @@ Bradley-Terry chat templates expect single-turn conversations in the following f
|
|||||||
|
|
||||||
### Process Reward Models (PRM)
|
### Process Reward Models (PRM)
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
Check out our [PRM blog](https://axolotlai.substack.com/p/process-reward-models).
|
||||||
|
:::
|
||||||
|
|
||||||
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
|
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
|
||||||
```yaml
|
```yaml
|
||||||
base_model: Qwen/Qwen2.5-3B
|
base_model: Qwen/Qwen2.5-3B
|
||||||
|
|||||||
@@ -298,7 +298,7 @@ The input format is a simple JSON input with customizable fields based on the ab
|
|||||||
|
|
||||||
### IPO
|
### IPO
|
||||||
|
|
||||||
As IPO is just DPO with a different loss function, all supported options for DPO works here.
|
As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
rl: ipo
|
rl: ipo
|
||||||
@@ -344,8 +344,9 @@ ORPO supports the following types with the following dataset format:
|
|||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
rl: kto
|
rl: kto
|
||||||
rl_beta: 0.5
|
rl_beta: 0.1 # default
|
||||||
kto_desirable_weight: 0.2
|
kto_desirable_weight: 1.0 # default
|
||||||
|
kto_undesirable_weight: 1.0 # default
|
||||||
|
|
||||||
remove_unused_columns: false
|
remove_unused_columns: false
|
||||||
|
|
||||||
@@ -497,6 +498,10 @@ The input format is a simple JSON input with customizable fields based on the ab
|
|||||||
|
|
||||||
### GRPO
|
### GRPO
|
||||||
|
|
||||||
|
::: {.callout-tip}
|
||||||
|
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
|
||||||
|
:::
|
||||||
|
|
||||||
GRPO uses custom reward functions and transformations. Please have them ready locally.
|
GRPO uses custom reward functions and transformations. Please have them ready locally.
|
||||||
|
|
||||||
For ex, to load OpenAI's GSM8K and use a random reward for completions:
|
For ex, to load OpenAI's GSM8K and use a random reward for completions:
|
||||||
@@ -540,6 +545,19 @@ To see other examples of custom reward functions, please see [TRL GRPO Docs](htt
|
|||||||
|
|
||||||
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
|
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
|
||||||
|
|
||||||
|
### SimPO
|
||||||
|
|
||||||
|
SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with alternative loss function.
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
rl: simpo
|
||||||
|
rl_beta: 0.1 # default in CPOTrainer
|
||||||
|
cpo_alpha: 1.0 # default in CPOTrainer
|
||||||
|
simpo_gamma: 0.5 # default in CPOTrainer
|
||||||
|
```
|
||||||
|
|
||||||
|
This method uses the same dataset format as [DPO](#dpo).
|
||||||
|
|
||||||
### Using local dataset files
|
### Using local dataset files
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
[build-system]
|
[build-system]
|
||||||
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
|
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==23.2"]
|
||||||
build-backend = "setuptools.build_meta"
|
build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
@@ -8,6 +8,7 @@ dynamic = ["version", "dependencies", "optional-dependencies"]
|
|||||||
description = "LLM Trainer"
|
description = "LLM Trainer"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
# license = "Apache-2.0"
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
axolotl = "axolotl.cli.main:main"
|
axolotl = "axolotl.cli.main:main"
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
|
|
||||||
# START section of dependencies that don't install on Darwin/MacOS
|
# START section of dependencies that don't install on Darwin/MacOS
|
||||||
bitsandbytes==0.45.2
|
bitsandbytes==0.45.3
|
||||||
triton>=3.0.0
|
triton>=3.0.0
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
flash-attn==2.7.4.post1
|
flash-attn==2.7.4.post1
|
||||||
@@ -12,12 +12,12 @@ liger-kernel==0.5.3
|
|||||||
|
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
|
|
||||||
peft==0.14.0
|
peft==0.15.0
|
||||||
transformers==4.49.0
|
transformers==4.49.0
|
||||||
tokenizers>=0.21.0
|
tokenizers>=0.21.1
|
||||||
accelerate==1.3.0
|
accelerate==1.5.2
|
||||||
datasets==3.2.0
|
datasets==3.4.1
|
||||||
deepspeed==0.16.1
|
deepspeed==0.16.4
|
||||||
trl==0.15.1
|
trl==0.15.1
|
||||||
|
|
||||||
optimum==1.16.2
|
optimum==1.16.2
|
||||||
|
|||||||
@@ -17,12 +17,12 @@ if v < V("2.4.0"):
|
|||||||
|
|
||||||
cce_spec = importlib.util.find_spec("cut_cross_entropy")
|
cce_spec = importlib.util.find_spec("cut_cross_entropy")
|
||||||
|
|
||||||
UNINSTALL_PREFIX = ""
|
uninstall_prefix = ""
|
||||||
if cce_spec:
|
if cce_spec:
|
||||||
if not importlib.util.find_spec("cut_cross_entropy.transformers"):
|
if not importlib.util.find_spec("cut_cross_entropy.transformers"):
|
||||||
UNINSTALL_PREFIX = "pip uninstall -y cut-cross-entropy && "
|
uninstall_prefix = "pip uninstall -y cut-cross-entropy && "
|
||||||
|
|
||||||
print(
|
print(
|
||||||
UNINSTALL_PREFIX
|
uninstall_prefix
|
||||||
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
|
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
|
||||||
)
|
)
|
||||||
|
|||||||
2
setup.py
2
setup.py
@@ -128,7 +128,7 @@ setup(
|
|||||||
"flash-attn==2.7.4.post1",
|
"flash-attn==2.7.4.post1",
|
||||||
],
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed==0.16.1",
|
"deepspeed==0.16.4",
|
||||||
"deepspeed-kernels",
|
"deepspeed-kernels",
|
||||||
],
|
],
|
||||||
"mamba-ssm": [
|
"mamba-ssm": [
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Module with Pydantic models for configuration."""
|
"""Module with Pydantic models for configuration."""
|
||||||
|
|
||||||
# pylint: disable=too-many-lines
|
# pylint: disable=too-many-lines
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
@@ -506,7 +507,7 @@ class HyperparametersConfig(BaseModel):
|
|||||||
weight_decay: Optional[float] = 0.0
|
weight_decay: Optional[float] = 0.0
|
||||||
optimizer: Optional[
|
optimizer: Optional[
|
||||||
Union[OptimizerNames, CustomSupportedOptimizers]
|
Union[OptimizerNames, CustomSupportedOptimizers]
|
||||||
] = OptimizerNames.ADAMW_HF
|
] = OptimizerNames.ADAMW_TORCH_FUSED
|
||||||
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
|
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
|
||||||
@@ -1827,6 +1828,14 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
data["torch_compile"] = False
|
data["torch_compile"] = False
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def check_beta_and_trl_beta_match(cls, data):
|
||||||
|
if data.get("beta") and data.get("trl", {}).get("beta"):
|
||||||
|
if data["beta"] != data["trl"]["beta"]:
|
||||||
|
raise ValueError("beta and trl.beta must match or one must be removed")
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
def handle_legacy_message_fields_logic(data: dict) -> dict:
|
def handle_legacy_message_fields_logic(data: dict) -> dict:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -344,6 +345,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
)
|
)
|
||||||
ds_from_iter.save_to_disk(str(prepared_ds_path))
|
ds_from_iter.save_to_disk(str(prepared_ds_path))
|
||||||
else:
|
else:
|
||||||
|
os.makedirs(prepared_ds_path, exist_ok=True)
|
||||||
dataset.save_to_disk(str(prepared_ds_path))
|
dataset.save_to_disk(str(prepared_ds_path))
|
||||||
if cfg.push_dataset_to_hub:
|
if cfg.push_dataset_to_hub:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
|
|||||||
@@ -108,6 +108,12 @@ def download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_tiny_shakespeare_dataset():
|
||||||
|
# download the dataset
|
||||||
|
snapshot_download_w_retry("Trelis/tiny-shakespeare", repo_type="dataset")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def temp_dir():
|
def temp_dir():
|
||||||
# Create a temporary directory
|
# Create a temporary directory
|
||||||
|
|||||||
@@ -40,8 +40,8 @@ class TestReLoraLlama(unittest.TestCase):
|
|||||||
"lora_alpha": 16,
|
"lora_alpha": 16,
|
||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_modules": ["q_proj", "v_proj"],
|
"lora_target_modules": ["q_proj", "v_proj"],
|
||||||
"relora_steps": 100,
|
"relora_steps": 50,
|
||||||
"relora_warmup_steps": 20,
|
"relora_warmup_steps": 10,
|
||||||
"relora_anneal_steps": 10,
|
"relora_anneal_steps": 10,
|
||||||
"relora_prune_ratio": 0.9,
|
"relora_prune_ratio": 0.9,
|
||||||
"relora_cpu_offload": True,
|
"relora_cpu_offload": True,
|
||||||
@@ -60,9 +60,9 @@ class TestReLoraLlama(unittest.TestCase):
|
|||||||
"message_field_content": "value",
|
"message_field_content": "value",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"warmup_steps": 20,
|
"warmup_steps": 10,
|
||||||
"num_epochs": 2,
|
"num_epochs": 2,
|
||||||
"max_steps": 205, # at least 2x relora_steps
|
"max_steps": 105, # at least 2x relora_steps
|
||||||
"micro_batch_size": 2,
|
"micro_batch_size": 2,
|
||||||
"gradient_accumulation_steps": 1,
|
"gradient_accumulation_steps": 1,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
|||||||
@@ -7,13 +7,13 @@ import tempfile
|
|||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
from conftest import snapshot_download_w_retry
|
||||||
from constants import (
|
from constants import (
|
||||||
ALPACA_MESSAGES_CONFIG_OG,
|
ALPACA_MESSAGES_CONFIG_OG,
|
||||||
ALPACA_MESSAGES_CONFIG_REVISION,
|
ALPACA_MESSAGES_CONFIG_REVISION,
|
||||||
SPECIAL_TOKENS,
|
SPECIAL_TOKENS,
|
||||||
)
|
)
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from huggingface_hub import snapshot_download
|
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from axolotl.utils.data import load_tokenized_prepared_datasets
|
from axolotl.utils.data import load_tokenized_prepared_datasets
|
||||||
@@ -69,7 +69,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
||||||
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
||||||
snapshot_download(
|
snapshot_download_w_retry(
|
||||||
repo_id="mhenrichsen/alpaca_2k_test",
|
repo_id="mhenrichsen/alpaca_2k_test",
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
local_dir=tmp_ds_path,
|
local_dir=tmp_ds_path,
|
||||||
@@ -81,7 +81,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
# how to load it.
|
# how to load it.
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"tokenizer_config": "huggyllama/llama-7b",
|
"tokenizer_config": "HuggingFaceTB/SmolLM2-135M",
|
||||||
"sequence_len": 1024,
|
"sequence_len": 1024,
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
@@ -339,7 +339,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
||||||
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
||||||
snapshot_download(
|
snapshot_download_w_retry(
|
||||||
repo_id="mhenrichsen/alpaca_2k_test",
|
repo_id="mhenrichsen/alpaca_2k_test",
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
local_dir=tmp_ds_path,
|
local_dir=tmp_ds_path,
|
||||||
@@ -381,7 +381,7 @@ class TestDatasetPreparation(unittest.TestCase):
|
|||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
||||||
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
tmp_ds_path.mkdir(parents=True, exist_ok=True)
|
||||||
snapshot_download(
|
snapshot_download_w_retry(
|
||||||
repo_id="mhenrichsen/alpaca_2k_test",
|
repo_id="mhenrichsen/alpaca_2k_test",
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
local_dir=tmp_ds_path,
|
local_dir=tmp_ds_path,
|
||||||
|
|||||||
Reference in New Issue
Block a user