Compare commits

..

64 Commits

Author SHA1 Message Date
Wing Lian
8c4f89745a fix softmax class check 2025-01-15 23:23:13 -05:00
Wing Lian
36b71f34d7 register rala 2025-01-15 23:21:22 -05:00
Wing Lian
d28fee7609 use autoconfig w rala 2025-01-15 23:14:47 -05:00
Wing Lian
c196776996 option to not concatenate during pretraining 2025-01-15 22:45:02 -05:00
Wing Lian
79ae776102 fixup logging layer 2025-01-15 21:36:14 -05:00
Wing Lian
145664d82c more fixups 2025-01-15 21:27:12 -05:00
Dan Saunders
28694219a5 inline comment change 2025-01-14 16:59:43 +00:00
Dan Saunders
fd8ad6fcbf fixing negative component mixing 2025-01-13 19:21:55 +00:00
Dan Saunders
661d71a14b adding diff attn negative component warmup (in progress) 2025-01-10 21:57:31 +00:00
Dan Saunders
6dd47edcb8 fire CLI fixes 2025-01-10 18:24:16 +00:00
Dan Saunders
7aca08ff60 adding guard statements 2025-01-10 16:39:21 +00:00
Dan Saunders
4f804f6d88 adding diff attn callback, adding documentation 2025-01-10 16:28:51 +00:00
Dan Saunders
443327c585 CLI build_command bugfix 2025-01-10 16:28:51 +00:00
Dan Saunders
70c4e6fbe6 updates and cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
2a7f139ad2 pre-commit fix 2025-01-10 16:28:51 +00:00
Dan Saunders
332ce0ae85 fixes and cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
e5fa842ff8 update 2025-01-10 16:28:51 +00:00
Dan Saunders
78e0ec0aa5 changes 2025-01-10 16:28:51 +00:00
Dan Saunders
3bc568eb27 adding registration function 2025-01-10 16:28:51 +00:00
Dan Saunders
eb6611d55f progress on modeling code 2025-01-10 16:28:51 +00:00
Dan Saunders
4ff3328e66 updated custom modeling code 2025-01-10 16:28:51 +00:00
Dan Saunders
a3fd5074a9 fix duplicate-code warnings 2025-01-10 16:28:51 +00:00
Dan Saunders
5b90da0be3 added modeling code; cleanup + refactor 2025-01-10 16:28:51 +00:00
Dan Saunders
fcbfa86373 refactor and fixing test isolation issues 2025-01-10 16:28:51 +00:00
Dan Saunders
0d56582090 adding yaml dumper preserving input config format 2025-01-10 16:28:51 +00:00
Dan Saunders
390cb5742e removing extra pytest xdist args 2025-01-10 16:28:51 +00:00
Dan Saunders
1d935f65c3 moving tests around for flash_attn install 2025-01-10 16:28:51 +00:00
Dan Saunders
66176b3e07 adding split_heads argument for retaining original (Q, K) dimensionanlity 2025-01-10 16:28:51 +00:00
Dan Saunders
505321ac95 isolating problematic test 2025-01-10 16:28:51 +00:00
Dan Saunders
0b382c88da fixes post-rebase 2025-01-10 16:28:51 +00:00
Dan Saunders
ea07a7086e plugin implementation 2025-01-10 16:28:51 +00:00
Dan Saunders
d22e1136bc convert-differential-transformer test coverage 2025-01-10 16:28:51 +00:00
Dan Saunders
63b8e42c6b duplicate code ignore 2025-01-10 16:28:51 +00:00
Dan Saunders
bda1eed59e differential flash attention 2; cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
41ebd93158 moving monkeypatch 2025-01-10 16:28:51 +00:00
Dan Saunders
4c050ce807 pre-commit fix 2025-01-10 16:28:51 +00:00
Dan Saunders
6665acf63d fix model save / load logic 2025-01-10 16:28:51 +00:00
Dan Saunders
2f9fa4c465 various improvemnents 2025-01-10 16:28:51 +00:00
Dan Saunders
849bc94112 various improvemnents 2025-01-10 16:28:51 +00:00
Dan Saunders
e484ec778d training fixes, patching, minor cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
df1504ae14 adding CLI command for convert-diff-transformer 2025-01-10 16:28:51 +00:00
Dan Saunders
7be0d7496c Adding script for doing conversion; fixes and updates 2025-01-10 16:28:51 +00:00
Dan Saunders
13cdffa91f initial diff attn layer / model conversion implementation (support for llama arch) 2025-01-10 16:28:51 +00:00
Dan Saunders
7a4b296f60 Basic evaluate CLI command / codepath (#2188)
* basic evaluate CLI command / codepath

* tests for evaluate CLI command

* fixes and cleanup

* review comments; slightly DRYing up things

---------

Co-authored-by: Dan Saunders <danjsaund@gmail.com>
2025-01-10 16:28:51 +00:00
Wing Lian
d8b4027200 use 2.5.1 docker images as latest tag as it seems stable (#2198) 2025-01-10 08:35:25 -05:00
Wing Lian
fb3352e21c rename liger test so it properly runs in ci (#2246) 2025-01-09 17:31:43 -05:00
NanoCode012
ed77e7001e feat: add support for data_files in pretraining (#2238) 2025-01-09 21:04:13 +00:00
Wing Lian
7669a03fb4 update upstream HF deps (#2239)
* bump axolotl contribs for upstream main conflicts:

* bump datasets, tokenizer, trl

* remove log workarounds in trl

* bump lm-eval

* remove unsloth_ import from critical path

* remove llama fa2 from conftest

* unsloth breaks with latest upstream
2025-01-09 21:01:59 +00:00
Vincenzo di Cicco
6553683170 Use SequentialSampler if curriculum_sampling is enabled with sample_packing (#2235) 2025-01-09 21:01:22 +00:00
Wing Lian
5e0124e2ab update modal version for ci (#2242) 2025-01-09 21:01:02 +00:00
NanoCode012
2e8d7c1adb fix: mistral nemo does not recognize token_type_ids in forward (#2233) 2025-01-09 21:00:36 +00:00
Wing Lian
3c1921e400 add hf cache caching for GHA (#2247)
* add hf cache caching for GHA

* use modal volume to cache hf data

* make sure to update the cache as we add new fixtures in conftest
2025-01-09 20:59:54 +00:00
Wing Lian
7faf2b6e8e Merge group queue (#2248)
* add support for merge groups

* also lint merge groups
2025-01-09 15:49:00 -05:00
salman
c1b920f291 Fixing OSX installation (#2231)
* bumping version, removing non-osx compatible deps

* updating pylintrc

* fixing linters

* reverting changes
2025-01-07 13:42:01 +00:00
Wing Lian
3915abee4c make sure padding is labeled as -100 for pretraining (#2227) 2024-12-31 15:22:18 -05:00
NJordan72
7a38dbe674 fix: allow trainer builder to use custom jinja chat template (#2219)
* fix: allow trainer builder to use custom jinja chat template

* chore: use get_chat_template_from_config

Co-authored-by: Chirag Jain <jain.chirag925@gmail.com>

* fix: swap imports

---------

Co-authored-by: Chirag Jain <jain.chirag925@gmail.com>
2024-12-24 16:18:50 -05:00
Wing Lian
e0a2eb2ebd fix untrained tokens if specified explicitly from a list (#2210) 2024-12-23 09:08:28 -05:00
Wing Lian
d852d7af7a inference - don't default w accelerate, fix base model (#2216) [skip ci] 2024-12-23 07:48:41 -05:00
Wing Lian
3742deb1de add deepspeed example with torch compile enabled (#2212) [skip ci] 2024-12-22 12:11:39 -05:00
Wing Lian
2312caaa98 GC every n steps (#2209) 2024-12-21 17:38:33 -05:00
Wing Lian
307cf7c685 move the dataset loading from remote/disk to a shared function so we can re-use for RL (#2204) 2024-12-20 21:43:52 -05:00
Dan Saunders
70541145f1 adding test_datasets compat with pretraining_dataset (streaming) (#2206) [skip ci] 2024-12-20 21:43:33 -05:00
Wing Lian
42bd32a233 add outputs (symlink) to gitignore [skip ci] (#2205) 2024-12-19 20:14:43 -05:00
Dan Saunders
5b8fb5e939 remove cicd pytest xdist args (#2201)
* remove cicd pytest xdist args

* Delete outputs
2024-12-19 11:44:53 -05:00
79 changed files with 3753 additions and 497 deletions

View File

@@ -1,6 +1,7 @@
name: lint
on:
# check on PRs, and manual triggers
merge_group:
pull_request:
paths:
- '**.py'

View File

@@ -25,7 +25,6 @@ jobs:
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras: mamba-ssm
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -36,6 +35,7 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -92,7 +92,6 @@ jobs:
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -103,6 +102,7 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -52,7 +52,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -129,7 +129,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -1,6 +1,7 @@
name: Tests
on:
# check on push/merge to main, PRs, and manual triggers
merge_group:
push:
branches:
- "main"
@@ -60,6 +61,15 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -100,6 +110,15 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
@@ -115,6 +134,15 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -156,6 +184,15 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
docker-e2e-tests-1st:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
@@ -183,7 +220,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -229,7 +266,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

.gitignore (4 changes)
View File

@@ -1,6 +1,7 @@
**/axolotl.egg-info
configs
last_run_prepared/
outputs
.vscode
_site/
@@ -185,3 +186,6 @@ out/
# vim
*.swp
# symlinked to axolotl-artifacts in docker containers
outputs

View File

@@ -23,7 +23,7 @@ repos:
hooks:
- id: flake8
- repo: https://github.com/PyCQA/pylint
rev: v2.17.4
rev: v3.3.0
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy

View File

@@ -1,5 +1,5 @@
[MASTER]
init-hook="from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))"
init-hook="from pylint.config import find_default_config_files; import sys; sys.path.append(next(find_default_config_files()).parent.as_posix())"
[TYPECHECK]
@@ -12,3 +12,4 @@ generated-members=numpy.*, torch.*
disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
too-many-positional-arguments, possibly-used-before-assignment

View File

@@ -8,6 +8,7 @@ ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev

View File

@@ -4,7 +4,6 @@ set -e
python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
# pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -1,6 +1,6 @@
"""
modal application to run axolotl gpu tests in Modal
"""
modal application to run axolotl gpu tests in Modal
"""
# pylint: disable=duplicate-code
import os
@@ -28,6 +28,7 @@ df_args = {
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}
dockerfile_contents = df_template.render(**df_args)
@@ -48,6 +49,12 @@ cicd_image = (
app = App("Axolotl CI/CD", secrets=[])
hf_cache_volume = modal.Volume.from_name(
"axolotl-ci-hf-hub-cache", create_if_missing=True
)
VOLUME_CONFIG = {
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
}
N_GPUS = int(os.environ.get("N_GPUS", 2))
GPU_CONFIG = modal.gpu.H100(count=N_GPUS)
@@ -67,6 +74,7 @@ def run_cmd(cmd: str, run_folder: str):
timeout=60 * 60,
cpu=8.0,
memory=131072 * N_GPUS,
volumes=VOLUME_CONFIG,
)
def cicd_pytest():
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")

View File

@@ -29,6 +29,7 @@ df_args = {
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}
dockerfile_contents = df_template.render(**df_args)
@@ -50,6 +51,12 @@ cicd_image = (
app = App("Axolotl CI/CD", secrets=[])
hf_cache_volume = modal.Volume.from_name(
"axolotl-ci-hf-hub-cache", create_if_missing=True
)
VOLUME_CONFIG = {
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.A10G(count=N_GPUS)
@@ -69,6 +76,7 @@ def run_cmd(cmd: str, run_folder: str):
timeout=60 * 60,
cpu=8.0,
memory=131072,
volumes=VOLUME_CONFIG,
)
def cicd_pytest():
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")

View File

@@ -0,0 +1,27 @@
{
"zero_optimization": {
"stage": 1,
"overlap_comm": true
},
"bf16": {
"enabled": "auto"
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"compile": {
"disable": false,
"backend": "inductor"
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

View File

@@ -2,7 +2,7 @@
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.0
triton>=2.3.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
flash-attn==2.7.0.post2
xformers>=0.0.23.post1
@@ -14,11 +14,11 @@ packaging==23.2
peft==0.14.0
transformers==4.47.1
tokenizers>=0.20.1
tokenizers>=0.21.0
accelerate==1.2.1
datasets==3.1.0
datasets==3.2.0
deepspeed==0.16.1
trl==0.12.1
trl==0.13.0
optimum==1.16.2
hf_transfer
@@ -53,7 +53,7 @@ zstandard==0.22.0
fastcore
# lm eval harness
lm_eval==0.4.4
lm_eval==0.4.7
langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
@@ -61,4 +61,4 @@ antlr4-python3-runtime==4.13.2
torchao==0.7.0
schedulefree==1.3.0
axolotl-contribs-lgpl==0.0.1b2
axolotl-contribs-lgpl==0.0.3

View File

@@ -1,4 +1,5 @@
"""setup.py for axolotl"""
import ast
import os
import platform
@@ -29,15 +30,30 @@ def parse_requirements():
elif not is_extras and line and line[0] != "#":
# Handle standard packages
_install_requires.append(line)
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
triton_version = [req for req in _install_requires if "triton" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system():
# don't install xformers on MacOS
_install_requires.pop(_install_requires.index(xformers_version))
# skip packages not compatible with OSX
skip_packages = [
"bitsandbytes",
"triton",
"mamba-ssm",
"flash-attn",
"xformers",
"autoawq",
"liger-kernel",
]
_install_requires = [
req
for req in _install_requires
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
]
print(
_install_requires, [req in skip_packages for req in _install_requires]
)
else:
# detect the version of torch already installed
# and set it so dependencies don't clobber the torch version
@@ -73,6 +89,8 @@ def parse_requirements():
_install_requires.append("xformers==0.0.28.post1")
elif (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(triton_version))
_install_requires.append("triton>=2.3.1")
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1")

View File

@@ -202,7 +202,7 @@ def do_inference(
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
elif cfg.datasets[0].type == "chat_template":
elif cfg.datasets and cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)

View File

@@ -3,7 +3,7 @@ CLI to run training on a model
"""
import logging
from pathlib import Path
from typing import Union
from typing import Dict, Union
import fire
from dotenv import load_dotenv
@@ -23,7 +23,7 @@ from axolotl.evaluate import evaluate
LOG = logging.getLogger("axolotl.cli.evaluate")
def do_evaluate(cfg, cli_args) -> None:
def do_evaluate(cfg, cli_args) -> Dict[str, float]:
# pylint: disable=duplicate-code
print_axolotl_text_art()
check_accelerate_default_config()
@@ -34,7 +34,7 @@ def do_evaluate(cfg, cli_args) -> None:
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
return evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:

View File

View File

@@ -0,0 +1,208 @@
"""CLI to convert a transformers model's attention layers to differential attention layers."""
import logging
import warnings
from pathlib import Path
from time import time
from typing import Union
import fire
import torch
import yaml
from colorama import Fore
from dotenv import load_dotenv
from transformers import HfArgumentParser
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
from axolotl.integrations.diff_transformer.modeling_diff_attn import (
LlamaDifferentialConfig,
LlamaDifferentialForCausalLM,
)
from axolotl.utils.yaml import dump_yaml_preserved_order
LOG = logging.getLogger(__name__)
def test_inference(model, tokenizer, prompt="The quick brown fox"):
"""Run test inference and return generation time"""
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()}
start = time()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=20,
num_beams=1,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
use_cache=False,
)
elapsed = time() - start
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
LOG.info("Prompt: %s", prompt)
LOG.info("Generated: %s", generated_text)
LOG.info("Generation time: %.2fs", elapsed)
return elapsed, generated_text
def convert_diff_transformer(cfg, cli_args, config_path):
assert not (
cli_args.split_heads and cli_args.zero_init
), "Both `split_heads` and `zero_init` cannot be `True`"
assert not (
cli_args.zero_init and cli_args.mirror_weights
), "Both `zero_init` and `mirror_weights` cannot be `True`"
debug_info = {}
# Load model and tokenizer
with warnings.catch_warnings():
warnings.simplefilter("ignore")
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
model.to(cfg.device, dtype=cfg.torch_dtype)
# Log original model info
LOG.info(
"Original model config:\n\t- Hidden size: %d\n\t- Num attention heads: %d",
model.config.hidden_size,
model.config.num_attention_heads,
)
# Test original model
if cli_args.debug:
LOG.info("Testing original model...")
debug_info["orig_time"], debug_info["orig_text"] = test_inference(
model, tokenizer
)
try:
# Convert attention
LOG.info("Converting to differential attention...")
config = LlamaDifferentialConfig(
**model.config.__dict__,
zero_init=cli_args.zero_init,
sublayer_norm=cli_args.sublayer_norm,
split_heads=cli_args.split_heads,
mirror_weights=cli_args.mirror_weights,
)
model = LlamaDifferentialForCausalLM.from_llama(model, config)
model.to(cfg.device, dtype=cfg.torch_dtype)
except Exception as exc:
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
raise
# Test converted model
if cli_args.debug:
LOG.info("Testing converted model...")
debug_info["conv_time"], debug_info["conv_text"] = test_inference(
model, tokenizer
)
# Save if requested
if cfg.output_dir:
# Save model and tokenizer
LOG.info("Saving converted model to %s", cfg.output_dir)
model.save_pretrained(cfg.output_dir)
tokenizer.save_pretrained(cfg.output_dir)
# Modify config to reflect new path / differential attention
output_config_path = Path(cfg.output_dir) / "axolotl_config.yml"
LOG.info("Saving updated config to %s", output_config_path)
with open(config_path, "r", encoding="utf-8") as file:
modified_cfg = yaml.safe_load(file) or {}
modified_cfg["base_model"] = cfg.output_dir
modified_cfg["diff_attention"] = True
plugin_class = (
"axolotl.integrations.diff_transformer.DifferentialTransformerPlugin"
)
if "plugins" in modified_cfg:
modified_cfg["plugins"].append(plugin_class)
else:
modified_cfg["plugins"] = [plugin_class]
# Write out the updated axolotl config while preserving original ordering / formatting
dump_yaml_preserved_order(
data=modified_cfg,
reference_yaml_path=config_path,
output_path=output_config_path,
)
else:
LOG.info("Not saving converted model to disk")
LOG.info("Pass --output-dir path/to/save to save model")
if cli_args.debug:
LOG.info(
Fore.GREEN
+ "Conversion successful!\n"
+ f"Original generation time: {debug_info['orig_time']:.2f}s\n"
+ f"Converted generation time: {debug_info['conv_time']:.2f}s"
+ Fore.RESET
)
if debug_info["orig_text"] == debug_info["conv_text"]:
LOG.info(
Fore.GREEN
+ "Generations match!\n"
+ "Model generation:\n"
+ "*" * 50
+ "\n"
+ f"{debug_info['orig_text']}\n"
+ "*" * 50
+ "\n"
+ Fore.RESET
)
debug_info["generations_match"] = True
else:
message = (
"Generations do not match.\n"
+ "Original generation:\n"
+ "*" * 50
+ "\n"
+ f"{debug_info['orig_text']}\n"
+ "*" * 50
+ "\n"
+ "Converted generation:\n"
+ "*" * 50
+ "\n"
+ f"{debug_info['conv_text']}\n"
+ "*" * 50
+ "\n"
)
debug_info["generations_match"] = False
if cli_args.zero_init and not cli_args.sublayer_norm:
LOG.info(Fore.RED + message + Fore.RESET)
debug_info["match_expected"] = True
else:
LOG.info(
Fore.YELLOW
+ message
+ "However, this is expected since --zero-init"
+ " and --no-sublayer-norm were not passed."
+ Fore.RESET
)
debug_info["match_expected"] = False
return model, debug_info
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
print_axolotl_text_art()
cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(ConvertDiffTransformerCliArgs)
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
convert_diff_transformer(cfg, cli_args, config)
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -0,0 +1,198 @@
"""CLI to convert a transformers model's attns to rala attns."""
import logging
import warnings
from pathlib import Path
from time import time
from typing import Union
import fire
import torch
import yaml
from colorama import Fore
from dotenv import load_dotenv
from transformers import HfArgumentParser
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
from axolotl.integrations.rala.convert import convert_to_rala
from axolotl.utils.yaml import dump_yaml_preserved_order
LOG = logging.getLogger(__name__)
def test_inference(model, tokenizer, prompt="The quick brown fox"):
"""Run test inference and return generation time"""
try:
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {
k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()
}
start = time()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=20,
num_beams=1,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
use_cache=False,
)
elapsed = time() - start
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
LOG.info("Prompt: %s", prompt)
LOG.info("Generated: %s", generated_text)
LOG.info("Generation time: %.2fs", elapsed)
return elapsed, generated_text
except Exception as exc:
LOG.error("Inference failed: %s", str(exc))
raise
def convert_rala(cfg, cli_args, config_path):
debug_info = {}
# Load model and tokenizer
with warnings.catch_warnings():
warnings.simplefilter("ignore")
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
model.to(cfg.device, dtype=cfg.torch_dtype)
# Log original model info
LOG.info(
"Original model config:\n\t- Hidden size: %d\n\t- Num attention heads: %d",
model.config.hidden_size,
model.config.num_attention_heads,
)
# Test original model
if cli_args.debug:
LOG.info("attention layers to RALA attention")
debug_info["orig_time"], debug_info["orig_text"] = test_inference(
model, tokenizer
)
# Convert attention
try:
model = convert_to_rala(
model=model,
zero_init=cli_args.zero_init,
)
model.to(cfg.device, dtype=cfg.torch_dtype)
model.config.model_type = "llama-rala"
except Exception as exc:
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
raise
# Test converted model
if cli_args.debug:
LOG.info("Testing converted model...")
debug_info["conv_time"], debug_info["conv_text"] = test_inference(
model, tokenizer
)
# Save if requested
if cfg.output_dir:
# Save model and tokenizer
LOG.info("Saving converted model to %s", cfg.output_dir)
model.save_pretrained(cfg.output_dir)
tokenizer.save_pretrained(cfg.output_dir)
# Modify config to reflect new path / RALA attention
output_config_path = Path(cfg.output_dir) / "axolotl_config.yml"
LOG.info("Saving updated config to %s", output_config_path)
with open(config_path, "r", encoding="utf-8") as file:
modified_cfg = yaml.safe_load(file) or {}
modified_cfg["base_model"] = cfg.output_dir
modified_cfg["rala_attention"] = True
plugin_class = "axolotl.integrations.rala.RalaPlugin"
if "plugins" in modified_cfg:
modified_cfg["plugins"].append(plugin_class)
else:
modified_cfg["plugins"] = [plugin_class]
dump_yaml_preserved_order(
data=modified_cfg,
reference_yaml_path=config_path,
output_path=output_config_path,
)
else:
LOG.info("Not saving converted model to disk")
LOG.info("Pass --output-dir path/to/save to save model")
if cli_args.debug:
LOG.info(
Fore.GREEN
+ "Conversion successful!\n"
+ f"Original generation time: {debug_info['orig_time']:.2f}s\n"
+ f"Converted generation time: {debug_info['conv_time']:.2f}s"
+ Fore.RESET
)
if debug_info["orig_text"] == debug_info["conv_text"]:
LOG.info(
Fore.GREEN
+ "Generations match!\n"
+ "Model generation:\n"
+ "*" * 50
+ "\n"
+ f"{debug_info['orig_text']}\n"
+ "*" * 50
+ "\n"
+ Fore.RESET
)
debug_info["generations_match"] = True
else:
message = (
"Generations do not match.\n"
+ "Original generation:\n"
+ "*" * 50
+ "\n"
+ f"{debug_info['orig_text']}\n"
+ "*" * 50
+ "\n"
+ "Converted generation:\n"
+ "*" * 50
+ "\n"
+ f"{debug_info['conv_text']}\n"
+ "*" * 50
+ "\n"
)
debug_info["generations_match"] = False
if cli_args.zero_init and not cli_args.sublayer_norm:
LOG.info(Fore.RED + message + Fore.RESET)
debug_info["match_expected"] = True
else:
LOG.info(
Fore.YELLOW
+ message
+ "However, this is expected since --zero-init"
+ " and --no-sublayer-norm were not passed."
+ Fore.RESET
)
debug_info["match_expected"] = False
return model, debug_info
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
print_axolotl_text_art()
cfg = load_cfg(config, **kwargs)
if cfg.rala_attention:
cfg.rala_attention = False
parser = HfArgumentParser(ConvertDiffTransformerCliArgs)
cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
convert_rala(cfg, cli_args, config)
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)

View File

@@ -12,7 +12,12 @@ from axolotl.cli.utils import (
build_command,
fetch_from_github,
)
from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.common.cli import (
ConvertDiffTransformerCliArgs,
EvaluateCliArgs,
PreprocessCliArgs,
TrainerCliArgs,
)
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@@ -77,6 +82,9 @@ def evaluate(config: str, accelerate: bool, **kwargs):
"""Evaluate a model."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
if config:
@@ -93,7 +101,7 @@ def evaluate(config: str, accelerate: bool, **kwargs):
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
default=False,
help="Use accelerate launch for multi-GPU inference",
)
@click.option(
@@ -124,7 +132,7 @@ def inference(
if lora_model_dir:
kwargs["lora_model_dir"] = lora_model_dir
if base_model:
kwargs["output_dir"] = base_model
kwargs["base_model"] = base_model
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
@@ -240,6 +248,32 @@ def merge_lora(
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def convert_diff_transformer(config: str, **kwargs):
"""Convert model attention layers to differential attention layers."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
from axolotl.cli.integrations.convert_diff_transformer import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(ConvertDiffTransformerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def convert_rala(config: str, **kwargs):
"""Convert model attention layers to RALA attention layers."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
from axolotl.cli.integrations.convert_rala import do_cli
do_cli(config=config, **kwargs)
@cli.command()
@click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
@click.option("--dest", help="Destination directory")

View File

@@ -22,7 +22,6 @@ def add_options_from_dataclass(config_class: Type[Any]):
# Process dataclass fields in reverse order for correct option ordering
for field in reversed(dataclasses.fields(config_class)):
field_type = field.type
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
@@ -44,6 +43,7 @@ def add_options_from_dataclass(config_class: Type[Any]):
default=field.default,
help=field.metadata.get("description"),
)(function)
return function
return decorator
@@ -55,7 +55,14 @@ def add_options_from_config(config_class: Type[BaseModel]):
def decorator(function):
# Process model fields in reverse order for correct option ordering
for name, field in reversed(config_class.model_fields.items()):
if field.annotation == bool:
field_type = field.annotation
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
# NOTE: defaults are handled by the pydantic model config classes.
if field_type == bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"
function = click.option(
@@ -66,6 +73,7 @@ def add_options_from_config(config_class: Type[BaseModel]):
function = click.option(
option_name, default=None, help=field.description
)(function)
return function
return decorator
@@ -84,6 +92,8 @@ def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
if isinstance(value, bool):
if value:
cmd.append(f"--{key}")
else:
cmd.append(f"--no{key}")
else:
cmd.extend([f"--{key}", str(value)])

View File

@@ -4,7 +4,7 @@ shared module for cli specific things
import logging
from dataclasses import dataclass, field
from typing import Optional
from typing import Optional, Union
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.logging_config import configure_logging
@@ -12,14 +12,12 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
configure_logging()
LOG = logging.getLogger("axolotl.common.cli")
LOG = logging.getLogger(__name__)
@dataclass
class PreprocessCliArgs:
"""
dataclass representing arguments for preprocessing only
"""
"""dataclass with arguments for preprocessing only"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
@@ -30,9 +28,7 @@ class PreprocessCliArgs:
@dataclass
class TrainerCliArgs:
"""
dataclass representing the various non-training arguments
"""
"""dataclass with various non-training arguments"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
@@ -45,19 +41,28 @@ class TrainerCliArgs:
@dataclass
class EvaluateCliArgs:
"""
dataclass representing the various evaluation arguments
"""
"""dataclass with various evaluation arguments"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
@dataclass
class ConvertDiffTransformerCliArgs:
"""dataclass with arguments for convert-diff-transformer CLI"""
debug: bool = field(default=False)
zero_init: bool = field(default=False)
sublayer_norm: bool = field(default=True)
split_heads: bool = field(default=False)
mirror_weights: bool = field(default=False)
def load_model_and_tokenizer(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
cli_args: Union[TrainerCliArgs, EvaluateCliArgs, ConvertDiffTransformerCliArgs],
):
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)

View File

@@ -22,7 +22,6 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch
import transformers
from datasets import Dataset
from packaging import version
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
@@ -56,6 +55,7 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GCCallback,
GPUStatsCallback,
LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback,
@@ -67,7 +67,7 @@ from axolotl.utils.callbacks import (
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
@@ -293,7 +293,7 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
"""
Training arguments for Causal trainer
This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
This code is duplicated due to HF TrainingArguments not setting output_dir with a default value
so it can't be used as a mixin.
"""
@@ -481,7 +481,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
if self.optimizer is None: # pylint: disable=access-member-before-definition
decay_parameters = self.get_decay_parameter_names(opt_model)
params = {
"to_weight_decay": {}, # LayerNorm and bias
"to_weight_decay": {}, # LayerNorm except bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
@@ -607,8 +607,14 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
self.state.train_batch_size or self.args.per_device_train_batch_size
)
batch_max_len = train_batch_size * self.args.max_seq_length
if self.args.curriculum_sampling:
sampler = SequentialSampler(self.train_dataset)
else:
sampler = RandomSampler(self.train_dataset)
return MultipackBatchSampler(
RandomSampler(self.train_dataset),
sampler,
lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
@@ -977,12 +983,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
try:
return super().log(logs, start_time)
except TypeError:
return super().log(logs) # transformers<=4.46
return super().log(logs) # transformers<=4.46
return super().log(logs, start_time)
def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
@@ -1166,22 +1167,6 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
torch.cuda.empty_cache()
return loss
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
@@ -1190,22 +1175,6 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
tag_names = ["axolotl", "orpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(ORPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
@@ -1214,49 +1183,6 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
tag_names = ["axolotl", "kto"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# train metrics should have no prefix, eval should have 'eval_'
prefix = "eval_" if train_eval == "eval" else ""
# accumulate average metrics from sums and lengths
for split in ["chosen", "rejected"]:
if f"count/{split}" in self._stored_metrics[train_eval]:
count_sum = (
torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"])
.sum()
.item()
)
for metric in ["rewards", "logps", "logits"]:
logs[f"{prefix}{metric}/{split}"] = (
torch.Tensor(
self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
)
.sum()
.item()
/ count_sum
)
# delete obsolete metric
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
del self._stored_metrics[train_eval][f"count/{split}"]
# calculate reward margin
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
logs[f"{prefix}rewards/margins"] = (
logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
)
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(KTOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
@@ -1265,22 +1191,6 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
tag_names = ["axolotl", "cpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(CPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
@@ -1289,15 +1199,6 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
tag_names = ["axolotl", "reward"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call
class TrainerBuilderBase(abc.ABC):
"""
@@ -1452,6 +1353,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
callbacks.append(SaveModelCallback())
return callbacks
@@ -1831,8 +1734,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = get_chat_template(
self.cfg.chat_template,
training_arguments_kwargs["chat_template"] = get_chat_template_from_config(
cfg=self.cfg,
tokenizer=self.tokenizer,
)
@@ -1974,6 +1877,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
):
if training_args.pretraining:
if self.cfg.pretraining_sample_concatenation is False:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
return None
if self.cfg.model_config_type == "mamba":

View File

@@ -9,12 +9,11 @@ from typing import Dict, Optional
import torch
from accelerate.logging import get_logger
from axolotl.common.cli import TrainerCliArgs
from axolotl.common.cli import EvaluateCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_processor, load_tokenizer
from axolotl.utils.models import load_processor
from axolotl.utils.trainer import setup_trainer
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
@@ -62,8 +61,9 @@ def evaluate_dataset(
return metrics
# pylint: disable=duplicate-code
def evaluate(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
*, cfg: DictDefault, cli_args: EvaluateCliArgs, dataset_meta: TrainDatasetMeta
) -> Dict[str, float]:
"""
Evaluate a model on training and validation datasets
@@ -79,16 +79,11 @@ def evaluate(
- The tokenizer
- Dictionary of evaluation metrics
"""
# pylint: disable=duplicate-code
# Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf()
# Load model
LOG.debug("loading model for evaluation...")
# Load tokenizer
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
# Load processor for multimodal models if needed
processor = None
@@ -100,12 +95,6 @@ def evaluate(
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
# Load model
LOG.debug("loading model for evaluation...")
model, _ = load_model(
cfg, tokenizer, processor=processor, inference=cli_args.inference
)
# Set up trainer
trainer = setup_trainer(
cfg,

View File

@@ -48,12 +48,12 @@ class BasePlugin:
Initializes the BasePlugin.
"""
def register(self, cfg): # pylint: disable=unused-argument
def register(self): # pylint: disable=unused-argument
"""
Registers the plugin with the given configuration.
Parameters:
cfg (dict): The configuration for the plugin.
None
Returns:
None
@@ -75,6 +75,19 @@ class BasePlugin:
None
"""
def set_attn_config(
self, cfg, model_kwargs, model_config
): # pylint: disable=unused-argument
"""
Sets attention configuration for the model.
Parameters:
cfg (dict): The configuration for the plugin.
model_kwargs (dict): The model kwargs for the plugin.
model_config (object): The model configuration.
Returns:
None
"""
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after the model is loaded.
@@ -274,6 +287,7 @@ class PluginManager:
try:
plugin = load_plugin(plugin_name)
self.plugins[plugin_name] = plugin
plugin.register()
except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}")
@@ -304,6 +318,17 @@ class PluginManager:
for plugin in self.plugins.values():
plugin.pre_model_load(cfg)
def set_attn_config(self, cfg, model_kwargs, model_config):
"""
modifies the attention configuration of the model kwargs for loading
Parameters:
cfg (dict): The configuration for the plugins.
model_kwargs (dict): The model's kwargs for constructing the model
model_config (dict): The model's configuration.
"""
for plugin in self.plugins.values():
plugin.set_attn_config(cfg, model_kwargs, model_config)
def post_model_load(self, cfg, model):
"""
Calls the post_model_load method of all registered plugins.

View File

@@ -43,10 +43,12 @@ def merge_input_args():
input_args: List[str] = plugin_manager.get_input_args()
plugin_classes = []
dynamic_input = ""
for plugin_args in input_args:
plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
plugin_classes.append(plugin_cls)
if dynamic_input:
dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n"
dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
@@ -62,4 +64,5 @@ def merge_input_args():
"AxolotlConfigWCapabilities"
]
return AxolotlConfigWCapabilities, AxolotlInputConfig
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase

View File

@@ -0,0 +1,12 @@
# Differential Transformer
### Usage
**Note:** The following will be set in the axolotl config output by the `axolotl convert-diff-transformer` command.
```yaml
plugins:
- axolotl.integrations.diff_transformer.DifferentialTransformerPlugin
diff_attention: true
```
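For reference, the same conversion can be driven programmatically with the helpers added in this change. A minimal sketch, assuming a standard axolotl YAML config; the config path and `output_dir` below are placeholders:

```python
# Sketch: convert a llama-architecture model to differential attention using
# the functions introduced in this diff. Paths are hypothetical placeholders.
from axolotl.cli import load_cfg
from axolotl.common.cli import ConvertDiffTransformerCliArgs
from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer

config_path = "path/to/config.yml"  # hypothetical axolotl config
cfg = load_cfg(config_path, output_dir="path/to/converted-model")

# debug=True runs a generation check before and after conversion; per the
# script's own logic, zero_init=True with sublayer_norm=False is the setting
# under which the converted model is expected to reproduce the original output.
cli_args = ConvertDiffTransformerCliArgs(debug=True, zero_init=True, sublayer_norm=False)

model, debug_info = convert_diff_transformer(cfg, cli_args, config_path)
print(debug_info.get("generations_match"))
```

The `axolotl convert-diff-transformer <config>` command added in `cli/main.py` wraps this same flow and exposes the `ConvertDiffTransformerCliArgs` fields as `--flag/--no-flag` options.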

View File

@@ -0,0 +1,67 @@
"""Definition of differential transformer plugin."""
import logging
from typing import List
from transformers import PreTrainedModel, TrainerCallback
from axolotl.integrations.base import BasePlugin
from axolotl.utils.callbacks.diff_attn import (
DifferentialAttentionMixingCallback,
DifferentialAttentionMonitorCallback,
)
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
class DifferentialTransformerPlugin(BasePlugin):
"""Plugin for differential transformer integration with Axolotl."""
def __init__(self) -> None:
"""
Constructor for differential transformers plugin. Calls `register_diff_attn`
to register differential attention custom modeling implementation to `AutoConfig`
and `AutoModel`.
"""
from .modeling_diff_attn import register_diff_attn
register_diff_attn()
def get_input_args(self) -> str:
"""Returns module path to diff transformer plugin args for `axolotl` config."""
return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
# pylint: disable=unused-argument
def add_callbacks_pre_trainer(
self, cfg: DictDefault, model: PreTrainedModel
) -> List[TrainerCallback]:
"""
Returns `DifferentialAttentionMonitorCallback` to be added to the list of
callbacks for the `axolotl` trainer if wandb usage is enabled.
Parameters:
cfg: Dictionary mapping `axolotl` config keys to values.
model: The loaded model.
Returns:
A list (possibly) containing an instantiated `DifferentialAttentionMonitorCallback`.
"""
callbacks = []
if cfg.use_wandb:
callbacks.append(
DifferentialAttentionMonitorCallback(
log_every=cfg.diff_attn_log_every,
num_monitor_layers=cfg.diff_attn_num_monitor_layers,
warmup_steps=cfg.diff_attn_warmup_steps,
)
)
if cfg.diff_attn_warmup_steps:
callbacks.append(
DifferentialAttentionMixingCallback(
warmup_steps=cfg.diff_attn_warmup_steps
)
)
return callbacks

View File

@@ -0,0 +1,27 @@
"""Module for handling differential transfomer input arguments."""
import logging
from typing import Optional
from pydantic import BaseModel
LOG = logging.getLogger(__name__)
class DifferentialTransformerArgs(BaseModel):
"""
Input args for differential transformer.
Attributes:
diff_attention: Whether to use differential attention layers.
diff_attn_log_every: How often to log differential attention statistics.
diff_attn_num_monitor_layers: Number of layers to monitor for attention stats.
diff_attn_warmup_steps: Number of steps to linearly increase negative attention
mixing weight from 0 to 1. If specified, will reach full mixing at this
step. If `None`, negative attention has full weight from the start.
"""
diff_attention: Optional[bool] = None
diff_attn_log_every: Optional[int] = 100
diff_attn_num_monitor_layers: Optional[int] = 3
diff_attn_warmup_steps: Optional[int] = None
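These options are surfaced in the main axolotl config through the plugin's `get_input_args` path shown above. A quick sketch of the defaults, with illustrative values for the optional fields:

```python
# Sketch: the pydantic args model added above, instantiated with its defaults.
from axolotl.integrations.diff_transformer.args import DifferentialTransformerArgs

args = DifferentialTransformerArgs(diff_attention=True, diff_attn_warmup_steps=1000)
print(args.diff_attn_log_every)           # 100 (default)
print(args.diff_attn_num_monitor_layers)  # 3 (default)
print(args.diff_attn_warmup_steps)        # 1000: full negative-attention mixing at this step
```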

View File

@@ -0,0 +1,694 @@
"""Re-implemention of differential attention from the Differential Transformer paper
(https://arxiv.org/abs/2410.05258)."""
# pylint: disable=invalid-name
import logging
import math
from typing import Any
import torch
import torch.nn.functional as F
from torch import nn
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import (
LlamaRMSNorm,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
)
logging.basicConfig(level=logging.INFO)
LOG = logging.getLogger(__name__)
try:
from flash_attn.flash_attn_interface import flash_attn_func
FLASH_ATTENTION_AVAILABLE = True
except ImportError:
FLASH_ATTENTION_AVAILABLE = False
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""
Repeats key/value heads to match the number of query heads in multi-head attention.
Args:
x: Input tensor of shape `(batch_size, num_kv_heads, seq_len, head_dim)`.
n_rep: Number of times to repeat each head.
Returns:
Tensor with repeated heads of shape `(batch_size, num_kv_heads * n_rep,
seq_len, head_dim)`.
If `n_rep` is 1, returns the input tensor unchanged.
"""
batch_size, n_kv_heads, slen, head_dim = x.shape
if n_rep == 1:
return x
return (
x[:, :, None, :, :]
.expand(batch_size, n_kv_heads, n_rep, slen, head_dim)
.reshape(batch_size, n_kv_heads * n_rep, slen, head_dim)
)
def lambda_init_fn(depth: int) -> float:
"""
Lambda mixing parameter init function from the "Differential Transformer" paper.
Args:
depth: Index of layer to init lambda parameter.
Returns:
Lambda initialization value (increasing with `depth`, from 0.2 toward 0.8).
"""
return 0.8 - 0.6 * math.exp(-0.3 * depth)
class LlamaDifferentialAttentionBase(nn.Module):
"""
Base class for differential attention implementations.
This class implements the core differential attention mechanism used in Llama models.
It supports both split heads and double projection modes for attention computation.
"""
def __init__(self, config: Any, layer_idx: int):
"""
Initializes the differential attention module.
Args:
config: Model configuration object containing hyperparameters, including:
- hidden_size: The size of hidden states.
- num_attention_heads: Number of attention heads.
- num_key_value_heads: Number of key/value heads.
- attention_bias: Whether to use bias in attention projections.
- split_heads: Whether to use split heads mode.
- rms_norm_eps: Epsilon for RMS normalization.
layer_idx: The index of this layer in the model.
Note:
The initialization process consists of four steps:
1. Configuration initialization (`_init_config`)
2. Projection layers initialization (`_init_projections`)
3. Differential parameters initialization (`_init_differential_params`)
4. Normalization layers initialization (`_init_normalization`)
"""
super().__init__()
self.config = config
self._init_config(layer_idx)
self._init_projections()
self._init_differential_params()
self._init_normalization()
# For logging
self.attn1 = None
self.attn2 = None
self.lambda_full = None
def _init_config(self, layer_idx: int) -> None:
"""
Initializes configuration parameters for the attention layer. Sets up various
dimension sizes and head counts based on the provided config. Handles both
split heads and double projection modes.
In split heads mode, the number of heads is divided by 2 (rounding down), which
differs from the original implementation that required an even number.
Args:
layer_idx: Index of the current layer.
"""
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
self.base_num_heads = self.config.num_attention_heads
self.base_num_kv_heads = self.config.num_key_value_heads
self.num_key_value_groups = self.base_num_heads // self.base_num_kv_heads
self.layer_idx = layer_idx
if self.config.split_heads:
self.heads_per_component = self.base_num_heads // 2
self.kv_heads_per_component = self.base_num_kv_heads // 2
self.value_head_dim = 2 * self.head_dim
else:
self.heads_per_component = self.base_num_heads
self.kv_heads_per_component = self.base_num_kv_heads
self.value_head_dim = self.head_dim
def _init_projections(self) -> None:
"""
Initializes the query, key, value, and output projection layers.
Creates linear transformations for Q, K, V projections with dimensions
depending on whether split heads or double projection mode is used.
The output projection combines the attention heads back to model dimension.
"""
if self.config.split_heads:
q_out_dim = self.config.hidden_size
k_out_dim = self.head_dim * self.base_num_kv_heads
else:
q_out_dim = self.config.hidden_size * 2
k_out_dim = self.head_dim * self.base_num_kv_heads * 2
self.q_proj = nn.Linear(
self.config.hidden_size, q_out_dim, bias=self.config.attention_bias
)
self.k_proj = nn.Linear(
self.config.hidden_size, k_out_dim, bias=self.config.attention_bias
)
self.v_proj = nn.Linear(
self.config.hidden_size,
self.head_dim * self.base_num_kv_heads,
bias=self.config.attention_bias,
)
self.o_proj = nn.Linear(
self.base_num_heads * self.head_dim,
self.config.hidden_size,
bias=self.config.attention_bias,
)
def _init_differential_params(self) -> None:
"""
Initializes parameters specific to differential attention.
Creates learnable parameters for the differential attention mechanism:
- Mixing parameter for negative attention component warmup phase.
- Lambda parameters for queries and keys.
- Initial lambda value based on layer index.
- Rotary position embedding layer.
"""
self.diff_attn_mix = 1.0 # Default to full mixing
self.lambda_init = nn.Parameter(
torch.full((), lambda_init_fn(self.layer_idx)),
requires_grad=False,
)
self.lambda_q1 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.lambda_k1 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.lambda_q2 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.lambda_k2 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
def _init_normalization(self) -> None:
"""
Initializes normalization layers for the attention mechanism.
Sets up either RMS normalization or identity transformation based on config.
The normalization is applied to the sublayer output if enabled.
"""
sublayer_norm = getattr(self.config, "sublayer_norm", True)
if sublayer_norm:
self.subln = LlamaRMSNorm(self.value_head_dim, eps=self.config.rms_norm_eps)
else:
self.subln = nn.Identity()
def _prepare_attention_inputs(
self, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Prepares input tensors for attention computation.
Projects input hidden states to query, key, and value spaces, then reshapes
them for multi-head attention processing.
Args:
hidden_states: Input tensor of shape `(batch_size, seq_len,
hidden_size)`.
Returns:
tuple: Tuple containing:
- q1: Positive attention query component
- q2: Negative attention query component
- k1: Positive attention key component
- k2: Negative attention key component
- v: Value tensor
"""
bsz, q_len, _ = hidden_states.size()
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q1, q2 = q.chunk(2, dim=-1)
k1, k2 = k.chunk(2, dim=-1)
q1 = q1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
1, 2
)
q2 = q2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
1, 2
)
k1 = k1.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
1, 2
)
k2 = k2.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
1, 2
)
v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
return q1, q2, k1, k2, v
def _apply_rotary_embeddings(
self,
q1: torch.Tensor,
q2: torch.Tensor,
k1: torch.Tensor,
k2: torch.Tensor,
position_ids: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None,
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
"""
Applies rotary positional embeddings to queries and keys.
Args:
q1: Positive attention query component.
q2: Negative attention query component.
k1: Positive attention key component.
k2: Negative attention key component.
position_ids: Token position indices.
position_embeddings: Pre-computed rotary embeddings (cos, sin).
Returns:
tuple: Tuple containing:
- q1: Positive attention query with positional encoding.
- q2: Negative attention query with positional encoding.
- k1: Positive attention key with positional encoding.
- k2: Negative attention key with positional encoding.
- cos: Cosine part of rotary embeddings.
- sin: Sine part of rotary embeddings.
"""
if position_embeddings is None:
LOG.warning(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(q1, position_ids)
else:
cos, sin = position_embeddings
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
return q1, q2, k1, k2, cos, sin
def _handle_cache(
self,
k1: torch.Tensor,
k2: torch.Tensor,
v: torch.Tensor,
past_key_value: Cache | None,
cache_kwargs: dict,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Handles key-value caching for autoregressive generation and the repetition of
key-value heads to match the number of query heads.
Args:
k1: Positive attention key component.
k2: Negative attention key component.
v: Value tensor.
past_key_value: Cache object for storing previous key-value pairs.
cache_kwargs: Additional arguments for cache handling.
Returns:
tuple: Tuple containing:
- k1: Processed positive attention key component.
- k2: Processed negative attention key component.
- v: Processed value tensor.
"""
if past_key_value is not None:
k = torch.stack([k1, k2], dim=1)
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
k1, k2 = k.unbind(dim=1)
k1 = repeat_kv(k1, self.num_key_value_groups)
k2 = repeat_kv(k2, self.num_key_value_groups)
v = repeat_kv(v, self.num_key_value_groups)
if self.config.split_heads:
v = torch.cat(torch.chunk(v, 2, dim=1), dim=-1)
return k1, k2, v
def _compute_lambda(self, q1: torch.Tensor) -> torch.Tensor:
"""
Computes lambda values for differential attention.
The lambda value is computed as λ₁ - λ₂ + λ_init, where λ₁ and λ₂ are computed
from the learned parameters. `diff_attn_mix` is multiplied through the result
for negative attention component warmup phase (if applicable).
Args:
q1: Positive attention query component, used for type casting.
Returns:
Computed lambda value for differential attention.
"""
lambda_1 = torch.exp(
torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
).type_as(q1)
lambda_2 = torch.exp(
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
).type_as(q1)
lambda_full = lambda_1 - lambda_2 + self.lambda_init
return self.diff_attn_mix * lambda_full
def _process_attention_output(
self, attn: torch.Tensor, bsz: int, q_len: int
) -> torch.Tensor:
"""
Processes and projects the attention output. Applies sublayer normalization
and projects back to model dimension. (The (1 - λ_init) scaling is currently
disabled; see the inline note below.)
Args:
attn: Raw attention output.
bsz: Batch size.
q_len: Query sequence length.
Returns:
Processed attention output of shape (batch_size, seq_len, hidden_size)
"""
attn = self.subln(attn)
# NOTE: this may need to be added back in, but doesn't interact well with
# `diff_attn_mix`, and doesn't allow us to preserve the original model output.
# attn = attn * self.diff_attn_mix * (1 - self.lambda_init)
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)
return self.o_proj(attn)
class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
"""
Standard implementation of differential attention.
This class implements the standard differential attention mechanism using
explicit matrix multiplications for the attention computation.
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: Cache | None = None,
output_attentions: bool = False,
use_cache: bool = False, # pylint: disable=unused-argument
cache_position: torch.LongTensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs, # pylint: disable=unused-argument
):
"""
Computes differential attention using standard matrix multiplication operations.
Args:
hidden_states: Input tensor containing sequence to attend to.
attention_mask: Mask to avoid attention on padding tokens.
position_ids: Indices of positions for positional embeddings.
past_key_value: Cached key and value tensors for autoregressive decoding.
output_attentions: Whether to return attention weights.
use_cache: Whether to use cached key/value states.
cache_position: Position indices for cached states.
position_embeddings: Pre-computed positional embeddings.
**kwargs: Additional arguments passed to the forward call.
Returns:
tuple containing:
- Output tensor after attention computation.
- Attention weights if output_attentions is True, else None.
- Updated key-value cache if use_cache is True, else None.
"""
bsz, q_len, _ = hidden_states.size()
q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
q1, q2, k1, k2, position_ids, position_embeddings
)
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs)
# Standard attention computation
attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim)
attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim)
if attention_mask is not None:
causal_mask = attention_mask[:, :, :, : k1.shape[-2]]
attn1 = attn1 + causal_mask
attn2 = attn2 + causal_mask
attn1 = F.softmax(attn1, dim=-1, dtype=torch.float32).type_as(attn1)
attn2 = F.softmax(attn2, dim=-1, dtype=torch.float32).type_as(attn2)
dropout_p = self.config.attention_dropout if self.training else 0.0
attn1 = F.dropout(attn1, p=dropout_p, training=self.training)
attn2 = F.dropout(attn2, p=dropout_p, training=self.training)
lambda_full = self._compute_lambda(q1)
attn = torch.matmul(attn1, v) - lambda_full * torch.matmul(attn2, v)
attn = self._process_attention_output(attn, bsz, q_len)
# Save for logging
self.attn1 = attn1
self.attn2 = attn2
self.lambda_full = lambda_full
if output_attentions:
attn_weights = attn1 - lambda_full * attn2
attn_weights = attn_weights.view(bsz, self.heads_per_component, q_len, -1)
return attn, attn_weights, past_key_value
return attn, None, past_key_value
class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase):
"""
SDPA-based implementation of differential attention.
This class implements differential attention using PyTorch's scaled_dot_product_attention
for improved performance on supported hardware.
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: Cache | None = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: torch.LongTensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs, # pylint: disable=unused-argument
):
"""
Computes differential attention using PyTorch's scaled dot product attention.
Args:
hidden_states: Input tensor containing sequence to attend to.
attention_mask: Mask to avoid attention on padding tokens.
position_ids: Indices of positions for positional embeddings.
past_key_value: Cached key and value tensors for autoregressive decoding.
output_attentions: Whether to return attention weights.
use_cache: Whether to use cached key/value states.
cache_position: Position indices for cached states.
position_embeddings: Pre-computed positional embeddings.
**kwargs: Additional arguments passed to the forward call.
Returns:
tuple containing:
- Output tensor after attention computation.
- None for attention weights (SDPA doesn't support output_attentions).
- Updated key-value cache if use_cache is True, else None.
"""
if output_attentions:
LOG.warning(
"LlamaDifferentialModel is using LlamaDifferentialSdpaAttention, but "
+ "`torch.nn.functional.scaled_dot_product_attention` does not support "
+ "`output_attentions=True`. Falling back to the eager attention implementation."
)
# pylint: disable=duplicate-code
return LlamaDifferentialAttention.forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
q1, q2, k1, k2, position_ids, position_embeddings
)
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs)
# SDPA-specific attention computation
causal_mask = (
None if attention_mask is None else attention_mask[:, :, :, : k1.shape[-2]]
)
is_causal = attention_mask is None and q_len > 1
dropout_p = self.config.attention_dropout if self.training else 0.0
if q1.device.type == "cuda" and causal_mask is not None:
q1, q2 = q1.contiguous(), q2.contiguous()
k1, k2 = k1.contiguous(), k2.contiguous()
v = v.contiguous()
attn1 = F.scaled_dot_product_attention(
q1, k1, v, attn_mask=causal_mask, dropout_p=dropout_p, is_causal=is_causal
)
attn2 = F.scaled_dot_product_attention(
q2, k2, v, attn_mask=causal_mask, dropout_p=dropout_p, is_causal=is_causal
)
lambda_full = self._compute_lambda(q1)
attn = attn1 - lambda_full * attn2
attn = self._process_attention_output(attn, bsz, q_len)
# Save for logging
self.attn1 = attn1
self.attn2 = attn2
self.lambda_full = lambda_full
return attn, None, past_key_value
class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
"""
Flash Attention 2-based implementation of differential attention.
This class implements differential attention using Flash Attention 2 for maximum
performance on supported hardware.
"""
def __init__(self, *args, **kwargs):
"""
Initializes the Flash Attention 2 differential attention module.
Args:
*args: Positional arguments passed to parent class.
**kwargs: Keyword arguments passed to parent class.
Raises:
ImportError: If flash-attn library is not installed.
"""
if not FLASH_ATTENTION_AVAILABLE:
raise ImportError(
"LlamaDifferentialFlashAttention2 requires flash-attn library. "
"Please install with `pip install flash-attn --no-build-isolation`"
)
super().__init__(*args, **kwargs)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: Cache | None = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: torch.LongTensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs, # pylint: disable=unused-argument
):
"""
Computes differential attention using Flash Attention 2.
Args:
hidden_states: Input tensor containing sequence to attend to.
attention_mask: Mask to avoid attention on padding tokens.
position_ids: Indices of positions for positional embeddings.
past_key_value: Cached key and value tensors for autoregressive decoding.
output_attentions: Whether to return attention weights.
use_cache: Whether to use cached key/value states.
cache_position: Position indices for cached states.
position_embeddings: Pre-computed positional embeddings.
**kwargs: Additional arguments passed to the forward call.
Returns:
tuple containing:
- Output tensor after attention computation.
- None for attention weights (Flash Attention doesn't support output_attentions).
- Updated key-value cache if use_cache is True, else None.
"""
if output_attentions:
LOG.warning(
"LlamaDifferentialModel is using LlamaDifferentialFlashAttention2, but "
+ "flash attenion does not support `output_attentions=True`. Falling back "
+ "to the eager attention implementation."
)
# pylint: disable=duplicate-code
return LlamaDifferentialAttention.forward(
self,
hidden_states,
attention_mask,
position_ids,
past_key_value,
output_attentions,
use_cache,
cache_position,
position_embeddings,
)
bsz, q_len, _ = hidden_states.size()
q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
q1, q2, k1, k2, position_ids, position_embeddings
)
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs)
# Flash Attention specific processing
q1, q2 = q1.transpose(1, 2), q2.transpose(1, 2)
k1, k2 = k1.transpose(1, 2), k2.transpose(1, 2)
v = v.transpose(1, 2)
dropout_p = self.config.attention_dropout if self.training else 0.0
if self.config.split_heads:
v1, v2 = v.chunk(2, dim=-1)
attn11 = flash_attn_func(q1, k1, v1, dropout_p=dropout_p, causal=True)
attn12 = flash_attn_func(q1, k1, v2, dropout_p=dropout_p, causal=True)
attn1 = torch.cat([attn11, attn12], dim=-1)
attn21 = flash_attn_func(q2, k2, v1, dropout_p=dropout_p, causal=True)
attn22 = flash_attn_func(q2, k2, v2, dropout_p=dropout_p, causal=True)
attn2 = torch.cat([attn21, attn22], dim=-1)
else:
attn1 = flash_attn_func(q1, k1, v, dropout_p=dropout_p, causal=True)
attn2 = flash_attn_func(q2, k2, v, dropout_p=dropout_p, causal=True)
attn1, attn2 = attn1.transpose(1, 2), attn2.transpose(1, 2)
lambda_full = self._compute_lambda(q1)
attn = attn1 - lambda_full * attn2
attn = self._process_attention_output(attn, bsz, q_len)
# Save for logging
self.attn1 = attn1
self.attn2 = attn2
self.lambda_full = lambda_full
return attn, None, past_key_value
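As a rough sanity check of the mechanism above (this snippet is not part of the patch), the eager differential attention combination can be reproduced on toy tensors; the scalar `lambda_full` below is a hypothetical stand-in for the learned λ₁ − λ₂ + λ_init:

import math
import torch
import torch.nn.functional as F

bsz, n_heads, seq_len, head_dim = 1, 2, 4, 8
q1, q2, k1, k2 = (torch.randn(bsz, n_heads, seq_len, head_dim) for _ in range(4))
v = torch.randn(bsz, n_heads, seq_len, head_dim)
# Two independent softmax attention maps, as in LlamaDifferentialAttention.forward
attn1 = F.softmax(q1 @ k1.transpose(-1, -2) / math.sqrt(head_dim), dim=-1)
attn2 = F.softmax(q2 @ k2.transpose(-1, -2) / math.sqrt(head_dim), dim=-1)
lambda_full = 0.8  # hypothetical stand-in for lambda_1 - lambda_2 + lambda_init
out = attn1 @ v - lambda_full * (attn2 @ v)  # [bsz, n_heads, seq_len, head_dim]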

View File

@@ -0,0 +1,401 @@
"""
Modeling for differential transformers.
This module implements differential attention variants of the LLaMA model,
providing various attention implementations for improved performance.
"""
import logging
import torch
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel
from .diff_attn import (
LlamaDifferentialAttention,
LlamaDifferentialFlashAttention2,
LlamaDifferentialSdpaAttention,
)
logger = logging.getLogger(__name__)
class LlamaDifferentialConfig(LlamaConfig):
"""
Configuration class for Differential LLaMA model.
Extends the base LLaMA configuration with additional parameters for differential
attention mechanisms.
"""
model_type = "llama-differential"
def __init__(
self,
split_heads: bool = False,
sublayer_norm: bool = True,
zero_init: bool = False,
mirror_weights: bool = False,
**kwargs,
):
"""
Initialize differential LLaMA configuration.
Args:
split_heads: Whether to use split heads mode for attention computation.
sublayer_norm: Whether to apply normalization to sublayers.
zero_init: Whether to initialize new weights to zero.
mirror_weights: Whether to copy the positive attention component weights to
the negative attention component.
**kwargs: Additional arguments passed to LlamaConfig.
"""
super().__init__(**kwargs)
self.split_heads = split_heads
self.sublayer_norm = sublayer_norm
self.zero_init = zero_init
self.mirror_weights = mirror_weights
self.architectures = ["LlamaDifferentialModel"]
self._attn_implementations = {
"eager": "differential_eager",
"sdpa": "differential_sdpa",
"flash_attention_2": "differential_flash_attention_2",
}
class LlamaDifferentialModel(LlamaModel):
"""
LlamaModel with differential attention.
This class extends the base LLaMA model by replacing standard attention with
differential attention mechanisms.
"""
config_class = LlamaDifferentialConfig
base_model_prefix = "llama_differential"
def __init__(self, config: LlamaDifferentialConfig):
"""
Initialize a differential LLaMA model.
Args:
config: Configuration object for the model.
Raises:
ValueError: If specified attention implementation is not supported.
"""
super().__init__(config)
# Handle attention implementation
attn_impl = config._attn_implementation or "eager"
if attn_impl in config._attn_implementations:
attn_impl = config._attn_implementations[attn_impl]
# Validate attention implementation
valid_impls = [
None,
"differential_eager",
"differential_sdpa",
"differential_flash_attention_2",
]
if attn_impl not in valid_impls:
raise ValueError(f"Invalid attention implementation: {attn_impl}")
# Replace standard attention with differential attention in each layer
attn_classes = {
"differential_eager": LlamaDifferentialAttention,
"differential_sdpa": LlamaDifferentialSdpaAttention,
"differential_flash_attention_2": LlamaDifferentialFlashAttention2,
}
attn_class = attn_classes.get(attn_impl, LlamaDifferentialAttention)
for idx, layer in enumerate(self.layers):
layer.self_attn = attn_class(config, idx)
@classmethod
# pylint: disable=protected-access
def _autoset_attn_implementation(
cls,
config: LlamaDifferentialConfig,
**kwargs, # pylint: disable=unused-argument
) -> LlamaDifferentialConfig:
"""
Automatically set the attention implementation based on config.
Args:
config: Model configuration object.
**kwargs: Additional arguments (unused).
Returns:
Updated configuration object.
Raises:
ValueError: If specified attention implementation is not supported.
"""
config._attn_implementation_autoset = True
attn_implementation = getattr(config, "_attn_implementation", None)
# Map standard types to differential types if mapping exists
if attn_implementation in config._attn_implementations:
config._attn_implementation = config._attn_implementations[
attn_implementation
]
return config
# If no mapping, validate it's a valid differential type
valid_impls = [
None,
"differential_eager",
"differential_sdpa",
"differential_flash_attention_2",
]
if attn_implementation not in valid_impls:
message = (
f"Specified `attn_implementation={attn_implementation}` is not supported. "
f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
)
raise ValueError(message)
return config
@classmethod
def from_llama(
cls,
model: LlamaModel | LlamaForCausalLM,
config: LlamaDifferentialConfig | None = None,
) -> "LlamaDifferentialModel":
"""
Convert a `LlamaModel` to use differential attention.
Args:
model: Base LLaMA model to convert.
config: Configuration for differential attention. If `None`, created from
base model config.
Returns:
Converted model with differential attention.
Raises:
ValueError: If number of heads is not even when using `split_heads` mode.
"""
logger.info(f"Converting {type(model).__name__} to {cls.__name__}")
# Handle LlamaForCausalLM
if isinstance(model, LlamaForCausalLM):
model = model.model
if config is None:
config = LlamaDifferentialConfig(**model.config.__dict__)
logger.debug(f"Created config: {config}")
# Validate head counts if using split heads mode
if config.split_heads:
if config.num_attention_heads % 2 != 0:
raise ValueError(
f"Number of attention heads ({config.num_attention_heads}) must be even "
"when using split_heads=True"
)
if config.num_key_value_heads % 2 != 0:
raise ValueError(
f"Number of key/value heads ({config.num_key_value_heads}) must be even "
"when using split_heads=True"
)
new_model = cls(config)
# Copy all weights except attention
logger.debug("Copying embeddings and norm")
new_model.embed_tokens.load_state_dict(model.embed_tokens.state_dict())
new_model.norm.load_state_dict(model.norm.state_dict())
logger.debug("Copying layer weights")
for layer_idx, (new_layer, old_layer) in enumerate(
zip(new_model.layers, model.layers)
):
# Copy everything except attention weights
new_layer.mlp.load_state_dict(old_layer.mlp.state_dict())
new_layer.input_layernorm.load_state_dict(
old_layer.input_layernorm.state_dict()
)
new_layer.post_attention_layernorm.load_state_dict(
old_layer.post_attention_layernorm.state_dict()
)
# Handle attention weights
new_layer.self_attn.v_proj.load_state_dict(
old_layer.self_attn.v_proj.state_dict()
)
new_layer.self_attn.o_proj.load_state_dict(
old_layer.self_attn.o_proj.state_dict()
)
# Get the original projection sizes
old_q_size = old_layer.self_attn.q_proj.weight.size(0)
old_k_size = old_layer.self_attn.k_proj.weight.size(0)
if not config.split_heads:
logger.debug(
f"Layer {layer_idx}: Copying Q/K projections with sizes {old_q_size}, {old_k_size}"
)
new_layer.self_attn.q_proj.weight.data[:old_q_size].copy_(
old_layer.self_attn.q_proj.weight.data
)
new_layer.self_attn.k_proj.weight.data[:old_k_size].copy_(
old_layer.self_attn.k_proj.weight.data
)
if config.zero_init:
logger.debug(f"Layer {layer_idx}: Zero initializing")
with torch.no_grad():
new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_()
new_layer.self_attn.k_proj.weight.data[old_k_size:].zero_()
new_layer.self_attn.lambda_q1.zero_()
new_layer.self_attn.lambda_k1.zero_()
new_layer.self_attn.lambda_q2.zero_()
new_layer.self_attn.lambda_k2.zero_()
new_layer.self_attn.lambda_init.zero_()
elif config.mirror_weights:
# Mirror weights for second component
new_layer.self_attn.q_proj.weight.data[old_q_size:].copy_(
old_layer.self_attn.q_proj.weight.data
)
new_layer.self_attn.k_proj.weight.data[old_k_size:].copy_(
old_layer.self_attn.k_proj.weight.data
)
logger.info("Conversion complete")
return new_model
class LlamaDifferentialForCausalLM(LlamaForCausalLM):
"""
`LlamaForCausalLM` with differential attention.
This class extends the base LLaMA causal language model by incorporating
differential attention mechanisms.
"""
config_class = LlamaDifferentialConfig
base_model_prefix = "llama_differential"
def __init__(self, config: LlamaDifferentialConfig):
"""
Initialize a differential LLaMA model for causal language modeling.
Args:
config: Configuration object for the model.
"""
super().__init__(config)
self.model = LlamaDifferentialModel(config)
@classmethod
# pylint: disable=protected-access
def _autoset_attn_implementation(
cls,
config: LlamaDifferentialConfig,
**kwargs, # pylint: disable=unused-argument
) -> LlamaDifferentialConfig:
"""
Automatically set the attention implementation based on config.
Args:
config: Model configuration object.
**kwargs: Additional arguments (unused).
Returns:
Updated configuration object.
Raises:
ValueError: If specified attention implementation is not supported.
"""
config._attn_implementation_autoset = True
attn_implementation = getattr(config, "_attn_implementation", None)
# Map standard types to differential types if mapping exists
if attn_implementation in config._attn_implementations:
config._attn_implementation = config._attn_implementations[
attn_implementation
]
return config
# If no mapping, validate it's a valid differential type
valid_impls = [
None,
"differential_eager",
"differential_sdpa",
"differential_flash_attention_2",
]
if attn_implementation not in valid_impls:
message = (
f"Specified `attn_implementation={attn_implementation}` is not supported. "
f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
)
raise ValueError(message)
return config
@classmethod
def from_llama(
cls, model: LlamaForCausalLM, config: LlamaDifferentialConfig | None = None
) -> "LlamaDifferentialForCausalLM":
"""
Convert a `LlamaForCausalLM` to use differential attention.
Args:
model: Base LLaMA model to convert.
config: Configuration for differential attention. If `None`, created from
base model config.
Returns:
Converted model with differential attention.
Raises:
ValueError: If number of heads is not even when using `split_heads` mode.
"""
if config is None:
config = LlamaDifferentialConfig(**model.config.__dict__)
# Validate head counts if using split heads mode
if config.split_heads:
if config.num_attention_heads % 2 != 0:
raise ValueError(
f"Number of attention heads ({config.num_attention_heads}) must be even "
"when using split_heads=True"
)
if config.num_key_value_heads % 2 != 0:
raise ValueError(
f"Number of key/value heads ({config.num_key_value_heads}) must be even "
"when using split_heads=True"
)
new_model = cls(config)
new_model.model = LlamaDifferentialModel.from_llama(model.model, config)
new_model.lm_head.load_state_dict(model.lm_head.state_dict())
return new_model
def register_diff_attn() -> None:
"""
Register differential attention components with the transformers library.
This function registers the differential attention configurations and model classes
with the Auto* classes from `transformers`, making them available through the
standard model loading pipeline.
"""
# Register configs
AutoConfig.register("llama-differential", LlamaDifferentialConfig)
# Register models
AutoModel.register(LlamaDifferentialConfig, LlamaDifferentialModel)
AutoModelForCausalLM.register(LlamaDifferentialConfig, LlamaDifferentialForCausalLM)
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention
LLAMA_ATTENTION_CLASSES[
"differential_flash_attention_2"
] = LlamaDifferentialFlashAttention2
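Once `register_diff_attn()` has been called, a checkpoint whose config carries `model_type: "llama-differential"` resolves to these classes through the normal Auto* loading path. A minimal usage sketch (the checkpoint path is a placeholder, and it assumes this module is importable):

from transformers import AutoConfig, AutoModelForCausalLM

register_diff_attn()
# "my-diff-llama" stands in for a checkpoint saved with model_type "llama-differential"
config = AutoConfig.from_pretrained("my-diff-llama")
model = AutoModelForCausalLM.from_pretrained("my-diff-llama", config=config)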

View File

@@ -22,13 +22,6 @@ import inspect
import logging
import sys
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from axolotl.integrations.base import BasePlugin
from ...utils.distributed import zero_only
@@ -46,6 +39,13 @@ class LigerPlugin(BasePlugin):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)

View File

@@ -0,0 +1,21 @@
"""Definition of RALA plugin."""
import logging
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.rala.auto.llama.modeling_rala import register_rala_model
LOG = logging.getLogger(__name__)
class RalaPlugin(BasePlugin):
"""
Plugin for Rala integration with Axolotl.
"""
def get_input_args(self):
return "axolotl.integrations.rala.args.RalaArgs"
def register(self):
LOG.info("Registering RALA model with AutoConfig & AutoModel")
register_rala_model()

View File

@@ -0,0 +1,14 @@
"""Module for handling RALA input arguments."""
import logging
from typing import Optional
from pydantic import BaseModel
LOG = logging.getLogger(__name__)
class RalaArgs(BaseModel):
"""Input args for RALA."""
rala_attention: Optional[bool] = None

View File

@@ -0,0 +1,13 @@
"""
Rala config class
"""
from transformers import LlamaConfig
class LlamaRalaConfig(LlamaConfig):
"""
Configuration for LlamaRala model
"""
model_type = "llama-rala"
softmax_every: int = 6 # every N-th layer applies softmax

View File

@@ -0,0 +1,623 @@
# Copyright 2024-2025 Axolotl AI. All rights reserved.
#
# This software may be used and distributed according to
# the terms of the Apache License 2.0 (the "License");
# you may not use this file except in compliance with the License.
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations under
# the License.
"""
Custom modeling code for RALA Llama
"""
from typing import List, Optional, Tuple, Union, Unpack
import torch
import torch.nn.functional as F
from torch import nn
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
Cache,
GenerationMixin,
LlamaModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import (
LLAMA_ATTENTION_CLASSES,
KwargsForCausalLM,
LlamaDynamicNTKScalingRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
LlamaMLP,
LlamaPreTrainedModel,
LlamaRMSNorm,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
repeat_kv,
)
from .configuration_rala import LlamaRalaConfig
def kappa(x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
"""
The paper uses κ(x) = ELU(x) + 1.
x is assumed to be [batch, n_heads, seq_len, head_dim].
"""
return F.elu(x) + 1
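For intuition (this snippet is not part of the file), κ maps any real input to a strictly positive value, which keeps the RALA attention weights non-negative without a per-query softmax:

import torch
import torch.nn.functional as F
x = torch.tensor([-3.0, 0.0, 2.0])
print(F.elu(x) + 1)  # tensor([0.0498, 1.0000, 3.0000]) -- strictly positive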
class LlamaRALAAttention(nn.Module):
"""
LlamaAttention replaced with Rank-Augmented Linear Attention (RALA).
Adapted from the standard LlamaAttention for demonstration.
**Not** a fully drop-in replacement if you need caching/TP.
"""
def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
# Same Q, K, V, output projections
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = nn.Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=config.attention_bias,
)
self.v_proj = nn.Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=config.attention_bias,
)
self.o_proj = nn.Linear(
self.hidden_size, self.hidden_size, bias=config.attention_bias
)
# We will preserve rope usage
self._init_rope()
# A simple φ-projection for RALA:
# The paper uses φ(x) as a linear transform or identity. We'll do a linear:
self.phi = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
def _init_rope(self):
# Standard Llama rope logic
if self.config.rope_scaling is None:
self.rotary_emb = LlamaRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "dynamic":
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False, # pylint: disable=unused-argument
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs, # pylint: disable=unused-argument
):
"""
RALA forward pass.
This version passes `past_key_value` through the standard cache update, but does
not implement a true incremental-decoding recurrence for linear attention
(linear attention caching is non-trivial).
"""
bsz, q_len, _ = hidden_states.size()
# Standard Q, K, V
query_states = self.q_proj(hidden_states) # [b, seq, n_heads*dim]
key_states = self.k_proj(hidden_states) # [b, seq, n_kv_heads*dim]
value_states = self.v_proj(hidden_states) # [b, seq, n_kv_heads*dim]
# Reshape to [b, n_heads, seq_len, head_dim]
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# Apply RoPE (rotary embeddings) just as in standard Llama
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
# 4. If we have a past_key_value (Cache object), let it update / append
if past_key_value is not None:
# This is the normal Llama pattern
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
# The .update() method returns updated (key_states, value_states)
# and typically updates internal buffers. It may also store `layer_idx` data.
key_states, value_states = past_key_value.update(
key_states, value_states, self.layer_idx, cache_kwargs
)
# If you still want to handle the repeated KV for multi-group setups:
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# Now we apply RALA.
# 1) Apply κ(.) to Q,K: shape [b, n_heads, seq_len, head_dim]
Q_kappa = kappa(query_states) # pylint: disable=invalid-name
K_kappa = kappa(key_states) # pylint: disable=invalid-name
# 2) Compute global query Q_g = average of Q_kappa across seq_len => [b, n_heads, head_dim]
# The paper denotes Q_g = (1/N) Σ_i Q_kappa_i
seq_len_float = float(q_len) # for scaling
Q_g = Q_kappa.mean( # pylint: disable=invalid-name
dim=2
) # [b, n_heads, head_dim]
# 3) Compute alpha_j for each token j in [0..seq_len-1]
# alpha_j = N * softmax( Q_g · K_kappa_j^T ), shape => [b, n_heads, seq_len]
# Dot product over head_dim
# K_kappa is [b, n_heads, seq_len, head_dim], Q_g is [b, n_heads, head_dim]
# We'll do an einsum or transpose to produce logits [b, n_heads, seq_len]
# Dot product across the last dimension (d_head), resulting in shape [b, n_heads, seq_len]
# logits = torch.einsum("bnh, bnsh -> bns", Q_g, K_kappa) # [b, n_heads, seq_len]
logits = (Q_g.unsqueeze(2) * K_kappa).sum(
dim=-1
) # -> [b, n_heads, seq_len] # identical to above but torch.compile should work
# 4) Incorporate causal or padding mask if provided.
# In standard Llama, attention_mask is broadcast as [b, 1, seq_len, seq_len] or similar.
# For RALA, we only do a single softmax over "j" dimension. We can add the mask to logits.
# Caution: This might not replicate strict causal linear attention. It's a best-effort approach.
if attention_mask is not None:
# Usually Llama's causal mask is [b, 1, q_len, kv_len] with 0 or -inf
# We want shape [b, n_heads, seq_len], so we can broadcast accordingly:
# e.g., attention_mask: [b, 1, q_len, seq_len]
# We pick the slice that corresponds to q_len vs. kv_len.
# Typically the last two dims are (q_len, kv_len). We want the kv_len dimension to be `seq_len`.
# We'll do something like:
if attention_mask.dim() == 4:
# attention_mask: [b, 1, q_len, kv_len]
# if q_len == kv_len, we can do attention_mask[:, :, :, :seq_len], then squeeze dims
mask_2d = attention_mask[:, 0, :, :q_len] # [b, q_len, seq_len]
# We only need [b, n_heads, seq_len], but the mask is indexed per query
# position i, while alpha_j here is computed once per key position j (per
# head), ignoring i. As an approximation, a token j is masked out if it is
# invalid for ANY query position i, i.e. we take the minimum of the mask
# over the q_len dimension:
mask_1d = torch.min(mask_2d, dim=1)[
0
] # [b, seq_len], picking the worst mask across query positions
# broadcast for n_heads
mask_1d = mask_1d.unsqueeze(1).expand(
-1, self.num_heads, -1
) # [b, n_heads, seq_len]
logits = logits + mask_1d
else:
# Possibly it's [b, seq_len]. Then we just broadcast to [b,n_heads,seq_len].
mask_1d = attention_mask # [b, seq_len]
mask_1d = mask_1d.unsqueeze(1).expand(-1, self.num_heads, -1)
logits = logits + mask_1d
alpha = F.softmax(logits, dim=-1) # [b, n_heads, seq_len]
# multiply by seq_len per the formula
alpha = alpha * seq_len_float
# 5) Construct the outer-sum: Σ_j alpha_j * (K_kappa_j^T V_j)
# The paper shows a d×d matrix formed per head.
# K_kappa: [b, n_heads, seq_len, head_dim], V: [b, n_heads, seq_len, head_dim]
# For each j, do outer product K_kappa_j (d×1) × V_j^T (1×d) => d×d
# Then multiply by alpha_j and sum over j.
# We'll do an einsum for that: [b,n_heads,seq_len,d] outer [b,n_heads,seq_len,d] => [b,n_heads,d,d]
# alpha: [b, n_heads, seq_len].
value_states_ = value_states # [b, n_heads, seq_len, head_dim]
outer_sum = torch.einsum("bns,bnsd,bnsf->bndf", alpha, K_kappa, value_states_)
# Explanation of the einsum "bns,bnsd,bnsf->bndf":
# - alpha has shape [b, n_heads, seq_len]                  -> 'bns'
# - K_kappa has shape [b, n_heads, seq_len, head_dim]      -> 'bnsd'
# - value_states has shape [b, n_heads, seq_len, head_dim] -> 'bnsf'
# For each token j it forms the outer product K_kappa_j ⊗ V_j (a d×d matrix
# per head), scales it by alpha_j, and sums over j, giving [b, n_heads, d, d].
# 6) For each token i, Y_i = φ(X_i) ∘ [ κ(Q_i) × outer_sum ]
# Here Q_kappa has shape [b, n_heads, s, head_dim] and outer_sum has shape
# [b, n_heads, head_dim, head_dim]. A single batched einsum gives
# result_attn = Q_kappa × outer_sum => [b, n_heads, s, head_dim], which is
# then multiplied elementwise by φ(X) reshaped from [b, seq_len, d_model]
# to [b, n_heads, seq_len, head_dim].
# first, compute φ(X):
# X is the original hidden_states: [b, seq_len, d_model]
X_phi = self.phi( # pylint: disable=invalid-name
hidden_states
) # [b, seq_len, d_model]
X_phi = X_phi.view( # pylint: disable=invalid-name
bsz, q_len, self.num_heads, self.head_dim
) # [b, s, n, d]
X_phi = X_phi.transpose(1, 2) # [b, n, s, d] # pylint: disable=invalid-name
# Now for each i in [0..q_len-1], we do a matrix multiply:
# result_attn_i = Q_kappa_i [b,n,s,d] × outer_sum [b,n,d,d] => we want [b,n,s,d].
# We'll do:
result_attn = torch.einsum("bnsd,bndf->bnsf", Q_kappa, outer_sum) # [b,n,s,d]
# Then elementwise multiply by φ(X_i):
context_layer = X_phi * result_attn # [b,n,s,d]
# Finally, reorder to [b, s, n, d] -> [b, s, n*d]
context_layer = context_layer.transpose(1, 2).contiguous() # [b, s, n, d]
context_layer = context_layer.view(bsz, q_len, self.hidden_size)
# One last linear projection:
attn_output = self.o_proj(context_layer)
if output_attentions:
# alpha => [b, n_heads, (past_len + q_len)]
attn_weights = alpha
else:
attn_weights = None
# Return 3-tuple: (attn_output, attn_weights, past_key_value)
return attn_output, attn_weights, past_key_value
class LlamaRalaDecoderLayer(nn.Module):
"""
LlamaDecoderLayer with RALA support
"""
def __init__(self, config: LlamaRalaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
if LlamaRalaDecoderLayer.is_layer_idx_softmax(
config.num_hidden_layers, layer_idx, config.softmax_every
):
self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](
config=config, layer_idx=layer_idx
)
# self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
else:
self.self_attn = LlamaRALAAttention(config=config, layer_idx=layer_idx)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
@classmethod
def is_layer_idx_softmax(
cls, num_hidden_layers: int, layer_idx: int, softmax_every: int
) -> bool:
inner_layers = num_hidden_layers - 2
if 1 + softmax_every * (inner_layers // softmax_every) == inner_layers:
softmax_start_idx = 1
elif 1 + softmax_every * (inner_layers // softmax_every) > inner_layers:
layer_group_size = 1 + softmax_every * ((inner_layers // softmax_every) - 1)
softmax_start_idx = 1 + (inner_layers - layer_group_size) // 2
elif 1 + softmax_every * (inner_layers // softmax_every) < inner_layers:
layer_group_size = 1 + softmax_every * (inner_layers // softmax_every)
softmax_start_idx = 1 + (inner_layers - layer_group_size) // 2
softmax_layers = set(range(softmax_start_idx, num_hidden_layers, softmax_every))
softmax_layers.add(0)
softmax_layers.add(num_hidden_layers - 1)
return layer_idx in softmax_layers
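As a worked example (illustrative, computed from the logic above): with `num_hidden_layers=16` and `softmax_every=6`, the inner layer count is 14, the start index works out to 1, and the layers that keep standard softmax attention are {0, 1, 7, 13, 15}; every other layer receives `LlamaRALAAttention`:

softmax_layers = [i for i in range(16) if LlamaRalaDecoderLayer.is_layer_idx_softmax(16, i, 6)]  # -> [0, 1, 7, 13, 15]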
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[
Tuple[torch.Tensor, torch.Tensor]
] = None, # will become mandatory in v4.46
**kwargs,
) -> Tuple[
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*):
attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1,
query_sequence_length, key_sequence_length)` if default attention is used.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
Indices depicting the position of the input sequence tokens in the sequence
position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
with `head_dim` being the embedding dimension of each attention head.
kwargs (`dict`, *optional*):
Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code
into the model
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
position_embeddings=position_embeddings,
**kwargs,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,) # type: ignore
if use_cache:
outputs += (present_key_value,) # type: ignore
return outputs # type: ignore
class LlamaRalaModel(LlamaModel):
"""
LlamaModel with RALA support
"""
config_class = LlamaRalaConfig
def __init__(self, config: LlamaRalaConfig):
LlamaPreTrainedModel.__init__(self, config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, self.padding_idx
)
self.layers = nn.ModuleList(
[
LlamaRalaDecoderLayer(config, layer_idx)
for layer_idx in range(config.num_hidden_layers)
]
)
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.rotary_emb = LlamaRotaryEmbedding(config=config)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
class LlamaRalaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
"""
LlamaForCausalLM with RALA support
"""
config_class = LlamaRalaConfig
_no_split_modules = ["LlamaRalaDecoderLayer"]
_tied_weights_keys = ["lm_head.weight"]
_tp_plan = {"lm_head": "colwise_rep"}
def __init__(self, config):
super().__init__(config)
self.model = LlamaRalaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
def get_input_embeddings(self):
return self.model.embed_tokens
def set_input_embeddings(self, value):
self.model.embed_tokens = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def set_decoder(self, decoder):
self.model = decoder
def get_decoder(self):
return self.model
def forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
num_logits_to_keep: int = 0,
**kwargs: Unpack[KwargsForCausalLM], # type: ignore
) -> Union[Tuple, CausalLMOutputWithPast]:
r"""
Args:
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
"Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
```"""
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(
logits=logits,
labels=labels,
vocab_size=self.config.vocab_size,
**kwargs,
)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def register_rala_model() -> None:
"""
Register RALA attention components with the transformers library.
This function registers the RALA configuration and model classes with the
Auto* classes from `transformers`, making them available through the standard
model loading pipeline.
"""
# Register configs
AutoConfig.register("llama-rala", LlamaRalaConfig)
# Register models
AutoModel.register(LlamaRalaConfig, LlamaRalaModel)
AutoModelForCausalLM.register(LlamaRalaConfig, LlamaRalaForCausalLM)
LLAMA_ATTENTION_CLASSES["rala"] = LlamaRALAAttention

View File

@@ -0,0 +1,106 @@
"""
Conversion of Llama models to use RALA attention.
"""
import logging
from torch import nn
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import LlamaAttention
from axolotl.integrations.rala.auto.llama.modeling_rala import (
LlamaRALAAttention,
LlamaRalaDecoderLayer,
)
logger = logging.getLogger(__name__)
ATTENTION_MAPPING = {
LlamaAttention: LlamaRALAAttention,
}
def copy_attention_weights(
old_attn,
new_attn,
zero_init: bool = False,
) -> None:
"""
Copy weights from old attention layer to new RALA layer.
Copies q, k, v, o
"""
new_attn.q_proj.weight.data.copy_(old_attn.q_proj.weight.data)
new_attn.k_proj.weight.data.copy_(old_attn.k_proj.weight.data)
new_attn.v_proj.weight.data.copy_(old_attn.v_proj.weight.data)
new_attn.o_proj.weight.data.copy_(old_attn.o_proj.weight.data)
# Initialize the phi projection: zeros when zero_init is set, random normal otherwise
if zero_init:
nn.init.zeros_(new_attn.phi.weight)
else:
nn.init.normal_(new_attn.phi.weight)
if new_attn.phi.bias is not None:
nn.init.normal_(new_attn.phi.bias)
logger.debug(
"Copied positive attention weights from %s to %s",
type(old_attn).__name__,
type(new_attn).__name__,
)
def convert_to_rala(
model: PreTrainedModel, zero_init: bool = False, softmax_every_n: int = 6
) -> PreTrainedModel:
"""Convert a pre-trained model's attention layers to differential attention"""
layer_idx = 0
def convert_module(module, softmax_every, num_hidden_layers):
nonlocal layer_idx
# Iterate through module children, convert any attention layers to RALA attention
for name, child in module.named_children():
if isinstance(child, tuple(ATTENTION_MAPPING.keys())):
decoder_layer_idx = child.layer_idx
if LlamaRalaDecoderLayer.is_layer_idx_softmax(
num_hidden_layers, decoder_layer_idx, softmax_every
):
continue
# Choose the appropriate RALA attention class
# pylint: disable=duplicate-code
attention_class = ATTENTION_MAPPING[type(child)]
layer_type = type(child).__name__
logger.info(
f"Converting attention layer {decoder_layer_idx}: {layer_type} to {attention_class.__name__}"
)
# Create new RALA attention layer
new_attention = attention_class(
config=module.config if hasattr(module, "config") else model.config,
layer_idx=decoder_layer_idx,
)
# Copy weights from old attention to new attention
new_attention.to(child.q_proj.weight.device)
copy_attention_weights(child, new_attention, zero_init=zero_init)
# Replace the layer
setattr(module, name, new_attention)
layer_idx += 1
elif len(list(child.children())) > 0:
convert_module(child, softmax_every, num_hidden_layers)
model.config.softmax_every = softmax_every_n
convert_module(model, softmax_every_n, model.config.num_hidden_layers)
logger.info(f"Converted {layer_idx} attention layers to RALA attention")
model.config.architectures = [
"LlamaRalaForCausalLM",
]
model.config.model_type = "llama-rala"
# model.config.auto_map = {
# "AutoConfig": "llama.configuration_rala.LlamaRalaConfig",
# "AutoModel": "llama.modeling_rala.LlamaRalaModel",
# "AutoModelForCausalLM": "llama.modeling_rala.LlamaRalaForCausalLM",
# }
return model
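A minimal end-to-end conversion sketch built on the function above (the checkpoint name and output path are placeholders; any Llama checkpoint works the same way):

from transformers import AutoModelForCausalLM
from axolotl.integrations.rala.auto.llama.modeling_rala import register_rala_model

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = convert_to_rala(model, zero_init=True, softmax_every_n=6)
register_rala_model()  # so the converted checkpoint can be reloaded via the Auto* classes
model.save_pretrained("./llama-rala-converted")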

View File

@@ -6,7 +6,7 @@ import logging
from transformers import Trainer
from axolotl.monkeypatch.unsloth_ import detab_code
from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")

View File

@@ -8,7 +8,7 @@ import logging
from transformers import LlamaForCausalLM, Trainer
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from axolotl.monkeypatch.unsloth_ import detab_code
from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")

View File

@@ -1,9 +1,7 @@
"""module for patching with unsloth optimizations"""
import inspect
import re
import types
from typing import Tuple
import torch
from accelerate.logging import get_logger
@@ -11,6 +9,8 @@ from peft import PeftModelForCausalLM
from torch import nn
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
from axolotl.monkeypatch.utils import detab_code
LOG = get_logger("axolotl.monkeypatch.unsloth")
ORIGINAL_QKV_CODE = """
@@ -93,15 +93,6 @@ def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
raise ValueError("Unsupported model type")
def detab_code(code: str) -> Tuple[str, str]:
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
return code, spaces
self_attn_lora_patched = False # pylint: disable=invalid-name

View File

@@ -1,7 +1,8 @@
"""
Shared utils for the monkeypatches
"""
from typing import Optional
import re
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
@@ -223,3 +224,12 @@ def patched_prepare_4d_causal_attention_mask_for_sdpa(
mask_2d_to_4d(attention_mask, dtype=dtype),
*args,
)
def detab_code(code: str) -> Tuple[str, str]:
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
return code, spaces
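A quick usage sketch (not from the source) of what detab_code returns for an indented snippet:

snippet = "    def f(x):\n        return x + 1\n"
code, indent = detab_code(snippet)
assert indent == "    "              # the leading whitespace that was stripped
assert code.startswith("def f(x):")  # every line shifted left by that amount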

View File

@@ -1,5 +1,6 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import inspect
import os
import signal
import sys
@@ -126,7 +127,20 @@ def train(
)
if cfg.fix_untrained_tokens:
fix_untrained_tokens(model, tokenizer, train_dataset)
# check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens)
# if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list
):
fix_untrained_tokens(
model,
tokenizer,
train_dataset,
token_ids_to_fix=cfg.fix_untrained_tokens,
)
else:
fix_untrained_tokens(model, tokenizer, train_dataset)
if cfg.local_rank == 0:
model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
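The signature check above is a small backward-compatibility pattern: only pass the new keyword when the installed fix_untrained_tokens actually accepts it. A generic sketch of the same idea (helper name is hypothetical):

import inspect

def call_with_supported_kwargs(func, *args, **maybe_kwargs):
    # Forward a keyword argument only if `func` declares it, so older
    # implementations keep working with the plain positional call.
    sig = inspect.signature(func)
    kwargs = {k: v for k, v in maybe_kwargs.items() if k in sig.parameters}
    return func(*args, **kwargs)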

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import gc
import logging
import math
import os
@@ -842,3 +843,17 @@ class SaveModelCallback(TrainerCallback):
):
control.should_save = True
return control
class GCCallback(TrainerCallback):
"""Callback to garbage collect torch cache"""
def __init__(self, gc_steps=None):
self.gc_steps = gc_steps
def on_step_end(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
if state.global_step % self.gc_steps == 0:
torch.cuda.empty_cache()
gc.collect()
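Note that gc_steps defaults to None, so the modulo in on_step_end would raise a TypeError if the callback were ever constructed without a value; presumably it is only registered when cfg.gc_steps is set. A defensively guarded variant, as a sketch (assuming the same imports, gc / torch / TrainerCallback, as the module above):

class SafeGCCallback(TrainerCallback):
    """Empty the CUDA cache and run gc every `gc_steps` steps, when configured."""

    def __init__(self, gc_steps=None):
        self.gc_steps = gc_steps

    def on_step_end(self, args, state, control, **kwargs):  # pylint: disable=unused-argument
        if self.gc_steps and state.global_step % self.gc_steps == 0:
            torch.cuda.empty_cache()
            gc.collect()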

View File

@@ -0,0 +1,234 @@
"""
Monitor and log differential attention components during training.
This module provides a callback for tracking the behavior of differential attention
mechanisms, including lambda parameters and attention statistics.
"""
from typing import Any
import torch
import wandb
from torch import nn
from transformers import TrainerCallback
from axolotl.utils.distributed import is_main_process
class DifferentialAttentionMonitorCallback(TrainerCallback):
"""
Callback to monitor differential attention components and lambda parameters.
This callback tracks attention statistics across all layers and provides detailed
monitoring for a specified number of layers evenly spaced through the model.
"""
def __init__(
self,
log_every: int = 250,
num_monitor_layers: int = 3,
warmup_steps: int | None = None,
):
"""
Initialize the differential attention monitor.
Args:
log_every: Number of steps between logging events.
num_monitor_layers: Number of individual layers to monitor in detail.
warmup_steps: Optional parameter for negative attention component warmup.
"""
self.log_every = log_every
self.num_monitor_layers = num_monitor_layers
self.warmup_steps = warmup_steps
self.monitor_layers: list[int] | None = None # Will be set in on_train_begin
# pylint: disable=unused-argument
def on_train_begin(
self,
args: Any,
state: Any,
control: Any,
model: torch.nn.Module,
**kwargs,
) -> None:
"""
Set up layer monitoring at the start of training.
Args:
args: Training arguments.
state: Training state.
control: Training control object.
model: The model being trained.
**kwargs: Additional arguments passed by the trainer.
"""
if is_main_process():
num_layers = len(model.model.layers)
self.num_monitor_layers = min(self.num_monitor_layers, num_layers)
stride = (
(num_layers - 1) / (self.num_monitor_layers - 1)
if self.num_monitor_layers > 1
else 0
)
self.monitor_layers = [
round(i * stride) for i in range(self.num_monitor_layers)
]
print(f"Monitoring layers {self.monitor_layers} in detail")
# pylint: disable=unused-argument
def on_step_end(
self, args: Any, state: Any, control: Any, model: torch.nn.Module, **kwargs
) -> None:
"""
Log attention metrics at the end of each step.
Collects and logs:
- Lambda parameter norms and values.
- Attention statistics (mean and std).
- Both per-layer and aggregate metrics.
Args:
args: Training arguments.
state: Training state.
control: Training control object.
model: The model being trained.
**kwargs: Additional arguments passed by the trainer.
"""
if not is_main_process() or state.global_step % self.log_every != 0:
return
assert self.monitor_layers is not None
# Aggregate stats across all layers
all_q1_norms = []
all_q2_norms = []
all_k1_norms = []
all_k2_norms = []
all_lambda1 = []
all_lambda2 = []
all_lambda_full = []
metrics = {}
for layer_idx, layer in enumerate(model.model.layers):
attn = layer.self_attn
# Collect stats for aggregation
all_q1_norms.append(attn.lambda_q1.norm().item())
all_q2_norms.append(attn.lambda_q2.norm().item())
all_k1_norms.append(attn.lambda_k1.norm().item())
all_k2_norms.append(attn.lambda_k2.norm().item())
lambda1 = torch.exp(torch.sum(attn.lambda_q1 * attn.lambda_k1)).item()
lambda2 = torch.exp(torch.sum(attn.lambda_q2 * attn.lambda_k2)).item()
all_lambda1.append(lambda1)
all_lambda2.append(lambda2)
all_lambda_full.append(attn.lambda_full)
# Log detailed metrics for monitored layers
if layer_idx in self.monitor_layers:
metrics.update(
{
f"layer_{layer_idx}/lambda_q1_norm": attn.lambda_q1.norm().item(),
f"layer_{layer_idx}/lambda_k1_norm": attn.lambda_k1.norm().item(),
f"layer_{layer_idx}/lambda_q2_norm": attn.lambda_q2.norm().item(),
f"layer_{layer_idx}/lambda_k2_norm": attn.lambda_k2.norm().item(),
f"layer_{layer_idx}/lambda1": lambda1,
f"layer_{layer_idx}/lambda2": lambda2,
f"layer_{layer_idx}/lambda_init": attn.lambda_init.item(),
f"layer_{layer_idx}/lambda_full": lambda1
- lambda2
+ attn.lambda_init.item(),
f"layer_{layer_idx}/attn1_mean": attn.attn1.mean().item(),
f"layer_{layer_idx}/attn2_mean": attn.attn2.mean().item(),
f"layer_{layer_idx}/attn1_std": attn.attn1.std().item(),
f"layer_{layer_idx}/attn2_std": attn.attn2.std().item(),
}
)
# Add aggregate metrics
metrics.update(
{
"aggregate/lambda_q1_norm_mean": torch.tensor(all_q1_norms)
.mean()
.item(),
"aggregate/lambda_q1_norm_std": torch.tensor(all_q1_norms).std().item(),
"aggregate/lambda_q2_norm_mean": torch.tensor(all_q2_norms)
.mean()
.item(),
"aggregate/lambda_q2_norm_std": torch.tensor(all_q2_norms).std().item(),
"aggregate/lambda_k1_norm_mean": torch.tensor(all_k1_norms)
.mean()
.item(),
"aggregate/lambda_k1_norm_std": torch.tensor(all_k1_norms).std().item(),
"aggregate/lambda_k2_norm_mean": torch.tensor(all_k2_norms)
.mean()
.item(),
"aggregate/lambda_k2_norm_std": torch.tensor(all_k2_norms).std().item(),
"aggregate/lambda1_mean": torch.tensor(all_lambda1).mean().item(),
"aggregate/lambda1_std": torch.tensor(all_lambda1).std().item(),
"aggregate/lambda2_mean": torch.tensor(all_lambda2).mean().item(),
"aggregate/lambda2_std": torch.tensor(all_lambda2).std().item(),
"aggregate/lambda_full_mean": torch.tensor(all_lambda_full)
.mean()
.item(),
"aggregate/lambda_full_std": torch.tensor(all_lambda_full).std().item(),
}
)
if self.warmup_steps:
metrics["aggregate/diff_attn_mix"] = attn.diff_attn_mix
wandb.log(metrics, step=state.global_step)
class DifferentialAttentionMixingCallback(TrainerCallback):
"""
Callback to gradually increase the weight of negative attention components during
training.
"""
def __init__(self, warmup_steps: int):
"""
Args:
warmup_steps: Number of steps to linearly increase negative attention
weight from 0 to 1. If `None`, negative attention has full weight from
start.
"""
self.warmup_steps = warmup_steps
self.diff_attention_layers: list[nn.Module] | None = None
# pylint: disable=unused-argument
def on_train_begin(
self,
args: Any,
state: Any,
control: Any,
model: torch.nn.Module,
**kwargs,
) -> None:
"""Cache the differential attention layers at the start of training."""
if model is not None:
# Get the actual model if it's wrapped
if hasattr(model, "module"):
model = model.module
# Cache all differential attention layers
self.diff_attention_layers = [
module for module in model.modules() if hasattr(module, "diff_attn_mix")
]
def on_step_begin(
self,
args: Any,
state: Any,
control: Any,
model: torch.nn.Module = None,
**kwargs,
) -> None:
if self.diff_attention_layers and self.warmup_steps:
# Calculate mixing parameter (0 to 1)
mix = min(1.0, state.global_step / self.warmup_steps)
# Update cached layers
for layer in self.diff_attention_layers:
layer.diff_attn_mix = mix
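Per the monitor above, each layer's effective lambda follows lambda_full = exp(sum(lambda_q1 * lambda_k1)) - exp(sum(lambda_q2 * lambda_k2)) + lambda_init, which is exactly the per-layer value it logs. A minimal sketch of attaching both callbacks to a Hugging Face Trainer (model and dataset names are placeholders; axolotl normally wires these in through its own trainer builder):

from transformers import Trainer, TrainingArguments

callbacks = [
    DifferentialAttentionMonitorCallback(log_every=100, num_monitor_layers=3, warmup_steps=500),
    DifferentialAttentionMixingCallback(warmup_steps=500),
]
trainer = Trainer(
    model=model,                  # placeholder: a converted differential-attention model
    args=TrainingArguments(output_dir="out", max_steps=1000),
    train_dataset=train_dataset,  # placeholder dataset
    callbacks=callbacks,
)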

View File

@@ -43,7 +43,7 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
getattr, self.layers_attribute.split("."), self.trainer.model
)
LOG.info(
f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps"
f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers * 100 / len(layers)}%) every {self.step_interval} steps"
)
def freeze_all_layers(self):

View File

@@ -128,6 +128,7 @@ class PretrainingDataset(BaseModel):
text_column: Optional[str] = "text"
type: Optional[str] = "pretrain"
trust_remote_code: Optional[bool] = False
data_files: Optional[str] = None
class UserDefinedPrompterType(BaseModel):
@@ -666,6 +667,8 @@ class AxolotlInputConfig(
loss_watchdog_threshold: Optional[float] = None
loss_watchdog_patience: Optional[int] = None
gc_steps: Optional[int] = None
bf16: Optional[Union[Literal["auto"], bool]] = "auto"
fp16: Optional[bool] = None
bfloat16: Optional[bool] = None # for non-AMP cases
@@ -695,6 +698,12 @@ class AxolotlInputConfig(
pad_to_sequence_len: Optional[bool] = None
curriculum_sampling: Optional[bool] = None
multipack_real_batches: Optional[bool] = None
pretraining_sample_concatenation: Optional[bool] = Field(
default=None,
json_schema_extra={
"description": "whether to soft pack/concatenate samples during pretraining",
},
)
batch_flattening: Optional[Union[Literal["auto"], bool]] = None
@@ -792,7 +801,7 @@ class AxolotlInputConfig(
chat_template_jinja: Optional[str] = None
default_system_message: Optional[str] = None
fix_untrained_tokens: Optional[bool] = None
fix_untrained_tokens: Optional[Union[int, List[int]]] = None
# INTERNALS - document for now, generally not set externally
is_preprocess: Optional[bool] = None
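Taken together, the schema additions in this file allow configs along these lines (a sketch; values and dataset name are purely illustrative):

from axolotl.utils.dict import DictDefault

cfg = DictDefault(
    {
        "pretraining_dataset": [
            {"path": "my-org/streaming-corpus", "data_files": "shard-0.jsonl"},  # hypothetical dataset
        ],
        "pretraining_sample_concatenation": False,  # keep one sample per sequence instead of soft packing
        "gc_steps": 100,                            # empty the CUDA cache / gc.collect() every 100 steps
        "fix_untrained_tokens": [128002, 128003],   # now accepts specific token ids, not just a bool
    }
)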

View File

@@ -18,7 +18,10 @@ LOG = logging.getLogger("axolotl")
def encode_pretraining(
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
tokenizer: PreTrainedTokenizerBase,
max_tokens: int,
examples: Dict[str, List],
concatenate: bool = True,
) -> Dict[str, List]:
res = tokenizer(
examples["text"],
@@ -28,8 +31,17 @@ def encode_pretraining(
)
# Convert to PyTorch tensors
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
targets = [torch.tensor(seq) for seq in res["input_ids"]]
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
if not concatenate:
return {
"input_ids": [seq.tolist() for seq in input_ids],
"labels": [seq.tolist() for seq in targets],
"attention_mask": [seq.tolist() for seq in attention_mask],
}
new_input_ids = []
new_labels = []
new_attention_mask = []
# Append EOS and PAD tokens to input_ids, and correct attention_mask
for i, _ in enumerate(input_ids):
@@ -40,22 +52,34 @@ def encode_pretraining(
),
dim=0,
)
targets[i] = torch.cat(
(
targets[i],
torch.tensor([tokenizer.eos_token_id, -100]),
),
dim=0,
)
attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
# Concatenate tokens so that their lengths are less than max_tokens
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_labels = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
for ids, mask in zip(input_ids, attention_mask):
for ids, labels, mask in zip(input_ids, targets, attention_mask):
if buffer_input_ids.numel() == max_tokens:
new_input_ids.append(buffer_input_ids)
new_labels.append(buffer_labels)
new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_labels = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
else:
buffer_input_ids = torch.cat(
@@ -69,6 +93,17 @@ def encode_pretraining(
),
dim=0,
)
buffer_labels = torch.cat(
(
buffer_labels,
torch.full(
(max_tokens - buffer_labels.numel(),),
-100,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat(
(
buffer_attention_mask,
@@ -81,11 +116,14 @@ def encode_pretraining(
dim=0,
)
new_input_ids.append(buffer_input_ids)
new_labels.append(buffer_labels)
new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_labels = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
if buffer_input_ids.numel() > 0: # for any leftover tokens
@@ -101,6 +139,17 @@ def encode_pretraining(
),
dim=0,
)
buffer_labels = torch.cat(
(
buffer_labels,
torch.full(
(max_tokens - buffer_labels.numel(),),
-100,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat(
(
buffer_attention_mask,
@@ -113,11 +162,12 @@ def encode_pretraining(
dim=0,
)
new_input_ids.append(buffer_input_ids)
new_labels.append(buffer_labels)
new_attention_mask.append(buffer_attention_mask)
ret = {
"input_ids": [seq.tolist() for seq in new_input_ids],
"labels": [seq.tolist() for seq in new_input_ids],
"labels": [seq.tolist() for seq in new_labels],
"attention_mask": [seq.tolist() for seq in new_attention_mask],
}
@@ -155,6 +205,10 @@ def wrap_pretraining_dataset(
)
# set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1
elif cfg.pretraining_sample_concatenation is False:
encode = functools.partial(
encode_pretraining, tokenizer, max_tokens, concatenate=False
)
else:
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
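A usage sketch of the two paths (tokenizer/model name is only an example):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
tok.pad_token = tok.eos_token  # ensure a pad token exists
examples = {"text": ["first document", "a second, somewhat longer document"]}

packed = encode_pretraining(tok, max_tokens=64, examples=examples)                          # soft-packed buffers
per_sample = encode_pretraining(tok, max_tokens=64, examples=examples, concatenate=False)   # one row per input

assert len(per_sample["input_ids"]) == len(examples["text"])

With concatenation enabled, label buffers are padded with -100 so the packed padding positions are ignored by the loss, as the torch.full calls above show.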

View File

@@ -3,7 +3,7 @@
import functools
import logging
from pathlib import Path
from typing import List, Optional, Tuple, Union
from typing import List, Tuple, Union
from datasets import (
Dataset,
@@ -12,8 +12,6 @@ from datasets import (
load_dataset,
load_from_disk,
)
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HFValidationError
from transformers import PreTrainedTokenizerBase
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
@@ -42,6 +40,7 @@ from axolotl.prompters import (
UnsupportedPrompter,
)
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.shared import load_dataset_w_config
from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
md5,
@@ -85,9 +84,11 @@ def prepare_dataset(cfg, tokenizer, processor=None):
processor=processor,
)
else:
# Load streaming dataset if pretraining_dataset is given
path = cfg.pretraining_dataset
split = "train"
name = None
data_files = None
if isinstance(cfg.pretraining_dataset, list) and isinstance(
cfg.pretraining_dataset[0], dict
):
@@ -96,6 +97,8 @@ def prepare_dataset(cfg, tokenizer, processor=None):
if "split" in cfg.pretraining_dataset[0]:
split = cfg.pretraining_dataset[0]["split"]
data_files = cfg.pretraining_dataset[0].get("data_files")
ds_wrapper_partial = functools.partial(
get_dataset_wrapper,
cfg.pretraining_dataset[0],
@@ -105,7 +108,9 @@ def prepare_dataset(cfg, tokenizer, processor=None):
)
train_dataset = wrap_pretraining_dataset(
load_dataset(path, streaming=True, split=split, name=name),
load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
),
tokenizer,
cfg,
ds_wrapper_partial,
@@ -116,7 +121,18 @@ def prepare_dataset(cfg, tokenizer, processor=None):
)
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch")
# Load eval dataset (non-streaming) if specified
eval_dataset = None
if cfg.test_datasets:
_, eval_dataset, _ = load_prepare_datasets(
tokenizer,
cfg,
DEFAULT_DATASET_PREPARED_PATH,
split="test",
processor=processor,
)
if cfg.dataset_exact_deduplication:
LOG.info("Deduplication not available for pretrained datasets")
@@ -243,195 +259,9 @@ def load_tokenized_prepared_datasets(
# pylint: disable=invalid-name
for config_dataset in for_d_in_datasets(cfg_datasets):
ds: Optional[Union[Dataset, DatasetDict]] = None
ds_from_hub = False
ds_trust_remote_code = config_dataset.trust_remote_code
try:
# this is just a basic check to see if the path is a
# valid HF dataset that's loadable
load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=True,
token=use_auth_token,
revision=config_dataset.revision,
trust_remote_code=ds_trust_remote_code,
)
ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
pass
ds_from_cloud = False
storage_options = {}
remote_file_system = None
if config_dataset.path.startswith("s3://"):
try:
import aiobotocore.session # type: ignore
import s3fs # type: ignore
except ImportError as exc:
raise ImportError(
"s3:// paths require aiobotocore and s3fs to be installed"
) from exc
# Takes credentials from ~/.aws/credentials for default profile
s3_session = aiobotocore.session.AioSession(profile="default")
storage_options = {"session": s3_session}
remote_file_system = s3fs.S3FileSystem(**storage_options)
elif config_dataset.path.startswith(
"gs://"
) or config_dataset.path.startswith("gcs://"):
try:
import gcsfs # type: ignore
except ImportError as exc:
raise ImportError(
"gs:// or gcs:// paths require gcsfs to be installed"
) from exc
# gcsfs will use default credentials from the environment else anon
# https://gcsfs.readthedocs.io/en/latest/#credentials
storage_options = {"token": None}
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
# TODO: Figure out how to get auth creds passed
# elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
# try:
# import adlfs
# except ImportError as exc:
# raise ImportError(
# "adl:// or abfs:// paths require adlfs to be installed"
# ) from exc
# # Gen 1
# storage_options = {
# "tenant_id": TENANT_ID,
# "client_id": CLIENT_ID,
# "client_secret": CLIENT_SECRET,
# }
# # Gen 2
# storage_options = {
# "account_name": ACCOUNT_NAME,
# "account_key": ACCOUNT_KEY,
# }
# remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
try:
if remote_file_system and remote_file_system.exists(
config_dataset.path
):
ds_from_cloud = True
except (FileNotFoundError, ConnectionError):
pass
# prefer local dataset, even if hub exists
local_path = Path(config_dataset.path)
if local_path.exists():
if local_path.is_dir():
if config_dataset.data_files:
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.data_files,
streaming=False,
split=None,
)
else:
try:
ds = load_from_disk(config_dataset.path)
except FileNotFoundError:
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=False,
split=None,
)
elif local_path.is_file():
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
)
else:
raise ValueError(
"unhandled dataset load: local path exists, but is neither a directory or a file"
)
elif ds_from_hub:
load_ds_kwargs = {}
if config_dataset.split:
load_ds_kwargs["split"] = config_dataset.split
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=False,
data_files=config_dataset.data_files,
token=use_auth_token,
revision=config_dataset.revision,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
elif ds_from_cloud and remote_file_system:
if remote_file_system.isdir(config_dataset.path):
ds = load_from_disk(
config_dataset.path,
storage_options=storage_options,
)
elif remote_file_system.isfile(config_dataset.path):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
)
elif config_dataset.path.startswith("https://"):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
)
else:
if isinstance(config_dataset.data_files, str):
fp = hf_hub_download(
repo_id=config_dataset.path,
repo_type="dataset",
filename=config_dataset.data_files,
revision=config_dataset.revision,
)
elif isinstance(config_dataset.data_files, list):
fp = []
for file in config_dataset.data_files:
fp.append(
hf_hub_download(
repo_id=config_dataset.path,
repo_type="dataset",
filename=file,
revision=config_dataset.revision,
)
)
else:
raise ValueError(
"data_files must be either a string or list of strings"
)
ds = load_dataset(
"json",
name=config_dataset.name,
data_files=fp,
streaming=False,
split=None,
)
if not ds:
raise ValueError("unhandled dataset load")
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
config_dataset, use_auth_token
)
d_base_type = d_prompt_style = None
d_type = config_dataset.type
@@ -501,24 +331,6 @@ def load_tokenized_prepared_datasets(
return dataset, prompters
def get_ds_type(config_dataset: DictDefault):
"""
Get the dataset type from the path if it's not specified
"""
ds_type = "json"
if config_dataset.ds_type:
ds_type = config_dataset.ds_type
elif ".parquet" in config_dataset.path:
ds_type = "parquet"
elif ".arrow" in config_dataset.path:
ds_type = "arrow"
elif ".csv" in config_dataset.path:
ds_type = "csv"
elif ".txt" in config_dataset.path:
ds_type = "text"
return ds_type
def load_prepare_datasets(
tokenizer: PreTrainedTokenizerBase,
cfg,

View File

@@ -0,0 +1,222 @@
"""
dataset loading shared utils
"""
from pathlib import Path
from typing import Optional, Union
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import HFValidationError
from axolotl.utils.dict import DictDefault
def get_ds_type(config_dataset: DictDefault):
"""
Get the dataset type from the path if it's not specified
"""
ds_type = "json"
if config_dataset.ds_type:
ds_type = config_dataset.ds_type
elif ".parquet" in config_dataset.path:
ds_type = "parquet"
elif ".arrow" in config_dataset.path:
ds_type = "arrow"
elif ".csv" in config_dataset.path:
ds_type = "csv"
elif ".txt" in config_dataset.path:
ds_type = "text"
return ds_type
def load_dataset_w_config(config_dataset, auth_token):
# pylint: disable=invalid-name
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
ds_from_hub = False
ds_trust_remote_code = config_dataset.trust_remote_code
try:
# this is just a basic check to see if the path is a
# valid HF dataset that's loadable
load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=True,
token=auth_token,
revision=config_dataset.revision,
trust_remote_code=ds_trust_remote_code,
)
ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
pass
ds_from_cloud = False
storage_options = {}
remote_file_system = None
if config_dataset.path.startswith("s3://"):
try:
import aiobotocore.session # type: ignore
import s3fs # type: ignore
except ImportError as exc:
raise ImportError(
"s3:// paths require aiobotocore and s3fs to be installed"
) from exc
# Takes credentials from ~/.aws/credentials for default profile
s3_session = aiobotocore.session.AioSession(profile="default")
storage_options = {"session": s3_session}
remote_file_system = s3fs.S3FileSystem(**storage_options)
elif config_dataset.path.startswith("gs://") or config_dataset.path.startswith(
"gcs://"
):
try:
import gcsfs # type: ignore
except ImportError as exc:
raise ImportError(
"gs:// or gcs:// paths require gcsfs to be installed"
) from exc
# gcsfs will use default credentials from the environment else anon
# https://gcsfs.readthedocs.io/en/latest/#credentials
storage_options = {"token": None}
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
# TODO: Figure out how to get auth creds passed
# elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
# try:
# import adlfs
# except ImportError as exc:
# raise ImportError(
# "adl:// or abfs:// paths require adlfs to be installed"
# ) from exc
# # Gen 1
# storage_options = {
# "tenant_id": TENANT_ID,
# "client_id": CLIENT_ID,
# "client_secret": CLIENT_SECRET,
# }
# # Gen 2
# storage_options = {
# "account_name": ACCOUNT_NAME,
# "account_key": ACCOUNT_KEY,
# }
# remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
try:
if remote_file_system and remote_file_system.exists(config_dataset.path):
ds_from_cloud = True
except (FileNotFoundError, ConnectionError):
pass
# prefer local dataset, even if hub exists
local_path = Path(config_dataset.path)
if local_path.exists():
if local_path.is_dir():
if config_dataset.data_files:
ds_type = get_ds_type(config_dataset)
ds = load_dataset( # pylint: disable=invalid-name
ds_type,
name=config_dataset.name,
data_files=config_dataset.data_files,
streaming=False,
split=None,
)
else:
try:
ds = load_from_disk(
config_dataset.path
) # pylint: disable=invalid-name
except FileNotFoundError:
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=False,
split=None,
)
elif local_path.is_file():
ds_type = get_ds_type(config_dataset)
ds = load_dataset( # pylint: disable=invalid-name
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
)
else:
raise ValueError(
"unhandled dataset load: local path exists, but is neither a directory or a file"
)
elif ds_from_hub:
load_ds_kwargs = {}
if config_dataset.split:
load_ds_kwargs["split"] = config_dataset.split
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=False,
data_files=config_dataset.data_files,
token=auth_token,
revision=config_dataset.revision,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
elif ds_from_cloud and remote_file_system:
if remote_file_system.isdir(config_dataset.path):
ds = load_from_disk(
config_dataset.path,
storage_options=storage_options,
)
elif remote_file_system.isfile(config_dataset.path):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
)
elif config_dataset.path.startswith("https://"):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
)
else:
if isinstance(config_dataset.data_files, str):
fp = hf_hub_download(
repo_id=config_dataset.path,
repo_type="dataset",
filename=config_dataset.data_files,
revision=config_dataset.revision,
)
elif isinstance(config_dataset.data_files, list):
fp = []
for file in config_dataset.data_files:
fp.append(
hf_hub_download(
repo_id=config_dataset.path,
repo_type="dataset",
filename=file,
revision=config_dataset.revision,
)
)
else:
raise ValueError("data_files must be either a string or list of strings")
ds = load_dataset(
"json",
name=config_dataset.name,
data_files=fp,
streaming=False,
split=None,
)
if not ds:
raise ValueError("unhandled dataset load")
return ds
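A minimal usage sketch of the consolidated loader (dataset id reused from the tests later in this diff; unset fields fall through DictDefault as None):

from axolotl.utils.dict import DictDefault

config_dataset = DictDefault(
    {
        "path": "axolotl-ai-co/alpaca_100_test",  # hub id, local path, s3://, gs://, or https:// URL
        "name": None,
        "split": "train",
        "data_files": None,
        "revision": None,
        "trust_remote_code": False,
    }
)
ds = load_dataset_w_config(config_dataset, auth_token=None)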

View File

@@ -270,7 +270,7 @@ def load_sharded_model_quant(
model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config)
if cfg.local_rank == 0 and verbose:
print(f"Loaded model weights in {time.time()-start:.3f} seconds")
print(f"Loaded model weights in {time.time() - start:.3f} seconds")
# cleanup any extra memory usage from parallel loading
torch.cuda.empty_cache()

View File

@@ -48,6 +48,7 @@ from transformers.integrations.deepspeed import (
)
from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.integrations.base import PluginManager
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -375,8 +376,6 @@ class ModelLoader:
def apply_patches(self) -> None:
# load any patches from plugins
from axolotl.integrations.base import PluginManager
plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(self.cfg)
@@ -713,24 +712,53 @@ class ModelLoader:
if self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention:
pass
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
if self.cfg.diff_attention:
self.model_kwargs[
"attn_implementation"
] = "differential_flash_attention_2"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_flash_attention_2"
)
else:
self.model_kwargs["attn_implementation"] = "flash_attention_2"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
)
elif self.cfg.sdp_attention:
self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"sdpa"
)
if self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_sdpa"
)
else:
self.model_kwargs["attn_implementation"] = "sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"sdpa"
)
elif self.cfg.eager_attention:
self.model_kwargs["attn_implementation"] = "eager"
if self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_eager"
)
else:
self.model_kwargs["attn_implementation"] = "eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)
elif self.cfg.diff_attention:
self.model_kwargs["attn_implementation"] = "differential_eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
"differential_eager"
)
if self.cfg.low_cpu_mem_usage:
self.model_kwargs["low_cpu_mem_usage"] = True
plugin_manager = PluginManager.get_instance()
plugin_manager.set_attn_config(self.cfg, self.model_kwargs, self.model_config)
def build_model(self, qlora_fsdp) -> bool:
def _configure_zero3_memory_efficient_loading():
"""
@@ -816,6 +844,7 @@ class ModelLoader:
if self.cfg.is_multimodal:
self.model_config.text_config = self.text_model_config
self.model = self.AutoModelLoader.from_pretrained(
self.base_model,
config=self.model_config,

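For orientation, the branching introduced above reduces to the following mapping (a hypothetical helper, not the actual ModelLoader method; the sample_packing/s2_attention interaction is omitted and the final fallback is an assumption):

def pick_attn_implementation(cfg) -> str:
    # diff_attention swaps each backend for its differential counterpart;
    # on its own it falls back to the eager differential implementation.
    if cfg.flash_attention:
        return "differential_flash_attention_2" if cfg.diff_attention else "flash_attention_2"
    if cfg.sdp_attention:
        return "differential_sdpa" if cfg.diff_attention else "sdpa"
    if cfg.eager_attention:
        return "differential_eager" if cfg.diff_attention else "eager"
    if cfg.diff_attention:
        return "differential_eager"
    return "eager"  # assumption: default when no attention flag is set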
View File

@@ -196,7 +196,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask")
if cfg.model_config_type == "falcon":
if cfg.model_config_type in ["falcon", "mistral"]:
LOG.info("dropping token_type_ids column if it exists")
if "token_type_ids" in train_dataset.column_names:
train_dataset = train_dataset.remove_columns("token_type_ids")

157
src/axolotl/utils/yaml.py Normal file
View File

@@ -0,0 +1,157 @@
"""Utilities for YAML files."""
from collections import OrderedDict
from typing import Any, Dict, List, Set, Tuple, Union
import yaml
class YAMLOrderTracker:
"""Tracks the order of keys and section breaks in YAML files."""
def __init__(self, yaml_path: str):
self.yaml_path = yaml_path
self.structure, self.needs_break = self._parse_yaml_structure()
def _get_indentation_level(self, line: str) -> int:
"""Get the indentation level of a line."""
return len(line) - len(line.lstrip())
def _parse_yaml_structure(
self,
) -> Tuple[Dict[str, Union[List[str], Dict]], Set[str]]:
"""Parse the YAML file to extract structure and identify section breaks."""
with open(self.yaml_path, "r", encoding="utf-8") as file:
contents = file.readlines()
structure: OrderedDict = OrderedDict()
needs_break = set() # Track which keys should have a break before them
current_path = []
last_indentation = -1
had_empty_line = False
for line in contents:
# Track empty lines and comments
if not line.strip() or line.strip().startswith("#"):
had_empty_line = True
continue
# Get indentation level and content
indentation = self._get_indentation_level(line)
content = line.strip()
# Skip lines that don't define keys
if ":" not in content:
continue
# Extract key
key = content.split(":")[0].strip()
# If this is a top-level key and we had an empty line, mark it
if indentation == 0:
if had_empty_line:
needs_break.add(key)
had_empty_line = False
# Handle indentation changes
if indentation > last_indentation:
current_path.append(key)
elif indentation < last_indentation:
levels_up = (last_indentation - indentation) // 2
current_path = current_path[:-levels_up]
current_path[-1] = key
else:
if current_path:
current_path[-1] = key
# Update structure
current_dict = structure
for path_key in current_path[:-1]:
if path_key not in current_dict:
current_dict[path_key] = OrderedDict()
current_dict = current_dict[path_key]
if current_path:
if current_path[-1] not in current_dict:
current_dict[current_path[-1]] = OrderedDict()
last_indentation = indentation
return structure, needs_break
class OrderedDumper(yaml.SafeDumper):
"""Custom YAML dumper that maintains dictionary order."""
def represent_none(self, _):
"""Represent None values as empty fields."""
return self.represent_scalar("tag:yaml.org,2002:null", "")
def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any:
"""Custom representer for dictionaries that maintains order."""
return dumper.represent_mapping("tag:yaml.org,2002:map", data.items())
def reorder_dict(data: Dict, reference_structure: Dict) -> OrderedDict:
"""Reorder a dictionary based on a reference structure."""
ordered = OrderedDict()
# First add keys that are in the reference order
for key in reference_structure:
if key in data:
if isinstance(reference_structure[key], dict) and isinstance(
data[key], dict
):
ordered[key] = reorder_dict(data[key], reference_structure[key])
else:
ordered[key] = data[key]
# Then add any remaining keys that weren't in the reference
for key in data:
if key not in ordered:
ordered[key] = data[key]
return ordered
def dump_yaml_preserved_order(
data: Dict, reference_yaml_path: str, output_path: str
) -> None:
"""Dump YAML file while preserving nested order and normalized spacing."""
# Get reference structure and spacing
tracker = YAMLOrderTracker(reference_yaml_path)
# Reorder the data
ordered_data = reorder_dict(data, tracker.structure)
# Register the custom representers
OrderedDumper.add_representer(type(None), represent_none)
OrderedDumper.add_representer(dict, ordered_dict_representer)
OrderedDumper.add_representer(OrderedDict, ordered_dict_representer)
# First dump to string
yaml_str = yaml.dump(
ordered_data, Dumper=OrderedDumper, sort_keys=False, default_flow_style=False
)
# Add spacing according to reference
lines = yaml_str.split("\n")
result_lines: List[str] = []
current_line = 0
while current_line < len(lines):
line = lines[current_line]
if line.strip() and ":" in line and not line.startswith(" "): # Top-level key
key = line.split(":")[0].strip()
if key in tracker.needs_break:
# Add single empty line before this key
if result_lines and result_lines[-1] != "":
result_lines.append("")
result_lines.append(line)
current_line += 1
# Write the final result
with open(output_path, "w", encoding="utf-8") as file:
file.write("\n".join(result_lines))
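A brief usage sketch of the dumper (paths are illustrative): load an existing config, modify it, and re-dump it with the original key order and top-level blank lines preserved.

import yaml

with open("config.yml", "r", encoding="utf-8") as f:  # the user's original config
    cfg_dict = yaml.safe_load(f)

cfg_dict["output_dir"] = "./converted"                # hypothetical edit

dump_yaml_preserved_order(
    data=cfg_dict,
    reference_yaml_path="config.yml",
    output_path="./converted/axolotl_config.yml",
)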

View File

@@ -1,4 +1,5 @@
"""Shared pytest fixtures for cli module."""
import pytest
from click.testing import CliRunner

View File

@@ -43,14 +43,12 @@ class BaseCliTest:
result = cli_runner.invoke(cli, [command, str(config_path)])
assert mock.called
assert mock.call_args.args[0] == [
assert mock.call_args.args[0][:5] == [
"accelerate",
"launch",
"-m",
f"axolotl.cli.{command}",
str(config_path),
"--debug-num-examples",
"0",
]
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI fetch command."""
from unittest.mock import patch
from axolotl.cli.main import fetch

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI inference command."""
from unittest.mock import patch
from axolotl.cli.main import cli

View File

@@ -1,4 +1,5 @@
"""General pytest tests for axolotl.cli.main interface."""
from axolotl.cli.main import build_command, cli
@@ -22,6 +23,7 @@ def test_build_command():
"--batch-size",
"8",
"--debug",
"--nouse-fp16",
]

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI merge_lora command."""
from unittest.mock import patch
from axolotl.cli.main import cli

View File

@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI preprocess command."""
import shutil
from pathlib import Path
from unittest.mock import patch

View File

@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI shard command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli
@@ -11,14 +12,12 @@ def test_shard_with_accelerate(cli_runner, config_path):
result = cli_runner.invoke(cli, ["shard", str(config_path), "--accelerate"])
assert mock.called
assert mock.call_args.args[0] == [
assert mock.call_args.args[0][:5] == [
"accelerate",
"launch",
"-m",
"axolotl.cli.shard",
str(config_path),
"--debug-num-examples",
"0",
]
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI --version"""
from axolotl.cli.main import cli

View File

@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI utils."""
# pylint: disable=redefined-outer-name
import json
from unittest.mock import Mock, patch

View File

@@ -120,13 +120,12 @@ def temp_dir():
@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
from transformers import Trainer
from transformers.models.llama.modeling_llama import (
from transformers.models.llama.modeling_llama import ( # LlamaFlashAttention2,
LlamaAttention,
LlamaFlashAttention2,
LlamaForCausalLM,
)
original_fa2_forward = LlamaFlashAttention2.forward
# original_fa2_forward = LlamaFlashAttention2.forward
original_llama_attn_forward = LlamaAttention.forward
original_llama_forward = LlamaForCausalLM.forward
original_trainer_inner_training_loop = (
@@ -136,7 +135,7 @@ def cleanup_monkeypatches():
# monkey patches can happen inside the tests
yield
# Reset LlamaFlashAttention2 forward
LlamaFlashAttention2.forward = original_fa2_forward
# LlamaFlashAttention2.forward = original_fa2_forward
LlamaAttention.forward = original_llama_attn_forward
LlamaForCausalLM.forward = original_llama_forward
Trainer._inner_training_loop = ( # pylint: disable=protected-access
@@ -149,7 +148,10 @@ def cleanup_monkeypatches():
("transformers.models.llama",),
(
"transformers.models.llama.modeling_llama",
["LlamaFlashAttention2", "LlamaAttention"],
[
# "LlamaFlashAttention2",
"LlamaAttention",
],
),
("transformers.trainer",),
("transformers", ["Trainer"]),

View File

@@ -0,0 +1,31 @@
"""Shared fixtures for differential transformer conversion tests."""
import pytest
from click.testing import CliRunner
@pytest.fixture(scope="class")
def base_config():
"""Basic config for testing."""
return {
"base_model": "HuggingFaceTB/SmolLM2-135M",
"datasets": [
{
"path": "axolotl-ai-co/alpaca_100_test",
"type": "alpaca",
},
],
"gradient_accumulation_steps": 1,
"learning_rate": 1e-4,
"val_set_size": 0.1,
"micro_batch_size": 1,
"sequence_len": 2048,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
}
@pytest.fixture(scope="class")
def cli_runner():
return CliRunner()

View File

@@ -0,0 +1,51 @@
"""End-to-end tests for differential transformer conversion and evaluation."""
# pylint: disable=duplicate-code
from pathlib import Path
import yaml
from pytest import approx
from axolotl.cli import load_cfg
from axolotl.cli.evaluate import do_evaluate
from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer
from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs
def test_conversion_and_eval_cli(tmp_path: Path, base_config):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(
debug=True, zero_init=True, sublayer_norm=False
)
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is True
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
eval_cfg = load_cfg(str(output_dir))
eval_cli_args = EvaluateCliArgs()
all_metrics = do_evaluate(eval_cfg, eval_cli_args)
assert list(all_metrics.keys()) == [
"train_loss",
"train_model_preparation_time",
"train_runtime",
"train_samples_per_second",
"train_steps_per_second",
"eval_loss",
"eval_model_preparation_time",
"eval_runtime",
"eval_samples_per_second",
"eval_steps_per_second",
]
assert all_metrics["train_loss"] == approx(1.7307, rel=1e-4)
assert all_metrics["eval_loss"] == approx(1.8387, rel=1e-4)

View File

@@ -0,0 +1,150 @@
"""End-to-end tests for differential transformer conversion."""
# pylint: disable=redefined-outer-name
# pylint: disable=duplicate-code
from pathlib import Path
from typing import Optional
from unittest.mock import patch
import pytest
import yaml
from axolotl.cli import load_cfg
from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer
from axolotl.cli.main import cli
from axolotl.common.cli import ConvertDiffTransformerCliArgs
def test_cli_validation(cli_runner):
# Test missing config file
result = cli_runner.invoke(cli, ["convert-diff-transformer"])
assert result.exit_code != 0
assert "Error: Missing argument 'CONFIG'." in result.output
# Test non-existent config file
result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"])
assert result.exit_code != 0
assert "Error: Invalid value for 'CONFIG'" in result.output
def test_basic_execution(cli_runner, tmp_path: Path, base_config):
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
with patch(
"axolotl.cli.integrations.convert_diff_transformer.do_cli"
) as mock_do_cli:
result = cli_runner.invoke(cli, ["convert-diff-transformer", str(config_path)])
assert result.exit_code == 0
mock_do_cli.assert_called_once()
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
def test_conversion_cli_basic(tmp_path: Path, base_config):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs()
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert not debug_info
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
def test_conversion_cli_debug(tmp_path: Path, base_config):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(debug=True)
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert not debug_info["generations_match"]
assert not debug_info["match_expected"]
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
def test_conversion_cli_reproduce(tmp_path: Path, base_config):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(
debug=True, zero_init=True, sublayer_norm=False
)
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is True
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
@pytest.mark.parametrize(
"attention", ["eager_attention", "sdp_attention", "flash_attention"]
)
def test_conversion_cli_repoduce_attentions(
tmp_path: Path, base_config, attention: Optional[str]
):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
base_config[attention] = True
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(
debug=True, zero_init=True, sublayer_norm=False
)
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is True
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
@pytest.mark.parametrize(
"attention", ["eager_attention", "sdp_attention", "flash_attention"]
)
def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
output_dir = tmp_path / "converted"
# Smallest model with an even number of attention heads
base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B"
base_config["output_dir"] = str(output_dir)
base_config[attention] = True
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is False
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()

View File

@@ -1,43 +1,40 @@
"""
Simple end-to-end test for Liger integration
"""
import unittest
from pathlib import Path
from e2e.utils import require_torch_2_4_1
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
class LigerIntegrationTestCase(unittest.TestCase):
class LigerIntegrationTestCase:
"""
e2e tests for liger integration with Axolotl
"""
@with_temp_dir
@require_torch_2_4_1
def test_llama_wo_flce(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"plugins": [
"axolotl.integrations.liger.LigerPlugin",
],
"liger_rope": True,
"liger_rms_norm": True,
"liger_swiglu": True,
"liger_glu_activation": True,
"liger_cross_entropy": True,
"liger_fused_linear_cross_entropy": False,
"sequence_len": 1024,
"val_set_size": 0.1,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
@@ -46,15 +43,15 @@ class LigerIntegrationTestCase(unittest.TestCase):
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"max_steps": 10,
"max_steps": 5,
}
)
prepare_plugins(cfg)
@@ -65,26 +62,24 @@ class LigerIntegrationTestCase(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
@with_temp_dir
@require_torch_2_4_1
def test_llama_w_flce(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"plugins": [
"axolotl.integrations.liger.LigerPlugin",
],
"liger_rope": True,
"liger_rms_norm": True,
"liger_swiglu": True,
"liger_glu_activation": True,
"liger_cross_entropy": False,
"liger_fused_linear_cross_entropy": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
@@ -93,15 +88,15 @@ class LigerIntegrationTestCase(unittest.TestCase):
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"max_steps": 10,
"max_steps": 5,
}
)
prepare_plugins(cfg)

View File

@@ -1,9 +1,14 @@
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
import unittest
import pytest
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
@pytest.mark.skip(
reason="Unsloth integration will be broken going into latest transformers"
)
class TestUnslothIntegration(unittest.TestCase):
"""Unsloth monkeypatch integration tests."""

View File

@@ -20,6 +20,9 @@ os.environ["WANDB_DISABLED"] = "true"
# pylint: disable=duplicate-code
@pytest.mark.skip(
reason="Unsloth integration will be broken going into latest transformers"
)
class TestUnslothQLoRA:
"""
Test class for Unsloth QLoRA Llama models

View File

@@ -113,6 +113,7 @@ class TestCustomOptimizers(unittest.TestCase):
@with_temp_dir
def test_fft_schedule_free_adamw(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -49,7 +49,19 @@ def require_torch_2_3_1(test_case):
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.3.1")
return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)
return unittest.skipUnless(is_min_2_3_1(), "test requires torch>=2.3.1")(test_case)
def require_torch_2_4_1(test_case):
"""
Decorator marking a test that requires torch >= 2.4.1
"""
def is_min_2_4_1():
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.4.1")
return unittest.skipUnless(is_min_2_4_1(), "test requires torch>=2.4.1")(test_case)
def require_torch_2_5_1(test_case):
@@ -61,7 +73,7 @@ def require_torch_2_5_1(test_case):
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.5.1")
return unittest.skipUnless(is_min_2_5_1(), "test torch 2.5.1")(test_case)
return unittest.skipUnless(is_min_2_5_1(), "test requires torch>=2.5.1")(test_case)
def is_hopper():

View File

@@ -7,11 +7,11 @@ from typing import Optional
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.config import prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="minimal_base_cfg")
@pytest.fixture(name="minimal_liger_cfg")
def fixture_cfg():
return DictDefault(
{
@@ -25,56 +25,57 @@ def fixture_cfg():
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"plugins": ["axolotl.integrations.liger.LigerPlugin"],
}
)
class BaseValidation:
# pylint: disable=too-many-public-methods
class TestValidation:
"""
Base validation module to setup the log capture
Test the validation module for liger
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
caplog.set_level(logging.WARNING)
self._caplog = caplog
# pylint: disable=too-many-public-methods
class TestValidation(BaseValidation):
"""
Test the validation module for liger
"""
def test_deprecated_swiglu(self, minimal_cfg):
def test_deprecated_swiglu(self, minimal_liger_cfg):
test_cfg = DictDefault(
{
"liger_swiglu": False,
}
| minimal_cfg
| minimal_liger_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level(
logging.WARNING, logger="axolotl.integrations.liger.args"
):
prepare_plugins(test_cfg)
updated_cfg = validate_config(test_cfg)
assert (
"The 'liger_swiglu' argument is deprecated"
in self._caplog.records[0].message
)
# TODO this test is brittle in CI
# assert (
# "The 'liger_swiglu' argument is deprecated"
# in self._caplog.records[0].message
# )
assert updated_cfg.liger_swiglu is None
assert updated_cfg.liger_glu_activations is False
assert updated_cfg.liger_glu_activation is False
def test_conflict_swiglu_ligergluactivation(self, minimal_cfg):
def test_conflict_swiglu_ligergluactivation(self, minimal_liger_cfg):
test_cfg = DictDefault(
{
"liger_swiglu": False,
"liger_glu_activations": True,
"liger_glu_activation": True,
}
| minimal_cfg
| minimal_liger_cfg
)
with pytest.raises(
ValueError,
match=r".*You cannot have both `liger_swiglu` and `liger_glu_activation` set.*",
):
prepare_plugins(test_cfg)
validate_config(test_cfg)

View File

@@ -4,9 +4,7 @@ import json
import logging
import unittest
from pathlib import Path
from typing import Optional
import pytest
from datasets import load_dataset
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
@@ -65,12 +63,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
Test class for prompt tokenization strategies.
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")