Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
59047ee6c4 dump snapshot location for caching 2025-01-09 11:26:33 -05:00
50 changed files with 267 additions and 481 deletions

View File

@@ -1,7 +1,6 @@
name: lint name: lint
on: on:
# check on PRs, and manual triggers # check on PRs, and manual triggers
merge_group:
pull_request: pull_request:
paths: paths:
- '**.py' - '**.py'

View File

@@ -25,6 +25,7 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.3.1 pytorch: 2.3.1
axolotl_extras: mamba-ssm axolotl_extras: mamba-ssm
is_latest: true
- cuda: 124 - cuda: 124
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
@@ -35,7 +36,6 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.5.1 pytorch: 2.5.1
axolotl_extras: axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout
@@ -92,6 +92,7 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.3.1 pytorch: 2.3.1
axolotl_extras: axolotl_extras:
is_latest: true
- cuda: 124 - cuda: 124
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
@@ -102,7 +103,6 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.5.1 pytorch: 2.5.1
axolotl_extras: axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout

View File

@@ -52,7 +52,7 @@ jobs:
- name: Install Modal - name: Install Modal
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install modal==0.71.8 jinja2 pip install modal==0.63.64 jinja2
- name: Update env vars - name: Update env vars
run: | run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -129,7 +129,7 @@ jobs:
- name: Install Modal - name: Install Modal
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install modal==0.71.8 jinja2 pip install modal==0.63.64 jinja2
- name: Update env vars - name: Update env vars
run: | run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -1,7 +1,6 @@
name: Tests name: Tests
on: on:
# check on push/merge to main, PRs, and manual triggers # check on push/merge to main, PRs, and manual triggers
merge_group:
push: push:
branches: branches:
- "main" - "main"
@@ -61,15 +60,6 @@ jobs:
- name: Check out repository code - name: Check out repository code
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
- name: Setup Python - name: Setup Python
uses: actions/setup-python@v5 uses: actions/setup-python@v5
with: with:
@@ -110,15 +100,6 @@ jobs:
run: | run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \; find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest-sdist: pytest-sdist:
name: PyTest from Source Dist name: PyTest from Source Dist
runs-on: ubuntu-latest runs-on: ubuntu-latest
@@ -134,15 +115,6 @@ jobs:
- name: Check out repository code - name: Check out repository code
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
- name: Setup Python - name: Setup Python
uses: actions/setup-python@v5 uses: actions/setup-python@v5
with: with:
@@ -184,15 +156,6 @@ jobs:
run: | run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \; find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
docker-e2e-tests-1st: docker-e2e-tests-1st:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }} if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners... # this job needs to be run on self-hosted GPU runners...
@@ -220,7 +183,7 @@ jobs:
- name: Install Modal - name: Install Modal
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install modal==0.71.8 jinja2 pip install modal==0.63.64 jinja2
- name: Update env vars - name: Update env vars
run: | run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -266,7 +229,7 @@ jobs:
- name: Install Modal - name: Install Modal
run: | run: |
python -m pip install --upgrade pip python -m pip install --upgrade pip
pip install modal==0.71.8 jinja2 pip install modal==0.63.64 jinja2
- name: Update env vars - name: Update env vars
run: | run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

3
.gitignore vendored
View File

@@ -186,6 +186,3 @@ out/
# vim # vim
*.swp *.swp
# symlinked to axolotl-artifacts in docker containers
outputs

View File

@@ -8,7 +8,6 @@ ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}" ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}" ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}" ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev

View File

@@ -4,6 +4,7 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__" python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/ pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/ pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/ pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/ pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -1,6 +1,6 @@
""" """
modal application to run axolotl gpu tests in Modal modal application to run axolotl gpu tests in Modal
""" """
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
import os import os
@@ -28,7 +28,6 @@ df_args = {
"CUDA": os.environ.get("CUDA", "121"), "CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"), "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""), "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
} }
dockerfile_contents = df_template.render(**df_args) dockerfile_contents = df_template.render(**df_args)
@@ -49,12 +48,6 @@ cicd_image = (
app = App("Axolotl CI/CD", secrets=[]) app = App("Axolotl CI/CD", secrets=[])
hf_cache_volume = modal.Volume.from_name(
"axolotl-ci-hf-hub-cache", create_if_missing=True
)
VOLUME_CONFIG = {
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
}
N_GPUS = int(os.environ.get("N_GPUS", 2)) N_GPUS = int(os.environ.get("N_GPUS", 2))
GPU_CONFIG = modal.gpu.H100(count=N_GPUS) GPU_CONFIG = modal.gpu.H100(count=N_GPUS)
@@ -74,7 +67,6 @@ def run_cmd(cmd: str, run_folder: str):
timeout=60 * 60, timeout=60 * 60,
cpu=8.0, cpu=8.0,
memory=131072 * N_GPUS, memory=131072 * N_GPUS,
volumes=VOLUME_CONFIG,
) )
def cicd_pytest(): def cicd_pytest():
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl") run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")

View File

@@ -29,7 +29,6 @@ df_args = {
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"), "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""), "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""), "NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
} }
dockerfile_contents = df_template.render(**df_args) dockerfile_contents = df_template.render(**df_args)
@@ -51,12 +50,6 @@ cicd_image = (
app = App("Axolotl CI/CD", secrets=[]) app = App("Axolotl CI/CD", secrets=[])
hf_cache_volume = modal.Volume.from_name(
"axolotl-ci-hf-hub-cache", create_if_missing=True
)
VOLUME_CONFIG = {
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
}
N_GPUS = int(os.environ.get("N_GPUS", 1)) N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.A10G(count=N_GPUS) GPU_CONFIG = modal.gpu.A10G(count=N_GPUS)
@@ -76,7 +69,6 @@ def run_cmd(cmd: str, run_folder: str):
timeout=60 * 60, timeout=60 * 60,
cpu=8.0, cpu=8.0,
memory=131072, memory=131072,
volumes=VOLUME_CONFIG,
) )
def cicd_pytest(): def cicd_pytest():
run_cmd("./cicd/cicd.sh", "/workspace/axolotl") run_cmd("./cicd/cicd.sh", "/workspace/axolotl")

