Compare commits


1 Commit

Author: Wing Lian
SHA1: b4d84d56d5
Message: support for batched sharegpt tokenization to skip bad data
Date: 2023-10-06 15:03:07 -04:00
84 changed files with 1198 additions and 2554 deletions


@@ -25,11 +25,6 @@ jobs:
python_version: "3.10" python_version: "3.10"
pytorch: 2.0.1 pytorch: 2.0.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
- cuda: "118"
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.1.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v3


@@ -23,12 +23,6 @@ jobs:
python_version: "3.10" python_version: "3.10"
pytorch: 2.0.1 pytorch: 2.0.1
axolotl_extras: axolotl_extras:
is_latest: true
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.1.0
axolotl_extras:
runs-on: [self-hosted, gpu, docker] runs-on: [self-hosted, gpu, docker]
steps: steps:
- name: Checkout - name: Checkout
@@ -52,12 +46,9 @@ jobs:
build-args: | build-args: |
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }} CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
file: ./docker/Dockerfile file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }} push: ${{ github.event_name != 'pull_request' }}
tags: | tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }} labels: ${{ steps.metadata.outputs.labels }}
build-axolotl-runpod: build-axolotl-runpod:
needs: build-axolotl needs: build-axolotl
@@ -77,11 +68,6 @@ jobs:
pytorch: 2.0.1 pytorch: 2.0.1
axolotl_extras: axolotl_extras:
is_latest: true is_latest: true
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.1.0
axolotl_extras:
runs-on: [self-hosted, gpu, docker] runs-on: [self-hosted, gpu, docker]
steps: steps:
- name: Checkout - name: Checkout


@@ -12,3 +12,4 @@ generated-members=numpy.*, torch.*
disable=missing-function-docstring, line-too-long, import-error, disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods, too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation, too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
too-many-nested-blocks,

README.md

@@ -23,15 +23,15 @@ Features:
- [Supported Features](#axolotl-supports) - [Supported Features](#axolotl-supports)
- [Quickstart](#quickstart-) - [Quickstart](#quickstart-)
- [Installation](#installation) - [Installation](#installation)
- [Docker](#docker) - [Docker Installation](#environment)
- [Conda/Pip venv](#condapip-venv) - [Conda/Pip venv Installation](#condapip-venv)
- [LambdaLabs](#lambdalabs) - [LambdaLabs Installation](#lambdalabs)
- [Windows](#windows)
- [Dataset](#dataset) - [Dataset](#dataset)
- [How to Add Custom Prompts](#how-to-add-custom-prompts) - [How to Add Custom Prompts](#how-to-add-custom-prompts)
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset) - [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
- [Config](#config) - [Config](#config)
- [Train](#train) - [Train](#train)
- [Training w/ Deepspeed](#training-with-deepspeed)
- [Inference](#inference) - [Inference](#inference)
- [Merge LORA to Base](#merge-lora-to-base) - [Merge LORA to Base](#merge-lora-to-base)
- [Common Errors](#common-errors-) - [Common Errors](#common-errors-)
@@ -50,7 +50,7 @@ Features:
<b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b> <b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b>
</p> </p>
<p> <p>
Go ahead and Axolotl questions!! Go ahead and axolotl questions!!
</p> </p>
<img src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit"> <img src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
<img alt="PyTest Status" src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/tests.yml/badge.svg?branch=main"> <img alt="PyTest Status" src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
@@ -102,7 +102,7 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
### Environment ### Environment
#### Docker - Docker
```bash ```bash
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1 docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
``` ```
@@ -114,31 +114,12 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
docker compose up -d docker compose up -d
``` ```
<details> - Conda/Pip venv
<summary>Docker advanced</summary>
A more powerful Docker command to run would be this:
```bash
docker run --gpus '"all"' --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=volume,src=axolotl,target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
```
It additionally:
* Prevents memory issues when running e.g. deepspeed (e.g. you could hit SIGBUS/signal 7 error) through `--ipc` and `--ulimit` args.
* Persists the downloaded HF data (models etc.) and your modifications to axolotl code through `--mount`/`-v` args.
* The `--name` argument simply makes it easier to refer to the container in vscode (`Dev Containers: Attach to Running Container...`) or in your terminal.
[More information on nvidia website](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem)
</details>
#### Conda/Pip venv
1. Install python >=**3.9** 1. Install python >=**3.9**
2. Install pytorch stable https://pytorch.org/get-started/locally/ 2. Install pytorch stable https://pytorch.org/get-started/locally/
3. Install Axolotl along with python dependencies 3. Install axolotl along with python dependencies
```bash ```bash
pip3 install packaging pip3 install packaging
pip3 install -e '.[flash-attn,deepspeed]' pip3 install -e '.[flash-attn,deepspeed]'
@@ -149,7 +130,7 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
``` ```
Get the token at huggingface.co/settings/tokens Get the token at huggingface.co/settings/tokens
#### LambdaLabs - LambdaLabs
<details> <details>
<summary>Click to Expand</summary> <summary>Click to Expand</summary>
@@ -193,8 +174,7 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
``` ```
</details> </details>
#### Windows - Windows: Please use WSL or Docker!
Please use WSL or Docker!
### Dataset ### Dataset
@@ -315,24 +295,25 @@ Have dataset(s) in one of the following format (JSONL recommended):
#### How to add custom prompts #### How to add custom prompts
For a dataset that is preprocessed for instruction purposes: Using yaml. Example:
```json
{"instruction": "...", "output": "..."}
```
You can use this example in your YAML config:
```yaml ```yaml
datasets: datasets:
- path: repo - path: repo
type: type:
system_prompt: "" system_prompt: ""
field_system: system no_input_format: |-
format: "[INST] {instruction} [/INST]" User: {instruction}<|end_of_turn|>
no_input_format: "[INST] {instruction} [/INST]" Assistant:
format: |-
User: {instruction}
{input}<|end_of_turn|>
Assistant:
``` ```
Using file:
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
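A minimal sketch of how step 2 would look in a config (the module, function, and dataset names below are hypothetical placeholders, not files from this diff):
```yaml
datasets:
  - path: your_org/your_dataset        # hypothetical dataset repo
    # loads load_mine() from src/axolotl/prompt_strategies/my_strategy.py
    type: my_strategy.load_mine
```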
#### How to use your custom pretokenized dataset #### How to use your custom pretokenized dataset
- Do not pass a `type:` - Do not pass a `type:`
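A minimal sketch, assuming the rows are already tokenized into the fields axolotl trains on (the dataset path is a placeholder, and the exact column names are not confirmed by this diff):
```yaml
datasets:
  # no `type:` is given, so the rows are used as-is;
  # they are expected to already carry tokenized fields such as
  # input_ids / attention_mask / labels
  - path: your_org/your_pretokenized_dataset
```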
@@ -374,13 +355,6 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
- typescript - typescript
type: ... # unimplemented custom format type: ... # unimplemented custom format
# fastchat conversation
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
datasets:
- path: ...
type: sharegpt
conversation: chatml
# local # local
datasets: datasets:
- path: data.jsonl # or json - path: data.jsonl # or json
@@ -419,18 +393,18 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
<details> <details>
<summary>All yaml options (click me)</summary> <summary>All yaml options</summary>
```yaml ```yaml
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files # this is the huggingface model that contains *.pt, *.safetensors, or *.bin files
# This can also be a relative path to a model on disk # this can also be a relative path to a model on disk
base_model: ./llama-7b-hf base_model: ./llama-7b-hf
# You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc) # you can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
base_model_ignore_patterns: base_model_ignore_patterns:
# If the base_model repo on hf hub doesn't include configuration .json files, # if the base_model repo on hf hub doesn't include configuration .json files,
# You can set that here, or leave this empty to default to base_model # you can set that here, or leave this empty to default to base_model
base_model_config: ./llama-7b-hf base_model_config: ./llama-7b-hf
# You can specify to choose a specific model revision from huggingface hub # you can specify to choose a specific model revision from huggingface hub
model_revision: model_revision:
# Optional tokenizer configuration override in case you want to use a different tokenizer # Optional tokenizer configuration override in case you want to use a different tokenizer
# than the one defined in the base model # than the one defined in the base model
@@ -445,24 +419,23 @@ trust_remote_code:
tokenizer_use_fast: tokenizer_use_fast:
# Whether to use the legacy tokenizer setting, defaults to True # Whether to use the legacy tokenizer setting, defaults to True
tokenizer_legacy: tokenizer_legacy:
# Resize the model embeddings when new tokens are added to multiples of 32 # resize the model embeddings when new tokens are added to multiples of 32
# This is reported to improve training speed on some models # this is reported to improve training speed on some models
resize_token_embeddings_to_32x: resize_token_embeddings_to_32x:
# Used to identify which the model is based on # used to identify which the model is based on
is_falcon_derived_model: is_falcon_derived_model:
is_llama_derived_model: is_llama_derived_model:
# Please note that if you set this to true, `padding_side` will be set to "left" by default
is_mistral_derived_model: is_mistral_derived_model:
# Whether you are training a 4-bit GPTQ quantized model # whether you are training a 4-bit GPTQ quantized model
gptq: true gptq: true
gptq_groupsize: 128 # group size gptq_groupsize: 128 # group size
gptq_model_v1: false # v1 or v2 gptq_model_v1: false # v1 or v2
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer # this will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit: true load_in_8bit: true
# Use bitsandbytes 4 bit # use bitsandbytes 4 bit
load_in_4bit: load_in_4bit:
# Use CUDA bf16 # Use CUDA bf16
@@ -476,9 +449,9 @@ tf32: true # require >=ampere
bfloat16: true # require >=ampere bfloat16: true # require >=ampere
float16: true float16: true
# A list of one or more datasets to finetune the model with # a list of one or more datasets to finetune the model with
datasets: datasets:
# HuggingFace dataset repo | "json" for local dataset, make sure to fill data_files # hf dataset repo | "json" for local dataset, make sure to fill data_files
- path: vicgalle/alpaca-gpt4 - path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection] # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn> type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
@@ -486,22 +459,19 @@ datasets:
data_files: # Optional[str] path to source data files data_files: # Optional[str] path to source data files
shards: # Optional[int] number of shards to split data into shards: # Optional[int] number of shards to split data into
name: # Optional[str] name of dataset configuration to load name: # Optional[str] name of dataset configuration to load
conversation: # Optional[str] fastchat conversation type, only used with type: sharegpt
# Optional[str] fastchat conversation type, only used with type: sharegpt # custom user prompt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
# Custom user prompt
- path: repo - path: repo
type: type:
# The below are defaults. only set what's needed. # the below are defaults. only set what's needed.
system_prompt: "" system_prompt: ""
system_format: "{system}"
field_system: system field_system: system
field_instruction: instruction field_instruction: instruction
field_input: input field_output: input
field_output: output
# Customizable to be single line or multi-line # customizable to be single line or multi-line
system_format: "{system}"
# 'format' can include {input} # 'format' can include {input}
format: |- format: |-
User: {instruction} {input} User: {instruction} {input}
@@ -509,13 +479,13 @@ datasets:
# 'no_input_format' cannot include {input} # 'no_input_format' cannot include {input}
no_input_format: "{instruction} " no_input_format: "{instruction} "
# For `completion` datsets only, uses the provided field instead of `text` column # for completions datsets, uses the provided field if not `text`
field: field:
# Axolotl attempts to save the dataset as an arrow after packing the data together so # axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path # subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared dataset_prepared_path: data/last_run_prepared
# Push prepared dataset to hub # push prepared dataset to hub
push_dataset_to_hub: # repo path push_dataset_to_hub: # repo path
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# if not set. # if not set.
@@ -525,8 +495,8 @@ hub_model_id: # repo path to push finetuned model
# how to push checkpoints to hub # how to push checkpoints to hub
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy # https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
hub_strategy: hub_strategy:
# Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets # whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
# Required to be true when used in combination with `push_dataset_to_hub` # required to be true when used in combination with `push_dataset_to_hub`
hf_use_auth_token: # boolean hf_use_auth_token: # boolean
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval. # How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
val_set_size: 0.04 val_set_size: 0.04
@@ -535,34 +505,30 @@ dataset_shard_num:
# Index of shard to use for whole dataset # Index of shard to use for whole dataset
dataset_shard_idx: dataset_shard_idx:
# The maximum length of an input to train with, this should typically be less than 2048 # the maximum length of an input to train with, this should typically be less than 2048
# as most models have a token/context limit of 2048 # as most models have a token/context limit of 2048
sequence_len: 2048 sequence_len: 2048
# Pad inputs so each step uses constant sized buffers # pad inputs so each step uses constant sized buffers
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently # this will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
pad_to_sequence_len: pad_to_sequence_len:
# Max sequence length to concatenate training samples together up to # max sequence length to concatenate training samples together up to
# Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning # inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# FutureWarning: This will soon be DEPRECATED # FutureWarning: This will soon be DEPRECATED
max_packed_sequence_len: 1024 max_packed_sequence_len: 1024
# Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true' # use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
sample_packing: sample_packing:
# Set to 'false' if getting errors during eval with sample_packing on. # set to 'false' if getting errors during eval with sample_packing on.
eval_sample_packing: eval_sample_packing:
# You can set these packing optimizations AFTER starting a training at least once. # you can set these packing optimizations AFTER starting a training at least once.
# The trainer will provide recommended values for these values. # The trainer will provide recommended values for these values.
sample_packing_eff_est: sample_packing_eff_est:
total_num_tokens: total_num_tokens:
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model # if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
adapter: lora adapter: lora
# If you already have a lora model trained that you want to load, put that here. # if you already have a lora model trained that you want to load, put that here
# This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`. # lora hyperparameters
lora_model_dir: lora_model_dir:
# LoRA hyperparameters
# For more details about the following options, see:
# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
lora_r: 8 lora_r: 8
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05
@@ -574,96 +540,81 @@ lora_target_modules:
# - gate_proj # - gate_proj
# - down_proj # - down_proj
# - up_proj # - up_proj
lora_target_linear: # If true, will target all linear layers lora_target_linear: # if true, will target all linear layers
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
lora_modules_to_save: lora_modules_to_save:
# - embed_tokens # - embed_tokens
# - lm_head # - lm_head
# Once you complete training, the model will be saved to the following directory.
# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
# Make sure `lora_model_dir` points to this directory if you want to use the trained model.
lora_out_dir: lora_out_dir:
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
# ReLoRA configuration # ReLoRA configuration
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed # must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
relora_steps: # Number of steps per ReLoRA restart relora_steps: # number of steps per ReLoRA restart
relora_warmup_steps: # Number of per-restart warmup steps relora_warmup_steps: # number of per-restart warmup steps
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings relora_cpu_offload: # true to perform lora weight merges on cpu during restarts, for modest gpu memory savings
# wandb configuration if you're using it # wandb configuration if you're using it
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
wandb_project: # Your wandb project name wandb_project: # your wandb project name
wandb_entity: # A wandb Team name if using a Team wandb_entity: # a wandb Team name if using a Team
wandb_watch: wandb_watch:
wandb_run_id: # Set the name of your wandb run wandb_run_id: # set the name of your wandb run
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
# Where to save the full-finetuned model to # where to save the finished model to
output_dir: ./completed-model output_dir: ./completed-model
# Whether to use torch.compile and which backend to use # whether to use torch.compile and which backend to use
torch_compile: # bool torch_compile: # bool
torch_compile_backend: # Optional[str] torch_compile_backend: # Optional[str]
# Training hyperparameters # training hyperparameters
# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
micro_batch_size: 2 micro_batch_size: 2
eval_batch_size: eval_batch_size:
num_epochs: 4 num_epochs: 3
warmup_steps: 100 warmup_steps: 100
learning_rate: 0.00003 learning_rate: 0.00003
lr_quadratic_warmup: lr_quadratic_warmup:
logging_steps: logging_steps:
save_strategy: # Set to `no` to skip checkpoint saves save_strategy: # set to `no` to skip checkpoint saves
save_steps: # Leave empty to save at each epoch save_steps: # leave empty to save at each epoch
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps eval_steps: # leave empty to eval at each epoch
save_total_limit: # Checkpoints saved at a time save_total_limit: # checkpoints saved at a time
# Maximum number of iterations to train for. It precedes num_epochs which means that
# if both are set, num_epochs will not be guaranteed.
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
max_steps: max_steps:
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0 eval_table_size: # approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128 eval_table_max_new_tokens: # total number of tokens generated for predictions sent to wandb. Default is 128
# Save model as safetensors (require safetensors package) # save model as safetensors (require safetensors package)
save_safetensors: save_safetensors:
# Whether to mask out or include the human's prompt from the training labels # whether to mask out or include the human's prompt from the training labels
train_on_inputs: false train_on_inputs: false
# Group similarly sized data to minimize padding. # group similarly sized data to minimize padding
# May be slower to start, as it must download and sort the entire dataset. # may be slower to start, as it must download and sort the entire dataset
# Note that training loss may have an oscillating pattern with this enabled. # note that training loss may have an oscillating pattern with this enabled
group_by_length: false group_by_length: false
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false gradient_checkpointing: false
# Stop training after this many evaluation losses have increased in a row # stop training after this many evaluation losses have increased in a row
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
early_stopping_patience: 3 early_stopping_patience: 3
# Specify a scheduler and kwargs to use with the optimizer # specify a scheduler and kwargs to use with the optimizer
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
lr_scheduler_kwargs: lr_scheduler_kwargs:
# For one_cycle optim # for one_cycle optim
lr_div_factor: # Learning rate div factor lr_div_factor: # learning rate div factor
# For log_sweep optim # for log_sweep optim
log_sweep_min_lr: log_sweep_min_lr:
log_sweep_max_lr: log_sweep_max_lr:
# Specify optimizer # specify optimizer
# Valid values are driven by the Transformers OptimizerNames class, see: # Valid values are driven by the Transformers OptimizerNames class, see:
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134 # https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
# #
@@ -689,7 +640,7 @@ log_sweep_max_lr:
# - paged_lion_32bit # - paged_lion_32bit
# - paged_lion_8bit # - paged_lion_8bit
optimizer: optimizer:
# Specify weight decay # specify weight decay
weight_decay: weight_decay:
# adamw hyperparams # adamw hyperparams
adam_beta1: adam_beta1:
@@ -698,58 +649,49 @@ adam_epsilon:
# Gradient clipping max norm # Gradient clipping max norm
max_grad_norm: max_grad_norm:
# Augmentation techniques # whether to bettertransformers
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
# currently only supported on Llama and Mistral
noisy_embedding_alpha:
# Whether to bettertransformers
flash_optimum: flash_optimum:
# Whether to use xformers attention patch https://github.com/facebookresearch/xformers: # whether to use xformers attention patch https://github.com/facebookresearch/xformers:
xformers_attention: xformers_attention:
# Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention: # whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
flash_attention: flash_attention:
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation # whether to use scaled-dot-product attention
flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
# Whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention: sdp_attention:
# Landmark attention (only llama) # Landmark attention (only llama)
landmark_attention: landmark_attention:
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
# LLaMA only # llama only
xpos_rope: xpos_rope:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653 # RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling: rope_scaling:
type: # linear | dynamic type: # linear | dynamic
factor: # float factor: # float
# Resume from a specific checkpoint dir # resume from a specific checkpoint dir
resume_from_checkpoint: resume_from_checkpoint:
# If resume_from_checkpoint isn't set and you simply want it to start where it left off. # if resume_from_checkpoint isn't set and you simply want it to start where it left off
# Be careful with this being turned on between different models. # be careful with this being turned on between different models
auto_resume_from_checkpoints: false auto_resume_from_checkpoints: false
# Don't mess with this, it's here for accelerate and torchrun # don't mess with this, it's here for accelerate and torchrun
local_rank: local_rank:
# Add or change special tokens. # add or change special tokens
# If you add tokens here, you don't need to add them to the `tokens` list.
special_tokens: special_tokens:
# bos_token: "<s>" # bos_token: "<s>"
# eos_token: "</s>" # eos_token: "</s>"
# unk_token: "<unk>" # unk_token: "<unk>"
# add extra tokens
# Add extra tokens.
tokens: tokens:
# FSDP # FSDP
fsdp: fsdp:
fsdp_config: fsdp_config:
# Deepspeed config path. e.g., deepspeed/zero3.json # Deepspeed config path
deepspeed: deepspeed:
# Advanced DDP Arguments # Advanced DDP Arguments
@@ -775,66 +717,6 @@ strict:
</details> </details>
<details>
<summary> Understanding of batch size and gradient accumulation steps </summary>
<br/>
Gradient accumulation means accumulating gradients over several mini-batches and updating the model weights afterward. When the samples in each batch are diverse, this technique doesn't significantly impact learning.
This method allows for effective training with larger effective batch sizes without needing proportionally larger memory. Here's why:
1. **Memory Consumption with Batch Size**: The primary reason increasing the batch size impacts memory is due to the storage requirements for intermediate activations. When you forward propagate a batch through a network, you have to store the activations at each layer for each sample in the batch, because these activations are used during backpropagation to compute gradients. Therefore, larger batches mean more activations, leading to greater GPU memory consumption.
2. **Gradient Accumulation**: With gradient accumulation, you're effectively simulating a larger batch size by accumulating gradients over several smaller batches (or micro-batches). However, at any given time, you're only forward and backward propagating a micro-batch. This means you only store activations for the micro-batch, not the full accumulated batch. As a result, you can simulate the effect of a larger batch size without the memory cost of storing activations for a large batch.
**Example 1:**
Micro batch size: 3
Gradient accumulation steps: 2
Number of GPUs: 3
Total batch size = 3 * 2 * 3 = 18
```
| GPU 1 | GPU 2 | GPU 3 |
|----------------|----------------|----------------|
| S1, S2, S3 | S4, S5, S6 | S7, S8, S9 |
| e1, e2, e3 | e4, e5, e6 | e7, e8, e9 |
|----------------|----------------|----------------|
| → (accumulate) | → (accumulate) | → (accumulate) |
|----------------|----------------|----------------|
| S10, S11, S12 | S13, S14, S15 | S16, S17, S18 |
| e10, e11, e12 | e13, e14, e15 | e16, e17, e18 |
|----------------|----------------|----------------|
| → (apply) | → (apply) | → (apply) |
Accumulated gradient for the weight w1 after the second iteration (considering all GPUs):
Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6 + e7 + e8 + e9 + e10 + e11 + e12 + e13 + e14 + e15 + e16 + e17 + e18
Weight update for w1:
w1_new = w1_old - learning rate x (Total gradient for w1 / 18)
```
**Example 2:**
Micro batch size: 2
Gradient accumulation steps: 1
Number of GPUs: 3
Total batch size = 2 * 1 * 3 = 6
```
| GPU 1 | GPU 2 | GPU 3 |
|-----------|-----------|-----------|
| S1, S2 | S3, S4 | S5, S6 |
| e1, e2 | e3, e4 | e5, e6 |
|-----------|-----------|-----------|
| → (apply) | → (apply) | → (apply) |
Accumulated gradient for the weight w1 (considering all GPUs):
Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6
Weight update for w1:
w1_new = w1_old - learning rate × (Total gradient for w1 / 6)
```
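Tying the two examples back to the config keys used throughout this README, a sketch of the settings that reproduce Example 1 (the GPU count itself comes from however you launch training, not from these keys):
```yaml
micro_batch_size: 3            # samples per GPU per forward/backward pass
gradient_accumulation_steps: 2 # accumulate over 2 micro-batches before each optimizer step
# effective (total) batch size = micro_batch_size * gradient_accumulation_steps * num_gpus
#                              = 3 * 2 * 3 = 18 when launched on 3 GPUs
```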
</details>
### Train ### Train
Run Run
@@ -842,41 +724,14 @@ Run
accelerate launch -m axolotl.cli.train your_config.yml accelerate launch -m axolotl.cli.train your_config.yml
``` ```
#### Preprocess dataset
You can optionally pre-tokenize dataset with the following before finetuning.
This is recommended for large datasets.
- Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
- Use `--debug` to see preprocessed examples.
```bash
python -m axolotl.cli.preprocess your_config.yml
```
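The bullets above refer to config keys that also appear in the full options list; a sketch of where they live in the YAML (the repo path is a placeholder):
```yaml
dataset_prepared_path: data/last_run_prepared  # where the pre-tokenized arrow files are cached
push_dataset_to_hub: hf_user/repo              # optional: also push the prepared dataset to the Hub
```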
#### Multi-GPU #### Multi-GPU
Below are the options available in axolotl for training with multiple GPUs. Note that DeepSpeed You can optionally pre-tokenize dataset with the following before finetuning:
is the recommended multi-GPU option currently because FSDP may experience ```bash
[loss instability](https://github.com/huggingface/transformers/issues/26498). CUDA_VISIBLE_DEVICES="" accelerate launch -m axolotl.cli.train your_config.yml --prepare_ds_only
##### DeepSpeed
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
```yaml
deepspeed: deepspeed/zero1.json
``` ```
```shell ##### Config
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
```
##### FSDP
- llama FSDP - llama FSDP
```yaml ```yaml
@@ -901,6 +756,24 @@ wandb_run_id:
wandb_log_model: wandb_log_model:
``` ```
### Training with Deepspeed
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
```shell
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
```
or
```yaml
deepspeed: deepspeed/zero1.json
```
### Inference ### Inference
Pass the appropriate flag to the train command: Pass the appropriate flag to the train command:
@@ -919,10 +792,6 @@ Pass the appropriate flag to the train command:
--base_model="./completed-model" --prompter=None --load_in_8bit=True --base_model="./completed-model" --prompter=None --load_in_8bit=True
``` ```
Please use `--sample_packing False` if you have it on and receive the error similar to below:
> RuntimeError: stack expects each tensor to be equal size, but got [1, 32, 1, 128] at entry 0 and [1, 32, 8, 128] at entry 1
### Merge LORA to base ### Merge LORA to base
Add below flag to train command above Add below flag to train command above
@@ -939,8 +808,6 @@ CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ...
## Common Errors 🧰 ## Common Errors 🧰
See also the [FAQ's](./docs/faq.md).
> If you encounter a 'Cuda out of memory' error, it means your GPU ran out of memory during the training process. Here's how to resolve it: > If you encounter a 'Cuda out of memory' error, it means your GPU ran out of memory during the training process. Here's how to resolve it:
Please reduce any below Please reduce any below


@@ -1,6 +1,14 @@
{ {
"zero_optimization": { "zero_optimization": {
"stage": 3, "stage": 3,
"offload_optimizer": {
"device": "cpu",
"pin_memory": true
},
"offload_param": {
"device": "cpu",
"pin_memory": true
},
"overlap_comm": true, "overlap_comm": true,
"contiguous_gradients": true, "contiguous_gradients": true,
"sub_group_size": 0, "sub_group_size": 0,
@@ -33,13 +41,12 @@
} }
}, },
"scheduler": { "scheduler": {
"type": "WarmupDecayLR", "type": "WarmupLR",
"params": { "params": {
"warmup_min_lr": "auto", "warmup_min_lr": "auto",
"warmup_max_lr": "auto", "warmup_max_lr": "auto",
"warmup_num_steps": "auto", "warmup_num_steps": "auto",
"warmup_type": "linear", "warmup_type": "linear"
"total_num_steps": "auto"
} }
}, },
"gradient_accumulation_steps": "auto", "gradient_accumulation_steps": "auto",


@@ -5,9 +5,6 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS="" ARG AXOLOTL_EXTRAS=""
ARG CUDA="118" ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA ENV BNB_CUDA_VERSION=$CUDA
ARG PYTORCH_VERSION="2.0.1"
ENV PYTORCH_VERSION=$PYTORCH_VERSION
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y vim curl apt-get install -y vim curl
@@ -19,7 +16,6 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
WORKDIR /workspace/axolotl WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \ pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
else \ else \


@@ -14,7 +14,7 @@ ARG CUDA="118"
ENV PYTHON_VERSION=$PYTHON_VERSION ENV PYTHON_VERSION=$PYTHON_VERSION
RUN apt-get update \ RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \ && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/*
&& wget \ && wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \ && mkdir /root/.conda \
@@ -57,6 +57,11 @@ FROM base-builder
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX" ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
# recompile apex
RUN python3 -m pip uninstall -y apex
RUN git clone https://github.com/NVIDIA/apex
RUN cd apex && python3 -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
RUN mkdir -p /workspace/builds RUN mkdir -p /workspace/builds
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes


@@ -1,18 +0,0 @@
# Axolotl FAQ's
> The trainer stopped and hasn't progressed in several minutes.
Usually an issue with the GPU's communicating with each other. See the [NCCL doc](../docs/nccl.md)
> Exitcode -9
This usually happens when you run out of system RAM.
> Exitcode -7 while using deepspeed
Try upgrading deepspeed w: `pip install -U deepspeed`
> AttributeError: 'DummyOptim' object has no attribute 'step'
You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.


@@ -1,51 +0,0 @@
# Multipack
4k context, bsz =4,
each character represents 256 tokens
X represents a padding token
```
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
[[ A A A A A A A A A A A ]
B B B B B B ]
C C C C C C C ]
D D D D ]]
[[ E E E E E E E E ]
[ F F F F ]
[ G G G ]
[ H H H H ]]
[[ I I I ]
[ J J J ]
[ K K K K K]
[ L L L ]]
```
after padding to longest input in each step
```
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
[[ A A A A A A A A A A A ]
B B B B B B X X X X X X ]
C C C C C C C X X X X ]
D D D D X X X X X X X ]]
[[ E E E E E E E E ]
[ F F F F X X X X ]
[ G G G X X X X X ]
[ H H H H X X X X ]]
[[ I I I X X ]
[ J J J X X ]
[ K K K K K ]
[ L L L X X ]]
```
w packing ( note it's the same effective number of tokens per step, but a true bsz of 1)
```
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
[[ A A A A A A A A A A A B B B B B
B C C C C C C C D D D D E E E E
E E E E F F F F F G G G H H H H
I I I J J J J K K K K K L L L X ]]
```
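For reference, the behaviour sketched above is driven by config keys shown elsewhere in this diff; a minimal config exercising it might look like this (values chosen to mirror the 4k-context, bsz=4 scenario, not taken from any file here):
```yaml
sequence_len: 4096        # the 4k row width in the diagrams above
micro_batch_size: 4       # bsz = 4 before packing
sample_packing: true      # pack short samples together with block-diagonal attention and per-sequence position_ids
pad_to_sequence_len: true
```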


@@ -1,4 +1,5 @@
base_model: cerebras/btlm-3b-8k-base base_model: cerebras/btlm-3b-8k-base
base_model_config: cerebras/btlm-3b-8k-base
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: GPT2Tokenizer tokenizer_type: GPT2Tokenizer
trust_remote_code: true trust_remote_code: true


@@ -1,4 +1,5 @@
base_model: cerebras/Cerebras-GPT-1.3B base_model: cerebras/Cerebras-GPT-1.3B
base_model_config: cerebras/Cerebras-GPT-1.3B
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true
strict: false strict: false
@@ -49,7 +50,7 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:


@@ -1,4 +1,5 @@
base_model: codellama/CodeLlama-13b-hf base_model: codellama/CodeLlama-13b-hf
base_model_config: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true is_llama_derived_model: true
@@ -34,7 +35,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 4 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
@@ -54,7 +55,7 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:


@@ -1,4 +1,5 @@
base_model: codellama/CodeLlama-13b-hf base_model: codellama/CodeLlama-13b-hf
base_model_config: codellama/CodeLlama-13b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true is_llama_derived_model: true
@@ -36,7 +37,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 4 num_epochs: 3
optimizer: paged_adamw_32bit optimizer: paged_adamw_32bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
@@ -56,7 +57,7 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:


@@ -1,4 +1,5 @@
base_model: codellama/CodeLlama-34b-hf base_model: codellama/CodeLlama-34b-hf
base_model_config: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true is_llama_derived_model: true
@@ -34,7 +35,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 4 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
@@ -54,7 +55,7 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:


@@ -1,4 +1,5 @@
base_model: codellama/CodeLlama-34b-hf base_model: codellama/CodeLlama-34b-hf
base_model_config: codellama/CodeLlama-34b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true is_llama_derived_model: true
@@ -36,7 +37,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 4 num_epochs: 3
optimizer: paged_adamw_32bit optimizer: paged_adamw_32bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
@@ -56,7 +57,7 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:


@@ -1,4 +1,5 @@
base_model: codellama/CodeLlama-7b-hf base_model: codellama/CodeLlama-7b-hf
base_model_config: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true is_llama_derived_model: true
@@ -34,7 +35,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 4 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
@@ -54,7 +55,7 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:


@@ -1,4 +1,5 @@
base_model: codellama/CodeLlama-7b-hf base_model: codellama/CodeLlama-7b-hf
base_model_config: codellama/CodeLlama-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: CodeLlamaTokenizer tokenizer_type: CodeLlamaTokenizer
is_llama_derived_model: true is_llama_derived_model: true
@@ -36,7 +37,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 4 num_epochs: 3
optimizer: paged_adamw_32bit optimizer: paged_adamw_32bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
@@ -56,7 +57,7 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:


@@ -1,4 +1,5 @@
base_model: tiiuae/falcon-7b base_model: tiiuae/falcon-7b
base_model_config: tiiuae/falcon-7b
trust_remote_code: true trust_remote_code: true
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer


@@ -1,6 +1,7 @@
# 1b: tiiuae/falcon-rw-1b # 1b: tiiuae/falcon-rw-1b
# 40b: tiiuae/falcon-40b # 40b: tiiuae/falcon-40b
base_model: tiiuae/falcon-7b base_model: tiiuae/falcon-7b
base_model_config: tiiuae/falcon-7b
# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main # required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
trust_remote_code: true trust_remote_code: true
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
@@ -53,7 +54,7 @@ output_dir: ./qlora-out
# decrease if OOM, increase for max VRAM utilization # decrease if OOM, increase for max VRAM utilization
micro_batch_size: 1 micro_batch_size: 1
gradient_accumulation_steps: 2 gradient_accumulation_steps: 2
num_epochs: 4 num_epochs: 3
# Optimizer for QLoRA # Optimizer for QLoRA
optimizer: paged_adamw_32bit optimizer: paged_adamw_32bit
torchdistx_path: torchdistx_path:


@@ -1,4 +1,5 @@
base_model: tiiuae/falcon-7b base_model: tiiuae/falcon-7b
base_model_config: tiiuae/falcon-7b
trust_remote_code: true trust_remote_code: true
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer


@@ -1,4 +1,5 @@
base_model: EleutherAI/gpt-j-6b base_model: EleutherAI/gpt-j-6b
base_model_config: EleutherAI/gpt-j-6b
load_in_8bit: false load_in_8bit: false
load_in_4bit: true load_in_4bit: true
strict: false strict: false
@@ -46,7 +47,7 @@ flash_attention:
gptq_groupsize: gptq_groupsize:
gptq_model_v1: gptq_model_v1:
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
save_steps: save_steps:
debug: debug:
deepspeed: deepspeed:


@@ -1,4 +1,5 @@
base_model: huggyllama/llama-7b base_model: huggyllama/llama-7b
base_model_config: huggyllama/llama-7b
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
load_in_8bit: false load_in_8bit: false
@@ -24,7 +25,7 @@ wandb_log_model:
output_dir: ./jeopardy-bot-7b output_dir: ./jeopardy-bot-7b
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 4 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
torchdistx_path: torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine


@@ -9,16 +9,12 @@ gradient_accumulation_steps: 2
micro_batch_size: 1 micro_batch_size: 1
```shell ```shell
accelerate launch -m axolotl.cli.train examples/llama-2/qlora.yml accelerate launch scripts/finetune.py examples/llama-2/qlora.yml
``` ```
or or
```shell ```shell
accelerate launch -m axolotl.cli.train examples/llama-2/lora.yml accelerate launch scripts/finetune.py examples/llama-2/lora.yml
```
To launch a full finetuning with 16-bit precision:
```shell
accelerate launch -m axolotl.cli.train examples/llama-2/fft_optimized.yml
``` ```


@@ -1,72 +0,0 @@
base_model: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
adapter:
lora_model_dir:
lora_r:
lora_alpha:
lora_dropout:
lora_target_linear:
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:
gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: true
fp16: false
tf32: false
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true
flash_attn_cross_entropy: false
flash_attn_rms_norm: true
flash_attn_fuse_qkv: false
flash_attn_fuse_mlp: true
warmup_steps: 100
eval_steps: 0.05
eval_table_size:
save_steps:
debug:
deepspeed: #deepspeed/zero2.json # multi-gpu only
weight_decay: 0.1
fsdp:
fsdp_config:
special_tokens:
bos_token: "<s>"
eos_token: "</s>"
unk_token: "<unk>"


@@ -1,4 +1,5 @@
base_model: TheBloke/Llama-2-7B-GPTQ base_model: TheBloke/Llama-2-7B-GPTQ
base_model_config: TheBloke/Llama-2-7B-GPTQ
is_llama_derived_model: false is_llama_derived_model: false
gptq: true gptq: true
gptq_disable_exllama: true gptq_disable_exllama: true
@@ -37,7 +38,7 @@ wandb_log_model:
output_dir: ./model-out output_dir: ./model-out
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 4 num_epochs: 3
optimizer: adamw_torch optimizer: adamw_torch
adam_beta2: 0.95 adam_beta2: 0.95
adam_eps: 0.00001 adam_eps: 0.00001


@@ -1,4 +1,5 @@
base_model: NousResearch/Llama-2-7b-hf base_model: NousResearch/Llama-2-7b-hf
base_model_config: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true is_llama_derived_model: true
@@ -34,7 +35,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 4 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
@@ -54,7 +55,7 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
eval_table_size: eval_table_size:
eval_table_max_new_tokens: 128 eval_table_max_new_tokens: 128
save_steps: save_steps:


@@ -1,4 +1,5 @@
base_model: NousResearch/Llama-2-7b-hf base_model: NousResearch/Llama-2-7b-hf
base_model_config: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true is_llama_derived_model: true
@@ -36,7 +37,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 4 num_epochs: 3
optimizer: paged_adamw_32bit optimizer: paged_adamw_32bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
@@ -56,7 +57,7 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
eval_table_size: eval_table_size:
save_steps: save_steps:
debug: debug:


@@ -1,4 +1,5 @@
base_model: NousResearch/Llama-2-7b-hf base_model: NousResearch/Llama-2-7b-hf
base_model_config: NousResearch/Llama-2-7b-hf
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_llama_derived_model: true is_llama_derived_model: true
@@ -40,7 +41,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 4 micro_batch_size: 4
num_epochs: 4 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
@@ -60,7 +61,7 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
save_steps: 50 save_steps: 50
debug: debug:
deepspeed: deepspeed:


@@ -1,4 +1,5 @@
base_model: PY007/TinyLlama-1.1B-step-50K-105b base_model: PY007/TinyLlama-1.1B-step-50K-105b
base_model_config: PY007/TinyLlama-1.1B-step-50K-105b
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
@@ -34,7 +35,7 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 4 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.0002
@@ -54,7 +55,7 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
eval_table_size: eval_table_size:
save_steps: save_steps:
debug: debug:


@@ -1,4 +1,5 @@
base_model: mistralai/Mistral-7B-v0.1 base_model: mistralai/Mistral-7B-v0.1
base_model_config: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true is_mistral_derived_model: true
@@ -15,8 +16,8 @@ val_set_size: 0.01
output_dir: ./out output_dir: ./out
sequence_len: 8192 sequence_len: 8192
sample_packing: true sample_packing:
pad_to_sequence_len: true pad_to_sequence_len:
wandb_project: wandb_project:
wandb_entity: wandb_entity:
@@ -26,10 +27,10 @@ wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 2 micro_batch_size: 2
num_epochs: 4 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.000005 learning_rate: 0.0002
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: false
@@ -46,8 +47,8 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
eval_table_size: eval_table_size: 5
eval_table_max_new_tokens: 128 eval_table_max_new_tokens: 128
save_steps: save_steps:
debug: debug:

View File

@@ -1,4 +1,5 @@
base_model: mistralai/Mistral-7B-v0.1 base_model: mistralai/Mistral-7B-v0.1
base_model_config: mistralai/Mistral-7B-v0.1
model_type: MistralForCausalLM model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true is_mistral_derived_model: true
@@ -18,8 +19,8 @@ adapter: qlora
lora_model_dir: lora_model_dir:
sequence_len: 8192 sequence_len: 8192
sample_packing: true sample_packing: True
pad_to_sequence_len: true pad_to_sequence_len: True
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
@@ -42,7 +43,7 @@ wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 2 micro_batch_size: 4
num_epochs: 1 num_epochs: 1
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
lr_scheduler: cosine lr_scheduler: cosine
@@ -63,8 +64,8 @@ xformers_attention:
flash_attention: true flash_attention: true
warmup_steps: 10 warmup_steps: 10
eval_steps: 0.05 eval_steps: 20
eval_table_size: eval_table_size: 5
eval_table_max_new_tokens: 128 eval_table_max_new_tokens: 128
save_steps: save_steps:
debug: debug:

View File

@@ -1,4 +1,5 @@
base_model: mosaicml/mpt-7b base_model: mosaicml/mpt-7b
base_model_config: mosaicml/mpt-7b
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
trust_remote_code: true # required for mpt as their model class is not merged into transformers yet trust_remote_code: true # required for mpt as their model class is not merged into transformers yet
load_in_8bit: false load_in_8bit: false
@@ -26,7 +27,7 @@ wandb_log_model:
output_dir: ./mpt-alpaca-7b output_dir: ./mpt-alpaca-7b
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 4 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
torchdistx_path: torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine

View File

@@ -1,4 +1,5 @@
base_model: openlm-research/open_llama_3b_v2 base_model: openlm-research/open_llama_3b_v2
base_model_config: openlm-research/open_llama_3b_v2
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
load_in_8bit: false load_in_8bit: false

View File

@@ -1,4 +1,5 @@
base_model: openlm-research/open_llama_3b_v2 base_model: openlm-research/open_llama_3b_v2
base_model_config: openlm-research/open_llama_3b_v2
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
load_in_8bit: true load_in_8bit: true

View File

@@ -1,4 +1,5 @@
base_model: openlm-research/open_llama_3b_v2 base_model: openlm-research/open_llama_3b_v2
base_model_config: openlm-research/open_llama_3b_v2
model_type: LlamaForCausalLM model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer tokenizer_type: LlamaTokenizer
load_in_8bit: false load_in_8bit: false

View File

@@ -1,4 +1,5 @@
base_model: microsoft/phi-1_5 base_model: microsoft/phi-1_5
base_model_config: microsoft/phi-1_5
model_type: MixFormerSequentialForCausalLM model_type: MixFormerSequentialForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_llama_derived_model: false is_llama_derived_model: false

View File

@@ -1,4 +1,5 @@
base_model: microsoft/phi-1_5 base_model: microsoft/phi-1_5
base_model_config: microsoft/phi-1_5
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
is_llama_derived_model: false is_llama_derived_model: false

View File

@@ -1,4 +1,5 @@
base_model: EleutherAI/pythia-12b-deduped base_model: EleutherAI/pythia-12b-deduped
base_model_config: EleutherAI/pythia-12b-deduped
base_model_ignore_patterns: pytorch* # prefer safetensors base_model_ignore_patterns: pytorch* # prefer safetensors
model_type: GPTNeoXForCausalLM model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer

View File

@@ -1,4 +1,5 @@
base_model: EleutherAI/pythia-1.4b-deduped base_model: EleutherAI/pythia-1.4b-deduped
base_model_config: EleutherAI/pythia-1.4b-deduped
load_in_8bit: true load_in_8bit: true
datasets: datasets:
- path: teknium/GPT4-LLM-Cleaned - path: teknium/GPT4-LLM-Cleaned
@@ -23,15 +24,15 @@ wandb_log_model:
output_dir: ./lora-alpaca-pythia output_dir: ./lora-alpaca-pythia
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
micro_batch_size: 4 micro_batch_size: 4
num_epochs: 4 num_epochs: 3
learning_rate: 0.00001 learning_rate: 0.00001
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: false
bf16: true bf16: True
tf32: true tf32: True
early_stopping_patience: early_stopping_patience:
resume_from_checkpoint: resume_from_checkpoint:
local_rank: local_rank:
weight_decay: 0.1 weight_decay: 0.1
eval_steps: 0.05 eval_steps: 20
logging_steps: 1 logging_steps: 1

View File

@@ -1,4 +1,5 @@
base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1 base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1
model_type: GPTNeoXForCausalLM model_type: GPTNeoXForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
trust_remote_code: trust_remote_code:
@@ -27,7 +28,7 @@ wandb_log_model:
output_dir: ./redpajama-alpaca-3b output_dir: ./redpajama-alpaca-3b
batch_size: 4 batch_size: 4
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 4 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
torchdistx_path: torchdistx_path:
lr_scheduler: cosine lr_scheduler: cosine

View File

@@ -1,4 +1,5 @@
base_model: replit/replit-code-v1-3b base_model: replit/replit-code-v1-3b
base_model_config: replit/replit-code-v1-3b
trust_remote_code: true trust_remote_code: true
load_in_8bit: false load_in_8bit: false
datasets: datasets:
@@ -26,7 +27,7 @@ wandb_log_model:
output_dir: ./lora-replit output_dir: ./lora-replit
batch_size: 8 batch_size: 8
micro_batch_size: 1 micro_batch_size: 1
num_epochs: 4 num_epochs: 3
optimizer: optimizer:
torchdistx_path: torchdistx_path:
lr_scheduler: lr_scheduler:

View File

@@ -1,6 +1,7 @@
# An example finetuning Salesforce's XGen-7b model with 8k context using qlora # An example finetuning Salesforce's XGen-7b model with 8k context using qlora
# on Tim Dettmers' Guanaco dataset. # on Tim Dettmers' Guanaco dataset.
base_model: Salesforce/xgen-7b-8k-base base_model: Salesforce/xgen-7b-8k-base
base_model_config: Salesforce/xgen-7b-8k-base
trust_remote_code: true trust_remote_code: true
model_type: AutoModelForCausalLM model_type: AutoModelForCausalLM
tokenizer_type: AutoTokenizer tokenizer_type: AutoTokenizer
@@ -51,7 +52,7 @@ output_dir: ./qlora-out
# decrease if OOM, increase for max VRAM utilization # decrease if OOM, increase for max VRAM utilization
micro_batch_size: 1 micro_batch_size: 1
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
num_epochs: 4 num_epochs: 3
# Optimizer for QLoRA # Optimizer for QLoRA
optimizer: paged_adamw_32bit optimizer: paged_adamw_32bit
torchdistx_path: torchdistx_path:

View File

@@ -1 +0,0 @@
# Page

View File

@@ -1,4 +0,0 @@
# Table of contents
* [Page](README.md)
* [Small dev details](small-dev-details.md)

View File

@@ -1,3 +0,0 @@
# Small dev details
/

Binary file not shown (removed image, 370 KiB).

View File

@@ -4,7 +4,7 @@ torch==2.0.1
auto-gptq auto-gptq
packaging packaging
peft @ git+https://github.com/huggingface/peft.git peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git@acc394c4f5e1283c19783581790b3dc3105a3697 transformers @ git+https://github.com/huggingface/transformers.git@bd6205919aad4d3a2300a39a98a642f1cc3a5348
bitsandbytes>=0.41.1 bitsandbytes>=0.41.1
accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9 accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
deepspeed deepspeed
@@ -16,7 +16,7 @@ flash-attn>=2.3.0
sentencepiece sentencepiece
wandb wandb
einops einops
xformers>=0.0.22 xformers
optimum optimum
hf_transfer hf_transfer
colorama colorama
@@ -31,4 +31,3 @@ scikit-learn==1.2.2
pynvml pynvml
art art
fschat==0.2.29 fschat==0.2.29
tensor_parallel

View File

@@ -45,6 +45,8 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
shard(cfg=parsed_cfg, cli_args=parsed_cli_args) shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
else: else:
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only:
return
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)

View File

@@ -21,14 +21,6 @@ def parse_requirements():
): ):
# Handle standard packages # Handle standard packages
_install_requires.append(line) _install_requires.append(line)
# TODO(wing) remove once xformers release supports torch 2.1.0
if "torch==2.1.0" in _install_requires:
_install_requires.pop(_install_requires.index("xformers>=0.0.22"))
_install_requires.append(
"xformers @ git+https://github.com/facebookresearch/xformers.git@main"
)
return _install_requires, _dependency_links return _install_requires, _dependency_links
@@ -46,7 +38,7 @@ setup(
dependency_links=dependency_links, dependency_links=dependency_links,
extras_require={ extras_require={
"flash-attn": [ "flash-attn": [
"flash-attn>=2.3.0", "flash-attn>=2.2.1",
], ],
"deepspeed": [ "deepspeed": [
"deepspeed", "deepspeed",

View File

@@ -194,7 +194,6 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
# load the config from the yaml file # load the config from the yaml file
with open(config, encoding="utf-8") as file: with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file)) cfg: DictDefault = DictDefault(yaml.safe_load(file))
cfg.axolotl_config_path = config
# if there are any options passed in the cli, if it is something that seems valid from the yaml, # if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value # then overwrite the value
cfg_keys = cfg.keys() cfg_keys = cfg.keys()
@@ -222,9 +221,7 @@ def load_datasets(
) -> TrainDatasetMeta: ) -> TrainDatasetMeta:
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset( train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
cfg, tokenizer
)
if cli_args.debug or cfg.debug: if cli_args.debug or cfg.debug:
LOG.info("check_dataset_labels...") LOG.info("check_dataset_labels...")
@@ -240,10 +237,6 @@ def load_datasets(
text_only=cli_args.debug_text_only, text_only=cli_args.debug_text_only,
) )
LOG.info("printing prompters...")
for prompter in prompters:
LOG.info(prompter)
return TrainDatasetMeta( return TrainDatasetMeta(
train_dataset=train_dataset, train_dataset=train_dataset,
eval_dataset=eval_dataset, eval_dataset=eval_dataset,

View File

@@ -14,7 +14,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
print_axolotl_text_art() print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs) parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.sample_packing = False
parser = transformers.HfArgumentParser((TrainerCliArgs)) parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses( parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True return_remaining_strings=True

View File

@@ -1,53 +0,0 @@
"""
CLI to run training on a model
"""
import logging
from pathlib import Path
import fire
import transformers
from colorama import Fore
from axolotl.cli import (
check_accelerate_default_config,
check_user_token,
load_cfg,
load_datasets,
print_axolotl_text_art,
)
from axolotl.common.cli import PreprocessCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
LOG = logging.getLogger("axolotl.cli.preprocess")
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)
check_accelerate_default_config()
check_user_token()
parser = transformers.HfArgumentParser((PreprocessCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)
if not parsed_cfg.dataset_prepared_path:
msg = (
Fore.RED
+ "preprocess CLI called without dataset_prepared_path set, "
+ f"using default path: {DEFAULT_DATASET_PREPARED_PATH}"
+ Fore.RESET
)
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
_ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
LOG.info(
Fore.GREEN
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
+ Fore.RESET
)
if __name__ == "__main__":
fire.Fire(do_cli)
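
For reference, the preprocess entry point in this file is a thin fire wrapper around load_cfg/load_datasets, so it can also be called programmatically; a minimal sketch, assuming a valid axolotl YAML config exists (the example path below is illustrative):

from pathlib import Path

from axolotl.cli.preprocess import do_cli

# run dataset preprocessing only; dataset_prepared_path falls back to
# last_run_prepared when the YAML does not set it (per the warning above)
do_cli(config=Path("examples/openllama-3b/lora.yml"))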

View File

@@ -1,7 +1,6 @@
""" """
CLI to run training on a model CLI to run training on a model
""" """
import logging
from pathlib import Path from pathlib import Path
import fire import fire
@@ -17,8 +16,6 @@ from axolotl.cli import (
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train from axolotl.train import train
LOG = logging.getLogger("axolotl.cli.train")
def do_cli(config: Path = Path("examples/"), **kwargs): def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
@@ -30,7 +27,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
parsed_cli_args, _ = parser.parse_args_into_dataclasses( parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True return_remaining_strings=True
) )
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only:
return
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta) train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)

View File

@@ -25,22 +25,11 @@ class TrainerCliArgs:
debug_num_examples: int = field(default=5) debug_num_examples: int = field(default=5)
inference: bool = field(default=False) inference: bool = field(default=False)
merge_lora: bool = field(default=False) merge_lora: bool = field(default=False)
prepare_ds_only: bool = field(default=False)
prompter: Optional[str] = field(default=None) prompter: Optional[str] = field(default=None)
shard: bool = field(default=False) shard: bool = field(default=False)
@dataclass
class PreprocessCliArgs:
"""
dataclass representing arguments for preprocessing only
"""
debug: bool = field(default=False)
debug_text_only: bool = field(default=False)
debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None)
def load_model_and_tokenizer( def load_model_and_tokenizer(
*, *,
cfg: DictDefault, cfg: DictDefault,

View File

@@ -1,5 +0,0 @@
"""
Various shared constants
"""
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"

View File

@@ -1,711 +0,0 @@
"""
Builder for the training args and trainer
"""
import abc
import importlib
import logging
import math
import os
import sys
from abc import abstractmethod
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Optional, Union
import tensor_parallel as tp
import torch
import transformers
from datasets import Dataset
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import SequentialDistributedSampler
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
log_prediction_callback_factory,
)
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader
from axolotl.utils.distributed import is_distributed
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
try:
import torch._dynamo # pylint: disable=ungrouped-imports
except ImportError:
pass
LOG = logging.getLogger("axolotl.core.trainer_builder")
@dataclass
class AxolotlTrainingArguments(TrainingArguments):
"""
Extend the base TrainingArguments for axolotl helpers
"""
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
sample_packing_seq_len_multiplier: int = field(
default=1,
metadata={"help": "the multiplier for the max len for packed sequences"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
tensor_parallel: bool = field(
default=False, metadata={"help": "Use tensor parallelism to train"}
)
class AxolotlTrainer(Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
args = None # type: AxolotlTrainingArguments
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
self.num_epochs = num_epochs
self.bench_data_collator = bench_data_collator
super().__init__(*args, **kwargs)
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
):
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
else:
return super().create_scheduler(num_training_steps, optimizer)
return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size > 1 and self.args.sample_packing:
return DistributedSampler(
self.train_dataset,
num_replicas=self.args.world_size,
rank=self.args.process_index,
seed=self.args.seed,
)
return super()._get_train_sampler()
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if (
self.args.world_size > 1
and self.args.sample_packing
and self.args.eval_sample_packing is not False
):
return SequentialDistributedSampler(
eval_dataset,
num_replicas=self.args.world_size,
rank=self.args.process_index,
batch_size=self.args.per_device_eval_batch_size,
)
return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing:
train_sampler = self._get_train_sampler()
return self.accelerator.prepare(
MultipackDistributedDataloader(
self.train_dataset,
batch_size=self._train_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=train_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
num_epochs=self.num_epochs,
)
)
return super().get_train_dataloader()
def get_eval_dataloader(
self, eval_dataset: Optional[Dataset] = None
) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset)
return self.accelerator.prepare(
MultipackDistributedDataloader(
eval_dataset,
batch_size=self.args.eval_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=eval_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
num_epochs=self.num_epochs,
)
)
return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
def get_bench_dataloader(
self,
bench_dataset: Dataset,
) -> Union[DataLoader, MultipackDistributedDataloader]:
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": self.bench_data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc
# if self.args.sample_packing:
# labels = inputs.pop("labels")
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
return super().compute_loss(model, inputs, return_outputs=return_outputs)
def _wrap_model(self, model, training=True, dataloader=None):
if self.args.tensor_parallel:
model = tp.tensor_parallel(model, distributed=is_distributed())
model.hf_device_map = tp.infer_sharded_device_map(model)
else:
model = super()._wrap_model(model, training=training, dataloader=dataloader)
return model
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
pct_start=pct_start,
div_factor=6,
)
return self.lr_scheduler
class ReLoRATrainer(AxolotlTrainer):
"""
Trainer subclass that uses the ReLoRA scheduler
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
if self.args.relora_steps:
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler
return self.lr_scheduler
class TrainerBuilderBase(abc.ABC):
"""
Base class for trainer builder
"""
_train_dataset = None
_eval_dataset = None
def __init__(self, cfg, model, tokenizer):
self.cfg = cfg
self.model = model
self.tokenizer = tokenizer
@property
def train_dataset(self):
return self._train_dataset
@train_dataset.setter
def train_dataset(self, dataset):
self._train_dataset = dataset
@property
def eval_dataset(self):
return self._eval_dataset
@eval_dataset.setter
def eval_dataset(self, dataset):
self._eval_dataset = dataset
@abstractmethod
def build(self, total_num_steps):
pass
@abstractmethod
def get_callbacks(self):
pass
@abstractmethod
def get_post_trainer_create_callbacks(self, trainer):
"""
Callbacks added after the trainer is created, usually b/c these need access to the trainer
"""
class HFCausalTrainerBuilder(TrainerBuilderBase):
"""
Build the HuggingFace training args/trainer for Causal models
"""
def hook_pre_create_training_args(self, training_arguments_kwargs):
# TODO
return training_arguments_kwargs
def hook_post_create_training_args(self, training_arguments):
# TODO
return training_arguments
def hook_pre_create_trainer(self, trainer_kwargs, trainer_cls):
# TODO
return trainer_kwargs, trainer_cls
def hook_post_create_trainer(self, trainer):
if self.cfg.tensor_parallel:
trainer.model = trainer.accelerator.prepare_model(
trainer.model, device_placement=True
)
return trainer
def get_callbacks(self):
callbacks = []
callbacks.append(GPUStatsCallback(self.cfg))
callbacks.append(EvalFirstStepCallback)
if self.cfg.relora_steps:
callbacks.append(ReLoRACallback(self.cfg))
if (
hasattr(self.model, "use_bettertransformer")
and self.model.use_bettertransformer is True
):
callbacks.append(SaveBetterTransformerModelCallback)
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
)
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer
)
callbacks.append(LogPredictionCallback(self.cfg))
if self.cfg.do_bench_eval:
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
if self.cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
self.cfg.early_stopping_patience,
)
callbacks.append(early_stop_cb)
return callbacks
def _get_trainer_cls(self):
if self.cfg.lr_scheduler == "one_cycle" and (
self.cfg.fsdp or self.cfg.adapter == "qlora"
):
return OneCycleLRSchedulerTrainer
if self.cfg.relora_steps:
return ReLoRATrainer
return AxolotlTrainer
def build(self, total_num_steps):
warmup_steps = (
self.cfg.warmup_steps
if self.cfg.warmup_steps is not None
else min(int(0.03 * total_num_steps), 100)
)
logging_steps = (
self.cfg.logging_steps
if self.cfg.logging_steps is not None
else max(min(int(0.005 * total_num_steps), 10), 1)
)
training_arguments_kwargs = {}
if self.cfg.bf16 == "full":
training_arguments_kwargs["bf16_full_eval"] = True
else:
training_arguments_kwargs["bf16"] = self.cfg.bf16
training_arguments_kwargs["fp16"] = (
self.cfg.fp16 and not self.cfg.bf16
) or False
training_arguments_kwargs["tf32"] = self.cfg.tf32
training_arguments_kwargs["warmup_steps"] = warmup_steps
training_arguments_kwargs["logging_steps"] = logging_steps
if self.cfg.seed:
training_arguments_kwargs["seed"] = self.cfg.seed
if self.cfg.gradient_checkpointing:
training_arguments_kwargs[
"gradient_checkpointing"
] = self.cfg.gradient_checkpointing
if self.cfg.fsdp:
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
if self.cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
# deepspeed
if self.cfg.deepspeed:
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
if self.cfg.lr_quadratic_warmup is not None:
training_arguments_kwargs[
"lr_quadratic_warmup"
] = self.cfg.lr_quadratic_warmup
if self.cfg.adam_beta1:
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
if self.cfg.adam_beta2:
training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2
if self.cfg.adam_epsilon:
training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon
if self.cfg.max_grad_norm:
training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm
if self.cfg.hub_model_id:
training_arguments_kwargs["hub_model_id"] = self.cfg.hub_model_id
training_arguments_kwargs["push_to_hub"] = True
training_arguments_kwargs["hub_private_repo"] = True
if self.cfg.hub_strategy:
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
if self.cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
if self.cfg.sample_packing_eff_est:
training_arguments_kwargs[
"sample_packing_efficiency"
] = self.cfg.sample_packing_eff_est
if self.cfg.eval_steps:
training_arguments_kwargs["evaluation_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
elif self.cfg.evaluation_strategy:
training_arguments_kwargs[
"evaluation_strategy"
] = self.cfg.evaluation_strategy
elif self.cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["evaluation_strategy"] = "no"
else:
# we have an eval set, but no steps defined, default to use epoch
training_arguments_kwargs["evaluation_strategy"] = "epoch"
if self.cfg.save_steps:
training_arguments_kwargs["save_strategy"] = "steps"
training_arguments_kwargs["save_steps"] = self.cfg.save_steps
elif self.cfg.save_strategy:
training_arguments_kwargs["save_strategy"] = self.cfg.save_strategy
else:
# default to saving each epoch if not defined
training_arguments_kwargs["save_strategy"] = "epoch"
if self.cfg.do_bench_eval:
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
if self.cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
if self.cfg.metric_for_best_model:
training_arguments_kwargs[
"metric_for_best_model"
] = self.cfg.metric_for_best_model
if self.cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
if self.cfg.torch_compile:
if torch.__version__ < "2.1.0": # pylint: disable=protected-access
LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
elif torch._dynamo: # pylint: disable=protected-access
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
)
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
if self.cfg.torch_compile_backend:
training_arguments_kwargs[
"torch_compile_backend"
] = self.cfg.torch_compile_backend
# DDP Config
if self.cfg.ddp_timeout:
training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
if self.cfg.ddp_bucket_cap_mb:
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
if self.cfg.ddp_broadcast_buffers is not None:
training_arguments_kwargs[
"ddp_broadcast_buffers"
] = self.cfg.ddp_broadcast_buffers
# these are all the "standard" kwargs that are def used
training_arguments_kwargs["max_steps"] = (
total_num_steps if self.cfg.max_steps else -1
)
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
training_arguments_kwargs[
"per_device_train_batch_size"
] = self.cfg.micro_batch_size
training_arguments_kwargs[
"per_device_eval_batch_size"
] = self.cfg.eval_batch_size
training_arguments_kwargs[
"gradient_accumulation_steps"
] = self.cfg.gradient_accumulation_steps
training_arguments_kwargs[
"eval_accumulation_steps"
] = self.cfg.gradient_accumulation_steps
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
training_arguments_kwargs["output_dir"] = self.cfg.output_dir
training_arguments_kwargs["save_total_limit"] = (
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
)
training_arguments_kwargs["load_best_model_at_end"] = (
(
self.cfg.load_best_model_at_end is not False
or self.cfg.early_stopping_patience
)
and self.cfg.val_set_size > 0
and self.cfg.save_steps
and self.cfg.eval_steps
and self.cfg.save_steps % self.cfg.eval_steps == 0
) or False
training_arguments_kwargs["ddp_find_unused_parameters"] = (
False if self.cfg.ddp else None
)
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
training_arguments_kwargs["run_name"] = (
self.cfg.wandb_run_id if self.cfg.use_wandb else None
)
training_arguments_kwargs["optim"] = (
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
)
training_arguments_kwargs["lr_scheduler_type"] = (
self.cfg.lr_scheduler
if self.cfg.lr_scheduler
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
else "cosine"
)
training_arguments_kwargs["weight_decay"] = (
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
)
training_arguments_kwargs["sample_packing"] = (
self.cfg.sample_packing if self.cfg.sample_packing else False
)
training_arguments_kwargs["eval_sample_packing"] = (
self.cfg.sample_packing if self.cfg.sample_packing else False
)
training_arguments_kwargs[
"sample_packing_seq_len_multiplier"
] = self.cfg.micro_batch_size
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
training_arguments_kwargs["tensor_parallel"] = self.cfg.tensor_parallel is True
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
)
training_args = (
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
**training_arguments_kwargs,
)
)
training_args = self.hook_post_create_training_args(training_args)
trainer_kwargs = {}
if self.cfg.optimizer == "adamw_anyprecision":
if Path(self.cfg.torchdistx_path).exists():
sys.path.append(self.cfg.torchdistx_path)
importlib.import_module("torchdistx")
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
}
if self.cfg.pad_to_sequence_len:
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
self.cfg.sequence_len / 64
)
else:
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = 64
if self.cfg.is_llama_derived_model and self.cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import (
add_mem_tokens,
get_mem_id,
set_model_mem_id,
)
set_model_mem_id(self.model, self.tokenizer)
LOG.info("Adding landmark attention tokens to dataset")
for dataset in [self.train_dataset, self.eval_dataset]:
dataset = dataset.map(
partial(
add_mem_tokens, mem_freq=50, mem_id=get_mem_id(self.tokenizer)
),
batched=False,
num_proc=32,
)
trainer_cls = self._get_trainer_cls()
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
trainer_kwargs, trainer_cls
)
trainer = trainer_cls(
model=self.model,
train_dataset=self.train_dataset,
eval_dataset=self.eval_dataset,
args=training_args,
data_collator=DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
callbacks=self.get_callbacks(),
num_epochs=self.cfg.num_epochs,
**trainer_kwargs,
)
trainer = self.hook_post_create_trainer(trainer)
for callback in self.get_post_trainer_create_callbacks(trainer):
trainer.add_callback(callback)
return trainer
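
As a reference point, the warmup and logging-step defaults computed at the top of build() are simple fractions of the total step count, capped at 100 and 10 respectively; a worked sketch with an illustrative step count:

# pure arithmetic mirroring the defaults above; total_num_steps is an example value
total_num_steps = 1200
warmup_steps = min(int(0.03 * total_num_steps), 100)           # -> 36
logging_steps = max(min(int(0.005 * total_num_steps), 10), 1)  # -> 6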

View File

@@ -5,7 +5,7 @@ import os
from typing import List from typing import List
import torch import torch
from datasets import Dataset, IterableDataset from datasets import Dataset, IterableDataset, Sequence, Value
from .prompt_tokenizers import PromptTokenizingStrategy from .prompt_tokenizers import PromptTokenizingStrategy
@@ -42,11 +42,15 @@ class TokenizedPromptDataset(Dataset):
if self.prompt_tokenizer.supports_batched: if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True map_kwargs["batched"] = True
map_kwargs["batch_size"] = 100 map_kwargs["batch_size"] = 100
return dataset.map( return (
self.prompt_tokenizer.tokenize_prompt, dataset.map(
num_proc=num_proc, self.prompt_tokenizer.tokenize_prompt,
remove_columns=features, num_proc=num_proc,
**map_kwargs, remove_columns=features,
**map_kwargs,
)
.cast_column("input_ids", Sequence(feature=Value(dtype="int32", id=None)))
.cast_column("labels", Sequence(feature=Value(dtype="int32", id=None)))
) )
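
The cast_column calls shown here narrow the tokenized columns from the datasets default of int64 to int32, roughly halving their memory and cache footprint; a minimal standalone sketch with a toy dataset (in the real pipeline these columns come from tokenize_prompt):

from datasets import Dataset, Sequence, Value

ds = Dataset.from_dict({"input_ids": [[1, 2, 3]], "labels": [[1, 2, 3]]})
ds = ds.cast_column("input_ids", Sequence(feature=Value(dtype="int32", id=None)))
ds = ds.cast_column("labels", Sequence(feature=Value(dtype="int32", id=None)))
print(ds.features)  # both columns are now sequences of int32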

View File

@@ -13,18 +13,12 @@ import transformers
from einops import rearrange from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input from flash_attn.bert_padding import pad_input, unpad_input
from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.models.llama.modeling_llama import ( from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OriginalLlamaDecoderLayer, LlamaDecoderLayer as OriginalLlamaDecoderLayer,
) )
from transformers.models.llama.modeling_llama import ( from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
LlamaMLP,
apply_rotary_pos_emb,
repeat_kv,
)
from xformers.ops import SwiGLU
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
try: try:
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
@@ -44,28 +38,6 @@ except ImportError:
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
def replace_llama_mlp_with_swiglu(model):
for name, module in model.named_modules():
if isinstance(module, LlamaMLP):
mlp = FusedMLP(
module.config, module.gate_proj, module.up_proj, module.down_proj
)
set_module_name(model, name, mlp)
def replace_llama_qkv_with_fused(model):
for name, module in model.named_modules():
if isinstance(module, LlamaAttention):
qkv = FusedAttention(
module.config,
module.q_proj,
module.k_proj,
module.v_proj,
module.o_proj,
)
set_module_name(model, name, qkv)
def replace_llama_attn_with_flash_attn( def replace_llama_attn_with_flash_attn(
packed: Optional[bool] = False, packed: Optional[bool] = False,
cross_entropy: Optional[bool] = False, cross_entropy: Optional[bool] = False,
@@ -114,92 +86,6 @@ def replace_llama_attn_with_flash_attn(
) )
class FusedAttention(LlamaAttention):
"""
Fused QKV Attention layer for incrementally improved training efficiency
"""
def __init__(
self,
config,
q: torch.nn.Linear, # pylint: disable=invalid-name
k: torch.nn.Linear, # pylint: disable=invalid-name
v: torch.nn.Linear, # pylint: disable=invalid-name
o: torch.nn.Linear, # pylint: disable=invalid-name
):
super().__init__(config)
self.config = config
self.init_device = next(iter(q.state_dict().values())).device
# define equivalent fused qkv projection
self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
self.qkv_proj = torch.nn.Linear(
q.in_features, sum(self.out_features), device=self.init_device, bias=False
)
self.o_proj = o
# overwrite initialized weights with pretrained weights
self.qkv_proj.weight.data = torch.cat(
(q.weight.data, k.weight.data, v.weight.data), dim=0
)
def _post_training(self, model, name):
q_proj, k_proj, v_proj = torch.split(
self.qkv_proj.weight.data, self.out_features, dim=0
)
new_attn = LlamaAttention(self.config)
new_attn.q_proj.weight.data = q_proj
new_attn.k_proj.weight.data = k_proj
new_attn.v_proj.weight.data = v_proj
new_attn.o_proj.weight.data = self.o_proj.weight.data
set_module_name(model, name, new_attn)
class FusedMLP(torch.nn.Module):
"""
Fused MLP layer for incrementally improved training efficiency
"""
def __init__(
self,
config,
gate_proj: torch.nn.Linear,
up_proj: torch.nn.Linear,
down_proj: torch.nn.Linear,
):
super().__init__()
self.config = config
self.swiglu = SwiGLU(
in_features=config.hidden_size,
hidden_features=config.intermediate_size,
bias=False,
_pack_weights=True,
)
# overwrite initialized weights with pretrained weights
self.swiglu.w12.weight.data = torch.cat(
(gate_proj.weight.data, up_proj.weight.data), dim=0
)
self.swiglu.w3.weight.data = down_proj.weight.data
def _post_training(self, model, name):
w1, w2 = torch.split( # pylint: disable=invalid-name
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
)
# Assign the split weights back to the original layers
new_mlp = LlamaMLP(self.config)
new_mlp.gate_proj.weight.data = w1
new_mlp.up_proj.weight.data = w2
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
set_module_name(model, name, new_mlp)
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
return self.swiglu(x)
# Disable the transformation of the attention mask in LlamaModel as the flash attention # Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask # requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask( def _prepare_decoder_attention_mask(
@@ -261,14 +147,9 @@ def flashattn_forward(
value_states = torch.cat(value_states, dim=-1) value_states = torch.cat(value_states, dim=-1)
else: else:
if isinstance(self, FusedAttention): query_states = self.q_proj(hidden_states)
query_states, key_states, value_states = self.qkv_proj(hidden_states).split( key_states = self.k_proj(hidden_states)
self.out_features, dim=-1 value_states = self.v_proj(hidden_states)
)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view( query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim bsz, q_len, self.num_heads, self.head_dim
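
For reference, the FusedAttention layer removed in this hunk folds the separate q/k/v projections into one Linear by concatenating their weight matrices along the output dimension and splitting the activations back apart; a minimal sketch with toy sizes (all names and shapes below are illustrative):

import torch

hidden, n_heads, head_dim = 64, 4, 16
q = torch.nn.Linear(hidden, n_heads * head_dim, bias=False)
k = torch.nn.Linear(hidden, n_heads * head_dim, bias=False)
v = torch.nn.Linear(hidden, n_heads * head_dim, bias=False)

# one fused projection whose rows are the stacked q, k, v weights
out_features = [q.out_features, k.out_features, v.out_features]
qkv = torch.nn.Linear(hidden, sum(out_features), bias=False)
qkv.weight.data = torch.cat((q.weight.data, k.weight.data, v.weight.data), dim=0)

x = torch.randn(2, 8, hidden)
q_s, k_s, v_s = qkv(x).split(out_features, dim=-1)  # matches q(x), k(x), v(x)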

View File

@@ -14,9 +14,6 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
flash_attn_varlen_qkvpacked_func, flash_attn_varlen_qkvpacked_func,
) )
from transformers.modeling_outputs import BaseModelOutputWithPast from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.mistral.modeling_mistral import (
MistralAttention as OriginalMistralAttention,
)
from transformers.models.mistral.modeling_mistral import ( from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OriginalMistralDecoderLayer, MistralDecoderLayer as OriginalMistralDecoderLayer,
) )
@@ -45,44 +42,6 @@ def replace_mistral_attn_with_flash_attn(
) )
@torch.jit.script
def _make_sliding_window_causal_mask(
bsz: int,
tgt_len: int,
dtype: torch.dtype,
device: torch.device,
past_key_values_length: int = 0,
sliding_window: int = 4096,
):
"""
Make causal mask used for sliding window attention
"""
tensor = torch.full(
(tgt_len, tgt_len),
fill_value=1,
device=device,
)
mask = torch.tril(tensor, diagonal=0)
# make the mask banded to account for sliding window
# NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
mask = torch.triu(mask, diagonal=-sliding_window + 1)
mask = torch.log(mask).to(dtype)
if past_key_values_length > 0:
mask = torch.cat(
[
torch.zeros(
tgt_len, past_key_values_length, dtype=dtype, device=device
),
mask,
],
dim=-1,
)
return mask[None, None, :, :].expand(
bsz, 1, tgt_len, tgt_len + past_key_values_length
)
# Disable the transformation of the attention mask in LlamaModel as the flash attention # Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask # requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask( def _prepare_decoder_attention_mask(
@@ -94,29 +53,11 @@ def _prepare_decoder_attention_mask(
sliding_window, sliding_window,
): # pylint: disable=unused-argument ): # pylint: disable=unused-argument
# [bsz, seq_len] # [bsz, seq_len]
if attention_mask is None:
return attention_mask
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
# Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
sliding_window_mask = _make_sliding_window_causal_mask(
bsz=input_shape[0],
tgt_len=input_shape[1],
dtype=inputs_embeds.dtype,
device=inputs_embeds.device,
past_key_values_length=past_key_values_length,
sliding_window=sliding_window,
)
attention_mask = attention_mask + sliding_window_mask
else:
LOG.info("skipping sliding window mask, not broadcastable with attention mask")
return attention_mask return attention_mask
def flashattn_forward( def flashattn_forward(
self: OriginalMistralAttention, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None,
@@ -150,41 +91,10 @@ def flashattn_forward(
query_states, key_states, cos, sin, position_ids query_states, key_states, cos, sin, position_ids
) )
use_sliding_windows = (
hasattr(self.config, "sliding_window") is not None
and kv_seq_len > self.config.sliding_window
)
if use_sliding_windows:
window_size = (self.config.sliding_window, self.config.sliding_window)
else:
window_size = (-1, -1)
if past_key_value is not None: if past_key_value is not None:
# Activate slicing cache only if the config has a value `sliding_windows` attribute # reuse k, v, self_attention
if ( key_states = torch.cat([past_key_value[0], key_states], dim=2)
hasattr(self.config, "sliding_window") value_states = torch.cat([past_key_value[1], value_states], dim=2)
and kv_seq_len > self.config.sliding_window
):
slicing_tokens = kv_seq_len - self.config.sliding_window
past_key = past_key_value[0]
past_value = past_key_value[1]
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
if past_key.shape[-2] != self.config.sliding_window - 1:
raise ValueError(
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
f" {past_key.shape}"
)
past_key_value = (past_key, past_value) if use_cache else None
if past_key_value is not None:
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None past_key_value = (key_states, value_states) if use_cache else None
@@ -210,13 +120,7 @@ def flashattn_forward(
qkv = rearrange(qkv, "b s ... -> (b s) ...") qkv = rearrange(qkv, "b s ... -> (b s) ...")
output = flash_attn_varlen_qkvpacked_func( output = flash_attn_varlen_qkvpacked_func(
qkv, qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
cu_seqlens,
max_seqlen,
0.0,
softmax_scale=None,
causal=True,
window_size=window_size,
) )
output = rearrange(output, "(b s) ... -> b s ...", b=bsz) output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
elif query_states.shape == key_states.shape: elif query_states.shape == key_states.shape:
@@ -242,7 +146,6 @@ def flashattn_forward(
0.0, 0.0,
softmax_scale=None, softmax_scale=None,
causal=is_causal, causal=is_causal,
window_size=window_size,
) )
output = output_pad_fn(output_unpad) output = output_pad_fn(output_unpad)
else: else:
@@ -254,7 +157,6 @@ def flashattn_forward(
query_states, query_states,
torch.stack([key_states, value_states], 2), torch.stack([key_states, value_states], 2),
causal=is_causal, causal=is_causal,
window_size=window_size,
) )
else: else:
( # pylint: disable=unbalanced-tuple-unpacking ( # pylint: disable=unbalanced-tuple-unpacking
@@ -289,7 +191,6 @@ def flashattn_forward(
0.0, 0.0,
softmax_scale=None, softmax_scale=None,
causal=is_causal, causal=is_causal,
window_size=window_size,
) )
output = output_pad_fn(output_unpad) output = output_pad_fn(output_unpad)
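
For reference, the _make_sliding_window_causal_mask helper removed above builds a banded causal mask by intersecting a lower-triangular matrix with an upper band and taking the log, so positions inside the window become 0 and everything else becomes -inf; a minimal sketch with toy sizes:

import torch

tgt_len, sliding_window = 6, 3
band = torch.tril(torch.full((tgt_len, tgt_len), 1.0))
band = torch.triu(band, diagonal=-sliding_window + 1)  # keep only the last `sliding_window` keys
mask = torch.log(band)  # 0.0 where attention is allowed, -inf elsewhere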

View File

@@ -1,65 +0,0 @@
"""
patches implemented through the trainer hooks to enable NEFT/noisy embeddings per https://arxiv.org/abs/2310.05914
"""
import torch
from peft import PeftModel
from transformers import PreTrainedModel
def patch_neft(alpha, model):
embeddings = None
if isinstance(model, PreTrainedModel):
embeddings = model.get_input_embeddings()
if isinstance(model, PeftModel):
embeddings = model.base_model.get_input_embeddings()
if not embeddings:
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
embeddings.noisy_embedding_alpha = alpha
old_forward = embeddings.forward
# This hack seems to be needed to properly use a custom forward pass
# all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
bound_method = neft_forward.__get__( # pylint: disable=no-value-for-parameter
embeddings, embeddings.__class__
)
setattr(embeddings, "forward", bound_method)
embeddings._old_forward = old_forward # pylint: disable=protected-access
return model
def unpatch_neft(model):
embeddings = None
if isinstance(model, PreTrainedModel):
embeddings = model.get_input_embeddings()
if isinstance(model, PeftModel):
embeddings = model.base_model.get_input_embeddings()
if not embeddings:
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
if hasattr(embeddings, "_old_forward"):
embeddings.forward = embeddings._old_forward # pylint: disable=protected-access
del embeddings._old_forward # pylint: disable=protected-access
del embeddings.noisy_embedding_alpha
def neft_forward(self, inputs: torch.Tensor):
embeddings = self._old_forward(inputs) # pylint: disable=protected-access
if self.training:
dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
mag_norm = self.noisy_embedding_alpha / torch.sqrt(dims)
embeddings = embeddings + torch.zeros_like(embeddings).uniform_(
-mag_norm, mag_norm
)
return embeddings
def pretrain_hook(cfg, trainer):
if cfg.noisy_embedding_alpha:
trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)
def post_train_hook(cfg, trainer):
if cfg.noisy_embedding_alpha:
unpatch_neft(trainer.model)
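
For reference, the NEFT hook removed here adds training-time uniform noise to the input embeddings with magnitude alpha / sqrt(seq_len * hidden_dim), following https://arxiv.org/abs/2310.05914; a minimal standalone sketch with toy shapes:

import math

import torch

alpha = 5.0                             # noisy_embedding_alpha in the config
embeddings = torch.randn(2, 128, 4096)  # (batch, seq_len, hidden_dim), toy values
mag_norm = alpha / math.sqrt(embeddings.size(1) * embeddings.size(2))
noisy = embeddings + torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm)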

View File

@@ -101,16 +101,3 @@ def get_cu_seqlens_from_pos_ids(position_ids):
max_seq_lens.append(max_seq_len) max_seq_lens.append(max_seq_len)
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens) return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
def set_module_name(model, name, value):
if "." in name:
parent_name = name.rsplit(".", 1)[0]
child_name = name[len(parent_name) + 1 :]
parent = model.get_submodule(parent_name)
else:
parent_name = ""
parent = model
child_name = name
setattr(parent, child_name, value)
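
For reference, the set_module_name helper removed here swaps a submodule addressed by its dotted path (the same mechanism the fused layers used to splice themselves in and out); a minimal usage sketch on a toy module tree, with illustrative names:

import torch

model = torch.nn.Sequential()
model.add_module("block", torch.nn.Sequential())
model.get_submodule("block").add_module("mlp", torch.nn.Linear(4, 4))

# replace block.mlp in place, as set_module_name(model, "block.mlp", value) did
parent_name, child_name = "block.mlp".rsplit(".", 1)
setattr(model.get_submodule(parent_name), child_name, torch.nn.Identity())
assert isinstance(model.get_submodule("block.mlp"), torch.nn.Identity)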

View File

@@ -1,6 +1,6 @@
"""Module for Alpaca prompt strategy classes""" """Module containing the AlpacaQAPromptTokenizingStrategy class"""
from typing import Any, Dict, Optional, Tuple from typing import Tuple
from axolotl.prompt_tokenizers import ( from axolotl.prompt_tokenizers import (
AlpacaPromptTokenizingStrategy, AlpacaPromptTokenizingStrategy,
@@ -9,13 +9,9 @@ from axolotl.prompt_tokenizers import (
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): def load(tokenizer, cfg):
prompt_style = PromptStyle.CHAT.value
if ds_cfg and "conversation" in ds_cfg:
prompt_style = ds_cfg["conversation"]
return AlpacaPromptTokenizingStrategy( return AlpacaPromptTokenizingStrategy(
AlpacaPrompter(prompt_style), AlpacaPrompter(PromptStyle.CHAT.value),
tokenizer, tokenizer,
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,

View File

@@ -24,7 +24,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
) )
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
strategy = SimpleShareGPTPromptTokenizingStrategy( strat = ShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2( ShareGPTPrompterV2(
conversation=conversation, conversation=conversation,
role_key_model=field_model, role_key_model=field_model,
@@ -34,9 +34,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
if ds_cfg and "strict" in ds_cfg: if ds_cfg and ds_cfg["skip"]:
strategy.strict = ds_cfg["strict"] strat.skip_invalid = True
return strategy return strat
def load_role(tokenizer, cfg): def load_role(tokenizer, cfg):
@@ -57,30 +57,37 @@ def load_guanaco(tokenizer, cfg):
) )
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): def load_nous(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
return NousShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class NousShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
""" """
basic sharegpt strategy to grab conversations from the sample row basic sharegpt strategy used by nous/ldj for input/output keyed data
""" """
_strict = True def get_conversation_thread(self):
return "conversation"
@property def map_conversation_thread(self, conversation):
def strict(self): turns = []
return self._strict for turn in conversation:
turns.append({"from": "human", "value": turn["input"]})
@strict.setter turns.append({"from": "gpt", "value": turn["output"]})
def strict(self, strict):
self._strict = strict
def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
if self.strict:
return conversations
# remap roles - allow for assistant turn
role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"}
turns = [
{"from": role_map[t["from"]], "value": t["value"]} for t in conversations
]
return turns return turns
@@ -89,10 +96,11 @@ class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrateg
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
""" """
def get_conversation_thread(self, prompt): def map_conversation_thread(self, conversation):
conversations = prompt["conversations"]
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ... # remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
turns = [{"from": t["role"], "value": t["value"]} for t in conversations] turns = [
{"from": turn["role"], "value": turn["value"]} for turn in conversation
]
return turns return turns
@@ -101,11 +109,11 @@ class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
sharegpt strategy that remaps oasst data to sharegpt format sharegpt strategy that remaps oasst data to sharegpt format
""" """
def get_conversation_thread(self, prompt): def map_conversation_thread(self, conversation):
conversations = prompt["conversations"]
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ... # remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
role_map = {"prompter": "human", "assistant": "gpt"} role_map = {"prompter": "human", "assistant": "gpt"}
turns = [ turns = [
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations {"from": role_map[turn["role"]], "value": turn["text"]}
for turn in conversation
] ]
return turns return turns
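For reference, a minimal sketch of the oasst-to-sharegpt role remapping performed by the Guanaco strategy in this hunk; the sample data below is illustrative, not taken from the repository:

```python
# Minimal sketch, not repository code: remap oasst-style turns into the
# from/value schema the ShareGPT prompters expect.
role_map = {"prompter": "human", "assistant": "gpt"}

def remap_oasst_turns(conversation):
    return [
        {"from": role_map[turn["role"]], "value": turn["text"]}
        for turn in conversation
    ]

# remap_oasst_turns([{"role": "prompter", "text": "hi"},
#                    {"role": "assistant", "text": "hello"}])
# -> [{"from": "human", "value": "hi"}, {"from": "gpt", "value": "hello"}]
```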

View File

@@ -2,7 +2,9 @@
import abc import abc
import copy import copy
import functools
import logging import logging
from collections import defaultdict
from typing import Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union
from fastchat.conversation import Conversation from fastchat.conversation import Conversation
@@ -45,8 +47,6 @@ class PromptTokenizingStrategy(abc.ABC):
self.prompter = prompter self.prompter = prompter
self.tokenizer: PreTrainedTokenizer = tokenizer self.tokenizer: PreTrainedTokenizer = tokenizer
self.train_on_inputs = train_on_inputs self.train_on_inputs = train_on_inputs
# sequence_len and max_length can be different for CompletionPromptTokenizingStrategy.
# TODO: Document how they are different.
self.sequence_len = sequence_len self.sequence_len = sequence_len
self.max_length = sequence_len self.max_length = sequence_len
@@ -58,34 +58,57 @@ class PromptTokenizingStrategy(abc.ABC):
def supports_batched(self): def supports_batched(self):
return False return False
@functools.lru_cache(maxsize=128)
def _get_user_token(self):
try:
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
if isinstance(id_or_ids, (int,)):
return id_or_ids
except KeyError:
pass
return False
@functools.lru_cache(maxsize=128)
def _get_assistant_token(self):
try:
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
if isinstance(id_or_ids, (int,)):
return id_or_ids
except KeyError:
pass
return False
def _tokenize( def _tokenize(
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
) -> BatchEncoding: ) -> BatchEncoding:
empty = BatchEncoding(data={"input_ids": [], "attention_mask": []}) result: BatchEncoding
if not prompt: if not prompt:
LOG.warning("Empty text requested for tokenization.") LOG.warning("Empty text requested for tokenization.")
return empty result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
else:
result = self.tokenizer( result = self.tokenizer(
prompt, prompt,
truncation=True, truncation=True,
max_length=self.max_length, max_length=self.max_length,
padding=False, padding=False,
return_tensors=None, return_tensors=None,
) )
if len(result["input_ids"]) == 0: if len(result["input_ids"]) == 0:
LOG.warning("Tokenizer result is empty. You may want to audit your dataset") LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
return empty
if ( if (
result["input_ids"][-1] != self.tokenizer.eos_token_id len(result["input_ids"]) > 0
and result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(result["input_ids"]) < self.max_length and len(result["input_ids"]) < self.max_length
and add_eos_token and add_eos_token
): ):
result["input_ids"].append(self.tokenizer.eos_token_id) result["input_ids"].append(self.tokenizer.eos_token_id)
result["attention_mask"].append(1) result["attention_mask"].append(1)
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token: if (
len(result["input_ids"]) > 0
and result["input_ids"][0] == self.tokenizer.bos_token_id
and strip_bos_token
):
result["input_ids"] = result["input_ids"][1:] result["input_ids"] = result["input_ids"][1:]
result["attention_mask"] = result["attention_mask"][1:] result["attention_mask"] = result["attention_mask"][1:]
@@ -121,7 +144,7 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
if not self.train_on_inputs: if not self.train_on_inputs:
user_prompt_len = len(tokenized_prompt["input_ids"]) user_prompt_len = len(tokenized_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing # TODO this could be sped up using numpy array slicing
tokenized_prompt["labels"] = [IGNORE_INDEX] * user_prompt_len tokenized_prompt["labels"] = [-100] * user_prompt_len
tokenized_res_prompt = self._tokenize( tokenized_res_prompt = self._tokenize(
response, strip_bos_token=True, add_eos_token=True response, strip_bos_token=True, add_eos_token=True
) )
@@ -245,7 +268,6 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
raise NotImplementedError raise NotImplementedError
def tokenize_prompt(self, prompt): def tokenize_prompt(self, prompt):
# pylint: disable=duplicate-code
( (
instruction, instruction,
input, # pylint: disable=redefined-builtin input, # pylint: disable=redefined-builtin
@@ -270,7 +292,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
user_prompt_len = len(tokenized_user_prompt["input_ids"]) user_prompt_len = len(tokenized_user_prompt["input_ids"])
# TODO this could be sped up using numpy array slicing # TODO this could be sped up using numpy array slicing
tokenized_full_prompt["labels"] = [ tokenized_full_prompt["labels"] = [
IGNORE_INDEX -100
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:] ] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
return tokenized_full_prompt return tokenized_full_prompt
@@ -330,105 +352,141 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
Tokenizing strategy for ShareGPT prompts. Tokenizing strategy for ShareGPT prompts.
""" """
def get_conversation_thread(self, prompt): _skip_invalid = False
return prompt["conversations"]
@property
def supports_batched(self):
return True
@property
def skip_invalid(self):
return self._skip_invalid
@skip_invalid.setter
def skip_invalid(self, value):
self._skip_invalid = value
def get_conversation_thread(self):
return "conversations"
def map_conversation_thread(self, conversation):
return conversation
def tokenize_prompt(self, prompt): def tokenize_prompt(self, prompt):
# Initial values. We will append to these as we go through the conversation. tokenized_res = defaultdict(lambda: [])
result, current_len = tokenize_prompt_default() conv_field = self.get_conversation_thread()
conversation: Conversation = ( for prmpt in prompt[conv_field]:
self.prompter._conversation.copy() # pylint: disable=protected-access result, current_len = tokenize_prompt_default()
) user_token = self._get_user_token()
assistant_token = self._get_assistant_token()
conversation: Conversation = (
self.prompter._conversation # pylint: disable=protected-access
)
try:
for _, part in enumerate(
self.prompter.build_prompt(self.map_conversation_thread(prmpt))
):
if isinstance(part, tuple):
if conversation.roles[0] in part[0]:
turn = part[0] + part[1] if not user_token else part[1]
# this is still the user query, we should
if not part[1].strip():
err_msg = f"user turn has empty text: {prmpt}"
if self.skip_invalid:
raise ValueError(err_msg)
LOG.warning(err_msg)
res = self._tokenize(
turn,
add_eos_token=False,
strip_bos_token=True,
)
if user_token:
res["input_ids"] = [user_token, *res["input_ids"]]
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif conversation.roles[1] in part[0]:
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
turn = part[0] + part[1] if not assistant_token else part[1]
# this should be the assistant response, should end with an eos token
if not part[1].strip():
err_msg = f"assistant turn has empty text: {prmpt}"
if self.skip_invalid:
raise ValueError(err_msg)
LOG.warning(err_msg)
res = self._tokenize(
turn,
add_eos_token=True,
strip_bos_token=True,
)
if assistant_token:
res["input_ids"] = [
assistant_token,
*res["input_ids"],
]
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
elif part[0] == "":
turn = part[1]
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
turn, add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
err_msg = f"unhandled role: {part[0]}"
if self.skip_invalid:
raise ValueError(err_msg)
LOG.warning(err_msg)
continue
# support for custom roles from the dataset, only useful for vicuna style prompts/roles # pylint: disable=duplicate-code
role_remap = [] result, current_len = parse_tokenized_to_result(
result,
current_len,
res,
labels,
pad_token_id=self.tokenizer.pad_token_id,
)
for key, val in sorted(result.items(), key=lambda x: x[0]):
tokenized_res[key].append(val)
except (KeyError, AssertionError, IndexError) as err:
raise InvalidDataException(str(err)) from err
except ValueError as err:
LOG.warning("skipping prompt: %s", str(err))
return tokenized_res
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
if not prompt.strip():
LOG.warning("Empty text requested for tokenization.")
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
else:
result = self.tokenizer(
prompt,
truncation=True,
max_length=self.sequence_len,
padding=False,
return_tensors=None,
)
if ( if (
conversation.name == "vicuna_v1.1" len(result["input_ids"]) > 0
and "roles" in prompt and result["input_ids"][-1] != self.tokenizer.eos_token_id
and len(prompt["roles"]) >= 2 and len(result["input_ids"]) < self.sequence_len
and add_eos_token
): ):
role_remap = [ result["input_ids"].append(self.tokenizer.eos_token_id)
{"from": conversation.roles[0], "to": prompt["roles"][0]}, result["attention_mask"].append(1)
{"from": conversation.roles[1], "to": prompt["roles"][1]},
]
try: if (
for _, part in enumerate( len(result["input_ids"]) > 0
self.prompter.build_prompt(self.get_conversation_thread(prompt)) and result["input_ids"][0] == self.tokenizer.bos_token_id
): and strip_bos_token
if not isinstance(part, tuple): ):
LOG.warning(f"expected tuple, got {part}") result["input_ids"] = result["input_ids"][1:]
continue result["attention_mask"] = result["attention_mask"][1:]
user, assistant = conversation.roles result["labels"] = result["input_ids"].copy()
role, content = part return result
# Uses "in" because role contains extra characters
if user in role:
role = (
role.replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
else role
)
turn = role + content
# this is still the user query, we should
if not content.strip():
LOG.warning(f"user turn has empty text: {prompt}")
res = self._tokenize(
turn,
add_eos_token=False,
strip_bos_token=True,
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif assistant in role:
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
role = (
role.replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap
else role
)
turn = role + content
# this should be the assistant response, should end with an eos token
if not content.strip():
LOG.warning(f"assistant turn has empty text: {prompt}")
res = self._tokenize(
turn,
add_eos_token=True,
strip_bos_token=True,
)
role_res = self._tokenize(
role.rstrip(),
add_eos_token=False,
strip_bos_token=True,
)
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
len_role = len(role_res["input_ids"])
labels[:len_role] = [IGNORE_TOKEN_ID] * min(len_role, len(labels))
elif role == "":
turn = content
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
turn, add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {role}")
continue
# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
result,
current_len,
res,
labels,
pad_token_id=self.tokenizer.pad_token_id,
)
return result
except (KeyError, AssertionError, IndexError) as err:
raise InvalidDataException(str(err)) from err
def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]: def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
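Both versions of `_tokenize` in the hunks above share the same EOS/BOS post-processing; a condensed sketch of that pattern, with an assumed helper name:

```python
# Hedged sketch of the EOS-append / BOS-strip handling used by _tokenize above.
def finalize_tokens(result, tokenizer, max_length, add_eos_token=True, strip_bos_token=False):
    ids, mask = result["input_ids"], result["attention_mask"]
    # append EOS only when the text was not truncated and does not already end with it
    if add_eos_token and ids and ids[-1] != tokenizer.eos_token_id and len(ids) < max_length:
        ids.append(tokenizer.eos_token_id)
        mask.append(1)
    # drop a leading BOS when this turn is concatenated onto earlier turns
    if strip_bos_token and ids and ids[0] == tokenizer.bos_token_id:
        ids, mask = ids[1:], mask[1:]
    return {"input_ids": ids, "attention_mask": mask}
```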

View File

@@ -4,12 +4,10 @@ import logging
from enum import Enum from enum import Enum
from typing import Generator, Optional, Union from typing import Generator, Optional, Union
from colorama import Fore
from fastchat.conversation import Conversation, get_conv_template from fastchat.conversation import Conversation, get_conv_template
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
IGNORE_TOKEN_ID = -100 IGNORE_TOKEN_ID = -100
REPR_TEMPLATE = "\n<start>\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n<end>\n"
class PromptStyle(Enum): class PromptStyle(Enum):
@@ -57,15 +55,20 @@ class AlpacaPrompter:
) )
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n" self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
def _build_result(self, instruction, input_text, output): def build_prompt(
self,
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
) -> Generator[str, None, None]:
# returns the full prompt from instruction and optional input # returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended. # if a label (=response, =output) is provided, it's also appended.
if input_text: if input:
res = ( res = (
self.system_format.format(system=self.system_prompt) self.system_format.format(system=self.system_prompt)
if self.system_prompt if self.system_prompt
else "" else ""
) + self.turn_format.format(instruction=instruction, input=input_text) ) + self.turn_format.format(instruction=instruction, input=input)
else: else:
res = ( res = (
self.system_format.format(system=self.system_no_input_prompt) self.system_format.format(system=self.system_no_input_prompt)
@@ -74,21 +77,7 @@ class AlpacaPrompter:
) + self.turn_no_input_format.format(instruction=instruction) ) + self.turn_no_input_format.format(instruction=instruction)
if output: if output:
res = f"{res}{output}" res = f"{res}{output}"
yield res
return res
def build_prompt(
self,
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
) -> Generator[str, None, None]:
yield self._build_result(instruction, input, output)
def __repr__(self) -> str:
return REPR_TEMPLATE.format(
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
)
class UnpromptedPrompter(AlpacaPrompter): class UnpromptedPrompter(AlpacaPrompter):
@@ -202,14 +191,14 @@ class ReflectAlpacaPrompter:
) )
self.response_split = "ASSISTANT:" self.response_split = "ASSISTANT:"
def _build_result( def build_prompt(
self, self,
instruction: str, instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None, output: Union[None, str] = None,
reflection: Union[None, str] = None, reflection: Union[None, str] = None,
corrected: Union[None, str] = None, corrected: Union[None, str] = None,
): ) -> Generator[str, None, None]:
# returns the full prompt from instruction and optional input # returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended. # if a label (=response, =output) is provided, it's also appended.
if input: if input:
@@ -223,30 +212,7 @@ class ReflectAlpacaPrompter:
corrected=corrected, corrected=corrected,
) )
res = f"{res}{label}" res = f"{res}{label}"
yield res
return res
def build_prompt(
self,
instruction: str,
input: Union[None, str] = None, # pylint: disable=redefined-builtin
output: Union[None, str] = None,
reflection: Union[None, str] = None,
corrected: Union[None, str] = None,
) -> Generator[str, None, None]:
# pylint: disable=duplicate-code
yield self._build_result(
instruction,
input,
output,
reflection,
corrected,
)
def __repr__(self) -> str:
return REPR_TEMPLATE.format(
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
)
SHAREGPT_ASSERTION_FAILED_ROLE = ( SHAREGPT_ASSERTION_FAILED_ROLE = (
@@ -281,7 +247,7 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
if role_key_model: if role_key_model:
self.role_key_model = role_key_model self.role_key_model = role_key_model
def _build_result(self, source): def build_prompt(self, source) -> Generator[str, None, None]:
if len(source) < 2: if len(source) < 2:
# If there isn't a back and forth conversation, ignore it # If there isn't a back and forth conversation, ignore it
# also happens on the data splitting leaving empty conversations # also happens on the data splitting leaving empty conversations
@@ -308,28 +274,17 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
raise err raise err
conv.messages = [] conv.messages = []
for _, sentence in enumerate(source): for j, sentence in enumerate(source):
role = roles[sentence["from"]] role = roles[sentence["from"]]
if len(conv.messages) > 0 and ( if role != conv.roles[j % 2]:
(role == conv.messages[-1][0]) or (role not in conv.roles)
):
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}") LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"]) conv.append_message(role, sentence["value"])
return conv.get_turns() for part in conv.get_turns():
def build_prompt(self, source) -> Generator[str, None, None]:
turns = self._build_result(source)
for part in turns:
if part[0] and not part[1]: if part[0] and not part[1]:
LOG.warning(f"role with empty message: {part[0]}") LOG.warning(f"role with empty message: {part[0]}")
yield part yield part
def __repr__(self) -> str:
turns = self._build_result([{"from": "{from}", "value": "{value}"}])
return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns])
class ShareGPTPrompterV2(ShareGPTPrompter): class ShareGPTPrompterV2(ShareGPTPrompter):
""" """
@@ -347,15 +302,3 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
role_key_human=role_key_human, role_key_human=role_key_human,
role_key_model=role_key_model, role_key_model=role_key_model,
) )
class UnsupportedPrompter:
"""
A dummy class for custom prompters
"""
def __init__(self) -> None:
pass
def __repr__(self):
return "Pre-tokenized or custom dataset types are unsupported for logging"
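The AlpacaPrompter and ReflectAlpacaPrompter hunks above split prompt assembly into a plain `_build_result` string builder, keep `build_prompt` as a one-item generator, and reuse the builder for `__repr__`. A minimal sketch of that shape; the templates below are placeholders, not the repository's actual prompt formats:

```python
# Minimal sketch of the _build_result / build_prompt split shown above.
class SketchPrompter:
    turn_format = "### Instruction:\n{instruction}\n### Input:\n{input}\n### Response:\n"
    turn_no_input_format = "### Instruction:\n{instruction}\n### Response:\n"

    def _build_result(self, instruction, input_text=None, output=None):
        if input_text:
            res = self.turn_format.format(instruction=instruction, input=input_text)
        else:
            res = self.turn_no_input_format.format(instruction=instruction)
        return f"{res}{output}" if output else res

    def build_prompt(self, instruction, input_text=None, output=None):
        # generator wrapper keeps the public interface unchanged
        yield self._build_result(instruction, input_text, output)

    def __repr__(self):
        # renders the template with the field names left in place
        return self._build_result("{instruction}", "{input}", "{output}")
```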

View File

@@ -12,11 +12,9 @@ import torch
import transformers.modelcard import transformers.modelcard
from datasets import Dataset from datasets import Dataset
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from transformers.deepspeed import is_deepspeed_zero3_enabled
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.logging_config import configure_logging from axolotl.logging_config import configure_logging
from axolotl.monkeypatch import neft_embeddings
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import setup_trainer
@@ -41,7 +39,10 @@ class TrainDatasetMeta:
def train( def train(
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta *,
cfg: DictDefault,
cli_args: TrainerCliArgs,
dataset_meta: TrainDatasetMeta,
): ):
# load the tokenizer first # load the tokenizer first
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}") LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
@@ -57,7 +58,9 @@ def train(
safe_serialization = cfg.save_safetensors is True safe_serialization = cfg.save_safetensors is True
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints: if (
cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints
) or cfg.resume_from_checkpoint is True:
possible_checkpoints = [ possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*") str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
] ]
@@ -70,7 +73,9 @@ def train(
LOG.info( LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}" f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
) )
resume_from_checkpoint = cfg.resume_from_checkpoint resume_from_checkpoint = (
cfg.resume_from_checkpoint if cfg.resume_from_checkpoint is not True else None
)
trainer = setup_trainer( trainer = setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
@@ -108,7 +113,6 @@ def train(
if cfg.group_by_length: if cfg.group_by_length:
LOG.info("hang tight... sorting dataset for group_by_length") LOG.info("hang tight... sorting dataset for group_by_length")
pretrain_hooks(cfg, trainer)
if cfg.flash_optimum: if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel( with torch.backends.cuda.sdp_kernel(
enable_flash=True, enable_math=True, enable_mem_efficient=True enable_flash=True, enable_math=True, enable_mem_efficient=True
@@ -116,15 +120,9 @@ def train(
trainer.train(resume_from_checkpoint=resume_from_checkpoint) trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else: else:
trainer.train(resume_from_checkpoint=resume_from_checkpoint) trainer.train(resume_from_checkpoint=resume_from_checkpoint)
post_train_hooks(cfg, trainer)
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
# post training
for name, module in model.named_modules():
if hasattr(module, "_post_training"):
module._post_training(model, name) # pylint: disable=protected-access
if trainer.is_fsdp_enabled: if trainer.is_fsdp_enabled:
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT") trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.") LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
@@ -140,22 +138,6 @@ def train(
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp: if cfg.fsdp:
trainer.save_model(cfg.output_dir) trainer.save_model(cfg.output_dir)
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
trainer.accelerator.wait_for_everyone()
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
# For Zero Stages 1 and 2, models are saved as usual in the output directory.
# The model name saved is `pytorch_model.bin`
unwrapped_model.save_pretrained(
cfg.output_dir,
is_main_process=trainer.accelerator.is_main_process,
save_function=trainer.accelerator.save,
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
)
elif cfg.local_rank == 0: elif cfg.local_rank == 0:
if cfg.flash_optimum: if cfg.flash_optimum:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)
@@ -166,23 +148,3 @@ def train(
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./")) trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
return model, tokenizer return model, tokenizer
def pretrain_hooks(cfg, trainer):
"""
Run hooks right before kicking off the training
:param cfg:
:param trainer:
:return:
"""
neft_embeddings.pretrain_hook(cfg, trainer)
def post_train_hooks(cfg, trainer):
"""
Run hooks right after training completes
:param cfg:
:param trainer:
:return:
"""
neft_embeddings.post_train_hook(cfg, trainer)
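For the auto-resume hunk above, a sketch of how the newest checkpoint directory could be chosen; the numeric sort on the `checkpoint-<step>` suffix is an assumption, not something the diff shows:

```python
# Hedged sketch: pick the most recent checkpoint-* directory under output_dir.
from pathlib import Path

def find_resume_checkpoint(output_dir):
    possible_checkpoints = [str(cp) for cp in Path(output_dir).glob("checkpoint-*")]
    if not possible_checkpoints:
        return None
    # assumes Trainer-style "checkpoint-<global_step>" directory names
    return max(possible_checkpoints, key=lambda path: int(path.rsplit("-", 1)[-1]))
```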

View File

@@ -1,13 +1,10 @@
"""Benchmarking and measurement utilities""" """Benchmarking and measurement utilities"""
import functools import functools
import logging
import pynvml import pynvml
import torch import torch
from pynvml.nvml import NVMLError from pynvml.nvml import NVMLError
LOG = logging.getLogger("axolotl.utils.bench")
def check_cuda_device(default_value): def check_cuda_device(default_value):
""" """
@@ -65,14 +62,7 @@ def gpu_memory_usage_smi(device=0):
def log_gpu_memory_usage(log, msg, device): def log_gpu_memory_usage(log, msg, device):
if not torch.cuda.is_available(): usage, cache, misc = gpu_memory_usage_all(device)
return (0, 0, 0)
try:
usage, cache, misc = gpu_memory_usage_all(device)
except ValueError as exc:
LOG.exception(exc)
return (0, 0, 0)
extras = [] extras = []
if cache > 0: if cache > 0:
extras.append(f"+{cache:.03f}GB cache") extras.append(f"+{cache:.03f}GB cache")
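The `log_gpu_memory_usage` change above guards against hosts without CUDA; a condensed sketch of that defensive pattern, with `probe` standing in for `gpu_memory_usage_all`:

```python
# Illustrative sketch of the CUDA-availability guard; `probe` is a stand-in for
# the real gpu_memory_usage_all(device) helper.
import logging

import torch

LOG = logging.getLogger("axolotl.utils.bench")

def safe_gpu_memory_usage(probe, device=0):
    if not torch.cuda.is_available():
        return (0, 0, 0)
    try:
        return probe(device)
    except ValueError as exc:
        LOG.exception(exc)
        return (0, 0, 0)
```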

View File

@@ -37,7 +37,7 @@ from axolotl.utils.distributed import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from axolotl.core.trainer_builder import AxolotlTrainingArguments from axolotl.utils.trainer import AxolotlTrainingArguments
LOG = logging.getLogger("axolotl.callbacks") LOG = logging.getLogger("axolotl.callbacks")
IGNORE_INDEX = -100 IGNORE_INDEX = -100
@@ -514,27 +514,3 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
return control return control
return LogPredictionCallback return LogPredictionCallback
class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
"""Callback to save axolotl config to wandb"""
def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path
def on_train_begin(
self,
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
try:
artifact = wandb.Artifact(name="axolotl-config", type="config")
artifact.add_file(local_path=self.axolotl_config_path)
wandb.run.log_artifact(artifact)
LOG.info("Axolotl config has been saved to WandB as an artifact.")
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
return control

View File

@@ -79,9 +79,6 @@ def normalize_config(cfg):
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count() cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
if not cfg.base_model_config:
cfg.base_model_config = cfg.base_model
model_config = load_model_config(cfg) model_config = load_model_config(cfg)
cfg.model_config_type = model_config.model_type cfg.model_config_type = model_config.model_type
@@ -122,9 +119,6 @@ def normalize_config(cfg):
or (cfg.model_type and "mistral" in cfg.model_type.lower()) or (cfg.model_type and "mistral" in cfg.model_type.lower())
) )
if isinstance(cfg.learning_rate, str):
cfg.learning_rate = float(cfg.learning_rate)
log_gpu_memory_usage(LOG, "baseline", cfg.device) log_gpu_memory_usage(LOG, "baseline", cfg.device)
@@ -195,15 +189,9 @@ def validate_config(cfg):
if not cfg.load_in_4bit: if not cfg.load_in_4bit:
raise ValueError("Require cfg.load_in_4bit to be True for qlora") raise ValueError("Require cfg.load_in_4bit to be True for qlora")
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
raise ValueError("Fused modules are not supported with QLoRA")
if not cfg.load_in_8bit and cfg.adapter == "lora": if not cfg.load_in_8bit and cfg.adapter == "lora":
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning") LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
raise ValueError("Fused modules are not supported with LoRA")
if cfg.relora_steps: if cfg.relora_steps:
if cfg.adapter not in ("lora", "qlora"): if cfg.adapter not in ("lora", "qlora"):
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA") raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
@@ -217,9 +205,6 @@ def validate_config(cfg):
if cfg.lr_scheduler == "one_cycle": if cfg.lr_scheduler == "one_cycle":
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler") raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
raise ValueError("Fused modules are not supported with ReLoRA")
if cfg.trust_remote_code: if cfg.trust_remote_code:
LOG.warning( LOG.warning(
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model." "`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
@@ -354,25 +339,6 @@ def validate_config(cfg):
"eval_steps and evaluation_strategy are not supported with val_set_size == 0" "eval_steps and evaluation_strategy are not supported with val_set_size == 0"
) )
if (
cfg.sample_packing
and cfg.eval_table_size
and cfg.eval_sample_packing is not False
):
raise ValueError(
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
)
if not cfg.adapter and (cfg.load_in_8bit or cfg.load_in_4bit):
raise ValueError(
"load_in_8bit and load_in_4bit are not supported without setting an adapter."
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
)
if cfg.tensor_parallel and cfg.gradient_checkpointing:
raise ValueError(
"TensorParallelPreTrainedModel does not support gradient checkpointing"
)
# TODO # TODO
# MPT 7b # MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25 # https://github.com/facebookresearch/bitsandbytes/issues/25
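Among the smaller normalize_config hunks above is the coercion of string learning rates read from YAML; as a quick illustration:

```python
# Sketch of the string-to-float learning-rate normalization from normalize_config.
def normalize_learning_rate(learning_rate):
    return float(learning_rate) if isinstance(learning_rate, str) else learning_rate

# normalize_learning_rate("3e-4") -> 0.0003
```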

View File

@@ -3,7 +3,7 @@ import functools
import hashlib import hashlib
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union
import torch import torch
from datasets import ( from datasets import (
@@ -16,7 +16,6 @@ from datasets import (
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_strategies import load from axolotl.prompt_strategies import load
from axolotl.prompt_tokenizers import ( from axolotl.prompt_tokenizers import (
@@ -36,7 +35,6 @@ from axolotl.prompters import (
MultipleChoiceExplainPrompter, MultipleChoiceExplainPrompter,
ReflectAlpacaPrompter, ReflectAlpacaPrompter,
SummarizeTLDRPrompter, SummarizeTLDRPrompter,
UnsupportedPrompter,
) )
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import is_main_process, zero_first from axolotl.utils.distributed import is_main_process, zero_first
@@ -46,6 +44,7 @@ from axolotl.utils.trainer import (
) )
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
def md5(to_hash: str, encoding: str = "utf-8") -> str: def md5(to_hash: str, encoding: str = "utf-8") -> str:
@@ -56,10 +55,9 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
def prepare_dataset(cfg, tokenizer): def prepare_dataset(cfg, tokenizer):
prompters = []
if not cfg.pretraining_dataset: if not cfg.pretraining_dataset:
with zero_first(is_main_process()): with zero_first(is_main_process()):
train_dataset, eval_dataset, prompters = load_prepare_datasets( train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
) )
else: else:
@@ -72,7 +70,7 @@ def prepare_dataset(cfg, tokenizer):
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch") train_dataset = train_dataset.with_format("torch")
eval_dataset = None eval_dataset = None
return train_dataset, eval_dataset, cfg.max_steps, prompters return train_dataset, eval_dataset, cfg.max_steps
with zero_first(is_main_process()): with zero_first(is_main_process()):
train_dataset, eval_dataset = process_datasets_for_packing( train_dataset, eval_dataset = process_datasets_for_packing(
@@ -85,7 +83,7 @@ def prepare_dataset(cfg, tokenizer):
LOG.info(f"Maximum number of steps set at {total_num_steps}") LOG.info(f"Maximum number of steps set at {total_num_steps}")
else: else:
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer) total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
return train_dataset, eval_dataset, total_num_steps, prompters return train_dataset, eval_dataset, total_num_steps
def load_tokenized_prepared_datasets( def load_tokenized_prepared_datasets(
@@ -111,7 +109,6 @@ def load_tokenized_prepared_datasets(
else Path(default_dataset_prepared_path) / ds_hash else Path(default_dataset_prepared_path) / ds_hash
) )
dataset = None dataset = None
prompters = []
use_auth_token = cfg.hf_use_auth_token use_auth_token = cfg.hf_use_auth_token
try: try:
if cfg.push_dataset_to_hub: if cfg.push_dataset_to_hub:
@@ -150,48 +147,50 @@ def load_tokenized_prepared_datasets(
yield dataset yield dataset
# pylint: disable=invalid-name # pylint: disable=invalid-name
for config_dataset in for_d_in_datasets(cfg.datasets): for d in for_d_in_datasets(cfg.datasets):
ds: Union[Dataset, DatasetDict] = None ds: Union[Dataset, DatasetDict] = None
ds_from_hub = False ds_from_hub = False
try: try:
load_dataset( load_dataset(
config_dataset.path, d.path,
name=config_dataset.name, name=d.name,
streaming=True, streaming=True,
token=use_auth_token, token=use_auth_token,
) )
ds_from_hub = True ds_from_hub = True
except (FileNotFoundError, ConnectionError): except (FileNotFoundError, ValueError):
pass pass
# prefer local dataset, even if hub exists # prefer local dataset, even if hub exists
local_path = Path(config_dataset.path) local_path = Path(d.path)
if local_path.exists(): if local_path.exists():
if local_path.is_dir(): if local_path.is_dir():
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk` if not d.type:
ds = load_dataset( ds = load_from_disk(d.path)
config_dataset.path, else:
name=config_dataset.name, ds = load_dataset(
data_files=config_dataset.data_files, d.path,
streaming=False, name=d.name,
split=None, data_files=d.data_files,
) streaming=False,
split=None,
)
elif local_path.is_file(): elif local_path.is_file():
ds_type = "json" ds_type = "json"
if config_dataset.ds_type: if d.ds_type:
ds_type = config_dataset.ds_type ds_type = d.ds_type
elif ".parquet" in config_dataset.path: elif ".parquet" in d.path:
ds_type = "parquet" ds_type = "parquet"
elif ".arrow" in config_dataset.path: elif ".arrow" in d.path:
ds_type = "arrow" ds_type = "arrow"
elif ".csv" in config_dataset.path: elif ".csv" in d.path:
ds_type = "csv" ds_type = "csv"
elif ".txt" in config_dataset.path: elif ".txt" in d.path:
ds_type = "text" ds_type = "text"
ds = load_dataset( ds = load_dataset(
ds_type, ds_type,
name=config_dataset.name, name=d.name,
data_files=config_dataset.path, data_files=d.path,
streaming=False, streaming=False,
split=None, split=None,
) )
@@ -201,25 +200,25 @@ def load_tokenized_prepared_datasets(
) )
elif ds_from_hub: elif ds_from_hub:
ds = load_dataset( ds = load_dataset(
config_dataset.path, d.path,
name=config_dataset.name, name=d.name,
streaming=False, streaming=False,
data_files=config_dataset.data_files, data_files=d.data_files,
token=use_auth_token, token=use_auth_token,
) )
else: else:
if isinstance(config_dataset.data_files, str): if isinstance(d.data_files, str):
fp = hf_hub_download( fp = hf_hub_download(
repo_id=config_dataset.path, repo_id=d.path,
repo_type="dataset", repo_type="dataset",
filename=config_dataset.data_files, filename=d.data_files,
) )
elif isinstance(config_dataset.data_files, list): elif isinstance(d.data_files, list):
fp = [] fp = []
for file in config_dataset.data_files: for file in d.data_files:
fp.append( fp.append(
hf_hub_download( hf_hub_download(
repo_id=config_dataset.path, repo_id=d.path,
repo_type="dataset", repo_type="dataset",
filename=file, filename=file,
) )
@@ -229,27 +228,21 @@ def load_tokenized_prepared_datasets(
"data_files must be either a string or list of strings" "data_files must be either a string or list of strings"
) )
ds = load_dataset( ds = load_dataset(
"json", "json", name=d.name, data_files=fp, streaming=False, split=None
name=config_dataset.name,
data_files=fp,
streaming=False,
split=None,
) )
if not ds: if not ds:
raise ValueError("unhandled dataset load") raise ValueError("unhandled dataset load")
# support for using a subset of the data # support for using a subset of the data
if config_dataset.shards: if d.shards:
if "train" in ds: if "train" in ds:
ds = ds.shuffle(seed=seed)["train"].shard( ds = ds.shuffle(seed=seed)["train"].shard(
num_shards=config_dataset.shards, index=0 num_shards=d.shards, index=0
) )
else: else:
ds = ds.shuffle(seed=seed).shard( ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
num_shards=config_dataset.shards, index=0
)
d_base_type = d_prompt_style = None d_base_type = d_prompt_style = None
d_type = config_dataset.type d_type = d.type
if isinstance(d_type, str): if isinstance(d_type, str):
d_type_split = d_type.split(":") d_type_split = d_type.split(":")
d_base_type = d_type_split[0] d_base_type = d_type_split[0]
@@ -258,33 +251,115 @@ def load_tokenized_prepared_datasets(
ds = ds["train"] ds = ds["train"]
elif ( elif (
isinstance(ds, DatasetDict) isinstance(ds, DatasetDict)
and config_dataset.train_on_split and d.train_on_split
and config_dataset.train_on_split in ds and d.train_on_split in ds
): ):
ds = ds[config_dataset.train_on_split] ds = ds[d.train_on_split]
elif isinstance(ds, DatasetDict): elif isinstance(ds, DatasetDict):
raise ValueError( raise ValueError(
f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `" f"no train split found for dataset {d.path}, you may specify a split with 'train_on_split: `"
)
if (
"input_ids" in ds.features
and "attention_mask" in ds.features
and "labels" in ds.features
):
# dataset is already tokenized, just drop it straight in
datasets.append(ds)
elif isinstance(d.type, DictDefault):
ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict())
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif ds_strategy := load(d.type, tokenizer, cfg, d):
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "alpaca":
ds_strategy = AlpacaPromptTokenizingStrategy(
AlpacaPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "explainchoice":
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
MultipleChoiceExplainPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "concisechoice":
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
MultipleChoiceConcisePrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "summarizetldr":
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
SummarizeTLDRPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "jeopardy":
ds_strategy = JeopardyPromptTokenizingStrategy(
JeopardyPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "oasst":
ds_strategy = OpenAssistantPromptTokenizingStrategy(
AlpacaPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "gpteacher":
ds_strategy = GPTeacherPromptTokenizingStrategy(
GPTeacherPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "reflection":
ds_strategy = AlpacaReflectionPTStrategy(
ReflectAlpacaPrompter(d_prompt_style),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
else:
suffix = ""
if ":load_" in d.type:
suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
LOG.error(f"unhandled prompt tokenization strategy: {d.type}. {suffix}")
raise ValueError(
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
) )
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
config_dataset=config_dataset,
dataset=ds,
tokenizer=tokenizer,
cfg=cfg,
d_base_type=d_base_type,
d_prompt_style=d_prompt_style,
)
datasets.append(dataset_wrapper)
prompters.append(dataset_prompter)
LOG.info("merging datasets") LOG.info("merging datasets")
dataset = concatenate_datasets(datasets) dataset = concatenate_datasets(datasets)
if len(datasets) > 1: if len(datasets) > 1:
LOG.info("shuffle merged datasets") LOG.info("shuffle merged datasets")
dataset = dataset.shuffle(seed=seed) dataset = dataset.shuffle(seed=seed)
if cfg.local_rank == 0: if cfg.local_rank == 0 and cfg.dataset_prepared_path:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(prepared_ds_path) dataset.save_to_disk(prepared_ds_path)
if cfg.push_dataset_to_hub: if cfg.push_dataset_to_hub:
@@ -295,14 +370,14 @@ def load_tokenized_prepared_datasets(
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
) )
return dataset, prompters return dataset
def load_prepare_datasets( def load_prepare_datasets(
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
cfg, cfg,
default_dataset_prepared_path, default_dataset_prepared_path,
) -> Tuple[Dataset, Dataset, List[Any]]: ) -> Tuple[Dataset, Dataset]:
max_packed_sequence_len = ( max_packed_sequence_len = (
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
) )
@@ -311,7 +386,6 @@ def load_prepare_datasets(
) # make sure we don't accidentally set it larger than sequence_len ) # make sure we don't accidentally set it larger than sequence_len
tokenizer_name = tokenizer.__class__.__name__ tokenizer_name = tokenizer.__class__.__name__
prompters = []
if cfg.max_packed_sequence_len is not None: if cfg.max_packed_sequence_len is not None:
# see if we can go ahead and load the stacked dataset # see if we can go ahead and load the stacked dataset
seed = f"@{str(cfg.seed)}" if cfg.seed else "" seed = f"@{str(cfg.seed)}" if cfg.seed else ""
@@ -367,7 +441,7 @@ def load_prepare_datasets(
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
) )
else: else:
dataset, prompters = load_tokenized_prepared_datasets( dataset = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path tokenizer, cfg, default_dataset_prepared_path
) )
@@ -409,7 +483,7 @@ def load_prepare_datasets(
private=True, private=True,
) )
else: else:
dataset, prompters = load_tokenized_prepared_datasets( dataset = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path tokenizer, cfg, default_dataset_prepared_path
) )
@@ -460,124 +534,7 @@ def load_prepare_datasets(
train_dataset = dataset train_dataset = dataset
eval_dataset = None eval_dataset = None
return train_dataset, eval_dataset, prompters return train_dataset, eval_dataset
def get_dataset_wrapper(
config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
):
dataset_wrapper = None
dataset_prompter = None
if (
"input_ids" in dataset.features
and "attention_mask" in dataset.features
and "labels" in dataset.features
):
# dataset is already tokenized, just drop it straight in
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = dataset
elif isinstance(config_dataset.type, DictDefault):
ds_strategy = load(
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
)
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
dataset_prompter = UnsupportedPrompter()
dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
elif d_base_type == "alpaca":
dataset_prompter = AlpacaPrompter(d_prompt_style)
ds_strategy = AlpacaPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
dataset_wrapper = ds_wrapper
elif d_base_type == "explainchoice":
dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
dataset_wrapper = ds_wrapper
elif d_base_type == "concisechoice":
dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
dataset_wrapper = ds_wrapper
elif d_base_type == "summarizetldr":
dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
dataset_wrapper = ds_wrapper
elif d_base_type == "jeopardy":
dataset_prompter = JeopardyPrompter(d_prompt_style)
ds_strategy = JeopardyPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
dataset_wrapper = ds_wrapper
elif d_base_type == "oasst":
dataset_prompter = AlpacaPrompter(d_prompt_style)
ds_strategy = OpenAssistantPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
dataset_wrapper = ds_wrapper
elif d_base_type == "gpteacher":
dataset_prompter = GPTeacherPrompter(d_prompt_style)
ds_strategy = GPTeacherPromptTokenizingStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
dataset_wrapper = ds_wrapper
elif d_base_type == "reflection":
dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
ds_strategy = AlpacaReflectionPTStrategy(
dataset_prompter,
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
dataset_wrapper = ds_wrapper
else:
suffix = ""
if ":load_" in config_dataset.type:
suffix = f" Did you mean {config_dataset.type.replace(':load_', '.load_')}?"
LOG.error(
f"unhandled prompt tokenization strategy: {config_dataset.type}. {suffix}"
)
raise ValueError(
f"unhandled prompt tokenization strategy: {config_dataset.type} {suffix}"
)
return dataset_wrapper, dataset_prompter
def encode_pretraining( def encode_pretraining(
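The strategy dispatch above keys off a dataset `type` string that may carry a prompt style after a colon; a small sketch of that parsing (the example value is hypothetical):

```python
# Sketch of the "<base_type>:<prompt_style>" split used before selecting a
# prompt tokenizing strategy above.
def parse_dataset_type(d_type):
    d_base_type = d_prompt_style = None
    if isinstance(d_type, str):
        d_type_split = d_type.split(":")
        d_base_type = d_type_split[0]
        if len(d_type_split) > 1:
            d_prompt_style = d_type_split[1]
    return d_base_type, d_prompt_style

# parse_dataset_type("alpaca:chat") -> ("alpaca", "chat")  # hypothetical example
```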

View File

@@ -3,9 +3,6 @@ import hashlib
import itertools import itertools
import logging import logging
import math import math
import time
from queue import Queue
from threading import Thread
from typing import Any, Callable, List, Union from typing import Any, Callable, List, Union
import numba import numba
@@ -152,8 +149,6 @@ class MultipackDistributedDataloader:
packing_efficiency_estimate: float = 1.0, packing_efficiency_estimate: float = 1.0,
sample_packing_seq_len_multiplier: int = 1, sample_packing_seq_len_multiplier: int = 1,
device_count: int = 1, device_count: int = 1,
prefetch_max: int = 1000,
num_epochs: int = 1,
): ):
# Dataset # Dataset
self.dataset = dataset self.dataset = dataset
@@ -172,7 +167,6 @@ class MultipackDistributedDataloader:
self.seq_max_length = seq_max_length self.seq_max_length = seq_max_length
self.batch_max_length = batch_size * seq_max_length self.batch_max_length = batch_size * seq_max_length
self.collate_fn = collate_fn self.collate_fn = collate_fn
self.num_epochs = num_epochs
self.num_replicas = 1 self.num_replicas = 1
self.rank = 0 self.rank = 0
@@ -183,44 +177,6 @@ class MultipackDistributedDataloader:
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.device_count = device_count self.device_count = device_count
# maxsize is maximum number of samples in queue
self.prefetch_max = prefetch_max
self.queue: Queue = Queue(maxsize=prefetch_max)
self.thread = None
def _worker(self):
LOG.info(
f"[WORKER] Epochs: {self.num_epochs}, Samples: {self.len_w_stats()*self.batch_size}"
)
for epoch in range(self.num_epochs):
for sample in self._internal_batch_generator():
while True:
if self.queue.full():
time.sleep(1)
else:
break
self.queue.put(sample)
# stop the queue when epoch is done
self.queue.put(None)
def __iter__(self):
if hasattr(self.sampler, "set_epoch"):
new_epoch = self.sampler.epoch + 1
self.sampler.set_epoch(new_epoch)
LOG.info(f"calling sampler.set_epoch({new_epoch})")
if self.thread is None:
self.thread = Thread(target=self._worker, daemon=True)
self.thread.start()
while True:
item = self.queue.get()
if item is None:
break
yield item
def generate_batches(self, set_stats=False): def generate_batches(self, set_stats=False):
LOG.info("generating packed batches") LOG.info("generating packed batches")
if self.sampler: if self.sampler:
@@ -250,7 +206,11 @@ class MultipackDistributedDataloader:
return batches, totseqs return batches, totseqs
def _internal_batch_generator(self): def __iter__(self):
if hasattr(self.sampler, "set_epoch"):
new_epoch = self.sampler.epoch + 1
self.sampler.set_epoch(new_epoch)
LOG.info(f"calling sampler.set_epoch({new_epoch})")
all_batches, _ = self.generate_batches(set_stats=True) all_batches, _ = self.generate_batches(set_stats=True)
features = self.dataset.features.keys() features = self.dataset.features.keys()
len_remaining = self._len_est() len_remaining = self._len_est()
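One side of the dataloader diff above adds a background prefetch thread with a bounded queue; a self-contained sketch of that producer/consumer pattern (not the repository's exact class):

```python
# Minimal sketch of the queue-based batch prefetching shown above: a daemon thread
# fills a bounded Queue, and a None sentinel marks the end of each epoch.
from queue import Queue
from threading import Thread

class PrefetchingBatches:
    def __init__(self, batch_generator, num_epochs=1, prefetch_max=1000):
        self.batch_generator = batch_generator  # callable returning an iterator of batches
        self.num_epochs = num_epochs
        self.queue = Queue(maxsize=prefetch_max)
        self.thread = None

    def _worker(self):
        for _ in range(self.num_epochs):
            for sample in self.batch_generator():
                self.queue.put(sample)  # blocks while the queue is full
            self.queue.put(None)  # sentinel: this epoch is done

    def __iter__(self):
        # one __iter__ call consumes one epoch's worth of batches
        if self.thread is None:
            self.thread = Thread(target=self._worker, daemon=True)
            self.thread.start()
        while True:
            item = self.queue.get()
            if item is None:
                break
            yield item
```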

View File

@@ -7,7 +7,6 @@ from typing import Optional, Tuple # noqa: F401
import bitsandbytes as bnb import bitsandbytes as bnb
import torch import torch
import transformers import transformers
import transformers.utils.bitsandbytes
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from peft import PeftConfig, prepare_model_for_kbit_training from peft import PeftConfig, prepare_model_for_kbit_training
from peft.tuners.lora import QuantLinear from peft.tuners.lora import QuantLinear
@@ -32,7 +31,7 @@ LOG = logging.getLogger("axolotl")
def load_model_config(cfg): def load_model_config(cfg):
model_config_name = cfg.base_model_config or cfg.base_model model_config_name = cfg.base_model_config or cfg.base_model
trust_remote_code = cfg.trust_remote_code is True trust_remote_code: bool = False or cfg.trust_remote_code
return AutoConfig.from_pretrained( return AutoConfig.from_pretrained(
model_config_name, trust_remote_code=trust_remote_code model_config_name, trust_remote_code=trust_remote_code
) )
@@ -73,6 +72,11 @@ def load_tokenizer(cfg):
# set a pad_token, but use eos_token so we don't add a new token # set a pad_token, but use eos_token so we don't add a new token
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast": if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
tokenizer.add_special_tokens({"pad_token": "[PAD]"}) tokenizer.add_special_tokens({"pad_token": "[PAD]"})
os.environ["TOKENIZERS_PARALLELISM"] = "false" os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -94,11 +98,6 @@ def load_tokenizer(cfg):
] ]
) )
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
return tokenizer return tokenizer
@@ -222,7 +221,7 @@ def load_model(
load_in_4bit=True, load_in_4bit=True,
llm_int8_threshold=6.0, llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False, llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.float16, bnb_4bit_compute_dtype=cfg.torch_dtype,
bnb_4bit_use_double_quant=True, bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4", bnb_4bit_quant_type="nf4",
) )
@@ -236,12 +235,7 @@ def load_model(
model_kwargs["use_flash_attention_2"] = True model_kwargs["use_flash_attention_2"] = True
try: try:
if ( if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
cfg.is_llama_derived_model
and not cfg.trust_remote_code
and not cfg.gptq
and not cfg.tensor_parallel
):
from transformers import LlamaForCausalLM from transformers import LlamaForCausalLM
config_kwargs = {} config_kwargs = {}
@@ -258,20 +252,6 @@ def load_model(
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs, **model_kwargs,
) )
if cfg.flash_attention and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_mlp_with_swiglu,
replace_llama_qkv_with_fused,
)
if cfg.flash_attn_fuse_mlp:
LOG.info("patching with SwiGLU")
replace_llama_mlp_with_swiglu(model)
if cfg.flash_attn_fuse_qkv:
LOG.info("patching with fused QKV")
replace_llama_qkv_with_fused(model)
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention: # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
# This is a WIP, still an issue with the backward pass # This is a WIP, still an issue with the backward pass
# RuntimeError: grad can be implicitly created only for scalar outputs # RuntimeError: grad can be implicitly created only for scalar outputs
@@ -307,7 +287,7 @@ def load_model(
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs, **model_kwargs,
) )
elif model_type and not cfg.trust_remote_code and not cfg.tensor_parallel: elif model_type and not cfg.trust_remote_code:
if cfg.gptq: if cfg.gptq:
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
base_model, base_model,
@@ -322,17 +302,6 @@ def load_model(
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs, **model_kwargs,
) )
elif cfg.tensor_parallel:
model_kwargs.pop("device_map")
model = AutoModelForCausalLM.from_pretrained(
base_model,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
low_cpu_mem_usage=True,
offload_state_dict=True,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
else: else:
config = AutoConfig.from_pretrained( config = AutoConfig.from_pretrained(
base_model, base_model,
@@ -383,18 +352,15 @@ def load_model(
**model_kwargs, **model_kwargs,
) )
try: embeddings_len = (
embeddings_len = ( math.ceil(len(tokenizer) / 32) * 32
math.ceil(len(tokenizer) / 32) * 32 if cfg.resize_token_embeddings_to_32x
if cfg.resize_token_embeddings_to_32x else len(tokenizer)
else len(tokenizer) )
) if model.get_input_embeddings().num_embeddings < embeddings_len:
if model.get_input_embeddings().num_embeddings < embeddings_len: model.resize_token_embeddings(embeddings_len)
model.resize_token_embeddings(embeddings_len) else:
else: model.tie_weights()
model.tie_weights()
except NotImplementedError:
LOG.warning("`resize_token_embeddings` not implemented on model")
if ( if (
hasattr(model.config, "max_position_embeddings") hasattr(model.config, "max_position_embeddings")
@@ -406,20 +372,6 @@ def load_model(
) )
model.config.max_position_embeddings = cfg.sequence_len model.config.max_position_embeddings = cfg.sequence_len
if (
hasattr(model.config, "bos_token_id")
and model.config.bos_token_id
and model.config.bos_token_id != tokenizer.bos_token_id
):
model.config.bos_token_id = tokenizer.bos_token_id
if (
hasattr(model.config, "eos_token_id")
and model.config.eos_token_id
and model.config.eos_token_id != tokenizer.eos_token_id
):
model.config.eos_token_id = tokenizer.eos_token_id
if model.device.type == "cuda": if model.device.type == "cuda":
log_gpu_memory_usage(LOG, "after model load", model.device) log_gpu_memory_usage(LOG, "after model load", model.device)
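This hunk also drops the sync of bos/eos token ids between the tokenizer and the model config. A minimal sketch of that sync as a standalone helper; the None-checks here stand in for the original truthiness checks:

```python
# Sketch of the special-token sync removed above: keep the model config's
# bos/eos ids aligned with the tokenizer when both sides define them.
def sync_special_token_ids(model, tokenizer) -> None:
    for attr in ("bos_token_id", "eos_token_id"):
        tok_id = getattr(tokenizer, attr, None)
        cfg_id = getattr(model.config, attr, None)
        if cfg_id is not None and tok_id is not None and cfg_id != tok_id:
            setattr(model.config, attr, tok_id)
```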
@@ -430,7 +382,7 @@ def load_model(
if model_config.model_type == "btlm": if model_config.model_type == "btlm":
# don't upcast lm_head for btlm # don't upcast lm_head for btlm
continue continue
if "lm_head" in name or "embed_tokens" in name: if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
if hasattr(module, "weight"): if hasattr(module, "weight"):
module.to(torch.float32) module.to(torch.float32)
@@ -497,12 +449,7 @@ def load_adapter(model, cfg, adapter, inference=False):
if adapter is None: if adapter is None:
return model, None return model, None
if hasattr(model, "enable_input_require_grads"): if hasattr(model, "enable_input_require_grads"):
try: model.enable_input_require_grads()
model.enable_input_require_grads()
except NotImplementedError:
LOG.warning("enable_input_require_grads not implemented on model")
if adapter == "qlora" and cfg.tensor_parallel:
model, _ = load_tp_qlora(model)
if adapter in ["lora", "qlora"]: if adapter in ["lora", "qlora"]:
return load_lora(model, cfg, inference=inference) return load_lora(model, cfg, inference=inference)
if adapter == "llama-adapter": if adapter == "llama-adapter":
@@ -540,11 +487,7 @@ def find_all_linear_names(model):
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear) cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
lora_module_names = set() lora_module_names = set()
for name, module in model.named_modules(): for name, module in model.named_modules():
if ( if isinstance(module, cls) or "Linear" in module.__class__.__name__:
isinstance(module, cls)
or "Linear" in module.__class__.__name__
and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
):
names = name.split(".") names = name.split(".")
lora_module_names.add(names[0] if len(names) == 1 else names[-1]) lora_module_names.add(names[0] if len(names) == 1 else names[-1])
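find_all_linear_names exists to discover which modules should receive LoRA adapters. A usage sketch of the assumed wiring into a PEFT LoraConfig; this is not the project's exact call path, and the hyperparameters are placeholders:

```python
# Usage sketch: feed the discovered linear layer names into a PEFT LoraConfig.
from peft import LoraConfig, get_peft_model


def build_lora_model(model, r: int = 8, alpha: int = 16, dropout: float = 0.05):
    target_modules = find_all_linear_names(model)  # helper shown in the hunk above
    lora_config = LoraConfig(
        r=r,
        lora_alpha=alpha,
        lora_dropout=dropout,
        target_modules=target_modules,
        bias="none",
        task_type="CAUSAL_LM",
    )
    return get_peft_model(model, lora_config)
```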
@@ -554,25 +497,6 @@ def find_all_linear_names(model):
return list(lora_module_names) return list(lora_module_names)
def load_tp_qlora(model):
from transformers.utils.bitsandbytes import replace_with_bnb_linear
model = replace_with_bnb_linear(
model,
quantization_config=BitsAndBytesConfig(
load_in_4bit=True,
llm_int8_threshold=6.0,
llm_int8_has_fp16_weight=False,
bnb_4bit_compute_dtype=torch.float16,
bnb_4bit_use_double_quant=True,
bnb_4bit_quant_type="nf4",
),
)
model.is_loaded_in_4bit = True
return model, None
def load_lora(model, cfg, inference=False): def load_lora(model, cfg, inference=False):
# type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]] # type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]

View File

@@ -34,5 +34,6 @@ def check_example_labels(example, tokenizer, text_only=False):
delimiter = "" if text_only else " " delimiter = "" if text_only else " "
LOG.info(delimiter.join(colored_tokens)) LOG.info(delimiter.join(colored_tokens))
LOG.info("\n\n\n") LOG.info("\n\n\n")
print(" ".join(colored_tokens))
return " ".join(colored_tokens) return " ".join(colored_tokens)

View File

@@ -1,19 +1,39 @@
"""Module containing the Trainer class and related functions""" """Module containing the Trainer class and related functions"""
import importlib
import logging import logging
import math import math
import os import os
import sys
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial from functools import partial
from typing import List from pathlib import Path
from typing import List, Optional, Union
import numpy as np import numpy as np
import torch import torch
import torch.cuda import torch.cuda
import torch.distributed as dist import torch.distributed as dist
from datasets import set_caching_enabled import transformers
from torch.utils.data import DistributedSampler, RandomSampler from datasets import Dataset, set_caching_enabled
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import (
DataLoader,
DistributedSampler,
RandomSampler,
SequentialSampler,
)
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import SequentialDistributedSampler
from axolotl.core.trainer_builder import HFCausalTrainerBuilder from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
log_prediction_callback_factory,
)
from axolotl.utils.collators import DataCollatorForSeq2Seq from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader from axolotl.utils.dataloader import MultipackDistributedDataloader
from axolotl.utils.distributed import ( from axolotl.utils.distributed import (
@@ -22,6 +42,7 @@ from axolotl.utils.distributed import (
reduce_and_broadcast, reduce_and_broadcast,
zero_first, zero_first,
) )
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@@ -88,6 +109,269 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True):
return weighted_cross_entropy(logits, labels, weights) return weighted_cross_entropy(logits, labels, weights)
@dataclass
class AxolotlTrainingArguments(TrainingArguments):
"""
Extend the base TrainingArguments for axolotl helpers
"""
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
sample_packing_seq_len_multiplier: int = field(
default=1,
metadata={"help": "the multiplier for the max len for packed sequences"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
class AxolotlTrainer(Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
args = None # type: AxolotlTrainingArguments
def __init__(self, *args, bench_data_collator=None, **kwargs):
self.bench_data_collator = bench_data_collator
super().__init__(*args, **kwargs)
def create_scheduler(
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
):
"""
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
passed as an argument.
Args:
num_training_steps (int): The number of training steps to do.
optimizer (torch.optim.Optimizer): The training optimizer
"""
# fmt: off
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
# fmt: on
if (
self.args.lr_scheduler_type == "cosine"
and self.args.lr_quadratic_warmup is True
):
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
optimizer,
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps,
)
else:
return super().create_scheduler(num_training_steps, optimizer)
return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size > 1 and self.args.sample_packing:
return DistributedSampler(
self.train_dataset,
num_replicas=self.args.world_size,
rank=self.args.process_index,
seed=self.args.seed,
)
return super()._get_train_sampler()
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if (
self.args.world_size > 1
and self.args.sample_packing
and self.args.eval_sample_packing is not False
):
return SequentialDistributedSampler(
eval_dataset,
num_replicas=self.args.world_size,
rank=self.args.process_index,
batch_size=self.args.per_device_eval_batch_size,
)
return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing:
train_sampler = self._get_train_sampler()
return self.accelerator.prepare(
MultipackDistributedDataloader(
self.train_dataset,
batch_size=self._train_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=train_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
)
)
return super().get_train_dataloader()
def get_eval_dataloader(
self, eval_dataset: Optional[Dataset] = None
) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset)
return self.accelerator.prepare(
MultipackDistributedDataloader(
eval_dataset,
batch_size=self.args.eval_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=eval_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
)
)
return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
def get_bench_dataloader(
self,
bench_dataset: Dataset,
) -> Union[DataLoader, MultipackDistributedDataloader]:
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": self.bench_data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc
# if self.args.sample_packing:
# labels = inputs.pop("labels")
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
return super().compute_loss(model, inputs, return_outputs=return_outputs)
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
"""
Trainer subclass that uses the OneCycleLR scheduler
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
pct_start = num_warmup_steps / num_training_steps
self.lr_scheduler = OneCycleLR(
optimizer,
max_lr=self.args.learning_rate,
total_steps=num_training_steps,
pct_start=pct_start,
div_factor=6,
)
return self.lr_scheduler
class ReLoRATrainer(AxolotlTrainer):
"""
    Trainer subclass that uses the ReLoRA scheduler
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.lr_scheduler = None
def create_scheduler(
self,
num_training_steps: int,
optimizer: Optional[torch.optim.Optimizer] = None,
):
optimizer = self.optimizer if optimizer is None else optimizer
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
if self.args.relora_steps:
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
warmup_steps,
)
else:
self.lr_scheduler = lr_scheduler
return self.lr_scheduler
def add_position_ids(sample): def add_position_ids(sample):
sample_len = len(sample["input_ids"]) sample_len = len(sample["input_ids"])
sample["position_ids"] = torch.arange(len(sample["input_ids"])) sample["position_ids"] = torch.arange(len(sample["input_ids"]))
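OneCycleLRSchedulerTrainer above derives pct_start directly from the warmup fraction of total steps. A worked sketch of that wiring with illustrative numbers and a throwaway optimizer:

```python
# Worked sketch of the OneCycleLR math used by OneCycleLRSchedulerTrainer.
import torch
from torch.optim.lr_scheduler import OneCycleLR

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=2e-4)

num_training_steps = 1000
num_warmup_steps = 30                                # e.g. 3% warmup
pct_start = num_warmup_steps / num_training_steps    # 0.03

scheduler = OneCycleLR(
    optimizer,
    max_lr=2e-4,
    total_steps=num_training_steps,
    pct_start=pct_start,
    div_factor=6,                                    # initial LR = max_lr / 6
)
```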
@@ -138,9 +422,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
) )
# Phi doesn't want the attention_mask feature when training # Phi doesn't want the attention_mask feature when training
if "CodeGenTokenizer" in tokenizer.__class__.__name__ or ( if "CodeGenTokenizer" in tokenizer.__class__.__name__:
cfg.is_mistral_derived_model and cfg.flash_attention
):
train_dataset = train_dataset.remove_columns("attention_mask") train_dataset = train_dataset.remove_columns("attention_mask")
if eval_dataset: if eval_dataset:
eval_dataset = eval_dataset.remove_columns("attention_mask") eval_dataset = eval_dataset.remove_columns("attention_mask")
@@ -216,7 +498,6 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
packing_efficiency_estimate=cfg.sample_packing_eff_est, packing_efficiency_estimate=cfg.sample_packing_eff_est,
sample_packing_seq_len_multiplier=cfg.micro_batch_size, sample_packing_seq_len_multiplier=cfg.micro_batch_size,
device_count=int(os.environ.get("WORLD_SIZE", 1)), device_count=int(os.environ.get("WORLD_SIZE", 1)),
num_epochs=cfg.num_epochs,
) )
data_loader_len = data_loader.len_w_stats() data_loader_len = data_loader.len_w_stats()
actual_eff = data_loader.efficiency() actual_eff = data_loader.efficiency()
@@ -266,8 +547,242 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
elif cfg.deepspeed: elif cfg.deepspeed:
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer) warmup_steps = (
trainer_builder.train_dataset = train_dataset cfg.warmup_steps
trainer_builder.eval_dataset = eval_dataset if cfg.warmup_steps is not None
else min(int(0.03 * total_num_steps), 100)
)
logging_steps = (
cfg.logging_steps
if cfg.logging_steps is not None
else max(min(int(0.005 * total_num_steps), 10), 1)
)
return trainer_builder.build(total_num_steps) training_arguments_kwargs = {}
if cfg.bf16 == "full":
training_arguments_kwargs["bf16_full_eval"] = True
else:
training_arguments_kwargs["bf16"] = cfg.bf16
training_arguments_kwargs["fp16"] = (cfg.fp16 and not cfg.bf16) or False
training_arguments_kwargs["tf32"] = cfg.tf32
training_arguments_kwargs["warmup_steps"] = warmup_steps
training_arguments_kwargs["logging_steps"] = logging_steps
if cfg.seed:
training_arguments_kwargs["seed"] = cfg.seed
if cfg.gradient_checkpointing:
training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
if cfg.fsdp:
training_arguments_kwargs["fsdp"] = cfg.fsdp
if cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
# deepspeed
if cfg.deepspeed:
training_arguments_kwargs["deepspeed"] = cfg.deepspeed
if cfg.lr_quadratic_warmup is not None:
training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup
if cfg.adam_beta1:
training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
if cfg.adam_beta2:
training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
if cfg.adam_epsilon:
training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
if cfg.max_grad_norm:
training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
if cfg.hub_model_id:
training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
training_arguments_kwargs["push_to_hub"] = True
training_arguments_kwargs["hub_private_repo"] = True
if cfg.hub_strategy:
training_arguments_kwargs["hub_strategy"] = cfg.hub_strategy
if cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
if cfg.sample_packing_eff_est:
training_arguments_kwargs[
"sample_packing_efficiency"
] = cfg.sample_packing_eff_est
if cfg.eval_steps:
training_arguments_kwargs["evaluation_strategy"] = "steps"
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
elif cfg.evaluation_strategy:
training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
elif cfg.val_set_size == 0:
# no eval set, so don't eval
training_arguments_kwargs["evaluation_strategy"] = "no"
else:
# we have an eval set, but no steps defined, default to use epoch
training_arguments_kwargs["evaluation_strategy"] = "epoch"
if cfg.save_steps:
training_arguments_kwargs["save_strategy"] = "steps"
training_arguments_kwargs["save_steps"] = cfg.save_steps
elif cfg.save_strategy:
training_arguments_kwargs["save_strategy"] = cfg.save_strategy
else:
# default to saving each epoch if not defined
training_arguments_kwargs["save_strategy"] = "epoch"
if cfg.do_bench_eval:
training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
if cfg.bench_dataset:
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
if cfg.metric_for_best_model:
training_arguments_kwargs["metric_for_best_model"] = cfg.metric_for_best_model
if cfg.greater_is_better:
training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
if cfg.torch_compile:
if torch.__version__ < "2.1.0": # pylint: disable=protected-access
LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
else:
import torch._dynamo # pylint: disable=redefined-outer-name
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
True
)
training_arguments_kwargs["torch_compile"] = cfg.torch_compile
if cfg.torch_compile_backend:
training_arguments_kwargs[
"torch_compile_backend"
] = cfg.torch_compile_backend
# DDP Config
if cfg.ddp_timeout:
training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
if cfg.ddp_bucket_cap_mb:
training_arguments_kwargs["ddp_bucket_cap_mb"] = cfg.ddp_bucket_cap_mb
if cfg.ddp_broadcast_buffers is not None:
training_arguments_kwargs["ddp_broadcast_buffers"] = cfg.ddp_broadcast_buffers
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
max_steps=total_num_steps if cfg.max_steps else -1,
max_seq_length=cfg.sequence_len,
per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size,
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
eval_accumulation_steps=cfg.gradient_accumulation_steps,
num_train_epochs=cfg.num_epochs,
learning_rate=cfg.learning_rate,
output_dir=cfg.output_dir,
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
load_best_model_at_end=(
(cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
and cfg.val_set_size > 0
and cfg.save_steps
and cfg.eval_steps
and cfg.save_steps % cfg.eval_steps == 0
)
or False,
ddp_find_unused_parameters=False if cfg.ddp else None,
group_by_length=cfg.group_by_length,
report_to="wandb" if cfg.use_wandb else None,
run_name=cfg.wandb_run_id if cfg.use_wandb else None,
optim=cfg.optimizer if cfg.optimizer else "adamw_hf",
lr_scheduler_type=cfg.lr_scheduler
if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
else "cosine",
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
eval_sample_packing=cfg.eval_sample_packing,
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
relora_steps=cfg.relora_steps,
relora_warmup_steps=cfg.relora_warmup_steps,
**training_arguments_kwargs,
)
trainer_kwargs = {}
if cfg.optimizer == "adamw_anyprecision":
if Path(cfg.torchdistx_path).exists():
sys.path.append(cfg.torchdistx_path)
importlib.import_module("torchdistx")
callbacks = []
callbacks.append(GPUStatsCallback(cfg))
callbacks.append(EvalFirstStepCallback)
if cfg.relora_steps:
callbacks.append(ReLoRACallback(cfg))
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
callbacks.append(SaveBetterTransformerModelCallback)
data_collator_kwargs = {
"padding": True, # True/"longest" is the default
}
if cfg.pad_to_sequence_len:
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
cfg.sequence_len / 64
)
else:
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = 64
if cfg.is_llama_derived_model and cfg.landmark_attention:
from axolotl.monkeypatch.llama_landmark_attn import (
add_mem_tokens,
get_mem_id,
set_model_mem_id,
)
set_model_mem_id(model, tokenizer)
LOG.info("Adding landmark attention tokens to dataset")
for dataset in [train_dataset, eval_dataset]:
dataset = dataset.map(
partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)),
batched=False,
num_proc=32,
)
trainer_cls = AxolotlTrainer
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora"):
trainer_cls = OneCycleLRSchedulerTrainer
elif cfg.relora_steps:
trainer_cls = ReLoRATrainer
trainer = trainer_cls(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,
args=training_args,
data_collator=DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
bench_data_collator=transformers.DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
**data_collator_kwargs,
),
callbacks=callbacks,
**trainer_kwargs,
)
if cfg.use_wandb and cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
trainer.add_callback(LogPredictionCallback(cfg))
if cfg.do_bench_eval:
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
# TODO on_save callback to sync checkpoints to GCP/AWS in background
if cfg.early_stopping_patience:
early_stop_cb = EarlyStoppingCallback(
cfg.early_stopping_patience,
)
trainer.add_callback(early_stop_cb)
return trainer
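The rewritten setup_trainer starts by computing default warmup and logging steps: warmup defaults to 3% of total steps capped at 100, logging to 0.5% capped at 10 with a floor of 1. A worked sketch of those defaults with illustrative step counts:

```python
# Worked sketch of the defaults computed at the top of setup_trainer above.
def default_warmup_steps(total_num_steps: int, warmup_steps=None) -> int:
    return (
        warmup_steps
        if warmup_steps is not None
        else min(int(0.03 * total_num_steps), 100)
    )


def default_logging_steps(total_num_steps: int, logging_steps=None) -> int:
    return (
        logging_steps
        if logging_steps is not None
        else max(min(int(0.005 * total_num_steps), 10), 1)
    )


assert default_warmup_steps(10_000) == 100   # 3% = 300, capped at 100
assert default_warmup_steps(2_000) == 60     # 3% of 2000
assert default_logging_steps(2_000) == 10    # 0.5% = 10
assert default_logging_steps(100) == 1       # floor of 1
```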

View File

@@ -1,72 +0,0 @@
"""
E2E tests for lora llama
"""
import logging
import os
import tempfile
import unittest
from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestFusedLlama(unittest.TestCase):
"""
Test case for Llama models using Fused layers
"""
def test_fft_packing(self):
# pylint: disable=duplicate-code
output_dir = tempfile.mkdtemp()
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"flash_attention": True,
"flash_attn_fuse_qkv": True,
"flash_attn_fuse_mlp": True,
"sample_packing": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": output_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(output_dir) / "pytorch_model.bin").exists()

View File

@@ -29,6 +29,7 @@ class TestLoraLlama(unittest.TestCase):
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "JackFram/llama-68m", "base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer", "tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024, "sequence_len": 1024,
"load_in_8bit": True, "load_in_8bit": True,
@@ -71,6 +72,7 @@ class TestLoraLlama(unittest.TestCase):
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "JackFram/llama-68m", "base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer", "tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024, "sequence_len": 1024,
"sample_packing": True, "sample_packing": True,
@@ -115,6 +117,7 @@ class TestLoraLlama(unittest.TestCase):
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ", "base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
"base_model_config": "TheBlokeAI/jackfram_llama-68m-GPTQ",
"model_type": "AutoModelForCausalLM", "model_type": "AutoModelForCausalLM",
"tokenizer_type": "LlamaTokenizer", "tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024, "sequence_len": 1024,

View File

@@ -31,6 +31,7 @@ class TestMistral(unittest.TestCase):
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "openaccess-ai-collective/tiny-mistral", "base_model": "openaccess-ai-collective/tiny-mistral",
"base_model_config": "openaccess-ai-collective/tiny-mistral",
"flash_attention": True, "flash_attention": True,
"sequence_len": 1024, "sequence_len": 1024,
"load_in_8bit": True, "load_in_8bit": True,
@@ -76,6 +77,7 @@ class TestMistral(unittest.TestCase):
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "openaccess-ai-collective/tiny-mistral", "base_model": "openaccess-ai-collective/tiny-mistral",
"base_model_config": "openaccess-ai-collective/tiny-mistral",
"flash_attention": True, "flash_attention": True,
"sequence_len": 1024, "sequence_len": 1024,
"val_set_size": 0.1, "val_set_size": 0.1,

View File

@@ -31,6 +31,7 @@ class TestMistral(unittest.TestCase):
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "openaccess-ai-collective/tiny-mistral", "base_model": "openaccess-ai-collective/tiny-mistral",
"base_model_config": "openaccess-ai-collective/tiny-mistral",
"flash_attention": True, "flash_attention": True,
"sample_packing": True, "sample_packing": True,
"sequence_len": 1024, "sequence_len": 1024,
@@ -77,6 +78,7 @@ class TestMistral(unittest.TestCase):
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "openaccess-ai-collective/tiny-mistral", "base_model": "openaccess-ai-collective/tiny-mistral",
"base_model_config": "openaccess-ai-collective/tiny-mistral",
"flash_attention": True, "flash_attention": True,
"sample_packing": True, "sample_packing": True,
"sequence_len": 1024, "sequence_len": 1024,

View File

@@ -27,6 +27,7 @@ class TestPhi(unittest.TestCase):
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "microsoft/phi-1_5", "base_model": "microsoft/phi-1_5",
"base_model_config": "microsoft/phi-1_5",
"trust_remote_code": True, "trust_remote_code": True,
"model_type": "MixFormerSequentialForCausalLM", "model_type": "MixFormerSequentialForCausalLM",
"tokenizer_type": "AutoTokenizer", "tokenizer_type": "AutoTokenizer",
@@ -70,6 +71,7 @@ class TestPhi(unittest.TestCase):
cfg = DictDefault( cfg = DictDefault(
{ {
"base_model": "microsoft/phi-1_5", "base_model": "microsoft/phi-1_5",
"base_model_config": "microsoft/phi-1_5",
"trust_remote_code": True, "trust_remote_code": True,
"model_type": "MixFormerSequentialForCausalLM", "model_type": "MixFormerSequentialForCausalLM",
"tokenizer_type": "AutoTokenizer", "tokenizer_type": "AutoTokenizer",

File diff suppressed because one or more lines are too long

View File

@@ -1,46 +0,0 @@
"""
Test classes for checking functionality of the cfg normalization
"""
import unittest
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
class NormalizeConfigTestCase(unittest.TestCase):
"""
test class for normalize_config checks
"""
def _get_base_cfg(self):
return DictDefault(
{
"base_model": "JackFram/llama-68m",
"base_model_config": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"num_epochs": 1,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
}
)
def test_lr_as_float(self):
cfg = (
self._get_base_cfg()
| DictDefault( # pylint: disable=unsupported-binary-operation
{
"learning_rate": "5e-5",
}
)
)
normalize_config(cfg)
assert cfg.learning_rate == 0.00005
def test_base_model_config_set_when_empty(self):
cfg = self._get_base_cfg()
del cfg.base_model_config
normalize_config(cfg)
assert cfg.base_model_config == cfg.base_model

View File

@@ -90,73 +90,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
strat.tokenize_prompt(conversation) strat.tokenize_prompt(conversation)
assert "assistant turn has empty text" in self._caplog.records[1].message assert "assistant turn has empty text" in self._caplog.records[1].message
def test_sharegpt_warnings_turns(self):
conversation = {
"conversations": [
{"from": "system", "value": "lorem"},
{"from": "gpt", "value": "ipsum"},
{"from": "human", "value": "dolor"},
{"from": "human", "value": "dolor"},
{"from": "gpt", "value": "sit"},
]
}
prompter = ShareGPTPrompterV2()
strat = ShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
with self._caplog.at_level(logging.WARNING):
strat.tokenize_prompt(conversation)
assert (
"Role did not alternate between turns (gpt and human)"
in self._caplog.records[0].message
)
def test_sharegpt_changes_roles(self):
conversation = {
"roles": ["USER", "CHARACTER"],
"conversations": [
{"from": "system", "value": "lorem"},
{"from": "gpt", "value": "ipsum"},
{"from": "human", "value": "dolor"},
{"from": "gpt", "value": "sit"},
],
}
prompter = ShareGPTPrompterV2()
strat = ShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
with self._caplog.at_level(logging.WARNING):
res = strat.tokenize_prompt(conversation)
assert "CHARACTER" in self.tokenizer.decode(res["input_ids"])
def test_sharegpt_assistant_label_ignore(self):
conversation = {
"roles": ["user", "assistant"],
"conversations": [
{"from": "system", "value": "lorem"},
{"from": "gpt", "value": "ipsum"},
{"from": "human", "value": "dolor"},
{"from": "gpt", "value": "sit"},
],
}
prompter = ShareGPTPrompterV2()
strat = ShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
with self._caplog.at_level(logging.WARNING):
res = strat.tokenize_prompt(conversation)
idx = res["input_ids"].index(20255) # assistant token
assert res["labels"][idx] == -100
def test_no_sys_prompt(self): def test_no_sys_prompt(self):
""" """
tests the interface between the user and assistant parts tests the interface between the user and assistant parts
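The removed ShareGPT tests relied on the usual label-masking convention: tokens that should not contribute to the loss carry the label -100, which torch's cross-entropy ignores by default. A minimal sketch of that convention; the token ids and mask layout are illustrative:

```python
# Sketch of -100 label masking: only assistant tokens keep their labels.
from typing import List

IGNORE_TOKEN_ID = -100


def mask_non_assistant(input_ids: List[int], assistant_mask: List[bool]) -> List[int]:
    """Keep labels only where assistant_mask is True; mask everything else."""
    return [
        tok if is_assistant else IGNORE_TOKEN_ID
        for tok, is_assistant in zip(input_ids, assistant_mask)
    ]


labels = mask_non_assistant([1, 15043, 3186, 2], [False, False, True, True])
assert labels == [-100, -100, 3186, 2]
```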

View File

@@ -565,87 +565,3 @@ class ValidationTest(unittest.TestCase):
) )
validate_config(cfg) validate_config(cfg)
def test_eval_table_size_conflict_eval_packing(self):
cfg = DictDefault(
{
"sample_packing": True,
"eval_table_size": 100,
}
)
with pytest.raises(
ValueError, match=r".*Please set 'eval_sample_packing' to false.*"
):
validate_config(cfg)
cfg = DictDefault(
{
"sample_packing": True,
"eval_sample_packing": False,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"sample_packing": False,
"eval_table_size": 100,
}
)
validate_config(cfg)
cfg = DictDefault(
{
"sample_packing": True,
"eval_table_size": 100,
"eval_sample_packing": False,
}
)
validate_config(cfg)
def test_load_in_x_bit_without_adapter(self):
cfg = DictDefault(
{
"load_in_4bit": True,
}
)
with pytest.raises(
ValueError,
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
):
validate_config(cfg)
cfg = DictDefault(
{
"load_in_8bit": True,
}
)
with pytest.raises(
ValueError,
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
):
validate_config(cfg)
cfg = DictDefault(
{
"load_in_4bit": True,
"adapter": "qlora",
}
)
validate_config(cfg)
cfg = DictDefault(
{
"load_in_8bit": True,
"adapter": "lora",
}
)
validate_config(cfg)
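The removed validation tests documented that 8-bit/4-bit loading is only allowed together with an adapter. A hedged sketch of such a check, mirroring the error message asserted above; it is not the project's actual validate_config:

```python
# Sketch of the rule the removed tests exercised: quantized loading needs an adapter.
def check_quantization_requires_adapter(cfg: dict) -> None:
    if (cfg.get("load_in_8bit") or cfg.get("load_in_4bit")) and not cfg.get("adapter"):
        raise ValueError(
            "load_in_8bit and load_in_4bit are not supported without setting an adapter."
        )


check_quantization_requires_adapter({"load_in_4bit": True, "adapter": "qlora"})  # passes
try:
    check_quantization_requires_adapter({"load_in_8bit": True})
except ValueError as err:
    assert "not supported without setting an adapter" in str(err)
```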