Compare commits

..

1 Commits

Author SHA1 Message Date
Wing Lian
1dd7f087b3 support for custom lr groups for non-embedding modules
invert name check for group modules
include lr_groups in training args
additional conditional for creating optimizer
fix regular params as w weight decay
fix lookup and add docs
2024-12-25 22:36:59 -05:00
62 changed files with 466 additions and 524 deletions

View File

@@ -1,7 +1,6 @@
name: lint
on:
# check on PRs, and manual triggers
merge_group:
pull_request:
paths:
- '**.py'

View File

@@ -25,6 +25,7 @@ jobs:
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras: mamba-ssm
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -35,7 +36,6 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -92,6 +92,7 @@ jobs:
python_version: "3.11"
pytorch: 2.3.1
axolotl_extras:
is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
@@ -102,7 +103,6 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout

View File

@@ -52,7 +52,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.71.8 jinja2
pip install modal==0.63.64 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -129,7 +129,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.71.8 jinja2
pip install modal==0.63.64 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -1,7 +1,6 @@
name: Tests
on:
# check on push/merge to main, PRs, and manual triggers
merge_group:
push:
branches:
- "main"
@@ -61,15 +60,6 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -110,15 +100,6 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
@@ -134,15 +115,6 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -184,15 +156,6 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
docker-e2e-tests-1st:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
@@ -220,7 +183,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.71.8 jinja2
pip install modal==0.63.64 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -266,7 +229,7 @@ jobs:
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.71.8 jinja2
pip install modal==0.63.64 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -23,7 +23,7 @@ repos:
hooks:
- id: flake8
- repo: https://github.com/PyCQA/pylint
rev: v3.3.0
rev: v2.17.4
hooks:
- id: pylint
- repo: https://github.com/pre-commit/mirrors-mypy

View File

@@ -1,5 +1,5 @@
[MASTER]
init-hook="from pylint.config import find_default_config_files; import sys; sys.path.append(next(find_default_config_files()).parent.as_posix())"
init-hook="from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))"
[TYPECHECK]
@@ -12,4 +12,3 @@ generated-members=numpy.*, torch.*
disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
too-many-positional-arguments, possibly-used-before-assignment

View File

@@ -8,7 +8,6 @@ ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
ENV GITHUB_REF="{{ GITHUB_REF }}"
ENV GITHUB_SHA="{{ GITHUB_SHA }}"
ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
ENV HF_HOME="{{ HF_HOME }}"
RUN apt-get update && \
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev

View File

@@ -28,7 +28,6 @@ df_args = {
"CUDA": os.environ.get("CUDA", "121"),
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}
dockerfile_contents = df_template.render(**df_args)
@@ -49,12 +48,6 @@ cicd_image = (
app = App("Axolotl CI/CD", secrets=[])
hf_cache_volume = modal.Volume.from_name(
"axolotl-ci-hf-hub-cache", create_if_missing=True
)
VOLUME_CONFIG = {
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
}
N_GPUS = int(os.environ.get("N_GPUS", 2))
GPU_CONFIG = modal.gpu.H100(count=N_GPUS)
@@ -74,7 +67,6 @@ def run_cmd(cmd: str, run_folder: str):
timeout=60 * 60,
cpu=8.0,
memory=131072 * N_GPUS,
volumes=VOLUME_CONFIG,
)
def cicd_pytest():
run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")

View File

@@ -29,7 +29,6 @@ df_args = {
"GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
"GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
"NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
"HF_HOME": "/workspace/data/huggingface-cache/hub",
}
dockerfile_contents = df_template.render(**df_args)
@@ -51,12 +50,6 @@ cicd_image = (
app = App("Axolotl CI/CD", secrets=[])
hf_cache_volume = modal.Volume.from_name(
"axolotl-ci-hf-hub-cache", create_if_missing=True
)
VOLUME_CONFIG = {
"/workspace/data/huggingface-cache/hub": hf_cache_volume,
}
N_GPUS = int(os.environ.get("N_GPUS", 1))
GPU_CONFIG = modal.gpu.A10G(count=N_GPUS)
@@ -76,7 +69,6 @@ def run_cmd(cmd: str, run_folder: str):
timeout=60 * 60,
cpu=8.0,
memory=131072,
volumes=VOLUME_CONFIG,
)
def cicd_pytest():
run_cmd("./cicd/cicd.sh", "/workspace/axolotl")

View File

@@ -19,14 +19,7 @@ For pretraining, there is no prompt template or roles. The only required field
Axolotl usually loads the entire dataset into memory. This will be challenging for large datasets. Use the following config to enable streaming:
```{.yaml filename="config.yaml"}
pretraining_dataset:
- name:
path:
split:
text_column: # column in dataset with the data, usually `text`
type: pretrain
trust_remote_code:
skip: # number of rows of data to skip over from the beginning
pretraining_dataset: # hf path only
...
```

29
docs/lr_groups.qmd Normal file
View File

@@ -0,0 +1,29 @@
---
title: Learning Rate Groups
description: "Setting different learning rates by module name"
---
## Background
Inspired by LoRA+, Axolotl allows practitioners to specify separate learning rates for each module or groups of
modules in a model.
## Example
```yaml
lr_groups:
- name: o_proj
modules:
- self_attn.o_proj.weight
lr: 1e-6
- name: q_proj
modules:
- model.layers.2.self_attn.q_proj.weight
lr: 1e-5
learning_rate: 2e-5
```
In this example, we have a default learning rate of 2e-5 across the entire model, but we have a separate learning rate
of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning are of 1e-5 to the 3rd layer's
self attention `q_proj` module.

