Compare commits

..

62 Commits

Author SHA1 Message Date
Wing Lian
8c4f89745a fix softmax class check 2025-01-15 23:23:13 -05:00
Wing Lian
36b71f34d7 register rala 2025-01-15 23:21:22 -05:00
Wing Lian
d28fee7609 use autoconfig w rala 2025-01-15 23:14:47 -05:00
Wing Lian
c196776996 option to not concatenate during pretraining 2025-01-15 22:45:02 -05:00
Wing Lian
79ae776102 fixup logging layer 2025-01-15 21:36:14 -05:00
Wing Lian
145664d82c more fixups 2025-01-15 21:27:12 -05:00
Dan Saunders
28694219a5 inline comment change 2025-01-14 16:59:43 +00:00
Dan Saunders
fd8ad6fcbf fixing negative component mixing 2025-01-13 19:21:55 +00:00
Dan Saunders
661d71a14b adding diff attn negative component warmup (in progress) 2025-01-10 21:57:31 +00:00
Dan Saunders
6dd47edcb8 fire CLI fixes 2025-01-10 18:24:16 +00:00
Dan Saunders
7aca08ff60 adding guard statements 2025-01-10 16:39:21 +00:00
Dan Saunders
4f804f6d88 adding diff attn callback, adding documentation 2025-01-10 16:28:51 +00:00
Dan Saunders
443327c585 CLI build_command bugfix 2025-01-10 16:28:51 +00:00
Dan Saunders
70c4e6fbe6 updates and cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
2a7f139ad2 pre-commit fix 2025-01-10 16:28:51 +00:00
Dan Saunders
332ce0ae85 fixes and cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
e5fa842ff8 update 2025-01-10 16:28:51 +00:00
Dan Saunders
78e0ec0aa5 changes 2025-01-10 16:28:51 +00:00
Dan Saunders
3bc568eb27 adding registration function 2025-01-10 16:28:51 +00:00
Dan Saunders
eb6611d55f progress on modeling code 2025-01-10 16:28:51 +00:00
Dan Saunders
4ff3328e66 updated custom modeling code 2025-01-10 16:28:51 +00:00
Dan Saunders
a3fd5074a9 fix duplicate-code warnings 2025-01-10 16:28:51 +00:00
Dan Saunders
5b90da0be3 added modeling code; cleanup + refactor 2025-01-10 16:28:51 +00:00
Dan Saunders
fcbfa86373 refactor and fixing test isolation issues 2025-01-10 16:28:51 +00:00
Dan Saunders
0d56582090 adding yaml dumper preserving input config format 2025-01-10 16:28:51 +00:00
Dan Saunders
390cb5742e removing extra pytest xdist args 2025-01-10 16:28:51 +00:00
Dan Saunders
1d935f65c3 moving tests around for flash_attn install 2025-01-10 16:28:51 +00:00
Dan Saunders
66176b3e07 adding split_heads argument for retaining original (Q, K) dimensionanlity 2025-01-10 16:28:51 +00:00
Dan Saunders
505321ac95 isolating problematic test 2025-01-10 16:28:51 +00:00
Dan Saunders
0b382c88da fixes post-rebase 2025-01-10 16:28:51 +00:00
Dan Saunders
ea07a7086e plugin implementation 2025-01-10 16:28:51 +00:00
Dan Saunders
d22e1136bc convert-differential-transformer test coverage 2025-01-10 16:28:51 +00:00
Dan Saunders
63b8e42c6b duplicate code ignore 2025-01-10 16:28:51 +00:00
Dan Saunders
bda1eed59e differential flash attention 2; cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
41ebd93158 moving monkeypatch 2025-01-10 16:28:51 +00:00
Dan Saunders
4c050ce807 pre-commit fix 2025-01-10 16:28:51 +00:00
Dan Saunders
6665acf63d fix model save / load logic 2025-01-10 16:28:51 +00:00
Dan Saunders
2f9fa4c465 various improvemnents 2025-01-10 16:28:51 +00:00
Dan Saunders
849bc94112 various improvemnents 2025-01-10 16:28:51 +00:00
Dan Saunders
e484ec778d training fixes, patching, minor cleanup 2025-01-10 16:28:51 +00:00
Dan Saunders
df1504ae14 adding CLI command for convert-diff-transformer 2025-01-10 16:28:51 +00:00
Dan Saunders
7be0d7496c Adding script for doing conversion; fixes and updates 2025-01-10 16:28:51 +00:00
Dan Saunders
13cdffa91f initial diff attn layer / model conversion implementation (support for llama arch) 2025-01-10 16:28:51 +00:00
Dan Saunders
7a4b296f60 Basic evaluate CLI command / codepath (#2188)
* basic evaluate CLI command / codepath

* tests for evaluate CLI command

* fixes and cleanup

* review comments; slightly DRYing up things

---------

Co-authored-by: Dan Saunders <danjsaund@gmail.com>
2025-01-10 16:28:51 +00:00
Wing Lian
d8b4027200 use 2.5.1 docker images as latest tag as it seems stable (#2198) 2025-01-10 08:35:25 -05:00
Wing Lian
fb3352e21c rename liger test so it properly runs in ci (#2246) 2025-01-09 17:31:43 -05:00
NanoCode012
ed77e7001e feat: add support for data_files in pretraining (#2238) 2025-01-09 21:04:13 +00:00
Wing Lian
7669a03fb4 update upstream HF deps (#2239)
* bump axolotl contribs for upstream main conflicts:

* bump datasets, tokenizer, trl

* remove log workarounds in trl

* bump lm-eval

* remove unsloth_ import from critical path

* remove llama fa2 from conftest

* unsloth breaks with latest upstream
2025-01-09 21:01:59 +00:00
Vincenzo di Cicco
6553683170 Use SequentialSampler if curriculum_sampling is enabled with sample_packing (#2235) 2025-01-09 21:01:22 +00:00
Wing Lian
5e0124e2ab update modal version for ci (#2242) 2025-01-09 21:01:02 +00:00
NanoCode012
2e8d7c1adb fix: mistral nemo does not recognize token_type_ids in forward (#2233) 2025-01-09 21:00:36 +00:00
Wing Lian
3c1921e400 add hf cache caching for GHA (#2247)
* add hf cache caching for GHA

* use modal volume to cache hf data

* make sure to update the cache as we add new fixtures in conftest
2025-01-09 20:59:54 +00:00
Wing Lian
7faf2b6e8e Merge group queue (#2248)
* add support for merge groups

* also lint merge groups
2025-01-09 15:49:00 -05:00
salman
c1b920f291 Fixing OSX installation (#2231)
* bumping version, removing non-osx compatible deps

* updating pylintrc

* fixing linters

* reverting changes
2025-01-07 13:42:01 +00:00
Wing Lian
3915abee4c make sure padding is labeled as -100 for pretraining (#2227) 2024-12-31 15:22:18 -05:00
NJordan72
7a38dbe674 fix: allow trainer builder to use custom jinja chat template (#2219)
* fix: allow trainer builder to use custom jinja chat template

* chore: use get_chat_template_from_config

Co-authored-by: Chirag Jain <jain.chirag925@gmail.com>

* fix: swap imports

---------

Co-authored-by: Chirag Jain <jain.chirag925@gmail.com>
2024-12-24 16:18:50 -05:00
Wing Lian
e0a2eb2ebd fix untrained tokens if specified explicitly from a list (#2210) 2024-12-23 09:08:28 -05:00
Wing Lian
d852d7af7a inference - don't default w accelerate, fix base model (#2216) [skip ci] 2024-12-23 07:48:41 -05:00
Wing Lian
3742deb1de add deepspeed example with torch compile enabled (#2212) [skip ci] 2024-12-22 12:11:39 -05:00
Wing Lian
2312caaa98 GC every n steps (#2209) 2024-12-21 17:38:33 -05:00
Wing Lian
307cf7c685 move the dataset loading from remote/disk to a shared function so we can re-use for RL (#2204) 2024-12-20 21:43:52 -05:00
Dan Saunders
70541145f1 adding test_datasets compat with pretraining_dataset (streaming) (#2206) [skip ci] 2024-12-20 21:43:33 -05:00
63 changed files with 1842 additions and 1129 deletions

View File

@@ -1,6 +1,7 @@
name: lint
on:
# check on PRs, and manual triggers
merge_group:
pull_request:
paths:
- '**.py'

View File

@@ -25,7 +25,6 @@ jobs:
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras: mamba-ssm
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -36,6 +35,7 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -92,7 +92,6 @@ jobs:
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -103,6 +102,7 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -52,7 +52,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -129,7 +129,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -1,6 +1,7 @@
name: Tests
on:
# check on push/merge to main, PRs, and manual triggers
merge_group:
push:
branches:
- "main"
@@ -60,6 +61,15 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -100,6 +110,15 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
@@ -115,6 +134,15 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -156,6 +184,15 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
docker-e2e-tests-1st:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
@@ -183,7 +220,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -229,7 +266,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
pip install modal==0.71.8 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -23,7 +23,7 @@ repos:
hooks:
- id: flake8
- repo: https://github.com/PyCQA/pylint
rev: v2.17.4
rev: v3.3.0
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy

View File

@@ -1,5 +1,5 @@
[MASTER]
init-hook="from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))"
init-hook="from pylint.config import find_default_config_files; import sys; sys.path.append(next(find_default_config_files()).parent.as_posix())"
[TYPECHECK]
@@ -12,3 +12,4 @@ generated-members=numpy.*, torch.*
disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
too-many-positional-arguments, possibly-used-before-assignment

View File

@@ -8,6 +8,7 @@ ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev

View File

@@ -28,6 +28,7 @@ df_args = {
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}
dockerfile_contents = df_template.render(**df_args)
@@ -48,6 +49,12 @@ cicd_image = (
app = App("Axolotl CI/CD", secrets=[])
hf_cache_volume = modal.Volume.from_name(
"axolotl-ci-hf-hub-cache", create_if_missing=True
)
VOLUME_CONFIG = {
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
}
N_GPUS = int(os.environ.get("N_GPUS", 2))
GPU_CONFIG = modal.gpu.H100(count=N_GPUS)
@@ -67,6 +74,7 @@ def run_cmd(cmd: str, run_folder: str):
timeout=60 * 60,
cpu=8.0,
memory=131072 * N_GPUS,
volumes=VOLUME_CONFIG,
)
def cicd_pytest():
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")

View File

@@ -29,6 +29,7 @@ df_args = {
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}
dockerfile_contents = df_template.render(**df_args)
@@ -50,6 +51,12 @@ cicd_image = (
app = App("Axolotl CI/CD", secrets=[])
hf_cache_volume = modal.Volume.from_name(
"axolotl-ci-hf-hub-cache", create_if_missing=True
)
VOLUME_CONFIG = {
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.A10G(count=N_GPUS)
@@ -69,6 +76,7 @@ def run_cmd(cmd: str, run_folder: str):
timeout=60 * 60,
cpu=8.0,
memory=131072,
volumes=VOLUME_CONFIG,
)
def cicd_pytest():
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")

View File

@@ -0,0 +1,27 @@
{
"zero_optimization": {
"stage": 1,
"overlap_comm": true
},
"bf16": {
"enabled": "auto"
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"compile": {
"disable": false,
"backend": "inductor"
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

View File

@@ -2,7 +2,7 @@
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.0
triton>=2.3.0
triton>=3.0.0
mamba-ssm==1.2.0.post1
flash-attn==2.7.0.post2
xformers>=0.0.23.post1
@@ -14,11 +14,11 @@ packaging==23.2
peft==0.14.0
transformers==4.47.1
tokenizers>=0.20.1
tokenizers>=0.21.0
accelerate==1.2.1
datasets==3.1.0
datasets==3.2.0
deepspeed==0.16.1
trl==0.12.1
trl==0.13.0
optimum==1.16.2
hf_transfer
@@ -53,7 +53,7 @@ zstandard==0.22.0
fastcore
# lm eval harness
lm_eval==0.4.4
lm_eval==0.4.7
langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
@@ -61,4 +61,4 @@ antlr4-python3-runtime==4.13.2
torchao==0.7.0
schedulefree==1.3.0
axolotl-contribs-lgpl==0.0.1b2
axolotl-contribs-lgpl==0.0.3

View File

@@ -1,4 +1,5 @@
"""setup.py for axolotl"""
import ast
import os
import platform
@@ -29,15 +30,30 @@ def parse_requirements():
elif not is_extras and line and line[0] != "#":
# Handle standard packages
_install_requires.append(line)
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
triton_version = [req for req in _install_requires if "triton" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system():
# don't install xformers on MacOS
_install_requires.pop(_install_requires.index(xformers_version))
# skip packages not compatible with OSX
skip_packages = [
"bitsandbytes",
"triton",
"mamba-ssm",
"flash-attn",
"xformers",
"autoawq",
"liger-kernel",
]
_install_requires = [
req
for req in _install_requires
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
]
print(
_install_requires, [req in skip_packages for req in _install_requires]
)
else:
# detect the version of torch already installed
# and set it so dependencies don't clobber the torch version
@@ -73,6 +89,8 @@ def parse_requirements():
_install_requires.append("xformers==0.0.28.post1")
elif (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(triton_version))
_install_requires.append("triton>=2.3.1")
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1")

View File

@@ -202,7 +202,7 @@ def do_inference(
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
elif cfg.datasets[0].type == "chat_template":
elif cfg.datasets and cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)

View File

@@ -1,4 +1,5 @@
"""CLI to convert a transformers model's attns to diff attns."""
"""CLI to convert a transformers model's attention layers to differential attention layers."""
import logging
import warnings
from pathlib import Path
@@ -14,7 +15,10 @@ from transformers import HfArgumentParser
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
from axolotl.integrations.diff_transformer.convert import convert_to_diff_attn
from axolotl.integrations.diff_transformer.modeling_diff_attn import (
LlamaDifferentialConfig,
LlamaDifferentialForCausalLM,
)
from axolotl.utils.yaml import dump_yaml_preserved_order
LOG = logging.getLogger(__name__)
@@ -22,37 +26,37 @@ LOG = logging.getLogger(__name__)
def test_inference(model, tokenizer, prompt="The quick brown fox"):
"""Run test inference and return generation time"""
try:
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {
k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()
}
inputs = tokenizer(prompt, return_tensors="pt")
inputs = {k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items()}
start = time()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=20,
num_beams=1,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
use_cache=False,
)
elapsed = time() - start
start = time()
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=20,
num_beams=1,
do_sample=False,
pad_token_id=tokenizer.pad_token_id,
use_cache=False,
)
elapsed = time() - start
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
LOG.info("Prompt: %s", prompt)
LOG.info("Generated: %s", generated_text)
LOG.info("Generation time: %.2fs", elapsed)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
LOG.info("Prompt: %s", prompt)
LOG.info("Generated: %s", generated_text)
LOG.info("Generation time: %.2fs", elapsed)
return elapsed, generated_text
except Exception as exc:
LOG.error("Inference failed: %s", str(exc))
raise
return elapsed, generated_text
def convert_diff_transformer(cfg, cli_args, config_path):
assert not (
cli_args.split_heads and cli_args.zero_init
), "Both `split_heads` and `zero_init` cannot be `True`"
assert not (
cli_args.zero_init and cli_args.mirror_weights
), "Both `zero_init` and `mirror_weights` cannot be `True`"
debug_info = {}
# Load model and tokenizer
@@ -75,22 +79,18 @@ def convert_diff_transformer(cfg, cli_args, config_path):
model, tokenizer
)
# Convert attention
LOG.info("Converting to differential attention...")
if cli_args.split_heads and cli_args.zero_init:
LOG.warning(
Fore.YELLOW
+ "Warning: Using split_heads with zero_init is not recommended; "
+ "split_heads will preclude the effects of zero_init"
+ Fore.RESET
)
try:
model = convert_to_diff_attn(
model=model,
# Convert attention
LOG.info("Converting to differential attention...")
config = LlamaDifferentialConfig(
**model.config.__dict__,
zero_init=cli_args.zero_init,
sublayer_norm=cli_args.sublayer_norm,
split_heads=cli_args.split_heads,
mirror_weights=cli_args.mirror_weights,
)
model = LlamaDifferentialForCausalLM.from_llama(model, config)
model.to(cfg.device, dtype=cfg.torch_dtype)
except Exception as exc:
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
@@ -127,6 +127,7 @@ def convert_diff_transformer(cfg, cli_args, config_path):
else:
modified_cfg["plugins"] = [plugin_class]
# Write out the updated axolotl config while preserving original ordering / formatting
dump_yaml_preserved_order(
data=modified_cfg,
reference_yaml_path=config_path,

View File

@@ -82,6 +82,7 @@ def convert_rala(cfg, cli_args, config_path):
zero_init=cli_args.zero_init,
)
model.to(cfg.device, dtype=cfg.torch_dtype)
model.config.model_type = "llama-rala"
except Exception as exc:
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
raise

View File

@@ -101,7 +101,7 @@ def evaluate(config: str, accelerate: bool, **kwargs):
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=True,
default=False,
help="Use accelerate launch for multi-GPU inference",
)
@click.option(
@@ -132,7 +132,7 @@ def inference(
if lora_model_dir:
kwargs["lora_model_dir"] = lora_model_dir
if base_model:
kwargs["output_dir"] = base_model
kwargs["base_model"] = base_model
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]

View File

@@ -92,6 +92,8 @@ def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
if isinstance(value, bool):
if value:
cmd.append(f"--{key}")
else:
cmd.append(f"--no{key}")
else:
cmd.extend([f"--{key}", str(value)])

View File

@@ -12,14 +12,12 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
configure_logging()
LOG = logging.getLogger("axolotl.common.cli")
LOG = logging.getLogger(__name__)
@dataclass
class PreprocessCliArgs:
"""
dataclass with arguments for preprocessing only
"""
"""dataclass with arguments for preprocessing only"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
@@ -30,9 +28,7 @@ class PreprocessCliArgs:
@dataclass
class TrainerCliArgs:
"""
dataclass with various non-training arguments
"""
"""dataclass with various non-training arguments"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
@@ -45,9 +41,7 @@ class TrainerCliArgs:
@dataclass
class EvaluateCliArgs:
"""
dataclass with various evaluation arguments
"""
"""dataclass with various evaluation arguments"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
@@ -56,14 +50,13 @@ class EvaluateCliArgs:
@dataclass
class ConvertDiffTransformerCliArgs:
"""
dataclass with arguments for convert-diff-transformer CLI
"""
"""dataclass with arguments for convert-diff-transformer CLI"""
debug: bool = field(default=False)
zero_init: bool = field(default=False)
sublayer_norm: bool = field(default=True)
split_heads: bool = field(default=False)
mirror_weights: bool = field(default=False)
def load_model_and_tokenizer(

View File

@@ -22,7 +22,6 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch
import transformers
from datasets import Dataset
from packaging import version
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
@@ -56,6 +55,7 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_comet_available, is_mlflow_available
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GCCallback,
GPUStatsCallback,
LossWatchDogCallback,
SaveAxolotlConfigtoWandBCallback,
@@ -67,7 +67,7 @@ from axolotl.utils.callbacks import (
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
@@ -607,8 +607,14 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
self.state.train_batch_size or self.args.per_device_train_batch_size
)
batch_max_len = train_batch_size * self.args.max_seq_length
if self.args.curriculum_sampling:
sampler = SequentialSampler(self.train_dataset)
else:
sampler = RandomSampler(self.train_dataset)
return MultipackBatchSampler(
RandomSampler(self.train_dataset),
sampler,
lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
@@ -977,12 +983,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
try:
return super().log(logs, start_time)
except TypeError:
return super().log(logs) # transformers<=4.46
return super().log(logs) # transformers<=4.46
return super().log(logs, start_time)
def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
@@ -1166,22 +1167,6 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
torch.cuda.empty_cache()
return loss
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
@@ -1190,22 +1175,6 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
tag_names = ["axolotl", "orpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(ORPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
@@ -1214,49 +1183,6 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
tag_names = ["axolotl", "kto"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# train metrics should have no prefix, eval should have 'eval_'
prefix = "eval_" if train_eval == "eval" else ""
# accumulate average metrics from sums and lengths
for split in ["chosen", "rejected"]:
if f"count/{split}" in self._stored_metrics[train_eval]:
count_sum = (
torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"])
.sum()
.item()
)
for metric in ["rewards", "logps", "logits"]:
logs[f"{prefix}{metric}/{split}"] = (
torch.Tensor(
self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
)
.sum()
.item()
/ count_sum
)
# delete obsolete metric
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
del self._stored_metrics[train_eval][f"count/{split}"]
# calculate reward margin
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
logs[f"{prefix}rewards/margins"] = (
logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
)
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(KTOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
@@ -1265,22 +1191,6 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
tag_names = ["axolotl", "cpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(CPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
@@ -1289,15 +1199,6 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
tag_names = ["axolotl", "reward"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call
class TrainerBuilderBase(abc.ABC):
"""
@@ -1452,6 +1353,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
callbacks.append(SaveModelCallback())
return callbacks
@@ -1831,8 +1734,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = get_chat_template(
self.cfg.chat_template,
training_arguments_kwargs["chat_template"] = get_chat_template_from_config(
cfg=self.cfg,
tokenizer=self.tokenizer,
)
@@ -1974,6 +1877,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
):
if training_args.pretraining:
if self.cfg.pretraining_sample_concatenation is False:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
return None
if self.cfg.model_config_type == "mamba":

View File

@@ -48,12 +48,12 @@ class BasePlugin:
Initializes the BasePlugin.
"""
def register(self, cfg): # pylint: disable=unused-argument
def register(self): # pylint: disable=unused-argument
"""
Registers the plugin with the given configuration.
Parameters:
cfg (dict): The configuration for the plugin.
None
Returns:
None
@@ -80,12 +80,10 @@ class BasePlugin:
): # pylint: disable=unused-argument
"""
Sets attention configuration for the model.
Parameters:
cfg (dict): The configuration for the plugin.
model_kwargs (dict): The model kwargs for the plugin.
model_config (object): The model configuration.
Returns:
None
"""
@@ -289,6 +287,7 @@ class PluginManager:
try:
plugin = load_plugin(plugin_name)
self.plugins[plugin_name] = plugin
plugin.register()
except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}")
@@ -322,7 +321,6 @@ class PluginManager:
def set_attn_config(self, cfg, model_kwargs, model_config):
"""
modifies the attention configuration of the model kwargs for loading
Parameters:
cfg (dict): The configuration for the plugins.
model_kwargs (dict): The model's kwargs for construction the model

View File

@@ -2,6 +2,8 @@
### Usage
**Note:** The following with be set in the model config output by the `axolotl convert-diff-transformer` command.
```yaml
plugins:
- axolotl.integrations.diff_transformer.DifferentialTransformerPlugin

View File

@@ -1,25 +1,67 @@
"""Definition of differential transformer plugin."""
import logging
from typing import List
from transformers import PreTrainedModel, TrainerCallback
from axolotl.integrations.base import BasePlugin
from axolotl.utils.callbacks.diff_attn import (
DifferentialAttentionMixingCallback,
DifferentialAttentionMonitorCallback,
)
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger(__name__)
class DifferentialTransformerPlugin(BasePlugin):
"""
Plugin for differential transformer integration with Axolotl.
"""
"""Plugin for differential transformer integration with Axolotl."""
def get_input_args(self):
def __init__(self) -> None:
"""
Constructor for differential transformers plugin. Calls `register_diff_attn`
to register differential attention custom modeling implementation to `AutoConfig`
and `AutoModel`.
"""
from .modeling_diff_attn import register_diff_attn
register_diff_attn()
def get_input_args(self) -> str:
"""Returns module path to diff transformer plugin args for `axolotl` config."""
return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
def pre_model_load(self, cfg):
"""Apply differential attention patch before model loading if enabled."""
if cfg.diff_attention:
from axolotl.monkeypatch.attention.differential import (
patch_llama_attention_classes,
# pylint: disable=unused-argument
def add_callbacks_pre_trainer(
self, cfg: DictDefault, model: PreTrainedModel
) -> List[TrainerCallback]:
"""
Returns `DifferentialAttentionMonitorCallback` to be added to the list of
callbacks for the `axolotl` trainer if wandb usage is enabled.
Parameters:
cfg: Dictionary mapping `axolotl` config keys to values.
model: The loaded mfodel.
Returns:
A list (possibly) containing an instantiated `DifferentialAttentionMonitorCallback`.
"""
callbacks = []
if cfg.use_wandb:
callbacks.append(
DifferentialAttentionMonitorCallback(
log_every=cfg.diff_attn_log_every,
num_monitor_layers=cfg.diff_attn_num_monitor_layers,
warmup_steps=cfg.diff_attn_warmup_steps,
)
)
patch_llama_attention_classes()
if cfg.diff_attn_warmup_steps:
callbacks.append(
DifferentialAttentionMixingCallback(
warmup_steps=cfg.diff_attn_warmup_steps
)
)
return callbacks

View File

@@ -9,6 +9,19 @@ LOG = logging.getLogger(__name__)
class DifferentialTransformerArgs(BaseModel):
"""Input args for differential transformer."""
"""
Input args for differential transformer.
Attributes:
diff_attention: Whether to use differential attention layers.
diff_attn_log_every: How often to log differential attention statistics.
diff_attn_num_monitor_layers: Number of layers to monitor for attention stats.
diff_attn_warmup_steps: Number of steps to linearly increase negative attention
mixing weight from 0 to 1. If specified, will reach full mixing at this
step. If `None`, negative attention has full weight from the start.
"""
diff_attention: Optional[bool] = None
diff_attn_log_every: Optional[int] = 100
diff_attn_num_monitor_layers: Optional[int] = 3
diff_attn_warmup_steps: Optional[int] = None

View File

@@ -1,130 +0,0 @@
"""Differential attention conversion logic for a huggingface pre-trained model."""
import logging
from typing import Union
import torch
from torch import nn
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaFlashAttention2,
LlamaSdpaAttention,
)
from .diff_attn import (
LlamaDifferentialAttention,
LlamaDifferentialFlashAttention2,
LlamaDifferentialSdpaAttention,
)
logger = logging.getLogger(__name__)
ATTENTION_MAPPING = {
LlamaAttention: LlamaDifferentialAttention,
LlamaSdpaAttention: LlamaDifferentialSdpaAttention,
LlamaFlashAttention2: LlamaDifferentialFlashAttention2,
}
def copy_attention_weights(
old_attn: Union[LlamaAttention, LlamaSdpaAttention, LlamaFlashAttention2],
new_attn: Union[
LlamaDifferentialAttention,
LlamaDifferentialSdpaAttention,
LlamaDifferentialFlashAttention2,
],
zero_init: bool = False,
) -> None:
"""
Copy weights from old attention layer to new differential attention layer.
Copies old weights to Q1 and K1, zeros out Q2 and K2 for exact equivalence
to original attention mechanism.
"""
# For Q projection (Q1 and Q2)
new_q = torch.empty_like(new_attn.q_proj.weight.data)
new_q[: new_attn.hidden_size] = old_attn.q_proj.weight.data # Q1
if zero_init:
new_q[new_attn.hidden_size :] = 0
else:
nn.init.normal_(new_q[new_attn.hidden_size :], mean=0, std=0.1)
new_attn.q_proj.weight.data.copy_(new_q)
# For K projection (K1 and K2)
old_kv_size = old_attn.k_proj.weight.data.size(0) # Size for 3 heads
new_k = torch.empty_like(new_attn.k_proj.weight.data)
new_k[:old_kv_size] = old_attn.k_proj.weight.data # K1
if zero_init:
new_k[old_kv_size:] = 0
else:
nn.init.normal_(new_k[old_kv_size:], mean=0, std=0.1)
new_attn.k_proj.weight.data.copy_(new_k)
# For V projection (single V)
new_attn.v_proj.weight.data.copy_(old_attn.v_proj.weight.data)
# Output projection remains the same
new_attn.o_proj.weight.data.copy_(old_attn.o_proj.weight.data)
# Zero out lambda parameters for exact equivalence
if zero_init:
nn.init.zeros_(new_attn.lambda_q1)
nn.init.zeros_(new_attn.lambda_k1)
nn.init.zeros_(new_attn.lambda_q2)
nn.init.zeros_(new_attn.lambda_k2)
nn.init.zeros_(new_attn.lambda_init)
logger.debug(
"Copied positive attention weights from %s to %s",
type(old_attn).__name__,
type(new_attn).__name__,
)
def convert_to_diff_attn(
model: PreTrainedModel,
zero_init: bool = False,
sublayer_norm: bool = True,
split_heads: bool = True,
) -> PreTrainedModel:
"""Convert a pre-trained model's attention layers to differential attention"""
layer_idx = 0
# Set sublayer norm as config on the model.
model.config.sublayer_norm = sublayer_norm
model.config.split_heads = split_heads
def convert_module(module):
nonlocal layer_idx
# Iterate through module children, convert any attn layers to diff attn
for name, child in module.named_children():
if isinstance(child, tuple(ATTENTION_MAPPING.keys())):
# Choose appropriate differential attention class
attention_class = ATTENTION_MAPPING[type(child)]
layer_type = type(child).__name__
logger.info(
f"Converting attention layer {layer_idx}: {layer_type} to {attention_class.__name__}"
)
# Create new diff attn layer
new_attention = attention_class(
config=module.config if hasattr(module, "config") else model.config,
layer_idx=layer_idx,
)
# Copy weights from old attention to new attention
new_attention.to(child.q_proj.weight.device)
if not split_heads:
copy_attention_weights(child, new_attention, zero_init=zero_init)
# Replace the layer
setattr(module, name, new_attention)
layer_idx += 1
elif len(list(child.children())) > 0:
convert_module(child)
convert_module(model)
logger.info(f"Converted {layer_idx} attention layers to differential attention")
return model

View File

@@ -1,13 +1,13 @@
"""Re-implemention of differential attention."""
"""Re-implemention of differential attention from the Differential Transformer paper
(https://arxiv.org/abs/2410.05258)."""
# pylint: disable=invalid-name
import logging
import math
from typing import Any, Optional, Tuple
from typing import Any
import torch
import torch.nn.functional as F
from flash_attn.flash_attn_interface import flash_attn_func
from torch import nn
from transformers.cache_utils import Cache
from transformers.models.llama.modeling_llama import (
@@ -17,11 +17,29 @@ from transformers.models.llama.modeling_llama import (
)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
LOG = logging.getLogger(__name__)
try:
from flash_attn.flash_attn_interface import flash_attn_func
FLASH_ATTENTION_AVAILABLE = True
except ImportError:
FLASH_ATTENTION_AVAILABLE = False
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
"""torch.repeat_interleave(x, dim=1, repeats=n_rep)"""
"""
Repeats key/value heads to match the number of query heads in multi-head attention.
Args:
x: Input tensor of shape `(batch_size, num_kv_heads, seq_len, head_dim)`.
n_rep: Number of times to repeat each head.
Returns:
Tensor with repeated heads of shape `(batch_size, num_kv_heads * n_rep,
seq_len, head_dim)`.
If `n_rep` is 1, returns the input tensor unchanged.
"""
batch_size, n_kv_heads, slen, head_dim = x.shape
if n_rep == 1:
return x
@@ -32,69 +50,132 @@ def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
)
def lambda_init_fn(depth):
def lambda_init_fn(depth: int) -> float:
"""
Lambda mixing parameter init function from the "Differential Transformer" paper.
Args:
depth: Index of layer to init lambda parameter.
Returns:
Lambda initialization value (decreasing with `depth`).
"""
return 0.8 - 0.6 * math.exp(-0.3 * depth)
class DifferentialAttentionBase(nn.Module):
"""Base class for differential attention implementations."""
class LlamaDifferentialAttentionBase(nn.Module):
"""
Base class for differential attention implementations.
This class implements the core differential attention mechanism used in Llama models.
It supports both split heads and double projection modes for attention computation.
"""
def __init__(self, config: Any, layer_idx: int):
"""
Initializes the differential attention module.
Args:
config: Model configuration object containing hyperparameters, including:
- hidden_size: The size of hidden states.
- num_attention_heads: Number of attention heads.
- num_key_value_heads: Number of key/value heads.
- attention_bias: Whether to use bias in attention projections.
- split_heads: Whether to use split heads mode.
- rms_norm_eps: Epsilon for RMS normalization.
layer_idx: The index of this layer in the model.
Note:
The initialization process consists of four steps:
1. Configuration initialization (`_init_config`)
2. Projection layers initialization (`_init_projections`)
3. Differential parameters initialization (`_init_differential_params`)
4. Normalization layers initialization (`_init_normalization`)
"""
super().__init__()
self._init_config(config, layer_idx)
self.config = config
self._init_config(layer_idx)
self._init_projections()
self._init_differential_params()
self._init_normalization(config)
self._init_normalization()
def _init_config(self, config: Any, layer_idx: int):
"""Initialize configuration parameters."""
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.base_num_heads = config.num_attention_heads
self.base_num_kv_heads = config.num_key_value_heads
# For logging
self.attn1 = None
self.attn2 = None
self.lambda_full = None
def _init_config(self, layer_idx: int) -> None:
"""
Initializes configuration parameters for the attention layer. Sets up various
dimension sizes and head counts based on the provided config. Handles both
split heads and double projection modes.
In split heads mode, the number of heads is divided by 2 (rounding down), which
differs from the original implementation that required an even number.
Args:
layer_idx: Index of the current layer.
"""
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
self.base_num_heads = self.config.num_attention_heads
self.base_num_kv_heads = self.config.num_key_value_heads
self.num_key_value_groups = self.base_num_heads // self.base_num_kv_heads
self.layer_idx = layer_idx
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.split_heads = config.split_heads
if config.split_heads:
# Split heads mode - single projections
self.head_dim = config.hidden_size // config.num_attention_heads // 2
# NOTE: This rounds down `base_num_heads / 2` as opposed to the original
# implementation, which asserts `self.base_num_heads` is even.
if self.config.split_heads:
self.heads_per_component = self.base_num_heads // 2
self.kv_heads_per_component = self.base_num_kv_heads // 2
self.value_head_dim = 2 * self.head_dim
else:
# Double projection mode
self.head_dim = config.hidden_size // config.num_attention_heads
self.heads_per_component = self.base_num_heads
self.kv_heads_per_component = self.base_num_kv_heads
self.value_head_dim = self.head_dim
def _init_projections(self):
"""Initialize Q, K, V projections."""
if self.split_heads:
# Split heads mode - single projections
q_out_dim = self.hidden_size
k_out_dim = self.hidden_size // self.base_num_heads * self.base_num_kv_heads
def _init_projections(self) -> None:
"""
Initializes the query, key, value, and output projection layers.
Creates linear transformations for Q, K, V projections with dimensions
depending on whether split heads or double projection mode is used.
The output projection combines the attention heads back to model dimension.
"""
if self.config.split_heads:
q_out_dim = self.config.hidden_size
k_out_dim = self.head_dim * self.base_num_kv_heads
else:
# Double projection mode
q_out_dim = self.hidden_size * 2
k_out_dim = (
self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2
)
q_out_dim = self.config.hidden_size * 2
k_out_dim = self.head_dim * self.base_num_kv_heads * 2
self.q_proj = nn.Linear(self.hidden_size, q_out_dim, bias=False)
self.k_proj = nn.Linear(self.hidden_size, k_out_dim, bias=False)
self.v_proj = nn.Linear(
self.hidden_size,
self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
bias=False,
self.q_proj = nn.Linear(
self.config.hidden_size, q_out_dim, bias=self.config.attention_bias
)
self.k_proj = nn.Linear(
self.config.hidden_size, k_out_dim, bias=self.config.attention_bias
)
self.v_proj = nn.Linear(
self.config.hidden_size,
self.head_dim * self.base_num_kv_heads,
bias=self.config.attention_bias,
)
self.o_proj = nn.Linear(
self.base_num_heads * self.head_dim,
self.config.hidden_size,
bias=self.config.attention_bias,
)
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
def _init_differential_params(self):
"""Initialize differential attention parameters."""
def _init_differential_params(self) -> None:
"""
Initializes parameters specific to differential attention.
Creates learnable parameters for the differential attention mechanism:
- Mixing parameter for negative attention component warmup phase.
- Lambda parameters for queries and keys.
- Initial lambda value based on layer index.
- Rotary position embedding layer.
"""
self.diff_attn_mix = 1.0 # Default to full mixing
self.lambda_init = nn.Parameter(
torch.full((), lambda_init_fn(self.layer_idx)),
requires_grad=False,
@@ -111,106 +192,245 @@ class DifferentialAttentionBase(nn.Module):
self.lambda_k2 = nn.Parameter(
torch.zeros(self.head_dim).normal_(mean=0, std=0.1)
)
self.rotary_emb = LlamaRotaryEmbedding(
self.max_position_embeddings, self.head_dim, self.rope_theta
)
def _init_normalization(self, config):
"""Initialize normalization layers."""
sublayer_norm = getattr(config, "sublayer_norm", True)
self.subln = (
LlamaRMSNorm(self.value_head_dim, eps=1e-5)
if sublayer_norm
else nn.Identity()
)
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
def _prepare_attention_inputs(self, hidden_states: torch.Tensor):
"""Prepare inputs for attention computation."""
def _init_normalization(self) -> None:
"""
Initializes normalization layers for the attention mechanism.
Sets up either RMS normalization or identity transformation based on config.
The normalization is applied to the sublayer output if enabled.
"""
sublayer_norm = getattr(self.config, "sublayer_norm", True)
if sublayer_norm:
self.subln = LlamaRMSNorm(self.value_head_dim, eps=self.config.rms_norm_eps)
else:
self.subln = nn.Identity()
def _prepare_attention_inputs(
self, hidden_states: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Prepares input tensors for attention computation.
Projects input hidden states to query, key, and value spaces, then reshapes
them for multi-head attention processing.
Args:
hidden_states: Input tensor of shape `(batch_size, seq_len,
hidden_size)`.
Returns:
tuple: Tuple containing:
- q1: Positive attention query component
- q2: Negative attention query component
- k1: Positive attention key component
- k2: Negative attention key component
- v: Value tensor
"""
bsz, q_len, _ = hidden_states.size()
# Project and split
qp = self.q_proj(hidden_states)
kp = self.k_proj(hidden_states)
q = self.q_proj(hidden_states)
k = self.k_proj(hidden_states)
v = self.v_proj(hidden_states)
q1, q2 = qp.chunk(2, dim=-1)
k1, k2 = kp.chunk(2, dim=-1)
q1, q2 = q.chunk(2, dim=-1)
k1, k2 = k.chunk(2, dim=-1)
# Reshape
q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
v = v.view(bsz, q_len, -1, self.value_head_dim).transpose(1, 2)
q1 = q1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
1, 2
)
q2 = q2.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
1, 2
)
k1 = k1.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
1, 2
)
k2 = k2.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
1, 2
)
v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
return q1, q2, k1, k2, v
def _apply_rotary_embeddings(
self, q1, q2, k1, k2, position_ids, position_embeddings
):
"""Apply rotary embeddings to queries and keys."""
self,
q1: torch.Tensor,
q2: torch.Tensor,
k1: torch.Tensor,
k2: torch.Tensor,
position_ids: torch.Tensor,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None,
) -> tuple[
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
torch.Tensor,
]:
"""
Applies rotary positional embeddings to queries and keys.
Args:
q1: Positive attention query component.
q2: Negative attention query component.
k1: Positive attention key component.
k2: Negative attention key component.
position_ids: Token position indices.
position_embeddings: Pre-computed rotary embeddings (cos, sin).
Returns:
tuple: Tuple containing:
- q1: Positive attention query with positional encoding.
- q2: Negative attention query with positional encoding.
- k1: Positive attention key with positional encoding.
- k2: Negative attention key with positional encoding.
- cos: Cosine part of rotary embeddings.
- sin: Sine part of rotary embeddings.
"""
if position_embeddings is None:
if position_ids is None:
position_ids = torch.arange(q1.size(-2), device=q1.device)
LOG.warning(
"The attention layers in this model are transitioning from computing the RoPE embeddings internally "
"through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
"`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
"removed and `position_embeddings` will be mandatory."
)
cos, sin = self.rotary_emb(q1, position_ids)
else:
cos, sin = position_embeddings
if self.split_heads:
cos, _ = cos.chunk(2, dim=2)
sin, _ = sin.chunk(2, dim=2)
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
return q1, q2, k1, k2, cos, sin
def _handle_cache(self, k1, k2, v, past_key_value, cache_kwargs):
"""Handle caching for autoregressive generation."""
def _handle_cache(
self,
k1: torch.Tensor,
k2: torch.Tensor,
v: torch.Tensor,
past_key_value: Cache | None,
cache_kwargs: dict,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Handles key-value caching for autoregressive generation and the repetition of
key-value heads to match the number of query heads.
Args:
k1: Positive attention key component.
k2: Negative attention key component.
v: Value tensor.
past_key_value: Cache object for storing previous key-value pairs.
cache_kwargs: Additional arguments for cache handling.
Returns:
tuple: Tuple containing:
- k1: Processed positive attention key component.
- k2: Processed negative attention key component.
- v: Processed value tensor.
"""
if past_key_value is not None:
k = torch.stack([k1, k2], dim=1)
k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs)
k1, k2 = k.unbind(dim=1)
# Repeat KV heads
k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads)
k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads)
v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)
k1 = repeat_kv(k1, self.num_key_value_groups)
k2 = repeat_kv(k2, self.num_key_value_groups)
v = repeat_kv(v, self.num_key_value_groups)
if self.config.split_heads:
v = torch.cat(torch.chunk(v, 2, dim=1), dim=-1)
return k1, k2, v
def _compute_lambda(self, q1):
"""Compute lambda values for differential attention."""
def _compute_lambda(self, q1: torch.Tensor) -> torch.Tensor:
"""
Computes lambda values for differential attention.
The lambda value is computed as λ₁ - λ₂ + λ_init, where λ₁ and λ₂ are computed
from the learned parameters. `diff_attn_mix` is multiplied through the result
for negative attention component warmup phase (if applicable).
Args:
q1: Positive attention query component, used for type casting.
Returns:
Computed lambda value for differential attention.
"""
lambda_1 = torch.exp(
torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float()
).type_as(q1)
lambda_2 = torch.exp(
torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float()
).type_as(q1)
return lambda_1 - lambda_2 + self.lambda_init
lambda_full = lambda_1 - lambda_2 + self.lambda_init
def _process_attention_output(self, attn, bsz, q_len):
"""Process and project attention output."""
return self.diff_attn_mix * lambda_full
def _process_attention_output(
self, attn: torch.Tensor, bsz: int, q_len: int
) -> torch.Tensor:
"""
Processes and projects the attention output. Applies sublayer normalization,
scales by (1 - λ_init), and projects back to model dimension.
Args:
attn: Raw attention output.
bsz: Batch size.
q_len: Query sequence length.
Returns:
Processed attention output of shape (batch_size, seq_len, hidden_size)
"""
attn = self.subln(attn)
attn = attn * (1 - self.lambda_init)
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
# NOTE: this may need to be added back in, but doesn't interact well with
# `diff_attn_mix`, and doesn't allow us to preserve the original model output.
# attn = attn * self.diff_attn_mix * (1 - self.lambda_init)
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)
return self.o_proj(attn)
class LlamaDifferentialAttention(DifferentialAttentionBase):
"""Standard implementation of differential attention."""
class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
"""
Standard implementation of differential attention.
This class implements the standard differential attention mechanism using
explicit matrix multiplications for the attention computation.
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: Cache | None = None,
output_attentions: bool = False,
use_cache: bool = False, # pylint: disable=unused-argument
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
cache_position: torch.LongTensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs, # pylint: disable=unused-argument
):
"""
Computes differential attention using standard matrix multiplication operations.
Args:
hidden_states: Input tensor containing sequence to attend to.
attention_mask: Mask to avoid attention on padding tokens.
position_ids: Indices of positions for positional embeddings.
past_key_value: Cached key and value tensors for autoregressive decoding.
output_attentions: Whether to return attention weights.
use_cache: Whether to use cached key/value states.
cache_position: Position indices for cached states.
position_embeddings: Pre-computed positional embeddings.
**kwargs: Additional arguments passed to the forward call.
Returns:
tuple containing:
- Output tensor after attention computation.
- Attention weights if output_attentions is True, else None.
- Updated key-value cache if use_cache is True, else None.
"""
bsz, q_len, _ = hidden_states.size()
q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states)
q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings(
@@ -232,36 +452,74 @@ class LlamaDifferentialAttention(DifferentialAttentionBase):
attn1 = F.softmax(attn1, dim=-1, dtype=torch.float32).type_as(attn1)
attn2 = F.softmax(attn2, dim=-1, dtype=torch.float32).type_as(attn2)
dropout_p = self.attention_dropout if self.training else 0.0
dropout_p = self.config.attention_dropout if self.training else 0.0
attn1 = F.dropout(attn1, p=dropout_p, training=self.training)
attn2 = F.dropout(attn2, p=dropout_p, training=self.training)
lambda_full = self._compute_lambda(q1)
attn = torch.matmul(attn1, v) - lambda_full * torch.matmul(attn2, v)
attn = self._process_attention_output(attn, bsz, q_len)
# Save for logging
self.attn1 = attn1
self.attn2 = attn2
self.lambda_full = lambda_full
if output_attentions:
return attn, attn1 - lambda_full * attn2, past_key_value
attn_weights = attn1 - lambda_full * attn2
attn_weights = attn_weights.view(bsz, self.heads_per_component, q_len, -1)
return attn, attn_weights, past_key_value
return attn, None, past_key_value
class LlamaDifferentialSdpaAttention(DifferentialAttentionBase):
"""SDPA-based implementation of differential attention."""
class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase):
"""
SDPA-based implementation of differential attention.
This class implements differential attention using PyTorch's scaled_dot_product_attention
for improved performance on supported hardware.
"""
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: Cache | None = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
cache_position: torch.LongTensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs, # pylint: disable=unused-argument
):
"""
Computes differential attention using PyTorch's scaled dot product attention.
Args:
hidden_states: Input tensor containing sequence to attend to.
attention_mask: Mask to avoid attention on padding tokens.
position_ids: Indices of positions for positional embeddings.
past_key_value: Cached key and value tensors for autoregressive decoding.
output_attentions: Whether to return attention weights.
use_cache: Whether to use cached key/value states.
cache_position: Position indices for cached states.
position_embeddings: Pre-computed positional embeddings.
**kwargs: Additional arguments passed to the forward call.
Returns:
tuple containing:
- Output tensor after attention computation.
- None for attention weights (SDPA doesn't support output_attentions).
- Updated key-value cache if use_cache is True, else None.
"""
if output_attentions:
LOG.warning(
"LlamaDifferentialModel is using LlamaDifferentialSdpaAttention, but "
+ "`torch.nn.functional.scaled_dot_product_attention` does not support "
+ "`output_attentions=True`. Falling back to the eager attention implementation."
)
# pylint: disable=duplicate-code
return LlamaDifferentialAttention.forward(
self,
hidden_states,
@@ -288,7 +546,7 @@ class LlamaDifferentialSdpaAttention(DifferentialAttentionBase):
None if attention_mask is None else attention_mask[:, :, :, : k1.shape[-2]]
)
is_causal = attention_mask is None and q_len > 1
dropout_p = self.attention_dropout if self.training else 0.0
dropout_p = self.config.attention_dropout if self.training else 0.0
if q1.device.type == "cuda" and causal_mask is not None:
q1, q2 = q1.contiguous(), q2.contiguous()
@@ -304,27 +562,83 @@ class LlamaDifferentialSdpaAttention(DifferentialAttentionBase):
lambda_full = self._compute_lambda(q1)
attn = attn1 - lambda_full * attn2
attn = self._process_attention_output(attn, bsz, q_len)
# Save for logging
self.attn1 = attn1
self.attn2 = attn2
self.lambda_full = lambda_full
return attn, None, past_key_value
class LlamaDifferentialFlashAttention2(DifferentialAttentionBase):
"""Flash Attention 2-based implementation of differential attention."""
class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
"""
Flash Attention 2-based implementation of differential attention.
This class implements differential attention using Flash Attention 2 for maximum
performance on supported hardware.
"""
def __init__(self, *args, **kwargs):
"""
Initializes the Flash Attention 2 differential attention module.
Args:
*args: Positional arguments passed to parent class.
**kwargs: Keyword arguments passed to parent class.
Raises:
ImportError: If flash-attn library is not installed.
"""
if not FLASH_ATTENTION_AVAILABLE:
raise ImportError(
"LlamaDifferentialFlashAttention2 requires flash-attn library. "
"Please install with `pip install flash-attn --no-build-isolation`"
)
super().__init__(*args, **kwargs)
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_value: Cache | None = None,
output_attentions: bool = False,
use_cache: bool = False,
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
cache_position: torch.LongTensor | None = None,
position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
**kwargs, # pylint: disable=unused-argument
):
"""
Computes differential attention using Flash Attention 2.
Args:
hidden_states: Input tensor containing sequence to attend to.
attention_mask: Mask to avoid attention on padding tokens.
position_ids: Indices of positions for positional embeddings.
past_key_value: Cached key and value tensors for autoregressive decoding.
output_attentions: Whether to return attention weights.
use_cache: Whether to use cached key/value states.
cache_position: Position indices for cached states.
position_embeddings: Pre-computed positional embeddings.
**kwargs: Additional arguments passed to the forward call.
Returns:
tuple containing:
- Output tensor after attention computation.
- None for attention weights (Flash Attention doesn't support output_attentions).
- Updated key-value cache if use_cache is True, else None.
"""
if output_attentions:
LOG.warning(
"LlamaDifferentialModel is using LlamaDifferentialFlashAttention2, but "
+ "flash attenion does not support `output_attentions=True`. Falling back "
+ "to the eager attention implementation."
)
# pylint: disable=duplicate-code
return LlamaDifferentialAttention.forward(
self,
hidden_states,
@@ -351,9 +665,9 @@ class LlamaDifferentialFlashAttention2(DifferentialAttentionBase):
k1, k2 = k1.transpose(1, 2), k2.transpose(1, 2)
v = v.transpose(1, 2)
dropout_p = self.attention_dropout if self.training else 0.0
dropout_p = self.config.attention_dropout if self.training else 0.0
if self.split_heads:
if self.config.split_heads:
v1, v2 = v.chunk(2, dim=-1)
attn11 = flash_attn_func(q1, k1, v1, dropout_p=dropout_p, causal=True)
attn12 = flash_attn_func(q1, k1, v2, dropout_p=dropout_p, causal=True)
@@ -370,6 +684,11 @@ class LlamaDifferentialFlashAttention2(DifferentialAttentionBase):
lambda_full = self._compute_lambda(q1)
attn = attn1 - lambda_full * attn2
attn = self._process_attention_output(attn, bsz, q_len)
# Save for logging
self.attn1 = attn1
self.attn2 = attn2
self.lambda_full = lambda_full
return attn, None, past_key_value

View File

@@ -0,0 +1,401 @@
"""
Modeling for differential transformers.
This module implements differential attention variants of the LLaMA model,
providing various attention implementations for improved performance.
"""
import logging
import torch
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel
from .diff_attn import (
LlamaDifferentialAttention,
LlamaDifferentialFlashAttention2,
LlamaDifferentialSdpaAttention,
)
logger = logging.getLogger(__name__)
class LlamaDifferentialConfig(LlamaConfig):
"""
Configuration class for Differential LLaMA model.
Extends the base LLaMA configuration with additional parameters for differential
attention mechanisms.
"""
model_type = "llama-differential"
def __init__(
self,
split_heads: bool = False,
sublayer_norm: bool = True,
zero_init: bool = False,
mirror_weights: bool = False,
**kwargs,
):
"""
Initialize differential LLaMA configuration.
Args:
split_heads: Whether to use split heads mode for attention computation.
sublayer_norm: Whether to apply normalization to sublayers.
zero_init: Whether to initialize new weights to zero.
mirror_weights: Whether to copy the positive attention component weights to
the negative attention component.
**kwargs: Additional arguments passed to LlamaConfig.
"""
super().__init__(**kwargs)
self.split_heads = split_heads
self.sublayer_norm = sublayer_norm
self.zero_init = zero_init
self.mirror_weights = mirror_weights
self.architectures = ["LlamaDifferentialModel"]
self._attn_implementations = {
"eager": "differential_eager",
"sdpa": "differential_sdpa",
"flash_attention_2": "differential_flash_attention_2",
}
class LlamaDifferentialModel(LlamaModel):
"""
LlamaModel with differential attention.
This class extends the base LLaMA model by replacing standard attention with
differential attention mechanisms.
"""
config_class = LlamaDifferentialConfig
base_model_prefix = "llama_differential"
def __init__(self, config: LlamaDifferentialConfig):
"""
Initialize a differential LLaMA model.
Args:
config: Configuration object for the model.
Raises:
ValueError: If specified attention implementation is not supported.
"""
super().__init__(config)
# Handle attention implementation
attn_impl = config._attn_implementation or "eager"
if attn_impl in config._attn_implementations:
attn_impl = config._attn_implementations[attn_impl]
# Validate attention implementation
valid_impls = [
None,
"differential_eager",
"differential_sdpa",
"differential_flash_attention_2",
]
if attn_impl not in valid_impls:
raise ValueError(f"Invalid attention implementation: {attn_impl}")
# Replace standard attention with differential attention in each layer
attn_classes = {
"differential_eager": LlamaDifferentialAttention,
"differential_sdpa": LlamaDifferentialSdpaAttention,
"differential_flash_attention_2": LlamaDifferentialFlashAttention2,
}
attn_class = attn_classes.get(attn_impl, LlamaDifferentialAttention)
for idx, layer in enumerate(self.layers):
layer.self_attn = attn_class(config, idx)
@classmethod
# pylint: disable=protected-access
def _autoset_attn_implementation(
cls,
config: LlamaDifferentialConfig,
**kwargs, # pylint: disable=unused-argument
) -> LlamaDifferentialConfig:
"""
Automatically set the attention implementation based on config.
Args:
config: Model configuration object.
**kwargs: Additional arguments (unused).
Returns:
Updated configuration object.
Raises:
ValueError: If specified attention implementation is not supported.
"""
config._attn_implementation_autoset = True
attn_implementation = getattr(config, "_attn_implementation", None)
# Map standard types to differential types if mapping exists
if attn_implementation in config._attn_implementations:
config._attn_implementation = config._attn_implementations[
attn_implementation
]
return config
# If no mapping, validate it's a valid differential type
valid_impls = [
None,
"differential_eager",
"differential_sdpa",
"differential_flash_attention_2",
]
if attn_implementation not in valid_impls:
message = (
f"Specified `attn_implementation={attn_implementation}` is not supported. "
f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
)
raise ValueError(message)
return config
@classmethod
def from_llama(
cls,
model: LlamaModel | LlamaForCausalLM,
config: LlamaDifferentialConfig | None = None,
) -> "LlamaDifferentialModel":
"""
Convert a `LlamaModel` to use differential attention.
Args:
model: Base LLaMA model to convert.
config: Configuration for differential attention. If `None`, created from
base model config.
Returns:
Converted model with differential attention.
Raises:
ValueError: If number of heads is not even when using `split_heads` mode.
"""
logger.info(f"Converting {type(model).__name__} to {cls.__name__}")
# Handle LlamaForCausalLM
if isinstance(model, LlamaForCausalLM):
model = model.model
if config is None:
config = LlamaDifferentialConfig(**model.config.__dict__)
logger.debug(f"Created config: {config}")
# Validate head counts if using split heads mode
if config.split_heads:
if config.num_attention_heads % 2 != 0:
raise ValueError(
f"Number of attention heads ({config.num_attention_heads}) must be even "
"when using split_heads=True"
)
if config.num_key_value_heads % 2 != 0:
raise ValueError(
f"Number of key/value heads ({config.num_key_value_heads}) must be even "
"when using split_heads=True"
)
new_model = cls(config)
# Copy all weights except attention
logger.debug("Copying embeddings and norm")
new_model.embed_tokens.load_state_dict(model.embed_tokens.state_dict())
new_model.norm.load_state_dict(model.norm.state_dict())
logger.debug("Copying layer weights")
for layer_idx, (new_layer, old_layer) in enumerate(
zip(new_model.layers, model.layers)
):
# Copy everything except attention weights
new_layer.mlp.load_state_dict(old_layer.mlp.state_dict())
new_layer.input_layernorm.load_state_dict(
old_layer.input_layernorm.state_dict()
)
new_layer.post_attention_layernorm.load_state_dict(
old_layer.post_attention_layernorm.state_dict()
)
# Handle attention weights
new_layer.self_attn.v_proj.load_state_dict(
old_layer.self_attn.v_proj.state_dict()
)
new_layer.self_attn.o_proj.load_state_dict(
old_layer.self_attn.o_proj.state_dict()
)
# Get the original projection sizes
old_q_size = old_layer.self_attn.q_proj.weight.size(0)
old_k_size = old_layer.self_attn.k_proj.weight.size(0)
if not config.split_heads:
logger.debug(
f"Layer {layer_idx}: Copying Q/K projections with sizes {old_q_size}, {old_k_size}"
)
new_layer.self_attn.q_proj.weight.data[:old_q_size].copy_(
old_layer.self_attn.q_proj.weight.data
)
new_layer.self_attn.k_proj.weight.data[:old_k_size].copy_(
old_layer.self_attn.k_proj.weight.data
)
if config.zero_init:
logger.debug(f"Layer {layer_idx}: Zero initializing")
with torch.no_grad():
new_layer.self_attn.q_proj.weight.data[old_q_size:].zero_()
new_layer.self_attn.k_proj.weight.data[old_k_size:].zero_()
new_layer.self_attn.lambda_q1.zero_()
new_layer.self_attn.lambda_k1.zero_()
new_layer.self_attn.lambda_q2.zero_()
new_layer.self_attn.lambda_k2.zero_()
new_layer.self_attn.lambda_init.zero_()
elif config.mirror_weights:
# Mirror weights for second component
new_layer.self_attn.q_proj.weight.data[old_q_size:].copy_(
old_layer.self_attn.q_proj.weight.data
)
new_layer.self_attn.k_proj.weight.data[old_k_size:].copy_(
old_layer.self_attn.k_proj.weight.data
)
logger.info("Conversion complete")
return new_model
class LlamaDifferentialForCausalLM(LlamaForCausalLM):
"""
`LlamaForCausalLM` with differential attention.
This class extends the base LLaMA causal language model by incorporating
differential attention mechanisms.
"""
config_class = LlamaDifferentialConfig
base_model_prefix = "llama_differential"
def __init__(self, config: LlamaDifferentialConfig):
"""
Initialize a differential LLaMA model for causal language modeling.
Args:
config: Configuration object for the model.
"""
super().__init__(config)
self.model = LlamaDifferentialModel(config)
@classmethod
# pylint: disable=protected-access
def _autoset_attn_implementation(
cls,
config: LlamaDifferentialConfig,
**kwargs, # pylint: disable=unused-argument
) -> LlamaDifferentialConfig:
"""
Automatically set the attention implementation based on config.
Args:
config: Model configuration object.
**kwargs: Additional arguments (unused).
Returns:
Updated configuration object.
Raises:
ValueError: If specified attention implementation is not supported.
"""
config._attn_implementation_autoset = True
attn_implementation = getattr(config, "_attn_implementation", None)
# Map standard types to differential types if mapping exists
if attn_implementation in config._attn_implementations:
config._attn_implementation = config._attn_implementations[
attn_implementation
]
return config
# If no mapping, validate it's a valid differential type
valid_impls = [
None,
"differential_eager",
"differential_sdpa",
"differential_flash_attention_2",
]
if attn_implementation not in valid_impls:
message = (
f"Specified `attn_implementation={attn_implementation}` is not supported. "
f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
)
raise ValueError(message)
return config
@classmethod
def from_llama(
cls, model: LlamaForCausalLM, config: LlamaDifferentialConfig | None = None
) -> "LlamaDifferentialForCausalLM":
"""
Convert a `LlamaForCausalLM` to use differential attention.
Args:
model: Base LLaMA model to convert.
config: Configuration for differential attention. If `None`, created from
base model config.
Returns:
Converted model with differential attention.
Raises:
ValueError: If number of heads is not even when using `split_heads` mode.
"""
if config is None:
config = LlamaDifferentialConfig(**model.config.__dict__)
# Validate head counts if using split heads mode
if config.split_heads:
if config.num_attention_heads % 2 != 0:
raise ValueError(
f"Number of attention heads ({config.num_attention_heads}) must be even "
"when using split_heads=True"
)
if config.num_key_value_heads % 2 != 0:
raise ValueError(
f"Number of key/value heads ({config.num_key_value_heads}) must be even "
"when using split_heads=True"
)
new_model = cls(config)
new_model.model = LlamaDifferentialModel.from_llama(model.model, config)
new_model.lm_head.load_state_dict(model.lm_head.state_dict())
return new_model
def register_diff_attn() -> None:
"""
Register differential attention components with the transformers library.
This function registers the differential attention configurations and model classes
with the Auto* classes from `transformers`, making them available through the
standard model loading pipeline.
"""
# Register configs
AutoConfig.register("llama-differential", LlamaDifferentialConfig)
# Register models
AutoModel.register(LlamaDifferentialConfig, LlamaDifferentialModel)
AutoModelForCausalLM.register(LlamaDifferentialConfig, LlamaDifferentialForCausalLM)
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention
LLAMA_ATTENTION_CLASSES[
"differential_flash_attention_2"
] = LlamaDifferentialFlashAttention2

View File

@@ -22,13 +22,6 @@ import inspect
import logging
import sys
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from axolotl.integrations.base import BasePlugin
from ...utils.distributed import zero_only
@@ -46,6 +39,13 @@ class LigerPlugin(BasePlugin):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)

View File

@@ -2,10 +2,8 @@
import logging
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.rala.auto.llama.modeling_rala import LlamaRALAAttention
from axolotl.integrations.rala.auto.llama.modeling_rala import register_rala_model
LOG = logging.getLogger(__name__)
@@ -18,17 +16,6 @@ class RalaPlugin(BasePlugin):
def get_input_args(self):
return "axolotl.integrations.rala.args.RalaArgs"
def pre_model_load(self, cfg):
"""Apply differential attention patch before model loading if enabled."""
if cfg.rala_attention:
LLAMA_ATTENTION_CLASSES["rala"] = LlamaRALAAttention
from axolotl.monkeypatch.attention.differential import (
patch_llama_attention_classes,
)
patch_llama_attention_classes()
def set_attn_config(self, cfg, model_kwargs, model_config):
if cfg.rala_attention:
model_kwargs["attn_implementation"] = "rala"
def register(self):
LOG.info("Registering RALA model with AutoConfig & AutoModel")
register_rala_model()

View File

@@ -9,4 +9,5 @@ class LlamaRalaConfig(LlamaConfig):
Configuration for LlamaRala model
"""
softmax_every: int = 6 # every 8th layer applies softmax
model_type = "llama-rala"
softmax_every: int = 6 # every N-th layer applies softmax

View File

@@ -19,9 +19,17 @@ from typing import List, Optional, Tuple, Union, Unpack
import torch
import torch.nn.functional as F
from torch import nn
from transformers import Cache, GenerationMixin, LlamaModel
from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
Cache,
GenerationMixin,
LlamaModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import (
LLAMA_ATTENTION_CLASSES,
KwargsForCausalLM,
LlamaDynamicNTKScalingRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
@@ -325,7 +333,15 @@ class LlamaRalaDecoderLayer(nn.Module):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = LlamaRALAAttention(config=config, layer_idx=layer_idx)
if LlamaRalaDecoderLayer.is_layer_idx_softmax(
config.num_hidden_layers, layer_idx, config.softmax_every
):
self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](
config=config, layer_idx=layer_idx
)
# self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
else:
self.self_attn = LlamaRALAAttention(config=config, layer_idx=layer_idx)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -518,25 +534,18 @@ class LlamaRalaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
num_logits_to_keep (`int`, *optional*):
Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all
`input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
Returns:
Example:
```python
>>> from transformers import AutoTokenizer, LlamaForCausalLM
>>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
>>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
>>> prompt = "Hey, are you conscious? Can you talk to me?"
>>> inputs = tokenizer(prompt, return_tensors="pt")
>>> # Generate
>>> generate_ids = model.generate(inputs.input_ids, max_length=30)
>>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
@@ -595,3 +604,20 @@ class LlamaRalaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def register_rala_model() -> None:
"""
Register differential attention components with the transformers library.
This function registers the differential attention configurations and model classes
with the Auto* classes from `transformers`, making them available through the
standard model loading pipeline.
"""
# Register configs
AutoConfig.register("llama-rala", LlamaRalaConfig)
# Register models
AutoModel.register(LlamaRalaConfig, LlamaRalaModel)
AutoModelForCausalLM.register(LlamaRalaConfig, LlamaRalaForCausalLM)
LLAMA_ATTENTION_CLASSES["rala"] = LlamaRALAAttention

View File

@@ -7,8 +7,10 @@ from torch import nn
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import LlamaAttention
from axolotl.integrations.rala import LlamaRALAAttention
from axolotl.integrations.rala.auto.llama.modeling_rala import LlamaRalaDecoderLayer
from axolotl.integrations.rala.auto.llama.modeling_rala import (
LlamaRALAAttention,
LlamaRalaDecoderLayer,
)
logger = logging.getLogger(__name__)
@@ -69,7 +71,7 @@ def convert_to_rala(
layer_type = type(child).__name__
logger.info(
f"Converting attention layer {layer_idx}: {layer_type} to {attention_class.__name__}"
f"Converting attention layer {decoder_layer_idx}: {layer_type} to {attention_class.__name__}"
)
# Create new diff attn layer
@@ -95,10 +97,10 @@ def convert_to_rala(
model.config.architectures = [
"LlamaRalaForCausalLM",
]
model.config.model_type = "llama_rala"
model.config.auto_map = {
"AutoConfig": "llama.configuration_rala.LlamaRalaConfig",
"AutoModel": "llama.modeling_rala.LlamaRalaModel",
"AutoModelForCausalLM": "llama.modeling_rala.LlamaRalaForCausalLM",
}
model.config.model_type = "llama-rala"
# model.config.auto_map = {
# "AutoConfig": "llama.configuration_rala.LlamaRalaConfig",
# "AutoModel": "llama.modeling_rala.LlamaRalaModel",
# "AutoModelForCausalLM": "llama.modeling_rala.LlamaRalaForCausalLM",
# }
return model

View File

@@ -1,280 +0,0 @@
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from transformers import Cache
from transformers.models.llama.modeling_llama import (
LlamaDynamicNTKScalingRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding,
LlamaRotaryEmbedding,
apply_rotary_pos_emb,
repeat_kv,
)
def kappa(x: torch.Tensor) -> torch.Tensor:
"""
The paper uses κ(x) = ELU(x) + 1.
x is assumed to be [batch, n_heads, seq_len, head_dim].
"""
return F.elu(x) + 1
class LlamaRALAAttention(nn.Module):
"""
LlamaAttention replaced with Rank-Augmented Linear Attention (RALA).
Adapted from the standard LlamaAttention for demonstration.
**Not** a fully drop-in replacement if you need caching/TP.
"""
def __init__(self, config, layer_idx: Optional[int] = None):
super().__init__()
self.config = config
self.layer_idx = layer_idx
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
# Same Q, K, V, output projections
self.q_proj = nn.Linear(
self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias
)
self.k_proj = nn.Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=config.attention_bias,
)
self.v_proj = nn.Linear(
self.hidden_size,
self.num_key_value_heads * self.head_dim,
bias=config.attention_bias,
)
self.o_proj = nn.Linear(
self.hidden_size, self.hidden_size, bias=config.attention_bias
)
# We will preserve rope usage
self._init_rope()
# A simple φ-projection for RALA:
# The paper uses φ(x) as a linear transform or identity. We'll do a linear:
self.phi = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
def _init_rope(self):
# Standard Llama rope logic
if self.config.rope_scaling is None:
self.rotary_emb = LlamaRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "dynamic":
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False, # pylint: disable=unused-argument
cache_position: Optional[torch.LongTensor] = None,
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
**kwargs, # pylint: disable=unused-argument
):
"""
RALA forward pass.
This version omits incremental decoding with `past_key_value` for simplicity
(linear attention caching is non-trivial).
"""
bsz, q_len, _ = hidden_states.size()
# Standard Q, K, V
query_states = self.q_proj(hidden_states) # [b, seq, n_heads*dim]
key_states = self.k_proj(hidden_states) # [b, seq, n_kv_heads*dim]
value_states = self.v_proj(hidden_states) # [b, seq, n_kv_heads*dim]
# Reshape to [b, n_heads, seq_len, head_dim]
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# Apply RoPE (rotary embeddings) just as in standard Llama
cos, sin = self.rotary_emb(value_states, position_ids)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin
)
# If you still want to handle the repeated KV for multi-group setups:
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
# Now we apply RALA.
# 1) Apply κ(.) to Q,K: shape [b, n_heads, seq_len, head_dim]
Q_kappa = kappa(query_states)
K_kappa = kappa(key_states)
# 2) Compute global query Q_g = average of Q_kappa across seq_len => [b, n_heads, head_dim]
# The paper denotes Q_g = (1/N) Σ_i Q_kappa_i
seq_len_float = float(q_len) # for scaling
Q_g = Q_kappa.mean(dim=2) # [b, n_heads, head_dim]
# 3) Compute alpha_j for each token j in [0..seq_len-1]
# alpha_j = N * softmax( Q_g · K_kappa_j^T ), shape => [b, n_heads, seq_len]
# Dot product over head_dim
# K_kappa is [b, n_heads, seq_len, head_dim], Q_g is [b, n_heads, head_dim]
# We'll do an einsum or transpose to produce logits [b, n_heads, seq_len]
# Dot product across the last dimension (d_head), resulting in shape [b, n_heads, seq_len]
# logits = torch.einsum("bnh, bnsh -> bns", Q_g, K_kappa) # [b, n_heads, seq_len]
logits = (Q_g.unsqueeze(2) * K_kappa).sum(
dim=-1
) # -> [b, n_heads, seq_len] # identical to above but torch.compile should work
# 4) Incorporate causal or padding mask if provided.
# In standard Llama, attention_mask is broadcast as [b, 1, seq_len, seq_len] or similar.
# For RALA, we only do a single softmax over "j" dimension. We can add the mask to logits.
# Caution: This might not replicate strict causal linear attention. It's a best-effort approach.
if attention_mask is not None:
# Usually Llama's causal mask is [b, 1, q_len, kv_len] with 0 or -inf
# We want shape [b, n_heads, seq_len], so we can broadcast accordingly:
# e.g., attention_mask: [b, 1, q_len, seq_len]
# We pick the slice that corresponds to q_len vs. kv_len.
# Typically the last two dims are (q_len, kv_len). We want the kv_len dimension to be `seq_len`.
# We'll do something like:
if attention_mask.dim() == 4:
# attention_mask: [b, 1, q_len, kv_len]
# if q_len == kv_len, we can do attention_mask[:, :, :, :seq_len], then squeeze dims
mask_2d = attention_mask[:, 0, :, :q_len] # [b, q_len, seq_len]
# we only want [b, n_heads, seq_len], so we must broadcast over q_len if needed
# but in this snippet, we do a single alpha_j for each j *per head*,
# ignoring per-token Q_i. So there's a mismatch.
# A simpler approach is to apply the mask for the entire sequence if a token j is invalid for ANY i.
# That is approximate. We'll just pick the first row of q_len, or do min across i dimension...
# For demonstration, let's sum or min across i dimension to see if j is valid for ANY i.
# Or we do a "causal" approach: all tokens j>i get masked. But there's no direct i index here in alpha_j.
# We'll just do a rough approach, e.g. mask = min across the q_len dimension:
mask_1d = torch.min(mask_2d, dim=1)[
0
] # [b, seq_len], picking the worst mask across query positions
# broadcast for n_heads
mask_1d = mask_1d.unsqueeze(1).expand(
-1, self.num_heads, -1
) # [b, n_heads, seq_len]
logits = logits + mask_1d
else:
# Possibly it's [b, seq_len]. Then we just broadcast to [b,n_heads,seq_len].
mask_1d = attention_mask # [b, seq_len]
mask_1d = mask_1d.unsqueeze(1).expand(-1, self.num_heads, -1)
logits = logits + mask_1d
alpha = F.softmax(logits, dim=-1) # [b, n_heads, seq_len]
# multiply by seq_len per the formula
alpha = alpha * seq_len_float
# 5) Construct the outer-sum: Σ_j alpha_j * (K_kappa_j^T V_j)
# The paper shows a d×d matrix formed per head.
# K_kappa: [b, n_heads, seq_len, head_dim], V: [b, n_heads, seq_len, head_dim]
# For each j, do outer product K_kappa_j (d×1) × V_j^T (1×d) => d×d
# Then multiply by alpha_j and sum over j.
# We'll do an einsum for that: [b,n_heads,seq_len,d] outer [b,n_heads,seq_len,d] => [b,n_heads,d,d]
# alpha: [b, n_heads, seq_len].
value_states_ = value_states # [b, n_heads, seq_len, head_dim]
outer_sum = torch.einsum("bns,bnsd,bnsf->bndf", alpha, K_kappa, value_states_)
# Explanation:
# - 'bnhs' is alpha (batch, n_heads, seq_len)
# - 'bnhsd' is K_kappa (b,n_heads,seq_len, d)
# - 'bnhsf' is V (b,n_heads,seq_len, d)
# We want [b,n_heads,d,f], which is the d×d matrix per head.
# Actually we need an outer product (K_kappa_j^T × V_j). That is [d, d].
# The call above is not quite correct if we want K_kappa_j^T × V_j as [d,d].
# Let's do a simpler approach:
# outer_sum = sum_j alpha_j * (K_kappa_j^T outer V_j).
# = "bnhs,bnhsd,bnhsf -> bnhdf"
# means: alpha has shape (b,n,h,s), K_kappa has shape (b,n,h,s,d), V has shape (b,n,h,s,d)
# We want to produce (b,n,h,d,d).
# So the correct einsum string is 'bnhs,bnhsd,bnhsf->bnhdf':
# alpha indexes b,n,h,s
# K_kappa indexes b,n,h,s,d => K_kappa_j
# V indexes b,n,h,s,f => V_j
# The resulting shape is (b,n,h,d,f). Great.
# 6) For each token i, Y_i = φ(X_i) ∘ [ κ(Q_i) × outer_sum ]
# Here κ(Q_i) is shape [b,n,h,d], outer_sum is shape [b,n,h,d,d].
# We'll do a batch matmul: result_attn = Q_kappa_i × outer_sum => [b,n,h,d]
# Then multiply elementwise by φ(X_i).
# But φ(X_i) is a single [b,seq_len,d_model], so we reshape to [b,seq_len,n,h_dim].
# We'll do per-token i in a loop or broadcast. Let's do it in a single operation with einsum:
# first, compute φ(X):
# X is the original hidden_states: [b, seq_len, d_model]
X_phi = self.phi(hidden_states) # [b, seq_len, d_model]
X_phi = X_phi.view(bsz, q_len, self.num_heads, self.head_dim) # [b, s, n, d]
X_phi = X_phi.transpose(1, 2) # [b, n, s, d]
# Now for each i in [0..q_len-1], we do a matrix multiply:
# result_attn_i = Q_kappa_i [b,n,s,d] × outer_sum [b,n,d,d] => we want [b,n,s,d].
# We'll do:
result_attn = torch.einsum("bnsd,bndf->bnsf", Q_kappa, outer_sum) # [b,n,s,d]
# Then elementwise multiply by φ(X_i):
context_layer = X_phi * result_attn # [b,n,s,d]
# Finally, reorder to [b, s, n, d] -> [b, s, n*d]
context_layer = context_layer.transpose(1, 2).contiguous() # [b, s, n, d]
context_layer = context_layer.view(bsz, q_len, self.hidden_size)
# One last linear projection:
attn_output = self.o_proj(context_layer)
# Not returning a standard attn_weights.
# If you want to return alpha as "attention," we can do so:
if output_attentions:
# alpha: [b, n_heads, seq_len], but note it's only the "global" weighting of each key,
# not a (q_len x kv_len) map like standard attention.
attn_weights = alpha
else:
attn_weights = None
# We omit cache / past_key_value returns to keep it simpler.
return attn_output, attn_weights, None

View File

@@ -1,49 +0,0 @@
"""Patches related to differential transformers implementation."""
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
from axolotl.integrations.diff_transformer.diff_attn import (
LlamaDifferentialAttention,
LlamaDifferentialFlashAttention2,
LlamaDifferentialSdpaAttention,
)
def patch_llama_attention_classes():
"""Patch transformers to support differential attention"""
# Add our attention class to the registry
LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention
LLAMA_ATTENTION_CLASSES[
"differential_flash_attention_2"
] = LlamaDifferentialFlashAttention2
@classmethod
def new_autoset(_, config, **kwargs): # pylint: disable=unused-argument
config._attn_implementation_autoset = True # pylint: disable=protected-access
attn_implementation = getattr(config, "_attn_implementation", None)
valid_impls = [
None,
"eager",
"sdpa",
"flash_attention_2",
"differential_eager",
"differential_sdpa",
"differential_flash_attention_2",
"rala",
]
if attn_implementation not in valid_impls:
message = (
f"Specified `attn_implementation={attn_implementation}` is not supported. "
f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
)
raise ValueError(message + ".")
return config
# Apply patch
PreTrainedModel._autoset_attn_implementation = ( # pylint: disable=protected-access
new_autoset
)

View File

@@ -6,7 +6,7 @@ import logging
from transformers import Trainer
from axolotl.monkeypatch.unsloth_ import detab_code
from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")

View File

@@ -8,7 +8,7 @@ import logging
from transformers import LlamaForCausalLM, Trainer
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from axolotl.monkeypatch.unsloth_ import detab_code
from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")

View File

@@ -1,9 +1,7 @@
"""module for patching with unsloth optimizations"""
import inspect
import re
import types
from typing import Tuple
import torch
from accelerate.logging import get_logger
@@ -11,6 +9,8 @@ from peft import PeftModelForCausalLM
from torch import nn
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
from axolotl.monkeypatch.utils import detab_code
LOG = get_logger("axolotl.monkeypatch.unsloth")
ORIGINAL_QKV_CODE = """
@@ -93,15 +93,6 @@ def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
raise ValueError("Unsupported model type")
def detab_code(code: str) -> Tuple[str, str]:
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
return code, spaces
self_attn_lora_patched = False # pylint: disable=invalid-name

View File

@@ -1,7 +1,8 @@
"""
Shared utils for the monkeypatches
"""
from typing import Optional
import re
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
@@ -223,3 +224,12 @@ def patched_prepare_4d_causal_attention_mask_for_sdpa(
mask_2d_to_4d(attention_mask, dtype=dtype),
*args,
)
def detab_code(code: str) -> Tuple[str, str]:
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
return code, spaces

View File

@@ -1,5 +1,6 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import inspect
import os
import signal
import sys
@@ -126,7 +127,20 @@ def train(
)
if cfg.fix_untrained_tokens:
fix_untrained_tokens(model, tokenizer, train_dataset)
# check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens)
# if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list
):
fix_untrained_tokens(
model,
tokenizer,
train_dataset,
token_ids_to_fix=cfg.fix_untrained_tokens,
)
else:
fix_untrained_tokens(model, tokenizer, train_dataset)
if cfg.local_rank == 0:
model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
import gc
import logging
import math
import os
@@ -842,3 +843,17 @@ class SaveModelCallback(TrainerCallback):
):
control.should_save = True
return control
class GCCallback(TrainerCallback):
"""Callback to garbage collect torch cache"""
def __init__(self, gc_steps=None):
self.gc_steps = gc_steps
def on_step_end(
self, args, state, control, **kwargs # pylint: disable=unused-argument
):
if state.global_step % self.gc_steps == 0:
torch.cuda.empty_cache()
gc.collect()

View File

@@ -0,0 +1,234 @@
"""
Monitor and log differential attention components during training.
This module provides a callback for tracking the behavior of differential attention
mechanisms, including lambda parameters and attention statistics.
"""
from typing import Any
import torch
import wandb
from torch import nn
from transformers import TrainerCallback
from axolotl.utils.distributed import is_main_process
class DifferentialAttentionMonitorCallback(TrainerCallback):
"""
Callback to monitor differential attention components and lambda parameters.
This callback tracks attention statistics across all layers and provides detailed
monitoring for a specified number of layers evenly spaced through the model.
"""
def __init__(
self,
log_every: int = 250,
num_monitor_layers: int = 3,
warmup_steps: int | None = None,
):
"""
Initialize the differential attention monitor.
Args:
log_every: Number of steps between logging events.
num_monitor_layers: Number of individual layers to monitor in detail.
warmup_steps: Optional parameter for negative attention component warmup.
"""
self.log_every = log_every
self.num_monitor_layers = num_monitor_layers
self.warmup_steps = warmup_steps
self.monitor_layers: list[int] | None = None # Will be set in on_train_begin
# pylint: disable=unused-argument
def on_train_begin(
self,
args: Any,
state: Any,
control: Any,
model: torch.nn.Module,
**kwargs,
) -> None:
"""
Set up layer monitoring at the start of training.
Args:
args: Training arguments.
state: Training state.
control: Training control object.
model: The model being trained.
**kwargs: Additional arguments passed by the trainer.
"""
if is_main_process():
num_layers = len(model.model.layers)
self.num_monitor_layers = min(self.num_monitor_layers, num_layers)
stride = (
(num_layers - 1) / (self.num_monitor_layers - 1)
if self.num_monitor_layers > 1
else 0
)
self.monitor_layers = [
round(i * stride) for i in range(self.num_monitor_layers)
]
print(f"Monitoring layers {self.monitor_layers} in detail")
# pylint: disable=unused-argument
def on_step_end(
self, args: Any, state: Any, control: Any, model: torch.nn.Module, **kwargs
) -> None:
"""
Log attention metrics at the end of each step.
Collects and logs:
- Lambda parameter norms and values.
- Attention statistics (mean and std).
- Both per-layer and aggregate metrics.
Args:
args: Training arguments.
state: Training state.
control: Training control object.
model: The model being trained.
**kwargs: Additional arguments passed by the trainer.
"""
if not is_main_process() or state.global_step % self.log_every != 0:
return
assert self.monitor_layers is not None
# Aggregate stats across all layers
all_q1_norms = []
all_q2_norms = []
all_k1_norms = []
all_k2_norms = []
all_lambda1 = []
all_lambda2 = []
all_lambda_full = []
metrics = {}
for layer_idx, layer in enumerate(model.model.layers):
attn = layer.self_attn
# Collect stats for aggregation
all_q1_norms.append(attn.lambda_q1.norm().item())
all_q2_norms.append(attn.lambda_q2.norm().item())
all_k1_norms.append(attn.lambda_k1.norm().item())
all_k2_norms.append(attn.lambda_k2.norm().item())
lambda1 = torch.exp(torch.sum(attn.lambda_q1 * attn.lambda_k1)).item()
lambda2 = torch.exp(torch.sum(attn.lambda_q2 * attn.lambda_k2)).item()
all_lambda1.append(lambda1)
all_lambda2.append(lambda2)
all_lambda_full.append(attn.lambda_full)
# Log detailed metrics for monitored layers
if layer_idx in self.monitor_layers:
metrics.update(
{
f"layer_{layer_idx}/lambda_q1_norm": attn.lambda_q1.norm().item(),
f"layer_{layer_idx}/lambda_k1_norm": attn.lambda_k1.norm().item(),
f"layer_{layer_idx}/lambda_q2_norm": attn.lambda_q2.norm().item(),
f"layer_{layer_idx}/lambda_k2_norm": attn.lambda_k2.norm().item(),
f"layer_{layer_idx}/lambda1": lambda1,
f"layer_{layer_idx}/lambda2": lambda2,
f"layer_{layer_idx}/lambda_init": attn.lambda_init.item(),
f"layer_{layer_idx}/lambda_full": lambda1
- lambda2
+ attn.lambda_init.item(),
f"layer_{layer_idx}/attn1_mean": attn.attn1.mean().item(),
f"layer_{layer_idx}/attn2_mean": attn.attn2.mean().item(),
f"layer_{layer_idx}/attn1_std": attn.attn1.std().item(),
f"layer_{layer_idx}/attn2_std": attn.attn2.std().item(),
}
)
# Add aggregate metrics
metrics.update(
{
"aggregate/lambda_q1_norm_mean": torch.tensor(all_q1_norms)
.mean()
.item(),
"aggregate/lambda_q1_norm_std": torch.tensor(all_q1_norms).std().item(),
"aggregate/lambda_q2_norm_mean": torch.tensor(all_q2_norms)
.mean()
.item(),
"aggregate/lambda_q2_norm_std": torch.tensor(all_q2_norms).std().item(),
"aggregate/lambda_k1_norm_mean": torch.tensor(all_k1_norms)
.mean()
.item(),
"aggregate/lambda_k1_norm_std": torch.tensor(all_k1_norms).std().item(),
"aggregate/lambda_k2_norm_mean": torch.tensor(all_k2_norms)
.mean()
.item(),
"aggregate/lambda_k2_norm_std": torch.tensor(all_k2_norms).std().item(),
"aggregate/lambda1_mean": torch.tensor(all_lambda1).mean().item(),
"aggregate/lambda1_std": torch.tensor(all_lambda1).std().item(),
"aggregate/lambda2_mean": torch.tensor(all_lambda2).mean().item(),
"aggregate/lambda2_std": torch.tensor(all_lambda2).std().item(),
"aggregate/lambda_full_mean": torch.tensor(all_lambda_full)
.mean()
.item(),
"aggregate/lambda_full_std": torch.tensor(all_lambda_full).std().item(),
}
)
if self.warmup_steps:
metrics["aggregate/diff_attn_mix"] = attn.diff_attn_mix
wandb.log(metrics, step=state.global_step)
class DifferentialAttentionMixingCallback(TrainerCallback):
"""
Callback to gradually increase the weight of negative attention components during
training.
"""
def __init__(self, warmup_steps: int):
"""
Args:
warmup_steps: Number of steps to linearly increase negative attention
weight from 0 to 1. If `None`, negative attention has full weight from
start.
"""
self.warmup_steps = warmup_steps
self.diff_attention_layers: list[nn.Module] | None = None
# pylint: disable=unused-argument
def on_train_begin(
self,
args: Any,
state: Any,
control: Any,
model: torch.nn.Module,
**kwargs,
) -> None:
"""Cache the differential attention layers at the start of training."""
if model is not None:
# Get the actual model if it's wrapped
if hasattr(model, "module"):
model = model.module
# Cache all differential attention layers
self.diff_attention_layers = [
module for module in model.modules() if hasattr(module, "diff_attn_mix")
]
def on_step_begin(
self,
args: Any,
state: Any,
control: Any,
model: torch.nn.Module = None,
**kwargs,
) -> None:
if self.diff_attention_layers and self.warmup_steps:
# Calculate mixing parameter (0 to 1)
mix = min(1.0, state.global_step / self.warmup_steps)
# Update cached layers
for layer in self.diff_attention_layers:
layer.diff_attn_mix = mix

View File

@@ -43,7 +43,7 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
getattr, self.layers_attribute.split("."), self.trainer.model
)
LOG.info(
f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps"
f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers * 100 / len(layers)}%) every {self.step_interval} steps"
)
def freeze_all_layers(self):

View File

@@ -128,6 +128,7 @@ class PretrainingDataset(BaseModel):
text_column: Optional[str] = "text"
type: Optional[str] = "pretrain"
trust_remote_code: Optional[bool] = False
data_files: Optional[str] = None
class UserDefinedPrompterType(BaseModel):
@@ -666,6 +667,8 @@ class AxolotlInputConfig(
loss_watchdog_threshold: Optional[float] = None
loss_watchdog_patience: Optional[int] = None
gc_steps: Optional[int] = None
bf16: Optional[Union[Literal["auto"], bool]] = "auto"
fp16: Optional[bool] = None
bfloat16: Optional[bool] = None # for non-AMP cases
@@ -695,6 +698,12 @@ class AxolotlInputConfig(
pad_to_sequence_len: Optional[bool] = None
curriculum_sampling: Optional[bool] = None
multipack_real_batches: Optional[bool] = None
pretraining_sample_concatenation: Optional[bool] = Field(
default=None,
json_schema_extra={
"description": "whether to soft pack/concatenate samples during pretraining",
},
)
batch_flattening: Optional[Union[Literal["auto"], bool]] = None
@@ -792,7 +801,7 @@ class AxolotlInputConfig(
chat_template_jinja: Optional[str] = None
default_system_message: Optional[str] = None
fix_untrained_tokens: Optional[bool] = None
fix_untrained_tokens: Optional[Union[int, List[int]]] = None
# INTERNALS - document for now, generally not set externally
is_preprocess: Optional[bool] = None

View File

@@ -18,7 +18,10 @@ LOG = logging.getLogger("axolotl")
def encode_pretraining(
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
tokenizer: PreTrainedTokenizerBase,
max_tokens: int,
examples: Dict[str, List],
concatenate: bool = True,
) -> Dict[str, List]:
res = tokenizer(
examples["text"],
@@ -28,8 +31,17 @@ def encode_pretraining(
)
# Convert to PyTorch tensors
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
targets = [torch.tensor(seq) for seq in res["input_ids"]]
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
if not concatenate:
return {
"input_ids": [seq.tolist() for seq in input_ids],
"labels": [seq.tolist() for seq in targets],
"attention_mask": [seq.tolist() for seq in attention_mask],
}
new_input_ids = []
new_labels = []
new_attention_mask = []
# Append EOS and PAD tokens to input_ids, and correct attention_mask
for i, _ in enumerate(input_ids):
@@ -40,22 +52,34 @@ def encode_pretraining(
),
dim=0,
)
targets[i] = torch.cat(
(
targets[i],
torch.tensor([tokenizer.eos_token_id, -100]),
),
dim=0,
)
attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
# Concatenate tokens so that their lengths are less than max_tokens
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_labels = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
for ids, mask in zip(input_ids, attention_mask):
for ids, labels, mask in zip(input_ids, targets, attention_mask):
if buffer_input_ids.numel() == max_tokens:
new_input_ids.append(buffer_input_ids)
new_labels.append(buffer_labels)
new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_labels = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
else:
buffer_input_ids = torch.cat(
@@ -69,6 +93,17 @@ def encode_pretraining(
),
dim=0,
)
buffer_labels = torch.cat(
(
buffer_labels,
torch.full(
(max_tokens - buffer_labels.numel(),),
-100,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat(
(
buffer_attention_mask,
@@ -81,11 +116,14 @@ def encode_pretraining(
dim=0,
)
new_input_ids.append(buffer_input_ids)
new_labels.append(buffer_labels)
new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_labels = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
if buffer_input_ids.numel() > 0: # for any leftover tokens
@@ -101,6 +139,17 @@ def encode_pretraining(
),
dim=0,
)
buffer_labels = torch.cat(
(
buffer_labels,
torch.full(
(max_tokens - buffer_labels.numel(),),
-100,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat(
(
buffer_attention_mask,
@@ -113,11 +162,12 @@ def encode_pretraining(
dim=0,
)
new_input_ids.append(buffer_input_ids)
new_labels.append(buffer_labels)
new_attention_mask.append(buffer_attention_mask)
ret = {
"input_ids": [seq.tolist() for seq in new_input_ids],
"labels": [seq.tolist() for seq in new_input_ids],
"labels": [seq.tolist() for seq in new_labels],
"attention_mask": [seq.tolist() for seq in new_attention_mask],
}
@@ -155,6 +205,10 @@ def wrap_pretraining_dataset(
)
# set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1
elif cfg.pretraining_sample_concatenation is False:
encode = functools.partial(
encode_pretraining, tokenizer, max_tokens, concatenate=False
)
else:
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)

View File

@@ -3,7 +3,7 @@
import functools
import logging
from pathlib import Path
from typing import List, Optional, Tuple, Union
from typing import List, Tuple, Union
from datasets import (
Dataset,
@@ -12,8 +12,6 @@ from datasets import (
load_dataset,
load_from_disk,
)
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HFValidationError
from transformers import PreTrainedTokenizerBase
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
@@ -42,6 +40,7 @@ from axolotl.prompters import (
UnsupportedPrompter,
)
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
from axolotl.utils.data.shared import load_dataset_w_config
from axolotl.utils.data.utils import (
deduplicate_and_log_datasets,
md5,
@@ -85,9 +84,11 @@ def prepare_dataset(cfg, tokenizer, processor=None):
processor=processor,
)
else:
# Load streaming dataset if pretraining_dataset is given
path = cfg.pretraining_dataset
split = "train"
name = None
data_files = None
if isinstance(cfg.pretraining_dataset, list) and isinstance(
cfg.pretraining_dataset[0], dict
):
@@ -96,6 +97,8 @@ def prepare_dataset(cfg, tokenizer, processor=None):
if "split" in cfg.pretraining_dataset[0]:
split = cfg.pretraining_dataset[0]["split"]
data_files = cfg.pretraining_dataset[0].get("data_files")
ds_wrapper_partial = functools.partial(
get_dataset_wrapper,
cfg.pretraining_dataset[0],
@@ -105,7 +108,9 @@ def prepare_dataset(cfg, tokenizer, processor=None):
)
train_dataset = wrap_pretraining_dataset(
load_dataset(path, streaming=True, split=split, name=name),
load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
),
tokenizer,
cfg,
ds_wrapper_partial,
@@ -116,7 +121,18 @@ def prepare_dataset(cfg, tokenizer, processor=None):
)
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch")
# Load eval dataset (non-streaming) if specified
eval_dataset = None
if cfg.test_datasets:
_, eval_dataset, _ = load_prepare_datasets(
tokenizer,
cfg,
DEFAULT_DATASET_PREPARED_PATH,
split="test",
processor=processor,
)
if cfg.dataset_exact_deduplication:
LOG.info("Deduplication not available for pretrained datasets")
@@ -243,195 +259,9 @@ def load_tokenized_prepared_datasets(
# pylint: disable=invalid-name
for config_dataset in for_d_in_datasets(cfg_datasets):
ds: Optional[Union[Dataset, DatasetDict]] = None
ds_from_hub = False
ds_trust_remote_code = config_dataset.trust_remote_code
try:
# this is just a basic check to see if the path is a
# valid HF dataset that's loadable
load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=True,
token=use_auth_token,
revision=config_dataset.revision,
trust_remote_code=ds_trust_remote_code,
)
ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
pass
ds_from_cloud = False
storage_options = {}
remote_file_system = None
if config_dataset.path.startswith("s3://"):
try:
import aiobotocore.session # type: ignore
import s3fs # type: ignore
except ImportError as exc:
raise ImportError(
"s3:// paths require aiobotocore and s3fs to be installed"
) from exc
# Takes credentials from ~/.aws/credentials for default profile
s3_session = aiobotocore.session.AioSession(profile="default")
storage_options = {"session": s3_session}
remote_file_system = s3fs.S3FileSystem(**storage_options)
elif config_dataset.path.startswith(
"gs://"
) or config_dataset.path.startswith("gcs://"):
try:
import gcsfs # type: ignore
except ImportError as exc:
raise ImportError(
"gs:// or gcs:// paths require gcsfs to be installed"
) from exc
# gcsfs will use default credentials from the environment else anon
# https://gcsfs.readthedocs.io/en/latest/#credentials
storage_options = {"token": None}
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
# TODO: Figure out how to get auth creds passed
# elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
# try:
# import adlfs
# except ImportError as exc:
# raise ImportError(
# "adl:// or abfs:// paths require adlfs to be installed"
# ) from exc
# # Gen 1
# storage_options = {
# "tenant_id": TENANT_ID,
# "client_id": CLIENT_ID,
# "client_secret": CLIENT_SECRET,
# }
# # Gen 2
# storage_options = {
# "account_name": ACCOUNT_NAME,
# "account_key": ACCOUNT_KEY,
# }
# remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
try:
if remote_file_system and remote_file_system.exists(
config_dataset.path
):
ds_from_cloud = True
except (FileNotFoundError, ConnectionError):
pass
# prefer local dataset, even if hub exists
local_path = Path(config_dataset.path)
if local_path.exists():
if local_path.is_dir():
if config_dataset.data_files:
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.data_files,
streaming=False,
split=None,
)
else:
try:
ds = load_from_disk(config_dataset.path)
except FileNotFoundError:
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=False,
split=None,
)
elif local_path.is_file():
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
)
else:
raise ValueError(
"unhandled dataset load: local path exists, but is neither a directory or a file"
)
elif ds_from_hub:
load_ds_kwargs = {}
if config_dataset.split:
load_ds_kwargs["split"] = config_dataset.split
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=False,
data_files=config_dataset.data_files,
token=use_auth_token,
revision=config_dataset.revision,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
elif ds_from_cloud and remote_file_system:
if remote_file_system.isdir(config_dataset.path):
ds = load_from_disk(
config_dataset.path,
storage_options=storage_options,
)
elif remote_file_system.isfile(config_dataset.path):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
)
elif config_dataset.path.startswith("https://"):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
)
else:
if isinstance(config_dataset.data_files, str):
fp = hf_hub_download(
repo_id=config_dataset.path,
repo_type="dataset",
filename=config_dataset.data_files,
revision=config_dataset.revision,
)
elif isinstance(config_dataset.data_files, list):
fp = []
for file in config_dataset.data_files:
fp.append(
hf_hub_download(
repo_id=config_dataset.path,
repo_type="dataset",
filename=file,
revision=config_dataset.revision,
)
)
else:
raise ValueError(
"data_files must be either a string or list of strings"
)
ds = load_dataset(
"json",
name=config_dataset.name,
data_files=fp,
streaming=False,
split=None,
)
if not ds:
raise ValueError("unhandled dataset load")
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
config_dataset, use_auth_token
)
d_base_type = d_prompt_style = None
d_type = config_dataset.type
@@ -501,24 +331,6 @@ def load_tokenized_prepared_datasets(
return dataset, prompters
def get_ds_type(config_dataset: DictDefault):
"""
Get the dataset type from the path if it's not specified
"""
ds_type = "json"
if config_dataset.ds_type:
ds_type = config_dataset.ds_type
elif ".parquet" in config_dataset.path:
ds_type = "parquet"
elif ".arrow" in config_dataset.path:
ds_type = "arrow"
elif ".csv" in config_dataset.path:
ds_type = "csv"
elif ".txt" in config_dataset.path:
ds_type = "text"
return ds_type
def load_prepare_datasets(
tokenizer: PreTrainedTokenizerBase,
cfg,

View File

@@ -0,0 +1,222 @@
"""
dataset loading shared utils
"""
from pathlib import Path
from typing import Optional, Union
from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
from huggingface_hub import hf_hub_download
from huggingface_hub.errors import HFValidationError
from axolotl.utils.dict import DictDefault
def get_ds_type(config_dataset: DictDefault):
"""
Get the dataset type from the path if it's not specified
"""
ds_type = "json"
if config_dataset.ds_type:
ds_type = config_dataset.ds_type
elif ".parquet" in config_dataset.path:
ds_type = "parquet"
elif ".arrow" in config_dataset.path:
ds_type = "arrow"
elif ".csv" in config_dataset.path:
ds_type = "csv"
elif ".txt" in config_dataset.path:
ds_type = "text"
return ds_type
def load_dataset_w_config(config_dataset, auth_token):
# pylint: disable=invalid-name
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
ds_from_hub = False
ds_trust_remote_code = config_dataset.trust_remote_code
try:
# this is just a basic check to see if the path is a
# valid HF dataset that's loadable
load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=True,
token=auth_token,
revision=config_dataset.revision,
trust_remote_code=ds_trust_remote_code,
)
ds_from_hub = True
except (FileNotFoundError, ConnectionError, HFValidationError, ValueError):
pass
ds_from_cloud = False
storage_options = {}
remote_file_system = None
if config_dataset.path.startswith("s3://"):
try:
import aiobotocore.session # type: ignore
import s3fs # type: ignore
except ImportError as exc:
raise ImportError(
"s3:// paths require aiobotocore and s3fs to be installed"
) from exc
# Takes credentials from ~/.aws/credentials for default profile
s3_session = aiobotocore.session.AioSession(profile="default")
storage_options = {"session": s3_session}
remote_file_system = s3fs.S3FileSystem(**storage_options)
elif config_dataset.path.startswith("gs://") or config_dataset.path.startswith(
"gcs://"
):
try:
import gcsfs # type: ignore
except ImportError as exc:
raise ImportError(
"gs:// or gcs:// paths require gcsfs to be installed"
) from exc
# gcsfs will use default credentials from the environment else anon
# https://gcsfs.readthedocs.io/en/latest/#credentials
storage_options = {"token": None}
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
# TODO: Figure out how to get auth creds passed
# elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
# try:
# import adlfs
# except ImportError as exc:
# raise ImportError(
# "adl:// or abfs:// paths require adlfs to be installed"
# ) from exc
# # Gen 1
# storage_options = {
# "tenant_id": TENANT_ID,
# "client_id": CLIENT_ID,
# "client_secret": CLIENT_SECRET,
# }
# # Gen 2
# storage_options = {
# "account_name": ACCOUNT_NAME,
# "account_key": ACCOUNT_KEY,
# }
# remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
try:
if remote_file_system and remote_file_system.exists(config_dataset.path):
ds_from_cloud = True
except (FileNotFoundError, ConnectionError):
pass
# prefer local dataset, even if hub exists
local_path = Path(config_dataset.path)
if local_path.exists():
if local_path.is_dir():
if config_dataset.data_files:
ds_type = get_ds_type(config_dataset)
ds = load_dataset( # pylint: disable=invalid-name
ds_type,
name=config_dataset.name,
data_files=config_dataset.data_files,
streaming=False,
split=None,
)
else:
try:
ds = load_from_disk(
config_dataset.path
) # pylint: disable=invalid-name
except FileNotFoundError:
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=False,
split=None,
)
elif local_path.is_file():
ds_type = get_ds_type(config_dataset)
ds = load_dataset( # pylint: disable=invalid-name
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
)
else:
raise ValueError(
"unhandled dataset load: local path exists, but is neither a directory or a file"
)
elif ds_from_hub:
load_ds_kwargs = {}
if config_dataset.split:
load_ds_kwargs["split"] = config_dataset.split
ds = load_dataset(
config_dataset.path,
name=config_dataset.name,
streaming=False,
data_files=config_dataset.data_files,
token=auth_token,
revision=config_dataset.revision,
trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
)
elif ds_from_cloud and remote_file_system:
if remote_file_system.isdir(config_dataset.path):
ds = load_from_disk(
config_dataset.path,
storage_options=storage_options,
)
elif remote_file_system.isfile(config_dataset.path):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
)
elif config_dataset.path.startswith("https://"):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code,
)
else:
if isinstance(config_dataset.data_files, str):
fp = hf_hub_download(
repo_id=config_dataset.path,
repo_type="dataset",
filename=config_dataset.data_files,
revision=config_dataset.revision,
)
elif isinstance(config_dataset.data_files, list):
fp = []
for file in config_dataset.data_files:
fp.append(
hf_hub_download(
repo_id=config_dataset.path,
repo_type="dataset",
filename=file,
revision=config_dataset.revision,
)
)
else:
raise ValueError("data_files must be either a string or list of strings")
ds = load_dataset(
"json",
name=config_dataset.name,
data_files=fp,
streaming=False,
split=None,
)
if not ds:
raise ValueError("unhandled dataset load")
return ds

View File

@@ -270,7 +270,7 @@ def load_sharded_model_quant(
model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config)
if cfg.local_rank == 0 and verbose:
print(f"Loaded model weights in {time.time()-start:.3f} seconds")
print(f"Loaded model weights in {time.time() - start:.3f} seconds")
# cleanup any extra memory usage from parallel loading
torch.cuda.empty_cache()

View File

@@ -713,7 +713,7 @@ class ModelLoader:
if not self.cfg.sample_packing and self.cfg.s2_attention:
pass
if self.cfg.differentiaion:
if self.cfg.diff_attention:
self.model_kwargs[
"attn_implementation"
] = "differential_flash_attention_2"

View File

@@ -196,7 +196,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask")
if cfg.model_config_type == "falcon":
if cfg.model_config_type in ["falcon", "mistral"]:
LOG.info("dropping token_type_ids column if it exists")
if "token_type_ids" in train_dataset.column_names:
train_dataset = train_dataset.remove_columns("token_type_ids")

View File

@@ -84,6 +84,11 @@ class OrderedDumper(yaml.SafeDumper):
"""Custom YAML dumper that maintains dictionary order."""
def represent_none(self, _):
"""Represent None values as empty fields."""
return self.represent_scalar("tag:yaml.org,2002:null", "")
def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any:
"""Custom representer for dictionaries that maintains order."""
return dumper.represent_mapping("tag:yaml.org,2002:map", data.items())
@@ -121,7 +126,8 @@ def dump_yaml_preserved_order(
# Reorder the data
ordered_data = reorder_dict(data, tracker.structure)
# Register the custom representer
# Register the custom representers
OrderedDumper.add_representer(type(None), represent_none)
OrderedDumper.add_representer(dict, ordered_dict_representer)
OrderedDumper.add_representer(OrderedDict, ordered_dict_representer)

View File

@@ -43,14 +43,12 @@ class BaseCliTest:
result = cli_runner.invoke(cli, [command, str(config_path)])
assert mock.called
assert mock.call_args.args[0] == [
assert mock.call_args.args[0][:5] == [
"accelerate",
"launch",
"-m",
f"axolotl.cli.{command}",
str(config_path),
"--debug-num-examples",
"0",
]
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0

View File

@@ -23,6 +23,7 @@ def test_build_command():
"--batch-size",
"8",
"--debug",
"--nouse-fp16",
]

View File

@@ -12,14 +12,12 @@ def test_shard_with_accelerate(cli_runner, config_path):
result = cli_runner.invoke(cli, ["shard", str(config_path), "--accelerate"])
assert mock.called
assert mock.call_args.args[0] == [
assert mock.call_args.args[0][:5] == [
"accelerate",
"launch",
"-m",
"axolotl.cli.shard",
str(config_path),
"--debug-num-examples",
"0",
]
assert mock.call_args.kwargs == {"check": True}
assert result.exit_code == 0

View File

@@ -120,13 +120,12 @@ def temp_dir():
@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
from transformers import Trainer
from transformers.models.llama.modeling_llama import (
from transformers.models.llama.modeling_llama import ( # LlamaFlashAttention2,
LlamaAttention,
LlamaFlashAttention2,
LlamaForCausalLM,
)
original_fa2_forward = LlamaFlashAttention2.forward
# original_fa2_forward = LlamaFlashAttention2.forward
original_llama_attn_forward = LlamaAttention.forward
original_llama_forward = LlamaForCausalLM.forward
original_trainer_inner_training_loop = (
@@ -136,7 +135,7 @@ def cleanup_monkeypatches():
# monkey patches can happen inside the tests
yield
# Reset LlamaFlashAttention2 forward
LlamaFlashAttention2.forward = original_fa2_forward
# LlamaFlashAttention2.forward = original_fa2_forward
LlamaAttention.forward = original_llama_attn_forward
LlamaForCausalLM.forward = original_llama_forward
Trainer._inner_training_loop = ( # pylint: disable=protected-access
@@ -149,7 +148,10 @@ def cleanup_monkeypatches():
("transformers.models.llama",),
(
"transformers.models.llama.modeling_llama",
["LlamaFlashAttention2", "LlamaAttention"],
[
# "LlamaFlashAttention2",
"LlamaAttention",
],
),
("transformers.trainer",),
("transformers", ["Trainer"]),

View File

@@ -4,7 +4,7 @@ import pytest
from click.testing import CliRunner
@pytest.fixture()
@pytest.fixture(scope="class")
def base_config():
"""Basic config for testing."""
return {
@@ -26,6 +26,6 @@ def base_config():
}
@pytest.fixture
@pytest.fixture(scope="class")
def cli_runner():
return CliRunner()

View File

@@ -130,6 +130,9 @@ def test_conversion_cli_repoduce_attentions(
)
def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
output_dir = tmp_path / "converted"
# Smallest model with an even number of attention heads
base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B"
base_config["output_dir"] = str(output_dir)
base_config[attention] = True

View File

@@ -1,43 +1,40 @@
"""
Simple end-to-end test for Liger integration
"""
import unittest
from pathlib import Path
from e2e.utils import require_torch_2_4_1
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
class LigerIntegrationTestCase(unittest.TestCase):
class LigerIntegrationTestCase:
"""
e2e tests for liger integration with Axolotl
"""
@with_temp_dir
@require_torch_2_4_1
def test_llama_wo_flce(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"plugins": [
"axolotl.integrations.liger.LigerPlugin",
],
"liger_rope": True,
"liger_rms_norm": True,
"liger_swiglu": True,
"liger_glu_activation": True,
"liger_cross_entropy": True,
"liger_fused_linear_cross_entropy": False,
"sequence_len": 1024,
"val_set_size": 0.1,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
@@ -46,15 +43,15 @@ class LigerIntegrationTestCase(unittest.TestCase):
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"max_steps": 10,
"max_steps": 5,
}
)
prepare_plugins(cfg)
@@ -65,26 +62,24 @@ class LigerIntegrationTestCase(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
@with_temp_dir
@require_torch_2_4_1
def test_llama_w_flce(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"plugins": [
"axolotl.integrations.liger.LigerPlugin",
],
"liger_rope": True,
"liger_rms_norm": True,
"liger_swiglu": True,
"liger_glu_activation": True,
"liger_cross_entropy": False,
"liger_fused_linear_cross_entropy": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{
@@ -93,15 +88,15 @@ class LigerIntegrationTestCase(unittest.TestCase):
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"max_steps": 10,
"max_steps": 5,
}
)
prepare_plugins(cfg)

View File

@@ -1,9 +1,14 @@
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
import unittest
import pytest
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
@pytest.mark.skip(
reason="Unsloth integration will be broken going into latest transformers"
)
class TestUnslothIntegration(unittest.TestCase):
"""Unsloth monkeypatch integration tests."""

View File

@@ -20,6 +20,9 @@ os.environ["WANDB_DISABLED"] = "true"
# pylint: disable=duplicate-code
@pytest.mark.skip(
reason="Unsloth integration will be broken going into latest transformers"
)
class TestUnslothQLoRA:
"""
Test class for Unsloth QLoRA Llama models

View File

@@ -113,6 +113,7 @@ class TestCustomOptimizers(unittest.TestCase):
@with_temp_dir
def test_fft_schedule_free_adamw(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -49,7 +49,19 @@ def require_torch_2_3_1(test_case):
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.3.1")
return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)
return unittest.skipUnless(is_min_2_3_1(), "test requires torch>=2.3.1")(test_case)
def require_torch_2_4_1(test_case):
"""
Decorator marking a test that requires torch >= 2.5.1
"""
def is_min_2_4_1():
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.4.1")
return unittest.skipUnless(is_min_2_4_1(), "test requires torch>=2.4.1")(test_case)
def require_torch_2_5_1(test_case):
@@ -61,7 +73,7 @@ def require_torch_2_5_1(test_case):
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.5.1")
return unittest.skipUnless(is_min_2_5_1(), "test torch 2.5.1")(test_case)
return unittest.skipUnless(is_min_2_5_1(), "test requires torch>=2.5.1")(test_case)
def is_hopper():

View File

@@ -7,11 +7,11 @@ from typing import Optional
import pytest
from axolotl.utils.config import validate_config
from axolotl.utils.config import prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="minimal_base_cfg")
@pytest.fixture(name="minimal_liger_cfg")
def fixture_cfg():
return DictDefault(
{
@@ -25,56 +25,57 @@ def fixture_cfg():
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"plugins": ["axolotl.integrations.liger.LigerPlugin"],
}
)
class BaseValidation:
# pylint: disable=too-many-public-methods
class TestValidation:
"""
Base validation module to setup the log capture
Test the validation module for liger
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
caplog.set_level(logging.WARNING)
self._caplog = caplog
# pylint: disable=too-many-public-methods
class TestValidation(BaseValidation):
"""
Test the validation module for liger
"""
def test_deprecated_swiglu(self, minimal_cfg):
def test_deprecated_swiglu(self, minimal_liger_cfg):
test_cfg = DictDefault(
{
"liger_swiglu": False,
}
| minimal_cfg
| minimal_liger_cfg
)
with self._caplog.at_level(logging.WARNING):
with self._caplog.at_level(
logging.WARNING, logger="axolotl.integrations.liger.args"
):
prepare_plugins(test_cfg)
updated_cfg = validate_config(test_cfg)
assert (
"The 'liger_swiglu' argument is deprecated"
in self._caplog.records[0].message
)
# TODO this test is brittle in CI
# assert (
# "The 'liger_swiglu' argument is deprecated"
# in self._caplog.records[0].message
# )
assert updated_cfg.liger_swiglu is None
assert updated_cfg.liger_glu_activations is False
assert updated_cfg.liger_glu_activation is False
def test_conflict_swiglu_ligergluactivation(self, minimal_cfg):
def test_conflict_swiglu_ligergluactivation(self, minimal_liger_cfg):
test_cfg = DictDefault(
{
"liger_swiglu": False,
"liger_glu_activations": True,
"liger_glu_activation": True,
}
| minimal_cfg
| minimal_liger_cfg
)
with pytest.raises(
ValueError,
match=r".*You cannot have both `liger_swiglu` and `liger_glu_activation` set.*",
):
prepare_plugins(test_cfg)
validate_config(test_cfg)

View File

@@ -4,9 +4,7 @@ import json
import logging
import unittest
from pathlib import Path
from typing import Optional
import pytest
from datasets import load_dataset
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
@@ -65,12 +63,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
Test class for prompt tokenization strategies.
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")