Compare commits
1 Commits
ia3-peft
...
sharegpt-b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b4d84d56d5 |
5
.github/workflows/base.yml
vendored
5
.github/workflows/base.yml
vendored
@@ -25,11 +25,6 @@ jobs:
|
|||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.0.1
|
pytorch: 2.0.1
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||||
- cuda: "118"
|
|
||||||
cuda_version: 11.8.0
|
|
||||||
python_version: "3.10"
|
|
||||||
pytorch: 2.1.0
|
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v3
|
||||||
|
|||||||
11
.github/workflows/main.yml
vendored
11
.github/workflows/main.yml
vendored
@@ -23,11 +23,6 @@ jobs:
|
|||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.0.1
|
pytorch: 2.0.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
- cuda: 118
|
|
||||||
cuda_version: 11.8.0
|
|
||||||
python_version: "3.10"
|
|
||||||
pytorch: 2.1.0
|
|
||||||
axolotl_extras:
|
|
||||||
runs-on: [self-hosted, gpu, docker]
|
runs-on: [self-hosted, gpu, docker]
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -51,7 +46,6 @@ jobs:
|
|||||||
build-args: |
|
build-args: |
|
||||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||||
CUDA=${{ matrix.cuda }}
|
CUDA=${{ matrix.cuda }}
|
||||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
|
||||||
file: ./docker/Dockerfile
|
file: ./docker/Dockerfile
|
||||||
push: ${{ github.event_name != 'pull_request' }}
|
push: ${{ github.event_name != 'pull_request' }}
|
||||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
@@ -74,11 +68,6 @@ jobs:
|
|||||||
pytorch: 2.0.1
|
pytorch: 2.0.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
is_latest: true
|
||||||
- cuda: 118
|
|
||||||
cuda_version: 11.8.0
|
|
||||||
python_version: "3.10"
|
|
||||||
pytorch: 2.1.0
|
|
||||||
axolotl_extras:
|
|
||||||
runs-on: [self-hosted, gpu, docker]
|
runs-on: [self-hosted, gpu, docker]
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
@@ -12,4 +12,4 @@ generated-members=numpy.*, torch.*
|
|||||||
disable=missing-function-docstring, line-too-long, import-error,
|
disable=missing-function-docstring, line-too-long, import-error,
|
||||||
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
|
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
|
||||||
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
|
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
|
||||||
too-many-boolean-expressions,
|
too-many-nested-blocks,
|
||||||
|
|||||||
316
README.md
316
README.md
@@ -23,10 +23,9 @@ Features:
|
|||||||
- [Supported Features](#axolotl-supports)
|
- [Supported Features](#axolotl-supports)
|
||||||
- [Quickstart](#quickstart-)
|
- [Quickstart](#quickstart-)
|
||||||
- [Installation](#installation)
|
- [Installation](#installation)
|
||||||
- [Docker](#docker)
|
- [Docker Installation](#environment)
|
||||||
- [Conda/Pip venv](#condapip-venv)
|
- [Conda/Pip venv Installation](#condapip-venv)
|
||||||
- [LambdaLabs](#lambdalabs)
|
- [LambdaLabs Installation](#lambdalabs)
|
||||||
- [Windows](#windows)
|
|
||||||
- [Dataset](#dataset)
|
- [Dataset](#dataset)
|
||||||
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
|
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
|
||||||
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
|
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
|
||||||
@@ -51,7 +50,7 @@ Features:
|
|||||||
<b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b>
|
<b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b>
|
||||||
</p>
|
</p>
|
||||||
<p>
|
<p>
|
||||||
Go ahead and Axolotl questions!!
|
Go ahead and axolotl questions!!
|
||||||
</p>
|
</p>
|
||||||
<img src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
|
<img src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
|
||||||
<img alt="PyTest Status" src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
|
<img alt="PyTest Status" src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
|
||||||
@@ -96,14 +95,14 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
|||||||
|
|
||||||
# inference
|
# inference
|
||||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||||
--peft_model_dir="./lora-out"
|
--lora_model_dir="./lora-out"
|
||||||
```
|
```
|
||||||
|
|
||||||
## Installation
|
## Installation
|
||||||
|
|
||||||
### Environment
|
### Environment
|
||||||
|
|
||||||
#### Docker
|
- Docker
|
||||||
```bash
|
```bash
|
||||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
|
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||||
```
|
```
|
||||||
@@ -115,12 +114,12 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
|||||||
docker compose up -d
|
docker compose up -d
|
||||||
```
|
```
|
||||||
|
|
||||||
#### Conda/Pip venv
|
- Conda/Pip venv
|
||||||
1. Install python >=**3.9**
|
1. Install python >=**3.9**
|
||||||
|
|
||||||
2. Install pytorch stable https://pytorch.org/get-started/locally/
|
2. Install pytorch stable https://pytorch.org/get-started/locally/
|
||||||
|
|
||||||
3. Install Axolotl along with python dependencies
|
3. Install axolotl along with python dependencies
|
||||||
```bash
|
```bash
|
||||||
pip3 install packaging
|
pip3 install packaging
|
||||||
pip3 install -e '.[flash-attn,deepspeed]'
|
pip3 install -e '.[flash-attn,deepspeed]'
|
||||||
@@ -131,7 +130,7 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
|||||||
```
|
```
|
||||||
Get the token at huggingface.co/settings/tokens
|
Get the token at huggingface.co/settings/tokens
|
||||||
|
|
||||||
#### LambdaLabs
|
- LambdaLabs
|
||||||
<details>
|
<details>
|
||||||
|
|
||||||
<summary>Click to Expand</summary>
|
<summary>Click to Expand</summary>
|
||||||
@@ -175,8 +174,7 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
|||||||
```
|
```
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
#### Windows
|
- Windows: Please use WSL or Docker!
|
||||||
Please use WSL or Docker!
|
|
||||||
|
|
||||||
### Dataset
|
### Dataset
|
||||||
|
|
||||||
@@ -297,24 +295,25 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
|||||||
|
|
||||||
#### How to add custom prompts
|
#### How to add custom prompts
|
||||||
|
|
||||||
For a dataset that is preprocessed for instruction purposes:
|
Using yaml. Example:
|
||||||
|
|
||||||
```json
|
|
||||||
{"instruction": "...", "output": "..."}
|
|
||||||
```
|
|
||||||
|
|
||||||
You can use this example in your YAML config:
|
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
datasets:
|
datasets:
|
||||||
- path: repo
|
- path: repo
|
||||||
type:
|
type:
|
||||||
system_prompt: ""
|
system_prompt: ""
|
||||||
field_system: system
|
no_input_format: |-
|
||||||
format: "[INST] {instruction} [/INST]"
|
User: {instruction}<|end_of_turn|>
|
||||||
no_input_format: "[INST] {instruction} [/INST]"
|
Assistant:
|
||||||
|
format: |-
|
||||||
|
User: {instruction}
|
||||||
|
{input}<|end_of_turn|>
|
||||||
|
Assistant:
|
||||||
```
|
```
|
||||||
|
|
||||||
|
Using file:
|
||||||
|
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
|
||||||
|
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
|
||||||
|
|
||||||
#### How to use your custom pretokenized dataset
|
#### How to use your custom pretokenized dataset
|
||||||
|
|
||||||
- Do not pass a `type:`
|
- Do not pass a `type:`
|
||||||
@@ -384,10 +383,10 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
|||||||
- lora
|
- lora
|
||||||
```yaml
|
```yaml
|
||||||
adapter: lora # qlora or leave blank for full finetune
|
adapter: lora # qlora or leave blank for full finetune
|
||||||
peft_r: 8
|
lora_r: 8
|
||||||
peft_alpha: 16
|
lora_alpha: 16
|
||||||
peft_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
peft_target_modules:
|
lora_target_modules:
|
||||||
- q_proj
|
- q_proj
|
||||||
- v_proj
|
- v_proj
|
||||||
```
|
```
|
||||||
@@ -397,15 +396,15 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
|||||||
<summary>All yaml options</summary>
|
<summary>All yaml options</summary>
|
||||||
|
|
||||||
```yaml
|
```yaml
|
||||||
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
# this is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
||||||
# This can also be a relative path to a model on disk
|
# this can also be a relative path to a model on disk
|
||||||
base_model: ./llama-7b-hf
|
base_model: ./llama-7b-hf
|
||||||
# You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
|
# you can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
|
||||||
base_model_ignore_patterns:
|
base_model_ignore_patterns:
|
||||||
# If the base_model repo on hf hub doesn't include configuration .json files,
|
# if the base_model repo on hf hub doesn't include configuration .json files,
|
||||||
# You can set that here, or leave this empty to default to base_model
|
# you can set that here, or leave this empty to default to base_model
|
||||||
base_model_config: ./llama-7b-hf
|
base_model_config: ./llama-7b-hf
|
||||||
# You can specify to choose a specific model revision from huggingface hub
|
# you can specify to choose a specific model revision from huggingface hub
|
||||||
model_revision:
|
model_revision:
|
||||||
# Optional tokenizer configuration override in case you want to use a different tokenizer
|
# Optional tokenizer configuration override in case you want to use a different tokenizer
|
||||||
# than the one defined in the base model
|
# than the one defined in the base model
|
||||||
@@ -420,24 +419,23 @@ trust_remote_code:
|
|||||||
tokenizer_use_fast:
|
tokenizer_use_fast:
|
||||||
# Whether to use the legacy tokenizer setting, defaults to True
|
# Whether to use the legacy tokenizer setting, defaults to True
|
||||||
tokenizer_legacy:
|
tokenizer_legacy:
|
||||||
# Resize the model embeddings when new tokens are added to multiples of 32
|
# resize the model embeddings when new tokens are added to multiples of 32
|
||||||
# This is reported to improve training speed on some models
|
# this is reported to improve training speed on some models
|
||||||
resize_token_embeddings_to_32x:
|
resize_token_embeddings_to_32x:
|
||||||
|
|
||||||
# Used to identify which the model is based on
|
# used to identify which the model is based on
|
||||||
is_falcon_derived_model:
|
is_falcon_derived_model:
|
||||||
is_llama_derived_model:
|
is_llama_derived_model:
|
||||||
# Please note that if you set this to true, `padding_side` will be set to "left" by default
|
|
||||||
is_mistral_derived_model:
|
is_mistral_derived_model:
|
||||||
|
|
||||||
# Whether you are training a 4-bit GPTQ quantized model
|
# whether you are training a 4-bit GPTQ quantized model
|
||||||
gptq: true
|
gptq: true
|
||||||
gptq_groupsize: 128 # group size
|
gptq_groupsize: 128 # group size
|
||||||
gptq_model_v1: false # v1 or v2
|
gptq_model_v1: false # v1 or v2
|
||||||
|
|
||||||
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
# this will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
||||||
load_in_8bit: true
|
load_in_8bit: true
|
||||||
# Use bitsandbytes 4 bit
|
# use bitsandbytes 4 bit
|
||||||
load_in_4bit:
|
load_in_4bit:
|
||||||
|
|
||||||
# Use CUDA bf16
|
# Use CUDA bf16
|
||||||
@@ -451,9 +449,9 @@ tf32: true # require >=ampere
|
|||||||
bfloat16: true # require >=ampere
|
bfloat16: true # require >=ampere
|
||||||
float16: true
|
float16: true
|
||||||
|
|
||||||
# A list of one or more datasets to finetune the model with
|
# a list of one or more datasets to finetune the model with
|
||||||
datasets:
|
datasets:
|
||||||
# HuggingFace dataset repo | "json" for local dataset, make sure to fill data_files
|
# hf dataset repo | "json" for local dataset, make sure to fill data_files
|
||||||
- path: vicgalle/alpaca-gpt4
|
- path: vicgalle/alpaca-gpt4
|
||||||
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
||||||
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
||||||
@@ -463,18 +461,17 @@ datasets:
|
|||||||
name: # Optional[str] name of dataset configuration to load
|
name: # Optional[str] name of dataset configuration to load
|
||||||
conversation: # Optional[str] fastchat conversation type, only used with type: sharegpt
|
conversation: # Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||||
|
|
||||||
# Custom user prompt
|
# custom user prompt
|
||||||
- path: repo
|
- path: repo
|
||||||
type:
|
type:
|
||||||
# The below are defaults. only set what's needed.
|
# the below are defaults. only set what's needed.
|
||||||
system_prompt: ""
|
system_prompt: ""
|
||||||
system_format: "{system}"
|
|
||||||
field_system: system
|
field_system: system
|
||||||
field_instruction: instruction
|
field_instruction: instruction
|
||||||
field_input: input
|
field_output: input
|
||||||
field_output: output
|
|
||||||
|
|
||||||
# Customizable to be single line or multi-line
|
# customizable to be single line or multi-line
|
||||||
|
system_format: "{system}"
|
||||||
# 'format' can include {input}
|
# 'format' can include {input}
|
||||||
format: |-
|
format: |-
|
||||||
User: {instruction} {input}
|
User: {instruction} {input}
|
||||||
@@ -482,13 +479,13 @@ datasets:
|
|||||||
# 'no_input_format' cannot include {input}
|
# 'no_input_format' cannot include {input}
|
||||||
no_input_format: "{instruction} "
|
no_input_format: "{instruction} "
|
||||||
|
|
||||||
# For `completion` datsets only, uses the provided field instead of `text` column
|
# for completions datsets, uses the provided field if not `text`
|
||||||
field:
|
field:
|
||||||
|
|
||||||
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
# axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||||
# subsequent training attempts load faster, relative path
|
# subsequent training attempts load faster, relative path
|
||||||
dataset_prepared_path: data/last_run_prepared
|
dataset_prepared_path: data/last_run_prepared
|
||||||
# Push prepared dataset to hub
|
# push prepared dataset to hub
|
||||||
push_dataset_to_hub: # repo path
|
push_dataset_to_hub: # repo path
|
||||||
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
|
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
|
||||||
# if not set.
|
# if not set.
|
||||||
@@ -498,8 +495,8 @@ hub_model_id: # repo path to push finetuned model
|
|||||||
# how to push checkpoints to hub
|
# how to push checkpoints to hub
|
||||||
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
|
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
|
||||||
hub_strategy:
|
hub_strategy:
|
||||||
# Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
|
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
|
||||||
# Required to be true when used in combination with `push_dataset_to_hub`
|
# required to be true when used in combination with `push_dataset_to_hub`
|
||||||
hf_use_auth_token: # boolean
|
hf_use_auth_token: # boolean
|
||||||
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
|
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
|
||||||
val_set_size: 0.04
|
val_set_size: 0.04
|
||||||
@@ -508,38 +505,34 @@ dataset_shard_num:
|
|||||||
# Index of shard to use for whole dataset
|
# Index of shard to use for whole dataset
|
||||||
dataset_shard_idx:
|
dataset_shard_idx:
|
||||||
|
|
||||||
# The maximum length of an input to train with, this should typically be less than 2048
|
# the maximum length of an input to train with, this should typically be less than 2048
|
||||||
# as most models have a token/context limit of 2048
|
# as most models have a token/context limit of 2048
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
# Pad inputs so each step uses constant sized buffers
|
# pad inputs so each step uses constant sized buffers
|
||||||
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
|
# this will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
|
||||||
pad_to_sequence_len:
|
pad_to_sequence_len:
|
||||||
# Max sequence length to concatenate training samples together up to
|
# max sequence length to concatenate training samples together up to
|
||||||
# Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
||||||
# FutureWarning: This will soon be DEPRECATED
|
# FutureWarning: This will soon be DEPRECATED
|
||||||
max_packed_sequence_len: 1024
|
max_packed_sequence_len: 1024
|
||||||
# Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
|
# use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
|
||||||
sample_packing:
|
sample_packing:
|
||||||
# Set to 'false' if getting errors during eval with sample_packing on.
|
# set to 'false' if getting errors during eval with sample_packing on.
|
||||||
eval_sample_packing:
|
eval_sample_packing:
|
||||||
# You can set these packing optimizations AFTER starting a training at least once.
|
# you can set these packing optimizations AFTER starting a training at least once.
|
||||||
# The trainer will provide recommended values for these values.
|
# The trainer will provide recommended values for these values.
|
||||||
sample_packing_eff_est:
|
sample_packing_eff_est:
|
||||||
total_num_tokens:
|
total_num_tokens:
|
||||||
|
|
||||||
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||||
adapter: lora
|
adapter: lora
|
||||||
# If you already have a lora model trained that you want to load, put that here.
|
# if you already have a lora model trained that you want to load, put that here
|
||||||
# This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
|
# lora hyperparameters
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
|
lora_r: 8
|
||||||
# LoRA hyperparameters
|
lora_alpha: 16
|
||||||
# For more details about the following options, see:
|
lora_dropout: 0.05
|
||||||
# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
|
lora_target_modules:
|
||||||
peft_r: 8
|
|
||||||
peft_alpha: 16
|
|
||||||
peft_dropout: 0.05
|
|
||||||
peft_target_modules:
|
|
||||||
- q_proj
|
- q_proj
|
||||||
- v_proj
|
- v_proj
|
||||||
# - k_proj
|
# - k_proj
|
||||||
@@ -547,49 +540,36 @@ peft_target_modules:
|
|||||||
# - gate_proj
|
# - gate_proj
|
||||||
# - down_proj
|
# - down_proj
|
||||||
# - up_proj
|
# - up_proj
|
||||||
peft_target_linear: # if true, will target all linear layers
|
lora_target_linear: # if true, will target all linear layers
|
||||||
|
lora_modules_to_save:
|
||||||
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
|
|
||||||
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
|
|
||||||
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
|
|
||||||
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
|
|
||||||
peft_modules_to_save:
|
|
||||||
# - embed_tokens
|
# - embed_tokens
|
||||||
# - lm_head
|
# - lm_head
|
||||||
|
|
||||||
# Once you complete training, the model will be saved to the following directory.
|
|
||||||
# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
|
|
||||||
# Make sure `lora_model_dir` points to this directory if you want to use the trained model.
|
|
||||||
lora_out_dir:
|
lora_out_dir:
|
||||||
peft_fan_in_fan_out: false
|
lora_fan_in_fan_out: false
|
||||||
peft_feedforward_modules: # ffn modules for IA3, for llama down projection
|
|
||||||
|
|
||||||
# ReLoRA configuration
|
# ReLoRA configuration
|
||||||
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
|
# must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
|
||||||
relora_steps: # Number of steps per ReLoRA restart
|
relora_steps: # number of steps per ReLoRA restart
|
||||||
relora_warmup_steps: # Number of per-restart warmup steps
|
relora_warmup_steps: # number of per-restart warmup steps
|
||||||
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
relora_cpu_offload: # true to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
||||||
|
|
||||||
# wandb configuration if you're using it
|
# wandb configuration if you're using it
|
||||||
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
||||||
wandb_project: # Your wandb project name
|
wandb_project: # your wandb project name
|
||||||
wandb_entity: # A wandb Team name if using a Team
|
wandb_entity: # a wandb Team name if using a Team
|
||||||
wandb_watch:
|
wandb_watch:
|
||||||
wandb_run_id: # Set the name of your wandb run
|
wandb_run_id: # set the name of your wandb run
|
||||||
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
||||||
|
|
||||||
# Where to save the full-finetuned model to
|
# where to save the finished model to
|
||||||
output_dir: ./completed-model
|
output_dir: ./completed-model
|
||||||
|
|
||||||
# Whether to use torch.compile and which backend to use
|
# whether to use torch.compile and which backend to use
|
||||||
torch_compile: # bool
|
torch_compile: # bool
|
||||||
torch_compile_backend: # Optional[str]
|
torch_compile_backend: # Optional[str]
|
||||||
|
|
||||||
# Training hyperparameters
|
# training hyperparameters
|
||||||
|
|
||||||
# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
|
|
||||||
gradient_accumulation_steps: 1
|
gradient_accumulation_steps: 1
|
||||||
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
|
|
||||||
micro_batch_size: 2
|
micro_batch_size: 2
|
||||||
eval_batch_size:
|
eval_batch_size:
|
||||||
num_epochs: 3
|
num_epochs: 3
|
||||||
@@ -597,47 +577,44 @@ warmup_steps: 100
|
|||||||
learning_rate: 0.00003
|
learning_rate: 0.00003
|
||||||
lr_quadratic_warmup:
|
lr_quadratic_warmup:
|
||||||
logging_steps:
|
logging_steps:
|
||||||
save_strategy: # Set to `no` to skip checkpoint saves
|
save_strategy: # set to `no` to skip checkpoint saves
|
||||||
save_steps: # Leave empty to save at each epoch
|
save_steps: # leave empty to save at each epoch
|
||||||
eval_steps: # Leave empty to eval at each epoch
|
eval_steps: # leave empty to eval at each epoch
|
||||||
save_total_limit: # Checkpoints saved at a time
|
save_total_limit: # checkpoints saved at a time
|
||||||
# Maximum number of iterations to train for. It precedes num_epochs which means that
|
|
||||||
# if both are set, num_epochs will not be guaranteed.
|
|
||||||
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
|
|
||||||
max_steps:
|
max_steps:
|
||||||
|
|
||||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
eval_table_size: # approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||||
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
eval_table_max_new_tokens: # total number of tokens generated for predictions sent to wandb. Default is 128
|
||||||
|
|
||||||
# Save model as safetensors (require safetensors package)
|
# save model as safetensors (require safetensors package)
|
||||||
save_safetensors:
|
save_safetensors:
|
||||||
|
|
||||||
# Whether to mask out or include the human's prompt from the training labels
|
# whether to mask out or include the human's prompt from the training labels
|
||||||
train_on_inputs: false
|
train_on_inputs: false
|
||||||
# Group similarly sized data to minimize padding.
|
# group similarly sized data to minimize padding
|
||||||
# May be slower to start, as it must download and sort the entire dataset.
|
# may be slower to start, as it must download and sort the entire dataset
|
||||||
# Note that training loss may have an oscillating pattern with this enabled.
|
# note that training loss may have an oscillating pattern with this enabled
|
||||||
group_by_length: false
|
group_by_length: false
|
||||||
|
|
||||||
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||||
gradient_checkpointing: false
|
gradient_checkpointing: false
|
||||||
|
|
||||||
# Stop training after this many evaluation losses have increased in a row
|
# stop training after this many evaluation losses have increased in a row
|
||||||
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
||||||
early_stopping_patience: 3
|
early_stopping_patience: 3
|
||||||
|
|
||||||
# Specify a scheduler and kwargs to use with the optimizer
|
# specify a scheduler and kwargs to use with the optimizer
|
||||||
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
|
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
|
||||||
lr_scheduler_kwargs:
|
lr_scheduler_kwargs:
|
||||||
|
|
||||||
# For one_cycle optim
|
# for one_cycle optim
|
||||||
lr_div_factor: # Learning rate div factor
|
lr_div_factor: # learning rate div factor
|
||||||
|
|
||||||
# For log_sweep optim
|
# for log_sweep optim
|
||||||
log_sweep_min_lr:
|
log_sweep_min_lr:
|
||||||
log_sweep_max_lr:
|
log_sweep_max_lr:
|
||||||
|
|
||||||
# Specify optimizer
|
# specify optimizer
|
||||||
# Valid values are driven by the Transformers OptimizerNames class, see:
|
# Valid values are driven by the Transformers OptimizerNames class, see:
|
||||||
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
|
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
|
||||||
#
|
#
|
||||||
@@ -663,7 +640,7 @@ log_sweep_max_lr:
|
|||||||
# - paged_lion_32bit
|
# - paged_lion_32bit
|
||||||
# - paged_lion_8bit
|
# - paged_lion_8bit
|
||||||
optimizer:
|
optimizer:
|
||||||
# Specify weight decay
|
# specify weight decay
|
||||||
weight_decay:
|
weight_decay:
|
||||||
# adamw hyperparams
|
# adamw hyperparams
|
||||||
adam_beta1:
|
adam_beta1:
|
||||||
@@ -672,56 +649,49 @@ adam_epsilon:
|
|||||||
# Gradient clipping max norm
|
# Gradient clipping max norm
|
||||||
max_grad_norm:
|
max_grad_norm:
|
||||||
|
|
||||||
# Augmentation techniques
|
# whether to bettertransformers
|
||||||
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
|
|
||||||
# currently only supported on Llama and Mistral
|
|
||||||
noisy_embedding_alpha:
|
|
||||||
|
|
||||||
# Whether to bettertransformers
|
|
||||||
flash_optimum:
|
flash_optimum:
|
||||||
# Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
# whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||||
xformers_attention:
|
xformers_attention:
|
||||||
# Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
# whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
||||||
flash_attention:
|
flash_attention:
|
||||||
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
||||||
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
||||||
# Whether to use scaled-dot-product attention
|
# whether to use scaled-dot-product attention
|
||||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||||
sdp_attention:
|
sdp_attention:
|
||||||
# Landmark attention (only llama)
|
# Landmark attention (only llama)
|
||||||
landmark_attention:
|
landmark_attention:
|
||||||
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
||||||
# LLaMA only
|
# llama only
|
||||||
xpos_rope:
|
xpos_rope:
|
||||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||||
rope_scaling:
|
rope_scaling:
|
||||||
type: # linear | dynamic
|
type: # linear | dynamic
|
||||||
factor: # float
|
factor: # float
|
||||||
|
|
||||||
# Resume from a specific checkpoint dir
|
# resume from a specific checkpoint dir
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
|
# if resume_from_checkpoint isn't set and you simply want it to start where it left off
|
||||||
# Be careful with this being turned on between different models.
|
# be careful with this being turned on between different models
|
||||||
auto_resume_from_checkpoints: false
|
auto_resume_from_checkpoints: false
|
||||||
|
|
||||||
# Don't mess with this, it's here for accelerate and torchrun
|
# don't mess with this, it's here for accelerate and torchrun
|
||||||
local_rank:
|
local_rank:
|
||||||
|
|
||||||
# Add or change special tokens.
|
# add or change special tokens
|
||||||
# If you add tokens here, you don't need to add them to the `tokens` list.
|
|
||||||
special_tokens:
|
special_tokens:
|
||||||
# bos_token: "<s>"
|
# bos_token: "<s>"
|
||||||
# eos_token: "</s>"
|
# eos_token: "</s>"
|
||||||
# unk_token: "<unk>"
|
# unk_token: "<unk>"
|
||||||
|
# add extra tokens
|
||||||
# Add extra tokens.
|
|
||||||
tokens:
|
tokens:
|
||||||
|
|
||||||
# FSDP
|
# FSDP
|
||||||
fsdp:
|
fsdp:
|
||||||
fsdp_config:
|
fsdp_config:
|
||||||
|
|
||||||
# Deepspeed config path. e.g., deepspeed/zero3.json
|
# Deepspeed config path
|
||||||
deepspeed:
|
deepspeed:
|
||||||
|
|
||||||
# Advanced DDP Arguments
|
# Advanced DDP Arguments
|
||||||
@@ -747,66 +717,6 @@ strict:
|
|||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
<details>
|
|
||||||
<summary> Understanding of batch size and gradient accumulation steps </summary>
|
|
||||||
<br/>
|
|
||||||
Gradient accumulation means accumulating gradients over several mini-batches and updating the model weights afterward. When the samples in each batch are diverse, this technique doesn't significantly impact learning.
|
|
||||||
|
|
||||||
This method allows for effective training with larger effective batch sizes without needing proportionally larger memory. Here's why:
|
|
||||||
|
|
||||||
1. **Memory Consumption with Batch Size**: The primary reason increasing the batch size impacts memory is due to the storage requirements for intermediate activations. When you forward propagate a batch through a network, you have to store the activations at each layer for each sample in the batch, because these activations are used during backpropagation to compute gradients. Therefore, larger batches mean more activations, leading to greater GPU memory consumption.
|
|
||||||
|
|
||||||
2. **Gradient Accumulation**: With gradient accumulation, you're effectively simulating a larger batch size by accumulating gradients over several smaller batches (or micro-batches). However, at any given time, you're only forward and backward propagating a micro-batch. This means you only store activations for the micro-batch, not the full accumulated batch. As a result, you can simulate the effect of a larger batch size without the memory cost of storing activations for a large batch.
|
|
||||||
|
|
||||||
**Example 1:**
|
|
||||||
Micro batch size: 3
|
|
||||||
Gradient accumulation steps: 2
|
|
||||||
Number of GPUs: 3
|
|
||||||
Total batch size = 3 * 2 * 3 = 18
|
|
||||||
|
|
||||||
```
|
|
||||||
| GPU 1 | GPU 2 | GPU 3 |
|
|
||||||
|----------------|----------------|----------------|
|
|
||||||
| S1, S2, S3 | S4, S5, S6 | S7, S8, S9 |
|
|
||||||
| e1, e2, e3 | e4, e5, e6 | e7, e8, e9 |
|
|
||||||
|----------------|----------------|----------------|
|
|
||||||
| → (accumulate) | → (accumulate) | → (accumulate) |
|
|
||||||
|----------------|----------------|----------------|
|
|
||||||
| S10, S11, S12 | S13, S14, S15 | S16, S17, S18 |
|
|
||||||
| e10, e11, e12 | e13, e14, e15 | e16, e17, e18 |
|
|
||||||
|----------------|----------------|----------------|
|
|
||||||
| → (apply) | → (apply) | → (apply) |
|
|
||||||
|
|
||||||
Accumulated gradient for the weight w1 after the second iteration (considering all GPUs):
|
|
||||||
Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6 + e7 + e8 + e9 + e10 + e11 + e12 + e13 + e14 + e15 + e16 + e17 + e18
|
|
||||||
|
|
||||||
Weight update for w1:
|
|
||||||
w1_new = w1_old - learning rate x (Total gradient for w1 / 18)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Example 2:**
|
|
||||||
Micro batch size: 2
|
|
||||||
Gradient accumulation steps: 1
|
|
||||||
Number of GPUs: 3
|
|
||||||
Total batch size = 2 * 1 * 3 = 6
|
|
||||||
|
|
||||||
```
|
|
||||||
| GPU 1 | GPU 2 | GPU 3 |
|
|
||||||
|-----------|-----------|-----------|
|
|
||||||
| S1, S2 | S3, S4 | S5, S6 |
|
|
||||||
| e1, e2 | e3, e4 | e5, e6 |
|
|
||||||
|-----------|-----------|-----------|
|
|
||||||
| → (apply) | → (apply) | → (apply) |
|
|
||||||
|
|
||||||
Accumulated gradient for the weight w1 (considering all GPUs):
|
|
||||||
Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6
|
|
||||||
|
|
||||||
Weight update for w1:
|
|
||||||
w1_new = w1_old - learning rate × (Total gradient for w1 / 6)
|
|
||||||
```
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
### Train
|
### Train
|
||||||
|
|
||||||
Run
|
Run
|
||||||
@@ -870,7 +780,7 @@ Pass the appropriate flag to the train command:
|
|||||||
|
|
||||||
- Pretrained LORA:
|
- Pretrained LORA:
|
||||||
```bash
|
```bash
|
||||||
python -m axolotl.cli.inference examples/your_config.yml --peft_model_dir="./lora-output-dir"
|
python -m axolotl.cli.inference examples/your_config.yml --lora_model_dir="./lora-output-dir"
|
||||||
```
|
```
|
||||||
- Full weights finetune:
|
- Full weights finetune:
|
||||||
```bash
|
```bash
|
||||||
@@ -882,16 +792,12 @@ Pass the appropriate flag to the train command:
|
|||||||
--base_model="./completed-model" --prompter=None --load_in_8bit=True
|
--base_model="./completed-model" --prompter=None --load_in_8bit=True
|
||||||
```
|
```
|
||||||
|
|
||||||
Please use `--sample_packing False` if you have it on and receive the error similar to below:
|
|
||||||
|
|
||||||
> RuntimeError: stack expects each tensor to be equal size, but got [1, 32, 1, 128] at entry 0 and [1, 32, 8, 128] at entry 1
|
|
||||||
|
|
||||||
### Merge LORA to base
|
### Merge LORA to base
|
||||||
|
|
||||||
Add below flag to train command above
|
Add below flag to train command above
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python3 -m axolotl.cli.merge_lora examples/your_config.yml --peft_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
|
python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
|
||||||
```
|
```
|
||||||
|
|
||||||
If you run out of CUDA memory, you can try to merge in system RAM with
|
If you run out of CUDA memory, you can try to merge in system RAM with
|
||||||
|
|||||||
@@ -5,9 +5,6 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
|||||||
ARG AXOLOTL_EXTRAS=""
|
ARG AXOLOTL_EXTRAS=""
|
||||||
ARG CUDA="118"
|
ARG CUDA="118"
|
||||||
ENV BNB_CUDA_VERSION=$CUDA
|
ENV BNB_CUDA_VERSION=$CUDA
|
||||||
ARG PYTORCH_VERSION="2.0.1"
|
|
||||||
|
|
||||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
|
||||||
|
|
||||||
RUN apt-get update && \
|
RUN apt-get update && \
|
||||||
apt-get install -y vim curl
|
apt-get install -y vim curl
|
||||||
@@ -19,7 +16,6 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
|||||||
WORKDIR /workspace/axolotl
|
WORKDIR /workspace/axolotl
|
||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
|
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
|
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
|
||||||
else \
|
else \
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ ARG CUDA="118"
|
|||||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||||
|
|
||||||
RUN apt-get update \
|
RUN apt-get update \
|
||||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
|
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/*
|
||||||
&& wget \
|
&& wget \
|
||||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||||
&& mkdir /root/.conda \
|
&& mkdir /root/.conda \
|
||||||
@@ -57,6 +57,11 @@ FROM base-builder
|
|||||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||||
|
|
||||||
|
# recompile apex
|
||||||
|
RUN python3 -m pip uninstall -y apex
|
||||||
|
RUN git clone https://github.com/NVIDIA/apex
|
||||||
|
RUN cd apex && python3 -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
|
||||||
|
|
||||||
RUN mkdir -p /workspace/builds
|
RUN mkdir -p /workspace/builds
|
||||||
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
|
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
|
||||||
|
|
||||||
|
|||||||
@@ -1,51 +0,0 @@
|
|||||||
# Multipack
|
|
||||||
|
|
||||||
4k context, bsz =4,
|
|
||||||
each character represents 256 tokens
|
|
||||||
X represents a padding token
|
|
||||||
|
|
||||||
```
|
|
||||||
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
|
||||||
[[ A A A A A A A A A A A ]
|
|
||||||
B B B B B B ]
|
|
||||||
C C C C C C C ]
|
|
||||||
D D D D ]]
|
|
||||||
|
|
||||||
[[ E E E E E E E E ]
|
|
||||||
[ F F F F ]
|
|
||||||
[ G G G ]
|
|
||||||
[ H H H H ]]
|
|
||||||
|
|
||||||
[[ I I I ]
|
|
||||||
[ J J J ]
|
|
||||||
[ K K K K K]
|
|
||||||
[ L L L ]]
|
|
||||||
```
|
|
||||||
|
|
||||||
after padding to longest input in each step
|
|
||||||
```
|
|
||||||
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
|
||||||
[[ A A A A A A A A A A A ]
|
|
||||||
B B B B B B X X X X X X ]
|
|
||||||
C C C C C C C X X X X ]
|
|
||||||
D D D D X X X X X X X ]]
|
|
||||||
|
|
||||||
[[ E E E E E E E E ]
|
|
||||||
[ F F F F X X X X ]
|
|
||||||
[ G G G X X X X X ]
|
|
||||||
[ H H H H X X X X ]]
|
|
||||||
|
|
||||||
[[ I I I X X ]
|
|
||||||
[ J J J X X ]
|
|
||||||
[ K K K K K ]
|
|
||||||
[ L L L X X ]]
|
|
||||||
```
|
|
||||||
|
|
||||||
w packing ( note it's the same effective number of tokens per step, but a true bsz of 1)
|
|
||||||
```
|
|
||||||
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
|
||||||
[[ A A A A A A A A A A A B B B B B
|
|
||||||
B C C C C C C C D D D D E E E E
|
|
||||||
E E E E F F F F F G G G H H H H
|
|
||||||
I I I J J J J K K K K K L L L X ]]
|
|
||||||
```
|
|
||||||
@@ -18,7 +18,7 @@ dataset_prepared_path: last_prepared_run
|
|||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
|
|
||||||
adapter:
|
adapter:
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
max_packed_sequence_len:
|
max_packed_sequence_len:
|
||||||
sample_packing: false
|
sample_packing: false
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
max_packed_sequence_len: 2048
|
max_packed_sequence_len: 2048
|
||||||
lora_r: 16
|
lora_r: 16
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ sample_packing: true
|
|||||||
pad_to_sequence_len: true
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ val_set_size: 0.01
|
|||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ sample_packing: true
|
|||||||
pad_to_sequence_len: true
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ val_set_size: 0.01
|
|||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ sample_packing: true
|
|||||||
pad_to_sequence_len: true
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ val_set_size: 0.01
|
|||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
adapter: lora
|
adapter: lora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
max_packed_sequence_len:
|
max_packed_sequence_len:
|
||||||
lora_r: 16
|
lora_r: 16
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ dataset_prepared_path:
|
|||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
# enable QLoRA
|
# enable QLoRA
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
max_packed_sequence_len:
|
max_packed_sequence_len:
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
adapter:
|
adapter:
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
max_packed_sequence_len:
|
max_packed_sequence_len:
|
||||||
lora_r: 64
|
lora_r: 64
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
max_packed_sequence_len:
|
max_packed_sequence_len:
|
||||||
lora_r: 8
|
lora_r: 8
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.02
|
val_set_size: 0.02
|
||||||
adapter:
|
adapter:
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 512
|
sequence_len: 512
|
||||||
max_packed_sequence_len:
|
max_packed_sequence_len:
|
||||||
lora_r:
|
lora_r:
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
adapter: lora
|
adapter: lora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing:
|
sample_packing:
|
||||||
lora_r: 8
|
lora_r: 8
|
||||||
|
|||||||
@@ -1,72 +0,0 @@
|
|||||||
base_model: meta-llama/Llama-2-7b-hf
|
|
||||||
base_model_config: meta-llama/Llama-2-7b-hf
|
|
||||||
model_type: LlamaForCausalLM
|
|
||||||
tokenizer_type: LlamaTokenizer
|
|
||||||
is_llama_derived_model: true
|
|
||||||
|
|
||||||
load_in_8bit: true
|
|
||||||
load_in_4bit: false
|
|
||||||
strict: false
|
|
||||||
|
|
||||||
datasets:
|
|
||||||
- path: mhenrichsen/alpaca_2k_test
|
|
||||||
type: alpaca
|
|
||||||
dataset_prepared_path: last_run_prepared
|
|
||||||
val_set_size: 0.01
|
|
||||||
output_dir: ./ia3-out
|
|
||||||
|
|
||||||
sequence_len: 4096
|
|
||||||
sample_packing: true
|
|
||||||
pad_to_sequence_len: true
|
|
||||||
|
|
||||||
adapter: ia3
|
|
||||||
peft_model_dir:
|
|
||||||
peft_target_modules:
|
|
||||||
- k_proj
|
|
||||||
- v_proj
|
|
||||||
- down_proj
|
|
||||||
peft_feedforward_modules:
|
|
||||||
- down_proj
|
|
||||||
peft_fan_in_fan_out: false
|
|
||||||
|
|
||||||
wandb_project:
|
|
||||||
wandb_entity:
|
|
||||||
wandb_watch:
|
|
||||||
wandb_run_id:
|
|
||||||
wandb_log_model:
|
|
||||||
|
|
||||||
gradient_accumulation_steps: 1
|
|
||||||
micro_batch_size: 2
|
|
||||||
num_epochs: 5
|
|
||||||
optimizer: adamw_bnb_8bit
|
|
||||||
lr_scheduler: cosine
|
|
||||||
learning_rate: 0.0002
|
|
||||||
|
|
||||||
train_on_inputs: false
|
|
||||||
group_by_length: false
|
|
||||||
bf16: true
|
|
||||||
fp16: false
|
|
||||||
tf32: false
|
|
||||||
|
|
||||||
gradient_checkpointing: true
|
|
||||||
early_stopping_patience:
|
|
||||||
resume_from_checkpoint:
|
|
||||||
local_rank:
|
|
||||||
logging_steps: 1
|
|
||||||
xformers_attention:
|
|
||||||
flash_attention: true
|
|
||||||
|
|
||||||
warmup_steps: 10
|
|
||||||
eval_steps: 0.05
|
|
||||||
eval_table_size:
|
|
||||||
eval_table_max_new_tokens:
|
|
||||||
save_steps:
|
|
||||||
debug:
|
|
||||||
deepspeed:
|
|
||||||
weight_decay: 0.1
|
|
||||||
fsdp:
|
|
||||||
fsdp_config:
|
|
||||||
special_tokens:
|
|
||||||
bos_token: "<s>"
|
|
||||||
eos_token: "</s>"
|
|
||||||
unk_token: "<unk>"
|
|
||||||
@@ -20,7 +20,7 @@ sample_packing: true
|
|||||||
pad_to_sequence_len: true
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ val_set_size: 0.01
|
|||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ val_set_size: 0.01
|
|||||||
output_dir: ./relora-out
|
output_dir: ./relora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ sequence_len: 4096
|
|||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
|
|||||||
@@ -16,8 +16,8 @@ val_set_size: 0.01
|
|||||||
output_dir: ./out
|
output_dir: ./out
|
||||||
|
|
||||||
sequence_len: 8192
|
sequence_len: 8192
|
||||||
sample_packing: true
|
sample_packing:
|
||||||
pad_to_sequence_len: true
|
pad_to_sequence_len:
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
@@ -30,7 +30,7 @@ micro_batch_size: 2
|
|||||||
num_epochs: 3
|
num_epochs: 3
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
learning_rate: 0.000005
|
learning_rate: 0.0002
|
||||||
|
|
||||||
train_on_inputs: false
|
train_on_inputs: false
|
||||||
group_by_length: false
|
group_by_length: false
|
||||||
|
|||||||
@@ -19,8 +19,8 @@ adapter: qlora
|
|||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|
||||||
sequence_len: 8192
|
sequence_len: 8192
|
||||||
sample_packing: true
|
sample_packing: True
|
||||||
pad_to_sequence_len: true
|
pad_to_sequence_len: True
|
||||||
|
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
@@ -43,7 +43,7 @@ wandb_run_id:
|
|||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
|
|
||||||
gradient_accumulation_steps: 4
|
gradient_accumulation_steps: 4
|
||||||
micro_batch_size: 2
|
micro_batch_size: 4
|
||||||
num_epochs: 1
|
num_epochs: 1
|
||||||
optimizer: adamw_bnb_8bit
|
optimizer: adamw_bnb_8bit
|
||||||
lr_scheduler: cosine
|
lr_scheduler: cosine
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.02
|
val_set_size: 0.02
|
||||||
adapter:
|
adapter:
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
max_packed_sequence_len:
|
max_packed_sequence_len:
|
||||||
lora_r: 8
|
lora_r: 8
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.02
|
val_set_size: 0.02
|
||||||
adapter:
|
adapter:
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 1024
|
sequence_len: 1024
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
lora_r:
|
lora_r:
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.02
|
val_set_size: 0.02
|
||||||
adapter: lora
|
adapter: lora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 1024
|
sequence_len: 1024
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
lora_r: 8
|
lora_r: 8
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 1024
|
sequence_len: 1024
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
lora_r: 8
|
lora_r: 8
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ sample_packing: true
|
|||||||
pad_to_sequence_len:
|
pad_to_sequence_len:
|
||||||
|
|
||||||
adapter:
|
adapter:
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
lora_r:
|
lora_r:
|
||||||
lora_alpha:
|
lora_alpha:
|
||||||
lora_dropout:
|
lora_dropout:
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ sample_packing: false # not CURRENTLY compatible with LoRAs
|
|||||||
pad_to_sequence_len:
|
pad_to_sequence_len:
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
lora_r: 64
|
lora_r: 64
|
||||||
lora_alpha: 32
|
lora_alpha: 32
|
||||||
lora_dropout: 0.05
|
lora_dropout: 0.05
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
adapter:
|
adapter:
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
max_packed_sequence_len: 2048
|
max_packed_sequence_len: 2048
|
||||||
lora_r: 64
|
lora_r: 64
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
adapter: lora
|
adapter: lora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 512
|
sequence_len: 512
|
||||||
lora_r: 16
|
lora_r: 16
|
||||||
lora_alpha: 32
|
lora_alpha: 32
|
||||||
@@ -28,8 +28,8 @@ num_epochs: 3
|
|||||||
learning_rate: 0.00001
|
learning_rate: 0.00001
|
||||||
train_on_inputs: false
|
train_on_inputs: false
|
||||||
group_by_length: false
|
group_by_length: false
|
||||||
bf16: true
|
bf16: True
|
||||||
tf32: true
|
tf32: True
|
||||||
early_stopping_patience:
|
early_stopping_patience:
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
local_rank:
|
local_rank:
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.02
|
val_set_size: 0.02
|
||||||
adapter:
|
adapter:
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
max_packed_sequence_len:
|
max_packed_sequence_len:
|
||||||
lora_r: 8
|
lora_r: 8
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ datasets:
|
|||||||
dataset_prepared_path:
|
dataset_prepared_path:
|
||||||
val_set_size: 0.05
|
val_set_size: 0.05
|
||||||
adapter: lora
|
adapter: lora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 2048
|
sequence_len: 2048
|
||||||
max_packed_sequence_len:
|
max_packed_sequence_len:
|
||||||
lora_r: 8
|
lora_r: 8
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ dataset_prepared_path:
|
|||||||
val_set_size: 0.01
|
val_set_size: 0.01
|
||||||
# enable QLoRA
|
# enable QLoRA
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
peft_model_dir:
|
lora_model_dir:
|
||||||
sequence_len: 8192
|
sequence_len: 8192
|
||||||
max_packed_sequence_len:
|
max_packed_sequence_len:
|
||||||
|
|
||||||
|
|||||||
Binary file not shown.
|
Before Width: | Height: | Size: 370 KiB |
@@ -16,7 +16,7 @@ flash-attn>=2.3.0
|
|||||||
sentencepiece
|
sentencepiece
|
||||||
wandb
|
wandb
|
||||||
einops
|
einops
|
||||||
xformers>=0.0.22
|
xformers
|
||||||
optimum
|
optimum
|
||||||
hf_transfer
|
hf_transfer
|
||||||
colorama
|
colorama
|
||||||
|
|||||||
10
setup.py
10
setup.py
@@ -21,14 +21,6 @@ def parse_requirements():
|
|||||||
):
|
):
|
||||||
# Handle standard packages
|
# Handle standard packages
|
||||||
_install_requires.append(line)
|
_install_requires.append(line)
|
||||||
|
|
||||||
# TODO(wing) remove once xformers release supports torch 2.1.0
|
|
||||||
if "torch==2.1.0" in _install_requires:
|
|
||||||
_install_requires.pop(_install_requires.index("xformers>=0.0.22"))
|
|
||||||
_install_requires.append(
|
|
||||||
"xformers @ git+https://github.com/facebookresearch/xformers.git@main"
|
|
||||||
)
|
|
||||||
|
|
||||||
return _install_requires, _dependency_links
|
return _install_requires, _dependency_links
|
||||||
|
|
||||||
|
|
||||||
@@ -46,7 +38,7 @@ setup(
|
|||||||
dependency_links=dependency_links,
|
dependency_links=dependency_links,
|
||||||
extras_require={
|
extras_require={
|
||||||
"flash-attn": [
|
"flash-attn": [
|
||||||
"flash-attn>=2.3.0",
|
"flash-attn>=2.2.1",
|
||||||
],
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed",
|
"deepspeed",
|
||||||
|
|||||||
@@ -194,7 +194,6 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
|||||||
# load the config from the yaml file
|
# load the config from the yaml file
|
||||||
with open(config, encoding="utf-8") as file:
|
with open(config, encoding="utf-8") as file:
|
||||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||||
cfg.axolotl_config_path = config
|
|
||||||
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
||||||
# then overwrite the value
|
# then overwrite the value
|
||||||
cfg_keys = cfg.keys()
|
cfg_keys = cfg.keys()
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
print_axolotl_text_art()
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
parsed_cfg.sample_packing = False
|
|
||||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
return_remaining_strings=True
|
return_remaining_strings=True
|
||||||
|
|||||||
@@ -1,12 +1,10 @@
|
|||||||
"""
|
"""
|
||||||
CLI to run training on a model
|
CLI to run training on a model
|
||||||
"""
|
"""
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import transformers
|
import transformers
|
||||||
from colorama import Fore
|
|
||||||
|
|
||||||
from axolotl.cli import (
|
from axolotl.cli import (
|
||||||
check_accelerate_default_config,
|
check_accelerate_default_config,
|
||||||
@@ -16,11 +14,8 @@ from axolotl.cli import (
|
|||||||
print_axolotl_text_art,
|
print_axolotl_text_art,
|
||||||
)
|
)
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
|
||||||
from axolotl.train import train
|
from axolotl.train import train
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.cli.train")
|
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
@@ -32,14 +27,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||||
return_remaining_strings=True
|
return_remaining_strings=True
|
||||||
)
|
)
|
||||||
if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
|
|
||||||
msg = (
|
|
||||||
Fore.RED
|
|
||||||
+ "--prepare_ds_only called without dataset_prepared_path set."
|
|
||||||
+ Fore.RESET
|
|
||||||
)
|
|
||||||
LOG.warning(msg)
|
|
||||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
|
||||||
|
|
||||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
if parsed_cli_args.prepare_ds_only:
|
if parsed_cli_args.prepare_ds_only:
|
||||||
|
|||||||
@@ -1,5 +0,0 @@
|
|||||||
"""
|
|
||||||
Various shared constants
|
|
||||||
"""
|
|
||||||
|
|
||||||
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
|
||||||
@@ -5,7 +5,7 @@ import os
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from datasets import Dataset, IterableDataset
|
from datasets import Dataset, IterableDataset, Sequence, Value
|
||||||
|
|
||||||
from .prompt_tokenizers import PromptTokenizingStrategy
|
from .prompt_tokenizers import PromptTokenizingStrategy
|
||||||
|
|
||||||
@@ -42,11 +42,15 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
if self.prompt_tokenizer.supports_batched:
|
if self.prompt_tokenizer.supports_batched:
|
||||||
map_kwargs["batched"] = True
|
map_kwargs["batched"] = True
|
||||||
map_kwargs["batch_size"] = 100
|
map_kwargs["batch_size"] = 100
|
||||||
return dataset.map(
|
return (
|
||||||
self.prompt_tokenizer.tokenize_prompt,
|
dataset.map(
|
||||||
num_proc=num_proc,
|
self.prompt_tokenizer.tokenize_prompt,
|
||||||
remove_columns=features,
|
num_proc=num_proc,
|
||||||
**map_kwargs,
|
remove_columns=features,
|
||||||
|
**map_kwargs,
|
||||||
|
)
|
||||||
|
.cast_column("input_ids", Sequence(feature=Value(dtype="int32", id=None)))
|
||||||
|
.cast_column("labels", Sequence(feature=Value(dtype="int32", id=None)))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -116,8 +116,6 @@ def flashattn_forward(
|
|||||||
attention_mask: [bsz, q_len]
|
attention_mask: [bsz, q_len]
|
||||||
"""
|
"""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
original_dtype = hidden_states.dtype
|
|
||||||
|
|
||||||
bsz, q_len, _ = hidden_states.size()
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
if not hasattr(self, "pretraining_tp"):
|
if not hasattr(self, "pretraining_tp"):
|
||||||
@@ -153,13 +151,6 @@ def flashattn_forward(
|
|||||||
key_states = self.k_proj(hidden_states)
|
key_states = self.k_proj(hidden_states)
|
||||||
value_states = self.v_proj(hidden_states)
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
if query_states.dtype == torch.float32:
|
|
||||||
query_states = query_states.to(dtype=original_dtype)
|
|
||||||
if key_states.dtype == torch.float32:
|
|
||||||
key_states = key_states.to(dtype=original_dtype)
|
|
||||||
if value_states.dtype == torch.float32:
|
|
||||||
value_states = value_states.to(dtype=original_dtype)
|
|
||||||
|
|
||||||
query_states = query_states.view(
|
query_states = query_states.view(
|
||||||
bsz, q_len, self.num_heads, self.head_dim
|
bsz, q_len, self.num_heads, self.head_dim
|
||||||
).transpose(1, 2)
|
).transpose(1, 2)
|
||||||
@@ -318,10 +309,6 @@ def flashattn_forward(
|
|||||||
else:
|
else:
|
||||||
attn_output = self.o_proj(attn_output)
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
# handle conversion back for IA3
|
|
||||||
if attn_output.dtype == torch.float32:
|
|
||||||
attn_output = attn_output.to(dtype=original_dtype)
|
|
||||||
|
|
||||||
return attn_output, None, past_key_value
|
return attn_output, None, past_key_value
|
||||||
|
|
||||||
|
|
||||||
@@ -515,7 +502,6 @@ def llama_model_forward(
|
|||||||
)
|
)
|
||||||
|
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
original_dtype = hidden_states.dtype
|
|
||||||
|
|
||||||
if self.gradient_checkpointing and self.training:
|
if self.gradient_checkpointing and self.training:
|
||||||
if use_cache:
|
if use_cache:
|
||||||
@@ -573,10 +559,6 @@ def llama_model_forward(
|
|||||||
|
|
||||||
hidden_states = layer_outputs[0]
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
# handle conversion back for IA3
|
|
||||||
if hidden_states.dtype == torch.float32:
|
|
||||||
hidden_states = hidden_states.to(dtype=original_dtype)
|
|
||||||
|
|
||||||
if use_cache:
|
if use_cache:
|
||||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||||
|
|
||||||
|
|||||||
@@ -1,40 +0,0 @@
|
|||||||
"""
|
|
||||||
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers.models.llama.modeling_llama
|
|
||||||
from transformers.utils import logging
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
def noised_embed(orig_embed, noise_alpha, model):
|
|
||||||
def new_func(input_ids):
|
|
||||||
# during training, we add noise to the embedding
|
|
||||||
# during generation, we don't add noise to the embedding
|
|
||||||
if model.training:
|
|
||||||
embed_init = orig_embed(input_ids)
|
|
||||||
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
|
|
||||||
mag_norm = noise_alpha / torch.sqrt(dims)
|
|
||||||
return embed_init + torch.zeros_like(embed_init).uniform_(
|
|
||||||
-mag_norm, mag_norm
|
|
||||||
)
|
|
||||||
return orig_embed(input_ids)
|
|
||||||
|
|
||||||
return new_func
|
|
||||||
|
|
||||||
def post_init(orig_post_init):
|
|
||||||
def new_func(self):
|
|
||||||
orig_post_init(self)
|
|
||||||
self.embed_tokens.forward = noised_embed(
|
|
||||||
self.embed_tokens.forward, noise_alpha, self
|
|
||||||
)
|
|
||||||
|
|
||||||
return new_func
|
|
||||||
|
|
||||||
transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(
|
|
||||||
transformers.models.llama.modeling_llama.LlamaModel.post_init
|
|
||||||
)
|
|
||||||
@@ -14,9 +14,6 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
|
|||||||
flash_attn_varlen_qkvpacked_func,
|
flash_attn_varlen_qkvpacked_func,
|
||||||
)
|
)
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
from transformers.models.mistral.modeling_mistral import (
|
|
||||||
MistralAttention as OriginalMistralAttention,
|
|
||||||
)
|
|
||||||
from transformers.models.mistral.modeling_mistral import (
|
from transformers.models.mistral.modeling_mistral import (
|
||||||
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||||
)
|
)
|
||||||
@@ -45,44 +42,6 @@ def replace_mistral_attn_with_flash_attn(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@torch.jit.script
|
|
||||||
def _make_sliding_window_causal_mask(
|
|
||||||
bsz: int,
|
|
||||||
tgt_len: int,
|
|
||||||
dtype: torch.dtype,
|
|
||||||
device: torch.device,
|
|
||||||
past_key_values_length: int = 0,
|
|
||||||
sliding_window: int = 4096,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Make causal mask used for sliding window attention
|
|
||||||
"""
|
|
||||||
tensor = torch.full(
|
|
||||||
(tgt_len, tgt_len),
|
|
||||||
fill_value=1,
|
|
||||||
device=device,
|
|
||||||
)
|
|
||||||
mask = torch.tril(tensor, diagonal=0)
|
|
||||||
# make the mask banded to account for sliding window
|
|
||||||
# NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
|
|
||||||
mask = torch.triu(mask, diagonal=-sliding_window + 1)
|
|
||||||
mask = torch.log(mask).to(dtype)
|
|
||||||
|
|
||||||
if past_key_values_length > 0:
|
|
||||||
mask = torch.cat(
|
|
||||||
[
|
|
||||||
torch.zeros(
|
|
||||||
tgt_len, past_key_values_length, dtype=dtype, device=device
|
|
||||||
),
|
|
||||||
mask,
|
|
||||||
],
|
|
||||||
dim=-1,
|
|
||||||
)
|
|
||||||
return mask[None, None, :, :].expand(
|
|
||||||
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||||
# requires the attention mask to be the same as the key_padding_mask
|
# requires the attention mask to be the same as the key_padding_mask
|
||||||
def _prepare_decoder_attention_mask(
|
def _prepare_decoder_attention_mask(
|
||||||
@@ -94,29 +53,11 @@ def _prepare_decoder_attention_mask(
|
|||||||
sliding_window,
|
sliding_window,
|
||||||
): # pylint: disable=unused-argument
|
): # pylint: disable=unused-argument
|
||||||
# [bsz, seq_len]
|
# [bsz, seq_len]
|
||||||
if attention_mask is None:
|
|
||||||
return attention_mask
|
|
||||||
|
|
||||||
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
|
|
||||||
# Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
|
|
||||||
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
|
|
||||||
sliding_window_mask = _make_sliding_window_causal_mask(
|
|
||||||
bsz=input_shape[0],
|
|
||||||
tgt_len=input_shape[1],
|
|
||||||
dtype=inputs_embeds.dtype,
|
|
||||||
device=inputs_embeds.device,
|
|
||||||
past_key_values_length=past_key_values_length,
|
|
||||||
sliding_window=sliding_window,
|
|
||||||
)
|
|
||||||
attention_mask = attention_mask + sliding_window_mask
|
|
||||||
else:
|
|
||||||
LOG.info("skipping sliding window mask, not broadcastable with attention mask")
|
|
||||||
|
|
||||||
return attention_mask
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
def flashattn_forward(
|
def flashattn_forward(
|
||||||
self: OriginalMistralAttention,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
@@ -150,41 +91,10 @@ def flashattn_forward(
|
|||||||
query_states, key_states, cos, sin, position_ids
|
query_states, key_states, cos, sin, position_ids
|
||||||
)
|
)
|
||||||
|
|
||||||
use_sliding_windows = (
|
|
||||||
hasattr(self.config, "sliding_window") is not None
|
|
||||||
and kv_seq_len > self.config.sliding_window
|
|
||||||
)
|
|
||||||
|
|
||||||
if use_sliding_windows:
|
|
||||||
window_size = (self.config.sliding_window, self.config.sliding_window)
|
|
||||||
else:
|
|
||||||
window_size = (-1, -1)
|
|
||||||
|
|
||||||
if past_key_value is not None:
|
if past_key_value is not None:
|
||||||
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
# reuse k, v, self_attention
|
||||||
if (
|
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||||
hasattr(self.config, "sliding_window")
|
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||||
and kv_seq_len > self.config.sliding_window
|
|
||||||
):
|
|
||||||
slicing_tokens = kv_seq_len - self.config.sliding_window
|
|
||||||
|
|
||||||
past_key = past_key_value[0]
|
|
||||||
past_value = past_key_value[1]
|
|
||||||
|
|
||||||
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
|
||||||
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
|
||||||
|
|
||||||
if past_key.shape[-2] != self.config.sliding_window - 1:
|
|
||||||
raise ValueError(
|
|
||||||
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
|
|
||||||
f" {past_key.shape}"
|
|
||||||
)
|
|
||||||
|
|
||||||
past_key_value = (past_key, past_value) if use_cache else None
|
|
||||||
|
|
||||||
if past_key_value is not None:
|
|
||||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
|
||||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
|
||||||
|
|
||||||
past_key_value = (key_states, value_states) if use_cache else None
|
past_key_value = (key_states, value_states) if use_cache else None
|
||||||
|
|
||||||
@@ -210,13 +120,7 @@ def flashattn_forward(
|
|||||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
|
|
||||||
output = flash_attn_varlen_qkvpacked_func(
|
output = flash_attn_varlen_qkvpacked_func(
|
||||||
qkv,
|
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
|
||||||
cu_seqlens,
|
|
||||||
max_seqlen,
|
|
||||||
0.0,
|
|
||||||
softmax_scale=None,
|
|
||||||
causal=True,
|
|
||||||
window_size=window_size,
|
|
||||||
)
|
)
|
||||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||||
elif query_states.shape == key_states.shape:
|
elif query_states.shape == key_states.shape:
|
||||||
@@ -242,7 +146,6 @@ def flashattn_forward(
|
|||||||
0.0,
|
0.0,
|
||||||
softmax_scale=None,
|
softmax_scale=None,
|
||||||
causal=is_causal,
|
causal=is_causal,
|
||||||
window_size=window_size,
|
|
||||||
)
|
)
|
||||||
output = output_pad_fn(output_unpad)
|
output = output_pad_fn(output_unpad)
|
||||||
else:
|
else:
|
||||||
@@ -254,7 +157,6 @@ def flashattn_forward(
|
|||||||
query_states,
|
query_states,
|
||||||
torch.stack([key_states, value_states], 2),
|
torch.stack([key_states, value_states], 2),
|
||||||
causal=is_causal,
|
causal=is_causal,
|
||||||
window_size=window_size,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
( # pylint: disable=unbalanced-tuple-unpacking
|
( # pylint: disable=unbalanced-tuple-unpacking
|
||||||
@@ -289,7 +191,6 @@ def flashattn_forward(
|
|||||||
0.0,
|
0.0,
|
||||||
softmax_scale=None,
|
softmax_scale=None,
|
||||||
causal=is_causal,
|
causal=is_causal,
|
||||||
window_size=window_size,
|
|
||||||
)
|
)
|
||||||
output = output_pad_fn(output_unpad)
|
output = output_pad_fn(output_unpad)
|
||||||
|
|
||||||
|
|||||||
@@ -1,40 +0,0 @@
|
|||||||
"""
|
|
||||||
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
|
|
||||||
"""
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import transformers.models.mistral.modeling_mistral
|
|
||||||
from transformers.utils import logging
|
|
||||||
|
|
||||||
logger = logging.get_logger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
def noised_embed(orig_embed, noise_alpha, model):
|
|
||||||
def new_func(input_ids):
|
|
||||||
# during training, we add noise to the embedding
|
|
||||||
# during generation, we don't add noise to the embedding
|
|
||||||
if model.training:
|
|
||||||
embed_init = orig_embed(input_ids)
|
|
||||||
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
|
|
||||||
mag_norm = noise_alpha / torch.sqrt(dims)
|
|
||||||
return embed_init + torch.zeros_like(embed_init).uniform_(
|
|
||||||
-mag_norm, mag_norm
|
|
||||||
)
|
|
||||||
return orig_embed(input_ids)
|
|
||||||
|
|
||||||
return new_func
|
|
||||||
|
|
||||||
def post_init(orig_post_init):
|
|
||||||
def new_func(self):
|
|
||||||
orig_post_init(self)
|
|
||||||
self.embed_tokens.forward = noised_embed(
|
|
||||||
self.embed_tokens.forward, noise_alpha, self
|
|
||||||
)
|
|
||||||
|
|
||||||
return new_func
|
|
||||||
|
|
||||||
transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(
|
|
||||||
transformers.models.mistral.modeling_mistral.MistralModel.post_init
|
|
||||||
)
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
"""Module for Alpaca prompt strategy classes"""
|
"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
|
||||||
|
|
||||||
from typing import Any, Dict, Optional, Tuple
|
from typing import Tuple
|
||||||
|
|
||||||
from axolotl.prompt_tokenizers import (
|
from axolotl.prompt_tokenizers import (
|
||||||
AlpacaPromptTokenizingStrategy,
|
AlpacaPromptTokenizingStrategy,
|
||||||
@@ -9,13 +9,9 @@ from axolotl.prompt_tokenizers import (
|
|||||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
|
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
|
||||||
|
|
||||||
|
|
||||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
def load(tokenizer, cfg):
|
||||||
prompt_style = PromptStyle.CHAT.value
|
|
||||||
if ds_cfg and "conversation" in ds_cfg:
|
|
||||||
prompt_style = ds_cfg["conversation"]
|
|
||||||
|
|
||||||
return AlpacaPromptTokenizingStrategy(
|
return AlpacaPromptTokenizingStrategy(
|
||||||
AlpacaPrompter(prompt_style),
|
AlpacaPrompter(PromptStyle.CHAT.value),
|
||||||
tokenizer,
|
tokenizer,
|
||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
|
|||||||
@@ -24,7 +24,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
)
|
)
|
||||||
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||||
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||||
return SimpleShareGPTPromptTokenizingStrategy(
|
strat = ShareGPTPromptTokenizingStrategy(
|
||||||
ShareGPTPrompterV2(
|
ShareGPTPrompterV2(
|
||||||
conversation=conversation,
|
conversation=conversation,
|
||||||
role_key_model=field_model,
|
role_key_model=field_model,
|
||||||
@@ -34,6 +34,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
cfg.train_on_inputs,
|
cfg.train_on_inputs,
|
||||||
cfg.sequence_len,
|
cfg.sequence_len,
|
||||||
)
|
)
|
||||||
|
if ds_cfg and ds_cfg["skip"]:
|
||||||
|
strat.skip_invalid = True
|
||||||
|
return strat
|
||||||
|
|
||||||
|
|
||||||
def load_role(tokenizer, cfg):
|
def load_role(tokenizer, cfg):
|
||||||
@@ -54,13 +57,38 @@ def load_guanaco(tokenizer, cfg):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
def load_nous(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
|
conversation = (
|
||||||
|
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
||||||
|
)
|
||||||
|
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||||
|
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||||
|
return NousShareGPTPromptTokenizingStrategy(
|
||||||
|
ShareGPTPrompterV2(
|
||||||
|
conversation=conversation,
|
||||||
|
role_key_model=field_model,
|
||||||
|
role_key_human=field_human,
|
||||||
|
),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class NousShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||||
"""
|
"""
|
||||||
basic sharegpt strategy to grab conversations from the sample row
|
basic sharegpt strategy used by nous/ldj for input/output keyed data
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_conversation_thread(self, prompt):
|
def get_conversation_thread(self):
|
||||||
return prompt["conversations"]
|
return "conversation"
|
||||||
|
|
||||||
|
def map_conversation_thread(self, conversation):
|
||||||
|
turns = []
|
||||||
|
for turn in conversation:
|
||||||
|
turns.append({"from": "human", "value": turn["input"]})
|
||||||
|
turns.append({"from": "gpt", "value": turn["output"]})
|
||||||
|
return turns
|
||||||
|
|
||||||
|
|
||||||
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||||
@@ -68,10 +96,11 @@ class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrateg
|
|||||||
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
|
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_conversation_thread(self, prompt):
|
def map_conversation_thread(self, conversation):
|
||||||
conversations = prompt["conversations"]
|
|
||||||
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
||||||
turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
|
turns = [
|
||||||
|
{"from": turn["role"], "value": turn["value"]} for turn in conversation
|
||||||
|
]
|
||||||
return turns
|
return turns
|
||||||
|
|
||||||
|
|
||||||
@@ -80,11 +109,11 @@ class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
|||||||
sharegpt strategy that remaps oasst data to sharegpt format
|
sharegpt strategy that remaps oasst data to sharegpt format
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_conversation_thread(self, prompt):
|
def map_conversation_thread(self, conversation):
|
||||||
conversations = prompt["conversations"]
|
|
||||||
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
||||||
role_map = {"prompter": "human", "assistant": "gpt"}
|
role_map = {"prompter": "human", "assistant": "gpt"}
|
||||||
turns = [
|
turns = [
|
||||||
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
|
{"from": role_map[turn["role"]], "value": turn["text"]}
|
||||||
|
for turn in conversation
|
||||||
]
|
]
|
||||||
return turns
|
return turns
|
||||||
|
|||||||
@@ -2,7 +2,9 @@
|
|||||||
|
|
||||||
import abc
|
import abc
|
||||||
import copy
|
import copy
|
||||||
|
import functools
|
||||||
import logging
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
from typing import Dict, List, Tuple, Union
|
from typing import Dict, List, Tuple, Union
|
||||||
|
|
||||||
from fastchat.conversation import Conversation
|
from fastchat.conversation import Conversation
|
||||||
@@ -56,6 +58,26 @@ class PromptTokenizingStrategy(abc.ABC):
|
|||||||
def supports_batched(self):
|
def supports_batched(self):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=128)
|
||||||
|
def _get_user_token(self):
|
||||||
|
try:
|
||||||
|
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
|
||||||
|
if isinstance(id_or_ids, (int,)):
|
||||||
|
return id_or_ids
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
|
||||||
|
@functools.lru_cache(maxsize=128)
|
||||||
|
def _get_assistant_token(self):
|
||||||
|
try:
|
||||||
|
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
|
||||||
|
if isinstance(id_or_ids, (int,)):
|
||||||
|
return id_or_ids
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
|
||||||
def _tokenize(
|
def _tokenize(
|
||||||
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
||||||
) -> BatchEncoding:
|
) -> BatchEncoding:
|
||||||
@@ -330,99 +352,109 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
Tokenizing strategy for ShareGPT prompts.
|
Tokenizing strategy for ShareGPT prompts.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def get_conversation_thread(self, prompt):
|
_skip_invalid = False
|
||||||
return prompt["conversations"]
|
|
||||||
|
@property
|
||||||
|
def supports_batched(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def skip_invalid(self):
|
||||||
|
return self._skip_invalid
|
||||||
|
|
||||||
|
@skip_invalid.setter
|
||||||
|
def skip_invalid(self, value):
|
||||||
|
self._skip_invalid = value
|
||||||
|
|
||||||
|
def get_conversation_thread(self):
|
||||||
|
return "conversations"
|
||||||
|
|
||||||
|
def map_conversation_thread(self, conversation):
|
||||||
|
return conversation
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
result, current_len = tokenize_prompt_default()
|
tokenized_res = defaultdict(lambda: [])
|
||||||
conversation: Conversation = (
|
conv_field = self.get_conversation_thread()
|
||||||
self.prompter._conversation.copy() # pylint: disable=protected-access
|
for prmpt in prompt[conv_field]:
|
||||||
)
|
result, current_len = tokenize_prompt_default()
|
||||||
|
user_token = self._get_user_token()
|
||||||
|
assistant_token = self._get_assistant_token()
|
||||||
|
conversation: Conversation = (
|
||||||
|
self.prompter._conversation # pylint: disable=protected-access
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
for _, part in enumerate(
|
||||||
|
self.prompter.build_prompt(self.map_conversation_thread(prmpt))
|
||||||
|
):
|
||||||
|
if isinstance(part, tuple):
|
||||||
|
if conversation.roles[0] in part[0]:
|
||||||
|
turn = part[0] + part[1] if not user_token else part[1]
|
||||||
|
# this is still the user query, we should
|
||||||
|
if not part[1].strip():
|
||||||
|
err_msg = f"user turn has empty text: {prmpt}"
|
||||||
|
if self.skip_invalid:
|
||||||
|
raise ValueError(err_msg)
|
||||||
|
LOG.warning(err_msg)
|
||||||
|
res = self._tokenize(
|
||||||
|
turn,
|
||||||
|
add_eos_token=False,
|
||||||
|
strip_bos_token=True,
|
||||||
|
)
|
||||||
|
if user_token:
|
||||||
|
res["input_ids"] = [user_token, *res["input_ids"]]
|
||||||
|
# everything from this is masked out from the labels
|
||||||
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
|
elif conversation.roles[1] in part[0]:
|
||||||
|
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
||||||
|
turn = part[0] + part[1] if not assistant_token else part[1]
|
||||||
|
# this should be the assistant response, should end with an eos token
|
||||||
|
if not part[1].strip():
|
||||||
|
err_msg = f"assistant turn has empty text: {prmpt}"
|
||||||
|
if self.skip_invalid:
|
||||||
|
raise ValueError(err_msg)
|
||||||
|
LOG.warning(err_msg)
|
||||||
|
res = self._tokenize(
|
||||||
|
turn,
|
||||||
|
add_eos_token=True,
|
||||||
|
strip_bos_token=True,
|
||||||
|
)
|
||||||
|
if assistant_token:
|
||||||
|
res["input_ids"] = [
|
||||||
|
assistant_token,
|
||||||
|
*res["input_ids"],
|
||||||
|
]
|
||||||
|
# not masked out from labels
|
||||||
|
labels = copy.deepcopy(res["input_ids"])
|
||||||
|
elif part[0] == "":
|
||||||
|
turn = part[1]
|
||||||
|
# this is only ever the first part, should include the bos token and the user query
|
||||||
|
res = self._tokenize(
|
||||||
|
turn, add_eos_token=False, strip_bos_token=False
|
||||||
|
)
|
||||||
|
# everything from this is masked out from the labels
|
||||||
|
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||||
|
else:
|
||||||
|
err_msg = f"unhandled role: {part[0]}"
|
||||||
|
if self.skip_invalid:
|
||||||
|
raise ValueError(err_msg)
|
||||||
|
LOG.warning(err_msg)
|
||||||
|
continue
|
||||||
|
|
||||||
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
|
# pylint: disable=duplicate-code
|
||||||
role_remap = []
|
result, current_len = parse_tokenized_to_result(
|
||||||
if (
|
result,
|
||||||
conversation.name == "vicuna_v1.1"
|
current_len,
|
||||||
and "roles" in prompt
|
res,
|
||||||
and len(prompt["roles"]) >= 2
|
labels,
|
||||||
):
|
pad_token_id=self.tokenizer.pad_token_id,
|
||||||
role_remap = [
|
)
|
||||||
{"from": conversation.roles[0], "to": prompt["roles"][0]},
|
for key, val in sorted(result.items(), key=lambda x: x[0]):
|
||||||
{"from": conversation.roles[1], "to": prompt["roles"][1]},
|
tokenized_res[key].append(val)
|
||||||
]
|
except (KeyError, AssertionError, IndexError) as err:
|
||||||
|
raise InvalidDataException(str(err)) from err
|
||||||
try:
|
except ValueError as err:
|
||||||
for _, part in enumerate(
|
LOG.warning("skipping prompt: %s", str(err))
|
||||||
self.prompter.build_prompt(self.get_conversation_thread(prompt))
|
return tokenized_res
|
||||||
):
|
|
||||||
if isinstance(part, tuple):
|
|
||||||
if conversation.roles[0] in part[0]:
|
|
||||||
role = (
|
|
||||||
part[0].replace(role_remap[0]["from"], role_remap[0]["to"])
|
|
||||||
if role_remap
|
|
||||||
else part[0]
|
|
||||||
)
|
|
||||||
turn = role + part[1]
|
|
||||||
# this is still the user query, we should
|
|
||||||
if not part[1].strip():
|
|
||||||
LOG.warning(f"user turn has empty text: {prompt}")
|
|
||||||
res = self._tokenize(
|
|
||||||
turn,
|
|
||||||
add_eos_token=False,
|
|
||||||
strip_bos_token=True,
|
|
||||||
)
|
|
||||||
# everything from this is masked out from the labels
|
|
||||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
|
||||||
elif conversation.roles[1] in part[0]:
|
|
||||||
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
|
||||||
role = (
|
|
||||||
part[0].replace(role_remap[1]["from"], role_remap[1]["to"])
|
|
||||||
if role_remap
|
|
||||||
else part[0]
|
|
||||||
)
|
|
||||||
turn = role + part[1]
|
|
||||||
# this should be the assistant response, should end with an eos token
|
|
||||||
if not part[1].strip():
|
|
||||||
LOG.warning(f"assistant turn has empty text: {prompt}")
|
|
||||||
res = self._tokenize(
|
|
||||||
turn,
|
|
||||||
add_eos_token=True,
|
|
||||||
strip_bos_token=True,
|
|
||||||
)
|
|
||||||
role_res = self._tokenize(
|
|
||||||
role.rstrip(),
|
|
||||||
add_eos_token=False,
|
|
||||||
strip_bos_token=True,
|
|
||||||
)
|
|
||||||
# not masked out from labels
|
|
||||||
labels = copy.deepcopy(res["input_ids"])
|
|
||||||
len_role = len(role_res["input_ids"])
|
|
||||||
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
|
|
||||||
len_role, len(labels)
|
|
||||||
)
|
|
||||||
elif part[0] == "":
|
|
||||||
turn = part[1]
|
|
||||||
# this is only ever the first part, should include the bos token and the user query
|
|
||||||
res = self._tokenize(
|
|
||||||
turn, add_eos_token=False, strip_bos_token=False
|
|
||||||
)
|
|
||||||
# everything from this is masked out from the labels
|
|
||||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
|
||||||
else:
|
|
||||||
LOG.warning(f"unhandled role: {part[0]}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
result, current_len = parse_tokenized_to_result(
|
|
||||||
result,
|
|
||||||
current_len,
|
|
||||||
res,
|
|
||||||
labels,
|
|
||||||
pad_token_id=self.tokenizer.pad_token_id,
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
except (KeyError, AssertionError, IndexError) as err:
|
|
||||||
raise InvalidDataException(str(err)) from err
|
|
||||||
|
|
||||||
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
||||||
if not prompt.strip():
|
if not prompt.strip():
|
||||||
|
|||||||
@@ -274,11 +274,9 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
|||||||
raise err
|
raise err
|
||||||
|
|
||||||
conv.messages = []
|
conv.messages = []
|
||||||
for _, sentence in enumerate(source):
|
for j, sentence in enumerate(source):
|
||||||
role = roles[sentence["from"]]
|
role = roles[sentence["from"]]
|
||||||
if len(conv.messages) > 0 and (
|
if role != conv.roles[j % 2]:
|
||||||
(role == conv.messages[-1][0]) or (role not in conv.roles)
|
|
||||||
):
|
|
||||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||||
conv.append_message(role, sentence["value"])
|
conv.append_message(role, sentence["value"])
|
||||||
|
|
||||||
|
|||||||
@@ -58,7 +58,9 @@ def train(
|
|||||||
|
|
||||||
safe_serialization = cfg.save_safetensors is True
|
safe_serialization = cfg.save_safetensors is True
|
||||||
|
|
||||||
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
if (
|
||||||
|
cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints
|
||||||
|
) or cfg.resume_from_checkpoint is True:
|
||||||
possible_checkpoints = [
|
possible_checkpoints = [
|
||||||
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
||||||
]
|
]
|
||||||
@@ -71,7 +73,9 @@ def train(
|
|||||||
LOG.info(
|
LOG.info(
|
||||||
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
||||||
)
|
)
|
||||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
resume_from_checkpoint = (
|
||||||
|
cfg.resume_from_checkpoint if cfg.resume_from_checkpoint is not True else None
|
||||||
|
)
|
||||||
|
|
||||||
trainer = setup_trainer(
|
trainer = setup_trainer(
|
||||||
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
||||||
|
|||||||
@@ -514,27 +514,3 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
|
|||||||
return control
|
return control
|
||||||
|
|
||||||
return LogPredictionCallback
|
return LogPredictionCallback
|
||||||
|
|
||||||
|
|
||||||
class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|
||||||
"""Callback to save axolotl config to wandb"""
|
|
||||||
|
|
||||||
def __init__(self, axolotl_config_path):
|
|
||||||
self.axolotl_config_path = axolotl_config_path
|
|
||||||
|
|
||||||
def on_train_begin(
|
|
||||||
self,
|
|
||||||
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
|
|
||||||
state: TrainerState, # pylint: disable=unused-argument
|
|
||||||
control: TrainerControl,
|
|
||||||
**kwargs, # pylint: disable=unused-argument
|
|
||||||
):
|
|
||||||
if is_main_process():
|
|
||||||
try:
|
|
||||||
artifact = wandb.Artifact(name="axolotl-config", type="config")
|
|
||||||
artifact.add_file(local_path=self.axolotl_config_path)
|
|
||||||
wandb.run.log_artifact(artifact)
|
|
||||||
LOG.info("Axolotl config has been saved to WandB as an artifact.")
|
|
||||||
except (FileNotFoundError, ConnectionError) as err:
|
|
||||||
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
|
||||||
return control
|
|
||||||
|
|||||||
@@ -121,18 +121,6 @@ def normalize_config(cfg):
|
|||||||
|
|
||||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||||
|
|
||||||
if cfg.adapter is not None:
|
|
||||||
for key in list(cfg.keys()):
|
|
||||||
if key.startswith("lora_"):
|
|
||||||
new_key = key.replace("lora_", "peft_")
|
|
||||||
LOG.warning(
|
|
||||||
PendingDeprecationWarning(
|
|
||||||
f"{key} soon to be deprecated. please use {new_key}"
|
|
||||||
)
|
|
||||||
)
|
|
||||||
cfg[new_key] = cfg[key]
|
|
||||||
del cfg[key]
|
|
||||||
|
|
||||||
|
|
||||||
def validate_config(cfg):
|
def validate_config(cfg):
|
||||||
if is_torch_bf16_gpu_available():
|
if is_torch_bf16_gpu_available():
|
||||||
@@ -202,10 +190,7 @@ def validate_config(cfg):
|
|||||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||||
|
|
||||||
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
||||||
LOG.warning("We recommend setting `load_in_8bit: true` for LoRA finetuning")
|
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||||
|
|
||||||
if not cfg.load_in_8bit and cfg.adapter == "ia3":
|
|
||||||
LOG.warning("We recommend setting `load_in_8bit: true` for IA3 finetuning")
|
|
||||||
|
|
||||||
if cfg.relora_steps:
|
if cfg.relora_steps:
|
||||||
if cfg.adapter not in ("lora", "qlora"):
|
if cfg.adapter not in ("lora", "qlora"):
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ from datasets import (
|
|||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
|
||||||
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
|
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
|
||||||
from axolotl.prompt_strategies import load
|
from axolotl.prompt_strategies import load
|
||||||
from axolotl.prompt_tokenizers import (
|
from axolotl.prompt_tokenizers import (
|
||||||
@@ -45,6 +44,7 @@ from axolotl.utils.trainer import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||||
|
|
||||||
|
|
||||||
def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
||||||
@@ -158,21 +158,23 @@ def load_tokenized_prepared_datasets(
|
|||||||
token=use_auth_token,
|
token=use_auth_token,
|
||||||
)
|
)
|
||||||
ds_from_hub = True
|
ds_from_hub = True
|
||||||
except (FileNotFoundError, ConnectionError):
|
except (FileNotFoundError, ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# prefer local dataset, even if hub exists
|
# prefer local dataset, even if hub exists
|
||||||
local_path = Path(d.path)
|
local_path = Path(d.path)
|
||||||
if local_path.exists():
|
if local_path.exists():
|
||||||
if local_path.is_dir():
|
if local_path.is_dir():
|
||||||
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
|
if not d.type:
|
||||||
ds = load_dataset(
|
ds = load_from_disk(d.path)
|
||||||
d.path,
|
else:
|
||||||
name=d.name,
|
ds = load_dataset(
|
||||||
data_files=d.data_files,
|
d.path,
|
||||||
streaming=False,
|
name=d.name,
|
||||||
split=None,
|
data_files=d.data_files,
|
||||||
)
|
streaming=False,
|
||||||
|
split=None,
|
||||||
|
)
|
||||||
elif local_path.is_file():
|
elif local_path.is_file():
|
||||||
ds_type = "json"
|
ds_type = "json"
|
||||||
if d.ds_type:
|
if d.ds_type:
|
||||||
@@ -357,7 +359,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
if len(datasets) > 1:
|
if len(datasets) > 1:
|
||||||
LOG.info("shuffle merged datasets")
|
LOG.info("shuffle merged datasets")
|
||||||
dataset = dataset.shuffle(seed=seed)
|
dataset = dataset.shuffle(seed=seed)
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0 and cfg.dataset_prepared_path:
|
||||||
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||||
dataset.save_to_disk(prepared_ds_path)
|
dataset.save_to_disk(prepared_ds_path)
|
||||||
if cfg.push_dataset_to_hub:
|
if cfg.push_dataset_to_hub:
|
||||||
|
|||||||
@@ -180,26 +180,6 @@ def load_model(
|
|||||||
LOG.info("patching with flash attention")
|
LOG.info("patching with flash attention")
|
||||||
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
|
|
||||||
from axolotl.monkeypatch.llama_embeddings_hijack import (
|
|
||||||
replace_llama_embeddings_with_uniform_distribution,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info("patching with noisy embeddings")
|
|
||||||
replace_llama_embeddings_with_uniform_distribution(
|
|
||||||
noise_alpha=cfg.noisy_embedding_alpha
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
|
|
||||||
from axolotl.monkeypatch.mistral_embeddings_hijack import (
|
|
||||||
replace_mistral_embeddings_with_uniform_distribution,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info("patching with noisy embeddings")
|
|
||||||
replace_mistral_embeddings_with_uniform_distribution(
|
|
||||||
noise_alpha=cfg.noisy_embedding_alpha
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
||||||
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
||||||
replace_llama_rope_with_xpos_rope,
|
replace_llama_rope_with_xpos_rope,
|
||||||
@@ -402,25 +382,25 @@ def load_model(
|
|||||||
if model_config.model_type == "btlm":
|
if model_config.model_type == "btlm":
|
||||||
# don't upcast lm_head for btlm
|
# don't upcast lm_head for btlm
|
||||||
continue
|
continue
|
||||||
if "lm_head" in name or "embed_tokens" in name:
|
if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
|
||||||
if hasattr(module, "weight"):
|
if hasattr(module, "weight"):
|
||||||
module.to(torch.float32)
|
module.to(torch.float32)
|
||||||
|
|
||||||
require_peft: bool = False
|
needs_fa2_dtype = cfg.adapter or cfg.fsdp
|
||||||
if cfg.adapter in ["lora", "qlora", "ia3"]:
|
if (cfg.adapter == "lora" and load_in_8bit) or (
|
||||||
require_peft = True
|
cfg.adapter == "qlora" and cfg.load_in_4bit
|
||||||
|
):
|
||||||
if require_peft:
|
|
||||||
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
||||||
if cfg.gradient_checkpointing:
|
if cfg.gradient_checkpointing:
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable()
|
||||||
model = prepare_model_for_kbit_training(
|
model = prepare_model_for_kbit_training(
|
||||||
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
||||||
)
|
)
|
||||||
|
needs_fa2_dtype = True
|
||||||
|
|
||||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
||||||
# convert them back to fp16/bf16 for flash-attn compatibility.
|
# convert them back to fp16/bf16 for flash-attn compatibility.
|
||||||
if require_peft or cfg.fsdp or (cfg.flash_attention and cfg.is_llama_derived_model):
|
if needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model):
|
||||||
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
|
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if "norm" in name:
|
if "norm" in name:
|
||||||
@@ -429,7 +409,7 @@ def load_model(
|
|||||||
if hasattr(module, "weight"):
|
if hasattr(module, "weight"):
|
||||||
module.to(cfg.torch_dtype)
|
module.to(cfg.torch_dtype)
|
||||||
|
|
||||||
model, peft_config = load_adapter(model, cfg, cfg.adapter)
|
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||||
|
|
||||||
if cfg.ddp and not load_in_8bit:
|
if cfg.ddp and not load_in_8bit:
|
||||||
model.to(f"cuda:{cfg.local_rank}")
|
model.to(f"cuda:{cfg.local_rank}")
|
||||||
@@ -460,7 +440,7 @@ def load_model(
|
|||||||
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
||||||
|
|
||||||
# TODO resume_from_checkpoint handling
|
# TODO resume_from_checkpoint handling
|
||||||
return model, peft_config
|
return model, lora_config
|
||||||
|
|
||||||
|
|
||||||
def load_adapter(model, cfg, adapter, inference=False):
|
def load_adapter(model, cfg, adapter, inference=False):
|
||||||
@@ -470,8 +450,6 @@ def load_adapter(model, cfg, adapter, inference=False):
|
|||||||
return model, None
|
return model, None
|
||||||
if hasattr(model, "enable_input_require_grads"):
|
if hasattr(model, "enable_input_require_grads"):
|
||||||
model.enable_input_require_grads()
|
model.enable_input_require_grads()
|
||||||
if adapter == "ia3":
|
|
||||||
return load_ia3(model, cfg, inference=inference)
|
|
||||||
if adapter in ["lora", "qlora"]:
|
if adapter in ["lora", "qlora"]:
|
||||||
return load_lora(model, cfg, inference=inference)
|
return load_lora(model, cfg, inference=inference)
|
||||||
if adapter == "llama-adapter":
|
if adapter == "llama-adapter":
|
||||||
@@ -490,11 +468,11 @@ def load_llama_adapter(model, cfg):
|
|||||||
task_type="CAUSAL_LM",
|
task_type="CAUSAL_LM",
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.peft_model_dir:
|
if cfg.lora_model_dir:
|
||||||
LOG.debug("Loading pretained PEFT - llama_adapter")
|
LOG.debug("Loading pretained PEFT - llama_adapter")
|
||||||
model = PeftModel.from_pretrained(
|
model = PeftModel.from_pretrained(
|
||||||
model,
|
model,
|
||||||
cfg.peft_model_dir,
|
cfg.lora_model_dir,
|
||||||
torch_dtype=torch.float16,
|
torch_dtype=torch.float16,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -507,20 +485,16 @@ def load_llama_adapter(model, cfg):
|
|||||||
|
|
||||||
def find_all_linear_names(model):
|
def find_all_linear_names(model):
|
||||||
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
|
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
|
||||||
peft_module_names = set()
|
lora_module_names = set()
|
||||||
for name, module in model.named_modules():
|
for name, module in model.named_modules():
|
||||||
if (
|
if isinstance(module, cls) or "Linear" in module.__class__.__name__:
|
||||||
isinstance(module, cls)
|
|
||||||
or "Linear" in module.__class__.__name__
|
|
||||||
and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
|
|
||||||
):
|
|
||||||
names = name.split(".")
|
names = name.split(".")
|
||||||
peft_module_names.add(names[0] if len(names) == 1 else names[-1])
|
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
||||||
|
|
||||||
if "lm_head" in peft_module_names: # needed for 16-bit
|
if "lm_head" in lora_module_names: # needed for 16-bit
|
||||||
peft_module_names.remove("lm_head")
|
lora_module_names.remove("lm_head")
|
||||||
|
|
||||||
return list(peft_module_names)
|
return list(lora_module_names)
|
||||||
|
|
||||||
|
|
||||||
def load_lora(model, cfg, inference=False):
|
def load_lora(model, cfg, inference=False):
|
||||||
@@ -528,68 +502,34 @@ def load_lora(model, cfg, inference=False):
|
|||||||
|
|
||||||
from peft import LoraConfig, PeftModel, get_peft_model
|
from peft import LoraConfig, PeftModel, get_peft_model
|
||||||
|
|
||||||
peft_target_modules = list(cfg.peft_target_modules or [])
|
lora_target_modules = list(cfg.lora_target_modules or [])
|
||||||
|
|
||||||
if cfg.peft_target_linear:
|
if cfg.lora_target_linear:
|
||||||
linear_names = find_all_linear_names(model)
|
linear_names = find_all_linear_names(model)
|
||||||
LOG.info(f"found linear modules: {repr(linear_names)}")
|
LOG.info(f"found linear modules: {repr(linear_names)}")
|
||||||
peft_target_modules = list(set(peft_target_modules + linear_names))
|
lora_target_modules = list(set(lora_target_modules + linear_names))
|
||||||
|
|
||||||
peft_config = LoraConfig(
|
lora_config = LoraConfig(
|
||||||
r=cfg.peft_r,
|
r=cfg.lora_r,
|
||||||
lora_alpha=cfg.peft_alpha,
|
lora_alpha=cfg.lora_alpha,
|
||||||
target_modules=peft_target_modules,
|
target_modules=lora_target_modules,
|
||||||
lora_dropout=cfg.peft_dropout,
|
lora_dropout=cfg.lora_dropout,
|
||||||
fan_in_fan_out=cfg.peft_fan_in_fan_out,
|
fan_in_fan_out=cfg.lora_fan_in_fan_out,
|
||||||
modules_to_save=cfg.peft_modules_to_save if cfg.peft_modules_to_save else None,
|
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
|
||||||
bias="none",
|
bias="none",
|
||||||
task_type="CAUSAL_LM",
|
task_type="CAUSAL_LM",
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.peft_model_dir:
|
if cfg.lora_model_dir:
|
||||||
LOG.debug("Loading pretained PEFT - LoRA")
|
LOG.debug("Loading pretained PEFT - LoRA")
|
||||||
model = PeftModel.from_pretrained(
|
model = PeftModel.from_pretrained(
|
||||||
model,
|
model,
|
||||||
cfg.peft_model_dir,
|
cfg.lora_model_dir,
|
||||||
is_trainable=(not inference),
|
is_trainable=(not inference),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
model = get_peft_model(model, peft_config)
|
model = get_peft_model(model, lora_config)
|
||||||
|
|
||||||
model.print_trainable_parameters()
|
model.print_trainable_parameters()
|
||||||
|
|
||||||
return model, peft_config
|
return model, lora_config
|
||||||
|
|
||||||
|
|
||||||
def load_ia3(model, cfg, inference=False):
|
|
||||||
# type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
|
||||||
|
|
||||||
from peft import IA3Config, PeftModel, get_peft_model
|
|
||||||
|
|
||||||
peft_config_kwargs = {}
|
|
||||||
if cfg.peft_init_ia3_weights is not None:
|
|
||||||
peft_config_kwargs["init_ia3_weights"] = cfg.peft_init_ia3_weights
|
|
||||||
if cfg.peft_fan_in_fan_out is not None:
|
|
||||||
peft_config_kwargs["fan_in_fan_out"] = cfg.peft_fan_in_fan_out
|
|
||||||
|
|
||||||
peft_config = IA3Config(
|
|
||||||
target_modules=cfg.peft_target_modules,
|
|
||||||
feedforward_modules=cfg.peft_feedforward_modules,
|
|
||||||
modules_to_save=cfg.peft_modules_to_save,
|
|
||||||
task_type="CAUSAL_LM",
|
|
||||||
**peft_config_kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
if cfg.peft_model_dir:
|
|
||||||
LOG.debug("Loading pretained PEFT - IA3")
|
|
||||||
model = PeftModel.from_pretrained(
|
|
||||||
model,
|
|
||||||
cfg.peft_model_dir,
|
|
||||||
is_trainable=(not inference),
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
model = get_peft_model(model, peft_config)
|
|
||||||
|
|
||||||
model.print_trainable_parameters()
|
|
||||||
|
|
||||||
return model, peft_config
|
|
||||||
|
|||||||
@@ -30,7 +30,6 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
|||||||
from axolotl.utils.callbacks import (
|
from axolotl.utils.callbacks import (
|
||||||
EvalFirstStepCallback,
|
EvalFirstStepCallback,
|
||||||
GPUStatsCallback,
|
GPUStatsCallback,
|
||||||
SaveAxolotlConfigtoWandBCallback,
|
|
||||||
SaveBetterTransformerModelCallback,
|
SaveBetterTransformerModelCallback,
|
||||||
bench_eval_callback_factory,
|
bench_eval_callback_factory,
|
||||||
log_prediction_callback_factory,
|
log_prediction_callback_factory,
|
||||||
@@ -423,9 +422,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Phi doesn't want the attention_mask feature when training
|
# Phi doesn't want the attention_mask feature when training
|
||||||
if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
|
if "CodeGenTokenizer" in tokenizer.__class__.__name__:
|
||||||
cfg.is_mistral_derived_model and cfg.flash_attention
|
|
||||||
):
|
|
||||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
eval_dataset = eval_dataset.remove_columns("attention_mask")
|
eval_dataset = eval_dataset.remove_columns("attention_mask")
|
||||||
@@ -778,9 +775,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
|
LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
|
||||||
trainer.add_callback(LogPredictionCallback(cfg))
|
trainer.add_callback(LogPredictionCallback(cfg))
|
||||||
|
|
||||||
if cfg.use_wandb:
|
|
||||||
trainer.add_callback(SaveAxolotlConfigtoWandBCallback(cfg.axolotl_config_path))
|
|
||||||
|
|
||||||
if cfg.do_bench_eval:
|
if cfg.do_bench_eval:
|
||||||
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
|
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
|
||||||
|
|
||||||
|
|||||||
@@ -24,10 +24,6 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def test_lora(self):
|
def test_lora(self):
|
||||||
"""
|
|
||||||
support for legacy lora_ configs
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
output_dir = tempfile.mkdtemp()
|
output_dir = tempfile.mkdtemp()
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
@@ -70,101 +66,6 @@ class TestLoraLlama(unittest.TestCase):
|
|||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||||
|
|
||||||
def test_lora_peft(self):
|
|
||||||
"""
|
|
||||||
support for legacy lora_ configs
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
output_dir = tempfile.mkdtemp()
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "JackFram/llama-68m",
|
|
||||||
"base_model_config": "JackFram/llama-68m",
|
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"load_in_8bit": True,
|
|
||||||
"adapter": "lora",
|
|
||||||
"peft_r": 32,
|
|
||||||
"peft_alpha": 64,
|
|
||||||
"peft_dropout": 0.05,
|
|
||||||
"peft_target_linear": True,
|
|
||||||
"val_set_size": 0.1,
|
|
||||||
"special_tokens": {
|
|
||||||
"unk_token": "<unk>",
|
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 2,
|
|
||||||
"micro_batch_size": 8,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": output_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
normalize_config(cfg)
|
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
|
||||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
|
||||||
|
|
||||||
def test_ia3_peft(self):
|
|
||||||
"""
|
|
||||||
support for IA3 peft
|
|
||||||
:return:
|
|
||||||
"""
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
output_dir = tempfile.mkdtemp()
|
|
||||||
cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"base_model": "JackFram/llama-68m",
|
|
||||||
"base_model_config": "JackFram/llama-68m",
|
|
||||||
"tokenizer_type": "LlamaTokenizer",
|
|
||||||
"sequence_len": 1024,
|
|
||||||
"load_in_8bit": True,
|
|
||||||
"adapter": "ia3",
|
|
||||||
"peft_r": 32,
|
|
||||||
"peft_alpha": 64,
|
|
||||||
"peft_dropout": 0.05,
|
|
||||||
"peft_target_modules": ["k_proj", "v_proj", "down_proj"],
|
|
||||||
"peft_feedforward_modules": ["down_proj"],
|
|
||||||
"val_set_size": 0.1,
|
|
||||||
"special_tokens": {
|
|
||||||
"unk_token": "<unk>",
|
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
},
|
|
||||||
"datasets": [
|
|
||||||
{
|
|
||||||
"path": "mhenrichsen/alpaca_2k_test",
|
|
||||||
"type": "alpaca",
|
|
||||||
},
|
|
||||||
],
|
|
||||||
"num_epochs": 2,
|
|
||||||
"micro_batch_size": 8,
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"output_dir": output_dir,
|
|
||||||
"learning_rate": 0.00001,
|
|
||||||
"optimizer": "adamw_torch",
|
|
||||||
"lr_scheduler": "cosine",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
normalize_config(cfg)
|
|
||||||
cli_args = TrainerCliArgs()
|
|
||||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
|
||||||
|
|
||||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
|
||||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
|
||||||
|
|
||||||
def test_lora_packing(self):
|
def test_lora_packing(self):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
output_dir = tempfile.mkdtemp()
|
output_dir = tempfile.mkdtemp()
|
||||||
|
|||||||
2
tests/fixtures/conversation.tokenized.json
vendored
2
tests/fixtures/conversation.tokenized.json
vendored
File diff suppressed because one or more lines are too long
@@ -1,48 +0,0 @@
|
|||||||
"""Module for testing the validation module"""
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import unittest
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from axolotl.utils.config import normalize_config
|
|
||||||
from axolotl.utils.dict import DictDefault
|
|
||||||
|
|
||||||
|
|
||||||
class NormalizationTest(unittest.TestCase):
|
|
||||||
"""
|
|
||||||
Test the cfg normalization module
|
|
||||||
"""
|
|
||||||
|
|
||||||
_caplog: Optional[pytest.LogCaptureFixture] = None
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def inject_fixtures(self, caplog):
|
|
||||||
self._caplog = caplog
|
|
||||||
|
|
||||||
def test_lora_to_peft(self):
|
|
||||||
base_cfg = DictDefault(
|
|
||||||
{
|
|
||||||
"gradient_accumulation_steps": 1,
|
|
||||||
"micro_batch_size": 1,
|
|
||||||
"base_model": "NousResearch/Llama-2-7b-hf",
|
|
||||||
"base_model_config": "NousResearch/Llama-2-7b-hf",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
cfg = base_cfg | DictDefault(
|
|
||||||
{
|
|
||||||
"adapter": "lora",
|
|
||||||
"lora_r": 128,
|
|
||||||
"lora_alpha": 64,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
with self._caplog.at_level(logging.WARNING):
|
|
||||||
normalize_config(cfg)
|
|
||||||
assert any(
|
|
||||||
"soon to be deprecated. please use peft_" in record.message
|
|
||||||
for record in self._caplog.records
|
|
||||||
)
|
|
||||||
|
|
||||||
assert cfg.peft_r == 128
|
|
||||||
assert cfg.peft_alpha == 64
|
|
||||||
@@ -90,73 +90,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|||||||
strat.tokenize_prompt(conversation)
|
strat.tokenize_prompt(conversation)
|
||||||
assert "assistant turn has empty text" in self._caplog.records[1].message
|
assert "assistant turn has empty text" in self._caplog.records[1].message
|
||||||
|
|
||||||
def test_sharegpt_warnings_turns(self):
|
|
||||||
conversation = {
|
|
||||||
"conversations": [
|
|
||||||
{"from": "system", "value": "lorem"},
|
|
||||||
{"from": "gpt", "value": "ipsum"},
|
|
||||||
{"from": "human", "value": "dolor"},
|
|
||||||
{"from": "human", "value": "dolor"},
|
|
||||||
{"from": "gpt", "value": "sit"},
|
|
||||||
]
|
|
||||||
}
|
|
||||||
prompter = ShareGPTPrompterV2()
|
|
||||||
strat = ShareGPTPromptTokenizingStrategy(
|
|
||||||
prompter,
|
|
||||||
self.tokenizer,
|
|
||||||
False,
|
|
||||||
2048,
|
|
||||||
)
|
|
||||||
with self._caplog.at_level(logging.WARNING):
|
|
||||||
strat.tokenize_prompt(conversation)
|
|
||||||
assert (
|
|
||||||
"Role did not alternate between turns (gpt and human)"
|
|
||||||
in self._caplog.records[0].message
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_sharegpt_changes_roles(self):
|
|
||||||
conversation = {
|
|
||||||
"roles": ["USER", "CHARACTER"],
|
|
||||||
"conversations": [
|
|
||||||
{"from": "system", "value": "lorem"},
|
|
||||||
{"from": "gpt", "value": "ipsum"},
|
|
||||||
{"from": "human", "value": "dolor"},
|
|
||||||
{"from": "gpt", "value": "sit"},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
prompter = ShareGPTPrompterV2()
|
|
||||||
strat = ShareGPTPromptTokenizingStrategy(
|
|
||||||
prompter,
|
|
||||||
self.tokenizer,
|
|
||||||
False,
|
|
||||||
2048,
|
|
||||||
)
|
|
||||||
with self._caplog.at_level(logging.WARNING):
|
|
||||||
res = strat.tokenize_prompt(conversation)
|
|
||||||
assert "CHARACTER" in self.tokenizer.decode(res["input_ids"])
|
|
||||||
|
|
||||||
def test_sharegpt_assistant_label_ignore(self):
|
|
||||||
conversation = {
|
|
||||||
"roles": ["user", "assistant"],
|
|
||||||
"conversations": [
|
|
||||||
{"from": "system", "value": "lorem"},
|
|
||||||
{"from": "gpt", "value": "ipsum"},
|
|
||||||
{"from": "human", "value": "dolor"},
|
|
||||||
{"from": "gpt", "value": "sit"},
|
|
||||||
],
|
|
||||||
}
|
|
||||||
prompter = ShareGPTPrompterV2()
|
|
||||||
strat = ShareGPTPromptTokenizingStrategy(
|
|
||||||
prompter,
|
|
||||||
self.tokenizer,
|
|
||||||
False,
|
|
||||||
2048,
|
|
||||||
)
|
|
||||||
with self._caplog.at_level(logging.WARNING):
|
|
||||||
res = strat.tokenize_prompt(conversation)
|
|
||||||
idx = res["input_ids"].index(20255) # assistant token
|
|
||||||
assert res["labels"][idx] == -100
|
|
||||||
|
|
||||||
def test_no_sys_prompt(self):
|
def test_no_sys_prompt(self):
|
||||||
"""
|
"""
|
||||||
tests the interface between the user and assistant parts
|
tests the interface between the user and assistant parts
|
||||||
|
|||||||
Reference in New Issue
Block a user