Compare commits
1 Commits
ia3-peft
...
sharegpt-b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b4d84d56d5 |
5
.github/workflows/base.yml
vendored
5
.github/workflows/base.yml
vendored
@@ -25,11 +25,6 @@ jobs:
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
- cuda: "118"
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
11
.github/workflows/main.yml
vendored
11
.github/workflows/main.yml
vendored
@@ -23,11 +23,6 @@ jobs:
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.1
|
||||
axolotl_extras:
|
||||
- cuda: 118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.0
|
||||
axolotl_extras:
|
||||
runs-on: [self-hosted, gpu, docker]
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -51,7 +46,6 @@ jobs:
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
file: ./docker/Dockerfile
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
@@ -74,11 +68,6 @@ jobs:
|
||||
pytorch: 2.0.1
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
- cuda: 118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.0
|
||||
axolotl_extras:
|
||||
runs-on: [self-hosted, gpu, docker]
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
@@ -12,4 +12,4 @@ generated-members=numpy.*, torch.*
|
||||
disable=missing-function-docstring, line-too-long, import-error,
|
||||
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
|
||||
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
|
||||
too-many-boolean-expressions,
|
||||
too-many-nested-blocks,
|
||||
|
||||
316
README.md
316
README.md
@@ -23,10 +23,9 @@ Features:
|
||||
- [Supported Features](#axolotl-supports)
|
||||
- [Quickstart](#quickstart-)
|
||||
- [Installation](#installation)
|
||||
- [Docker](#docker)
|
||||
- [Conda/Pip venv](#condapip-venv)
|
||||
- [LambdaLabs](#lambdalabs)
|
||||
- [Windows](#windows)
|
||||
- [Docker Installation](#environment)
|
||||
- [Conda/Pip venv Installation](#condapip-venv)
|
||||
- [LambdaLabs Installation](#lambdalabs)
|
||||
- [Dataset](#dataset)
|
||||
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
|
||||
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
|
||||
@@ -51,7 +50,7 @@ Features:
|
||||
<b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b>
|
||||
</p>
|
||||
<p>
|
||||
Go ahead and Axolotl questions!!
|
||||
Go ahead and axolotl questions!!
|
||||
</p>
|
||||
<img src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
|
||||
<img alt="PyTest Status" src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
|
||||
@@ -96,14 +95,14 @@ accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
||||
|
||||
# inference
|
||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
--peft_model_dir="./lora-out"
|
||||
--lora_model_dir="./lora-out"
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
### Environment
|
||||
|
||||
#### Docker
|
||||
- Docker
|
||||
```bash
|
||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||
```
|
||||
@@ -115,12 +114,12 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
#### Conda/Pip venv
|
||||
- Conda/Pip venv
|
||||
1. Install python >=**3.9**
|
||||
|
||||
2. Install pytorch stable https://pytorch.org/get-started/locally/
|
||||
|
||||
3. Install Axolotl along with python dependencies
|
||||
3. Install axolotl along with python dependencies
|
||||
```bash
|
||||
pip3 install packaging
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
@@ -131,7 +130,7 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
```
|
||||
Get the token at huggingface.co/settings/tokens
|
||||
|
||||
#### LambdaLabs
|
||||
- LambdaLabs
|
||||
<details>
|
||||
|
||||
<summary>Click to Expand</summary>
|
||||
@@ -175,8 +174,7 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
```
|
||||
</details>
|
||||
|
||||
#### Windows
|
||||
Please use WSL or Docker!
|
||||
- Windows: Please use WSL or Docker!
|
||||
|
||||
### Dataset
|
||||
|
||||
@@ -297,24 +295,25 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
|
||||
#### How to add custom prompts
|
||||
|
||||
For a dataset that is preprocessed for instruction purposes:
|
||||
|
||||
```json
|
||||
{"instruction": "...", "output": "..."}
|
||||
```
|
||||
|
||||
You can use this example in your YAML config:
|
||||
|
||||
Using yaml. Example:
|
||||
```yaml
|
||||
datasets:
|
||||
- path: repo
|
||||
type:
|
||||
system_prompt: ""
|
||||
field_system: system
|
||||
format: "[INST] {instruction} [/INST]"
|
||||
no_input_format: "[INST] {instruction} [/INST]"
|
||||
no_input_format: |-
|
||||
User: {instruction}<|end_of_turn|>
|
||||
Assistant:
|
||||
format: |-
|
||||
User: {instruction}
|
||||
{input}<|end_of_turn|>
|
||||
Assistant:
|
||||
```
|
||||
|
||||
Using file:
|
||||
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
|
||||
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
|
||||
|
||||
#### How to use your custom pretokenized dataset
|
||||
|
||||
- Do not pass a `type:`
|
||||
@@ -384,10 +383,10 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
- lora
|
||||
```yaml
|
||||
adapter: lora # qlora or leave blank for full finetune
|
||||
peft_r: 8
|
||||
peft_alpha: 16
|
||||
peft_dropout: 0.05
|
||||
peft_target_modules:
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
```
|
||||
@@ -397,15 +396,15 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
<summary>All yaml options</summary>
|
||||
|
||||
```yaml
|
||||
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
||||
# This can also be a relative path to a model on disk
|
||||
# this is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
||||
# this can also be a relative path to a model on disk
|
||||
base_model: ./llama-7b-hf
|
||||
# You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
|
||||
# you can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
|
||||
base_model_ignore_patterns:
|
||||
# If the base_model repo on hf hub doesn't include configuration .json files,
|
||||
# You can set that here, or leave this empty to default to base_model
|
||||
# if the base_model repo on hf hub doesn't include configuration .json files,
|
||||
# you can set that here, or leave this empty to default to base_model
|
||||
base_model_config: ./llama-7b-hf
|
||||
# You can specify to choose a specific model revision from huggingface hub
|
||||
# you can specify to choose a specific model revision from huggingface hub
|
||||
model_revision:
|
||||
# Optional tokenizer configuration override in case you want to use a different tokenizer
|
||||
# than the one defined in the base model
|
||||
@@ -420,24 +419,23 @@ trust_remote_code:
|
||||
tokenizer_use_fast:
|
||||
# Whether to use the legacy tokenizer setting, defaults to True
|
||||
tokenizer_legacy:
|
||||
# Resize the model embeddings when new tokens are added to multiples of 32
|
||||
# This is reported to improve training speed on some models
|
||||
# resize the model embeddings when new tokens are added to multiples of 32
|
||||
# this is reported to improve training speed on some models
|
||||
resize_token_embeddings_to_32x:
|
||||
|
||||
# Used to identify which the model is based on
|
||||
# used to identify which the model is based on
|
||||
is_falcon_derived_model:
|
||||
is_llama_derived_model:
|
||||
# Please note that if you set this to true, `padding_side` will be set to "left" by default
|
||||
is_mistral_derived_model:
|
||||
|
||||
# Whether you are training a 4-bit GPTQ quantized model
|
||||
# whether you are training a 4-bit GPTQ quantized model
|
||||
gptq: true
|
||||
gptq_groupsize: 128 # group size
|
||||
gptq_model_v1: false # v1 or v2
|
||||
|
||||
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
||||
# this will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
||||
load_in_8bit: true
|
||||
# Use bitsandbytes 4 bit
|
||||
# use bitsandbytes 4 bit
|
||||
load_in_4bit:
|
||||
|
||||
# Use CUDA bf16
|
||||
@@ -451,9 +449,9 @@ tf32: true # require >=ampere
|
||||
bfloat16: true # require >=ampere
|
||||
float16: true
|
||||
|
||||
# A list of one or more datasets to finetune the model with
|
||||
# a list of one or more datasets to finetune the model with
|
||||
datasets:
|
||||
# HuggingFace dataset repo | "json" for local dataset, make sure to fill data_files
|
||||
# hf dataset repo | "json" for local dataset, make sure to fill data_files
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
||||
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
||||
@@ -463,18 +461,17 @@ datasets:
|
||||
name: # Optional[str] name of dataset configuration to load
|
||||
conversation: # Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||
|
||||
# Custom user prompt
|
||||
# custom user prompt
|
||||
- path: repo
|
||||
type:
|
||||
# The below are defaults. only set what's needed.
|
||||
# the below are defaults. only set what's needed.
|
||||
system_prompt: ""
|
||||
system_format: "{system}"
|
||||
field_system: system
|
||||
field_instruction: instruction
|
||||
field_input: input
|
||||
field_output: output
|
||||
field_output: input
|
||||
|
||||
# Customizable to be single line or multi-line
|
||||
# customizable to be single line or multi-line
|
||||
system_format: "{system}"
|
||||
# 'format' can include {input}
|
||||
format: |-
|
||||
User: {instruction} {input}
|
||||
@@ -482,13 +479,13 @@ datasets:
|
||||
# 'no_input_format' cannot include {input}
|
||||
no_input_format: "{instruction} "
|
||||
|
||||
# For `completion` datsets only, uses the provided field instead of `text` column
|
||||
# for completions datsets, uses the provided field if not `text`
|
||||
field:
|
||||
|
||||
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
# axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
# subsequent training attempts load faster, relative path
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
# Push prepared dataset to hub
|
||||
# push prepared dataset to hub
|
||||
push_dataset_to_hub: # repo path
|
||||
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
|
||||
# if not set.
|
||||
@@ -498,8 +495,8 @@ hub_model_id: # repo path to push finetuned model
|
||||
# how to push checkpoints to hub
|
||||
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
|
||||
hub_strategy:
|
||||
# Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
|
||||
# Required to be true when used in combination with `push_dataset_to_hub`
|
||||
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
|
||||
# required to be true when used in combination with `push_dataset_to_hub`
|
||||
hf_use_auth_token: # boolean
|
||||
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
|
||||
val_set_size: 0.04
|
||||
@@ -508,38 +505,34 @@ dataset_shard_num:
|
||||
# Index of shard to use for whole dataset
|
||||
dataset_shard_idx:
|
||||
|
||||
# The maximum length of an input to train with, this should typically be less than 2048
|
||||
# the maximum length of an input to train with, this should typically be less than 2048
|
||||
# as most models have a token/context limit of 2048
|
||||
sequence_len: 2048
|
||||
# Pad inputs so each step uses constant sized buffers
|
||||
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
|
||||
# pad inputs so each step uses constant sized buffers
|
||||
# this will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
|
||||
pad_to_sequence_len:
|
||||
# Max sequence length to concatenate training samples together up to
|
||||
# Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
||||
# max sequence length to concatenate training samples together up to
|
||||
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
||||
# FutureWarning: This will soon be DEPRECATED
|
||||
max_packed_sequence_len: 1024
|
||||
# Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
|
||||
# use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
|
||||
sample_packing:
|
||||
# Set to 'false' if getting errors during eval with sample_packing on.
|
||||
# set to 'false' if getting errors during eval with sample_packing on.
|
||||
eval_sample_packing:
|
||||
# You can set these packing optimizations AFTER starting a training at least once.
|
||||
# you can set these packing optimizations AFTER starting a training at least once.
|
||||
# The trainer will provide recommended values for these values.
|
||||
sample_packing_eff_est:
|
||||
total_num_tokens:
|
||||
|
||||
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||
# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||
adapter: lora
|
||||
# If you already have a lora model trained that you want to load, put that here.
|
||||
# This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
|
||||
peft_model_dir:
|
||||
|
||||
# LoRA hyperparameters
|
||||
# For more details about the following options, see:
|
||||
# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
|
||||
peft_r: 8
|
||||
peft_alpha: 16
|
||||
peft_dropout: 0.05
|
||||
peft_target_modules:
|
||||
# if you already have a lora model trained that you want to load, put that here
|
||||
# lora hyperparameters
|
||||
lora_model_dir:
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_modules:
|
||||
- q_proj
|
||||
- v_proj
|
||||
# - k_proj
|
||||
@@ -547,49 +540,36 @@ peft_target_modules:
|
||||
# - gate_proj
|
||||
# - down_proj
|
||||
# - up_proj
|
||||
peft_target_linear: # if true, will target all linear layers
|
||||
|
||||
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
|
||||
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
|
||||
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
|
||||
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
|
||||
peft_modules_to_save:
|
||||
lora_target_linear: # if true, will target all linear layers
|
||||
lora_modules_to_save:
|
||||
# - embed_tokens
|
||||
# - lm_head
|
||||
|
||||
# Once you complete training, the model will be saved to the following directory.
|
||||
# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
|
||||
# Make sure `lora_model_dir` points to this directory if you want to use the trained model.
|
||||
lora_out_dir:
|
||||
peft_fan_in_fan_out: false
|
||||
peft_feedforward_modules: # ffn modules for IA3, for llama down projection
|
||||
lora_fan_in_fan_out: false
|
||||
|
||||
# ReLoRA configuration
|
||||
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
|
||||
relora_steps: # Number of steps per ReLoRA restart
|
||||
relora_warmup_steps: # Number of per-restart warmup steps
|
||||
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
||||
# must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
|
||||
relora_steps: # number of steps per ReLoRA restart
|
||||
relora_warmup_steps: # number of per-restart warmup steps
|
||||
relora_cpu_offload: # true to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
||||
|
||||
# wandb configuration if you're using it
|
||||
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
||||
wandb_project: # Your wandb project name
|
||||
wandb_entity: # A wandb Team name if using a Team
|
||||
wandb_project: # your wandb project name
|
||||
wandb_entity: # a wandb Team name if using a Team
|
||||
wandb_watch:
|
||||
wandb_run_id: # Set the name of your wandb run
|
||||
wandb_run_id: # set the name of your wandb run
|
||||
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
||||
|
||||
# Where to save the full-finetuned model to
|
||||
# where to save the finished model to
|
||||
output_dir: ./completed-model
|
||||
|
||||
# Whether to use torch.compile and which backend to use
|
||||
# whether to use torch.compile and which backend to use
|
||||
torch_compile: # bool
|
||||
torch_compile_backend: # Optional[str]
|
||||
|
||||
# Training hyperparameters
|
||||
|
||||
# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
|
||||
# training hyperparameters
|
||||
gradient_accumulation_steps: 1
|
||||
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
|
||||
micro_batch_size: 2
|
||||
eval_batch_size:
|
||||
num_epochs: 3
|
||||
@@ -597,47 +577,44 @@ warmup_steps: 100
|
||||
learning_rate: 0.00003
|
||||
lr_quadratic_warmup:
|
||||
logging_steps:
|
||||
save_strategy: # Set to `no` to skip checkpoint saves
|
||||
save_steps: # Leave empty to save at each epoch
|
||||
eval_steps: # Leave empty to eval at each epoch
|
||||
save_total_limit: # Checkpoints saved at a time
|
||||
# Maximum number of iterations to train for. It precedes num_epochs which means that
|
||||
# if both are set, num_epochs will not be guaranteed.
|
||||
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
|
||||
save_strategy: # set to `no` to skip checkpoint saves
|
||||
save_steps: # leave empty to save at each epoch
|
||||
eval_steps: # leave empty to eval at each epoch
|
||||
save_total_limit: # checkpoints saved at a time
|
||||
max_steps:
|
||||
|
||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||
eval_table_size: # approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||
eval_table_max_new_tokens: # total number of tokens generated for predictions sent to wandb. Default is 128
|
||||
|
||||
# Save model as safetensors (require safetensors package)
|
||||
# save model as safetensors (require safetensors package)
|
||||
save_safetensors:
|
||||
|
||||
# Whether to mask out or include the human's prompt from the training labels
|
||||
# whether to mask out or include the human's prompt from the training labels
|
||||
train_on_inputs: false
|
||||
# Group similarly sized data to minimize padding.
|
||||
# May be slower to start, as it must download and sort the entire dataset.
|
||||
# Note that training loss may have an oscillating pattern with this enabled.
|
||||
# group similarly sized data to minimize padding
|
||||
# may be slower to start, as it must download and sort the entire dataset
|
||||
# note that training loss may have an oscillating pattern with this enabled
|
||||
group_by_length: false
|
||||
|
||||
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||
gradient_checkpointing: false
|
||||
|
||||
# Stop training after this many evaluation losses have increased in a row
|
||||
# stop training after this many evaluation losses have increased in a row
|
||||
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
||||
early_stopping_patience: 3
|
||||
|
||||
# Specify a scheduler and kwargs to use with the optimizer
|
||||
# specify a scheduler and kwargs to use with the optimizer
|
||||
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
|
||||
lr_scheduler_kwargs:
|
||||
|
||||
# For one_cycle optim
|
||||
lr_div_factor: # Learning rate div factor
|
||||
# for one_cycle optim
|
||||
lr_div_factor: # learning rate div factor
|
||||
|
||||
# For log_sweep optim
|
||||
# for log_sweep optim
|
||||
log_sweep_min_lr:
|
||||
log_sweep_max_lr:
|
||||
|
||||
# Specify optimizer
|
||||
# specify optimizer
|
||||
# Valid values are driven by the Transformers OptimizerNames class, see:
|
||||
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
|
||||
#
|
||||
@@ -663,7 +640,7 @@ log_sweep_max_lr:
|
||||
# - paged_lion_32bit
|
||||
# - paged_lion_8bit
|
||||
optimizer:
|
||||
# Specify weight decay
|
||||
# specify weight decay
|
||||
weight_decay:
|
||||
# adamw hyperparams
|
||||
adam_beta1:
|
||||
@@ -672,56 +649,49 @@ adam_epsilon:
|
||||
# Gradient clipping max norm
|
||||
max_grad_norm:
|
||||
|
||||
# Augmentation techniques
|
||||
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
|
||||
# currently only supported on Llama and Mistral
|
||||
noisy_embedding_alpha:
|
||||
|
||||
# Whether to bettertransformers
|
||||
# whether to bettertransformers
|
||||
flash_optimum:
|
||||
# Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||
# whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||
xformers_attention:
|
||||
# Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
||||
# whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
||||
flash_attention:
|
||||
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
||||
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
||||
# Whether to use scaled-dot-product attention
|
||||
# whether to use scaled-dot-product attention
|
||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
sdp_attention:
|
||||
# Landmark attention (only llama)
|
||||
landmark_attention:
|
||||
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
||||
# LLaMA only
|
||||
# llama only
|
||||
xpos_rope:
|
||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||
rope_scaling:
|
||||
type: # linear | dynamic
|
||||
factor: # float
|
||||
|
||||
# Resume from a specific checkpoint dir
|
||||
# resume from a specific checkpoint dir
|
||||
resume_from_checkpoint:
|
||||
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
|
||||
# Be careful with this being turned on between different models.
|
||||
# if resume_from_checkpoint isn't set and you simply want it to start where it left off
|
||||
# be careful with this being turned on between different models
|
||||
auto_resume_from_checkpoints: false
|
||||
|
||||
# Don't mess with this, it's here for accelerate and torchrun
|
||||
# don't mess with this, it's here for accelerate and torchrun
|
||||
local_rank:
|
||||
|
||||
# Add or change special tokens.
|
||||
# If you add tokens here, you don't need to add them to the `tokens` list.
|
||||
# add or change special tokens
|
||||
special_tokens:
|
||||
# bos_token: "<s>"
|
||||
# eos_token: "</s>"
|
||||
# unk_token: "<unk>"
|
||||
|
||||
# Add extra tokens.
|
||||
# add extra tokens
|
||||
tokens:
|
||||
|
||||
# FSDP
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
|
||||
# Deepspeed config path. e.g., deepspeed/zero3.json
|
||||
# Deepspeed config path
|
||||
deepspeed:
|
||||
|
||||
# Advanced DDP Arguments
|
||||
@@ -747,66 +717,6 @@ strict:
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary> Understanding of batch size and gradient accumulation steps </summary>
|
||||
<br/>
|
||||
Gradient accumulation means accumulating gradients over several mini-batches and updating the model weights afterward. When the samples in each batch are diverse, this technique doesn't significantly impact learning.
|
||||
|
||||
This method allows for effective training with larger effective batch sizes without needing proportionally larger memory. Here's why:
|
||||
|
||||
1. **Memory Consumption with Batch Size**: The primary reason increasing the batch size impacts memory is due to the storage requirements for intermediate activations. When you forward propagate a batch through a network, you have to store the activations at each layer for each sample in the batch, because these activations are used during backpropagation to compute gradients. Therefore, larger batches mean more activations, leading to greater GPU memory consumption.
|
||||
|
||||
2. **Gradient Accumulation**: With gradient accumulation, you're effectively simulating a larger batch size by accumulating gradients over several smaller batches (or micro-batches). However, at any given time, you're only forward and backward propagating a micro-batch. This means you only store activations for the micro-batch, not the full accumulated batch. As a result, you can simulate the effect of a larger batch size without the memory cost of storing activations for a large batch.
|
||||
|
||||
**Example 1:**
|
||||
Micro batch size: 3
|
||||
Gradient accumulation steps: 2
|
||||
Number of GPUs: 3
|
||||
Total batch size = 3 * 2 * 3 = 18
|
||||
|
||||
```
|
||||
| GPU 1 | GPU 2 | GPU 3 |
|
||||
|----------------|----------------|----------------|
|
||||
| S1, S2, S3 | S4, S5, S6 | S7, S8, S9 |
|
||||
| e1, e2, e3 | e4, e5, e6 | e7, e8, e9 |
|
||||
|----------------|----------------|----------------|
|
||||
| → (accumulate) | → (accumulate) | → (accumulate) |
|
||||
|----------------|----------------|----------------|
|
||||
| S10, S11, S12 | S13, S14, S15 | S16, S17, S18 |
|
||||
| e10, e11, e12 | e13, e14, e15 | e16, e17, e18 |
|
||||
|----------------|----------------|----------------|
|
||||
| → (apply) | → (apply) | → (apply) |
|
||||
|
||||
Accumulated gradient for the weight w1 after the second iteration (considering all GPUs):
|
||||
Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6 + e7 + e8 + e9 + e10 + e11 + e12 + e13 + e14 + e15 + e16 + e17 + e18
|
||||
|
||||
Weight update for w1:
|
||||
w1_new = w1_old - learning rate x (Total gradient for w1 / 18)
|
||||
```
|
||||
|
||||
**Example 2:**
|
||||
Micro batch size: 2
|
||||
Gradient accumulation steps: 1
|
||||
Number of GPUs: 3
|
||||
Total batch size = 2 * 1 * 3 = 6
|
||||
|
||||
```
|
||||
| GPU 1 | GPU 2 | GPU 3 |
|
||||
|-----------|-----------|-----------|
|
||||
| S1, S2 | S3, S4 | S5, S6 |
|
||||
| e1, e2 | e3, e4 | e5, e6 |
|
||||
|-----------|-----------|-----------|
|
||||
| → (apply) | → (apply) | → (apply) |
|
||||
|
||||
Accumulated gradient for the weight w1 (considering all GPUs):
|
||||
Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6
|
||||
|
||||
Weight update for w1:
|
||||
w1_new = w1_old - learning rate × (Total gradient for w1 / 6)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### Train
|
||||
|
||||
Run
|
||||
@@ -870,7 +780,7 @@ Pass the appropriate flag to the train command:
|
||||
|
||||
- Pretrained LORA:
|
||||
```bash
|
||||
python -m axolotl.cli.inference examples/your_config.yml --peft_model_dir="./lora-output-dir"
|
||||
python -m axolotl.cli.inference examples/your_config.yml --lora_model_dir="./lora-output-dir"
|
||||
```
|
||||
- Full weights finetune:
|
||||
```bash
|
||||
@@ -882,16 +792,12 @@ Pass the appropriate flag to the train command:
|
||||
--base_model="./completed-model" --prompter=None --load_in_8bit=True
|
||||
```
|
||||
|
||||
Please use `--sample_packing False` if you have it on and receive the error similar to below:
|
||||
|
||||
> RuntimeError: stack expects each tensor to be equal size, but got [1, 32, 1, 128] at entry 0 and [1, 32, 8, 128] at entry 1
|
||||
|
||||
### Merge LORA to base
|
||||
|
||||
Add below flag to train command above
|
||||
|
||||
```bash
|
||||
python3 -m axolotl.cli.merge_lora examples/your_config.yml --peft_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
|
||||
python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
|
||||
```
|
||||
|
||||
If you run out of CUDA memory, you can try to merge in system RAM with
|
||||
|
||||
@@ -5,9 +5,6 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ARG AXOLOTL_EXTRAS=""
|
||||
ARG CUDA="118"
|
||||
ENV BNB_CUDA_VERSION=$CUDA
|
||||
ARG PYTORCH_VERSION="2.0.1"
|
||||
|
||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y vim curl
|
||||
@@ -19,7 +16,6 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
|
||||
else \
|
||||
|
||||
@@ -14,7 +14,7 @@ ARG CUDA="118"
|
||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/*
|
||||
&& wget \
|
||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& mkdir /root/.conda \
|
||||
@@ -57,6 +57,11 @@ FROM base-builder
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
|
||||
# recompile apex
|
||||
RUN python3 -m pip uninstall -y apex
|
||||
RUN git clone https://github.com/NVIDIA/apex
|
||||
RUN cd apex && python3 -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
|
||||
|
||||
RUN mkdir -p /workspace/builds
|
||||
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
|
||||
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
# Multipack
|
||||
|
||||
4k context, bsz =4,
|
||||
each character represents 256 tokens
|
||||
X represents a padding token
|
||||
|
||||
```
|
||||
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
||||
[[ A A A A A A A A A A A ]
|
||||
B B B B B B ]
|
||||
C C C C C C C ]
|
||||
D D D D ]]
|
||||
|
||||
[[ E E E E E E E E ]
|
||||
[ F F F F ]
|
||||
[ G G G ]
|
||||
[ H H H H ]]
|
||||
|
||||
[[ I I I ]
|
||||
[ J J J ]
|
||||
[ K K K K K]
|
||||
[ L L L ]]
|
||||
```
|
||||
|
||||
after padding to longest input in each step
|
||||
```
|
||||
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
||||
[[ A A A A A A A A A A A ]
|
||||
B B B B B B X X X X X X ]
|
||||
C C C C C C C X X X X ]
|
||||
D D D D X X X X X X X ]]
|
||||
|
||||
[[ E E E E E E E E ]
|
||||
[ F F F F X X X X ]
|
||||
[ G G G X X X X X ]
|
||||
[ H H H H X X X X ]]
|
||||
|
||||
[[ I I I X X ]
|
||||
[ J J J X X ]
|
||||
[ K K K K K ]
|
||||
[ L L L X X ]]
|
||||
```
|
||||
|
||||
w packing ( note it's the same effective number of tokens per step, but a true bsz of 1)
|
||||
```
|
||||
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
||||
[[ A A A A A A A A A A A B B B B B
|
||||
B C C C C C C C D D D D E E E E
|
||||
E E E E F F F F F G G G H H H H
|
||||
I I I J J J J K K K K K L L L X ]]
|
||||
```
|
||||
@@ -18,7 +18,7 @@ dataset_prepared_path: last_prepared_run
|
||||
val_set_size: 0.01
|
||||
|
||||
adapter:
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
sample_packing: false
|
||||
|
||||
@@ -10,7 +10,7 @@ datasets:
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
adapter: qlora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 2048
|
||||
lora_r: 16
|
||||
|
||||
@@ -20,7 +20,7 @@ sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
|
||||
@@ -16,7 +16,7 @@ val_set_size: 0.01
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
@@ -20,7 +20,7 @@ sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
|
||||
@@ -16,7 +16,7 @@ val_set_size: 0.01
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
@@ -20,7 +20,7 @@ sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
|
||||
@@ -16,7 +16,7 @@ val_set_size: 0.01
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
@@ -15,7 +15,7 @@ datasets:
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
adapter: lora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
lora_r: 16
|
||||
|
||||
@@ -22,7 +22,7 @@ dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
# enable QLoRA
|
||||
adapter: qlora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ datasets:
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
adapter:
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
lora_r: 64
|
||||
|
||||
@@ -10,7 +10,7 @@ datasets:
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
adapter: qlora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
lora_r: 8
|
||||
|
||||
@@ -9,7 +9,7 @@ datasets:
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 512
|
||||
max_packed_sequence_len:
|
||||
lora_r:
|
||||
|
||||
@@ -18,7 +18,7 @@ datasets:
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
adapter: lora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 4096
|
||||
sample_packing:
|
||||
lora_r: 8
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
base_model: meta-llama/Llama-2-7b-hf
|
||||
base_model_config: meta-llama/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./ia3-out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: ia3
|
||||
peft_model_dir:
|
||||
peft_target_modules:
|
||||
- k_proj
|
||||
- v_proj
|
||||
- down_proj
|
||||
peft_feedforward_modules:
|
||||
- down_proj
|
||||
peft_fan_in_fan_out: false
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 2
|
||||
num_epochs: 5
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens:
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -20,7 +20,7 @@ sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter: lora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
|
||||
@@ -16,7 +16,7 @@ val_set_size: 0.01
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
@@ -16,7 +16,7 @@ val_set_size: 0.01
|
||||
output_dir: ./relora-out
|
||||
|
||||
adapter: qlora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
@@ -20,7 +20,7 @@ sequence_len: 4096
|
||||
sample_packing: true
|
||||
|
||||
adapter: lora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
|
||||
@@ -16,8 +16,8 @@ val_set_size: 0.01
|
||||
output_dir: ./out
|
||||
|
||||
sequence_len: 8192
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
sample_packing:
|
||||
pad_to_sequence_len:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
@@ -30,7 +30,7 @@ micro_batch_size: 2
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.000005
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
|
||||
@@ -19,8 +19,8 @@ adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 8192
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
sample_packing: True
|
||||
pad_to_sequence_len: True
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
@@ -43,7 +43,7 @@ wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
micro_batch_size: 4
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
|
||||
@@ -9,7 +9,7 @@ datasets:
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
lora_r: 8
|
||||
|
||||
@@ -12,7 +12,7 @@ datasets:
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 1024
|
||||
sample_packing: true
|
||||
lora_r:
|
||||
|
||||
@@ -12,7 +12,7 @@ datasets:
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.02
|
||||
adapter: lora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 1024
|
||||
sample_packing: true
|
||||
lora_r: 8
|
||||
|
||||
@@ -12,7 +12,7 @@ datasets:
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
adapter: qlora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 1024
|
||||
sample_packing: true
|
||||
lora_r: 8
|
||||
|
||||
@@ -22,7 +22,7 @@ sample_packing: true
|
||||
pad_to_sequence_len:
|
||||
|
||||
adapter:
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
lora_r:
|
||||
lora_alpha:
|
||||
lora_dropout:
|
||||
|
||||
@@ -22,7 +22,7 @@ sample_packing: false # not CURRENTLY compatible with LoRAs
|
||||
pad_to_sequence_len:
|
||||
|
||||
adapter: qlora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
lora_r: 64
|
||||
lora_alpha: 32
|
||||
lora_dropout: 0.05
|
||||
|
||||
@@ -13,7 +13,7 @@ datasets:
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
adapter:
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len: 2048
|
||||
lora_r: 64
|
||||
|
||||
@@ -7,7 +7,7 @@ datasets:
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 512
|
||||
lora_r: 16
|
||||
lora_alpha: 32
|
||||
@@ -28,8 +28,8 @@ num_epochs: 3
|
||||
learning_rate: 0.00001
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
bf16: True
|
||||
tf32: True
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
|
||||
@@ -10,7 +10,7 @@ datasets:
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
lora_r: 8
|
||||
|
||||
@@ -8,7 +8,7 @@ datasets:
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
max_packed_sequence_len:
|
||||
lora_r: 8
|
||||
|
||||
@@ -20,7 +20,7 @@ dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
# enable QLoRA
|
||||
adapter: qlora
|
||||
peft_model_dir:
|
||||
lora_model_dir:
|
||||
sequence_len: 8192
|
||||
max_packed_sequence_len:
|
||||
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 370 KiB |
@@ -16,7 +16,7 @@ flash-attn>=2.3.0
|
||||
sentencepiece
|
||||
wandb
|
||||
einops
|
||||
xformers>=0.0.22
|
||||
xformers
|
||||
optimum
|
||||
hf_transfer
|
||||
colorama
|
||||
|
||||
10
setup.py
10
setup.py
@@ -21,14 +21,6 @@ def parse_requirements():
|
||||
):
|
||||
# Handle standard packages
|
||||
_install_requires.append(line)
|
||||
|
||||
# TODO(wing) remove once xformers release supports torch 2.1.0
|
||||
if "torch==2.1.0" in _install_requires:
|
||||
_install_requires.pop(_install_requires.index("xformers>=0.0.22"))
|
||||
_install_requires.append(
|
||||
"xformers @ git+https://github.com/facebookresearch/xformers.git@main"
|
||||
)
|
||||
|
||||
return _install_requires, _dependency_links
|
||||
|
||||
|
||||
@@ -46,7 +38,7 @@ setup(
|
||||
dependency_links=dependency_links,
|
||||
extras_require={
|
||||
"flash-attn": [
|
||||
"flash-attn>=2.3.0",
|
||||
"flash-attn>=2.2.1",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed",
|
||||
|
||||
@@ -194,7 +194,6 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
||||
# load the config from the yaml file
|
||||
with open(config, encoding="utf-8") as file:
|
||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||
cfg.axolotl_config_path = config
|
||||
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
||||
# then overwrite the value
|
||||
cfg_keys = cfg.keys()
|
||||
|
||||
@@ -14,7 +14,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parsed_cfg.sample_packing = False
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
"""
|
||||
CLI to run training on a model
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
from colorama import Fore
|
||||
|
||||
from axolotl.cli import (
|
||||
check_accelerate_default_config,
|
||||
@@ -16,11 +14,8 @@ from axolotl.cli import (
|
||||
print_axolotl_text_art,
|
||||
)
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.train import train
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.train")
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -32,14 +27,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
|
||||
msg = (
|
||||
Fore.RED
|
||||
+ "--prepare_ds_only called without dataset_prepared_path set."
|
||||
+ Fore.RESET
|
||||
)
|
||||
LOG.warning(msg)
|
||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
if parsed_cli_args.prepare_ds_only:
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
"""
|
||||
Various shared constants
|
||||
"""
|
||||
|
||||
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, IterableDataset
|
||||
from datasets import Dataset, IterableDataset, Sequence, Value
|
||||
|
||||
from .prompt_tokenizers import PromptTokenizingStrategy
|
||||
|
||||
@@ -42,11 +42,15 @@ class TokenizedPromptDataset(Dataset):
|
||||
if self.prompt_tokenizer.supports_batched:
|
||||
map_kwargs["batched"] = True
|
||||
map_kwargs["batch_size"] = 100
|
||||
return dataset.map(
|
||||
self.prompt_tokenizer.tokenize_prompt,
|
||||
num_proc=num_proc,
|
||||
remove_columns=features,
|
||||
**map_kwargs,
|
||||
return (
|
||||
dataset.map(
|
||||
self.prompt_tokenizer.tokenize_prompt,
|
||||
num_proc=num_proc,
|
||||
remove_columns=features,
|
||||
**map_kwargs,
|
||||
)
|
||||
.cast_column("input_ids", Sequence(feature=Value(dtype="int32", id=None)))
|
||||
.cast_column("labels", Sequence(feature=Value(dtype="int32", id=None)))
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -116,8 +116,6 @@ def flashattn_forward(
|
||||
attention_mask: [bsz, q_len]
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
original_dtype = hidden_states.dtype
|
||||
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if not hasattr(self, "pretraining_tp"):
|
||||
@@ -153,13 +151,6 @@ def flashattn_forward(
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
if query_states.dtype == torch.float32:
|
||||
query_states = query_states.to(dtype=original_dtype)
|
||||
if key_states.dtype == torch.float32:
|
||||
key_states = key_states.to(dtype=original_dtype)
|
||||
if value_states.dtype == torch.float32:
|
||||
value_states = value_states.to(dtype=original_dtype)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
@@ -318,10 +309,6 @@ def flashattn_forward(
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
# handle conversion back for IA3
|
||||
if attn_output.dtype == torch.float32:
|
||||
attn_output = attn_output.to(dtype=original_dtype)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
@@ -515,7 +502,6 @@ def llama_model_forward(
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
original_dtype = hidden_states.dtype
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
@@ -573,10 +559,6 @@ def llama_model_forward(
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
# handle conversion back for IA3
|
||||
if hidden_states.dtype == torch.float32:
|
||||
hidden_states = hidden_states.to(dtype=original_dtype)
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
"""
|
||||
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
|
||||
"""
|
||||
|
||||
import torch
|
||||
import transformers.models.llama.modeling_llama
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
|
||||
# pylint: disable=duplicate-code
|
||||
def noised_embed(orig_embed, noise_alpha, model):
|
||||
def new_func(input_ids):
|
||||
# during training, we add noise to the embedding
|
||||
# during generation, we don't add noise to the embedding
|
||||
if model.training:
|
||||
embed_init = orig_embed(input_ids)
|
||||
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
|
||||
mag_norm = noise_alpha / torch.sqrt(dims)
|
||||
return embed_init + torch.zeros_like(embed_init).uniform_(
|
||||
-mag_norm, mag_norm
|
||||
)
|
||||
return orig_embed(input_ids)
|
||||
|
||||
return new_func
|
||||
|
||||
def post_init(orig_post_init):
|
||||
def new_func(self):
|
||||
orig_post_init(self)
|
||||
self.embed_tokens.forward = noised_embed(
|
||||
self.embed_tokens.forward, noise_alpha, self
|
||||
)
|
||||
|
||||
return new_func
|
||||
|
||||
transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(
|
||||
transformers.models.llama.modeling_llama.LlamaModel.post_init
|
||||
)
|
||||
@@ -14,9 +14,6 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
|
||||
flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralAttention as OriginalMistralAttention,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||
)
|
||||
@@ -45,44 +42,6 @@ def replace_mistral_attn_with_flash_attn(
|
||||
)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def _make_sliding_window_causal_mask(
|
||||
bsz: int,
|
||||
tgt_len: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
past_key_values_length: int = 0,
|
||||
sliding_window: int = 4096,
|
||||
):
|
||||
"""
|
||||
Make causal mask used for sliding window attention
|
||||
"""
|
||||
tensor = torch.full(
|
||||
(tgt_len, tgt_len),
|
||||
fill_value=1,
|
||||
device=device,
|
||||
)
|
||||
mask = torch.tril(tensor, diagonal=0)
|
||||
# make the mask banded to account for sliding window
|
||||
# NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
|
||||
mask = torch.triu(mask, diagonal=-sliding_window + 1)
|
||||
mask = torch.log(mask).to(dtype)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
tgt_len, past_key_values_length, dtype=dtype, device=device
|
||||
),
|
||||
mask,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
return mask[None, None, :, :].expand(
|
||||
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
||||
)
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
@@ -94,29 +53,11 @@ def _prepare_decoder_attention_mask(
|
||||
sliding_window,
|
||||
): # pylint: disable=unused-argument
|
||||
# [bsz, seq_len]
|
||||
if attention_mask is None:
|
||||
return attention_mask
|
||||
|
||||
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
|
||||
# Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
|
||||
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
|
||||
sliding_window_mask = _make_sliding_window_causal_mask(
|
||||
bsz=input_shape[0],
|
||||
tgt_len=input_shape[1],
|
||||
dtype=inputs_embeds.dtype,
|
||||
device=inputs_embeds.device,
|
||||
past_key_values_length=past_key_values_length,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
attention_mask = attention_mask + sliding_window_mask
|
||||
else:
|
||||
LOG.info("skipping sliding window mask, not broadcastable with attention mask")
|
||||
|
||||
return attention_mask
|
||||
|
||||
|
||||
def flashattn_forward(
|
||||
self: OriginalMistralAttention,
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
@@ -150,41 +91,10 @@ def flashattn_forward(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
|
||||
use_sliding_windows = (
|
||||
hasattr(self.config, "sliding_window") is not None
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
)
|
||||
|
||||
if use_sliding_windows:
|
||||
window_size = (self.config.sliding_window, self.config.sliding_window)
|
||||
else:
|
||||
window_size = (-1, -1)
|
||||
|
||||
if past_key_value is not None:
|
||||
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
||||
if (
|
||||
hasattr(self.config, "sliding_window")
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
):
|
||||
slicing_tokens = kv_seq_len - self.config.sliding_window
|
||||
|
||||
past_key = past_key_value[0]
|
||||
past_value = past_key_value[1]
|
||||
|
||||
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
||||
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
||||
|
||||
if past_key.shape[-2] != self.config.sliding_window - 1:
|
||||
raise ValueError(
|
||||
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
|
||||
f" {past_key.shape}"
|
||||
)
|
||||
|
||||
past_key_value = (past_key, past_value) if use_cache else None
|
||||
|
||||
if past_key_value is not None:
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
@@ -210,13 +120,7 @@ def flashattn_forward(
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
0.0,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
elif query_states.shape == key_states.shape:
|
||||
@@ -242,7 +146,6 @@ def flashattn_forward(
|
||||
0.0,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
output = output_pad_fn(output_unpad)
|
||||
else:
|
||||
@@ -254,7 +157,6 @@ def flashattn_forward(
|
||||
query_states,
|
||||
torch.stack([key_states, value_states], 2),
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
else:
|
||||
( # pylint: disable=unbalanced-tuple-unpacking
|
||||
@@ -289,7 +191,6 @@ def flashattn_forward(
|
||||
0.0,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
output = output_pad_fn(output_unpad)
|
||||
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
"""
|
||||
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
|
||||
"""
|
||||
|
||||
import torch
|
||||
import transformers.models.mistral.modeling_mistral
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
|
||||
# pylint: disable=duplicate-code
|
||||
def noised_embed(orig_embed, noise_alpha, model):
|
||||
def new_func(input_ids):
|
||||
# during training, we add noise to the embedding
|
||||
# during generation, we don't add noise to the embedding
|
||||
if model.training:
|
||||
embed_init = orig_embed(input_ids)
|
||||
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
|
||||
mag_norm = noise_alpha / torch.sqrt(dims)
|
||||
return embed_init + torch.zeros_like(embed_init).uniform_(
|
||||
-mag_norm, mag_norm
|
||||
)
|
||||
return orig_embed(input_ids)
|
||||
|
||||
return new_func
|
||||
|
||||
def post_init(orig_post_init):
|
||||
def new_func(self):
|
||||
orig_post_init(self)
|
||||
self.embed_tokens.forward = noised_embed(
|
||||
self.embed_tokens.forward, noise_alpha, self
|
||||
)
|
||||
|
||||
return new_func
|
||||
|
||||
transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(
|
||||
transformers.models.mistral.modeling_mistral.MistralModel.post_init
|
||||
)
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Module for Alpaca prompt strategy classes"""
|
||||
"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
@@ -9,13 +9,9 @@ from axolotl.prompt_tokenizers import (
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
prompt_style = PromptStyle.CHAT.value
|
||||
if ds_cfg and "conversation" in ds_cfg:
|
||||
prompt_style = ds_cfg["conversation"]
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return AlpacaPromptTokenizingStrategy(
|
||||
AlpacaPrompter(prompt_style),
|
||||
AlpacaPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
|
||||
@@ -24,7 +24,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
)
|
||||
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||
return SimpleShareGPTPromptTokenizingStrategy(
|
||||
strat = ShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation=conversation,
|
||||
role_key_model=field_model,
|
||||
@@ -34,6 +34,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
if ds_cfg and ds_cfg["skip"]:
|
||||
strat.skip_invalid = True
|
||||
return strat
|
||||
|
||||
|
||||
def load_role(tokenizer, cfg):
|
||||
@@ -54,13 +57,38 @@ def load_guanaco(tokenizer, cfg):
|
||||
)
|
||||
|
||||
|
||||
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
def load_nous(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
conversation = (
|
||||
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
||||
)
|
||||
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||
return NousShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation=conversation,
|
||||
role_key_model=field_model,
|
||||
role_key_human=field_human,
|
||||
),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
class NousShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
basic sharegpt strategy to grab conversations from the sample row
|
||||
basic sharegpt strategy used by nous/ldj for input/output keyed data
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
return prompt["conversations"]
|
||||
def get_conversation_thread(self):
|
||||
return "conversation"
|
||||
|
||||
def map_conversation_thread(self, conversation):
|
||||
turns = []
|
||||
for turn in conversation:
|
||||
turns.append({"from": "human", "value": turn["input"]})
|
||||
turns.append({"from": "gpt", "value": turn["output"]})
|
||||
return turns
|
||||
|
||||
|
||||
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
@@ -68,10 +96,11 @@ class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrateg
|
||||
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
def map_conversation_thread(self, conversation):
|
||||
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
||||
turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
|
||||
turns = [
|
||||
{"from": turn["role"], "value": turn["value"]} for turn in conversation
|
||||
]
|
||||
return turns
|
||||
|
||||
|
||||
@@ -80,11 +109,11 @@ class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
sharegpt strategy that remaps oasst data to sharegpt format
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
def map_conversation_thread(self, conversation):
|
||||
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
||||
role_map = {"prompter": "human", "assistant": "gpt"}
|
||||
turns = [
|
||||
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
|
||||
{"from": role_map[turn["role"]], "value": turn["text"]}
|
||||
for turn in conversation
|
||||
]
|
||||
return turns
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
|
||||
import abc
|
||||
import copy
|
||||
import functools
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
from fastchat.conversation import Conversation
|
||||
@@ -56,6 +58,26 @@ class PromptTokenizingStrategy(abc.ABC):
|
||||
def supports_batched(self):
|
||||
return False
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _get_user_token(self):
|
||||
try:
|
||||
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
|
||||
if isinstance(id_or_ids, (int,)):
|
||||
return id_or_ids
|
||||
except KeyError:
|
||||
pass
|
||||
return False
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _get_assistant_token(self):
|
||||
try:
|
||||
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
|
||||
if isinstance(id_or_ids, (int,)):
|
||||
return id_or_ids
|
||||
except KeyError:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _tokenize(
|
||||
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
||||
) -> BatchEncoding:
|
||||
@@ -330,99 +352,109 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
Tokenizing strategy for ShareGPT prompts.
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
return prompt["conversations"]
|
||||
_skip_invalid = False
|
||||
|
||||
@property
|
||||
def supports_batched(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def skip_invalid(self):
|
||||
return self._skip_invalid
|
||||
|
||||
@skip_invalid.setter
|
||||
def skip_invalid(self, value):
|
||||
self._skip_invalid = value
|
||||
|
||||
def get_conversation_thread(self):
|
||||
return "conversations"
|
||||
|
||||
def map_conversation_thread(self, conversation):
|
||||
return conversation
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
result, current_len = tokenize_prompt_default()
|
||||
conversation: Conversation = (
|
||||
self.prompter._conversation.copy() # pylint: disable=protected-access
|
||||
)
|
||||
tokenized_res = defaultdict(lambda: [])
|
||||
conv_field = self.get_conversation_thread()
|
||||
for prmpt in prompt[conv_field]:
|
||||
result, current_len = tokenize_prompt_default()
|
||||
user_token = self._get_user_token()
|
||||
assistant_token = self._get_assistant_token()
|
||||
conversation: Conversation = (
|
||||
self.prompter._conversation # pylint: disable=protected-access
|
||||
)
|
||||
try:
|
||||
for _, part in enumerate(
|
||||
self.prompter.build_prompt(self.map_conversation_thread(prmpt))
|
||||
):
|
||||
if isinstance(part, tuple):
|
||||
if conversation.roles[0] in part[0]:
|
||||
turn = part[0] + part[1] if not user_token else part[1]
|
||||
# this is still the user query, we should
|
||||
if not part[1].strip():
|
||||
err_msg = f"user turn has empty text: {prmpt}"
|
||||
if self.skip_invalid:
|
||||
raise ValueError(err_msg)
|
||||
LOG.warning(err_msg)
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
if user_token:
|
||||
res["input_ids"] = [user_token, *res["input_ids"]]
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
elif conversation.roles[1] in part[0]:
|
||||
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
||||
turn = part[0] + part[1] if not assistant_token else part[1]
|
||||
# this should be the assistant response, should end with an eos token
|
||||
if not part[1].strip():
|
||||
err_msg = f"assistant turn has empty text: {prmpt}"
|
||||
if self.skip_invalid:
|
||||
raise ValueError(err_msg)
|
||||
LOG.warning(err_msg)
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=True,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
if assistant_token:
|
||||
res["input_ids"] = [
|
||||
assistant_token,
|
||||
*res["input_ids"],
|
||||
]
|
||||
# not masked out from labels
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
elif part[0] == "":
|
||||
turn = part[1]
|
||||
# this is only ever the first part, should include the bos token and the user query
|
||||
res = self._tokenize(
|
||||
turn, add_eos_token=False, strip_bos_token=False
|
||||
)
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
else:
|
||||
err_msg = f"unhandled role: {part[0]}"
|
||||
if self.skip_invalid:
|
||||
raise ValueError(err_msg)
|
||||
LOG.warning(err_msg)
|
||||
continue
|
||||
|
||||
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
|
||||
role_remap = []
|
||||
if (
|
||||
conversation.name == "vicuna_v1.1"
|
||||
and "roles" in prompt
|
||||
and len(prompt["roles"]) >= 2
|
||||
):
|
||||
role_remap = [
|
||||
{"from": conversation.roles[0], "to": prompt["roles"][0]},
|
||||
{"from": conversation.roles[1], "to": prompt["roles"][1]},
|
||||
]
|
||||
|
||||
try:
|
||||
for _, part in enumerate(
|
||||
self.prompter.build_prompt(self.get_conversation_thread(prompt))
|
||||
):
|
||||
if isinstance(part, tuple):
|
||||
if conversation.roles[0] in part[0]:
|
||||
role = (
|
||||
part[0].replace(role_remap[0]["from"], role_remap[0]["to"])
|
||||
if role_remap
|
||||
else part[0]
|
||||
)
|
||||
turn = role + part[1]
|
||||
# this is still the user query, we should
|
||||
if not part[1].strip():
|
||||
LOG.warning(f"user turn has empty text: {prompt}")
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
elif conversation.roles[1] in part[0]:
|
||||
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
||||
role = (
|
||||
part[0].replace(role_remap[1]["from"], role_remap[1]["to"])
|
||||
if role_remap
|
||||
else part[0]
|
||||
)
|
||||
turn = role + part[1]
|
||||
# this should be the assistant response, should end with an eos token
|
||||
if not part[1].strip():
|
||||
LOG.warning(f"assistant turn has empty text: {prompt}")
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=True,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
role_res = self._tokenize(
|
||||
role.rstrip(),
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
# not masked out from labels
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
len_role = len(role_res["input_ids"])
|
||||
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
|
||||
len_role, len(labels)
|
||||
)
|
||||
elif part[0] == "":
|
||||
turn = part[1]
|
||||
# this is only ever the first part, should include the bos token and the user query
|
||||
res = self._tokenize(
|
||||
turn, add_eos_token=False, strip_bos_token=False
|
||||
)
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
else:
|
||||
LOG.warning(f"unhandled role: {part[0]}")
|
||||
continue
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
result, current_len = parse_tokenized_to_result(
|
||||
result,
|
||||
current_len,
|
||||
res,
|
||||
labels,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
)
|
||||
return result
|
||||
except (KeyError, AssertionError, IndexError) as err:
|
||||
raise InvalidDataException(str(err)) from err
|
||||
# pylint: disable=duplicate-code
|
||||
result, current_len = parse_tokenized_to_result(
|
||||
result,
|
||||
current_len,
|
||||
res,
|
||||
labels,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
)
|
||||
for key, val in sorted(result.items(), key=lambda x: x[0]):
|
||||
tokenized_res[key].append(val)
|
||||
except (KeyError, AssertionError, IndexError) as err:
|
||||
raise InvalidDataException(str(err)) from err
|
||||
except ValueError as err:
|
||||
LOG.warning("skipping prompt: %s", str(err))
|
||||
return tokenized_res
|
||||
|
||||
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
||||
if not prompt.strip():
|
||||
|
||||
@@ -274,11 +274,9 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
raise err
|
||||
|
||||
conv.messages = []
|
||||
for _, sentence in enumerate(source):
|
||||
for j, sentence in enumerate(source):
|
||||
role = roles[sentence["from"]]
|
||||
if len(conv.messages) > 0 and (
|
||||
(role == conv.messages[-1][0]) or (role not in conv.roles)
|
||||
):
|
||||
if role != conv.roles[j % 2]:
|
||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||
conv.append_message(role, sentence["value"])
|
||||
|
||||
|
||||
@@ -58,7 +58,9 @@ def train(
|
||||
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
||||
if (
|
||||
cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints
|
||||
) or cfg.resume_from_checkpoint is True:
|
||||
possible_checkpoints = [
|
||||
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
||||
]
|
||||
@@ -71,7 +73,9 @@ def train(
|
||||
LOG.info(
|
||||
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
||||
)
|
||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||
resume_from_checkpoint = (
|
||||
cfg.resume_from_checkpoint if cfg.resume_from_checkpoint is not True else None
|
||||
)
|
||||
|
||||
trainer = setup_trainer(
|
||||
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
||||
|
||||
@@ -514,27 +514,3 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
|
||||
return control
|
||||
|
||||
return LogPredictionCallback
|
||||
|
||||
|
||||
class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
||||
"""Callback to save axolotl config to wandb"""
|
||||
|
||||
def __init__(self, axolotl_config_path):
|
||||
self.axolotl_config_path = axolotl_config_path
|
||||
|
||||
def on_train_begin(
|
||||
self,
|
||||
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
|
||||
state: TrainerState, # pylint: disable=unused-argument
|
||||
control: TrainerControl,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
if is_main_process():
|
||||
try:
|
||||
artifact = wandb.Artifact(name="axolotl-config", type="config")
|
||||
artifact.add_file(local_path=self.axolotl_config_path)
|
||||
wandb.run.log_artifact(artifact)
|
||||
LOG.info("Axolotl config has been saved to WandB as an artifact.")
|
||||
except (FileNotFoundError, ConnectionError) as err:
|
||||
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
||||
return control
|
||||
|
||||
@@ -121,18 +121,6 @@ def normalize_config(cfg):
|
||||
|
||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||
|
||||
if cfg.adapter is not None:
|
||||
for key in list(cfg.keys()):
|
||||
if key.startswith("lora_"):
|
||||
new_key = key.replace("lora_", "peft_")
|
||||
LOG.warning(
|
||||
PendingDeprecationWarning(
|
||||
f"{key} soon to be deprecated. please use {new_key}"
|
||||
)
|
||||
)
|
||||
cfg[new_key] = cfg[key]
|
||||
del cfg[key]
|
||||
|
||||
|
||||
def validate_config(cfg):
|
||||
if is_torch_bf16_gpu_available():
|
||||
@@ -202,10 +190,7 @@ def validate_config(cfg):
|
||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||
|
||||
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
||||
LOG.warning("We recommend setting `load_in_8bit: true` for LoRA finetuning")
|
||||
|
||||
if not cfg.load_in_8bit and cfg.adapter == "ia3":
|
||||
LOG.warning("We recommend setting `load_in_8bit: true` for IA3 finetuning")
|
||||
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||
|
||||
if cfg.relora_steps:
|
||||
if cfg.adapter not in ("lora", "qlora"):
|
||||
|
||||
@@ -16,7 +16,6 @@ from datasets import (
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
|
||||
from axolotl.prompt_strategies import load
|
||||
from axolotl.prompt_tokenizers import (
|
||||
@@ -45,6 +44,7 @@ from axolotl.utils.trainer import (
|
||||
)
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||
|
||||
|
||||
def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
||||
@@ -158,21 +158,23 @@ def load_tokenized_prepared_datasets(
|
||||
token=use_auth_token,
|
||||
)
|
||||
ds_from_hub = True
|
||||
except (FileNotFoundError, ConnectionError):
|
||||
except (FileNotFoundError, ValueError):
|
||||
pass
|
||||
|
||||
# prefer local dataset, even if hub exists
|
||||
local_path = Path(d.path)
|
||||
if local_path.exists():
|
||||
if local_path.is_dir():
|
||||
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
|
||||
ds = load_dataset(
|
||||
d.path,
|
||||
name=d.name,
|
||||
data_files=d.data_files,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
if not d.type:
|
||||
ds = load_from_disk(d.path)
|
||||
else:
|
||||
ds = load_dataset(
|
||||
d.path,
|
||||
name=d.name,
|
||||
data_files=d.data_files,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
elif local_path.is_file():
|
||||
ds_type = "json"
|
||||
if d.ds_type:
|
||||
@@ -357,7 +359,7 @@ def load_tokenized_prepared_datasets(
|
||||
if len(datasets) > 1:
|
||||
LOG.info("shuffle merged datasets")
|
||||
dataset = dataset.shuffle(seed=seed)
|
||||
if cfg.local_rank == 0:
|
||||
if cfg.local_rank == 0 and cfg.dataset_prepared_path:
|
||||
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||
dataset.save_to_disk(prepared_ds_path)
|
||||
if cfg.push_dataset_to_hub:
|
||||
|
||||
@@ -180,26 +180,6 @@ def load_model(
|
||||
LOG.info("patching with flash attention")
|
||||
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
|
||||
from axolotl.monkeypatch.llama_embeddings_hijack import (
|
||||
replace_llama_embeddings_with_uniform_distribution,
|
||||
)
|
||||
|
||||
LOG.info("patching with noisy embeddings")
|
||||
replace_llama_embeddings_with_uniform_distribution(
|
||||
noise_alpha=cfg.noisy_embedding_alpha
|
||||
)
|
||||
|
||||
if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
|
||||
from axolotl.monkeypatch.mistral_embeddings_hijack import (
|
||||
replace_mistral_embeddings_with_uniform_distribution,
|
||||
)
|
||||
|
||||
LOG.info("patching with noisy embeddings")
|
||||
replace_mistral_embeddings_with_uniform_distribution(
|
||||
noise_alpha=cfg.noisy_embedding_alpha
|
||||
)
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
||||
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
||||
replace_llama_rope_with_xpos_rope,
|
||||
@@ -402,25 +382,25 @@ def load_model(
|
||||
if model_config.model_type == "btlm":
|
||||
# don't upcast lm_head for btlm
|
||||
continue
|
||||
if "lm_head" in name or "embed_tokens" in name:
|
||||
if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
|
||||
if hasattr(module, "weight"):
|
||||
module.to(torch.float32)
|
||||
|
||||
require_peft: bool = False
|
||||
if cfg.adapter in ["lora", "qlora", "ia3"]:
|
||||
require_peft = True
|
||||
|
||||
if require_peft:
|
||||
needs_fa2_dtype = cfg.adapter or cfg.fsdp
|
||||
if (cfg.adapter == "lora" and load_in_8bit) or (
|
||||
cfg.adapter == "qlora" and cfg.load_in_4bit
|
||||
):
|
||||
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
||||
if cfg.gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable()
|
||||
model = prepare_model_for_kbit_training(
|
||||
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
||||
)
|
||||
needs_fa2_dtype = True
|
||||
|
||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
||||
# convert them back to fp16/bf16 for flash-attn compatibility.
|
||||
if require_peft or cfg.fsdp or (cfg.flash_attention and cfg.is_llama_derived_model):
|
||||
if needs_fa2_dtype or (cfg.flash_attention and cfg.is_llama_derived_model):
|
||||
LOG.info("converting modules to %s for flash attention", cfg.torch_dtype)
|
||||
for name, module in model.named_modules():
|
||||
if "norm" in name:
|
||||
@@ -429,7 +409,7 @@ def load_model(
|
||||
if hasattr(module, "weight"):
|
||||
module.to(cfg.torch_dtype)
|
||||
|
||||
model, peft_config = load_adapter(model, cfg, cfg.adapter)
|
||||
model, lora_config = load_adapter(model, cfg, cfg.adapter)
|
||||
|
||||
if cfg.ddp and not load_in_8bit:
|
||||
model.to(f"cuda:{cfg.local_rank}")
|
||||
@@ -460,7 +440,7 @@ def load_model(
|
||||
log_gpu_memory_usage(LOG, "after adapters", model.device)
|
||||
|
||||
# TODO resume_from_checkpoint handling
|
||||
return model, peft_config
|
||||
return model, lora_config
|
||||
|
||||
|
||||
def load_adapter(model, cfg, adapter, inference=False):
|
||||
@@ -470,8 +450,6 @@ def load_adapter(model, cfg, adapter, inference=False):
|
||||
return model, None
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
model.enable_input_require_grads()
|
||||
if adapter == "ia3":
|
||||
return load_ia3(model, cfg, inference=inference)
|
||||
if adapter in ["lora", "qlora"]:
|
||||
return load_lora(model, cfg, inference=inference)
|
||||
if adapter == "llama-adapter":
|
||||
@@ -490,11 +468,11 @@ def load_llama_adapter(model, cfg):
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
if cfg.peft_model_dir:
|
||||
if cfg.lora_model_dir:
|
||||
LOG.debug("Loading pretained PEFT - llama_adapter")
|
||||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
cfg.peft_model_dir,
|
||||
cfg.lora_model_dir,
|
||||
torch_dtype=torch.float16,
|
||||
)
|
||||
else:
|
||||
@@ -507,20 +485,16 @@ def load_llama_adapter(model, cfg):
|
||||
|
||||
def find_all_linear_names(model):
|
||||
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
|
||||
peft_module_names = set()
|
||||
lora_module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
if (
|
||||
isinstance(module, cls)
|
||||
or "Linear" in module.__class__.__name__
|
||||
and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
|
||||
):
|
||||
if isinstance(module, cls) or "Linear" in module.__class__.__name__:
|
||||
names = name.split(".")
|
||||
peft_module_names.add(names[0] if len(names) == 1 else names[-1])
|
||||
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
||||
|
||||
if "lm_head" in peft_module_names: # needed for 16-bit
|
||||
peft_module_names.remove("lm_head")
|
||||
if "lm_head" in lora_module_names: # needed for 16-bit
|
||||
lora_module_names.remove("lm_head")
|
||||
|
||||
return list(peft_module_names)
|
||||
return list(lora_module_names)
|
||||
|
||||
|
||||
def load_lora(model, cfg, inference=False):
|
||||
@@ -528,68 +502,34 @@ def load_lora(model, cfg, inference=False):
|
||||
|
||||
from peft import LoraConfig, PeftModel, get_peft_model
|
||||
|
||||
peft_target_modules = list(cfg.peft_target_modules or [])
|
||||
lora_target_modules = list(cfg.lora_target_modules or [])
|
||||
|
||||
if cfg.peft_target_linear:
|
||||
if cfg.lora_target_linear:
|
||||
linear_names = find_all_linear_names(model)
|
||||
LOG.info(f"found linear modules: {repr(linear_names)}")
|
||||
peft_target_modules = list(set(peft_target_modules + linear_names))
|
||||
lora_target_modules = list(set(lora_target_modules + linear_names))
|
||||
|
||||
peft_config = LoraConfig(
|
||||
r=cfg.peft_r,
|
||||
lora_alpha=cfg.peft_alpha,
|
||||
target_modules=peft_target_modules,
|
||||
lora_dropout=cfg.peft_dropout,
|
||||
fan_in_fan_out=cfg.peft_fan_in_fan_out,
|
||||
modules_to_save=cfg.peft_modules_to_save if cfg.peft_modules_to_save else None,
|
||||
lora_config = LoraConfig(
|
||||
r=cfg.lora_r,
|
||||
lora_alpha=cfg.lora_alpha,
|
||||
target_modules=lora_target_modules,
|
||||
lora_dropout=cfg.lora_dropout,
|
||||
fan_in_fan_out=cfg.lora_fan_in_fan_out,
|
||||
modules_to_save=cfg.lora_modules_to_save if cfg.lora_modules_to_save else None,
|
||||
bias="none",
|
||||
task_type="CAUSAL_LM",
|
||||
)
|
||||
|
||||
if cfg.peft_model_dir:
|
||||
if cfg.lora_model_dir:
|
||||
LOG.debug("Loading pretained PEFT - LoRA")
|
||||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
cfg.peft_model_dir,
|
||||
cfg.lora_model_dir,
|
||||
is_trainable=(not inference),
|
||||
)
|
||||
else:
|
||||
model = get_peft_model(model, peft_config)
|
||||
model = get_peft_model(model, lora_config)
|
||||
|
||||
model.print_trainable_parameters()
|
||||
|
||||
return model, peft_config
|
||||
|
||||
|
||||
def load_ia3(model, cfg, inference=False):
|
||||
# type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
|
||||
from peft import IA3Config, PeftModel, get_peft_model
|
||||
|
||||
peft_config_kwargs = {}
|
||||
if cfg.peft_init_ia3_weights is not None:
|
||||
peft_config_kwargs["init_ia3_weights"] = cfg.peft_init_ia3_weights
|
||||
if cfg.peft_fan_in_fan_out is not None:
|
||||
peft_config_kwargs["fan_in_fan_out"] = cfg.peft_fan_in_fan_out
|
||||
|
||||
peft_config = IA3Config(
|
||||
target_modules=cfg.peft_target_modules,
|
||||
feedforward_modules=cfg.peft_feedforward_modules,
|
||||
modules_to_save=cfg.peft_modules_to_save,
|
||||
task_type="CAUSAL_LM",
|
||||
**peft_config_kwargs,
|
||||
)
|
||||
|
||||
if cfg.peft_model_dir:
|
||||
LOG.debug("Loading pretained PEFT - IA3")
|
||||
model = PeftModel.from_pretrained(
|
||||
model,
|
||||
cfg.peft_model_dir,
|
||||
is_trainable=(not inference),
|
||||
)
|
||||
else:
|
||||
model = get_peft_model(model, peft_config)
|
||||
|
||||
model.print_trainable_parameters()
|
||||
|
||||
return model, peft_config
|
||||
return model, lora_config
|
||||
|
||||
@@ -30,7 +30,6 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.utils.callbacks import (
|
||||
EvalFirstStepCallback,
|
||||
GPUStatsCallback,
|
||||
SaveAxolotlConfigtoWandBCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
log_prediction_callback_factory,
|
||||
@@ -423,9 +422,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
||||
)
|
||||
|
||||
# Phi doesn't want the attention_mask feature when training
|
||||
if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
|
||||
cfg.is_mistral_derived_model and cfg.flash_attention
|
||||
):
|
||||
if "CodeGenTokenizer" in tokenizer.__class__.__name__:
|
||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.remove_columns("attention_mask")
|
||||
@@ -778,9 +775,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
||||
LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
|
||||
trainer.add_callback(LogPredictionCallback(cfg))
|
||||
|
||||
if cfg.use_wandb:
|
||||
trainer.add_callback(SaveAxolotlConfigtoWandBCallback(cfg.axolotl_config_path))
|
||||
|
||||
if cfg.do_bench_eval:
|
||||
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
|
||||
|
||||
|
||||
@@ -24,10 +24,6 @@ class TestLoraLlama(unittest.TestCase):
|
||||
"""
|
||||
|
||||
def test_lora(self):
|
||||
"""
|
||||
support for legacy lora_ configs
|
||||
:return:
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
@@ -70,101 +66,6 @@ class TestLoraLlama(unittest.TestCase):
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||
|
||||
def test_lora_peft(self):
|
||||
"""
|
||||
support for legacy lora_ configs
|
||||
:return:
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"base_model_config": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"peft_r": 32,
|
||||
"peft_alpha": 64,
|
||||
"peft_dropout": 0.05,
|
||||
"peft_target_linear": True,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": output_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||
|
||||
def test_ia3_peft(self):
|
||||
"""
|
||||
support for IA3 peft
|
||||
:return:
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"base_model_config": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "ia3",
|
||||
"peft_r": 32,
|
||||
"peft_alpha": 64,
|
||||
"peft_dropout": 0.05,
|
||||
"peft_target_modules": ["k_proj", "v_proj", "down_proj"],
|
||||
"peft_feedforward_modules": ["down_proj"],
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": output_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||
|
||||
def test_lora_packing(self):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
|
||||
2
tests/fixtures/conversation.tokenized.json
vendored
2
tests/fixtures/conversation.tokenized.json
vendored
File diff suppressed because one or more lines are too long
@@ -1,48 +0,0 @@
|
||||
"""Module for testing the validation module"""
|
||||
|
||||
import logging
|
||||
import unittest
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class NormalizationTest(unittest.TestCase):
|
||||
"""
|
||||
Test the cfg normalization module
|
||||
"""
|
||||
|
||||
_caplog: Optional[pytest.LogCaptureFixture] = None
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def inject_fixtures(self, caplog):
|
||||
self._caplog = caplog
|
||||
|
||||
def test_lora_to_peft(self):
|
||||
base_cfg = DictDefault(
|
||||
{
|
||||
"gradient_accumulation_steps": 1,
|
||||
"micro_batch_size": 1,
|
||||
"base_model": "NousResearch/Llama-2-7b-hf",
|
||||
"base_model_config": "NousResearch/Llama-2-7b-hf",
|
||||
}
|
||||
)
|
||||
cfg = base_cfg | DictDefault(
|
||||
{
|
||||
"adapter": "lora",
|
||||
"lora_r": 128,
|
||||
"lora_alpha": 64,
|
||||
}
|
||||
)
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
normalize_config(cfg)
|
||||
assert any(
|
||||
"soon to be deprecated. please use peft_" in record.message
|
||||
for record in self._caplog.records
|
||||
)
|
||||
|
||||
assert cfg.peft_r == 128
|
||||
assert cfg.peft_alpha == 64
|
||||
@@ -90,73 +90,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
strat.tokenize_prompt(conversation)
|
||||
assert "assistant turn has empty text" in self._caplog.records[1].message
|
||||
|
||||
def test_sharegpt_warnings_turns(self):
|
||||
conversation = {
|
||||
"conversations": [
|
||||
{"from": "system", "value": "lorem"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
{"from": "human", "value": "dolor"},
|
||||
{"from": "human", "value": "dolor"},
|
||||
{"from": "gpt", "value": "sit"},
|
||||
]
|
||||
}
|
||||
prompter = ShareGPTPrompterV2()
|
||||
strat = ShareGPTPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
strat.tokenize_prompt(conversation)
|
||||
assert (
|
||||
"Role did not alternate between turns (gpt and human)"
|
||||
in self._caplog.records[0].message
|
||||
)
|
||||
|
||||
def test_sharegpt_changes_roles(self):
|
||||
conversation = {
|
||||
"roles": ["USER", "CHARACTER"],
|
||||
"conversations": [
|
||||
{"from": "system", "value": "lorem"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
{"from": "human", "value": "dolor"},
|
||||
{"from": "gpt", "value": "sit"},
|
||||
],
|
||||
}
|
||||
prompter = ShareGPTPrompterV2()
|
||||
strat = ShareGPTPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
res = strat.tokenize_prompt(conversation)
|
||||
assert "CHARACTER" in self.tokenizer.decode(res["input_ids"])
|
||||
|
||||
def test_sharegpt_assistant_label_ignore(self):
|
||||
conversation = {
|
||||
"roles": ["user", "assistant"],
|
||||
"conversations": [
|
||||
{"from": "system", "value": "lorem"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
{"from": "human", "value": "dolor"},
|
||||
{"from": "gpt", "value": "sit"},
|
||||
],
|
||||
}
|
||||
prompter = ShareGPTPrompterV2()
|
||||
strat = ShareGPTPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
res = strat.tokenize_prompt(conversation)
|
||||
idx = res["input_ids"].index(20255) # assistant token
|
||||
assert res["labels"][idx] == -100
|
||||
|
||||
def test_no_sys_prompt(self):
|
||||
"""
|
||||
tests the interface between the user and assistant parts
|
||||
|
||||
Reference in New Issue
Block a user