View File

@@ -2,7 +2,7 @@
# START section of dependencies that don't install on Darwin/MacOS # START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.0 bitsandbytes==0.45.0
triton>=3.0.0 triton>=2.3.0
mamba-ssm==1.2.0.post1 mamba-ssm==1.2.0.post1
flash-attn==2.7.0.post2 flash-attn==2.7.0.post2
xformers>=0.0.23.post1 xformers>=0.0.23.post1
@@ -14,11 +14,11 @@ packaging==23.2
peft==0.14.0 peft==0.14.0
transformers==4.47.1 transformers==4.47.1
tokenizers>=0.21.0 tokenizers>=0.20.1
accelerate==1.2.1 accelerate==1.2.1
datasets==3.2.0 datasets==3.1.0
deepspeed==0.16.1 deepspeed==0.16.1
trl==0.13.0 trl==0.12.1
optimum==1.16.2 optimum==1.16.2
hf_transfer hf_transfer
@@ -53,7 +53,7 @@ zstandard==0.22.0
fastcore fastcore
# lm eval harness # lm eval harness
lm_eval==0.4.7 lm_eval==0.4.4
langdetect==1.0.9 langdetect==1.0.9
immutabledict==4.2.0 immutabledict==4.2.0
antlr4-python3-runtime==4.13.2 antlr4-python3-runtime==4.13.2
@@ -61,4 +61,4 @@ antlr4-python3-runtime==4.13.2
torchao==0.7.0 torchao==0.7.0
schedulefree==1.3.0 schedulefree==1.3.0
axolotl-contribs-lgpl==0.0.3 axolotl-contribs-lgpl==0.0.2

View File

@@ -32,7 +32,6 @@ def parse_requirements():
_install_requires.append(line) _install_requires.append(line)
try: try:
xformers_version = [req for req in _install_requires if "xformers" in req][0] xformers_version = [req for req in _install_requires if "xformers" in req][0]
triton_version = [req for req in _install_requires if "triton" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0] torchao_version = [req for req in _install_requires if "torchao" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0] autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system(): if "Darwin" in platform.system():
@@ -89,8 +88,6 @@ def parse_requirements():
_install_requires.append("xformers==0.0.28.post1") _install_requires.append("xformers==0.0.28.post1")
elif (major, minor) >= (2, 3): elif (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(torchao_version)) _install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(triton_version))
_install_requires.append("triton>=2.3.1")
if patch == 0: if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version)) _install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1") _install_requires.append("xformers>=0.0.26.post1")

View File

@@ -202,7 +202,7 @@ def do_inference(
) )
elif cfg.chat_template: elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template) chat_template_str = get_chat_template(cfg.chat_template)
elif cfg.datasets and cfg.datasets[0].type == "chat_template": elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config( chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
) )

View File

@@ -3,7 +3,7 @@ CLI to run training on a model
""" """
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Dict, Union from typing import Union
import fire import fire
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -23,7 +23,7 @@ from axolotl.evaluate import evaluate
LOG = logging.getLogger("axolotl.cli.evaluate") LOG = logging.getLogger("axolotl.cli.evaluate")
def do_evaluate(cfg, cli_args) -> Dict[str, float]: def do_evaluate(cfg, cli_args) -> None:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
print_axolotl_text_art() print_axolotl_text_art()
check_accelerate_default_config() check_accelerate_default_config()
@@ -34,7 +34,7 @@ def do_evaluate(cfg, cli_args) -> Dict[str, float]:
else: else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
return evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:

View File