View File

@@ -2,7 +2,7 @@
# START section of dependencies that don't install on Darwin/MacOS
bitsandbytes==0.45.0
triton>=3.0.0
triton>=2.3.0
mamba-ssm==1.2.0.post1
flash-attn==2.7.0.post2
xformers>=0.0.23.post1
@@ -14,11 +14,11 @@ packaging==23.2
peft==0.14.0
transformers==4.47.1
tokenizers>=0.21.0
tokenizers>=0.20.1
accelerate==1.2.1
datasets==3.2.0
datasets==3.1.0
deepspeed==0.16.1
trl==0.13.0
trl==0.12.1
optimum==1.16.2
hf_transfer
@@ -53,7 +53,7 @@ zstandard==0.22.0
fastcore
# lm eval harness
lm_eval==0.4.7
lm_eval==0.4.4
langdetect==1.0.9
immutabledict==4.2.0
antlr4-python3-runtime==4.13.2
@@ -61,4 +61,4 @@ antlr4-python3-runtime==4.13.2
torchao==0.7.0
schedulefree==1.3.0
axolotl-contribs-lgpl==0.0.3
axolotl-contribs-lgpl==0.0.1b2

View File

@@ -1,5 +1,4 @@
"""setup.py for axolotl"""
import ast
import os
import platform
@@ -30,30 +29,15 @@ def parse_requirements():
elif not is_extras and line and line[0] != "#":
# Handle standard packages
_install_requires.append(line)
try:
xformers_version = [req for req in _install_requires if "xformers" in req][0]
triton_version = [req for req in _install_requires if "triton" in req][0]
torchao_version = [req for req in _install_requires if "torchao" in req][0]
autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
if "Darwin" in platform.system():
# skip packages not compatible with OSX
skip_packages = [
"bitsandbytes",
"triton",
"mamba-ssm",
"flash-attn",
"xformers",
"autoawq",
"liger-kernel",
]
_install_requires = [
req
for req in _install_requires
if re.split(r"[>=<]", req)[0].strip() not in skip_packages
]
print(
_install_requires, [req in skip_packages for req in _install_requires]
)
# don't install xformers on MacOS
_install_requires.pop(_install_requires.index(xformers_version))
else:
# detect the version of torch already installed
# and set it so dependencies don't clobber the torch version
@@ -89,8 +73,6 @@ def parse_requirements():
_install_requires.append("xformers==0.0.28.post1")
elif (major, minor) >= (2, 3):
_install_requires.pop(_install_requires.index(torchao_version))
_install_requires.pop(_install_requires.index(triton_version))
_install_requires.append("triton>=2.3.1")
if patch == 0:
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append("xformers>=0.0.26.post1")

View File

