Compare commits
31 Commits
sp-restore
...
print_venv
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
454eea049f | ||
|
|
5a961ecadf | ||
|
|
b37ddf9778 | ||
|
|
bf38e507fb | ||
|
|
a5946ff1f0 | ||
|
|
d00bd99279 | ||
|
|
2b41bfe9eb | ||
|
|
5bbbd599b4 | ||
|
|
26c782183d | ||
|
|
70ca1b2291 | ||
|
|
8065fed126 | ||
|
|
8ae5a2311b | ||
|
|
6383630155 | ||
|
|
f2b352f2e5 | ||
|
|
bf5928d0ee | ||
|
|
d1224db8f4 | ||
|
|
327b4e48e9 | ||
|
|
35fdbce102 | ||
|
|
cb811f8bf1 | ||
|
|
7563e1bd30 | ||
|
|
81893c775c | ||
|
|
a1a740608d | ||
|
|
ec15a7a691 | ||
|
|
0a7a216b60 | ||
|
|
d8280d45c1 | ||
|
|
24f2887e87 | ||
|
|
29289a4de9 | ||
|
|
a24957fa04 | ||
|
|
927bf530bc | ||
|
|
18954ba100 | ||
|
|
d8cf66edbd |
6
.github/workflows/base.yml
vendored
6
.github/workflows/base.yml
vendored
@@ -5,11 +5,13 @@ on:
|
|||||||
branches:
|
branches:
|
||||||
- "main"
|
- "main"
|
||||||
paths:
|
paths:
|
||||||
- 'Dockerfile-base'
|
- 'docker/Dockerfile-base'
|
||||||
|
- 'docker/Dockerfile-uv-base'
|
||||||
- '.github/workflows/base.yml'
|
- '.github/workflows/base.yml'
|
||||||
pull_request:
|
pull_request:
|
||||||
paths:
|
paths:
|
||||||
- 'Dockerfile-base'
|
- 'docker/Dockerfile-base'
|
||||||
|
- 'docker/Dockerfile-uv-base'
|
||||||
- '.github/workflows/base.yml'
|
- '.github/workflows/base.yml'
|
||||||
workflow_dispatch:
|
workflow_dispatch:
|
||||||
|
|
||||||
|
|||||||
13
.github/workflows/main.yml
vendored
13
.github/workflows/main.yml
vendored
@@ -20,12 +20,11 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.5.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 124
|
- cuda: 126
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras: vllm
|
axolotl_extras: vllm
|
||||||
is_latest: true
|
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -88,8 +87,8 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.5.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 124
|
- cuda: 126
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
@@ -146,8 +145,8 @@ jobs:
|
|||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 124
|
- cuda: 126
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
|||||||
6
.github/workflows/multi-gpu-e2e.yml
vendored
6
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -26,11 +26,11 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 124
|
- cuda: 126
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras: vllm
|
axolotl_extras:
|
||||||
num_gpus: 2
|
num_gpus: 2
|
||||||
nightly_build: "true"
|
nightly_build: "true"
|
||||||
- cuda: 124
|
- cuda: 124
|
||||||
|
|||||||
115
.github/workflows/tests-nightly.yml
vendored
115
.github/workflows/tests-nightly.yml
vendored
@@ -18,96 +18,9 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
SKIP: no-commit-to-branch
|
SKIP: no-commit-to-branch
|
||||||
|
|
||||||
preload-cache:
|
|
||||||
name: Preload HF cache
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
python_version: ["3.11"]
|
|
||||||
pytorch_version: ["2.6.0"]
|
|
||||||
timeout-minutes: 20
|
|
||||||
|
|
||||||
env:
|
|
||||||
AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
|
|
||||||
|
|
||||||
steps:
|
|
||||||
- name: Check out repository code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Restore HF cache
|
|
||||||
id: hf-cache-restore
|
|
||||||
uses: actions/cache/restore@v4
|
|
||||||
with:
|
|
||||||
path: |
|
|
||||||
/home/runner/.cache/huggingface/hub/datasets--*
|
|
||||||
/home/runner/.cache/huggingface/hub/models--*
|
|
||||||
key: ${{ runner.os }}-hf-hub-cache-v2
|
|
||||||
|
|
||||||
- name: Setup Python
|
|
||||||
uses: actions/setup-python@v5
|
|
||||||
with:
|
|
||||||
python-version: ${{ matrix.python_version }}
|
|
||||||
cache: 'pip' # caching pip dependencies
|
|
||||||
|
|
||||||
- name: upgrade pip
|
|
||||||
run: |
|
|
||||||
pip3 install --upgrade pip
|
|
||||||
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
|
|
||||||
|
|
||||||
- name: Install PyTorch
|
|
||||||
run: |
|
|
||||||
pip3 install torch==${{ matrix.pytorch_version }}
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: |
|
|
||||||
pip3 show torch
|
|
||||||
pip3 install --no-build-isolation -U -e .
|
|
||||||
python scripts/unsloth_install.py | sh
|
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
|
||||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
|
||||||
|
|
||||||
- name: Make sure PyTorch version wasn't clobbered
|
|
||||||
run: |
|
|
||||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
|
||||||
|
|
||||||
- name: Ensure axolotl CLI was installed
|
|
||||||
run: |
|
|
||||||
axolotl --help
|
|
||||||
|
|
||||||
- name: Pre-Download dataset fixture
|
|
||||||
run: |
|
|
||||||
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
|
||||||
|
|
||||||
- name: Run tests
|
|
||||||
run: |
|
|
||||||
pytest -v tests/conftest.py
|
|
||||||
|
|
||||||
- name: Upload coverage to Codecov
|
|
||||||
uses: codecov/codecov-action@v5
|
|
||||||
with:
|
|
||||||
token: ${{ secrets.CODECOV_TOKEN }}
|
|
||||||
files: ./coverage.xml
|
|
||||||
flags: unittests,pytorch-${{ matrix.pytorch_version }}
|
|
||||||
fail_ci_if_error: false
|
|
||||||
|
|
||||||
- name: cleanup pip cache
|
|
||||||
run: |
|
|
||||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
|
||||||
|
|
||||||
- name: Save HF cache
|
|
||||||
id: hf-cache
|
|
||||||
uses: actions/cache/save@v4
|
|
||||||
with:
|
|
||||||
path: |
|
|
||||||
/home/runner/.cache/huggingface/hub/datasets--*
|
|
||||||
/home/runner/.cache/huggingface/hub/models--*
|
|
||||||
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
|
|
||||||
|
|
||||||
pytest:
|
pytest:
|
||||||
name: PyTest
|
name: PyTest
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
needs: [preload-cache]
|
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
max-parallel: 2
|
max-parallel: 2
|
||||||
@@ -120,14 +33,11 @@ jobs:
|
|||||||
- name: Check out repository code
|
- name: Check out repository code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Restore HF cache
|
- name: Restore Cache from S3
|
||||||
id: hf-cache-restore
|
id: hf-cache-restore-s3
|
||||||
uses: actions/cache/restore@v4
|
run: |
|
||||||
with:
|
mkdir -p /home/runner/.cache/huggingface/hub
|
||||||
path: |
|
curl -L https://d1dttdx32dkk5p.cloudfront.net/hf-cache.tar.zst | tar -xf - -C /home/runner/.cache/huggingface/hub/ --use-compress-program unzstd
|
||||||
/home/runner/.cache/huggingface/hub/datasets--*
|
|
||||||
/home/runner/.cache/huggingface/hub/models--*
|
|
||||||
key: ${{ runner.os }}-hf-hub-cache-v2
|
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
@@ -168,10 +78,6 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
axolotl --help
|
axolotl --help
|
||||||
|
|
||||||
- name: Pre-Download dataset fixture
|
|
||||||
run: |
|
|
||||||
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||||
@@ -193,15 +99,8 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 124
|
- cuda: 126
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
|
||||||
pytorch: 2.5.1
|
|
||||||
num_gpus: 1
|
|
||||||
axolotl_extras:
|
|
||||||
nightly_build: "true"
|
|
||||||
- cuda: 124
|
|
||||||
cuda_version: 12.4.1
|
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
|
|||||||
12
.github/workflows/tests.yml
vendored
12
.github/workflows/tests.yml
vendored
@@ -195,12 +195,12 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 124
|
- cuda: 126
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras: vllm
|
axolotl_extras:
|
||||||
- cuda: 126
|
- cuda: 126
|
||||||
cuda_version: 12.6.3
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
@@ -247,8 +247,8 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
include:
|
include:
|
||||||
- cuda: 124
|
- cuda: 126
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.6.3
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
@@ -311,7 +311,7 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
num_gpus: 1
|
num_gpus: 1
|
||||||
axolotl_extras: vllm
|
axolotl_extras:
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: isort
|
- id: isort
|
||||||
- repo: https://github.com/PyCQA/flake8
|
- repo: https://github.com/PyCQA/flake8
|
||||||
rev: 7.2.0
|
rev: 7.3.0
|
||||||
hooks:
|
hooks:
|
||||||
- id: flake8
|
- id: flake8
|
||||||
- repo: https://github.com/pylint-dev/pylint
|
- repo: https://github.com/pylint-dev/pylint
|
||||||
@@ -27,7 +27,7 @@ repos:
|
|||||||
hooks:
|
hooks:
|
||||||
- id: pylint
|
- id: pylint
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
rev: v1.16.0
|
rev: v1.16.1
|
||||||
hooks:
|
hooks:
|
||||||
- id: mypy
|
- id: mypy
|
||||||
additional_dependencies:
|
additional_dependencies:
|
||||||
@@ -36,7 +36,7 @@ repos:
|
|||||||
'pydantic>=2.5.3',
|
'pydantic>=2.5.3',
|
||||||
]
|
]
|
||||||
- repo: https://github.com/PyCQA/bandit
|
- repo: https://github.com/PyCQA/bandit
|
||||||
rev: 1.8.3
|
rev: 1.8.5
|
||||||
hooks:
|
hooks:
|
||||||
- id: bandit
|
- id: bandit
|
||||||
args: [
|
args: [
|
||||||
|
|||||||
11
README.md
11
README.md
@@ -43,7 +43,7 @@ Features:
|
|||||||
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models.
|
- **Multiple Model Support**: Train various models like LLaMA, Mistral, Mixtral, Pythia, and more. We are compatible with HuggingFace transformers causal language models.
|
||||||
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM).
|
- **Training Methods**: Full fine-tuning, LoRA, QLoRA, GPTQ, QAT, Preference Tuning (DPO, IPO, KTO, ORPO), RL (GRPO), Multimodal, and Reward Modelling (RM) / Process Reward Modelling (PRM).
|
||||||
- **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference.
|
- **Easy Configuration**: Re-use a single YAML file between dataset preprocess, training, evaluation, quantization, and inference.
|
||||||
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), Sequence Parallelism (SP), LoRA optimizations, Multi-GPU training (FSDP1, FSDP2, DeepSpeed), Multi-node training (Torchrun, Ray), and many more!
|
- **Performance Optimizations**: [Multipacking](https://docs.axolotl.ai/docs/multipack.html), [Flash Attention](https://github.com/Dao-AILab/flash-attention), [Xformers](https://github.com/facebookresearch/xformers), [Flex Attention](https://pytorch.org/blog/flexattention/), [Liger Kernel](https://github.com/linkedin/Liger-Kernel), [Cut Cross Entropy](https://github.com/apple/ml-cross-entropy/tree/main), [Sequence Parallelism (SP)](https://docs.axolotl.ai/docs/sequence_parallelism.html), [LoRA optimizations](https://docs.axolotl.ai/docs/lora_optims.html), [Multi-GPU training (FSDP1, FSDP2, DeepSpeed)](https://docs.axolotl.ai/docs/multi-gpu.html), [Multi-node training (Torchrun, Ray)](https://docs.axolotl.ai/docs/multi-node.html), and many more!
|
||||||
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
|
- **Flexible Dataset Handling**: Load from local, HuggingFace, and cloud (S3, Azure, GCP, OCI) datasets.
|
||||||
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
|
- **Cloud Ready**: We ship [Docker images](https://hub.docker.com/u/axolotlai) and also [PyPI packages](https://pypi.org/project/axolotl/) for use on cloud platforms and local hardware.
|
||||||
|
|
||||||
@@ -59,6 +59,8 @@ Features:
|
|||||||
|
|
||||||
### Installation
|
### Installation
|
||||||
|
|
||||||
|
#### Using pip
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
|
pip3 install -U packaging==23.2 setuptools==75.8.0 wheel ninja
|
||||||
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
pip3 install --no-build-isolation axolotl[flash-attn,deepspeed]
|
||||||
@@ -68,6 +70,13 @@ axolotl fetch examples
|
|||||||
axolotl fetch deepspeed_configs # OPTIONAL
|
axolotl fetch deepspeed_configs # OPTIONAL
|
||||||
```
|
```
|
||||||
|
|
||||||
|
#### Using Docker
|
||||||
|
|
||||||
|
Installing with Docker can be less error prone than installing in your own environment.
|
||||||
|
```bash
|
||||||
|
docker run --gpus '"all"' --rm -it axolotlai/axolotl:main-latest
|
||||||
|
```
|
||||||
|
|
||||||
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
|
Other installation approaches are described [here](https://docs.axolotl.ai/docs/installation.html).
|
||||||
|
|
||||||
### Your First Fine-tune
|
### Your First Fine-tune
|
||||||
|
|||||||
@@ -32,6 +32,8 @@ df_args = {
|
|||||||
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
||||||
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
"CODECOV_TOKEN": os.environ.get("CODECOV_TOKEN", ""),
|
||||||
"HF_HOME": "/workspace/data/huggingface-cache/hub",
|
"HF_HOME": "/workspace/data/huggingface-cache/hub",
|
||||||
|
"PYTHONUNBUFFERED": os.environ.get("PYTHONUNBUFFERED", "1"),
|
||||||
|
"DEEPSPEED_LOG_LEVEL": os.environ.get("DEEPSPEED_LOG_LEVEL", "WARNING"),
|
||||||
}
|
}
|
||||||
|
|
||||||
dockerfile_contents = df_template.render(**df_args)
|
dockerfile_contents = df_template.render(**df_args)
|
||||||
|
|||||||
@@ -22,9 +22,11 @@ RUN apt-get update \
|
|||||||
&& mkdir /root/.conda \
|
&& mkdir /root/.conda \
|
||||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
||||||
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
&& conda create -n "axolotl-py${PYTHON_VERSION}" python="${PYTHON_VERSION}" \
|
||||||
|
&& conda init bash \
|
||||||
|
&& echo "conda activate axolotl-py${PYTHON_VERSION}" >> ~/.bashrc
|
||||||
|
|
||||||
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
ENV PATH="/root/miniconda3/envs/axolotl-py${PYTHON_VERSION}/bin:${PATH}"
|
||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
@@ -38,6 +40,6 @@ RUN git lfs install --skip-repo && \
|
|||||||
# The base image ships with `pydantic==1.8.2` which is not working
|
# The base image ships with `pydantic==1.8.2` which is not working
|
||||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||||
|
|
||||||
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
|
RUN if [ "$PYTORCH_VERSION" = "2.6.0" ] && [ "$CUDA" = "124" ] ; then \
|
||||||
pip3 install flash-attn==2.7.4.post1; \
|
FLASH_ATTENTION_FORCE_BUILD="TRUE" pip3 install --no-build-isolation flash-attn==2.8.0.post2; \
|
||||||
fi
|
fi
|
||||||
|
|||||||
@@ -22,9 +22,11 @@ RUN apt-get update \
|
|||||||
&& mkdir /root/.conda \
|
&& mkdir /root/.conda \
|
||||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
||||||
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
&& conda create -n "axolotl-py${PYTHON_VERSION}" python="${PYTHON_VERSION}" \
|
||||||
|
&& conda init bash \
|
||||||
|
&& echo "conda activate axolotl-py${PYTHON_VERSION}" >> ~/.bashrc
|
||||||
|
|
||||||
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
ENV PATH="/root/miniconda3/envs/axolotl-py${PYTHON_VERSION}/bin:${PATH}"
|
||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
|
|||||||
@@ -22,9 +22,11 @@ RUN apt-get update \
|
|||||||
&& mkdir /root/.conda \
|
&& mkdir /root/.conda \
|
||||||
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
&& bash Miniconda3-latest-Linux-x86_64.sh -b \
|
||||||
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
&& rm -f Miniconda3-latest-Linux-x86_64.sh \
|
||||||
&& conda create -n "py${PYTHON_VERSION}" python="${PYTHON_VERSION}"
|
&& conda create -n "axolotl-py${PYTHON_VERSION}" python="${PYTHON_VERSION}" \
|
||||||
|
&& conda init bash \
|
||||||
|
&& echo "conda activate axolotl-py${PYTHON_VERSION}" >> ~/.bashrc
|
||||||
|
|
||||||
ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
ENV PATH="/root/miniconda3/envs/axolotl-py${PYTHON_VERSION}/bin:${PATH}"
|
||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
|
|||||||
@@ -34,7 +34,3 @@ RUN uv pip install packaging setuptools wheel psutil \
|
|||||||
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
|
&& uv pip install --no-build-isolation "causal_conv1d @ git+https://github.com/Dao-AILab/causal-conv1d.git@main" \
|
||||||
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
|
&& uv pip install "mamba_ssm @ git+https://github.com/state-spaces/mamba.git@main" \
|
||||||
&& uv pip install awscli pydantic
|
&& uv pip install awscli pydantic
|
||||||
|
|
||||||
RUN if [ "$PYTORCH_VERSION" = "2.7.1" ] ; then \
|
|
||||||
uv pip install --no-build-isolation flash-attn==2.7.4.post1; \
|
|
||||||
fi
|
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ order: 3
|
|||||||
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
|
Chat Template strategy uses a jinja2 template that converts a list of messages into a prompt. Support using tokenizer's template, a supported template, or custom jinja2.
|
||||||
|
|
||||||
```{.json filename="data.jsonl"}
|
```{.json filename="data.jsonl"}
|
||||||
{"conversations": [{"role": "...", "content": "..."}]}
|
{"messages": [{"role": "...", "content": "..."}, {"role": "...", "content": "..."}, ...]}
|
||||||
```
|
```
|
||||||
|
|
||||||
See [configs](../config-reference.qmd) for full configs and supported templates.
|
See [configs](../config-reference.qmd) for full configs and supported templates.
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ format:
|
|||||||
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
|
This section describes the different Docker images that are released by AxolotlAI at [Docker Hub](https://hub.docker.com/u/axolotlai).
|
||||||
|
|
||||||
::: {.callout-important}
|
::: {.callout-important}
|
||||||
For Blackwell GPUs, please use the tags with Pytorch 2.7.1 and CUDA 12.8.
|
For Blackwell GPUs, please use the tags with PyTorch 2.7.1 and CUDA 12.8.
|
||||||
:::
|
:::
|
||||||
|
|
||||||
## Base
|
## Base
|
||||||
@@ -34,6 +34,7 @@ Tags examples:
|
|||||||
|
|
||||||
- `main-base-py3.11-cu128-2.7.1`
|
- `main-base-py3.11-cu128-2.7.1`
|
||||||
- `main-base-py3.11-cu126-2.7.1`
|
- `main-base-py3.11-cu126-2.7.1`
|
||||||
|
- `main-base-py3.11-cu126-2.6.0`
|
||||||
- `main-base-py3.11-cu124-2.6.0`
|
- `main-base-py3.11-cu124-2.6.0`
|
||||||
- `main-base-py3.11-cu124-2.5.1`
|
- `main-base-py3.11-cu124-2.5.1`
|
||||||
|
|
||||||
@@ -73,13 +74,15 @@ There may be some extra tags appended to the image, like `-vllm` which installs
|
|||||||
|
|
||||||
Tags examples:
|
Tags examples:
|
||||||
|
|
||||||
- `main-py3.11-cu126-2.7.0`
|
- `main-py3.11-cu128-2.7.1`
|
||||||
|
- `main-py3.11-cu126-2.7.1`
|
||||||
|
- `main-py3.11-cu126-2.6.0`
|
||||||
- `main-py3.11-cu124-2.6.0`
|
- `main-py3.11-cu124-2.6.0`
|
||||||
- `main-py3.11-cu124-2.5.1`
|
- `main-py3.11-cu124-2.5.1`
|
||||||
- `main-latest`
|
- `main-latest`
|
||||||
- `main-20250303-py3.11-cu124-2.6.0`
|
- `main-20250303-py3.11-cu124-2.6.0`
|
||||||
- `main-20250303-py3.11-cu124-2.5.1`
|
- `main-20250303-py3.11-cu124-2.5.1`
|
||||||
- `0.9.2`
|
- `0.10.1`
|
||||||
|
|
||||||
## Cloud
|
## Cloud
|
||||||
|
|
||||||
|
|||||||
@@ -51,6 +51,10 @@ description: Frequently asked questions
|
|||||||
> pad_token: "..."
|
> pad_token: "..."
|
||||||
> ```
|
> ```
|
||||||
|
|
||||||
|
**Q: `IterableDataset error` or `KeyError: 'input_ids'` when using `preprocess` CLI**
|
||||||
|
|
||||||
|
> A: This is because you may be using `preprocess` CLI with `pretraining_dataset:` or `skip_prepare_dataset: true` respectively. Please use `axolotl train` CLI directly instead as these datasets are prepared on demand.
|
||||||
|
|
||||||
### Chat templates
|
### Chat templates
|
||||||
|
|
||||||
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
**Q: `jinja2.exceptions.UndefinedError: 'dict object' has no attribute 'content' / 'role' / ____`**
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,7 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
|
|
||||||
# START section of dependencies that don't install on Darwin/MacOS
|
# START section of dependencies that don't install on Darwin/MacOS
|
||||||
bitsandbytes==0.45.4
|
bitsandbytes==0.46.0
|
||||||
triton>=3.0.0
|
triton>=3.0.0
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
xformers>=0.0.23.post1
|
xformers>=0.0.23.post1
|
||||||
@@ -15,7 +15,7 @@ huggingface_hub==0.32.2
|
|||||||
peft==0.15.2
|
peft==0.15.2
|
||||||
transformers==4.52.4
|
transformers==4.52.4
|
||||||
tokenizers>=0.21.1
|
tokenizers>=0.21.1
|
||||||
accelerate==1.7.0
|
accelerate==1.8.1
|
||||||
datasets==3.6.0
|
datasets==3.6.0
|
||||||
deepspeed>=0.17.0
|
deepspeed>=0.17.0
|
||||||
trl==0.18.2
|
trl==0.18.2
|
||||||
@@ -68,4 +68,4 @@ schedulefree==1.4.1
|
|||||||
axolotl-contribs-lgpl==0.0.6
|
axolotl-contribs-lgpl==0.0.6
|
||||||
axolotl-contribs-mit==0.0.3
|
axolotl-contribs-mit==0.0.3
|
||||||
|
|
||||||
mistral-common==1.6.0
|
mistral-common==1.6.3
|
||||||
|
|||||||
4
setup.py
4
setup.py
@@ -111,9 +111,9 @@ def get_package_version():
|
|||||||
|
|
||||||
|
|
||||||
extras_require = {
|
extras_require = {
|
||||||
"flash-attn": ["flash-attn==2.7.4.post1"],
|
"flash-attn": ["flash-attn==2.8.0.post2"],
|
||||||
"ring-flash-attn": [
|
"ring-flash-attn": [
|
||||||
"flash-attn==2.7.4.post1",
|
"flash-attn==2.8.0.post2",
|
||||||
"ring-flash-attn>=0.1.4",
|
"ring-flash-attn>=0.1.4",
|
||||||
"yunchang==0.6.0",
|
"yunchang==0.6.0",
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -35,6 +35,12 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
|||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
check_user_token()
|
check_user_token()
|
||||||
|
|
||||||
|
for key in ["skip_prepare_dataset", "pretraining_dataset"]:
|
||||||
|
if cfg.get("key"):
|
||||||
|
raise ValueError(
|
||||||
|
f"You have set `{key}:`. `preprocess` is not needed. Run the `axolotl train` CLI directly instead."
|
||||||
|
)
|
||||||
|
|
||||||
if not cfg.dataset_prepared_path:
|
if not cfg.dataset_prepared_path:
|
||||||
msg = (
|
msg = (
|
||||||
Fore.RED
|
Fore.RED
|
||||||
|
|||||||
@@ -75,13 +75,17 @@ def load_datasets(
|
|||||||
|
|
||||||
num_examples = cli_args.debug_num_examples if cli_args else 1
|
num_examples = cli_args.debug_num_examples if cli_args else 1
|
||||||
text_only = cli_args.debug_text_only if cli_args else False
|
text_only = cli_args.debug_text_only if cli_args else False
|
||||||
train_samples = sample_dataset(train_dataset, num_examples)
|
try:
|
||||||
check_dataset_labels(
|
train_samples = sample_dataset(train_dataset, num_examples)
|
||||||
train_samples,
|
check_dataset_labels(
|
||||||
tokenizer,
|
train_samples,
|
||||||
num_examples=num_examples,
|
tokenizer,
|
||||||
text_only=text_only,
|
num_examples=num_examples,
|
||||||
)
|
text_only=text_only,
|
||||||
|
)
|
||||||
|
except AttributeError:
|
||||||
|
# can't sample iterable datasets
|
||||||
|
pass
|
||||||
|
|
||||||
LOG.info("printing prompters...")
|
LOG.info("printing prompters...")
|
||||||
for prompter in prompters:
|
for prompter in prompters:
|
||||||
|
|||||||
@@ -219,7 +219,9 @@ class TrainerBuilderBase(abc.ABC):
|
|||||||
if self.cfg.bf16 == "full":
|
if self.cfg.bf16 == "full":
|
||||||
training_args_kwargs["bf16_full_eval"] = True
|
training_args_kwargs["bf16_full_eval"] = True
|
||||||
else:
|
else:
|
||||||
training_args_kwargs["bf16"] = self.cfg.bf16 or self.cfg.bfloat16
|
bf16 = self.cfg.bf16 or self.cfg.bfloat16
|
||||||
|
bf16 = bf16 if bf16 is not None else False
|
||||||
|
training_args_kwargs["bf16"] = bf16
|
||||||
|
|
||||||
def _configure_scheduler(self, training_args_kwargs: dict):
|
def _configure_scheduler(self, training_args_kwargs: dict):
|
||||||
if self.cfg.lr_scheduler in ["one_cycle", "rex"]:
|
if self.cfg.lr_scheduler in ["one_cycle", "rex"]:
|
||||||
|
|||||||
@@ -253,6 +253,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["eval_sample_packing"] = bool(
|
training_arguments_kwargs["eval_sample_packing"] = bool(
|
||||||
self.cfg.eval_sample_packing
|
self.cfg.eval_sample_packing
|
||||||
)
|
)
|
||||||
|
if self.cfg.sample_packing_sequentially is not None:
|
||||||
|
training_arguments_kwargs["sample_packing_sequentially"] = (
|
||||||
|
self.cfg.sample_packing_sequentially
|
||||||
|
)
|
||||||
if self.cfg.sample_packing_bin_size is not None:
|
if self.cfg.sample_packing_bin_size is not None:
|
||||||
training_arguments_kwargs["sample_packing_bin_size"] = (
|
training_arguments_kwargs["sample_packing_bin_size"] = (
|
||||||
self.cfg.sample_packing_bin_size
|
self.cfg.sample_packing_bin_size
|
||||||
@@ -413,7 +417,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
or self.cfg.micro_batch_size > 1
|
or self.cfg.micro_batch_size > 1
|
||||||
):
|
):
|
||||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||||
return None
|
if not (self.cfg.sample_packing and self.cfg.pretrain_multipack_attn):
|
||||||
|
return None
|
||||||
|
|
||||||
if self.cfg.model_config_type == "mamba":
|
if self.cfg.model_config_type == "mamba":
|
||||||
return MambaDataCollator(tokenizer=self.tokenizer)
|
return MambaDataCollator(tokenizer=self.tokenizer)
|
||||||
|
|||||||
@@ -116,6 +116,7 @@ class AxolotlTrainer(
|
|||||||
sequential=self.args.sample_packing_sequentially,
|
sequential=self.args.sample_packing_sequentially,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
num_processes=self.args.dataset_num_proc,
|
num_processes=self.args.dataset_num_proc,
|
||||||
|
mp_start_method=self.args.sample_packing_mp_start_method or "fork",
|
||||||
)
|
)
|
||||||
|
|
||||||
len(sampler)
|
len(sampler)
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class DPOStrategy:
|
|||||||
training_args_kwargs["max_completion_length"] = None
|
training_args_kwargs["max_completion_length"] = None
|
||||||
training_args_kwargs["max_length"] = cfg.sequence_len
|
training_args_kwargs["max_length"] = cfg.sequence_len
|
||||||
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
training_args_kwargs["max_prompt_length"] = cfg.sequence_len
|
||||||
training_args_kwargs["generate_during_eval"] = cfg.use_wandb
|
training_args_kwargs["generate_during_eval"] = cfg.dpo_generate_during_eval
|
||||||
if cfg.dpo_use_weighting is not None:
|
if cfg.dpo_use_weighting is not None:
|
||||||
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
training_args_kwargs["use_weighting"] = cfg.dpo_use_weighting
|
||||||
if cfg.dpo_padding_free is not None:
|
if cfg.dpo_padding_free is not None:
|
||||||
|
|||||||
@@ -38,6 +38,10 @@ class AxolotlTrainingMixins:
|
|||||||
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
sample_packing_mp_start_method: str | None = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The multiprocessing start method to use."},
|
||||||
|
)
|
||||||
multipack_real_batches: bool = field(
|
multipack_real_batches: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Use real batches for efficient training."},
|
metadata={"help": "Use real batches for efficient training."},
|
||||||
|
|||||||
@@ -776,6 +776,9 @@ class ModelLoader:
|
|||||||
dist_dtype: torch.dtype,
|
dist_dtype: torch.dtype,
|
||||||
before_kbit_train_or_finetune: bool,
|
before_kbit_train_or_finetune: bool,
|
||||||
):
|
):
|
||||||
|
dest = {"dtype": dist_dtype}
|
||||||
|
if self.cfg.lora_on_cpu:
|
||||||
|
dest["device"] = "cpu"
|
||||||
for name, module in self.model.named_modules():
|
for name, module in self.model.named_modules():
|
||||||
if "norm" in name:
|
if "norm" in name:
|
||||||
module.to(dist_dtype)
|
module.to(dist_dtype)
|
||||||
@@ -786,4 +789,4 @@ class ModelLoader:
|
|||||||
# don't upcast lm_head for btlm
|
# don't upcast lm_head for btlm
|
||||||
continue
|
continue
|
||||||
if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
|
if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
|
||||||
module.to(dist_dtype)
|
module.to(**dest)
|
||||||
|
|||||||
@@ -65,6 +65,7 @@ class PatchManager:
|
|||||||
self._apply_mistral_cross_entropy_patch()
|
self._apply_mistral_cross_entropy_patch()
|
||||||
self._apply_self_attention_lora_patch()
|
self._apply_self_attention_lora_patch()
|
||||||
self._apply_gemma3_conditional_generation_forward_patch()
|
self._apply_gemma3_conditional_generation_forward_patch()
|
||||||
|
self._apply_sequence_parallel_patches()
|
||||||
|
|
||||||
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
def apply_post_model_load_patches(self, model: PreTrainedModel):
|
||||||
"""Apply patches that require the model instance."""
|
"""Apply patches that require the model instance."""
|
||||||
@@ -231,6 +232,17 @@ class PatchManager:
|
|||||||
|
|
||||||
patch_gemma3_conditional_generation_forward()
|
patch_gemma3_conditional_generation_forward()
|
||||||
|
|
||||||
|
def _apply_sequence_parallel_patches(self):
|
||||||
|
"""Apply sequence parallelism patches."""
|
||||||
|
if self.cfg.sequence_parallel_degree and self.cfg.sequence_parallel_degree > 1:
|
||||||
|
from axolotl.monkeypatch.ring_attn.patch import (
|
||||||
|
patch_prepare_data_loader,
|
||||||
|
patch_prepare_device_mesh,
|
||||||
|
)
|
||||||
|
|
||||||
|
patch_prepare_data_loader()
|
||||||
|
patch_prepare_device_mesh(self.cfg.sequence_parallel_degree, self.cfg.fsdp)
|
||||||
|
|
||||||
def _patch_attention(self):
|
def _patch_attention(self):
|
||||||
"""Apply attention-specific patches based on model type."""
|
"""Apply attention-specific patches based on model type."""
|
||||||
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
|
if not (self.cfg.flash_attention and hasattr(self.model_config, "model_type")):
|
||||||
|
|||||||
@@ -42,6 +42,10 @@ def patch_for_multipack(model_type, model_name=None, has_remote_code=False):
|
|||||||
if has_remote_code:
|
if has_remote_code:
|
||||||
patch_remote(model_name)
|
patch_remote(model_name)
|
||||||
elif hasattr(transformers, "modeling_flash_attention_utils"):
|
elif hasattr(transformers, "modeling_flash_attention_utils"):
|
||||||
|
# sanity check in case upstream api changes on this
|
||||||
|
assert hasattr(
|
||||||
|
transformers.modeling_flash_attention_utils, "_get_unpad_data"
|
||||||
|
), "transformers api changed for _get_unpad_data for flash attention"
|
||||||
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
transformers.modeling_flash_attention_utils._get_unpad_data = ( # pylint: disable=protected-access
|
||||||
get_unpad_data
|
get_unpad_data
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -152,7 +152,7 @@ def update_ring_attn_params(position_ids: torch.Tensor | None):
|
|||||||
def patch_prepare_data_loader():
|
def patch_prepare_data_loader():
|
||||||
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
|
"""Patch `accelerate.data_loader.prepare_data_loader` to respect the SP degree.
|
||||||
|
|
||||||
Raies:
|
Raises:
|
||||||
RuntimeError: If source code to patch does not exist.
|
RuntimeError: If source code to patch does not exist.
|
||||||
"""
|
"""
|
||||||
original_fn = accelerate.data_loader.prepare_data_loader
|
original_fn = accelerate.data_loader.prepare_data_loader
|
||||||
@@ -168,23 +168,34 @@ def patch_prepare_data_loader():
|
|||||||
ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
|
ORIGINAL_PREPARE_DATALOADER_CODE, NEW_PREPARE_DATALOADER_CODE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
items_to_import = []
|
||||||
|
for item in dir(accelerate.data_loader):
|
||||||
|
if item in patched_source:
|
||||||
|
items_to_import.append(item)
|
||||||
|
|
||||||
# Create a new function from the patched source
|
# Create a new function from the patched source
|
||||||
namespace = {}
|
namespace = {}
|
||||||
exec( # pylint: disable=exec-used # nosec B102
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
patched_source, accelerate.data_loader.__dict__, namespace
|
f"from accelerate.data_loader import ({', '.join(items_to_import)})",
|
||||||
|
globals(),
|
||||||
|
)
|
||||||
|
exec( # pylint: disable=exec-used # nosec B102
|
||||||
|
patched_source, globals(), namespace
|
||||||
)
|
)
|
||||||
patched_function = namespace["prepare_data_loader"]
|
|
||||||
|
|
||||||
accelerate.data_loader.prepare_data_loader = patched_function
|
patched_function = namespace["prepare_data_loader"]
|
||||||
|
original_fn.__code__ = patched_function.__code__
|
||||||
|
|
||||||
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
|
LOG.info("Patched accelerate.data_loader.prepare_data_loader for SP support")
|
||||||
|
|
||||||
|
|
||||||
def patch_prepare_device_mesh(sequence_parallel_degree: int):
|
def patch_prepare_device_mesh(sequence_parallel_degree: int, fsdp: bool = False):
|
||||||
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
|
"""Patches the `Accelerator._prepare_device_mesh` method to create a device mesh
|
||||||
that includes sequence parallelism with the specified degree.
|
that includes sequence parallelism with the specified degree.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
sequence_parallel_degree (int): The degree of sequence parallelism to use.
|
sequence_parallel_degree: The degree of sequence parallelism to use.
|
||||||
|
fsdp: Whether to use FSDP.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _prepare_device_mesh(self):
|
def _prepare_device_mesh(self):
|
||||||
@@ -207,12 +218,14 @@ def patch_prepare_device_mesh(sequence_parallel_degree: int):
|
|||||||
)
|
)
|
||||||
device_ids = list(range(world_size))
|
device_ids = list(range(world_size))
|
||||||
|
|
||||||
# Note that we use "cp" instead of "sp" to match the PyTorch native "context
|
# NOTE: We use "cp" instead of "sp" to match the PyTorch native "context
|
||||||
# parallelism" implementation naming
|
# parallelism" implementation naming.
|
||||||
|
# NOTE: We have a simplified FSDP handling here; i.e., if FSDP is enabled, we
|
||||||
|
# only use "fsdp" and "cp" for the device mesh.
|
||||||
return dist.DeviceMesh(
|
return dist.DeviceMesh(
|
||||||
"cuda",
|
"cuda",
|
||||||
torch.tensor(device_ids).reshape(mesh_shape),
|
torch.tensor(device_ids).reshape(mesh_shape),
|
||||||
mesh_dim_names=("dp", "cp"),
|
mesh_dim_names=("dp", "cp") if not fsdp else ("fsdp", "cp"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Replace the original method with our new method
|
# Replace the original method with our new method
|
||||||
|
|||||||
@@ -103,6 +103,7 @@ class ChatTemplatePrompter(Prompter):
|
|||||||
chat_template_kwargs = {
|
chat_template_kwargs = {
|
||||||
"chat_template": self.chat_template,
|
"chat_template": self.chat_template,
|
||||||
"add_generation_prompt": add_generation_prompt,
|
"add_generation_prompt": add_generation_prompt,
|
||||||
|
**self.chat_template_kwargs,
|
||||||
}
|
}
|
||||||
|
|
||||||
if tools:
|
if tools:
|
||||||
|
|||||||
@@ -223,6 +223,9 @@ def execute_training(
|
|||||||
)
|
)
|
||||||
|
|
||||||
LOG.info("Starting trainer...")
|
LOG.info("Starting trainer...")
|
||||||
|
# TODO: disabling for now as not compatible with FSDP2 + torchao low bit optimizers
|
||||||
|
# if cfg.bf16:
|
||||||
|
# torch.set_default_dtype(torch.bfloat16)
|
||||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -12,8 +12,6 @@ from transformers.utils import ModelOutput
|
|||||||
|
|
||||||
from axolotl.monkeypatch.ring_attn import (
|
from axolotl.monkeypatch.ring_attn import (
|
||||||
get_ring_attn_group,
|
get_ring_attn_group,
|
||||||
patch_prepare_data_loader,
|
|
||||||
patch_prepare_device_mesh,
|
|
||||||
register_ring_attn,
|
register_ring_attn,
|
||||||
update_ring_attn_params,
|
update_ring_attn_params,
|
||||||
)
|
)
|
||||||
@@ -207,9 +205,6 @@ class SequenceParallelContextManager:
|
|||||||
# Store original sequence length and padding information
|
# Store original sequence length and padding information
|
||||||
self.original_seq_len = 0
|
self.original_seq_len = 0
|
||||||
self.pad_len = 0
|
self.pad_len = 0
|
||||||
|
|
||||||
# Store kwargs passed to model forward pass
|
|
||||||
self.original_kwargs: None | dict[str, torch.Tensor] = None
|
|
||||||
|
|
||||||
# Create a partially applied version of the apply_sequence_parallelism function
|
# Create a partially applied version of the apply_sequence_parallelism function
|
||||||
self.apply_sequence_parallelism = functools.partial(
|
self.apply_sequence_parallelism = functools.partial(
|
||||||
@@ -241,12 +236,6 @@ class SequenceParallelContextManager:
|
|||||||
ring_attn_func=self.ring_attn_func,
|
ring_attn_func=self.ring_attn_func,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Patches for accelerate functionality
|
|
||||||
patch_prepare_data_loader()
|
|
||||||
patch_prepare_device_mesh(
|
|
||||||
sequence_parallel_degree=self.sequence_parallel_degree
|
|
||||||
)
|
|
||||||
|
|
||||||
def _register_model_hooks(self):
|
def _register_model_hooks(self):
|
||||||
# Forward pre-hook to apply sequence parallelism
|
# Forward pre-hook to apply sequence parallelism
|
||||||
def sequence_parallel_pre_hook(_, args, kwargs):
|
def sequence_parallel_pre_hook(_, args, kwargs):
|
||||||
@@ -262,9 +251,6 @@ class SequenceParallelContextManager:
|
|||||||
|
|
||||||
# Any excess positional arguments are kept as-is
|
# Any excess positional arguments are kept as-is
|
||||||
remaining_args = args[len(forward_params) :]
|
remaining_args = args[len(forward_params) :]
|
||||||
|
|
||||||
# Store original kwargs
|
|
||||||
self.original_kwargs = {key: value.clone() for key, value in updated_kwargs.items()}
|
|
||||||
|
|
||||||
# Apply sequence parallelism to updated kwargs
|
# Apply sequence parallelism to updated kwargs
|
||||||
updated_kwargs, self.original_seq_len, self.pad_len = (
|
updated_kwargs, self.original_seq_len, self.pad_len = (
|
||||||
|
|||||||
@@ -224,10 +224,10 @@ def wrap_pretraining_dataset(
|
|||||||
remove_columns = []
|
remove_columns = []
|
||||||
if dataset.features is None:
|
if dataset.features is None:
|
||||||
for first_row in dataset:
|
for first_row in dataset:
|
||||||
remove_columns = first_row.keys()
|
remove_columns = list(first_row.keys())
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
remove_columns = dataset.features.keys()
|
remove_columns = list(dataset.features.keys())
|
||||||
|
|
||||||
dataset = dataset.map(
|
dataset = dataset.map(
|
||||||
encode,
|
encode,
|
||||||
@@ -267,6 +267,7 @@ def encode_packed_pretraining(
|
|||||||
batch_size=1,
|
batch_size=1,
|
||||||
batch_max_len=batch_size * max_seq_length,
|
batch_max_len=batch_size * max_seq_length,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
|
num_processes=1,
|
||||||
)
|
)
|
||||||
|
|
||||||
chunked_data = defaultdict(list)
|
chunked_data = defaultdict(list)
|
||||||
|
|||||||
@@ -334,7 +334,10 @@ def _load_raw_datasets(
|
|||||||
dataset = merge_datasets(datasets, cfg)
|
dataset = merge_datasets(datasets, cfg)
|
||||||
|
|
||||||
if not cfg.skip_prepare_dataset:
|
if not cfg.skip_prepare_dataset:
|
||||||
dataset = drop_long_seq_in_dataset(dataset, cfg)
|
if split == "test" and cfg.eval_sequence_len:
|
||||||
|
dataset = drop_long_seq_in_dataset(dataset, cfg.eval_sequence_len, cfg)
|
||||||
|
else:
|
||||||
|
dataset = drop_long_seq_in_dataset(dataset, cfg.sequence_len, cfg)
|
||||||
if cfg.sample_packing:
|
if cfg.sample_packing:
|
||||||
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
dataset, _ = process_datasets_for_packing(cfg, dataset, None)
|
||||||
|
|
||||||
|
|||||||
@@ -524,13 +524,25 @@ def merge_datasets(datasets: list[Dataset], cfg: DictDefault) -> Dataset:
|
|||||||
Merged dataset.
|
Merged dataset.
|
||||||
"""
|
"""
|
||||||
if len(datasets) == 1:
|
if len(datasets) == 1:
|
||||||
return datasets[0]
|
ds = datasets[0]
|
||||||
|
|
||||||
|
# Do not shuffle if curriculum sampling is enabled or
|
||||||
|
# shuffle_merged_datasets is disabled
|
||||||
|
if cfg.curriculum_sampling or not cfg.shuffle_merged_datasets:
|
||||||
|
return ds
|
||||||
|
|
||||||
|
return ds.shuffle(seed=cfg.seed)
|
||||||
|
|
||||||
LOG.info("Merging datasets...")
|
LOG.info("Merging datasets...")
|
||||||
merged_dataset = concatenate_datasets(datasets)
|
merged_dataset = concatenate_datasets(datasets)
|
||||||
|
|
||||||
if cfg.shuffle_merged_datasets:
|
if cfg.shuffle_merged_datasets:
|
||||||
LOG.debug("Shuffling merged datasets...")
|
LOG.debug("Shuffling merged datasets...")
|
||||||
|
if cfg.curriculum_sampling:
|
||||||
|
LOG.warning(
|
||||||
|
"Shuffling merged datasets with curriculum sampling is not recommended. "
|
||||||
|
"This will randomize the order of samples."
|
||||||
|
)
|
||||||
merged_dataset = merged_dataset.shuffle(seed=cfg.seed)
|
merged_dataset = merged_dataset.shuffle(seed=cfg.seed)
|
||||||
else:
|
else:
|
||||||
LOG.debug("Not shuffling merged datasets.")
|
LOG.debug("Not shuffling merged datasets.")
|
||||||
|
|||||||
@@ -148,11 +148,14 @@ def deduplicate_and_log_datasets(
|
|||||||
return dataset, other_dataset
|
return dataset, other_dataset
|
||||||
|
|
||||||
|
|
||||||
def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
|
def drop_long_seq_in_dataset(
|
||||||
|
dataset: Dataset, sequence_len: int, cfg: DictDefault
|
||||||
|
) -> Dataset:
|
||||||
"""Remove sequences longer than configured maximum from dataset.
|
"""Remove sequences longer than configured maximum from dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset: Dataset to filter.
|
dataset: Dataset to filter.
|
||||||
|
sequence_len: Maximum length for sequences to keep
|
||||||
cfg: Dictionary mapping `axolotl` config keys to values.
|
cfg: Dictionary mapping `axolotl` config keys to values.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -167,7 +170,7 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
|
|||||||
|
|
||||||
drop_long = functools.partial(
|
drop_long = functools.partial(
|
||||||
drop_long_seq,
|
drop_long_seq,
|
||||||
sequence_len=cfg.sequence_len,
|
sequence_len=sequence_len,
|
||||||
min_sequence_len=cfg.min_sample_len,
|
min_sequence_len=cfg.min_sample_len,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -187,7 +190,7 @@ def drop_long_seq_in_dataset(dataset: Dataset, cfg: DictDefault) -> Dataset:
|
|||||||
|
|
||||||
drop_long_kwargs = {}
|
drop_long_kwargs = {}
|
||||||
if filter_map_kwargs:
|
if filter_map_kwargs:
|
||||||
drop_long_kwargs["desc"] = "Dropping Long Sequences"
|
drop_long_kwargs["desc"] = f"Dropping Long Sequences (>{sequence_len})"
|
||||||
|
|
||||||
dataset = dataset.filter(
|
dataset = dataset.filter(
|
||||||
drop_long,
|
drop_long,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from typing import TYPE_CHECKING, Optional
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
from mistral_common.tokens.tokenizers.mistral import MistralTokenizer
|
||||||
from mistral_common.tokens.tokenizers.tekken import Tekkenizer
|
from mistral_common.tokens.tokenizers.tekken import SpecialTokenPolicy, Tekkenizer
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
@@ -251,10 +251,13 @@ class HFMistralTokenizer:
|
|||||||
token_ids = [token_ids]
|
token_ids = [token_ids]
|
||||||
|
|
||||||
if skip_special_tokens:
|
if skip_special_tokens:
|
||||||
return self._mistral.instruct_tokenizer.tokenizer.decode(token_ids)
|
return self._mistral.instruct_tokenizer.tokenizer.decode(
|
||||||
|
token_ids, special_token_policy=SpecialTokenPolicy.IGNORE
|
||||||
|
)
|
||||||
|
|
||||||
# to_string returns a string with special tokens
|
return self._mistral.instruct_tokenizer.tokenizer.decode(
|
||||||
return self._mistral.instruct_tokenizer.tokenizer.to_string(token_ids)
|
token_ids, special_token_policy=SpecialTokenPolicy.KEEP
|
||||||
|
)
|
||||||
|
|
||||||
def _create_mistral_chat_completion_request(
|
def _create_mistral_chat_completion_request(
|
||||||
self, conversation: list[dict], tools: list[dict] | None = None
|
self, conversation: list[dict], tools: list[dict] | None = None
|
||||||
|
|||||||
@@ -127,7 +127,7 @@ def pack_parallel(
|
|||||||
bin_size: int,
|
bin_size: int,
|
||||||
num_processes: int | None = None,
|
num_processes: int | None = None,
|
||||||
safe_mode: bool = True,
|
safe_mode: bool = True,
|
||||||
mp_start_method: str | None = "spawn",
|
mp_start_method: str | None = "fork",
|
||||||
) -> list[list[int]]:
|
) -> list[list[int]]:
|
||||||
"""Pack sequences into bins using parallel processing.
|
"""Pack sequences into bins using parallel processing.
|
||||||
|
|
||||||
@@ -260,12 +260,13 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
lengths: np.ndarray, # Sequence lengths
|
lengths: np.ndarray, # Sequence lengths
|
||||||
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
|
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
|
||||||
drop_last: bool = True, # Whether to drop final batches (might be incomplete)
|
drop_last: bool = True, # Whether to drop final batches (might be incomplete)
|
||||||
num_count_samples: int = 8, # Number of times to estimate batch count
|
num_count_samples: int = 4, # Number of times to estimate batch count
|
||||||
sequential: bool = False, # Whether to use sequential packing
|
sequential: bool = False, # Whether to use sequential packing
|
||||||
group_size: int = 100_000, # Size of groups for parallel packing
|
group_size: int = 100_000, # Size of groups for parallel packing
|
||||||
bin_size: int = 200, # The max number of samples that can be packed in a single bin
|
bin_size: int = 200, # The max number of samples that can be packed in a single bin
|
||||||
num_processes: int | None = None, # Number of processes for parallel packing
|
num_processes: int | None = None, # Number of processes for parallel packing
|
||||||
safe_mode: bool = True, # Conservative packing to prevent training instability
|
safe_mode: bool = True, # Conservative packing to prevent training instability
|
||||||
|
mp_start_method: str = "fork",
|
||||||
**kwargs, # pylint: disable=unused-argument
|
**kwargs, # pylint: disable=unused-argument
|
||||||
):
|
):
|
||||||
super().__init__(sampler, batch_size, drop_last)
|
super().__init__(sampler, batch_size, drop_last)
|
||||||
@@ -278,6 +279,7 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
self.bin_size = bin_size
|
self.bin_size = bin_size
|
||||||
self.num_processes = num_processes
|
self.num_processes = num_processes
|
||||||
self.safe_mode = safe_mode
|
self.safe_mode = safe_mode
|
||||||
|
self.mp_start_method = mp_start_method
|
||||||
|
|
||||||
assert isinstance(self.lengths, np.ndarray)
|
assert isinstance(self.lengths, np.ndarray)
|
||||||
|
|
||||||
@@ -333,13 +335,15 @@ class MultipackBatchSampler(BatchSampler):
|
|||||||
bins = [[indices[b_idx] for b_idx in bin_indices] for bin_indices in bins]
|
bins = [[indices[b_idx] for b_idx in bin_indices] for bin_indices in bins]
|
||||||
else:
|
else:
|
||||||
# Use parallel packing
|
# Use parallel packing
|
||||||
|
num_processes = self.num_processes or 1
|
||||||
all_bins = pack_parallel(
|
all_bins = pack_parallel(
|
||||||
lengths,
|
lengths,
|
||||||
bin_capacity=self.batch_max_len,
|
bin_capacity=self.batch_max_len,
|
||||||
group_size=self.group_size,
|
group_size=self.group_size,
|
||||||
bin_size=self.bin_size,
|
bin_size=self.bin_size,
|
||||||
num_processes=self.num_processes,
|
num_processes=min(4, num_processes) if num_processes else 4,
|
||||||
safe_mode=self.safe_mode,
|
safe_mode=self.safe_mode,
|
||||||
|
mp_start_method=self.mp_start_method,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Map bin indices back to original indices
|
# Map bin indices back to original indices
|
||||||
|
|||||||
@@ -146,6 +146,7 @@ class AxolotlInputConfig(
|
|||||||
dpo_label_smoothing: float | None = None
|
dpo_label_smoothing: float | None = None
|
||||||
dpo_norm_loss: bool | None = None
|
dpo_norm_loss: bool | None = None
|
||||||
dpo_padding_free: bool | None = None
|
dpo_padding_free: bool | None = None
|
||||||
|
dpo_generate_during_eval: bool | None = None
|
||||||
|
|
||||||
datasets: (
|
datasets: (
|
||||||
Annotated[
|
Annotated[
|
||||||
@@ -366,6 +367,12 @@ class AxolotlInputConfig(
|
|||||||
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
|
"description": "The maximum length of an input to train with, this should typically be less than 2048 as most models have a token/context limit of 2048"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
eval_sequence_len: int | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "The maximum length of an input for evaluation. If not specified, defaults to sequence_len"
|
||||||
|
},
|
||||||
|
)
|
||||||
min_sample_len: int | None = None
|
min_sample_len: int | None = None
|
||||||
max_prompt_len: int = Field(
|
max_prompt_len: int = Field(
|
||||||
default=512,
|
default=512,
|
||||||
@@ -393,6 +400,12 @@ class AxolotlInputConfig(
|
|||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={"description": "Whether to pack samples sequentially"},
|
json_schema_extra={"description": "Whether to pack samples sequentially"},
|
||||||
)
|
)
|
||||||
|
sample_packing_mp_start_method: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "The multiprocessing start method to use for packing. Should be 'fork', 'spawn' or 'forkserver'"
|
||||||
|
},
|
||||||
|
)
|
||||||
eval_sample_packing: bool | None = Field(
|
eval_sample_packing: bool | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
@@ -772,6 +785,12 @@ class AxolotlInputConfig(
|
|||||||
"description": "Custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null."
|
"description": "Custom jinja template for chat template. This will be only used if chat_template is set to `jinja` or `null` (in which case chat_template is automatically set to `jinja`). Default is null."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
chat_template_kwargs: dict[str, Any] | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "Additional kwargs to pass to the chat template. This is useful for customizing the chat template. For example, you can pass `thinking=False` to add a generation prompt to the chat template."
|
||||||
|
},
|
||||||
|
)
|
||||||
eot_tokens: list[str] | None = Field(
|
eot_tokens: list[str] | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
@@ -462,6 +462,20 @@ class TrainingValidationMixin:
|
|||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
|
@classmethod
|
||||||
|
def pretrain_with_tps(cls, data):
|
||||||
|
if data.get("pretraining_dataset") and data.get(
|
||||||
|
"include_tokens_per_second", False
|
||||||
|
):
|
||||||
|
# combining these would raise `TypeError: cannot pickle 'dict_keys' object`
|
||||||
|
# due to trying to count the number of tokens total in the dataset
|
||||||
|
raise ValueError(
|
||||||
|
"pretraining_dataset and include_tokens_per_second cannot be used together."
|
||||||
|
)
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
class LoRAValidationMixin:
|
class LoRAValidationMixin:
|
||||||
"""Validation methods related to LoRA/QLoRA configuration."""
|
"""Validation methods related to LoRA/QLoRA configuration."""
|
||||||
|
|||||||
@@ -381,6 +381,7 @@ def process_pretraining_datasets_for_packing(
|
|||||||
if not skip_position_ids:
|
if not skip_position_ids:
|
||||||
train_dataset = train_dataset.map(
|
train_dataset = train_dataset.map(
|
||||||
add_position_ids,
|
add_position_ids,
|
||||||
|
batched=True,
|
||||||
desc="Add position_id column (Pretraining Sample Packing)",
|
desc="Add position_id column (Pretraining Sample Packing)",
|
||||||
)
|
)
|
||||||
if drop_attention_mask:
|
if drop_attention_mask:
|
||||||
@@ -467,6 +468,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
sequential=cfg.sample_packing_sequentially,
|
sequential=cfg.sample_packing_sequentially,
|
||||||
drop_last=True,
|
drop_last=True,
|
||||||
num_processes=cfg.dataset_processes,
|
num_processes=cfg.dataset_processes,
|
||||||
|
mp_start_method=cfg.sample_packing_mp_start_method or "fork",
|
||||||
)
|
)
|
||||||
|
|
||||||
data_loader = DataLoader(
|
data_loader = DataLoader(
|
||||||
@@ -607,6 +609,9 @@ def prepare_opinionated_env(cfg):
|
|||||||
if cfg.qlora_sharded_model_loading:
|
if cfg.qlora_sharded_model_loading:
|
||||||
# model loading is forked after the tokenizer
|
# model loading is forked after the tokenizer
|
||||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
if cfg.sample_packing:
|
||||||
|
# multipack parallel packing sampler defaults to using fork
|
||||||
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||||
|
|
||||||
|
|
||||||
def setup_trainer(
|
def setup_trainer(
|
||||||
|
|||||||
@@ -4,12 +4,14 @@ shared pytest fixtures
|
|||||||
|
|
||||||
import functools
|
import functools
|
||||||
import importlib
|
import importlib
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import shutil
|
import shutil
|
||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path, PosixPath
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
import pytest
|
import pytest
|
||||||
@@ -24,6 +26,8 @@ from tests.hf_offline_utils import (
|
|||||||
hf_offline_context,
|
hf_offline_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logging.getLogger("filelock").setLevel(logging.CRITICAL)
|
||||||
|
|
||||||
|
|
||||||
def retry_on_request_exceptions(max_retries=3, delay=1):
|
def retry_on_request_exceptions(max_retries=3, delay=1):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
@@ -411,7 +415,7 @@ def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def temp_dir():
|
def temp_dir() -> Generator[str, None, None]:
|
||||||
# Create a temporary directory
|
# Create a temporary directory
|
||||||
_temp_dir = tempfile.mkdtemp()
|
_temp_dir = tempfile.mkdtemp()
|
||||||
yield _temp_dir
|
yield _temp_dir
|
||||||
@@ -419,6 +423,11 @@ def temp_dir():
|
|||||||
shutil.rmtree(_temp_dir)
|
shutil.rmtree(_temp_dir)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="function", autouse=True)
|
||||||
|
def unique_triton_cache_dir(temp_dir: str | PosixPath) -> None:
|
||||||
|
os.environ["TRITON_CACHE_DIR"] = str(temp_dir) + "/.triton/cache"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="function", autouse=True)
|
@pytest.fixture(scope="function", autouse=True)
|
||||||
def cleanup_monkeypatches():
|
def cleanup_monkeypatches():
|
||||||
from transformers import Trainer
|
from transformers import Trainer
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ class TestSequenceParallelism:
|
|||||||
"micro_batch_size": micro_batch_size,
|
"micro_batch_size": micro_batch_size,
|
||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ class TestPackedFlex:
|
|||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
"gradient_checkpointing": True,
|
"gradient_checkpointing": True,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
|
|||||||
@@ -309,6 +309,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
"warmup_steps": 10,
|
"warmup_steps": 10,
|
||||||
"val_set_size": 0.0,
|
"val_set_size": 0.0,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.0001,
|
"learning_rate": 0.0001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
@@ -400,6 +401,7 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
|||||||
"warmup_steps": 10,
|
"warmup_steps": 10,
|
||||||
"val_set_size": 0.0,
|
"val_set_size": 0.0,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.0001,
|
"learning_rate": 0.0001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
|
|||||||
@@ -38,12 +38,13 @@ class TestMultiGPUEval:
|
|||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
||||||
"val_set_size": 0.004,
|
"val_set_size": 0.05,
|
||||||
"special_tokens": {"pad_token": "<|endoftext|>"},
|
"special_tokens": {"pad_token": "<|endoftext|>"},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "teknium/GPT4-LLM-Cleaned",
|
"path": "teknium/GPT4-LLM-Cleaned",
|
||||||
"type": "alpaca",
|
"type": "alpaca",
|
||||||
|
"split": "train[:5%]",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
@@ -51,6 +52,7 @@ class TestMultiGPUEval:
|
|||||||
"micro_batch_size": 2,
|
"micro_batch_size": 2,
|
||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
@@ -107,12 +109,13 @@ class TestMultiGPUEval:
|
|||||||
"lora_dropout": 0.05,
|
"lora_dropout": 0.05,
|
||||||
"lora_target_linear": True,
|
"lora_target_linear": True,
|
||||||
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
||||||
"val_set_size": 0.0004,
|
"val_set_size": 0.01,
|
||||||
"special_tokens": {"pad_token": "<|endoftext|>"},
|
"special_tokens": {"pad_token": "<|endoftext|>"},
|
||||||
"datasets": [
|
"datasets": [
|
||||||
{
|
{
|
||||||
"path": "teknium/GPT4-LLM-Cleaned",
|
"path": "teknium/GPT4-LLM-Cleaned",
|
||||||
"type": "alpaca",
|
"type": "alpaca",
|
||||||
|
"split": "train[:5%]",
|
||||||
},
|
},
|
||||||
],
|
],
|
||||||
"num_epochs": 1,
|
"num_epochs": 1,
|
||||||
@@ -120,6 +123,7 @@ class TestMultiGPUEval:
|
|||||||
"micro_batch_size": 2,
|
"micro_batch_size": 2,
|
||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
|
|||||||
@@ -64,6 +64,7 @@ class TestMultiGPUGemma3:
|
|||||||
},
|
},
|
||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.0001,
|
"learning_rate": 0.0001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ class TestMultiGPULlama:
|
|||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
# "gradient_checkpointing": True,
|
# "gradient_checkpointing": True,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
@@ -127,6 +128,7 @@ class TestMultiGPULlama:
|
|||||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||||
# "gradient_checkpointing": True,
|
# "gradient_checkpointing": True,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
@@ -200,6 +202,7 @@ class TestMultiGPULlama:
|
|||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
# "gradient_checkpointing": True,
|
# "gradient_checkpointing": True,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"warmup_steps": 0,
|
"warmup_steps": 0,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
@@ -278,6 +281,7 @@ class TestMultiGPULlama:
|
|||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
# "gradient_checkpointing": True,
|
# "gradient_checkpointing": True,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"warmup_steps": 0,
|
"warmup_steps": 0,
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
@@ -340,6 +344,7 @@ class TestMultiGPULlama:
|
|||||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||||
# "gradient_checkpointing": True,
|
# "gradient_checkpointing": True,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
@@ -412,6 +417,7 @@ class TestMultiGPULlama:
|
|||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
# "gradient_checkpointing": True,
|
# "gradient_checkpointing": True,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
@@ -491,6 +497,7 @@ class TestMultiGPULlama:
|
|||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
"gradient_checkpointing": True,
|
"gradient_checkpointing": True,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_8bit",
|
"optimizer": "adamw_torch_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
@@ -573,6 +580,7 @@ class TestMultiGPULlama:
|
|||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
# "gradient_checkpointing": True,
|
# "gradient_checkpointing": True,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
@@ -669,6 +677,7 @@ class TestMultiGPULlama:
|
|||||||
"micro_batch_size": 1,
|
"micro_batch_size": 1,
|
||||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
@@ -743,6 +752,7 @@ class TestMultiGPULlama:
|
|||||||
"micro_batch_size": 1,
|
"micro_batch_size": 1,
|
||||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
@@ -817,6 +827,7 @@ class TestMultiGPULlama:
|
|||||||
"micro_batch_size": 1,
|
"micro_batch_size": 1,
|
||||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ class TestMultiGPUQwen2:
|
|||||||
"micro_batch_size": 2,
|
"micro_batch_size": 2,
|
||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch_fused",
|
"optimizer": "adamw_torch_fused",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ class TestMultiGPURay:
|
|||||||
"micro_batch_size": 4,
|
"micro_batch_size": 4,
|
||||||
"gradient_accumulation_steps": 2,
|
"gradient_accumulation_steps": 2,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_8bit",
|
"optimizer": "adamw_8bit",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
@@ -107,6 +108,7 @@ class TestMultiGPURay:
|
|||||||
"micro_batch_size": 1,
|
"micro_batch_size": 1,
|
||||||
"gradient_accumulation_steps": gradient_accumulation_steps,
|
"gradient_accumulation_steps": gradient_accumulation_steps,
|
||||||
"output_dir": temp_dir,
|
"output_dir": temp_dir,
|
||||||
|
"dataset_prepared_path": temp_dir + "/last_run_prepared",
|
||||||
"learning_rate": 0.00001,
|
"learning_rate": 0.00001,
|
||||||
"optimizer": "adamw_torch",
|
"optimizer": "adamw_torch",
|
||||||
"lr_scheduler": "cosine",
|
"lr_scheduler": "cosine",
|
||||||
|
|||||||
@@ -396,7 +396,7 @@ def test_model_architecture(model_config):
|
|||||||
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
def test_kernel_training_integration():
|
def test_kernel_training_integration(temp_dir):
|
||||||
"""Test model loading with kernel patches enabled."""
|
"""Test model loading with kernel patches enabled."""
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
|
|
||||||
@@ -426,6 +426,14 @@ def test_kernel_training_integration():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Write cfg to yaml file
|
||||||
|
path = Path(temp_dir) / "config.yaml"
|
||||||
|
with open(path, "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
# Load config
|
||||||
|
cfg = load_cfg(str(path))
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
model, _, _ = load_model_and_tokenizer(cfg=cfg)
|
model, _, _ = load_model_and_tokenizer(cfg=cfg)
|
||||||
|
|
||||||
@@ -505,7 +513,7 @@ def test_kernel_training_integration_auto_enable(temp_dir):
|
|||||||
assert found_patched_attn
|
assert found_patched_attn
|
||||||
|
|
||||||
|
|
||||||
def test_kernel_training_integration_dropout_non_zero():
|
def test_kernel_training_integration_dropout_non_zero(temp_dir):
|
||||||
"""Test model loading with dropout non-zero should not patch."""
|
"""Test model loading with dropout non-zero should not patch."""
|
||||||
|
|
||||||
from axolotl.cli.utils import load_model_and_tokenizer
|
from axolotl.cli.utils import load_model_and_tokenizer
|
||||||
@@ -533,6 +541,14 @@ def test_kernel_training_integration_dropout_non_zero():
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Write cfg to yaml file
|
||||||
|
path = Path(temp_dir) / "config.yaml"
|
||||||
|
with open(path, "w", encoding="utf-8") as fout:
|
||||||
|
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
|
||||||
|
|
||||||
|
# Load config
|
||||||
|
cfg = load_cfg(str(path))
|
||||||
|
|
||||||
# Get original attention class
|
# Get original attention class
|
||||||
attention_cls = get_attention_cls_from_config(cfg)
|
attention_cls = get_attention_cls_from_config(cfg)
|
||||||
|
|
||||||
|
|||||||
@@ -70,7 +70,7 @@ class TestBatchedSamplerPacking:
|
|||||||
)
|
)
|
||||||
train_dataset = concatenate_datasets([dataset_wrapper])
|
train_dataset = concatenate_datasets([dataset_wrapper])
|
||||||
|
|
||||||
train_dataset = drop_long_seq_in_dataset(train_dataset, cfg)
|
train_dataset = drop_long_seq_in_dataset(train_dataset, cfg.sequence_len, cfg)
|
||||||
|
|
||||||
lengths = get_dataset_lengths(train_dataset)
|
lengths = get_dataset_lengths(train_dataset)
|
||||||
batch_sampler = MultipackBatchSampler(
|
batch_sampler = MultipackBatchSampler(
|
||||||
|
|||||||
Reference in New Issue
Block a user