Compare commits


6 Commits

Author      SHA1         Message                                                  Date
Wing Lian   afb8218c67   fix the monkeypatch                                      2024-11-19 02:12:33 -05:00
Wing Lian   1ff78d6347   remove temp_dir decorator as we're using fixtures now    2024-11-19 01:28:27 -05:00
Wing Lian   613a217142   monkeypatch for zero3 w 8bit lora                        2024-11-19 00:45:20 -05:00
Wing Lian   127953af4e   zero3 can'y use 8bit optimizer                           2024-11-19 00:45:20 -05:00
Wing Lian   920ea77bdf   reduce number of steps                                   2024-11-19 00:45:20 -05:00
Wing Lian   ef60e3e851   bi-weekly 8bit lora zero3 check                          2024-11-19 00:45:20 -05:00
36 changed files with 471 additions and 1815 deletions

View File

@@ -10,7 +10,7 @@ on:
jobs:
build-axolotl:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
strategy:
fail-fast: false
matrix:
@@ -49,7 +49,7 @@ jobs:
axolotlai/axolotl
tags: |
type=ref,event=branch
type=pep440,pattern={{version}}
type=semver,pattern={{version}}
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
@@ -77,7 +77,7 @@ jobs:
build-axolotl-cloud:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
matrix:
@@ -116,7 +116,7 @@ jobs:
axolotlai/axolotl-cloud
tags: |
type=ref,event=branch
type=pep440,pattern={{version}}
type=semver,pattern={{version}}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
@@ -140,7 +140,7 @@ jobs:
build-axolotl-cloud-no-tmux:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
matrix:
@@ -163,7 +163,7 @@ jobs:
axolotlai/axolotl-cloud-term
tags: |
type=ref,event=branch
type=pep440,pattern={{version}}
type=semver,pattern={{version}}
- name: Login to Docker Hub
uses: docker/login-action@v3
with:

View File

@@ -15,7 +15,7 @@ concurrency:
jobs:
test-axolotl-multigpu:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
strategy:
fail-fast: false
matrix:

View File

@@ -7,7 +7,7 @@ on:
jobs:
build-axolotl:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
strategy:
fail-fast: false
matrix:
@@ -71,7 +71,7 @@ jobs:
build-axolotl-cloud:
needs: build-axolotl
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]') && github.repository_owner == 'axolotl-ai-cloud' }}
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
matrix:

View File

@@ -10,13 +10,20 @@ jobs:
setup_release:
name: Create Release
runs-on: ubuntu-latest
permissions:
contents: write
steps:
- name: Create release
- name: Get the tag version
id: extract_branch
run: echo ::set-output name=branch::${GITHUB_REF#refs/tags/}
shell: bash
- name: Create Release
id: create_release
uses: actions/create-release@v1
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
run: gh release create "$GITHUB_REF_NAME" # GITHUB_REF_NAME is the tag name in `on.push.tags` workflows
with:
tag_name: ${{ steps.extract_branch.outputs.branch }}
release_name: ${{ steps.extract_branch.outputs.branch }}
pypi-publish:
name: Upload release to PyPI
runs-on: ubuntu-latest

View File

@@ -77,56 +77,12 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.4.1", "2.5.1"]
timeout-minutes: 20
steps:
- name: Check out repository code
uses: actions/checkout@v4
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging setuptools wheel
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
- name: Install dependencies
run: |
pip3 show torch
python3 setup.py sdist
pip3 install dist/axolotl*.tar.gz
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Run tests
run: |
pytest -n8 --ignore=tests/e2e/ tests/
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests-1st:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
needs: [pre-commit, pytest, pytest-sdist]
needs: [pre-commit, pytest]
strategy:
fail-fast: false

.gitignore (vendored), 3 changed lines
View File

@@ -182,6 +182,3 @@ submit.sh
typings/
out/
# vim
*.swp

View File

@@ -1,4 +0,0 @@
include requirements.txt
include README.md
include LICENSE
recursive-include axolotl *.py

View File

@@ -11,10 +11,12 @@ standard industry baselines.
### Installation
The following will install the correct unsloth and extras from source.
The following will install unsloth from source and downgrade xformers as unsloth is incompatible with the most up
to date libraries.
```bash
python scripts/unsloth_install.py | sh
pip install --no-deps "unsloth @ git+https://github.com/unslothai/unsloth.git"
pip install --no-deps --force-reinstall xformers==0.0.26.post1
```
### Using unsloth w Axolotl

View File

@@ -2,15 +2,19 @@
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"id": "AKjdG7tbTb-n"
},
"source": [
"## Setting up"
"# Example notebook for running Axolotl on google colab"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"id": "RcbNpOgWRcii"
},
"outputs": [],
"source": [
"import torch\n",
@@ -18,76 +22,82 @@
"assert (torch.cuda.is_available()==True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "h3nLav8oTRA5"
},
"source": [
"## Install Axolotl and dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3c3yGAwnOIdi",
"outputId": "e3777b5a-40ef-424f-e181-62dfecd1dd01"
},
"outputs": [],
"source": [
"!pip install axolotl[deepspeed]"
"!pip install -e git+https://github.com/axolotl-ai-cloud/axolotl#egg=axolotl\n",
"!pip install flash-attn==\"2.7.0.post2\"\n",
"!pip install deepspeed==\"0.13.1\"!pip install mlflow==\"2.13.0\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"id": "BW2MFr7HTjub"
},
"source": [
"## Hugging Face login (optional)"
"## Create an yaml config file"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from huggingface_hub import notebook_login\n",
"notebook_login()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example configuration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"id": "9pkF2dSoQEUN"
},
"outputs": [],
"source": [
"import yaml\n",
"\n",
"# Your YAML string\n",
"yaml_string = \"\"\"\n",
"base_model: NousResearch/Meta-Llama-3.1-8B\n",
"base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n",
"model_type: LlamaForCausalLM\n",
"tokenizer_type: LlamaTokenizer\n",
"\n",
"load_in_8bit: false\n",
"load_in_4bit: true\n",
"strict: false\n",
"\n",
"datasets:\n",
" - path: tatsu-lab/alpaca\n",
" - path: mhenrichsen/alpaca_2k_test\n",
" type: alpaca\n",
"dataset_prepared_path: last_run_prepared\n",
"dataset_prepared_path:\n",
"val_set_size: 0.05\n",
"output_dir: ./outputs/lora-out\n",
"\n",
"sequence_len: 2048\n",
"sample_packing: true\n",
"eval_sample_packing: true\n",
"pad_to_sequence_len: true\n",
"output_dir: ./outputs/qlora-out\n",
"\n",
"adapter: qlora\n",
"lora_model_dir:\n",
"\n",
"sequence_len: 4096\n",
"sample_packing: true\n",
"eval_sample_packing: false\n",
"pad_to_sequence_len: true\n",
"\n",
"lora_r: 32\n",
"lora_alpha: 16\n",
"lora_dropout: 0.05\n",
"lora_target_modules:\n",
"lora_target_linear: true\n",
"lora_fan_in_fan_out:\n",
"lora_modules_to_save:\n",
" - embed_tokens\n",
" - lm_head\n",
"\n",
"wandb_project:\n",
"wandb_entity:\n",
@@ -95,12 +105,12 @@
"wandb_name:\n",
"wandb_log_model:\n",
"\n",
"gradient_accumulation_steps: 2\n",
"micro_batch_size: 1\n",
"num_epochs: 1\n",
"optimizer: paged_adamw_8bit\n",
"gradient_accumulation_steps: 4\n",
"micro_batch_size: 2\n",
"num_epochs: 4\n",
"optimizer: paged_adamw_32bit\n",
"lr_scheduler: cosine\n",
"learning_rate: 2e-5\n",
"learning_rate: 0.0002\n",
"\n",
"train_on_inputs: false\n",
"group_by_length: false\n",
@@ -111,15 +121,13 @@
"gradient_checkpointing: true\n",
"early_stopping_patience:\n",
"resume_from_checkpoint:\n",
"local_rank:\n",
"logging_steps: 1\n",
"xformers_attention:\n",
"flash_attention: false\n",
"sdp_attention: true\n",
"flash_attention: true\n",
"\n",
"warmup_steps: 1\n",
"max_steps: 25\n",
"evals_per_epoch: 1\n",
"eval_table_size:\n",
"warmup_steps: 10\n",
"evals_per_epoch: 4\n",
"saves_per_epoch: 1\n",
"debug:\n",
"deepspeed:\n",
@@ -127,9 +135,8 @@
"fsdp:\n",
"fsdp_config:\n",
"special_tokens:\n",
" pad_token: <|end_of_text|>\n",
"\"\"\"\n",
"\n",
"\"\"\"\n",
"\n",
"# Convert the YAML string to a Python dictionary\n",
"yaml_dict = yaml.safe_load(yaml_string)\n",
@@ -139,124 +146,31 @@
"\n",
"# Write the YAML file\n",
"with open(file_path, 'w') as file:\n",
" yaml.dump(yaml_dict, file)"
" yaml.dump(yaml_dict, file)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"metadata": {
"id": "bidoj8YLTusD"
},
"source": [
"Above we have a configuration file with base LLM model and datasets specified, among many other things. Axolotl can automatically detect whether the specified datasets are on HuggingFace repo or local machine.\n",
"\n",
"The Axolotl configuration options encompass model and dataset selection, data pre-processing, and training. Let's go through them line by line:\n",
"\n",
"* \"base model\": String value, specifies the underlying pre-trained LLM that will be used for finetuning\n",
"\n",
"Next we have options for model weights quantization. Quantization allows for reduction in occupied memory on GPUs.\n",
"\n",
"* \"load_in_8bit\": Boolean value, whether to quantize the model weights into 8-bit integer.\n",
"\n",
"* \"load_in_4bit\": Boolean value, whether to quantize the model weights into 4-bit integer.\n",
"\n",
"* \"strict\": Boolean value. If false, it allows for overriding established configuration options in the yaml file when executing in command-line interface.\n",
"\n",
"* \"datasets\": a list of dicts that contain path and type of data sets as well as other optional configurations where datasets are concerned. Supports multiple datasets.\n",
"\n",
"* \"val_set_size\": Either a float value less than one or an integer less than the total size of dataset. Sets the size of validation set from the whole dataset. If float, sets the proportion of the dataset assigned for validation. If integer, sets the direct size of validation set.\n",
"\n",
"* \"output_dir\": String value. Path of trained model.\n",
"\n",
"For data preprocessing:\n",
"\n",
"* \"sequence_len\": Integer. Specifies the maximum sequence length of the input. Typically 2048 or less.\n",
"\n",
"* \"pad_to_sequence_len\": Boolean. Padding input to maximum sequence length.\n",
"\n",
"* \"sample_packing\": Boolean. Specifies whether to use multi-packing with block diagonal attention.\n",
"\n",
"* \"special_tokens\": Python dict, optional. Allows users to specify the additional special tokens to be ignored by the tokenizer.\n",
"\n",
"For LoRA configuration and its hyperparamters:\n",
"\n",
"* \"adapter\": String. Either \"lora\" or \"qlora\", depending on user's choice.\n",
"\n",
"* \"lora_model_dir\": String, Optional. Path to directory that contains LoRA model, if there is already a trained LoRA model the user would like to use.\n",
"\n",
"* \"lora_r\": Integer. Refers to the rank of LoRA decomposition matrices. Higher value will reduce LoRA efficiency. Recommended to be set to 8.\n",
"\n",
"* \"lora_alpha\": Integer. Scale the weight matrices by $\\frac{\\text{lora_alpha}}{\\text{lora_r}}$Recommended to be fixed at 16.\n",
"\n",
"* \"lora_dropout\": Float that is 1 or less. The dropout probability of a lora layer.\n",
"\n",
"* \"lora_target_linear\": Boolean. If true, lora will target all linear modules in the transformers architecture.\n",
"\n",
"* \"lora_modules_to_save\": If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.\n",
"\n",
"See [LoRA](https://arxiv.org/abs/2106.09685) for detailed explanation of LoRA implementation.\n",
"\n",
"For the training configurations:\n",
"\n",
"* \"gradient_accumulation_steps\": Integer. The number of steps over which to accumulate gradient for batch training. E.g. if 2, backprop is performed every two steps.\n",
"\n",
"* \"micro_batch_size\": Integer. Batch size per gpu / gradient_accumulation_steps\n",
"\n",
"* \"num_epochs\": Integer. Number of epochs. One epoch is when training has looped over every batch in the whole data set once.\n",
"\n",
"* \"optimizer\": The optimizer to use for the training.\n",
"\n",
"* \"learning_rate\": The learning rate.\n",
"\n",
"* \"lr_scheduler\": The learning rate scheduler to use for adjusting learning rate during training.\n",
"\n",
"* \"train_on_inputs\": Boolean. Whether to ignore or include the user's prompt from the training labels.\n",
"\n",
"* \"group_by_length\": Boolean. Whether to group similarly sized data to minimize padding.\n",
"\n",
"* \"bf16\": Either \"auto\", \"true\", or \"false\". Whether to use CUDA bf16 floating point format. If set to \"auto\", will automatically apply bf16 should the gpu supports it.\n",
"\n",
"* \"fp16\": Optional. Specifies whether to use CUDA fp16. Automatically set to true if \"bf16\" is set to true. Otherwise false.\n",
"\n",
"* \"tf32\": Boolean. Whether to use CUDA tf32. Will override bf16.\n",
"\n",
"* \"gradient_checkpointing\": Boolean. Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing\n",
"\n",
"* \"gradient_checkpointing_kwargs\": Python Dict. Fed into the trainer.\n",
"\n",
"* \"logging_steps\": Integer. Log training information over every specified number of steps.\n",
"\n",
"* \"flash_attention\": Boolean. Whether to use the [flash attention](https://github.com/Dao-AILab/flash-attention) mechanism.\n",
"\n",
"* \"sdp_attention\": Boolean. Whether to use the Scaled Dot Product attention mechanism (the attention mechanism in the [original implementation](https://arxiv.org/abs/1706.03762) of transformers.)\n",
"\n",
"* \"warmup_steps\": Integer. The number of pre-training steps where a very low learning rate is used.\n",
"\n",
"* \"evals_per_epoch\": Integer. Number of evaluations to be performed within one training epoch.\n",
"\n",
"* \"saves_per_epoch\": Integer. Number of times the model is saved in one training epoch.\n",
"\n",
"* \"weight_decay\": Positive Float. Sets the \"strength\" of weight decay (i.e. setting the coefficient of L2 regularization)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The above is but a snippet aiming to get users familiarized with the types of streamlined configuration options axolotl provides. For a full list of configuration options, see [here](https://axolotl-ai-cloud.github.io/axolotl/docs/config.html)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Train the model"
"## Launch the training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ydTI2Jk2RStU",
"outputId": "d6d0df17-4b53-439c-c802-22c0456d301b"
},
"outputs": [],
"source": [
"# By using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
]
},
@@ -264,7 +178,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Predict with trained model"
"## Play with inference"
]
},
{
@@ -273,85 +187,36 @@
"metadata": {},
"outputs": [],
"source": [
"# By using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.inference /content/test_axolotl.yaml \\\n",
" --lora_model_dir=\"./outputs/lora-out\" --gradio"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Deeper Dive"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"It is also helpful to gain some familiarity over some of the core inner workings of axolotl"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Configuration Normalization"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Axolotl uses a custom Dict class, called ```DictDefault```\n",
"to store configurations specified in the yaml configuration file (into a Python variable named ```cfg```). The definition for this custom Dict can be found in the [utils/dict.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/dict.py)\n",
"\n",
"```DictDefault``` is amended such that calling a missing key from it will result in a ```None``` return type. This is important because if some configuration options aren't specified by the user, the ```None``` type allows Axolotl to perform boolean operations to determine the default settings for missing configurations. For more examples on how this is done, check out [utils/config/__init__.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/config/__init__.py)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Loading Models, Tokenizers, and Trainer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If we inspect [cli.train.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/cli/train.py), we will find that most of the heavy lifting were done by the function ```train()``` which is itself imported from [src/axolotl/train.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/train.py).\n",
"\n",
"```train()``` takes care of loading the appropriate tokenizer and pre-trained model through ```load_model()``` and ```load_tokenizer()``` from [src/axolotl/utils/models.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/models.py) respectively.\n",
"\n",
"```load_tokenizer()``` loads in the appropriate tokenizer given the desired model, as well as chat templates.\n",
"\n",
"```ModelLoader``` class follows after tokenizer has been selected. It will automatically discern the base model type, load in the desired model, as well as applying model-appropriate attention mechanism modifications (e.g. flash attention). Depending on which base model the user chooses in the configuration, ```ModelLoader``` will utilize the corresponding \"attention hijacking\" script. For example, if the user specified the base model to be ```NousResearch/Meta-Llama-3.1-8B```, which is of llama type, and set ```flash_attn``` to ```True```, ```ModelLoader``` will load in [llama_attn_hijack_flash.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/monkeypatch/llama_attn_hijack_flash.py). For a list of supported attention hijacking, please refer to the directory [/src/axolotl/monkeypatch/](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/monkeypatch)\n",
"\n",
"Another important operation encompassed in ```train()``` is setting up the training that takes into account of user-specified traning configurations (e.g. num_epochs, optimizer) through the use of ```setup_trainer()``` from [/src/axolotl/utils/trainer.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/utils/trainer.py), which in turn relies on modules from [/src/axolotl/core/trainer_builder.py](https://github.com/axolotl-ai-cloud/axolotl/blob/main/src/axolotl/core/trainer_builder.py).\n",
"```trainer_builder.py``` provides a list of trainer object options bespoke for the task type (Causal or Reinforcement learning ('dpo', 'ipo', 'kto') )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Monkey patch\n",
"\n",
"The [Monkey patch directory](https://github.com/axolotl-ai-cloud/axolotl/tree/main/src/axolotl/monkeypatch) is where model architecture/optimization patching scripts are stored (these are modifications that are not implemented in the official releases, hence the name monkey patch). It includes attention jacking, ReLoRA, and unsloth optimization."
" --qlora_model_dir=\"./qlora-out\" --gradio"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"version": "3.9.6"
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}
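The walkthrough cells removed from this notebook describe how Axolotl parses the YAML into a `DictDefault` (defined in `src/axolotl/utils/dict.py`) whose missing keys resolve to `None`, so unset options can be tested with plain boolean checks. A minimal sketch of that behavior, assuming only what the removed cell states:

```python
import yaml

from axolotl.utils.dict import DictDefault  # path cited in the removed walkthrough cell

# Parse the same style of YAML string the notebook writes out to test_axolotl.yaml.
yaml_string = """
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
load_in_4bit: true
"""
cfg = DictDefault(yaml.safe_load(yaml_string))

print(cfg.base_model)       # "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
print(cfg.flash_attention)  # None -- the key was never set, so a default can be chosen downstream
if not cfg.flash_attention:
    print("flash_attention left unset; default applies")
```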

View File

@@ -1,12 +1,12 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft==0.13.2
transformers==4.46.3
transformers==4.46.2
tokenizers>=0.20.1
bitsandbytes==0.44.1
accelerate==1.1.0
datasets==3.1.0
deepspeed==0.15.4
deepspeed==0.15.3
pydantic==2.6.3
addict
fire
@@ -33,7 +33,7 @@ tensorboard
python-dotenv==1.0.1
autoawq==0.2.7.post2
triton>=2.3.0
liger-kernel==0.4.2
liger-kernel==0.4.1
mamba-ssm==1.2.0.post1

View File

@@ -1,33 +0,0 @@
# noqa
# pylint: skip-file
try:
import torch
except ImportError:
raise ImportError("Install torch via `pip install torch`")
from packaging.version import Version as V
v = V(torch.__version__)
cuda = str(torch.version.cuda)
is_ampere = torch.cuda.get_device_capability()[0] >= 8
if cuda != "12.1" and cuda != "11.8" and cuda != "12.4":
raise RuntimeError(f"CUDA = {cuda} not supported!")
if v <= V("2.1.0"):
raise RuntimeError(f"Torch = {v} too old!")
elif v <= V("2.1.1"):
x = "cu{}{}-torch211"
elif v <= V("2.1.2"):
x = "cu{}{}-torch212"
elif v < V("2.3.0"):
x = "cu{}{}-torch220"
elif v < V("2.4.0"):
x = "cu{}{}-torch230"
elif v < V("2.5.0"):
x = "cu{}{}-torch240"
elif v < V("2.6.0"):
x = "cu{}{}-torch250"
else:
raise RuntimeError(f"Torch = {v} too new!")
x = x.format(cuda.replace(".", ""), "-ampere" if is_ampere else "")
print(
f'pip install unsloth-zoo && pip install --no-deps "unsloth[{x}] @ git+https://github.com/unslothai/unsloth.git"'
)
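As a worked example of the selection logic above: on torch 2.4.1 with CUDA 12.1 and an Ampere-or-newer GPU, the version branches resolve to the `cu121-torch240-ampere` tag, so the script prints the pip command sketched below (derived by tracing the removed code; the torch/CUDA versions are illustrative):

```python
# torch 2.4.1 falls into the `v < V("2.5.0")` branch -> "cu{}{}-torch240"
# CUDA "12.1" -> "121"; compute capability >= 8 -> "-ampere"
tag = "cu121-torch240-ampere"
print(
    f'pip install unsloth-zoo && pip install --no-deps "unsloth[{tag}] @ git+https://github.com/unslothai/unsloth.git"'
)
```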

View File

@@ -96,11 +96,11 @@ install_requires, dependency_links = parse_requirements()
setup(
name="axolotl",
version="0.5.2",
version="0.5.0",
description="LLM Trainer",
long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
package_dir={"": "src"},
packages=find_packages("src"),
packages=find_packages(),
install_requires=install_requires,
dependency_links=dependency_links,
extras_require={
@@ -108,7 +108,7 @@ setup(
"flash-attn==2.7.0.post2",
],
"deepspeed": [
"deepspeed==0.15.4",
"deepspeed==0.14.4",
"deepspeed-kernels",
],
"mamba-ssm": [

View File

@@ -30,10 +30,7 @@ from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
from axolotl.integrations.base import PluginManager
from axolotl.logging_config import configure_logging
from axolotl.train import TrainDatasetMeta
from axolotl.utils.chat_templates import (
get_chat_template,
get_chat_template_from_config,
)
from axolotl.utils.chat_templates import get_chat_template
from axolotl.utils.comet_ import setup_comet_env_vars
from axolotl.utils.config import (
normalize_cfg_datasets,
@@ -202,10 +199,6 @@ def do_inference(
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
elif cfg.datasets[0].type == "chat_template":
chat_template_str = get_chat_template_from_config(
cfg=cfg, ds_cfg=cfg.datasets[0], tokenizer=tokenizer
)
model = model.to(cfg.device, dtype=cfg.torch_dtype)

View File

@@ -1212,17 +1212,11 @@ class TrainerBuilderBase(abc.ABC):
Callbacks added after the trainer is created, usually b/c these need access to the trainer
"""
callbacks = []
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
callbacks.extend(
[
cb
for cb in plugin_manager.add_callbacks_post_trainer(
self.cfg, trainer
)
if cb
]
)
plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer)
)
return callbacks
def hook_pre_create_training_args(self, training_arguments_kwargs):
@@ -1269,7 +1263,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "wandb"
@@ -1307,7 +1301,17 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
callbacks.append(lisa_callback_factory(trainer))
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
callbacks.extend(
[
cb
for cb in plugin_manager.add_callbacks_post_trainer(
self.cfg, trainer
)
if cb
]
)
return callbacks
def _get_trainer_cls(self):

View File

@@ -1,361 +0,0 @@
"""
Copyright (c) 2024 by SageAttention team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Any, Optional
import torch
from torch.autograd import Function
from .triton.attn_qk_int8_per_block_causal_varlen import (
backward as sageattn_varlen_backward,
)
from .triton.attn_qk_int8_per_block_causal_varlen import forward as attn_true_varlen
from .triton.quant_per_block_varlen import (
per_block_int8 as per_block_int8_varlen_triton,
)
def get_cuda_arch_versions():
cuda_archs = []
for i in range(torch.cuda.device_count()):
major, minor = torch.cuda.get_device_capability(i)
cuda_archs.append(f"sm{major}{minor}")
return cuda_archs
def sageattn_varlen(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
cu_seqlens_q: torch.Tensor,
cu_seqlens_k: torch.Tensor,
max_seqlen_q: int,
max_seqlen_k: int,
sm_scale: Optional[float] = None,
smooth_k: bool = True,
**kwargs: Any,
) -> torch.Tensor:
"""
Parameters
----------
q : torch.Tensor
The query tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``.
k : torch.Tensor
The key tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``.
v : torch.Tensor
The value tensor, shape: ``[cu_seqlens_k[-1], num_kv_heads, head_dim]``.
cu_seqlens_q : torch.Tensor
The cumulative sequence lengths for the query sequences in the batch, used to index into `q`.
Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index.
cu_seqlens_k : torch.Tensor
The cumulative sequence lengths for the key and value sequences in the batch, used to index into `k` and `v`.
Shape: ``[batch_size + 1]``, where each entry represents the cumulative length of sequences up to that batch index.
max_seqlen_q : int
The maximum sequence length for the query tensor in the batch.
max_seqlen_k : int
The maximum sequence length for the key and value tensors in the batch.
is_causal : bool
Whether to apply causal mask to the attention matrix. Only applicable when qo_len == kv_len for each sequence.
Default: False.
sm_scale : Optional[float]
The scale used in softmax, if not provided, will be set to ``1.0 / sqrt(head_dim)``.
smooth_k : bool
Whether to smooth the key tensor by subtracting the mean along the sequence dimension.
Default: True.
Returns
-------
torch.Tensor
The output tensor, shape: ``[cu_seqlens_q[-1], num_qo_heads, head_dim]``.
Note
----
- ``num_qo_heads`` must be divisible by ``num_kv_heads``.
- The tensors `q`, `k`, and `v` must have the dtype ``torch.float16``, ``torch.bfloat16`` or ``torch.float32``.
- The tensors `cu_seqlens_q` and `cu_seqlens_k` must have the dtype ``torch.int32`` or ``torch.int64``.
- All tensors must be on the same cuda device.
- `smooth_k` will introduce slight overhead but will improve the accuracy under most circumstances.
"""
dtype = q.dtype
assert q.is_cuda, "Input tensors must be on cuda."
assert dtype in [
torch.float16,
torch.bfloat16,
], "Input tensors must be in dtype of torch.float16 or torch.bfloat16"
assert q.device == k.device == v.device, "All tensors must be on the same device."
assert q.dtype == k.dtype == v.dtype, "All tensors must have the same dtype."
head_dim = q.size(-1)
assert head_dim in [64, 128], "varlen only support head_dim [64, 128]."
assert (
q.stride(-1) == 1 and k.stride(-1) == 1 and v.stride(-1) == 1
), "Last dim of qkv must be contiguous."
assert (
cu_seqlens_q.is_contiguous() and cu_seqlens_k.is_contiguous()
), "cu_seqlens_q and cu_seqlens_k must be contiguous."
if dtype == torch.bfloat16 or dtype == torch.float32:
v = v.to(torch.float16)
if smooth_k:
km = k.mean(
dim=0, keepdim=True
) # ! km is calculated on the all the batches. Calculate over each individual sequence requires dedicated kernel.
k -= km
(
q_int8,
q_scale,
k_int8,
k_scale,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
) = per_block_int8_varlen_triton(
q, k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale=sm_scale
)
o = attn_true_varlen(
q_int8,
k_int8,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
q_scale,
k_scale,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
output_dtype=dtype,
)
return o
class SageAttentionFunction(Function):
@staticmethod
def forward(
ctx,
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
scale=None,
):
"""
query: Tensor of shape [batch_size, num_heads, seq_len_q, head_dim]
key: Tensor of shape [batch_size, num_heads, seq_len_k, head_dim]
value: Tensor of shape [batch_size, num_heads, seq_len_k, head_dim]
attn_mask: Optional[Tensor], mask tensor
dropout_p: float, dropout probability
is_causal: bool, whether to apply causal masking
scale: Optional[float], scaling factor for attention scores
"""
# Ensure inputs are contiguous
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
# Handle default scale
if scale is None:
scale = 1.0 / (query.size(-1) ** 0.5)
# Save parameters needed for backward
ctx.scale = scale
ctx.is_causal = is_causal
ctx.dropout_p = dropout_p
ctx.attn_mask = attn_mask
# Prepare cumulative sequence lengths and max sequence lengths
# Assuming batch sizes are consistent across query, key, and value
batch_size, num_heads, seq_len_q, head_dim = query.shape
seq_len_k = key.shape[2]
# Flatten batch and head dimensions
q = query.view(
-1, seq_len_q, head_dim
) # [batch_size * num_heads, seq_len_q, head_dim]
k = key.view(-1, seq_len_k, head_dim)
v = value.view(-1, seq_len_k, head_dim)
# Create cumulative sequence lengths
cu_seqlens_q = torch.arange(
0,
(batch_size * num_heads + 1) * seq_len_q,
seq_len_q,
dtype=torch.int32,
device=query.device,
)
cu_seqlens_k = torch.arange(
0,
(batch_size * num_heads + 1) * seq_len_k,
seq_len_k,
dtype=torch.int32,
device=key.device,
)
max_seqlen_q = seq_len_q
max_seqlen_k = seq_len_k
# Call your custom per-block int8 quantization function
(
q_int8,
q_scale,
k_int8,
k_scale,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
) = per_block_int8_varlen_triton(
q, k, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, sm_scale=scale
)
# Call your custom attention function
if is_causal:
output = attn_true_varlen(
q_int8,
k_int8,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
q_scale,
k_scale,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
output_dtype=query.dtype,
)
else:
raise NotImplementedError("Non-causal attention is not implemented yet.")
# Reshape output to match the expected shape
output = output.view(batch_size, num_heads, seq_len_q, head_dim)
# Save tensors for backward
ctx.save_for_backward(
query,
key,
value,
q_int8,
k_int8,
q_scale,
k_scale,
cu_seqlens_q,
cu_seqlens_k,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
output,
)
return output
@staticmethod
def backward(ctx, grad_output):
(
query,
key,
value,
q_int8,
k_int8,
q_scale,
k_scale,
cu_seqlens_q,
cu_seqlens_k,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
output,
) = ctx.saved_tensors
scale = ctx.scale
is_causal = ctx.is_causal
dropout_p = ctx.dropout_p
attn_mask = ctx.attn_mask
# Flatten batch and head dimensions
batch_size, num_heads, seq_len_q, head_dim = query.shape
seq_len_k = key.shape[2]
grad_output = grad_output.contiguous()
do = grad_output.view(-1, seq_len_q, head_dim)
# Compute gradients w.r.t. q, k, v
dq, dk, dv = sageattn_varlen_backward(
do,
query.view(-1, seq_len_q, head_dim),
key.view(-1, seq_len_k, head_dim),
value.view(-1, seq_len_k, head_dim),
cu_seqlens_q,
cu_seqlens_k,
seq_len_q,
seq_len_k,
q_int8,
k_int8,
q_scale,
k_scale,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
scale,
is_causal,
)
# Reshape gradients to match the input shapes
dq = dq.view(batch_size, num_heads, seq_len_q, head_dim)
dk = dk.view(batch_size, num_heads, seq_len_k, head_dim)
dv = dv.view(batch_size, num_heads, seq_len_k, head_dim)
# Handle optional arguments
d_attn_mask = None # Assuming attn_mask does not require gradients
d_dropout_p = (
None # Dropout probability is a hyperparameter, typically not optimized
)
d_is_causal = None # Not differentiable
d_scale = None # If scale is a tensor and requires grad, compute its gradient
return dq, dk, dv, d_attn_mask, d_dropout_p, d_is_causal, d_scale
def scaled_dot_product_attention(
query,
key,
value,
attn_mask=None,
dropout_p=0.0,
is_causal=False,
scale=None,
):
"""
Custom scaled dot product attention using SageAttentionFunction.
"""
return SageAttentionFunction.apply(
query, key, value, attn_mask, dropout_p, is_causal, scale
)
def monkeypatch_sdp_w_sage_attention():
"""
Replace torch.nn.functional.scaled_dot_product_attention with custom scaled dot product attention using SageAttentionFunction.
"""
torch.nn.functional.scaled_dot_product_attention = scaled_dot_product_attention

View File

@@ -1,622 +0,0 @@
"""
Copyright (c) 2024 by SageAttention team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import math
import torch
import triton
import triton.language as tl
@triton.jit
def _attn_fwd_inner(
acc,
l_i,
m_i,
q,
q_scale,
kv_len,
K_ptrs,
K_scale_ptr,
V_ptrs,
stride_kn,
stride_vn,
start_m,
H: tl.constexpr,
BLOCK_M: tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr,
STAGE: tl.constexpr,
offs_m: tl.constexpr,
offs_n: tl.constexpr,
):
if STAGE == 1:
lo, hi = 0, start_m * BLOCK_M
elif STAGE == 2:
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
lo = tl.multiple_of(lo, BLOCK_M)
K_scale_ptr += (lo // BLOCK_N) * H
K_ptrs += stride_kn * lo
V_ptrs += stride_vn * lo
for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
k_mask = offs_n[None, :] < (kv_len - start_n)
k = tl.load(K_ptrs, mask=k_mask)
k_scale = tl.load(K_scale_ptr)
qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale
if STAGE == 2:
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
qk = qk + tl.where(mask, 0, -1.0e6)
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk -= m_ij[:, None]
else:
m_ij = tl.maximum(m_i, tl.max(qk, 1))
qk = qk - m_ij[:, None]
p = tl.math.exp2(qk)
l_ij = tl.sum(p, 1)
alpha = tl.math.exp2(m_i - m_ij)
l_i = l_i * alpha + l_ij
acc = acc * alpha[:, None]
v = tl.load(V_ptrs, mask=offs_n[:, None] < (kv_len - start_n))
p = p.to(tl.float16)
acc += tl.dot(p, v, out_dtype=tl.float16)
m_i = m_ij
K_ptrs += BLOCK_N * stride_kn
K_scale_ptr += H
V_ptrs += BLOCK_N * stride_vn
return acc, l_i, m_i
@triton.jit
def _attn_fwd(
Q,
K,
V,
cu_seqlens_q,
cu_seqlens_k,
Q_scale,
K_scale,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
Out,
stride_qh,
stride_qn,
stride_kh,
stride_kn,
stride_vh,
stride_vn,
stride_oh,
stride_on,
H: tl.constexpr,
num_kv_groups: tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
STAGE: tl.constexpr,
):
start_m = tl.program_id(0)
off_z = tl.program_id(2).to(tl.int64)
off_h = tl.program_id(1).to(tl.int64)
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
qo_len = cu_seqlens_q_end - cu_seqlens_q_start
if (start_m * BLOCK_M) >= qo_len:
return
cu_seq_lens_q_scale_start = tl.load(cu_seqlens_q_scale + off_z)
cu_seq_lens_k_scale_start = tl.load(cu_seqlens_k_scale + off_z)
q_scale_offset = cu_seq_lens_q_scale_start * H + off_h + start_m * H
k_scale_offset = (
cu_seq_lens_k_scale_start * (H // num_kv_groups) + off_h // num_kv_groups
)
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
kv_len = cu_seqlens_k_end - cu_seqlens_k_start
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, HEAD_DIM)
Q_ptrs = (
Q
+ (cu_seqlens_q_start * stride_qn + off_h * stride_qh)
+ offs_m[:, None] * stride_qn
+ offs_k[None, :]
)
Q_scale_ptr = Q_scale + q_scale_offset
K_ptrs = (
K
+ (cu_seqlens_k_start * stride_kn + (off_h // num_kv_groups) * stride_kh)
+ offs_n[None, :] * stride_kn
+ offs_k[:, None]
)
K_scale_ptr = K_scale + k_scale_offset
V_ptrs = (
V
+ (cu_seqlens_k_start * stride_vn + (off_h // num_kv_groups) * stride_vh)
+ offs_n[:, None] * stride_vn
+ offs_k[None, :]
)
O_block_ptr = (
Out
+ (cu_seqlens_q_start * stride_on + off_h * stride_oh)
+ offs_m[:, None] * stride_on
+ offs_k[None, :]
)
m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0
acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len)
q_scale = tl.load(Q_scale_ptr)
acc, l_i, m_i = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
q_scale,
kv_len,
K_ptrs,
K_scale_ptr,
V_ptrs,
stride_kn,
stride_vn,
start_m,
H // num_kv_groups,
BLOCK_M,
HEAD_DIM,
BLOCK_N,
4 - STAGE,
offs_m,
offs_n,
)
acc, l_i, _ = _attn_fwd_inner(
acc,
l_i,
m_i,
q,
q_scale,
kv_len,
K_ptrs,
K_scale_ptr,
V_ptrs,
stride_kn,
stride_vn,
start_m,
H // num_kv_groups,
BLOCK_M,
HEAD_DIM,
BLOCK_N,
2,
offs_m,
offs_n,
)
acc = acc / l_i[:, None]
tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=(offs_m[:, None] < qo_len))
@triton.jit
def _attn_bwd_inner(
dq_acc,
dk_acc,
dv_acc,
l_i,
m_i,
q,
k,
v,
do,
q_scale,
k_scale,
kv_len,
stride_kn,
stride_vn,
start_m,
H,
BLOCK_M: tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_N: tl.constexpr,
STAGE: tl.constexpr,
offs_m: tl.constexpr,
offs_n: tl.constexpr,
):
if STAGE == 1:
lo, hi = 0, start_m * BLOCK_M
elif STAGE == 2:
lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M
lo = tl.multiple_of(lo, BLOCK_M)
k += stride_kn * lo
v += stride_vn * lo
for start_n in range(lo, hi, BLOCK_N):
start_n = tl.multiple_of(start_n, BLOCK_N)
k_mask = offs_n[None, :] < (kv_len - start_n)
k_curr = tl.load(k, mask=k_mask)
v_curr = tl.load(v, mask=k_mask)
k_scale_curr = tl.load(k_scale)
s = tl.dot(q, k_curr, trans_b=True).to(tl.float32) * q_scale * k_scale_curr
if STAGE == 2:
mask = offs_m[:, None] >= (start_n + offs_n[None, :])
s = s + tl.where(mask, 0.0, -float("inf"))
m_ij = tl.maximum(m_i, tl.max(s, 1))
s = s - m_ij[:, None]
else:
m_ij = tl.maximum(m_i, tl.max(s, 1))
s = s - m_ij[:, None]
p = tl.math.exp2(s)
l_ij = tl.sum(p, 1)
alpha = tl.math.exp2(m_i - m_ij)
l_i = l_i * alpha + l_ij
m_i = m_ij
p = p / l_i[:, None] # Normalize probabilities
# Compute gradients
# Compute softmax gradient
do_scaled = do / l_i[:, None]
dv_contrib = tl.dot(p.to(tl.float16).T, do_scaled.to(tl.float16))
dv_acc += dv_contrib
dp = tl.dot(do_scaled.to(tl.float16), v_curr.to(tl.float16).T)
# Compute ds (gradient w.r.t. logits s)
p_dp = p * dp
sum_p_dp = tl.sum(p_dp, axis=1)
ds = (p_dp - p * sum_p_dp[:, None]) * tl.math.log(2.0) # Adjust for exp2
# Compute gradients w.r.t q and k
dq_contrib = tl.dot(ds.to(tl.float16), k_curr.to(tl.float16))
dk_contrib = tl.dot(ds.to(tl.float16).T, q.to(tl.float16))
dq_acc += dq_contrib * (q_scale * k_scale_curr)
dk_acc += dk_contrib * (q_scale * k_scale_curr)
k += BLOCK_N * stride_kn
k_scale += H
v += BLOCK_N * stride_vn
return dq_acc, dk_acc, dv_acc, l_i, m_i
@triton.jit
def _attn_bwd(
DO,
Q,
K,
V,
cu_seqlens_q,
cu_seqlens_k,
Q_scale,
K_scale,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
L,
M,
DQ,
DK,
DV,
stride_qh,
stride_qn,
stride_kh,
stride_kn,
stride_vh,
stride_vn,
H: tl.constexpr,
num_kv_groups: tl.constexpr,
HEAD_DIM: tl.constexpr,
BLOCK_M: tl.constexpr,
BLOCK_N: tl.constexpr,
STAGE: tl.constexpr,
):
start_m = tl.program_id(0)
off_z = tl.program_id(2).to(tl.int64)
off_h = tl.program_id(1).to(tl.int64)
cu_seqlens_q_start = tl.load(cu_seqlens_q + off_z)
cu_seqlens_q_end = tl.load(cu_seqlens_q + off_z + 1)
qo_len = cu_seqlens_q_end - cu_seqlens_q_start
if (start_m * BLOCK_M) >= qo_len:
return
cu_seq_lens_q_scale_start = tl.load(cu_seqlens_q_scale + off_z)
cu_seq_lens_k_scale_start = tl.load(cu_seqlens_k_scale + off_z)
q_scale_offset = cu_seq_lens_q_scale_start * H + off_h + start_m * H
k_scale_offset = (
cu_seq_lens_k_scale_start * (H // num_kv_groups) + off_h // num_kv_groups
)
cu_seqlens_k_start = tl.load(cu_seqlens_k + off_z)
cu_seqlens_k_end = tl.load(cu_seqlens_k + off_z + 1)
kv_len = cu_seqlens_k_end - cu_seqlens_k_start
offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, HEAD_DIM)
Q_ptrs = (
Q
+ (cu_seqlens_q_start * stride_qn + off_h * stride_qh)
+ offs_m[:, None] * stride_qn
+ offs_k[None, :]
)
DO_ptrs = (
DO
+ (cu_seqlens_q_start * stride_qn + off_h * stride_qh)
+ offs_m[:, None] * stride_qn
+ offs_k[None, :]
)
Q_scale_ptr = Q_scale + q_scale_offset
K_ptrs = (
K
+ (cu_seqlens_k_start * stride_kn + (off_h // num_kv_groups) * stride_kh)
+ offs_n[None, :] * stride_kn
+ offs_k[:, None]
)
K_scale_ptr = K_scale + k_scale_offset
V_ptrs = (
V
+ (cu_seqlens_k_start * stride_vn + (off_h // num_kv_groups) * stride_vh)
+ offs_n[:, None] * stride_vn
+ offs_k[None, :]
)
DQ_ptrs = (
DQ
+ (cu_seqlens_q_start * stride_qn + off_h * stride_qh)
+ offs_m[:, None] * stride_qn
+ offs_k[None, :]
)
DK_ptrs = (
DK
+ (cu_seqlens_k_start * stride_kn + (off_h // num_kv_groups) * stride_kh)
+ offs_n[None, :] * stride_kn
+ offs_k[:, None]
)
DV_ptrs = (
DV
+ (cu_seqlens_k_start * stride_vn + (off_h // num_kv_groups) * stride_vh)
+ offs_n[:, None] * stride_vn
+ offs_k[None, :]
)
L_ptrs = L + (cu_seqlens_q_start + offs_m)
M_ptrs = M + (cu_seqlens_q_start + offs_m)
m_i = tl.load(M_ptrs, mask=offs_m < qo_len, other=float("-inf"))
l_i = tl.load(L_ptrs, mask=offs_m < qo_len, other=1.0)
dq_acc = tl.zeros([BLOCK_M, HEAD_DIM], dtype=tl.float32)
dk_acc = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32)
dv_acc = tl.zeros([BLOCK_N, HEAD_DIM], dtype=tl.float32)
q = tl.load(Q_ptrs, mask=offs_m[:, None] < qo_len)
do = tl.load(DO_ptrs, mask=offs_m[:, None] < qo_len)
q_scale = tl.load(Q_scale_ptr)
dq_acc, dk_acc, dv_acc, l_i, m_i = _attn_bwd_inner(
dq_acc,
dk_acc,
dv_acc,
l_i,
m_i,
q,
K_ptrs,
V_ptrs,
do,
q_scale,
K_scale_ptr,
kv_len,
stride_kn,
stride_vn,
start_m,
H // num_kv_groups,
BLOCK_M,
HEAD_DIM,
BLOCK_N,
4 - STAGE,
offs_m,
offs_n,
)
dq_acc, dk_acc, dv_acc, l_i, m_i = _attn_bwd_inner(
dq_acc,
dk_acc,
dv_acc,
l_i,
m_i,
q,
K_ptrs,
V_ptrs,
do,
q_scale,
K_scale_ptr,
kv_len,
stride_kn,
stride_vn,
start_m,
H // num_kv_groups,
BLOCK_M,
HEAD_DIM,
BLOCK_N,
2,
offs_m,
offs_n,
)
tl.store(DQ_ptrs, dq_acc.to(DQ.dtype.element_ty), mask=offs_m[:, None] < qo_len)
tl.store(DK_ptrs, dk_acc.to(DK.dtype.element_ty), mask=offs_n[None, :] < kv_len)
tl.store(DV_ptrs, dv_acc.to(DV.dtype.element_ty), mask=offs_n[:, None] < kv_len)
def forward(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
q_scale,
k_scale,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
output_dtype=torch.float16,
):
BLOCK_M = 128
BLOCK_N = 64
stage = 3
o = torch.empty(q.shape, dtype=output_dtype, device=q.device)
b = cu_seqlens_q.shape[0] - 1
_, h_qo, head_dim = q.shape
_, h_kv, _ = k.shape
HEAD_DIM_K = head_dim
num_kv_groups = h_qo // h_kv
grid = (triton.cdiv(max_seqlen_q, BLOCK_M), h_qo, b)
_attn_fwd[grid](
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
q_scale,
k_scale,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
o,
q.stride(1),
q.stride(0),
k.stride(1),
k.stride(0),
v.stride(1),
v.stride(0),
o.stride(1),
o.stride(0),
h_qo,
num_kv_groups,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
HEAD_DIM=HEAD_DIM_K,
STAGE=stage,
num_warps=4 if head_dim == 64 else 8,
num_stages=4,
)
return o
def backward(
do,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
q_scale,
k_scale,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
l,
m,
output_dtype=torch.float16,
):
BLOCK_M = 128
BLOCK_N = 64
stage = 3
device = q.device
dtype = q.dtype
b = cu_seqlens_q.shape[0] - 1
_, h_qo, head_dim = q.shape
_, h_kv, _ = k.shape
num_kv_groups = h_qo // h_kv
dq = torch.zeros_like(q, dtype=output_dtype)
dk = torch.zeros_like(k, dtype=output_dtype)
dv = torch.zeros_like(v, dtype=output_dtype)
grid = (triton.cdiv(max_seqlen_q, BLOCK_M), h_qo, b)
_attn_bwd[grid](
do,
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
q_scale,
k_scale,
cu_seqlens_q_scale,
cu_seqlens_k_scale,
l,
m,
dq,
dk,
dv,
q.stride(1),
q.stride(0),
k.stride(1),
k.stride(0),
v.stride(1),
v.stride(0),
h_qo,
num_kv_groups,
HEAD_DIM=head_dim,
BLOCK_M=BLOCK_M,
BLOCK_N=BLOCK_N,
STAGE=stage,
num_warps=4 if head_dim == 64 else 8,
num_stages=4,
)
return dq, dk, dv
# class TritonAttentionFunction(torch.autograd.Function):
# @staticmethod
# def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale):
# l = torch.zeros(q.shape[0], device=q.device, dtype=torch.float32)
# m = torch.zeros(q.shape[0], device=q.device, dtype=torch.float32)
# output = forward(q, k, v, cu_seqlens_q, cu_seqlens_k, q.shape[0], q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, l, m)
# ctx.save_for_backward(q, k, v, cu_seqlens_q, cu_seqlens_k, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, l, m)
# return output
#
# @staticmethod
# def backward(ctx, do):
# q, k, v, cu_seqlens_q, cu_seqlens_k, q_scale, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale, l, m = ctx.saved_tensors
# dq, dk, dv = backward(
# do, q, k, v,
# cu_seqlens_q, cu_seqlens_k,
# q.shape[0], q_scale, k_scale,
# cu_seqlens_q_scale, cu_seqlens_k_scale,
# l, m,
# )
# return dq, dk, dv, None, None, None, None, None, None

View File

@@ -1,158 +0,0 @@
"""
Copyright (c) 2024 by SageAttention team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import torch
import triton
import triton.language as tl
@triton.jit
def quant_per_block_int8_kernel(
Input,
Output,
Scale,
cu_seqlens_input,
cu_seqlens_scale,
stride_ih,
stride_in,
stride_oh,
stride_on,
sm_scale,
H: tl.constexpr,
C: tl.constexpr,
BLK: tl.constexpr,
):
off_blk = tl.program_id(0)
off_h = tl.program_id(1)
off_b = tl.program_id(2)
cu_seqlens_input_start = tl.load(cu_seqlens_input + off_b)
cu_seqlens_input_end = tl.load(cu_seqlens_input + off_b + 1)
L = cu_seqlens_input_end - cu_seqlens_input_start
if (off_blk * BLK) >= L:
return
cu_seqlens_scale_start = tl.load(cu_seqlens_scale + off_b)
offs_n = off_blk * BLK + tl.arange(0, BLK)
offs_k = tl.arange(0, C)
input_ptrs = (
Input
+ cu_seqlens_input_start * stride_in
+ off_h * stride_ih
+ offs_n[:, None] * stride_in
+ offs_k[None, :]
)
output_ptrs = (
Output
+ cu_seqlens_input_start * stride_on
+ off_h * stride_oh
+ offs_n[:, None] * stride_on
+ offs_k[None, :]
)
scale_ptrs = Scale + cu_seqlens_scale_start * H + off_h + off_blk * H
x = tl.load(input_ptrs, mask=offs_n[:, None] < L)
x = x.to(tl.float32)
x *= sm_scale
scale = tl.max(tl.abs(x)) / 127.0
x_int8 = x / scale
x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)
x_int8 = x_int8.to(tl.int8)
tl.store(output_ptrs, x_int8, mask=offs_n[:, None] < L)
tl.store(scale_ptrs, scale)
def per_block_int8(
q,
k,
cu_seqlens_q,
cu_seqlens_k,
max_seqlen_q,
max_seqlen_k,
BLKQ=128,
BLKK=64,
sm_scale=None,
):
q_int8 = torch.empty(q.shape, dtype=torch.int8, device=q.device)
k_int8 = torch.empty(k.shape, dtype=torch.int8, device=k.device)
h_qo = q.shape[1]
h_kv = k.shape[1]
head_dim = q.shape[-1]
b = cu_seqlens_q.shape[0] - 1
q_batch_len = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
k_batch_len = cu_seqlens_k[1:] - cu_seqlens_k[:-1]
q_scale_len = (q_batch_len + BLKQ - 1) // BLKQ
k_scale_len = (k_batch_len + BLKK - 1) // BLKK
cu_seqlens_q_scale = torch.nn.functional.pad(
torch.cumsum(q_scale_len, dim=0), (1, 0), value=0
)
cu_seqlens_k_scale = torch.nn.functional.pad(
torch.cumsum(k_scale_len, dim=0), (1, 0), value=0
)
q_scale = torch.empty(
(cu_seqlens_q_scale[-1], h_qo), device=q.device, dtype=torch.float32
)
k_scale = torch.empty(
(cu_seqlens_k_scale[-1], h_kv), device=k.device, dtype=torch.float32
)
if sm_scale is None:
sm_scale = head_dim**-0.5
grid = ((max_seqlen_q + BLKQ - 1) // BLKQ, h_qo, b)
quant_per_block_int8_kernel[grid](
q,
q_int8,
q_scale,
cu_seqlens_q,
cu_seqlens_q_scale,
q.stride(1),
q.stride(0),
q_int8.stride(1),
q_int8.stride(0),
sm_scale=(sm_scale * 1.44269504),
H=h_qo,
C=head_dim,
BLK=BLKQ,
)
grid = ((max_seqlen_k + BLKK - 1) // BLKK, h_kv, b)
quant_per_block_int8_kernel[grid](
k,
k_int8,
k_scale,
cu_seqlens_k,
cu_seqlens_k_scale,
k.stride(1),
k.stride(0),
k_int8.stride(1),
k_int8.stride(0),
sm_scale=1.0,
H=h_kv,
C=head_dim,
BLK=BLKK,
)
return q_int8, q_scale, k_int8, k_scale, cu_seqlens_q_scale, cu_seqlens_k_scale

View File

@@ -0,0 +1,83 @@
"""
fix for zero3 8-bit lora
see https://github.com/huggingface/transformers/pull/32943/files
"""
import inspect
import logging
from transformers import modeling_utils
LOG = logging.getLogger("axolotl.monkeypatch.modeling_zero3_int8_lora")
ORIGINAL_LOAD_CODE = """
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
module, tensor_name = get_module_from_name(model, param_name)
value = getattr(module, tensor_name)
param_to = "cpu"
if is_fsdp_enabled() and not is_local_dist_rank_0():
param_to = "meta"
value = type(value)(value.data.to(param_to), **value.__dict__)
setattr(module, tensor_name, value)
"""
PATCHED_LOAD_CODE = """
if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
module, tensor_name = get_module_from_name(model, param_name)
value = getattr(module, tensor_name)
param_to = "cpu"
if is_fsdp_enabled() and not is_local_dist_rank_0():
param_to = "meta"
val_kwargs = {}
if hasattr(module, "weight") and module.weight.__class__.__name__ == "Int8Params":
val_kwargs["requires_grad"] = False
value = type(value)(value.data.to(param_to), **val_kwargs, **value.__dict__)
setattr(module, tensor_name, value)
"""
def get_modeling_state_dict_code() -> str:
load_code = inspect.getsource(
modeling_utils._load_state_dict_into_meta_model # pylint: disable=protected-access
)
return load_code
def check_modeling_state_dict_code_is_patchable() -> bool:
load_code = get_modeling_state_dict_code()
return ORIGINAL_LOAD_CODE in load_code
def patch_modeling_state_dict_code():
"""
monkeypatch for fixing the meta model loader for zero3 8-bit lora
"""
load_code = get_modeling_state_dict_code()
modeling_utils._original_load_state_dict_into_meta_model = ( # pylint: disable=protected-access
load_code
)
assert (
ORIGINAL_LOAD_CODE in load_code
), "Original _load_state_dict_into_meta_model code not found"
load_code = load_code.replace(ORIGINAL_LOAD_CODE, PATCHED_LOAD_CODE)
load_code = load_code.replace(
"def _load_state_dict_into_meta_model(",
"def _fixed_load_state_dict_into_meta_model(",
1,
)
items_to_import = []
for item in dir(modeling_utils):
if item in load_code:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.modeling_utils import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(load_code, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _load_state_dict_into_meta_model")
modeling_utils._load_state_dict_into_meta_model = _fixed_load_state_dict_into_meta_model # pylint: disable=protected-access,undefined-variable # noqa: F821
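A minimal sketch of how this patch could be applied before model load for the ZeRO-3 + 8-bit LoRA case; the import path is an assumption inferred from the file's logger name, and the guard simply reuses the helper defined above:

```python
# Assumed module path, taken from the logger name "axolotl.monkeypatch.modeling_zero3_int8_lora".
from axolotl.monkeypatch.modeling_zero3_int8_lora import (
    check_modeling_state_dict_code_is_patchable,
    patch_modeling_state_dict_code,
)

# Only rewrite transformers' loader when its current source still matches ORIGINAL_LOAD_CODE;
# otherwise the installed transformers version has diverged and the patch would not apply cleanly.
if check_modeling_state_dict_code_is_patchable():
    patch_modeling_state_dict_code()
# From here on, _load_state_dict_into_meta_model passes requires_grad=False when it
# re-materializes bitsandbytes Int8Params under DeepSpeed ZeRO-3 / FSDP.
```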

View File

@@ -0,0 +1,83 @@
"""
fix for FSDP gradient accumulation
see https://github.com/huggingface/transformers/pull/34645
"""
import inspect
from accelerate.logging import get_logger
from transformers.trainer import Trainer
from axolotl.monkeypatch.unsloth_ import detab_code
LOG = get_logger("axolotl.monkeypatch.trainer_fsdp_grad_accumulation")
ORIGINAL_CONTEXT_CODE = """
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i == len(batch_samples) - 1
else contextlib.nullcontext
)
"""
PATCHED_CONTEXT_CODE = """
context = (
functools.partial(self.accelerator.no_sync, model=model)
if i != len(batch_samples) - 1
else contextlib.nullcontext
)
"""
def get_training_loop_code() -> str:
training_loop = inspect.getsource(
Trainer._inner_training_loop # pylint: disable=protected-access
)
return training_loop
def check_training_loop_is_patchable() -> bool:
train_loop = get_training_loop_code()
train_loop, _ = detab_code(train_loop)
return ORIGINAL_CONTEXT_CODE in train_loop
def patch_training_loop_for_fsdp_grad_accum():
"""
monkeypatch for fixing the training loop for FSDP gradient accumulation
"""
train_loop = get_training_loop_code()
Trainer._original_inner_training_loop = ( # pylint: disable=protected-access
train_loop
)
train_loop, _ = detab_code(train_loop)
assert (
ORIGINAL_CONTEXT_CODE in train_loop
), "Original _inner_training_loop code not found"
train_loop = train_loop.replace(ORIGINAL_CONTEXT_CODE, PATCHED_CONTEXT_CODE)
train_loop = train_loop.replace(
"def _inner_training_loop(",
"def _fixed_inner_training_loop(",
1,
)
# load imports necessary
import transformers.trainer
items_to_import = []
for item in dir(transformers.trainer):
if item in train_loop:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from transformers.trainer import ("
+ ", ".join(x for x in items_to_import)
+ ")",
globals(),
)
exec(train_loop, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching _inner_training_loop", main_process_only=True)
Trainer._inner_training_loop = ( # pylint: disable=protected-access
_fixed_inner_training_loop # pylint: disable=undefined-variable # noqa: F821
)
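And a similar sketch for enabling the gradient-accumulation fix; again the import path is an assumption based on the logger name, and the functions are the ones defined in this file:

```python
# Assumed module path, taken from the logger name "axolotl.monkeypatch.trainer_fsdp_grad_accumulation".
from axolotl.monkeypatch.trainer_fsdp_grad_accumulation import (
    check_training_loop_is_patchable,
    patch_training_loop_for_fsdp_grad_accum,
)

# The patch flips the no_sync condition so gradients are synchronized only on the
# last micro-batch of each accumulation window; apply it only if the upstream
# Trainer._inner_training_loop still contains the expected context block.
if check_training_loop_is_patchable():
    patch_training_loop_for_fsdp_grad_accum()
```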

View File

@@ -188,7 +188,7 @@ def integrate_lora_mlp_patch(peft_model: PeftModelForCausalLM):
for module in layer_modules
)
mlp_not_dora = all(
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
getattr(module, "lora_magnitude_vector", None) is None
for module in layer_modules
)
@@ -213,7 +213,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
for module in layer_modules
)
qkv_not_dora = all(
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
getattr(module, "lora_magnitude_vector", None) is None
for module in layer_modules
)
@@ -232,7 +232,7 @@ def integrate_lora_patch(peft_model: PeftModelForCausalLM, cfg):
for module in layer_modules
)
o_not_dora = all(
len(getattr(module, "lora_magnitude_vector", []) or []) == 0
getattr(module, "lora_magnitude_vector", None) is None
for module in layer_modules
)

View File

@@ -4,9 +4,6 @@ import functools
import pynvml
import torch
from pynvml.nvml import NVMLError
from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.distributed import get_device_type
def check_cuda_device(default_value):
@@ -56,12 +53,6 @@ def mps_memory_usage_all():
return usage, reserved - usage, 0
def npu_memory_usage_all(device=0):
usage = torch.npu.memory_allocated(device) / 1024.0**3
reserved = torch.npu.memory_reserved(device) / 1024.0**3
return usage, reserved - usage, 0
@check_cuda_device(0.0)
def gpu_memory_usage_smi(device=0):
if isinstance(device, torch.device):
@@ -78,11 +69,8 @@ def gpu_memory_usage_smi(device=0):
def log_gpu_memory_usage(log, msg, device):
cur_device = get_device_type()
if torch.backends.mps.is_available():
usage, cache, misc = mps_memory_usage_all()
elif "npu" in str(cur_device) and is_torch_npu_available():
usage, cache, misc = npu_memory_usage_all(device)
else:
usage, cache, misc = gpu_memory_usage_all(device)
extras = []
@@ -91,7 +79,6 @@ def log_gpu_memory_usage(log, msg, device):
if misc > 0:
extras.append(f"+{misc:.03f}GB misc")
log.info(
f"{str(cur_device)} memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})",
stacklevel=2,
f"GPU memory usage {msg}: {usage:.03f}GB ({', '.join(extras)})", stacklevel=2
)
return usage, cache, misc

View File

@@ -5,7 +5,6 @@ from typing import Optional
import torch
from transformers.utils import is_torch_bf16_gpu_available
from transformers.utils.import_utils import is_torch_npu_available
from axolotl.integrations.config import merge_input_args
from axolotl.utils.bench import log_gpu_memory_usage
@@ -30,10 +29,7 @@ def choose_device(cfg):
if torch.backends.mps.is_available():
return "mps"
if is_torch_npu_available():
return f"npu:{cfg.local_rank}"
raise SystemError("No CUDA/mps/npu device found")
raise SystemError("No CUDA/mps device found")
except Exception: # pylint: disable=broad-exception-caught
return "cpu"
@@ -43,8 +39,6 @@ def choose_device(cfg):
else:
if cfg.device.startswith("cuda"):
cfg.device_map = {"": torch.cuda.current_device()}
elif cfg.device.startswith("npu"):
cfg.device_map = {"npu": torch.npu.current_device()}
else:
cfg.device_map = {"": cfg.device}

View File

@@ -7,6 +7,7 @@ Module for pydantic models for configuration
import logging
import os
from enum import Enum
from importlib.metadata import version
from typing import Annotated, Any, Dict, List, Literal, Optional, Tuple, Union
from pydantic import (
@@ -19,7 +20,6 @@ from pydantic import (
)
from transformers import SchedulerType
from transformers.training_args import OptimizerNames
from transformers.utils.import_utils import is_torch_npu_available
from axolotl.utils.config.models.internals import GPUCapabilities
@@ -250,10 +250,8 @@ class KTODataset(BaseModel):
class LoftQConfig(BaseModel):
"""LoftQ configuration subset"""
loftq_bits: int = Field(
default=4, json_schema_extra={"description": "Quantization bits for LoftQ"}
)
# loftq_iter: int = Field(default=1, json_schema_extra={"description": "Alternating iterations for LoftQ"})
loftq_bits: int = Field(default=4, metadata={"help": "Quantization bits for LoftQ"})
# loftq_iter: int = Field(default=1, metadata={"help": "Alternating iterations for LoftQ"})
class PeftConfig(BaseModel):
@@ -296,8 +294,8 @@ class LoraConfig(BaseModel):
qlora_sharded_model_loading: Optional[bool] = Field(
default=False,
json_schema_extra={
"description": "load qlora model in sharded format for FSDP using answer.ai technique."
metadata={
"help": "load qlora model in sharded format for FSDP using answer.ai technique."
},
)
lora_on_cpu: Optional[bool] = None
@@ -306,15 +304,13 @@ class LoraConfig(BaseModel):
loraplus_lr_ratio: Optional[float] = Field(
default=None,
json_schema_extra={
"description": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
metadata={
"help": "loraplus learning rate ratio lr_B / lr_A. Recommended value is 2^4."
},
)
loraplus_lr_embedding: Optional[float] = Field(
default=1e-6,
json_schema_extra={
"description": "loraplus learning rate for lora embedding layers."
},
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
merge_lora: Optional[bool] = None
@@ -384,10 +380,10 @@ class ModelInputConfig(BaseModel):
tokenizer_use_fast: Optional[bool] = None
tokenizer_legacy: Optional[bool] = None
tokenizer_type: Optional[str] = Field(
default=None, json_schema_extra={"description": "transformers tokenizer class"}
default=None, metadata={"help": "transformers tokenizer class"}
)
processor_type: Optional[str] = Field(
default=None, json_schema_extra={"description": "transformers processor class"}
default=None, metadata={"help": "transformers processor class"}
)
trust_remote_code: Optional[bool] = None
@@ -409,18 +405,18 @@ class HyperparametersConfig(BaseModel):
gradient_accumulation_steps: Optional[int] = Field(default=1)
micro_batch_size: Optional[int] = Field(
default=1,
json_schema_extra={"description": "per gpu micro batch size for training"},
metadata={"help": "per gpu micro batch size for training"},
)
batch_size: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "Total batch size, we do not recommended setting this manually"
metadata={
"help": "Total batch size, we do not recommended setting this manually"
},
)
eval_batch_size: Optional[int] = Field(
default=None,
json_schema_extra={
"description": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
metadata={
"help": "per gpu micro batch size for evals, defaults to value of micro_batch_size"
},
)
@@ -445,13 +441,12 @@ class HyperparametersConfig(BaseModel):
]
] = OptimizerNames.ADAMW_HF.value
optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
default=None,
json_schema_extra={"description": "Optional arguments to supply to optimizer."},
default=None, metadata={"help": "Optional arguments to supply to optimizer."}
)
optim_target_modules: Optional[Union[List[str], Literal["all_linear"]]] = Field(
default=None,
json_schema_extra={
"description": "The target modules to optimize, i.e. the module names that you would like to train."
metadata={
"help": "The target modules to optimize, i.e. the module names that you would like to train."
},
)
torchdistx_path: Optional[str] = None
@@ -511,15 +506,15 @@ class LISAConfig(BaseModel):
lisa_n_layers: Optional[int] = Field(
default=None,
json_schema_extra={"description": "the number of activate layers in LISA"},
metadata={"help": "the number of activate layers in LISA"},
)
lisa_step_interval: Optional[int] = Field(
default=None,
json_schema_extra={"description": "how often to switch layers in LISA"},
metadata={"help": "how often to switch layers in LISA"},
)
lisa_layers_attribute: Optional[str] = Field(
default="model.layers",
json_schema_extra={"description": "path under the model to access the layers"},
metadata={"help": "path under the model to access the layers"},
)
@@ -618,8 +613,7 @@ class AxolotlInputConfig(
pretraining_dataset: Optional[ # type: ignore
conlist(Union[PretrainingDataset, SFTDataset], min_length=1)
] = Field(
default=None,
json_schema_extra={"description": "streaming dataset to use for pretraining"},
default=None, metadata={"help": {"streaming dataset to use for pretraining"}}
)
dataset_processes: Optional[int] = Field(default=os.cpu_count())
dataset_keep_in_memory: Optional[bool] = None
@@ -679,8 +673,7 @@ class AxolotlInputConfig(
sequence_len: int = Field(default=512)
min_sample_len: Optional[int] = None
max_prompt_len: int = Field(
default=512,
json_schema_extra={"description": "maximum prompt length for RL training"},
default=512, metadata={"help": "maximum prompt length for RL training"}
)
sample_packing: Optional[bool] = None
sample_packing_group_size: Optional[int] = 100_000
@@ -699,8 +692,8 @@ class AxolotlInputConfig(
pretrain_multipack_buffer_size: Optional[int] = 10_000
pretrain_multipack_attn: Optional[bool] = Field(
default=True,
json_schema_extra={
"description": "whether to prevent cross attention for packed sequences during pretraining",
metadata={
"help": "whether to prevent cross attention for packed sequences during pretraining",
},
)
@@ -1314,7 +1307,6 @@ class AxolotlInputConfig(
and data.get("gradient_checkpointing_kwargs", {})
and data.get("gradient_checkpointing_kwargs", {}).get("use_reentrant")
is False
and data.get("deepspeed", "") is not None
and "zero3" in data.get("deepspeed", "")
):
# may result in:
@@ -1426,6 +1418,21 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_unsloth_xformers_version(cls, data):
if (
data.get("unsloth_lora_mlp")
or data.get("unsloth_lora_qkv")
or data.get("unsloth_lora_o")
):
xformers_version = version("xformers")
if xformers_version == "0.0.27":
raise ValueError(
"xformers version 0.0.27 is not supported with unsloth. Please downgrade to 0.0.26.post1"
)
return data
@model_validator(mode="before")
@classmethod
def check_torch_compile_deepspeed(cls, data):
@@ -1435,40 +1442,6 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_npu_config(cls, data):
if is_torch_npu_available():
# check attention config
attn_list = ["flash_attention", "sdp_attention", "s2_attention"]
for attn in attn_list:
if data.get(attn):
raise NotImplementedError(
f"{attn} is currently not supported in Ascend npu, please disable this configuration."
)
# check quant config
if data.get("optimizer") is not None and "bit" in data.get("optimizer"):
optimizer = data.get("optimizer")
raise NotImplementedError(
f"{optimizer} is currently not supported in Ascend npu, choose another one please."
)
quant_list = ["load_in_8bit", "load_in_4bit"]
for quant in quant_list:
if data.get(quant):
raise NotImplementedError(
f"Quantification is currently not supported in Ascend npu, please disable {quant}."
)
# check dtype config
if data.get("tf32"):
raise NotImplementedError(
"tf32 dtype is currently not supported in Ascend npu, please disable this configuration"
)
return data
class AxolotlConfigWCapabilities(AxolotlInputConfig):
"""wrapper to valdiate gpu capabilities with the configured options"""

View File

@@ -64,57 +64,15 @@ def map_dataset(cfg, data_set, ds_transform_fn, tokenizer):
tokenizer = load_tokenizer(cfg)
ds_transform_fn = partial(ds_transform_fn, tokenizer=tokenizer)
if isinstance(data_set, DatasetDict):
data_set = data_set["train"]
data_set = data_set.map(
ds_transform_fn,
desc="Mapping RL Dataset",
)
if isinstance(data_set, DatasetDict):
data_set = data_set["train"]
return data_set
def drop_long_rl_seq(
sample, rl, tokenizer, sequence_len # pylint: disable=invalid-name
):
if rl in ("dpo", "ipo", "orpo", "simpo"):
if not (
sample.get("prompt") and sample.get("chosen") and sample.get("rejected")
):
raise ValueError(
"Prompt, chosen and rejected keys are required for DPO/ORPO datasets"
)
prompt = sample["prompt"]
chosen = sample["chosen"]
rejected = sample["rejected"]
len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
len_chosen = len(tokenizer(chosen, add_special_tokens=False)["input_ids"])
len_rejected = len(tokenizer(rejected, add_special_tokens=False)["input_ids"])
return (len_prompt + len_chosen) <= sequence_len and (
len_prompt + len_rejected
) <= sequence_len
if rl == "kto":
if not (sample.get("prompt") and sample.get("completion")):
raise ValueError("Prompt and completion keys are required for KTO datasets")
prompt = sample["prompt"]
completion = sample["completion"]
len_prompt = len(tokenizer(prompt, add_special_tokens=False)["input_ids"])
len_completion = len(
tokenizer(completion, add_special_tokens=False)["input_ids"]
)
return (len_prompt + len_completion) <= sequence_len
raise ValueError("Unknown RL type")
def load_prepare_dpo_datasets(cfg):
def load_split(dataset_cfgs, _cfg):
split_datasets: List[Any] = []
@@ -136,7 +94,7 @@ def load_prepare_dpo_datasets(cfg):
)
split_datasets.insert(i, ds)
tokenizer = load_tokenizer(cfg)
tokenizer = None
for i, data_set in enumerate(split_datasets):
_type = dataset_cfgs[i]["type"]
@@ -163,28 +121,7 @@ def load_prepare_dpo_datasets(cfg):
# "prompt", "chosen" and "rejected" already preprocessed
split_datasets[i] = data_set
drop_long = partial(
drop_long_rl_seq,
rl=_cfg.rl,
tokenizer=tokenizer,
sequence_len=cfg.sequence_len,
)
prior_len = len(split_datasets[i])
split_datasets[i] = split_datasets[i].filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(split_datasets[i])
if dropped:
LOG.warning(f"Dropped {dropped} long samples from dataset index {i}")
combined_datasets = concatenate_datasets(split_datasets)
combined_datasets = combined_datasets.shuffle(seed=cfg.seed)
return combined_datasets
return concatenate_datasets(split_datasets)
with zero_first(is_main_process()):
train_is_preprocessed = False

View File

@@ -9,44 +9,10 @@ from datetime import timedelta
import torch
import torch.distributed as dist
from accelerate import PartialState
from transformers.utils.import_utils import (
is_torch_cuda_available,
is_torch_mps_available,
is_torch_npu_available,
)
distributed_state = None # pylint: disable=invalid-name
def get_device_type():
device = torch.device("cpu")
if is_torch_cuda_available():
device = torch.device("cuda")
elif is_torch_mps_available():
device = torch.device("mps")
elif is_torch_npu_available():
device = torch.device("npu")
return device
def get_device_count():
cur_device = get_device_type()
if "cuda" in str(cur_device):
return torch.cuda.device_count()
if "npu" in str(cur_device):
return torch.npu.device_count()
return 1
def get_current_device():
cur_device = get_device_type()
if "cuda" in str(cur_device):
return torch.cuda.current_device()
if "npu" in str(cur_device):
return torch.npu.current_device()
return 0
def is_distributed():
"""
Check if distributed training is initialized.
@@ -125,7 +91,7 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
if not is_distributed():
return [value_scalar]
value_tensor = torch.tensor(
value_scalar, device=f"{get_device_type()}:{get_current_device()}"
value_scalar, device=torch.cuda.current_device()
).float()
if not is_main_process():
@@ -149,14 +115,13 @@ def broadcast_dict(vals: dict):
if not is_distributed():
return vals
cur_device = get_device_type()
if is_main_process():
data_byte = pickle.dumps(vals)
data_tensor = torch.ByteTensor(list(data_byte)).to(cur_device)
data_size = torch.IntTensor([len(data_byte)]).to(cur_device)
data_tensor = torch.ByteTensor(list(data_byte)).to("cuda")
data_size = torch.IntTensor([len(data_byte)]).to("cuda")
else:
data_tensor = torch.empty([1024], dtype=torch.uint8, device=cur_device)
data_size = torch.IntTensor([0]).to(cur_device)
data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda")
data_size = torch.IntTensor([0]).to("cuda")
dist.broadcast(data_size, 0)
if not is_main_process():
@@ -185,15 +150,14 @@ def compute_and_broadcast(fn): # pylint: disable=invalid-name
Returns:
- The computed value (int or float).
"""
cur_device = f"{get_device_type()}:{get_current_device()}"
if is_main_process():
value_scalar = fn()
value_tensor = torch.tensor(
value_scalar, device=cur_device, dtype=torch.float32
value_scalar, device=torch.cuda.current_device(), dtype=torch.float32
)
else:
value_tensor = torch.tensor(
0.0, device=cur_device, dtype=torch.float32
0.0, device=torch.cuda.current_device(), dtype=torch.float32
) # Placeholder tensor
# Broadcast the tensor to all processes.
@@ -220,7 +184,7 @@ def gather_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-name
"""
value_scalar = fn()
value_tensor = torch.tensor(
value_scalar, device=f"{get_device_type()}:{get_current_device()}"
value_scalar, device=torch.cuda.current_device()
).float()
# Placeholder tensor for gathering results

View File

@@ -46,7 +46,6 @@ from transformers.integrations.deepspeed import (
)
from axolotl.common.architectures import MOE_ARCH_BLOCK
from axolotl.integrations.sageattention.lib.core import monkeypatch_sdp_w_sage_attention
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
@@ -56,7 +55,7 @@ from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import get_device_count, get_device_type, zero_only
from axolotl.utils.distributed import zero_only
from axolotl.utils.gradient_checkpointing import hf_grad_checkpoint_unsloth_wrapper
from axolotl.utils.lora_embeddings import get_linear_embedding_layers
from axolotl.utils.model_shard_quant import load_sharded_model, load_sharded_model_quant
@@ -571,8 +570,7 @@ class ModelLoader:
)
max_memory = {}
num_device = get_device_count()
for i in range(num_device):
for i in range(torch.cuda.device_count()):
max_memory[i] = gpu_memory_limit
max_memory["cpu"] = "256GiB" # something sufficiently large to fit anything
@@ -597,11 +595,8 @@ class ModelLoader:
self.model_kwargs["device_map"] = device_map
self.model_kwargs["torch_dtype"] = self.cfg.torch_dtype
cur_device = get_device_type()
if "mps" in str(cur_device):
if torch.backends.mps.is_available():
self.model_kwargs["device_map"] = "mps:0"
elif "npu" in str(cur_device):
self.model_kwargs["device_map"] = "npu:0"
# TODO can we put the reference model on its own gpu? I think we have to move logits around to calculate loss
# if cfg.rl:
@@ -708,7 +703,6 @@ class ModelLoader:
self.model_config._attn_implementation = ( # pylint: disable=protected-access
"sdpa"
)
monkeypatch_sdp_w_sage_attention()
elif self.cfg.eager_attention:
self.model_kwargs["attn_implementation"] = "eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access
@@ -1056,11 +1050,7 @@ class ModelLoader:
self.ajust_model_config()
# log device memory usage
if hasattr(self.model, "device") and self.model.device.type in (
"cuda",
"mps",
"npu",
):
if hasattr(self.model, "device") and self.model.device.type in ("cuda", "mps"):
log_gpu_memory_usage(LOG, "after model load", self.model.device)
# make sure these are fp32 per Ramesh et al. (2021)
@@ -1128,9 +1118,9 @@ class ModelLoader:
and not skip_move_to_device
):
# TODO revalidate this conditional
self.model.to(f"{str(get_device_type())}:{self.cfg.local_rank}")
self.model.to(f"cuda:{self.cfg.local_rank}")
if get_device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
setattr(self.model, "is_parallelizable", True)
setattr(self.model, "model_parallel", True)

View File

@@ -66,47 +66,28 @@ def process_tokens_for_rl_debug(tokens, color, tokenizer, text_only):
def check_rl_example_labels(example, tokenizer, text_only=False):
field_prompt, field_chosen, field_rejected, field_completion = (
"prompt",
"chosen",
"rejected",
"completion",
)
field_prompt, field_chosen, field_rejected = "prompt", "chosen", "rejected"
input_tokens = example[field_prompt]
labels_chosen = example.get(field_chosen)
labels_rejected = example.get(field_rejected)
labels_completion = example.get(field_completion)
# Create a delimiter based on text_only flag
delimiter = "" if text_only else " "
labels_chosen, labels_rejected = example[field_chosen], example[field_rejected]
# Process and color each type of token
colored_tokens = process_tokens_for_rl_debug(
input_tokens, "yellow", tokenizer, text_only
)
colored_chosens = process_tokens_for_rl_debug(
labels_chosen, "green", tokenizer, text_only
)
colored_rejecteds = process_tokens_for_rl_debug(
labels_rejected, "red", tokenizer, text_only
)
# Process tokens
if labels_completion is None:
colored_chosens = process_tokens_for_rl_debug(
labels_chosen, "green", tokenizer, text_only
)
colored_rejecteds = process_tokens_for_rl_debug(
labels_rejected, "red", tokenizer, text_only
)
else:
colored_completion = process_tokens_for_rl_debug(
labels_completion, "green", tokenizer, text_only
)
# Create a delimiter based on text_only flag
delimiter = "" if text_only else " "
# Logging information
LOG.info(f"INPUT PROMPT: {delimiter.join(colored_tokens)}\n\n")
if labels_completion is None:
LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
else:
LOG.info(f"COMPLETION RESPONSE: {delimiter.join(colored_completion)}\n\n\n")
LOG.info(f"CHOSEN RESPONSE: {delimiter.join(colored_chosens)}\n\n")
LOG.info(f"REJECTED RESPONSE: {delimiter.join(colored_rejecteds)}\n\n\n")
return delimiter.join(colored_tokens)

View File

@@ -16,6 +16,9 @@ from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.monkeypatch.trainer_fsdp_grad_accum import (
patch_training_loop_for_fsdp_grad_accum,
)
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -203,59 +206,37 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns("token_type_ids")
prior_len = len(train_dataset)
train_dataset = train_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(train_dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from train dataset")
if eval_dataset:
prior_len = len(eval_dataset)
eval_dataset = eval_dataset.filter(
drop_long,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Dropping Long Sequences",
)
dropped = prior_len - len(eval_dataset)
if dropped:
LOG.warning(f"Dropped {dropped} long samples from eval dataset")
# drop samples where the number of elements with labels not equal to -100 is zero
def drop_no_trainable_tokens(sample):
return np.sum(np.array(sample["labels"]) != -100) > 0
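# (-100 is the ignore_index used by the PyTorch/Hugging Face cross-entropy loss, so a
# sample whose labels are all -100 contributes no gradient signal and can be dropped.)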
prior_len = len(train_dataset)
train_dataset = train_dataset.filter(
drop_no_trainable_tokens,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Drop Samples with Zero Trainable Tokens",
)
dropped = prior_len - len(train_dataset)
if dropped:
LOG.warning(
f"Dropped {dropped} samples with no trainable tokens from train dataset"
)
if eval_dataset:
prior_len = len(eval_dataset)
eval_dataset = eval_dataset.filter(
drop_no_trainable_tokens,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Drop Samples with Zero Trainable Tokens",
)
dropped = prior_len - len(eval_dataset)
if dropped:
LOG.warning(
f"Dropped {dropped} samples with no trainable tokens from eval dataset"
)
if cfg.group_by_length:
train_dataset = train_dataset.map(
@@ -456,6 +437,15 @@ def setup_deepspeed_env(cfg, stage=None):
os.environ["ACCELERATE_DEEPSPEED_ZERO_STAGE"] = str(stage)
if stage == 3:
os.environ["ACCELERATE_DEEPSPEED_ZERO3_INIT"] = "true"
if cfg.adapter and cfg.load_in_8bit:
from axolotl.monkeypatch.modeling_zero3_int8_lora import (
patch_modeling_state_dict_code,
)
try:
patch_modeling_state_dict_code()
except AssertionError:
LOG.warning("Failed to patch the meta model loading code")
# If we don't assign this, it doesn't actually get set in the accelerate weakref
_ = HfTrainerDeepSpeedConfig(cfg.deepspeed)
@@ -515,7 +505,12 @@ def prepare_opinionated_env(cfg):
def setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, processor, total_num_steps
):
if cfg.rl in ("dpo", "ipo", "orpo", "kto", "simpo"):
if cfg.fsdp:
try:
patch_training_loop_for_fsdp_grad_accum()
except AssertionError:
pass
if cfg.rl in ["dpo", "ipo", "orpo", "kto", "simpo"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer, processor)
trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2]

View File

@@ -601,3 +601,61 @@ class TestMultiGPULlama:
str(Path(temp_dir) / "config.yaml"),
]
)
def test_8bit_lora_ds_zero3(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"load_in_8bit": True,
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 2048,
"sample_packing": True,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "tatsu-lab/alpaca",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"flash_attention": True,
"deepspeed": "deepspeed_configs/zero3_bf16_cpuoffload_all.json",
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)

View File

@@ -0,0 +1,15 @@
"""Test module for checking whether the integration of Unsloth with Hugging Face Transformers is working as expected."""
import unittest
from axolotl.monkeypatch.trainer_fsdp_grad_accum import check_training_loop_is_patchable
class TestTrainerFSDPIntegration(unittest.TestCase):
"""Unsloth monkeypatch integration tests."""
def test_train_loop_patchable(self):
# ensures the current version of transformers has training loop code that matches our patching code
self.assertTrue(
check_training_loop_is_patchable(),
"HF transformers _inner_training_loop has changed and isn't patchable",
)

View File

@@ -32,19 +32,16 @@ class TestCosineConstantLr(unittest.TestCase):
def test_schedulers(self):
self.assertEqual(self.lr_scheduler.get_last_lr()[0], 0)
for _ in range(self.warmup_steps):
self.optimizer.step()
self.lr_scheduler.step()
self.assertEqual(self.lr_scheduler.get_last_lr()[0], self._lr)
constant_step = int(self.train_steps * self.constant_lr_ratio)
remaining_step = self.train_steps - constant_step
for _ in range(constant_step):
self.optimizer.step()
self.lr_scheduler.step()
self.assertEqual(
self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio
)
for _ in range(remaining_step):
self.optimizer.step()
self.lr_scheduler.step()
self.assertEqual(
self.lr_scheduler.get_last_lr()[0], self._lr * self.min_lr_ratio

View File

@@ -68,53 +68,6 @@ class TestValidation(BaseValidation):
assert cfg.train_on_inputs is False
assert cfg.weight_decay is None
def test_zero3_qlora_use_reentrant_false(self, minimal_cfg):
test_cfg = DictDefault(
{
"deepspeed": "deepspeed_configs/zero3_bf16.json",
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": False},
"load_in_4bit": True,
"adapter": "qlora",
}
| minimal_cfg
)
with self._caplog.at_level(logging.WARNING):
validate_config(test_cfg)
assert (
"qlora + zero3 with use_reentrant: false may result in a CheckpointError about recomputed values"
in self._caplog.records[0].message
)
def test_deepspeed_empty(self, minimal_cfg):
test_cfg = DictDefault(
{
"deepspeed": "",
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": False},
"load_in_4bit": True,
"adapter": "qlora",
}
| minimal_cfg
)
_ = validate_config(test_cfg)
def test_deepspeed_not_set(self, minimal_cfg):
test_cfg = DictDefault(
{
"deepspeed": None,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": False},
"load_in_4bit": True,
"adapter": "qlora",
}
| minimal_cfg
)
_ = validate_config(test_cfg)
def test_datasets_min_length(self):
cfg = DictDefault(
{