@@ -93,7 +93,7 @@ def evaluate(config: str, accelerate: bool, **kwargs):
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=False,
default=True,
help="Use accelerate launch for multi-GPU inference",
)
@click.option(
@@ -124,7 +124,7 @@ def inference(
if lora_model_dir:
kwargs["lora_model_dir"] = lora_model_dir
if base_model:
kwargs["base_model"] = base_model
kwargs["output_dir"] = base_model
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]

View File

@@ -27,6 +27,7 @@ def add_options_from_dataclass(config_class: Type[Any]):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
if field_type == bool:
field_name = field.name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"

View File

@@ -22,6 +22,7 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
import torch
import transformers
from datasets import Dataset
from packaging import version
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
@@ -67,7 +68,7 @@ from axolotl.utils.callbacks import (
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForSeq2Seq,
@@ -243,6 +244,10 @@ class AxolotlTrainingMixins:
default=None,
metadata={"help": "Scale the learning rate for the embedding layers."},
)
lr_groups: Optional[list[dict]] = field(
default=None,
metadata={"help": "Specify learning rate groups for with different LRs."},
)
embedding_lr: Optional[float] = field(
default=None,
metadata={"help": "absolute learning rate for the embedding layers."},
@@ -461,11 +466,96 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
)
return super()._wrap_model(model, training=training, dataloader=dataloader)
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
decay_parameters = self.get_decay_parameter_names(opt_model)
params = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
lr_groups_lookup = {}
lr_groups_learning_rates = {}
if self.args.lr_groups:
for lr_group in self.args.lr_groups:
group_name = lr_group["name"]
group_modules = lr_group["modules"]
for module in group_modules:
lr_groups_lookup[module] = group_name
lr_groups_learning_rates[group_name] = lr_group["lr"]
params[f"to_weight_decay_{group_name}"] = {}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
if lr_groups_lookup and any(
group_modules in name for group_modules in lr_groups_lookup
):
lr_group_module = [
group_modules
for group_modules in lr_groups_lookup
if group_modules in name
][0]
group_name = lr_groups_lookup[lr_group_module]
params[f"to_weight_decay_{group_name}"][name] = param
else:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
for group_name, group_lr in lr_groups_learning_rates.items():
if params[f"to_weight_decay_{group_name}"]:
optimizer_grouped_parameters.append(
{
"params": list(
params[f"to_weight_decay_{group_name}"].values()
),
"weight_decay": self.args.weight_decay,
"lr": group_lr,
}
)
return optimizer_grouped_parameters
def create_optimizer(self):
if (
self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None
and self.args.lr_groups is None
and self.args.alternate_optimizer
not in [
"optimi_adamw",
@@ -479,59 +569,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
decay_parameters = self.get_decay_parameter_names(opt_model)
params = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
opt_model, optimizer_kwargs
)
if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
@@ -548,6 +592,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
elif (
self.args.embedding_lr_scale is not None
or self.args.embedding_lr is not None
or self.args.lr_groups is not None
):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
@@ -607,14 +652,8 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
self.state.train_batch_size or self.args.per_device_train_batch_size
)
batch_max_len = train_batch_size * self.args.max_seq_length
if self.args.curriculum_sampling:
sampler = SequentialSampler(self.train_dataset)
else:
sampler = RandomSampler(self.train_dataset)
return MultipackBatchSampler(
sampler,
RandomSampler(self.train_dataset),
lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
@@ -983,7 +1022,12 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
return super().log(logs, start_time)
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
try:
return super().log(logs, start_time)
except TypeError:
return super().log(logs) # transformers<=4.46
return super().log(logs) # transformers<=4.46
def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
@@ -1167,6 +1211,22 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
torch.cuda.empty_cache()
return loss
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
"""
@@ -1175,6 +1235,22 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
tag_names = ["axolotl", "orpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(ORPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
"""
@@ -1183,6 +1259,49 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
tag_names = ["axolotl", "kto"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# train metrics should have no prefix, eval should have 'eval_'
prefix = "eval_" if train_eval == "eval" else ""
# accumulate average metrics from sums and lengths
for split in ["chosen", "rejected"]:
if f"count/{split}" in self._stored_metrics[train_eval]:
count_sum = (
torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"])
.sum()
.item()
)
for metric in ["rewards", "logps", "logits"]:
logs[f"{prefix}{metric}/{split}"] = (
torch.Tensor(
self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
)
.sum()
.item()
/ count_sum
)
# delete obsolete metric
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
del self._stored_metrics[train_eval][f"count/{split}"]
# calculate reward margin
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
logs[f"{prefix}rewards/margins"] = (
logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
)
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(KTOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
"""
@@ -1191,6 +1310,22 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
tag_names = ["axolotl", "cpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(CPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
"""
@@ -1199,6 +1334,15 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
tag_names = ["axolotl", "reward"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the updated to the Trainer.log method
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call
class TrainerBuilderBase(abc.ABC):
"""
@@ -1664,6 +1808,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
@@ -1734,8 +1879,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = get_chat_template_from_config(
cfg=self.cfg,
training_arguments_kwargs["chat_template"] = get_chat_template(
self.cfg.chat_template,
tokenizer=self.tokenizer,
)

View File

@@ -22,6 +22,13 @@ import inspect
import logging
import sys
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from axolotl.integrations.base import BasePlugin
from ...utils.distributed import zero_only
@@ -39,13 +46,6 @@ class LigerPlugin(BasePlugin):
return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg):
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn)

View File

@@ -6,7 +6,7 @@ import logging
from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code
from axolotl.monkeypatch.unsloth_ import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")

View File

@@ -8,7 +8,7 @@ import logging
from transformers import LlamaForCausalLM, Trainer
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from axolotl.monkeypatch.utils import detab_code
from axolotl.monkeypatch.unsloth_ import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")

View File

@@ -1,7 +1,9 @@
"""module for patching with unsloth optimizations"""
import inspect
import re
import types
from typing import Tuple
import torch
from accelerate.logging import get_logger
@@ -9,8 +11,6 @@ from peft import PeftModelForCausalLM
from torch import nn
from transformers.models.llama.modeling_llama import LlamaFlashAttention2
from axolotl.monkeypatch.utils import detab_code
LOG = get_logger("axolotl.monkeypatch.unsloth")
ORIGINAL_QKV_CODE = """
@@ -93,6 +93,15 @@ def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
raise ValueError("Unsupported model type")
def detab_code(code: str) -> Tuple[str, str]:
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
return code, spaces
self_attn_lora_patched = False # pylint: disable=invalid-name

View File

@@ -1,8 +1,7 @@
"""
Shared utils for the monkeypatches
"""
import re
from typing import Optional, Tuple
from typing import Optional
import torch
import torch.nn.functional as F
@@ -224,12 +223,3 @@ def patched_prepare_4d_causal_attention_mask_for_sdpa(
mask_2d_to_4d(attention_mask, dtype=dtype),
*args,
)
def detab_code(code: str) -> Tuple[str, str]:
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
return code, spaces

View File

@@ -1,6 +1,5 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import inspect
import os
import signal
import sys
@@ -127,20 +126,7 @@ def train(
)
if cfg.fix_untrained_tokens:
# check if the `token_ids_to_fix` kwarg exists in the fix_untrained_tokens args
sig = inspect.signature(fix_untrained_tokens)
# if the function has the `token_ids_to_fix` arg, and fix_untrained_tokens is a list
if "token_ids_to_fix" in sig.parameters and isinstance(
cfg.fix_untrained_tokens, list
):
fix_untrained_tokens(
model,
tokenizer,
train_dataset,
token_ids_to_fix=cfg.fix_untrained_tokens,
)
else:
fix_untrained_tokens(model, tokenizer, train_dataset)
fix_untrained_tokens(model, tokenizer, train_dataset)
if cfg.local_rank == 0:
model.save_pretrained(
str(Path(cfg.output_dir)), safe_serialization=safe_serialization

View File

@@ -43,7 +43,7 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
getattr, self.layers_attribute.split("."), self.trainer.model
)
LOG.info(
f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers * 100 / len(layers)}%) every {self.step_interval} steps"
f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps"
)
def freeze_all_layers(self):

View File

@@ -128,8 +128,6 @@ class PretrainingDataset(BaseModel):
text_column: Optional[str] = "text"
type: Optional[str] = "pretrain"
trust_remote_code: Optional[bool] = False
data_files: Optional[str] = None
skip: Optional[int] = None
class UserDefinedPrompterType(BaseModel):
@@ -147,6 +145,14 @@ class UserDefinedPrompterType(BaseModel):
field: Optional[str] = None
class LrGroup(BaseModel):
"""Custom learning rate group configuration"""
name: str
modules: List[str]
lr: float
class SFTDataset(BaseModel):
"""SFT configuration subset"""
@@ -368,13 +374,6 @@ class LoraConfig(BaseModel):
loraplus_lr_embedding = float(loraplus_lr_embedding)
return loraplus_lr_embedding
@model_validator(mode="before")
@classmethod
def validate_lora_dropout(cls, data):
if data.get("adapter") is not None and data.get("lora_dropout") is None:
data["lora_dropout"] = 0.0
return data
class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset"""
@@ -475,6 +474,7 @@ class HyperparametersConfig(BaseModel):
cosine_min_lr_ratio: Optional[float] = None
cosine_constant_lr_ratio: Optional[float] = None
lr_div_factor: Optional[float] = None
lr_groups: Optional[List[LrGroup]] = None
adam_epsilon: Optional[float] = None
adam_beta1: Optional[float] = None
@@ -803,7 +803,7 @@ class AxolotlInputConfig(
chat_template_jinja: Optional[str] = None
default_system_message: Optional[str] = None
fix_untrained_tokens: Optional[Union[int, List[int]]] = None
fix_untrained_tokens: Optional[bool] = None
# INTERNALS - document for now, generally not set externally
is_preprocess: Optional[bool] = None

View File

@@ -28,10 +28,8 @@ def encode_pretraining(
)
# Convert to PyTorch tensors
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
targets = [torch.tensor(seq) for seq in res["input_ids"]]
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
new_input_ids = []
new_labels = []
new_attention_mask = []
# Append EOS and PAD tokens to input_ids, and correct attention_mask
for i, _ in enumerate(input_ids):
@@ -42,34 +40,22 @@ def encode_pretraining(
),
dim=0,
)
targets[i] = torch.cat(
(
targets[i],
torch.tensor([tokenizer.eos_token_id, -100]),
),
dim=0,
)
attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
# Concatenate tokens so that their lengths are less than max_tokens
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_labels = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
for ids, labels, mask in zip(input_ids, targets, attention_mask):
for ids, mask in zip(input_ids, attention_mask):
if buffer_input_ids.numel() == max_tokens:
new_input_ids.append(buffer_input_ids)
new_labels.append(buffer_labels)
new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_labels = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
else:
buffer_input_ids = torch.cat(
@@ -83,17 +69,6 @@ def encode_pretraining(
),
dim=0,
)
buffer_labels = torch.cat(
(
buffer_labels,
torch.full(
(max_tokens - buffer_labels.numel(),),
-100,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat(
(
buffer_attention_mask,
@@ -106,14 +81,11 @@ def encode_pretraining(
dim=0,
)
new_input_ids.append(buffer_input_ids)
new_labels.append(buffer_labels)
new_attention_mask.append(buffer_attention_mask)
buffer_input_ids = torch.tensor([], dtype=torch.long)
buffer_labels = torch.tensor([], dtype=torch.long)
buffer_attention_mask = torch.tensor([], dtype=torch.long)
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
buffer_labels = torch.cat((buffer_labels, labels), dim=0)
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
if buffer_input_ids.numel() > 0: # for any leftover tokens
@@ -129,17 +101,6 @@ def encode_pretraining(
),
dim=0,
)
buffer_labels = torch.cat(
(
buffer_labels,
torch.full(
(max_tokens - buffer_labels.numel(),),
-100,
dtype=torch.long,
),
),
dim=0,
)
buffer_attention_mask = torch.cat(
(
buffer_attention_mask,
@@ -152,12 +113,11 @@ def encode_pretraining(
dim=0,
)
new_input_ids.append(buffer_input_ids)
new_labels.append(buffer_labels)
new_attention_mask.append(buffer_attention_mask)
ret = {
"input_ids": [seq.tolist() for seq in new_input_ids],
"labels": [seq.tolist() for seq in new_labels],
"labels": [seq.tolist() for seq in new_input_ids],
"attention_mask": [seq.tolist() for seq in new_attention_mask],
}

View File

@@ -88,19 +88,14 @@ def prepare_dataset(cfg, tokenizer, processor=None):
path = cfg.pretraining_dataset
split = "train"
name = None
data_files = None
skip = 0
if isinstance(cfg.pretraining_dataset, list) and isinstance(
cfg.pretraining_dataset[0], dict
):
path = cfg.pretraining_dataset[0]["path"]
name = cfg.pretraining_dataset[0]["name"]
skip = cfg.pretraining_dataset[0]["skip"]
if "split" in cfg.pretraining_dataset[0]:
split = cfg.pretraining_dataset[0]["split"]
data_files = cfg.pretraining_dataset[0].get("data_files")
ds_wrapper_partial = functools.partial(
get_dataset_wrapper,
cfg.pretraining_dataset[0],
@@ -109,14 +104,8 @@ def prepare_dataset(cfg, tokenizer, processor=None):
cfg.pretraining_dataset[0]["type"] or "pretrain",
)
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
if skip:
LOG.info(f"Skipping {skip} samples from the dataset")
iter_ds = iter_ds.skip(skip)
train_dataset = wrap_pretraining_dataset(
iter_ds,
load_dataset(path, streaming=True, split=split, name=name),
tokenizer,
cfg,
ds_wrapper_partial,

View File

@@ -270,7 +270,7 @@ def load_sharded_model_quant(
model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config)
if cfg.local_rank == 0 and verbose:
print(f"Loaded model weights in {time.time() - start:.3f} seconds")
print(f"Loaded model weights in {time.time()-start:.3f} seconds")
# cleanup any extra memory usage from parallel loading
torch.cuda.empty_cache()

View File

@@ -196,7 +196,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask")
if cfg.model_config_type in ["falcon", "mistral"]:
if cfg.model_config_type == "falcon":
LOG.info("dropping token_type_ids column if it exists")
if "token_type_ids" in train_dataset.column_names:
train_dataset = train_dataset.remove_columns("token_type_ids")

View File

@@ -120,12 +120,13 @@ def temp_dir():
@pytest.fixture(scope="function", autouse=True)
def cleanup_monkeypatches():
from transformers import Trainer
from transformers.models.llama.modeling_llama import ( # LlamaFlashAttention2,
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaFlashAttention2,
LlamaForCausalLM,
)
# original_fa2_forward = LlamaFlashAttention2.forward
original_fa2_forward = LlamaFlashAttention2.forward
original_llama_attn_forward = LlamaAttention.forward
original_llama_forward = LlamaForCausalLM.forward
original_trainer_inner_training_loop = (
@@ -135,7 +136,7 @@ def cleanup_monkeypatches():
# monkey patches can happen inside the tests
yield
# Reset LlamaFlashAttention2 forward
# LlamaFlashAttention2.forward = original_fa2_forward
LlamaFlashAttention2.forward = original_fa2_forward
LlamaAttention.forward = original_llama_attn_forward
LlamaForCausalLM.forward = original_llama_forward
Trainer._inner_training_loop = ( # pylint: disable=protected-access
@@ -148,10 +149,7 @@ def cleanup_monkeypatches():
("transformers.models.llama",),
(
"transformers.models.llama.modeling_llama",
[
# "LlamaFlashAttention2",
"LlamaAttention",
],
["LlamaFlashAttention2", "LlamaAttention"],
),
("transformers.trainer",),
("transformers", ["Trainer"]),

View File

@@ -1,8 +1,8 @@
"""
Simple end-to-end test for Liger integration
"""
from e2e.utils import require_torch_2_4_1
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -10,32 +10,34 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists
from ..utils import with_temp_dir
class LigerIntegrationTestCase:
class LigerIntegrationTestCase(unittest.TestCase):
"""
e2e tests for liger integration with Axolotl
"""
@require_torch_2_4_1
@with_temp_dir
def test_llama_wo_flce(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"plugins": [
"axolotl.integrations.liger.LigerPlugin",
],
"liger_rope": True,
"liger_rms_norm": True,
"liger_glu_activation": True,
"liger_swiglu": True,
"liger_cross_entropy": True,
"liger_fused_linear_cross_entropy": False,
"sequence_len": 1024,
"val_set_size": 0.05,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
@@ -44,15 +46,15 @@ class LigerIntegrationTestCase:
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"max_steps": 5,
"max_steps": 10,
}
)
prepare_plugins(cfg)
@@ -61,26 +63,28 @@ class LigerIntegrationTestCase:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "model.safetensors").exists()
@require_torch_2_4_1
@with_temp_dir
def test_llama_w_flce(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"plugins": [
"axolotl.integrations.liger.LigerPlugin",
],
"liger_rope": True,
"liger_rms_norm": True,
"liger_glu_activation": True,
"liger_swiglu": True,
"liger_cross_entropy": False,
"liger_fused_linear_cross_entropy": True,
"sequence_len": 1024,
"val_set_size": 0.05,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
@@ -89,15 +93,15 @@ class LigerIntegrationTestCase:
},
],
"num_epochs": 1,
"micro_batch_size": 2,
"gradient_accumulation_steps": 2,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"save_safetensors": True,
"bf16": "auto",
"max_steps": 5,
"max_steps": 10,
}
)
prepare_plugins(cfg)
@@ -106,4 +110,4 @@ class LigerIntegrationTestCase:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "model.safetensors").exists()

View File

@@ -2,6 +2,8 @@
Simple end-to-end test for Cut Cross Entropy integration
"""
from pathlib import Path
import pytest
from axolotl.cli import load_datasets
@@ -11,8 +13,6 @@ from axolotl.utils import get_pytorch_version
from axolotl.utils.config import normalize_config, prepare_plugins
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists
# pylint: disable=duplicate-code
@@ -67,7 +67,7 @@ class TestCutCrossEntropyIntegration:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "model.safetensors").exists()
@pytest.mark.parametrize(
"attention_type",
@@ -95,4 +95,4 @@ class TestCutCrossEntropyIntegration:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
else:
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "model.safetensors").exists()

View File

@@ -5,6 +5,7 @@ E2E tests for multipack fft llama using 4d attention masks
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -12,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, require_torch_2_3_1, with_temp_dir
from ..utils import require_torch_2_3_1, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -66,7 +67,7 @@ class Test4dMultipackLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_torch_lora_packing(self, temp_dir):
@@ -110,4 +111,4 @@ class Test4dMultipackLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()

View File

@@ -4,6 +4,7 @@ E2E tests for lora llama
import logging
import os
from pathlib import Path
import pytest
from transformers.utils import is_torch_bf16_gpu_available
@@ -14,7 +15,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, check_tensorboard
from ..utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -81,7 +82,7 @@ class TestFAXentropyLlama:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high"

View File

@@ -5,6 +5,7 @@ E2E tests for falcon
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -12,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -68,7 +69,7 @@ class TestFalconPatched(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_ft(self, temp_dir):
@@ -108,4 +109,4 @@ class TestFalconPatched(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "pytorch_model.bin").exists()

View File

@@ -5,6 +5,7 @@ E2E tests for lora llama
import logging
import os
import unittest
from pathlib import Path
import pytest
from transformers.utils import is_torch_bf16_gpu_available
@@ -15,7 +16,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -72,4 +73,4 @@ class TestFusedLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "pytorch_model.bin").exists()

View File

@@ -5,6 +5,7 @@ E2E tests for llama w/ S2 attn
import logging
import os
import unittest
from pathlib import Path
import pytest
@@ -14,7 +15,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -70,7 +71,7 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_fft_s2_attn(self, temp_dir):
@@ -110,4 +111,4 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "pytorch_model.bin").exists()

View File

@@ -5,6 +5,7 @@ E2E tests for lora llama
import logging
import os
import unittest
from pathlib import Path
import pytest
from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available
@@ -15,7 +16,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -75,7 +76,7 @@ class TestLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
@with_temp_dir
@@ -125,4 +126,4 @@ class TestLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()

View File

@@ -5,6 +5,7 @@ E2E tests for lora llama
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -12,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -68,7 +69,7 @@ class TestMistral(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_ft_packing(self, temp_dir):
@@ -109,4 +110,4 @@ class TestMistral(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "pytorch_model.bin").exists()

View File

@@ -5,6 +5,7 @@ E2E tests for mixtral
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -12,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -65,7 +66,7 @@ class TestMixtral(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_ft(self, temp_dir):
@@ -107,4 +108,4 @@ class TestMixtral(unittest.TestCase):
"MixtralFlashAttention2"
in model.model.layers[0].self_attn.__class__.__name__
)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "pytorch_model.bin").exists()

View File

@@ -5,6 +5,7 @@ E2E tests for lora llama
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -12,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, with_temp_dir
from ..utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -68,7 +69,7 @@ class TestPhiMultipack(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
@with_temp_dir
def test_qlora_packed(self, temp_dir):
@@ -119,4 +120,4 @@ class TestPhiMultipack(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()

View File

@@ -6,6 +6,7 @@ import logging
import os
import re
import subprocess
from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available
@@ -15,7 +16,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, most_recent_subdir
from ..utils import most_recent_subdir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -82,7 +83,7 @@ class TestResumeLlama:
cli_args = TrainerCliArgs()
train(cfg=resume_cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()
tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
cmd = f"tensorboard --inspect --logdir {tb_log_path_1}"

View File

@@ -1,14 +1,9 @@
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
import unittest
import pytest
from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
@pytest.mark.skip(
reason="Unsloth integration will be broken going into latest transformers"
)
class TestUnslothIntegration(unittest.TestCase):
"""Unsloth monkeypatch integration tests."""

View File

@@ -3,6 +3,7 @@ e2e tests for unsloth qlora
"""
import logging
import os
from pathlib import Path
import pytest
@@ -12,16 +13,13 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, check_tensorboard
from ..utils import check_tensorboard
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
# pylint: disable=duplicate-code
@pytest.mark.skip(
reason="Unsloth integration will be broken going into latest transformers"
)
class TestUnslothQLoRA:
"""
Test class for Unsloth QLoRA Llama models
@@ -76,7 +74,7 @@ class TestUnslothQLoRA:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
@@ -126,7 +124,7 @@ class TestUnslothQLoRA:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
@@ -181,7 +179,7 @@ class TestUnslothQLoRA:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"

View File

@@ -15,7 +15,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -68,7 +68,7 @@ class TestDPOLlamaLora(unittest.TestCase):
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@with_temp_dir
def test_dpo_nll_lora(self, temp_dir):
@@ -113,7 +113,7 @@ class TestDPOLlamaLora(unittest.TestCase):
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@with_temp_dir
def test_dpo_use_weighting(self, temp_dir):
@@ -158,7 +158,7 @@ class TestDPOLlamaLora(unittest.TestCase):
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@pytest.mark.skip("kto_pair no longer supported in trl")
@with_temp_dir
@@ -203,7 +203,7 @@ class TestDPOLlamaLora(unittest.TestCase):
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@with_temp_dir
def test_ipo_lora(self, temp_dir):
@@ -247,7 +247,7 @@ class TestDPOLlamaLora(unittest.TestCase):
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@with_temp_dir
def test_orpo_lora(self, temp_dir):
@@ -294,7 +294,7 @@ class TestDPOLlamaLora(unittest.TestCase):
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@pytest.mark.skip(reason="Fix the implementation")
@with_temp_dir
@@ -358,4 +358,4 @@ class TestDPOLlamaLora(unittest.TestCase):
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()

View File

@@ -5,6 +5,7 @@ E2E tests for llama pretrain
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -12,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
from .utils import check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -61,7 +62,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
@@ -105,7 +106,7 @@ class TestEmbeddingsLrScale(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"

View File

@@ -5,6 +5,7 @@ E2E tests for falcon
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -12,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -70,7 +71,7 @@ class TestFalcon(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_lora_added_vocab(self, temp_dir):
@@ -123,7 +124,7 @@ class TestFalcon(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_ft(self, temp_dir):
@@ -162,4 +163,4 @@ class TestFalcon(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "pytorch_model.bin").exists()

View File

@@ -4,8 +4,7 @@ E2E tests for llama
import logging
import os
from e2e.utils import check_model_output_exists
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -61,7 +60,7 @@ class TestLlama:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "model.safetensors").exists()
def test_fix_untrained_tokens(self, temp_dir):
# pylint: disable=duplicate-code
@@ -104,7 +103,7 @@ class TestLlama:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "model.safetensors").exists()
def test_batch_flattening(self, temp_dir):
# pylint: disable=duplicate-code
@@ -143,4 +142,4 @@ class TestLlama:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "model.safetensors").exists()

View File

@@ -5,6 +5,7 @@ E2E tests for llama pretrain
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -12,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -63,4 +64,4 @@ class TestPretrainLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "model.safetensors").exists()

View File

@@ -5,6 +5,7 @@ E2E tests for lora llama
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -12,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -67,7 +68,7 @@ class TestLlamaVision(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
@with_temp_dir
def test_lora_llama_vision_multimodal_dataset(self, temp_dir):
@@ -112,4 +113,4 @@ class TestLlamaVision(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.safetensors").exists()

View File

@@ -5,6 +5,7 @@ E2E tests for lora llama
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -12,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -64,4 +65,4 @@ class TestLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()

View File

@@ -5,6 +5,7 @@ E2E tests for lora llama
import logging
import os
import unittest
from pathlib import Path
import pytest
@@ -14,7 +15,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -64,4 +65,4 @@ class TestMamba(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "pytorch_model.bin").exists()

View File

@@ -5,6 +5,7 @@ E2E tests for lora llama
import logging
import os
import unittest
from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available
@@ -14,7 +15,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -68,7 +69,7 @@ class TestMistral(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_ft(self, temp_dir):
@@ -111,4 +112,4 @@ class TestMistral(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "pytorch_model.bin").exists()

View File

@@ -5,6 +5,7 @@ E2E tests for mixtral
import logging
import os
import unittest
from pathlib import Path
import torch
from transformers.utils import is_torch_bf16_gpu_available
@@ -15,7 +16,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -78,7 +79,7 @@ class TestMixtral(unittest.TestCase):
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_qlora_wo_fa2(self, temp_dir):
@@ -132,7 +133,7 @@ class TestMixtral(unittest.TestCase):
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_16bit_lora_w_fa2(self, temp_dir):
@@ -189,7 +190,7 @@ class TestMixtral(unittest.TestCase):
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_16bit_lora_wo_fa2(self, temp_dir):
@@ -246,7 +247,7 @@ class TestMixtral(unittest.TestCase):
model.base_model.model.model.layers[0].block_sparse_moe.gate.weight.dtype
== torch.float32
)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_ft(self, temp_dir):
@@ -286,4 +287,4 @@ class TestMixtral(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "pytorch_model.bin").exists()

View File

@@ -5,6 +5,7 @@ E2E tests for custom optimizers using Llama
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -12,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, require_torch_2_5_1, with_temp_dir
from .utils import require_torch_2_5_1, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -64,7 +65,7 @@ class TestCustomOptimizers(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
@require_torch_2_5_1
@@ -108,11 +109,10 @@ class TestCustomOptimizers(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_fft_schedule_free_adamw(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -144,4 +144,4 @@ class TestCustomOptimizers(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "model.safetensors").exists()

View File

@@ -5,6 +5,7 @@ E2E tests for lora llama
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -12,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -66,7 +67,7 @@ class TestPhi(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "pytorch_model.bin").exists()
@with_temp_dir
def test_phi_qlora(self, temp_dir):
@@ -115,4 +116,4 @@ class TestPhi(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()

View File

@@ -13,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
from .utils import check_tensorboard, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -78,10 +78,10 @@ class TestReLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
assert (
Path(temp_dir) / "checkpoint-100/relora/model.safetensors"
).exists(), "Relora model checkpoint not found"
Path(temp_dir) / "checkpoint-100/adapter/adapter_model.safetensors"
).exists()
assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/grad_norm", 0.2, "grad_norm is too high"

View File

@@ -5,6 +5,7 @@ E2E tests for reward model lora llama
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
@@ -12,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import check_model_output_exists, with_temp_dir
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -70,4 +71,4 @@ class TestRewardModelLoraLlama(unittest.TestCase):
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
assert (Path(temp_dir) / "adapter_model.bin").exists()

View File

@@ -14,8 +14,6 @@ import torch
from packaging import version
from tbparse import SummaryReader
from axolotl.utils.dict import DictDefault
def with_temp_dir(test_func):
@wraps(test_func)
@@ -51,19 +49,7 @@ def require_torch_2_3_1(test_case):
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.3.1")
return unittest.skipUnless(is_min_2_3_1(), "test requires torch>=2.3.1")(test_case)
def require_torch_2_4_1(test_case):
"""
Decorator marking a test that requires torch >= 2.5.1
"""
def is_min_2_4_1():
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.4.1")
return unittest.skipUnless(is_min_2_4_1(), "test requires torch>=2.4.1")(test_case)
return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)
def require_torch_2_5_1(test_case):
@@ -75,7 +61,7 @@ def require_torch_2_5_1(test_case):
torch_version = version.parse(torch.__version__)
return torch_version >= version.parse("2.5.1")
return unittest.skipUnless(is_min_2_5_1(), "test requires torch>=2.5.1")(test_case)
return unittest.skipUnless(is_min_2_5_1(), "test torch 2.5.1")(test_case)
def is_hopper():
@@ -95,27 +81,3 @@ def check_tensorboard(
df = reader.scalars # pylint: disable=invalid-name
df = df[(df.tag == tag)] # pylint: disable=invalid-name
assert df.value.values[-1] < lt_val, assertion_err
def check_model_output_exists(temp_dir: str, cfg: DictDefault) -> None:
"""
helper function to check if a model output file exists after training
checks based on adapter or not and if safetensors saves are enabled or not
"""
if cfg.save_safetensors:
if not cfg.adapter:
assert (Path(temp_dir) / "model.safetensors").exists()
else:
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
else:
# check for both, b/c in trl, it often defaults to saving safetensors
if not cfg.adapter:
assert (Path(temp_dir) / "pytorch_model.bin").exists() or (
Path(temp_dir) / "model.safetensors"
).exists()
else:
assert (Path(temp_dir) / "adapter_model.bin").exists() or (
Path(temp_dir) / "adapter_model.safetensors"
).exists()

View File

@@ -7,11 +7,11 @@ from typing import Optional
import pytest
from axolotl.utils.config import prepare_plugins, validate_config
from axolotl.utils.config import validate_config
from axolotl.utils.dict import DictDefault
@pytest.fixture(name="minimal_liger_cfg")
@pytest.fixture(name="minimal_base_cfg")
def fixture_cfg():
return DictDefault(
{
@@ -25,57 +25,56 @@ def fixture_cfg():
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"plugins": ["axolotl.integrations.liger.LigerPlugin"],
}
)
# pylint: disable=too-many-public-methods
class TestValidation:
class BaseValidation:
"""
Test the validation module for liger
Base validation module to setup the log capture
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
caplog.set_level(logging.WARNING)
self._caplog = caplog
def test_deprecated_swiglu(self, minimal_liger_cfg):
# pylint: disable=too-many-public-methods
class TestValidation(BaseValidation):
"""
Test the validation module for liger
"""
def test_deprecated_swiglu(self, minimal_cfg):
test_cfg = DictDefault(
{
"liger_swiglu": False,
}
| minimal_liger_cfg
| minimal_cfg
)
with self._caplog.at_level(
logging.WARNING, logger="axolotl.integrations.liger.args"
):
prepare_plugins(test_cfg)
with self._caplog.at_level(logging.WARNING):
updated_cfg = validate_config(test_cfg)
# TODO this test is brittle in CI
# assert (
# "The 'liger_swiglu' argument is deprecated"
# in self._caplog.records[0].message
# )
assert (
"The 'liger_swiglu' argument is deprecated"
in self._caplog.records[0].message
)
assert updated_cfg.liger_swiglu is None
assert updated_cfg.liger_glu_activation is False
assert updated_cfg.liger_glu_activations is False
def test_conflict_swiglu_ligergluactivation(self, minimal_liger_cfg):
def test_conflict_swiglu_ligergluactivation(self, minimal_cfg):
test_cfg = DictDefault(
{
"liger_swiglu": False,
"liger_glu_activation": True,
"liger_glu_activations": True,
}
| minimal_liger_cfg
| minimal_cfg
)
with pytest.raises(
ValueError,
match=r".*You cannot have both `liger_swiglu` and `liger_glu_activation` set.*",
):
prepare_plugins(test_cfg)
validate_config(test_cfg)

View File

@@ -1,69 +0,0 @@
"""
tests for loading loras
"""
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
# pylint: disable=duplicate-code
minimal_config = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"learning_rate": 0.000001,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
}
],
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
}
)
class TestLoRALoad:
"""
Test class for loading LoRA weights
"""
def test_load_lora_weights(self):
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.0,
"lora_target_linear": True,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"sequence_len": 1024,
}
| minimal_config
)
cfg = validate_config(cfg)
normalize_config(cfg)
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer)
def test_load_lora_weights_empty_dropout(self):
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": None,
"lora_target_linear": True,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"sequence_len": 1024,
}
| minimal_config
)
cfg = validate_config(cfg)
normalize_config(cfg)
assert cfg.lora_dropout == 0.0
tokenizer = load_tokenizer(cfg)
load_model(cfg, tokenizer)

View File

@@ -4,7 +4,9 @@ import json
import logging
import unittest
from pathlib import Path
from typing import Optional
import pytest
from datasets import load_dataset
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
@@ -63,6 +65,12 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
Test class for prompt tokenization strategies.
"""
_caplog: Optional[pytest.LogCaptureFixture] = None
@pytest.fixture(autouse=True)
def inject_fixtures(self, caplog):
self._caplog = caplog
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")