Compare commits

..

55 Commits

Author SHA1 Message Date
Dan Saunders
66262c3092 moving out all diff attn code to plugin repo 2025-01-24 17:46:11 +00:00
Dan Saunders
016ba124e4 README update 2025-01-23 22:11:35 +00:00
Dan Saunders
7145d52d99 moving diff attn code to separate repo 2025-01-23 21:33:53 +00:00
Dan Saunders
28694219a5 inline comment change 2025-01-14 16:59:43 +00:00
Dan Saunders
fd8ad6fcbf fixing negative component mixing 2025-01-13 19:21:55 +00:00
Dan Saunders
661d71a14b adding diff attn negative component warmup (in progress) 2025-01-10 21:57:31 +00:00
Dan Saunders
6dd47edcb8 fire CLI fixes 2025-01-10 18:24:16 +00:00
Dan Saunders
7aca08ff60 adding guard statements 2025-01-10 16:39:21 +00:00
Dan Saunders
4f804f6d88 adding diff attn callback, adding documentation 2025-01-10 16:28:51 +00:00
Dan Saunders
443327c585 CLI build_command bugfix 2025-01-10 16:28:51 +00:00
Dan Saunders
70c4e6fbe6 updates and cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
2a7f139ad2 pre-commit fix 2025-01-10 16:28:51 +00:00
Dan Saunders
332ce0ae85 fixes and cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
e5fa842ff8 update 2025-01-10 16:28:51 +00:00
Dan Saunders
78e0ec0aa5 changes 2025-01-10 16:28:51 +00:00
Dan Saunders
3bc568eb27 adding registration function 2025-01-10 16:28:51 +00:00
Dan Saunders
eb6611d55f progress on modeling code 2025-01-10 16:28:51 +00:00
Dan Saunders
4ff3328e66 updated custom modeling code 2025-01-10 16:28:51 +00:00
Dan Saunders
a3fd5074a9 fix duplicate-code warnings 2025-01-10 16:28:51 +00:00
Dan Saunders
5b90da0be3 added modeling code; cleanup + refactor 2025-01-10 16:28:51 +00:00
Dan Saunders
fcbfa86373 refactor and fixing test isolation issues 2025-01-10 16:28:51 +00:00
Dan Saunders
0d56582090 adding yaml dumper preserving input config format 2025-01-10 16:28:51 +00:00
Dan Saunders
390cb5742e removing extra pytest xdist args 2025-01-10 16:28:51 +00:00
Dan Saunders
1d935f65c3 moving tests around for flash_attn install 2025-01-10 16:28:51 +00:00
Dan Saunders
66176b3e07 adding split_heads argument for retaining original (Q, K) dimensionanlity 2025-01-10 16:28:51 +00:00
Dan Saunders
505321ac95 isolating problematic test 2025-01-10 16:28:51 +00:00
Dan Saunders
0b382c88da fixes post-rebase 2025-01-10 16:28:51 +00:00
Dan Saunders
ea07a7086e plugin implementation 2025-01-10 16:28:51 +00:00
Dan Saunders
d22e1136bc convert-differential-transformer test coverage 2025-01-10 16:28:51 +00:00
Dan Saunders
63b8e42c6b duplicate code ignore 2025-01-10 16:28:51 +00:00
Dan Saunders
bda1eed59e differential flash attention 2; cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
41ebd93158 moving monkeypatch 2025-01-10 16:28:51 +00:00
Dan Saunders
4c050ce807 pre-commit fix 2025-01-10 16:28:51 +00:00
Dan Saunders
6665acf63d fix model save / load logic 2025-01-10 16:28:51 +00:00
Dan Saunders
2f9fa4c465 various improvemnents 2025-01-10 16:28:51 +00:00
Dan Saunders
849bc94112 various improvemnents 2025-01-10 16:28:51 +00:00
Dan Saunders
e484ec778d training fixes, patching, minor cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
df1504ae14 adding CLI command for convert-diff-transformer 2025-01-10 16:28:51 +00:00
Dan Saunders
7be0d7496c Adding script for doing conversion; fixes and updates 2025-01-10 16:28:51 +00:00
Dan Saunders
13cdffa91f initial diff attn layer / model conversion implementation (support for llama arch) 2025-01-10 16:28:51 +00:00
Dan Saunders
7a4b296f60 Basic evaluate CLI command / codepath (#2188)
* basic evaluate CLI command / codepath

* tests for evaluate CLI command

* fixes and cleanup

* review comments; slightly DRYing up things

---------

Co-authored-by: Dan Saunders <danjsaund@gmail.com>
2025-01-10 16:28:51 +00:00
Wing Lian
d8b4027200 use 2.5.1 docker images as latest tag as it seems stable (#2198) 2025-01-10 08:35:25 -05:00
Wing Lian
fb3352e21c rename liger test so it properly runs in ci (#2246) 2025-01-09 17:31:43 -05:00
NanoCode012
ed77e7001e feat: add support for data_files in pretraining (#2238) 2025-01-09 21:04:13 +00:00
Wing Lian
7669a03fb4 update upstream HF deps (#2239)
* bump axolotl contribs for upstream main conflicts:

* bump datasets, tokenizer, trl

* remove log workarounds in trl

* bump lm-eval

* remove unsloth_ import from critical path

* remove llama fa2 from conftest

* unsloth breaks with latest upstream
2025-01-09 21:01:59 +00:00
Vincenzo di Cicco
6553683170 Use SequentialSampler if curriculum_sampling is enabled with sample_packing (#2235) 2025-01-09 21:01:22 +00:00
Wing Lian
5e0124e2ab update modal version for ci (#2242) 2025-01-09 21:01:02 +00:00
NanoCode012
2e8d7c1adb fix: mistral nemo does not recognize token_type_ids in forward (#2233) 2025-01-09 21:00:36 +00:00
Wing Lian
3c1921e400 add hf cache caching for GHA (#2247)
* add hf cache caching for GHA

* use modal volume to cache hf data

* make sure to update the cache as we add new fixtures in conftest
2025-01-09 20:59:54 +00:00
Wing Lian
7faf2b6e8e Merge group queue (#2248)
* add support for merge groups

* also lint merge groups
2025-01-09 15:49:00 -05:00
salman
c1b920f291 Fixing OSX installation (#2231)
* bumping version, removing non-osx compatible deps

* updating pylintrc

* fixing linters

* reverting changes
2025-01-07 13:42:01 +00:00
Wing Lian
3915abee4c make sure padding is labeled as -100 for pretraining (#2227) 2024-12-31 15:22:18 -05:00
NJordan72
7a38dbe674 fix: allow trainer builder to use custom jinja chat template (#2219)
* fix: allow trainer builder to use custom jinja chat template

* chore: use get_chat_template_from_config

Co-authored-by: Chirag Jain <jain.chirag925@gmail.com>

* fix: swap imports

---------

Co-authored-by: Chirag Jain <jain.chirag925@gmail.com>
2024-12-24 16:18:50 -05:00
Wing Lian
e0a2eb2ebd fix untrained tokens if specified explicitly from a list (#2210) 2024-12-23 09:08:28 -05:00
Wing Lian
d852d7af7a inference - don't default w accelerate, fix base model (#2216) [skip ci] 2024-12-23 07:48:41 -05:00
57 changed files with 614 additions and 412 deletions

View File

@@ -1,6 +1,7 @@
name: lint
on:
# check on PRs, and manual triggers
merge_group:
pull_request:
paths:
- '**.py'

View File

@@ -25,7 +25,6 @@ jobs:
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras: mamba-ssm
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -36,6 +35,7 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -92,7 +92,6 @@ jobs:
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -103,6 +102,7 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -52,7 +52,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -129,7 +129,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -1,6 +1,7 @@
name: Tests
on:
# check on push/merge to main, PRs, and manual triggers
merge_group:
push:
branches:
- "main"
@@ -60,6 +61,15 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -100,6 +110,15 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
@@ -115,6 +134,15 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -156,6 +184,15 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
docker-e2e-tests-1st:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
@@ -183,7 +220,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -229,7 +266,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

3
.gitignore vendored
View File

@@ -186,3 +186,6 @@ out/
# vim
*.swp
# symlinked to axolotl-artifacts in docker containers
outputs

View File

@@ -23,7 +23,7 @@ repos:
hooks:
- id: flake8
- repo: https://github.com/PyCQA/pylint
rev: v2.17.4
rev: v3.3.0
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy

View File

@@ -1,5 +1,5 @@
[MASTER]
init-hook="from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))"
init-hook="from pylint.config import find_default_config_files; import sys; sys.path.append(next(find_default_config_files()).parent.as_posix())"
[TYPECHECK]
@@ -12,3 +12,4 @@ generated-members=numpy.*, torch.*
disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
too-many-positional-arguments, possibly-used-before-assignment

View File

@@ -8,6 +8,7 @@ ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev

View File

@@ -4,7 +4,6 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -1,6 +1,6 @@
"""
modal application to run axolotl gpu tests in Modal
"""
modal application to run axolotl gpu tests in Modal
"""
# pylint: disable=duplicate-code
import os
@@ -28,6 +28,7 @@ df_args = {
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}
dockerfile_contents = df_template.render(**df_args)
@@ -48,6 +49,12 @@ cicd_image = (
app = App("Axolotl CI/CD", secrets=[])
hf_cache_volume = modal.Volume.from_name(
"axolotl-ci-hf-hub-cache", create_if_missing=True
)
VOLUME_CONFIG = {
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
}
N_GPUS = int(os.environ.get("N_GPUS", 2))
GPU_CONFIG = modal.gpu.H100(count=N_GPUS)
@@ -67,6 +74,7 @@ def run_cmd(cmd: str, run_folder: str):
timeout=60 * 60,
cpu=8.0,
memory=131072 * N_GPUS,
volumes=VOLUME_CONFIG,
)
def cicd_pytest():
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")

View File

@@ -29,6 +29,7 @@ df_args = {
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}
dockerfile_contents = df_template.render(**df_args)
@@ -50,6 +51,12 @@ cicd_image = (
app = App("Axolotl CI/CD", secrets=[])
hf_cache_volume = modal.Volume.from_name(
"axolotl-ci-hf-hub-cache", create_if_missing=True
)
VOLUME_CONFIG = {
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.A10G(count=N_GPUS)
@@ -69,6 +76,7 @@ def run_cmd(cmd: str, run_folder: str):
timeout=60 * 60,
cpu=8.0,
memory=131072,
volumes=VOLUME_CONFIG,
)
def cicd_pytest():
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")

View File

@@ -1,29 +0,0 @@
---
title: Learning Rate Groups
description: "Setting different learning rates by module name"
---
## Background
Inspired by LoRA+, Axolotl allows practitioners to specify separate learning rates for each module or groups of
modules in a model.
## Example
```yaml
lr_groups:
- name: o_proj
modules:
- self_attn.o_proj.weight
lr: 1e-6
- name: q_proj
modules:
- model.layers.2.self_attn.q_proj.weight
lr: 1e-5
learning_rate: 2e-5
```
In this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate
of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning are of 1e-5 to the 3rd layer's
self attention `q_proj` module.

View File

@@ -2,7 +2,7 @@
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.0
triton>=2.3.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
flash-attn==2.7.0.post2
xformers>=0.0.23.post1
@@ -14,11 +14,11 @@ packaging==23.2
peft==0.14.0
transformers==4.47.1
tokenizers>=0.20.1
tokenizers>=0.21.0
accelerate==1.2.1
datasets==3.1.0
datasets==3.2.0
deepspeed==0.16.1
trl==0.12.1
trl==0.13.0
optimum==1.16.2
hf_transfer
@@ -53,7 +53,7 @@ zstandard==0.22.0
fastcore
# lm eval harness
lm_eval==0.4.4
lm_eval==0.4.7
langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
@@ -61,4 +61,4 @@ antlr4-python3-runtime==4.13.2
torchao==0.7.0
schedulefree==1.3.0
axolotl-contribs-lgpl==0.0.1b2
axolotl-contribs-lgpl==0.0.3

View File

@@ -1,4 +1,5 @@
"""setup.py for axolotl"""
import ast
import os
import platform
@@ -29,15 +30,30 @@ def parse_requirements():
elif not is_extras and line and line[0] != "#":
# Handle standard packages
_install_requires.append(line)
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
triton_version = [req for req in _install_requires if "triton" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system():
# don't install xformers on MacOS
_install_requires.pop(_install_requires.index(xformers_version))
# skip packages not compatible with OSX
skip_packages = [
"bitsandbytes",
"triton",
"mamba-ssm",
"flash-attn",
"xformers",
"autoawq",
"liger-kernel",
]
_install_requires = [
req
for req in _install_requires
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
]
print(
_install_requires, [req in skip_packages for req in _install_requires]
)
else:
# detect the version of torch already installed
# and set it so dependencies don't clobber the torch version
@@ -73,6 +89,8 @@ def parse_requirements():
_install_requires.append("xformers==0.0.28.post1")
elif (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(triton_version))
_install_requires.append("triton>=2.3.1")
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1")

View File

@@ -202,7 +202,7 @@ def do_inference(
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
elif cfg.datasets[0].type == "chat_template":
elif cfg.datasets and cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)

View File

@@ -3,7 +3,7 @@ CLI to run training on a model
"""
import logging
from pathlib import Path
from typing import Union
from typing import Dict, Union
import fire
from dotenv import load_dotenv
@@ -23,7 +23,7 @@ from axolotl.evaluate import evaluate
LOG = logging.getLogger("axolotl.cli.evaluate")
def do_evaluate(cfg, cli_args) -> None:
def do_evaluate(cfg, cli_args) -> Dict[str, float]:
# pylint: disable=duplicate-code
print_axolotl_text_art()
check_accelerate_default_config()
@@ -34,7 +34,7 @@ def do_evaluate(cfg, cli_args) -> None:
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
return evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:

View File

@@ -1,11 +1,13 @@
"""CLI definition for various axolotl commands."""
# pylint: disable=redefined-outer-name
import subprocess # nosec B404
from typing import Optional
import click
import axolotl
from axolotl.cli.plugins import setup_plugin_commands
from axolotl.cli.utils import (
add_options_from_config,
add_options_from_dataclass,
@@ -77,6 +79,9 @@ def evaluate(config: str, accelerate: bool, **kwargs):
"""Evaluate a model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
if config:
@@ -93,7 +98,7 @@ def evaluate(config: str, accelerate: bool, **kwargs):
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
default=False,
help="Use accelerate launch for multi-GPU inference",
)
@click.option(
@@ -124,7 +129,7 @@ def inference(
if lora_model_dir:
kwargs["lora_model_dir"] = lora_model_dir
if base_model:
kwargs["output_dir"] = base_model
kwargs["base_model"] = base_model
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
@@ -254,6 +259,9 @@ def fetch(directory: str, dest: Optional[str]):
fetch_from_github(f"{directory}/", dest)
setup_plugin_commands(cli)
def main():
cli()

View File

@@ -0,0 +1,36 @@
"""Module for adding click CLI commands from axolotl plugins."""
import logging
import click
from axolotl.cli.utils import add_options_from_config, add_options_from_dataclass
from axolotl.logging_config import configure_logging
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
configure_logging()
LOG = logging.getLogger(__name__)
def setup_plugin_commands(cli: click.core.Group) -> None:
"""
Setup CLI commands for available plugins.
Args:
cli: Click CLI object to add plugin CLI options to.
"""
try:
from axolotl_diff_transformer.convert_diff_transformer import do_cli
from axolotl_diff_transformer.plugin.cli import ConvertDiffTransformerCliArgs
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def convert_diff_transformer(config: str, **kwargs):
"""Convert model attention layers to differential attention layers."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
do_cli(config=config, **kwargs)
except ImportError as exc:
LOG.debug("axolotl-diff-transformer not found: %s", exc)

View File

@@ -22,7 +22,6 @@ def add_options_from_dataclass(config_class: Type[Any]):
# Process dataclass fields in reverse order for correct option ordering
for field in reversed(dataclasses.fields(config_class)):
field_type = field.type
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
@@ -44,6 +43,7 @@ def add_options_from_dataclass(config_class: Type[Any]):
default=field.default,
help=field.metadata.get("description"),
)(function)
return function
return decorator
@@ -55,7 +55,14 @@ def add_options_from_config(config_class: Type[BaseModel]):
def decorator(function):
# Process model fields in reverse order for correct option ordering
for name, field in reversed(config_class.model_fields.items()):
if field.annotation == bool:
field_type = field.annotation
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
# NOTE: defaults are handled by the pydantic model config classes.
if field_type == bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
@@ -66,6 +73,7 @@ def add_options_from_config(config_class: Type[BaseModel]):
function = click.option(
option_name, default=None, help=field.description
)(function)
return function
return decorator
@@ -84,6 +92,8 @@ def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
if isinstance(value, bool):
if value:
cmd.append(f"--{key}")
else:
cmd.append(f"--no{key}")
else:
cmd.extend([f"--{key}", str(value)])

View File

@@ -4,22 +4,26 @@ shared module for cli specific things
import logging
from dataclasses import dataclass, field
from typing import Optional
from typing import TYPE_CHECKING, Optional, Union
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
if TYPE_CHECKING:
try:
from axolotl_diff_transformer.plugin.cli import ConvertDiffTransformerCliArgs
except: # noqa: E722 # pylint: disable=bare-except # nosec B110
pass
configure_logging()
LOG = logging.getLogger("axolotl.common.cli")
LOG = logging.getLogger(__name__)
@dataclass
class PreprocessCliArgs:
"""
dataclass representing arguments for preprocessing only
"""
"""dataclass with arguments for preprocessing only"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
@@ -30,9 +34,7 @@ class PreprocessCliArgs:
@dataclass
class TrainerCliArgs:
"""
dataclass representing the various non-training arguments
"""
"""dataclass with various non-training arguments"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
@@ -45,9 +47,7 @@ class TrainerCliArgs:
@dataclass
class EvaluateCliArgs:
"""
dataclass representing the various evaluation arguments
"""
"""dataclass with various evaluation arguments"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
@@ -57,7 +57,7 @@ class EvaluateCliArgs:
def load_model_and_tokenizer(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
cli_args: Union[TrainerCliArgs, EvaluateCliArgs, "ConvertDiffTransformerCliArgs"],
):
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)

View File

@@ -22,7 +22,6 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch
import transformers
from datasets import Dataset
from packaging import version
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
@@ -68,7 +67,7 @@ from axolotl.utils.callbacks import (
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
@@ -244,10 +243,6 @@ class AxolotlTrainingMixins:
default=None,
metadata={"help": "Scale the learning rate for the embedding layers."},
)
lr_groups: Optional[list[dict]] = field(
default=None,
metadata={"help": "Specify learning rate groups for with different LRs."},
)
embedding_lr: Optional[float] = field(
default=None,
metadata={"help": "absolute learning rate for the embedding layers."},
@@ -298,7 +293,7 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
"""
Training arguments for Causal trainer
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
This code is duplicated due to HF TrainingArguments not setting output_dir with a default value
so it can't be used as a mixin.
"""
@@ -466,96 +461,11 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
decay_parameters = self.get_decay_parameter_names(opt_model)
params = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
lr_groups_lookup = {}
lr_groups_learning_rates = {}
if self.args.lr_groups:
for lr_group in self.args.lr_groups:
group_name = lr_group["name"]
group_modules = lr_group["modules"]
for module in group_modules:
lr_groups_lookup[module] = group_name
lr_groups_learning_rates[group_name] = lr_group["lr"]
params[f"to_weight_decay_{group_name}"] = {}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
if lr_groups_lookup and any(
group_modules in name for group_modules in lr_groups_lookup
):
lr_group_module = [
group_modules
for group_modules in lr_groups_lookup
if group_modules in name
][0]
group_name = lr_groups_lookup[lr_group_module]
params[f"to_weight_decay_{group_name}"][name] = param
else:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
for group_name, group_lr in lr_groups_learning_rates.items():
if params[f"to_weight_decay_{group_name}"]:
optimizer_grouped_parameters.append(
{
"params": list(
params[f"to_weight_decay_{group_name}"].values()
),
"weight_decay": self.args.weight_decay,
"lr": group_lr,
}
)
return optimizer_grouped_parameters
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None
and self.args.lr_groups is None
and self.args.alternate_optimizer
not in [
"optimi_adamw",
@@ -569,13 +479,59 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
decay_parameters = self.get_decay_parameter_names(opt_model)
params = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
opt_model, optimizer_kwargs
)
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
@@ -592,7 +548,6 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
elif (
self.args.embedding_lr_scale is not None
or self.args.embedding_lr is not None
or self.args.lr_groups is not None
):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
@@ -652,8 +607,14 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
self.state.train_batch_size or self.args.per_device_train_batch_size
)
batch_max_len = train_batch_size * self.args.max_seq_length
if self.args.curriculum_sampling:
sampler = SequentialSampler(self.train_dataset)
else:
sampler = RandomSampler(self.train_dataset)
return MultipackBatchSampler(
RandomSampler(self.train_dataset),
sampler,
lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
@@ -1022,12 +983,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
try:
return super().log(logs, start_time)
except TypeError:
return super().log(logs) # transformers<=4.46
return super().log(logs) # transformers<=4.46
return super().log(logs, start_time)
def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
@@ -1211,22 +1167,6 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
torch.cuda.empty_cache()
return loss
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
@@ -1235,22 +1175,6 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
tag_names = ["axolotl", "orpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(ORPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
@@ -1259,49 +1183,6 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
tag_names = ["axolotl", "kto"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# train metrics should have no prefix, eval should have 'eval_'
prefix = "eval_" if train_eval == "eval" else ""
# accumulate average metrics from sums and lengths
for split in ["chosen", "rejected"]:
if f"count/{split}" in self._stored_metrics[train_eval]:
count_sum = (
torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"])
.sum()
.item()
)
for metric in ["rewards", "logps", "logits"]:
logs[f"{prefix}{metric}/{split}"] = (
torch.Tensor(
self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
)
.sum()
.item()
/ count_sum
)
# delete obsolete metric
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
del self._stored_metrics[train_eval][f"count/{split}"]
# calculate reward margin
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
logs[f"{prefix}rewards/margins"] = (
logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
)
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(KTOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
@@ -1310,22 +1191,6 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
tag_names = ["axolotl", "cpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(CPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
@@ -1334,15 +1199,6 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
tag_names = ["axolotl", "reward"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call
class TrainerBuilderBase(abc.ABC):
"""
@@ -1808,7 +1664,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
@@ -1879,8 +1734,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = get_chat_template(
self.cfg.chat_template,
training_arguments_kwargs["chat_template"] = get_chat_template_from_config(
cfg=self.cfg,
tokenizer=self.tokenizer,
)

View File

@@ -9,12 +9,11 @@ from typing import Dict, Optional
import torch
from accelerate.logging import get_logger
from axolotl.common.cli import TrainerCliArgs
from axolotl.common.cli import EvaluateCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.models import load_processor
from axolotl.utils.trainer import setup_trainer
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -62,8 +61,9 @@ def evaluate_dataset(
return metrics
# pylint: disable=duplicate-code
def evaluate(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
*, cfg: DictDefault, cli_args: EvaluateCliArgs, dataset_meta: TrainDatasetMeta
) -> Dict[str, float]:
"""
Evaluate a model on training and validation datasets
@@ -79,16 +79,11 @@ def evaluate(
- The tokenizer
- Dictionary of evaluation metrics
"""
# pylint: disable=duplicate-code
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
# Load model
LOG.debug("loading model for evaluation...")
# Load tokenizer
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
# Load processor for multimodal models if needed
processor = None
@@ -100,12 +95,6 @@ def evaluate(
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
# Load model
LOG.debug("loading model for evaluation...")
model, _ = load_model(
cfg, tokenizer, processor=processor, inference=cli_args.inference
)
# Set up trainer
trainer = setup_trainer(
cfg,

View File

@@ -43,10 +43,12 @@ def merge_input_args():
input_args: List[str] = plugin_manager.get_input_args()
plugin_classes = []
dynamic_input = ""
for plugin_args in input_args:
plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
plugin_classes.append(plugin_cls)
if dynamic_input:
dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n"
dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
@@ -62,4 +64,5 @@ def merge_input_args():
"AxolotlConfigWCapabilities"
]
return AxolotlConfigWCapabilities, AxolotlInputConfig
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase

View File

@@ -22,13 +22,6 @@ import inspect
import logging
import sys
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from axolotl.integrations.base import BasePlugin
from ...utils.distributed import zero_only
@@ -46,6 +39,13 @@ class LigerPlugin(BasePlugin):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)

View File

@@ -6,7 +6,7 @@ import logging
from transformers import Trainer
from axolotl.monkeypatch.unsloth_ import detab_code
from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")

View File

@@ -8,7 +8,7 @@ import logging
from transformers import LlamaForCausalLM, Trainer
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from axolotl.monkeypatch.unsloth_ import detab_code
from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")

View File

@@ -1,9 +1,7 @@
"""module for patching with unsloth optimizations"""
import inspect
import re
import types
from typing import Tuple
import torch
from accelerate.logging import get_logger
@@ -11,6 +9,8 @@ from peft import PeftModelForCausalLM
from torch import nn
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
from axolotl.monkeypatch.utils import detab_code
LOG = get_logger("axolotl.monkeypatch.unsloth")
ORIGINAL_QKV_CODE = """
@@ -93,15 +93,6 @@ def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
raise ValueError("Unsupported model type")
def detab_code(code: str) -> Tuple[str, str]:
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
return code, spaces
self_attn_lora_patched = False # pylint: disable=invalid-name

View File

@@ -1,7 +1,8 @@
"""
Shared utils for the monkeypatches
"""
from typing import Optional
import re
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
@@ -223,3 +224,12 @@ def patched_prepare_4d_causal_attention_mask_for_sdpa(
mask_2d_to_4d(attention_mask, dtype=dtype),
*args,
)
def detab_code(code: str) -> Tuple[str, str]:
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
return code, spaces

View File

@@ -1,5 +1,6 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import inspect
import os
import signal
import sys
@@ -126,7 +127,20 @@ def train(
)
if cfg.fix_untrained_tokens:
fix_untrained_tokens(model, tokenizer, train_dataset)
# check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens)
# if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list
):
fix_untrained_tokens(
model,
tokenizer,
train_dataset,
token_ids_to_fix=cfg.fix_untrained_tokens,
)
else:
fix_untrained_tokens(model, tokenizer, train_dataset)
if cfg.local_rank == 0:
model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization

View File

@@ -43,7 +43,7 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
getattr, self.layers_attribute.split("."), self.trainer.model
)
LOG.info(
f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps"
f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers * 100 / len(layers)}%) every {self.step_interval} steps"
)
def freeze_all_layers(self):

View File

@@ -128,6 +128,7 @@ class PretrainingDataset(BaseModel):
text_column: Optional[str] = "text"
type: Optional[str] = "pretrain"
trust_remote_code: Optional[bool] = False
data_files: Optional[str] = None
class UserDefinedPrompterType(BaseModel):
@@ -145,14 +146,6 @@ class UserDefinedPrompterType(BaseModel):
field: Optional[str] = None
class LrGroup(BaseModel):
"""Custom learning rate group configuration"""
name: str
modules: List[str]
lr: float
class SFTDataset(BaseModel):
"""SFT configuration subset"""
@@ -474,7 +467,6 @@ class HyperparametersConfig(BaseModel):
cosine_min_lr_ratio: Optional[float] = None
cosine_constant_lr_ratio: Optional[float] = None
lr_div_factor: Optional[float] = None
lr_groups: Optional[List[LrGroup]] = None
adam_epsilon: Optional[float] = None
adam_beta1: Optional[float] = None
@@ -803,7 +795,7 @@ class AxolotlInputConfig(
chat_template_jinja: Optional[str] = None
default_system_message: Optional[str] = None
fix_untrained_tokens: Optional[bool] = None
fix_untrained_tokens: Optional[Union[int, List[int]]] = None
# INTERNALS - document for now, generally not set externally
is_preprocess: Optional[bool] = None

View File

@@ -28,8 +28,10 @@ def encode_pretraining(
)
# Convert to PyTorch tensors
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
targets = [torch.tensor(seq) for seq in res["input_ids"]]
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
new_input_ids = []
new_labels = []
new_attention_mask = []
# Append EOS and PAD tokens to input_ids, and correct attention_mask
for i, _ in enumerate(input_ids):
@@ -40,22 +42,34 @@ def encode_pretraining(
),
dim=0,
)
targets[i] = torch.cat(
(
targets[i],
torch.tensor([tokenizer.eos_token_id, -100]),
),
dim=0,
)
attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
# Concatenate tokens so that their lengths are less than max_tokens
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_labels = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
for ids, mask in zip(input_ids, attention_mask):
for ids, labels, mask in zip(input_ids, targets, attention_mask):
if buffer_input_ids.numel() == max_tokens:
new_input_ids.append(buffer_input_ids)
new_labels.append(buffer_labels)
new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_labels = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
else:
buffer_input_ids = torch.cat(
@@ -69,6 +83,17 @@ def encode_pretraining(
),
dim=0,
)
buffer_labels = torch.cat(
(
buffer_labels,
torch.full(
(max_tokens - buffer_labels.numel(),),
-100,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat(
(
buffer_attention_mask,
@@ -81,11 +106,14 @@ def encode_pretraining(
dim=0,
)
new_input_ids.append(buffer_input_ids)
new_labels.append(buffer_labels)
new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_labels = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
if buffer_input_ids.numel() > 0: # for any leftover tokens
@@ -101,6 +129,17 @@ def encode_pretraining(
),
dim=0,
)
buffer_labels = torch.cat(
(
buffer_labels,
torch.full(
(max_tokens - buffer_labels.numel(),),
-100,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat(
(
buffer_attention_mask,
@@ -113,11 +152,12 @@ def encode_pretraining(
dim=0,
)
new_input_ids.append(buffer_input_ids)
new_labels.append(buffer_labels)
new_attention_mask.append(buffer_attention_mask)
ret = {
"input_ids": [seq.tolist() for seq in new_input_ids],
"labels": [seq.tolist() for seq in new_input_ids],
"labels": [seq.tolist() for seq in new_labels],
"attention_mask": [seq.tolist() for seq in new_attention_mask],
}

View File

@@ -88,6 +88,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
path = cfg.pretraining_dataset
split = "train"
name = None
data_files = None
if isinstance(cfg.pretraining_dataset, list) and isinstance(
cfg.pretraining_dataset[0], dict
):
@@ -96,6 +97,8 @@ def prepare_dataset(cfg, tokenizer, processor=None):
if "split" in cfg.pretraining_dataset[0]:
split = cfg.pretraining_dataset[0]["split"]
data_files = cfg.pretraining_dataset[0].get("data_files")
ds_wrapper_partial = functools.partial(
get_dataset_wrapper,
cfg.pretraining_dataset[0],
@@ -105,7 +108,9 @@ def prepare_dataset(cfg, tokenizer, processor=None):
)
train_dataset = wrap_pretraining_dataset(
load_dataset(path, streaming=True, split=split, name=name),
load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
),
tokenizer,
cfg,
ds_wrapper_partial,

View File

@@ -270,7 +270,7 @@ def load_sharded_model_quant(
model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config)
if cfg.local_rank == 0 and verbose:
print(f"Loaded model weights in {time.time()-start:.3f} seconds")
print(f"Loaded model weights in {time.time() - start:.3f} seconds")
# cleanup any extra memory usage from parallel loading
torch.cuda.empty_cache()

View File

@@ -713,19 +713,45 @@ class ModelLoader:
if self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention:
pass
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
if self.cfg.diff_attention:
self.model_kwargs[
"attn_implementation"
] = "differential_flash_attention_2"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_flash_attention_2"
)
else:
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
elif self.cfg.sdp_attention:
self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"sdpa"
)
if self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_sdpa"
)
else:
self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"sdpa"
)
elif self.cfg.eager_attention:
self.model_kwargs["attn_implementation"] = "eager"
if self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_eager"
)
else:
self.model_kwargs["attn_implementation"] = "eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)
elif self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
"differential_eager"
)
if self.cfg.low_cpu_mem_usage:
@@ -816,6 +842,7 @@ class ModelLoader:
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.AutoModelLoader.from_pretrained(
self.base_model,
config=self.model_config,

View File

@@ -196,7 +196,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask")
if cfg.model_config_type == "falcon":
if cfg.model_config_type in ["falcon", "mistral"]:
LOG.info("dropping token_type_ids column if it exists")
if "token_type_ids" in train_dataset.column_names:
train_dataset = train_dataset.remove_columns("token_type_ids")

157
src/axolotl/utils/yaml.py Normal file
View File

@@ -0,0 +1,157 @@
"""Utilities for YAML files."""
from collections import OrderedDict
from typing import Any, Dict, List, Set, Tuple, Union
import yaml
class YAMLOrderTracker:
"""Tracks the order of keys and section breaks in YAML files."""
def __init__(self, yaml_path: str):
self.yaml_path = yaml_path
self.structure, self.needs_break = self._parse_yaml_structure()
def _get_indentation_level(self, line: str) -> int:
"""Get the indentation level of a line."""
return len(line) - len(line.lstrip())
def _parse_yaml_structure(
self,
) -> Tuple[Dict[str, Union[List[str], Dict]], Set[str]]:
"""Parse the YAML file to extract structure and identify section breaks."""
with open(self.yaml_path, "r", encoding="utf-8") as file:
contents = file.readlines()
structure: OrderedDict = OrderedDict()
needs_break = set() # Track which keys should have a break before them
current_path = []
last_indentation = -1
had_empty_line = False
for line in contents:
# Track empty lines and comments
if not line.strip() or line.strip().startswith("#"):
had_empty_line = True
continue
# Get indentation level and content
indentation = self._get_indentation_level(line)
content = line.strip()
# Skip lines that don't define keys
if ":" not in content:
continue
# Extract key
key = content.split(":")[0].strip()
# If this is a top-level key and we had an empty line, mark it
if indentation == 0:
if had_empty_line:
needs_break.add(key)
had_empty_line = False
# Handle indentation changes
if indentation > last_indentation:
current_path.append(key)
elif indentation < last_indentation:
levels_up = (last_indentation - indentation) // 2
current_path = current_path[:-levels_up]
current_path[-1] = key
else:
if current_path:
current_path[-1] = key
# Update structure
current_dict = structure
for path_key in current_path[:-1]:
if path_key not in current_dict:
current_dict[path_key] = OrderedDict()
current_dict = current_dict[path_key]
if current_path:
if current_path[-1] not in current_dict:
current_dict[current_path[-1]] = OrderedDict()
last_indentation = indentation
return structure, needs_break
class OrderedDumper(yaml.SafeDumper):
"""Custom YAML dumper that maintains dictionary order."""
def represent_none(self, _):
"""Represent None values as empty fields."""
return self.represent_scalar("tag:yaml.org,2002:null", "")
def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any:
"""Custom representer for dictionaries that maintains order."""
return dumper.represent_mapping("tag:yaml.org,2002:map", data.items())
def reorder_dict(data: Dict, reference_structure: Dict) -> OrderedDict:
"""Reorder a dictionary based on a reference structure."""
ordered = OrderedDict()
# First add keys that are in the reference order
for key in reference_structure:
if key in data:
if isinstance(reference_structure[key], dict) and isinstance(
data[key], dict
):
ordered[key] = reorder_dict(data[key], reference_structure[key])
else:
ordered[key] = data[key]
# Then add any remaining keys that weren't in the reference
for key in data:
if key not in ordered:
ordered[key] = data[key]
return ordered
def dump_yaml_preserved_order(
data: Dict, reference_yaml_path: str, output_path: str
) -> None:
"""Dump YAML file while preserving nested order and normalized spacing."""
# Get reference structure and spacing
tracker = YAMLOrderTracker(reference_yaml_path)
# Reorder the data
ordered_data = reorder_dict(data, tracker.structure)
# Register the custom representers
OrderedDumper.add_representer(type(None), represent_none)
OrderedDumper.add_representer(dict, ordered_dict_representer)
OrderedDumper.add_representer(OrderedDict, ordered_dict_representer)
# First dump to string
yaml_str = yaml.dump(
ordered_data, Dumper=OrderedDumper, sort_keys=False, default_flow_style=False
)
# Add spacing according to reference
lines = yaml_str.split("\n")
result_lines: List[str] = []
current_line = 0
while current_line < len(lines):
line = lines[current_line]
if line.strip() and ":" in line and not line.startswith(" "): # Top-level key
key = line.split(":")[0].strip()
if key in tracker.needs_break:
# Add single empty line before this key
if result_lines and result_lines[-1] != "":
result_lines.append("")
result_lines.append(line)
current_line += 1
# Write the final result
with open(output_path, "w", encoding="utf-8") as file:
file.write("\n".join(result_lines))

View File

@@ -1,4 +1,5 @@
"""Shared pytest fixtures for cli module."""
import pytest
from click.testing import CliRunner

View File

@@ -43,14 +43,12 @@ class BaseCliTest:
result = cli_runner.invoke(cli, [command, str(config_path)])
assert mock.called
assert mock.call_args.args[0] == [
assert mock.call_args.args[0][:5] == [
"accelerate",
"launch",
"-m",
f"axolotl.cli.{command}",
str(config_path),
"--debug-num-examples",
"0",
]
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI fetch command."""
from unittest.mock import patch
from axolotl.cli.main import fetch

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI inference command."""
from unittest.mock import patch
from axolotl.cli.main import cli

View File

@@ -1,4 +1,5 @@
"""General pytest tests for axolotl.cli.main interface."""
from axolotl.cli.main import build_command, cli
@@ -22,6 +23,7 @@ def test_build_command():
"--batch-size",
"8",
"--debug",
"--nouse-fp16",
]

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI merge_lora command."""
from unittest.mock import patch
from axolotl.cli.main import cli

View File

@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI preprocess command."""
import shutil
from pathlib import Path
from unittest.mock import patch

View File

@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI shard command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli
@@ -11,14 +12,12 @@ def test_shard_with_accelerate(cli_runner, config_path):
result = cli_runner.invoke(cli, ["shard", str(config_path), "--accelerate"])
assert mock.called
assert mock.call_args.args[0] == [
assert mock.call_args.args[0][:5] == [
"accelerate",
"launch",
"-m",
"axolotl.cli.shard",
str(config_path),
"--debug-num-examples",
"0",
]
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI --version"""
from axolotl.cli.main import cli

View File

@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI utils."""
# pylint: disable=redefined-outer-name
import json
from unittest.mock import Mock, patch

View File

@@ -120,13 +120,12 @@ def temp_dir():
@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
from transformers import Trainer
from transformers.models.llama.modeling_llama import (
from transformers.models.llama.modeling_llama import ( # LlamaFlashAttention2,
LlamaAttention,
LlamaFlashAttention2,
LlamaForCausalLM,
)
original_fa2_forward = LlamaFlashAttention2.forward
# original_fa2_forward = LlamaFlashAttention2.forward
original_llama_attn_forward = LlamaAttention.forward
original_llama_forward = LlamaForCausalLM.forward
original_trainer_inner_training_loop = (
@@ -136,7 +135,7 @@ def cleanup_monkeypatches():
# monkey patches can happen inside the tests
yield
# Reset LlamaFlashAttention2 forward
LlamaFlashAttention2.forward = original_fa2_forward
# LlamaFlashAttention2.forward = original_fa2_forward
LlamaAttention.forward = original_llama_attn_forward
LlamaForCausalLM.forward = original_llama_forward
Trainer._inner_training_loop = ( # pylint: disable=protected-access
@@ -149,7 +148,10 @@ def cleanup_monkeypatches():
("transformers.models.llama",),
(
"transformers.models.llama.modeling_llama",
["LlamaFlashAttention2", "LlamaAttention"],
[
# "LlamaFlashAttention2",
"LlamaAttention",
],
),
("transformers.trainer",),
("transformers", ["Trainer"]),

View File

@@ -1,43 +1,40 @@
"""
Simple end-to-end test for Liger integration
"""
import unittest
from pathlib import Path
from e2e.utils import require_torch_2_4_1
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
class LigerIntegrationTestCase(unittest.TestCase):
class LigerIntegrationTestCase:
"""
e2e tests for liger integration with Axolotl
"""
@with_temp_dir
@require_torch_2_4_1
def test_llama_wo_flce(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"plugins": [
"axolotl.integrations.liger.LigerPlugin",
],
"liger_rope": True,
"liger_rms_norm": True,
"liger_swiglu": True,
"liger_glu_activation": True,
"liger_cross_entropy": True,
"liger_fused_linear_cross_entropy": False,
"sequence_len": 1024,
"val_set_size": 0.1,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
@@ -46,15 +43,15 @@ class LigerIntegrationTestCase(unittest.TestCase):
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"max_steps": 10,
"max_steps": 5,
}
)
prepare_plugins(cfg)
@@ -65,26 +62,24 @@ class LigerIntegrationTestCase(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
@with_temp_dir
@require_torch_2_4_1
def test_llama_w_flce(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"plugins": [
"axolotl.integrations.liger.LigerPlugin",
],
"liger_rope": True,
"liger_rms_norm": True,
"liger_swiglu": True,
"liger_glu_activation": True,
"liger_cross_entropy": False,
"liger_fused_linear_cross_entropy": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
@@ -93,15 +88,15 @@ class LigerIntegrationTestCase(unittest.TestCase):
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"max_steps": 10,
"max_steps": 5,
}
)
prepare_plugins(cfg)

View File

@@ -1,9 +1,14 @@
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
import unittest
import pytest
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
@pytest.mark.skip(
reason="Unsloth integration will be broken going into latest transformers"
)
class TestUnslothIntegration(unittest.TestCase):
"""Unsloth monkeypatch integration tests."""

View File

@@ -20,6 +20,9 @@ os.environ["WANDB_DISABLED"] = "true"
# pylint: disable=duplicate-code
@pytest.mark.skip(
reason="Unsloth integration will be broken going into latest transformers"
)
class TestUnslothQLoRA:
"""
Test class for Unsloth QLoRA Llama models

View File

@@ -113,6 +113,7 @@ class TestCustomOptimizers(unittest.TestCase):
@with_temp_dir
def test_fft_schedule_free_adamw(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -49,7 +49,19 @@ def require_torch_2_3_1(test_case):
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.3.1")
return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)
return unittest.skipUnless(is_min_2_3_1(), "test requires torch>=2.3.1")(test_case)
def require_torch_2_4_1(test_case):
"""
Decorator marking a test that requires torch >= 2.5.1
"""
def is_min_2_4_1():
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.4.1")
return unittest.skipUnless(is_min_2_4_1(), "test requires torch>=2.4.1")(test_case)
def require_torch_2_5_1(test_case):
@@ -61,7 +73,7 @@ def require_torch_2_5_1(test_case):
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.5.1")
return unittest.skipUnless(is_min_2_5_1(), "test torch 2.5.1")(test_case)
return unittest.skipUnless(is_min_2_5_1(), "test requires torch>=2.5.1")(test_case)
def is_hopper():

View File

@@ -7,11 +7,11 @@ from typing import Optional
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.config import prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="minimal_base_cfg")
@pytest.fixture(name="minimal_liger_cfg")
def fixture_cfg():
return DictDefault(
{
@@ -25,56 +25,57 @@ def fixture_cfg():
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"plugins": ["axolotl.integrations.liger.LigerPlugin"],
}
)
class BaseValidation:
# pylint: disable=too-many-public-methods
class TestValidation:
"""
Base validation module to setup the log capture
Test the validation module for liger
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
caplog.set_level(logging.WARNING)
self._caplog = caplog
# pylint: disable=too-many-public-methods
class TestValidation(BaseValidation):
"""
Test the validation module for liger
"""
def test_deprecated_swiglu(self, minimal_cfg):
def test_deprecated_swiglu(self, minimal_liger_cfg):
test_cfg = DictDefault(
{
"liger_swiglu": False,
}
| minimal_cfg
| minimal_liger_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level(
logging.WARNING, logger="axolotl.integrations.liger.args"
):
prepare_plugins(test_cfg)
updated_cfg = validate_config(test_cfg)
assert (
"The 'liger_swiglu' argument is deprecated"
in self._caplog.records[0].message
)
# TODO this test is brittle in CI
# assert (
# "The 'liger_swiglu' argument is deprecated"
# in self._caplog.records[0].message
# )
assert updated_cfg.liger_swiglu is None
assert updated_cfg.liger_glu_activations is False
assert updated_cfg.liger_glu_activation is False
def test_conflict_swiglu_ligergluactivation(self, minimal_cfg):
def test_conflict_swiglu_ligergluactivation(self, minimal_liger_cfg):
test_cfg = DictDefault(
{
"liger_swiglu": False,
"liger_glu_activations": True,
"liger_glu_activation": True,
}
| minimal_cfg
| minimal_liger_cfg
)
with pytest.raises(
ValueError,
match=r".*You cannot have both `liger_swiglu` and `liger_glu_activation` set.*",
):
prepare_plugins(test_cfg)
validate_config(test_cfg)

View File

@@ -4,9 +4,7 @@ import json
import logging
import unittest
from pathlib import Path
from typing import Optional
import pytest
from datasets import load_dataset
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
@@ -65,12 +63,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
Test class for prompt tokenization strategies.
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")