Compare commits
44 Commits
optimizers
...
quartodoc
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0bffef25d0 | ||
|
|
94c00c1d04 | ||
|
|
ddd84d7c65 | ||
|
|
42bdf0bd74 | ||
|
|
b03d96a228 | ||
|
|
2653f170fc | ||
|
|
3bfcce9f0a | ||
|
|
8feb746953 | ||
|
|
a563815fe7 | ||
|
|
81f2203151 | ||
|
|
5b7e688fc5 | ||
|
|
5134aa66cd | ||
|
|
ba9a867adb | ||
|
|
c618f42c39 | ||
|
|
fc1f985296 | ||
|
|
a5e37f183c | ||
|
|
e6a7bbe9ff | ||
|
|
e4fd7aad0b | ||
|
|
c907ac173e | ||
|
|
187227d837 | ||
|
|
f8de8bb4f2 | ||
|
|
8e604848a4 | ||
|
|
aae4337f40 | ||
|
|
38df5a36ea | ||
|
|
4d92a68a96 | ||
|
|
85147ec430 | ||
|
|
51cd409488 | ||
|
|
7235123d44 | ||
|
|
4f5eb42a73 | ||
|
|
fbe54be6b8 | ||
|
|
04f6324833 | ||
|
|
f0072f3b9d | ||
|
|
59899b9817 | ||
|
|
4a736986fa | ||
|
|
5d0f110a3b | ||
|
|
83f8698b8a | ||
|
|
60a11a6410 | ||
|
|
46a045e528 | ||
|
|
3b477e08a0 | ||
|
|
16dc6ee68d | ||
|
|
fa7c79b3b9 | ||
|
|
ae66374156 | ||
|
|
5e21b1a9da | ||
|
|
575e5f28ec |
8
.github/workflows/base.yml
vendored
8
.github/workflows/base.yml
vendored
@@ -40,6 +40,12 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
- cuda: "128"
|
||||
cuda_version: 12.8.1
|
||||
cudnn_version: ""
|
||||
python_version: "3.11"
|
||||
pytorch: nightly
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -61,7 +67,7 @@ jobs:
|
||||
uses: docker/build-push-action@v4
|
||||
with:
|
||||
context: .
|
||||
file: ./docker/Dockerfile-base
|
||||
file: ${{ matrix.pytorch == 'nightly' && './docker/Dockerfile-base-nightly' || './docker/Dockerfile-base' }}
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
|
||||
7
.github/workflows/docs.yml
vendored
7
.github/workflows/docs.yml
vendored
@@ -20,9 +20,12 @@ jobs:
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
- name: install dependencies
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python3 -m pip install jupyter
|
||||
python3 -m pip install jupyter quartodoc
|
||||
python3 -m pip install -e .
|
||||
- name: Build autodoc
|
||||
run: quartodoc build
|
||||
- name: Publish to GitHub Pages (and render)
|
||||
uses: quarto-dev/quarto-actions/publish@v2
|
||||
with:
|
||||
|
||||
5
.github/workflows/main.yml
vendored
5
.github/workflows/main.yml
vendored
@@ -88,6 +88,11 @@ jobs:
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
5
.github/workflows/nightlies.yml
vendored
5
.github/workflows/nightlies.yml
vendored
@@ -80,6 +80,11 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras:
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
pytorch: 2.6.0
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
49
.github/workflows/precommit-autoupdate.yml
vendored
Normal file
49
.github/workflows/precommit-autoupdate.yml
vendored
Normal file
@@ -0,0 +1,49 @@
|
||||
name: Pre-commit auto-update
|
||||
|
||||
on:
|
||||
schedule:
|
||||
- cron: '0 0 * * 0' # Run weekly
|
||||
workflow_dispatch: # Manual kickoff
|
||||
|
||||
jobs:
|
||||
auto-update:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Update pre-commit hooks
|
||||
id: update
|
||||
run: |
|
||||
pip install pre-commit
|
||||
pre-commit autoupdate
|
||||
if [[ -n $(git status --porcelain) ]]; then
|
||||
echo "changes=true" >> $GITHUB_OUTPUT
|
||||
git diff .pre-commit-config.yaml > pre-commit-update.diff
|
||||
fi
|
||||
|
||||
- name: Create Pull Request
|
||||
if: steps.update.outputs.changes == 'true'
|
||||
uses: peter-evans/create-pull-request@v6
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }}
|
||||
branch: update/pre-commit-hooks
|
||||
delete-branch: true
|
||||
title: "chore: update pre-commit hooks"
|
||||
commit-message: "chore: update pre-commit hooks"
|
||||
body: |
|
||||
Automated PR to update pre-commit hooks to their latest versions.
|
||||
|
||||
<details>
|
||||
<summary>Changes:</summary>
|
||||
|
||||
```diff
|
||||
${{ steps.update.outputs.diff }}
|
||||
```
|
||||
</details>
|
||||
2
.github/workflows/pypi.yml
vendored
2
.github/workflows/pypi.yml
vendored
@@ -40,7 +40,7 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install wheel packaging
|
||||
pip3 install wheel packaging==23.2
|
||||
pip3 install --no-build-isolation -e .
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
|
||||
4
.github/workflows/tests-nightly.yml
vendored
4
.github/workflows/tests-nightly.yml
vendored
@@ -42,7 +42,7 @@ jobs:
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging setuptools wheel
|
||||
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
@@ -59,7 +59,7 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging
|
||||
pip3 install --upgrade packaging==23.2
|
||||
pip3 install --no-build-isolation -U -e .
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -74,7 +74,7 @@ jobs:
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging setuptools wheel
|
||||
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
@@ -147,7 +147,7 @@ jobs:
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging setuptools setuptools_scm build wheel
|
||||
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 setuptools_scm build wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
|
||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -181,6 +181,10 @@ prepared-datasets/
|
||||
submit.sh
|
||||
*.out*
|
||||
|
||||
# Quartodoc generated files
|
||||
objects.json
|
||||
site_libs/
|
||||
|
||||
typings/
|
||||
out/
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ default_language_version:
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.4.0
|
||||
rev: v5.0.0
|
||||
hooks:
|
||||
- id: check-yaml
|
||||
- id: end-of-file-fixer
|
||||
@@ -11,23 +11,23 @@ repos:
|
||||
- id: no-commit-to-branch
|
||||
args: ['--branch', 'main']
|
||||
- repo: https://github.com/psf/black
|
||||
rev: 23.3.0
|
||||
rev: 25.1.0
|
||||
hooks:
|
||||
- id: black
|
||||
- repo: https://github.com/pycqa/isort
|
||||
rev: 5.12.0
|
||||
rev: 6.0.1
|
||||
hooks:
|
||||
- id: isort
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 6.1.0
|
||||
rev: 7.1.2
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://github.com/PyCQA/pylint
|
||||
rev: v3.3.0
|
||||
- repo: https://github.com/pylint-dev/pylint
|
||||
rev: v3.3.6
|
||||
hooks:
|
||||
- id: pylint
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.3.0
|
||||
rev: v1.15.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
@@ -36,7 +36,7 @@ repos:
|
||||
'pydantic>=2.5.3',
|
||||
]
|
||||
- repo: https://github.com/PyCQA/bandit
|
||||
rev: 1.7.5
|
||||
rev: 1.8.3
|
||||
hooks:
|
||||
- id: bandit
|
||||
args: [
|
||||
|
||||
@@ -55,6 +55,7 @@ Features:
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
|
||||
# Download example axolotl configs, deepspeed configs
|
||||
@@ -96,6 +97,7 @@ That's it! Check out our [Getting Started Guide](https://axolotl-ai-cloud.github
|
||||
- [Multi-GPU Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-gpu.html)
|
||||
- [Multi-Node Training](https://axolotl-ai-cloud.github.io/axolotl/docs/multi-node.html)
|
||||
- [Multipacking](https://axolotl-ai-cloud.github.io/axolotl/docs/multipack.html)
|
||||
- [API Reference](https://axolotl-ai-cloud.github.io/axolotl/docs/api/) - Auto-generated code documentation
|
||||
- [FAQ](https://axolotl-ai-cloud.github.io/axolotl/docs/faq.html) - Frequently asked questions
|
||||
|
||||
## 🤝 Getting Help
|
||||
|
||||
200
_quarto.yml
200
_quarto.yml
@@ -1,6 +1,178 @@
|
||||
project:
|
||||
type: website
|
||||
|
||||
quartodoc:
|
||||
dir: docs/api
|
||||
package: axolotl
|
||||
title: API Reference
|
||||
parser: google
|
||||
|
||||
sections:
|
||||
- title: Core
|
||||
desc: Core functionality for training
|
||||
contents:
|
||||
- train
|
||||
- evaluate
|
||||
- datasets
|
||||
- convert
|
||||
- prompt_tokenizers
|
||||
- logging_config
|
||||
- core.trainer_builder
|
||||
- core.training_args
|
||||
- core.chat.messages
|
||||
- core.chat.format.chatml
|
||||
- core.chat.format.llama3x
|
||||
- core.chat.format.shared
|
||||
- core.datasets.chat
|
||||
- core.datasets.transforms.chat_builder
|
||||
- title: CLI
|
||||
desc: Command-line interface
|
||||
contents:
|
||||
- cli.main
|
||||
- cli.train
|
||||
- cli.evaluate
|
||||
- cli.args
|
||||
- cli.checks
|
||||
- cli.config
|
||||
- cli.inference
|
||||
- cli.merge_lora
|
||||
- cli.merge_sharded_fsdp_weights
|
||||
- cli.preprocess
|
||||
- cli.sweeps
|
||||
- cli.utils
|
||||
- cli.cloud.base
|
||||
- cli.cloud.modal_
|
||||
- title: Trainers
|
||||
desc: Training implementations
|
||||
contents:
|
||||
- core.trainers.base
|
||||
- core.trainers.trl
|
||||
- core.trainers.dpo.trainer
|
||||
- core.trainers.grpo.trainer
|
||||
- title: Prompt Strategies
|
||||
desc: Prompt formatting strategies
|
||||
contents:
|
||||
- prompt_strategies.base
|
||||
- prompt_strategies.chat_template
|
||||
- prompt_strategies.alpaca_chat
|
||||
- prompt_strategies.alpaca_instruct
|
||||
- prompt_strategies.alpaca_w_system
|
||||
- prompt_strategies.user_defined
|
||||
- prompt_strategies.llama2_chat
|
||||
- prompt_strategies.completion
|
||||
- prompt_strategies.input_output
|
||||
- prompt_strategies.stepwise_supervised
|
||||
- prompt_strategies.metharme
|
||||
- prompt_strategies.orcamini
|
||||
- prompt_strategies.pygmalion
|
||||
- prompt_strategies.messages.chat
|
||||
- prompt_strategies.dpo.chat_template
|
||||
- prompt_strategies.dpo.llama3
|
||||
- prompt_strategies.dpo.chatml
|
||||
- prompt_strategies.dpo.zephyr
|
||||
- prompt_strategies.dpo.user_defined
|
||||
- prompt_strategies.dpo.passthrough
|
||||
- prompt_strategies.kto.llama3
|
||||
- prompt_strategies.kto.chatml
|
||||
- prompt_strategies.kto.user_defined
|
||||
- prompt_strategies.orpo.chat_template
|
||||
- prompt_strategies.bradley_terry.llama3
|
||||
- title: Kernels
|
||||
desc: Low-level performance optimizations
|
||||
contents:
|
||||
- kernels.lora
|
||||
- kernels.geglu
|
||||
- kernels.swiglu
|
||||
- kernels.quantize
|
||||
- kernels.utils
|
||||
- title: MonkeyPatches
|
||||
desc: Runtime patches for model optimizations
|
||||
contents:
|
||||
- monkeypatch.llama_attn_hijack_flash
|
||||
- monkeypatch.llama_attn_hijack_xformers
|
||||
- monkeypatch.mistral_attn_hijack_flash
|
||||
- monkeypatch.multipack
|
||||
- monkeypatch.relora
|
||||
- monkeypatch.llama_expand_mask
|
||||
- monkeypatch.lora_kernels
|
||||
- monkeypatch.utils
|
||||
- monkeypatch.btlm_attn_hijack_flash
|
||||
- monkeypatch.llama_patch_multipack
|
||||
- monkeypatch.stablelm_attn_hijack_flash
|
||||
- monkeypatch.trainer_fsdp_optim
|
||||
- monkeypatch.transformers_fa_utils
|
||||
- monkeypatch.unsloth_
|
||||
- monkeypatch.attention.mllama
|
||||
- monkeypatch.data.batch_dataset_fetcher
|
||||
- monkeypatch.mixtral
|
||||
- title: Utils
|
||||
desc: Utility functions
|
||||
contents:
|
||||
- utils.models
|
||||
- utils.tokenization
|
||||
- utils.chat_templates
|
||||
- utils.lora
|
||||
- utils.lora_embeddings
|
||||
- utils.model_shard_quant
|
||||
- utils.bench
|
||||
- utils.freeze
|
||||
- utils.trainer
|
||||
- utils.schedulers
|
||||
- utils.distributed
|
||||
- utils.dict
|
||||
- utils.optimizers.adopt
|
||||
- utils.data.pretraining
|
||||
- utils.data.sft
|
||||
- utils.gradient_checkpointing.unsloth
|
||||
- title: Schemas
|
||||
desc: Pydantic data models for Axolotl config
|
||||
contents:
|
||||
- utils.schemas.config
|
||||
- utils.schemas.model
|
||||
- utils.schemas.training
|
||||
- utils.schemas.datasets
|
||||
- utils.schemas.peft
|
||||
- utils.schemas.trl
|
||||
- utils.schemas.integrations
|
||||
- utils.schemas.enums
|
||||
- utils.schemas.utils
|
||||
- title: Integrations
|
||||
desc: Third-party integrations and extensions
|
||||
contents:
|
||||
- integrations.base
|
||||
- integrations.cut_cross_entropy.args
|
||||
- integrations.grokfast.optimizer
|
||||
- integrations.kd.trainer
|
||||
- integrations.liger.args
|
||||
- integrations.lm_eval.args
|
||||
- integrations.spectrum.args
|
||||
- title: Common
|
||||
desc: Common utilities and shared functionality
|
||||
contents:
|
||||
- common.architectures
|
||||
- common.const
|
||||
- common.datasets
|
||||
- title: Models
|
||||
desc: Custom model implementations
|
||||
contents:
|
||||
- models.mamba.modeling_mamba
|
||||
- title: Data Processing
|
||||
desc: Data processing utilities
|
||||
contents:
|
||||
- utils.collators.core
|
||||
- utils.collators.batching
|
||||
- utils.collators.mamba
|
||||
- utils.collators.mm_chat
|
||||
- utils.samplers.multipack
|
||||
- title: Callbacks
|
||||
desc: Training callbacks
|
||||
contents:
|
||||
- utils.callbacks.perplexity
|
||||
- utils.callbacks.profiler
|
||||
- utils.callbacks.lisa
|
||||
- utils.callbacks.mlflow_
|
||||
- utils.callbacks.comet_
|
||||
|
||||
website:
|
||||
title: "Axolotl"
|
||||
description: "We make fine-tuning accessible, scalable, and fun"
|
||||
@@ -32,8 +204,11 @@ website:
|
||||
contents:
|
||||
- docs/getting-started.qmd
|
||||
- docs/installation.qmd
|
||||
- docs/cli.qmd
|
||||
- docs/inference.qmd
|
||||
- docs/cli.qmd
|
||||
- docs/config.qmd
|
||||
- text: "API Reference"
|
||||
href: docs/api
|
||||
|
||||
- section: "Dataset Formats"
|
||||
contents: docs/dataset-formats/*
|
||||
@@ -74,12 +249,27 @@ website:
|
||||
- docs/debugging.qmd
|
||||
- docs/nccl.qmd
|
||||
|
||||
- section: "Reference"
|
||||
contents:
|
||||
- docs/config.qmd
|
||||
|
||||
format:
|
||||
html:
|
||||
theme: darkly
|
||||
css: styles.css
|
||||
toc: true
|
||||
# Enable better handling of line breaks in markdown
|
||||
preserve-tabs: true
|
||||
html-math-method: mathjax
|
||||
# Improved markdown processing options
|
||||
md-extensions:
|
||||
- markdown_it
|
||||
- def_list
|
||||
- attr_list
|
||||
- fenced_divs
|
||||
- tables
|
||||
- html_admonition
|
||||
- lineblocks
|
||||
- fancy_lists
|
||||
# Control whitespace handling
|
||||
whitespace: preserve
|
||||
# Process newlines in paragraphs
|
||||
wrap: preserve
|
||||
# Better line break handling
|
||||
preserve-linebreaks: true
|
||||
|
||||
@@ -31,6 +31,7 @@ RUN if [ "$NIGHTLY_BUILD" = "true" ] ; then \
|
||||
sed -i 's#^datasets.*#datasets @ git+https://github.com/huggingface/datasets.git@main#' requirements.txt; \
|
||||
fi
|
||||
|
||||
RUN pip install packaging==23.2 setuptools==75.8.0
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install --no-build-isolation -e .[deepspeed,flash-attn,optimizers,ray,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
|
||||
else \
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
modal application to run axolotl gpu tests in Modal
|
||||
"""
|
||||
modal application to run axolotl gpu tests in Modal
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import os
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Modal app to run axolotl GPU tests"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import os
|
||||
|
||||
@@ -28,7 +28,7 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install -U packaging==23.2 setuptools==75.8.0 wheel && \
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA && \
|
||||
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
|
||||
|
||||
39
docker/Dockerfile-base-nightly
Normal file
39
docker/Dockerfile-base-nightly
Normal file
@@ -0,0 +1,39 @@
|
||||
ARG CUDA_VERSION="12.8.1"
|
||||
ARG CUDNN_VERSION="8"
|
||||
ARG UBUNTU_VERSION="22.04"
|
||||
ARG MAX_JOBS=4
|
||||
|
||||
FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION AS base-builder
|
||||
|
||||
ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||
|
||||
ARG PYTHON_VERSION="3.11"
|
||||
ARG PYTORCH_VERSION="nightly"
|
||||
ARG CUDA="128"
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev pkg-config && rm -rf /var/lib/apt/lists/* \
|
||||
&& wget \
|
||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& mkdir /root/.conda \
|
||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
||||
|
||||
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||
python3 -m pip install --no-cache-dir -U torch --extra-index-url https://download.pytorch.org/whl/nightly/cu$CUDA && \
|
||||
python3 -m pip install --no-cache-dir "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" && \
|
||||
python3 -m pip install --no-cache-dir "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main"
|
||||
|
||||
RUN git lfs install --skip-repo && \
|
||||
pip3 install awscli && \
|
||||
# The base image ships with `pydantic==1.8.2` which is not working
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||
@@ -14,7 +14,7 @@ COPY scripts/motd /etc/motd
|
||||
|
||||
RUN pip install jupyterlab notebook ipywidgets && \
|
||||
jupyter lab clean
|
||||
RUN apt install --yes --no-install-recommends openssh-server tmux && \
|
||||
RUN apt install --yes --no-install-recommends openssh-server tmux iproute2 nvtop && \
|
||||
mkdir -p ~/.ssh && \
|
||||
chmod 700 ~/.ssh && \
|
||||
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
|
||||
|
||||
2
docs/.gitignore
vendored
2
docs/.gitignore
vendored
@@ -1,2 +1,4 @@
|
||||
/.quarto/
|
||||
_site/
|
||||
/api/*.qmd
|
||||
/api/*.html
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: "CLI Reference"
|
||||
title: "Command Line Interface (CLI)"
|
||||
format:
|
||||
html:
|
||||
toc: true
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: Config options
|
||||
title: Config Reference
|
||||
description: A complete list of all configuration options.
|
||||
---
|
||||
|
||||
@@ -30,6 +30,8 @@ tokenizer_legacy:
|
||||
# Resize the model embeddings when new tokens are added to multiples of 32
|
||||
# This is reported to improve training speed on some models
|
||||
resize_token_embeddings_to_32x:
|
||||
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
|
||||
shrink_embeddings:
|
||||
|
||||
# (Internal use only)
|
||||
# Used to identify which the model is based on
|
||||
@@ -83,6 +85,12 @@ gpu_memory_limit: 20GiB
|
||||
# Do the LoRA/PEFT loading on CPU -- this is required if the base model is so large it takes up most or all of the available GPU VRAM, e.g. during a model and LoRA merge
|
||||
lora_on_cpu: true
|
||||
|
||||
# List[str]. Add plugins to extend the pipeline.
|
||||
# See `src/axolotl/integrations` for the available plugins or doc below for more details.
|
||||
# https://axolotl-ai-cloud.github.io/axolotl/docs/custom_integrations.html
|
||||
plugins:
|
||||
# - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin
|
||||
|
||||
# A list of one or more datasets to finetune the model with
|
||||
datasets:
|
||||
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
|
||||
@@ -154,8 +162,6 @@ datasets:
|
||||
content: value
|
||||
# ...
|
||||
|
||||
message_property_mappings:
|
||||
|
||||
# Optional[Dict[str, List]]. Roles mapping in the messages. The default is:
|
||||
roles:
|
||||
user: ["human", "user"]
|
||||
@@ -207,10 +213,46 @@ test_datasets:
|
||||
data_files:
|
||||
- /workspace/data/eval.jsonl
|
||||
|
||||
# use RL training: 'dpo', 'ipo', 'kto'
|
||||
# use RL training: 'dpo', 'ipo', 'kto', 'simpo', 'orpo', 'grpo'
|
||||
rl:
|
||||
# whether to perform weighting if doing DPO training. Boolean.
|
||||
dpo_use_weighting:
|
||||
rl_beta: # Optional[float]. The beta parameter for the RL training.
|
||||
|
||||
# dpo
|
||||
dpo_use_weighting: # Optional[bool]. Whether to perform weighting.
|
||||
rpo_alpha: # Optional[float]. Weighting of NLL term in loss from RPO paper.
|
||||
|
||||
# orpo
|
||||
orpo_alpha: 0.1 # Parameter controlling the relative ratio loss weight in the ORPO loss. Passed to `beta` in `ORPOConfig` due to trl mapping.
|
||||
|
||||
# kto
|
||||
kto_desirable_weight: # Optional[float]. Factor for desirable loss term in KTO loss.
|
||||
kto_undesirable_weight: # Optional[float]. Factor for undesirable loss term in KTO loss.
|
||||
|
||||
# simpo
|
||||
cpo_alpha: 1.0 # Weight of the BC regularizer
|
||||
simpo_gamma: 0.5 # Target reward margin for the SimPO loss
|
||||
|
||||
# grpo
|
||||
trl:
|
||||
use_vllm: # Optional[bool]. Whether to use VLLM for RL training.
|
||||
vllm_device: # Optional[str]. Device to use for VLLM.
|
||||
vllm_gpu_memory_utilization: # Optional[float]. GPU memory utilization for VLLM.
|
||||
vllm_max_model_len: # Optional[int]. Maximum length of the model for VLLM.
|
||||
vllm_dtype: # Optional[str]. Data type for VLLM.
|
||||
|
||||
beta: # Optional[float]. Beta parameter for the RL training. Same as `rl_beta`. Use
|
||||
max_completion_length: # Optional[int]. Maximum length of the completion for RL training.
|
||||
|
||||
reward_funcs: # Optional[list[str]]. List of reward functions to load. Paths must be importable from current dir.
|
||||
reward_weights: # Optional[list[float]]. List of reward weights for the reward functions.
|
||||
|
||||
num_generations: # Optional[int]. Number of generations to sample.
|
||||
log_completions: # Optional[bool]. Whether to log completions.
|
||||
|
||||
sync_ref_model: # Optional[bool]. Whether to sync the reference model.
|
||||
ref_model_mixup_alpha: # Optional[float]. Mixup alpha for the reference model.
|
||||
ref_model_sync_steps: # Optional[int]. Sync steps for the reference model.
|
||||
|
||||
|
||||
# reward modelling: `True` or `False`
|
||||
reward_model:
|
||||
@@ -234,7 +276,7 @@ default_system_message: You are a helpful assistant. Please give a long and deta
|
||||
# subsequent training attempts load faster, relative path
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
# Push prepared dataset to hub
|
||||
push_dataset_to_hub: # repo path
|
||||
push_dataset_to_hub: # Optional[str] repo_org/repo_name
|
||||
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
|
||||
# if not set.
|
||||
dataset_processes: # defaults to os.cpu_count() if not set
|
||||
@@ -556,6 +598,13 @@ special_tokens:
|
||||
# Add extra tokens.
|
||||
tokens:
|
||||
|
||||
# Mapping token_id to new_token_string to override reserved added_tokens in the tokenizer.
|
||||
# Only works for tokens that are not part of the base vocab (aka are added_tokens).
|
||||
# Can be checked if they exist in tokenizer.json added_tokens.
|
||||
added_tokens_overrides: # Dict[int, str]
|
||||
# 128041: "<|im_start|>"
|
||||
# 128042: "<|im_end|>"
|
||||
|
||||
# FSDP
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
|
||||
@@ -55,3 +55,47 @@ sections = [
|
||||
for section_name, folder_name in sections:
|
||||
print(print_section(section_name, folder_name))
|
||||
```
|
||||
|
||||
## Adding a new integration
|
||||
|
||||
Plugins can be used to customize the behavior of the training pipeline through [hooks](https://en.wikipedia.org/wiki/Hooking). See [`axolotl.integrations.BasePlugin`](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/integrations/base.py) for the possible hooks.
|
||||
|
||||
To add a new integration, please follow these steps:
|
||||
|
||||
1. Create a new folder in the `src/axolotl/integrations` directory.
|
||||
2. Add any relevant files (`LICENSE`, `README.md`, `ACKNOWLEDGEMENTS.md`, etc.) to the new folder.
|
||||
3. Add `__init__.py` and `args.py` files to the new folder.
|
||||
- `__init__.py` should import the integration and hook into the appropriate functions.
|
||||
- `args.py` should define the arguments for the integration.
|
||||
4. (If applicable) Add CPU tests under `tests/integrations` or GPU tests under `tests/e2e/integrations`.
|
||||
|
||||
::: {.callout-tip}
|
||||
|
||||
See [src/axolotl/integrations/cut_cross_entropy](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/integrations/cut_cross_entropy) for a minimal integration example.
|
||||
|
||||
:::
|
||||
|
||||
::: {.callout-warning}
|
||||
|
||||
If you could not load your integration, please ensure you are pip installing in editable mode.
|
||||
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
and correctly spelled the integration name in the config file.
|
||||
|
||||
```yaml
|
||||
plugins:
|
||||
- axolotl.integrations.your_integration_name.YourIntegrationPlugin
|
||||
```
|
||||
|
||||
:::
|
||||
|
||||
::: {.callout-note}
|
||||
|
||||
It is not necessary to place your integration in the `integrations` folder. It can be in any location, so long as it's installed in a package in your python env.
|
||||
|
||||
See this repo for an example: [https://github.com/axolotl-ai-cloud/diff-transformer](https://github.com/axolotl-ai-cloud/diff-transformer)
|
||||
|
||||
:::
|
||||
|
||||
@@ -74,6 +74,10 @@ datasets:
|
||||
train_on_eos:
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
If you receive an error like "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null.", it means the tokenizer does not have a default `chat_template`. Follow the examples below instead to set a custom `chat_template`.
|
||||
:::
|
||||
|
||||
2. Using the `gemma` chat template to override the tokenizer_config.json's chat template on OpenAI messages format, training on all assistant messages.
|
||||
|
||||
```yaml
|
||||
|
||||
@@ -6,7 +6,7 @@ description: How datasets are processed
|
||||
## Overview
|
||||
|
||||
Dataset pre-processing is the step where Axolotl takes each dataset you've configured alongside
|
||||
the [dataset format](docs/dataset-formats) and prompt strategies to:
|
||||
the [dataset format](dataset-formats) and prompt strategies to:
|
||||
|
||||
- parse the dataset based on the *dataset format*
|
||||
- transform the dataset to how you would interact with the model based on the *prompt strategy*
|
||||
|
||||
14
docs/faq.qmd
14
docs/faq.qmd
@@ -27,6 +27,16 @@ description: Frequently asked questions
|
||||
|
||||
> A: This is usually an issue with the GPU. This can be resolved through setting the os environment variable `CUDA_VISIBLE_DEVICES=0`. If you are on runpod, this is usually a pod issue. Starting a new pod should take care of it.
|
||||
|
||||
**Q: Received mismatch error on merge adapters / loading adapters between torch.Size of checkpoint and model.**
|
||||
|
||||
> A: This is likely due to vocab size mismatch. By default, Axolotl expands the model's embeddings if the tokenizer has more tokens than the model. Please use the `axolotl merge-lora` command to merge the adapters instead of using your own scripts.
|
||||
|
||||
> On the other hand, if the model has more tokens than the tokenizer, Axolotl does not shrink the model's embeddings unless `shrink_embeddings: true` is set in the config.
|
||||
|
||||
**Q: How to call Axolotl via custom python scripts?**
|
||||
|
||||
> A: Yes, since Axolotl is just Python, please see `src/axolotl/cli/main.py` on how each command is called.
|
||||
|
||||
### Chat templates
|
||||
|
||||
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
||||
@@ -52,3 +62,7 @@ description: Frequently asked questions
|
||||
**Q: The EOS/EOT token is incorrectly being masked or not being masked.**
|
||||
|
||||
> A: This is because of the mismatch between `tokenizer.eos_token` and EOS/EOT token in template. Please make sure to set `eos_token` under `special_tokens` to the same EOS/EOT token as in template.
|
||||
|
||||
**Q: "`chat_template` choice is `tokenizer_default` but tokenizer's `chat_template` is null. Please add a `chat_template` in tokenizer config"**
|
||||
|
||||
> A: This is because the tokenizer does not have a chat template. Please add a chat template in the tokenizer config. See [chat_template](dataset-formats/conversation.qmd#chat-template) for more details.
|
||||
|
||||
@@ -36,7 +36,9 @@ The YAML configuration file controls everything about your training. Here's what
|
||||
|
||||
```yaml
|
||||
base_model: NousResearch/Llama-3.2-1B
|
||||
# hub_model_id: username/custom_model_name
|
||||
|
||||
load_in_8bit: true
|
||||
adapter: lora
|
||||
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
@@ -44,11 +46,15 @@ datasets:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.1
|
||||
output_dir: ./outputs/lora-out
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
```
|
||||
|
||||
::: {.callout-tip}
|
||||
`load_in_8bit: true` and `adapter: lora` enables LoRA adapter finetuning.
|
||||
|
||||
- To perform Full finetuning, remove these two lines.
|
||||
- To perform QLoRA finetuning, replace with `load_in_4bit: true` and `adapter: qlora`.
|
||||
:::
|
||||
|
||||
See our [Config options](config.qmd) for more details.
|
||||
|
||||
### Training {#sec-training}
|
||||
@@ -56,7 +62,7 @@ See our [Config options](config.qmd) for more details.
|
||||
When you run `axolotl train`, Axolotl:
|
||||
|
||||
1. Downloads the base model
|
||||
2. (If specified) applies LoRA adapter layers
|
||||
2. (If specified) applies QLoRA/LoRA adapter layers
|
||||
3. Loads and processes the dataset
|
||||
4. Runs the training loop
|
||||
5. Saves the trained model and / or LoRA weights
|
||||
@@ -69,6 +75,8 @@ Let's modify the example for your own data:
|
||||
|
||||
```yaml
|
||||
base_model: NousResearch/Nous-Hermes-llama-1b-v1
|
||||
|
||||
load_in_8bit: true
|
||||
adapter: lora
|
||||
|
||||
# Training settings
|
||||
@@ -104,8 +112,6 @@ format):
|
||||
{"instruction": "Classify this text", "input": "Not good at all", "output": "negative"}
|
||||
```
|
||||
|
||||
Please consult the supported [Dataset Formats](dataset-formats/) for more details.
|
||||
|
||||
3. Run the training:
|
||||
|
||||
```bash
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
---
|
||||
title: "Inference"
|
||||
title: "Inference and Merging"
|
||||
format:
|
||||
html:
|
||||
toc: true
|
||||
@@ -9,10 +9,14 @@ execute:
|
||||
enabled: false
|
||||
---
|
||||
|
||||
This guide covers how to use your trained models for inference, including model loading, interactive testing, and common troubleshooting steps.
|
||||
This guide covers how to use your trained models for inference, including model loading, interactive testing, merging adapters, and common troubleshooting steps.
|
||||
|
||||
## Quick Start {#sec-quickstart}
|
||||
|
||||
::: {.callout-tip}
|
||||
Use the same config used for training on inference/merging.
|
||||
:::
|
||||
|
||||
### Basic Inference {#sec-basic}
|
||||
|
||||
::: {.panel-tabset}
|
||||
|
||||
@@ -22,6 +22,7 @@ This guide covers all the ways you can install and set up Axolotl for your envir
|
||||
### PyPI Installation (Recommended) {#sec-pypi}
|
||||
|
||||
```{.bash}
|
||||
pip3 install -U packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||
```
|
||||
|
||||
@@ -37,7 +38,7 @@ For the latest features between releases:
|
||||
```{.bash}
|
||||
git clone https://github.com/axolotl-ai-cloud/axolotl.git
|
||||
cd axolotl
|
||||
pip3 install packaging ninja
|
||||
pip3 install -U packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
|
||||
@@ -78,6 +79,7 @@ For providers supporting Docker:
|
||||
- [Latitude.sh](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
|
||||
- [JarvisLabs.ai](https://jarvislabs.ai/templates/axolotl)
|
||||
- [RunPod](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||
- [Novita](https://novita.ai/gpus-console?templateId=311)
|
||||
|
||||
### Google Colab {#sec-colab}
|
||||
|
||||
@@ -107,7 +109,7 @@ We recommend using WSL2 (Windows Subsystem for Linux) or Docker.
|
||||
2. Install PyTorch: https://pytorch.org/get-started/locally/
|
||||
3. Install Axolotl:
|
||||
```{.bash}
|
||||
pip3 install packaging
|
||||
pip3 install -U packaging setuptools wheel ninja
|
||||
pip3 install --no-build-isolation -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
4. (Optional) Login to Hugging Face:
|
||||
|
||||
@@ -66,6 +66,10 @@ logic to be compatible with more of them.
|
||||
|
||||
</details>
|
||||
|
||||
::: {.callout-tip}
|
||||
Check out our [LoRA optimizations blog](https://axolotlai.substack.com/p/accelerating-lora-fine-tuning-with).
|
||||
:::
|
||||
|
||||
## Usage
|
||||
|
||||
These optimizations can be enabled in your Axolotl config YAML file. The
|
||||
|
||||
@@ -28,8 +28,23 @@ val_set_size: 0.1
|
||||
eval_steps: 100
|
||||
```
|
||||
|
||||
Bradley-Terry chat templates expect single-turn conversations in the following format:
|
||||
|
||||
```json
|
||||
{
|
||||
"system": "...", // optional
|
||||
"input": "...",
|
||||
"chosen": "...",
|
||||
"rejected": "..."
|
||||
}
|
||||
```
|
||||
|
||||
### Process Reward Models (PRM)
|
||||
|
||||
::: {.callout-tip}
|
||||
Check out our [PRM blog](https://axolotlai.substack.com/p/process-reward-models).
|
||||
:::
|
||||
|
||||
Process reward models are trained using data which contains preference annotations for each step in a series of interactions. Typically, PRMs are trained to provide reward signals over each step of a reasoning trace and are used for downstream reinforcement learning.
|
||||
```yaml
|
||||
base_model: Qwen/Qwen2.5-3B
|
||||
@@ -45,3 +60,5 @@ datasets:
|
||||
val_set_size: 0.1
|
||||
eval_steps: 100
|
||||
```
|
||||
|
||||
Please see [stepwise_supervised](dataset-formats/stepwise_supervised.qmd) for more details on the dataset format.
|
||||
|
||||
@@ -3,6 +3,7 @@ title: "RLHF (Beta)"
|
||||
description: "Reinforcement Learning from Human Feedback is a method whereby a language model is optimized from data using human feedback."
|
||||
back-to-top-navigation: true
|
||||
toc: true
|
||||
toc-expand: 2
|
||||
toc-depth: 4
|
||||
---
|
||||
|
||||
@@ -297,7 +298,7 @@ The input format is a simple JSON input with customizable fields based on the ab
|
||||
|
||||
### IPO
|
||||
|
||||
As IPO is just DPO with a different loss function, all supported options for DPO works here.
|
||||
As IPO is just DPO with a different loss function, all supported dataset formats for [DPO](#dpo) are also supported for IPO.
|
||||
|
||||
```yaml
|
||||
rl: ipo
|
||||
@@ -343,8 +344,9 @@ ORPO supports the following types with the following dataset format:
|
||||
|
||||
```yaml
|
||||
rl: kto
|
||||
rl_beta: 0.5
|
||||
kto_desirable_weight: 0.2
|
||||
rl_beta: 0.1 # default
|
||||
kto_desirable_weight: 1.0 # default
|
||||
kto_undesirable_weight: 1.0 # default
|
||||
|
||||
remove_unused_columns: false
|
||||
|
||||
@@ -496,6 +498,10 @@ The input format is a simple JSON input with customizable fields based on the ab
|
||||
|
||||
### GRPO
|
||||
|
||||
::: {.callout-tip}
|
||||
Check out our [GRPO cookbook](https://github.com/axolotl-ai-cloud/axolotl-cookbook/tree/main/grpo#training-an-r1-style-large-language-model-using-grpo).
|
||||
:::
|
||||
|
||||
GRPO uses custom reward functions and transformations. Please have them ready locally.
|
||||
|
||||
For ex, to load OpenAI's GSM8K and use a random reward for completions:
|
||||
@@ -528,6 +534,7 @@ trl:
|
||||
vllm_gpu_memory_utilization: 0.15
|
||||
num_generations: 4
|
||||
reward_funcs: ["rewards.rand_reward_func"] # format: '{file_name}.{fn_name}'
|
||||
reward_weights: [1.0]
|
||||
datasets:
|
||||
- path: openai/gsm8k
|
||||
name: main
|
||||
@@ -536,6 +543,21 @@ datasets:
|
||||
|
||||
To see other examples of custom reward functions, please see [TRL GRPO Docs](https://github.com/huggingface/trl/blob/main/docs/source/grpo_trainer.md#using-a-custom-reward-function).
|
||||
|
||||
To see description of the configs, please see [TRLConfig](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/models/input/v0_4_1/trl.py).
|
||||
|
||||
### SimPO
|
||||
|
||||
SimPO uses [CPOTrainer](https://huggingface.co/docs/trl/main/en/cpo_trainer) but with alternative loss function.
|
||||
|
||||
```yaml
|
||||
rl: simpo
|
||||
rl_beta: 0.1 # default in CPOTrainer
|
||||
cpo_alpha: 1.0 # default in CPOTrainer
|
||||
simpo_gamma: 0.5 # default in CPOTrainer
|
||||
```
|
||||
|
||||
This method uses the same dataset format as [DPO](#dpo).
|
||||
|
||||
### Using local dataset files
|
||||
|
||||
```yaml
|
||||
|
||||
@@ -55,7 +55,7 @@ tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
gradient_checkpointing_kwargs:
|
||||
use_reentrant: true
|
||||
use_reentrant: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8"]
|
||||
requires = ["setuptools>=64", "wheel", "setuptools_scm>=8", "packaging==23.2"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
@@ -8,6 +8,7 @@ dynamic = ["version", "dependencies", "optional-dependencies"]
|
||||
description = "LLM Trainer"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10"
|
||||
# license = "Apache-2.0"
|
||||
|
||||
[project.scripts]
|
||||
axolotl = "axolotl.cli.main:main"
|
||||
|
||||
@@ -2,3 +2,5 @@ pre-commit
|
||||
black
|
||||
mypy
|
||||
types-requests
|
||||
quartodoc
|
||||
jupyter
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
|
||||
# START section of dependencies that don't install on Darwin/MacOS
|
||||
bitsandbytes==0.45.2
|
||||
bitsandbytes==0.45.3
|
||||
triton>=3.0.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
flash-attn==2.7.4.post1
|
||||
@@ -12,12 +12,12 @@ liger-kernel==0.5.3
|
||||
|
||||
packaging==23.2
|
||||
|
||||
peft==0.14.0
|
||||
peft==0.15.0
|
||||
transformers==4.49.0
|
||||
tokenizers>=0.21.0
|
||||
accelerate==1.3.0
|
||||
datasets==3.2.0
|
||||
deepspeed==0.16.1
|
||||
tokenizers>=0.21.1
|
||||
accelerate==1.5.2
|
||||
datasets==3.4.1
|
||||
deepspeed==0.16.4
|
||||
trl==0.15.1
|
||||
|
||||
optimum==1.16.2
|
||||
@@ -62,5 +62,5 @@ antlr4-python3-runtime==4.13.2
|
||||
torchao==0.7.0
|
||||
schedulefree==1.3.0
|
||||
|
||||
axolotl-contribs-lgpl==0.0.3
|
||||
axolotl-contribs-lgpl==0.0.6
|
||||
axolotl-contribs-mit==0.0.3
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
helper script to parse chat datasets into a usable yaml
|
||||
"""
|
||||
|
||||
import click
|
||||
import yaml
|
||||
from datasets import load_dataset
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Script to output the correct installation command for cut-cross-entropy."""
|
||||
|
||||
import importlib.util
|
||||
import sys
|
||||
|
||||
@@ -24,5 +25,5 @@ if cce_spec:
|
||||
|
||||
print(
|
||||
UNINSTALL_PREFIX
|
||||
+ 'pip install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
|
||||
+ 'pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"'
|
||||
)
|
||||
|
||||
2
setup.py
2
setup.py
@@ -128,7 +128,7 @@ setup(
|
||||
"flash-attn==2.7.4.post1",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed==0.16.1",
|
||||
"deepspeed==0.16.4",
|
||||
"deepspeed-kernels",
|
||||
],
|
||||
"mamba-ssm": [
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
launch axolotl in supported cloud platforms
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
base class for cloud platforms from cli
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Modal Cloud support from CLI
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
@@ -113,7 +114,7 @@ class ModalCloud(Cloud):
|
||||
[
|
||||
# Random id for cache busting of branch commits
|
||||
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
|
||||
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
|
||||
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch} && git pull",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -270,6 +271,7 @@ def _preprocess(config_yaml: str, volumes=None):
|
||||
|
||||
|
||||
def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
||||
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
|
||||
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
||||
f_out.write(config_yaml)
|
||||
run_folder = "/workspace/mounts"
|
||||
@@ -288,6 +290,7 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None, **kwargs):
|
||||
|
||||
|
||||
def _lm_eval(config_yaml: str, volumes=None):
|
||||
Path("/workspace/mounts").mkdir(parents=True, exist_ok=True)
|
||||
with open("/workspace/mounts/config.yaml", "w", encoding="utf-8") as f_out:
|
||||
f_out.write(config_yaml)
|
||||
run_folder = "/workspace/mounts"
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Click CLI definitions for various axolotl commands."""
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
import logging
|
||||
@@ -24,7 +25,7 @@ from axolotl.cli.utils import (
|
||||
)
|
||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig
|
||||
|
||||
|
||||
@click.group()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""CLI to run training on a model."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
@@ -34,7 +35,8 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
|
||||
"""
|
||||
print_axolotl_text_art()
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
check_user_token()
|
||||
|
||||
if cfg.rl:
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -5,7 +5,6 @@ import dataclasses
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import typing
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
from types import NoneType
|
||||
@@ -24,7 +23,7 @@ configure_logging()
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def strip_optional_type(field_type: type | typing._SpecialForm | None):
|
||||
def strip_optional_type(field_type: type | str | None):
|
||||
"""
|
||||
Extracts the non-`None` type from an `Optional` / `Union` type.
|
||||
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
"""Module containing File Reader, File Writer, Json Parser, and Jsonl Serializer classes"""
|
||||
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
ChatML transformation functions for MessageContents
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from ..messages import MessageContents, Messages
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Llama 3.x chat formatting functions for MessageContents
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from ..messages import MessageContents, Messages
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
shared functions for format transforms
|
||||
"""
|
||||
|
||||
from axolotl.core.chat.messages import MessageContents, Messages
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
internal message representations of chat messages
|
||||
"""
|
||||
|
||||
import json
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
chat dataset module
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
@@ -43,7 +44,7 @@ class TokenizedChatDataset(Dataset):
|
||||
process_or_cpu_count: int = (
|
||||
process_count or os.cpu_count() # type: ignore[assignment]
|
||||
)
|
||||
num_proc = min(64, process_or_cpu_count)
|
||||
num_proc = min(32, process_or_cpu_count)
|
||||
features = data.features.keys()
|
||||
tokenized_data = data.map(
|
||||
map_fn,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
This module contains a function that builds a transform that takes a row from the dataset and converts it to a Chat.
|
||||
"""
|
||||
|
||||
from typing import Any, Mapping, Union
|
||||
|
||||
|
||||
|
||||
@@ -13,9 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
"""
|
||||
Builder for the training args and trainer
|
||||
"""
|
||||
"""Builder for the training args and trainer"""
|
||||
|
||||
import abc
|
||||
import importlib
|
||||
@@ -85,8 +83,8 @@ from axolotl.utils.collators import (
|
||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||
)
|
||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||
from axolotl.utils.config.models.input.v0_4_1 import CustomSupportedOptimizers
|
||||
from axolotl.utils.models import ensure_dtype
|
||||
from axolotl.utils.schemas.enums import CustomSupportedOptimizers
|
||||
|
||||
try:
|
||||
import torch._dynamo # pylint: disable=ungrouped-imports
|
||||
@@ -332,9 +330,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs = {}
|
||||
|
||||
if self.cfg.include_tokens_per_second is not None:
|
||||
training_arguments_kwargs[
|
||||
"include_tokens_per_second"
|
||||
] = self.cfg.include_tokens_per_second
|
||||
training_arguments_kwargs["include_tokens_per_second"] = (
|
||||
self.cfg.include_tokens_per_second
|
||||
)
|
||||
|
||||
if self.cfg.bf16 == "full":
|
||||
training_arguments_kwargs["bf16_full_eval"] = True
|
||||
@@ -351,13 +349,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["seed"] = self.cfg.seed
|
||||
|
||||
if self.cfg.gradient_checkpointing:
|
||||
training_arguments_kwargs[
|
||||
"gradient_checkpointing"
|
||||
] = self.cfg.gradient_checkpointing
|
||||
training_arguments_kwargs["gradient_checkpointing"] = (
|
||||
self.cfg.gradient_checkpointing
|
||||
)
|
||||
if self.cfg.gradient_checkpointing_kwargs is not None:
|
||||
training_arguments_kwargs[
|
||||
"gradient_checkpointing_kwargs"
|
||||
] = self.cfg.gradient_checkpointing_kwargs
|
||||
training_arguments_kwargs["gradient_checkpointing_kwargs"] = (
|
||||
self.cfg.gradient_checkpointing_kwargs
|
||||
)
|
||||
if self.cfg.fsdp:
|
||||
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
||||
if self.cfg.fsdp_config:
|
||||
@@ -373,9 +371,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
|
||||
|
||||
if self.cfg.lr_quadratic_warmup is not None:
|
||||
training_arguments_kwargs[
|
||||
"lr_quadratic_warmup"
|
||||
] = self.cfg.lr_quadratic_warmup
|
||||
training_arguments_kwargs["lr_quadratic_warmup"] = (
|
||||
self.cfg.lr_quadratic_warmup
|
||||
)
|
||||
|
||||
if self.cfg.adam_beta1:
|
||||
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
|
||||
@@ -399,28 +397,28 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||
|
||||
if self.cfg.dataloader_pin_memory is not None:
|
||||
training_arguments_kwargs[
|
||||
"dataloader_pin_memory"
|
||||
] = self.cfg.dataloader_pin_memory
|
||||
training_arguments_kwargs["dataloader_pin_memory"] = (
|
||||
self.cfg.dataloader_pin_memory
|
||||
)
|
||||
if self.cfg.dataloader_num_workers is not None:
|
||||
training_arguments_kwargs[
|
||||
"dataloader_num_workers"
|
||||
] = self.cfg.dataloader_num_workers
|
||||
training_arguments_kwargs["dataloader_num_workers"] = (
|
||||
self.cfg.dataloader_num_workers
|
||||
)
|
||||
if self.cfg.dataloader_prefetch_factor is not None:
|
||||
training_arguments_kwargs[
|
||||
"dataloader_prefetch_factor"
|
||||
] = self.cfg.dataloader_prefetch_factor
|
||||
training_arguments_kwargs["dataloader_prefetch_factor"] = (
|
||||
self.cfg.dataloader_prefetch_factor
|
||||
)
|
||||
if self.cfg.dataloader_drop_last is not None:
|
||||
training_arguments_kwargs[
|
||||
"dataloader_drop_last"
|
||||
] = self.cfg.dataloader_drop_last
|
||||
training_arguments_kwargs["dataloader_drop_last"] = (
|
||||
self.cfg.dataloader_drop_last
|
||||
)
|
||||
elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
|
||||
training_arguments_kwargs["dataloader_drop_last"] = True
|
||||
|
||||
if self.cfg.remove_unused_columns is not None:
|
||||
training_arguments_kwargs[
|
||||
"remove_unused_columns"
|
||||
] = self.cfg.remove_unused_columns
|
||||
training_arguments_kwargs["remove_unused_columns"] = (
|
||||
self.cfg.remove_unused_columns
|
||||
)
|
||||
|
||||
if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
|
||||
# no eval set, so don't eval
|
||||
@@ -452,9 +450,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.do_causal_lm_eval:
|
||||
training_arguments_kwargs["do_causal_lm_eval"] = self.cfg.do_causal_lm_eval
|
||||
if self.cfg.metric_for_best_model:
|
||||
training_arguments_kwargs[
|
||||
"metric_for_best_model"
|
||||
] = self.cfg.metric_for_best_model
|
||||
training_arguments_kwargs["metric_for_best_model"] = (
|
||||
self.cfg.metric_for_best_model
|
||||
)
|
||||
if self.cfg.greater_is_better:
|
||||
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
|
||||
|
||||
@@ -467,13 +465,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
)
|
||||
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
|
||||
if self.cfg.torch_compile_backend:
|
||||
training_arguments_kwargs[
|
||||
"torch_compile_backend"
|
||||
] = self.cfg.torch_compile_backend
|
||||
training_arguments_kwargs["torch_compile_backend"] = (
|
||||
self.cfg.torch_compile_backend
|
||||
)
|
||||
if self.cfg.torch_compile_mode:
|
||||
training_arguments_kwargs[
|
||||
"torch_compile_mode"
|
||||
] = self.cfg.torch_compile_mode
|
||||
training_arguments_kwargs["torch_compile_mode"] = (
|
||||
self.cfg.torch_compile_mode
|
||||
)
|
||||
|
||||
# DDP Config
|
||||
if self.cfg.ddp_timeout:
|
||||
@@ -482,32 +480,32 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.ddp_bucket_cap_mb:
|
||||
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
|
||||
if self.cfg.ddp_broadcast_buffers is not None:
|
||||
training_arguments_kwargs[
|
||||
"ddp_broadcast_buffers"
|
||||
] = self.cfg.ddp_broadcast_buffers
|
||||
training_arguments_kwargs["ddp_broadcast_buffers"] = (
|
||||
self.cfg.ddp_broadcast_buffers
|
||||
)
|
||||
|
||||
# these are all the "standard" kwargs that are def used
|
||||
training_arguments_kwargs["max_steps"] = (
|
||||
total_num_steps if self.cfg.max_steps else -1
|
||||
)
|
||||
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
|
||||
training_arguments_kwargs[
|
||||
"per_device_train_batch_size"
|
||||
] = self.cfg.micro_batch_size
|
||||
training_arguments_kwargs["per_device_train_batch_size"] = (
|
||||
self.cfg.micro_batch_size
|
||||
)
|
||||
if self.cfg.eval_batch_size:
|
||||
training_arguments_kwargs[
|
||||
"per_device_eval_batch_size"
|
||||
] = self.cfg.eval_batch_size
|
||||
training_arguments_kwargs["per_device_eval_batch_size"] = (
|
||||
self.cfg.eval_batch_size
|
||||
)
|
||||
if self.cfg.auto_find_batch_size is not None:
|
||||
training_arguments_kwargs[
|
||||
"auto_find_batch_size"
|
||||
] = self.cfg.auto_find_batch_size
|
||||
training_arguments_kwargs[
|
||||
"gradient_accumulation_steps"
|
||||
] = self.cfg.gradient_accumulation_steps
|
||||
training_arguments_kwargs[
|
||||
"eval_accumulation_steps"
|
||||
] = self.cfg.gradient_accumulation_steps
|
||||
training_arguments_kwargs["auto_find_batch_size"] = (
|
||||
self.cfg.auto_find_batch_size
|
||||
)
|
||||
training_arguments_kwargs["gradient_accumulation_steps"] = (
|
||||
self.cfg.gradient_accumulation_steps
|
||||
)
|
||||
training_arguments_kwargs["eval_accumulation_steps"] = (
|
||||
self.cfg.gradient_accumulation_steps
|
||||
)
|
||||
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||
training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
|
||||
training_arguments_kwargs["output_dir"] = self.cfg.output_dir
|
||||
@@ -554,9 +552,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if self.cfg.lr_scheduler in ["one_cycle", "rex", "log_sweep"]:
|
||||
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
||||
training_arguments_kwargs[
|
||||
"alternate_lr_scheduler_type"
|
||||
] = self.cfg.lr_scheduler
|
||||
training_arguments_kwargs["alternate_lr_scheduler_type"] = (
|
||||
self.cfg.lr_scheduler
|
||||
)
|
||||
else:
|
||||
training_arguments_kwargs["lr_scheduler_type"] = (
|
||||
self.cfg.lr_scheduler if self.cfg.lr_scheduler else "cosine"
|
||||
@@ -565,9 +563,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
||||
)
|
||||
training_arguments_kwargs["cosine_min_lr_ratio"] = self.cfg.cosine_min_lr_ratio
|
||||
training_arguments_kwargs[
|
||||
"cosine_constant_lr_ratio"
|
||||
] = self.cfg.cosine_constant_lr_ratio
|
||||
training_arguments_kwargs["cosine_constant_lr_ratio"] = (
|
||||
self.cfg.cosine_constant_lr_ratio
|
||||
)
|
||||
training_arguments_kwargs["weight_decay"] = (
|
||||
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
||||
)
|
||||
@@ -580,40 +578,40 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.eval_sample_packing
|
||||
)
|
||||
if self.cfg.sample_packing_bin_size is not None:
|
||||
training_arguments_kwargs[
|
||||
"sample_packing_bin_size"
|
||||
] = self.cfg.sample_packing_bin_size
|
||||
training_arguments_kwargs["sample_packing_bin_size"] = (
|
||||
self.cfg.sample_packing_bin_size
|
||||
)
|
||||
if self.cfg.sample_packing_group_size is not None:
|
||||
training_arguments_kwargs[
|
||||
"sample_packing_group_size"
|
||||
] = self.cfg.sample_packing_group_size
|
||||
training_arguments_kwargs["sample_packing_group_size"] = (
|
||||
self.cfg.sample_packing_group_size
|
||||
)
|
||||
if self.cfg.sample_packing_eff_est:
|
||||
training_arguments_kwargs[
|
||||
"sample_packing_efficiency"
|
||||
] = self.cfg.sample_packing_eff_est
|
||||
training_arguments_kwargs["sample_packing_efficiency"] = (
|
||||
self.cfg.sample_packing_eff_est
|
||||
)
|
||||
|
||||
if self.cfg.relora_steps:
|
||||
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
||||
training_arguments_kwargs[
|
||||
"relora_warmup_steps"
|
||||
] = self.cfg.relora_warmup_steps
|
||||
training_arguments_kwargs["relora_warmup_steps"] = (
|
||||
self.cfg.relora_warmup_steps
|
||||
)
|
||||
if self.cfg.relora_anneal_steps:
|
||||
training_arguments_kwargs[
|
||||
"relora_anneal_steps"
|
||||
] = self.cfg.relora_anneal_steps
|
||||
training_arguments_kwargs["relora_anneal_steps"] = (
|
||||
self.cfg.relora_anneal_steps
|
||||
)
|
||||
if self.cfg.relora_prune_ratio:
|
||||
training_arguments_kwargs[
|
||||
"relora_prune_ratio"
|
||||
] = self.cfg.relora_prune_ratio
|
||||
training_arguments_kwargs["relora_prune_ratio"] = (
|
||||
self.cfg.relora_prune_ratio
|
||||
)
|
||||
|
||||
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
||||
training_arguments_kwargs["lisa_n_layers"] = self.cfg.lisa_n_layers
|
||||
training_arguments_kwargs[
|
||||
"lisa_step_interval"
|
||||
] = self.cfg.lisa_step_interval
|
||||
training_arguments_kwargs[
|
||||
"lisa_layers_attribute"
|
||||
] = self.cfg.lisa_layers_attribute
|
||||
training_arguments_kwargs["lisa_step_interval"] = (
|
||||
self.cfg.lisa_step_interval
|
||||
)
|
||||
training_arguments_kwargs["lisa_layers_attribute"] = (
|
||||
self.cfg.lisa_layers_attribute
|
||||
)
|
||||
|
||||
training_arguments_kwargs = self.hook_pre_create_training_args(
|
||||
training_arguments_kwargs
|
||||
@@ -627,9 +625,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
)
|
||||
|
||||
if self.cfg.neftune_noise_alpha is not None:
|
||||
training_arguments_kwargs[
|
||||
"neftune_noise_alpha"
|
||||
] = self.cfg.neftune_noise_alpha
|
||||
training_arguments_kwargs["neftune_noise_alpha"] = (
|
||||
self.cfg.neftune_noise_alpha
|
||||
)
|
||||
|
||||
trainer_kwargs = {}
|
||||
|
||||
@@ -731,23 +729,23 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
importlib.import_module("torchdistx")
|
||||
|
||||
if self.cfg.optim_target_modules:
|
||||
training_arguments_kwargs[
|
||||
"optim_target_modules"
|
||||
] = self.cfg.optim_target_modules
|
||||
training_arguments_kwargs["optim_target_modules"] = (
|
||||
self.cfg.optim_target_modules
|
||||
)
|
||||
|
||||
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
||||
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||
|
||||
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
||||
training_arguments_kwargs[
|
||||
"loraplus_lr_embedding"
|
||||
] = self.cfg.loraplus_lr_embedding
|
||||
training_arguments_kwargs["loraplus_lr_embedding"] = (
|
||||
self.cfg.loraplus_lr_embedding
|
||||
)
|
||||
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
|
||||
|
||||
if self.cfg.accelerator_config:
|
||||
training_arguments_kwargs[
|
||||
"accelerator_config"
|
||||
] = self.cfg.accelerator_config
|
||||
training_arguments_kwargs["accelerator_config"] = (
|
||||
self.cfg.accelerator_config
|
||||
)
|
||||
|
||||
if self.cfg.kd_ce_alpha is not None:
|
||||
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
||||
@@ -756,13 +754,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.kd_temperature is not None:
|
||||
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
|
||||
if self.cfg.kd_zscore_base_temp is not None:
|
||||
training_arguments_kwargs[
|
||||
"kd_zscore_base_temp"
|
||||
] = self.cfg.kd_zscore_base_temp
|
||||
training_arguments_kwargs["kd_zscore_base_temp"] = (
|
||||
self.cfg.kd_zscore_base_temp
|
||||
)
|
||||
if self.cfg.kd_top_k_before_softmax is not None:
|
||||
training_arguments_kwargs[
|
||||
"kd_top_k_before_softmax"
|
||||
] = self.cfg.kd_top_k_before_softmax
|
||||
training_arguments_kwargs["kd_top_k_before_softmax"] = (
|
||||
self.cfg.kd_top_k_before_softmax
|
||||
)
|
||||
|
||||
if self.cfg.reward_model:
|
||||
training_args_cls = AxolotlRewardConfig
|
||||
@@ -972,32 +970,32 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
||||
)
|
||||
if self.cfg.remove_unused_columns is not None:
|
||||
training_args_kwargs[
|
||||
"remove_unused_columns"
|
||||
] = self.cfg.remove_unused_columns
|
||||
training_args_kwargs["remove_unused_columns"] = (
|
||||
self.cfg.remove_unused_columns
|
||||
)
|
||||
else:
|
||||
training_args_kwargs["remove_unused_columns"] = False
|
||||
|
||||
if self.cfg.dataloader_pin_memory is not None:
|
||||
training_args_kwargs[
|
||||
"dataloader_pin_memory"
|
||||
] = self.cfg.dataloader_pin_memory
|
||||
training_args_kwargs["dataloader_pin_memory"] = (
|
||||
self.cfg.dataloader_pin_memory
|
||||
)
|
||||
if self.cfg.dataloader_num_workers is not None:
|
||||
training_args_kwargs[
|
||||
"dataloader_num_workers"
|
||||
] = self.cfg.dataloader_num_workers
|
||||
training_args_kwargs["dataloader_num_workers"] = (
|
||||
self.cfg.dataloader_num_workers
|
||||
)
|
||||
if self.cfg.dataloader_prefetch_factor is not None:
|
||||
training_args_kwargs[
|
||||
"dataloader_prefetch_factor"
|
||||
] = self.cfg.dataloader_prefetch_factor
|
||||
training_args_kwargs["dataloader_prefetch_factor"] = (
|
||||
self.cfg.dataloader_prefetch_factor
|
||||
)
|
||||
if self.cfg.gradient_checkpointing:
|
||||
training_args_kwargs[
|
||||
"gradient_checkpointing"
|
||||
] = self.cfg.gradient_checkpointing
|
||||
training_args_kwargs["gradient_checkpointing"] = (
|
||||
self.cfg.gradient_checkpointing
|
||||
)
|
||||
if self.cfg.gradient_checkpointing_kwargs is not None:
|
||||
training_args_kwargs[
|
||||
"gradient_checkpointing_kwargs"
|
||||
] = self.cfg.gradient_checkpointing_kwargs
|
||||
training_args_kwargs["gradient_checkpointing_kwargs"] = (
|
||||
self.cfg.gradient_checkpointing_kwargs
|
||||
)
|
||||
else:
|
||||
training_args_kwargs["gradient_checkpointing_kwargs"] = {
|
||||
"use_reentrant": False
|
||||
@@ -1071,9 +1069,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.dpo_use_weighting is not None:
|
||||
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
|
||||
if self.cfg.dpo_use_logits_to_keep is not None:
|
||||
training_args_kwargs[
|
||||
"use_logits_to_keep"
|
||||
] = self.cfg.dpo_use_logits_to_keep
|
||||
training_args_kwargs["use_logits_to_keep"] = (
|
||||
self.cfg.dpo_use_logits_to_keep
|
||||
)
|
||||
|
||||
for blocklist_key in blocklist_args_kwargs:
|
||||
if blocklist_key in training_args_kwargs:
|
||||
@@ -1108,9 +1106,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.adapter and self.peft_config:
|
||||
dpo_trainer_kwargs["peft_config"] = self.peft_config
|
||||
if self.cfg.precompute_ref_log_probs is not None:
|
||||
dpo_trainer_kwargs[
|
||||
"precompute_ref_log_probs"
|
||||
] = self.cfg.precompute_ref_log_probs
|
||||
dpo_trainer_kwargs["precompute_ref_log_probs"] = (
|
||||
self.cfg.precompute_ref_log_probs
|
||||
)
|
||||
if self.cfg.rl == "grpo":
|
||||
trainer_cls = GRPOStrategy.get_trainer_class()
|
||||
trainer_cls_args = [self.model]
|
||||
|
||||
@@ -462,9 +462,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params[
|
||||
"prefetch_factor"
|
||||
] = self.args.dataloader_prefetch_factor
|
||||
dataloader_params["prefetch_factor"] = (
|
||||
self.args.dataloader_prefetch_factor
|
||||
)
|
||||
|
||||
sampler = self._get_train_sampler()
|
||||
if isinstance(sampler, BatchSampler):
|
||||
@@ -509,9 +509,9 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, Trainer):
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params[
|
||||
"prefetch_factor"
|
||||
] = self.args.dataloader_prefetch_factor
|
||||
dataloader_params["prefetch_factor"] = (
|
||||
self.args.dataloader_prefetch_factor
|
||||
)
|
||||
|
||||
if isinstance(eval_sampler, BatchSampler):
|
||||
dataloader_params["batch_sampler"] = eval_sampler
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
DPO Specific Strategy for training
|
||||
"""
|
||||
|
||||
from axolotl.core.trainers.dpo.trainer import AxolotlDPOTrainer
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Axolotl specific DPO args
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from trl import DPOConfig
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
DPO trainer for axolotl
|
||||
"""
|
||||
|
||||
import gc
|
||||
from functools import wraps
|
||||
from typing import Any, Dict, Union
|
||||
|
||||
@@ -9,7 +9,7 @@ import logging
|
||||
from trl.trainer.grpo_trainer import RewardFunc
|
||||
|
||||
from axolotl.core.trainers.grpo.trainer import AxolotlGRPOTrainer
|
||||
from axolotl.utils.config.models.input.v0_4_1.trl import TRLConfig
|
||||
from axolotl.utils.schemas.trl import TRLConfig
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
@@ -45,9 +45,9 @@ class GRPOStrategy:
|
||||
)
|
||||
|
||||
if trl.vllm_gpu_memory_utilization:
|
||||
grpo_args_kwargs[
|
||||
"vllm_gpu_memory_utilization"
|
||||
] = trl.vllm_gpu_memory_utilization
|
||||
grpo_args_kwargs["vllm_gpu_memory_utilization"] = (
|
||||
trl.vllm_gpu_memory_utilization
|
||||
)
|
||||
|
||||
if trl.vllm_max_model_len:
|
||||
grpo_args_kwargs["vllm_max_model_len"] = trl.vllm_max_model_len
|
||||
@@ -86,9 +86,9 @@ class GRPOStrategy:
|
||||
def set_trainer_kwargs(cls, cfg):
|
||||
trainer_kwargs = {}
|
||||
if cfg.trl and cfg.trl.reward_processing_classes:
|
||||
trainer_kwargs[
|
||||
"reward_processing_classes"
|
||||
] = cfg.trl.reward_processing_classes
|
||||
trainer_kwargs["reward_processing_classes"] = (
|
||||
cfg.trl.reward_processing_classes
|
||||
)
|
||||
return trainer_kwargs
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Axolotl Specific Training Args
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from trl import GRPOConfig
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Axolotl GRPO trainer
|
||||
"""
|
||||
|
||||
from accelerate.utils import is_peft_model
|
||||
from accelerate.utils.other import is_compiled_module
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
module for TRL PPO training
|
||||
"""
|
||||
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
from trl import PPOTrainer
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
extra axolotl specific training args
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@@ -8,6 +8,8 @@ from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import Dataset
|
||||
from transformers.trainer import Trainer
|
||||
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.train import TrainDatasetMeta
|
||||
@@ -25,18 +27,18 @@ LOG = get_logger("axolotl.evaluate")
|
||||
|
||||
|
||||
def evaluate_dataset(
|
||||
trainer, dataset, dataset_type: str, flash_optimum: bool = False
|
||||
trainer: Trainer, dataset: Dataset, dataset_type: str, flash_optimum: bool = False
|
||||
) -> Optional[Dict[str, float]]:
|
||||
"""Helper function to evaluate a single dataset safely.
|
||||
"""Helper function to evaluate a single dataset.
|
||||
|
||||
Args:
|
||||
trainer: The trainer instance
|
||||
dataset: Dataset to evaluate
|
||||
dataset_type: Type of dataset ('train' or 'eval')
|
||||
flash_optimum: Whether to use flash optimum
|
||||
trainer: The trainer instance.
|
||||
dataset: Dataset to evaluate.
|
||||
dataset_type: Type of dataset ('train' or 'eval').
|
||||
flash_optimum: Whether to use flash optimum.
|
||||
|
||||
Returns:
|
||||
Dictionary of metrics or None if dataset is None
|
||||
Dictionary of metrics or None if dataset is None.
|
||||
"""
|
||||
if dataset is None:
|
||||
return None
|
||||
@@ -63,17 +65,14 @@ def evaluate_dataset(
|
||||
|
||||
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]:
|
||||
"""
|
||||
Evaluate a model on training and validation datasets
|
||||
Evaluate a model on training and validation datasets.
|
||||
|
||||
Args:
|
||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||
dataset_meta: Dataset metadata containing training and evaluation datasets.
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- The model (either PeftModel or PreTrainedModel)
|
||||
- The tokenizer
|
||||
- Dictionary of evaluation metrics
|
||||
Dictionary mapping metric names to their values.
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
|
||||
@@ -11,19 +11,17 @@
|
||||
# the License.
|
||||
|
||||
"""
|
||||
module to handle merging the plugins' input arguments with the base configurations.
|
||||
Module to handle merging the plugins' input arguments with the base configurations.
|
||||
|
||||
this was moved here to prevent circular imports
|
||||
This was moved here to prevent circular imports.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||
from axolotl.utils.schemas.config import (
|
||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||
)
|
||||
from axolotl.utils.config.models.input.v0_4_1 import (
|
||||
AxolotlInputConfig as AxolotlInputConfigBase,
|
||||
)
|
||||
from axolotl.utils.schemas.config import AxolotlInputConfig as AxolotlInputConfigBase
|
||||
|
||||
|
||||
def merge_input_args():
|
||||
|
||||
@@ -17,7 +17,7 @@ Run the following command to install `cut_cross_entropy[transformers]` if you do
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
|
||||
# if you are not in dev environment
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy @ git+https://github.com/apple/ml-cross-entropy.git@9c297c905f55b73594b5d650722d1e78183b77bd"'
|
||||
pip3 uninstall -y cut-cross-entropy && pip3 install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"
|
||||
```
|
||||
|
||||
## Usage
|
||||
|
||||
@@ -33,7 +33,7 @@ LOG = logging.getLogger("axolotl.integrations.cut_cross_entropy")
|
||||
|
||||
_CCE_INSTALL_MESSAGE = (
|
||||
"Please install cut_cross_entropy with transformers support using "
|
||||
'`pip install "cut-cross-entropy[transformers]==24.11.4"`'
|
||||
'`pip install "cut-cross-entropy[transformers] @ git+https://github.com/apple/ml-cross-entropy.git@24fbe4b5dab9a6c250a014573613c1890190536c"`'
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Grokfast plugin for Axolotl
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from transformers.trainer_callback import TrainerCallback
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
config args for grokfast plugin
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -26,12 +26,12 @@ class KDArgs(BaseModel):
|
||||
"""
|
||||
|
||||
kd_trainer: Optional[bool] = None # whether to use KD trainer
|
||||
kd_ce_alpha: Optional[
|
||||
float
|
||||
] = None # loss coefficient for cross-entropy loss during KD
|
||||
kd_ce_alpha: Optional[float] = (
|
||||
None # loss coefficient for cross-entropy loss during KD
|
||||
)
|
||||
kd_alpha: Optional[float] = None # loss coefficient for KD loss
|
||||
kd_temperature: Optional[float] = None # temperature for sampling during KD
|
||||
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
|
||||
kd_top_k_before_softmax: Optional[
|
||||
bool
|
||||
] = None # whether to sample top k before softmax during KD
|
||||
kd_top_k_before_softmax: Optional[bool] = (
|
||||
None # whether to sample top k before softmax during KD
|
||||
)
|
||||
|
||||
@@ -55,9 +55,9 @@ class LigerPlugin(BasePlugin):
|
||||
if "cross_entropy" in liger_fn_sig.parameters:
|
||||
kwargs["cross_entropy"] = cfg.liger_cross_entropy
|
||||
if "fused_linear_cross_entropy" in liger_fn_sig.parameters:
|
||||
kwargs[
|
||||
"fused_linear_cross_entropy"
|
||||
] = cfg.liger_fused_linear_cross_entropy
|
||||
kwargs["fused_linear_cross_entropy"] = (
|
||||
cfg.liger_fused_linear_cross_entropy
|
||||
)
|
||||
if "rms_norm" in liger_fn_sig.parameters:
|
||||
kwargs["rms_norm"] = cfg.liger_rms_norm
|
||||
if "layer_norm" in liger_fn_sig.parameters:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
DeepseekV2 model with LigerFusedLinearCrossEntropyLoss
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Jamba model with LigerFusedLinearCrossEntropyLoss
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Module for the Plugin for LM Eval Harness
|
||||
"""
|
||||
|
||||
import subprocess # nosec
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Module for handling lm eval harness input arguments.
|
||||
"""
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
axolotl CLI for running lm_eval tasks
|
||||
"""
|
||||
|
||||
import subprocess # nosec
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
|
||||
@@ -17,7 +17,7 @@ Module for handling Spectrum input arguments.
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, model_validator
|
||||
|
||||
|
||||
class SpectrumArgs(BaseModel):
|
||||
@@ -27,3 +27,20 @@ class SpectrumArgs(BaseModel):
|
||||
|
||||
spectrum_top_fraction: Optional[float] = 0.5
|
||||
spectrum_model_name: Optional[str] = None
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_fsdp_use_orig_params(cls, data):
|
||||
if (
|
||||
data.get("fsdp")
|
||||
and data.get("fsdp_config")
|
||||
and not data["fsdp_config"].get("use_orig_params")
|
||||
and data.get("plugins")
|
||||
and any("SpectrumPlugin" in plugin for plugin in data["plugins"])
|
||||
):
|
||||
# would otherwise raise
|
||||
# ValueError: Must flatten tensors with uniform `requires_grad` when `use_orig_params=False`
|
||||
raise ValueError(
|
||||
"FSDP + SpectrumPlugin cannot be used together when `use_orig_params=False` is set"
|
||||
)
|
||||
return data
|
||||
|
||||
@@ -5,6 +5,7 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
|
||||
|
||||
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name,unnecessary-lambda-assignment,duplicate-code
|
||||
|
||||
import torch
|
||||
|
||||
@@ -6,6 +6,7 @@ See "LoRA: Low-Rank Adaptation of Large Language Models"
|
||||
|
||||
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
|
||||
"""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
from typing import Callable
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Dequantization utilities for `bitsandbytes` integration."""
|
||||
|
||||
# pylint: disable=invalid-name,global-statement
|
||||
|
||||
import ctypes
|
||||
|
||||
@@ -5,6 +5,7 @@ See "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
|
||||
|
||||
Credit to `unsloth` (https://unsloth.ai/) for inspiration for this implementation.
|
||||
"""
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
HF Transformers MambaConfig
|
||||
"""
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Monkeypatch for Vision Llama for FA2 support
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from typing import Optional, Tuple
|
||||
@@ -220,10 +221,10 @@ def patch_mllama():
|
||||
True
|
||||
)
|
||||
MLLAMA_TEXT_ATTENTION_CLASSES["flash_attention_2"] = MllamaTextSelfFlashAttention2
|
||||
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES[
|
||||
"flash_attention_2"
|
||||
] = MllamaTextCrossFlashAttention2
|
||||
MLLAMA_TEXT_CROSS_ATTENTION_CLASSES["flash_attention_2"] = (
|
||||
MllamaTextCrossFlashAttention2
|
||||
)
|
||||
# fallback to SDPA
|
||||
MLLAMA_VISION_ATTENTION_CLASSES[
|
||||
"flash_attention_2"
|
||||
] = MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]
|
||||
MLLAMA_VISION_ATTENTION_CLASSES["flash_attention_2"] = (
|
||||
MLLAMA_VISION_ATTENTION_CLASSES["sdpa"]
|
||||
)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""monkey patches for the dataset fetcher to handle batches of packed indexes"""
|
||||
|
||||
# pylint: disable=protected-access
|
||||
|
||||
import torch
|
||||
|
||||
@@ -12,7 +12,9 @@ import transformers
|
||||
from einops import rearrange
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaAttention,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
|
||||
)
|
||||
@@ -490,9 +492,11 @@ def flashattn_forward(
|
||||
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||
# the attention_mask should be the same as the key_padding_mask
|
||||
key_padding_mask=attention_mask,
|
||||
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
||||
if attention_mask is not None
|
||||
else None,
|
||||
query_padding_mask=(
|
||||
attention_mask[:, -query_states.size(1) :]
|
||||
if attention_mask is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
output_unpad = flash_attn_varlen_qkvpacked_func(
|
||||
qkv_unpad,
|
||||
@@ -531,9 +535,11 @@ def flashattn_forward(
|
||||
value_states,
|
||||
kvpacked=True,
|
||||
key_padding_mask=attention_mask,
|
||||
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
||||
if attention_mask is not None
|
||||
else None,
|
||||
query_padding_mask=(
|
||||
attention_mask[:, -query_states.size(1) :]
|
||||
if attention_mask is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
if q_unpad.dtype != kv_unpad.dtype:
|
||||
kv_unpad = kv_unpad.to(q_unpad.dtype)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Flash attention monkey patch for mistral model"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import logging
|
||||
@@ -21,7 +22,10 @@ from transformers.models.mistral.modeling_mistral import (
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
|
||||
@@ -243,9 +247,11 @@ def flashattn_forward(
|
||||
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||
# the attention_mask should be the same as the key_padding_mask
|
||||
key_padding_mask=attention_mask,
|
||||
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
||||
if attention_mask is not None
|
||||
else None,
|
||||
query_padding_mask=(
|
||||
attention_mask[:, -query_states.size(1) :]
|
||||
if attention_mask is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
output_unpad = flash_attn_varlen_qkvpacked_func(
|
||||
qkv_unpad,
|
||||
@@ -286,9 +292,11 @@ def flashattn_forward(
|
||||
value_states,
|
||||
kvpacked=True,
|
||||
key_padding_mask=attention_mask,
|
||||
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
||||
if attention_mask is not None
|
||||
else None,
|
||||
query_padding_mask=(
|
||||
attention_mask[:, -query_states.size(1) :]
|
||||
if attention_mask is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
if q_unpad.dtype != kv_unpad.dtype:
|
||||
kv_unpad = kv_unpad.to(q_unpad.dtype)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Patches to support multipack for mixtral
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Implements the ReLoRA training procedure from https://arxiv.org/abs/2307.05695, minus the initial full fine-tune."""
|
||||
|
||||
import glob
|
||||
import json
|
||||
import logging
|
||||
@@ -411,7 +412,10 @@ def merge_and_save(
|
||||
if shard_path.endswith(".safetensors"):
|
||||
in_tensors = st.load_file(str(Path(model_src) / shard_path))
|
||||
else:
|
||||
in_tensors = torch.load(Path(model_src) / shard_path)
|
||||
in_tensors = torch.load(
|
||||
Path(model_src) / shard_path,
|
||||
weights_only=True, # to prevent arbitrary code execution
|
||||
)
|
||||
if "state_dict" in in_tensors:
|
||||
in_tensors = in_tensors["state_dict"]
|
||||
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
||||
# pylint: disable=duplicate-code
|
||||
""" PyTorch StableLM Epoch model. """
|
||||
"""PyTorch StableLM Epoch model."""
|
||||
import importlib
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
fix for FSDP optimizer save in trainer w 4.47.0
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Shared utils for the monkeypatches
|
||||
"""
|
||||
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Fused MLP layer for incrementally improved training efficiency
|
||||
"""
|
||||
|
||||
import torch
|
||||
from transformers.models.llama.modeling_llama import LlamaMLP
|
||||
from xformers.ops import SwiGLU
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Prompt strategies loader for alpaca instruction datasets with system prompts
|
||||
"""
|
||||
|
||||
from typing import Generator, Tuple, Union
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
|
||||
@@ -13,7 +13,7 @@ from axolotl.prompt_strategies.jinja_template_analyzer import JinjaTemplateAnaly
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID, Prompter
|
||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||
from axolotl.utils.config.models.input.v0_4_1 import DatasetConfig
|
||||
from axolotl.utils.schemas.datasets import DatasetConfig
|
||||
|
||||
# Configure the logger
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
Basic completion text
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from typing import Any, Dict, Generator, Optional, Tuple
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Module containing the classes for Context QA Prompt Tokenization Strategies"""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
"""
|
||||
module for DPO style dataset transform strategies
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
|
||||
from ..base import load as load_base
|
||||
|
||||
@@ -3,7 +3,7 @@ DPO prompt strategies for using tokenizer chat templates.
|
||||
"""
|
||||
|
||||
from axolotl.utils.chat_templates import extract_chat_template_args, get_chat_template
|
||||
from axolotl.utils.config.models.input.v0_4_1 import handle_legacy_message_fields_logic
|
||||
from axolotl.utils.schemas.utils import handle_legacy_message_fields_logic
|
||||
|
||||
|
||||
def default(
|
||||
|
||||
@@ -33,9 +33,9 @@ def default(
|
||||
f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["prompt"] = (
|
||||
f"<|im_start|>user\n{sample[prompt_key]}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
sample["chosen"] = f"{sample[chosen_key]}<|im_end|>"
|
||||
sample["rejected"] = f"{sample[rejected_key]}<|im_end|>"
|
||||
return sample
|
||||
@@ -52,9 +52,9 @@ def argilla_chat(
|
||||
"""
|
||||
|
||||
def transform_fn(sample):
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["prompt"] = (
|
||||
f"<|im_start|>user\n{sample['chosen'][0]['content']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
|
||||
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
|
||||
return sample
|
||||
@@ -78,9 +78,9 @@ def icr(
|
||||
f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["prompt"] = (
|
||||
f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
|
||||
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
|
||||
return sample
|
||||
@@ -100,9 +100,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
|
||||
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["prompt"] = (
|
||||
f"<|im_start|>user\n{sample['question']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
|
||||
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
|
||||
return sample
|
||||
@@ -120,9 +120,9 @@ def prompt_pairs(
|
||||
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["prompt"] = (
|
||||
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
sample["chosen"] = f"{sample['chosen']}<|im_end|>"
|
||||
sample["rejected"] = f"{sample['rejected']}<|im_end|>"
|
||||
return sample
|
||||
@@ -142,9 +142,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
|
||||
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
sample["prompt"] = (
|
||||
f"<|im_start|>user\n{sample['prompt']}<|im_end|>\n<|im_start|>assistant\n"
|
||||
)
|
||||
sample["chosen"] = f"{sample['chosen'][1]['content']}<|im_end|>"
|
||||
sample["rejected"] = f"{sample['rejected'][1]['content']}<|im_end|>"
|
||||
return sample
|
||||
|
||||
@@ -34,9 +34,9 @@ def default(
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
sample["prompt"] = (
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample[prompt_key]}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
sample["chosen"] = f"{sample[chosen_key]}<|eot_id|>"
|
||||
sample["rejected"] = f"{sample[rejected_key]}<|eot_id|>"
|
||||
return sample
|
||||
@@ -53,9 +53,9 @@ def argilla_chat(
|
||||
"""
|
||||
|
||||
def transform_fn(sample):
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
sample["prompt"] = (
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['chosen'][0]['content']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
|
||||
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
|
||||
return sample
|
||||
@@ -79,9 +79,9 @@ def icr(
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
sample["prompt"] = (
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['input']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
|
||||
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
|
||||
return sample
|
||||
@@ -101,9 +101,9 @@ def intel(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
sample["prompt"] = (
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['question']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
|
||||
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
|
||||
return sample
|
||||
@@ -121,9 +121,9 @@ def prompt_pairs(
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
sample["prompt"] = (
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
sample["chosen"] = f"{sample['chosen']}<|eot_id|>"
|
||||
sample["rejected"] = f"{sample['rejected']}<|eot_id|>"
|
||||
return sample
|
||||
@@ -143,9 +143,9 @@ def ultra(cfg, **kwargs): # pylint: disable=possibly-unused-variable,unused-arg
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
else:
|
||||
sample[
|
||||
"prompt"
|
||||
] = f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
sample["prompt"] = (
|
||||
f"<|start_header_id|>user<|end_header_id|>\n\n{sample['prompt']}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
)
|
||||
sample["chosen"] = f"{sample['chosen'][1]['content']}<|eot_id|>"
|
||||
sample["rejected"] = f"{sample['rejected'][1]['content']}<|eot_id|>"
|
||||
return sample
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Module for plain input/output prompt pairs"""
|
||||
|
||||
from typing import Generator, Tuple
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user