@@ -1,13 +1,11 @@
"""CLI definition for various axolotl commands.""" """CLI definition for various axolotl commands."""
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
import subprocess # nosec B404 import subprocess # nosec B404
from typing import Optional from typing import Optional
import click import click
import axolotl import axolotl
from axolotl.cli.plugins import setup_plugin_commands
from axolotl.cli.utils import ( from axolotl.cli.utils import (
add_options_from_config, add_options_from_config,
add_options_from_dataclass, add_options_from_dataclass,
@@ -79,9 +77,6 @@ def evaluate(config: str, accelerate: bool, **kwargs):
"""Evaluate a model.""" """Evaluate a model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None} kwargs = {k: v for k, v in kwargs.items() if v is not None}
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
if accelerate: if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"] base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
if config: if config:
@@ -259,9 +254,6 @@ def fetch(directory: str, dest: Optional[str]):
fetch_from_github(f"{directory}/", dest) fetch_from_github(f"{directory}/", dest)
setup_plugin_commands(cli)
def main(): def main():
cli() cli()

View File

@@ -1,36 +0,0 @@
"""Module for adding click CLI commands from axolotl plugins."""
import logging
import click
from axolotl.cli.utils import add_options_from_config, add_options_from_dataclass
from axolotl.logging_config import configure_logging
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
configure_logging()
LOG = logging.getLogger(__name__)
def setup_plugin_commands(cli: click.core.Group) -> None:
"""
Setup CLI commands for available plugins.
Args:
cli: Click CLI object to add plugin CLI options to.
"""
try:
from axolotl_diff_transformer.convert_diff_transformer import do_cli
from axolotl_diff_transformer.plugin.cli import ConvertDiffTransformerCliArgs
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def convert_diff_transformer(config: str, **kwargs):
"""Convert model attention layers to differential attention layers."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
do_cli(config=config, **kwargs)
except ImportError as exc:
LOG.debug("axolotl-diff-transformer not found: %s", exc)

View File

@@ -22,6 +22,7 @@ def add_options_from_dataclass(config_class: Type[Any]):
# Process dataclass fields in reverse order for correct option ordering # Process dataclass fields in reverse order for correct option ordering
for field in reversed(dataclasses.fields(config_class)): for field in reversed(dataclasses.fields(config_class)):
field_type = field.type field_type = field.type
if get_origin(field_type) is Union and type(None) in get_args(field_type): if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next( field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType) t for t in get_args(field_type) if not isinstance(t, NoneType)
@@ -43,7 +44,6 @@ def add_options_from_dataclass(config_class: Type[Any]):
default=field.default, default=field.default,
help=field.metadata.get("description"), help=field.metadata.get("description"),
)(function) )(function)
return function return function
return decorator return decorator
@@ -55,14 +55,7 @@ def add_options_from_config(config_class: Type[BaseModel]):
def decorator(function): def decorator(function):
# Process model fields in reverse order for correct option ordering # Process model fields in reverse order for correct option ordering
for name, field in reversed(config_class.model_fields.items()): for name, field in reversed(config_class.model_fields.items()):
field_type = field.annotation if field.annotation == bool:
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
# NOTE: defaults are handled by the pydantic model config classes.
if field_type == bool:
field_name = name.replace("_", "-") field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}" option_name = f"--{field_name}/--no-{field_name}"
function = click.option( function = click.option(
@@ -73,7 +66,6 @@ def add_options_from_config(config_class: Type[BaseModel]):
function = click.option( function = click.option(
option_name, default=None, help=field.description option_name, default=None, help=field.description
)(function) )(function)
return function return function
return decorator return decorator
@@ -92,8 +84,6 @@ def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
if isinstance(value, bool): if isinstance(value, bool):
if value: if value:
cmd.append(f"--{key}") cmd.append(f"--{key}")
else:
cmd.append(f"--no{key}")
else: else:
cmd.extend([f"--{key}", str(value)]) cmd.extend([f"--{key}", str(value)])

View File

@@ -4,26 +4,22 @@ shared module for cli specific things
import logging import logging
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Optional, Union from typing import Optional
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.models import load_model, load_tokenizer
if TYPE_CHECKING:
try:
from axolotl_diff_transformer.plugin.cli import ConvertDiffTransformerCliArgs
except: # noqa: E722 # pylint: disable=bare-except # nosec B110
pass
configure_logging() configure_logging()
LOG = logging.getLogger(__name__) LOG = logging.getLogger("axolotl.common.cli")
@dataclass @dataclass
class PreprocessCliArgs: class PreprocessCliArgs:
"""dataclass with arguments for preprocessing only""" """
dataclass representing arguments for preprocessing only
"""
debug: bool = field(default=False) debug: bool = field(default=False)
debug_text_only: bool = field(default=False) debug_text_only: bool = field(default=False)
@@ -34,7 +30,9 @@ class PreprocessCliArgs:
@dataclass @dataclass
class TrainerCliArgs: class TrainerCliArgs:
"""dataclass with various non-training arguments""" """
dataclass representing the various non-training arguments
"""
debug: bool = field(default=False) debug: bool = field(default=False)
debug_text_only: bool = field(default=False) debug_text_only: bool = field(default=False)
@@ -47,7 +45,9 @@ class TrainerCliArgs:
@dataclass @dataclass
class EvaluateCliArgs: class EvaluateCliArgs:
"""dataclass with various evaluation arguments""" """
dataclass representing the various evaluation arguments
"""
debug: bool = field(default=False) debug: bool = field(default=False)
debug_text_only: bool = field(default=False) debug_text_only: bool = field(default=False)
@@ -57,7 +57,7 @@ class EvaluateCliArgs:
def load_model_and_tokenizer( def load_model_and_tokenizer(
*, *,
cfg: DictDefault, cfg: DictDefault,
cli_args: Union[TrainerCliArgs, EvaluateCliArgs, "ConvertDiffTransformerCliArgs"], cli_args: TrainerCliArgs,
): ):
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)

View File

@@ -22,6 +22,7 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch import torch
import transformers import transformers
from datasets import Dataset from datasets import Dataset
from packaging import version
from peft.optimizers import create_loraplus_optimizer from peft.optimizers import create_loraplus_optimizer
from torch import nn from torch import nn
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
@@ -293,7 +294,7 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
""" """
Training arguments for Causal trainer Training arguments for Causal trainer
This code is duplicated due to HF TrainingArguments not setting output_dir with a default value This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
so it can't be used as a mixin. so it can't be used as a mixin.
""" """
@@ -607,14 +608,8 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
self.state.train_batch_size or self.args.per_device_train_batch_size self.state.train_batch_size or self.args.per_device_train_batch_size
) )
batch_max_len = train_batch_size * self.args.max_seq_length batch_max_len = train_batch_size * self.args.max_seq_length
if self.args.curriculum_sampling:
sampler = SequentialSampler(self.train_dataset)
else:
sampler = RandomSampler(self.train_dataset)
return MultipackBatchSampler( return MultipackBatchSampler(
sampler, RandomSampler(self.train_dataset),
lengths=get_dataset_lengths(self.train_dataset), lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency, packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len, batch_max_len=batch_max_len,
@@ -983,7 +978,12 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
logs[key] = torch.tensor(metrics).mean().item() logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
return super().log(logs, start_time) if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
try:
return super().log(logs, start_time)
except TypeError:
return super().log(logs) # transformers<=4.46
return super().log(logs) # transformers<=4.46
def store_metrics( def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train" self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
@@ -1167,6 +1167,22 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
torch.cuda.empty_cache() torch.cuda.empty_cache()
return loss return loss
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
""" """
@@ -1175,6 +1191,22 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
tag_names = ["axolotl", "orpo"] tag_names = ["axolotl", "orpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(ORPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
""" """
@@ -1183,6 +1215,49 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
tag_names = ["axolotl", "kto"] tag_names = ["axolotl", "kto"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# train metrics should have no prefix, eval should have 'eval_'
prefix = "eval_" if train_eval == "eval" else ""
# accumulate average metrics from sums and lengths
for split in ["chosen", "rejected"]:
if f"count/{split}" in self._stored_metrics[train_eval]:
count_sum = (
torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"])
.sum()
.item()
)
for metric in ["rewards", "logps", "logits"]:
logs[f"{prefix}{metric}/{split}"] = (
torch.Tensor(
self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
)
.sum()
.item()
/ count_sum
)
# delete obsolete metric
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
del self._stored_metrics[train_eval][f"count/{split}"]
# calculate reward margin
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
logs[f"{prefix}rewards/margins"] = (
logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
)
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(KTOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
""" """
@@ -1191,6 +1266,22 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
tag_names = ["axolotl", "cpo"] tag_names = ["axolotl", "cpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(CPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
""" """
@@ -1199,6 +1290,15 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
tag_names = ["axolotl", "reward"] tag_names = ["axolotl", "reward"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call
class TrainerBuilderBase(abc.ABC): class TrainerBuilderBase(abc.ABC):
""" """

View File

@@ -9,11 +9,12 @@ from typing import Dict, Optional
import torch import torch
from accelerate.logging import get_logger from accelerate.logging import get_logger
from axolotl.common.cli import EvaluateCliArgs, load_model_and_tokenizer from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import setup_trainer
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -61,9 +62,8 @@ def evaluate_dataset(
return metrics return metrics
# pylint: disable=duplicate-code
def evaluate( def evaluate(
*, cfg: DictDefault, cli_args: EvaluateCliArgs, dataset_meta: TrainDatasetMeta *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
) -> Dict[str, float]: ) -> Dict[str, float]:
""" """
Evaluate a model on training and validation datasets Evaluate a model on training and validation datasets
@@ -79,11 +79,16 @@ def evaluate(
- The tokenizer - The tokenizer
- Dictionary of evaluation metrics - Dictionary of evaluation metrics
""" """
# Load model # pylint: disable=duplicate-code
LOG.debug("loading model for evaluation...") # Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) # Load tokenizer
model = model.to(cfg.device, dtype=cfg.torch_dtype) LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)
# Load processor for multimodal models if needed # Load processor for multimodal models if needed
processor = None processor = None
@@ -95,6 +100,12 @@ def evaluate(
eval_dataset = dataset_meta.eval_dataset eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps total_num_steps = dataset_meta.total_num_steps
# Load model
LOG.debug("loading model for evaluation...")
model, _ = load_model(
cfg, tokenizer, processor=processor, inference=cli_args.inference
)
# Set up trainer # Set up trainer
trainer = setup_trainer( trainer = setup_trainer(
cfg, cfg,

View File

@@ -43,12 +43,10 @@ def merge_input_args():
input_args: List[str] = plugin_manager.get_input_args() input_args: List[str] = plugin_manager.get_input_args()
plugin_classes = [] plugin_classes = []
dynamic_input = "" dynamic_input = ""
for plugin_args in input_args: for plugin_args in input_args:
plugin_module, plugin_cls = plugin_args.rsplit(".", 1) plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
dynamic_input += f"from {plugin_module} import {plugin_cls}\n" dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
plugin_classes.append(plugin_cls) plugin_classes.append(plugin_cls)
if dynamic_input: if dynamic_input:
dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n" dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n"
dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n" dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
@@ -64,5 +62,4 @@ def merge_input_args():
"AxolotlConfigWCapabilities" "AxolotlConfigWCapabilities"
] ]
return AxolotlConfigWCapabilities, AxolotlInputConfig return AxolotlConfigWCapabilities, AxolotlInputConfig
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase

View File

@@ -22,6 +22,13 @@ import inspect
import logging import logging
import sys import sys
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from ...utils.distributed import zero_only from ...utils.distributed import zero_only
@@ -39,13 +46,6 @@ class LigerPlugin(BasePlugin):
return "axolotl.integrations.liger.LigerArgs" return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg): def pre_model_load(self, cfg):
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN: if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type] apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn) liger_fn_sig = inspect.signature(apply_liger_fn)

View File

@@ -6,7 +6,7 @@ import logging
from transformers import Trainer from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code from axolotl.monkeypatch.unsloth_ import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save") LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")

View File

@@ -8,7 +8,7 @@ import logging
from transformers import LlamaForCausalLM, Trainer from transformers import LlamaForCausalLM, Trainer
from transformers.modeling_flash_attention_utils import _flash_attention_forward from transformers.modeling_flash_attention_utils import _flash_attention_forward
from axolotl.monkeypatch.utils import detab_code from axolotl.monkeypatch.unsloth_ import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum") LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")

View File

@@ -1,7 +1,9 @@
"""module for patching with unsloth optimizations""" """module for patching with unsloth optimizations"""
import inspect import inspect
import re
import types import types
from typing import Tuple
import torch import torch
from accelerate.logging import get_logger from accelerate.logging import get_logger
@@ -9,8 +11,6 @@ from peft import PeftModelForCausalLM
from torch import nn from torch import nn
from transformers.models.llama.modeling_llama import LlamaFlashAttention2 from transformers.models.llama.modeling_llama import LlamaFlashAttention2
from axolotl.monkeypatch.utils import detab_code
LOG = get_logger("axolotl.monkeypatch.unsloth") LOG = get_logger("axolotl.monkeypatch.unsloth")
ORIGINAL_QKV_CODE = """ ORIGINAL_QKV_CODE = """
@@ -93,6 +93,15 @@ def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
raise ValueError("Unsupported model type") raise ValueError("Unsupported model type")
def detab_code(code: str) -> Tuple[str, str]:
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
return code, spaces
self_attn_lora_patched = False # pylint: disable=invalid-name self_attn_lora_patched = False # pylint: disable=invalid-name

View File

@@ -1,8 +1,7 @@
""" """
Shared utils for the monkeypatches Shared utils for the monkeypatches
""" """
import re from typing import Optional
from typing import Optional, Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -224,12 +223,3 @@ def patched_prepare_4d_causal_attention_mask_for_sdpa(
mask_2d_to_4d(attention_mask, dtype=dtype), mask_2d_to_4d(attention_mask, dtype=dtype),
*args, *args,
) )
def detab_code(code: str) -> Tuple[str, str]:
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
return code, spaces

View File

@@ -128,7 +128,6 @@ class PretrainingDataset(BaseModel):
text_column: Optional[str] = "text" text_column: Optional[str] = "text"
type: Optional[str] = "pretrain" type: Optional[str] = "pretrain"
trust_remote_code: Optional[bool] = False trust_remote_code: Optional[bool] = False
data_files: Optional[str] = None
class UserDefinedPrompterType(BaseModel): class UserDefinedPrompterType(BaseModel):

View File

@@ -88,7 +88,6 @@ def prepare_dataset(cfg, tokenizer, processor=None):
path = cfg.pretraining_dataset path = cfg.pretraining_dataset
split = "train" split = "train"
name = None name = None
data_files = None
if isinstance(cfg.pretraining_dataset, list) and isinstance( if isinstance(cfg.pretraining_dataset, list) and isinstance(
cfg.pretraining_dataset[0], dict cfg.pretraining_dataset[0], dict
): ):
@@ -97,8 +96,6 @@ def prepare_dataset(cfg, tokenizer, processor=None):
if "split" in cfg.pretraining_dataset[0]: if "split" in cfg.pretraining_dataset[0]:
split = cfg.pretraining_dataset[0]["split"] split = cfg.pretraining_dataset[0]["split"]
data_files = cfg.pretraining_dataset[0].get("data_files")
ds_wrapper_partial = functools.partial( ds_wrapper_partial = functools.partial(
get_dataset_wrapper, get_dataset_wrapper,
cfg.pretraining_dataset[0], cfg.pretraining_dataset[0],
@@ -108,9 +105,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
) )
train_dataset = wrap_pretraining_dataset( train_dataset = wrap_pretraining_dataset(
load_dataset( load_dataset(path, streaming=True, split=split, name=name),
path, streaming=True, split=split, name=name, data_files=data_files
),
tokenizer, tokenizer,
cfg, cfg,
ds_wrapper_partial, ds_wrapper_partial,

View File

@@ -713,45 +713,19 @@ class ModelLoader:
if self.cfg.flash_attention: if self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention: if not self.cfg.sample_packing and self.cfg.s2_attention:
pass pass
self.model_kwargs["attn_implementation"] = "flash_attention_2"
if self.cfg.diff_attention:
self.model_kwargs[
"attn_implementation"
] = "differential_flash_attention_2"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_flash_attention_2"
)
else:
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
elif self.cfg.sdp_attention:
if self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_sdpa"
)
else:
self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"sdpa"
)
elif self.cfg.eager_attention:
if self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_eager"
)
else:
self.model_kwargs["attn_implementation"] = "eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)
elif self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_eager" "flash_attention_2"
)
elif self.cfg.sdp_attention:
self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"sdpa"
)
elif self.cfg.eager_attention:
self.model_kwargs["attn_implementation"] = "eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
) )
if self.cfg.low_cpu_mem_usage: if self.cfg.low_cpu_mem_usage:
@@ -842,7 +816,6 @@ class ModelLoader:
if self.cfg.is_multimodal: if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config self.model_config.text_config = self.text_model_config
self.model = self.AutoModelLoader.from_pretrained( self.model = self.AutoModelLoader.from_pretrained(
self.base_model, self.base_model,
config=self.model_config, config=self.model_config,

View File

@@ -196,7 +196,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset: if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask") eval_dataset = eval_dataset.remove_columns("attention_mask")
if cfg.model_config_type in ["falcon", "mistral"]: if cfg.model_config_type == "falcon":
LOG.info("dropping token_type_ids column if it exists") LOG.info("dropping token_type_ids column if it exists")
if "token_type_ids" in train_dataset.column_names: if "token_type_ids" in train_dataset.column_names:
train_dataset = train_dataset.remove_columns("token_type_ids") train_dataset = train_dataset.remove_columns("token_type_ids")

View File

@@ -1,157 +0,0 @@
"""Utilities for YAML files."""
from collections import OrderedDict
from typing import Any, Dict, List, Set, Tuple, Union
import yaml
class YAMLOrderTracker:
"""Tracks the order of keys and section breaks in YAML files."""
def __init__(self, yaml_path: str):
self.yaml_path = yaml_path
self.structure, self.needs_break = self._parse_yaml_structure()
def _get_indentation_level(self, line: str) -> int:
"""Get the indentation level of a line."""
return len(line) - len(line.lstrip())
def _parse_yaml_structure(
self,
) -> Tuple[Dict[str, Union[List[str], Dict]], Set[str]]:
"""Parse the YAML file to extract structure and identify section breaks."""
with open(self.yaml_path, "r", encoding="utf-8") as file:
contents = file.readlines()
structure: OrderedDict = OrderedDict()
needs_break = set() # Track which keys should have a break before them
current_path = []
last_indentation = -1
had_empty_line = False
for line in contents:
# Track empty lines and comments
if not line.strip() or line.strip().startswith("#"):
had_empty_line = True
continue
# Get indentation level and content
indentation = self._get_indentation_level(line)
content = line.strip()
# Skip lines that don't define keys
if ":" not in content:
continue
# Extract key
key = content.split(":")[0].strip()
# If this is a top-level key and we had an empty line, mark it
if indentation == 0:
if had_empty_line:
needs_break.add(key)
had_empty_line = False
# Handle indentation changes
if indentation > last_indentation:
current_path.append(key)
elif indentation < last_indentation:
levels_up = (last_indentation - indentation) // 2
current_path = current_path[:-levels_up]
current_path[-1] = key
else:
if current_path:
current_path[-1] = key
# Update structure
current_dict = structure
for path_key in current_path[:-1]:
if path_key not in current_dict:
current_dict[path_key] = OrderedDict()
current_dict = current_dict[path_key]
if current_path:
if current_path[-1] not in current_dict:
current_dict[current_path[-1]] = OrderedDict()
last_indentation = indentation
return structure, needs_break
class OrderedDumper(yaml.SafeDumper):
"""Custom YAML dumper that maintains dictionary order."""
def represent_none(self, _):
"""Represent None values as empty fields."""
return self.represent_scalar("tag:yaml.org,2002:null", "")
def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any:
"""Custom representer for dictionaries that maintains order."""
return dumper.represent_mapping("tag:yaml.org,2002:map", data.items())
def reorder_dict(data: Dict, reference_structure: Dict) -> OrderedDict:
"""Reorder a dictionary based on a reference structure."""
ordered = OrderedDict()
# First add keys that are in the reference order
for key in reference_structure:
if key in data:
if isinstance(reference_structure[key], dict) and isinstance(
data[key], dict
):
ordered[key] = reorder_dict(data[key], reference_structure[key])
else:
ordered[key] = data[key]
# Then add any remaining keys that weren't in the reference
for key in data:
if key not in ordered:
ordered[key] = data[key]
return ordered
def dump_yaml_preserved_order(
data: Dict, reference_yaml_path: str, output_path: str
) -> None:
"""Dump YAML file while preserving nested order and normalized spacing."""
# Get reference structure and spacing
tracker = YAMLOrderTracker(reference_yaml_path)
# Reorder the data
ordered_data = reorder_dict(data, tracker.structure)
# Register the custom representers
OrderedDumper.add_representer(type(None), represent_none)
OrderedDumper.add_representer(dict, ordered_dict_representer)
OrderedDumper.add_representer(OrderedDict, ordered_dict_representer)
# First dump to string
yaml_str = yaml.dump(
ordered_data, Dumper=OrderedDumper, sort_keys=False, default_flow_style=False
)
# Add spacing according to reference
lines = yaml_str.split("\n")
result_lines: List[str] = []
current_line = 0
while current_line < len(lines):
line = lines[current_line]
if line.strip() and ":" in line and not line.startswith(" "): # Top-level key
key = line.split(":")[0].strip()
if key in tracker.needs_break:
# Add single empty line before this key
if result_lines and result_lines[-1] != "":
result_lines.append("")
result_lines.append(line)
current_line += 1
# Write the final result
with open(output_path, "w", encoding="utf-8") as file:
file.write("\n".join(result_lines))

View File

@@ -1,5 +1,4 @@
"""Shared pytest fixtures for cli module.""" """Shared pytest fixtures for cli module."""
import pytest import pytest
from click.testing import CliRunner from click.testing import CliRunner

View File

@@ -43,12 +43,14 @@ class BaseCliTest:
result = cli_runner.invoke(cli, [command, str(config_path)]) result = cli_runner.invoke(cli, [command, str(config_path)])
assert mock.called assert mock.called
assert mock.call_args.args[0][:5] == [ assert mock.call_args.args[0] == [
"accelerate", "accelerate",
"launch", "launch",
"-m", "-m",
f"axolotl.cli.{command}", f"axolotl.cli.{command}",
str(config_path), str(config_path),
"--debug-num-examples",
"0",
] ]
assert mock.call_args.kwargs == {"check": True} assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0 assert result.exit_code == 0

View File

@@ -1,5 +1,4 @@
"""pytest tests for axolotl CLI fetch command.""" """pytest tests for axolotl CLI fetch command."""
from unittest.mock import patch from unittest.mock import patch
from axolotl.cli.main import fetch from axolotl.cli.main import fetch

View File

@@ -1,5 +1,4 @@
"""pytest tests for axolotl CLI inference command.""" """pytest tests for axolotl CLI inference command."""
from unittest.mock import patch from unittest.mock import patch
from axolotl.cli.main import cli from axolotl.cli.main import cli

View File

@@ -1,5 +1,4 @@
"""General pytest tests for axolotl.cli.main interface.""" """General pytest tests for axolotl.cli.main interface."""
from axolotl.cli.main import build_command, cli from axolotl.cli.main import build_command, cli
@@ -23,7 +22,6 @@ def test_build_command():
"--batch-size", "--batch-size",
"8", "8",
"--debug", "--debug",
"--nouse-fp16",
] ]

