Compare commits

..

13 Commits

Author SHA1 Message Date
Wing Lian  ee20600b9a  use alternate math-hard repo  2025-01-13 08:46:35 -05:00
Wing Lian  fd91de3ea6  apply chat template as arg  2025-01-12 17:38:32 -05:00
Wing Lian  530bf77cf9  revision support  2025-01-12 05:17:03 -05:00
Wing Lian  bfc91a91ca  use chat template  2025-01-11 23:18:27 -05:00
Wing Lian  5c226b600d  pr feedback  2025-01-08 08:38:06 -05:00
Wing Lian  af66f7c274  update link in README to include utm  2025-01-07 15:13:18 -05:00
Wing Lian  079f94ee99  include modal in requirements  2025-01-07 08:48:25 -05:00
Wing Lian  981ad965d0  allow minimal yaml for lm eval  2025-01-06 17:41:10 -05:00
Wing Lian  7ba701a355  cache bust when using branch, grab sha of latest image tag, update lm-eval dep  2025-01-06 16:19:08 -05:00
Wing Lian  0390bce7aa  lm_eval option to not post eval, and append not extend  2025-01-06 11:52:07 -05:00
Wing Lian  2741d8de23  Fix the sub call to lm-eval  2025-01-06 11:44:55 -05:00
Wing Lian  27a88f37cd  do lm_eval in cloud too  2025-01-06 11:17:14 -05:00
Wing Lian  6da8abc01f  native support for modal cloud from CLI  2025-01-05 21:49:53 -05:00
116 changed files with 2600 additions and 2071 deletions

View File

@@ -1,7 +1,6 @@
 name: lint
 on:
   # check on PRs, and manual triggers
-  merge_group:
   pull_request:
     paths:
       - '**.py'

View File

@@ -25,6 +25,7 @@ jobs:
             python_version: "3.11"
             pytorch: 2.3.1
             axolotl_extras: mamba-ssm
+            is_latest: true
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
@@ -35,7 +36,6 @@
             python_version: "3.11"
             pytorch: 2.5.1
             axolotl_extras:
-            is_latest: true
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
@@ -92,6 +92,7 @@ jobs:
             python_version: "3.11"
             pytorch: 2.3.1
             axolotl_extras:
+            is_latest: true
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
@@ -102,7 +103,6 @@
             python_version: "3.11"
             pytorch: 2.5.1
             axolotl_extras:
-            is_latest: true
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout

View File

@@ -52,7 +52,7 @@ jobs:
       - name: Install Modal
         run: |
           python -m pip install --upgrade pip
-          pip install modal==0.71.8 jinja2
+          pip install modal==0.63.64 jinja2
       - name: Update env vars
         run: |
           echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -129,7 +129,7 @@ jobs:
       - name: Install Modal
         run: |
           python -m pip install --upgrade pip
-          pip install modal==0.71.8 jinja2
+          pip install modal==0.63.64 jinja2
       - name: Update env vars
         run: |
           echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -1,7 +1,6 @@
 name: Tests
 on:
   # check on push/merge to main, PRs, and manual triggers
-  merge_group:
   push:
     branches:
       - "main"
@@ -61,15 +60,6 @@
       - name: Check out repository code
         uses: actions/checkout@v4
-      - name: Restore HF cache
-        id: hf-cache-restore
-        uses: actions/cache/restore@v4
-        with:
-          path: |
-            /home/runner/.cache/huggingface/hub/datasets--*
-            /home/runner/.cache/huggingface/hub/models--*
-          key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
       - name: Setup Python
         uses: actions/setup-python@v5
         with:
@@ -110,15 +100,6 @@
         run: |
           find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
-      - name: Save HF cache
-        id: hf-cache
-        uses: actions/cache/save@v4
-        with:
-          path: |
-            /home/runner/.cache/huggingface/hub/datasets--*
-            /home/runner/.cache/huggingface/hub/models--*
-          key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
   pytest-sdist:
     name: PyTest from Source Dist
     runs-on: ubuntu-latest
@@ -134,15 +115,6 @@
       - name: Check out repository code
         uses: actions/checkout@v4
-      - name: Restore HF cache
-        id: hf-cache-restore
-        uses: actions/cache/restore@v4
-        with:
-          path: |
-            /home/runner/.cache/huggingface/hub/datasets--*
-            /home/runner/.cache/huggingface/hub/models--*
-          key: ${{ runner.os }}-hf-hub-cache-${{ hashFiles('**/conftest.py') }}
       - name: Setup Python
         uses: actions/setup-python@v5
         with:
@@ -184,15 +156,6 @@
         run: |
           find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
-      - name: Save HF cache
-        id: hf-cache
-        uses: actions/cache/save@v4
-        with:
-          path: |
-            /home/runner/.cache/huggingface/hub/datasets--*
-            /home/runner/.cache/huggingface/hub/models--*
-          key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
   docker-e2e-tests-1st:
     if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
     # this job needs to be run on self-hosted GPU runners...
@@ -220,7 +183,7 @@
       - name: Install Modal
         run: |
           python -m pip install --upgrade pip
-          pip install modal==0.71.8 jinja2
+          pip install modal==0.63.64 jinja2
       - name: Update env vars
         run: |
           echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
@@ -266,7 +229,7 @@
       - name: Install Modal
         run: |
           python -m pip install --upgrade pip
-          pip install modal==0.71.8 jinja2
+          pip install modal==0.63.64 jinja2
       - name: Update env vars
         run: |
           echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV

View File

@@ -23,7 +23,7 @@ repos:
     hooks:
       - id: flake8
   - repo: https://github.com/PyCQA/pylint
-    rev: v3.3.0
+    rev: v2.17.4
     hooks:
       - id: pylint
   - repo: https://github.com/pre-commit/mirrors-mypy

View File

@@ -1,5 +1,5 @@
 [MASTER]
-init-hook="from pylint.config import find_default_config_files; import sys; sys.path.append(next(find_default_config_files()).parent.as_posix())"
+init-hook="from pylint.config import find_pylintrc; import os, sys; sys.path.append(os.path.dirname(find_pylintrc()))"

 [TYPECHECK]
@@ -12,4 +12,3 @@ generated-members=numpy.*, torch.*
 disable=missing-function-docstring, line-too-long, import-error,
         too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
         too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
-        too-many-positional-arguments, possibly-used-before-assignment

View File

@@ -217,7 +217,7 @@ If you love axolotl, consider sponsoring the project by reaching out directly to
 ---

-- [Modal](https://modal.com/) Modal lets you run data/AI jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale, fine-tune LLM models, run protein folding simulations, and much more.
+- [Modal](https://www.modal.com?utm_source=github&utm_medium=github&utm_campaign=axolotl) Modal lets you run data/AI jobs in the cloud, by just writing a few lines of Python. Customers use Modal to deploy Gen AI models at large scale, fine-tune LLM models, run protein folding simulations, and much more.

 ---
@@ -519,8 +519,8 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
     train_on_split: validation

   # loading from s3 or gcs
-  # s3 creds will be loaded from the system default / gcs will attempt to load from gcloud creds, google metadata service, or anon
-  - path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above
+  # s3 creds will be loaded from the system default and gcs only supports public access
+  - path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
   ...

   # Loading Data From a Public URL

View File

@@ -8,7 +8,6 @@ ENV PYTORCH_VERSION="{{ PYTORCH_VERSION }}"
 ENV GITHUB_REF="{{ GITHUB_REF }}"
 ENV GITHUB_SHA="{{ GITHUB_SHA }}"
 ENV NIGHTLY_BUILD="{{ NIGHTLY_BUILD }}"
-ENV HF_HOME="{{ HF_HOME }}"

 RUN apt-get update && \
     apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev

View File

@@ -6,6 +6,5 @@ python -c "import torch; assert '$PYTORCH_VERSION' in torch.__version__"
 pytest -v --durations=10 -n8 --ignore=tests/e2e/ --ignore=tests/patched/ /workspace/axolotl/tests/
 # pytest -v --durations=10 -n8 --dist loadfile /workspace/axolotl/tests/patched/
 pytest -v --durations=10 /workspace/axolotl/tests/e2e/patched/
-pytest -v --durations=10 -n1 /workspace/axolotl/tests/e2e/solo/
 pytest -v --durations=10 /workspace/axolotl/tests/e2e/integrations/
-pytest -v --durations=10 --ignore=tests/e2e/solo/ --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/
+pytest -v --durations=10 --ignore=tests/e2e/patched/ --ignore=tests/e2e/multigpu/ --ignore=tests/e2e/integrations/ /workspace/axolotl/tests/e2e/

View File

@@ -28,7 +28,6 @@ df_args = {
     "CUDA": os.environ.get("CUDA", "121"),
     "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
     "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
-    "HF_HOME": "/workspace/data/huggingface-cache/hub",
 }

 dockerfile_contents = df_template.render(**df_args)
@@ -49,12 +48,6 @@ cicd_image = (
 app = App("Axolotl CI/CD", secrets=[])

-hf_cache_volume = modal.Volume.from_name(
-    "axolotl-ci-hf-hub-cache", create_if_missing=True
-)
-VOLUME_CONFIG = {
-    "/workspace/data/huggingface-cache/hub": hf_cache_volume,
-}

 N_GPUS = int(os.environ.get("N_GPUS", 2))
 GPU_CONFIG = modal.gpu.H100(count=N_GPUS)
@@ -74,7 +67,6 @@ def run_cmd(cmd: str, run_folder: str):
     timeout=60 * 60,
     cpu=8.0,
     memory=131072 * N_GPUS,
-    volumes=VOLUME_CONFIG,
 )
 def cicd_pytest():
     run_cmd("./cicd/multigpu.sh", "/workspace/axolotl")

View File

@@ -29,7 +29,6 @@ df_args = {
     "GITHUB_REF": os.environ.get("GITHUB_REF", "refs/heads/main"),
     "GITHUB_SHA": os.environ.get("GITHUB_SHA", ""),
     "NIGHTLY_BUILD": os.environ.get("NIGHTLY_BUILD", ""),
-    "HF_HOME": "/workspace/data/huggingface-cache/hub",
 }

 dockerfile_contents = df_template.render(**df_args)
@@ -51,12 +50,6 @@ cicd_image = (
 app = App("Axolotl CI/CD", secrets=[])

-hf_cache_volume = modal.Volume.from_name(
-    "axolotl-ci-hf-hub-cache", create_if_missing=True
-)
-VOLUME_CONFIG = {
-    "/workspace/data/huggingface-cache/hub": hf_cache_volume,
-}

 N_GPUS = int(os.environ.get("N_GPUS", 1))
 GPU_CONFIG = modal.gpu.A10G(count=N_GPUS)
@@ -76,7 +69,6 @@ def run_cmd(cmd: str, run_folder: str):
     timeout=60 * 60,
     cpu=8.0,
     memory=131072,
-    volumes=VOLUME_CONFIG,
 )
 def cicd_pytest():
     run_cmd("./cicd/cicd.sh", "/workspace/axolotl")

View File

@@ -20,8 +20,7 @@ RUN apt install --yes --no-install-recommends openssh-server tmux && \
     printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
     printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
     chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
-    chmod +x /root/cloud-entrypoint.sh && \
-    echo 'set-option -g history-limit 5000' >> ~/.tmux.conf
+    chmod +x /root/cloud-entrypoint.sh

 ENTRYPOINT ["/root/cloud-entrypoint.sh"]
 CMD ["sleep", "infinity"]

View File

@@ -244,8 +244,6 @@ total_num_tokens:
 sample_packing_group_size: 100000
 # The number of samples which can be packed into one sequence. Increase if using a large sequence_len with many short samples.
 sample_packing_bin_size: 200
-# whether to concatenate samples during pretraining
-pretraining_sample_concatenation:

 # Use batch flattening for speedups when not using sample_packing
 batch_flattening:
@@ -360,11 +358,10 @@ warmup_ratio: 0.05 # cannot use with warmup_steps
 learning_rate: 0.00003
 lr_quadratic_warmup:
 logging_steps:
-eval_steps: # Leave empty to eval at each epoch, integer for every N steps. float for fraction of total steps
+eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
 evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
-eval_strategy: # Set to `"no"` to skip evaluation, `"epoch"` at end of each epoch, leave empty to infer from `eval_steps`.
-save_strategy: # Set to `"no"` to skip checkpoint saves, `"epoch"` at end of each epoch, `"best"` when better result is achieved, leave empty to infer from `save_steps`.
-save_steps: # Leave empty to save at each epoch, integer for every N steps. float for fraction of total steps
+save_strategy: # Set to `"no"` to skip checkpoint saves
+save_steps: # Leave empty to save at each epoch
 saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
 save_total_limit: # Checkpoints saved at a time
 # Maximum number of iterations to train for. It precedes num_epochs which means that

View File

@@ -19,14 +19,7 @@ For pretraining, there is no prompt template or roles. The only required field
 Axolotl usually loads the entire dataset into memory. This will be challenging for large datasets. Use the following config to enable streaming:

 ```{.yaml filename="config.yaml"}
-pretraining_dataset:
-  - name:
-    path:
-    split:
-    text_column: # column in dataset with the data, usually `text`
-    type: pretrain
-    trust_remote_code:
-    skip: # number of rows of data to skip over from the beginning
+pretraining_dataset: # hf path only
 ...
 ```
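For intuition, streaming in the Hugging Face `datasets` stack returns an `IterableDataset` that yields rows lazily instead of materializing the corpus in memory. A minimal sketch outside axolotl, with the dataset name chosen purely as an example:

```python
from datasets import load_dataset

# streaming=True avoids downloading or materializing the full corpus.
ds = load_dataset("allenai/c4", "en", split="train", streaming=True)

for row in ds.take(2):  # pull just the first two rows over the wire
    print(row["text"][:80])
```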

View File

@@ -1,29 +0,0 @@
---
title: Learning Rate Groups
description: "Setting different learning rates by module name"
---
## Background
Inspired by LoRA+, Axolotl allows practitioners to specify separate learning rates for each module or groups of
modules in a model.
## Example
```yaml
lr_groups:
- name: o_proj
modules:
- self_attn.o_proj.weight
lr: 1e-6
- name: q_proj
modules:
- model.layers.2.self_attn.q_proj.weight
lr: 1e-5
learning_rate: 2e-5
```
In this example, we have a default learning rate of 2e-5 across the entire model, a separate learning rate
of 1e-6 for all the self attention `o_proj` modules across all layers, and a learning rate of 1e-5 for the 3rd layer's
self attention `q_proj` module.
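The deleted page above maps naturally onto PyTorch optimizer parameter groups. A rough sketch of the assumed mechanics (not axolotl's literal implementation), where parameters are bucketed by module-name suffix and each bucket gets its own learning rate:

```python
import torch
from torch import nn

def build_param_groups(model: nn.Module, lr_groups: list, default_lr: float):
    """Bucket parameters by module-name suffix; one optimizer group per bucket."""
    claimed = set()
    groups = []
    for spec in lr_groups:
        matched = {
            name: param
            for name, param in model.named_parameters()
            if any(name.endswith(suffix) for suffix in spec["modules"])
        }
        claimed.update(matched)
        groups.append({"params": list(matched.values()), "lr": spec["lr"]})
    # Everything not claimed by an explicit group falls back to the default LR.
    rest = [p for n, p in model.named_parameters() if n not in claimed]
    groups.append({"params": rest, "lr": default_lr})
    return groups

# optimizer = torch.optim.AdamW(build_param_groups(model, lr_groups, 2e-5))
```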

15
examples/cloud/modal.yaml Normal file
View File

@@ -0,0 +1,15 @@
volumes:
  - name: axolotl-data
    mount: /workspace/data
  - name: axolotl-artifacts
    mount: /workspace/artifacts
secrets:
  - HF_TOKEN
  - WANDB_API_KEY
branch: cli-cloud-modal
gpu: h100
gpu_count: 1
memory: 128
timeout: 86400
timeout_preprocess: 14400
memory_preprocess: 32
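This config is consumed by the cloud CLI helpers added further down in this diff. A hedged usage sketch (the module path follows from the imports shown below, and the training-config path is a placeholder):

```python
from axolotl.cli.cloud import do_cli_train

# Launches `axolotl train` remotely on Modal using the volumes, secrets,
# branch, and GPU settings from the cloud config above.
do_cli_train(
    cloud_config="examples/cloud/modal.yaml",
    config="path/to/train.yaml",  # placeholder: your axolotl training YAML
    accelerate=True,
)
```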

11
lm_eval-kd.yaml Normal file
View File

@@ -0,0 +1,11 @@
lm_eval_model: axolotl-ai-co/numina-8b-ep1-exp1
lm_eval_tasks:
- leaderboard_math_hard
lm_eval_batch_size: 64
apply_chat_template: false
wandb_project: numina-kd-experiment
wandb_entity: axolotl-ai
bf16: true
flash_attention: true
output_dir: ./outputs/model-evals-out
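With the cloud helpers in this branch, the eval config above can be shipped to Modal; the remote side ultimately shells out to `axolotl lm-eval <config>` (see `_lm_eval` further down). A sketch, with the cloud config path assumed:

```python
from axolotl.cli.cloud import do_cli_lm_eval

do_cli_lm_eval(
    cloud_config="examples/cloud/modal.yaml",
    config="lm_eval-kd.yaml",
)
```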

View File

@@ -2,7 +2,7 @@
 # START section of dependencies that don't install on Darwin/MacOS
 bitsandbytes==0.45.0
-triton>=3.0.0
+triton>=2.3.0
 mamba-ssm==1.2.0.post1
 flash-attn==2.7.0.post2
 xformers>=0.0.23.post1
@@ -13,18 +13,19 @@ liger-kernel==0.5.2
 packaging==23.2
 peft==0.14.0
-transformers==4.48.1
-tokenizers>=0.21.0
-accelerate==1.3.0
-datasets==3.2.0
+transformers==4.47.1
+tokenizers>=0.20.1
+accelerate==1.2.1
+datasets==3.1.0
 deepspeed==0.16.1
-trl==0.13.0
+trl==0.12.1
 optimum==1.16.2
 hf_transfer
 sentencepiece
 gradio==3.50.2
+modal==0.70.5
 pydantic==2.6.3
 addict
 fire
@@ -61,4 +62,4 @@ antlr4-python3-runtime==4.13.2
 torchao==0.7.0
 schedulefree==1.3.0
-axolotl-contribs-lgpl==0.0.3
+axolotl-contribs-lgpl==0.0.2

View File

@@ -30,7 +30,7 @@ def parse_dataset(dataset=None, split="train"):
         )
     ds_cfg["field_messages"] = field_messages

-    message_fields = features[field_messages][0].keys()
+    message_fields = features["conversations"][0].keys()
     message_field_role = None
     for key in ["from", "role"]:
         if key in message_fields:

52
scripts/finetune.py Normal file
View File

@@ -0,0 +1,52 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import logging
from pathlib import Path
import fire
import transformers
from axolotl.cli import (
check_accelerate_default_config,
check_user_token,
do_inference,
do_merge_lora,
load_cfg,
load_datasets,
print_axolotl_text_art,
)
from axolotl.cli.shard import shard
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
LOG = logging.getLogger("axolotl.scripts.finetune")
def do_cli(config: Path = Path("examples/"), **kwargs):
print_axolotl_text_art()
LOG.warning(
str(
PendingDeprecationWarning(
"scripts/finetune.py will be replaced with calling axolotl.cli.train"
)
)
)
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if parsed_cli_args.inference:
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.merge_lora:
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
elif parsed_cli_args.shard:
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
if __name__ == "__main__":
fire.Fire(do_cli)

View File

@@ -1,10 +1,15 @@
-[previous axolotl text banner ("dP"-style figlet art)]
+[new block-style axolotl ASCII logo]
 Welcome to the axolotl cloud image! If you've mounted a disk to /workspace and the axolotl directory is empty, run the following commands:

View File

@@ -1,5 +1,4 @@
 """setup.py for axolotl"""
-
 import ast
 import os
 import platform
@@ -30,30 +29,15 @@ def parse_requirements():
         elif not is_extras and line and line[0] != "#":
             # Handle standard packages
             _install_requires.append(line)

     try:
         xformers_version = [req for req in _install_requires if "xformers" in req][0]
-        triton_version = [req for req in _install_requires if "triton" in req][0]
         torchao_version = [req for req in _install_requires if "torchao" in req][0]
         autoawq_version = [req for req in _install_requires if "autoawq" in req][0]
         if "Darwin" in platform.system():
-            # skip packages not compatible with OSX
-            skip_packages = [
-                "bitsandbytes",
-                "triton",
-                "mamba-ssm",
-                "flash-attn",
-                "xformers",
-                "autoawq",
-                "liger-kernel",
-            ]
-            _install_requires = [
-                req
-                for req in _install_requires
-                if re.split(r"[>=<]", req)[0].strip() not in skip_packages
-            ]
-            print(
-                _install_requires, [req in skip_packages for req in _install_requires]
-            )
+            # don't install xformers on MacOS
+            _install_requires.pop(_install_requires.index(xformers_version))
         else:
             # detect the version of torch already installed
             # and set it so dependencies don't clobber the torch version
@@ -89,8 +73,6 @@ def parse_requirements():
             _install_requires.append("xformers==0.0.28.post1")
         elif (major, minor) >= (2, 3):
             _install_requires.pop(_install_requires.index(torchao_version))
-            _install_requires.pop(_install_requires.index(triton_version))
-            _install_requires.append("triton>=2.3.1")
             if patch == 0:
                 _install_requires.pop(_install_requires.index(xformers_version))
                 _install_requires.append("xformers>=0.0.26.post1")

View File

@@ -1,5 +1,568 @@
-"""Axolotl CLI module initialization."""
+"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import importlib
import json
import logging
import math
import os
import random
import sys
import tempfile
from pathlib import Path
from threading import Thread
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse
import requests
import torch
import yaml
# add src to the pythonpath so we don't need to pip install this
from accelerate.commands.config import config_args
from art import text2art
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.import_utils import _is_package_available
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import (
get_chat_template,
get_chat_template_from_config,
)
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
normalize_config,
prepare_plugins,
validate_config,
)
from axolotl.utils.data import load_prepare_dpo_datasets, prepare_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
configure_logging()
LOG = logging.getLogger("axolotl.scripts")
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
AXOLOTL_LOGO = """
#@@ #@@ @@# @@#
@@ @@ @@ @@ =@@# @@ #@ =@@#.
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
@@@@ @@@@@@@@@@@@@@@@
"""
def print_legacy_axolotl_text_art(suffix=None):
font = "nancyj"
ascii_text = " axolotl"
if suffix:
ascii_text += f" x {suffix}"
ascii_art = text2art(ascii_text, font=font)
if is_main_process():
print(ascii_art)
print_dep_versions()
def print_axolotl_text_art(
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
print(AXOLOTL_LOGO)
def print_dep_versions():
packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"]
max_len = max(len(pkg) for pkg in packages)
if is_main_process():
print("*" * 40)
print("**** Axolotl Dependency Versions *****")
for pkg in packages:
pkg_version = _is_package_available(pkg, return_version=True)
print(f"{pkg: >{max_len}}: {pkg_version[1]: <15}")
print("*" * 40)
def check_remote_config(config: Union[str, Path]):
# Check if the config is a valid HTTPS URL to a .yml or .yaml file
if not (isinstance(config, str) and config.startswith("https://")):
return config # Return the original value if it's not a valid URL
filename = os.path.basename(urlparse(config).path)
temp_dir = tempfile.mkdtemp()
try:
response = requests.get(config, timeout=30)
response.raise_for_status() # Check for HTTP errors
content = response.content
try:
# Try parsing as JSON first to catch cases where JSON content is mistakenly considered YAML
json.loads(content)
# Log a warning but do not raise an error; JSON is technically valid YAML - this can happen when you forget to point to a raw github link
LOG.warning(
f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended."
)
except json.JSONDecodeError:
# If it's not valid JSON, verify it's valid YAML
try:
yaml.safe_load(content)
except yaml.YAMLError as err:
raise ValueError(
f"Failed to parse the content at {config} as YAML: {err}"
) from err
# Write the content to a file if it's valid YAML (or JSON treated as YAML)
output_path = Path(temp_dir) / filename
with open(output_path, "wb") as file:
file.write(content)
LOG.info(
f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n"
)
return output_path
except requests.RequestException as err:
# This catches all requests-related exceptions including HTTPError
raise RuntimeError(f"Failed to download {config}: {err}") from err
except Exception as err:
# Catch-all for any other exceptions
raise err
def get_multi_line_input() -> Optional[str]:
print("Give me an instruction (Ctrl + D to submit): ")
instruction = ""
for line in sys.stdin:
instruction += line # pylint: disable=consider-using-join
# instruction = pathlib.Path("/proc/self/fd/0").read_text()
return instruction
def do_merge_lora(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.info("running merge of LoRA with base model")
model = model.merge_and_unload(progressbar=True)
try:
model.to(dtype=cfg.torch_dtype)
except RuntimeError:
pass
model.generation_config.do_sample = True
if cfg.local_rank == 0:
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
progressbar=True,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
def do_inference(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
prompter_module = None
chat_template_str = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
while True:
print("=" * 80)
# support for multiline inputs
instruction = get_multi_line_input()
if not instruction:
return
if prompter_module:
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
if chat_template_str:
batch = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": prompt,
}
],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=1024,
temperature=0.9,
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextStreamer(tokenizer)
generated = model.generate(
inputs=batch["input_ids"].to(cfg.device),
generation_config=generation_config,
streamer=streamer,
)
print("=" * 40)
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
def do_inference_gradio(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
import gradio as gr
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
prompter = cli_args.prompter
prompter_module = None
chat_template_str = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
def generate(instruction):
if not instruction:
return
if prompter_module:
# pylint: disable=stop-iteration-return
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
if chat_template_str:
batch = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": prompt,
}
],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=cfg.get("gradio_max_new_tokens", 1024),
temperature=cfg.get("gradio_temperature", 0.9),
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = {
"inputs": batch["input_ids"].to(cfg.device),
"attention_mask": batch["attention_mask"].to(cfg.device),
"generation_config": generation_config,
"streamer": streamer,
}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
all_text = ""
for new_text in streamer:
all_text += new_text
yield all_text
demo = gr.Interface(
fn=generate,
inputs="textbox",
outputs="text",
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
)
demo.queue().launch(
show_api=False,
share=cfg.get("gradio_share", True),
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
server_port=cfg.get("gradio_server_port", None),
)
def choose_config(path: Path):
yaml_files = list(path.glob("*.yml"))
if not yaml_files:
raise ValueError(
"No YAML config files found in the specified directory. Are you using a .yml extension?"
)
if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
return str(yaml_files[0])
print("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
print(f"{idx + 1}. {file}")
chosen_file = None
while chosen_file is None:
try:
choice = int(input("Enter the number of your choice: "))
if 1 <= choice <= len(yaml_files):
chosen_file = str(yaml_files[choice - 1])
else:
print("Invalid choice. Please choose a number from the list.")
except ValueError:
print("Invalid input. Please enter a number.")
return chosen_file
def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
return not any(el in list2 for el in list1)
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
config = check_remote_config(config)
if Path(config).is_dir():
config = choose_config(Path(config))
# load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value
cfg_keys = cfg.keys()
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
else:
cfg[k] = kwargs[k]
cfg.axolotl_config_path = config
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
except: # pylint: disable=bare-except # noqa: E722
gpu_version = None
prepare_plugins(cfg)
cfg = validate_config(
cfg,
capabilities={
"bf16": is_torch_bf16_gpu_available(),
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version,
},
env_capabilities={
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0],
},
)
prepare_optim_env(cfg)
prepare_opinionated_env(cfg)
normalize_config(cfg)
normalize_cfg_datasets(cfg)
setup_wandb_env_vars(cfg)
setup_mlflow_env_vars(cfg)
setup_comet_env_vars(cfg)
return cfg
def load_datasets(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg,
tokenizer,
processor=processor,
)
if (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
):
LOG.info("check_dataset_labels...")
check_dataset_labels(
train_dataset.select(
[
random.randrange(0, len(train_dataset) - 1) # nosec
for _ in range(cli_args.debug_num_examples)
]
),
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
)
LOG.info("printing prompters...")
for prompter in prompters:
LOG.info(prompter)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
)
def load_rl_datasets(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs, # pylint: disable=unused-argument
) -> TrainDatasetMeta:
train_dataset, eval_dataset = load_prepare_dpo_datasets(cfg)
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
tokenizer = load_tokenizer(cfg)
check_dataset_labels(
train_dataset.select(
[
random.randrange(0, len(train_dataset) - 1) # nosec
for _ in range(cli_args.debug_num_examples)
]
),
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
rl_mode=True,
)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
)
def check_accelerate_default_config():
if Path(config_args.default_yaml_config_file).exists():
LOG.warning(
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
)
def check_user_token():
# Skip check if HF_HUB_OFFLINE is set to True
if os.getenv("HF_HUB_OFFLINE") == "1":
LOG.info(
"Skipping HuggingFace token verification because HF_HUB_OFFLINE is set to True. Only local files will be used."
)
return True
# Verify if token is valid
api = HfApi()
try:
user_info = api.whoami()
return bool(user_info)
except LocalTokenNotFoundError:
LOG.warning(
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
)
return False
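Taken together, the consolidated module's typical entry path is small. A sketch under the definitions shown above, with placeholder paths; note that a keyword passed to `load_cfg` overrides a key only if it already exists in the YAML or `strict` is unset:

```python
from axolotl.cli import load_cfg, load_datasets
from axolotl.common.cli import TrainerCliArgs

cfg = load_cfg("path/to/config.yml", micro_batch_size=1)  # kwarg overrides YAML
dataset_meta = load_datasets(cfg=cfg, cli_args=TrainerCliArgs())
```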

View File

@@ -1,49 +0,0 @@
"""Module for axolotl CLI command arguments."""
from dataclasses import dataclass, field
from typing import Optional
@dataclass
class PreprocessCliArgs:
"""Dataclass with CLI arguments for `axolotl preprocess` command."""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True)
iterable: Optional[bool] = field(
default=None,
metadata={
"help": "Use IterableDataset for streaming processing of large datasets"
},
)
@dataclass
class TrainerCliArgs:
"""Dataclass with CLI arguments for `axolotl train` command."""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
@dataclass
class EvaluateCliArgs:
"""Dataclass with CLI arguments for `axolotl evaluate` command."""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
@dataclass
class InferenceCliArgs:
"""Dataclass with CLI arguments for `axolotl inference` command."""
prompter: Optional[str] = field(default=None)

View File

@@ -1,23 +0,0 @@
"""Axolotl ASCII logo utils."""
from axolotl.utils.distributed import is_main_process
AXOLOTL_LOGO = """
#@@ #@@ @@# @@#
@@ @@ @@ @@ =@@# @@ #@ =@@#.
@@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@
#@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@
@@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@
@@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@
=@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@
@@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. @@
=@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@
@@@@ @@@@@@@@@@@@@@@@
"""
def print_axolotl_text_art():
"""Prints axolotl ASCII art."""
if is_main_process():
print(AXOLOTL_LOGO)

View File

@@ -1,50 +0,0 @@
"""Various checks for Axolotl CLI."""
import logging
import os
from pathlib import Path
from accelerate.commands.config import config_args
from huggingface_hub import HfApi
from huggingface_hub.utils import LocalTokenNotFoundError
from axolotl.logging_config import configure_logging
configure_logging()
LOG = logging.getLogger(__name__)
def check_accelerate_default_config() -> None:
"""Logs at warning level if no accelerate config file is found."""
if Path(config_args.default_yaml_config_file).exists():
LOG.warning(
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
)
def check_user_token() -> bool:
"""Checks for HF user info. Check is skipped if HF_HUB_OFFLINE=1.
Returns:
Boolean indicating successful check (i.e., HF_HUB_OFFLINE=1 or HF user info is retrieved).
Raises:
LocalTokenNotFoundError: If HF user info can't be retrieved.
"""
# Skip check if HF_HUB_OFFLINE is set to True
if os.getenv("HF_HUB_OFFLINE") == "1":
LOG.info(
"Skipping HuggingFace token verification because HF_HUB_OFFLINE is set to True. Only local files will be used."
)
return True
# Verify if token is valid
api = HfApi()
try:
user_info = api.whoami()
return bool(user_info)
except LocalTokenNotFoundError:
LOG.warning(
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
)
return False

View File

@@ -0,0 +1,56 @@
"""
launch axolotl in supported cloud platforms
"""
from pathlib import Path
from typing import Union
import yaml
from axolotl.cli import print_axolotl_text_art
from axolotl.cli.cloud.modal_ import ModalCloud
from axolotl.utils.dict import DictDefault
def load_cloud_cfg(cloud_config: Union[Path, str]) -> DictDefault:
"""Load and validate cloud configuration."""
# Load cloud configuration.
with open(cloud_config, encoding="utf-8") as file:
cloud_cfg: DictDefault = DictDefault(yaml.safe_load(file))
return cloud_cfg
def do_cli_preprocess(
cloud_config: Union[Path, str],
config: Union[Path, str] = Path("examples/"),
) -> None:
print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
cloud.preprocess(config_yaml)
def do_cli_train(
cloud_config: Union[Path, str],
config: Union[Path, str] = Path("examples/"),
accelerate: bool = True,
) -> None:
print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
cloud.train(config_yaml, accelerate=accelerate)
def do_cli_lm_eval(
cloud_config: Union[Path, str],
config: Union[Path, str] = Path("examples/"),
) -> None:
print_axolotl_text_art()
cloud_cfg = load_cloud_cfg(cloud_config)
cloud = ModalCloud(cloud_cfg)
with open(config, "r", encoding="utf-8") as file:
config_yaml = file.read()
cloud.lm_eval(config_yaml)

View File

@@ -0,0 +1,18 @@
"""
base class for cloud platforms from cli
"""
from abc import ABC, abstractmethod
class Cloud(ABC):
"""
Abstract base class for cloud platforms.
"""
@abstractmethod
def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
pass
@abstractmethod
def train(self, config_yaml: str, accelerate: bool = True) -> str:
pass
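Any additional provider plugs in by subclassing this ABC. A hypothetical local stub, not part of the diff, just to show the contract:

```python
from axolotl.cli.cloud.base import Cloud

class EchoCloud(Cloud):
    """Toy implementation that only prints what it would run."""

    def preprocess(self, config_yaml: str, *args, **kwargs) -> None:
        print("would run: axolotl preprocess <config>")

    def train(self, config_yaml: str, accelerate: bool = True) -> str:
        flag = "--accelerate" if accelerate else "--no-accelerate"
        print(f"would run: axolotl train {flag} <config>")
        return config_yaml
```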

View File

@@ -0,0 +1,272 @@
"""
Modal Cloud support from CLI
"""
import copy
import json
import os
import subprocess # nosec B404
from pathlib import Path
from random import randint
import modal
from axolotl.cli.cloud.base import Cloud
def run_cmd(cmd: str, run_folder: str, volumes=None):
"""Run a command inside a folder, with Modal Volume reloading before and commit on success."""
# Ensure volumes contain latest files.
if volumes:
for _, vol in volumes.items():
vol.reload()
# modal workaround so it doesn't use the automounted axolotl
new_env = copy.deepcopy(os.environ)
if "PYTHONPATH" in new_env:
del new_env["PYTHONPATH"]
# Propagate errors from subprocess.
if exit_code := subprocess.call( # nosec B603
cmd.split(), cwd=run_folder, env=new_env
):
exit(exit_code) # pylint: disable=consider-using-sys-exit
# Commit writes to volume.
if volumes:
for _, vol in volumes.items():
vol.commit()
class ModalCloud(Cloud):
"""
Modal Cloud implementation.
"""
def __init__(self, config, app=None):
self.config = config
if not app:
app = modal.App()
self.app = app
self.volumes = {}
if config.volumes:
for volume_config in config.volumes:
_, mount, vol = self.create_volume(volume_config)
self.volumes[mount] = (vol, volume_config)
def get_env(self):
res = {
"HF_DATASETS_CACHE": "/workspace/data/huggingface-cache/datasets",
"HF_HUB_CACHE": "/workspace/data/huggingface-cache/hub",
}
for key in self.config.get("env", []):
if isinstance(key, str):
if val := os.environ.get(key, ""):
res[key] = val
elif isinstance(key, dict):
(key_, val) = list(key.items())[0]
res[key_] = val
return res
def get_image(self):
docker_tag = "main-py3.11-cu124-2.5.1"
if self.config.docker_tag:
docker_tag = self.config.docker_tag
docker_image = f"axolotlai/axolotl:{docker_tag}"
# grab the sha256 hash from docker hub for this image+tag
# this ensures that we always get the latest image for this tag, even if it's already cached
try:
manifest = subprocess.check_output( # nosec B602
f"docker manifest inspect {docker_image}",
shell=True,
).decode("utf-8")
sha256_hash = json.loads(manifest)["manifests"][0]["digest"]
except subprocess.CalledProcessError:
sha256_hash = None
# create the image
if sha256_hash:
image = modal.Image.from_registry(f"axolotlai/axolotl@{sha256_hash}")
else:
image = modal.Image.from_registry(docker_image)
# branch
if self.config.branch:
image = image.dockerfile_commands(
[
# Random id for cache busting of branch commits
f"RUN echo '{str(randint(0, 1000000))}'", # nosec B311
f"RUN cd /workspace/axolotl && git fetch && git checkout {self.config.branch}",
"RUN cd /workspace/ && git clone https://github.com/winglian/lm-evaluation-harness.git && cd lm-evaluation-harness && pip install -e .[math]",
]
)
if env := self.get_env():
image = image.env(env)
image = image.pip_install("fastapi==0.110.0", "pydantic==2.6.3")
return image
def get_secrets(self):
res = []
if self.config.secrets:
for key in self.config.get("secrets", []):
# pylint: disable=duplicate-code
if isinstance(key, str):
if val := os.environ.get(key, ""):
res.append(modal.Secret.from_dict({key: val}))
elif isinstance(key, dict):
(key_, val) = list(key.items())[0]
res.append(modal.Secret.from_dict({key_: val}))
return res
def create_volume(self, volume_config):
name = volume_config.name
mount = volume_config.mount
return name, mount, modal.Volume.from_name(name, create_if_missing=True)
def get_ephemeral_disk_size(self):
return 1000 * 525 # 1 TiB
def get_preprocess_timeout(self):
if self.config.timeout_preprocess:
return int(self.config.timeout_preprocess)
return 60 * 60 * 3 # 3 hours
def get_preprocess_memory(self):
memory = 128 # default to 128GiB
if self.config.memory:
memory = int(self.config.memory)
if self.config.memory_preprocess:
memory = int(self.config.memory_preprocess)
return 1024 * memory
def get_preprocess_env(self):
return self.app.function(
image=self.get_image(),
volumes={k: v[0] for k, v in self.volumes.items()},
cpu=8.0,
ephemeral_disk=self.get_ephemeral_disk_size(),
memory=self.get_preprocess_memory(),
timeout=self.get_preprocess_timeout(),
secrets=self.get_secrets(),
)
def preprocess(self, config_yaml: str, *args, **kwargs):
modal_fn = self.get_preprocess_env()(_preprocess)
with modal.enable_output():
with self.app.run(detach=True):
modal_fn.remote(
config_yaml,
volumes={k: v[0] for k, v in self.volumes.items()},
*args,
**kwargs,
)
def get_train_timeout(self):
if self.config.timeout:
return int(self.config.timeout)
return 60 * 60 * 24 # 24 hours
def get_train_gpu(self): # pylint: disable=too-many-return-statements
count = self.config.gpu_count or 1
family = self.config.gpu.lower() or "l40s"
if family == "l40s":
return modal.gpu.L40S(count=count)
if family == "a100":
return modal.gpu.A100(count=count, size="40GB")
if family == "a100-80gb":
return modal.gpu.A100(count=count, size="80GB")
if family in ["a10", "a10g"]:
return modal.gpu.A10G(count=count)
if family == "h100":
return modal.gpu.H100(count=count)
if family == "t4":
return modal.gpu.T4(count=count)
if family == "l4":
return modal.gpu.L4(count=count)
raise ValueError(f"Unsupported GPU family: {family}")
def get_train_memory(self):
memory = 128 # default to 128GiB
if self.config.memory:
memory = int(self.config.memory)
return 1024 * memory
def get_train_env(self):
return self.app.function(
image=self.get_image(),
volumes={k: v[0] for k, v in self.volumes.items()},
cpu=16.0,
gpu=self.get_train_gpu(),
memory=self.get_train_memory(),
timeout=self.get_train_timeout(),
secrets=self.get_secrets(),
)
def train(self, config_yaml: str, accelerate: bool = True):
modal_fn = self.get_train_env()(_train)
with modal.enable_output():
with self.app.run(detach=True):
modal_fn.remote(
config_yaml,
accelerate=accelerate,
volumes={k: v[0] for k, v in self.volumes.items()},
)
def lm_eval(self, config_yaml: str):
modal_fn = self.get_train_env()(_lm_eval)
with modal.enable_output():
with self.app.run(detach=True):
modal_fn.remote(
config_yaml,
volumes={k: v[0] for k, v in self.volumes.items()},
)
def _preprocess(config_yaml: str, volumes=None):
Path("/workspace/artifacts/axolotl").mkdir(parents=True, exist_ok=True)
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/artifacts/axolotl"
run_cmd(
"axolotl preprocess /workspace/artifacts/axolotl/config.yaml --dataset-processes=8",
run_folder,
volumes,
)
def _train(config_yaml: str, accelerate: bool = True, volumes=None):
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/artifacts/axolotl"
if accelerate:
accelerate_args = "--accelerate"
else:
accelerate_args = "--no-accelerate"
run_cmd(
f"axolotl train {accelerate_args} /workspace/artifacts/axolotl/config.yaml",
run_folder,
volumes,
)
def _lm_eval(config_yaml: str, volumes=None):
with open(
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
) as f_out:
f_out.write(config_yaml)
run_folder = "/workspace/artifacts/axolotl"
run_cmd(
"axolotl lm-eval /workspace/artifacts/axolotl/config.yaml",
run_folder,
volumes,
)
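One detail worth isolating from `get_image()` above: pinning by sha256 digest rather than by tag is what defeats the image cache when a moving tag (e.g. `main-py3.11-cu124-2.5.1`) is republished. The same logic, standalone:

```python
import json
import subprocess  # nosec B404

def resolve_digest(image: str):
    """Return the sha256 digest for a Docker Hub image tag, or None on failure."""
    try:
        manifest = subprocess.check_output(  # nosec B602
            f"docker manifest inspect {image}", shell=True
        ).decode("utf-8")
        return json.loads(manifest)["manifests"][0]["digest"]
    except subprocess.CalledProcessError:
        return None

# Use modal.Image.from_registry(f"axolotlai/axolotl@{digest}") when a digest
# resolves, falling back to the plain tag otherwise.
```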

View File

@@ -1,217 +0,0 @@
"""Configuration loading and processing."""
import json
import logging
import os
import tempfile
from pathlib import Path
from typing import Union
from urllib.parse import urlparse
import requests
import torch
import yaml
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.integrations.base import PluginManager
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
normalize_config,
validate_config,
)
from axolotl.utils.dict import DictDefault
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
from axolotl.utils.trainer import prepare_opinionated_env, prepare_optim_env
from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = logging.getLogger(__name__)
def check_remote_config(config: Union[str, Path]) -> Union[str, Path]:
"""
First, determines if the passed config is a valid HTTPS URL. Then, attempts to query
for it and parse its content, first as JSON, then as YAML (YAML is preferred).
Finally, the parsed content is written to a local file and its path is returned.
Args:
config: HTTPS URL to a YAML or JSON file.
Returns:
Either the original `config` if it's not a valid HTTPS URL, or the path to the
downloaded remote config.
Raises:
ValueError: If the remote configuration is neither valid JSON or YAML.
RuntimeError: If some request-related exception occurs from the file download.
Exception: Catch-all for any other exception.
"""
# Check if the config is a valid HTTPS URL to a .yml or .yaml file
if not (isinstance(config, str) and config.startswith("https://")):
return config # Return the original value if it's not a valid URL
filename = os.path.basename(urlparse(config).path)
temp_dir = tempfile.mkdtemp()
try:
response = requests.get(config, timeout=30)
response.raise_for_status() # Check for HTTP errors
content = response.content
try:
# Try parsing as JSON first to catch cases where JSON content is mistakenly
# considered YAML.
json.loads(content)
# Log a warning but do not raise an error; JSON is technically valid YAML.
# This can happen when you forget to point to a raw GitHub link.
LOG.warning(
f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended."
)
except json.JSONDecodeError:
# If it's not valid JSON, verify it's valid YAML
try:
yaml.safe_load(content)
except yaml.YAMLError as err:
raise ValueError(
f"Failed to parse the content at {config} as YAML: {err}"
) from err
# Write the content to a file if it's valid YAML (or JSON treated as YAML)
output_path = Path(temp_dir) / filename
with open(output_path, "wb") as file:
file.write(content)
LOG.info(
f"Using the following config obtained from {config}: \n\n{content.decode('utf-8')}\n"
)
return output_path
except requests.RequestException as err:
# This catches all requests-related exceptions including HTTPError
raise RuntimeError(f"Failed to download {config}: {err}") from err
except Exception as err:
# Catch-all for any other exceptions
raise err
def choose_config(path: Path) -> str:
"""
Helper method for choosing a `axolotl` config YAML file (considering only files
ending with `.yml` or `.yaml`). If more than one config file exists in the passed
`path`, the user is prompted to choose one.
Args:
path: Directory in which config file(s) are stored.
Returns:
Path to either (1) the sole YAML file, or (2) if more than one YAML files exist,
the user-selected YAML file.
Raises:
ValueError: If no YAML files are found in the given `path`.
"""
yaml_files = list(path.glob("*.yml")) + list(path.glob("*.yaml"))
if not yaml_files:
raise ValueError(
"No YAML config files found in the specified directory. Are you using a .yml extension?"
)
if len(yaml_files) == 1:
print(f"Using default YAML file '{yaml_files[0]}'")
return str(yaml_files[0])
print("Choose a YAML file:")
for idx, file in enumerate(yaml_files):
print(f"{idx + 1}. {file}")
chosen_file = None
while chosen_file is None:
try:
choice = int(input("Enter the number of your choice: "))
if 1 <= choice <= len(yaml_files):
chosen_file = str(yaml_files[choice - 1])
else:
print("Invalid choice. Please choose a number from the list.")
except ValueError:
print("Invalid input. Please enter a number.")
return chosen_file
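# Illustrative usage sketch (not part of the module; the directory is hypothetical):
# with a single YAML file present it is returned directly; with several, the user
# is prompted to pick one by number.
#
#   config_path = choose_config(Path("configs/"))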
def prepare_plugins(cfg: DictDefault):
"""
Registers the plugins for the given configuration.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
"""
if cfg.get("plugins"):
plugin_manager = PluginManager.get_instance()
for plugin_name in cfg["plugins"]:
plugin_manager.register(plugin_name)
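# Illustrative sketch (not part of the module): `plugins` entries are dotted import
# paths handed to PluginManager.register(); the path below is made up.
#
#   cfg = DictDefault({"plugins": ["my_pkg.plugins.MyPlugin"]})
#   prepare_plugins(cfg)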
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs) -> DictDefault:
"""
Loads the `axolotl` configuration stored at `config`, validates it, and performs
various setup.
Args:
config: Path (local or remote) to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
Returns:
`DictDefault` mapping configuration keys to values.
"""
config = check_remote_config(config)
if Path(config).is_dir():
config = choose_config(Path(config))
# Load the config from the yaml file
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
    # If an option was passed on the CLI and it looks valid against the YAML,
    # overwrite the config value
cfg_keys = cfg.keys()
for k, _ in kwargs.items():
# if not strict, allow writing to cfg even if it's not in the yml already
if k in cfg_keys or not cfg.strict:
# handle booleans
if isinstance(cfg[k], bool):
cfg[k] = bool(kwargs[k])
else:
cfg[k] = kwargs[k]
cfg.axolotl_config_path = config
try:
device_props = torch.cuda.get_device_properties("cuda")
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
except: # pylint: disable=bare-except # noqa: E722
gpu_version = None
prepare_plugins(cfg)
cfg = validate_config(
cfg,
capabilities={
"bf16": is_torch_bf16_gpu_available(),
"n_gpu": int(os.environ.get("WORLD_SIZE", 1)),
"compute_capability": gpu_version,
},
env_capabilities={
"torch_version": str(torch.__version__).split("+", maxsplit=1)[0]
},
)
prepare_optim_env(cfg)
prepare_opinionated_env(cfg)
normalize_config(cfg)
normalize_cfg_datasets(cfg)
setup_wandb_env_vars(cfg)
setup_mlflow_env_vars(cfg)
setup_comet_env_vars(cfg)
return cfg
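# Illustrative usage sketch (not part of the module; the path and overrides are
# examples): keyword arguments override values from the YAML file, applied when
# the key already exists in the YAML or cfg.strict is unset.
#
#   cfg = load_cfg("examples/config.yml", learning_rate=1e-5, micro_batch_size=2)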

View File

@@ -1,5 +1,6 @@
-"""CLI to run evaluation on a model."""
+"""
+CLI to run training on a model
+"""
 import logging
 from pathlib import Path
 from typing import Union
@@ -8,48 +9,35 @@ import fire
 from dotenv import load_dotenv
 from transformers.hf_argparser import HfArgumentParser

-from axolotl.cli.args import TrainerCliArgs
-from axolotl.cli.art import print_axolotl_text_art
-from axolotl.cli.checks import check_accelerate_default_config, check_user_token
-from axolotl.cli.config import load_cfg
-from axolotl.common.datasets import load_datasets, load_preference_datasets
+from axolotl.cli import (
+    check_accelerate_default_config,
+    check_user_token,
+    load_cfg,
+    load_datasets,
+    load_rl_datasets,
+    print_axolotl_text_art,
+)
+from axolotl.common.cli import TrainerCliArgs
 from axolotl.evaluate import evaluate
-from axolotl.utils.dict import DictDefault

-LOG = logging.getLogger(__name__)
+LOG = logging.getLogger("axolotl.cli.evaluate")


-def do_evaluate(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
-    """
-    Evaluates a `transformers` model by first loading the dataset(s) specified in the
-    `axolotl` config, and then calling `axolotl.evaluate.evaluate`, which computes
-    evaluation metrics on the given dataset(s) and writes them to disk.
-
-    Args:
-        cfg: Dictionary mapping `axolotl` config keys to values.
-        cli_args: CLI arguments.
-    """
+def do_evaluate(cfg, cli_args) -> None:
     # pylint: disable=duplicate-code
     print_axolotl_text_art()
     check_accelerate_default_config()
     check_user_token()

-    if cfg.rl:
-        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
+    if cfg.rl:  # and cfg.rl != "orpo":
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
     else:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-    evaluate(cfg=cfg, dataset_meta=dataset_meta)
+    evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)


 def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
-    """
-    Parses `axolotl` config, CLI args, and calls `do_evaluate`.
-
-    Args:
-        config: Path to `axolotl` config YAML file.
-        kwargs: Additional keyword arguments to override config file values.
-    """
     # pylint: disable=duplicate-code
     parsed_cfg = load_cfg(config, **kwargs)
     parser = HfArgumentParser(TrainerCliArgs)

View File

@@ -1,267 +1,32 @@
-"""CLI to run inference on a trained model."""
-
-import importlib
-import logging
-import sys
+"""
+CLI to run inference on a trained model
+"""
 from pathlib import Path
-from threading import Thread
 from typing import Union

 import fire
-import torch
 import transformers
 from dotenv import load_dotenv
-from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer

-from axolotl.cli.args import InferenceCliArgs
-from axolotl.cli.art import print_axolotl_text_art
-from axolotl.cli.config import load_cfg
-from axolotl.cli.utils import load_model_and_tokenizer
-from axolotl.utils.chat_templates import (
-    get_chat_template,
-    get_chat_template_from_config,
-)
-from axolotl.utils.dict import DictDefault
+from axolotl.cli import (
+    do_inference,
+    do_inference_gradio,
+    load_cfg,
+    print_axolotl_text_art,
+)
+from axolotl.common.cli import TrainerCliArgs

-LOG = logging.getLogger(__name__)


-def get_multi_line_input() -> str:
+def do_cli(config: Union[Path, str] = Path("examples/"), gradio=False, **kwargs):
"""
Gets multi-line input from terminal.
Returns:
Possibly multi-line, possibly empty stdin input as a string.
"""
print("Give me an instruction (Ctrl + D to submit): ")
instruction = ""
for line in sys.stdin:
instruction += line # pylint: disable=consider-using-join
return instruction
def do_inference(
*,
cfg: DictDefault,
cli_args: InferenceCliArgs,
):
"""
Runs inference on the command line in a loop. User input is accepted, a chat template
is (optionally) applied, and the model specified in the `axolotl` config is used to
generate completions according to a default generation config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Inference-specific CLI arguments.
"""
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
prompter = cli_args.prompter
prompter_module = None
chat_template_str = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
while True:
print("=" * 80)
# support for multiline inputs
instruction = get_multi_line_input()
if not instruction:
return
if prompter_module:
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
if chat_template_str:
batch = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": prompt,
}
],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
print("=" * 40)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=1024,
temperature=0.9,
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextStreamer(tokenizer)
generated = model.generate(
inputs=batch["input_ids"].to(cfg.device),
generation_config=generation_config,
streamer=streamer,
)
print("=" * 40)
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
def do_inference_gradio(
*,
cfg: DictDefault,
cli_args: InferenceCliArgs,
):
"""
Runs inference in a Gradio interface. User input is accepted, a chat template is
(optionally) applied, and the model specified in the `axolotl` config is used to
generate completions according to a default generation config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Inference-specific CLI arguments.
"""
import gradio as gr
model, tokenizer = load_model_and_tokenizer(cfg=cfg, inference=True)
prompter = cli_args.prompter
prompter_module = None
chat_template_str = None
if prompter:
prompter_module = getattr(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
model = model.to(cfg.device, dtype=cfg.torch_dtype)
def generate(instruction):
if not instruction:
return
if prompter_module:
# pylint: disable=stop-iteration-return
prompt: str = next(
prompter_module().build_prompt(instruction=instruction.strip("\n"))
)
else:
prompt = instruction.strip()
if chat_template_str:
batch = tokenizer.apply_chat_template(
[
{
"role": "user",
"content": prompt,
}
],
return_tensors="pt",
add_special_tokens=True,
add_generation_prompt=True,
chat_template=chat_template_str,
tokenize=True,
return_dict=True,
)
else:
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
model.eval()
with torch.no_grad():
generation_config = GenerationConfig(
repetition_penalty=1.1,
max_new_tokens=cfg.get("gradio_max_new_tokens", 1024),
temperature=cfg.get("gradio_temperature", 0.9),
top_p=0.95,
top_k=40,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
do_sample=True,
use_cache=True,
return_dict_in_generate=True,
output_attentions=False,
output_hidden_states=False,
output_scores=False,
)
streamer = TextIteratorStreamer(tokenizer)
generation_kwargs = {
"inputs": batch["input_ids"].to(cfg.device),
"attention_mask": batch["attention_mask"].to(cfg.device),
"generation_config": generation_config,
"streamer": streamer,
}
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
all_text = ""
for new_text in streamer:
all_text += new_text
yield all_text
demo = gr.Interface(
fn=generate,
inputs="textbox",
outputs="text",
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
)
demo.queue().launch(
show_api=False,
share=cfg.get("gradio_share", True),
server_name=cfg.get("gradio_server_name", "127.0.0.1"),
server_port=cfg.get("gradio_server_port", None),
)
-def do_cli(
-    config: Union[Path, str] = Path("examples/"), gradio: bool = False, **kwargs
-) -> None:
-    """
-    Parses axolotl config, CLI args, and calls `do_inference` or `do_inference_gradio`.
-
-    Args:
-        config: Path to `axolotl` config YAML file.
-        kwargs: Additional keyword arguments to override config file values.
-    """
     # pylint: disable=duplicate-code
     print_axolotl_text_art()
     parsed_cfg = load_cfg(config, inference=True, **kwargs)
     parsed_cfg.sample_packing = False
-    parser = transformers.HfArgumentParser(InferenceCliArgs)
+    parser = transformers.HfArgumentParser((TrainerCliArgs))
     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
         return_remaining_strings=True
     )
+    parsed_cli_args.inference = True

     if gradio:
         do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)
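For orientation, a hedged sketch of driving the branch's Gradio path directly from Python, equivalent to passing `--gradio` on the command line (the config path is illustrative):

    from axolotl.cli.inference import do_cli

    do_cli(config="examples/config.yml", gradio=True)  # launches the Gradio UI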

View File

@@ -1,20 +1,19 @@
-"""Click CLI definitions for various axolotl commands."""
+"""CLI definition for various axolotl commands."""

 # pylint: disable=redefined-outer-name

 import subprocess  # nosec B404
 from typing import Optional

 import click

 import axolotl
-from axolotl.cli.args import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
 from axolotl.cli.utils import (
     add_options_from_config,
     add_options_from_dataclass,
     build_command,
     fetch_from_github,
-    filter_none_kwargs,
 )
+from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
+from axolotl.integrations.lm_eval.cli import lm_eval
 from axolotl.utils import set_pytorch_cuda_alloc_conf
 from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
@@ -27,58 +26,59 @@ def cli():
 @cli.command()
 @click.argument("config", type=click.Path(exists=True, path_type=str))
+@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
 @add_options_from_dataclass(PreprocessCliArgs)
 @add_options_from_config(AxolotlInputConfig)
-@filter_none_kwargs
-def preprocess(config: str, **kwargs) -> None:
-    """
-    Preprocess datasets before training.
-
-    Args:
-        config: Path to `axolotl` config YAML file.
-        kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
-            config options.
-    """
-    from axolotl.cli.preprocess import do_cli
-
-    do_cli(config=config, **kwargs)
-
-
-@cli.command()
-@click.argument("config", type=click.Path(exists=True, path_type=str))
-@click.option(
-    "--accelerate/--no-accelerate",
-    default=True,
-    help="Use accelerate launch for multi-GPU training",
-)
-@add_options_from_dataclass(TrainerCliArgs)
-@add_options_from_config(AxolotlInputConfig)
-@filter_none_kwargs
-def train(config: str, accelerate: bool, **kwargs) -> None:
-    """
-    Train or fine-tune a model.
-
-    Args:
-        config: Path to `axolotl` config YAML file.
-        accelerate: Whether to use `accelerate` launcher.
-        kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
-            config options.
-    """
-    # Enable expandable segments for cuda allocation to improve VRAM usage
-    set_pytorch_cuda_alloc_conf()
-
-    if accelerate:
-        base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
-        if config:
-            base_cmd.append(config)
-        cmd = build_command(base_cmd, kwargs)
-        subprocess.run(cmd, check=True)  # nosec B603
-    else:
-        from axolotl.cli.train import do_cli
-
-        do_cli(config=config, **kwargs)
+def preprocess(config: str, cloud: Optional[str] = None, **kwargs):
+    """Preprocess datasets before training."""
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
+    if cloud:
+        from axolotl.cli.cloud import do_cli_preprocess
+
+        do_cli_preprocess(cloud_config=cloud, config=config)
+    else:
+        from axolotl.cli.preprocess import do_cli
+
+        do_cli(config=config, **kwargs)
+
+
+@cli.command()
+@click.argument("config", type=click.Path(exists=True, path_type=str))
+@click.option(
+    "--accelerate/--no-accelerate",
+    default=True,
+    help="Use accelerate launch for multi-GPU training",
+)
+@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
+@add_options_from_dataclass(TrainerCliArgs)
+@add_options_from_config(AxolotlInputConfig)
+def train(config: str, accelerate: bool, cloud: Optional[str], **kwargs):
+    """Train or fine-tune a model."""
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
+    # Enable expandable segments for cuda allocation to improve VRAM usage
+    set_pytorch_cuda_alloc_conf()
+
+    from axolotl.cli.cloud import do_cli_train
+
+    if accelerate:
+        if cloud:
+            do_cli_train(cloud_config=cloud, config=config, accelerate=True)
+        else:
+            base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.train"]
+            if config:
+                base_cmd.append(config)
+            cmd = build_command(base_cmd, kwargs)
+            subprocess.run(cmd, check=True)  # nosec B603
+    else:
+        if cloud:
+            do_cli_train(cloud_config=cloud, config=config, accelerate=False)
+        else:
+            from axolotl.cli.train import do_cli
+
+            do_cli(config=config, **kwargs)
 @cli.command()
 @click.argument("config", type=click.Path(exists=True, path_type=str))
 @click.option(
@@ -88,17 +88,10 @@ def train(config: str, accelerate: bool, **kwargs) -> None:
 )
 @add_options_from_dataclass(EvaluateCliArgs)
 @add_options_from_config(AxolotlInputConfig)
-@filter_none_kwargs
-def evaluate(config: str, accelerate: bool, **kwargs) -> None:
-    """
-    Evaluate a model.
-
-    Args:
-        config: Path to `axolotl` config YAML file.
-        accelerate: Whether to use `accelerate` launcher.
-        kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
-            config options.
-    """
+def evaluate(config: str, accelerate: bool, **kwargs):
+    """Evaluate a model."""
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
     if accelerate:
         base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
         if config:
@@ -118,33 +111,81 @@ def evaluate(config: str, accelerate: bool, **kwargs) -> None:
     default=False,
     help="Use accelerate launch for multi-GPU inference",
 )
+@click.option(
+    "--lora-model-dir",
+    type=click.Path(exists=True, path_type=str),
+    help="Directory containing LoRA model",
+)
+@click.option(
+    "--base-model",
+    type=click.Path(exists=True, path_type=str),
+    help="Path to base model for non-LoRA models",
+)
 @click.option("--gradio", is_flag=True, help="Launch Gradio interface")
+@click.option("--load-in-8bit", is_flag=True, help="Load model in 8-bit mode")
 @add_options_from_dataclass(TrainerCliArgs)
 @add_options_from_config(AxolotlInputConfig)
-@filter_none_kwargs
-def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
-    """
-    Run inference with a trained model.
-
-    Args:
-        config: Path to `axolotl` config YAML file.
-        accelerate: Whether to use `accelerate` launcher.
-        gradio: Whether to use Gradio browser interface or command line for inference.
-        kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
-            config options.
-    """
+def inference(
+    config: str,
+    accelerate: bool,
+    lora_model_dir: Optional[str] = None,
+    base_model: Optional[str] = None,
+    **kwargs,
+):
+    """Run inference with a trained model."""
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+    del kwargs["inference"]  # interferes with inference.do_cli
+    if lora_model_dir:
+        kwargs["lora_model_dir"] = lora_model_dir
+    if base_model:
+        kwargs["base_model"] = base_model
+
     if accelerate:
         base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.inference"]
         if config:
             base_cmd.append(config)
-        if gradio:
-            base_cmd.append("--gradio")
         cmd = build_command(base_cmd, kwargs)
         subprocess.run(cmd, check=True)  # nosec B603
     else:
         from axolotl.cli.inference import do_cli

-        do_cli(config=config, gradio=gradio, **kwargs)
+        do_cli(config=config, **kwargs)
@cli.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option(
"--accelerate/--no-accelerate",
default=False,
help="Use accelerate launch for multi-GPU operations",
)
@click.option(
"--model-dir",
type=click.Path(exists=True, path_type=str),
help="Directory containing model weights to shard",
)
@click.option(
"--save-dir",
type=click.Path(path_type=str),
help="Directory to save sharded weights",
)
@add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig)
def shard(config: str, accelerate: bool, **kwargs):
"""Shard model weights."""
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.shard"]
if config:
base_cmd.append(config)
cmd = build_command(base_cmd, kwargs)
subprocess.run(cmd, check=True) # nosec B603
else:
from axolotl.cli.shard import do_cli
do_cli(config=config, **kwargs)
 @cli.command()
@@ -154,19 +195,20 @@ def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
     default=True,
     help="Use accelerate launch for weight merging",
 )
+@click.option(
+    "--model-dir",
+    type=click.Path(exists=True, path_type=str),
+    help="Directory containing sharded weights",
+)
+@click.option(
+    "--save-path", type=click.Path(path_type=str), help="Path to save merged weights"
+)
 @add_options_from_dataclass(TrainerCliArgs)
 @add_options_from_config(AxolotlInputConfig)
-@filter_none_kwargs
-def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
-    """
-    Merge sharded FSDP model weights.
-
-    Args:
-        config: Path to `axolotl` config YAML file.
-        accelerate: Whether to use `accelerate` launcher.
-        kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
-            config options.
-    """
+def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs):
+    """Merge sharded FSDP model weights."""
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
     if accelerate:
         base_cmd = [
             "accelerate",
@@ -186,19 +228,28 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
 @cli.command()
 @click.argument("config", type=click.Path(exists=True, path_type=str))
-@add_options_from_dataclass(TrainerCliArgs)
-@add_options_from_config(AxolotlInputConfig)
-@filter_none_kwargs
-def merge_lora(config: str, **kwargs) -> None:
-    """
-    Merge trained LoRA adapters into a base model.
-
-    Args:
-        config: Path to `axolotl` config YAML file.
-        accelerate: Whether to use `accelerate` launcher.
-        kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
-            config options.
-    """
+@click.option(
+    "--lora-model-dir",
+    type=click.Path(exists=True, path_type=str),
+    help="Directory containing the LoRA model to merge",
+)
+@click.option(
+    "--output-dir",
+    type=click.Path(path_type=str),
+    help="Directory to save the merged model",
+)
+def merge_lora(
+    config: str,
+    lora_model_dir: Optional[str] = None,
+    output_dir: Optional[str] = None,
+):
+    """Merge a trained LoRA into a base model"""
+    kwargs = {}
+    if lora_model_dir:
+        kwargs["lora_model_dir"] = lora_model_dir
+    if output_dir:
+        kwargs["output_dir"] = output_dir
+
     from axolotl.cli.merge_lora import do_cli

     do_cli(config=config, **kwargs)
@@ -207,21 +258,20 @@ def merge_lora(config: str, **kwargs) -> None:
 @cli.command()
 @click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"]))
 @click.option("--dest", help="Destination directory")
-def fetch(directory: str, dest: Optional[str]) -> None:
+def fetch(directory: str, dest: Optional[str]):
     """
     Fetch example configs or other resources.

     Available directories:
     - examples: Example configuration files
     - deepspeed_configs: DeepSpeed configuration files
-
-    Args:
-        directory: One of `examples`, `deepspeed_configs`.
-        dest: Optional destination directory.
     """
     fetch_from_github(f"{directory}/", dest)


+cli.add_command(lm_eval)
+
+
 def main():
     cli()
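As a usage sketch (not part of the diff): the new `--cloud` flag takes a separate cloud config file, and `do_cli_train` is the entry point the branch routes it to. The file names below are illustrative.

    from axolotl.cli.cloud import do_cli_train

    # Equivalent to `axolotl train config.yml --cloud cloud.yml` with accelerate enabled.
    do_cli_train(cloud_config="cloud.yml", config="config.yml", accelerate=True)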

View File

@@ -1,6 +1,6 @@
-"""CLI to merge a trained LoRA into a base model."""
+"""
+CLI to run merge a trained LoRA into a base model
+"""
-import logging
 from pathlib import Path
 from typing import Union
@@ -8,58 +8,14 @@ import fire
 import transformers
 from dotenv import load_dotenv

-from axolotl.cli.args import TrainerCliArgs
-from axolotl.cli.art import print_axolotl_text_art
-from axolotl.cli.config import load_cfg
-from axolotl.cli.utils import load_model_and_tokenizer
-from axolotl.utils.dict import DictDefault
+from axolotl.cli import do_merge_lora, load_cfg, print_axolotl_text_art
+from axolotl.common.cli import TrainerCliArgs

-LOG = logging.getLogger(__name__)


-def do_merge_lora(*, cfg: DictDefault) -> None:
+def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
"""
Calls `transformers`' `merge_and_unload` on the model given in the `axolotl` config
along with the LoRA adapters to combine them into a single base model.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
"""
print_axolotl_text_art()
model, tokenizer = load_model_and_tokenizer(cfg=cfg)
safe_serialization = cfg.save_safetensors is True
LOG.info("Running merge of LoRA with base model...")
model = model.merge_and_unload(progressbar=True)
model.to(dtype=cfg.torch_dtype)
model.generation_config.do_sample = True
if cfg.local_rank == 0:
LOG.info(f"Saving merged model to: {str(Path(cfg.output_dir) / 'merged')}...")
model.save_pretrained(
str(Path(cfg.output_dir) / "merged"),
safe_serialization=safe_serialization,
progressbar=True,
)
tokenizer.save_pretrained(str(Path(cfg.output_dir) / "merged"))
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, CLI args, and calls `do_merge_lora`. Note that various
config values will be overwritten to allow the LoRA merge logic to work as expected
(`load_in_8bit=False`, `load_in4bit=False`, `flash_attention=False`, etc.).
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
Raises:
ValueError: If target directory for LoRA merged model does not exist.
"""
     # pylint: disable=duplicate-code
-    parser = transformers.HfArgumentParser(TrainerCliArgs)
+    print_axolotl_text_art()
+    parser = transformers.HfArgumentParser((TrainerCliArgs))
     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
         return_remaining_strings=True
     )
@@ -90,7 +46,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
     parsed_cfg.fsdp = None
     parsed_cfg.fsdp_config = None

-    do_merge_lora(cfg=parsed_cfg)
+    do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)


 if __name__ == "__main__":

View File

@@ -1,5 +1,6 @@
-"""CLI to merge sharded FSDP model checkpoints into a single combined checkpoint."""
+"""
+This module provides a CLI to merge sharded FSDP model checkpoints into a single combined checkpoint
+"""
 import json
 import logging
 import os
@@ -24,15 +25,16 @@ from huggingface_hub import split_torch_state_dict_into_shards
 from safetensors.torch import save_file as safe_save_file
 from torch.distributed.checkpoint.format_utils import _EmptyStateDictLoadPlanner

-from axolotl.cli.args import TrainerCliArgs
-from axolotl.cli.art import print_axolotl_text_art
-from axolotl.cli.config import load_cfg
+from axolotl.cli import load_cfg, print_axolotl_text_art
+from axolotl.common.cli import TrainerCliArgs

-LOG = logging.getLogger(__name__)
+LOG = logging.getLogger("axolotl.cli.merge_sharded_fsdp_weights")


 class BFloat16CastPlanner(_EmptyStateDictLoadPlanner):
-    """A custom planner to cast tensors to bfloat16 on the fly during loading."""
+    """
+    A custom planner to cast tensors to bfloat16 on the fly during loading.
+    """

     def commit_tensor(self, read_item, tensor):  # pylint: disable=unused-argument
         tensor.copy_(tensor.to(torch.bfloat16))
@@ -43,19 +45,11 @@ def _distributed_checkpoint_to_merged_weights(
     save_path: str,
     safe_serialization: bool = False,
     max_shard_size: str = "5GB",
-) -> Path:
+):
     """
-    Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`. Will
-    save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
-
-    Args:
-        checkpoint_dir: Directory where distributed checkpoint is saved.
-        save_path: Path to save model to.
-        safe_serialization: Whether to save in safetensors format.
-        max_shard_size: Max size of model shards to save.
-
-    Returns:
-        Path where model is saved.
+    Passthrough to `torch.distributed.checkpoint.format_utils.dcp_to_torch_save`
+
+    Will save under `save_path` as either `model.safetensors` or `pytorch_model.bin`.
     """
     state_dict: Dict = {}
@@ -85,7 +79,6 @@ def _distributed_checkpoint_to_merged_weights(
     state_dict_split = split_torch_state_dict_into_shards(
         state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
     )
-
     # Save index if sharded
     index = None
     if state_dict_split.is_sharded:
@@ -142,9 +135,6 @@ def merge_fsdp_weights(
             Whether to save the merged weights with safetensors (recommended).
         remove_checkpoint_dir (`bool`, *optional*, defaults to `False`):
             Whether to remove the checkpoint directory after merging.
-
-    Raises:
-        ValueError: If torch version < 2.3.0, or if `checkpoint_dir` does not exist.
     """
     checkpoint_dir_ = Path(checkpoint_dir)
     from accelerate.state import PartialState
@@ -188,21 +178,18 @@ def merge_fsdp_weights(

 def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
-    """
-    Parses `axolotl` config, CLI args, and calls `merge_fsdp_weights`.
-
-    Args:
-        config: Path to `axolotl` config YAML file.
-        kwargs: Additional keyword arguments to override config file values.
-    """
     # pylint: disable=duplicate-code
     print_axolotl_text_art()

-    parser = transformers.HfArgumentParser(TrainerCliArgs)
+    parser = transformers.HfArgumentParser((TrainerCliArgs))
     parsed_cli_args, _ = parser.parse_args_into_dataclasses(
         return_remaining_strings=True
     )
     parsed_cli_args.merge_lora = True
-    parsed_cfg = load_cfg(config, **kwargs)
+
+    parsed_cfg = load_cfg(
+        config,
+        **kwargs,
+    )

     fsdp_dir = Path(parsed_cfg.output_dir) / "pytorch_model_fsdp_0"
     merge_fsdp_weights(

View File

@@ -1,5 +1,6 @@
-"""CLI to run preprocessing of a dataset."""
+"""
+CLI to run training on a model
+"""
 import logging
 import warnings
 from pathlib import Path
@@ -12,31 +13,34 @@ from colorama import Fore
 from dotenv import load_dotenv
 from transformers import AutoModelForCausalLM

-from axolotl.cli.args import PreprocessCliArgs
-from axolotl.cli.art import print_axolotl_text_art
-from axolotl.cli.checks import check_accelerate_default_config, check_user_token
-from axolotl.cli.config import load_cfg
+from axolotl.cli import (
+    check_accelerate_default_config,
+    check_user_token,
+    load_cfg,
+    load_datasets,
+    load_rl_datasets,
+    print_axolotl_text_art,
+)
+from axolotl.common.cli import PreprocessCliArgs
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
-from axolotl.common.datasets import load_datasets, load_preference_datasets
-from axolotl.utils.dict import DictDefault
 from axolotl.utils.trainer import disable_datasets_caching

-LOG = logging.getLogger(__name__)
+LOG = logging.getLogger("axolotl.cli.preprocess")


-def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
-    """
-    Preprocesses dataset specified in axolotl config.
-
-    Args:
-        cfg: Dictionary mapping `axolotl` config keys to values.
-        cli_args: Preprocessing-specific CLI arguments.
-    """
+def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
+    # pylint: disable=duplicate-code
     print_axolotl_text_art()
+    parsed_cfg = load_cfg(config, **kwargs)
+    parsed_cfg.is_preprocess = True
     check_accelerate_default_config()
     check_user_token()
+    parser = transformers.HfArgumentParser((PreprocessCliArgs))
+    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
+        return_remaining_strings=True
+    )

-    if not cfg.dataset_prepared_path:
+    if not parsed_cfg.dataset_prepared_path:
         msg = (
             Fore.RED
             + "preprocess CLI called without dataset_prepared_path set, "
@@ -44,16 +48,16 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
             + Fore.RESET
         )
         LOG.warning(msg)
-        cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
+        parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH

     with disable_datasets_caching():
-        if cfg.rl:
-            load_preference_datasets(cfg=cfg, cli_args=cli_args)
+        if parsed_cfg.rl:  # and parsed_cfg.rl != "orpo":
+            load_rl_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
         else:
-            load_datasets(cfg=cfg, cli_args=cli_args)
+            load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)

-    if cli_args.download:
-        model_name = cfg.base_model
+    if parsed_cli_args.download:
+        model_name = parsed_cfg.base_model
         with warnings.catch_warnings():
             # there are a bunch of useless UserWarnings about
             # "copying from a non-meta parameter in the checkpoint to a meta parameter in the current model"
@@ -70,30 +74,11 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
     LOG.info(
         Fore.GREEN
-        + f"Success! Preprocessed data path: `dataset_prepared_path: {cfg.dataset_prepared_path}`"
+        + f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
         + Fore.RESET
     )
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, CLI args, and calls `do_preprocess`.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.is_preprocess = True
parser = transformers.HfArgumentParser(PreprocessCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
do_preprocess(parsed_cfg, parsed_cli_args)
 if __name__ == "__main__":
     load_dotenv()
     fire.Fire(do_cli)

45
src/axolotl/cli/shard.py Normal file
View File

@@ -0,0 +1,45 @@
"""
CLI to shard a trained model into 10GiB chunks
"""
import logging
from pathlib import Path
from typing import Union
import fire
import transformers
from dotenv import load_dotenv
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.scripts")
def shard(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
model, _ = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
safe_serialization = cfg.save_safetensors is True
LOG.debug("Re-saving model w/ sharding")
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
parsed_cli_args.shard = True
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
if __name__ == "__main__":
load_dotenv()
fire.Fire(do_cli)
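A hedged usage sketch for the re-added shard CLI (the config path is illustrative): `do_cli` loads the config, sets `shard=True` on the parsed CLI args, and re-saves the model weights in sharded form under `cfg.output_dir`.

    from axolotl.cli.shard import do_cli

    do_cli(config="examples/config.yml")  # re-saves the model with sharding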

View File

@@ -1,5 +1,6 @@
-"""CLI to run training on a model."""
+"""
+CLI to run training on a model
+"""
 import logging
 from pathlib import Path
 from typing import Union
@@ -8,38 +9,42 @@ import fire
 from dotenv import load_dotenv
 from transformers.hf_argparser import HfArgumentParser

-from axolotl.cli.args import TrainerCliArgs
-from axolotl.cli.art import print_axolotl_text_art
-from axolotl.cli.checks import check_accelerate_default_config, check_user_token
-from axolotl.cli.config import load_cfg
-from axolotl.common.datasets import load_datasets, load_preference_datasets
+from axolotl.cli import (
+    check_accelerate_default_config,
+    check_user_token,
+    load_cfg,
+    load_datasets,
+    load_rl_datasets,
+    print_axolotl_text_art,
+)
+from axolotl.common.cli import TrainerCliArgs
 from axolotl.integrations.base import PluginManager
 from axolotl.train import train
-from axolotl.utils.dict import DictDefault

-LOG = logging.getLogger(__name__)
+LOG = logging.getLogger("axolotl.cli.train")


-def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
-    """
-    Trains a `transformers` model by first loading the dataset(s) specified in the
-    `axolotl` config, and then calling `axolotl.train.train`. Also runs the plugin
-    manager's `post_train_unload` once training completes.
-
-    Args:
-        cfg: Dictionary mapping `axolotl` config keys to values.
-        cli_args: Training-specific CLI arguments.
-    """
+def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
+    # pylint: disable=duplicate-code
+    parsed_cfg = load_cfg(config, **kwargs)
+    parser = HfArgumentParser((TrainerCliArgs))
+    parsed_cli_args, _ = parser.parse_args_into_dataclasses(
+        return_remaining_strings=True
+    )
+    return do_train(parsed_cfg, parsed_cli_args)
+
+
+def do_train(cfg, cli_args) -> None:
     print_axolotl_text_art()
     check_accelerate_default_config()
     check_user_token()

-    if cfg.rl:
-        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
+    if cfg.rl:  # and cfg.rl != "orpo":
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
     else:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-    model, tokenizer = train(cfg=cfg, dataset_meta=dataset_meta)
+    model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

     plugin_manager = PluginManager.get_instance()
     del model
@@ -48,24 +53,6 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs) -> None:
     plugin_manager.post_train_unload(cfg)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
"""
Parses `axolotl` config, CLI args, and calls `do_train`.
Args:
config: Path to `axolotl` config YAML file.
kwargs: Additional keyword arguments to override config file values.
"""
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser(TrainerCliArgs)
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
do_train(parsed_cfg, parsed_cli_args)
 if __name__ == "__main__":
     load_dotenv()
     fire.Fire(do_cli)

View File

@@ -1,84 +1,32 @@
-"""Utility methods for axolotl CLI."""
+"""Utility methods for axoltl CLI."""

 import concurrent.futures
 import dataclasses
 import hashlib
 import json
 import logging
-import typing
-from functools import wraps
 from pathlib import Path
 from types import NoneType
-from typing import Any, Callable, Type, Union, get_args, get_origin
+from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_args, get_origin

 import click
 import requests
 from pydantic import BaseModel
-from transformers import PreTrainedModel, PreTrainedTokenizer, PreTrainedTokenizerFast

-from axolotl.logging_config import configure_logging
-from axolotl.utils.dict import DictDefault
-from axolotl.utils.models import load_model, load_tokenizer
-
-configure_logging()
-LOG = logging.getLogger(__name__)
+LOG = logging.getLogger("axolotl.cli.utils")


-def strip_optional_type(field_type: type | typing._SpecialForm | None):
-    """
-    Extracts the non-`None` type from an `Optional` / `Union` type.
-
-    Args:
+def add_options_from_dataclass(config_class: Type[Any]):
+    """Create Click options from the fields of a dataclass."""
+
+    def decorator(function):
field_type: Type of field for Axolotl CLI command.
Returns:
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
returns the input type unchanged.
"""
if get_origin(field_type) is Union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
return field_type
def filter_none_kwargs(func: Callable) -> Callable:
"""
Wraps function to remove `None`-valued `kwargs`.
Args:
func: Function to wrap.
Returns:
Wrapped function.
"""
@wraps(func)
def wrapper(*args, **kwargs) -> Callable:
"""Filters out `None`-valued `kwargs`."""
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
return func(*args, **filtered_kwargs)
return wrapper
def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
"""
Create Click options from the fields of a dataclass.
Args:
config_class: Dataclass with fields to parse from the CLI.
Returns:
Function decorator for Axolotl CLI command.
"""
-    def decorator(function: Callable) -> Callable:
         # Process dataclass fields in reverse order for correct option ordering
         for field in reversed(dataclasses.fields(config_class)):
-            field_type = strip_optional_type(field.type)
+            field_type = field.type
+            if get_origin(field_type) is Union and type(None) in get_args(field_type):
+                field_type = next(
+                    t for t in get_args(field_type) if not isinstance(t, NoneType)
+                )

             if field_type == bool:
                 field_name = field.name.replace("_", "-")
@@ -96,29 +44,18 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
                     default=field.default,
                     help=field.metadata.get("description"),
                 )(function)

         return function

     return decorator


-def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
-    """
-    Create Click options from the fields of a Pydantic model.
-
-    Args:
-        config_class: PyDantic model with fields to parse from the CLI
-
-    Returns:
-        Function decorator for Axolotl CLI command.
-    """
-
-    def decorator(function: Callable) -> Callable:
+def add_options_from_config(config_class: Type[BaseModel]):
+    """Create Click options from the fields of a Pydantic model."""
+
+    def decorator(function):
         # Process model fields in reverse order for correct option ordering
         for name, field in reversed(config_class.model_fields.items()):
-            field_type = strip_optional_type(field.annotation)
-            if field_type == bool:
+            if field.annotation == bool:
                 field_name = name.replace("_", "-")
                 option_name = f"--{field_name}/--no-{field_name}"
                 function = click.option(
@@ -129,23 +66,13 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
                 function = click.option(
                     option_name, default=None, help=field.description
                 )(function)

         return function

     return decorator


-def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
-    """
-    Build command list from base command and options.
-
-    Args:
-        base_cmd: Command without options.
-        options: Options to parse and append to base command.
-
-    Returns:
-        List of strings giving shell command.
-    """
+def build_command(base_cmd: List[str], options: Dict[str, Any]) -> List[str]:
+    """Build command list from base command and options."""
     cmd = base_cmd.copy()

     for key, value in options.items():
@@ -165,18 +92,18 @@ def build_command(base_cmd: list[str], options: dict[str, Any]) -> list[str]:
 def download_file(
     file_info: tuple, raw_base_url: str, dest_path: Path, dir_prefix: str
-) -> tuple[str, str]:
+) -> Tuple[str, str]:
     """
     Download a single file and return its processing status.

     Args:
-        file_info: Tuple of (file_path, remote_sha).
-        raw_base_url: Base URL for raw GitHub content.
-        dest_path: Local destination directory.
-        dir_prefix: Directory prefix to filter files.
+        file_info: Tuple of (file_path, remote_sha)
+        raw_base_url: Base URL for raw GitHub content
+        dest_path: Local destination directory
+        dir_prefix: Directory prefix to filter files

     Returns:
-        Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'.
+        Tuple of (file_path, status) where status is 'new', 'updated', or 'unchanged'
     """
     file_path, remote_sha = file_info
     raw_url = f"{raw_base_url}/{file_path}"
@@ -218,17 +145,16 @@ def download_file(
 def fetch_from_github(
-    dir_prefix: str, dest_dir: str | None = None, max_workers: int = 5
+    dir_prefix: str, dest_dir: Optional[str] = None, max_workers: int = 5
 ) -> None:
     """
     Sync files from a specific directory in the GitHub repository.
     Only downloads files that don't exist locally or have changed.

     Args:
-        dir_prefix: Directory prefix to filter files (e.g., 'examples/',
-            'deepspeed_configs/').
-        dest_dir: Local destination directory.
-        max_workers: Maximum number of concurrent downloads.
+        dir_prefix: Directory prefix to filter files (e.g., 'examples/', 'deepspeed_configs/')
+        dest_dir: Local destination directory
+        max_workers: Maximum number of concurrent downloads
     """
     api_url = "https://api.github.com/repos/axolotl-ai-cloud/axolotl/git/trees/main?recursive=1"
     raw_base_url = "https://raw.githubusercontent.com/axolotl-ai-cloud/axolotl/main"
@@ -253,7 +179,7 @@ def fetch_from_github(
     dest_path = Path(dest_dir) if dest_dir else default_dest

     # Keep track of processed files for summary
-    files_processed: dict[str, list[str]] = {
+    files_processed: Dict[str, List[str]] = {
         "new": [],
         "updated": [],
         "unchanged": [],
@@ -290,28 +216,3 @@ def fetch_from_github(
     LOG.info(f"Unchanged files: {len(files_processed['unchanged'])}")
     if files_processed["error"]:
         LOG.info(f"Failed files: {len(files_processed['error'])}")
def load_model_and_tokenizer(
*,
cfg: DictDefault,
inference: bool = False,
) -> tuple[PreTrainedModel, PreTrainedTokenizer | PreTrainedTokenizerFast | Any]:
"""
Helper function for loading a model and tokenizer specified in the given `axolotl`
config.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
inference: Boolean denoting inference mode.
Returns:
`transformers` model and tokenizer.
"""
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
LOG.info("loading model...")
model, _ = load_model(cfg, tokenizer, inference=inference)
return model, tokenizer

69
src/axolotl/common/cli.py Normal file
View File

@@ -0,0 +1,69 @@
"""
shared module for cli specific things
"""
import logging
from dataclasses import dataclass, field
from typing import Optional
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
configure_logging()
LOG = logging.getLogger("axolotl.common.cli")
@dataclass
class PreprocessCliArgs:
"""
dataclass representing arguments for preprocessing only
"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True)
@dataclass
class TrainerCliArgs:
"""
dataclass representing the various non-training arguments
"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
inference: bool = field(default=False)
merge_lora: bool = field(default=False)
prompter: Optional[str] = field(default=None)
shard: bool = field(default=False)
@dataclass
class EvaluateCliArgs:
"""
dataclass representing the various evaluation arguments
"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=0)
def load_model_and_tokenizer(
*,
cfg: DictDefault,
cli_args: TrainerCliArgs,
):
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
tokenizer = load_tokenizer(cfg)
LOG.info("loading model and (optionally) peft_config...")
inference = getattr(cli_args, "inference", False)
model, _ = load_model(cfg, tokenizer, inference=inference)
return model, tokenizer
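A hedged sketch of how this module's helper is meant to be called (the config object comes from `load_cfg`; the path is illustrative):

    from axolotl.cli import load_cfg
    from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer

    cfg = load_cfg("examples/config.yml")
    cli_args = TrainerCliArgs(inference=True)
    model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)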

View File

@@ -1,140 +0,0 @@
"""Dataset loading utilities."""
import logging
import math
import random
from dataclasses import dataclass
from typing import Optional, Union
from datasets import Dataset
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.cli.args import PreprocessCliArgs, TrainerCliArgs
from axolotl.utils.data import prepare_dataset
from axolotl.utils.data.rl import load_prepare_preference_datasets
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_processor, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
LOG = logging.getLogger(__name__)
@dataclass
class TrainDatasetMeta:
"""Dataclass with fields for training and validation datasets and metadata."""
train_dataset: Dataset
eval_dataset: Optional[Dataset] = None
total_num_steps: Optional[int] = None
def sample_dataset(dataset: Dataset, num_samples: int) -> Dataset:
"""
Randomly sample `num_samples` samples from `dataset`.
Args:
dataset: Dataset.
num_samples: Number of samples to return.
Returns:
Random sample (with replacement) of examples in `dataset`.
"""
return dataset.select(
[random.randrange(0, len(dataset) - 1) for _ in range(num_samples)] # nosec
)
def load_datasets(
*,
cfg: DictDefault,
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets, calling
`axolotl.utils.data.prepare_dataset`. Optionally, logs out debug information.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Command-specific CLI arguments.
Returns:
Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`.
"""
tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg,
tokenizer,
processor=processor,
)
if (
cli_args.debug
or cfg.debug
or cli_args.debug_text_only
or int(cli_args.debug_num_examples) > 0
):
LOG.info("check_dataset_labels...")
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
check_dataset_labels(
train_samples,
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
)
LOG.info("printing prompters...")
for prompter in prompters:
LOG.info(prompter)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
)
def load_preference_datasets(
*,
cfg: DictDefault,
cli_args: Union[PreprocessCliArgs, TrainerCliArgs],
) -> TrainDatasetMeta:
"""
Loads one or more training or evaluation datasets for RL training using paired
preference data, calling `axolotl.utils.data.rl.load_prepare_preference_datasets`.
Optionally, logs out debug information.
Args:
cfg: Dictionary mapping `axolotl` config keys to values.
cli_args: Command-specific CLI arguments.
Returns:
Dataclass with fields for training and evaluation datasets and the computed
`total_num_steps`.
"""
train_dataset, eval_dataset = load_prepare_preference_datasets(cfg)
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...")
tokenizer = load_tokenizer(cfg)
train_samples = sample_dataset(train_dataset, cli_args.debug_num_examples)
check_dataset_labels(
train_samples,
tokenizer,
num_examples=cli_args.debug_num_examples,
text_only=cli_args.debug_text_only,
rl_mode=True,
)
return TrainDatasetMeta(
train_dataset=train_dataset,
eval_dataset=eval_dataset,
total_num_steps=total_num_steps,
)

View File

@@ -22,6 +22,7 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
 import torch
 import transformers
 from datasets import Dataset
+from packaging import version
 from peft.optimizers import create_loraplus_optimizer
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
@@ -243,10 +244,6 @@ class AxolotlTrainingMixins:
         default=None,
         metadata={"help": "Scale the learning rate for the embedding layers."},
     )
-    lr_groups: Optional[list[dict]] = field(
-        default=None,
-        metadata={"help": "Specify learning rate groups for with different LRs."},
-    )
     embedding_lr: Optional[float] = field(
         default=None,
         metadata={"help": "absolute learning rate for the embedding layers."},
@@ -465,95 +462,11 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         )
         return super()._wrap_model(model, training=training, dataloader=dataloader)
def create_optimizer_grouped_parameters(self, opt_model, optimizer_kwargs):
decay_parameters = self.get_decay_parameter_names(opt_model)
params = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
lr_groups_lookup = {}
lr_groups_learning_rates = {}
if self.args.lr_groups:
for lr_group in self.args.lr_groups:
group_name = lr_group["name"]
group_modules = lr_group["modules"]
for module in group_modules:
lr_groups_lookup[module] = group_name
lr_groups_learning_rates[group_name] = lr_group["lr"]
params[f"to_weight_decay_{group_name}"] = {}
for name, param in opt_model.named_parameters():
if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
lr_group_modules = [
group_modules
for group_modules in lr_groups_lookup
if group_modules in name
]
if lr_groups_lookup and any(lr_group_modules):
lr_group_module = lr_group_modules[0]
group_name = lr_groups_lookup[lr_group_module]
params[f"to_weight_decay_{group_name}"][name] = param
else:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
for group_name, group_lr in lr_groups_learning_rates.items():
if params[f"to_weight_decay_{group_name}"]:
optimizer_grouped_parameters.append(
{
"params": list(
params[f"to_weight_decay_{group_name}"].values()
),
"weight_decay": self.args.weight_decay,
"lr": group_lr,
}
)
return optimizer_grouped_parameters
def create_optimizer(self): def create_optimizer(self):
if ( if (
self.args.loraplus_lr_ratio is None self.args.loraplus_lr_ratio is None
and self.args.embedding_lr_scale is None and self.args.embedding_lr_scale is None
and self.args.embedding_lr is None and self.args.embedding_lr is None
and self.args.lr_groups is None
and self.args.alternate_optimizer and self.args.alternate_optimizer
not in [ not in [
"optimi_adamw", "optimi_adamw",
@@ -567,13 +480,59 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition if self.optimizer is None: # pylint: disable=access-member-before-definition
decay_parameters = self.get_decay_parameter_names(opt_model)
params = {
"to_weight_decay": {}, # LayerNorm and bias
"embeddings": {}, # lm_head, embed_tokens,
"no_weight_decay": {},
}
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs( optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args, self.args,
opt_model, opt_model,
) )
optimizer_grouped_parameters = self.create_optimizer_grouped_parameters(
opt_model, optimizer_kwargs for name, param in opt_model.named_parameters():
) if not param.requires_grad:
continue
if name.endswith("modules_to_save.default.weight") or any(
embed_name in name for embed_name in ["embed_tokens", "lm_head"]
):
params["embeddings"][name] = param
elif name in decay_parameters:
params["to_weight_decay"][name] = param
else:
params["no_weight_decay"][name] = param
optimizer_grouped_parameters = []
if params["to_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["to_weight_decay"].values()),
"weight_decay": self.args.weight_decay,
"lr": optimizer_kwargs["lr"],
}
)
if params["embeddings"]:
lr = optimizer_kwargs["lr"] # pylint: disable=invalid-name
if self.args.embedding_lr_scale:
lr *= self.args.embedding_lr_scale # pylint: disable=invalid-name
elif self.args.embedding_lr:
lr = self.args.embedding_lr # pylint: disable=invalid-name
optimizer_grouped_parameters.append(
{
"params": list(params["embeddings"].values()),
"weight_decay": 0.0,
"lr": lr,
}
)
if params["no_weight_decay"]:
optimizer_grouped_parameters.append(
{
"params": list(params["no_weight_decay"].values()),
"weight_decay": 0.0,
"lr": optimizer_kwargs["lr"],
}
)
if self.args.loraplus_lr_ratio is not None: if self.args.loraplus_lr_ratio is not None:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
@@ -590,7 +549,6 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
elif ( elif (
self.args.embedding_lr_scale is not None self.args.embedding_lr_scale is not None
or self.args.embedding_lr is not None or self.args.embedding_lr is not None
or self.args.lr_groups is not None
): ):
self.optimizer = ( # pylint: disable=attribute-defined-outside-init self.optimizer = ( # pylint: disable=attribute-defined-outside-init
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs) optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
@@ -650,14 +608,8 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
self.state.train_batch_size or self.args.per_device_train_batch_size self.state.train_batch_size or self.args.per_device_train_batch_size
) )
batch_max_len = train_batch_size * self.args.max_seq_length batch_max_len = train_batch_size * self.args.max_seq_length
if self.args.curriculum_sampling:
sampler = SequentialSampler(self.train_dataset)
else:
sampler = RandomSampler(self.train_dataset)
return MultipackBatchSampler( return MultipackBatchSampler(
sampler, RandomSampler(self.train_dataset),
lengths=get_dataset_lengths(self.train_dataset), lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency, packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len, batch_max_len=batch_max_len,
@@ -1026,7 +978,12 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
logs[key] = torch.tensor(metrics).mean().item() logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval] del self._stored_metrics[train_eval]
return super().log(logs, start_time) if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
try:
return super().log(logs, start_time)
except TypeError:
return super().log(logs) # transformers<=4.46
return super().log(logs) # transformers<=4.46
def store_metrics( def store_metrics(
self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train" self, metrics: Dict[str, float], train_eval: Literal["train", "eval"] = "train"
@@ -1122,7 +1079,6 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.dataset_tags = dataset_tags self.dataset_tags = dataset_tags
self.optimizer = None self.optimizer = None
self.model_accepts_loss_kwargs = False
def create_optimizer(self): def create_optimizer(self):
if self.args.loraplus_lr_ratio is None: if self.args.loraplus_lr_ratio is None:
@@ -1211,6 +1167,22 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
torch.cuda.empty_cache() torch.cuda.empty_cache()
return loss return loss
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the update to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(DPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(DPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
""" """
@@ -1219,6 +1191,22 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
tag_names = ["axolotl", "orpo"] tag_names = ["axolotl", "orpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the update to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(ORPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(ORPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
""" """
@@ -1227,6 +1215,49 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
tag_names = ["axolotl", "kto"] tag_names = ["axolotl", "kto"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the update to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# train metrics should have no prefix, eval should have 'eval_'
prefix = "eval_" if train_eval == "eval" else ""
# accumulate average metrics from sums and lengths
for split in ["chosen", "rejected"]:
if f"count/{split}" in self._stored_metrics[train_eval]:
count_sum = (
torch.Tensor(self._stored_metrics[train_eval][f"count/{split}"])
.sum()
.item()
)
for metric in ["rewards", "logps", "logits"]:
logs[f"{prefix}{metric}/{split}"] = (
torch.Tensor(
self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
)
.sum()
.item()
/ count_sum
)
# delete obsolete metric
del self._stored_metrics[train_eval][f"{metric}/{split}_sum"]
del self._stored_metrics[train_eval][f"count/{split}"]
# calculate reward margin
if f"{prefix}rewards/chosen" in logs and f"{prefix}rewards/rejected" in logs:
logs[f"{prefix}rewards/margins"] = (
logs[f"{prefix}rewards/chosen"] - logs[f"{prefix}rewards/rejected"]
)
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[f"{prefix}{key}"] = torch.Tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(KTOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(KTOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
""" """
@@ -1235,6 +1266,22 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
tag_names = ["axolotl", "cpo"] tag_names = ["axolotl", "cpo"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the update to the Trainer.log method
# logs either has 'loss' or 'eval_loss'
train_eval = "train" if "loss" in logs else "eval"
# Add averaged stored metrics to logs
for key, metrics in self._stored_metrics[train_eval].items():
logs[key] = torch.tensor(metrics).mean().item()
del self._stored_metrics[train_eval]
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(CPOTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(CPOTrainer, self).log(logs) # pylint: disable=bad-super-call
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
""" """
@@ -1243,6 +1290,15 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
tag_names = ["axolotl", "reward"] tag_names = ["axolotl", "reward"]
def log(self, logs: Dict[str, float], start_time: Optional[float] = None) -> None:
# TODO remove once trl supports the update to the Trainer.log method
if version.parse(transformers.__version__) >= version.parse("4.47.0.dev0"):
return super(RewardTrainer, self).log( # pylint: disable=bad-super-call
logs, start_time
)
# transformers<=4.46
return super(RewardTrainer, self).log(logs) # pylint: disable=bad-super-call
class TrainerBuilderBase(abc.ABC): class TrainerBuilderBase(abc.ABC):
""" """
@@ -1708,7 +1764,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
] = self.cfg.loraplus_lr_embedding ] = self.cfg.loraplus_lr_embedding
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
training_arguments_kwargs["lr_groups"] = self.cfg.lr_groups
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]: if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
training_arguments_kwargs["lr_scheduler_type"] = "cosine" training_arguments_kwargs["lr_scheduler_type"] = "cosine"
@@ -1922,10 +1977,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
): ):
if training_args.pretraining: if training_args.pretraining:
if self.cfg.pretraining_sample_concatenation is False:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
if self.cfg.micro_batch_size > 1:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
return None return None
if self.cfg.model_config_type == "mamba": if self.cfg.model_config_type == "mamba":

View File

@@ -9,6 +9,7 @@ from typing import Dict, Optional
import torch import torch
from accelerate.logging import get_logger from accelerate.logging import get_logger
from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta from axolotl.train import TrainDatasetMeta
from axolotl.utils import set_pytorch_cuda_alloc_conf from axolotl.utils import set_pytorch_cuda_alloc_conf
@@ -61,13 +62,16 @@ def evaluate_dataset(
return metrics return metrics
def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, float]: def evaluate(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
) -> Dict[str, float]:
""" """
Evaluate a model on training and validation datasets Evaluate a model on training and validation datasets
Args: Args:
cfg: Dictionary mapping `axolotl` config keys to values. cfg: Configuration dictionary
dataset_meta: Dataset metadata containing training and evaluation datasets. cli_args: Command line arguments
dataset_meta: Dataset metadata containing training and evaluation datasets
Returns: Returns:
Tuple containing: Tuple containing:
@@ -98,7 +102,9 @@ def evaluate(*, cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> Dict[str, f
# Load model # Load model
LOG.debug("loading model for evaluation...") LOG.debug("loading model for evaluation...")
model, _ = load_model(cfg, tokenizer, processor=processor) model, _ = load_model(
cfg, tokenizer, processor=processor, inference=cli_args.inference
)
# Set up trainer # Set up trainer
trainer = setup_trainer( trainer = setup_trainer(

View File

@@ -22,6 +22,13 @@ import inspect
import logging import logging
import sys import sys
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from ...utils.distributed import zero_only from ...utils.distributed import zero_only
@@ -39,13 +46,6 @@ class LigerPlugin(BasePlugin):
return "axolotl.integrations.liger.LigerArgs" return "axolotl.integrations.liger.LigerArgs"
def pre_model_load(self, cfg): def pre_model_load(self, cfg):
from liger_kernel.transformers.cross_entropy import LigerCrossEntropyLoss
from liger_kernel.transformers.functional import liger_cross_entropy
from liger_kernel.transformers.monkey_patch import MODEL_TYPE_TO_APPLY_LIGER_FN
from liger_kernel.transformers.rms_norm import LigerRMSNorm
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from liger_kernel.transformers.swiglu import LigerSwiGLUMLP
if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN: if cfg.model_config_type in MODEL_TYPE_TO_APPLY_LIGER_FN:
apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type] apply_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN[cfg.model_config_type]
liger_fn_sig = inspect.signature(apply_liger_fn) liger_fn_sig = inspect.signature(apply_liger_fn)

View File

@@ -2,9 +2,9 @@
Module for the Plugin for LM Eval Harness Module for the Plugin for LM Eval Harness
""" """
import subprocess # nosec import subprocess # nosec
from datetime import datetime
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.integrations.lm_eval.cli import build_lm_eval_command
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401 from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
@@ -18,25 +18,19 @@ class LMEvalPlugin(BasePlugin):
return "axolotl.integrations.lm_eval.LMEvalArgs" return "axolotl.integrations.lm_eval.LMEvalArgs"
def post_train_unload(self, cfg): def post_train_unload(self, cfg):
tasks = ",".join(cfg.lm_eval_tasks) if cfg.lm_eval_post_train:
fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else "" # pylint: disable=duplicate-code
dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16" for lm_eval_args in build_lm_eval_command(
output_path = cfg.output_dir cfg.lm_eval_tasks,
output_path += "" if cfg.output_dir.endswith("/") else "/" bfloat16=cfg.bfloat16 or cfg.bf16,
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S") flash_attention=cfg.flash_attention,
subprocess.run( # nosec output_dir=cfg.output_dir,
[ batch_size=cfg.lm_eval_batch_size,
"lm_eval", wandb_project=cfg.wandb_project,
"--model", wandb_entity=cfg.wandb_entity,
"hf", model=cfg.lm_eval_model or cfg.hub_model_id,
"--model_args", ):
f"pretrained={cfg.output_dir}{fa2}{dtype}", subprocess.run( # nosec
"--tasks", lm_eval_args,
tasks, check=True,
"--batch_size", )
str(cfg.lm_eval_batch_size),
"--output_path",
output_path,
],
check=True,
)

View File

@@ -13,3 +13,5 @@ class LMEvalArgs(BaseModel):
lm_eval_tasks: List[str] = [] lm_eval_tasks: List[str] = []
lm_eval_batch_size: Optional[int] = 8 lm_eval_batch_size: Optional[int] = 8
lm_eval_post_train: Optional[bool] = True
lm_eval_model: Optional[str] = None

View File

@@ -0,0 +1,113 @@
"""
axolotl CLI for running lm_eval tasks
"""
import subprocess # nosec
from collections import defaultdict
from datetime import datetime
from typing import Optional
import click
import yaml
from axolotl.utils.dict import DictDefault
def build_lm_eval_command(
tasks: list[str],
bfloat16=True,
flash_attention=False,
output_dir="./",
batch_size=8,
wandb_project=None,
wandb_entity=None,
model=None,
revision=None,
apply_chat_template=None,
fewshot_as_multiturn=None,
):
tasks_by_num_fewshot: dict[str, list] = defaultdict(list)
for task in tasks:
num_fewshot = "-1"
task_parts = task.split(":")
task_name = task_parts[0]
if len(task_parts) == 2:
task_name, num_fewshot = task_parts
tasks_by_num_fewshot[str(num_fewshot)].append(task_name)
for num_fewshot, tasks_list in tasks_by_num_fewshot.items():
tasks_str = ",".join(tasks_list)
num_fewshot_val = num_fewshot if num_fewshot != "-1" else None
pretrained = "pretrained="
pretrained += model if model else output_dir
fa2 = ",attn_implementation=flash_attention_2" if flash_attention else ""
dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16"
revision = f",revision={revision}" if revision else ""
output_path = output_dir
output_path += "" if output_dir.endswith("/") else "/"
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
lm_eval_args = [
"lm_eval",
"--model",
"hf",
"--model_args",
f"{pretrained}{fa2}{dtype}{revision}",
"--tasks",
tasks_str,
"--batch_size",
str(batch_size),
"--output_path",
output_path,
]
wandb_args = []
if wandb_project:
wandb_args.append(f"project={wandb_project}")
if wandb_entity:
wandb_args.append(f"entity={wandb_entity}")
if wandb_args:
lm_eval_args.append("--wandb_args")
lm_eval_args.append(",".join(wandb_args))
if apply_chat_template:
lm_eval_args.append("--apply_chat_template")
if num_fewshot_val:
lm_eval_args.append("--num_fewshot")
lm_eval_args.append(str(num_fewshot_val))
if apply_chat_template and fewshot_as_multiturn:
lm_eval_args.append("--fewshot_as_multiturn")
yield lm_eval_args
@click.command()
@click.argument("config", type=click.Path(exists=True, path_type=str))
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
def lm_eval(config: str, cloud: Optional[str] = None):
"""
use lm eval to evaluate a trained language model
"""
if cloud:
from axolotl.cli.cloud import do_cli_lm_eval
do_cli_lm_eval(cloud_config=cloud, config=config)
else:
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# pylint: disable=duplicate-code
for lm_eval_args in build_lm_eval_command(
cfg.lm_eval_tasks,
bfloat16=cfg.bfloat16 or cfg.bf16,
flash_attention=cfg.flash_attention,
output_dir=cfg.output_dir,
batch_size=cfg.lm_eval_batch_size,
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
model=cfg.lm_eval_model or cfg.hub_model_id,
revision=cfg.revision,
apply_chat_template=cfg.apply_chat_template,
fewshot_as_multiturn=cfg.fewshot_as_multiturn,
):
subprocess.run( # nosec
lm_eval_args,
check=True,
)
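To illustrate the builder above: for `tasks=["gsm8k:5"]` with `bfloat16=True`, `flash_attention=True`, `output_dir="./out"`, and the default batch size, the generator should yield roughly the argv below (timestamp abbreviated; this is the expected shape, not captured output):

# Expected shape of one command yielded by build_lm_eval_command:
[
    "lm_eval",
    "--model", "hf",
    "--model_args", "pretrained=./out,attn_implementation=flash_attention_2,dtype=bfloat16",
    "--tasks", "gsm8k",
    "--batch_size", "8",
    "--output_path", "./out/lm_eval_results/<timestamp>",
    "--num_fewshot", "5",
]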

View File

@@ -6,7 +6,7 @@ import logging
from transformers import Trainer from transformers import Trainer
from axolotl.monkeypatch.utils import detab_code from axolotl.monkeypatch.unsloth_ import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save") LOG = logging.getLogger("axolotl.monkeypatch.trainer_fsdp_save")

View File

@@ -0,0 +1,308 @@
"""
fix for FSDP gradient accumulation
see https://github.com/huggingface/transformers/pull/35128
"""
import inspect
import logging
from transformers import LlamaForCausalLM, Trainer
from transformers.modeling_flash_attention_utils import _flash_attention_forward
from axolotl.monkeypatch.unsloth_ import detab_code
LOG = logging.getLogger("axolotl.monkeypatch.trainer_grad_accum")
ORIGINAL_CONTEXT_CODE = """
with self.compute_loss_context_manager():
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
"""
PATCHED_CONTEXT_CODE = """
with self.compute_loss_context_manager():
if self.model_accepts_loss_kwargs:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
else:
loss = self.compute_loss(model, inputs)
"""
ORIGINAL_LLAMA_FCLM_CODE = """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs)
"""
PATCHED_LLAMA_FCLM_CODE = """
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# remove num_items_in_batch otherwise self.model attempts to pass it to flash_attention
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs[0]
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
loss = None
if labels is not None:
loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, num_items_in_batch=num_items_in_batch, **kwargs)
"""
def get_training_step_code() -> str:
training_step = inspect.getsource(
Trainer.training_step # pylint: disable=protected-access
)
return training_step
def check_training_step_is_patchable() -> bool:
training_step = get_training_step_code()
training_step, _ = detab_code(training_step)
return ORIGINAL_CONTEXT_CODE in training_step
def patch_training_step_for_ga():
"""
monkeypatch fixing Trainer.training_step for gradient accumulation
"""
try:
training_step = get_training_step_code()
except OSError:
return
Trainer._original_training_step = training_step # pylint: disable=protected-access
training_step, _ = detab_code(training_step)
if ORIGINAL_CONTEXT_CODE not in training_step:
return
# assert (
# ORIGINAL_CONTEXT_CODE in training_step
# ), "Original training_step code not found"
training_step = training_step.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
training_step = training_step.replace(
"def training_step(",
"def _fixed_training_step(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_step:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_step, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching training_step")
Trainer.training_step = ( # pylint: disable=protected-access
_fixed_training_step # pylint: disable=undefined-variable # noqa: F821
)
def get_model_forward_code() -> str:
forward = inspect.getsource(
LlamaForCausalLM.forward # pylint: disable=protected-access
)
return forward
def check_forward_is_patchable() -> bool:
forward = get_model_forward_code()
forward, _ = detab_code(forward)
return ORIGINAL_LLAMA_FCLM_CODE in forward
def patch_forward_for_ga():
"""
monkeypatch fixing LlamaForCausalLM.forward for gradient accumulation
"""
try:
forward = get_model_forward_code()
except OSError:
return
LlamaForCausalLM._original_forward = forward # pylint: disable=protected-access
forward, _ = detab_code(forward)
if ORIGINAL_LLAMA_FCLM_CODE not in forward:
return
# assert ORIGINAL_LLAMA_FCLM_CODE in forward, "Original forward code not found"
forward = forward.replace(ORIGINAL_LLAMA_FCLM_CODE, PATCHED_LLAMA_FCLM_CODE)
forward = forward.replace(
"def forward(",
"def _fixed_forward(",
1,
)
# load imports necessary
import transformers.models.llama.modeling_llama
items_to_import = []
for item in dir(transformers.models.llama.modeling_llama):
if item in forward:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.models.llama.modeling_llama import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(forward, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching forward")
LlamaForCausalLM.forward = ( # pylint: disable=protected-access
_fixed_forward # pylint: disable=undefined-variable # noqa: F821
)
ORIGINAL_TRAINER_CODE = """
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
"""
PATCHED_TRAINER_CODE = """
disable_deepspeed_no_sync = (
self.accelerator.distributed_type == DistributedType.DEEPSPEED
# and self.accelerator.deepspeed_engine_wrapped.engine.zero_optimization_partition_gradients()
)
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1 and not disable_deepspeed_no_sync
else contextlib.nullcontext
)
with context():
tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
"""
def get_training_loop_code() -> str:
training_loop = inspect.getsource(
Trainer._inner_training_loop # pylint: disable=protected-access
)
return training_loop
def check_training_loop_is_patchable() -> bool:
training_loop = get_training_loop_code()
training_loop, _ = detab_code(training_loop)
return ORIGINAL_TRAINER_CODE in training_loop
def patch_training_loop_for_deepspeed_0_16_x():
"""
monkeypatch for fixing the training loop for deepspeed GA
see https://github.com/huggingface/transformers/pull/35157
"""
try:
training_loop = get_training_loop_code()
except OSError:
return
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
training_loop
)
training_loop, _ = detab_code(training_loop)
if ORIGINAL_TRAINER_CODE not in training_loop:
return
training_loop = training_loop.replace(ORIGINAL_TRAINER_CODE, PATCHED_TRAINER_CODE)
training_loop = training_loop.replace(
"def _inner_training_loop(",
"def _fixed_inner_training_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in training_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(training_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop for fsdp optimizer save")
Trainer._inner_training_loop = ( # pylint: disable=protected-access
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
)
def patch_flash_attention_forward():
"""
monkeypatch for fixing the forward pass for flash attention to ignore num_items_in_batch
"""
import transformers.modeling_flash_attention_utils
def proxy_flash_attention_forward(*args, **kwargs):
kwargs.pop("num_items_in_batch", None)
return _flash_attention_forward(*args, **kwargs)
transformers.modeling_flash_attention_utils._flash_attention_forward = ( # pylint: disable=protected-access
proxy_flash_attention_forward
)
transformers.models.llama.modeling_llama._flash_attention_forward = ( # pylint: disable=protected-access
proxy_flash_attention_forward
)
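All three patches in this module share one mechanism: fetch the target's source with `inspect.getsource`, dedent it, string-replace a known snippet, then `exec` the renamed definition and swap it in. A self-contained toy demonstration of the mechanism (not an actual transformers patch):

# Toy demo of the source-rewrite monkeypatch mechanism used above.
import inspect
import textwrap


def greet(name):
    return "hello " + name


source = textwrap.dedent(inspect.getsource(greet))
# Patch only when the expected snippet is present, mirroring the guards above.
if '"hello "' in source:
    patched = source.replace('"hello "', '"hi "')
    patched = patched.replace("def greet(", "def _fixed_greet(", 1)
    namespace: dict = {}
    exec(patched, namespace)  # nosec - demonstration only
    greet = namespace["_fixed_greet"]

assert greet("ada") == "hi ada"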

View File

@@ -1,67 +0,0 @@
"""
see https://github.com/huggingface/transformers/pull/35834
"""
import logging
from functools import partial
from typing import Optional
import torch
logger = logging.getLogger(__name__)
def fixed_fa_peft_integration_check(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
target_dtype: Optional[torch.dtype] = None,
preferred_dtype: Optional[torch.dtype] = None,
):
"""
PEFT usually casts the layer norms to float32 for training stability reasons;
therefore the input hidden states get silently cast to float32. Hence, we need to
cast them back to float16 / bfloat16 just to be sure everything works as expected.
This might slow down training & inference, so it is recommended not to cast the LayerNorms!
Args:
query (`torch.Tensor`):
Input query states to be passed to Flash Attention API
key (`torch.Tensor`):
Input key states to be passed to Flash Attention API
value (`torch.Tensor`):
Input value states to be passed to Flash Attention API
target_dtype (`torch.dtype`, *optional*):
The dtype to convert the attention tensors to. Conversion can be ignored by
not providing the target dtype.
preferred_dtype (`torch.dtype`, *optional*):
The preferred dtype to convert the attention tensors to regardless of the
target dtype.
"""
if target_dtype is None and preferred_dtype is None:
return query, key, value
if preferred_dtype and target_dtype != preferred_dtype:
target_dtype = preferred_dtype
# check if any of query, key, or value are in float32. If so, cast them back to target dtype.
if any(module.dtype == torch.float32 for module in [query, key, value]):
logger.warning(  # stdlib logging has no warning_once
f"The input hidden states seem to have been silently cast to float32; this might be"
f" because you have upcast embedding or layer norm layers to float32. We will cast"
f" the input back to {target_dtype}."
)
query = query.to(target_dtype)
key = key.to(target_dtype)
value = value.to(target_dtype)
return query, key, value
def patch_fa_peft_integration():
import transformers.modeling_flash_attention_utils
transformers.modeling_flash_attention_utils.fa_peft_integration_check = partial(
fixed_fa_peft_integration_check, preferred_dtype=None
)
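A quick illustration of the behavior this (now removed) helper provided, assuming the function above is in scope: if any attention input drifted to float32, all three tensors are cast back to the requested dtype.

# Hypothetical usage: a float32 query drags q/k/v back to the target dtype.
import torch

q = torch.randn(2, 4, dtype=torch.float32)   # e.g. silently upcast by a fp32 LayerNorm
k = torch.randn(2, 4, dtype=torch.bfloat16)
v = torch.randn(2, 4, dtype=torch.bfloat16)

q, k, v = fixed_fa_peft_integration_check(q, k, v, target_dtype=torch.bfloat16)
assert q.dtype == k.dtype == v.dtype == torch.bfloat16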

View File

@@ -1,7 +1,9 @@
"""module for patching with unsloth optimizations""" """module for patching with unsloth optimizations"""
import inspect import inspect
import re
import types import types
from typing import Tuple
import torch import torch
from accelerate.logging import get_logger from accelerate.logging import get_logger
@@ -9,8 +11,6 @@ from peft import PeftModelForCausalLM
from torch import nn from torch import nn
from transformers.models.llama.modeling_llama import LlamaFlashAttention2 from transformers.models.llama.modeling_llama import LlamaFlashAttention2
from axolotl.monkeypatch.utils import detab_code
LOG = get_logger("axolotl.monkeypatch.unsloth") LOG = get_logger("axolotl.monkeypatch.unsloth")
ORIGINAL_QKV_CODE = """ ORIGINAL_QKV_CODE = """
@@ -93,6 +93,15 @@ def integrate_cross_entropy_loss_patch(model_type: str = "llama") -> None:
raise ValueError("Unsupported model type") raise ValueError("Unsupported model type")
def detab_code(code: str) -> Tuple[str, str]:
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
return code, spaces
self_attn_lora_patched = False # pylint: disable=invalid-name self_attn_lora_patched = False # pylint: disable=invalid-name
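`detab_code` exists because `inspect.getsource` returns method bodies indented, which would break snippet matching; a quick illustration (the second return value is the removed prefix):

# Illustration of detab_code: strips the first line's leading whitespace
# from every line and returns (dedented_code, removed_prefix).
code, prefix = detab_code("    def f():\n        return 1\n")
assert code == "def f():\n    return 1\n"
assert prefix == "    "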

View File

@@ -1,8 +1,7 @@
""" """
Shared utils for the monkeypatches Shared utils for the monkeypatches
""" """
import re from typing import Optional
from typing import Optional, Tuple
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
@@ -224,12 +223,3 @@ def patched_prepare_4d_causal_attention_mask_for_sdpa(
mask_2d_to_4d(attention_mask, dtype=dtype), mask_2d_to_4d(attention_mask, dtype=dtype),
*args, *args,
) )
def detab_code(code: str) -> Tuple[str, str]:
try:
spaces = re.match(r"([\s\t]{1,})", code).group(0)
code = re.sub(r"^" + spaces, "", code, flags=re.MULTILINE)
except AttributeError:
return code, ""
return code, spaces

View File

@@ -5,19 +5,21 @@ import os
import signal import signal
import sys import sys
import weakref import weakref
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from typing import Tuple, Union from typing import Optional, Tuple, Union
import torch import torch
import transformers.modelcard import transformers.modelcard
from accelerate.logging import get_logger from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model from accelerate.utils import save_fsdp_model
from datasets import Dataset
from peft import PeftModel from peft import PeftModel
from pkg_resources import get_distribution # type: ignore from pkg_resources import get_distribution # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer from transformers import PreTrainedModel, PreTrainedTokenizer
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.datasets import TrainDatasetMeta from axolotl.common.cli import TrainerCliArgs
from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module from axolotl.contribs.lgpl.unsloth import ( # pylint: disable = no-name-in-module
fix_untrained_tokens, fix_untrained_tokens,
) )
@@ -37,11 +39,22 @@ src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir) sys.path.insert(0, src_dir)
configure_logging() configure_logging()
LOG = get_logger(__name__) LOG = get_logger("axolotl.train")
@dataclass
class TrainDatasetMeta:
"""
dataclass to capture the dataset specific options for training
"""
train_dataset: Dataset
eval_dataset: Optional[Dataset] = None
total_num_steps: Optional[int] = None
def train( def train(
*, cfg: DictDefault, dataset_meta: TrainDatasetMeta *, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]: ) -> Tuple[Union[PeftModel, PreTrainedModel], PreTrainedTokenizer]:
# Load tokenizer # Load tokenizer
LOG.debug( LOG.debug(
@@ -80,7 +93,9 @@ def train(
if cfg.adapter: if cfg.adapter:
msg += " and peft_config..." msg += " and peft_config..."
LOG.debug(msg) LOG.debug(msg)
model, peft_config = load_model(cfg, tokenizer, processor=processor) model, peft_config = load_model(
cfg, tokenizer, processor=processor, inference=cli_args.inference
)
if model.generation_config is not None: if model.generation_config is not None:
model.generation_config.do_sample = True model.generation_config.do_sample = True
@@ -92,7 +107,9 @@ def train(
model_ref = None # explicit setting to None model_ref = None # explicit setting to None
else: else:
# load the model again for model_ref/baseline # load the model again for model_ref/baseline
model_ref, _ = load_model(cfg, tokenizer, reference_model=True) model_ref, _ = load_model(
cfg, tokenizer, inference=cli_args.inference, reference_model=True
)
safe_serialization = cfg.save_safetensors is True safe_serialization = cfg.save_safetensors is True

View File

@@ -43,7 +43,7 @@ def lisa_callback_factory(trainer: "AxolotlTrainer"):
getattr, self.layers_attribute.split("."), self.trainer.model getattr, self.layers_attribute.split("."), self.trainer.model
) )
LOG.info( LOG.info(
f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers * 100 / len(layers)}%) every {self.step_interval} steps" f"LISA will activate {self.n_layers}/{len(layers)} layers ({self.n_layers*100/len(layers)}%) every {self.step_interval} steps"
) )
def freeze_all_layers(self): def freeze_all_layers(self):

View File

@@ -128,8 +128,6 @@ class PretrainingDataset(BaseModel):
text_column: Optional[str] = "text" text_column: Optional[str] = "text"
type: Optional[str] = "pretrain" type: Optional[str] = "pretrain"
trust_remote_code: Optional[bool] = False trust_remote_code: Optional[bool] = False
data_files: Optional[str] = None
skip: Optional[int] = None
class UserDefinedPrompterType(BaseModel): class UserDefinedPrompterType(BaseModel):
@@ -147,14 +145,6 @@ class UserDefinedPrompterType(BaseModel):
field: Optional[str] = None field: Optional[str] = None
class LrGroup(BaseModel):
"""Custom learning rate group configuration"""
name: str
modules: List[str]
lr: float
class SFTDataset(BaseModel): class SFTDataset(BaseModel):
"""SFT configuration subset""" """SFT configuration subset"""
@@ -376,13 +366,6 @@ class LoraConfig(BaseModel):
loraplus_lr_embedding = float(loraplus_lr_embedding) loraplus_lr_embedding = float(loraplus_lr_embedding)
return loraplus_lr_embedding return loraplus_lr_embedding
@model_validator(mode="before")
@classmethod
def validate_lora_dropout(cls, data):
if data.get("adapter") is not None and data.get("lora_dropout") is None:
data["lora_dropout"] = 0.0
return data
class ReLoRAConfig(BaseModel): class ReLoRAConfig(BaseModel):
"""ReLoRA configuration subset""" """ReLoRA configuration subset"""
@@ -483,7 +466,6 @@ class HyperparametersConfig(BaseModel):
cosine_min_lr_ratio: Optional[float] = None cosine_min_lr_ratio: Optional[float] = None
cosine_constant_lr_ratio: Optional[float] = None cosine_constant_lr_ratio: Optional[float] = None
lr_div_factor: Optional[float] = None lr_div_factor: Optional[float] = None
lr_groups: Optional[List[LrGroup]] = None
adam_epsilon: Optional[float] = None adam_epsilon: Optional[float] = None
adam_beta1: Optional[float] = None adam_beta1: Optional[float] = None
@@ -715,12 +697,6 @@ class AxolotlInputConfig(
pad_to_sequence_len: Optional[bool] = None pad_to_sequence_len: Optional[bool] = None
curriculum_sampling: Optional[bool] = None curriculum_sampling: Optional[bool] = None
multipack_real_batches: Optional[bool] = None multipack_real_batches: Optional[bool] = None
pretraining_sample_concatenation: Optional[bool] = Field(
default=None,
json_schema_extra={
"description": "whether to soft pack/concatenate samples during pretraining",
},
)
batch_flattening: Optional[Union[Literal["auto"], bool]] = None batch_flattening: Optional[Union[Literal["auto"], bool]] = None

View File

@@ -5,7 +5,7 @@ from axolotl.utils.data.pretraining import ( # noqa: F401
encode_pretraining, encode_pretraining,
wrap_pretraining_dataset, wrap_pretraining_dataset,
) )
from axolotl.utils.data.rl import load_prepare_preference_datasets # noqa: F401 from axolotl.utils.data.rl import load_prepare_dpo_datasets # noqa: F401
from axolotl.utils.data.sft import ( # noqa: F401 from axolotl.utils.data.sft import ( # noqa: F401
get_dataset_wrapper, get_dataset_wrapper,
load_prepare_datasets, load_prepare_datasets,

View File

@@ -18,14 +18,10 @@ LOG = logging.getLogger("axolotl")
def encode_pretraining( def encode_pretraining(
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: Dict[str, List]
max_tokens: int,
examples: Dict[str, List],
text_column: str = "text",
concatenate: bool = True,
) -> Dict[str, List]: ) -> Dict[str, List]:
res = tokenizer( res = tokenizer(
examples[text_column], examples["text"],
truncation=True, truncation=True,
max_length=max_tokens - 2, max_length=max_tokens - 2,
add_special_tokens=True, add_special_tokens=True,
@@ -34,13 +30,6 @@ def encode_pretraining(
input_ids = [torch.tensor(seq) for seq in res["input_ids"]] input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
targets = [torch.tensor(seq) for seq in res["input_ids"]] targets = [torch.tensor(seq) for seq in res["input_ids"]]
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]] attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
if not concatenate:
return {
"input_ids": [seq.tolist() for seq in input_ids],
"labels": [seq.tolist() for seq in targets],
"attention_mask": [seq.tolist() for seq in attention_mask],
}
new_input_ids = [] new_input_ids = []
new_labels = [] new_labels = []
new_attention_mask = [] new_attention_mask = []
@@ -191,7 +180,7 @@ def wrap_pretraining_dataset(
tokenizer, tokenizer,
return_tensors="pt", return_tensors="pt",
padding=True, padding=True,
pad_to_multiple_of=max_tokens, pad_to_multiple_of=max_tokens * batch_size,
multipack_attn=cfg.pretrain_multipack_attn, multipack_attn=cfg.pretrain_multipack_attn,
) )
encode = functools.partial( encode = functools.partial(
@@ -201,17 +190,13 @@ def wrap_pretraining_dataset(
max_seq_length=max_tokens, max_seq_length=max_tokens,
batch_size=batch_size, batch_size=batch_size,
multipack_attn=cfg.pretrain_multipack_attn, multipack_attn=cfg.pretrain_multipack_attn,
group_size=cfg.sample_packing_group_size,
bin_size=cfg.sample_packing_bin_size,
) )
# set this to 1 so downstream data_loader doesn't try to increase the batch again # set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1 cfg.micro_batch_size = 1
else: else:
encode = functools.partial( encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
encode_pretraining,
tokenizer,
max_tokens,
text_column=cfg.pretraining_dataset[0].text_column or "text",
concatenate=cfg.pretraining_sample_concatenation is True,
)
if cfg.shuffle_merged_datasets: if cfg.shuffle_merged_datasets:
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size) dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
@@ -245,7 +230,9 @@ def encode_packed_pretraining(
examples: Dict[str, List], examples: Dict[str, List],
max_seq_length: int = 2048, max_seq_length: int = 2048,
batch_size: int = 4, batch_size: int = 4,
multipack_attn: Optional[bool] = True, multipack_attn: Optional[bool] = False,
group_size: int = 100000,
bin_size: int = 200,
) -> Dict[str, List]: ) -> Dict[str, List]:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
# tokenize all the examples # tokenize all the examples
@@ -256,9 +243,6 @@ def encode_packed_pretraining(
train_dataset, train_dataset,
max_seq_length, max_seq_length,
skip_position_ids=not multipack_attn, skip_position_ids=not multipack_attn,
# FIXME using attention mask unpad/pad with trainer and packed pretraining is broken atm
# workaround by using the position id logic for now in trainer
drop_attention_mask=multipack_attn,
) )
sampler = MultipackBatchSampler( sampler = MultipackBatchSampler(
@@ -266,6 +250,8 @@ def encode_packed_pretraining(
lengths=get_dataset_lengths(train_dataset), lengths=get_dataset_lengths(train_dataset),
batch_size=1, batch_size=1,
batch_max_len=batch_size * max_seq_length, batch_max_len=batch_size * max_seq_length,
group_size=group_size,
bin_size=bin_size,
drop_last=True, drop_last=True,
) )
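A rough illustration of the `text_column`/`concatenate` parameters on `encode_pretraining` (hypothetical tokenizer and column name; any fast tokenizer should do): with `concatenate=False` each input example stays its own row instead of being packed into `max_tokens`-sized chunks.

# Hedged sketch of the encode_pretraining parameters.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # example tokenizer
examples = {"content": ["short sample one", "short sample two"]}

rows = encode_pretraining(
    tokenizer, 32, examples, text_column="content", concatenate=False
)
assert len(rows["input_ids"]) == 2  # one row per input example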

View File

@@ -115,7 +115,7 @@ def drop_long_rl_seq(
raise ValueError("Unknown RL type") raise ValueError("Unknown RL type")
def load_prepare_preference_datasets(cfg): def load_prepare_dpo_datasets(cfg):
def load_split(dataset_cfgs, _cfg): def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = [] split_datasets: List[Any] = []
for i, ds_cfg in enumerate(dataset_cfgs): for i, ds_cfg in enumerate(dataset_cfgs):

View File

@@ -88,19 +88,14 @@ def prepare_dataset(cfg, tokenizer, processor=None):
path = cfg.pretraining_dataset path = cfg.pretraining_dataset
split = "train" split = "train"
name = None name = None
data_files = None
skip = 0
if isinstance(cfg.pretraining_dataset, list) and isinstance( if isinstance(cfg.pretraining_dataset, list) and isinstance(
cfg.pretraining_dataset[0], dict cfg.pretraining_dataset[0], dict
): ):
path = cfg.pretraining_dataset[0]["path"] path = cfg.pretraining_dataset[0]["path"]
name = cfg.pretraining_dataset[0]["name"] name = cfg.pretraining_dataset[0]["name"]
skip = cfg.pretraining_dataset[0]["skip"]
if "split" in cfg.pretraining_dataset[0]: if "split" in cfg.pretraining_dataset[0]:
split = cfg.pretraining_dataset[0]["split"] split = cfg.pretraining_dataset[0]["split"]
data_files = cfg.pretraining_dataset[0].get("data_files")
ds_wrapper_partial = functools.partial( ds_wrapper_partial = functools.partial(
get_dataset_wrapper, get_dataset_wrapper,
cfg.pretraining_dataset[0], cfg.pretraining_dataset[0],
@@ -109,14 +104,8 @@ def prepare_dataset(cfg, tokenizer, processor=None):
cfg.pretraining_dataset[0]["type"] or "pretrain", cfg.pretraining_dataset[0]["type"] or "pretrain",
) )
iter_ds = load_dataset(
path, streaming=True, split=split, name=name, data_files=data_files
)
if skip:
LOG.info(f"Skipping {skip} samples from the dataset")
iter_ds = iter_ds.skip(skip)
train_dataset = wrap_pretraining_dataset( train_dataset = wrap_pretraining_dataset(
iter_ds, load_dataset(path, streaming=True, split=split, name=name),
tokenizer, tokenizer,
cfg, cfg,
ds_wrapper_partial, ds_wrapper_partial,

View File

@@ -107,13 +107,6 @@ def load_dataset_w_config(config_dataset, auth_token):
except (FileNotFoundError, ConnectionError): except (FileNotFoundError, ConnectionError):
pass pass
# gather extra args from the config
load_ds_kwargs = {}
if config_dataset.split:
load_ds_kwargs["split"] = config_dataset.split
else:
load_ds_kwargs["split"] = None
# prefer local dataset, even if hub exists # prefer local dataset, even if hub exists
local_path = Path(config_dataset.path) local_path = Path(config_dataset.path)
if local_path.exists(): if local_path.exists():
@@ -125,7 +118,7 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name, name=config_dataset.name,
data_files=config_dataset.data_files, data_files=config_dataset.data_files,
streaming=False, streaming=False,
**load_ds_kwargs, split=None,
) )
else: else:
try: try:
@@ -137,7 +130,7 @@ def load_dataset_w_config(config_dataset, auth_token):
config_dataset.path, config_dataset.path,
name=config_dataset.name, name=config_dataset.name,
streaming=False, streaming=False,
**load_ds_kwargs, split=None,
) )
elif local_path.is_file(): elif local_path.is_file():
ds_type = get_ds_type(config_dataset) ds_type = get_ds_type(config_dataset)
@@ -147,13 +140,16 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name, name=config_dataset.name,
data_files=config_dataset.path, data_files=config_dataset.path,
streaming=False, streaming=False,
**load_ds_kwargs, split=None,
) )
else: else:
raise ValueError( raise ValueError(
"unhandled dataset load: local path exists, but is neither a directory or a file" "unhandled dataset load: local path exists, but is neither a directory or a file"
) )
elif ds_from_hub: elif ds_from_hub:
load_ds_kwargs = {}
if config_dataset.split:
load_ds_kwargs["split"] = config_dataset.split
ds = load_dataset( ds = load_dataset(
config_dataset.path, config_dataset.path,
name=config_dataset.name, name=config_dataset.name,
@@ -177,9 +173,9 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name, name=config_dataset.name,
data_files=config_dataset.path, data_files=config_dataset.path,
streaming=False, streaming=False,
split=None,
storage_options=storage_options, storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code, trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
) )
elif config_dataset.path.startswith("https://"): elif config_dataset.path.startswith("https://"):
ds_type = get_ds_type(config_dataset) ds_type = get_ds_type(config_dataset)
@@ -188,9 +184,9 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name, name=config_dataset.name,
data_files=config_dataset.path, data_files=config_dataset.path,
streaming=False, streaming=False,
split=None,
storage_options=storage_options, storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code, trust_remote_code=config_dataset.trust_remote_code,
**load_ds_kwargs,
) )
else: else:
if isinstance(config_dataset.data_files, str): if isinstance(config_dataset.data_files, str):
@@ -218,7 +214,7 @@ def load_dataset_w_config(config_dataset, auth_token):
name=config_dataset.name, name=config_dataset.name,
data_files=fp, data_files=fp,
streaming=False, streaming=False,
**load_ds_kwargs, split=None,
) )
if not ds: if not ds:
raise ValueError("unhandled dataset load") raise ValueError("unhandled dataset load")
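The `split` toggling above leans on standard `datasets` behavior worth spelling out: `split=None` loads every split as a `DatasetDict`, while an explicit split name returns a single `Dataset`. An illustration with an example public dataset:

# datasets behavior the split kwargs rely on (example dataset for illustration).
from datasets import load_dataset

everything = load_dataset("tatsu-lab/alpaca", split=None)     # DatasetDict of splits
train_only = load_dataset("tatsu-lab/alpaca", split="train")  # single Dataset
print(type(everything).__name__, type(train_only).__name__)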

View File

@@ -270,7 +270,7 @@ def load_sharded_model_quant(
model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config) model.hf_quantizer = AutoHfQuantizer.from_config(quantization_config)
if cfg.local_rank == 0 and verbose: if cfg.local_rank == 0 and verbose:
print(f"Loaded model weights in {time.time() - start:.3f} seconds") print(f"Loaded model weights in {time.time()-start:.3f} seconds")
# cleanup any extra memory usage from parallel loading # cleanup any extra memory usage from parallel loading
torch.cuda.empty_cache() torch.cuda.empty_cache()

View File

@@ -380,19 +380,23 @@ class ModelLoader:
plugin_manager = PluginManager.get_instance() plugin_manager = PluginManager.get_instance()
plugin_manager.pre_model_load(self.cfg) plugin_manager.pre_model_load(self.cfg)
if self.cfg.adapter:
from axolotl.monkeypatch.transformers_fa_utils import (
patch_fa_peft_integration,
)
patch_fa_peft_integration()
if self.cfg.gradient_checkpointing == "unsloth": if self.cfg.gradient_checkpointing == "unsloth":
transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper transformers.modeling_utils.checkpoint = hf_grad_checkpoint_unsloth_wrapper
if self.cfg.flash_attention: if self.cfg.flash_attention:
self.patch_attention() self.patch_attention()
if self.cfg.model_config_type == "llama":
from axolotl.monkeypatch.trainer_grad_accum import (
patch_flash_attention_forward,
patch_forward_for_ga,
patch_training_step_for_ga,
)
patch_flash_attention_forward()
patch_forward_for_ga()
patch_training_step_for_ga()
if self.cfg.sample_packing and self.cfg.s2_attention: if self.cfg.sample_packing and self.cfg.s2_attention:
raise ValueError( raise ValueError(
"Received `sample_packing=true` and `s2_attention=true`; however, \ "Received `sample_packing=true` and `s2_attention=true`; however, \
@@ -1053,7 +1057,7 @@ class ModelLoader:
) )
if ( if (
hasattr(self.model, "get_input_embeddings") hasattr(self.model, "get_input_embeddings")
and self.model.get_input_embeddings().num_embeddings != embeddings_len and self.model.get_input_embeddings().num_embeddings < embeddings_len
): ):
resize_kwargs = {} resize_kwargs = {}
if self.cfg.mean_resizing_embeddings is not None: if self.cfg.mean_resizing_embeddings is not None:
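Context for the comparison change above (`!=` vs `<`): resizing goes through the standard HF API, and `!=` also covers the shrink case where the embedding table is larger than the tokenizer. A self-contained sketch with an example model:

# Sketch of the embedding-resize check (example model/tokenizer).
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.add_special_tokens({"pad_token": "<|pad|>"})

embeddings_len = len(tokenizer)
# `!=` resizes on any mismatch (grow or shrink); `<` only when the tokenizer outgrew the table.
if model.get_input_embeddings().num_embeddings != embeddings_len:
    model.resize_token_embeddings(embeddings_len)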

View File

@@ -196,7 +196,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset: if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask") eval_dataset = eval_dataset.remove_columns("attention_mask")
if cfg.model_config_type in ["falcon", "mistral"]: if cfg.model_config_type == "falcon":
LOG.info("dropping token_type_ids column if it exists") LOG.info("dropping token_type_ids column if it exists")
if "token_type_ids" in train_dataset.column_names: if "token_type_ids" in train_dataset.column_names:
train_dataset = train_dataset.remove_columns("token_type_ids") train_dataset = train_dataset.remove_columns("token_type_ids")
@@ -310,22 +310,19 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
def process_pretraining_datasets_for_packing( def process_pretraining_datasets_for_packing(
train_dataset, sequence_len, skip_position_ids=True, drop_attention_mask=False train_dataset, sequence_len, skip_position_ids=True
): ):
drop_long = partial(drop_long_seq, sequence_len=sequence_len) drop_long = partial(drop_long_seq, sequence_len=sequence_len)
train_dataset = train_dataset.filter( train_dataset = train_dataset.filter(
drop_long, drop_long,
desc="Dropping Long Sequences", desc="Dropping Long Sequences",
load_from_cache_file=False,
) )
if not skip_position_ids: if skip_position_ids:
train_dataset = train_dataset.map( train_dataset = train_dataset.map(
add_position_ids, add_position_ids,
desc="Add position_id column (Pretraining Sample Packing)", desc="Add position_id column (Pretraining Sample Packing)",
) )
if drop_attention_mask:
train_dataset = train_dataset.remove_columns("attention_mask")
return train_dataset return train_dataset
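Note the guards differ between the two sides above (`if not skip_position_ids:` vs `if skip_position_ids:` around `add_position_ids`). For reference, a hedged sketch of what `add_position_ids` does per sample; the real helper lives elsewhere in the codebase:

# Hedged sketch of add_position_ids: each sample gets a fresh 0..n-1 ramp,
# which packed-attention paths use to mark sequence boundaries.
def add_position_ids(sample):
    sample["position_ids"] = list(range(len(sample["input_ids"])))
    return sample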

View File

@@ -1,5 +1,4 @@
"""Shared pytest fixtures for cli module.""" """Shared pytest fixtures for cli module."""
import pytest import pytest
from click.testing import CliRunner from click.testing import CliRunner

View File

@@ -1,5 +1,4 @@
"""pytest tests for axolotl CLI fetch command.""" """pytest tests for axolotl CLI fetch command."""
from unittest.mock import patch from unittest.mock import patch
from axolotl.cli.main import fetch from axolotl.cli.main import fetch

View File

@@ -1,5 +1,4 @@
"""pytest tests for axolotl CLI inference command.""" """pytest tests for axolotl CLI inference command."""
from unittest.mock import patch from unittest.mock import patch
from axolotl.cli.main import cli from axolotl.cli.main import cli

View File

@@ -1,5 +1,4 @@
 """General pytest tests for axolotl.cli.main interface."""
-
 from axolotl.cli.main import build_command, cli


@@ -1,5 +1,4 @@
 """pytest tests for axolotl CLI merge_lora command."""
-
 from unittest.mock import patch

 from axolotl.cli.main import cli


@@ -1,6 +1,5 @@
 """pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
-
 # pylint: disable=duplicate-code
 from unittest.mock import patch

 from axolotl.cli.main import cli
@@ -16,3 +15,46 @@ def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path):
         assert mock.called
         assert mock.call_args.kwargs["config"] == str(config_path)
         assert result.exit_code == 0
+
+
+def test_merge_sharded_fsdp_weights_with_model_dir(cli_runner, config_path, tmp_path):
+    """Test merge_sharded_fsdp_weights command with model_dir option"""
+    model_dir = tmp_path / "model"
+    model_dir.mkdir()
+
+    with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
+        result = cli_runner.invoke(
+            cli,
+            [
+                "merge-sharded-fsdp-weights",
+                str(config_path),
+                "--no-accelerate",
+                "--model-dir",
+                str(model_dir),
+            ],
+        )
+        assert mock.called
+        assert mock.call_args.kwargs["config"] == str(config_path)
+        assert mock.call_args.kwargs["model_dir"] == str(model_dir)
+        assert result.exit_code == 0
+
+
+def test_merge_sharded_fsdp_weights_with_save_path(cli_runner, config_path):
+    """Test merge_sharded_fsdp_weights command with save_path option"""
+    with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
+        result = cli_runner.invoke(
+            cli,
+            [
+                "merge-sharded-fsdp-weights",
+                str(config_path),
+                "--no-accelerate",
+                "--save-path",
+                "/path/to/save",
+            ],
+        )
+        assert mock.called
+        assert mock.call_args.kwargs["config"] == str(config_path)
+        assert mock.call_args.kwargs["save_path"] == "/path/to/save"
+        assert result.exit_code == 0


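The tests above all follow one pattern: click's CliRunner drives the top-level cli group in-process while the heavyweight do_cli entrypoint is mocked away. A minimal standalone version of that pattern (the config path is a placeholder):

    from unittest.mock import patch

    from click.testing import CliRunner

    from axolotl.cli.main import cli

    runner = CliRunner()
    with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock_do_cli:
        result = runner.invoke(
            cli, ["merge-sharded-fsdp-weights", "config.yml", "--no-accelerate"]
        )
        assert mock_do_cli.called
    assert result.exit_code == 0
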
@@ -1,5 +1,4 @@
 """pytest tests for axolotl CLI preprocess command."""
-
 import shutil
 from pathlib import Path
 from unittest.mock import patch


@@ -0,0 +1,76 @@
+"""pytest tests for axolotl CLI shard command."""
+# pylint: disable=duplicate-code
+from unittest.mock import patch
+
+from axolotl.cli.main import cli
+
+
+def test_shard_with_accelerate(cli_runner, config_path):
+    """Test shard command with accelerate"""
+    with patch("subprocess.run") as mock:
+        result = cli_runner.invoke(cli, ["shard", str(config_path), "--accelerate"])
+
+        assert mock.called
+        assert mock.call_args.args[0] == [
+            "accelerate",
+            "launch",
+            "-m",
+            "axolotl.cli.shard",
+            str(config_path),
+            "--debug-num-examples",
+            "0",
+        ]
+        assert mock.call_args.kwargs == {"check": True}
+        assert result.exit_code == 0
+
+
+def test_shard_no_accelerate(cli_runner, config_path):
+    """Test shard command without accelerate"""
+    with patch("axolotl.cli.shard.do_cli") as mock:
+        result = cli_runner.invoke(cli, ["shard", str(config_path), "--no-accelerate"])
+        assert mock.called
+        assert result.exit_code == 0
+
+
+def test_shard_with_model_dir(cli_runner, config_path, tmp_path):
+    """Test shard command with model_dir option"""
+    model_dir = tmp_path / "model"
+    model_dir.mkdir()
+
+    with patch("axolotl.cli.shard.do_cli") as mock:
+        result = cli_runner.invoke(
+            cli,
+            [
+                "shard",
+                str(config_path),
+                "--no-accelerate",
+                "--model-dir",
+                str(model_dir),
+            ],
+            catch_exceptions=False,
+        )
+        assert mock.called
+        assert mock.call_args.kwargs["config"] == str(config_path)
+        assert mock.call_args.kwargs["model_dir"] == str(model_dir)
+        assert result.exit_code == 0
+
+
+def test_shard_with_save_dir(cli_runner, config_path):
+    with patch("axolotl.cli.shard.do_cli") as mock:
+        result = cli_runner.invoke(
+            cli,
+            [
+                "shard",
+                str(config_path),
+                "--no-accelerate",
+                "--save-dir",
+                "/path/to/save",
+            ],
+        )
+        assert mock.called
+        assert mock.call_args.kwargs["config"] == str(config_path)
+        assert mock.call_args.kwargs["save_dir"] == "/path/to/save"
+        assert result.exit_code == 0


@@ -1,5 +1,4 @@
 """pytest tests for axolotl CLI --version"""
-
 from axolotl.cli.main import cli


@@ -1,6 +1,5 @@
 """pytest tests for axolotl CLI utils."""
-
 # pylint: disable=redefined-outer-name
 import json
 from unittest.mock import Mock, patch


@@ -120,12 +120,13 @@ def temp_dir():
 @pytest.fixture(scope="function", autouse=True)
 def cleanup_monkeypatches():
     from transformers import Trainer
-    from transformers.models.llama.modeling_llama import (  # LlamaFlashAttention2,
+    from transformers.models.llama.modeling_llama import (
         LlamaAttention,
+        LlamaFlashAttention2,
         LlamaForCausalLM,
     )

-    # original_fa2_forward = LlamaFlashAttention2.forward
+    original_fa2_forward = LlamaFlashAttention2.forward
     original_llama_attn_forward = LlamaAttention.forward
     original_llama_forward = LlamaForCausalLM.forward
     original_trainer_inner_training_loop = (
@@ -135,7 +136,7 @@ def cleanup_monkeypatches():
     # monkey patches can happen inside the tests
     yield
     # Reset LlamaFlashAttention2 forward
-    # LlamaFlashAttention2.forward = original_fa2_forward
+    LlamaFlashAttention2.forward = original_fa2_forward
     LlamaAttention.forward = original_llama_attn_forward
     LlamaForCausalLM.forward = original_llama_forward
     Trainer._inner_training_loop = (  # pylint: disable=protected-access
@@ -148,10 +149,7 @@ def cleanup_monkeypatches():
         ("transformers.models.llama",),
         (
             "transformers.models.llama.modeling_llama",
-            [
-                # "LlamaFlashAttention2",
-                "LlamaAttention",
-            ],
+            ["LlamaFlashAttention2", "LlamaAttention"],
         ),
         ("transformers.trainer",),
         ("transformers", ["Trainer"]),


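The cleanup_monkeypatches fixture above is the usual save/patch/restore pattern: capture the originals before the test, yield, then unconditionally restore. A stripped-down, generic sketch of the same idea (some_library is a hypothetical module, not part of axolotl):

    import pytest


    @pytest.fixture(autouse=True)
    def restore_patched_methods():
        import some_library  # hypothetical module whose methods tests may patch

        original = some_library.SomeClass.method  # save before the test runs
        yield  # the test body executes (and may monkeypatch) here
        some_library.SomeClass.method = original  # restore afterwards, pass or fail
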
@@ -1,41 +1,43 @@
 """
 Simple end-to-end test for Liger integration
 """

+import unittest
+from pathlib import Path
+
-from e2e.utils import require_torch_2_4_1
-
-from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_datasets
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config, prepare_plugins
 from axolotl.utils.dict import DictDefault

-from ..utils import check_model_output_exists
+from ..utils import with_temp_dir


-class LigerIntegrationTestCase:
+class LigerIntegrationTestCase(unittest.TestCase):
     """
     e2e tests for liger integration with Axolotl
     """

-    @require_torch_2_4_1
+    @with_temp_dir
     def test_llama_wo_flce(self, temp_dir):
-        # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
                 "plugins": [
                     "axolotl.integrations.liger.LigerPlugin",
                 ],
                 "liger_rope": True,
                 "liger_rms_norm": True,
-                "liger_glu_activation": True,
+                "liger_swiglu": True,
                 "liger_cross_entropy": True,
                 "liger_fused_linear_cross_entropy": False,
                 "sequence_len": 1024,
-                "val_set_size": 0.05,
+                "val_set_size": 0.1,
                 "special_tokens": {
-                    "pad_token": "<|endoftext|>",
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
                 },
                 "datasets": [
                     {
@@ -44,15 +46,15 @@ class LigerIntegrationTestCase:
                     },
                 ],
                 "num_epochs": 1,
-                "micro_batch_size": 2,
-                "gradient_accumulation_steps": 2,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
                 "output_dir": temp_dir,
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_torch",
                 "lr_scheduler": "cosine",
                 "save_safetensors": True,
                 "bf16": "auto",
-                "max_steps": 5,
+                "max_steps": 10,
             }
         )
         prepare_plugins(cfg)
@@ -60,27 +62,29 @@ class LigerIntegrationTestCase:
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "model.safetensors").exists()

-    @require_torch_2_4_1
+    @with_temp_dir
     def test_llama_w_flce(self, temp_dir):
-        # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
                 "plugins": [
                     "axolotl.integrations.liger.LigerPlugin",
                 ],
                 "liger_rope": True,
                 "liger_rms_norm": True,
-                "liger_glu_activation": True,
+                "liger_swiglu": True,
                 "liger_cross_entropy": False,
                 "liger_fused_linear_cross_entropy": True,
                 "sequence_len": 1024,
-                "val_set_size": 0.05,
+                "val_set_size": 0.1,
                 "special_tokens": {
-                    "pad_token": "<|endoftext|>",
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
                 },
                 "datasets": [
                     {
@@ -89,15 +93,15 @@ class LigerIntegrationTestCase:
                     },
                 ],
                 "num_epochs": 1,
-                "micro_batch_size": 2,
-                "gradient_accumulation_steps": 2,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
                 "output_dir": temp_dir,
                 "learning_rate": 0.00001,
                 "optimizer": "adamw_torch",
                 "lr_scheduler": "cosine",
                 "save_safetensors": True,
                 "bf16": "auto",
-                "max_steps": 5,
+                "max_steps": 10,
             }
         )
         prepare_plugins(cfg)
@@ -105,5 +109,5 @@ class LigerIntegrationTestCase:
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "model.safetensors").exists()

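Several hunks in this diff swap raw Path(...).exists() asserts for a shared check_model_output_exists helper. Its implementation isn't shown here; a plausible sketch, inferred from the filenames the old asserts checked (adapter_model.*, model.safetensors, pytorch_model.bin):

    from pathlib import Path


    def check_model_output_exists(output_dir, cfg):
        # adapter runs (lora/qlora) save adapter weights; full fine-tunes save
        # the model itself, as safetensors or a pytorch_model.bin
        if cfg.get("adapter"):
            candidates = ("adapter_model.bin", "adapter_model.safetensors")
        elif cfg.get("save_safetensors"):
            candidates = ("model.safetensors",)
        else:
            candidates = ("pytorch_model.bin",)
        assert any((Path(output_dir) / name).exists() for name in candidates)
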

@@ -2,17 +2,17 @@
 Simple end-to-end test for Cut Cross Entropy integration
 """

+from pathlib import Path
+
 import pytest

-from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_datasets
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils import get_pytorch_version
 from axolotl.utils.config import normalize_config, prepare_plugins
 from axolotl.utils.dict import DictDefault

-from ..utils import check_model_output_exists

 # pylint: disable=duplicate-code
@@ -64,10 +64,10 @@ class TestCutCrossEntropyIntegration:
         major, minor, _ = get_pytorch_version()
         if (major, minor) < (2, 4):
             with pytest.raises(ImportError):
-                train(cfg=cfg, dataset_meta=dataset_meta)
+                train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         else:
-            train(cfg=cfg, dataset_meta=dataset_meta)
-            check_model_output_exists(temp_dir, cfg)
+            train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+            assert (Path(temp_dir) / "model.safetensors").exists()

     @pytest.mark.parametrize(
         "attention_type",
@@ -92,7 +92,7 @@ class TestCutCrossEntropyIntegration:
         major, minor, _ = get_pytorch_version()
         if (major, minor) < (2, 4):
             with pytest.raises(ImportError):
-                train(cfg=cfg, dataset_meta=dataset_meta)
+                train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         else:
-            train(cfg=cfg, dataset_meta=dataset_meta)
-            check_model_output_exists(temp_dir, cfg)
+            train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+            assert (Path(temp_dir) / "model.safetensors").exists()


@@ -63,7 +63,6 @@ class TestMultiGPULlama:
                 "lr_scheduler": "cosine",
                 "flash_attention": True,
                 "use_tensorboard": True,
-                "bf16": True,
             }
         )
@@ -128,7 +127,6 @@ class TestMultiGPULlama:
                 "lr_scheduler": "cosine",
                 "flash_attention": True,
                 "use_tensorboard": True,
-                "bf16": True,
             }
         )
@@ -203,7 +201,6 @@ class TestMultiGPULlama:
                 "lr_scheduler": "cosine",
                 "flash_attention": True,
                 "use_tensorboard": True,
-                "bf16": True,
             }
         )
@@ -226,12 +223,8 @@ class TestMultiGPULlama:
             ]
         )

-        loss_threshold = 2.3
         check_tensorboard(
-            temp_dir + "/runs",
-            "train/train_loss",
-            loss_threshold,
-            "Train Loss is too high",
+            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
         )

     def test_dpo_qlora_ddp(self, temp_dir):
@@ -282,7 +275,6 @@ class TestMultiGPULlama:
                 "lr_scheduler": "cosine",
                 "flash_attention": True,
                 "use_tensorboard": True,
-                "bf16": True,
             }
         )
@@ -305,12 +297,8 @@ class TestMultiGPULlama:
             ]
         )

-        loss_threshold = 2.3
         check_tensorboard(
-            temp_dir + "/runs",
-            "train/train_loss",
-            loss_threshold,
-            "Train Loss is too high",
+            temp_dir + "/runs", "train/train_loss", 2.3, "Train Loss is too high"
         )

     @pytest.mark.parametrize(

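check_tensorboard, used throughout these tests, reads a scalar series out of the tensorboard event files and asserts it stayed under a threshold. A sketch of one way such a helper can be written (assumes the tensorboard package; the real helper may differ):

    from tensorboard.backend.event_processing.event_accumulator import (
        EventAccumulator,
    )


    def check_tensorboard(log_dir, tag, threshold, message):
        accumulator = EventAccumulator(log_dir)
        accumulator.Reload()  # parse event files under log_dir
        values = [event.value for event in accumulator.Scalars(tag)]
        assert values and values[-1] < threshold, message
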
@@ -5,14 +5,15 @@ E2E tests for multipack fft llama using 4d attention masks
 import logging
 import os
 import unittest
+from pathlib import Path

-from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_datasets
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import check_model_output_exists, require_torch_2_3_1, with_temp_dir
+from ..utils import require_torch_2_3_1, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -65,8 +66,8 @@ class Test4dMultipackLlama(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()

     @with_temp_dir
     def test_torch_lora_packing(self, temp_dir):
@@ -109,5 +110,5 @@ class Test4dMultipackLlama(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()


@@ -5,7 +5,7 @@ from pathlib import Path
 import yaml

-from axolotl.cli.config import load_cfg
+from axolotl.cli import load_cfg
 from axolotl.utils.dict import DictDefault


@@ -4,17 +4,18 @@ E2E tests for lora llama
 import logging
 import os
+from pathlib import Path

 import pytest
 from transformers.utils import is_torch_bf16_gpu_available

-from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_datasets
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import check_model_output_exists, check_tensorboard
+from ..utils import check_tensorboard

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -80,8 +81,8 @@ class TestFAXentropyLlama:
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()

         check_tensorboard(
             temp_dir + "/runs", "train/train_loss", 1.5, "Train Loss is too high"


@@ -5,14 +5,15 @@ E2E tests for falcon
 import logging
 import os
 import unittest
+from pathlib import Path

-from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_datasets
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import check_model_output_exists, with_temp_dir
+from ..utils import with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -67,8 +68,8 @@ class TestFalconPatched(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()

     @with_temp_dir
     def test_ft(self, temp_dir):
@@ -107,5 +108,5 @@ class TestFalconPatched(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()


@@ -5,17 +5,18 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
+from pathlib import Path

 import pytest
 from transformers.utils import is_torch_bf16_gpu_available

-from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_datasets
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import check_model_output_exists, with_temp_dir
+from ..utils import with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -71,5 +72,5 @@ class TestFusedLlama(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()


@@ -5,16 +5,17 @@ E2E tests for llama w/ S2 attn
 import logging
 import os
 import unittest
+from pathlib import Path

 import pytest

-from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_datasets
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import check_model_output_exists, with_temp_dir
+from ..utils import with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -69,8 +70,8 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()

     @with_temp_dir
     def test_fft_s2_attn(self, temp_dir):
@@ -109,5 +110,5 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()


@@ -5,17 +5,18 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
+from pathlib import Path

 import pytest
 from transformers.utils import is_auto_gptq_available, is_torch_bf16_gpu_available

-from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_datasets
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import check_model_output_exists, with_temp_dir
+from ..utils import with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -74,8 +75,8 @@ class TestLoraLlama(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()

     @pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
     @with_temp_dir
@@ -124,5 +125,5 @@ class TestLoraLlama(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()


@@ -5,14 +5,15 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
+from pathlib import Path

-from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_datasets
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import check_model_output_exists, with_temp_dir
+from ..utils import with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -67,8 +68,8 @@ class TestMistral(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()

     @with_temp_dir
     def test_ft_packing(self, temp_dir):
@@ -108,5 +109,5 @@ class TestMistral(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()


@@ -5,14 +5,15 @@ E2E tests for mixtral
 import logging
 import os
 import unittest
+from pathlib import Path

-from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_datasets
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import check_model_output_exists, with_temp_dir
+from ..utils import with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -64,8 +65,8 @@ class TestMixtral(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()

     @with_temp_dir
     def test_ft(self, temp_dir):
@@ -102,5 +103,9 @@ class TestMixtral(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        model, _ = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (
+            "MixtralFlashAttention2"
+            in model.model.layers[0].self_attn.__class__.__name__
+        )
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()


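The MixtralFlashAttention2 assertions above pin which attention class the monkeypatch installed. For reference, one generic way to inspect the attention implementation a loaded Hugging Face model ended up with (the tiny checkpoint name is an assumption for illustration; swap in a model you actually use):

    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "hf-internal-testing/tiny-random-MixtralForCausalLM"
    )
    print(model.config._attn_implementation)  # e.g. "eager" or "flash_attention_2"
    print(model.model.layers[0].self_attn.__class__.__name__)
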
@@ -6,6 +6,7 @@ import unittest
 import transformers

+from axolotl.common.cli import TrainerCliArgs
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
@@ -48,8 +49,14 @@ class TestModelPatches(unittest.TestCase):
             }
         )
         normalize_config(cfg)
+        cli_args = TrainerCliArgs()
         tokenizer = load_tokenizer(cfg)
-        load_model(cfg, tokenizer, inference=False)
+        model, _ = load_model(cfg, tokenizer, inference=cli_args.inference)
+        assert (
+            "MixtralFlashAttention2"
+            in model.model.layers[0].self_attn.__class__.__name__
+        )

     @with_temp_dir
     def test_mistral_multipack(self, temp_dir):
@@ -80,8 +87,9 @@ class TestModelPatches(unittest.TestCase):
             }
         )
         normalize_config(cfg)
+        cli_args = TrainerCliArgs()
         tokenizer = load_tokenizer(cfg)
-        load_model(cfg, tokenizer, inference=False)
+        load_model(cfg, tokenizer, inference=cli_args.inference)
         assert (
             "torch.jit"


@@ -5,14 +5,15 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
+from pathlib import Path

-from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_datasets
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import check_model_output_exists, with_temp_dir
+from ..utils import with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -67,8 +68,8 @@ class TestPhiMultipack(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()

     @with_temp_dir
     def test_qlora_packed(self, temp_dir):
@@ -118,5 +119,5 @@ class TestPhiMultipack(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()


@@ -6,16 +6,17 @@ import logging
 import os
 import re
 import subprocess
+from pathlib import Path

 from transformers.utils import is_torch_bf16_gpu_available

-from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_datasets
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import check_model_output_exists, most_recent_subdir
+from ..utils import most_recent_subdir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -71,7 +72,7 @@ class TestResumeLlama:
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)

         resume_cfg = cfg | DictDefault(
             {
@@ -81,8 +82,8 @@ class TestResumeLlama:
         normalize_config(resume_cfg)
         cli_args = TrainerCliArgs()

-        train(cfg=resume_cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=resume_cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()

         tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
         cmd = f"tensorboard --inspect --logdir {tb_log_path_1}"


@@ -1,18 +1,13 @@
 """Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""

 import unittest

-import pytest
+from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable


-@pytest.mark.skip(
-    reason="Unsloth integration will be broken going into latest transformers"
-)
 class TestUnslothIntegration(unittest.TestCase):
     """Unsloth monkeypatch integration tests."""

     def test_is_self_attn_patchable(self):
-        from axolotl.monkeypatch.unsloth_ import check_self_attn_is_patchable
-
         # ensures the current version of transformers has loss code that matches our patching code
         self.assertTrue(
             check_self_attn_is_patchable(),


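check_self_attn_is_patchable guards the monkeypatch by confirming the upstream source still looks the way the patch expects. A generic sketch of that kind of guard (hash comparison is an assumption here; axolotl's helper may compare source differently):

    import hashlib
    import inspect


    def is_patchable(fn, expected_sha256):
        """Return True if fn's source still matches the hash the patch targets."""
        source = inspect.getsource(fn)
        return hashlib.sha256(source.encode("utf-8")).hexdigest() == expected_sha256
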
@@ -3,25 +3,23 @@ e2e tests for unsloth qlora
 """
 import logging
 import os
+from pathlib import Path

 import pytest

-from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_datasets
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from ..utils import check_model_output_exists, check_tensorboard
+from ..utils import check_tensorboard

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"

 # pylint: disable=duplicate-code


-@pytest.mark.skip(
-    reason="Unsloth integration will be broken going into latest transformers"
-)
 class TestUnslothQLoRA:
     """
     Test class for Unsloth QLoRA Llama models
@@ -75,8 +73,8 @@ class TestUnslothQLoRA:
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()

         check_tensorboard(
             temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
@@ -125,8 +123,8 @@ class TestUnslothQLoRA:
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()

         check_tensorboard(
             temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"
@@ -180,8 +178,8 @@ class TestUnslothQLoRA:
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()

         check_tensorboard(
             temp_dir + "/runs", "train/train_loss", 2.0, "Train Loss is too high"


@@ -9,13 +9,13 @@ from pathlib import Path
 import pytest

-from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_preference_datasets
+from axolotl.cli import load_rl_datasets
+from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import check_model_output_exists, with_temp_dir
+from .utils import with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -65,10 +65,10 @@ class TestDPOLlamaLora(unittest.TestCase):
         )
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
-        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()

     @with_temp_dir
     def test_dpo_nll_lora(self, temp_dir):
@@ -110,10 +110,10 @@ class TestDPOLlamaLora(unittest.TestCase):
         )
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
-        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()

     @with_temp_dir
     def test_dpo_use_weighting(self, temp_dir):
@@ -155,10 +155,10 @@ class TestDPOLlamaLora(unittest.TestCase):
         )
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
-        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()

     @pytest.mark.skip("kto_pair no longer supported in trl")
     @with_temp_dir
@@ -200,10 +200,10 @@ class TestDPOLlamaLora(unittest.TestCase):
         )
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
-        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()

     @with_temp_dir
     def test_ipo_lora(self, temp_dir):
@@ -244,10 +244,10 @@ class TestDPOLlamaLora(unittest.TestCase):
         )
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
-        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()

     @with_temp_dir
     def test_orpo_lora(self, temp_dir):
@@ -291,10 +291,10 @@ class TestDPOLlamaLora(unittest.TestCase):
         )
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
-        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()

     @pytest.mark.skip(reason="Fix the implementation")
     @with_temp_dir
@@ -355,7 +355,7 @@ class TestDPOLlamaLora(unittest.TestCase):
         )
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
-        dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(Path(temp_dir) / "checkpoint-20", cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()


@@ -5,14 +5,15 @@ E2E tests for llama pretrain
 import logging
 import os
 import unittest
+from pathlib import Path

-from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_datasets
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import check_model_output_exists, check_tensorboard, with_temp_dir
+from .utils import check_tensorboard, with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -60,8 +61,8 @@ class TestEmbeddingsLrScale(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "model.safetensors").exists()

         check_tensorboard(
             temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"
@@ -104,8 +105,8 @@ class TestEmbeddingsLrScale(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "model.safetensors").exists()

         check_tensorboard(
             temp_dir + "/runs", "train/train_loss", 2.0, "Loss is too high"


@@ -5,14 +5,15 @@ E2E tests for falcon
 import logging
 import os
 import unittest
+from pathlib import Path

-from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_datasets
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import check_model_output_exists, with_temp_dir
+from .utils import with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -69,8 +70,8 @@ class TestFalcon(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()

     @with_temp_dir
     def test_lora_added_vocab(self, temp_dir):
@@ -122,8 +123,8 @@ class TestFalcon(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()

     @with_temp_dir
     def test_ft(self, temp_dir):
@@ -161,5 +162,5 @@ class TestFalcon(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "pytorch_model.bin").exists()


@@ -4,11 +4,10 @@ E2E tests for llama
 import logging
 import os
+from pathlib import Path

-from e2e.utils import check_model_output_exists
-
-from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_datasets
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault
@@ -60,8 +59,8 @@ class TestLlama:
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "model.safetensors").exists()

     def test_fix_untrained_tokens(self, temp_dir):
         # pylint: disable=duplicate-code
@@ -103,8 +102,8 @@ class TestLlama:
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "model.safetensors").exists()

     def test_batch_flattening(self, temp_dir):
         # pylint: disable=duplicate-code
@@ -142,5 +141,5 @@ class TestLlama:
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "model.safetensors").exists()


@@ -4,49 +4,40 @@ E2E tests for llama pretrain
 import logging
 import os
+import unittest
+from pathlib import Path

-import pytest
-
-from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_datasets
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import check_model_output_exists, check_tensorboard
+from .utils import with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"


-class TestPretrainLlama:
+class TestPretrainLlama(unittest.TestCase):
     """
     Test case for Llama models w pretraining
     """

-    @pytest.mark.parametrize(
-        "sample_packing",
-        [True, False],
-    )
-    @pytest.mark.parametrize(
-        "pretrain_multipack_attn",
-        [True, False],
-    )
-    def test_pretrain(self, temp_dir, sample_packing, pretrain_multipack_attn):
-        if not sample_packing and pretrain_multipack_attn:
-            return
+    @with_temp_dir
+    def test_pretrain_w_sample_packing(self, temp_dir):
         # pylint: disable=duplicate-code
         cfg = DictDefault(
             {
-                "base_model": "HuggingFaceTB/SmolLM2-135M",
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
                 "flash_attention": True,
                 "sequence_len": 1024,
-                "sample_packing": sample_packing,
-                "pretrain_multipack_attn": pretrain_multipack_attn,
-                "dataset_processes": 1,
+                "sample_packing": True,
                 "special_tokens": {
-                    "pad_token": "<|endoftext|>",
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
                 },
                 "pretraining_dataset": [
                     {
@@ -57,7 +48,7 @@ class TestPretrainLlama:
                 ],
                 "max_steps": 5,
                 "num_epochs": 1,
-                "micro_batch_size": 2,
+                "micro_batch_size": 1,
                 "gradient_accumulation_steps": 1,
                 "val_set_size": 0.0,
                 "output_dir": temp_dir,
@@ -66,21 +57,11 @@ class TestPretrainLlama:
                 "lr_scheduler": "cosine",
                 "save_safetensors": True,
                 "bf16": "auto",
-                "use_tensorboard": True,
             }
         )
         normalize_config(cfg)
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
-
-        loss_threshold = 3.5
-        if sample_packing and not pretrain_multipack_attn:
-            loss_threshold = 6.5
-        check_tensorboard(
-            temp_dir + "/runs",
-            "train/train_loss",
-            loss_threshold,
-            "Train Loss is too high",
-        )
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "model.safetensors").exists()


@@ -5,14 +5,15 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
+from pathlib import Path

-from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_datasets
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import check_model_output_exists, with_temp_dir
+from .utils import with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -66,8 +67,8 @@ class TestLlamaVision(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.safetensors").exists()

     @with_temp_dir
     def test_lora_llama_vision_multimodal_dataset(self, temp_dir):
@@ -111,5 +112,5 @@ class TestLlamaVision(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.safetensors").exists()


@@ -5,14 +5,15 @@ E2E tests for lora llama
 import logging
 import os
 import unittest
+from pathlib import Path

-from axolotl.cli.args import TrainerCliArgs
-from axolotl.common.datasets import load_datasets
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
 from axolotl.train import train
 from axolotl.utils.config import normalize_config
 from axolotl.utils.dict import DictDefault

-from .utils import check_model_output_exists, with_temp_dir
+from .utils import with_temp_dir

 LOG = logging.getLogger("axolotl.tests.e2e")
 os.environ["WANDB_DISABLED"] = "true"
@@ -63,5 +64,5 @@ class TestLoraLlama(unittest.TestCase):
         cli_args = TrainerCliArgs()
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

-        train(cfg=cfg, dataset_meta=dataset_meta)
-        check_model_output_exists(temp_dir, cfg)
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()

Some files were not shown because too many files have changed in this diff.