Compare commits

..

22 Commits

Author SHA1 Message Date
Wing Lian
080612219b use even if not using sample packing 2023-10-13 17:54:35 -04:00
Wing Lian
f95858d369 alternate impl of NEFT 2023-10-13 17:45:24 -04:00
Wing Lian
7f2027d93f tweak for xformers install w pytorch 2.1.0 (#727) 2023-10-13 15:21:17 -04:00
Wing Lian
8d288a2ad4 workaround for installing xformers w torch 2.1.0 (#725) 2023-10-13 11:19:30 -04:00
Wing Lian
f30afe4544 misc sharegpt fixes (#723)
* support for sharegpt with assistant talking first, better masking of assistant token, allow remap of roles from dataset

* invalid role is actually not possible

* update tokenized fixture for corrected labels
2023-10-13 11:04:39 -04:00
Wing Lian
bfbdba8614 pin xformers >= 0.0.22 (#724) 2023-10-13 10:27:56 -04:00
Maxime
3bd9528390 add noisy embedding (#721)
* add noisy embedding

* fix format

* Update README.md

* Update README.md

* linter issues

* caseus fixes

---------

Co-authored-by: Maxime <maxime@nope.no>
2023-10-13 10:00:42 -04:00
Wing Lian
2aa1f71464 fix pytorch 2.1.0 build, add multipack docs (#722) 2023-10-13 08:57:28 -04:00
Wing Lian
1c412c7e9d improve handling of the prepared ds path and other cfg defaults (#701) 2023-10-13 07:46:07 -04:00
Jan Philipp Harries
490923fb78 Save Axolotl config as WandB artifact (#716) 2023-10-11 07:28:12 -04:00
NanoCode012
5855dded3d fix(doc): update default doc according to arg (#714) 2023-10-10 21:51:56 +09:00
atgctg
ace70b33c6 Fix: lowercase True values in config (#713)
* Fix: lowercase `True` values in config

* Fix: lowercase `True` values in config
2023-10-10 21:32:20 +09:00
NanoCode012
11c48c5e03 fix(doc): Add note on inference w sample packing (#712) 2023-10-10 21:08:17 +09:00
lukemarsden
295b2662e1 Get qlora mistral-7b fine tuning working on a single 4090 (#708) 2023-10-10 15:14:23 +09:00
seungduk.kim.2304
77c84e02fd Update README with some explanations (#700)
* Update README with some explanations

* revert commit-hook change

* add more explanation about batch size and gradient accum

* not use latex foromat

* decorate

* git hook again

* Attach a link that explains about LoRA hyperparameters

* update table of content

* Explanation about lora_modules_to_save
2023-10-08 13:37:54 -04:00
mhenrichsen
f91db198f3 fix unneeded space (#699) 2023-10-07 14:19:25 -04:00
Wing Lian
7f2618b5f4 add docker images for pytorch 2.10 (#697) 2023-10-07 12:23:31 -04:00
Wing Lian
aca0398315 apex not needed as amp is part of pytorch (#696) 2023-10-07 12:20:45 -04:00
mhenrichsen
29b8f46aed Merge pull request #693 from OpenAccess-AI-Collective/update-mistral-example
update mistral lr, sample pack
2023-10-07 11:04:58 +02:00
mhenrichsen
83a950bb87 lint 2023-10-07 11:04:35 +02:00
Wing Lian
de87ea68f6 fix multiline for docker (#694) 2023-10-06 22:38:15 -04:00
mhenrichsen
4c8ddf2c6f new lr, sample pack 2023-10-06 22:58:13 +02:00
30 changed files with 649 additions and 298 deletions

View File

@@ -25,6 +25,11 @@ jobs:
python_version: "3.10" python_version: "3.10"
pytorch: 2.0.1 pytorch: 2.0.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
- cuda: "118"
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.1.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
steps: steps:
- name: Checkout - name: Checkout
uses: actions/checkout@v3 uses: actions/checkout@v3

View File

@@ -23,6 +23,11 @@ jobs:
python_version: "3.10" python_version: "3.10"
pytorch: 2.0.1 pytorch: 2.0.1
axolotl_extras: axolotl_extras:
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.1.0
axolotl_extras:
runs-on: [self-hosted, gpu, docker] runs-on: [self-hosted, gpu, docker]
steps: steps:
- name: Checkout - name: Checkout
@@ -46,6 +51,7 @@ jobs:
build-args: | build-args: |
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }} BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }} CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
file: ./docker/Dockerfile file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }} push: ${{ github.event_name != 'pull_request' }}
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
@@ -68,6 +74,11 @@ jobs:
pytorch: 2.0.1 pytorch: 2.0.1
axolotl_extras: axolotl_extras:
is_latest: true is_latest: true
- cuda: 118
cuda_version: 11.8.0
python_version: "3.10"
pytorch: 2.1.0
axolotl_extras:
runs-on: [self-hosted, gpu, docker] runs-on: [self-hosted, gpu, docker]
steps: steps:
- name: Checkout - name: Checkout

View File

@@ -12,4 +12,3 @@ generated-members=numpy.*, torch.*
disable=missing-function-docstring, line-too-long, import-error, disable=missing-function-docstring, line-too-long, import-error,
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods, too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation, too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
too-many-nested-blocks,

264
README.md
View File

@@ -23,9 +23,10 @@ Features:
- [Supported Features](#axolotl-supports) - [Supported Features](#axolotl-supports)
- [Quickstart](#quickstart-) - [Quickstart](#quickstart-)
- [Installation](#installation) - [Installation](#installation)
- [Docker Installation](#environment) - [Docker](#docker)
- [Conda/Pip venv Installation](#condapip-venv) - [Conda/Pip venv](#condapip-venv)
- [LambdaLabs Installation](#lambdalabs) - [LambdaLabs](#lambdalabs)
- [Windows](#windows)
- [Dataset](#dataset) - [Dataset](#dataset)
- [How to Add Custom Prompts](#how-to-add-custom-prompts) - [How to Add Custom Prompts](#how-to-add-custom-prompts)
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset) - [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
@@ -50,7 +51,7 @@ Features:
<b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b> <b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b>
</p> </p>
<p> <p>
Go ahead and axolotl questions!! Go ahead and Axolotl questions!!
</p> </p>
<img src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit"> <img src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
<img alt="PyTest Status" src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/tests.yml/badge.svg?branch=main"> <img alt="PyTest Status" src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
@@ -102,7 +103,7 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
### Environment ### Environment
- Docker #### Docker
```bash ```bash
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1 docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
``` ```
@@ -114,12 +115,12 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
docker compose up -d docker compose up -d
``` ```
- Conda/Pip venv #### Conda/Pip venv
1. Install python >=**3.9** 1. Install python >=**3.9**
2. Install pytorch stable https://pytorch.org/get-started/locally/ 2. Install pytorch stable https://pytorch.org/get-started/locally/
3. Install axolotl along with python dependencies 3. Install Axolotl along with python dependencies
```bash ```bash
pip3 install packaging pip3 install packaging
pip3 install -e '.[flash-attn,deepspeed]' pip3 install -e '.[flash-attn,deepspeed]'
@@ -130,7 +131,7 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
``` ```
Get the token at huggingface.co/settings/tokens Get the token at huggingface.co/settings/tokens
- LambdaLabs #### LambdaLabs
<details> <details>
<summary>Click to Expand</summary> <summary>Click to Expand</summary>
@@ -174,7 +175,8 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
``` ```
</details> </details>
- Windows: Please use WSL or Docker! #### Windows
Please use WSL or Docker!
### Dataset ### Dataset
@@ -396,15 +398,15 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
<summary>All yaml options</summary> <summary>All yaml options</summary>
```yaml ```yaml
# this is the huggingface model that contains *.pt, *.safetensors, or *.bin files # This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
# this can also be a relative path to a model on disk # This can also be a relative path to a model on disk
base_model: ./llama-7b-hf base_model: ./llama-7b-hf
# you can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc) # You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
base_model_ignore_patterns: base_model_ignore_patterns:
# if the base_model repo on hf hub doesn't include configuration .json files, # If the base_model repo on hf hub doesn't include configuration .json files,
# you can set that here, or leave this empty to default to base_model # You can set that here, or leave this empty to default to base_model
base_model_config: ./llama-7b-hf base_model_config: ./llama-7b-hf
# you can specify to choose a specific model revision from huggingface hub # You can specify to choose a specific model revision from huggingface hub
model_revision: model_revision:
# Optional tokenizer configuration override in case you want to use a different tokenizer # Optional tokenizer configuration override in case you want to use a different tokenizer
# than the one defined in the base model # than the one defined in the base model
@@ -419,23 +421,24 @@ trust_remote_code:
tokenizer_use_fast: tokenizer_use_fast:
# Whether to use the legacy tokenizer setting, defaults to True # Whether to use the legacy tokenizer setting, defaults to True
tokenizer_legacy: tokenizer_legacy:
# resize the model embeddings when new tokens are added to multiples of 32 # Resize the model embeddings when new tokens are added to multiples of 32
# this is reported to improve training speed on some models # This is reported to improve training speed on some models
resize_token_embeddings_to_32x: resize_token_embeddings_to_32x:
# used to identify which the model is based on # Used to identify which the model is based on
is_falcon_derived_model: is_falcon_derived_model:
is_llama_derived_model: is_llama_derived_model:
# Please note that if you set this to true, `padding_side` will be set to "left" by default
is_mistral_derived_model: is_mistral_derived_model:
# whether you are training a 4-bit GPTQ quantized model # Whether you are training a 4-bit GPTQ quantized model
gptq: true gptq: true
gptq_groupsize: 128 # group size gptq_groupsize: 128 # group size
gptq_model_v1: false # v1 or v2 gptq_model_v1: false # v1 or v2
# this will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer # This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
load_in_8bit: true load_in_8bit: true
# use bitsandbytes 4 bit # Use bitsandbytes 4 bit
load_in_4bit: load_in_4bit:
# Use CUDA bf16 # Use CUDA bf16
@@ -449,9 +452,9 @@ tf32: true # require >=ampere
bfloat16: true # require >=ampere bfloat16: true # require >=ampere
float16: true float16: true
# a list of one or more datasets to finetune the model with # A list of one or more datasets to finetune the model with
datasets: datasets:
# hf dataset repo | "json" for local dataset, make sure to fill data_files # HuggingFace dataset repo | "json" for local dataset, make sure to fill data_files
- path: vicgalle/alpaca-gpt4 - path: vicgalle/alpaca-gpt4
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection] # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn> type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
@@ -461,17 +464,18 @@ datasets:
name: # Optional[str] name of dataset configuration to load name: # Optional[str] name of dataset configuration to load
conversation: # Optional[str] fastchat conversation type, only used with type: sharegpt conversation: # Optional[str] fastchat conversation type, only used with type: sharegpt
# custom user prompt # Custom user prompt
- path: repo - path: repo
type: type:
# the below are defaults. only set what's needed. # The below are defaults. only set what's needed.
system_prompt: "" system_prompt: ""
system_format: "{system}"
field_system: system field_system: system
field_instruction: instruction field_instruction: instruction
field_output: input field_input: input
field_output: output
# customizable to be single line or multi-line # Customizable to be single line or multi-line
system_format: "{system}"
# 'format' can include {input} # 'format' can include {input}
format: |- format: |-
User: {instruction} {input} User: {instruction} {input}
@@ -479,13 +483,13 @@ datasets:
# 'no_input_format' cannot include {input} # 'no_input_format' cannot include {input}
no_input_format: "{instruction} " no_input_format: "{instruction} "
# for completions datsets, uses the provided field if not `text` # For `completion` datsets only, uses the provided field instead of `text` column
field: field:
# axolotl attempts to save the dataset as an arrow after packing the data together so # Axolotl attempts to save the dataset as an arrow after packing the data together so
# subsequent training attempts load faster, relative path # subsequent training attempts load faster, relative path
dataset_prepared_path: data/last_run_prepared dataset_prepared_path: data/last_run_prepared
# push prepared dataset to hub # Push prepared dataset to hub
push_dataset_to_hub: # repo path push_dataset_to_hub: # repo path
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()` # The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
# if not set. # if not set.
@@ -495,8 +499,8 @@ hub_model_id: # repo path to push finetuned model
# how to push checkpoints to hub # how to push checkpoints to hub
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy # https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
hub_strategy: hub_strategy:
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets # Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
# required to be true when used in combination with `push_dataset_to_hub` # Required to be true when used in combination with `push_dataset_to_hub`
hf_use_auth_token: # boolean hf_use_auth_token: # boolean
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval. # How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
val_set_size: 0.04 val_set_size: 0.04
@@ -505,30 +509,34 @@ dataset_shard_num:
# Index of shard to use for whole dataset # Index of shard to use for whole dataset
dataset_shard_idx: dataset_shard_idx:
# the maximum length of an input to train with, this should typically be less than 2048 # The maximum length of an input to train with, this should typically be less than 2048
# as most models have a token/context limit of 2048 # as most models have a token/context limit of 2048
sequence_len: 2048 sequence_len: 2048
# pad inputs so each step uses constant sized buffers # Pad inputs so each step uses constant sized buffers
# this will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently # This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
pad_to_sequence_len: pad_to_sequence_len:
# max sequence length to concatenate training samples together up to # Max sequence length to concatenate training samples together up to
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning # Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# FutureWarning: This will soon be DEPRECATED # FutureWarning: This will soon be DEPRECATED
max_packed_sequence_len: 1024 max_packed_sequence_len: 1024
# use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true' # Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
sample_packing: sample_packing:
# set to 'false' if getting errors during eval with sample_packing on. # Set to 'false' if getting errors during eval with sample_packing on.
eval_sample_packing: eval_sample_packing:
# you can set these packing optimizations AFTER starting a training at least once. # You can set these packing optimizations AFTER starting a training at least once.
# The trainer will provide recommended values for these values. # The trainer will provide recommended values for these values.
sample_packing_eff_est: sample_packing_eff_est:
total_num_tokens: total_num_tokens:
# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model # If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
adapter: lora adapter: lora
# if you already have a lora model trained that you want to load, put that here # If you already have a lora model trained that you want to load, put that here.
# lora hyperparameters # This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
lora_model_dir: lora_model_dir:
# LoRA hyperparameters
# For more details about the following options, see:
# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
lora_r: 8 lora_r: 8
lora_alpha: 16 lora_alpha: 16
lora_dropout: 0.05 lora_dropout: 0.05
@@ -540,36 +548,48 @@ lora_target_modules:
# - gate_proj # - gate_proj
# - down_proj # - down_proj
# - up_proj # - up_proj
lora_target_linear: # if true, will target all linear layers lora_target_linear: # If true, will target all linear layers
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
lora_modules_to_save: lora_modules_to_save:
# - embed_tokens # - embed_tokens
# - lm_head # - lm_head
# Once you complete training, the model will be saved to the following directory.
# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
# Make sure `lora_model_dir` points to this directory if you want to use the trained model.
lora_out_dir: lora_out_dir:
lora_fan_in_fan_out: false lora_fan_in_fan_out: false
# ReLoRA configuration # ReLoRA configuration
# must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed # Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
relora_steps: # number of steps per ReLoRA restart relora_steps: # Number of steps per ReLoRA restart
relora_warmup_steps: # number of per-restart warmup steps relora_warmup_steps: # Number of per-restart warmup steps
relora_cpu_offload: # true to perform lora weight merges on cpu during restarts, for modest gpu memory savings relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
# wandb configuration if you're using it # wandb configuration if you're using it
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
wandb_project: # your wandb project name wandb_project: # Your wandb project name
wandb_entity: # a wandb Team name if using a Team wandb_entity: # A wandb Team name if using a Team
wandb_watch: wandb_watch:
wandb_run_id: # set the name of your wandb run wandb_run_id: # Set the name of your wandb run
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
# where to save the finished model to # Where to save the full-finetuned model to
output_dir: ./completed-model output_dir: ./completed-model
# whether to use torch.compile and which backend to use # Whether to use torch.compile and which backend to use
torch_compile: # bool torch_compile: # bool
torch_compile_backend: # Optional[str] torch_compile_backend: # Optional[str]
# training hyperparameters # Training hyperparameters
# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
gradient_accumulation_steps: 1 gradient_accumulation_steps: 1
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
micro_batch_size: 2 micro_batch_size: 2
eval_batch_size: eval_batch_size:
num_epochs: 3 num_epochs: 3
@@ -577,44 +597,47 @@ warmup_steps: 100
learning_rate: 0.00003 learning_rate: 0.00003
lr_quadratic_warmup: lr_quadratic_warmup:
logging_steps: logging_steps:
save_strategy: # set to `no` to skip checkpoint saves save_strategy: # Set to `no` to skip checkpoint saves
save_steps: # leave empty to save at each epoch save_steps: # Leave empty to save at each epoch
eval_steps: # leave empty to eval at each epoch eval_steps: # Leave empty to eval at each epoch
save_total_limit: # checkpoints saved at a time save_total_limit: # Checkpoints saved at a time
# Maximum number of iterations to train for. It precedes num_epochs which means that
# if both are set, num_epochs will not be guaranteed.
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
max_steps: max_steps:
eval_table_size: # approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0 eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
eval_table_max_new_tokens: # total number of tokens generated for predictions sent to wandb. Default is 128 eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
# save model as safetensors (require safetensors package) # Save model as safetensors (require safetensors package)
save_safetensors: save_safetensors:
# whether to mask out or include the human's prompt from the training labels # Whether to mask out or include the human's prompt from the training labels
train_on_inputs: false train_on_inputs: false
# group similarly sized data to minimize padding # Group similarly sized data to minimize padding.
# may be slower to start, as it must download and sort the entire dataset # May be slower to start, as it must download and sort the entire dataset.
# note that training loss may have an oscillating pattern with this enabled # Note that training loss may have an oscillating pattern with this enabled.
group_by_length: false group_by_length: false
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing # Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
gradient_checkpointing: false gradient_checkpointing: false
# stop training after this many evaluation losses have increased in a row # Stop training after this many evaluation losses have increased in a row
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback # https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
early_stopping_patience: 3 early_stopping_patience: 3
# specify a scheduler and kwargs to use with the optimizer # Specify a scheduler and kwargs to use with the optimizer
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
lr_scheduler_kwargs: lr_scheduler_kwargs:
# for one_cycle optim # For one_cycle optim
lr_div_factor: # learning rate div factor lr_div_factor: # Learning rate div factor
# for log_sweep optim # For log_sweep optim
log_sweep_min_lr: log_sweep_min_lr:
log_sweep_max_lr: log_sweep_max_lr:
# specify optimizer # Specify optimizer
# Valid values are driven by the Transformers OptimizerNames class, see: # Valid values are driven by the Transformers OptimizerNames class, see:
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134 # https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
# #
@@ -640,7 +663,7 @@ log_sweep_max_lr:
# - paged_lion_32bit # - paged_lion_32bit
# - paged_lion_8bit # - paged_lion_8bit
optimizer: optimizer:
# specify weight decay # Specify weight decay
weight_decay: weight_decay:
# adamw hyperparams # adamw hyperparams
adam_beta1: adam_beta1:
@@ -649,49 +672,56 @@ adam_epsilon:
# Gradient clipping max norm # Gradient clipping max norm
max_grad_norm: max_grad_norm:
# whether to bettertransformers # Augmentation techniques
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
# currently only supported on Llama and Mistral
noisy_embedding_alpha:
# Whether to bettertransformers
flash_optimum: flash_optimum:
# whether to use xformers attention patch https://github.com/facebookresearch/xformers: # Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
xformers_attention: xformers_attention:
# whether to use flash attention patch https://github.com/Dao-AILab/flash-attention: # Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
flash_attention: flash_attention:
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
# whether to use scaled-dot-product attention # Whether to use scaled-dot-product attention
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html # https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
sdp_attention: sdp_attention:
# Landmark attention (only llama) # Landmark attention (only llama)
landmark_attention: landmark_attention:
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py # xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
# llama only # LLaMA only
xpos_rope: xpos_rope:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653 # RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling: rope_scaling:
type: # linear | dynamic type: # linear | dynamic
factor: # float factor: # float
# resume from a specific checkpoint dir # Resume from a specific checkpoint dir
resume_from_checkpoint: resume_from_checkpoint:
# if resume_from_checkpoint isn't set and you simply want it to start where it left off # If resume_from_checkpoint isn't set and you simply want it to start where it left off.
# be careful with this being turned on between different models # Be careful with this being turned on between different models.
auto_resume_from_checkpoints: false auto_resume_from_checkpoints: false
# don't mess with this, it's here for accelerate and torchrun # Don't mess with this, it's here for accelerate and torchrun
local_rank: local_rank:
# add or change special tokens # Add or change special tokens.
# If you add tokens here, you don't need to add them to the `tokens` list.
special_tokens: special_tokens:
# bos_token: "<s>" # bos_token: "<s>"
# eos_token: "</s>" # eos_token: "</s>"
# unk_token: "<unk>" # unk_token: "<unk>"
# add extra tokens
# Add extra tokens.
tokens: tokens:
# FSDP # FSDP
fsdp: fsdp:
fsdp_config: fsdp_config:
# Deepspeed config path # Deepspeed config path. e.g., deepspeed/zero3.json
deepspeed: deepspeed:
# Advanced DDP Arguments # Advanced DDP Arguments
@@ -717,6 +747,66 @@ strict:
</details> </details>
<details>
<summary> Understanding of batch size and gradient accumulation steps </summary>
<br/>
Gradient accumulation means accumulating gradients over several mini-batches and updating the model weights afterward. When the samples in each batch are diverse, this technique doesn't significantly impact learning.
This method allows for effective training with larger effective batch sizes without needing proportionally larger memory. Here's why:
1. **Memory Consumption with Batch Size**: The primary reason increasing the batch size impacts memory is due to the storage requirements for intermediate activations. When you forward propagate a batch through a network, you have to store the activations at each layer for each sample in the batch, because these activations are used during backpropagation to compute gradients. Therefore, larger batches mean more activations, leading to greater GPU memory consumption.
2. **Gradient Accumulation**: With gradient accumulation, you're effectively simulating a larger batch size by accumulating gradients over several smaller batches (or micro-batches). However, at any given time, you're only forward and backward propagating a micro-batch. This means you only store activations for the micro-batch, not the full accumulated batch. As a result, you can simulate the effect of a larger batch size without the memory cost of storing activations for a large batch.
**Example 1:**
Micro batch size: 3
Gradient accumulation steps: 2
Number of GPUs: 3
Total batch size = 3 * 2 * 3 = 18
```
| GPU 1 | GPU 2 | GPU 3 |
|----------------|----------------|----------------|
| S1, S2, S3 | S4, S5, S6 | S7, S8, S9 |
| e1, e2, e3 | e4, e5, e6 | e7, e8, e9 |
|----------------|----------------|----------------|
| → (accumulate) | → (accumulate) | → (accumulate) |
|----------------|----------------|----------------|
| S10, S11, S12 | S13, S14, S15 | S16, S17, S18 |
| e10, e11, e12 | e13, e14, e15 | e16, e17, e18 |
|----------------|----------------|----------------|
| → (apply) | → (apply) | → (apply) |
Accumulated gradient for the weight w1 after the second iteration (considering all GPUs):
Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6 + e7 + e8 + e9 + e10 + e11 + e12 + e13 + e14 + e15 + e16 + e17 + e18
Weight update for w1:
w1_new = w1_old - learning rate x (Total gradient for w1 / 18)
```
**Example 2:**
Micro batch size: 2
Gradient accumulation steps: 1
Number of GPUs: 3
Total batch size = 2 * 1 * 3 = 6
```
| GPU 1 | GPU 2 | GPU 3 |
|-----------|-----------|-----------|
| S1, S2 | S3, S4 | S5, S6 |
| e1, e2 | e3, e4 | e5, e6 |
|-----------|-----------|-----------|
| → (apply) | → (apply) | → (apply) |
Accumulated gradient for the weight w1 (considering all GPUs):
Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6
Weight update for w1:
w1_new = w1_old - learning rate × (Total gradient for w1 / 6)
```
</details>
### Train ### Train
Run Run
@@ -792,6 +882,10 @@ Pass the appropriate flag to the train command:
--base_model="./completed-model" --prompter=None --load_in_8bit=True --base_model="./completed-model" --prompter=None --load_in_8bit=True
``` ```
Please use `--sample_packing False` if you have it on and receive the error similar to below:
> RuntimeError: stack expects each tensor to be equal size, but got [1, 32, 1, 128] at entry 0 and [1, 32, 8, 128] at entry 1
### Merge LORA to base ### Merge LORA to base
Add below flag to train command above Add below flag to train command above

View File

@@ -5,6 +5,9 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
ARG AXOLOTL_EXTRAS="" ARG AXOLOTL_EXTRAS=""
ARG CUDA="118" ARG CUDA="118"
ENV BNB_CUDA_VERSION=$CUDA ENV BNB_CUDA_VERSION=$CUDA
ARG PYTORCH_VERSION="2.0.1"
ENV PYTORCH_VERSION=$PYTORCH_VERSION
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y vim curl apt-get install -y vim curl
@@ -16,6 +19,7 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
WORKDIR /workspace/axolotl WORKDIR /workspace/axolotl
# If AXOLOTL_EXTRAS is set, append it in brackets # If AXOLOTL_EXTRAS is set, append it in brackets
RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \ RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \ pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
else \ else \

View File

@@ -14,7 +14,7 @@ ARG CUDA="118"
ENV PYTHON_VERSION=$PYTHON_VERSION ENV PYTHON_VERSION=$PYTHON_VERSION
RUN apt-get update \ RUN apt-get update \
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* && apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
&& wget \ && wget \
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \ https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
&& mkdir /root/.conda \ && mkdir /root/.conda \
@@ -57,11 +57,6 @@ FROM base-builder
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX" ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
# recompile apex
RUN python3 -m pip uninstall -y apex
RUN git clone https://github.com/NVIDIA/apex
RUN cd apex && python3 -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
RUN mkdir -p /workspace/builds RUN mkdir -p /workspace/builds
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes

51
docs/multipack.md Normal file
View File

@@ -0,0 +1,51 @@
# Multipack
4k context, bsz =4,
each character represents 256 tokens
X represents a padding token
```
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
[[ A A A A A A A A A A A ]
B B B B B B ]
C C C C C C C ]
D D D D ]]
[[ E E E E E E E E ]
[ F F F F ]
[ G G G ]
[ H H H H ]]
[[ I I I ]
[ J J J ]
[ K K K K K]
[ L L L ]]
```
after padding to longest input in each step
```
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
[[ A A A A A A A A A A A ]
B B B B B B X X X X X X ]
C C C C C C C X X X X ]
D D D D X X X X X X X ]]
[[ E E E E E E E E ]
[ F F F F X X X X ]
[ G G G X X X X X ]
[ H H H H X X X X ]]
[[ I I I X X ]
[ J J J X X ]
[ K K K K K ]
[ L L L X X ]]
```
w packing ( note it's the same effective number of tokens per step, but a true bsz of 1)
```
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
[[ A A A A A A A A A A A B B B B B
B C C C C C C C D D D D E E E E
E E E E F F F F F G G G H H H H
I I I J J J J K K K K K L L L X ]]
```

View File

@@ -16,8 +16,8 @@ val_set_size: 0.01
output_dir: ./out output_dir: ./out
sequence_len: 8192 sequence_len: 8192
sample_packing: sample_packing: true
pad_to_sequence_len: pad_to_sequence_len: true
wandb_project: wandb_project:
wandb_entity: wandb_entity:
@@ -30,7 +30,7 @@ micro_batch_size: 2
num_epochs: 3 num_epochs: 3
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
lr_scheduler: cosine lr_scheduler: cosine
learning_rate: 0.0002 learning_rate: 0.000005
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: false

View File

@@ -19,8 +19,8 @@ adapter: qlora
lora_model_dir: lora_model_dir:
sequence_len: 8192 sequence_len: 8192
sample_packing: True sample_packing: true
pad_to_sequence_len: True pad_to_sequence_len: true
lora_r: 32 lora_r: 32
lora_alpha: 16 lora_alpha: 16
@@ -43,7 +43,7 @@ wandb_run_id:
wandb_log_model: wandb_log_model:
gradient_accumulation_steps: 4 gradient_accumulation_steps: 4
micro_batch_size: 4 micro_batch_size: 2
num_epochs: 1 num_epochs: 1
optimizer: adamw_bnb_8bit optimizer: adamw_bnb_8bit
lr_scheduler: cosine lr_scheduler: cosine

View File

@@ -28,8 +28,8 @@ num_epochs: 3
learning_rate: 0.00001 learning_rate: 0.00001
train_on_inputs: false train_on_inputs: false
group_by_length: false group_by_length: false
bf16: True bf16: true
tf32: True tf32: true
early_stopping_patience: early_stopping_patience:
resume_from_checkpoint: resume_from_checkpoint:
local_rank: local_rank:

View File

@@ -16,7 +16,7 @@ flash-attn>=2.3.0
sentencepiece sentencepiece
wandb wandb
einops einops
xformers xformers>=0.0.22
optimum optimum
hf_transfer hf_transfer
colorama colorama

View File

@@ -21,6 +21,14 @@ def parse_requirements():
): ):
# Handle standard packages # Handle standard packages
_install_requires.append(line) _install_requires.append(line)
# TODO(wing) remove once xformers release supports torch 2.1.0
if "torch==2.1.0" in _install_requires:
_install_requires.pop(_install_requires.index("xformers>=0.0.22"))
_install_requires.append(
"xformers @ git+https://github.com/facebookresearch/xformers.git@main"
)
return _install_requires, _dependency_links return _install_requires, _dependency_links

View File

@@ -194,6 +194,7 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
# load the config from the yaml file # load the config from the yaml file
with open(config, encoding="utf-8") as file: with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file)) cfg: DictDefault = DictDefault(yaml.safe_load(file))
cfg.axolotl_config_path = config
# if there are any options passed in the cli, if it is something that seems valid from the yaml, # if there are any options passed in the cli, if it is something that seems valid from the yaml,
# then overwrite the value # then overwrite the value
cfg_keys = cfg.keys() cfg_keys = cfg.keys()

View File

@@ -14,6 +14,7 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
print_axolotl_text_art() print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs) parsed_cfg = load_cfg(config, **kwargs)
parsed_cfg.sample_packing = False
parser = transformers.HfArgumentParser((TrainerCliArgs)) parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses( parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True return_remaining_strings=True

View File

@@ -1,10 +1,12 @@
""" """
CLI to run training on a model CLI to run training on a model
""" """
import logging
from pathlib import Path from pathlib import Path
import fire import fire
import transformers import transformers
from colorama import Fore
from axolotl.cli import ( from axolotl.cli import (
check_accelerate_default_config, check_accelerate_default_config,
@@ -14,8 +16,11 @@ from axolotl.cli import (
print_axolotl_text_art, print_axolotl_text_art,
) )
from axolotl.common.cli import TrainerCliArgs from axolotl.common.cli import TrainerCliArgs
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.train import train from axolotl.train import train
LOG = logging.getLogger("axolotl.cli.train")
def do_cli(config: Path = Path("examples/"), **kwargs): def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
@@ -27,6 +32,14 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
parsed_cli_args, _ = parser.parse_args_into_dataclasses( parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True return_remaining_strings=True
) )
if parsed_cli_args.prepare_ds_only and not parsed_cfg.dataset_prepared_path:
msg = (
Fore.RED
+ "--prepare_ds_only called without dataset_prepared_path set."
+ Fore.RESET
)
LOG.warning(msg)
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args) dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
if parsed_cli_args.prepare_ds_only: if parsed_cli_args.prepare_ds_only:

View File

@@ -0,0 +1,5 @@
"""
Various shared constants
"""
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"

View File

@@ -5,7 +5,7 @@ import os
from typing import List from typing import List
import torch import torch
from datasets import Dataset, IterableDataset, Sequence, Value from datasets import Dataset, IterableDataset
from .prompt_tokenizers import PromptTokenizingStrategy from .prompt_tokenizers import PromptTokenizingStrategy
@@ -42,15 +42,11 @@ class TokenizedPromptDataset(Dataset):
if self.prompt_tokenizer.supports_batched: if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True map_kwargs["batched"] = True
map_kwargs["batch_size"] = 100 map_kwargs["batch_size"] = 100
return ( return dataset.map(
dataset.map( self.prompt_tokenizer.tokenize_prompt,
self.prompt_tokenizer.tokenize_prompt, num_proc=num_proc,
num_proc=num_proc, remove_columns=features,
remove_columns=features, **map_kwargs,
**map_kwargs,
)
.cast_column("input_ids", Sequence(feature=Value(dtype="int32", id=None)))
.cast_column("labels", Sequence(feature=Value(dtype="int32", id=None)))
) )

View File

@@ -42,11 +42,21 @@ def replace_llama_attn_with_flash_attn(
packed: Optional[bool] = False, packed: Optional[bool] = False,
cross_entropy: Optional[bool] = False, cross_entropy: Optional[bool] = False,
rms_norm: Optional[bool] = False, rms_norm: Optional[bool] = False,
noisy_embeddings_alpha: Optional[int] = False,
): ):
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
_prepare_decoder_attention_mask _prepare_decoder_attention_mask
) )
transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward transformers.models.llama.modeling_llama.LlamaAttention.forward = flashattn_forward
if noisy_embeddings_alpha:
transformers.models.llama.modeling_llama.LlamaModel.get_inputs_embeds = partial(
llama_model_get_inputs_embeds, noisy_embeddings_alpha=noisy_embeddings_alpha
)
else:
transformers.models.llama.modeling_llama.LlamaModel.get_inputs_embeds = (
llama_model_get_inputs_embeds
)
if packed: if packed:
transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer transformers.models.llama.modeling_llama.LlamaDecoderLayer = LlamaDecoderLayer
transformers.models.llama.modeling_llama.LlamaModel.forward = ( transformers.models.llama.modeling_llama.LlamaModel.forward = (
@@ -411,6 +421,28 @@ def generate_qkv(
) )
def llama_model_get_inputs_embeds(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
noisy_embeddings_alpha: Optional[int] = None,
):
inputs_embeds = self.embed_tokens(input_ids)
if noisy_embeddings_alpha:
input_mask = attention_mask.to(inputs_embeds) # B x L
input_lengths = torch.sum(input_mask, 1) # B
noise_ = torch.zeros_like(inputs_embeds).uniform_(-1, 1)
delta = noise_ * input_mask.unsqueeze(2)
dims = input_lengths * inputs_embeds.size(-1)
mag = noisy_embeddings_alpha / torch.sqrt(dims)
delta = (delta * mag.view(-1, 1, 1)).detach()
inputs_embeds += delta
return inputs_embeds
def llama_model_forward( def llama_model_forward(
self, self,
input_ids: torch.LongTensor = None, input_ids: torch.LongTensor = None,
@@ -477,7 +509,8 @@ def llama_model_forward(
cu_seqlens = cu_seqlens.squeeze() cu_seqlens = cu_seqlens.squeeze()
if inputs_embeds is None: if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids) inputs_embeds = self.get_inputs_embeds(input_ids, attention_mask)
# embed positions # embed positions
if attention_mask is None: if attention_mask is None:
attention_mask = torch.ones( attention_mask = torch.ones(

View File

@@ -0,0 +1,40 @@
"""
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
"""
import torch
import transformers.models.llama.modeling_llama
from transformers.utils import logging
logger = logging.get_logger(__name__)
def replace_llama_embeddings_with_uniform_distribution(noise_alpha=5):
# pylint: disable=duplicate-code
def noised_embed(orig_embed, noise_alpha, model):
def new_func(input_ids):
# during training, we add noise to the embedding
# during generation, we don't add noise to the embedding
if model.training:
embed_init = orig_embed(input_ids)
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
mag_norm = noise_alpha / torch.sqrt(dims)
return embed_init + torch.zeros_like(embed_init).uniform_(
-mag_norm, mag_norm
)
return orig_embed(input_ids)
return new_func
def post_init(orig_post_init):
def new_func(self):
orig_post_init(self)
self.embed_tokens.forward = noised_embed(
self.embed_tokens.forward, noise_alpha, self
)
return new_func
transformers.models.llama.modeling_llama.LlamaModel.post_init = post_init(
transformers.models.llama.modeling_llama.LlamaModel.post_init
)

View File

@@ -0,0 +1,40 @@
"""
patch to add noisy embeddings per https://arxiv.org/abs/2310.05914
"""
import torch
import transformers.models.mistral.modeling_mistral
from transformers.utils import logging
logger = logging.get_logger(__name__)
def replace_mistral_embeddings_with_uniform_distribution(noise_alpha=5):
# pylint: disable=duplicate-code
def noised_embed(orig_embed, noise_alpha, model):
def new_func(input_ids):
# during training, we add noise to the embedding
# during generation, we don't add noise to the embedding
if model.training:
embed_init = orig_embed(input_ids)
dims = torch.tensor(embed_init.size(1) * embed_init.size(2))
mag_norm = noise_alpha / torch.sqrt(dims)
return embed_init + torch.zeros_like(embed_init).uniform_(
-mag_norm, mag_norm
)
return orig_embed(input_ids)
return new_func
def post_init(orig_post_init):
def new_func(self):
orig_post_init(self)
self.embed_tokens.forward = noised_embed(
self.embed_tokens.forward, noise_alpha, self
)
return new_func
transformers.models.mistral.modeling_mistral.MistralModel.post_init = post_init(
transformers.models.mistral.modeling_mistral.MistralModel.post_init
)

View File

@@ -24,7 +24,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
) )
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
strat = ShareGPTPromptTokenizingStrategy( return SimpleShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2( ShareGPTPrompterV2(
conversation=conversation, conversation=conversation,
role_key_model=field_model, role_key_model=field_model,
@@ -34,9 +34,6 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
cfg.train_on_inputs, cfg.train_on_inputs,
cfg.sequence_len, cfg.sequence_len,
) )
if ds_cfg and ds_cfg["skip"]:
strat.skip_invalid = True
return strat
def load_role(tokenizer, cfg): def load_role(tokenizer, cfg):
@@ -57,38 +54,13 @@ def load_guanaco(tokenizer, cfg):
) )
def load_nous(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None): class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
return NousShareGPTPromptTokenizingStrategy(
ShareGPTPrompterV2(
conversation=conversation,
role_key_model=field_model,
role_key_human=field_human,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
class NousShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
""" """
basic sharegpt strategy used by nous/ldj for input/output keyed data basic sharegpt strategy to grab conversations from the sample row
""" """
def get_conversation_thread(self): def get_conversation_thread(self, prompt):
return "conversation" return prompt["conversations"]
def map_conversation_thread(self, conversation):
turns = []
for turn in conversation:
turns.append({"from": "human", "value": turn["input"]})
turns.append({"from": "gpt", "value": turn["output"]})
return turns
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
@@ -96,11 +68,10 @@ class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrateg
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
""" """
def map_conversation_thread(self, conversation): def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ... # remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
turns = [ turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
{"from": turn["role"], "value": turn["value"]} for turn in conversation
]
return turns return turns
@@ -109,11 +80,11 @@ class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
sharegpt strategy that remaps oasst data to sharegpt format sharegpt strategy that remaps oasst data to sharegpt format
""" """
def map_conversation_thread(self, conversation): def get_conversation_thread(self, prompt):
conversations = prompt["conversations"]
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ... # remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
role_map = {"prompter": "human", "assistant": "gpt"} role_map = {"prompter": "human", "assistant": "gpt"}
turns = [ turns = [
{"from": role_map[turn["role"]], "value": turn["text"]} {"from": role_map[t["role"]], "value": t["text"]} for t in conversations
for turn in conversation
] ]
return turns return turns

View File

@@ -2,9 +2,7 @@
import abc import abc
import copy import copy
import functools
import logging import logging
from collections import defaultdict
from typing import Dict, List, Tuple, Union from typing import Dict, List, Tuple, Union
from fastchat.conversation import Conversation from fastchat.conversation import Conversation
@@ -58,26 +56,6 @@ class PromptTokenizingStrategy(abc.ABC):
def supports_batched(self): def supports_batched(self):
return False return False
@functools.lru_cache(maxsize=128)
def _get_user_token(self):
try:
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
if isinstance(id_or_ids, (int,)):
return id_or_ids
except KeyError:
pass
return False
@functools.lru_cache(maxsize=128)
def _get_assistant_token(self):
try:
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
if isinstance(id_or_ids, (int,)):
return id_or_ids
except KeyError:
pass
return False
def _tokenize( def _tokenize(
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
) -> BatchEncoding: ) -> BatchEncoding:
@@ -352,109 +330,99 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
Tokenizing strategy for ShareGPT prompts. Tokenizing strategy for ShareGPT prompts.
""" """
_skip_invalid = False def get_conversation_thread(self, prompt):
return prompt["conversations"]
@property
def supports_batched(self):
return True
@property
def skip_invalid(self):
return self._skip_invalid
@skip_invalid.setter
def skip_invalid(self, value):
self._skip_invalid = value
def get_conversation_thread(self):
return "conversations"
def map_conversation_thread(self, conversation):
return conversation
def tokenize_prompt(self, prompt): def tokenize_prompt(self, prompt):
tokenized_res = defaultdict(lambda: []) result, current_len = tokenize_prompt_default()
conv_field = self.get_conversation_thread() conversation: Conversation = (
for prmpt in prompt[conv_field]: self.prompter._conversation.copy() # pylint: disable=protected-access
result, current_len = tokenize_prompt_default() )
user_token = self._get_user_token()
assistant_token = self._get_assistant_token()
conversation: Conversation = (
self.prompter._conversation # pylint: disable=protected-access
)
try:
for _, part in enumerate(
self.prompter.build_prompt(self.map_conversation_thread(prmpt))
):
if isinstance(part, tuple):
if conversation.roles[0] in part[0]:
turn = part[0] + part[1] if not user_token else part[1]
# this is still the user query, we should
if not part[1].strip():
err_msg = f"user turn has empty text: {prmpt}"
if self.skip_invalid:
raise ValueError(err_msg)
LOG.warning(err_msg)
res = self._tokenize(
turn,
add_eos_token=False,
strip_bos_token=True,
)
if user_token:
res["input_ids"] = [user_token, *res["input_ids"]]
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif conversation.roles[1] in part[0]:
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
turn = part[0] + part[1] if not assistant_token else part[1]
# this should be the assistant response, should end with an eos token
if not part[1].strip():
err_msg = f"assistant turn has empty text: {prmpt}"
if self.skip_invalid:
raise ValueError(err_msg)
LOG.warning(err_msg)
res = self._tokenize(
turn,
add_eos_token=True,
strip_bos_token=True,
)
if assistant_token:
res["input_ids"] = [
assistant_token,
*res["input_ids"],
]
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
elif part[0] == "":
turn = part[1]
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
turn, add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
err_msg = f"unhandled role: {part[0]}"
if self.skip_invalid:
raise ValueError(err_msg)
LOG.warning(err_msg)
continue
# pylint: disable=duplicate-code # support for custom roles from the dataset, only useful for vicuna style prompts/roles
result, current_len = parse_tokenized_to_result( role_remap = []
result, if (
current_len, conversation.name == "vicuna_v1.1"
res, and "roles" in prompt
labels, and len(prompt["roles"]) >= 2
pad_token_id=self.tokenizer.pad_token_id, ):
) role_remap = [
for key, val in sorted(result.items(), key=lambda x: x[0]): {"from": conversation.roles[0], "to": prompt["roles"][0]},
tokenized_res[key].append(val) {"from": conversation.roles[1], "to": prompt["roles"][1]},
except (KeyError, AssertionError, IndexError) as err: ]
raise InvalidDataException(str(err)) from err
except ValueError as err: try:
LOG.warning("skipping prompt: %s", str(err)) for _, part in enumerate(
return tokenized_res self.prompter.build_prompt(self.get_conversation_thread(prompt))
):
if isinstance(part, tuple):
if conversation.roles[0] in part[0]:
role = (
part[0].replace(role_remap[0]["from"], role_remap[0]["to"])
if role_remap
else part[0]
)
turn = role + part[1]
# this is still the user query, we should
if not part[1].strip():
LOG.warning(f"user turn has empty text: {prompt}")
res = self._tokenize(
turn,
add_eos_token=False,
strip_bos_token=True,
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
elif conversation.roles[1] in part[0]:
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
role = (
part[0].replace(role_remap[1]["from"], role_remap[1]["to"])
if role_remap
else part[0]
)
turn = role + part[1]
# this should be the assistant response, should end with an eos token
if not part[1].strip():
LOG.warning(f"assistant turn has empty text: {prompt}")
res = self._tokenize(
turn,
add_eos_token=True,
strip_bos_token=True,
)
role_res = self._tokenize(
role.rstrip(),
add_eos_token=False,
strip_bos_token=True,
)
# not masked out from labels
labels = copy.deepcopy(res["input_ids"])
len_role = len(role_res["input_ids"])
labels[:len_role] = [IGNORE_TOKEN_ID] * min(
len_role, len(labels)
)
elif part[0] == "":
turn = part[1]
# this is only ever the first part, should include the bos token and the user query
res = self._tokenize(
turn, add_eos_token=False, strip_bos_token=False
)
# everything from this is masked out from the labels
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
else:
LOG.warning(f"unhandled role: {part[0]}")
continue
# pylint: disable=duplicate-code
result, current_len = parse_tokenized_to_result(
result,
current_len,
res,
labels,
pad_token_id=self.tokenizer.pad_token_id,
)
return result
except (KeyError, AssertionError, IndexError) as err:
raise InvalidDataException(str(err)) from err
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False): def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
if not prompt.strip(): if not prompt.strip():

View File

@@ -274,9 +274,11 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
raise err raise err
conv.messages = [] conv.messages = []
for j, sentence in enumerate(source): for _, sentence in enumerate(source):
role = roles[sentence["from"]] role = roles[sentence["from"]]
if role != conv.roles[j % 2]: if len(conv.messages) > 0 and (
(role == conv.messages[-1][0]) or (role not in conv.roles)
):
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}") LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
conv.append_message(role, sentence["value"]) conv.append_message(role, sentence["value"])

View File

@@ -58,9 +58,7 @@ def train(
safe_serialization = cfg.save_safetensors is True safe_serialization = cfg.save_safetensors is True
if ( if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints
) or cfg.resume_from_checkpoint is True:
possible_checkpoints = [ possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*") str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
] ]
@@ -73,9 +71,7 @@ def train(
LOG.info( LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}" f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
) )
resume_from_checkpoint = ( resume_from_checkpoint = cfg.resume_from_checkpoint
cfg.resume_from_checkpoint if cfg.resume_from_checkpoint is not True else None
)
trainer = setup_trainer( trainer = setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps

View File

@@ -514,3 +514,27 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
return control return control
return LogPredictionCallback return LogPredictionCallback
class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
"""Callback to save axolotl config to wandb"""
def __init__(self, axolotl_config_path):
self.axolotl_config_path = axolotl_config_path
def on_train_begin(
self,
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
state: TrainerState, # pylint: disable=unused-argument
control: TrainerControl,
**kwargs, # pylint: disable=unused-argument
):
if is_main_process():
try:
artifact = wandb.Artifact(name="axolotl-config", type="config")
artifact.add_file(local_path=self.axolotl_config_path)
wandb.run.log_artifact(artifact)
LOG.info("Axolotl config has been saved to WandB as an artifact.")
except (FileNotFoundError, ConnectionError) as err:
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
return control

View File

@@ -16,6 +16,7 @@ from datasets import (
from huggingface_hub import hf_hub_download from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
from axolotl.prompt_strategies import load from axolotl.prompt_strategies import load
from axolotl.prompt_tokenizers import ( from axolotl.prompt_tokenizers import (
@@ -44,7 +45,6 @@ from axolotl.utils.trainer import (
) )
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
def md5(to_hash: str, encoding: str = "utf-8") -> str: def md5(to_hash: str, encoding: str = "utf-8") -> str:
@@ -158,23 +158,21 @@ def load_tokenized_prepared_datasets(
token=use_auth_token, token=use_auth_token,
) )
ds_from_hub = True ds_from_hub = True
except (FileNotFoundError, ValueError): except FileNotFoundError:
pass pass
# prefer local dataset, even if hub exists # prefer local dataset, even if hub exists
local_path = Path(d.path) local_path = Path(d.path)
if local_path.exists(): if local_path.exists():
if local_path.is_dir(): if local_path.is_dir():
if not d.type: # TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
ds = load_from_disk(d.path) ds = load_dataset(
else: d.path,
ds = load_dataset( name=d.name,
d.path, data_files=d.data_files,
name=d.name, streaming=False,
data_files=d.data_files, split=None,
streaming=False, )
split=None,
)
elif local_path.is_file(): elif local_path.is_file():
ds_type = "json" ds_type = "json"
if d.ds_type: if d.ds_type:
@@ -359,7 +357,7 @@ def load_tokenized_prepared_datasets(
if len(datasets) > 1: if len(datasets) > 1:
LOG.info("shuffle merged datasets") LOG.info("shuffle merged datasets")
dataset = dataset.shuffle(seed=seed) dataset = dataset.shuffle(seed=seed)
if cfg.local_rank == 0 and cfg.dataset_prepared_path: if cfg.local_rank == 0:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(prepared_ds_path) dataset.save_to_disk(prepared_ds_path)
if cfg.push_dataset_to_hub: if cfg.push_dataset_to_hub:

View File

@@ -136,7 +136,11 @@ def load_model(
replace_stablelm_attn_with_flash_attn(cfg.base_model) replace_stablelm_attn_with_flash_attn(cfg.base_model)
if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing: if (
cfg.is_llama_derived_model
and cfg.flash_attention
and (cfg.noisy_embeddings_alpha or cfg.sample_packing)
):
if cfg.device not in ["mps", "cpu"] and not inference: if cfg.device not in ["mps", "cpu"] and not inference:
from axolotl.monkeypatch.llama_attn_hijack_flash import ( from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn, replace_llama_attn_with_flash_attn,
@@ -147,6 +151,7 @@ def load_model(
packed=cfg.sample_packing, packed=cfg.sample_packing,
cross_entropy=cfg.flash_attn_cross_entropy, cross_entropy=cfg.flash_attn_cross_entropy,
rms_norm=cfg.flash_attn_rms_norm, rms_norm=cfg.flash_attn_rms_norm,
noisy_embeddings_alpha=cfg.noisy_embeddings_alpha,
) )
elif cfg.is_llama_derived_model and cfg.xformers_attention: elif cfg.is_llama_derived_model and cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import ( from axolotl.monkeypatch.llama_attn_hijack_xformers import (
@@ -180,6 +185,26 @@ def load_model(
LOG.info("patching with flash attention") LOG.info("patching with flash attention")
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing) replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
# if cfg.is_llama_derived_model and cfg.noisy_embedding_alpha:
# from axolotl.monkeypatch.llama_embeddings_hijack import (
# replace_llama_embeddings_with_uniform_distribution,
# )
#
# LOG.info("patching with noisy embeddings")
# replace_llama_embeddings_with_uniform_distribution(
# noise_alpha=cfg.noisy_embedding_alpha
# )
#
if cfg.is_mistral_derived_model and cfg.noisy_embedding_alpha:
from axolotl.monkeypatch.mistral_embeddings_hijack import (
replace_mistral_embeddings_with_uniform_distribution,
)
LOG.info("patching with noisy embeddings")
replace_mistral_embeddings_with_uniform_distribution(
noise_alpha=cfg.noisy_embedding_alpha
)
if cfg.is_llama_derived_model and cfg.xpos_rope: if cfg.is_llama_derived_model and cfg.xpos_rope:
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import ( from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
replace_llama_rope_with_xpos_rope, replace_llama_rope_with_xpos_rope,
@@ -382,7 +407,7 @@ def load_model(
if model_config.model_type == "btlm": if model_config.model_type == "btlm":
# don't upcast lm_head for btlm # don't upcast lm_head for btlm
continue continue
if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]): if "lm_head" in name or "embed_tokens" in name:
if hasattr(module, "weight"): if hasattr(module, "weight"):
module.to(torch.float32) module.to(torch.float32)

View File

@@ -30,6 +30,7 @@ from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
EvalFirstStepCallback, EvalFirstStepCallback,
GPUStatsCallback, GPUStatsCallback,
SaveAxolotlConfigtoWandBCallback,
SaveBetterTransformerModelCallback, SaveBetterTransformerModelCallback,
bench_eval_callback_factory, bench_eval_callback_factory,
log_prediction_callback_factory, log_prediction_callback_factory,
@@ -775,6 +776,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer) LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
trainer.add_callback(LogPredictionCallback(cfg)) trainer.add_callback(LogPredictionCallback(cfg))
if cfg.use_wandb:
trainer.add_callback(SaveAxolotlConfigtoWandBCallback(cfg.axolotl_config_path))
if cfg.do_bench_eval: if cfg.do_bench_eval:
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer)) trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))

File diff suppressed because one or more lines are too long

View File

@@ -90,6 +90,73 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
strat.tokenize_prompt(conversation) strat.tokenize_prompt(conversation)
assert "assistant turn has empty text" in self._caplog.records[1].message assert "assistant turn has empty text" in self._caplog.records[1].message
def test_sharegpt_warnings_turns(self):
conversation = {
"conversations": [
{"from": "system", "value": "lorem"},
{"from": "gpt", "value": "ipsum"},
{"from": "human", "value": "dolor"},
{"from": "human", "value": "dolor"},
{"from": "gpt", "value": "sit"},
]
}
prompter = ShareGPTPrompterV2()
strat = ShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
with self._caplog.at_level(logging.WARNING):
strat.tokenize_prompt(conversation)
assert (
"Role did not alternate between turns (gpt and human)"
in self._caplog.records[0].message
)
def test_sharegpt_changes_roles(self):
conversation = {
"roles": ["USER", "CHARACTER"],
"conversations": [
{"from": "system", "value": "lorem"},
{"from": "gpt", "value": "ipsum"},
{"from": "human", "value": "dolor"},
{"from": "gpt", "value": "sit"},
],
}
prompter = ShareGPTPrompterV2()
strat = ShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
with self._caplog.at_level(logging.WARNING):
res = strat.tokenize_prompt(conversation)
assert "CHARACTER" in self.tokenizer.decode(res["input_ids"])
def test_sharegpt_assistant_label_ignore(self):
conversation = {
"roles": ["user", "assistant"],
"conversations": [
{"from": "system", "value": "lorem"},
{"from": "gpt", "value": "ipsum"},
{"from": "human", "value": "dolor"},
{"from": "gpt", "value": "sit"},
],
}
prompter = ShareGPTPrompterV2()
strat = ShareGPTPromptTokenizingStrategy(
prompter,
self.tokenizer,
False,
2048,
)
with self._caplog.at_level(logging.WARNING):
res = strat.tokenize_prompt(conversation)
idx = res["input_ids"].index(20255) # assistant token
assert res["labels"][idx] == -100
def test_no_sys_prompt(self): def test_no_sys_prompt(self):
""" """
tests the interface between the user and assistant parts tests the interface between the user and assistant parts