View File

@@ -1,5 +1,4 @@
"""pytest tests for axolotl CLI merge_lora command.""" """pytest tests for axolotl CLI merge_lora command."""
from unittest.mock import patch from unittest.mock import patch
from axolotl.cli.main import cli from axolotl.cli.main import cli

View File

@@ -1,6 +1,5 @@
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command.""" """pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
from unittest.mock import patch from unittest.mock import patch
from axolotl.cli.main import cli from axolotl.cli.main import cli

View File

@@ -1,5 +1,4 @@
"""pytest tests for axolotl CLI preprocess command.""" """pytest tests for axolotl CLI preprocess command."""
import shutil import shutil
from pathlib import Path from pathlib import Path
from unittest.mock import patch from unittest.mock import patch

View File

@@ -1,6 +1,5 @@
"""pytest tests for axolotl CLI shard command.""" """pytest tests for axolotl CLI shard command."""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
from unittest.mock import patch from unittest.mock import patch
from axolotl.cli.main import cli from axolotl.cli.main import cli
@@ -12,12 +11,14 @@ def test_shard_with_accelerate(cli_runner, config_path):
result = cli_runner.invoke(cli, ["shard", str(config_path), "--accelerate"]) result = cli_runner.invoke(cli, ["shard", str(config_path), "--accelerate"])
assert mock.called assert mock.called
assert mock.call_args.args[0][:5] == [ assert mock.call_args.args[0] == [
"accelerate", "accelerate",
"launch", "launch",
"-m", "-m",
"axolotl.cli.shard", "axolotl.cli.shard",
str(config_path), str(config_path),
"--debug-num-examples",
"0",
] ]
assert mock.call_args.kwargs == {"check": True} assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0 assert result.exit_code == 0

