Compare commits
14 Commits
cli-cloud-
...
fix-merge-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
385736fae1 | ||
|
|
f89e962119 | ||
|
|
bc1c9c20e3 | ||
|
|
dd26cc3c0f | ||
|
|
d8b4027200 | ||
|
|
fb3352e21c | ||
|
|
ed77e7001e | ||
|
|
7669a03fb4 | ||
|
|
6553683170 | ||
|
|
5e0124e2ab | ||
|
|
2e8d7c1adb | ||
|
|
3c1921e400 | ||
|
|
7faf2b6e8e | ||
|
|
c1b920f291 |
1
.github/workflows/lint.yml
vendored
1
.github/workflows/lint.yml
vendored
@@ -1,6 +1,7 @@
|
||||
name: lint
|
||||
on:
|
||||
# check on PRs, and manual triggers
|
||||
merge_group:
|
||||
pull_request:
|
||||
paths:
|
||||
- '**.py'
|
||||
|
||||
4
.github/workflows/main.yml
vendored
4
.github/workflows/main.yml
vendored
@@ -25,7 +25,6 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.3.1
|
||||
axolotl_extras: mamba-ssm
|
||||
is_latest: true
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
@@ -36,6 +35,7 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -92,7 +92,6 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.3.1
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
- cuda: 124
|
||||
cuda_version: 12.4.1
|
||||
python_version: "3.11"
|
||||
@@ -103,6 +102,7 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.5.1
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
runs-on: axolotl-gpu-runner
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
2
.github/workflows/multi-gpu-e2e.yml
vendored
2
.github/workflows/multi-gpu-e2e.yml
vendored
@@ -52,7 +52,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==0.63.64 jinja2
|
||||
pip install modal==0.71.8 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
|
||||
2
.github/workflows/tests-nightly.yml
vendored
2
.github/workflows/tests-nightly.yml
vendored
@@ -129,7 +129,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==0.63.64 jinja2
|
||||
pip install modal==0.71.8 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
|
||||
41
.github/workflows/tests.yml
vendored
41
.github/workflows/tests.yml
vendored
@@ -1,6 +1,7 @@
|
||||
name: Tests
|
||||
on:
|
||||
# check on push/merge to main, PRs, and manual triggers
|
||||
merge_group:
|
||||
push:
|
||||
branches:
|
||||
- "main"
|
||||
@@ -60,6 +61,15 @@ jobs:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Restore HF cache
|
||||
id: hf-cache-restore
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
@@ -100,6 +110,15 @@ jobs:
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
- name: Save HF cache
|
||||
id: hf-cache
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
|
||||
|
||||
pytest-sdist:
|
||||
name: PyTest from Source Dist
|
||||
runs-on: ubuntu-latest
|
||||
@@ -115,6 +134,15 @@ jobs:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Restore HF cache
|
||||
id: hf-cache-restore
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
@@ -156,6 +184,15 @@ jobs:
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
- name: Save HF cache
|
||||
id: hf-cache
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
|
||||
|
||||
docker-e2e-tests-1st:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
@@ -183,7 +220,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==0.63.64 jinja2
|
||||
pip install modal==0.71.8 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
@@ -229,7 +266,7 @@ jobs:
|
||||
- name: Install Modal
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install modal==0.63.64 jinja2
|
||||
pip install modal==0.71.8 jinja2
|
||||
- name: Update env vars
|
||||
run: |
|
||||
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
|
||||
|
||||
@@ -23,7 +23,7 @@ repos:
|
||||
hooks:
|
||||
- id: flake8
|
||||
- repo: https://github.com/PyCQA/pylint
|
||||
rev: v2.17.4
|
||||
rev: v3.3.0
|
||||
hooks:
|
||||
- id: pylint
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
[MASTER]
|
||||
init-hook="from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))"
|
||||
init-hook="from pylint.config import find_default_config_files; import sys; sys.path.append(next(find_default_config_files()).parent.as_posix())"
|
||||
|
||||
[TYPECHECK]
|
||||
|
||||
@@ -12,3 +12,4 @@ generated-members=numpy.*, torch.*
|
||||
disable=missing-function-docstring, line-too-long, import-error,
|
||||
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
|
||||
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
|
||||
too-many-positional-arguments, possibly-used-before-assignment
|
||||
|
||||
@@ -217,7 +217,7 @@ If you love axolotl, consider sponsoring the project by reaching out directly to
|
||||
|
||||
---
|
||||
|
||||
- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl) Modal lets you run data/AI jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale, fine-tune LLM models, run protein folding simulations, and much more.
|
||||
- [Modal](https://modal.com/) Modal lets you run data/AI jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale, fine-tune LLM models, run protein folding simulations, and much more.
|
||||
|
||||
---
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
|
||||
ENV GITHUB_REF="{{ GITHUB_REF }}"
|
||||
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
|
||||
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
|
||||
ENV HF_HOME="{{ HF_HOME }}"
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
|
||||
|
||||
@@ -28,6 +28,7 @@ df_args = {
|
||||
"CUDA": os.environ.get("CUDA", "121"),
|
||||
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
||||
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
||||
"HF_HOME": "/workspace/data/huggingface-cache/hub",
|
||||
}
|
||||
|
||||
dockerfile_contents = df_template.render(**df_args)
|
||||
@@ -48,6 +49,12 @@ cicd_image = (
|
||||
|
||||
app = App("Axolotl CI/CD", secrets=[])
|
||||
|
||||
hf_cache_volume = modal.Volume.from_name(
|
||||
"axolotl-ci-hf-hub-cache", create_if_missing=True
|
||||
)
|
||||
VOLUME_CONFIG = {
|
||||
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
|
||||
}
|
||||
|
||||
N_GPUS = int(os.environ.get("N_GPUS", 2))
|
||||
GPU_CONFIG = modal.gpu.H100(count=N_GPUS)
|
||||
@@ -67,6 +74,7 @@ def run_cmd(cmd: str, run_folder: str):
|
||||
timeout=60 * 60,
|
||||
cpu=8.0,
|
||||
memory=131072 * N_GPUS,
|
||||
volumes=VOLUME_CONFIG,
|
||||
)
|
||||
def cicd_pytest():
|
||||
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")
|
||||
|
||||
@@ -29,6 +29,7 @@ df_args = {
|
||||
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
|
||||
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
|
||||
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
|
||||
"HF_HOME": "/workspace/data/huggingface-cache/hub",
|
||||
}
|
||||
|
||||
dockerfile_contents = df_template.render(**df_args)
|
||||
@@ -50,6 +51,12 @@ cicd_image = (
|
||||
|
||||
app = App("Axolotl CI/CD", secrets=[])
|
||||
|
||||
hf_cache_volume = modal.Volume.from_name(
|
||||
"axolotl-ci-hf-hub-cache", create_if_missing=True
|
||||
)
|
||||
VOLUME_CONFIG = {
|
||||
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
|
||||
}
|
||||
|
||||
N_GPUS = int(os.environ.get("N_GPUS", 1))
|
||||
GPU_CONFIG = modal.gpu.A10G(count=N_GPUS)
|
||||
@@ -69,6 +76,7 @@ def run_cmd(cmd: str, run_folder: str):
|
||||
timeout=60 * 60,
|
||||
cpu=8.0,
|
||||
memory=131072,
|
||||
volumes=VOLUME_CONFIG,
|
||||
)
|
||||
def cicd_pytest():
|
||||
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")
|
||||
|
||||
@@ -19,7 +19,14 @@ For pretraining, there is no prompt template or roles. The only required field
|
||||
Axolotl usually loads the entire dataset into memory. This will be challenging for large datasets. Use the following config to enable streaming:
|
||||
|
||||
```{.yaml filename="config.yaml"}
|
||||
pretraining_dataset: # hf path only
|
||||
pretraining_dataset:
|
||||
- name:
|
||||
path:
|
||||
split:
|
||||
text_column: # column in dataset with the data, usually `text`
|
||||
type: pretrain
|
||||
trust_remote_code:
|
||||
skip: # number of rows of data to skip over from the beginning
|
||||
...
|
||||
```
|
||||
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
volumes:
|
||||
- name: axolotl-data
|
||||
mount: /workspace/data
|
||||
- name: axolotl-artifacts
|
||||
mount: /workspace/artifacts
|
||||
secrets:
|
||||
- HF_TOKEN
|
||||
- WANDB_API_KEY
|
||||
branch: cli-cloud-modal
|
||||
gpu: h100
|
||||
gpu_count: 1
|
||||
memory: 128
|
||||
timeout: 86400
|
||||
timeout_preprocess: 14400
|
||||
memory_preprocess: 32
|
||||
@@ -1,11 +0,0 @@
|
||||
lm_eval_model: axolotl-ai-co/numina-8b-ep1-exp1
|
||||
lm_eval_tasks:
|
||||
- leaderboard_math_hard
|
||||
lm_eval_batch_size: 64
|
||||
|
||||
apply_chat_template: false
|
||||
wandb_project: numina-kd-experiment
|
||||
wandb_entity: axolotl-ai
|
||||
bf16: true
|
||||
flash_attention: true
|
||||
output_dir: ./outputs/model-evals-out
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
# START section of dependencies that don't install on Darwin/MacOS
|
||||
bitsandbytes==0.45.0
|
||||
triton>=2.3.0
|
||||
triton>=3.0.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
flash-attn==2.7.0.post2
|
||||
xformers>=0.0.23.post1
|
||||
@@ -14,18 +14,17 @@ packaging==23.2
|
||||
|
||||
peft==0.14.0
|
||||
transformers==4.47.1
|
||||
tokenizers>=0.20.1
|
||||
tokenizers>=0.21.0
|
||||
accelerate==1.2.1
|
||||
datasets==3.1.0
|
||||
datasets==3.2.0
|
||||
deepspeed==0.16.1
|
||||
trl==0.12.1
|
||||
trl==0.13.0
|
||||
|
||||
optimum==1.16.2
|
||||
hf_transfer
|
||||
sentencepiece
|
||||
gradio==3.50.2
|
||||
|
||||
modal==0.70.5
|
||||
pydantic==2.6.3
|
||||
addict
|
||||
fire
|
||||
@@ -62,4 +61,4 @@ antlr4-python3-runtime==4.13.2
|
||||
torchao==0.7.0
|
||||
schedulefree==1.3.0
|
||||
|
||||
axolotl-contribs-lgpl==0.0.2
|
||||
axolotl-contribs-lgpl==0.0.3
|
||||
|
||||
17
scripts/motd
17
scripts/motd
@@ -1,15 +1,10 @@
|
||||
|
||||
#@@ #@@ @@# @@#
|
||||
@@ @@ @@ @@ =@@# @@ #@ =@@#.
|
||||
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
|
||||
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
|
||||
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
|
||||
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
|
||||
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
|
||||
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
|
||||
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
|
||||
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
|
||||
@@@@ @@@@@@@@@@@@@@@@
|
||||
dP dP dP
|
||||
88 88 88
|
||||
.d8888b. dP. .dP .d8888b. 88 .d8888b. d8888P 88
|
||||
88' `88 `8bd8' 88' `88 88 88' `88 88 88
|
||||
88. .88 .d88b. 88. .88 88 88. .88 88 88
|
||||
`88888P8 dP' `dP `88888P' dP `88888P' dP dP
|
||||
|
||||
Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace and the axolotl directory ie empty, run the following commands:
|
||||
|
||||
|
||||
26
setup.py
26
setup.py
@@ -1,4 +1,5 @@
|
||||
"""setup.py for axolotl"""
|
||||
|
||||
import ast
|
||||
import os
|
||||
import platform
|
||||
@@ -29,15 +30,30 @@ def parse_requirements():
|
||||
elif not is_extras and line and line[0] != "#":
|
||||
# Handle standard packages
|
||||
_install_requires.append(line)
|
||||
|
||||
try:
|
||||
xformers_version = [req for req in _install_requires if "xformers" in req][0]
|
||||
triton_version = [req for req in _install_requires if "triton" in req][0]
|
||||
torchao_version = [req for req in _install_requires if "torchao" in req][0]
|
||||
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
|
||||
|
||||
if "Darwin" in platform.system():
|
||||
# don't install xformers on MacOS
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
# skip packages not compatible with OSX
|
||||
skip_packages = [
|
||||
"bitsandbytes",
|
||||
"triton",
|
||||
"mamba-ssm",
|
||||
"flash-attn",
|
||||
"xformers",
|
||||
"autoawq",
|
||||
"liger-kernel",
|
||||
]
|
||||
_install_requires = [
|
||||
req
|
||||
for req in _install_requires
|
||||
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
|
||||
]
|
||||
print(
|
||||
_install_requires, [req in skip_packages for req in _install_requires]
|
||||
)
|
||||
else:
|
||||
# detect the version of torch already installed
|
||||
# and set it so dependencies don't clobber the torch version
|
||||
@@ -73,6 +89,8 @@ def parse_requirements():
|
||||
_install_requires.append("xformers==0.0.28.post1")
|
||||
elif (major, minor) >= (2, 3):
|
||||
_install_requires.pop(_install_requires.index(torchao_version))
|
||||
_install_requires.pop(_install_requires.index(triton_version))
|
||||
_install_requires.append("triton>=2.3.1")
|
||||
if patch == 0:
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append("xformers>=0.0.26.post1")
|
||||
|
||||
@@ -1,56 +0,0 @@
|
||||
"""
|
||||
launch axolotl in supported cloud platforms
|
||||
"""
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import yaml
|
||||
|
||||
from axolotl.cli import print_axolotl_text_art
|
||||
from axolotl.cli.cloud.modal_ import ModalCloud
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
|
||||
"""Load and validate cloud configuration."""
|
||||
# Load cloud configuration.
|
||||
with open(cloud_config, encoding="utf-8") as file:
|
||||
cloud_cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||
return cloud_cfg
|
||||
|
||||
|
||||
def do_cli_preprocess(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str] = Path("examples/"),
|
||||
) -> None:
|
||||
print_axolotl_text_art()
|
||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
with open(config, "r", encoding="utf-8") as file:
|
||||
config_yaml = file.read()
|
||||
cloud.preprocess(config_yaml)
|
||||
|
||||
|
||||
def do_cli_train(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str] = Path("examples/"),
|
||||
accelerate: bool = True,
|
||||
) -> None:
|
||||
print_axolotl_text_art()
|
||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
with open(config, "r", encoding="utf-8") as file:
|
||||
config_yaml = file.read()
|
||||
cloud.train(config_yaml, accelerate=accelerate)
|
||||
|
||||
|
||||
def do_cli_lm_eval(
|
||||
cloud_config: Union[Path, str],
|
||||
config: Union[Path, str] = Path("examples/"),
|
||||
) -> None:
|
||||
print_axolotl_text_art()
|
||||
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||
cloud = ModalCloud(cloud_cfg)
|
||||
with open(config, "r", encoding="utf-8") as file:
|
||||
config_yaml = file.read()
|
||||
cloud.lm_eval(config_yaml)
|
||||
@@ -1,18 +0,0 @@
|
||||
"""
|
||||
base class for cloud platforms from cli
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class Cloud(ABC):
|
||||
"""
|
||||
Abstract base class for cloud platforms.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def train(self, config_yaml: str, accelerate: bool = True) -> str:
|
||||
pass
|
||||
@@ -1,272 +0,0 @@
|
||||
"""
|
||||
Modal Cloud support from CLI
|
||||
"""
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import subprocess # nosec B404
|
||||
from pathlib import Path
|
||||
from random import randint
|
||||
|
||||
import modal
|
||||
|
||||
from axolotl.cli.cloud.base import Cloud
|
||||
|
||||
|
||||
def run_cmd(cmd: str, run_folder: str, volumes=None):
|
||||
"""Run a command inside a folder, with Modal Volume reloading before and commit on success."""
|
||||
# Ensure volumes contain latest files.
|
||||
if volumes:
|
||||
for _, vol in volumes.items():
|
||||
vol.reload()
|
||||
|
||||
# modal workaround so it doesn't use the automounted axolotl
|
||||
new_env = copy.deepcopy(os.environ)
|
||||
if "PYTHONPATH" in new_env:
|
||||
del new_env["PYTHONPATH"]
|
||||
|
||||
# Propagate errors from subprocess.
|
||||
if exit_code := subprocess.call( # nosec B603
|
||||
cmd.split(), cwd=run_folder, env=new_env
|
||||
):
|
||||
exit(exit_code) # pylint: disable=consider-using-sys-exit
|
||||
|
||||
# Commit writes to volume.
|
||||
if volumes:
|
||||
for _, vol in volumes.items():
|
||||
vol.commit()
|
||||
|
||||
|
||||
class ModalCloud(Cloud):
|
||||
"""
|
||||
Modal Cloud implementation.
|
||||
"""
|
||||
|
||||
def __init__(self, config, app=None):
|
||||
self.config = config
|
||||
if not app:
|
||||
app = modal.App()
|
||||
self.app = app
|
||||
|
||||
self.volumes = {}
|
||||
if config.volumes:
|
||||
for volume_config in config.volumes:
|
||||
_, mount, vol = self.create_volume(volume_config)
|
||||
self.volumes[mount] = (vol, volume_config)
|
||||
|
||||
def get_env(self):
|
||||
res = {
|
||||
"HF_DATASETS_CACHE": "/workspace/data/huggingface-cache/datasets",
|
||||
"HF_HUB_CACHE": "/workspace/data/huggingface-cache/hub",
|
||||
}
|
||||
|
||||
for key in self.config.get("env", []):
|
||||
if isinstance(key, str):
|
||||
if val := os.environ.get(key, ""):
|
||||
res[key] = val
|
||||
elif isinstance(key, dict):
|
||||
(key_, val) = list(key.items())[0]
|
||||
res[key_] = val
|
||||
return res
|
||||
|
||||
def get_image(self):
|
||||
docker_tag = "main-py3.11-cu124-2.5.1"
|
||||
if self.config.docker_tag:
|
||||
docker_tag = self.config.docker_tag
|
||||
docker_image = f"axolotlai/axolotl:{docker_tag}"
|
||||
|
||||
# grab the sha256 hash from docker hub for this image+tag
|
||||
# this ensures that we always get the latest image for this tag, even if it's already cached
|
||||
try:
|
||||
manifest = subprocess.check_output( # nosec B602
|
||||
f"docker manifest inspect {docker_image}",
|
||||
shell=True,
|
||||
).decode("utf-8")
|
||||
sha256_hash = json.loads(manifest)["manifests"][0]["digest"]
|
||||
except subprocess.CalledProcessError:
|
||||
sha256_hash = None
|
||||
|
||||
# create the image
|
||||
if sha256_hash:
|
||||
image = modal.Image.from_registry(f"axolotlai/axolotl@{sha256_hash}")
|
||||
else:
|
||||
image = modal.Image.from_registry(docker_image)
|
||||
|
||||
# branch
|
||||
if self.config.branch:
|
||||
image = image.dockerfile_commands(
|
||||
[
|
||||
# Random id for cache busting of branch commits
|
||||
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
|
||||
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
|
||||
"RUN cd /workspace/ && git clone https://github.com/winglian/lm-evaluation-harness.git && cd lm-evaluation-harness && pip install -e .[math]",
|
||||
]
|
||||
)
|
||||
|
||||
if env := self.get_env():
|
||||
image = image.env(env)
|
||||
|
||||
image = image.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
|
||||
|
||||
return image
|
||||
|
||||
def get_secrets(self):
|
||||
res = []
|
||||
if self.config.secrets:
|
||||
for key in self.config.get("secrets", []):
|
||||
# pylint: disable=duplicate-code
|
||||
if isinstance(key, str):
|
||||
if val := os.environ.get(key, ""):
|
||||
res.append(modal.Secret.from_dict({key: val}))
|
||||
elif isinstance(key, dict):
|
||||
(key_, val) = list(key.items())[0]
|
||||
res.append(modal.Secret.from_dict({key_: val}))
|
||||
return res
|
||||
|
||||
def create_volume(self, volume_config):
|
||||
name = volume_config.name
|
||||
mount = volume_config.mount
|
||||
return name, mount, modal.Volume.from_name(name, create_if_missing=True)
|
||||
|
||||
def get_ephemeral_disk_size(self):
|
||||
return 1000 * 525 # 1 TiB
|
||||
|
||||
def get_preprocess_timeout(self):
|
||||
if self.config.timeout_preprocess:
|
||||
return int(self.config.timeout_preprocess)
|
||||
return 60 * 60 * 3 # 3 hours
|
||||
|
||||
def get_preprocess_memory(self):
|
||||
memory = 128 # default to 128GiB
|
||||
if self.config.memory:
|
||||
memory = int(self.config.memory)
|
||||
if self.config.memory_preprocess:
|
||||
memory = int(self.config.memory_preprocess)
|
||||
return 1024 * memory
|
||||
|
||||
def get_preprocess_env(self):
|
||||
return self.app.function(
|
||||
image=self.get_image(),
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
cpu=8.0,
|
||||
ephemeral_disk=self.get_ephemeral_disk_size(),
|
||||
memory=self.get_preprocess_memory(),
|
||||
timeout=self.get_preprocess_timeout(),
|
||||
secrets=self.get_secrets(),
|
||||
)
|
||||
|
||||
def preprocess(self, config_yaml: str, *args, **kwargs):
|
||||
modal_fn = self.get_preprocess_env()(_preprocess)
|
||||
with modal.enable_output():
|
||||
with self.app.run(detach=True):
|
||||
modal_fn.remote(
|
||||
config_yaml,
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
*args,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_train_timeout(self):
|
||||
if self.config.timeout:
|
||||
return int(self.config.timeout)
|
||||
return 60 * 60 * 24 # 24 hours
|
||||
|
||||
def get_train_gpu(self): # pylint: disable=too-many-return-statements
|
||||
count = self.config.gpu_count or 1
|
||||
family = self.config.gpu.lower() or "l40s"
|
||||
|
||||
if family == "l40s":
|
||||
return modal.gpu.L40S(count=count)
|
||||
if family == "a100":
|
||||
return modal.gpu.A100(count=count, size="40GB")
|
||||
if family == "a100-80gb":
|
||||
return modal.gpu.A100(count=count, size="80GB")
|
||||
if family in ["a10", "a10g"]:
|
||||
return modal.gpu.A10G(count=count)
|
||||
if family == "h100":
|
||||
return modal.gpu.H100(count=count)
|
||||
if family == "t4":
|
||||
return modal.gpu.T4(count=count)
|
||||
if family == "l4":
|
||||
return modal.gpu.L4(count=count)
|
||||
raise ValueError(f"Unsupported GPU family: {family}")
|
||||
|
||||
def get_train_memory(self):
|
||||
memory = 128 # default to 128GiB
|
||||
if self.config.memory:
|
||||
memory = int(self.config.memory)
|
||||
return 1024 * memory
|
||||
|
||||
def get_train_env(self):
|
||||
return self.app.function(
|
||||
image=self.get_image(),
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
cpu=16.0,
|
||||
gpu=self.get_train_gpu(),
|
||||
memory=self.get_train_memory(),
|
||||
timeout=self.get_train_timeout(),
|
||||
secrets=self.get_secrets(),
|
||||
)
|
||||
|
||||
def train(self, config_yaml: str, accelerate: bool = True):
|
||||
modal_fn = self.get_train_env()(_train)
|
||||
with modal.enable_output():
|
||||
with self.app.run(detach=True):
|
||||
modal_fn.remote(
|
||||
config_yaml,
|
||||
accelerate=accelerate,
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
)
|
||||
|
||||
def lm_eval(self, config_yaml: str):
|
||||
modal_fn = self.get_train_env()(_lm_eval)
|
||||
with modal.enable_output():
|
||||
with self.app.run(detach=True):
|
||||
modal_fn.remote(
|
||||
config_yaml,
|
||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||
)
|
||||
|
||||
|
||||
def _preprocess(config_yaml: str, volumes=None):
|
||||
Path("/workspace/artifacts/axolotl").mkdir(parents=True, exist_ok=True)
|
||||
with open(
|
||||
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
||||
) as f_out:
|
||||
f_out.write(config_yaml)
|
||||
run_folder = "/workspace/artifacts/axolotl"
|
||||
run_cmd(
|
||||
"axolotl preprocess /workspace/artifacts/axolotl/config.yaml --dataset-processes=8",
|
||||
run_folder,
|
||||
volumes,
|
||||
)
|
||||
|
||||
|
||||
def _train(config_yaml: str, accelerate: bool = True, volumes=None):
|
||||
with open(
|
||||
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
||||
) as f_out:
|
||||
f_out.write(config_yaml)
|
||||
run_folder = "/workspace/artifacts/axolotl"
|
||||
if accelerate:
|
||||
accelerate_args = "--accelerate"
|
||||
else:
|
||||
accelerate_args = "--no-accelerate"
|
||||
run_cmd(
|
||||
f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml",
|
||||
run_folder,
|
||||
volumes,
|
||||
)
|
||||
|
||||
|
||||
def _lm_eval(config_yaml: str, volumes=None):
|
||||
with open(
|
||||
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
||||
) as f_out:
|
||||
f_out.write(config_yaml)
|
||||
run_folder = "/workspace/artifacts/axolotl"
|
||||
run_cmd(
|
||||
"axolotl lm-eval /workspace/artifacts/axolotl/config.yaml",
|
||||
run_folder,
|
||||
volumes,
|
||||
)
|
||||
@@ -13,7 +13,6 @@ from axolotl.cli.utils import (
|
||||
fetch_from_github,
|
||||
)
|
||||
from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
||||
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||
|
||||
@@ -26,21 +25,15 @@ def cli():
|
||||
|
||||
@cli.command()
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
|
||||
@add_options_from_dataclass(PreprocessCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
|
||||
def preprocess(config: str, **kwargs):
|
||||
"""Preprocess datasets before training."""
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
if cloud:
|
||||
from axolotl.cli.cloud import do_cli_preprocess
|
||||
from axolotl.cli.preprocess import do_cli
|
||||
|
||||
do_cli_preprocess(cloud_config=cloud, config=config)
|
||||
else:
|
||||
from axolotl.cli.preprocess import do_cli
|
||||
|
||||
do_cli(config=config, **kwargs)
|
||||
do_cli(config=config, **kwargs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@@ -50,33 +43,25 @@ def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
|
||||
default=True,
|
||||
help="Use accelerate launch for multi-GPU training",
|
||||
)
|
||||
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
|
||||
@add_options_from_dataclass(TrainerCliArgs)
|
||||
@add_options_from_config(AxolotlInputConfig)
|
||||
def train(config: str, accelerate: bool, cloud: Optional[str], **kwargs):
|
||||
def train(config: str, accelerate: bool, **kwargs):
|
||||
"""Train or fine-tune a model."""
|
||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||
|
||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||
set_pytorch_cuda_alloc_conf()
|
||||
from axolotl.cli.cloud import do_cli_train
|
||||
|
||||
if accelerate:
|
||||
if cloud:
|
||||
do_cli_train(cloud_config=cloud, config=config, accelerate=True)
|
||||
else:
|
||||
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
|
||||
if config:
|
||||
base_cmd.append(config)
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
|
||||
if config:
|
||||
base_cmd.append(config)
|
||||
cmd = build_command(base_cmd, kwargs)
|
||||
subprocess.run(cmd, check=True) # nosec B603
|
||||
else:
|
||||
if cloud:
|
||||
do_cli_train(cloud_config=cloud, config=config, accelerate=False)
|
||||
else:
|
||||
from axolotl.cli.train import do_cli
|
||||
from axolotl.cli.train import do_cli
|
||||
|
||||
do_cli(config=config, **kwargs)
|
||||
do_cli(config=config, **kwargs)
|
||||
|
||||
|
||||
@cli.command()
|
||||
@@ -269,9 +254,6 @@ def fetch(directory: str, dest: Optional[str]):
|
||||
fetch_from_github(f"{directory}/", dest)
|
||||
|
||||
|
||||
cli.add_command(lm_eval)
|
||||
|
||||
|
||||
def main():
|
||||
cli()
|
||||
|
||||
|
||||
@@ -27,7 +27,6 @@ def add_options_from_dataclass(config_class: Type[Any]):
|
||||
field_type = next(
|
||||
t for t in get_args(field_type) if not isinstance(t, NoneType)
|
||||
)
|
||||
|
||||
if field_type == bool:
|
||||
field_name = field.name.replace("_", "-")
|
||||
option_name = f"--{field_name}/--no-{field_name}"
|
||||
|
||||
@@ -22,7 +22,6 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
|
||||
import torch
|
||||
import transformers
|
||||
from datasets import Dataset
|
||||
from packaging import version
|
||||
from peft.optimizers import create_loraplus_optimizer
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
@@ -608,8 +607,14 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
self.state.train_batch_size or self.args.per_device_train_batch_size
|
||||
)
|
||||
batch_max_len = train_batch_size * self.args.max_seq_length
|
||||
|
||||
if self.args.curriculum_sampling:
|
||||
sampler = SequentialSampler(self.train_dataset)
|
||||
else:
|
||||
sampler = RandomSampler(self.train_dataset)
|
||||
|
||||
return MultipackBatchSampler(
|
||||
RandomSampler(self.train_dataset),
|
||||
sampler,
|
||||
lengths=get_dataset_lengths(self.train_dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
batch_max_len=batch_max_len,
|
||||
@@ -978,12 +983,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
logs[key] = torch.tensor(metrics).mean().item()
|
||||
del self._stored_metrics[train_eval]
|
||||
|
||||
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
||||
try:
|
||||
return super().log(logs, start_time)
|
||||
except TypeError:
|
||||
return super().log(logs) # transformers<=4.46
|
||||
return super().log(logs) # transformers<=4.46
|
||||
return super().log(logs, start_time)
|
||||
|
||||
def store_metrics(
|
||||
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
|
||||
@@ -1167,22 +1167,6 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
||||
torch.cuda.empty_cache()
|
||||
return loss
|
||||
|
||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||
# TODO remove once trl supports the updated to the Trainer.log method
|
||||
# logs either has 'loss' or 'eval_loss'
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
# Add averaged stored metrics to logs
|
||||
for key, metrics in self._stored_metrics[train_eval].items():
|
||||
logs[key] = torch.tensor(metrics).mean().item()
|
||||
del self._stored_metrics[train_eval]
|
||||
|
||||
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
||||
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
|
||||
logs, start_time
|
||||
)
|
||||
# transformers<=4.46
|
||||
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
|
||||
|
||||
|
||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||
"""
|
||||
@@ -1191,22 +1175,6 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
||||
|
||||
tag_names = ["axolotl", "orpo"]
|
||||
|
||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||
# TODO remove once trl supports the updated to the Trainer.log method
|
||||
# logs either has 'loss' or 'eval_loss'
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
# Add averaged stored metrics to logs
|
||||
for key, metrics in self._stored_metrics[train_eval].items():
|
||||
logs[key] = torch.tensor(metrics).mean().item()
|
||||
del self._stored_metrics[train_eval]
|
||||
|
||||
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
||||
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
|
||||
logs, start_time
|
||||
)
|
||||
# transformers<=4.46
|
||||
return super(ORPOTrainer, self).log(logs) # pylint: disable=bad-super-call
|
||||
|
||||
|
||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||
"""
|
||||
@@ -1215,49 +1183,6 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
||||
|
||||
tag_names = ["axolotl", "kto"]
|
||||
|
||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||
# TODO remove once trl supports the updated to the Trainer.log method
|
||||
# logs either has 'loss' or 'eval_loss'
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
# train metrics should have no prefix, eval should have 'eval_'
|
||||
prefix = "eval_" if train_eval == "eval" else ""
|
||||
# accumulate average metrics from sums and lengths
|
||||
for split in ["chosen", "rejected"]:
|
||||
if f"count/{split}" in self._stored_metrics[train_eval]:
|
||||
count_sum = (
|
||||
torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"])
|
||||
.sum()
|
||||
.item()
|
||||
)
|
||||
for metric in ["rewards", "logps", "logits"]:
|
||||
logs[f"{prefix}{metric}/{split}"] = (
|
||||
torch.Tensor(
|
||||
self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
|
||||
)
|
||||
.sum()
|
||||
.item()
|
||||
/ count_sum
|
||||
)
|
||||
# delete obsolete metric
|
||||
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
|
||||
del self._stored_metrics[train_eval][f"count/{split}"]
|
||||
# calculate reward margin
|
||||
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
|
||||
logs[f"{prefix}rewards/margins"] = (
|
||||
logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
|
||||
)
|
||||
# Add averaged stored metrics to logs
|
||||
for key, metrics in self._stored_metrics[train_eval].items():
|
||||
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
|
||||
del self._stored_metrics[train_eval]
|
||||
|
||||
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
||||
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
|
||||
logs, start_time
|
||||
)
|
||||
# transformers<=4.46
|
||||
return super(KTOTrainer, self).log(logs) # pylint: disable=bad-super-call
|
||||
|
||||
|
||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||
"""
|
||||
@@ -1266,22 +1191,6 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
||||
|
||||
tag_names = ["axolotl", "cpo"]
|
||||
|
||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||
# TODO remove once trl supports the updated to the Trainer.log method
|
||||
# logs either has 'loss' or 'eval_loss'
|
||||
train_eval = "train" if "loss" in logs else "eval"
|
||||
# Add averaged stored metrics to logs
|
||||
for key, metrics in self._stored_metrics[train_eval].items():
|
||||
logs[key] = torch.tensor(metrics).mean().item()
|
||||
del self._stored_metrics[train_eval]
|
||||
|
||||
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
||||
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
|
||||
logs, start_time
|
||||
)
|
||||
# transformers<=4.46
|
||||
return super(CPOTrainer, self).log(logs) # pylint: disable=bad-super-call
|
||||
|
||||
|
||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||
"""
|
||||
@@ -1290,15 +1199,6 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
||||
|
||||
tag_names = ["axolotl", "reward"]
|
||||
|
||||
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
|
||||
# TODO remove once trl supports the updated to the Trainer.log method
|
||||
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
|
||||
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
|
||||
logs, start_time
|
||||
)
|
||||
# transformers<=4.46
|
||||
return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call
|
||||
|
||||
|
||||
class TrainerBuilderBase(abc.ABC):
|
||||
"""
|
||||
|
||||
@@ -22,13 +22,6 @@ import inspect
|
||||
import logging
|
||||
import sys
|
||||
|
||||
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
|
||||
from ...utils.distributed import zero_only
|
||||
@@ -46,6 +39,13 @@ class LigerPlugin(BasePlugin):
|
||||
return "axolotl.integrations.liger.LigerArgs"
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
|
||||
from liger_kernel.transformers.functional import liger_cross_entropy
|
||||
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
|
||||
from liger_kernel.transformers.rms_norm import LigerRMSNorm
|
||||
from liger_kernel.transformers.rope import liger_rotary_pos_emb
|
||||
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
|
||||
|
||||
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
|
||||
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
|
||||
liger_fn_sig = inspect.signature(apply_liger_fn)
|
||||
|
||||
@@ -2,9 +2,9 @@
|
||||
Module for the Plugin for LM Eval Harness
|
||||
"""
|
||||
import subprocess # nosec
|
||||
from datetime import datetime
|
||||
|
||||
from axolotl.integrations.base import BasePlugin
|
||||
from axolotl.integrations.lm_eval.cli import build_lm_eval_command
|
||||
|
||||
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
|
||||
|
||||
@@ -18,19 +18,25 @@ class LMEvalPlugin(BasePlugin):
|
||||
return "axolotl.integrations.lm_eval.LMEvalArgs"
|
||||
|
||||
def post_train_unload(self, cfg):
|
||||
if cfg.lm_eval_post_train:
|
||||
# pylint: disable=duplicate-code
|
||||
for lm_eval_args in build_lm_eval_command(
|
||||
cfg.lm_eval_tasks,
|
||||
bfloat16=cfg.bfloat16 or cfg.bf16,
|
||||
flash_attention=cfg.flash_attention,
|
||||
output_dir=cfg.output_dir,
|
||||
batch_size=cfg.lm_eval_batch_size,
|
||||
wandb_project=cfg.wandb_project,
|
||||
wandb_entity=cfg.wandb_entity,
|
||||
model=cfg.lm_eval_model or cfg.hub_model_id,
|
||||
):
|
||||
subprocess.run( # nosec
|
||||
lm_eval_args,
|
||||
check=True,
|
||||
)
|
||||
tasks = ",".join(cfg.lm_eval_tasks)
|
||||
fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
|
||||
dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
|
||||
output_path = cfg.output_dir
|
||||
output_path += "" if cfg.output_dir.endswith("/") else "/"
|
||||
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
subprocess.run( # nosec
|
||||
[
|
||||
"lm_eval",
|
||||
"--model",
|
||||
"hf",
|
||||
"--model_args",
|
||||
f"pretrained={cfg.output_dir}{fa2}{dtype}",
|
||||
"--tasks",
|
||||
tasks,
|
||||
"--batch_size",
|
||||
str(cfg.lm_eval_batch_size),
|
||||
"--output_path",
|
||||
output_path,
|
||||
],
|
||||
check=True,
|
||||
)
|
||||
|
||||
@@ -13,5 +13,3 @@ class LMEvalArgs(BaseModel):
|
||||
|
||||
lm_eval_tasks: List[str] = []
|
||||
lm_eval_batch_size: Optional[int] = 8
|
||||
lm_eval_post_train: Optional[bool] = True
|
||||
lm_eval_model: Optional[str] = None
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
"""
|
||||
axolotl CLI for running lm_eval tasks
|
||||
"""
|
||||
import subprocess # nosec
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
import yaml
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def build_lm_eval_command(
|
||||
tasks: list[str],
|
||||
bfloat16=True,
|
||||
flash_attention=False,
|
||||
output_dir="./",
|
||||
batch_size=8,
|
||||
wandb_project=None,
|
||||
wandb_entity=None,
|
||||
model=None,
|
||||
revision=None,
|
||||
apply_chat_template=None,
|
||||
fewshot_as_multiturn=None,
|
||||
):
|
||||
tasks_by_num_fewshot: dict[str, list] = defaultdict(list)
|
||||
for task in tasks:
|
||||
num_fewshot = "-1"
|
||||
task_parts = task.split(":")
|
||||
task_name = task_parts[0]
|
||||
if len(task_parts) == 2:
|
||||
task_name, num_fewshot = task_parts
|
||||
tasks_by_num_fewshot[str(num_fewshot)].append(task_name)
|
||||
|
||||
for num_fewshot, tasks_list in tasks_by_num_fewshot.items():
|
||||
tasks_str = ",".join(tasks_list)
|
||||
num_fewshot_val = num_fewshot if num_fewshot != "-1" else None
|
||||
pretrained = "pretrained="
|
||||
pretrained += model if model else output_dir
|
||||
fa2 = ",attn_implementation=flash_attention_2" if flash_attention else ""
|
||||
dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16"
|
||||
revision = f",revision={revision}" if revision else ""
|
||||
output_path = output_dir
|
||||
output_path += "" if output_dir.endswith("/") else "/"
|
||||
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
lm_eval_args = [
|
||||
"lm_eval",
|
||||
"--model",
|
||||
"hf",
|
||||
"--model_args",
|
||||
f"{pretrained}{fa2}{dtype}{revision}",
|
||||
"--tasks",
|
||||
tasks_str,
|
||||
"--batch_size",
|
||||
str(batch_size),
|
||||
"--output_path",
|
||||
output_path,
|
||||
]
|
||||
wandb_args = []
|
||||
if wandb_project:
|
||||
wandb_args.append(f"project={wandb_project}")
|
||||
if wandb_entity:
|
||||
wandb_args.append(f"entity={wandb_entity}")
|
||||
if wandb_args:
|
||||
lm_eval_args.append("--wandb_args")
|
||||
lm_eval_args.append(",".join(wandb_args))
|
||||
if apply_chat_template:
|
||||
lm_eval_args.append("--apply_chat_template")
|
||||
if num_fewshot_val:
|
||||
lm_eval_args.append("--num_fewshot")
|
||||
lm_eval_args.append(str(num_fewshot_val))
|
||||
if apply_chat_template and fewshot_as_multiturn:
|
||||
lm_eval_args.append("--fewshot_as_multiturn")
|
||||
|
||||
yield lm_eval_args
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
|
||||
def lm_eval(config: str, cloud: Optional[str] = None):
|
||||
"""
|
||||
use lm eval to evaluate a trained language model
|
||||
"""
|
||||
|
||||
if cloud:
|
||||
from axolotl.cli.cloud import do_cli_lm_eval
|
||||
|
||||
do_cli_lm_eval(cloud_config=cloud, config=config)
|
||||
else:
|
||||
with open(config, encoding="utf-8") as file:
|
||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
for lm_eval_args in build_lm_eval_command(
|
||||
cfg.lm_eval_tasks,
|
||||
bfloat16=cfg.bfloat16 or cfg.bf16,
|
||||
flash_attention=cfg.flash_attention,
|
||||
output_dir=cfg.output_dir,
|
||||
batch_size=cfg.lm_eval_batch_size,
|
||||
wandb_project=cfg.wandb_project,
|
||||
wandb_entity=cfg.wandb_entity,
|
||||
model=cfg.lm_eval_model or cfg.hub_model_id,
|
||||
revision=cfg.revision,
|
||||
apply_chat_template=cfg.apply_chat_template,
|
||||
fewshot_as_multiturn=cfg.fewshot_as_multiturn,
|
||||
):
|
||||
subprocess.run( # nosec
|
||||
lm_eval_args,
|
||||
check=True,
|
||||
)
|
||||
@@ -6,7 +6,7 @@ import logging
|
||||
|
||||
from transformers import Trainer
|
||||
|
||||
from axolotl.monkeypatch.unsloth_ import detab_code
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ import logging
|
||||
from transformers import LlamaForCausalLM, Trainer
|
||||
from transformers.modeling_flash_attention_utils import _flash_attention_forward
|
||||
|
||||
from axolotl.monkeypatch.unsloth_ import detab_code
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
|
||||
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
"""module for patching with unsloth optimizations"""
|
||||
|
||||
import inspect
|
||||
import re
|
||||
import types
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from accelerate.logging import get_logger
|
||||
@@ -11,6 +9,8 @@ from peft import PeftModelForCausalLM
|
||||
from torch import nn
|
||||
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
|
||||
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
|
||||
LOG = get_logger("axolotl.monkeypatch.unsloth")
|
||||
|
||||
ORIGINAL_QKV_CODE = """
|
||||
@@ -93,15 +93,6 @@ def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
|
||||
raise ValueError("Unsupported model type")
|
||||
|
||||
|
||||
def detab_code(code: str) -> Tuple[str, str]:
|
||||
try:
|
||||
spaces = re.match(r"([\s\t]{1,})", code).group(0)
|
||||
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
|
||||
except AttributeError:
|
||||
return code, ""
|
||||
return code, spaces
|
||||
|
||||
|
||||
self_attn_lora_patched = False # pylint: disable=invalid-name
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""
|
||||
Shared utils for the monkeypatches
|
||||
"""
|
||||
from typing import Optional
|
||||
import re
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -223,3 +224,12 @@ def patched_prepare_4d_causal_attention_mask_for_sdpa(
|
||||
mask_2d_to_4d(attention_mask, dtype=dtype),
|
||||
*args,
|
||||
)
|
||||
|
||||
|
||||
def detab_code(code: str) -> Tuple[str, str]:
|
||||
try:
|
||||
spaces = re.match(r"([\s\t]{1,})", code).group(0)
|
||||
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
|
||||
except AttributeError:
|
||||
return code, ""
|
||||
return code, spaces
|
||||
|
||||
@@ -43,7 +43,7 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
|
||||
getattr, self.layers_attribute.split("."), self.trainer.model
|
||||
)
|
||||
LOG.info(
|
||||
f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps"
|
||||
f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers * 100 / len(layers)}%) every {self.step_interval} steps"
|
||||
)
|
||||
|
||||
def freeze_all_layers(self):
|
||||
|
||||
@@ -128,6 +128,8 @@ class PretrainingDataset(BaseModel):
|
||||
text_column: Optional[str] = "text"
|
||||
type: Optional[str] = "pretrain"
|
||||
trust_remote_code: Optional[bool] = False
|
||||
data_files: Optional[str] = None
|
||||
skip: Optional[int] = None
|
||||
|
||||
|
||||
class UserDefinedPrompterType(BaseModel):
|
||||
@@ -366,6 +368,13 @@ class LoraConfig(BaseModel):
|
||||
loraplus_lr_embedding = float(loraplus_lr_embedding)
|
||||
return loraplus_lr_embedding
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_lora_dropout(cls, data):
|
||||
if data.get("adapter") is not None and data.get("lora_dropout") is None:
|
||||
data["lora_dropout"] = 0.0
|
||||
return data
|
||||
|
||||
|
||||
class ReLoRAConfig(BaseModel):
|
||||
"""ReLoRA configuration subset"""
|
||||
|
||||
@@ -88,14 +88,19 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
||||
path = cfg.pretraining_dataset
|
||||
split = "train"
|
||||
name = None
|
||||
data_files = None
|
||||
skip = 0
|
||||
if isinstance(cfg.pretraining_dataset, list) and isinstance(
|
||||
cfg.pretraining_dataset[0], dict
|
||||
):
|
||||
path = cfg.pretraining_dataset[0]["path"]
|
||||
name = cfg.pretraining_dataset[0]["name"]
|
||||
skip = cfg.pretraining_dataset[0]["skip"]
|
||||
if "split" in cfg.pretraining_dataset[0]:
|
||||
split = cfg.pretraining_dataset[0]["split"]
|
||||
|
||||
data_files = cfg.pretraining_dataset[0].get("data_files")
|
||||
|
||||
ds_wrapper_partial = functools.partial(
|
||||
get_dataset_wrapper,
|
||||
cfg.pretraining_dataset[0],
|
||||
@@ -104,8 +109,14 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
||||
cfg.pretraining_dataset[0]["type"] or "pretrain",
|
||||
)
|
||||
|
||||
iter_ds = load_dataset(
|
||||
path, streaming=True, split=split, name=name, data_files=data_files
|
||||
)
|
||||
if skip:
|
||||
LOG.info(f"Skipping {skip} samples from the dataset")
|
||||
iter_ds = iter_ds.skip(skip)
|
||||
train_dataset = wrap_pretraining_dataset(
|
||||
load_dataset(path, streaming=True, split=split, name=name),
|
||||
iter_ds,
|
||||
tokenizer,
|
||||
cfg,
|
||||
ds_wrapper_partial,
|
||||
|
||||
@@ -270,7 +270,7 @@ def load_sharded_model_quant(
|
||||
model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config)
|
||||
|
||||
if cfg.local_rank == 0 and verbose:
|
||||
print(f"Loaded model weights in {time.time()-start:.3f} seconds")
|
||||
print(f"Loaded model weights in {time.time() - start:.3f} seconds")
|
||||
# cleanup any extra memory usage from parallel loading
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
@@ -196,7 +196,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.remove_columns("attention_mask")
|
||||
|
||||
if cfg.model_config_type == "falcon":
|
||||
if cfg.model_config_type in ["falcon", "mistral"]:
|
||||
LOG.info("dropping token_type_ids column if it exists")
|
||||
if "token_type_ids" in train_dataset.column_names:
|
||||
train_dataset = train_dataset.remove_columns("token_type_ids")
|
||||
|
||||
@@ -120,13 +120,12 @@ def temp_dir():
|
||||
@pytest.fixture(scope="function", autouse=True)
|
||||
def cleanup_monkeypatches():
|
||||
from transformers import Trainer
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
from transformers.models.llama.modeling_llama import ( # LlamaFlashAttention2,
|
||||
LlamaAttention,
|
||||
LlamaFlashAttention2,
|
||||
LlamaForCausalLM,
|
||||
)
|
||||
|
||||
original_fa2_forward = LlamaFlashAttention2.forward
|
||||
# original_fa2_forward = LlamaFlashAttention2.forward
|
||||
original_llama_attn_forward = LlamaAttention.forward
|
||||
original_llama_forward = LlamaForCausalLM.forward
|
||||
original_trainer_inner_training_loop = (
|
||||
@@ -136,7 +135,7 @@ def cleanup_monkeypatches():
|
||||
# monkey patches can happen inside the tests
|
||||
yield
|
||||
# Reset LlamaFlashAttention2 forward
|
||||
LlamaFlashAttention2.forward = original_fa2_forward
|
||||
# LlamaFlashAttention2.forward = original_fa2_forward
|
||||
LlamaAttention.forward = original_llama_attn_forward
|
||||
LlamaForCausalLM.forward = original_llama_forward
|
||||
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
||||
@@ -149,7 +148,10 @@ def cleanup_monkeypatches():
|
||||
("transformers.models.llama",),
|
||||
(
|
||||
"transformers.models.llama.modeling_llama",
|
||||
["LlamaFlashAttention2", "LlamaAttention"],
|
||||
[
|
||||
# "LlamaFlashAttention2",
|
||||
"LlamaAttention",
|
||||
],
|
||||
),
|
||||
("transformers.trainer",),
|
||||
("transformers", ["Trainer"]),
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
Simple end-to-end test for Cut Cross Entropy integration
|
||||
"""
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
@@ -13,6 +11,8 @@ from axolotl.utils import get_pytorch_version
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ class TestCutCrossEntropyIntegration:
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
else:
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"attention_type",
|
||||
@@ -95,4 +95,4 @@ class TestCutCrossEntropyIntegration:
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
else:
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"""
|
||||
Simple end-to-end test for Liger integration
|
||||
"""
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from e2e.utils import require_torch_2_4_1
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
@@ -10,34 +10,32 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, prepare_plugins
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import with_temp_dir
|
||||
from ..utils import check_model_output_exists
|
||||
|
||||
|
||||
class LigerIntegrationTestCase(unittest.TestCase):
|
||||
class LigerIntegrationTestCase:
|
||||
"""
|
||||
e2e tests for liger integration with Axolotl
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
@require_torch_2_4_1
|
||||
def test_llama_wo_flce(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"plugins": [
|
||||
"axolotl.integrations.liger.LigerPlugin",
|
||||
],
|
||||
"liger_rope": True,
|
||||
"liger_rms_norm": True,
|
||||
"liger_swiglu": True,
|
||||
"liger_glu_activation": True,
|
||||
"liger_cross_entropy": True,
|
||||
"liger_fused_linear_cross_entropy": False,
|
||||
"sequence_len": 1024,
|
||||
"val_set_size": 0.1,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
@@ -46,15 +44,15 @@ class LigerIntegrationTestCase(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"max_steps": 10,
|
||||
"max_steps": 5,
|
||||
}
|
||||
)
|
||||
prepare_plugins(cfg)
|
||||
@@ -63,28 +61,26 @@ class LigerIntegrationTestCase(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
@require_torch_2_4_1
|
||||
def test_llama_w_flce(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"plugins": [
|
||||
"axolotl.integrations.liger.LigerPlugin",
|
||||
],
|
||||
"liger_rope": True,
|
||||
"liger_rms_norm": True,
|
||||
"liger_swiglu": True,
|
||||
"liger_glu_activation": True,
|
||||
"liger_cross_entropy": False,
|
||||
"liger_fused_linear_cross_entropy": True,
|
||||
"sequence_len": 1024,
|
||||
"val_set_size": 0.1,
|
||||
"val_set_size": 0.05,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
@@ -93,15 +89,15 @@ class LigerIntegrationTestCase(unittest.TestCase):
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 2,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_safetensors": True,
|
||||
"bf16": "auto",
|
||||
"max_steps": 10,
|
||||
"max_steps": 5,
|
||||
}
|
||||
)
|
||||
prepare_plugins(cfg)
|
||||
@@ -110,4 +106,4 @@ class LigerIntegrationTestCase(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
@@ -5,7 +5,6 @@ E2E tests for multipack fft llama using 4d attention masks
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
@@ -13,7 +12,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import require_torch_2_3_1, with_temp_dir
|
||||
from ..utils import check_model_output_exists, require_torch_2_3_1, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -67,7 +66,7 @@ class Test4dMultipackLlama(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_torch_lora_packing(self, temp_dir):
|
||||
@@ -111,4 +110,4 @@ class Test4dMultipackLlama(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -4,7 +4,6 @@ E2E tests for lora llama
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
@@ -15,7 +14,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_tensorboard
|
||||
from ..utils import check_model_output_exists, check_tensorboard
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -82,7 +81,7 @@ class TestFAXentropyLlama:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high"
|
||||
|
||||
@@ -5,7 +5,6 @@ E2E tests for falcon
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
@@ -13,7 +12,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import with_temp_dir
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -69,7 +68,7 @@ class TestFalconPatched(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
@@ -109,4 +108,4 @@ class TestFalconPatched(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -5,7 +5,6 @@ E2E tests for lora llama
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
@@ -16,7 +15,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import with_temp_dir
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -73,4 +72,4 @@ class TestFusedLlama(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -5,7 +5,6 @@ E2E tests for llama w/ S2 attn
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -15,7 +14,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import with_temp_dir
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -71,7 +70,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_fft_s2_attn(self, temp_dir):
|
||||
@@ -111,4 +110,4 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -5,7 +5,6 @@ E2E tests for lora llama
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
|
||||
@@ -16,7 +15,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import with_temp_dir
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -76,7 +75,7 @@ class TestLoraLlama(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
|
||||
@with_temp_dir
|
||||
@@ -126,4 +125,4 @@ class TestLoraLlama(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -5,7 +5,6 @@ E2E tests for lora llama
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
@@ -13,7 +12,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import with_temp_dir
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -69,7 +68,7 @@ class TestMistral(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_ft_packing(self, temp_dir):
|
||||
@@ -110,4 +109,4 @@ class TestMistral(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -5,7 +5,6 @@ E2E tests for mixtral
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
@@ -13,7 +12,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import with_temp_dir
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -66,7 +65,7 @@ class TestMixtral(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
@@ -108,4 +107,4 @@ class TestMixtral(unittest.TestCase):
|
||||
"MixtralFlashAttention2"
|
||||
in model.model.layers[0].self_attn.__class__.__name__
|
||||
)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -5,7 +5,6 @@ E2E tests for lora llama
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
@@ -13,7 +12,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import with_temp_dir
|
||||
from ..utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -69,7 +68,7 @@ class TestPhiMultipack(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_qlora_packed(self, temp_dir):
|
||||
@@ -120,4 +119,4 @@ class TestPhiMultipack(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -6,7 +6,6 @@ import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
@@ -16,7 +15,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import most_recent_subdir
|
||||
from ..utils import check_model_output_exists, most_recent_subdir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -83,7 +82,7 @@ class TestResumeLlama:
|
||||
cli_args = TrainerCliArgs()
|
||||
|
||||
train(cfg=resume_cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
|
||||
cmd = f"tensorboard --inspect --logdir {tb_log_path_1}"
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
|
||||
|
||||
|
||||
@pytest.mark.skip(
|
||||
reason="Unsloth integration will be broken going into latest transformers"
|
||||
)
|
||||
class TestUnslothIntegration(unittest.TestCase):
|
||||
"""Unsloth monkeypatch integration tests."""
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ e2e tests for unsloth qlora
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -13,13 +12,16 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_tensorboard
|
||||
from ..utils import check_model_output_exists, check_tensorboard
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
@pytest.mark.skip(
|
||||
reason="Unsloth integration will be broken going into latest transformers"
|
||||
)
|
||||
class TestUnslothQLoRA:
|
||||
"""
|
||||
Test class for Unsloth QLoRA Llama models
|
||||
@@ -74,7 +76,7 @@ class TestUnslothQLoRA:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
||||
@@ -124,7 +126,7 @@ class TestUnslothQLoRA:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
||||
@@ -179,7 +181,7 @@ class TestUnslothQLoRA:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
|
||||
|
||||
@@ -15,7 +15,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -68,7 +68,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_dpo_nll_lora(self, temp_dir):
|
||||
@@ -113,7 +113,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_dpo_use_weighting(self, temp_dir):
|
||||
@@ -158,7 +158,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||
|
||||
@pytest.mark.skip("kto_pair no longer supported in trl")
|
||||
@with_temp_dir
|
||||
@@ -203,7 +203,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_ipo_lora(self, temp_dir):
|
||||
@@ -247,7 +247,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_orpo_lora(self, temp_dir):
|
||||
@@ -294,7 +294,7 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||
|
||||
@pytest.mark.skip(reason="Fix the implementation")
|
||||
@with_temp_dir
|
||||
@@ -358,4 +358,4 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
|
||||
|
||||
@@ -5,7 +5,6 @@ E2E tests for llama pretrain
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
@@ -13,7 +12,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_tensorboard, with_temp_dir
|
||||
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -62,7 +61,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
|
||||
@@ -106,7 +105,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
|
||||
|
||||
@@ -5,7 +5,6 @@ E2E tests for falcon
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
@@ -13,7 +12,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -71,7 +70,7 @@ class TestFalcon(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora_added_vocab(self, temp_dir):
|
||||
@@ -124,7 +123,7 @@ class TestFalcon(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
@@ -163,4 +162,4 @@ class TestFalcon(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -4,7 +4,8 @@ E2E tests for llama
|
||||
|
||||
import logging
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
from e2e.utils import check_model_output_exists
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
@@ -60,7 +61,7 @@ class TestLlama:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
def test_fix_untrained_tokens(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -103,7 +104,7 @@ class TestLlama:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
def test_batch_flattening(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -142,4 +143,4 @@ class TestLlama:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -5,7 +5,6 @@ E2E tests for llama pretrain
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
@@ -13,7 +12,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -64,4 +63,4 @@ class TestPretrainLlama(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -5,7 +5,6 @@ E2E tests for lora llama
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
@@ -13,7 +12,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -68,7 +67,7 @@ class TestLlamaVision(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora_llama_vision_multimodal_dataset(self, temp_dir):
|
||||
@@ -113,4 +112,4 @@ class TestLlamaVision(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -5,7 +5,6 @@ E2E tests for lora llama
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
@@ -13,7 +12,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -65,4 +64,4 @@ class TestLoraLlama(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -5,7 +5,6 @@ E2E tests for lora llama
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -15,7 +14,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -65,4 +64,4 @@ class TestMamba(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -5,7 +5,6 @@ E2E tests for lora llama
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
@@ -15,7 +14,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -69,7 +68,7 @@ class TestMistral(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
@@ -112,4 +111,4 @@ class TestMistral(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -5,7 +5,6 @@ E2E tests for mixtral
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
@@ -16,7 +15,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -79,7 +78,7 @@ class TestMixtral(unittest.TestCase):
|
||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||
== torch.float32
|
||||
)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_qlora_wo_fa2(self, temp_dir):
|
||||
@@ -133,7 +132,7 @@ class TestMixtral(unittest.TestCase):
|
||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||
== torch.float32
|
||||
)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_16bit_lora_w_fa2(self, temp_dir):
|
||||
@@ -190,7 +189,7 @@ class TestMixtral(unittest.TestCase):
|
||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||
== torch.float32
|
||||
)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_16bit_lora_wo_fa2(self, temp_dir):
|
||||
@@ -247,7 +246,7 @@ class TestMixtral(unittest.TestCase):
|
||||
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
|
||||
== torch.float32
|
||||
)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
@@ -287,4 +286,4 @@ class TestMixtral(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -5,7 +5,6 @@ E2E tests for custom optimizers using Llama
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
@@ -13,7 +12,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import require_torch_2_5_1, with_temp_dir
|
||||
from .utils import check_model_output_exists, require_torch_2_5_1, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -65,7 +64,7 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
@require_torch_2_5_1
|
||||
@@ -109,10 +108,11 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_fft_schedule_free_adamw(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -144,4 +144,4 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -5,7 +5,6 @@ E2E tests for lora llama
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
@@ -13,7 +12,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -67,7 +66,7 @@ class TestPhi(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_phi_qlora(self, temp_dir):
|
||||
@@ -116,4 +115,4 @@ class TestPhi(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -13,7 +13,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_tensorboard, with_temp_dir
|
||||
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -78,10 +78,10 @@ class TestReLoraLlama(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
|
||||
assert (
|
||||
Path(temp_dir) / "checkpoint-100/adapter/adapter_model.safetensors"
|
||||
).exists()
|
||||
assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists()
|
||||
Path(temp_dir) / "checkpoint-100/relora/model.safetensors"
|
||||
).exists(), "Relora model checkpoint not found"
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/grad_norm", 0.2, "grad_norm is too high"
|
||||
|
||||
@@ -5,7 +5,6 @@ E2E tests for reward model lora llama
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
@@ -13,7 +12,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
from .utils import check_model_output_exists, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -71,4 +70,4 @@ class TestRewardModelLoraLlama(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -14,6 +14,8 @@ import torch
|
||||
from packaging import version
|
||||
from tbparse import SummaryReader
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def with_temp_dir(test_func):
|
||||
@wraps(test_func)
|
||||
@@ -49,7 +51,19 @@ def require_torch_2_3_1(test_case):
|
||||
torch_version = version.parse(torch.__version__)
|
||||
return torch_version >= version.parse("2.3.1")
|
||||
|
||||
return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)
|
||||
return unittest.skipUnless(is_min_2_3_1(), "test requires torch>=2.3.1")(test_case)
|
||||
|
||||
|
||||
def require_torch_2_4_1(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires torch >= 2.5.1
|
||||
"""
|
||||
|
||||
def is_min_2_4_1():
|
||||
torch_version = version.parse(torch.__version__)
|
||||
return torch_version >= version.parse("2.4.1")
|
||||
|
||||
return unittest.skipUnless(is_min_2_4_1(), "test requires torch>=2.4.1")(test_case)
|
||||
|
||||
|
||||
def require_torch_2_5_1(test_case):
|
||||
@@ -61,7 +75,7 @@ def require_torch_2_5_1(test_case):
|
||||
torch_version = version.parse(torch.__version__)
|
||||
return torch_version >= version.parse("2.5.1")
|
||||
|
||||
return unittest.skipUnless(is_min_2_5_1(), "test torch 2.5.1")(test_case)
|
||||
return unittest.skipUnless(is_min_2_5_1(), "test requires torch>=2.5.1")(test_case)
|
||||
|
||||
|
||||
def is_hopper():
|
||||
@@ -81,3 +95,27 @@ def check_tensorboard(
|
||||
df = reader.scalars # pylint: disable=invalid-name
|
||||
df = df[(df.tag == tag)] # pylint: disable=invalid-name
|
||||
assert df.value.values[-1] < lt_val, assertion_err
|
||||
|
||||
|
||||
def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:
|
||||
"""
|
||||
helper function to check if a model output file exists after training
|
||||
|
||||
checks based on adapter or not and if safetensors saves are enabled or not
|
||||
"""
|
||||
|
||||
if cfg.save_safetensors:
|
||||
if not cfg.adapter:
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
else:
|
||||
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
|
||||
else:
|
||||
# check for both, b/c in trl, it often defaults to saving safetensors
|
||||
if not cfg.adapter:
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists() or (
|
||||
Path(temp_dir) / "model.safetensors"
|
||||
).exists()
|
||||
else:
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists() or (
|
||||
Path(temp_dir) / "adapter_model.safetensors"
|
||||
).exists()
|
||||
|
||||
@@ -7,11 +7,11 @@ from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.utils.config import validate_config
|
||||
from axolotl.utils.config import prepare_plugins, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
@pytest.fixture(name="minimal_base_cfg")
|
||||
@pytest.fixture(name="minimal_liger_cfg")
|
||||
def fixture_cfg():
|
||||
return DictDefault(
|
||||
{
|
||||
@@ -25,56 +25,57 @@ def fixture_cfg():
|
||||
],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"plugins": ["axolotl.integrations.liger.LigerPlugin"],
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class BaseValidation:
|
||||
# pylint: disable=too-many-public-methods
|
||||
class TestValidation:
|
||||
"""
|
||||
Base validation module to setup the log capture
|
||||
Test the validation module for liger
|
||||
"""
|
||||
|
||||
_caplog: Optional[pytest.LogCaptureFixture] = None
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, caplog):
|
||||
caplog.set_level(logging.WARNING)
|
||||
self._caplog = caplog
|
||||
|
||||
|
||||
# pylint: disable=too-many-public-methods
|
||||
class TestValidation(BaseValidation):
|
||||
"""
|
||||
Test the validation module for liger
|
||||
"""
|
||||
|
||||
def test_deprecated_swiglu(self, minimal_cfg):
|
||||
def test_deprecated_swiglu(self, minimal_liger_cfg):
|
||||
test_cfg = DictDefault(
|
||||
{
|
||||
"liger_swiglu": False,
|
||||
}
|
||||
| minimal_cfg
|
||||
| minimal_liger_cfg
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
with self._caplog.at_level(
|
||||
logging.WARNING, logger="axolotl.integrations.liger.args"
|
||||
):
|
||||
prepare_plugins(test_cfg)
|
||||
updated_cfg = validate_config(test_cfg)
|
||||
assert (
|
||||
"The 'liger_swiglu' argument is deprecated"
|
||||
in self._caplog.records[0].message
|
||||
)
|
||||
# TODO this test is brittle in CI
|
||||
# assert (
|
||||
# "The 'liger_swiglu' argument is deprecated"
|
||||
# in self._caplog.records[0].message
|
||||
# )
|
||||
assert updated_cfg.liger_swiglu is None
|
||||
assert updated_cfg.liger_glu_activations is False
|
||||
assert updated_cfg.liger_glu_activation is False
|
||||
|
||||
def test_conflict_swiglu_ligergluactivation(self, minimal_cfg):
|
||||
def test_conflict_swiglu_ligergluactivation(self, minimal_liger_cfg):
|
||||
test_cfg = DictDefault(
|
||||
{
|
||||
"liger_swiglu": False,
|
||||
"liger_glu_activations": True,
|
||||
"liger_glu_activation": True,
|
||||
}
|
||||
| minimal_cfg
|
||||
| minimal_liger_cfg
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*You cannot have both `liger_swiglu` and `liger_glu_activation` set.*",
|
||||
):
|
||||
prepare_plugins(test_cfg)
|
||||
validate_config(test_cfg)
|
||||
69
tests/test_lora.py
Normal file
69
tests/test_lora.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""
|
||||
tests for loading loras
|
||||
"""
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
minimal_config = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"learning_rate": 0.000001,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
}
|
||||
],
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
class TestLoRALoad:
|
||||
"""
|
||||
Test class for loading LoRA weights
|
||||
"""
|
||||
|
||||
def test_load_lora_weights(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.0,
|
||||
"lora_target_linear": True,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"sequence_len": 1024,
|
||||
}
|
||||
| minimal_config
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
load_model(cfg, tokenizer)
|
||||
|
||||
def test_load_lora_weights_empty_dropout(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": None,
|
||||
"lora_target_linear": True,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"sequence_len": 1024,
|
||||
}
|
||||
| minimal_config
|
||||
)
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
assert cfg.lora_dropout == 0.0
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
load_model(cfg, tokenizer)
|
||||
@@ -4,9 +4,7 @@ import json
|
||||
import logging
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
from datasets import load_dataset
|
||||
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
|
||||
|
||||
@@ -65,12 +63,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
Test class for prompt tokenization strategies.
|
||||
"""
|
||||
|
||||
_caplog: Optional[pytest.LogCaptureFixture] = None
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, caplog):
|
||||
self._caplog = caplog
|
||||
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
|
||||
Reference in New Issue
Block a user