Compare commits
1 Commits
tensor-par
...
llama-drop
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
7771498eae |
7
.github/ISSUE_TEMPLATE/bug-report.yaml
vendored
7
.github/ISSUE_TEMPLATE/bug-report.yaml
vendored
@@ -53,13 +53,6 @@ body:
|
||||
validations:
|
||||
required: true
|
||||
|
||||
- type: textarea
|
||||
id: config
|
||||
attributes:
|
||||
label: Config yaml
|
||||
description: |
|
||||
Please attach the config yaml!
|
||||
|
||||
- type: textarea
|
||||
id: possible-solution
|
||||
attributes:
|
||||
|
||||
5
.github/workflows/base.yml
vendored
5
.github/workflows/base.yml
vendored
@@ -25,11 +25,6 @@ jobs:
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
- cuda: "118"
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
16
.github/workflows/main.yml
vendored
16
.github/workflows/main.yml
vendored
@@ -23,12 +23,6 @@ jobs:
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.1
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
- cuda: 118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.0
|
||||
axolotl_extras:
|
||||
runs-on: [self-hosted, gpu, docker]
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -52,12 +46,9 @@ jobs:
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
file: ./docker/Dockerfile
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: |
|
||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
build-axolotl-runpod:
|
||||
needs: build-axolotl
|
||||
@@ -77,11 +68,6 @@ jobs:
|
||||
pytorch: 2.0.1
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
- cuda: 118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.0
|
||||
axolotl_extras:
|
||||
runs-on: [self-hosted, gpu, docker]
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
8
.github/workflows/tests.yml
vendored
8
.github/workflows/tests.yml
vendored
@@ -6,11 +6,9 @@ on:
|
||||
- "main"
|
||||
paths:
|
||||
- '**.py'
|
||||
- 'requirements.txt'
|
||||
pull_request:
|
||||
paths:
|
||||
- '**.py'
|
||||
- 'requirements.txt'
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
@@ -46,7 +44,7 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install -U -e .
|
||||
pip3 install -e .
|
||||
pip3 install -r requirements-tests.txt
|
||||
|
||||
- name: Run tests
|
||||
@@ -71,8 +69,8 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 uninstall -y transformers accelerate
|
||||
pip3 install -U -e .[flash-attn]
|
||||
pip3 install -e .
|
||||
pip3 install flash-attn
|
||||
pip3 install -r requirements-tests.txt
|
||||
|
||||
- name: Run e2e tests
|
||||
|
||||
411
README.md
411
README.md
@@ -23,15 +23,15 @@ Features:
|
||||
- [Supported Features](#axolotl-supports)
|
||||
- [Quickstart](#quickstart-)
|
||||
- [Installation](#installation)
|
||||
- [Docker](#docker)
|
||||
- [Conda/Pip venv](#condapip-venv)
|
||||
- [LambdaLabs](#lambdalabs)
|
||||
- [Windows](#windows)
|
||||
- [Docker Installation](#environment)
|
||||
- [Conda/Pip venv Installation](#condapip-venv)
|
||||
- [LambdaLabs Installation](#lambdalabs)
|
||||
- [Dataset](#dataset)
|
||||
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
|
||||
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
|
||||
- [Config](#config)
|
||||
- [Train](#train)
|
||||
- [Training w/ Deepspeed](#training-with-deepspeed)
|
||||
- [Inference](#inference)
|
||||
- [Merge LORA to Base](#merge-lora-to-base)
|
||||
- [Common Errors](#common-errors-)
|
||||
@@ -50,7 +50,7 @@ Features:
|
||||
<b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b>
|
||||
</p>
|
||||
<p>
|
||||
Go ahead and Axolotl questions!!
|
||||
Go ahead and axolotl questions!!
|
||||
</p>
|
||||
<img src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
|
||||
<img alt="PyTest Status" src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
|
||||
@@ -102,7 +102,7 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
|
||||
### Environment
|
||||
|
||||
#### Docker
|
||||
- Docker
|
||||
```bash
|
||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||
```
|
||||
@@ -114,42 +114,18 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Docker advanced</summary>
|
||||
|
||||
A more powerful Docker command to run would be this:
|
||||
|
||||
```bash
|
||||
docker run --gpus '"all"' --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=volume,src=axolotl,target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||
```
|
||||
|
||||
It additionally:
|
||||
* Prevents memory issues when running e.g. deepspeed (e.g. you could hit SIGBUS/signal 7 error) through `--ipc` and `--ulimit` args.
|
||||
* Persists the downloaded HF data (models etc.) and your modifications to axolotl code through `--mount`/`-v` args.
|
||||
* The `--name` argument simply makes it easier to refer to the container in vscode (`Dev Containers: Attach to Running Container...`) or in your terminal.
|
||||
|
||||
[More information on nvidia website](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem)
|
||||
|
||||
</details>
|
||||
|
||||
#### Conda/Pip venv
|
||||
- Conda/Pip venv
|
||||
1. Install python >=**3.9**
|
||||
|
||||
2. Install pytorch stable https://pytorch.org/get-started/locally/
|
||||
|
||||
3. Install Axolotl along with python dependencies
|
||||
3. Install axolotl along with python dependencies
|
||||
```bash
|
||||
pip3 install packaging
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
4. (Optional) Login to Huggingface to use gated models/datasets.
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
Get the token at huggingface.co/settings/tokens
|
||||
|
||||
#### LambdaLabs
|
||||
- LambdaLabs
|
||||
<details>
|
||||
|
||||
<summary>Click to Expand</summary>
|
||||
@@ -193,8 +169,7 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
```
|
||||
</details>
|
||||
|
||||
#### Windows
|
||||
Please use WSL or Docker!
|
||||
- Windows: Please use WSL or Docker!
|
||||
|
||||
### Dataset
|
||||
|
||||
@@ -205,7 +180,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
```json
|
||||
{"instruction": "...", "input": "...", "output": "..."}
|
||||
```
|
||||
- `sharegpt`: conversations where `from` is `human`/`gpt`
|
||||
- `sharegpt:chat`: conversations where `from` is `human`/`gpt`
|
||||
```json
|
||||
{"conversations": [{"from": "...", "value": "..."}]}
|
||||
```
|
||||
@@ -270,10 +245,6 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
```json
|
||||
{"article": "...", "question": "...", "answer": "..."}
|
||||
```
|
||||
- `context_qa.load_v2`: in context question answering (alternate)
|
||||
```json
|
||||
{"context": "...", "question": "...", "answer": "..."}
|
||||
```
|
||||
- `context_qa.load_404`: in context question answering from an article, with default response for no answer from context
|
||||
```json
|
||||
{"article": "...", "unanswerable_question": "..."}
|
||||
@@ -298,11 +269,11 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
```json
|
||||
{"prompt": "...", "generation": "..."}
|
||||
```
|
||||
- `sharegpt.load_role`: conversations where `role` is used instead of `from`
|
||||
- `sharegpt_simple.load_role`: conversations where `role` is used instead of `from`
|
||||
```json
|
||||
{"conversations": [{"role": "...", "value": "..."}]}
|
||||
```
|
||||
- `sharegpt.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
|
||||
- `sharegpt_simple.load_guanaco`: conversations where `from` is `prompter`/`assistant` instead of default sharegpt
|
||||
```json
|
||||
{"conversations": [{"from": "...", "value": "..."}]}
|
||||
```
|
||||
@@ -315,28 +286,29 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
|
||||
#### How to add custom prompts
|
||||
|
||||
For a dataset that is preprocessed for instruction purposes:
|
||||
|
||||
```json
|
||||
{"instruction": "...", "output": "..."}
|
||||
```
|
||||
|
||||
You can use this example in your YAML config:
|
||||
|
||||
Using yaml. Example:
|
||||
```yaml
|
||||
datasets:
|
||||
- path: repo
|
||||
type:
|
||||
system_prompt: ""
|
||||
field_system: system
|
||||
format: "[INST] {instruction} [/INST]"
|
||||
no_input_format: "[INST] {instruction} [/INST]"
|
||||
no_input_format: |-
|
||||
User: {instruction}<|end_of_turn|>
|
||||
Assistant:
|
||||
format: |-
|
||||
User: {instruction}
|
||||
{input}<|end_of_turn|>
|
||||
Assistant:
|
||||
```
|
||||
|
||||
Using file:
|
||||
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
|
||||
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
|
||||
|
||||
#### How to use your custom pretokenized dataset
|
||||
|
||||
- Do not pass a `type:`
|
||||
- Columns in Dataset must be exactly `input_ids`, `attention_mask`, `labels`
|
||||
- Dataset must contain `input_ids`, `attention_mask`, `labels` in columns
|
||||
|
||||
|
||||
### Config
|
||||
@@ -374,24 +346,11 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
- typescript
|
||||
type: ... # unimplemented custom format
|
||||
|
||||
# fastchat conversation
|
||||
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
datasets:
|
||||
- path: ...
|
||||
type: sharegpt
|
||||
conversation: chatml
|
||||
|
||||
# local
|
||||
datasets:
|
||||
- path: data.jsonl # or json
|
||||
ds_type: json # see other options below
|
||||
type: alpaca
|
||||
|
||||
# dataset with splits, but no train split
|
||||
dataset:
|
||||
- path: knowrohit07/know_sql
|
||||
type: context_qa.load_v2
|
||||
train_on_split: validation
|
||||
```
|
||||
|
||||
- loading
|
||||
@@ -419,18 +378,18 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
|
||||
<details>
|
||||
|
||||
<summary>All yaml options (click me)</summary>
|
||||
<summary>All yaml options</summary>
|
||||
|
||||
```yaml
|
||||
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
||||
# This can also be a relative path to a model on disk
|
||||
# this is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
||||
# this can also be a relative path to a model on disk
|
||||
base_model: ./llama-7b-hf
|
||||
# You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
|
||||
# you can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
|
||||
base_model_ignore_patterns:
|
||||
# If the base_model repo on hf hub doesn't include configuration .json files,
|
||||
# You can set that here, or leave this empty to default to base_model
|
||||
# if the base_model repo on hf hub doesn't include configuration .json files,
|
||||
# you can set that here, or leave this empty to default to base_model
|
||||
base_model_config: ./llama-7b-hf
|
||||
# You can specify to choose a specific model revision from huggingface hub
|
||||
# you can specify to choose a specific model revision from huggingface hub
|
||||
model_revision:
|
||||
# Optional tokenizer configuration override in case you want to use a different tokenizer
|
||||
# than the one defined in the base model
|
||||
@@ -445,24 +404,18 @@ trust_remote_code:
|
||||
tokenizer_use_fast:
|
||||
# Whether to use the legacy tokenizer setting, defaults to True
|
||||
tokenizer_legacy:
|
||||
# Resize the model embeddings when new tokens are added to multiples of 32
|
||||
# This is reported to improve training speed on some models
|
||||
# resize the model embeddings when new tokens are added to multiples of 32
|
||||
# this is reported to improve training speed on some models
|
||||
resize_token_embeddings_to_32x:
|
||||
|
||||
# Used to identify which the model is based on
|
||||
is_falcon_derived_model:
|
||||
is_llama_derived_model:
|
||||
# Please note that if you set this to true, `padding_side` will be set to "left" by default
|
||||
is_mistral_derived_model:
|
||||
|
||||
# Whether you are training a 4-bit GPTQ quantized model
|
||||
# whether you are training a 4-bit GPTQ quantized model
|
||||
gptq: true
|
||||
gptq_groupsize: 128 # group size
|
||||
gptq_model_v1: false # v1 or v2
|
||||
|
||||
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
||||
# this will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
||||
load_in_8bit: true
|
||||
# Use bitsandbytes 4 bit
|
||||
# use bitsandbytes 4 bit
|
||||
load_in_4bit:
|
||||
|
||||
# Use CUDA bf16
|
||||
@@ -476,9 +429,9 @@ tf32: true # require >=ampere
|
||||
bfloat16: true # require >=ampere
|
||||
float16: true
|
||||
|
||||
# A list of one or more datasets to finetune the model with
|
||||
# a list of one or more datasets to finetune the model with
|
||||
datasets:
|
||||
# HuggingFace dataset repo | "json" for local dataset, make sure to fill data_files
|
||||
# hf dataset repo | "json" for local dataset, make sure to fill data_files
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
||||
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
||||
@@ -487,21 +440,17 @@ datasets:
|
||||
shards: # Optional[int] number of shards to split data into
|
||||
name: # Optional[str] name of dataset configuration to load
|
||||
|
||||
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
|
||||
# Custom user prompt
|
||||
# custom user prompt
|
||||
- path: repo
|
||||
type:
|
||||
# The below are defaults. only set what's needed.
|
||||
# the below are defaults. only set what's needed.
|
||||
system_prompt: ""
|
||||
system_format: "{system}"
|
||||
field_system: system
|
||||
field_instruction: instruction
|
||||
field_input: input
|
||||
field_output: output
|
||||
field_output: input
|
||||
|
||||
# Customizable to be single line or multi-line
|
||||
# customizable to be single line or multi-line
|
||||
system_format: "{system}"
|
||||
# 'format' can include {input}
|
||||
format: |-
|
||||
User: {instruction} {input}
|
||||
@@ -509,24 +458,21 @@ datasets:
|
||||
# 'no_input_format' cannot include {input}
|
||||
no_input_format: "{instruction} "
|
||||
|
||||
# For `completion` datsets only, uses the provided field instead of `text` column
|
||||
# for completions datsets, uses the provided field if not `text`
|
||||
field:
|
||||
|
||||
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
# axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
# subsequent training attempts load faster, relative path
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
# Push prepared dataset to hub
|
||||
# push prepared dataset to hub
|
||||
push_dataset_to_hub: # repo path
|
||||
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
|
||||
# if not set.
|
||||
dataset_processes: # defaults to os.cpu_count() if not set
|
||||
# push checkpoints to hub
|
||||
hub_model_id: # repo path to push finetuned model
|
||||
# how to push checkpoints to hub
|
||||
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
|
||||
hub_strategy:
|
||||
# Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
|
||||
# Required to be true when used in combination with `push_dataset_to_hub`
|
||||
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
|
||||
# required to be true when used in combination with `push_dataset_to_hub`
|
||||
hf_use_auth_token: # boolean
|
||||
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
|
||||
val_set_size: 0.04
|
||||
@@ -535,34 +481,30 @@ dataset_shard_num:
|
||||
# Index of shard to use for whole dataset
|
||||
dataset_shard_idx:
|
||||
|
||||
# The maximum length of an input to train with, this should typically be less than 2048
|
||||
# the maximum length of an input to train with, this should typically be less than 2048
|
||||
# as most models have a token/context limit of 2048
|
||||
sequence_len: 2048
|
||||
# Pad inputs so each step uses constant sized buffers
|
||||
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
|
||||
# pad inputs so each step uses constant sized buffers
|
||||
# this will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
|
||||
pad_to_sequence_len:
|
||||
# Max sequence length to concatenate training samples together up to
|
||||
# Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
||||
# max sequence length to concatenate training samples together up to
|
||||
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
||||
# FutureWarning: This will soon be DEPRECATED
|
||||
max_packed_sequence_len: 1024
|
||||
# Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
|
||||
# use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
|
||||
sample_packing:
|
||||
# Set to 'false' if getting errors during eval with sample_packing on.
|
||||
# set to 'false' if getting errors during eval with sample_packing on.
|
||||
eval_sample_packing:
|
||||
# You can set these packing optimizations AFTER starting a training at least once.
|
||||
# you can set these packing optimizations AFTER starting a training at least once.
|
||||
# The trainer will provide recommended values for these values.
|
||||
sample_packing_eff_est:
|
||||
total_num_tokens:
|
||||
|
||||
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||
# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||
adapter: lora
|
||||
# If you already have a lora model trained that you want to load, put that here.
|
||||
# This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
|
||||
# if you already have a lora model trained that you want to load, put that here
|
||||
# lora hyperparameters
|
||||
lora_model_dir:
|
||||
|
||||
# LoRA hyperparameters
|
||||
# For more details about the following options, see:
|
||||
# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
@@ -574,96 +516,81 @@ lora_target_modules:
|
||||
# - gate_proj
|
||||
# - down_proj
|
||||
# - up_proj
|
||||
lora_target_linear: # If true, will target all linear layers
|
||||
|
||||
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
|
||||
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
|
||||
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
|
||||
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
|
||||
lora_target_linear: # if true, will target all linear layers
|
||||
lora_modules_to_save:
|
||||
# - embed_tokens
|
||||
# - lm_head
|
||||
|
||||
# Once you complete training, the model will be saved to the following directory.
|
||||
# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
|
||||
# Make sure `lora_model_dir` points to this directory if you want to use the trained model.
|
||||
lora_out_dir:
|
||||
lora_fan_in_fan_out: false
|
||||
|
||||
# ReLoRA configuration
|
||||
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
|
||||
relora_steps: # Number of steps per ReLoRA restart
|
||||
relora_warmup_steps: # Number of per-restart warmup steps
|
||||
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
||||
# must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
|
||||
relora_steps: # number of steps per ReLoRA restart
|
||||
relora_warmup_steps: # number of per-restart warmup steps
|
||||
relora_cpu_offload: # true to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
||||
|
||||
# wandb configuration if you're using it
|
||||
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
||||
wandb_project: # Your wandb project name
|
||||
wandb_entity: # A wandb Team name if using a Team
|
||||
wandb_project: # your wandb project name
|
||||
wandb_entity: # a wandb Team name if using a Team
|
||||
wandb_watch:
|
||||
wandb_run_id: # Set the name of your wandb run
|
||||
wandb_run_id: # set the name of your wandb run
|
||||
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
||||
|
||||
# Where to save the full-finetuned model to
|
||||
# where to save the finished model to
|
||||
output_dir: ./completed-model
|
||||
|
||||
# Whether to use torch.compile and which backend to use
|
||||
# whether to use torch.compile and which backend to use
|
||||
torch_compile: # bool
|
||||
torch_compile_backend: # Optional[str]
|
||||
|
||||
# Training hyperparameters
|
||||
|
||||
# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
|
||||
# training hyperparameters
|
||||
gradient_accumulation_steps: 1
|
||||
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
|
||||
micro_batch_size: 2
|
||||
eval_batch_size:
|
||||
num_epochs: 4
|
||||
eval_batch_size: 2
|
||||
num_epochs: 3
|
||||
warmup_steps: 100
|
||||
learning_rate: 0.00003
|
||||
lr_quadratic_warmup:
|
||||
logging_steps:
|
||||
save_strategy: # Set to `no` to skip checkpoint saves
|
||||
save_steps: # Leave empty to save at each epoch
|
||||
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
|
||||
save_total_limit: # Checkpoints saved at a time
|
||||
# Maximum number of iterations to train for. It precedes num_epochs which means that
|
||||
# if both are set, num_epochs will not be guaranteed.
|
||||
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
|
||||
save_strategy: # set to `no` to skip checkpoint saves
|
||||
save_steps: # leave empty to save at each epoch
|
||||
eval_steps: # leave empty to eval at each epoch
|
||||
save_total_limit: # checkpoints saved at a time
|
||||
max_steps:
|
||||
|
||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||
eval_table_size: # approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||
eval_table_max_new_tokens: # total number of tokens generated for predictions sent to wandb. Default is 128
|
||||
|
||||
# Save model as safetensors (require safetensors package)
|
||||
# save model as safetensors (require safetensors package)
|
||||
save_safetensors:
|
||||
|
||||
# Whether to mask out or include the human's prompt from the training labels
|
||||
# whether to mask out or include the human's prompt from the training labels
|
||||
train_on_inputs: false
|
||||
# Group similarly sized data to minimize padding.
|
||||
# May be slower to start, as it must download and sort the entire dataset.
|
||||
# Note that training loss may have an oscillating pattern with this enabled.
|
||||
# group similarly sized data to minimize padding
|
||||
# may be slower to start, as it must download and sort the entire dataset
|
||||
# note that training loss may have an oscillating pattern with this enabled
|
||||
group_by_length: false
|
||||
|
||||
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||
gradient_checkpointing: false
|
||||
|
||||
# Stop training after this many evaluation losses have increased in a row
|
||||
# stop training after this many evaluation losses have increased in a row
|
||||
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
||||
early_stopping_patience: 3
|
||||
|
||||
# Specify a scheduler and kwargs to use with the optimizer
|
||||
# specify a scheduler and kwargs to use with the optimizer
|
||||
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
|
||||
lr_scheduler_kwargs:
|
||||
|
||||
# For one_cycle optim
|
||||
lr_div_factor: # Learning rate div factor
|
||||
# for one_cycle optim
|
||||
lr_div_factor: # learning rate div factor
|
||||
|
||||
# For log_sweep optim
|
||||
# for log_sweep optim
|
||||
log_sweep_min_lr:
|
||||
log_sweep_max_lr:
|
||||
|
||||
# Specify optimizer
|
||||
# specify optimizer
|
||||
# Valid values are driven by the Transformers OptimizerNames class, see:
|
||||
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
|
||||
#
|
||||
@@ -689,7 +616,7 @@ log_sweep_max_lr:
|
||||
# - paged_lion_32bit
|
||||
# - paged_lion_8bit
|
||||
optimizer:
|
||||
# Specify weight decay
|
||||
# specify weight decay
|
||||
weight_decay:
|
||||
# adamw hyperparams
|
||||
adam_beta1:
|
||||
@@ -698,58 +625,47 @@ adam_epsilon:
|
||||
# Gradient clipping max norm
|
||||
max_grad_norm:
|
||||
|
||||
# Augmentation techniques
|
||||
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
|
||||
# currently only supported on Llama and Mistral
|
||||
noisy_embedding_alpha:
|
||||
|
||||
# Whether to bettertransformers
|
||||
# whether to bettertransformers
|
||||
flash_optimum:
|
||||
# Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||
# whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||
xformers_attention:
|
||||
# Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
||||
# whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
||||
flash_attention:
|
||||
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
||||
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
||||
flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
|
||||
flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
|
||||
# Whether to use scaled-dot-product attention
|
||||
# whether to use scaled-dot-product attention
|
||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
sdp_attention:
|
||||
# Landmark attention (only llama)
|
||||
landmark_attention:
|
||||
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
||||
# LLaMA only
|
||||
# llama only
|
||||
xpos_rope:
|
||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||
rope_scaling:
|
||||
type: # linear | dynamic
|
||||
factor: # float
|
||||
|
||||
# Resume from a specific checkpoint dir
|
||||
# resume from a specific checkpoint dir
|
||||
resume_from_checkpoint:
|
||||
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
|
||||
# Be careful with this being turned on between different models.
|
||||
# if resume_from_checkpoint isn't set and you simply want it to start where it left off
|
||||
# be careful with this being turned on between different models
|
||||
auto_resume_from_checkpoints: false
|
||||
|
||||
# Don't mess with this, it's here for accelerate and torchrun
|
||||
# don't mess with this, it's here for accelerate and torchrun
|
||||
local_rank:
|
||||
|
||||
# Add or change special tokens.
|
||||
# If you add tokens here, you don't need to add them to the `tokens` list.
|
||||
# add or change special tokens
|
||||
special_tokens:
|
||||
# bos_token: "<s>"
|
||||
# eos_token: "</s>"
|
||||
# unk_token: "<unk>"
|
||||
|
||||
# Add extra tokens.
|
||||
# add extra tokens
|
||||
tokens:
|
||||
|
||||
# FSDP
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
|
||||
# Deepspeed config path. e.g., deepspeed/zero3.json
|
||||
# Deepspeed config path
|
||||
deepspeed:
|
||||
|
||||
# Advanced DDP Arguments
|
||||
@@ -775,66 +691,6 @@ strict:
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary> Understanding of batch size and gradient accumulation steps </summary>
|
||||
<br/>
|
||||
Gradient accumulation means accumulating gradients over several mini-batches and updating the model weights afterward. When the samples in each batch are diverse, this technique doesn't significantly impact learning.
|
||||
|
||||
This method allows for effective training with larger effective batch sizes without needing proportionally larger memory. Here's why:
|
||||
|
||||
1. **Memory Consumption with Batch Size**: The primary reason increasing the batch size impacts memory is due to the storage requirements for intermediate activations. When you forward propagate a batch through a network, you have to store the activations at each layer for each sample in the batch, because these activations are used during backpropagation to compute gradients. Therefore, larger batches mean more activations, leading to greater GPU memory consumption.
|
||||
|
||||
2. **Gradient Accumulation**: With gradient accumulation, you're effectively simulating a larger batch size by accumulating gradients over several smaller batches (or micro-batches). However, at any given time, you're only forward and backward propagating a micro-batch. This means you only store activations for the micro-batch, not the full accumulated batch. As a result, you can simulate the effect of a larger batch size without the memory cost of storing activations for a large batch.
|
||||
|
||||
**Example 1:**
|
||||
Micro batch size: 3
|
||||
Gradient accumulation steps: 2
|
||||
Number of GPUs: 3
|
||||
Total batch size = 3 * 2 * 3 = 18
|
||||
|
||||
```
|
||||
| GPU 1 | GPU 2 | GPU 3 |
|
||||
|----------------|----------------|----------------|
|
||||
| S1, S2, S3 | S4, S5, S6 | S7, S8, S9 |
|
||||
| e1, e2, e3 | e4, e5, e6 | e7, e8, e9 |
|
||||
|----------------|----------------|----------------|
|
||||
| → (accumulate) | → (accumulate) | → (accumulate) |
|
||||
|----------------|----------------|----------------|
|
||||
| S10, S11, S12 | S13, S14, S15 | S16, S17, S18 |
|
||||
| e10, e11, e12 | e13, e14, e15 | e16, e17, e18 |
|
||||
|----------------|----------------|----------------|
|
||||
| → (apply) | → (apply) | → (apply) |
|
||||
|
||||
Accumulated gradient for the weight w1 after the second iteration (considering all GPUs):
|
||||
Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6 + e7 + e8 + e9 + e10 + e11 + e12 + e13 + e14 + e15 + e16 + e17 + e18
|
||||
|
||||
Weight update for w1:
|
||||
w1_new = w1_old - learning rate x (Total gradient for w1 / 18)
|
||||
```
|
||||
|
||||
**Example 2:**
|
||||
Micro batch size: 2
|
||||
Gradient accumulation steps: 1
|
||||
Number of GPUs: 3
|
||||
Total batch size = 2 * 1 * 3 = 6
|
||||
|
||||
```
|
||||
| GPU 1 | GPU 2 | GPU 3 |
|
||||
|-----------|-----------|-----------|
|
||||
| S1, S2 | S3, S4 | S5, S6 |
|
||||
| e1, e2 | e3, e4 | e5, e6 |
|
||||
|-----------|-----------|-----------|
|
||||
| → (apply) | → (apply) | → (apply) |
|
||||
|
||||
Accumulated gradient for the weight w1 (considering all GPUs):
|
||||
Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6
|
||||
|
||||
Weight update for w1:
|
||||
w1_new = w1_old - learning rate × (Total gradient for w1 / 6)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### Train
|
||||
|
||||
Run
|
||||
@@ -842,41 +698,14 @@ Run
|
||||
accelerate launch -m axolotl.cli.train your_config.yml
|
||||
```
|
||||
|
||||
#### Preprocess dataset
|
||||
|
||||
You can optionally pre-tokenize dataset with the following before finetuning.
|
||||
This is recommended for large datasets.
|
||||
|
||||
- Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
|
||||
- Use `--debug` to see preprocessed examples.
|
||||
|
||||
```bash
|
||||
python -m axolotl.cli.preprocess your_config.yml
|
||||
```
|
||||
|
||||
#### Multi-GPU
|
||||
|
||||
Below are the options available in axolotl for training with multiple GPUs. Note that DeepSpeed
|
||||
is the recommended multi-GPU option currently because FSDP may experience
|
||||
[loss instability](https://github.com/huggingface/transformers/issues/26498).
|
||||
|
||||
##### DeepSpeed
|
||||
|
||||
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
|
||||
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
|
||||
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
|
||||
|
||||
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
|
||||
|
||||
```yaml
|
||||
deepspeed: deepspeed/zero1.json
|
||||
You can optionally pre-tokenize dataset with the following before finetuning:
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES="" accelerate launch -m axolotl.cli.train your_config.yml --prepare_ds_only
|
||||
```
|
||||
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
|
||||
```
|
||||
|
||||
##### FSDP
|
||||
##### Config
|
||||
|
||||
- llama FSDP
|
||||
```yaml
|
||||
@@ -901,6 +730,24 @@ wandb_run_id:
|
||||
wandb_log_model:
|
||||
```
|
||||
|
||||
### Training with Deepspeed
|
||||
|
||||
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
|
||||
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
|
||||
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
|
||||
|
||||
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
|
||||
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```yaml
|
||||
deepspeed: deepspeed/zero1.json
|
||||
```
|
||||
|
||||
### Inference
|
||||
|
||||
Pass the appropriate flag to the train command:
|
||||
@@ -919,10 +766,6 @@ Pass the appropriate flag to the train command:
|
||||
--base_model="./completed-model" --prompter=None --load_in_8bit=True
|
||||
```
|
||||
|
||||
Please use `--sample_packing False` if you have it on and receive the error similar to below:
|
||||
|
||||
> RuntimeError: stack expects each tensor to be equal size, but got [1, 32, 1, 128] at entry 0 and [1, 32, 8, 128] at entry 1
|
||||
|
||||
### Merge LORA to base
|
||||
|
||||
Add below flag to train command above
|
||||
@@ -939,8 +782,6 @@ CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ...
|
||||
|
||||
## Common Errors 🧰
|
||||
|
||||
See also the [FAQ's](./docs/faq.md).
|
||||
|
||||
> If you encounter a 'Cuda out of memory' error, it means your GPU ran out of memory during the training process. Here's how to resolve it:
|
||||
|
||||
Please reduce any below
|
||||
|
||||
@@ -1,6 +1,14 @@
|
||||
{
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 0,
|
||||
@@ -33,13 +41,12 @@
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupDecayLR",
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"warmup_type": "linear",
|
||||
"total_num_steps": "auto"
|
||||
"warmup_type": "linear"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
|
||||
@@ -5,9 +5,6 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ARG AXOLOTL_EXTRAS=""
|
||||
ARG CUDA="118"
|
||||
ENV BNB_CUDA_VERSION=$CUDA
|
||||
ARG PYTORCH_VERSION="2.0.1"
|
||||
|
||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y vim curl
|
||||
@@ -15,19 +12,17 @@ RUN apt-get update && \
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
||||
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
RUN cd axolotl && \
|
||||
if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
|
||||
else \
|
||||
pip install -e .[flash-attn]; \
|
||||
fi
|
||||
|
||||
# fix so that git fetch/pull from remote works
|
||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||
RUN cd axolotl && \
|
||||
git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||
git config --get remote.origin.fetch
|
||||
|
||||
# helper for huggingface-login cli
|
||||
|
||||
@@ -14,7 +14,7 @@ ARG CUDA="118"
|
||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/*
|
||||
&& wget \
|
||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& mkdir /root/.conda \
|
||||
@@ -57,6 +57,11 @@ FROM base-builder
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
|
||||
# recompile apex
|
||||
RUN python3 -m pip uninstall -y apex
|
||||
RUN git clone https://github.com/NVIDIA/apex
|
||||
RUN cd apex && python3 -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
|
||||
|
||||
RUN mkdir -p /workspace/builds
|
||||
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
|
||||
|
||||
|
||||
18
docs/faq.md
18
docs/faq.md
@@ -1,18 +0,0 @@
|
||||
# Axolotl FAQ's
|
||||
|
||||
|
||||
> The trainer stopped and hasn't progressed in several minutes.
|
||||
|
||||
Usually an issue with the GPU's communicating with each other. See the [NCCL doc](../docs/nccl.md)
|
||||
|
||||
> Exitcode -9
|
||||
|
||||
This usually happens when you run out of system RAM.
|
||||
|
||||
> Exitcode -7 while using deepspeed
|
||||
|
||||
Try upgrading deepspeed w: `pip install -U deepspeed`
|
||||
|
||||
> AttributeError: 'DummyOptim' object has no attribute 'step'
|
||||
|
||||
You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.
|
||||
@@ -1,51 +0,0 @@
|
||||
# Multipack
|
||||
|
||||
4k context, bsz =4,
|
||||
each character represents 256 tokens
|
||||
X represents a padding token
|
||||
|
||||
```
|
||||
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
||||
[[ A A A A A A A A A A A ]
|
||||
B B B B B B ]
|
||||
C C C C C C C ]
|
||||
D D D D ]]
|
||||
|
||||
[[ E E E E E E E E ]
|
||||
[ F F F F ]
|
||||
[ G G G ]
|
||||
[ H H H H ]]
|
||||
|
||||
[[ I I I ]
|
||||
[ J J J ]
|
||||
[ K K K K K]
|
||||
[ L L L ]]
|
||||
```
|
||||
|
||||
after padding to longest input in each step
|
||||
```
|
||||
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
||||
[[ A A A A A A A A A A A ]
|
||||
B B B B B B X X X X X X ]
|
||||
C C C C C C C X X X X ]
|
||||
D D D D X X X X X X X ]]
|
||||
|
||||
[[ E E E E E E E E ]
|
||||
[ F F F F X X X X ]
|
||||
[ G G G X X X X X ]
|
||||
[ H H H H X X X X ]]
|
||||
|
||||
[[ I I I X X ]
|
||||
[ J J J X X ]
|
||||
[ K K K K K ]
|
||||
[ L L L X X ]]
|
||||
```
|
||||
|
||||
w packing ( note it's the same effective number of tokens per step, but a true bsz of 1)
|
||||
```
|
||||
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
||||
[[ A A A A A A A A A A A B B B B B
|
||||
B C C C C C C C D D D D E E E E
|
||||
E E E E F F F F F G G G H H H H
|
||||
I I I J J J J K K K K K L L L X ]]
|
||||
```
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: cerebras/btlm-3b-8k-base
|
||||
base_model_config: cerebras/btlm-3b-8k-base
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: GPT2Tokenizer
|
||||
trust_remote_code: true
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: cerebras/Cerebras-GPT-1.3B
|
||||
base_model_config: cerebras/Cerebras-GPT-1.3B
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
@@ -6,7 +7,7 @@ push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
@@ -49,7 +50,7 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: codellama/CodeLlama-13b-hf
|
||||
base_model_config: codellama/CodeLlama-13b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -10,7 +11,7 @@ strict: false
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./lora-out
|
||||
|
||||
@@ -34,7 +35,7 @@ wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -54,7 +55,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: codellama/CodeLlama-13b-hf
|
||||
base_model_config: codellama/CodeLlama-13b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -10,7 +11,7 @@ strict: false
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./qlora-out
|
||||
|
||||
@@ -36,7 +37,7 @@ wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -56,7 +57,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: codellama/CodeLlama-34b-hf
|
||||
base_model_config: codellama/CodeLlama-34b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -10,7 +11,7 @@ strict: false
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./lora-out
|
||||
|
||||
@@ -34,7 +35,7 @@ wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -54,7 +55,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: codellama/CodeLlama-34b-hf
|
||||
base_model_config: codellama/CodeLlama-34b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -10,7 +11,7 @@ strict: false
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./qlora-out
|
||||
|
||||
@@ -36,7 +37,7 @@ wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -56,7 +57,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: codellama/CodeLlama-7b-hf
|
||||
base_model_config: codellama/CodeLlama-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -10,7 +11,7 @@ strict: false
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./lora-out
|
||||
|
||||
@@ -34,7 +35,7 @@ wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -54,7 +55,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: codellama/CodeLlama-7b-hf
|
||||
base_model_config: codellama/CodeLlama-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -10,7 +11,7 @@ strict: false
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./qlora-out
|
||||
|
||||
@@ -36,7 +37,7 @@ wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -56,7 +57,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
base_model: tiiuae/falcon-7b
|
||||
base_model_config: tiiuae/falcon-7b
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
is_falcon_derived_model: true
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
gptq: false
|
||||
@@ -11,7 +11,7 @@ push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca:chat
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
# 1b: tiiuae/falcon-rw-1b
|
||||
# 40b: tiiuae/falcon-40b
|
||||
base_model: tiiuae/falcon-7b
|
||||
base_model_config: tiiuae/falcon-7b
|
||||
# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
is_falcon_derived_model: true
|
||||
load_in_8bit: false
|
||||
# enable 4bit for QLoRA
|
||||
load_in_4bit: true
|
||||
@@ -17,7 +17,7 @@ datasets:
|
||||
data_files:
|
||||
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json
|
||||
type: "alpaca:chat"
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
# enable QLoRA
|
||||
adapter: qlora
|
||||
@@ -53,7 +53,7 @@ output_dir: ./qlora-out
|
||||
# decrease if OOM, increase for max VRAM utilization
|
||||
micro_batch_size: 1
|
||||
gradient_accumulation_steps: 2
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
# Optimizer for QLoRA
|
||||
optimizer: paged_adamw_32bit
|
||||
torchdistx_path:
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
base_model: tiiuae/falcon-7b
|
||||
base_model_config: tiiuae/falcon-7b
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
is_falcon_derived_model: true
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
gptq: false
|
||||
@@ -11,7 +11,7 @@ push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca:chat
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: EleutherAI/gpt-j-6b
|
||||
base_model_config: EleutherAI/gpt-j-6b
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
@@ -6,7 +7,7 @@ push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
@@ -46,7 +47,7 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
base_model: huggyllama/llama-7b
|
||||
base_model_config: huggyllama/llama-7b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: false
|
||||
datasets:
|
||||
- path: openaccess-ai-collective/jeopardy
|
||||
type: jeopardy
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
@@ -24,7 +25,7 @@ wandb_log_model:
|
||||
output_dir: ./jeopardy-bot-7b
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
|
||||
@@ -9,16 +9,12 @@ gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/qlora.yml
|
||||
accelerate launch scripts/finetune.py examples/llama-2/qlora.yml
|
||||
|
||||
```
|
||||
or
|
||||
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/lora.yml
|
||||
```
|
||||
accelerate launch scripts/finetune.py examples/llama-2/lora.yml
|
||||
|
||||
To launch a full finetuning with 16-bit precision:
|
||||
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/fft_optimized.yml
|
||||
```
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
lora_r:
|
||||
lora_alpha:
|
||||
lora_dropout:
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
flash_attn_cross_entropy: false
|
||||
flash_attn_rms_norm: true
|
||||
flash_attn_fuse_qkv: false
|
||||
flash_attn_fuse_mlp: true
|
||||
|
||||
warmup_steps: 100
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed: #deepspeed/zero2.json # multi-gpu only
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: TheBloke/Llama-2-7B-GPTQ
|
||||
base_model_config: TheBloke/Llama-2-7B-GPTQ
|
||||
is_llama_derived_model: false
|
||||
gptq: true
|
||||
gptq_disable_exllama: true
|
||||
@@ -14,7 +15,7 @@ hf_use_auth_token: true
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
@@ -37,7 +38,7 @@ wandb_log_model:
|
||||
output_dir: ./model-out
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_torch
|
||||
adam_beta2: 0.95
|
||||
adam_eps: 0.00001
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
base_model_config: NousResearch/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -10,7 +11,7 @@ strict: false
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./lora-out
|
||||
|
||||
@@ -34,7 +35,7 @@ wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -54,8 +55,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
eval_steps: 20
|
||||
eval_table_size: 5
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
debug:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
base_model_config: NousResearch/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -10,7 +11,7 @@ strict: false
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./qlora-out
|
||||
|
||||
@@ -36,7 +37,7 @@ wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -56,8 +57,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
eval_steps: 20
|
||||
eval_table_size: 5
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
base_model_config: NousResearch/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -10,7 +11,7 @@ strict: false
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./relora-out
|
||||
|
||||
@@ -40,7 +41,7 @@ wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 4
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -60,7 +61,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_steps: 20
|
||||
save_steps: 50
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: PY007/TinyLlama-1.1B-step-50K-105b
|
||||
base_model_config: PY007/TinyLlama-1.1B-step-50K-105b
|
||||
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
@@ -11,7 +12,7 @@ strict: false
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./lora-out
|
||||
|
||||
@@ -34,7 +35,7 @@ wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -54,8 +55,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
eval_steps: 20
|
||||
eval_table_size: 5
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -1,12 +0,0 @@
|
||||
**Mistral 7B** is a language model with a total of 7.3 billion parameters, showcasing a notable performance across a variety of benchmarks.
|
||||
|
||||
Fine Tune:
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/mistral/config.yml
|
||||
|
||||
```
|
||||
|
||||
If you run into CUDA OOM, use deepspeed with config zero2.json:
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/mistral/config.yml --deepspeed deepspeed/zero2.json
|
||||
```
|
||||
@@ -1,61 +0,0 @@
|
||||
base_model: mistralai/Mistral-7B-v0.1
|
||||
model_type: MistralForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_mistral_derived_model: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.01
|
||||
output_dir: ./out
|
||||
|
||||
sequence_len: 8192
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.000005
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,78 +0,0 @@
|
||||
base_model: mistralai/Mistral-7B-v0.1
|
||||
model_type: MistralForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_mistral_derived_model: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 8192
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
lora_target_modules:
|
||||
- gate_proj
|
||||
- down_proj
|
||||
- up_proj
|
||||
- q_proj
|
||||
- v_proj
|
||||
- k_proj
|
||||
- o_proj
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,11 +1,12 @@
|
||||
base_model: mosaicml/mpt-7b
|
||||
base_model_config: mosaicml/mpt-7b
|
||||
tokenizer_type: AutoTokenizer
|
||||
trust_remote_code: true # required for mpt as their model class is not merged into transformers yet
|
||||
load_in_8bit: false
|
||||
datasets:
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
@@ -26,7 +27,7 @@ wandb_log_model:
|
||||
output_dir: ./mpt-alpaca-7b
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: openlm-research/open_llama_3b_v2
|
||||
base_model_config: openlm-research/open_llama_3b_v2
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: false
|
||||
@@ -8,7 +9,7 @@ push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: openlm-research/open_llama_3b_v2
|
||||
base_model_config: openlm-research/open_llama_3b_v2
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: true
|
||||
@@ -8,7 +9,7 @@ push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: openlm-research/open_llama_3b_v2
|
||||
base_model_config: openlm-research/open_llama_3b_v2
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: false
|
||||
@@ -8,7 +9,7 @@ push_dataset_to_hub:
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: microsoft/phi-1_5
|
||||
base_model_config: microsoft/phi-1_5
|
||||
model_type: MixFormerSequentialForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
is_llama_derived_model: false
|
||||
@@ -12,7 +13,7 @@ datasets:
|
||||
- path: garage-bAInd/Open-Platypus
|
||||
type: alpaca
|
||||
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
output_dir: ./phi-sft-out
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: microsoft/phi-1_5
|
||||
base_model_config: microsoft/phi-1_5
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
is_llama_derived_model: false
|
||||
@@ -12,7 +13,7 @@ datasets:
|
||||
- path: garage-bAInd/Open-Platypus
|
||||
type: alpaca
|
||||
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
output_dir: ./phi-sft-out
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: EleutherAI/pythia-12b-deduped
|
||||
base_model_config: EleutherAI/pythia-12b-deduped
|
||||
base_model_ignore_patterns: pytorch* # prefer safetensors
|
||||
model_type: GPTNeoXForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
@@ -9,7 +10,7 @@ device_map: auto
|
||||
datasets:
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
base_model: EleutherAI/pythia-1.4b-deduped
|
||||
base_model_config: EleutherAI/pythia-1.4b-deduped
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
@@ -23,15 +24,15 @@ wandb_log_model:
|
||||
output_dir: ./lora-alpaca-pythia
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 4
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
learning_rate: 0.00001
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
bf16: True
|
||||
tf32: True
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
weight_decay: 0.1
|
||||
eval_steps: 0.05
|
||||
eval_steps: 20
|
||||
logging_steps: 1
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
|
||||
base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1
|
||||
model_type: GPTNeoXForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
trust_remote_code:
|
||||
@@ -6,7 +7,7 @@ load_in_8bit: false
|
||||
datasets:
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.02
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
@@ -27,7 +28,7 @@ wandb_log_model:
|
||||
output_dir: ./redpajama-alpaca-3b
|
||||
batch_size: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
base_model: replit/replit-code-v1-3b
|
||||
base_model_config: replit/replit-code-v1-3b
|
||||
trust_remote_code: true
|
||||
load_in_8bit: false
|
||||
datasets:
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
@@ -26,7 +27,7 @@ wandb_log_model:
|
||||
output_dir: ./lora-replit
|
||||
batch_size: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer:
|
||||
torchdistx_path:
|
||||
lr_scheduler:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# An example finetuning Saleforce's XGen-7b model with 8k context using qlora
|
||||
# on Tim Dettmer's Guanaco dataset.
|
||||
base_model: Salesforce/xgen-7b-8k-base
|
||||
base_model_config: Salesforce/xgen-7b-8k-base
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
@@ -15,7 +16,7 @@ datasets:
|
||||
data_files:
|
||||
- openassistant_best_replies_train.jsonl
|
||||
type: "completion"
|
||||
dataset_prepared_path:
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.01
|
||||
# enable QLoRA
|
||||
adapter: qlora
|
||||
@@ -51,7 +52,7 @@ output_dir: ./qlora-out
|
||||
# decrease if OOM, increase for max VRAM utilization
|
||||
micro_batch_size: 1
|
||||
gradient_accumulation_steps: 1
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
# Optimizer for QLoRA
|
||||
optimizer: paged_adamw_32bit
|
||||
torchdistx_path:
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
# Page
|
||||
@@ -1,4 +0,0 @@
|
||||
# Table of contents
|
||||
|
||||
* [Page](README.md)
|
||||
* [Small dev details](small-dev-details.md)
|
||||
@@ -1,3 +0,0 @@
|
||||
# Small dev details
|
||||
|
||||
/
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 370 KiB |
@@ -4,19 +4,20 @@ torch==2.0.1
|
||||
auto-gptq
|
||||
packaging
|
||||
peft @ git+https://github.com/huggingface/peft.git
|
||||
transformers @ git+https://github.com/huggingface/transformers.git@acc394c4f5e1283c19783581790b3dc3105a3697
|
||||
transformers @ git+https://github.com/huggingface/transformers.git
|
||||
bitsandbytes>=0.41.1
|
||||
accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
|
||||
accelerate @ git+https://github.com/huggingface/accelerate
|
||||
deepspeed
|
||||
addict
|
||||
evaluate
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
datasets
|
||||
flash-attn>=2.3.0
|
||||
flash-attn>=2.2.1
|
||||
sentencepiece
|
||||
wandb
|
||||
einops
|
||||
xformers>=0.0.22
|
||||
xformers
|
||||
optimum
|
||||
hf_transfer
|
||||
colorama
|
||||
@@ -30,5 +31,3 @@ scipy
|
||||
scikit-learn==1.2.2
|
||||
pynvml
|
||||
art
|
||||
fschat==0.2.29
|
||||
tensor_parallel
|
||||
|
||||
@@ -7,7 +7,6 @@ import transformers
|
||||
|
||||
from axolotl.cli import (
|
||||
check_accelerate_default_config,
|
||||
check_user_token,
|
||||
do_inference,
|
||||
do_merge_lora,
|
||||
load_cfg,
|
||||
@@ -32,7 +31,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
)
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
@@ -45,6 +43,8 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
if parsed_cli_args.prepare_ds_only:
|
||||
return
|
||||
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
||||
|
||||
|
||||
|
||||
10
setup.py
10
setup.py
@@ -21,14 +21,6 @@ def parse_requirements():
|
||||
):
|
||||
# Handle standard packages
|
||||
_install_requires.append(line)
|
||||
|
||||
# TODO(wing) remove once xformers release supports torch 2.1.0
|
||||
if "torch==2.1.0" in _install_requires:
|
||||
_install_requires.pop(_install_requires.index("xformers>=0.0.22"))
|
||||
_install_requires.append(
|
||||
"xformers @ git+https://github.com/facebookresearch/xformers.git@main"
|
||||
)
|
||||
|
||||
return _install_requires, _dependency_links
|
||||
|
||||
|
||||
@@ -46,7 +38,7 @@ setup(
|
||||
dependency_links=dependency_links,
|
||||
extras_require={
|
||||
"flash-attn": [
|
||||
"flash-attn>=2.3.0",
|
||||
"flash-attn>=2.2.1",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed",
|
||||
|
||||
@@ -14,8 +14,6 @@ import yaml
|
||||
# add src to the pythonpath so we don't need to pip install this
|
||||
from accelerate.commands.config import config_args
|
||||
from art import text2art
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||
from transformers import GenerationConfig, TextStreamer
|
||||
|
||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
@@ -51,7 +49,7 @@ def print_axolotl_text_art(suffix=None):
|
||||
|
||||
|
||||
def get_multi_line_input() -> Optional[str]:
|
||||
print("Give me an instruction (Ctrl + D to submit): ")
|
||||
print("Give me an instruction (Ctrl + D to finish): ")
|
||||
instruction = ""
|
||||
for line in sys.stdin:
|
||||
instruction += line # pylint: disable=consider-using-join
|
||||
@@ -194,7 +192,6 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
||||
# load the config from the yaml file
|
||||
with open(config, encoding="utf-8") as file:
|
||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||
cfg.axolotl_config_path = config
|
||||
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
||||
# then overwrite the value
|
||||
cfg_keys = cfg.keys()
|
||||
@@ -222,9 +219,7 @@ def load_datasets(
|
||||
) -> TrainDatasetMeta:
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
|
||||
cfg, tokenizer
|
||||
)
|
||||
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
|
||||
|
||||
if cli_args.debug or cfg.debug:
|
||||
LOG.info("check_dataset_labels...")
|
||||
@@ -240,10 +235,6 @@ def load_datasets(
|
||||
text_only=cli_args.debug_text_only,
|
||||
)
|
||||
|
||||
LOG.info("printing prompters...")
|
||||
for prompter in prompters:
|
||||
LOG.info(prompter)
|
||||
|
||||
return TrainDatasetMeta(
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
@@ -256,16 +247,3 @@ def check_accelerate_default_config():
|
||||
LOG.warning(
|
||||
f"accelerate config file found at {config_args.default_yaml_config_file}. This can lead to unexpected errors"
|
||||
)
|
||||
|
||||
|
||||
def check_user_token():
|
||||
# Verify if token is valid
|
||||
api = HfApi()
|
||||
try:
|
||||
user_info = api.whoami()
|
||||
return bool(user_info)
|
||||
except LocalTokenNotFoundError:
|
||||
LOG.warning(
|
||||
"Error verifying HuggingFace token. Remember to log in using `huggingface-cli login` and get your access token from https://huggingface.co/settings/tokens if you want to use gated models or datasets."
|
||||
)
|
||||
return False
|
||||
|
||||
@@ -14,7 +14,6 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parsed_cfg.sample_packing = False
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
"""
|
||||
CLI to run training on a model
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
from colorama import Fore
|
||||
|
||||
from axolotl.cli import (
|
||||
check_accelerate_default_config,
|
||||
check_user_token,
|
||||
load_cfg,
|
||||
load_datasets,
|
||||
print_axolotl_text_art,
|
||||
)
|
||||
from axolotl.common.cli import PreprocessCliArgs
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.preprocess")
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
parser = transformers.HfArgumentParser((PreprocessCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
if not parsed_cfg.dataset_prepared_path:
|
||||
msg = (
|
||||
Fore.RED
|
||||
+ "preprocess CLI called without dataset_prepared_path set, "
|
||||
+ f"using default path: {DEFAULT_DATASET_PREPARED_PATH}"
|
||||
+ Fore.RESET
|
||||
)
|
||||
LOG.warning(msg)
|
||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
_ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
LOG.info(
|
||||
Fore.GREEN
|
||||
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
|
||||
+ Fore.RESET
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(do_cli)
|
||||
@@ -1,7 +1,6 @@
|
||||
"""
|
||||
CLI to run training on a model
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
@@ -9,7 +8,6 @@ import transformers
|
||||
|
||||
from axolotl.cli import (
|
||||
check_accelerate_default_config,
|
||||
check_user_token,
|
||||
load_cfg,
|
||||
load_datasets,
|
||||
print_axolotl_text_art,
|
||||
@@ -17,20 +15,20 @@ from axolotl.cli import (
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.train")
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
|
||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
if parsed_cli_args.prepare_ds_only:
|
||||
return
|
||||
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
||||
|
||||
|
||||
|
||||
@@ -25,22 +25,11 @@ class TrainerCliArgs:
|
||||
debug_num_examples: int = field(default=5)
|
||||
inference: bool = field(default=False)
|
||||
merge_lora: bool = field(default=False)
|
||||
prepare_ds_only: bool = field(default=False)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
shard: bool = field(default=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreprocessCliArgs:
|
||||
"""
|
||||
dataclass representing arguments for preprocessing only
|
||||
"""
|
||||
|
||||
debug: bool = field(default=False)
|
||||
debug_text_only: bool = field(default=False)
|
||||
debug_num_examples: int = field(default=1)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
"""
|
||||
Various shared constants
|
||||
"""
|
||||
|
||||
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||
@@ -1,711 +0,0 @@
|
||||
"""
|
||||
Builder for the training args and trainer
|
||||
"""
|
||||
|
||||
import abc
|
||||
import importlib
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Optional, Union
|
||||
|
||||
import tensor_parallel as tp
|
||||
import torch
|
||||
import transformers
|
||||
from datasets import Dataset
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import DataLoader, DistributedSampler, SequentialSampler
|
||||
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
||||
from transformers.trainer_pt_utils import SequentialDistributedSampler
|
||||
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.utils.callbacks import (
|
||||
EvalFirstStepCallback,
|
||||
GPUStatsCallback,
|
||||
SaveAxolotlConfigtoWandBCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
log_prediction_callback_factory,
|
||||
)
|
||||
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
||||
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
||||
from axolotl.utils.distributed import is_distributed
|
||||
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
||||
|
||||
try:
|
||||
import torch._dynamo # pylint: disable=ungrouped-imports
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingArguments(TrainingArguments):
|
||||
"""
|
||||
Extend the base TrainingArguments for axolotl helpers
|
||||
"""
|
||||
|
||||
lr_quadratic_warmup: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||
)
|
||||
sample_packing: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
eval_sample_packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "Use sample packing for efficient evals."},
|
||||
)
|
||||
sample_packing_efficiency: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "The maximum sequence length the model can handle"},
|
||||
)
|
||||
sample_packing_seq_len_multiplier: int = field(
|
||||
default=1,
|
||||
metadata={"help": "the multiplier for the max len for packed sequences"},
|
||||
)
|
||||
relora_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to reset for ReLoRA"},
|
||||
)
|
||||
relora_warmup_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
bench_split: Optional[str] = field(
|
||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||
)
|
||||
bench_dataset: Optional[str] = field(
|
||||
default="pharaouk/dharma-1/dharma_1_mini.json",
|
||||
metadata={
|
||||
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
|
||||
},
|
||||
)
|
||||
do_bench_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
||||
)
|
||||
max_bench_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
||||
},
|
||||
)
|
||||
bench_source_max_len: int = field(
|
||||
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
||||
)
|
||||
tensor_parallel: bool = field(
|
||||
default=False, metadata={"help": "Use tensor parallelism to train"}
|
||||
)
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
"""
|
||||
Extend the base Trainer for axolotl helpers
|
||||
"""
|
||||
|
||||
args = None # type: AxolotlTrainingArguments
|
||||
|
||||
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
|
||||
self.num_epochs = num_epochs
|
||||
self.bench_data_collator = bench_data_collator
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||
):
|
||||
"""
|
||||
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||
passed as an argument.
|
||||
|
||||
Args:
|
||||
num_training_steps (int): The number of training steps to do.
|
||||
optimizer (torch.optim.Optimizer): The training optimizer
|
||||
"""
|
||||
|
||||
# fmt: off
|
||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||
# fmt: on
|
||||
if (
|
||||
self.args.lr_scheduler_type == "cosine"
|
||||
and self.args.lr_quadratic_warmup is True
|
||||
):
|
||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
)
|
||||
else:
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
return self.lr_scheduler
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.world_size > 1 and self.args.sample_packing:
|
||||
return DistributedSampler(
|
||||
self.train_dataset,
|
||||
num_replicas=self.args.world_size,
|
||||
rank=self.args.process_index,
|
||||
seed=self.args.seed,
|
||||
)
|
||||
return super()._get_train_sampler()
|
||||
|
||||
def _get_eval_sampler(
|
||||
self, eval_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if (
|
||||
self.args.world_size > 1
|
||||
and self.args.sample_packing
|
||||
and self.args.eval_sample_packing is not False
|
||||
):
|
||||
return SequentialDistributedSampler(
|
||||
eval_dataset,
|
||||
num_replicas=self.args.world_size,
|
||||
rank=self.args.process_index,
|
||||
batch_size=self.args.per_device_eval_batch_size,
|
||||
)
|
||||
return super()._get_eval_sampler(eval_dataset)
|
||||
|
||||
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||
if self.args.sample_packing:
|
||||
train_sampler = self._get_train_sampler()
|
||||
return self.accelerator.prepare(
|
||||
MultipackDistributedDataloader(
|
||||
self.train_dataset,
|
||||
batch_size=self._train_batch_size,
|
||||
seq_max_length=self.args.max_seq_length,
|
||||
collate_fn=self.data_collator,
|
||||
sampler=train_sampler,
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
num_epochs=self.num_epochs,
|
||||
)
|
||||
)
|
||||
return super().get_train_dataloader()
|
||||
|
||||
def get_eval_dataloader(
|
||||
self, eval_dataset: Optional[Dataset] = None
|
||||
) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
eval_dataset = (
|
||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
)
|
||||
|
||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||
return self.accelerator.prepare(
|
||||
MultipackDistributedDataloader(
|
||||
eval_dataset,
|
||||
batch_size=self.args.eval_batch_size,
|
||||
seq_max_length=self.args.max_seq_length,
|
||||
collate_fn=self.data_collator,
|
||||
sampler=eval_sampler,
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
num_epochs=self.num_epochs,
|
||||
)
|
||||
)
|
||||
return super().get_eval_dataloader(eval_dataset)
|
||||
|
||||
def _get_bench_sampler(
|
||||
self, bench_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.world_size <= 1:
|
||||
return SequentialSampler(bench_dataset)
|
||||
return None
|
||||
|
||||
def get_bench_dataloader(
|
||||
self,
|
||||
bench_dataset: Dataset,
|
||||
) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||
dataloader_params = {
|
||||
"batch_size": self.args.eval_batch_size,
|
||||
"collate_fn": self.bench_data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
|
||||
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
|
||||
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
return DataLoader(bench_dataset, **dataloader_params)
|
||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
# use one's weighted cross entropy loss calc
|
||||
# if self.args.sample_packing:
|
||||
# labels = inputs.pop("labels")
|
||||
# outputs = model(**inputs)
|
||||
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||
# return (loss, outputs) if return_outputs else loss
|
||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||
|
||||
def _wrap_model(self, model, training=True, dataloader=None):
|
||||
if self.args.tensor_parallel:
|
||||
model = tp.tensor_parallel(model, distributed=is_distributed())
|
||||
model.hf_device_map = tp.infer_sharded_device_map(model)
|
||||
else:
|
||||
model = super()._wrap_model(model, training=training, dataloader=dataloader)
|
||||
return model
|
||||
|
||||
|
||||
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
||||
"""
|
||||
Trainer subclass that uses the OneCycleLR scheduler
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lr_scheduler = None
|
||||
|
||||
def create_scheduler(
|
||||
self,
|
||||
num_training_steps: int,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
):
|
||||
optimizer = self.optimizer if optimizer is None else optimizer
|
||||
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||
pct_start = num_warmup_steps / num_training_steps
|
||||
|
||||
self.lr_scheduler = OneCycleLR(
|
||||
optimizer,
|
||||
max_lr=self.args.learning_rate,
|
||||
total_steps=num_training_steps,
|
||||
pct_start=pct_start,
|
||||
div_factor=6,
|
||||
)
|
||||
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
class ReLoRATrainer(AxolotlTrainer):
|
||||
"""
|
||||
Trainer subclass that uses the OneCycleLR scheduler
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lr_scheduler = None
|
||||
|
||||
def create_scheduler(
|
||||
self,
|
||||
num_training_steps: int,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
):
|
||||
optimizer = self.optimizer if optimizer is None else optimizer
|
||||
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
if self.args.relora_steps:
|
||||
warmup_steps = (
|
||||
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
||||
)
|
||||
self.lr_scheduler = ReLoRAScheduler(
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
self.args.relora_steps,
|
||||
warmup_steps,
|
||||
)
|
||||
else:
|
||||
self.lr_scheduler = lr_scheduler
|
||||
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
class TrainerBuilderBase(abc.ABC):
|
||||
"""
|
||||
Base class for trainer builder
|
||||
"""
|
||||
|
||||
_train_dataset = None
|
||||
_eval_dataset = None
|
||||
|
||||
def __init__(self, cfg, model, tokenizer):
|
||||
self.cfg = cfg
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
@property
|
||||
def train_dataset(self):
|
||||
return self._train_dataset
|
||||
|
||||
@train_dataset.setter
|
||||
def train_dataset(self, dataset):
|
||||
self._train_dataset = dataset
|
||||
|
||||
@property
|
||||
def eval_dataset(self):
|
||||
return self._eval_dataset
|
||||
|
||||
@eval_dataset.setter
|
||||
def eval_dataset(self, dataset):
|
||||
self._eval_dataset = dataset
|
||||
|
||||
@abstractmethod
|
||||
def build(self, total_num_steps):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_callbacks(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
"""
|
||||
Callbacks added after the trainer is created, usually b/c these need access to the trainer
|
||||
"""
|
||||
|
||||
|
||||
class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
Build the HuggingFace training args/trainer for Causal models
|
||||
"""
|
||||
|
||||
def hook_pre_create_training_args(self, training_arguments_kwargs):
|
||||
# TODO
|
||||
return training_arguments_kwargs
|
||||
|
||||
def hook_post_create_training_args(self, training_arguments):
|
||||
# TODO
|
||||
return training_arguments
|
||||
|
||||
def hook_pre_create_trainer(self, trainer_kwargs, trainer_cls):
|
||||
# TODO
|
||||
return trainer_kwargs, trainer_cls
|
||||
|
||||
def hook_post_create_trainer(self, trainer):
|
||||
if self.cfg.tensor_parallel:
|
||||
trainer.model = trainer.accelerator.prepare_model(
|
||||
trainer.model, device_placement=True
|
||||
)
|
||||
return trainer
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = []
|
||||
callbacks.append(GPUStatsCallback(self.cfg))
|
||||
callbacks.append(EvalFirstStepCallback)
|
||||
|
||||
if self.cfg.relora_steps:
|
||||
callbacks.append(ReLoRACallback(self.cfg))
|
||||
|
||||
if (
|
||||
hasattr(self.model, "use_bettertransformer")
|
||||
and self.model.use_bettertransformer is True
|
||||
):
|
||||
callbacks.append(SaveBetterTransformerModelCallback)
|
||||
|
||||
if self.cfg.use_wandb:
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = []
|
||||
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
|
||||
LogPredictionCallback = log_prediction_callback_factory(
|
||||
trainer, self.tokenizer
|
||||
)
|
||||
callbacks.append(LogPredictionCallback(self.cfg))
|
||||
|
||||
if self.cfg.do_bench_eval:
|
||||
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
|
||||
|
||||
if self.cfg.early_stopping_patience:
|
||||
early_stop_cb = EarlyStoppingCallback(
|
||||
self.cfg.early_stopping_patience,
|
||||
)
|
||||
callbacks.append(early_stop_cb)
|
||||
|
||||
return callbacks
|
||||
|
||||
def _get_trainer_cls(self):
|
||||
if self.cfg.lr_scheduler == "one_cycle" and (
|
||||
self.cfg.fsdp or self.cfg.adapter == "qlora"
|
||||
):
|
||||
return OneCycleLRSchedulerTrainer
|
||||
if self.cfg.relora_steps:
|
||||
return ReLoRATrainer
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
warmup_steps = (
|
||||
self.cfg.warmup_steps
|
||||
if self.cfg.warmup_steps is not None
|
||||
else min(int(0.03 * total_num_steps), 100)
|
||||
)
|
||||
logging_steps = (
|
||||
self.cfg.logging_steps
|
||||
if self.cfg.logging_steps is not None
|
||||
else max(min(int(0.005 * total_num_steps), 10), 1)
|
||||
)
|
||||
|
||||
training_arguments_kwargs = {}
|
||||
if self.cfg.bf16 == "full":
|
||||
training_arguments_kwargs["bf16_full_eval"] = True
|
||||
else:
|
||||
training_arguments_kwargs["bf16"] = self.cfg.bf16
|
||||
training_arguments_kwargs["fp16"] = (
|
||||
self.cfg.fp16 and not self.cfg.bf16
|
||||
) or False
|
||||
training_arguments_kwargs["tf32"] = self.cfg.tf32
|
||||
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
||||
training_arguments_kwargs["logging_steps"] = logging_steps
|
||||
|
||||
if self.cfg.seed:
|
||||
training_arguments_kwargs["seed"] = self.cfg.seed
|
||||
|
||||
if self.cfg.gradient_checkpointing:
|
||||
training_arguments_kwargs[
|
||||
"gradient_checkpointing"
|
||||
] = self.cfg.gradient_checkpointing
|
||||
if self.cfg.fsdp:
|
||||
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
||||
if self.cfg.fsdp_config:
|
||||
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
|
||||
|
||||
# deepspeed
|
||||
if self.cfg.deepspeed:
|
||||
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
|
||||
|
||||
if self.cfg.lr_quadratic_warmup is not None:
|
||||
training_arguments_kwargs[
|
||||
"lr_quadratic_warmup"
|
||||
] = self.cfg.lr_quadratic_warmup
|
||||
|
||||
if self.cfg.adam_beta1:
|
||||
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
|
||||
if self.cfg.adam_beta2:
|
||||
training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2
|
||||
if self.cfg.adam_epsilon:
|
||||
training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon
|
||||
if self.cfg.max_grad_norm:
|
||||
training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm
|
||||
|
||||
if self.cfg.hub_model_id:
|
||||
training_arguments_kwargs["hub_model_id"] = self.cfg.hub_model_id
|
||||
training_arguments_kwargs["push_to_hub"] = True
|
||||
training_arguments_kwargs["hub_private_repo"] = True
|
||||
|
||||
if self.cfg.hub_strategy:
|
||||
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
||||
|
||||
if self.cfg.save_safetensors:
|
||||
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||
|
||||
if self.cfg.sample_packing_eff_est:
|
||||
training_arguments_kwargs[
|
||||
"sample_packing_efficiency"
|
||||
] = self.cfg.sample_packing_eff_est
|
||||
|
||||
if self.cfg.eval_steps:
|
||||
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
||||
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
elif self.cfg.evaluation_strategy:
|
||||
training_arguments_kwargs[
|
||||
"evaluation_strategy"
|
||||
] = self.cfg.evaluation_strategy
|
||||
elif self.cfg.val_set_size == 0:
|
||||
# no eval set, so don't eval
|
||||
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||
else:
|
||||
# we have an eval set, but no steps defined, default to use epoch
|
||||
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
||||
|
||||
if self.cfg.save_steps:
|
||||
training_arguments_kwargs["save_strategy"] = "steps"
|
||||
training_arguments_kwargs["save_steps"] = self.cfg.save_steps
|
||||
elif self.cfg.save_strategy:
|
||||
training_arguments_kwargs["save_strategy"] = self.cfg.save_strategy
|
||||
else:
|
||||
# default to saving each epoch if not defined
|
||||
training_arguments_kwargs["save_strategy"] = "epoch"
|
||||
|
||||
if self.cfg.do_bench_eval:
|
||||
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
|
||||
if self.cfg.bench_dataset:
|
||||
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
|
||||
if self.cfg.metric_for_best_model:
|
||||
training_arguments_kwargs[
|
||||
"metric_for_best_model"
|
||||
] = self.cfg.metric_for_best_model
|
||||
if self.cfg.greater_is_better:
|
||||
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
|
||||
|
||||
if self.cfg.torch_compile:
|
||||
if torch.__version__ < "2.1.0": # pylint: disable=protected-access
|
||||
LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
|
||||
elif torch._dynamo: # pylint: disable=protected-access
|
||||
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
|
||||
True
|
||||
)
|
||||
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
|
||||
if self.cfg.torch_compile_backend:
|
||||
training_arguments_kwargs[
|
||||
"torch_compile_backend"
|
||||
] = self.cfg.torch_compile_backend
|
||||
|
||||
# DDP Config
|
||||
if self.cfg.ddp_timeout:
|
||||
training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout
|
||||
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
|
||||
if self.cfg.ddp_bucket_cap_mb:
|
||||
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
|
||||
if self.cfg.ddp_broadcast_buffers is not None:
|
||||
training_arguments_kwargs[
|
||||
"ddp_broadcast_buffers"
|
||||
] = self.cfg.ddp_broadcast_buffers
|
||||
|
||||
# these are all the "standard" kwargs that are def used
|
||||
training_arguments_kwargs["max_steps"] = (
|
||||
total_num_steps if self.cfg.max_steps else -1
|
||||
)
|
||||
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
|
||||
training_arguments_kwargs[
|
||||
"per_device_train_batch_size"
|
||||
] = self.cfg.micro_batch_size
|
||||
training_arguments_kwargs[
|
||||
"per_device_eval_batch_size"
|
||||
] = self.cfg.eval_batch_size
|
||||
training_arguments_kwargs[
|
||||
"gradient_accumulation_steps"
|
||||
] = self.cfg.gradient_accumulation_steps
|
||||
training_arguments_kwargs[
|
||||
"eval_accumulation_steps"
|
||||
] = self.cfg.gradient_accumulation_steps
|
||||
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||
training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
|
||||
training_arguments_kwargs["output_dir"] = self.cfg.output_dir
|
||||
training_arguments_kwargs["save_total_limit"] = (
|
||||
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
|
||||
)
|
||||
training_arguments_kwargs["load_best_model_at_end"] = (
|
||||
(
|
||||
self.cfg.load_best_model_at_end is not False
|
||||
or self.cfg.early_stopping_patience
|
||||
)
|
||||
and self.cfg.val_set_size > 0
|
||||
and self.cfg.save_steps
|
||||
and self.cfg.eval_steps
|
||||
and self.cfg.save_steps % self.cfg.eval_steps == 0
|
||||
) or False
|
||||
training_arguments_kwargs["ddp_find_unused_parameters"] = (
|
||||
False if self.cfg.ddp else None
|
||||
)
|
||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
|
||||
training_arguments_kwargs["run_name"] = (
|
||||
self.cfg.wandb_run_id if self.cfg.use_wandb else None
|
||||
)
|
||||
training_arguments_kwargs["optim"] = (
|
||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
||||
)
|
||||
training_arguments_kwargs["lr_scheduler_type"] = (
|
||||
self.cfg.lr_scheduler
|
||||
if self.cfg.lr_scheduler
|
||||
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
|
||||
else "cosine"
|
||||
)
|
||||
training_arguments_kwargs["weight_decay"] = (
|
||||
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
||||
)
|
||||
training_arguments_kwargs["sample_packing"] = (
|
||||
self.cfg.sample_packing if self.cfg.sample_packing else False
|
||||
)
|
||||
training_arguments_kwargs["eval_sample_packing"] = (
|
||||
self.cfg.sample_packing if self.cfg.sample_packing else False
|
||||
)
|
||||
training_arguments_kwargs[
|
||||
"sample_packing_seq_len_multiplier"
|
||||
] = self.cfg.micro_batch_size
|
||||
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
||||
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
|
||||
training_arguments_kwargs["tensor_parallel"] = self.cfg.tensor_parallel is True
|
||||
|
||||
training_arguments_kwargs = self.hook_pre_create_training_args(
|
||||
training_arguments_kwargs
|
||||
)
|
||||
training_args = (
|
||||
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||
**training_arguments_kwargs,
|
||||
)
|
||||
)
|
||||
training_args = self.hook_post_create_training_args(training_args)
|
||||
trainer_kwargs = {}
|
||||
|
||||
if self.cfg.optimizer == "adamw_anyprecision":
|
||||
if Path(self.cfg.torchdistx_path).exists():
|
||||
sys.path.append(self.cfg.torchdistx_path)
|
||||
importlib.import_module("torchdistx")
|
||||
|
||||
data_collator_kwargs = {
|
||||
"padding": True, # True/"longest" is the default
|
||||
}
|
||||
if self.cfg.pad_to_sequence_len:
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
|
||||
self.cfg.sequence_len / 64
|
||||
)
|
||||
else:
|
||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 64
|
||||
|
||||
if self.cfg.is_llama_derived_model and self.cfg.landmark_attention:
|
||||
from axolotl.monkeypatch.llama_landmark_attn import (
|
||||
add_mem_tokens,
|
||||
get_mem_id,
|
||||
set_model_mem_id,
|
||||
)
|
||||
|
||||
set_model_mem_id(self.model, self.tokenizer)
|
||||
|
||||
LOG.info("Adding landmark attention tokens to dataset")
|
||||
|
||||
for dataset in [self.train_dataset, self.eval_dataset]:
|
||||
dataset = dataset.map(
|
||||
partial(
|
||||
add_mem_tokens, mem_freq=50, mem_id=get_mem_id(self.tokenizer)
|
||||
),
|
||||
batched=False,
|
||||
num_proc=32,
|
||||
)
|
||||
|
||||
trainer_cls = self._get_trainer_cls()
|
||||
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
||||
trainer_kwargs, trainer_cls
|
||||
)
|
||||
trainer = trainer_cls(
|
||||
model=self.model,
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
args=training_args,
|
||||
data_collator=DataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
),
|
||||
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
),
|
||||
callbacks=self.get_callbacks(),
|
||||
num_epochs=self.cfg.num_epochs,
|
||||
**trainer_kwargs,
|
||||
)
|
||||
trainer = self.hook_post_create_trainer(trainer)
|
||||
for callback in self.get_post_trainer_create_callbacks(trainer):
|
||||
trainer.add_callback(callback)
|
||||
|
||||
return trainer
|
||||
@@ -22,7 +22,7 @@ class TokenizedPromptDataset(Dataset):
|
||||
"""
|
||||
Dataset that returns tokenized prompts from a stream of text files.
|
||||
Args:
|
||||
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for processing the data.
|
||||
prompt_tokenizer (PromptTokenizingStrategy): The prompt tokenizing method for proccessing the data.
|
||||
dataset (dataset.Dataset): Dataset with text files.
|
||||
"""
|
||||
|
||||
@@ -55,7 +55,7 @@ class ConstantLengthDataset(IterableDataset):
|
||||
"""
|
||||
Iterable dataset that returns constant length chunks of tokens from stream of text files.
|
||||
Args:
|
||||
tokenizer (Tokenizer): The processor used for processing the data.
|
||||
tokenizer (Tokenizer): The processor used for proccessing the data.
|
||||
dataset (dataset.Dataset): Dataset with text files.
|
||||
seq_length (int): Length of token sequences to return.
|
||||
"""
|
||||
|
||||
@@ -711,8 +711,12 @@ class ParallelBlock(nn.Module):
|
||||
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
||||
self.block_idx = block_idx
|
||||
|
||||
self.mixer = MHA(config, layer_idx=block_idx)
|
||||
self.mlp = MLP(config)
|
||||
self.mixer = MHA(config=config, **mixer, layer_idx=block_idx)
|
||||
mlp_cls = mlp.pop("mlp_cls")
|
||||
if mlp_cls == "fused_mlp":
|
||||
self.mlp = FusedMLP(config=config, **mlp)
|
||||
else:
|
||||
self.mlp = MLP(config=config, **mlp)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
|
||||
@@ -7,7 +7,6 @@ import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from accelerate import init_empty_weights
|
||||
from flash_attn.flash_attn_interface import flash_attn_func
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
|
||||
@@ -18,8 +17,7 @@ def replace_btlm_attn_with_flash_attn(model_name="cerebras/btlm-3b-8k-base"):
|
||||
# this is a wonky hack to get the remotely loaded module
|
||||
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
# we need to load the model here in order for modeling_btlm to be available
|
||||
with init_empty_weights():
|
||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||
module_name = model_config.__class__.__module__.replace(
|
||||
".configuration_btlm", ".modeling_btlm"
|
||||
)
|
||||
|
||||
101
src/axolotl/monkeypatch/falcon_attn_hijack_flash.py
Normal file
101
src/axolotl/monkeypatch/falcon_attn_hijack_flash.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""
|
||||
Flash Attention monkey patch for Falcon
|
||||
|
||||
copied from https://github.com/pacman100/DHS-LLM-Workshop/blob/main/chat_assistant/training/falcon_flash_attn_monkey_patch.py
|
||||
"""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from flash_attn import flash_attn_func
|
||||
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
alibi: Optional[torch.Tensor],
|
||||
attention_mask: torch.Tensor, # pylint: disable=unused-argument
|
||||
layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
||||
head_mask: Optional[torch.Tensor] = None, # pylint: disable=unused-argument
|
||||
use_cache: bool = False,
|
||||
output_attentions: bool = False, # pylint: disable=unused-argument
|
||||
):
|
||||
fused_qkv = self.query_key_value(
|
||||
hidden_states
|
||||
) # [batch_size, seq_length, 3 x hidden_size]
|
||||
num_kv_heads = (
|
||||
self.num_heads if self.new_decoder_architecture else self.num_kv_heads
|
||||
)
|
||||
# 3 x [batch_size, seq_length, num_heads, head_dim]
|
||||
(
|
||||
query_layer,
|
||||
key_layer,
|
||||
value_layer,
|
||||
) = self._split_heads( # pylint: disable=protected-access
|
||||
fused_qkv
|
||||
)
|
||||
|
||||
batch_size, query_length, _, _ = query_layer.shape
|
||||
|
||||
query_layer = query_layer.transpose(1, 2).reshape(
|
||||
batch_size * self.num_heads, query_length, self.head_dim
|
||||
)
|
||||
key_layer = key_layer.transpose(1, 2).reshape(
|
||||
batch_size * num_kv_heads,
|
||||
query_length,
|
||||
self.head_dim,
|
||||
)
|
||||
value_layer = value_layer.transpose(1, 2).reshape(
|
||||
batch_size * num_kv_heads, query_length, self.head_dim
|
||||
)
|
||||
|
||||
past_kv_length = 0 if layer_past is None else layer_past[0].shape[1]
|
||||
query_layer, key_layer = self.maybe_rotary(query_layer, key_layer, past_kv_length)
|
||||
|
||||
if layer_past is not None:
|
||||
past_key, past_value = layer_past
|
||||
# concatenate along seq_length dimension:
|
||||
# - key: [batch_size * self.num_heads, kv_length, head_dim]
|
||||
# - value: [batch_size * self.num_heads, kv_length, head_dim]
|
||||
key_layer = torch.cat((past_key, key_layer), dim=1)
|
||||
value_layer = torch.cat((past_value, value_layer), dim=1)
|
||||
|
||||
# unused
|
||||
# _, kv_length, _ = key_layer.shape
|
||||
if use_cache:
|
||||
present = (key_layer, value_layer)
|
||||
else:
|
||||
present = None
|
||||
# unused
|
||||
# attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float("-1e9")).to(query_layer.dtype)
|
||||
query_layer_ = (
|
||||
query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
.to(torch.bfloat16)
|
||||
)
|
||||
key_layer_ = (
|
||||
key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
.to(torch.bfloat16)
|
||||
)
|
||||
value_layer_ = (
|
||||
value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim)
|
||||
.transpose(1, 2)
|
||||
.to(torch.bfloat16)
|
||||
)
|
||||
|
||||
if alibi is not None:
|
||||
raise ValueError("`alibi` is not supported when `use_flash_attn` is True")
|
||||
|
||||
# below output will have shape (batch_size, seqlen, nheads, headdim)
|
||||
attn_output = flash_attn_func(query_layer_, key_layer_, value_layer_, causal=True)
|
||||
attn_output = attn_output.reshape(
|
||||
batch_size, query_length, self.num_heads * self.head_dim
|
||||
)
|
||||
output_tensor = self.dense(attn_output)
|
||||
return output_tensor, present
|
||||
|
||||
|
||||
def replace_falcon_attn_with_flash_attn():
|
||||
transformers.models.falcon.modeling_falcon.FalconAttention.forward = forward
|
||||
@@ -1,174 +0,0 @@
|
||||
"""
|
||||
monkeypatch to add a get_turns method
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Generator, Tuple
|
||||
|
||||
from fastchat.conversation import SeparatorStyle
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.fastchat_conversation_turns")
|
||||
|
||||
|
||||
def get_prompt(self) -> str:
|
||||
ret = ""
|
||||
for role, msg in self.get_turns():
|
||||
ret += role + msg
|
||||
return ret
|
||||
|
||||
|
||||
def get_turns( # pylint: disable=too-many-return-statements
|
||||
self,
|
||||
) -> Generator[Tuple[str, str], None, None]:
|
||||
"""Get the prompt for generation."""
|
||||
system_prompt = self.system_template.format(system_message=self.system_message)
|
||||
if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", message + self.sep
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ADD_COLON_TWO:
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt + seps[0]
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield role + ": ", message + seps[i % 2]
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", message + self.sep
|
||||
else:
|
||||
yield role + ": ", "" # must be end with a space
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE:
|
||||
yield "", "" if system_prompt == "" else system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + "\n", message + self.sep
|
||||
else:
|
||||
yield role + "\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.NO_COLON_SINGLE:
|
||||
yield "", system_prompt
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role, message + self.sep
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.NO_COLON_TWO:
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield role, message + seps[i % 2]
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.RWKV:
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield role + ": ", message.replace("\r\n", "\n").replace(
|
||||
"\n\n", "\n"
|
||||
) + "\n\n"
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.LLAMA2:
|
||||
seps = [self.sep, self.sep2]
|
||||
if self.system_message:
|
||||
yield "", system_prompt
|
||||
else:
|
||||
yield "", "[INST] "
|
||||
for i, (role, message) in enumerate(self.messages[1:]):
|
||||
if message:
|
||||
yield role + " ", message + seps[i % 2]
|
||||
else:
|
||||
yield role, ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATGLM:
|
||||
# source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308
|
||||
# source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926
|
||||
round_add_n = 1 if self.name == "chatglm2" else 0
|
||||
if system_prompt:
|
||||
yield "", system_prompt + self.sep
|
||||
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if i % 2 == 0:
|
||||
yield "", f"[Round {i//2 + round_add_n}]{self.sep}"
|
||||
|
||||
if message:
|
||||
yield f"{role}:", f"{message}{self.sep}"
|
||||
else:
|
||||
yield f"{role}:", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATML:
|
||||
yield "", "" if system_prompt == "" else system_prompt + self.sep + "\n"
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + "\n", message + self.sep + "\n"
|
||||
else:
|
||||
yield role + "\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.CHATINTERN:
|
||||
# source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
prefix = "<s>" if i % 2 == 0 else ""
|
||||
if message:
|
||||
yield prefix + role + ":", message + seps[i % 2] + "\n"
|
||||
else:
|
||||
yield role + ":", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.DOLLY:
|
||||
seps = [self.sep, self.sep2]
|
||||
yield "", system_prompt
|
||||
for i, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
suffix = "\n\n" if i % 2 == 1 else ""
|
||||
yield role + ":\n", message + seps[i % 2] + suffix
|
||||
else:
|
||||
yield role + ":\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.PHOENIX:
|
||||
yield "", system_prompt
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", "<s>" + message + "</s>"
|
||||
else:
|
||||
yield role + ": " + "<s>", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.ROBIN:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ":\n", message + self.sep
|
||||
else:
|
||||
yield role + ":\n", ""
|
||||
return
|
||||
if self.sep_style == SeparatorStyle.FALCON_CHAT:
|
||||
if self.system_message:
|
||||
yield "", system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
yield role + ": ", message + self.sep
|
||||
else:
|
||||
yield role + ":", ""
|
||||
else:
|
||||
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||
|
||||
|
||||
def add_get_turns_to_conversation():
|
||||
import fastchat.conversation
|
||||
|
||||
fastchat.conversation.Conversation.get_turns = get_turns
|
||||
fastchat.conversation.Conversation.get_prompt = get_prompt
|
||||
@@ -12,19 +12,15 @@ import torch.nn.functional as F
|
||||
import transformers
|
||||
from einops import rearrange
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from torch import nn
|
||||
from transformers import LlamaConfig
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaMLP,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
from xformers.ops import SwiGLU
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||
@@ -44,33 +40,7 @@ except ImportError:
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def replace_llama_mlp_with_swiglu(model):
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaMLP):
|
||||
mlp = FusedMLP(
|
||||
module.config, module.gate_proj, module.up_proj, module.down_proj
|
||||
)
|
||||
set_module_name(model, name, mlp)
|
||||
|
||||
|
||||
def replace_llama_qkv_with_fused(model):
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaAttention):
|
||||
qkv = FusedAttention(
|
||||
module.config,
|
||||
module.q_proj,
|
||||
module.k_proj,
|
||||
module.v_proj,
|
||||
module.o_proj,
|
||||
)
|
||||
set_module_name(model, name, qkv)
|
||||
|
||||
|
||||
def replace_llama_attn_with_flash_attn(
|
||||
packed: Optional[bool] = False,
|
||||
cross_entropy: Optional[bool] = False,
|
||||
rms_norm: Optional[bool] = False,
|
||||
):
|
||||
def replace_llama_attn_with_flash_attn(packed: Optional[bool] = False):
|
||||
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
||||
_prepare_decoder_attention_mask
|
||||
)
|
||||
@@ -81,123 +51,46 @@ def replace_llama_attn_with_flash_attn(
|
||||
llama_model_forward
|
||||
)
|
||||
|
||||
# skip only if explicitly disabled
|
||||
if cross_entropy:
|
||||
try:
|
||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||
try:
|
||||
from flash_attn.losses.cross_entropy import CrossEntropyLoss
|
||||
|
||||
LOG.info("patching with flash_attn.losses.cross_entropy")
|
||||
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
||||
CrossEntropyLoss, inplace_backward=True
|
||||
)
|
||||
except ImportError:
|
||||
LOG.info(
|
||||
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
|
||||
)
|
||||
|
||||
# skip only if explicitly disabled
|
||||
if rms_norm:
|
||||
try:
|
||||
from flash_attn.ops.rms_norm import RMSNorm
|
||||
|
||||
class LlamaRMSNorm(RMSNorm):
|
||||
"""Patched LLamaRMSNorm"""
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__(hidden_size, eps=eps)
|
||||
|
||||
LOG.info("patching with flash_attn.ops.rms_norm")
|
||||
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
||||
except ImportError:
|
||||
LOG.info(
|
||||
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
||||
)
|
||||
|
||||
|
||||
class FusedAttention(LlamaAttention):
|
||||
"""
|
||||
Fused QKV Attention layer for incrementally improved training efficiency
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
q: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
k: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
v: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
o: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.init_device = next(iter(q.state_dict().values())).device
|
||||
|
||||
# define equivalent fused qkv projection
|
||||
self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
|
||||
self.qkv_proj = torch.nn.Linear(
|
||||
q.in_features, sum(self.out_features), device=self.init_device, bias=False
|
||||
LOG.info("patching with flash_attn.losses.cross_entropy")
|
||||
transformers.models.llama.modeling_llama.CrossEntropyLoss = partial(
|
||||
CrossEntropyLoss, inplace_backward=True
|
||||
)
|
||||
self.o_proj = o
|
||||
|
||||
# overwrite initialized weights with pretrained weights
|
||||
self.qkv_proj.weight.data = torch.cat(
|
||||
(q.weight.data, k.weight.data, v.weight.data), dim=0
|
||||
except ImportError:
|
||||
LOG.info(
|
||||
"optimized flash-attention CrossEntropyLoss not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=xentropy_cuda_lib&subdirectory=csrc/xentropy'`)"
|
||||
)
|
||||
|
||||
def _post_training(self, model, name):
|
||||
q_proj, k_proj, v_proj = torch.split(
|
||||
self.qkv_proj.weight.data, self.out_features, dim=0
|
||||
try:
|
||||
from flash_attn.ops.rms_norm import RMSNorm
|
||||
|
||||
class LlamaRMSNorm(RMSNorm):
|
||||
"""Patched LLamaRMSNorm"""
|
||||
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__(hidden_size, eps=eps)
|
||||
|
||||
LOG.info("patching with flash_attn.ops.rms_norm")
|
||||
transformers.models.llama.modeling_llama.LlamaRMSNorm = LlamaRMSNorm
|
||||
except ImportError:
|
||||
LOG.info(
|
||||
"optimized flash-attention RMSNorm not found (run `pip install 'git+https://github.com/Dao-AILab/flash-attention.git#egg=dropout_layer_norm&subdirectory=csrc/layer_norm'`)"
|
||||
)
|
||||
|
||||
new_attn = LlamaAttention(self.config)
|
||||
new_attn.q_proj.weight.data = q_proj
|
||||
new_attn.k_proj.weight.data = k_proj
|
||||
new_attn.v_proj.weight.data = v_proj
|
||||
new_attn.o_proj.weight.data = self.o_proj.weight.data
|
||||
|
||||
set_module_name(model, name, new_attn)
|
||||
class GaussianDropout(nn.Module):
|
||||
def __init__(self, p=0.5):
|
||||
super(GaussianDropout, self).__init__()
|
||||
if p <= 0 or p >= 1:
|
||||
raise Exception("p value should accomplish 0 < p < 1")
|
||||
self.p = p
|
||||
|
||||
|
||||
class FusedMLP(torch.nn.Module):
|
||||
"""
|
||||
Fused MLP layer for incrementally improved training efficiency
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
gate_proj: torch.nn.Linear,
|
||||
up_proj: torch.nn.Linear,
|
||||
down_proj: torch.nn.Linear,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.swiglu = SwiGLU(
|
||||
in_features=config.hidden_size,
|
||||
hidden_features=config.intermediate_size,
|
||||
bias=False,
|
||||
_pack_weights=True,
|
||||
)
|
||||
# overwrite initialized weights with pretrained weights
|
||||
self.swiglu.w12.weight.data = torch.cat(
|
||||
(gate_proj.weight.data, up_proj.weight.data), dim=0
|
||||
)
|
||||
self.swiglu.w3.weight.data = down_proj.weight.data
|
||||
|
||||
def _post_training(self, model, name):
|
||||
w1, w2 = torch.split( # pylint: disable=invalid-name
|
||||
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
|
||||
)
|
||||
|
||||
# Assign the split weights back to the original layers
|
||||
new_mlp = LlamaMLP(self.config)
|
||||
new_mlp.gate_proj.weight.data = w1
|
||||
new_mlp.up_proj.weight.data = w2
|
||||
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
|
||||
|
||||
set_module_name(model, name, new_mlp)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
|
||||
return self.swiglu(x)
|
||||
def forward(self, x):
|
||||
stddev = (self.p / (1.0 - self.p)) ** 0.5
|
||||
epsilon = torch.randn_like(x) * stddev
|
||||
return x * epsilon
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
@@ -221,7 +114,6 @@ def flashattn_forward(
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
@@ -261,14 +153,9 @@ def flashattn_forward(
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
if isinstance(self, FusedAttention):
|
||||
query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
|
||||
self.out_features, dim=-1
|
||||
)
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
@@ -330,7 +217,7 @@ def flashattn_forward(
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
|
||||
qkv, cu_seqlens, max_seqlen, dropout_p=0.0, softmax_scale=None, causal=True
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
elif query_states.shape == key_states.shape:
|
||||
@@ -604,13 +491,6 @@ def llama_model_forward(
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
padding_mask = None
|
||||
else:
|
||||
if 0 in attention_mask:
|
||||
padding_mask = attention_mask
|
||||
else:
|
||||
padding_mask = None
|
||||
|
||||
attention_mask = (
|
||||
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
||||
attention_mask,
|
||||
@@ -645,9 +525,7 @@ def llama_model_forward(
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(
|
||||
*inputs,
|
||||
)
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
@@ -656,10 +534,9 @@ def llama_model_forward(
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
None,
|
||||
output_attentions,
|
||||
None,
|
||||
padding_mask,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
)
|
||||
@@ -671,7 +548,6 @@ def llama_model_forward(
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
@@ -710,6 +586,15 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
|
||||
patched version of LlamaDecoderLayer to pass through the precalculated cu_seqlens
|
||||
"""
|
||||
|
||||
def __init__(self, config: LlamaConfig):
|
||||
super(LlamaDecoderLayer, self).__init__(config)
|
||||
self.attn_dropout = None
|
||||
self.mlp_dropout = None
|
||||
if config.dropout_attn:
|
||||
self.attn_dropout = GaussianDropout(p=config.dropout_attn)
|
||||
if config.dropout_mlp:
|
||||
self.mlp_dropout = GaussianDropout(p=config.dropout_mlp)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -718,7 +603,6 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[
|
||||
@@ -751,16 +635,19 @@ class LlamaDecoderLayer(OriginalLlamaDecoderLayer):
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
padding_mask=padding_mask,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
if self.training and self.attn_dropout:
|
||||
hidden_states = self.attn_dropout(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
if self.training and self.mlp_dropout:
|
||||
hidden_states = self.mlp_dropout(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
@@ -1,640 +0,0 @@
|
||||
"""Flash attention monkey patch for mistral model"""
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import logging
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from einops import rearrange
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||
flash_attn_kvpacked_func,
|
||||
flash_attn_varlen_kvpacked_func,
|
||||
flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralAttention as OriginalMistralAttention,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
|
||||
|
||||
|
||||
def replace_mistral_attn_with_flash_attn(
|
||||
packed: Optional[bool] = False,
|
||||
):
|
||||
transformers.models.mistral.modeling_mistral.MistralModel._prepare_decoder_attention_mask = ( # pylint: disable=protected-access
|
||||
_prepare_decoder_attention_mask
|
||||
)
|
||||
transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
|
||||
flashattn_forward
|
||||
)
|
||||
if packed:
|
||||
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
|
||||
MistralDecoderLayer
|
||||
)
|
||||
transformers.models.mistral.modeling_mistral.MistralModel.forward = (
|
||||
mistral_model_forward
|
||||
)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def _make_sliding_window_causal_mask(
|
||||
bsz: int,
|
||||
tgt_len: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
past_key_values_length: int = 0,
|
||||
sliding_window: int = 4096,
|
||||
):
|
||||
"""
|
||||
Make causal mask used for sliding window attention
|
||||
"""
|
||||
tensor = torch.full(
|
||||
(tgt_len, tgt_len),
|
||||
fill_value=1,
|
||||
device=device,
|
||||
)
|
||||
mask = torch.tril(tensor, diagonal=0)
|
||||
# make the mask banded to account for sliding window
|
||||
# NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
|
||||
mask = torch.triu(mask, diagonal=-sliding_window + 1)
|
||||
mask = torch.log(mask).to(dtype)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
tgt_len, past_key_values_length, dtype=dtype, device=device
|
||||
),
|
||||
mask,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
return mask[None, None, :, :].expand(
|
||||
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
||||
)
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
self,
|
||||
attention_mask,
|
||||
input_shape,
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
sliding_window,
|
||||
): # pylint: disable=unused-argument
|
||||
# [bsz, seq_len]
|
||||
if attention_mask is None:
|
||||
return attention_mask
|
||||
|
||||
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
|
||||
# Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
|
||||
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
|
||||
sliding_window_mask = _make_sliding_window_causal_mask(
|
||||
bsz=input_shape[0],
|
||||
tgt_len=input_shape[1],
|
||||
dtype=inputs_embeds.dtype,
|
||||
device=inputs_embeds.device,
|
||||
past_key_values_length=past_key_values_length,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
attention_mask = attention_mask + sliding_window_mask
|
||||
else:
|
||||
LOG.info("skipping sliding window mask, not broadcastable with attention mask")
|
||||
|
||||
return attention_mask
|
||||
|
||||
|
||||
def flashattn_forward(
|
||||
self: OriginalMistralAttention,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
|
||||
use_sliding_windows = (
|
||||
hasattr(self.config, "sliding_window") is not None
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
)
|
||||
|
||||
if use_sliding_windows:
|
||||
window_size = (self.config.sliding_window, self.config.sliding_window)
|
||||
else:
|
||||
window_size = (-1, -1)
|
||||
|
||||
if past_key_value is not None:
|
||||
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
||||
if (
|
||||
hasattr(self.config, "sliding_window")
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
):
|
||||
slicing_tokens = kv_seq_len - self.config.sliding_window
|
||||
|
||||
past_key = past_key_value[0]
|
||||
past_value = past_key_value[1]
|
||||
|
||||
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
||||
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
||||
|
||||
if past_key.shape[-2] != self.config.sliding_window - 1:
|
||||
raise ValueError(
|
||||
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
|
||||
f" {past_key.shape}"
|
||||
)
|
||||
|
||||
past_key_value = (past_key, past_value) if use_cache else None
|
||||
|
||||
if past_key_value is not None:
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if self.training:
|
||||
# during training q,k,v always have same seqlen
|
||||
assert key_states.shape == query_states.shape
|
||||
is_causal = True
|
||||
else:
|
||||
# turn off FA causal mask after first inference autoregressive iteration
|
||||
# only on first autoregressive step q,k,v have same seqlen
|
||||
is_causal = key_states.shape == query_states.shape
|
||||
|
||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||
# special handling using sample packing
|
||||
qkv = torch.stack(
|
||||
[query_states, key_states, value_states], dim=2
|
||||
) # [bsz, nh, 3, q_len, hd]
|
||||
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
0.0,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
elif query_states.shape == key_states.shape:
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
qkv_unpad, cu_seqlens_q, max_seqlen_q, _, output_pad_fn = generate_qkv(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
qkvpacked=True,
|
||||
# We have disabled _prepare_decoder_attention_mask in LlamaModel
|
||||
# the attention_mask should be the same as the key_padding_mask
|
||||
key_padding_mask=attention_mask,
|
||||
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
||||
if attention_mask is not None
|
||||
else None,
|
||||
)
|
||||
output_unpad = flash_attn_varlen_qkvpacked_func(
|
||||
qkv_unpad,
|
||||
cu_seqlens_q,
|
||||
max_seqlen_q,
|
||||
0.0,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
output = output_pad_fn(output_unpad)
|
||||
else:
|
||||
query_states = query_states.transpose(1, 2)
|
||||
key_states = key_states.transpose(1, 2)
|
||||
value_states = value_states.transpose(1, 2)
|
||||
if attention_mask is None or attention_mask.all().item():
|
||||
output = flash_attn_kvpacked_func(
|
||||
query_states,
|
||||
torch.stack([key_states, value_states], 2),
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
else:
|
||||
( # pylint: disable=unbalanced-tuple-unpacking
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
_,
|
||||
_,
|
||||
output_pad_fn,
|
||||
) = generate_qkv(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
kvpacked=True,
|
||||
key_padding_mask=attention_mask,
|
||||
query_padding_mask=attention_mask[:, -query_states.size(1) :]
|
||||
if attention_mask is not None
|
||||
else None,
|
||||
)
|
||||
if q_unpad.dtype != kv_unpad.dtype:
|
||||
kv_unpad = kv_unpad.to(q_unpad.dtype)
|
||||
output_unpad = flash_attn_varlen_kvpacked_func(
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
0.0,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
output = output_pad_fn(output_unpad)
|
||||
|
||||
attn_output = output
|
||||
if attn_output.size() != (bsz, q_len, self.num_heads, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, q_len, self.num_heads, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
|
||||
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
if not output_attentions:
|
||||
attn_weights = None
|
||||
|
||||
return attn_output, attn_weights, past_key_value
|
||||
|
||||
|
||||
# based on https://github.com/Dao-AILab/flash-attention/blob/364a5b/tests/test_flash_attn.py#L38
|
||||
def generate_qkv(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
query_padding_mask=None,
|
||||
key_padding_mask=None,
|
||||
kvpacked=False,
|
||||
qkvpacked=False,
|
||||
): # pylint: disable=invalid-name,unnecessary-lambda-assignment
|
||||
"""
|
||||
Arguments:
|
||||
q: (batch_size, seqlen_q, nheads, d)
|
||||
k: (batch_size, seqlen_k, nheads_k, d)
|
||||
v: (batch_size, seqlen_k, nheads_k, d)
|
||||
query_padding_mask: (batch_size, seqlen), bool
|
||||
key_padding_mask: (batch_size, seqlen), bool
|
||||
"""
|
||||
assert not (kvpacked and qkvpacked)
|
||||
batch_size, seqlen_q, nheads, d = q.shape
|
||||
_, seqlen_k, nheads_k, _ = k.shape
|
||||
assert k.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||
assert v.shape == (batch_size, seqlen_k, nheads_k, d)
|
||||
|
||||
if query_padding_mask is not None:
|
||||
q_unpad, indices_q, cu_seqlens_q, max_seqlen_q = unpad_input(
|
||||
q, query_padding_mask
|
||||
)
|
||||
|
||||
output_pad_fn = lambda output_unpad: pad_input( # noqa: E731
|
||||
output_unpad, indices_q, batch_size, seqlen_q
|
||||
)
|
||||
|
||||
else:
|
||||
q_unpad = rearrange(q, "b s h d -> (b s) h d")
|
||||
cu_seqlens_q = torch.arange(
|
||||
0,
|
||||
(batch_size + 1) * seqlen_q,
|
||||
step=seqlen_q,
|
||||
dtype=torch.int32,
|
||||
device=q_unpad.device,
|
||||
)
|
||||
max_seqlen_q = seqlen_q
|
||||
|
||||
output_pad_fn = lambda output_unpad: rearrange( # noqa: E731
|
||||
output_unpad, "(b s) h d -> b s h d", b=batch_size
|
||||
)
|
||||
|
||||
if key_padding_mask is not None:
|
||||
k_unpad, _, cu_seqlens_k, max_seqlen_k = unpad_input(k, key_padding_mask)
|
||||
v_unpad, _, _, _ = unpad_input(v, key_padding_mask)
|
||||
else:
|
||||
k_unpad = rearrange(k, "b s h d -> (b s) h d")
|
||||
v_unpad = rearrange(v, "b s h d -> (b s) h d")
|
||||
cu_seqlens_k = torch.arange(
|
||||
0,
|
||||
(batch_size + 1) * seqlen_k,
|
||||
step=seqlen_k,
|
||||
dtype=torch.int32,
|
||||
device=k_unpad.device,
|
||||
)
|
||||
max_seqlen_k = seqlen_k
|
||||
|
||||
if qkvpacked:
|
||||
assert nheads == nheads_k
|
||||
qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
|
||||
qkv = torch.stack([q, k, v], dim=2)
|
||||
return (qkv_unpad, cu_seqlens_q, max_seqlen_q, qkv, output_pad_fn)
|
||||
|
||||
if kvpacked:
|
||||
kv_unpad = torch.stack([k_unpad, v_unpad], dim=1)
|
||||
kv = torch.stack([k, v], dim=2)
|
||||
return (
|
||||
q_unpad,
|
||||
kv_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
q,
|
||||
kv,
|
||||
output_pad_fn,
|
||||
)
|
||||
|
||||
return (
|
||||
q_unpad,
|
||||
k_unpad,
|
||||
v_unpad,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
output_pad_fn,
|
||||
)
|
||||
|
||||
|
||||
def mistral_model_forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||
)
|
||||
if input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||
)
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
cu_seqlens = None
|
||||
max_seqlen = None
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_seqlens = cu_seqlens.squeeze()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
attention_mask = (
|
||||
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
sliding_window=self.config.sliding_window,
|
||||
)
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
transformers.logger.warning_once(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
None,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
|
||||
|
||||
class MistralDecoderLayer(OriginalMistralDecoderLayer):
|
||||
"""
|
||||
patched version of MistralDecoderLayer to pass through the precalculated cu_seqlens
|
||||
"""
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[
|
||||
torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
|
||||
]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
|
||||
output_attentions (`bool`, *optional*):
|
||||
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||
returned tensors for more detail.
|
||||
use_cache (`bool`, *optional*):
|
||||
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||
(see `past_key_values`).
|
||||
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||
cu_seqlens (`torch.Tensor`, *optional*) cumulative sequence len when packing
|
||||
"""
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
@@ -1,65 +0,0 @@
|
||||
"""
|
||||
patches implemented through the trainer hooks to enable NEFT/noisy embeddings per https://arxiv.org/abs/2310.05914
|
||||
"""
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
|
||||
def patch_neft(alpha, model):
|
||||
embeddings = None
|
||||
if isinstance(model, PreTrainedModel):
|
||||
embeddings = model.get_input_embeddings()
|
||||
if isinstance(model, PeftModel):
|
||||
embeddings = model.base_model.get_input_embeddings()
|
||||
if not embeddings:
|
||||
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
|
||||
embeddings.noisy_embedding_alpha = alpha
|
||||
old_forward = embeddings.forward
|
||||
|
||||
# This hack seems to be needed to properly use a custom forward pass
|
||||
# all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
|
||||
bound_method = neft_forward.__get__( # pylint: disable=no-value-for-parameter
|
||||
embeddings, embeddings.__class__
|
||||
)
|
||||
setattr(embeddings, "forward", bound_method)
|
||||
|
||||
embeddings._old_forward = old_forward # pylint: disable=protected-access
|
||||
return model
|
||||
|
||||
|
||||
def unpatch_neft(model):
|
||||
embeddings = None
|
||||
if isinstance(model, PreTrainedModel):
|
||||
embeddings = model.get_input_embeddings()
|
||||
if isinstance(model, PeftModel):
|
||||
embeddings = model.base_model.get_input_embeddings()
|
||||
if not embeddings:
|
||||
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
|
||||
if hasattr(embeddings, "_old_forward"):
|
||||
embeddings.forward = embeddings._old_forward # pylint: disable=protected-access
|
||||
del embeddings._old_forward # pylint: disable=protected-access
|
||||
del embeddings.noisy_embedding_alpha
|
||||
|
||||
|
||||
def neft_forward(self, inputs: torch.Tensor):
|
||||
embeddings = self._old_forward(inputs) # pylint: disable=protected-access
|
||||
|
||||
if self.training:
|
||||
dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
|
||||
mag_norm = self.noisy_embedding_alpha / torch.sqrt(dims)
|
||||
embeddings = embeddings + torch.zeros_like(embeddings).uniform_(
|
||||
-mag_norm, mag_norm
|
||||
)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
def pretrain_hook(cfg, trainer):
|
||||
if cfg.noisy_embedding_alpha:
|
||||
trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)
|
||||
|
||||
|
||||
def post_train_hook(cfg, trainer):
|
||||
if cfg.noisy_embedding_alpha:
|
||||
unpatch_neft(trainer.model)
|
||||
@@ -1,415 +0,0 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2023 Stability AI, EleutherAI, and The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# This code is based off the following work:
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt_neox/modeling_gpt_neox.py
|
||||
""" PyTorch StableLM Epoch model. """
|
||||
import importlib
|
||||
import math
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint
|
||||
from accelerate import init_empty_weights
|
||||
from einops import rearrange
|
||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||
flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
from torch import nn
|
||||
from transformers import AutoConfig, AutoModelForCausalLM
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.utils import logging
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def replace_stablelm_attn_with_flash_attn(model_name="stabilityai/stablelm-3b-4e1t"):
|
||||
# this is a wonky hack to get the remotely loaded module
|
||||
model_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
|
||||
# we need to load the model here in order for modeling_stablelm_epoch to be available
|
||||
with init_empty_weights():
|
||||
AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
|
||||
module_name = model_config.__class__.__module__.replace(
|
||||
".configuration_stablelm_epoch", ".modeling_stablelm_epoch"
|
||||
)
|
||||
modeling_stablelm = importlib.import_module(module_name)
|
||||
modeling_stablelm.Attention.forward = ( # pylint: disable=protected-access
|
||||
flashattn_attn
|
||||
)
|
||||
modeling_stablelm.StableLMEpochModel.forward = ( # pylint: disable=protected-access
|
||||
stablelm_model_forward
|
||||
)
|
||||
modeling_stablelm.DecoderLayer.forward = ( # pylint: disable=protected-access
|
||||
decoder_layer_forward
|
||||
)
|
||||
|
||||
|
||||
def rotate_half(x: torch.Tensor):
|
||||
"""Rotates half the hidden dims of the input."""
|
||||
# pylint: disable=invalid-name
|
||||
x1, x2 = torch.chunk(x, 2, dim=-1)
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
|
||||
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
|
||||
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
|
||||
# pylint: disable=invalid-name
|
||||
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
|
||||
cos = cos[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
|
||||
sin = sin[position_ids].unsqueeze(1) # [batch_size, 1, seq_len, dim]
|
||||
q_embed = (q * cos) + (rotate_half(q) * sin)
|
||||
k_embed = (k * cos) + (rotate_half(k) * sin)
|
||||
return q_embed, k_embed
|
||||
|
||||
|
||||
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
|
||||
"""
|
||||
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
|
||||
num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
|
||||
"""
|
||||
batch, num_key_value_heads, slen, head_dim = hidden_states.shape
|
||||
if n_rep == 1:
|
||||
return hidden_states
|
||||
hidden_states = hidden_states[:, :, None, :, :].expand(
|
||||
batch, num_key_value_heads, n_rep, slen, head_dim
|
||||
)
|
||||
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
|
||||
|
||||
|
||||
def flashattn_attn(
|
||||
self,
|
||||
hidden_states: torch.FloatTensor,
|
||||
attention_mask: torch.FloatTensor,
|
||||
position_ids: torch.LongTensor,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False, # pylint: disable=unused-argument
|
||||
use_cache: Optional[bool] = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
|
||||
query_rot = query_states[..., : self.rotary_ndims]
|
||||
query_pass = query_states[..., self.rotary_ndims :]
|
||||
key_rot = key_states[..., : self.rotary_ndims]
|
||||
key_pass = key_states[..., self.rotary_ndims :]
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_rot, key_rot, cos, sin, position_ids
|
||||
)
|
||||
|
||||
# [batch_size, num_heads, seq_len, head_dim]
|
||||
query_states = torch.cat((query_states, query_pass), dim=-1)
|
||||
key_states = torch.cat((key_states, key_pass), dim=-1)
|
||||
|
||||
if past_key_value is not None:
|
||||
# Reuse k, v, self_attention
|
||||
key_states = torch.cat((past_key_value[0], key_states), dim=2)
|
||||
value_states = torch.cat((past_key_value[1], value_states), dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# Repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||
# special handling using sample packing
|
||||
qkv = torch.stack(
|
||||
[query_states, key_states, value_states], dim=2
|
||||
) # [bsz, nh, 3, q_len, hd]
|
||||
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
softmax_scale = None
|
||||
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=softmax_scale, causal=True
|
||||
)
|
||||
|
||||
attn_output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
attn_output = rearrange(attn_output, "b s h d -> b s (h d)")
|
||||
else:
|
||||
attn_weights = torch.matmul(
|
||||
query_states, key_states.transpose(2, 3)
|
||||
) / math.sqrt(self.head_dim)
|
||||
|
||||
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
|
||||
f" {attn_weights.size()}"
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
|
||||
raise ValueError(
|
||||
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
|
||||
)
|
||||
attn_weights = attn_weights + attention_mask
|
||||
|
||||
# Upcast attention to fp32
|
||||
attn_weights = nn.functional.softmax(
|
||||
attn_weights, dim=-1, dtype=torch.float32
|
||||
).to(query_states.dtype)
|
||||
attn_output = torch.matmul(attn_weights, value_states)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
|
||||
# Merge heads
|
||||
attn_output = attn_output.transpose(1, 2).contiguous()
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
# Final linear projection
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
|
||||
|
||||
def decoder_layer_forward(
|
||||
self,
|
||||
hidden_states: Optional[torch.FloatTensor],
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: Optional[bool] = False,
|
||||
use_cache: Optional[bool] = False,
|
||||
cu_seqlens: Optional[torch.Tensor] = None,
|
||||
max_seqlen: Optional[torch.Tensor] = None,
|
||||
) -> Union[
|
||||
Tuple[torch.Tensor], Optional[Tuple[torch.Tensor, Tuple[torch.FloatTensor, ...]]]
|
||||
]:
|
||||
# pylint: disable=duplicate-code
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.input_layernorm(hidden_states)
|
||||
|
||||
# Self Attention
|
||||
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||
hidden_states=hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
# Fully Connected
|
||||
residual = hidden_states
|
||||
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||
hidden_states = self.mlp(hidden_states)
|
||||
hidden_states = residual + hidden_states
|
||||
|
||||
outputs = (hidden_states,)
|
||||
|
||||
if output_attentions:
|
||||
outputs += (self_attn_weights,)
|
||||
|
||||
if use_cache:
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
def stablelm_model_forward(
|
||||
self,
|
||||
input_ids: Optional[torch.LongTensor] = None,
|
||||
attention_mask: Optional[torch.FloatTensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPast]:
|
||||
# pylint: disable=duplicate-code
|
||||
output_attentions = (
|
||||
output_attentions
|
||||
if output_attentions is not None
|
||||
else self.config.output_attentions
|
||||
)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
# Retrieve input_ids and inputs_embeds
|
||||
if input_ids is not None and inputs_embeds is not None:
|
||||
raise ValueError(
|
||||
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||
)
|
||||
if input_ids is not None:
|
||||
batch_size, seq_length = input_ids.shape
|
||||
elif inputs_embeds is not None:
|
||||
batch_size, seq_length, _ = inputs_embeds.shape
|
||||
else:
|
||||
raise ValueError(
|
||||
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||
)
|
||||
|
||||
seq_length_with_past = seq_length
|
||||
past_key_values_length = 0
|
||||
|
||||
if past_key_values is not None:
|
||||
past_key_values_length = past_key_values[0][0].shape[2]
|
||||
seq_length_with_past = seq_length_with_past + past_key_values_length
|
||||
|
||||
cu_seqlens = None
|
||||
max_seqlen = None
|
||||
if position_ids is None:
|
||||
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||
position_ids = torch.arange(
|
||||
past_key_values_length,
|
||||
seq_length + past_key_values_length,
|
||||
dtype=torch.long,
|
||||
device=device,
|
||||
)
|
||||
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||
else:
|
||||
position_ids = position_ids.view(-1, seq_length).long()
|
||||
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
||||
cu_seqlens = cu_seqlens.squeeze()
|
||||
|
||||
if inputs_embeds is None:
|
||||
inputs_embeds = self.embed_tokens(input_ids)
|
||||
# Embed positions
|
||||
if attention_mask is None:
|
||||
attention_mask = torch.ones(
|
||||
(batch_size, seq_length_with_past),
|
||||
dtype=torch.bool,
|
||||
device=inputs_embeds.device,
|
||||
)
|
||||
attention_mask = (
|
||||
self._prepare_decoder_attention_mask( # pylint: disable=protected-access
|
||||
attention_mask,
|
||||
(batch_size, seq_length),
|
||||
inputs_embeds,
|
||||
past_key_values_length,
|
||||
)
|
||||
)
|
||||
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
if use_cache:
|
||||
logger.warning(
|
||||
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||
)
|
||||
use_cache = False
|
||||
|
||||
# Decoder layers
|
||||
all_hidden_states = () if output_hidden_states else None
|
||||
all_self_attns = () if output_attentions else None
|
||||
next_decoder_cache = () if use_cache else None
|
||||
|
||||
for idx, decoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
past_key_value = past_key_values[idx] if past_key_values is not None else None
|
||||
|
||||
if self.gradient_checkpointing and self.training:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
# None for past_key_value
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
layer_outputs = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(decoder_layer),
|
||||
hidden_states,
|
||||
attention_mask,
|
||||
position_ids,
|
||||
past_key_value,
|
||||
output_attentions,
|
||||
None,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
)
|
||||
else:
|
||||
layer_outputs = decoder_layer(
|
||||
hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_value=past_key_value,
|
||||
output_attentions=output_attentions,
|
||||
use_cache=use_cache,
|
||||
cu_seqlens=cu_seqlens,
|
||||
max_seqlen=max_seqlen,
|
||||
)
|
||||
|
||||
hidden_states = layer_outputs[0]
|
||||
|
||||
if use_cache:
|
||||
next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
|
||||
|
||||
if output_attentions:
|
||||
all_self_attns += (layer_outputs[1],)
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
|
||||
# Add hidden states from the last decoder layer
|
||||
if output_hidden_states:
|
||||
all_hidden_states += (hidden_states,)
|
||||
|
||||
next_cache = next_decoder_cache if use_cache else None
|
||||
if not return_dict:
|
||||
return tuple(
|
||||
v
|
||||
for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
|
||||
if v is not None
|
||||
)
|
||||
return BaseModelOutputWithPast(
|
||||
last_hidden_state=hidden_states,
|
||||
past_key_values=next_cache,
|
||||
hidden_states=all_hidden_states,
|
||||
attentions=all_self_attns,
|
||||
)
|
||||
@@ -101,16 +101,3 @@ def get_cu_seqlens_from_pos_ids(position_ids):
|
||||
max_seq_lens.append(max_seq_len)
|
||||
|
||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||
|
||||
|
||||
def set_module_name(model, name, value):
|
||||
if "." in name:
|
||||
parent_name = name.rsplit(".", 1)[0]
|
||||
child_name = name[len(parent_name) + 1 :]
|
||||
parent = model.get_submodule(parent_name)
|
||||
else:
|
||||
parent_name = ""
|
||||
parent = model
|
||||
child_name = name
|
||||
|
||||
setattr(parent, child_name, value)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Module for Alpaca prompt strategy classes"""
|
||||
"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
@@ -9,13 +9,9 @@ from axolotl.prompt_tokenizers import (
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
prompt_style = PromptStyle.CHAT.value
|
||||
if ds_cfg and "conversation" in ds_cfg:
|
||||
prompt_style = ds_cfg["conversation"]
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return AlpacaPromptTokenizingStrategy(
|
||||
AlpacaPrompter(prompt_style),
|
||||
AlpacaPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
|
||||
@@ -24,15 +24,6 @@ def load(tokenizer, cfg):
|
||||
)
|
||||
|
||||
|
||||
def load_v2(tokenizer, cfg):
|
||||
return ContextQaV2PromptTokenizingStrategy(
|
||||
ContextV2Prompter(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
class AlpacaContextPrompter(AlpacaPrompter):
|
||||
"""
|
||||
Customized system prompted for concise QA
|
||||
@@ -59,38 +50,6 @@ class AlpacaContextPromptTokenizingStrategy(InstructionPromptTokenizingStrategy)
|
||||
)
|
||||
|
||||
|
||||
class ContextQaV2PromptTokenizingStrategy(InstructionPromptTokenizingStrategy):
|
||||
"""
|
||||
Tokenization Strategy to combine in-context article with a question and answer
|
||||
"""
|
||||
|
||||
def parse_instruction_fields(self, prompt) -> Tuple[str, str, str]:
|
||||
return (
|
||||
"Context: "
|
||||
+ prompt["context"]
|
||||
+ "\nQuestion: "
|
||||
+ prompt["question"]
|
||||
+ "\n",
|
||||
"",
|
||||
"Answer: " + prompt["answer"],
|
||||
)
|
||||
|
||||
|
||||
class ContextV2Prompter(AlpacaPrompter):
|
||||
"""
|
||||
Customized system prompted for concise QA
|
||||
"""
|
||||
|
||||
system_prompt = ""
|
||||
system_no_input_prompt = ""
|
||||
|
||||
def match_prompt_style(self):
|
||||
# pylint: disable=duplicate-code
|
||||
self.turn_format = "{instruction}\n{input}"
|
||||
self.turn_no_input_format = "{instruction}"
|
||||
self.system_format = "{system}"
|
||||
|
||||
|
||||
class AlpacaMissingInfoContextPromptTokenizingStrategy(
|
||||
InstructionPromptTokenizingStrategy
|
||||
):
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
"""Module for Jokes prompts using sharegpt style """
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import ShareGPTPrompterV2
|
||||
from axolotl.prompters import PromptStyle, ShareGPTPrompter
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return SimpleJokesShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(),
|
||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
|
||||
@@ -1,47 +1,21 @@
|
||||
"""Module containing the SimpleShareGPTPromptTokenizingStrategy class"""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastchat.conversation import Conversation, SeparatorStyle, register_conv_template
|
||||
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import ShareGPTPrompterV2
|
||||
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="chatml",
|
||||
system_template="<|im_start|>system\n{system_message}",
|
||||
system_message="You are a helpful assistant.",
|
||||
roles=["<|im_start|>user", "<|im_start|>assistant"],
|
||||
sep_style=SeparatorStyle.CHATML,
|
||||
sep="<|im_end|>\n",
|
||||
)
|
||||
)
|
||||
from axolotl.prompters import PromptStyle, ShareGPTPrompter
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
conversation = (
|
||||
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
||||
)
|
||||
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation=conversation,
|
||||
role_key_model=field_model,
|
||||
role_key_human=field_human,
|
||||
),
|
||||
def load(tokenizer, cfg):
|
||||
return SimpleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
if ds_cfg and "strict" in ds_cfg:
|
||||
strategy.strict = ds_cfg["strict"]
|
||||
return strategy
|
||||
|
||||
|
||||
def load_role(tokenizer, cfg):
|
||||
return SimpleRoleShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(),
|
||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
@@ -50,7 +24,7 @@ def load_role(tokenizer, cfg):
|
||||
|
||||
def load_guanaco(tokenizer, cfg):
|
||||
return GuanacoShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(),
|
||||
ShareGPTPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
@@ -62,26 +36,8 @@ class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
basic sharegpt strategy to grab conversations from the sample row
|
||||
"""
|
||||
|
||||
_strict = True
|
||||
|
||||
@property
|
||||
def strict(self):
|
||||
return self._strict
|
||||
|
||||
@strict.setter
|
||||
def strict(self, strict):
|
||||
self._strict = strict
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
if self.strict:
|
||||
return conversations
|
||||
# remap roles - allow for assistant turn
|
||||
role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"}
|
||||
turns = [
|
||||
{"from": role_map[t["from"]], "value": t["value"]} for t in conversations
|
||||
]
|
||||
return turns
|
||||
return prompt["conversations"]
|
||||
|
||||
|
||||
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
@@ -2,15 +2,12 @@
|
||||
|
||||
import abc
|
||||
import copy
|
||||
import functools
|
||||
import logging
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
from fastchat.conversation import Conversation
|
||||
from transformers import BatchEncoding, PreTrainedTokenizer
|
||||
|
||||
from axolotl.monkeypatch.fastchat_conversation_turns import (
|
||||
add_get_turns_to_conversation,
|
||||
)
|
||||
from axolotl.prompters import IGNORE_TOKEN_ID
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
@@ -21,8 +18,6 @@ LLAMA_DEFAULT_EOS_TOKEN = "</s>" # nosec
|
||||
LLAMA_DEFAULT_BOS_TOKEN = "<s>" # nosec
|
||||
LLAMA_DEFAULT_UNK_TOKEN = "<unk>" # nosec
|
||||
|
||||
add_get_turns_to_conversation()
|
||||
|
||||
|
||||
class InvalidDataException(Exception):
|
||||
"""
|
||||
@@ -45,8 +40,6 @@ class PromptTokenizingStrategy(abc.ABC):
|
||||
self.prompter = prompter
|
||||
self.tokenizer: PreTrainedTokenizer = tokenizer
|
||||
self.train_on_inputs = train_on_inputs
|
||||
# sequence_len and max_length can be different for CompletionPromptTokenizingStrategy.
|
||||
# TODO: Document how they are different.
|
||||
self.sequence_len = sequence_len
|
||||
self.max_length = sequence_len
|
||||
|
||||
@@ -58,34 +51,57 @@ class PromptTokenizingStrategy(abc.ABC):
|
||||
def supports_batched(self):
|
||||
return False
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _get_user_token(self):
|
||||
try:
|
||||
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
|
||||
if isinstance(id_or_ids, (int,)):
|
||||
return id_or_ids
|
||||
except KeyError:
|
||||
pass
|
||||
return False
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _get_assistant_token(self):
|
||||
try:
|
||||
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
|
||||
if isinstance(id_or_ids, (int,)):
|
||||
return id_or_ids
|
||||
except KeyError:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _tokenize(
|
||||
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
||||
) -> BatchEncoding:
|
||||
empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
|
||||
if not prompt:
|
||||
result: BatchEncoding
|
||||
if not prompt.strip():
|
||||
LOG.warning("Empty text requested for tokenization.")
|
||||
return empty
|
||||
|
||||
result = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
padding=False,
|
||||
return_tensors=None,
|
||||
)
|
||||
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
|
||||
else:
|
||||
result = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
padding=False,
|
||||
return_tensors=None,
|
||||
)
|
||||
if len(result["input_ids"]) == 0:
|
||||
LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
|
||||
return empty
|
||||
|
||||
if (
|
||||
result["input_ids"][-1] != self.tokenizer.eos_token_id
|
||||
len(result["input_ids"]) > 0
|
||||
and result["input_ids"][-1] != self.tokenizer.eos_token_id
|
||||
and len(result["input_ids"]) < self.max_length
|
||||
and add_eos_token
|
||||
):
|
||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||
result["attention_mask"].append(1)
|
||||
|
||||
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
|
||||
if (
|
||||
len(result["input_ids"]) > 0
|
||||
and result["input_ids"][0] == self.tokenizer.bos_token_id
|
||||
and strip_bos_token
|
||||
):
|
||||
result["input_ids"] = result["input_ids"][1:]
|
||||
result["attention_mask"] = result["attention_mask"][1:]
|
||||
|
||||
@@ -121,7 +137,7 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
if not self.train_on_inputs:
|
||||
user_prompt_len = len(tokenized_prompt["input_ids"])
|
||||
# TODO this could be sped up using numpy array slicing
|
||||
tokenized_prompt["labels"] = [IGNORE_INDEX] * user_prompt_len
|
||||
tokenized_prompt["labels"] = [-100] * user_prompt_len
|
||||
tokenized_res_prompt = self._tokenize(
|
||||
response, strip_bos_token=True, add_eos_token=True
|
||||
)
|
||||
@@ -245,7 +261,6 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
raise NotImplementedError
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
# pylint: disable=duplicate-code
|
||||
(
|
||||
instruction,
|
||||
input, # pylint: disable=redefined-builtin
|
||||
@@ -270,7 +285,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
||||
# TODO this could be sped up using numpy array slicing
|
||||
tokenized_full_prompt["labels"] = [
|
||||
IGNORE_INDEX
|
||||
-100
|
||||
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
|
||||
|
||||
return tokenized_full_prompt
|
||||
@@ -334,89 +349,56 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
return prompt["conversations"]
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
# Initial values. We will append to these as we go through the conversation.
|
||||
result, current_len = tokenize_prompt_default()
|
||||
conversation: Conversation = (
|
||||
self.prompter._conversation.copy() # pylint: disable=protected-access
|
||||
)
|
||||
|
||||
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
|
||||
role_remap = []
|
||||
if (
|
||||
conversation.name == "vicuna_v1.1"
|
||||
and "roles" in prompt
|
||||
and len(prompt["roles"]) >= 2
|
||||
):
|
||||
role_remap = [
|
||||
{"from": conversation.roles[0], "to": prompt["roles"][0]},
|
||||
{"from": conversation.roles[1], "to": prompt["roles"][1]},
|
||||
]
|
||||
|
||||
user_token = self._get_user_token()
|
||||
assistant_token = self._get_assistant_token()
|
||||
try:
|
||||
for _, part in enumerate(
|
||||
self.prompter.build_prompt(self.get_conversation_thread(prompt))
|
||||
):
|
||||
if not isinstance(part, tuple):
|
||||
LOG.warning(f"expected tuple, got {part}")
|
||||
continue
|
||||
|
||||
user, assistant = conversation.roles
|
||||
role, content = part
|
||||
|
||||
# Uses "in" because role contains extra characters
|
||||
if user in role:
|
||||
role = (
|
||||
role.replace(role_remap[0]["from"], role_remap[0]["to"])
|
||||
if role_remap
|
||||
else role
|
||||
)
|
||||
turn = role + content
|
||||
# this is still the user query, we should
|
||||
if not content.strip():
|
||||
LOG.warning(f"user turn has empty text: {prompt}")
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
elif assistant in role:
|
||||
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
||||
role = (
|
||||
role.replace(role_remap[1]["from"], role_remap[1]["to"])
|
||||
if role_remap
|
||||
else role
|
||||
)
|
||||
turn = role + content
|
||||
# this should be the assistant response, should end with an eos token
|
||||
if not content.strip():
|
||||
LOG.warning(f"assistant turn has empty text: {prompt}")
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=True,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
role_res = self._tokenize(
|
||||
role.rstrip(),
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
# not masked out from labels
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
len_role = len(role_res["input_ids"])
|
||||
labels[:len_role] = [IGNORE_TOKEN_ID] * min(len_role, len(labels))
|
||||
elif role == "":
|
||||
turn = content
|
||||
# this is only ever the first part, should include the bos token and the user query
|
||||
res = self._tokenize(
|
||||
turn, add_eos_token=False, strip_bos_token=False
|
||||
)
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
else:
|
||||
LOG.warning(f"unhandled role: {role}")
|
||||
continue
|
||||
if isinstance(part, tuple):
|
||||
if part[0] == "USER:":
|
||||
turn = part[0] + part[1] if not user_token else part[1]
|
||||
# this is still the user query, we should
|
||||
if not part[1].strip():
|
||||
LOG.warning(f"user turn has empty text: {prompt}")
|
||||
res = self._tokenize(
|
||||
turn.strip(),
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
if user_token:
|
||||
res["input_ids"] = [user_token, *res["input_ids"]]
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
elif part[0] == "ASSISTANT:":
|
||||
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
||||
turn = part[0] + part[1] if not assistant_token else part[1]
|
||||
# this should be the assistant response, should end with an eos token
|
||||
if not part[1].strip():
|
||||
LOG.warning(f"assistant turn has empty text: {prompt}")
|
||||
res = self._tokenize(
|
||||
turn.strip(),
|
||||
add_eos_token=True,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
if assistant_token:
|
||||
res["input_ids"] = [
|
||||
assistant_token,
|
||||
*res["input_ids"],
|
||||
]
|
||||
# not masked out from labels
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
elif part[0] == "SYSTEM:":
|
||||
part = part[1] # Ignore the system role from preamble
|
||||
# this is only ever the first part, should include the bos token and the user query
|
||||
res = self._tokenize(
|
||||
part.strip(), add_eos_token=False, strip_bos_token=False
|
||||
)
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
else:
|
||||
LOG.warning(f"unhandled role: {part[0]}")
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
result, current_len = parse_tokenized_to_result(
|
||||
@@ -430,6 +412,38 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
except (KeyError, AssertionError, IndexError) as err:
|
||||
raise InvalidDataException(str(err)) from err
|
||||
|
||||
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
||||
if not prompt.strip():
|
||||
LOG.warning("Empty text requested for tokenization.")
|
||||
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
|
||||
else:
|
||||
result = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
max_length=self.sequence_len,
|
||||
padding=False,
|
||||
return_tensors=None,
|
||||
)
|
||||
if (
|
||||
len(result["input_ids"]) > 0
|
||||
and result["input_ids"][-1] != self.tokenizer.eos_token_id
|
||||
and len(result["input_ids"]) < self.sequence_len
|
||||
and add_eos_token
|
||||
):
|
||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||
result["attention_mask"].append(1)
|
||||
|
||||
if (
|
||||
len(result["input_ids"]) > 0
|
||||
and result["input_ids"][0] == self.tokenizer.bos_token_id
|
||||
and strip_bos_token
|
||||
):
|
||||
result["input_ids"] = result["input_ids"][1:]
|
||||
result["attention_mask"] = result["attention_mask"][1:]
|
||||
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
return result
|
||||
|
||||
|
||||
def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
|
||||
"""
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
"""Module containing prompters"""
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Generator, Optional, Union
|
||||
|
||||
from colorama import Fore
|
||||
from fastchat.conversation import Conversation, get_conv_template
|
||||
from enum import Enum, auto
|
||||
from typing import Generator, List, Optional, Tuple, Union
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
IGNORE_TOKEN_ID = -100
|
||||
REPR_TEMPLATE = "\n<start>\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n<end>\n"
|
||||
|
||||
|
||||
class PromptStyle(Enum):
|
||||
@@ -57,15 +54,20 @@ class AlpacaPrompter:
|
||||
)
|
||||
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
|
||||
|
||||
def _build_result(self, instruction, input_text, output):
|
||||
def build_prompt(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
if input_text:
|
||||
if input:
|
||||
res = (
|
||||
self.system_format.format(system=self.system_prompt)
|
||||
if self.system_prompt
|
||||
else ""
|
||||
) + self.turn_format.format(instruction=instruction, input=input_text)
|
||||
) + self.turn_format.format(instruction=instruction, input=input)
|
||||
else:
|
||||
res = (
|
||||
self.system_format.format(system=self.system_no_input_prompt)
|
||||
@@ -74,21 +76,7 @@ class AlpacaPrompter:
|
||||
) + self.turn_no_input_format.format(instruction=instruction)
|
||||
if output:
|
||||
res = f"{res}{output}"
|
||||
|
||||
return res
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
yield self._build_result(instruction, input, output)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return REPR_TEMPLATE.format(
|
||||
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
|
||||
)
|
||||
yield res
|
||||
|
||||
|
||||
class UnpromptedPrompter(AlpacaPrompter):
|
||||
@@ -202,14 +190,14 @@ class ReflectAlpacaPrompter:
|
||||
)
|
||||
self.response_split = "ASSISTANT:"
|
||||
|
||||
def _build_result(
|
||||
def build_prompt(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
reflection: Union[None, str] = None,
|
||||
corrected: Union[None, str] = None,
|
||||
):
|
||||
) -> Generator[str, None, None]:
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
if input:
|
||||
@@ -223,30 +211,54 @@ class ReflectAlpacaPrompter:
|
||||
corrected=corrected,
|
||||
)
|
||||
res = f"{res}{label}"
|
||||
yield res
|
||||
|
||||
return res
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
reflection: Union[None, str] = None,
|
||||
corrected: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
# pylint: disable=duplicate-code
|
||||
yield self._build_result(
|
||||
instruction,
|
||||
input,
|
||||
output,
|
||||
reflection,
|
||||
corrected,
|
||||
class SeparatorStyle(Enum):
|
||||
"""Different separator style."""
|
||||
|
||||
SINGLE = auto()
|
||||
TWO = auto()
|
||||
DOLLY = auto()
|
||||
|
||||
|
||||
# TODO clean this 💩 up
|
||||
@dataclasses.dataclass
|
||||
class Conversation:
|
||||
"""A class that keeps all conversation history."""
|
||||
|
||||
system: str
|
||||
roles: List[str]
|
||||
messages: List[List[str]]
|
||||
offset: int
|
||||
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
||||
sep: str = "###"
|
||||
sep2: Optional[str] = None
|
||||
|
||||
def get_prompt(self) -> Generator[Tuple[str, str], None, None]:
|
||||
# seps = [self.sep, self.sep2]
|
||||
preamble = self.system + self.sep
|
||||
yield ("SYSTEM:", preamble)
|
||||
for _, (role, message) in enumerate(self.messages):
|
||||
if message:
|
||||
yield (role + ":", " " + message)
|
||||
else:
|
||||
LOG.warning(f"role with empty message: {role}")
|
||||
yield (role + ":", "")
|
||||
|
||||
def copy(self):
|
||||
return Conversation(
|
||||
system=self.system,
|
||||
roles=self.roles,
|
||||
messages=[[x, y] for x, y in self.messages],
|
||||
offset=self.offset,
|
||||
sep_style=self.sep_style,
|
||||
sep=self.sep,
|
||||
sep2=self.sep2,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return REPR_TEMPLATE.format(
|
||||
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
|
||||
)
|
||||
def append_message(self, role, message):
|
||||
self.messages.append([role, message])
|
||||
|
||||
|
||||
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
||||
@@ -259,29 +271,30 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
A prompter that generates prompts for the ShareGPT
|
||||
"""
|
||||
|
||||
role_key_human = "human"
|
||||
role_key_model = "gpt"
|
||||
def __init__(self, prompt_style=None, system_prompt: Optional[str] = None):
|
||||
if prompt_style != PromptStyle.CHAT.value:
|
||||
raise ValueError(
|
||||
f"unsupported prompt_style for ShareGPTPrompter({prompt_style})"
|
||||
)
|
||||
system: str = (
|
||||
system_prompt
|
||||
if system_prompt
|
||||
else (
|
||||
"A chat between a curious user and an artificial intelligence assistant. "
|
||||
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
||||
)
|
||||
)
|
||||
self._conversation = Conversation(
|
||||
system=system,
|
||||
roles=["USER", "ASSISTANT"],
|
||||
messages=[],
|
||||
offset=0,
|
||||
sep_style=SeparatorStyle.TWO,
|
||||
sep=" ",
|
||||
sep2=" ",
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prompt_style=None, # pylint: disable=unused-argument
|
||||
conversation: Optional[Union[str, Conversation]] = None,
|
||||
role_key_human: Optional[str] = None,
|
||||
role_key_model: Optional[str] = None,
|
||||
):
|
||||
if conversation:
|
||||
if isinstance(conversation, Conversation):
|
||||
self._conversation = conversation
|
||||
else:
|
||||
self._conversation = get_conv_template(conversation)
|
||||
else:
|
||||
self._conversation = get_conv_template("vicuna_v1.1")
|
||||
if role_key_human:
|
||||
self.role_key_human = role_key_human
|
||||
if role_key_model:
|
||||
self.role_key_model = role_key_model
|
||||
|
||||
def _build_result(self, source):
|
||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||
if len(source) < 2:
|
||||
# If there isn't a back and forth conversation, ignore it
|
||||
# also happens on the data splitting leaving empty conversations
|
||||
@@ -293,14 +306,17 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
|
||||
# Add the conversation system prompt if provided, otherwise use the default one
|
||||
if source[0]["from"] == "system":
|
||||
conv.set_system_message(source[0]["value"])
|
||||
conv.system = source[0]["value"]
|
||||
source.pop(0)
|
||||
|
||||
roles = {self.role_key_human: conv.roles[0], self.role_key_model: conv.roles[1]}
|
||||
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
||||
|
||||
try:
|
||||
# Apply prompt templates
|
||||
if source[0]["from"] not in roles:
|
||||
if (
|
||||
source[0]["from"] not in roles
|
||||
or roles[source[0]["from"]] != conv.roles[0]
|
||||
):
|
||||
# Skip the first one if it is not from human
|
||||
source = source[1:]
|
||||
except IndexError as err:
|
||||
@@ -308,54 +324,10 @@ class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
raise err
|
||||
|
||||
conv.messages = []
|
||||
for _, sentence in enumerate(source):
|
||||
for j, sentence in enumerate(source):
|
||||
role = roles[sentence["from"]]
|
||||
if len(conv.messages) > 0 and (
|
||||
(role == conv.messages[-1][0]) or (role not in conv.roles)
|
||||
):
|
||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||
assert role == conv.roles[j % 2], SHAREGPT_ASSERTION_FAILED_ROLE
|
||||
conv.append_message(role, sentence["value"])
|
||||
|
||||
return conv.get_turns()
|
||||
|
||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||
turns = self._build_result(source)
|
||||
|
||||
for part in turns:
|
||||
if part[0] and not part[1]:
|
||||
LOG.warning(f"role with empty message: {part[0]}")
|
||||
for part in conv.get_prompt():
|
||||
yield part
|
||||
|
||||
def __repr__(self) -> str:
|
||||
turns = self._build_result([{"from": "{from}", "value": "{value}"}])
|
||||
return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns])
|
||||
|
||||
|
||||
class ShareGPTPrompterV2(ShareGPTPrompter):
|
||||
"""
|
||||
A V2 prompter that generates prompts for the ShareGPT
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
conversation: Optional[Union[str, Conversation]] = None,
|
||||
role_key_human: Optional[str] = None,
|
||||
role_key_model: Optional[str] = None,
|
||||
):
|
||||
super().__init__(
|
||||
conversation=conversation,
|
||||
role_key_human=role_key_human,
|
||||
role_key_model=role_key_model,
|
||||
)
|
||||
|
||||
|
||||
class UnsupportedPrompter:
|
||||
"""
|
||||
A dummy class for custom prompters
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return "Pre-tokenized or custom dataset types are unsupported for logging"
|
||||
|
||||
@@ -12,11 +12,9 @@ import torch
|
||||
import transformers.modelcard
|
||||
from datasets import Dataset
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.monkeypatch import neft_embeddings
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
@@ -41,7 +39,10 @@ class TrainDatasetMeta:
|
||||
|
||||
|
||||
def train(
|
||||
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
dataset_meta: TrainDatasetMeta,
|
||||
):
|
||||
# load the tokenizer first
|
||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
@@ -108,7 +109,6 @@ def train(
|
||||
if cfg.group_by_length:
|
||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||
|
||||
pretrain_hooks(cfg, trainer)
|
||||
if cfg.flash_optimum:
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
||||
@@ -116,15 +116,9 @@ def train(
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
else:
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
post_train_hooks(cfg, trainer)
|
||||
|
||||
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||
|
||||
# post training
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, "_post_training"):
|
||||
module._post_training(model, name) # pylint: disable=protected-access
|
||||
|
||||
if trainer.is_fsdp_enabled:
|
||||
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
|
||||
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
|
||||
@@ -140,22 +134,6 @@ def train(
|
||||
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
||||
if cfg.fsdp:
|
||||
trainer.save_model(cfg.output_dir)
|
||||
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
|
||||
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
|
||||
trainer.accelerator.wait_for_everyone()
|
||||
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
|
||||
|
||||
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
|
||||
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
|
||||
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
|
||||
# For Zero Stages 1 and 2, models are saved as usual in the output directory.
|
||||
# The model name saved is `pytorch_model.bin`
|
||||
unwrapped_model.save_pretrained(
|
||||
cfg.output_dir,
|
||||
is_main_process=trainer.accelerator.is_main_process,
|
||||
save_function=trainer.accelerator.save,
|
||||
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
|
||||
)
|
||||
elif cfg.local_rank == 0:
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.reverse(model)
|
||||
@@ -166,23 +144,3 @@ def train(
|
||||
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def pretrain_hooks(cfg, trainer):
|
||||
"""
|
||||
Run hooks right before kicking off the training
|
||||
:param cfg:
|
||||
:param trainer:
|
||||
:return:
|
||||
"""
|
||||
neft_embeddings.pretrain_hook(cfg, trainer)
|
||||
|
||||
|
||||
def post_train_hooks(cfg, trainer):
|
||||
"""
|
||||
Run hooks right after training completes
|
||||
:param cfg:
|
||||
:param trainer:
|
||||
:return:
|
||||
"""
|
||||
neft_embeddings.post_train_hook(cfg, trainer)
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
"""Benchmarking and measurement utilities"""
|
||||
import functools
|
||||
import logging
|
||||
|
||||
import pynvml
|
||||
import torch
|
||||
from pynvml.nvml import NVMLError
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.bench")
|
||||
|
||||
|
||||
def check_cuda_device(default_value):
|
||||
"""
|
||||
@@ -65,14 +62,7 @@ def gpu_memory_usage_smi(device=0):
|
||||
|
||||
|
||||
def log_gpu_memory_usage(log, msg, device):
|
||||
if not torch.cuda.is_available():
|
||||
return (0, 0, 0)
|
||||
|
||||
try:
|
||||
usage, cache, misc = gpu_memory_usage_all(device)
|
||||
except ValueError as exc:
|
||||
LOG.exception(exc)
|
||||
return (0, 0, 0)
|
||||
usage, cache, misc = gpu_memory_usage_all(device)
|
||||
extras = []
|
||||
if cache > 0:
|
||||
extras.append(f"+{cache:.03f}GB cache")
|
||||
|
||||
@@ -37,7 +37,7 @@ from axolotl.utils.distributed import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||
from axolotl.utils.trainer import AxolotlTrainingArguments
|
||||
|
||||
LOG = logging.getLogger("axolotl.callbacks")
|
||||
IGNORE_INDEX = -100
|
||||
@@ -514,27 +514,3 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
|
||||
return control
|
||||
|
||||
return LogPredictionCallback
|
||||
|
||||
|
||||
class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
||||
"""Callback to save axolotl config to wandb"""
|
||||
|
||||
def __init__(self, axolotl_config_path):
|
||||
self.axolotl_config_path = axolotl_config_path
|
||||
|
||||
def on_train_begin(
|
||||
self,
|
||||
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
|
||||
state: TrainerState, # pylint: disable=unused-argument
|
||||
control: TrainerControl,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
if is_main_process():
|
||||
try:
|
||||
artifact = wandb.Artifact(name="axolotl-config", type="config")
|
||||
artifact.add_file(local_path=self.axolotl_config_path)
|
||||
wandb.run.log_artifact(artifact)
|
||||
LOG.info("Axolotl config has been saved to WandB as an artifact.")
|
||||
except (FileNotFoundError, ConnectionError) as err:
|
||||
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
||||
return control
|
||||
|
||||
@@ -49,8 +49,6 @@ def normalize_config(cfg):
|
||||
cfg.batch_size = (
|
||||
cfg.batch_size or cfg.micro_batch_size * cfg.gradient_accumulation_steps
|
||||
)
|
||||
if cfg.eval_batch_size is None:
|
||||
cfg.eval_batch_size = cfg.micro_batch_size
|
||||
cfg.world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
cfg.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
cfg.eval_table_size = cfg.eval_table_size or 0
|
||||
@@ -77,11 +75,6 @@ def normalize_config(cfg):
|
||||
else:
|
||||
cfg.torch_dtype = torch.float32
|
||||
|
||||
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
|
||||
|
||||
if not cfg.base_model_config:
|
||||
cfg.base_model_config = cfg.base_model
|
||||
|
||||
model_config = load_model_config(cfg)
|
||||
cfg.model_config_type = model_config.model_type
|
||||
|
||||
@@ -89,42 +82,10 @@ def normalize_config(cfg):
|
||||
cfg.is_llama_derived_model = (
|
||||
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
|
||||
or cfg.is_llama_derived_model
|
||||
or "llama" in cfg.base_model.lower()
|
||||
or "llama" in cfg.base_model
|
||||
or (cfg.model_type and "llama" in cfg.model_type.lower())
|
||||
)
|
||||
|
||||
# figure out if the model is falcon
|
||||
cfg.is_falcon_derived_model = (
|
||||
(
|
||||
hasattr(model_config, "model_type")
|
||||
and model_config.model_type
|
||||
in [
|
||||
"falcon",
|
||||
"RefinedWebModel",
|
||||
"RefinedWeb",
|
||||
]
|
||||
)
|
||||
or cfg.is_falcon_derived_model
|
||||
or "falcon" in cfg.base_model.lower()
|
||||
or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
|
||||
)
|
||||
|
||||
cfg.is_mistral_derived_model = (
|
||||
(
|
||||
hasattr(model_config, "model_type")
|
||||
and model_config.model_type
|
||||
in [
|
||||
"mistral",
|
||||
]
|
||||
)
|
||||
or cfg.is_mistral_derived_model
|
||||
or "mistral" in cfg.base_model.lower()
|
||||
or (cfg.model_type and "mistral" in cfg.model_type.lower())
|
||||
)
|
||||
|
||||
if isinstance(cfg.learning_rate, str):
|
||||
cfg.learning_rate = float(cfg.learning_rate)
|
||||
|
||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||
|
||||
|
||||
@@ -165,11 +126,6 @@ def validate_config(cfg):
|
||||
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
||||
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
||||
)
|
||||
if cfg.eval_batch_size != cfg.micro_batch_size:
|
||||
LOG.warning(
|
||||
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
|
||||
)
|
||||
|
||||
if cfg.load_4bit:
|
||||
raise ValueError("cfg.load_4bit parameter has been deprecated")
|
||||
|
||||
@@ -195,15 +151,9 @@ def validate_config(cfg):
|
||||
if not cfg.load_in_4bit:
|
||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||
|
||||
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
||||
raise ValueError("Fused modules are not supported with QLoRA")
|
||||
|
||||
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
||||
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||
|
||||
if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
|
||||
raise ValueError("Fused modules are not supported with LoRA")
|
||||
|
||||
if cfg.relora_steps:
|
||||
if cfg.adapter not in ("lora", "qlora"):
|
||||
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
||||
@@ -217,9 +167,6 @@ def validate_config(cfg):
|
||||
if cfg.lr_scheduler == "one_cycle":
|
||||
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
|
||||
|
||||
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
||||
raise ValueError("Fused modules are not supported with ReLoRA")
|
||||
|
||||
if cfg.trust_remote_code:
|
||||
LOG.warning(
|
||||
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
||||
@@ -315,64 +262,6 @@ def validate_config(cfg):
|
||||
"`model_type: MixFormerSequentialForCausalLM` required for sample_packing"
|
||||
)
|
||||
|
||||
if cfg.datasets:
|
||||
for idx, ds_cfg in enumerate(cfg.datasets):
|
||||
if not ds_cfg.type:
|
||||
continue
|
||||
if ds_cfg.type == "sharegpt:chat":
|
||||
LOG.warning(
|
||||
PendingDeprecationWarning(
|
||||
"`type: sharegpt:chat` will soon be deprecated. simply use `type: sharegpt` instead."
|
||||
)
|
||||
)
|
||||
cfg.datasets[idx].type = "sharegpt"
|
||||
if "sharegpt_simple" in ds_cfg.type:
|
||||
LOG.warning(
|
||||
PendingDeprecationWarning(
|
||||
"`type: sharegpt_simple` will soon be deprecated. simply use `type: sharegpt` instead."
|
||||
)
|
||||
)
|
||||
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
|
||||
"sharegpt_simple", "sharegpt"
|
||||
)
|
||||
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
||||
raise ValueError(
|
||||
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.evaluation_strategy
|
||||
and cfg.eval_steps
|
||||
and cfg.evaluation_strategy != "steps"
|
||||
):
|
||||
raise ValueError(
|
||||
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
|
||||
)
|
||||
|
||||
if cfg.val_set_size == 0 and (cfg.eval_steps or cfg.evaluation_strategy):
|
||||
raise ValueError(
|
||||
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.sample_packing
|
||||
and cfg.eval_table_size
|
||||
and cfg.eval_sample_packing is not False
|
||||
):
|
||||
raise ValueError(
|
||||
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
|
||||
)
|
||||
|
||||
if not cfg.adapter and (cfg.load_in_8bit or cfg.load_in_4bit):
|
||||
raise ValueError(
|
||||
"load_in_8bit and load_in_4bit are not supported without setting an adapter."
|
||||
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
||||
)
|
||||
|
||||
if cfg.tensor_parallel and cfg.gradient_checkpointing:
|
||||
raise ValueError(
|
||||
"TensorParallelPreTrainedModel does not support gradient checkpointing"
|
||||
)
|
||||
# TODO
|
||||
# MPT 7b
|
||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||
|
||||
@@ -3,7 +3,7 @@ import functools
|
||||
import hashlib
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
from datasets import (
|
||||
@@ -16,7 +16,6 @@ from datasets import (
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
|
||||
from axolotl.prompt_strategies import load
|
||||
from axolotl.prompt_tokenizers import (
|
||||
@@ -26,6 +25,7 @@ from axolotl.prompt_tokenizers import (
|
||||
GPTeacherPromptTokenizingStrategy,
|
||||
JeopardyPromptTokenizingStrategy,
|
||||
OpenAssistantPromptTokenizingStrategy,
|
||||
ShareGPTPromptTokenizingStrategy,
|
||||
SummarizeTLDRPromptTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompters import (
|
||||
@@ -35,8 +35,8 @@ from axolotl.prompters import (
|
||||
MultipleChoiceConcisePrompter,
|
||||
MultipleChoiceExplainPrompter,
|
||||
ReflectAlpacaPrompter,
|
||||
ShareGPTPrompter,
|
||||
SummarizeTLDRPrompter,
|
||||
UnsupportedPrompter,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process, zero_first
|
||||
@@ -46,6 +46,7 @@ from axolotl.utils.trainer import (
|
||||
)
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||
|
||||
|
||||
def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
||||
@@ -56,10 +57,9 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
||||
|
||||
|
||||
def prepare_dataset(cfg, tokenizer):
|
||||
prompters = []
|
||||
if not cfg.pretraining_dataset:
|
||||
with zero_first(is_main_process()):
|
||||
train_dataset, eval_dataset, prompters = load_prepare_datasets(
|
||||
train_dataset, eval_dataset = load_prepare_datasets(
|
||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||
)
|
||||
else:
|
||||
@@ -72,11 +72,11 @@ def prepare_dataset(cfg, tokenizer):
|
||||
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
||||
train_dataset = train_dataset.with_format("torch")
|
||||
eval_dataset = None
|
||||
return train_dataset, eval_dataset, cfg.max_steps, prompters
|
||||
return train_dataset, eval_dataset, cfg.max_steps
|
||||
|
||||
with zero_first(is_main_process()):
|
||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||
cfg, train_dataset, eval_dataset, tokenizer
|
||||
cfg, train_dataset, eval_dataset
|
||||
)
|
||||
if cfg.max_steps:
|
||||
total_num_steps = min(
|
||||
@@ -85,7 +85,7 @@ def prepare_dataset(cfg, tokenizer):
|
||||
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
||||
else:
|
||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
||||
return train_dataset, eval_dataset, total_num_steps, prompters
|
||||
return train_dataset, eval_dataset, total_num_steps
|
||||
|
||||
|
||||
def load_tokenized_prepared_datasets(
|
||||
@@ -111,13 +111,12 @@ def load_tokenized_prepared_datasets(
|
||||
else Path(default_dataset_prepared_path) / ds_hash
|
||||
)
|
||||
dataset = None
|
||||
prompters = []
|
||||
use_auth_token = cfg.hf_use_auth_token
|
||||
try:
|
||||
if cfg.push_dataset_to_hub:
|
||||
dataset = load_dataset(
|
||||
f"{cfg.push_dataset_to_hub}/{ds_hash}",
|
||||
token=use_auth_token,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
dataset = dataset["train"]
|
||||
except Exception: # pylint: disable=broad-except # nosec
|
||||
@@ -125,7 +124,7 @@ def load_tokenized_prepared_datasets(
|
||||
|
||||
if dataset:
|
||||
...
|
||||
elif cfg.dataset_prepared_path and any(prepared_ds_path.glob("*")):
|
||||
elif any(prepared_ds_path.glob("*")):
|
||||
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
|
||||
dataset = load_from_disk(str(prepared_ds_path))
|
||||
LOG.info("Prepared dataset loaded from disk...")
|
||||
@@ -150,48 +149,48 @@ def load_tokenized_prepared_datasets(
|
||||
yield dataset
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
for config_dataset in for_d_in_datasets(cfg.datasets):
|
||||
for d in for_d_in_datasets(cfg.datasets):
|
||||
ds: Union[Dataset, DatasetDict] = None
|
||||
ds_from_hub = False
|
||||
try:
|
||||
load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
d.path,
|
||||
name=d.name,
|
||||
streaming=True,
|
||||
token=use_auth_token,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
ds_from_hub = True
|
||||
except (FileNotFoundError, ConnectionError):
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
# prefer local dataset, even if hub exists
|
||||
local_path = Path(config_dataset.path)
|
||||
local_path = Path(d.path)
|
||||
if local_path.exists():
|
||||
if local_path.is_dir():
|
||||
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
|
||||
ds = load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.data_files,
|
||||
d.path,
|
||||
name=d.name,
|
||||
data_files=d.data_files,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
elif local_path.is_file():
|
||||
ds_type = "json"
|
||||
if config_dataset.ds_type:
|
||||
ds_type = config_dataset.ds_type
|
||||
elif ".parquet" in config_dataset.path:
|
||||
if d.ds_type:
|
||||
ds_type = d.ds_type
|
||||
elif ".parquet" in d.path:
|
||||
ds_type = "parquet"
|
||||
elif ".arrow" in config_dataset.path:
|
||||
elif ".arrow" in d.path:
|
||||
ds_type = "arrow"
|
||||
elif ".csv" in config_dataset.path:
|
||||
elif ".csv" in d.path:
|
||||
ds_type = "csv"
|
||||
elif ".txt" in config_dataset.path:
|
||||
elif ".txt" in d.path:
|
||||
ds_type = "text"
|
||||
ds = load_dataset(
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
name=d.name,
|
||||
data_files=d.path,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
@@ -201,83 +200,143 @@ def load_tokenized_prepared_datasets(
|
||||
)
|
||||
elif ds_from_hub:
|
||||
ds = load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
d.path,
|
||||
name=d.name,
|
||||
streaming=False,
|
||||
data_files=config_dataset.data_files,
|
||||
token=use_auth_token,
|
||||
data_files=d.data_files,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
else:
|
||||
if isinstance(config_dataset.data_files, str):
|
||||
fp = hf_hub_download(
|
||||
repo_id=config_dataset.path,
|
||||
repo_type="dataset",
|
||||
filename=config_dataset.data_files,
|
||||
)
|
||||
elif isinstance(config_dataset.data_files, list):
|
||||
fp = []
|
||||
for file in config_dataset.data_files:
|
||||
fp.append(
|
||||
hf_hub_download(
|
||||
repo_id=config_dataset.path,
|
||||
repo_type="dataset",
|
||||
filename=file,
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"data_files must be either a string or list of strings"
|
||||
)
|
||||
fp = hf_hub_download(
|
||||
repo_id=d.path,
|
||||
repo_type="dataset",
|
||||
filename=d.data_files,
|
||||
)
|
||||
ds = load_dataset(
|
||||
"json",
|
||||
name=config_dataset.name,
|
||||
data_files=fp,
|
||||
streaming=False,
|
||||
split=None,
|
||||
"json", name=d.name, data_files=fp, streaming=False, split=None
|
||||
)
|
||||
if not ds:
|
||||
raise ValueError("unhandled dataset load")
|
||||
# support for using a subset of the data
|
||||
if config_dataset.shards:
|
||||
if d.shards:
|
||||
if "train" in ds:
|
||||
ds = ds.shuffle(seed=seed)["train"].shard(
|
||||
num_shards=config_dataset.shards, index=0
|
||||
num_shards=d.shards, index=0
|
||||
)
|
||||
else:
|
||||
ds = ds.shuffle(seed=seed).shard(
|
||||
num_shards=config_dataset.shards, index=0
|
||||
)
|
||||
ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
|
||||
|
||||
d_base_type = d_prompt_style = None
|
||||
d_type = config_dataset.type
|
||||
d_type = d.type
|
||||
if isinstance(d_type, str):
|
||||
d_type_split = d_type.split(":")
|
||||
d_base_type = d_type_split[0]
|
||||
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
|
||||
if "train" in ds:
|
||||
ds = ds["train"]
|
||||
elif (
|
||||
isinstance(ds, DatasetDict)
|
||||
and config_dataset.train_on_split
|
||||
and config_dataset.train_on_split in ds
|
||||
if (
|
||||
"input_ids" in ds.features
|
||||
and "attention_mask" in ds.features
|
||||
and "labels" in ds.features
|
||||
):
|
||||
ds = ds[config_dataset.train_on_split]
|
||||
elif isinstance(ds, DatasetDict):
|
||||
raise ValueError(
|
||||
f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `"
|
||||
# dataset is already tokenized, just drop it straight in
|
||||
datasets.append(ds)
|
||||
elif isinstance(d.type, DictDefault):
|
||||
ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict())
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif ds_strategy := load(d.type, tokenizer, cfg, d):
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "alpaca":
|
||||
ds_strategy = AlpacaPromptTokenizingStrategy(
|
||||
AlpacaPrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "explainchoice":
|
||||
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
||||
MultipleChoiceExplainPrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "concisechoice":
|
||||
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
||||
MultipleChoiceConcisePrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "summarizetldr":
|
||||
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
|
||||
SummarizeTLDRPrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "jeopardy":
|
||||
ds_strategy = JeopardyPromptTokenizingStrategy(
|
||||
JeopardyPrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "oasst":
|
||||
ds_strategy = OpenAssistantPromptTokenizingStrategy(
|
||||
AlpacaPrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "gpteacher":
|
||||
ds_strategy = GPTeacherPromptTokenizingStrategy(
|
||||
GPTeacherPrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "reflection":
|
||||
ds_strategy = AlpacaReflectionPTStrategy(
|
||||
ReflectAlpacaPrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "sharegpt":
|
||||
ds_strategy = ShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
else:
|
||||
suffix = ""
|
||||
if ":load_" in d.type:
|
||||
suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
|
||||
LOG.error(f"unhandled prompt tokenization strategy: {d.type}. {suffix}")
|
||||
raise ValueError(
|
||||
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
|
||||
)
|
||||
|
||||
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
|
||||
config_dataset=config_dataset,
|
||||
dataset=ds,
|
||||
tokenizer=tokenizer,
|
||||
cfg=cfg,
|
||||
d_base_type=d_base_type,
|
||||
d_prompt_style=d_prompt_style,
|
||||
)
|
||||
datasets.append(dataset_wrapper)
|
||||
prompters.append(dataset_prompter)
|
||||
|
||||
LOG.info("merging datasets")
|
||||
dataset = concatenate_datasets(datasets)
|
||||
|
||||
@@ -295,14 +354,14 @@ def load_tokenized_prepared_datasets(
|
||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
||||
)
|
||||
|
||||
return dataset, prompters
|
||||
return dataset
|
||||
|
||||
|
||||
def load_prepare_datasets(
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
cfg,
|
||||
default_dataset_prepared_path,
|
||||
) -> Tuple[Dataset, Dataset, List[Any]]:
|
||||
) -> Tuple[Dataset, Dataset]:
|
||||
max_packed_sequence_len = (
|
||||
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
||||
)
|
||||
@@ -311,7 +370,6 @@ def load_prepare_datasets(
|
||||
) # make sure we don't accidentally set it larger than sequence_len
|
||||
|
||||
tokenizer_name = tokenizer.__class__.__name__
|
||||
prompters = []
|
||||
if cfg.max_packed_sequence_len is not None:
|
||||
# see if we can go ahead and load the stacked dataset
|
||||
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
||||
@@ -345,7 +403,7 @@ def load_prepare_datasets(
|
||||
)
|
||||
dataset = load_dataset(
|
||||
f"{cfg.push_dataset_to_hub}/{ds_hash}",
|
||||
token=use_auth_token,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
dataset = dataset["train"]
|
||||
except Exception: # pylint: disable=broad-except # nosec
|
||||
@@ -353,7 +411,7 @@ def load_prepare_datasets(
|
||||
|
||||
if dataset:
|
||||
...
|
||||
elif cfg.dataset_prepared_path and any(prepared_ds_path.glob("*")):
|
||||
elif any(prepared_ds_path.glob("*")):
|
||||
LOG.info(
|
||||
f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
|
||||
)
|
||||
@@ -367,7 +425,7 @@ def load_prepare_datasets(
|
||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
||||
)
|
||||
else:
|
||||
dataset, prompters = load_tokenized_prepared_datasets(
|
||||
dataset = load_tokenized_prepared_datasets(
|
||||
tokenizer, cfg, default_dataset_prepared_path
|
||||
)
|
||||
|
||||
@@ -409,7 +467,7 @@ def load_prepare_datasets(
|
||||
private=True,
|
||||
)
|
||||
else:
|
||||
dataset, prompters = load_tokenized_prepared_datasets(
|
||||
dataset = load_tokenized_prepared_datasets(
|
||||
tokenizer, cfg, default_dataset_prepared_path
|
||||
)
|
||||
|
||||
@@ -460,124 +518,7 @@ def load_prepare_datasets(
|
||||
train_dataset = dataset
|
||||
eval_dataset = None
|
||||
|
||||
return train_dataset, eval_dataset, prompters
|
||||
|
||||
|
||||
def get_dataset_wrapper(
|
||||
config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
|
||||
):
|
||||
dataset_wrapper = None
|
||||
dataset_prompter = None
|
||||
|
||||
if (
|
||||
"input_ids" in dataset.features
|
||||
and "attention_mask" in dataset.features
|
||||
and "labels" in dataset.features
|
||||
):
|
||||
# dataset is already tokenized, just drop it straight in
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = dataset
|
||||
elif isinstance(config_dataset.type, DictDefault):
|
||||
ds_strategy = load(
|
||||
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
|
||||
)
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||
elif d_base_type == "alpaca":
|
||||
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
||||
ds_strategy = AlpacaPromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "explainchoice":
|
||||
dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
|
||||
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "concisechoice":
|
||||
dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
|
||||
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "summarizetldr":
|
||||
dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
|
||||
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "jeopardy":
|
||||
dataset_prompter = JeopardyPrompter(d_prompt_style)
|
||||
ds_strategy = JeopardyPromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "oasst":
|
||||
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
||||
ds_strategy = OpenAssistantPromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "gpteacher":
|
||||
dataset_prompter = GPTeacherPrompter(d_prompt_style)
|
||||
ds_strategy = GPTeacherPromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "reflection":
|
||||
dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
|
||||
ds_strategy = AlpacaReflectionPTStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, dataset)
|
||||
dataset_wrapper = ds_wrapper
|
||||
else:
|
||||
suffix = ""
|
||||
if ":load_" in config_dataset.type:
|
||||
suffix = f" Did you mean {config_dataset.type.replace(':load_', '.load_')}?"
|
||||
LOG.error(
|
||||
f"unhandled prompt tokenization strategy: {config_dataset.type}. {suffix}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"unhandled prompt tokenization strategy: {config_dataset.type} {suffix}"
|
||||
)
|
||||
|
||||
return dataset_wrapper, dataset_prompter
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
|
||||
def encode_pretraining(
|
||||
|
||||
@@ -3,9 +3,6 @@ import hashlib
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import time
|
||||
from queue import Queue
|
||||
from threading import Thread
|
||||
from typing import Any, Callable, List, Union
|
||||
|
||||
import numba
|
||||
@@ -152,8 +149,6 @@ class MultipackDistributedDataloader:
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
sample_packing_seq_len_multiplier: int = 1,
|
||||
device_count: int = 1,
|
||||
prefetch_max: int = 1000,
|
||||
num_epochs: int = 1,
|
||||
):
|
||||
# Dataset
|
||||
self.dataset = dataset
|
||||
@@ -172,7 +167,6 @@ class MultipackDistributedDataloader:
|
||||
self.seq_max_length = seq_max_length
|
||||
self.batch_max_length = batch_size * seq_max_length
|
||||
self.collate_fn = collate_fn
|
||||
self.num_epochs = num_epochs
|
||||
|
||||
self.num_replicas = 1
|
||||
self.rank = 0
|
||||
@@ -183,44 +177,6 @@ class MultipackDistributedDataloader:
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||
self.device_count = device_count
|
||||
|
||||
# maxsize is maximum number of samples in queue
|
||||
self.prefetch_max = prefetch_max
|
||||
self.queue: Queue = Queue(maxsize=prefetch_max)
|
||||
self.thread = None
|
||||
|
||||
def _worker(self):
|
||||
LOG.info(
|
||||
f"[WORKER] Epochs: {self.num_epochs}, Samples: {self.len_w_stats()*self.batch_size}"
|
||||
)
|
||||
for epoch in range(self.num_epochs):
|
||||
for sample in self._internal_batch_generator():
|
||||
while True:
|
||||
if self.queue.full():
|
||||
time.sleep(1)
|
||||
else:
|
||||
break
|
||||
self.queue.put(sample)
|
||||
|
||||
# stop the queue when epoch is done
|
||||
self.queue.put(None)
|
||||
|
||||
def __iter__(self):
|
||||
if hasattr(self.sampler, "set_epoch"):
|
||||
new_epoch = self.sampler.epoch + 1
|
||||
self.sampler.set_epoch(new_epoch)
|
||||
LOG.info(f"calling sampler.set_epoch({new_epoch})")
|
||||
|
||||
if self.thread is None:
|
||||
self.thread = Thread(target=self._worker, daemon=True)
|
||||
self.thread.start()
|
||||
|
||||
while True:
|
||||
item = self.queue.get()
|
||||
|
||||
if item is None:
|
||||
break
|
||||
yield item
|
||||
|
||||
def generate_batches(self, set_stats=False):
|
||||
LOG.info("generating packed batches")
|
||||
if self.sampler:
|
||||
@@ -250,7 +206,11 @@ class MultipackDistributedDataloader:
|
||||
|
||||
return batches, totseqs
|
||||
|
||||
def _internal_batch_generator(self):
|
||||
def __iter__(self):
|
||||
if hasattr(self.sampler, "set_epoch"):
|
||||
new_epoch = self.sampler.epoch + 1
|
||||
self.sampler.set_epoch(new_epoch)
|
||||
LOG.info(f"calling sampler.set_epoch({new_epoch})")
|
||||
all_batches, _ = self.generate_batches(set_stats=True)
|
||||
features = self.dataset.features.keys()
|
||||
len_remaining = self._len_est()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Module for models and model loading"""
|
||||
import importlib
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
@@ -7,12 +8,10 @@ from typing import Optional, Tuple # noqa: F401
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
import transformers
|
||||
import transformers.utils.bitsandbytes
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from peft import PeftConfig, prepare_model_for_kbit_training
|
||||
from peft.tuners.lora import QuantLinear
|
||||
from transformers import ( # noqa: F401
|
||||
AddedToken,
|
||||
AutoConfig,
|
||||
AutoModelForCausalLM,
|
||||
AutoTokenizer,
|
||||
@@ -32,7 +31,7 @@ LOG = logging.getLogger("axolotl")
|
||||
|
||||
def load_model_config(cfg):
|
||||
model_config_name = cfg.base_model_config or cfg.base_model
|
||||
trust_remote_code = cfg.trust_remote_code is True
|
||||
trust_remote_code: bool = False or cfg.trust_remote_code
|
||||
return AutoConfig.from_pretrained(
|
||||
model_config_name, trust_remote_code=trust_remote_code
|
||||
)
|
||||
@@ -73,32 +72,21 @@ def load_tokenizer(cfg):
|
||||
# set a pad_token, but use eos_token so we don't add a new token
|
||||
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
|
||||
|
||||
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
# Mistral's official FA implementation requires left padding
|
||||
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
if cfg.special_tokens:
|
||||
for k, val in cfg.special_tokens.items():
|
||||
tokenizer.add_special_tokens(
|
||||
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
|
||||
)
|
||||
if cfg.tokens:
|
||||
tokenizer.add_tokens(
|
||||
[
|
||||
AddedToken(token, rstrip=False, lstrip=False, normalized=False)
|
||||
for token in cfg.tokens
|
||||
]
|
||||
)
|
||||
|
||||
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||
|
||||
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
if cfg.special_tokens:
|
||||
for k, val in cfg.special_tokens.items():
|
||||
tokenizer.add_special_tokens({k: val})
|
||||
if cfg.tokens:
|
||||
tokenizer.add_tokens(list(cfg.tokens))
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
@@ -126,29 +114,26 @@ def load_model(
|
||||
|
||||
replace_btlm_attn_with_flash_attn(cfg.base_model)
|
||||
|
||||
if (
|
||||
hasattr(model_config, "model_type")
|
||||
and model_config.model_type == "stablelm_epoch"
|
||||
):
|
||||
if cfg.flash_attention and cfg.sample_packing:
|
||||
from axolotl.monkeypatch.stablelm_attn_hijack_flash import (
|
||||
replace_stablelm_attn_with_flash_attn,
|
||||
if hasattr(model_config, "model_type") and model_config.model_type in [
|
||||
"falcon",
|
||||
"RefinedWebModel",
|
||||
"RefinedWeb",
|
||||
]:
|
||||
if cfg.flash_attention:
|
||||
from axolotl.monkeypatch.falcon_attn_hijack_flash import (
|
||||
replace_falcon_attn_with_flash_attn,
|
||||
)
|
||||
|
||||
replace_stablelm_attn_with_flash_attn(cfg.base_model)
|
||||
replace_falcon_attn_with_flash_attn()
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.flash_attention and cfg.sample_packing:
|
||||
if cfg.is_llama_derived_model and cfg.flash_attention:
|
||||
if cfg.device not in ["mps", "cpu"] and not inference:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
replace_llama_attn_with_flash_attn,
|
||||
)
|
||||
|
||||
LOG.info("patching with flash attention for sample packing")
|
||||
replace_llama_attn_with_flash_attn(
|
||||
packed=cfg.sample_packing,
|
||||
cross_entropy=cfg.flash_attn_cross_entropy,
|
||||
rms_norm=cfg.flash_attn_rms_norm,
|
||||
)
|
||||
LOG.info("patching with flash attention")
|
||||
replace_llama_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||
elif cfg.is_llama_derived_model and cfg.xformers_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
|
||||
hijack_llama_attention,
|
||||
@@ -173,14 +158,6 @@ def load_model(
|
||||
# Note: This might overwrite previous additional_special_tokens
|
||||
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
|
||||
|
||||
if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
|
||||
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
||||
replace_mistral_attn_with_flash_attn,
|
||||
)
|
||||
|
||||
LOG.info("patching with flash attention")
|
||||
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
||||
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
||||
replace_llama_rope_with_xpos_rope,
|
||||
@@ -199,11 +176,21 @@ def load_model(
|
||||
LOG.info("patching _expand_mask")
|
||||
hijack_expand_mask()
|
||||
|
||||
# special handling b/c remote MixFormers code doesn't have _no_split_modules set
|
||||
if (
|
||||
"MixFormerSequentialConfig" in model_config.__class__.__name__
|
||||
and cfg.model_type == "AutoModelForCausalLM"
|
||||
):
|
||||
module_name = model_config.__class__.__module__.replace(
|
||||
".configuration_mixformer_sequential", ".modeling_mixformer_sequential"
|
||||
)
|
||||
modeling_phi = importlib.import_module(module_name)
|
||||
# pylint:disable=protected-access
|
||||
modeling_phi.MixFormerSequentialForCausalLM._no_split_modules = [
|
||||
"ParallelBlock"
|
||||
]
|
||||
|
||||
model_kwargs = {}
|
||||
|
||||
model_kwargs["device_map"] = cfg.device_map
|
||||
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
||||
|
||||
if cfg.model_revision:
|
||||
model_kwargs["revision"] = cfg.model_revision
|
||||
if cfg.gptq:
|
||||
@@ -222,26 +209,12 @@ def load_model(
|
||||
load_in_4bit=True,
|
||||
llm_int8_threshold=6.0,
|
||||
llm_int8_has_fp16_weight=False,
|
||||
bnb_4bit_compute_dtype=torch.float16,
|
||||
bnb_4bit_compute_dtype=cfg.torch_dtype,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
)
|
||||
# sample packing uses custom FA2 patch
|
||||
if cfg.flash_attention and not cfg.sample_packing:
|
||||
if (
|
||||
cfg.is_llama_derived_model
|
||||
or cfg.is_falcon_derived_model
|
||||
or cfg.is_mistral_derived_model
|
||||
):
|
||||
model_kwargs["use_flash_attention_2"] = True
|
||||
|
||||
try:
|
||||
if (
|
||||
cfg.is_llama_derived_model
|
||||
and not cfg.trust_remote_code
|
||||
and not cfg.gptq
|
||||
and not cfg.tensor_parallel
|
||||
):
|
||||
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
||||
from transformers import LlamaForCausalLM
|
||||
|
||||
config_kwargs = {}
|
||||
@@ -254,24 +227,12 @@ def load_model(
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=config,
|
||||
device_map=cfg.device_map,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
torch_dtype=cfg.torch_dtype,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
if cfg.flash_attention and not inference:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
replace_llama_mlp_with_swiglu,
|
||||
replace_llama_qkv_with_fused,
|
||||
)
|
||||
|
||||
if cfg.flash_attn_fuse_mlp:
|
||||
LOG.info("patching with SwiGLU")
|
||||
replace_llama_mlp_with_swiglu(model)
|
||||
|
||||
if cfg.flash_attn_fuse_qkv:
|
||||
LOG.info("patching with fused QKV")
|
||||
replace_llama_qkv_with_fused(model)
|
||||
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
||||
# This is a WIP, still an issue with the backward pass
|
||||
# RuntimeError: grad can be implicitly created only for scalar outputs
|
||||
@@ -303,36 +264,31 @@ def load_model(
|
||||
|
||||
model = MixFormerSequentialForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
device_map=cfg.device_map,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
torch_dtype=cfg.torch_dtype,
|
||||
**model_kwargs,
|
||||
)
|
||||
elif model_type and not cfg.trust_remote_code and not cfg.tensor_parallel:
|
||||
elif model_type and not cfg.trust_remote_code:
|
||||
if cfg.gptq:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
device_map=cfg.device_map,
|
||||
torch_dtype=cfg.torch_dtype,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
model = getattr(transformers, model_type).from_pretrained(
|
||||
base_model,
|
||||
device_map=cfg.device_map,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
torch_dtype=cfg.torch_dtype,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
elif cfg.tensor_parallel:
|
||||
model_kwargs.pop("device_map")
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
low_cpu_mem_usage=True,
|
||||
offload_state_dict=True,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
config = AutoConfig.from_pretrained(
|
||||
base_model,
|
||||
@@ -358,6 +314,8 @@ def load_model(
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=config,
|
||||
device_map=cfg.device_map,
|
||||
torch_dtype=cfg.torch_dtype,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
@@ -365,8 +323,10 @@ def load_model(
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=config,
|
||||
device_map=cfg.device_map,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
torch_dtype=cfg.torch_dtype,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
@@ -377,24 +337,23 @@ def load_model(
|
||||
LOG.exception(err)
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
device_map=cfg.device_map,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
torch_dtype=cfg.torch_dtype,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
try:
|
||||
embeddings_len = (
|
||||
math.ceil(len(tokenizer) / 32) * 32
|
||||
if cfg.resize_token_embeddings_to_32x
|
||||
else len(tokenizer)
|
||||
)
|
||||
if model.get_input_embeddings().num_embeddings < embeddings_len:
|
||||
model.resize_token_embeddings(embeddings_len)
|
||||
else:
|
||||
model.tie_weights()
|
||||
except NotImplementedError:
|
||||
LOG.warning("`resize_token_embeddings` not implemented on model")
|
||||
embeddings_len = (
|
||||
math.ceil(len(tokenizer) / 32) * 32
|
||||
if cfg.resize_token_embeddings_to_32x
|
||||
else len(tokenizer)
|
||||
)
|
||||
if model.get_input_embeddings().num_embeddings < embeddings_len:
|
||||
model.resize_token_embeddings(embeddings_len)
|
||||
else:
|
||||
model.tie_weights()
|
||||
|
||||
if (
|
||||
hasattr(model.config, "max_position_embeddings")
|
||||
@@ -406,20 +365,6 @@ def load_model(
|
||||
)
|
||||
model.config.max_position_embeddings = cfg.sequence_len
|
||||
|
||||
if (
|
||||
hasattr(model.config, "bos_token_id")
|
||||
and model.config.bos_token_id
|
||||
and model.config.bos_token_id != tokenizer.bos_token_id
|
||||
):
|
||||
model.config.bos_token_id = tokenizer.bos_token_id
|
||||
|
||||
if (
|
||||
hasattr(model.config, "eos_token_id")
|
||||
and model.config.eos_token_id
|
||||
and model.config.eos_token_id != tokenizer.eos_token_id
|
||||
):
|
||||
model.config.eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
if model.device.type == "cuda":
|
||||
log_gpu_memory_usage(LOG, "after model load", model.device)
|
||||
|
||||
@@ -497,12 +442,7 @@ def load_adapter(model, cfg, adapter, inference=False):
|
||||
if adapter is None:
|
||||
return model, None
|
||||
if hasattr(model, "enable_input_require_grads"):
|
||||
try:
|
||||
model.enable_input_require_grads()
|
||||
except NotImplementedError:
|
||||
LOG.warning("enable_input_require_grads not implemented on model")
|
||||
if adapter == "qlora" and cfg.tensor_parallel:
|
||||
model, _ = load_tp_qlora(model)
|
||||
model.enable_input_require_grads()
|
||||
if adapter in ["lora", "qlora"]:
|
||||
return load_lora(model, cfg, inference=inference)
|
||||
if adapter == "llama-adapter":
|
||||
@@ -540,11 +480,7 @@ def find_all_linear_names(model):
|
||||
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
|
||||
lora_module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
if (
|
||||
isinstance(module, cls)
|
||||
or "Linear" in module.__class__.__name__
|
||||
and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
|
||||
):
|
||||
if isinstance(module, cls) or "Linear" in module.__class__.__name__:
|
||||
names = name.split(".")
|
||||
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
||||
|
||||
@@ -554,25 +490,6 @@ def find_all_linear_names(model):
|
||||
return list(lora_module_names)
|
||||
|
||||
|
||||
def load_tp_qlora(model):
|
||||
from transformers.utils.bitsandbytes import replace_with_bnb_linear
|
||||
|
||||
model = replace_with_bnb_linear(
|
||||
model,
|
||||
quantization_config=BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
llm_int8_threshold=6.0,
|
||||
llm_int8_has_fp16_weight=False,
|
||||
bnb_4bit_compute_dtype=torch.float16,
|
||||
bnb_4bit_use_double_quant=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
),
|
||||
)
|
||||
model.is_loaded_in_4bit = True
|
||||
|
||||
return model, None
|
||||
|
||||
|
||||
def load_lora(model, cfg, inference=False):
|
||||
# type: (PreTrainedModel, DictDefault, bool) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
|
||||
|
||||
@@ -31,8 +31,7 @@ def check_example_labels(example, tokenizer, text_only=False):
|
||||
)
|
||||
colored_tokens.append(colored_token)
|
||||
|
||||
delimiter = "" if text_only else " "
|
||||
LOG.info(delimiter.join(colored_tokens))
|
||||
LOG.info(" ".join(colored_tokens))
|
||||
LOG.info("\n\n\n")
|
||||
|
||||
return " ".join(colored_tokens)
|
||||
|
||||
@@ -1,19 +1,39 @@
|
||||
"""Module containing the Trainer class and related functions"""
|
||||
import importlib
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.cuda
|
||||
import torch.distributed as dist
|
||||
from datasets import set_caching_enabled
|
||||
from torch.utils.data import DistributedSampler, RandomSampler
|
||||
import transformers
|
||||
from datasets import Dataset, set_caching_enabled
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import (
|
||||
DataLoader,
|
||||
DistributedSampler,
|
||||
RandomSampler,
|
||||
SequentialSampler,
|
||||
)
|
||||
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
||||
from transformers.trainer_pt_utils import SequentialDistributedSampler
|
||||
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.utils.callbacks import (
|
||||
EvalFirstStepCallback,
|
||||
GPUStatsCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
log_prediction_callback_factory,
|
||||
)
|
||||
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
||||
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
||||
from axolotl.utils.distributed import (
|
||||
@@ -22,6 +42,7 @@ from axolotl.utils.distributed import (
|
||||
reduce_and_broadcast,
|
||||
zero_first,
|
||||
)
|
||||
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
@@ -88,6 +109,269 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True):
|
||||
return weighted_cross_entropy(logits, labels, weights)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingArguments(TrainingArguments):
|
||||
"""
|
||||
Extend the base TrainingArguments for axolotl helpers
|
||||
"""
|
||||
|
||||
lr_quadratic_warmup: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||
)
|
||||
sample_packing: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
eval_sample_packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "Use sample packing for efficient evals."},
|
||||
)
|
||||
sample_packing_efficiency: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "The maximum sequence length the model can handle"},
|
||||
)
|
||||
sample_packing_seq_len_multiplier: int = field(
|
||||
default=1,
|
||||
metadata={"help": "the multiplier for the max len for packed sequences"},
|
||||
)
|
||||
relora_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to reset for ReLoRA"},
|
||||
)
|
||||
relora_warmup_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
bench_split: Optional[str] = field(
|
||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||
)
|
||||
bench_dataset: Optional[str] = field(
|
||||
default="pharaouk/dharma-1/dharma_1_mini.json",
|
||||
metadata={
|
||||
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
|
||||
},
|
||||
)
|
||||
do_bench_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
||||
)
|
||||
max_bench_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
||||
},
|
||||
)
|
||||
bench_source_max_len: int = field(
|
||||
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
||||
)
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
"""
|
||||
Extend the base Trainer for axolotl helpers
|
||||
"""
|
||||
|
||||
args = None # type: AxolotlTrainingArguments
|
||||
|
||||
def __init__(self, *args, bench_data_collator=None, **kwargs):
|
||||
self.bench_data_collator = bench_data_collator
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||
):
|
||||
"""
|
||||
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||
passed as an argument.
|
||||
|
||||
Args:
|
||||
num_training_steps (int): The number of training steps to do.
|
||||
optimizer (torch.optim.Optimizer): The training optimizer
|
||||
"""
|
||||
|
||||
# fmt: off
|
||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||
# fmt: on
|
||||
if (
|
||||
self.args.lr_scheduler_type == "cosine"
|
||||
and self.args.lr_quadratic_warmup is True
|
||||
):
|
||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
)
|
||||
else:
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
return self.lr_scheduler
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.world_size > 1 and self.args.sample_packing:
|
||||
return DistributedSampler(
|
||||
self.train_dataset,
|
||||
num_replicas=self.args.world_size,
|
||||
rank=self.args.process_index,
|
||||
seed=self.args.seed,
|
||||
)
|
||||
return super()._get_train_sampler()
|
||||
|
||||
def _get_eval_sampler(
|
||||
self, eval_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if (
|
||||
self.args.world_size > 1
|
||||
and self.args.sample_packing
|
||||
and self.args.eval_sample_packing is not False
|
||||
):
|
||||
return SequentialDistributedSampler(
|
||||
eval_dataset,
|
||||
num_replicas=self.args.world_size,
|
||||
rank=self.args.process_index,
|
||||
batch_size=self.args.per_device_eval_batch_size,
|
||||
)
|
||||
return super()._get_eval_sampler(eval_dataset)
|
||||
|
||||
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||
if self.args.sample_packing:
|
||||
train_sampler = self._get_train_sampler()
|
||||
return self.accelerator.prepare(
|
||||
MultipackDistributedDataloader(
|
||||
self.train_dataset,
|
||||
batch_size=self._train_batch_size,
|
||||
seq_max_length=self.args.max_seq_length,
|
||||
collate_fn=self.data_collator,
|
||||
sampler=train_sampler,
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
)
|
||||
)
|
||||
return super().get_train_dataloader()
|
||||
|
||||
def get_eval_dataloader(
|
||||
self, eval_dataset: Optional[Dataset] = None
|
||||
) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
eval_dataset = (
|
||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
)
|
||||
|
||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||
return self.accelerator.prepare(
|
||||
MultipackDistributedDataloader(
|
||||
eval_dataset,
|
||||
batch_size=self.args.eval_batch_size,
|
||||
seq_max_length=self.args.max_seq_length,
|
||||
collate_fn=self.data_collator,
|
||||
sampler=eval_sampler,
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
)
|
||||
)
|
||||
return super().get_eval_dataloader(eval_dataset)
|
||||
|
||||
def _get_bench_sampler(
|
||||
self, bench_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.world_size <= 1:
|
||||
return SequentialSampler(bench_dataset)
|
||||
return None
|
||||
|
||||
def get_bench_dataloader(
|
||||
self,
|
||||
bench_dataset: Dataset,
|
||||
) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||
dataloader_params = {
|
||||
"batch_size": self.args.eval_batch_size,
|
||||
"collate_fn": self.bench_data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
|
||||
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
|
||||
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
return DataLoader(bench_dataset, **dataloader_params)
|
||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
# use one's weighted cross entropy loss calc
|
||||
# if self.args.sample_packing:
|
||||
# labels = inputs.pop("labels")
|
||||
# outputs = model(**inputs)
|
||||
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||
# return (loss, outputs) if return_outputs else loss
|
||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||
|
||||
|
||||
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
||||
"""
|
||||
Trainer subclass that uses the OneCycleLR scheduler
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lr_scheduler = None
|
||||
|
||||
def create_scheduler(
|
||||
self,
|
||||
num_training_steps: int,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
):
|
||||
optimizer = self.optimizer if optimizer is None else optimizer
|
||||
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||
pct_start = num_warmup_steps / num_training_steps
|
||||
|
||||
self.lr_scheduler = OneCycleLR(
|
||||
optimizer,
|
||||
max_lr=self.args.learning_rate,
|
||||
total_steps=num_training_steps,
|
||||
pct_start=pct_start,
|
||||
div_factor=6,
|
||||
)
|
||||
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
class ReLoRATrainer(AxolotlTrainer):
|
||||
"""
|
||||
Trainer subclass that uses the OneCycleLR scheduler
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lr_scheduler = None
|
||||
|
||||
def create_scheduler(
|
||||
self,
|
||||
num_training_steps: int,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
):
|
||||
optimizer = self.optimizer if optimizer is None else optimizer
|
||||
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
if self.args.relora_steps:
|
||||
warmup_steps = (
|
||||
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
||||
)
|
||||
self.lr_scheduler = ReLoRAScheduler(
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
self.args.relora_steps,
|
||||
warmup_steps,
|
||||
)
|
||||
else:
|
||||
self.lr_scheduler = lr_scheduler
|
||||
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
def add_position_ids(sample):
|
||||
sample_len = len(sample["input_ids"])
|
||||
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
|
||||
@@ -113,38 +397,23 @@ def disable_datasets_caching():
|
||||
set_caching_enabled(True)
|
||||
|
||||
|
||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
|
||||
with zero_first(is_main_process()):
|
||||
train_dataset = train_dataset.filter(drop_long, num_proc=cfg.dataset_processes)
|
||||
train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count())
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.filter(
|
||||
drop_long, num_proc=cfg.dataset_processes
|
||||
)
|
||||
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
|
||||
|
||||
if cfg.group_by_length:
|
||||
train_dataset = train_dataset.map(
|
||||
add_length, num_proc=cfg.dataset_processes
|
||||
)
|
||||
train_dataset = train_dataset.map(add_length, num_proc=os.cpu_count())
|
||||
|
||||
if cfg.sample_packing:
|
||||
train_dataset = train_dataset.map(
|
||||
add_position_ids, num_proc=cfg.dataset_processes
|
||||
)
|
||||
train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
|
||||
if cfg.eval_sample_packing is not False:
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.map(
|
||||
add_position_ids, num_proc=cfg.dataset_processes
|
||||
add_position_ids, num_proc=os.cpu_count()
|
||||
)
|
||||
|
||||
# Phi doesn't want the attention_mask feature when training
|
||||
if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
|
||||
cfg.is_mistral_derived_model and cfg.flash_attention
|
||||
):
|
||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.remove_columns("attention_mask")
|
||||
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
|
||||
@@ -216,7 +485,6 @@ def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
||||
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
||||
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
num_epochs=cfg.num_epochs,
|
||||
)
|
||||
data_loader_len = data_loader.len_w_stats()
|
||||
actual_eff = data_loader.efficiency()
|
||||
@@ -266,8 +534,251 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
||||
elif cfg.deepspeed:
|
||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||
|
||||
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
|
||||
trainer_builder.train_dataset = train_dataset
|
||||
trainer_builder.eval_dataset = eval_dataset
|
||||
warmup_steps = (
|
||||
cfg.warmup_steps
|
||||
if cfg.warmup_steps is not None
|
||||
else min(int(0.03 * total_num_steps), 100)
|
||||
)
|
||||
logging_steps = (
|
||||
cfg.logging_steps
|
||||
if cfg.logging_steps is not None
|
||||
else max(min(int(0.005 * total_num_steps), 10), 1)
|
||||
)
|
||||
|
||||
return trainer_builder.build(total_num_steps)
|
||||
training_arguments_kwargs = {}
|
||||
if cfg.bf16 == "full":
|
||||
training_arguments_kwargs["bf16_full_eval"] = True
|
||||
else:
|
||||
training_arguments_kwargs["bf16"] = cfg.bf16
|
||||
training_arguments_kwargs["fp16"] = (cfg.fp16 and not cfg.bf16) or False
|
||||
training_arguments_kwargs["tf32"] = cfg.tf32
|
||||
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
||||
training_arguments_kwargs["logging_steps"] = logging_steps
|
||||
|
||||
if cfg.seed:
|
||||
training_arguments_kwargs["seed"] = cfg.seed
|
||||
|
||||
if cfg.gradient_checkpointing:
|
||||
training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
|
||||
if cfg.fsdp:
|
||||
training_arguments_kwargs["fsdp"] = cfg.fsdp
|
||||
if cfg.fsdp_config:
|
||||
training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
|
||||
|
||||
# deepspeed
|
||||
if cfg.deepspeed:
|
||||
training_arguments_kwargs["deepspeed"] = cfg.deepspeed
|
||||
|
||||
if cfg.lr_quadratic_warmup is not None:
|
||||
training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup
|
||||
|
||||
if cfg.adam_beta1:
|
||||
training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
|
||||
if cfg.adam_beta2:
|
||||
training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
|
||||
if cfg.adam_epsilon:
|
||||
training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
|
||||
if cfg.max_grad_norm:
|
||||
training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
|
||||
|
||||
if cfg.hub_model_id:
|
||||
training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
|
||||
training_arguments_kwargs["push_to_hub"] = True
|
||||
training_arguments_kwargs["hub_private_repo"] = True
|
||||
|
||||
if cfg.hub_strategy:
|
||||
training_arguments_kwargs["hub_strategy"] = cfg.hub_strategy
|
||||
|
||||
if cfg.save_safetensors:
|
||||
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
|
||||
|
||||
if cfg.sample_packing_eff_est:
|
||||
training_arguments_kwargs[
|
||||
"sample_packing_efficiency"
|
||||
] = cfg.sample_packing_eff_est
|
||||
|
||||
if cfg.eval_steps and cfg.evaluation_strategy:
|
||||
# assume if the user set both, they know what they're doing
|
||||
training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
|
||||
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
|
||||
elif cfg.val_set_size == 0:
|
||||
# no eval set, so don't eval
|
||||
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||
elif cfg.evaluation_strategy and cfg.evaluation_strategy in ["epoch", "no"]:
|
||||
# if explicitly set for epoch, just set, and eval steps don't matter
|
||||
training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
|
||||
elif cfg.eval_steps:
|
||||
# steps isn't used w/ epochs
|
||||
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
||||
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
|
||||
else:
|
||||
# we have an eval set, but no steps defined, default to use epoch
|
||||
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
||||
|
||||
if cfg.save_steps:
|
||||
# save_steps implies save_strategy of steps
|
||||
training_arguments_kwargs["save_strategy"] = "steps"
|
||||
training_arguments_kwargs["save_steps"] = cfg.save_steps
|
||||
elif cfg.save_strategy:
|
||||
training_arguments_kwargs["save_strategy"] = cfg.save_strategy
|
||||
else:
|
||||
# default to saving each epoch if not defined
|
||||
training_arguments_kwargs["save_strategy"] = "epoch"
|
||||
|
||||
if cfg.do_bench_eval:
|
||||
training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
|
||||
if cfg.bench_dataset:
|
||||
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
|
||||
if cfg.metric_for_best_model:
|
||||
training_arguments_kwargs["metric_for_best_model"] = cfg.metric_for_best_model
|
||||
if cfg.greater_is_better:
|
||||
training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
|
||||
|
||||
if cfg.torch_compile:
|
||||
if torch.__version__ < "2.1.0": # pylint: disable=protected-access
|
||||
LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
|
||||
else:
|
||||
import torch._dynamo # pylint: disable=redefined-outer-name
|
||||
|
||||
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
|
||||
True
|
||||
)
|
||||
training_arguments_kwargs["torch_compile"] = cfg.torch_compile
|
||||
if cfg.torch_compile_backend:
|
||||
training_arguments_kwargs[
|
||||
"torch_compile_backend"
|
||||
] = cfg.torch_compile_backend
|
||||
|
||||
# DDP Config
|
||||
if cfg.ddp_timeout:
|
||||
training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
|
||||
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
|
||||
if cfg.ddp_bucket_cap_mb:
|
||||
training_arguments_kwargs["ddp_bucket_cap_mb"] = cfg.ddp_bucket_cap_mb
|
||||
if cfg.ddp_broadcast_buffers is not None:
|
||||
training_arguments_kwargs["ddp_broadcast_buffers"] = cfg.ddp_broadcast_buffers
|
||||
|
||||
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||
max_steps=total_num_steps if cfg.max_steps else -1,
|
||||
max_seq_length=cfg.sequence_len,
|
||||
per_device_train_batch_size=cfg.micro_batch_size,
|
||||
per_device_eval_batch_size=cfg.eval_batch_size
|
||||
if cfg.eval_batch_size is not None
|
||||
else cfg.micro_batch_size,
|
||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||
num_train_epochs=cfg.num_epochs,
|
||||
learning_rate=cfg.learning_rate,
|
||||
output_dir=cfg.output_dir,
|
||||
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
|
||||
load_best_model_at_end=(
|
||||
(cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
|
||||
and cfg.val_set_size > 0
|
||||
and cfg.save_steps
|
||||
and cfg.eval_steps
|
||||
and cfg.save_steps % cfg.eval_steps == 0
|
||||
)
|
||||
or False,
|
||||
ddp_find_unused_parameters=False if cfg.ddp else None,
|
||||
group_by_length=cfg.group_by_length,
|
||||
report_to="wandb" if cfg.use_wandb else None,
|
||||
run_name=cfg.wandb_run_id if cfg.use_wandb else None,
|
||||
optim=cfg.optimizer if cfg.optimizer else "adamw_hf",
|
||||
lr_scheduler_type=cfg.lr_scheduler
|
||||
if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
|
||||
else "cosine",
|
||||
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
|
||||
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
|
||||
eval_sample_packing=cfg.eval_sample_packing,
|
||||
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
||||
relora_steps=cfg.relora_steps,
|
||||
relora_warmup_steps=cfg.relora_warmup_steps,
|
||||
**training_arguments_kwargs,
|
||||
)
|
||||
|
||||
trainer_kwargs = {}
|
||||
|
||||
if cfg.optimizer == "adamw_anyprecision":
|
||||
if Path(cfg.torchdistx_path).exists():
|
||||
sys.path.append(cfg.torchdistx_path)
|
||||
importlib.import_module("torchdistx")
|
||||
|
||||
callbacks = []
|
||||
callbacks.append(GPUStatsCallback(cfg))
|
||||
callbacks.append(EvalFirstStepCallback)
|
||||
|
||||
if cfg.relora_steps:
|
||||
callbacks.append(ReLoRACallback(cfg))
|
||||
|
||||
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
|
||||
callbacks.append(SaveBetterTransformerModelCallback)
|
||||
|
||||
data_collator_kwargs = {
|
||||
"padding": True, # True/"longest" is the default
|
||||
}
|
||||
if cfg.pad_to_sequence_len:
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
|
||||
cfg.sequence_len / 64
|
||||
)
|
||||
else:
|
||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 64
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.landmark_attention:
|
||||
from axolotl.monkeypatch.llama_landmark_attn import (
|
||||
add_mem_tokens,
|
||||
get_mem_id,
|
||||
set_model_mem_id,
|
||||
)
|
||||
|
||||
set_model_mem_id(model, tokenizer)
|
||||
|
||||
LOG.info("Adding landmark attention tokens to dataset")
|
||||
|
||||
for dataset in [train_dataset, eval_dataset]:
|
||||
dataset = dataset.map(
|
||||
partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)),
|
||||
batched=False,
|
||||
num_proc=32,
|
||||
)
|
||||
|
||||
trainer_cls = AxolotlTrainer
|
||||
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora"):
|
||||
trainer_cls = OneCycleLRSchedulerTrainer
|
||||
elif cfg.relora_steps:
|
||||
trainer_cls = ReLoRATrainer
|
||||
trainer = trainer_cls(
|
||||
model=model,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
args=training_args,
|
||||
data_collator=DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
),
|
||||
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
),
|
||||
callbacks=callbacks,
|
||||
**trainer_kwargs,
|
||||
)
|
||||
|
||||
if cfg.use_wandb and cfg.eval_table_size > 0:
|
||||
LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
|
||||
trainer.add_callback(LogPredictionCallback(cfg))
|
||||
|
||||
if cfg.do_bench_eval:
|
||||
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
|
||||
|
||||
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
||||
if cfg.early_stopping_patience:
|
||||
early_stop_cb = EarlyStoppingCallback(
|
||||
cfg.early_stopping_patience,
|
||||
)
|
||||
trainer.add_callback(early_stop_cb)
|
||||
|
||||
return trainer
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
"""
|
||||
E2E tests for lora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestFusedLlama(unittest.TestCase):
|
||||
"""
|
||||
Test case for Llama models using Fused layers
|
||||
"""
|
||||
|
||||
def test_fft_packing(self):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"flash_attention": True,
|
||||
"flash_attn_fuse_qkv": True,
|
||||
"flash_attn_fuse_mlp": True,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 1024,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": output_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "pytorch_model.bin").exists()
|
||||
@@ -29,6 +29,7 @@ class TestLoraLlama(unittest.TestCase):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"base_model_config": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
@@ -71,6 +72,7 @@ class TestLoraLlama(unittest.TestCase):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"base_model_config": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"sample_packing": True,
|
||||
@@ -115,6 +117,7 @@ class TestLoraLlama(unittest.TestCase):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
|
||||
"base_model_config": "TheBlokeAI/jackfram_llama-68m-GPTQ",
|
||||
"model_type": "AutoModelForCausalLM",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 1024,
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
"""
|
||||
E2E tests for lora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestMistral(unittest.TestCase):
|
||||
"""
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
def test_lora(self):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 32,
|
||||
"lora_alpha": 64,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": output_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||
|
||||
def test_ft(self):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": output_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "pytorch_model.bin").exists()
|
||||
@@ -1,116 +0,0 @@
|
||||
"""
|
||||
E2E tests for lora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestMistral(unittest.TestCase):
|
||||
"""
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
def test_lora_packing(self):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 32,
|
||||
"lora_alpha": 64,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": output_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||
|
||||
def test_ft_packing(self):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 1024,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": output_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(output_dir) / "pytorch_model.bin").exists()
|
||||
@@ -27,6 +27,7 @@ class TestPhi(unittest.TestCase):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "microsoft/phi-1_5",
|
||||
"base_model_config": "microsoft/phi-1_5",
|
||||
"trust_remote_code": True,
|
||||
"model_type": "MixFormerSequentialForCausalLM",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
@@ -70,6 +71,7 @@ class TestPhi(unittest.TestCase):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "microsoft/phi-1_5",
|
||||
"base_model_config": "microsoft/phi-1_5",
|
||||
"trust_remote_code": True,
|
||||
"model_type": "MixFormerSequentialForCausalLM",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
|
||||
2
tests/fixtures/conversation.tokenized.json
vendored
2
tests/fixtures/conversation.tokenized.json
vendored
File diff suppressed because one or more lines are too long
@@ -1,46 +0,0 @@
|
||||
"""
|
||||
Test classes for checking functionality of the cfg normalization
|
||||
"""
|
||||
import unittest
|
||||
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class NormalizeConfigTestCase(unittest.TestCase):
|
||||
"""
|
||||
test class for normalize_config checks
|
||||
"""
|
||||
|
||||
def _get_base_cfg(self):
|
||||
return DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"base_model_config": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
}
|
||||
)
|
||||
|
||||
def test_lr_as_float(self):
|
||||
cfg = (
|
||||
self._get_base_cfg()
|
||||
| DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"learning_rate": "5e-5",
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
normalize_config(cfg)
|
||||
|
||||
assert cfg.learning_rate == 0.00005
|
||||
|
||||
def test_base_model_config_set_when_empty(self):
|
||||
cfg = self._get_base_cfg()
|
||||
del cfg.base_model_config
|
||||
normalize_config(cfg)
|
||||
|
||||
assert cfg.base_model_config == cfg.base_model
|
||||
@@ -21,7 +21,7 @@ from axolotl.prompt_tokenizers import (
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
ShareGPTPromptTokenizingStrategy,
|
||||
)
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompter
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
@@ -60,7 +60,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
) as fin:
|
||||
data = fin.read()
|
||||
tokenized_conversation = json.loads(data)
|
||||
prompter = ShareGPTPrompterV2()
|
||||
prompter = ShareGPTPrompter("chat")
|
||||
strat = ShareGPTPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
@@ -79,7 +79,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
) as fin:
|
||||
data = fin.read()
|
||||
conversation = json.loads(data)
|
||||
prompter = ShareGPTPrompterV2()
|
||||
prompter = ShareGPTPrompter("chat")
|
||||
strat = ShareGPTPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
@@ -90,73 +90,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
strat.tokenize_prompt(conversation)
|
||||
assert "assistant turn has empty text" in self._caplog.records[1].message
|
||||
|
||||
def test_sharegpt_warnings_turns(self):
|
||||
conversation = {
|
||||
"conversations": [
|
||||
{"from": "system", "value": "lorem"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
{"from": "human", "value": "dolor"},
|
||||
{"from": "human", "value": "dolor"},
|
||||
{"from": "gpt", "value": "sit"},
|
||||
]
|
||||
}
|
||||
prompter = ShareGPTPrompterV2()
|
||||
strat = ShareGPTPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
strat.tokenize_prompt(conversation)
|
||||
assert (
|
||||
"Role did not alternate between turns (gpt and human)"
|
||||
in self._caplog.records[0].message
|
||||
)
|
||||
|
||||
def test_sharegpt_changes_roles(self):
|
||||
conversation = {
|
||||
"roles": ["USER", "CHARACTER"],
|
||||
"conversations": [
|
||||
{"from": "system", "value": "lorem"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
{"from": "human", "value": "dolor"},
|
||||
{"from": "gpt", "value": "sit"},
|
||||
],
|
||||
}
|
||||
prompter = ShareGPTPrompterV2()
|
||||
strat = ShareGPTPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
res = strat.tokenize_prompt(conversation)
|
||||
assert "CHARACTER" in self.tokenizer.decode(res["input_ids"])
|
||||
|
||||
def test_sharegpt_assistant_label_ignore(self):
|
||||
conversation = {
|
||||
"roles": ["user", "assistant"],
|
||||
"conversations": [
|
||||
{"from": "system", "value": "lorem"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
{"from": "human", "value": "dolor"},
|
||||
{"from": "gpt", "value": "sit"},
|
||||
],
|
||||
}
|
||||
prompter = ShareGPTPrompterV2()
|
||||
strat = ShareGPTPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
res = strat.tokenize_prompt(conversation)
|
||||
idx = res["input_ids"].index(20255) # assistant token
|
||||
assert res["labels"][idx] == -100
|
||||
|
||||
def test_no_sys_prompt(self):
|
||||
"""
|
||||
tests the interface between the user and assistant parts
|
||||
|
||||
@@ -374,278 +374,3 @@ class ValidationTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
def test_sharegpt_deprecation(self):
|
||||
cfg = DictDefault(
|
||||
{"datasets": [{"path": "lorem/ipsum", "type": "sharegpt:chat"}]}
|
||||
)
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"`type: sharegpt:chat` will soon be deprecated." in record.message
|
||||
for record in self._caplog.records
|
||||
)
|
||||
assert cfg.datasets[0].type == "sharegpt"
|
||||
|
||||
cfg = DictDefault(
|
||||
{"datasets": [{"path": "lorem/ipsum", "type": "sharegpt_simple:load_role"}]}
|
||||
)
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"`type: sharegpt_simple` will soon be deprecated." in record.message
|
||||
for record in self._caplog.records
|
||||
)
|
||||
assert cfg.datasets[0].type == "sharegpt:load_role"
|
||||
|
||||
def test_no_conflict_save_strategy(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"save_strategy": "epoch",
|
||||
"save_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match=r".*save_strategy and save_steps mismatch.*"
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"save_strategy": "no",
|
||||
"save_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match=r".*save_strategy and save_steps mismatch.*"
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"save_strategy": "steps",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"save_strategy": "steps",
|
||||
"save_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"save_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"save_strategy": "no",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
def test_no_conflict_eval_strategy(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "epoch",
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "no",
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match=r".*evaluation_strategy and eval_steps mismatch.*"
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "steps",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "steps",
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "no",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "epoch",
|
||||
"val_set_size": 0,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"eval_steps": 10,
|
||||
"val_set_size": 0,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*eval_steps and evaluation_strategy are not supported with val_set_size == 0.*",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"val_set_size": 0,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"eval_steps": 10,
|
||||
"val_set_size": 0.01,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"evaluation_strategy": "epoch",
|
||||
"val_set_size": 0.01,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
def test_eval_table_size_conflict_eval_packing(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"sample_packing": True,
|
||||
"eval_table_size": 100,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match=r".*Please set 'eval_sample_packing' to false.*"
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"sample_packing": True,
|
||||
"eval_sample_packing": False,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"sample_packing": False,
|
||||
"eval_table_size": 100,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"sample_packing": True,
|
||||
"eval_table_size": 100,
|
||||
"eval_sample_packing": False,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
def test_load_in_x_bit_without_adapter(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"load_in_4bit": True,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"load_in_8bit": True,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
Reference in New Issue
Block a user