View File

@@ -1,5 +1,4 @@
"""pytest tests for axolotl CLI --version""" """pytest tests for axolotl CLI --version"""
from axolotl.cli.main import cli from axolotl.cli.main import cli

View File

@@ -1,6 +1,5 @@
"""pytest tests for axolotl CLI utils.""" """pytest tests for axolotl CLI utils."""
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
import json import json
from unittest.mock import Mock, patch from unittest.mock import Mock, patch

View File

@@ -37,7 +37,8 @@ def retry_on_request_exceptions(max_retries=3, delay=1):
@retry_on_request_exceptions(max_retries=3, delay=5) @retry_on_request_exceptions(max_retries=3, delay=5)
def snapshot_download_w_retry(*args, **kwargs): def snapshot_download_w_retry(*args, **kwargs):
return snapshot_download(*args, **kwargs) url = snapshot_download(*args, **kwargs)
raise f"{args[0]}: {url}"
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
@@ -120,12 +121,13 @@ def temp_dir():
@pytest.fixture(scope="function", autouse=True) @pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches(): def cleanup_monkeypatches():
from transformers import Trainer from transformers import Trainer
from transformers.models.llama.modeling_llama import ( # LlamaFlashAttention2, from transformers.models.llama.modeling_llama import (
LlamaAttention, LlamaAttention,
LlamaFlashAttention2,
LlamaForCausalLM, LlamaForCausalLM,
) )
# original_fa2_forward = LlamaFlashAttention2.forward original_fa2_forward = LlamaFlashAttention2.forward
original_llama_attn_forward = LlamaAttention.forward original_llama_attn_forward = LlamaAttention.forward
original_llama_forward = LlamaForCausalLM.forward original_llama_forward = LlamaForCausalLM.forward
original_trainer_inner_training_loop = ( original_trainer_inner_training_loop = (
@@ -135,7 +137,7 @@ def cleanup_monkeypatches():
# monkey patches can happen inside the tests # monkey patches can happen inside the tests
yield yield
# Reset LlamaFlashAttention2 forward # Reset LlamaFlashAttention2 forward
# LlamaFlashAttention2.forward = original_fa2_forward LlamaFlashAttention2.forward = original_fa2_forward
LlamaAttention.forward = original_llama_attn_forward LlamaAttention.forward = original_llama_attn_forward
LlamaForCausalLM.forward = original_llama_forward LlamaForCausalLM.forward = original_llama_forward
Trainer._inner_training_loop = ( # pylint: disable=protected-access Trainer._inner_training_loop = ( # pylint: disable=protected-access
@@ -148,10 +150,7 @@ def cleanup_monkeypatches():
("transformers.models.llama",), ("transformers.models.llama",),
( (
"transformers.models.llama.modeling_llama", "transformers.models.llama.modeling_llama",
[ ["LlamaFlashAttention2", "LlamaAttention"],
# "LlamaFlashAttention2",
"LlamaAttention",
],
), ),
("transformers.trainer",), ("transformers.trainer",),
("transformers", ["Trainer"]), ("transformers", ["Trainer"]),

View File

@@ -1,40 +1,43 @@
""" """
Simple end-to-end test for Liger integration Simple end-to-end test for Liger integration
""" """
import unittest
from pathlib import Path from pathlib import Path
from e2e.utils import require_torch_2_4_1
from axolotl.cli import load_datasets from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
class LigerIntegrationTestCase:
class LigerIntegrationTestCase(unittest.TestCase):
""" """
e2e tests for liger integration with Axolotl e2e tests for liger integration with Axolotl
""" """
@require_torch_2_4_1 @with_temp_dir
def test_llama_wo_flce(self, temp_dir): def test_llama_wo_flce(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"plugins": [ "plugins": [
"axolotl.integrations.liger.LigerPlugin", "axolotl.integrations.liger.LigerPlugin",
], ],
"liger_rope": True, "liger_rope": True,
"liger_rms_norm": True, "liger_rms_norm": True,
"liger_glu_activation": True, "liger_swiglu": True,
"liger_cross_entropy": True, "liger_cross_entropy": True,
"liger_fused_linear_cross_entropy": False, "liger_fused_linear_cross_entropy": False,
"sequence_len": 1024, "sequence_len": 1024,
"val_set_size": 0.05, "val_set_size": 0.1,
"special_tokens": { "special_tokens": {
"pad_token": "<|endoftext|>", "unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
}, },
"datasets": [ "datasets": [
{ {
@@ -43,15 +46,15 @@ class LigerIntegrationTestCase:
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
"micro_batch_size": 2, "micro_batch_size": 8,
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 1,
"output_dir": temp_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"save_safetensors": True, "save_safetensors": True,
"bf16": "auto", "bf16": "auto",
"max_steps": 5, "max_steps": 10,
} }
) )
prepare_plugins(cfg) prepare_plugins(cfg)
@@ -62,24 +65,26 @@ class LigerIntegrationTestCase:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists() assert (Path(temp_dir) / "model.safetensors").exists()
@require_torch_2_4_1 @with_temp_dir
def test_llama_w_flce(self, temp_dir): def test_llama_w_flce(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"plugins": [ "plugins": [
"axolotl.integrations.liger.LigerPlugin", "axolotl.integrations.liger.LigerPlugin",
], ],
"liger_rope": True, "liger_rope": True,
"liger_rms_norm": True, "liger_rms_norm": True,
"liger_glu_activation": True, "liger_swiglu": True,
"liger_cross_entropy": False, "liger_cross_entropy": False,
"liger_fused_linear_cross_entropy": True, "liger_fused_linear_cross_entropy": True,
"sequence_len": 1024, "sequence_len": 1024,
"val_set_size": 0.05, "val_set_size": 0.1,
"special_tokens": { "special_tokens": {
"pad_token": "<|endoftext|>", "unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
}, },
"datasets": [ "datasets": [
{ {
@@ -88,15 +93,15 @@ class LigerIntegrationTestCase:
}, },
], ],
"num_epochs": 1, "num_epochs": 1,
"micro_batch_size": 2, "micro_batch_size": 8,
"gradient_accumulation_steps": 2, "gradient_accumulation_steps": 1,
"output_dir": temp_dir, "output_dir": temp_dir,
"learning_rate": 0.00001, "learning_rate": 0.00001,
"optimizer": "adamw_torch", "optimizer": "adamw_torch",
"lr_scheduler": "cosine", "lr_scheduler": "cosine",
"save_safetensors": True, "save_safetensors": True,
"bf16": "auto", "bf16": "auto",
"max_steps": 5, "max_steps": 10,
} }
) )
prepare_plugins(cfg) prepare_plugins(cfg)

View File

@@ -1,14 +1,9 @@
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected.""" """Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
import unittest import unittest
import pytest
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
@pytest.mark.skip(
reason="Unsloth integration will be broken going into latest transformers"
)
class TestUnslothIntegration(unittest.TestCase): class TestUnslothIntegration(unittest.TestCase):
"""Unsloth monkeypatch integration tests.""" """Unsloth monkeypatch integration tests."""

View File

@@ -20,9 +20,6 @@ os.environ["WANDB_DISABLED"] = "true"
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
@pytest.mark.skip(
reason="Unsloth integration will be broken going into latest transformers"
)
class TestUnslothQLoRA: class TestUnslothQLoRA:
""" """
Test class for Unsloth QLoRA Llama models Test class for Unsloth QLoRA Llama models

View File

@@ -113,7 +113,6 @@ class TestCustomOptimizers(unittest.TestCase):
@with_temp_dir @with_temp_dir
def test_fft_schedule_free_adamw(self, temp_dir): def test_fft_schedule_free_adamw(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -49,19 +49,7 @@ def require_torch_2_3_1(test_case):
torch_version = version.parse(torch.__version__) torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.3.1") return torch_version >= version.parse("2.3.1")
return unittest.skipUnless(is_min_2_3_1(), "test requires torch>=2.3.1")(test_case) return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)
def require_torch_2_4_1(test_case):
"""
Decorator marking a test that requires torch >= 2.5.1
"""
def is_min_2_4_1():
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.4.1")
return unittest.skipUnless(is_min_2_4_1(), "test requires torch>=2.4.1")(test_case)
def require_torch_2_5_1(test_case): def require_torch_2_5_1(test_case):
@@ -73,7 +61,7 @@ def require_torch_2_5_1(test_case):
torch_version = version.parse(torch.__version__) torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.5.1") return torch_version >= version.parse("2.5.1")
return unittest.skipUnless(is_min_2_5_1(), "test requires torch>=2.5.1")(test_case) return unittest.skipUnless(is_min_2_5_1(), "test torch 2.5.1")(test_case)
def is_hopper(): def is_hopper():

View File

@@ -7,11 +7,11 @@ from typing import Optional
import pytest import pytest
from axolotl.utils.config import prepare_plugins, validate_config from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
@pytest.fixture(name="minimal_liger_cfg") @pytest.fixture(name="minimal_base_cfg")
def fixture_cfg(): def fixture_cfg():
return DictDefault( return DictDefault(
{ {
@@ -25,57 +25,56 @@ def fixture_cfg():
], ],
"micro_batch_size": 1, "micro_batch_size": 1,
"gradient_accumulation_steps": 1, "gradient_accumulation_steps": 1,
"plugins": ["axolotl.integrations.liger.LigerPlugin"],
} }
) )
# pylint: disable=too-many-public-methods class BaseValidation:
class TestValidation:
""" """
Test the validation module for liger Base validation module to setup the log capture
""" """
_caplog: Optional[pytest.LogCaptureFixture] = None _caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def inject_fixtures(self, caplog): def inject_fixtures(self, caplog):
caplog.set_level(logging.WARNING)
self._caplog = caplog self._caplog = caplog
def test_deprecated_swiglu(self, minimal_liger_cfg):
# pylint: disable=too-many-public-methods
class TestValidation(BaseValidation):
"""
Test the validation module for liger
"""
def test_deprecated_swiglu(self, minimal_cfg):
test_cfg = DictDefault( test_cfg = DictDefault(
{ {
"liger_swiglu": False, "liger_swiglu": False,
} }
| minimal_liger_cfg | minimal_cfg
) )
with self._caplog.at_level( with self._caplog.at_level(logging.WARNING):
logging.WARNING, logger="axolotl.integrations.liger.args"
):
prepare_plugins(test_cfg)
updated_cfg = validate_config(test_cfg) updated_cfg = validate_config(test_cfg)
# TODO this test is brittle in CI assert (
# assert ( "The 'liger_swiglu' argument is deprecated"
# "The 'liger_swiglu' argument is deprecated" in self._caplog.records[0].message
# in self._caplog.records[0].message )
# )
assert updated_cfg.liger_swiglu is None assert updated_cfg.liger_swiglu is None
assert updated_cfg.liger_glu_activation is False assert updated_cfg.liger_glu_activations is False
def test_conflict_swiglu_ligergluactivation(self, minimal_liger_cfg): def test_conflict_swiglu_ligergluactivation(self, minimal_cfg):
test_cfg = DictDefault( test_cfg = DictDefault(
{ {
"liger_swiglu": False, "liger_swiglu": False,
"liger_glu_activation": True, "liger_glu_activations": True,
} }
| minimal_liger_cfg | minimal_cfg
) )
with pytest.raises( with pytest.raises(
ValueError, ValueError,
match=r".*You cannot have both `liger_swiglu` and `liger_glu_activation` set.*", match=r".*You cannot have both `liger_swiglu` and `liger_glu_activation` set.*",
): ):
prepare_plugins(test_cfg)
validate_config(test_cfg) validate_config(test_cfg)

View File

@@ -4,7 +4,9 @@ import json
import logging import logging
import unittest import unittest
from pathlib import Path from pathlib import Path
from typing import Optional
import pytest
from datasets import load_dataset from datasets import load_dataset
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
@@ -63,6 +65,12 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
Test class for prompt tokenization strategies. Test class for prompt tokenization strategies.
""" """
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
def setUp(self) -> None: def setUp(self) -> None:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")