Compare commits


25 Commits

Author SHA1 Message Date
Wing Lian
d46d7dfe30 wip 2024-02-01 00:28:16 -05:00
Wing Lian
047d9e1d5b helper utils 2024-01-31 12:49:29 -05:00
Wing Lian
88a0c05d2c wip 2024-01-31 12:07:39 -05:00
Wing Lian
8f2b591baf set torch version to what is installed during axolotl install (#1234) 2024-01-31 08:47:34 -05:00
DreamGenX
5787e1a23f Fix and document test_datasets (#1228)
* Make sure test_dataset are used and treat val_set_size.

* Add test_datasets docs.

* Apply suggestions from code review

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-01-31 06:48:57 -05:00
xhedit
8608d8003e Fix typo (#1231) [skip ci] 2024-01-31 06:46:55 -05:00
Wing Lian
4cb7900a56 Peft lotfq (#1222)
* loftq support for lora

* fix loftq check

* update readme for loftq

* readability cleanup

* use peft main for loftq fixes, remove unnecessary special tokens

* remove unused test from older deprecation
2024-01-28 18:50:08 -05:00
Filippo Broggini
18f811978c FEAT: add tagging support to axolotl for DPOTrainer (#1209)
* Add AxolotlDPOTrainer

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-01-26 20:01:57 -05:00
Wing Lian
afb5dd9655 Update FUNDING.yml [skip ci] 2024-01-26 20:00:28 -05:00
Wing Lian
8da1633124 Revert "run PR e2e docker CI tests in Modal" (#1220) [skip ci] 2024-01-26 16:50:44 -05:00
Wing Lian
36d053f6f0 run PR e2e docker CI tests in Modal (#1217) [skip ci]
* wip modal for ci

* handle falcon layernorms better

* update

* rebuild the template each time with the pseudo-ARGS

* fix ref

* update tests to use modal

* cleanup ci script

* make sure to install jinja2 also

* kickoff the gh action on gh hosted runners and specify num gpus
2024-01-26 16:13:27 -05:00
JohanWork
af29d81f80 ADD: warning if hub_model_id ist set but not any save strategy (#1202)
* warning if hub model id set but no save

* add warning

* move the warning

* add test

* allow more public methods for tests for now

* fix tests

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-01-26 10:38:55 -05:00
Wing Lian
1b180034c7 ensure the tests use the same version of torch as the latest base docker images (#1215) [skip ci] 2024-01-26 10:38:30 -05:00
DreamGenX
62ca4a2b71 Respect sliding_window=None (#1214) 2024-01-26 07:43:37 -05:00
Igor Berlenko
5407ddd233 Update qlora.yml - remove max_packed_sequence_len (#1210) [skip ci] 2024-01-26 07:43:05 -05:00
Wing Lian
74c72ca5eb drop py39 docker images, add py311, upgrade pytorch to 2.1.2 (#1205)
* drop py39 docker images, add py311, upgrade pytorch to 2.1.2

* also allow the main build to be manually triggered

* fix workflow_dispatch in yaml
2024-01-26 00:38:49 -05:00
Wing Lian
e923e62d24 more checks and fixes for deepspeed and fsdp (#1208) [skip ci] 2024-01-25 20:01:45 -05:00
Wing Lian
ba944e6554 workaround for transformers bug requireing do_sample for saveing pretrained (#1206) 2024-01-25 11:34:41 -05:00
Wing Lian
badda3783b make sure to register the base chatml template even if no system message is provided (#1207) 2024-01-25 10:38:08 -05:00
Wing Lian
a01b998c0f Update deps 202401 (#1204) [skip ci]
* update deps

* xformers fix too
2024-01-25 10:11:49 -05:00
Wing Lian
33e117088f precompute dpo logprobs setting and fixes (#1199) [skip ci]
* add support for precompute_ref_log_probs for dpo

* add chatml.icr type for argilla orca dpo

* update inline doc

* also set use_reentrant to false for dpo when not set

* don't set use_reentrant to true for rl

* make sure to set gradient checkpointing too
2024-01-25 09:31:55 -05:00
Ricardo Dominguez-Olmedo
b4ac96adef fix learning rate scheduler's warnings (#1135) [skip ci]
* fix schedulers warnings

* chore: lint

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-01-25 07:09:34 -05:00
mhenrichsen
98b4762077 Feat/chatml add system message (#1117)
* add system message to template

* readme update

* added code to register new system message

* register chatml template for test

---------

Co-authored-by: Mads Henrichsen <mads@BrbartiendeMads.lan>
Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-01-25 08:24:27 +01:00
JohanWork
ee0b5f60e5 add colab example (#1196) [skip ci] 2024-01-24 20:09:09 -05:00
NanoCode012
08719b9609 fix(log): improve warning to clarify that lora_modules_to_save expect a list (#1197) 2024-01-24 20:08:34 -05:00
32 changed files with 763 additions and 250 deletions

.github/FUNDING.yml vendored (2 lines changed)
View File

@@ -1,6 +1,6 @@
 # These are supported funding model platforms
-github: OpenAccess-AI-Collective # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
+github: [winglian, OpenAccess-AI-Collective] # Replace with up to 4 GitHub Sponsors-enabled usernames e.g., [user1, user2]
 patreon: # Replace with a single Patreon username
 open_collective: # Replace with a single Open Collective username
 ko_fi: axolotl_ai # Replace with a single Ko-fi username

View File

@@ -1,10 +1,7 @@
 name: ci-cd-base
 on:
-  push:
-    branches:
-      - "main-base"
-      - "dev-base"
+  workflow_dispatch:
 jobs:
   build-base:
@@ -15,11 +12,6 @@ jobs:
       fail-fast: false
       matrix:
         include:
-          - cuda: "118"
-            cuda_version: 11.8.0
-            python_version: "3.9"
-            pytorch: 2.0.1
-            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
           - cuda: "118"
             cuda_version: 11.8.0
             python_version: "3.10"
@@ -28,12 +20,17 @@ jobs:
- cuda: "118" - cuda: "118"
cuda_version: 11.8.0 cuda_version: 11.8.0
python_version: "3.10" python_version: "3.10"
pytorch: 2.1.1 pytorch: 2.1.2
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
- cuda: "121" - cuda: "121"
cuda_version: 12.1.0 cuda_version: 12.1.0
python_version: "3.10" python_version: "3.10"
pytorch: 2.1.1 pytorch: 2.1.2
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
- cuda: "121"
cuda_version: 12.1.0
python_version: "3.11"
pytorch: 2.1.2
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
steps: steps:
- name: Checkout - name: Checkout
@@ -56,7 +53,7 @@ jobs:
           context: .
           file: ./docker/Dockerfile-base
           push: ${{ github.event_name != 'pull_request' }}
-          tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
+          tags: ${{ steps.metadata.outputs.tags }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
           labels: ${{ steps.metadata.outputs.labels }}
           build-args: |
             CUDA_VERSION=${{ matrix.cuda_version }}

View File

@@ -4,6 +4,7 @@ on:
   push:
     branches:
       - "main"
+  workflow_dispatch:
 jobs:
   build-axolotl:
@@ -15,24 +16,24 @@ jobs:
         include:
           - cuda: 118
             cuda_version: 11.8.0
-            python_version: "3.9"
+            python_version: "3.10"
             pytorch: 2.0.1
             axolotl_extras:
           - cuda: 118
             cuda_version: 11.8.0
             python_version: "3.10"
-            pytorch: 2.0.1
+            pytorch: 2.1.2
             axolotl_extras:
             is_latest: true
-          - cuda: 118
-            cuda_version: 11.8.0
-            python_version: "3.10"
-            pytorch: 2.1.1
-            axolotl_extras:
           - cuda: 121
             cuda_version: 12.1.0
             python_version: "3.10"
-            pytorch: 2.1.1
+            pytorch: 2.1.2
+            axolotl_extras:
+          - cuda: 121
+            cuda_version: 12.1.0
+            python_version: "3.11"
+            pytorch: 2.1.2
             axolotl_extras:
     runs-on: [self-hosted, gpu, docker]
     steps:
@@ -86,24 +87,24 @@ jobs:
         include:
           - cuda: 118
             cuda_version: 11.8.0
-            python_version: "3.9"
+            python_version: "3.10"
             pytorch: 2.0.1
             axolotl_extras:
           - cuda: 118
             cuda_version: 11.8.0
             python_version: "3.10"
-            pytorch: 2.0.1
+            pytorch: 2.1.2
             axolotl_extras:
             is_latest: true
-          - cuda: 118
-            cuda_version: 11.8.0
-            python_version: "3.10"
-            pytorch: 2.1.1
-            axolotl_extras:
           - cuda: 121
             cuda_version: 12.1.0
             python_version: "3.10"
-            pytorch: 2.1.1
+            pytorch: 2.1.2
+            axolotl_extras:
+          - cuda: 121
+            cuda_version: 12.1.0
+            python_version: "3.11"
+            pytorch: 2.1.2
             axolotl_extras:
     runs-on: [self-hosted, gpu, docker]
     steps:

View File

@@ -73,7 +73,7 @@ jobs:
           - cuda: 121
             cuda_version: 12.1.0
             python_version: "3.10"
-            pytorch: 2.1.1
+            pytorch: 2.1.2
     steps:
       - name: Checkout
         uses: actions/checkout@v4
@@ -106,3 +106,7 @@ jobs:
       - name: GPU Unit Tests monkeypatched w docker image
         run: |
           docker run --privileged --gpus "all" --env WANDB_DISABLED=true --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} pytest /workspace/axolotl/tests/e2e/patched/
+      - name: Prune image from docker
+        if: github.ref != 'refs/heads/main'
+        run: |
+          docker rmi -f ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}

View File

@@ -607,12 +607,25 @@ datasets:
 # For `completion` datsets only, uses the provided field instead of `text` column
 field:
+# A list of one or more datasets to eval the model with.
+# You can use either test_datasets, or val_set_size, but not both.
+test_datasets:
+  - path: /workspace/data/eval.jsonl
+    ds_type: json
+    # You need to specify a split. For "json" datasets the default split is called "train".
+    split: train
+    type: completion
+    data_files:
+      - /workspace/data/eval.jsonl
+
 # use RL training: dpo, ipo, kto_pair
 rl:
 # Saves the desired chat template to the tokenizer_config.json for easier inferencing
 # Currently supports chatml and inst (mistral/mixtral)
 chat_template: chatml
+# Changes the default system message
+default_system_message: You are a helpful assistant. Please give a long and detailed answer. # Currently only supports chatml.
 # Axolotl attempts to save the dataset as an arrow after packing the data together so
 # subsequent training attempts load faster, relative path
 dataset_prepared_path: data/last_run_prepared
@@ -694,6 +707,12 @@ lora_modules_to_save:
 lora_fan_in_fan_out: false
+
+peft:
+  # Configuration options for loftq initialization for LoRA
+  # https://huggingface.co/docs/peft/developer_guides/quantization#loftq-initialization
+  loftq_config:
+    loftq_bits: # typically 4 bits
 # ReLoRA configuration
 # Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
 relora_steps: # Number of steps per ReLoRA restart
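The three options documented in the hunks above can be combined in one config. A hypothetical fragment follows; the dataset path and the bit width are illustrative assumptions, not values taken from this compare view:

# evaluate on a held-out file instead of carving out val_set_size
test_datasets:
  - path: /workspace/data/eval.jsonl   # assumed path
    ds_type: json
    split: train          # "json" datasets default to a split named "train"
    type: completion
chat_template: chatml
default_system_message: You are a helpful assistant.
peft:
  loftq_config:
    loftq_bits: 4         # LoftQ initialization typically uses 4 bits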

View File

@@ -15,15 +15,6 @@
"hysteresis": 2, "hysteresis": 2,
"min_loss_scale": 1 "min_loss_scale": 1
}, },
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"gradient_accumulation_steps": "auto", "gradient_accumulation_steps": "auto",
"train_batch_size": "auto", "train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto", "train_micro_batch_size_per_gpu": "auto",

View File

@@ -19,15 +19,6 @@
"hysteresis": 2, "hysteresis": 2,
"min_loss_scale": 1 "min_loss_scale": 1
}, },
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"gradient_accumulation_steps": "auto", "gradient_accumulation_steps": "auto",
"train_batch_size": "auto", "train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto", "train_micro_batch_size_per_gpu": "auto",

View File

@@ -23,15 +23,6 @@
"hysteresis": 2, "hysteresis": 2,
"min_loss_scale": 1 "min_loss_scale": 1
}, },
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"gradient_accumulation_steps": "auto", "gradient_accumulation_steps": "auto",
"train_batch_size": "auto", "train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto", "train_micro_batch_size_per_gpu": "auto",

View File

@@ -23,15 +23,6 @@
"hysteresis": 2, "hysteresis": 2,
"min_loss_scale": 1 "min_loss_scale": 1
}, },
"optimizer": {
"type": "AdamW",
"params": {
"lr": "auto",
"betas": "auto",
"eps": "auto",
"weight_decay": "auto"
}
},
"gradient_accumulation_steps": "auto", "gradient_accumulation_steps": "auto",
"train_batch_size": "auto", "train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto", "train_micro_batch_size_per_gpu": "auto",

View File

@@ -11,7 +11,6 @@ val_set_size: 0.05
 adapter: qlora
 lora_model_dir:
 sequence_len: 2048
-max_packed_sequence_len: 2048
 lora_r: 16
 lora_alpha: 32
 lora_dropout: 0.05

View File

@@ -0,0 +1,197 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "AKjdG7tbTb-n"
},
"source": [
"# Example notebook for running Axolotl on google colab"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "RcbNpOgWRcii"
},
"outputs": [],
"source": [
"import torch\n",
"# Check so there is a gpu available, a T4(free tier) is enough to run this notebook\n",
"assert (torch.cuda.is_available()==True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "h3nLav8oTRA5"
},
"source": [
"## Install Axolotl and dependencies"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "3c3yGAwnOIdi",
"outputId": "e3777b5a-40ef-424f-e181-62dfecd1dd01"
},
"outputs": [],
"source": [
"!pip install -e git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl\n",
"!pip install flash-attn==\"2.5.0\"\n",
"!pip install deepspeed==\"0.13.1\""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "BW2MFr7HTjub"
},
"source": [
"## Create an yaml config file"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "9pkF2dSoQEUN"
},
"outputs": [],
"source": [
"import yaml\n",
"\n",
"# Your YAML string\n",
"yaml_string = \"\"\"\n",
"base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T\n",
"model_type: LlamaForCausalLM\n",
"tokenizer_type: LlamaTokenizer\n",
"is_llama_derived_model: true\n",
"\n",
"load_in_8bit: false\n",
"load_in_4bit: true\n",
"strict: false\n",
"\n",
"datasets:\n",
" - path: mhenrichsen/alpaca_2k_test\n",
" type: alpaca\n",
"dataset_prepared_path:\n",
"val_set_size: 0.05\n",
"output_dir: ./qlora-out\n",
"\n",
"adapter: qlora\n",
"lora_model_dir:\n",
"\n",
"sequence_len: 1096\n",
"sample_packing: true\n",
"pad_to_sequence_len: true\n",
"\n",
"lora_r: 32\n",
"lora_alpha: 16\n",
"lora_dropout: 0.05\n",
"lora_target_modules:\n",
"lora_target_linear: true\n",
"lora_fan_in_fan_out:\n",
"\n",
"wandb_project:\n",
"wandb_entity:\n",
"wandb_watch:\n",
"wandb_name:\n",
"wandb_log_model:\n",
"\n",
"mlflow_experiment_name: colab-example\n",
"\n",
"gradient_accumulation_steps: 1\n",
"micro_batch_size: 1\n",
"num_epochs: 4\n",
"max_steps: 20\n",
"optimizer: paged_adamw_32bit\n",
"lr_scheduler: cosine\n",
"learning_rate: 0.0002\n",
"\n",
"train_on_inputs: false\n",
"group_by_length: false\n",
"bf16: false\n",
"fp16: true\n",
"tf32: false\n",
"\n",
"gradient_checkpointing: true\n",
"early_stopping_patience:\n",
"resume_from_checkpoint:\n",
"local_rank:\n",
"logging_steps: 1\n",
"xformers_attention:\n",
"flash_attention: false\n",
"\n",
"warmup_steps: 10\n",
"evals_per_epoch:\n",
"saves_per_epoch:\n",
"debug:\n",
"deepspeed:\n",
"weight_decay: 0.0\n",
"fsdp:\n",
"fsdp_config:\n",
"special_tokens:\n",
"\n",
"\"\"\"\n",
"\n",
"# Convert the YAML string to a Python dictionary\n",
"yaml_dict = yaml.safe_load(yaml_string)\n",
"\n",
"# Specify your file path\n",
"file_path = 'test_axolotl.yaml'\n",
"\n",
"# Write the YAML file\n",
"with open(file_path, 'w') as file:\n",
" yaml.dump(yaml_dict, file)\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "bidoj8YLTusD"
},
"source": [
"## Launch the training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "ydTI2Jk2RStU",
"outputId": "d6d0df17-4b53-439c-c802-22c0456d301b"
},
"outputs": [],
"source": [
"# Buy using the ! the comand will be executed as a bash command\n",
"!accelerate launch -m axolotl.cli.train /content/test_axolotl.yaml"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"gpuType": "T4",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

View File

@@ -67,6 +67,3 @@ weight_decay: 0.1
 fsdp:
 fsdp_config:
 special_tokens:
-  bos_token: "<s>"
-  eos_token: "</s>"
-  unk_token: "<unk>"

View File

@@ -0,0 +1,70 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
peft:
loftq_config:
loftq_bits: 4
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
s2_attention:
warmup_steps: 10
evals_per_epoch: 4
eval_table_size:
eval_table_max_new_tokens: 128
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -65,6 +65,3 @@ weight_decay: 0.0
 fsdp:
 fsdp_config:
 special_tokens:
-  bos_token: "<s>"
-  eos_token: "</s>"
-  unk_token: "<unk>"

View File

@@ -65,6 +65,3 @@ weight_decay: 0.0
 fsdp:
 fsdp_config:
 special_tokens:
-  bos_token: "<s>"
-  eos_token: "</s>"
-  unk_token: "<unk>"

View File

@@ -1,6 +1,6 @@
 --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
 packaging==23.2
-peft==0.7.0
+peft @ git+https://github.com/huggingface/peft.git
 transformers==4.37.0
 tokenizers==0.15.0
 bitsandbytes>=0.41.1
@@ -15,16 +15,14 @@ sentencepiece
 wandb
 einops
 xformers==0.0.22
-optimum==1.13.2
+optimum==1.16.2
 hf_transfer
 colorama
 numba
 numpy>=1.24.4
 mlflow
 # qlora things
-bert-score==0.3.13
 evaluate==0.4.0
-rouge-score==0.1.2
 scipy
 scikit-learn==1.2.2
 pynvml

View File

@@ -27,9 +27,10 @@ def parse_requirements():
     try:
         torch_version = version("torch")
-        if torch_version.startswith("2.1.1"):
+        _install_requires.append(f"torch=={torch_version}")
+        if torch_version.startswith("2.1."):
             _install_requires.pop(_install_requires.index("xformers==0.0.22"))
-            _install_requires.append("xformers==0.0.23")
+            _install_requires.append("xformers>=0.0.23")
     except PackageNotFoundError:
         pass
@@ -50,7 +51,7 @@ setup(
     dependency_links=dependency_links,
     extras_require={
         "flash-attn": [
-            "flash-attn==2.3.3",
+            "flash-attn==2.5.0",
         ],
         "fused-dense-lib": [
             "fused-dense-lib @ git+https://github.com/Dao-AILab/flash-attention@v2.3.3#subdirectory=csrc/fused_dense_lib",

View File

@@ -18,6 +18,7 @@ from axolotl.cli import (
 )
 from axolotl.common.cli import PreprocessCliArgs
 from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
+from axolotl.prompt_strategies.sharegpt import register_chatml_template
 LOG = logging.getLogger("axolotl.cli.preprocess")
@@ -34,6 +35,14 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
         return_remaining_strings=True
     )
+    if parsed_cfg.chat_template == "chatml" and parsed_cfg.default_system_message:
+        LOG.info(
+            f"ChatML set. Adding default system message: {parsed_cfg.default_system_message}"
+        )
+        register_chatml_template(parsed_cfg.default_system_message)
+    else:
+        register_chatml_template()
+
     if not parsed_cfg.dataset_prepared_path:
         msg = (
             Fore.RED

View File

@@ -18,6 +18,7 @@ from axolotl.cli import (
     print_axolotl_text_art,
 )
 from axolotl.common.cli import TrainerCliArgs
+from axolotl.prompt_strategies.sharegpt import register_chatml_template
 from axolotl.train import train
 LOG = logging.getLogger("axolotl.cli.train")
@@ -37,6 +38,14 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]:
     print_axolotl_text_art()
     check_accelerate_default_config()
     check_user_token()
+    if cfg.chat_template == "chatml" and cfg.default_system_message:
+        LOG.info(
+            f"ChatML set. Adding default system message: {cfg.default_system_message}"
+        )
+        register_chatml_template(cfg.default_system_message)
+    else:
+        register_chatml_template()
+
     if cfg.rl:
         dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
     else:

View File

@@ -8,15 +8,17 @@ import importlib
import logging import logging
import math import math
import sys import sys
import typing
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import wraps from functools import wraps, partial
from pathlib import Path from pathlib import Path
from typing import List, Optional, Type, Union from typing import Dict, List, Optional, Tuple, Type, Union
import torch import torch
import transformers import transformers
from datasets import Dataset from datasets import Dataset
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import ( from transformers import (
@@ -29,6 +31,7 @@ from transformers.trainer_utils import seed_worker
from trl import DPOTrainer from trl import DPOTrainer
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
EvalFirstStepCallback, EvalFirstStepCallback,
GPUStatsCallback, GPUStatsCallback,
@@ -50,15 +53,39 @@ from axolotl.utils.schedulers import (
get_cosine_schedule_with_min_lr, get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup, get_cosine_schedule_with_quadratic_warmup,
) )
from axolotl.utils.tensors import keep_unpacked_data, split_and_pad_packed
try: try:
import torch._dynamo # pylint: disable=ungrouped-imports import torch._dynamo # pylint: disable=ungrouped-imports
except ImportError: except ImportError:
pass pass
if typing.TYPE_CHECKING:
# hacky, but recommended per https://github.com/python/mypy/issues/5837
_MixinTrainerBase = Trainer
else:
_MixinTrainerBase = object
LOG = logging.getLogger("axolotl.core.trainer_builder") LOG = logging.getLogger("axolotl.core.trainer_builder")
def _sanitize_kwargs_for_tagging(tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]
if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names
return kwargs
@dataclass @dataclass
class AxolotlTrainingArguments(TrainingArguments): class AxolotlTrainingArguments(TrainingArguments):
""" """
@@ -137,66 +164,10 @@ class AxolotlTrainingArguments(TrainingArguments):
) )
class AxolotlTrainer(Trainer): class AxolotlMultiPackTrainerMixin(_MixinTrainerBase): # type: ignore
""" """Trainer Mixin class for dataloaders and samplers"""
Extend the base Trainer for axolotl helpers
"""
args = None # type: AxolotlTrainingArguments args = None # type: AxolotlTrainingArguments
tag_names = ["axolotl"]
def __init__(
self,
*_args,
num_epochs=1,
bench_data_collator=None,
eval_data_collator=None,
**kwargs
):
self.num_epochs = num_epochs
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
):
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.lr_scheduler_type == "cosine" and self.args.cosine_min_lr_ratio is not None:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
if self.args.deepspeed:
LOG.warning("Using cosine scheduler with deepspeed. This may be ignored if a scheduler is set \
in the deepspeed JSON")
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer)
return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining: if self.args.sample_packing and not self.args.pretraining:
@@ -210,20 +181,6 @@ class AxolotlTrainer(Trainer):
) )
return super()._get_train_sampler() return super()._get_train_sampler()
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
return MultipackBatchSampler(
SequentialSampler(eval_dataset),
self.args.per_device_eval_batch_size,
drop_last=True,
batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
lengths=get_dataset_lengths(eval_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> DataLoader: def get_train_dataloader(self) -> DataLoader:
if self.args.sample_packing and not self.args.pretraining: if self.args.sample_packing and not self.args.pretraining:
train_dataset = self.train_dataset train_dataset = self.train_dataset
@@ -247,7 +204,7 @@ class AxolotlTrainer(Trainer):
del dataloader_params["batch_size"] del dataloader_params["batch_size"]
else: else:
dataloader_params["sampler"] = sampler dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker dataloader_params["worker_init_fn"] = seed_worker
self.accelerator.even_batches = False self.accelerator.even_batches = False
@@ -256,6 +213,20 @@ class AxolotlTrainer(Trainer):
) )
return super().get_train_dataloader() return super().get_train_dataloader()
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
return MultipackBatchSampler(
SequentialSampler(eval_dataset),
self.args.per_device_eval_batch_size,
drop_last=True,
batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
lengths=get_dataset_lengths(eval_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
return super()._get_eval_sampler(eval_dataset)
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
if self.args.sample_packing and self.args.eval_sample_packing is False: if self.args.sample_packing and self.args.eval_sample_packing is False:
self.data_collator = ( # pylint: disable=attribute-defined-outside-init self.data_collator = ( # pylint: disable=attribute-defined-outside-init
@@ -327,6 +298,81 @@ class AxolotlTrainer(Trainer):
return DataLoader(bench_dataset, **dataloader_params) return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params)) # return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
class AxolotlTrainer(AxolotlMultiPackTrainerMixin, Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
args = None # type: AxolotlTrainingArguments
tag_names = ["axolotl"]
def __init__(
self,
*_args,
num_epochs=1,
bench_data_collator=None,
eval_data_collator=None,
**kwargs
):
self.num_epochs = num_epochs
self.bench_data_collator = bench_data_collator
self.eval_data_collator = eval_data_collator
super().__init__(*_args, **kwargs)
self.train_data_collator = self.data_collator
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
use_cosine_quadratic = (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
)
use_cosine_min_lr = (
self.args.lr_scheduler_type == "cosine"
and self.args.cosine_min_lr_ratio is not None
)
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if use_cosine_quadratic:
if use_cosine_min_lr:
LOG.warning("Both cosine quadratic warmup and min lr detected. Using quadratic warmup.")
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
elif self.args.cosine_min_lr_ratio and use_cosine_min_lr:
assert 0 <= self.args.cosine_min_lr_ratio <= 1.0, "cosine_min_lr_ratio must be between 0.0 and 1.0"
self.lr_scheduler = get_cosine_schedule_with_min_lr( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
min_lr_ratio=self.args.cosine_min_lr_ratio,
)
else:
return super().create_scheduler(num_training_steps, optimizer)
else:
if use_cosine_quadratic:
LOG.warning("axolotl's cosine scheduler with quadratic warmup not used (e.g., because of deepspeed).")
if use_cosine_min_lr:
LOG.warning("axolotl's cosine scheduler with min lr not used (e.g., because of deepspeed).")
return self.lr_scheduler
def compute_loss(self, model, inputs, return_outputs=False): def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc # use one's weighted cross entropy loss calc
# if self.args.sample_packing: # if self.args.sample_packing:
@@ -336,30 +382,13 @@ class AxolotlTrainer(Trainer):
# return (loss, outputs) if return_outputs else loss # return (loss, outputs) if return_outputs else loss
return super().compute_loss(model, inputs, return_outputs=return_outputs) return super().compute_loss(model, inputs, return_outputs=return_outputs)
def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None):
if isinstance(tag_names, str):
tag_names = [tag_names]
if kwargs is not None:
if "tags" not in kwargs:
kwargs["tags"] = tag_names
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
kwargs["tags"].extend(tag_names)
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
tag_names.append(kwargs["tags"])
kwargs["tags"] = tag_names
return kwargs
@wraps(Trainer.push_to_hub) @wraps(Trainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str: def push_to_hub(self, *args, **kwargs) -> str:
""" """
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details. model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
""" """
kwargs = self._sanitize_kwargs_for_tagging( kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
tag_names=self.tag_names, kwargs=kwargs
)
return super().push_to_hub(*args, **kwargs) return super().push_to_hub(*args, **kwargs)
@@ -458,6 +487,77 @@ class ReLoRATrainer(AxolotlTrainer):
return self.lr_scheduler return self.lr_scheduler
class AxolotlDPOTrainer(AxolotlMultiPackTrainerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
tag_names = ["axolotl", "dpo"]
@wraps(DPOTrainer.push_to_hub)
def push_to_hub(self, *args, **kwargs) -> str:
"""
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
"""
kwargs = _sanitize_kwargs_for_tagging(tag_names=self.tag_names, kwargs=kwargs)
return super().push_to_hub(*args, **kwargs)
def tokenize_row(self, feature, *args, **kwargs) -> Dict:
# check if dataset is already tokenized
if not self.is_encoder_decoder:
keys = [
"chosen_input_ids",
"chosen_attention_mask",
"chosen_labels",
"rejected_input_ids",
"rejected_attention_mask",
"rejected_labels",
]
if all(k in feature.keys() for k in keys):
return feature
else:
keys = [
"chosen_labels",
"rejected_labels",
"prompt_input_ids",
"prompt_attention_mask",
]
if all(k in feature.keys() for k in keys):
return feature
return super().tokenize_row(feature, *args, **kwargs)
def concatenated_forward(
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
) -> Tuple[
torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor
]:
all_logits = model(
batch["input_ids"],
attention_mask=batch["attention_mask"],
position_ids=batch["position_ids"],
).logits
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(batch["position_ids"])
logits_keep_fn = partial(keep_unpacked_data, pad_val=None, pairs=True)
unpacked_logits = split_and_pad_packed(all_logits, cu_seqlens, max_seqlen, logits_keep_fn)
labels_keep_fn = partial(keep_unpacked_data, pad_val=-100, pairs=True)
unpacked_labels = split_and_pad_packed(batch["labels"], cu_seqlens, max_seqlen, labels_keep_fn)
unpacked_logps = self.get_batch_logps(
unpacked_logits,
unpacked_labels,
average_log_prob=self.loss_type == "ipo",
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
chosen_logps = unpacked_logps[::2]
rejected_logps = unpacked_logps[1::2]
chosen_logits = unpacked_logits[::2]
rejected_logits = unpacked_logits[1::2]
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
class TrainerBuilderBase(abc.ABC): class TrainerBuilderBase(abc.ABC):
""" """
Base class for trainer builder Base class for trainer builder
@@ -638,7 +738,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs[
                 "gradient_checkpointing"
             ] = self.cfg.gradient_checkpointing
-            if self.cfg.gradient_checkpointing_kwargs:
+            if self.cfg.gradient_checkpointing_kwargs is not None:
                 training_arguments_kwargs[
                     "gradient_checkpointing_kwargs"
                 ] = self.cfg.gradient_checkpointing_kwargs
@@ -705,7 +805,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         elif self.cfg.sample_packing and self.cfg.eval_sample_packing is False:
             training_arguments_kwargs["dataloader_drop_last"] = True
-        if self.cfg.val_set_size == 0:
+        if not self.cfg.test_datasets and self.cfg.val_set_size == 0:
             # no eval set, so don't eval
             training_arguments_kwargs["evaluation_strategy"] = "no"
         elif self.cfg.eval_steps:
@@ -792,6 +892,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 self.cfg.load_best_model_at_end is not False
                 or self.cfg.early_stopping_patience
             )
+            and not self.cfg.test_datasets
             and self.cfg.val_set_size > 0
             and self.cfg.save_steps
             and self.cfg.eval_steps
@@ -1015,6 +1116,18 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
training_args_kwargs[ training_args_kwargs[
"dataloader_prefetch_factor" "dataloader_prefetch_factor"
] = self.cfg.dataloader_prefetch_factor ] = self.cfg.dataloader_prefetch_factor
if self.cfg.gradient_checkpointing:
training_args_kwargs[
"gradient_checkpointing"
] = self.cfg.gradient_checkpointing
if self.cfg.gradient_checkpointing_kwargs is not None:
training_args_kwargs[
"gradient_checkpointing_kwargs"
] = self.cfg.gradient_checkpointing_kwargs
else:
training_args_kwargs["gradient_checkpointing_kwargs"] = {
"use_reentrant": False
}
training_args = TrainingArguments( training_args = TrainingArguments(
per_device_train_batch_size=self.cfg.micro_batch_size, per_device_train_batch_size=self.cfg.micro_batch_size,
@@ -1025,9 +1138,6 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
save_steps=self.cfg.save_steps, save_steps=self.cfg.save_steps,
output_dir=self.cfg.output_dir, output_dir=self.cfg.output_dir,
warmup_steps=self.cfg.warmup_steps, warmup_steps=self.cfg.warmup_steps,
gradient_checkpointing=self.cfg.gradient_checkpointing,
gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs
or {"use_reentrant": False},
logging_first_step=True, logging_first_step=True,
logging_steps=1, logging_steps=1,
optim=self.cfg.optimizer, optim=self.cfg.optimizer,
@@ -1050,7 +1160,11 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config: if self.cfg.adapter and self.peft_config:
dpo_trainer_kwargs["peft_config"] = self.peft_config dpo_trainer_kwargs["peft_config"] = self.peft_config
dpo_trainer = DPOTrainer( if self.cfg.precompute_ref_log_probs is not None:
dpo_trainer_kwargs[
"precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs
dpo_trainer = AxolotlDPOTrainer(
self.model, self.model,
self.model_ref, self.model_ref,
args=training_args, args=training_args,
@@ -1064,6 +1178,7 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
             callbacks=self.get_callbacks(),
             **dpo_trainer_kwargs,
         )
+        setattr(dpo_trainer, "use_dpo_data_collator", True)
         dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
         for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
             dpo_trainer.add_callback(callback)
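The precompute_ref_log_probs plumbing and the AxolotlDPOTrainer swap above are driven from the config. A hypothetical fragment of a DPO run that would exercise them (option names come from this diff; the values are illustrative):

rl: dpo
precompute_ref_log_probs: true
gradient_checkpointing: true
# per the builder change above, gradient_checkpointing_kwargs defaults to
# {"use_reentrant": false} for DPO when it is not set explicitly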

View File

@@ -94,7 +94,7 @@ def _prepare_decoder_attention_mask(
     sliding_window,
 ):  # pylint: disable=unused-argument
     # [bsz, seq_len]
-    if attention_mask is None:
+    if attention_mask is None or sliding_window is None:
         return attention_mask
     # NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
@@ -151,7 +151,7 @@ def flashattn_forward(
     )
     use_sliding_windows = (
-        hasattr(self.config, "sliding_window") is not None
+        getattr(self.config, "sliding_window") is not None
         and kv_seq_len > self.config.sliding_window
     )

View File

@@ -23,6 +23,31 @@ def argilla(
     return transform_fn


+def icr(
+    cfg,
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    chatml transforms for datasets with system, input, chosen, rejected
+    ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
+    """
+
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["chosen"] = f"{sample['chosen']}<|im_end|>"
+        sample["rejected"] = f"{sample['rejected']}<|im_end|>"
+        return sample
+
+    return transform_fn
+
+
 def intel(cfg):  # pylint: disable=possibly-unused-variable,unused-argument
     """
     For Intel Orca DPO Pairs
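A quick sketch of what the new icr transform produces for one record; the sample values are made up, and it assumes icr is imported from this prompt-strategy module:

transform_fn = icr(cfg=None)  # cfg is unused by the transform itself
sample = {
    "system": "You are a helpful assistant.",
    "input": "Name a prime number.",
    "chosen": "7",
    "rejected": "8",
}
out = transform_fn(sample)
# out["prompt"]   -> "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n<|im_start|>user\nName a prime number.<|im_end|>\n<|im_start|>assistant\n"
# out["chosen"]   -> "7<|im_end|>"
# out["rejected"] -> "8<|im_end|>"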

View File

@@ -6,16 +6,19 @@ from fastchat.conversation import Conversation, SeparatorStyle, register_conv_te
 from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
 from axolotl.prompters import ShareGPTPrompterV2

-register_conv_template(
-    Conversation(
-        name="chatml",
-        system_template="<|im_start|>system\n{system_message}",
-        system_message="You are a helpful assistant.",
-        roles=["<|im_start|>user", "<|im_start|>assistant"],
-        sep_style=SeparatorStyle.CHATML,
-        sep="<|im_end|>",
-    )
-)
+def register_chatml_template(system_message=None):
+    system_message = system_message or "You are a helpful assistant."
+    register_conv_template(
+        Conversation(
+            name="chatml",
+            system_template="<|im_start|>system\n{system_message}",
+            system_message=system_message,
+            roles=["<|im_start|>user", "<|im_start|>assistant"],
+            sep_style=SeparatorStyle.CHATML,
+            sep="<|im_end|>",
+        )
+    )

 def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):

View File

@@ -63,6 +63,8 @@ def train(
msg += " and peft_config..." msg += " and peft_config..."
LOG.debug(msg) LOG.debug(msg)
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference) model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
model.generation_config.do_sample = True
model_ref = None model_ref = None
if cfg.rl: if cfg.rl:
if cfg.adapter and not cfg.rl_adapter_ref_model: if cfg.adapter and not cfg.rl_adapter_ref_model:

View File

@@ -20,7 +20,7 @@ def chat_templates(user_choice: str):
     templates = {
         "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}",  # I don't know what this one is called. Used by Mistral/Mixtral.
-        "chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
+        "chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
     }
     if user_choice in templates:
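Since the updated chatml template now injects a system turn (falling back to "You are a helpful assistant." when the conversation has none), a quick way to sanity-check it is to render a message list through a tokenizer. This is a generic transformers usage sketch, assuming the chat_templates() helper shown above is importable; the tokenizer choice is arbitrary:

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T")
tok.chat_template = chat_templates("chatml")
print(tok.apply_chat_template(
    [{"role": "user", "content": "hi"}],
    tokenize=False,
    add_generation_prompt=True,
))
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# hi<|im_end|>
# <|im_start|>assistant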

View File

@@ -178,6 +178,9 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
             features = [chunked_data]
         return super().__call__(features, return_tensors=return_tensors)

+@dataclass
+class BatchSamplerDPODataCollatorWithPadding:

 @dataclass
 class MambaDataCollator:

View File

@@ -95,7 +95,7 @@ def normalize_config(cfg):
         save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
         if save_steps < 1.0:  # prevent saves on every step
             cfg.save_steps = save_steps
-    if cfg.evals_per_epoch:
+    if (cfg.val_set_size or cfg.test_datasets) and cfg.evals_per_epoch:
         eval_steps = 1.0 / (cfg.evals_per_epoch * cfg.num_epochs)
         if eval_steps < 1.0:  # prevent evals on every step
             cfg.eval_steps = eval_steps
@@ -163,6 +163,7 @@ def normalize_config(cfg):
         cfg.gradient_checkpointing
         and cfg.unfrozen_parameters is None
         and cfg.gradient_checkpointing_kwargs is None
+        and cfg.rl is None
     ):
         cfg.gradient_checkpointing_kwargs = {"use_reentrant": True}
@@ -231,9 +232,6 @@ def validate_config(cfg):
"eval_batch_size != micro_batch_size. This can lead to VRAM instability." "eval_batch_size != micro_batch_size. This can lead to VRAM instability."
) )
if cfg.load_4bit:
raise ValueError("cfg.load_4bit parameter has been deprecated")
if cfg.adapter == "qlora": if cfg.adapter == "qlora":
if cfg.merge_lora: if cfg.merge_lora:
# can't merge qlora if loaded in 8bit or 4bit # can't merge qlora if loaded in 8bit or 4bit
@@ -259,7 +257,8 @@ def validate_config(cfg):
         if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
             raise ValueError("Fused modules are not supported with QLoRA")
-    if not cfg.load_in_8bit and cfg.adapter == "lora":
+    loftq = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
+    if not cfg.load_in_8bit and cfg.adapter == "lora" and not loftq:
         LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
     if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
@@ -339,6 +338,11 @@ def validate_config(cfg):
"push_to_hub_model_id is deprecated. Please use hub_model_id instead." "push_to_hub_model_id is deprecated. Please use hub_model_id instead."
) )
if cfg.hub_model_id and not (cfg.save_steps or cfg.saves_per_epoch):
LOG.warning(
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
)
if cfg.gptq and cfg.model_revision: if cfg.gptq and cfg.model_revision:
raise ValueError( raise ValueError(
"model_revision is not supported for GPTQ models. " "model_revision is not supported for GPTQ models. "
@@ -484,35 +488,43 @@ def validate_config(cfg):
"`use_reentrant` must be false when used with partially frozen model." "`use_reentrant` must be false when used with partially frozen model."
) )
if cfg.flash_attention and cfg.deepspeed and Path(cfg.deepspeed).is_file(): if cfg.deepspeed and Path(cfg.deepspeed).is_file():
with open(cfg.deepspeed, encoding="utf-8") as file: with open(cfg.deepspeed, encoding="utf-8") as file:
contents = file.read() contents = file.read()
deepspeed_cfg: DictDefault = DictDefault(json.loads(contents)) deepspeed_cfg: DictDefault = DictDefault(json.loads(contents))
if ( if cfg.flash_attention:
deepspeed_cfg.zero_optimization if (
and deepspeed_cfg.zero_optimization.stage == 3 deepspeed_cfg.zero_optimization
): and deepspeed_cfg.zero_optimization.stage == 3
if not (
(
deepspeed_cfg.bf16
and deepspeed_cfg.bf16.enabled # pylint: disable=no-member
is True
)
or (
deepspeed_cfg.fp16
and deepspeed_cfg.fp16.enabled # pylint: disable=no-member
is True
)
): ):
raise ValueError( if not (
"bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention" (
) deepspeed_cfg.bf16
and deepspeed_cfg.bf16.enabled # pylint: disable=no-member
is True
)
or (
deepspeed_cfg.fp16
and deepspeed_cfg.fp16.enabled # pylint: disable=no-member
is True
)
):
raise ValueError(
"bf16.enabled or fp16.enabled must be set to true when using ZeRO-3 with flash-attention"
)
if "8bit" in cfg.optimizer and deepspeed_cfg.optimizer:
LOG.warning(
f"conflicting optimizer: {cfg.optimizer} used alongside deepspeed optimizer."
)
if cfg.test_datasets and cfg.val_set_size: if cfg.test_datasets and cfg.val_set_size:
raise ValueError( raise ValueError(
"non-zero val_set_size should not be used with test_datasets configuration" "non-zero val_set_size should not be used with test_datasets configuration"
) )
if cfg.fsdp and "bnb" in cfg.optimizer:
raise ValueError(f"FSDP not compatible with {cfg.optimizer}")
# TODO # TODO
# MPT 7b # MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25 # https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -16,6 +16,7 @@ from datasets import (
     load_from_disk,
 )
 from huggingface_hub import hf_hub_download
+from huggingface_hub.utils import HFValidationError
 from torch.utils.data import RandomSampler
 from transformers import PreTrainedTokenizerBase
@@ -213,7 +214,7 @@ def load_tokenized_prepared_datasets(
                 token=use_auth_token,
             )
             ds_from_hub = True
-        except (FileNotFoundError, ConnectionError):
+        except (FileNotFoundError, ConnectionError, HFValidationError):
             pass
     ds_from_cloud = False
@@ -439,7 +440,7 @@ def load_prepare_datasets(
split="train", split="train",
) -> Tuple[Dataset, Dataset, List[Prompter]]: ) -> Tuple[Dataset, Dataset, List[Prompter]]:
dataset, prompters = load_tokenized_prepared_datasets( dataset, prompters = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path tokenizer, cfg, default_dataset_prepared_path, split=split
) )
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None: if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:

View File

@@ -9,7 +9,7 @@ import bitsandbytes as bnb
 import torch
 import transformers
 from optimum.bettertransformer import BetterTransformer
-from peft import PeftConfig, prepare_model_for_kbit_training
+from peft import LoftQConfig, PeftConfig, prepare_model_for_kbit_training
 from peft.tuners.lora import QuantLinear
 from transformers import (  # noqa: F401
     AddedToken,
@@ -67,7 +67,7 @@ def check_model_config(cfg: DictDefault, model_config: Union[AutoConfig, DictDef
     ):
         lora_modules_to_save = ", ".join(map(lambda x: f"`{x}`", lora_modules_to_save))
         raise ValueError(
-            f"`lora_modules_to_save` not properly set when adding new tokens. Please include {lora_modules_to_save} in `lora_modules_to_save`."
+            f"`lora_modules_to_save` not properly set when adding new tokens. Please include [{lora_modules_to_save}] in `lora_modules_to_save`."
         )
@@ -182,7 +182,7 @@ def load_tokenizer(cfg):
[f"`{x}`" for x in lora_modules_to_save] [f"`{x}`" for x in lora_modules_to_save]
) )
raise ValueError( raise ValueError(
f"Please set lora_modules_to_save to {lora_modules_to_save} when using an adapter and changing the special tokens." f"Please set lora_modules_to_save to [{lora_modules_to_save}] when using an adapter and changing the special tokens."
) )
tokenizer.add_special_tokens( tokenizer.add_special_tokens(
@@ -219,7 +219,13 @@ def load_tokenizer(cfg):
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if cfg.chat_template: if cfg.chat_template:
tokenizer.chat_template = chat_templates(cfg.chat_template) chat_template_string = chat_templates(cfg.chat_template)
if cfg.default_system_message and cfg.chat_template == "chatml":
chat_template_string = chat_template_string.replace(
"You are a helpful assistant.", cfg.default_system_message
)
tokenizer.chat_template = chat_template_string
else: else:
LOG.info( LOG.info(
"No Chat template selected. Consider adding a chat template for easier inference." "No Chat template selected. Consider adding a chat template for easier inference."
@@ -636,15 +642,17 @@ def load_model(
     # make sure these are fp32 per Ramesh et al. (2021)
     embedding_modules = get_linear_embedding_layers(cfg.model_config_type)
-    for name, module in model.named_modules():
-        if any(m in name for m in ["norm", "gate"]):
-            module.to(torch.float32)
-        if model_config.model_type == "btlm":
-            # don't upcast lm_head for btlm
-            continue
-        if any(m in name for m in embedding_modules):
-            if hasattr(module, "weight"):
-                module.to(torch.float32)
+    if not cfg.fsdp:
+        # FSDP doesn't like mixed Float and BFloat16
+        for name, module in model.named_modules():
+            if any(m in name for m in ["norm", "gate"]):
+                module.to(torch.float32)
+            if model_config.model_type == "btlm":
+                # don't upcast lm_head for btlm
+                continue
+            if any(m in name for m in embedding_modules):
+                if hasattr(module, "weight"):
+                    module.to(torch.float32)

     needs_fa2_dtype = cfg.adapter or cfg.fsdp
     skip_prepare_model_for_kbit_training = False
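
The guard above keeps the selective fp32 upcast away from FSDP runs: FSDP flattens parameters into shared buffers and requires a uniform dtype within each flat parameter, so the mixed fp32/bf16 state the loop produces would fail to shard. A quick way to inspect that state (illustrative helper, assuming a model already loaded in bfloat16):

import torch
from collections import Counter

def parameter_dtypes(model: torch.nn.Module) -> Counter:
    # e.g. Counter({torch.bfloat16: 291, torch.float32: 64}) after the upcast;
    # DDP tolerates this mixture, FSDP's flat parameters do not
    return Counter(param.dtype for param in model.parameters())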
@@ -659,13 +667,17 @@ def load_model(
         # Qwen doesn't play nicely with LoRA if this is enabled
         skip_prepare_model_for_kbit_training = True

-    if (cfg.adapter == "lora" and load_in_8bit) or (
-        cfg.adapter == "qlora" and cfg.load_in_4bit
-    ):
-        LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
+    loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
+    if cfg.adapter == "lora" and loftq_bits:
+        skip_prepare_model_for_kbit_training = True
+
+    if cfg.adapter in ["lora", "qlora"]:
         if cfg.gradient_checkpointing:
             model.gradient_checkpointing_enable()
-        if not skip_prepare_model_for_kbit_training:
+        if (
+            cfg.load_in_8bit or cfg.load_in_4bit
+        ) and not skip_prepare_model_for_kbit_training:
+            LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
             model = prepare_model_for_kbit_training(
                 model, use_gradient_checkpointing=cfg.gradient_checkpointing
             )
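
For reference, the standard PEFT path that this block now gates on load_in_8bit/load_in_4bit looks roughly like the sketch below (model id and flags are placeholders): prepare_model_for_kbit_training freezes the base weights, upcasts selected parameters to fp32, and enables input gradients so gradient checkpointing works with a trainable adapter.

from peft import prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "huggyllama/llama-7b",  # placeholder model id
    load_in_4bit=True,
    device_map="auto",
)
model.gradient_checkpointing_enable()
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)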
@@ -692,6 +704,7 @@ def load_model(
         model, lora_config = load_adapter(model, cfg, cfg.adapter)

     if cfg.ddp and not load_in_8bit and not (cfg.rl and cfg.load_in_4bit):
+        # TODO revalidate this conditional
         model.to(f"cuda:{cfg.local_rank}")

     if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
@@ -743,7 +756,7 @@ def load_llama_adapter(model, cfg):
     )

     if cfg.lora_model_dir:
-        LOG.debug("Loading pretained PEFT - llama_adapter")
+        LOG.debug("Loading pretrained PEFT - llama_adapter")
         model = PeftModel.from_pretrained(
             model,
             cfg.lora_model_dir,
@@ -789,6 +802,12 @@ def load_lora(model, cfg, inference=False, config_only=False):
         LOG.info(f"found linear modules: {repr(linear_names)}")
         lora_target_modules = list(set(lora_target_modules + linear_names))

+    lora_config_kwargs = {}
+    loftq_bits = cfg.peft and cfg.peft.loftq_config and cfg.peft.loftq_config.loftq_bits
+    if loftq_bits:
+        lora_config_kwargs["loftq_config"] = LoftQConfig(loftq_bits=loftq_bits)
+        lora_config_kwargs["init_lora_weights"] = "loftq"
+
     lora_config = LoraConfig(
         r=cfg.lora_r,
         lora_alpha=cfg.lora_alpha,
@@ -799,13 +818,14 @@ def load_lora(model, cfg, inference=False, config_only=False):
         modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
         bias="none",
         task_type="CAUSAL_LM",
+        **lora_config_kwargs,
     )

     if config_only:
         return None, lora_config

     if cfg.lora_model_dir:
-        LOG.debug("Loading pretained PEFT - LoRA")
+        LOG.debug("Loading pretrained PEFT - LoRA")
         model_kwargs: Any = {}
         if cfg.lora_on_cpu:
             model_kwargs["max_memory"] = {"cpu": "256GiB"}

View File

@@ -0,0 +1,61 @@
import torch
import torch.nn.functional as F


def keep_unpacked_data(data: torch.Tensor, index=None, nonzero_total=None, pad_val=None, pairs=False):
    # pad_val could be a padding token (input_ids), -100 (labels), or 0 (attention_mask),
    # so compare against None explicitly rather than relying on truthiness
    if index >= nonzero_total:
        return False
    if pairs and (index // 2) >= (nonzero_total // 2):
        return False
    if pad_val is not None and (data == pad_val).all():
        # drop sequences that consist entirely of padding
        return False
    return True


def split_and_pad_packed(tensor, cu_seqlens, max_seqlen, keep_fn=None):
    split_tensors = []
    counts = count_nonzero_sequences(cu_seqlens)
    # Iterate over each batch row
    for i in range(tensor.size(0)):
        seq_lens = cu_seqlens[i]
        start_idx = 0
        # Iterate over the cumulative sequence lengths
        for j, end_idx in enumerate(seq_lens[1:]):
            if end_idx == start_idx:
                # trailing repeated offsets mark the end of the packed sequences
                break
            # Extract the current sequence, then always advance the start index,
            # even when the sequence is filtered out below
            current_seq = tensor[i, start_idx:end_idx]
            start_idx = end_idx
            if keep_fn and not keep_fn(current_seq, index=j, nonzero_total=counts[i]):
                continue
            # Pad the sequence (first) dimension out to max_seqlen
            padding_size = max_seqlen - current_seq.size(0)
            padded_seq = F.pad(current_seq, (0, 0) * (current_seq.dim() - 1) + (0, padding_size))
            split_tensors.append(padded_seq)
    # Stack the padded sequences along a new leading dimension
    return torch.stack(split_tensors, dim=0)


def count_nonzero_sequences(cu_seqlens: torch.Tensor) -> torch.LongTensor:
    # non-zero deltas between consecutive offsets correspond to real sequences
    diffs = torch.diff(
        cu_seqlens, dim=1, prepend=torch.zeros(cu_seqlens.shape[0], 1, dtype=cu_seqlens.dtype)
    )
    valid_lengths = diffs != 0
    counts = valid_lengths.sum(dim=1).long()
    return counts


# Example usage
# Example tensor with dimensions [batch_size, seq_len, other_dimensions...]
# example_tensor = torch.randn(batch_size, seq_len, other_dimensions...)
# cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(batch["position_ids"])
# split_padded_tensor = split_and_pad_packed(example_tensor, cu_seqlens, max_seqlen)
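
A concrete, runnable illustration of the helpers above on synthetic data (shapes and offsets invented for the example; in training code cu_seqlens and max_seqlen would come from the packed batch as in the comments above):

import torch

# two packed rows of total length 6; row 0 packs lengths [3, 3], row 1 packs [4, 2];
# trailing repeated offsets pad each cu_seqlens row to the same width
input_ids = torch.arange(12).reshape(2, 6)
cu_seqlens = torch.tensor([[0, 3, 6, 6], [0, 4, 6, 6]])
max_seqlen = 4

unpacked = split_and_pad_packed(input_ids, cu_seqlens, max_seqlen)
print(unpacked.shape)  # torch.Size([4, 4]) -> four sequences, each right-padded with 0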

View File

@@ -7,9 +7,14 @@ from tokenizers import AddedToken
 from transformers import AutoTokenizer

 from axolotl.datasets import TokenizedPromptDataset
-from axolotl.prompt_strategies.sharegpt import SimpleShareGPTPromptTokenizingStrategy
+from axolotl.prompt_strategies.sharegpt import (
+    SimpleShareGPTPromptTokenizingStrategy,
+    register_chatml_template,
+)
 from axolotl.prompters import ShareGPTPrompterV2

+register_chatml_template()
+

 @pytest.fixture(name="sharegpt_dataset")
 def fixture_sharegpt_dataset():
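
The module-level register_chatml_template() call matters because the ShareGPT strategy resolves its conversation template by name at tokenization time; nothing else registers it when the tests import the strategy directly. A hedged sketch of that dependency (assuming the template is registered under the name "chatml" via fastchat; the exact registration lives in axolotl.prompt_strategies.sharegpt):

from axolotl.prompt_strategies.sharegpt import register_chatml_template
from fastchat.conversation import get_conv_template

register_chatml_template()                  # must run before any chatml lookup
conversation = get_conv_template("chatml")  # assumed template name; raises KeyError if unregistered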

View File

@@ -26,21 +26,12 @@ class BaseValidation(unittest.TestCase):
         self._caplog = caplog


+# pylint: disable=too-many-public-methods
 class ValidationTest(BaseValidation):
     """
     Test the validation module
     """

-    def test_load_4bit_deprecate(self):
-        cfg = DictDefault(
-            {
-                "load_4bit": True,
-            }
-        )
-
-        with pytest.raises(ValueError):
-            validate_config(cfg)
-
     def test_batch_size_unused_warning(self):
         cfg = DictDefault(
             {
@@ -698,6 +689,22 @@ class ValidationTest(BaseValidation):
         ):
             validate_config(cfg)

+    def test_hub_model_id_save_value_warns(self):
+        cfg = DictDefault({"hub_model_id": "test"})
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert (
+                "set without any models being saved" in self._caplog.records[0].message
+            )
+
+    def test_hub_model_id_save_value(self):
+        cfg = DictDefault({"hub_model_id": "test", "saves_per_epoch": 4})
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert len(self._caplog.records) == 0
+

 class ValidationCheckModelConfig(BaseValidation):
     """