Compare commits
1 Commits
unsloth_mo
...
sharegpt-b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b4d84d56d5 |
5
.github/workflows/base.yml
vendored
5
.github/workflows/base.yml
vendored
@@ -25,11 +25,6 @@ jobs:
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.1
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
- cuda: "118"
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.0
|
||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v3
|
||||
|
||||
16
.github/workflows/main.yml
vendored
16
.github/workflows/main.yml
vendored
@@ -23,12 +23,6 @@ jobs:
|
||||
python_version: "3.10"
|
||||
pytorch: 2.0.1
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
- cuda: 118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.0
|
||||
axolotl_extras:
|
||||
runs-on: [self-hosted, gpu, docker]
|
||||
steps:
|
||||
- name: Checkout
|
||||
@@ -52,12 +46,9 @@ jobs:
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
file: ./docker/Dockerfile
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: |
|
||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
tags: ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
build-axolotl-runpod:
|
||||
needs: build-axolotl
|
||||
@@ -77,11 +68,6 @@ jobs:
|
||||
pytorch: 2.0.1
|
||||
axolotl_extras:
|
||||
is_latest: true
|
||||
- cuda: 118
|
||||
cuda_version: 11.8.0
|
||||
python_version: "3.10"
|
||||
pytorch: 2.1.0
|
||||
axolotl_extras:
|
||||
runs-on: [self-hosted, gpu, docker]
|
||||
steps:
|
||||
- name: Checkout
|
||||
|
||||
1
.github/workflows/tests.yml
vendored
1
.github/workflows/tests.yml
vendored
@@ -71,7 +71,6 @@ jobs:
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
|
||||
pip3 uninstall -y transformers accelerate
|
||||
pip3 install -U -e .[flash-attn]
|
||||
pip3 install -r requirements-tests.txt
|
||||
|
||||
@@ -12,3 +12,4 @@ generated-members=numpy.*, torch.*
|
||||
disable=missing-function-docstring, line-too-long, import-error,
|
||||
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
|
||||
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
|
||||
too-many-nested-blocks,
|
||||
|
||||
464
README.md
464
README.md
@@ -23,17 +23,15 @@ Features:
|
||||
- [Supported Features](#axolotl-supports)
|
||||
- [Quickstart](#quickstart-)
|
||||
- [Installation](#installation)
|
||||
- [Docker](#docker)
|
||||
- [Conda/Pip venv](#condapip-venv)
|
||||
- [Runpod](#runpod)
|
||||
- [LambdaLabs](#lambdalabs)
|
||||
- [Windows](#windows)
|
||||
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
||||
- [Docker Installation](#environment)
|
||||
- [Conda/Pip venv Installation](#condapip-venv)
|
||||
- [LambdaLabs Installation](#lambdalabs)
|
||||
- [Dataset](#dataset)
|
||||
- [How to Add Custom Prompts](#how-to-add-custom-prompts)
|
||||
- [How to Use Custom Pretokenized Dataset](#how-to-use-your-custom-pretokenized-dataset)
|
||||
- [Config](#config)
|
||||
- [Train](#train)
|
||||
- [Training w/ Deepspeed](#training-with-deepspeed)
|
||||
- [Inference](#inference)
|
||||
- [Merge LORA to Base](#merge-lora-to-base)
|
||||
- [Common Errors](#common-errors-)
|
||||
@@ -52,7 +50,7 @@ Features:
|
||||
<b>Axolotl provides a unified repository for fine-tuning <br />a variety of AI models with ease</b>
|
||||
</p>
|
||||
<p>
|
||||
Go ahead and Axolotl questions!!
|
||||
Go ahead and axolotl questions!!
|
||||
</p>
|
||||
<img src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/pre-commit.yml/badge.svg?branch=main" alt="pre-commit">
|
||||
<img alt="PyTest Status" src="https://github.com/OpenAccess-AI-Collective/axolotl/actions/workflows/tests.yml/badge.svg?branch=main">
|
||||
@@ -76,8 +74,6 @@ Features:
|
||||
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
||||
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
||||
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
|
||||
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||
|
||||
|
||||
## Quickstart ⚡
|
||||
@@ -86,39 +82,31 @@ Get started with Axolotl in just a few steps! This quickstart guide will walk yo
|
||||
|
||||
**Requirements**: Python >=3.9 and Pytorch >=2.0.
|
||||
|
||||
`pip3 install "axolotl[flash-attn,deepspeed] @ git+https://github.com/OpenAccess-AI-Collective/axolotl"`
|
||||
|
||||
### For developers
|
||||
```bash
|
||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||
cd axolotl
|
||||
|
||||
pip3 install packaging
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
```
|
||||
pip3 install -U git+https://github.com/huggingface/peft.git
|
||||
|
||||
### Usage
|
||||
```bash
|
||||
# finetune lora
|
||||
accelerate launch -m axolotl.cli.train examples/openllama-3b/lora.yml
|
||||
|
||||
# inference
|
||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
--lora_model_dir="./lora-out"
|
||||
|
||||
# gradio
|
||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
--lora_model_dir="./lora-out" --gradio
|
||||
```
|
||||
|
||||
## Installation
|
||||
|
||||
### Environment
|
||||
|
||||
#### Docker
|
||||
- Docker
|
||||
```bash
|
||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||
```
|
||||
- `winglian/axolotl-runpod:main-latest`: for runpod or use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||
|
||||
Or run on the current files for development:
|
||||
|
||||
@@ -126,33 +114,12 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
<details>
|
||||
|
||||
<summary>Docker advanced</summary>
|
||||
|
||||
A more powerful Docker command to run would be this:
|
||||
|
||||
```bash
|
||||
docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 --mount type=volume,src=axolotl,target=/workspace/axolotl -v ${HOME}/.cache/huggingface:/root/.cache/huggingface winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||
```
|
||||
|
||||
It additionally:
|
||||
* Prevents memory issues when running e.g. deepspeed (e.g. you could hit SIGBUS/signal 7 error) through `--ipc` and `--ulimit` args.
|
||||
* Persists the downloaded HF data (models etc.) and your modifications to axolotl code through `--mount`/`-v` args.
|
||||
* The `--name` argument simply makes it easier to refer to the container in vscode (`Dev Containers: Attach to Running Container...`) or in your terminal.
|
||||
* The `--privileged` flag gives all capabilities to the container.
|
||||
* The `--shm-size 10g` argument increases the shared memory size. Use this if you see `exitcode: -7` errors using deepspeed.
|
||||
|
||||
[More information on nvidia website](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html#setincshmem)
|
||||
|
||||
</details>
|
||||
|
||||
#### Conda/Pip venv
|
||||
- Conda/Pip venv
|
||||
1. Install python >=**3.9**
|
||||
|
||||
2. Install pytorch stable https://pytorch.org/get-started/locally/
|
||||
|
||||
3. Install Axolotl along with python dependencies
|
||||
3. Install axolotl along with python dependencies
|
||||
```bash
|
||||
pip3 install packaging
|
||||
pip3 install -e '.[flash-attn,deepspeed]'
|
||||
@@ -163,11 +130,7 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
```
|
||||
Get the token at huggingface.co/settings/tokens
|
||||
|
||||
#### Runpod
|
||||
|
||||
Use `winglian/axolotl-runpod:main-latest` or use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||
|
||||
#### LambdaLabs
|
||||
- LambdaLabs
|
||||
<details>
|
||||
|
||||
<summary>Click to Expand</summary>
|
||||
@@ -211,30 +174,7 @@ Use `winglian/axolotl-runpod:main-latest` or use this [direct link](https://runp
|
||||
```
|
||||
</details>
|
||||
|
||||
#### Windows
|
||||
Please use WSL or Docker!
|
||||
|
||||
|
||||
#### Launching on public clouds via SkyPilot
|
||||
To launch on GPU instances (both on-demand and spot instances) on 7+ clouds (GCP, AWS, Azure, OCI, and more), you can use [SkyPilot](https://skypilot.readthedocs.io/en/latest/index.html):
|
||||
```bash
|
||||
pip install "skypilot-nightly[gcp,aws,azure,oci,lambda,kubernetes,ibm,scp]" # choose your clouds
|
||||
sky check
|
||||
```
|
||||
Get the [example YAMLs](https://github.com/skypilot-org/skypilot/tree/master/llm/axolotl) of using Axolotl to finetune `mistralai/Mistral-7B-v0.1`:
|
||||
```
|
||||
git clone https://github.com/skypilot-org/skypilot.git
|
||||
cd skypilot/llm/axolotl
|
||||
```
|
||||
Use one command to launch:
|
||||
```bash
|
||||
# On-demand
|
||||
HF_TOKEN=xx sky launch axolotl.yaml --env HF_TOKEN
|
||||
|
||||
# Managed spot (auto-recovery on preemption)
|
||||
HF_TOKEN=xx BUCKET=<unique-name> sky spot launch axolotl-spot.yaml --env HF_TOKEN --env BUCKET
|
||||
```
|
||||
|
||||
- Windows: Please use WSL or Docker!
|
||||
|
||||
### Dataset
|
||||
|
||||
@@ -355,24 +295,25 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
||||
|
||||
#### How to add custom prompts
|
||||
|
||||
For a dataset that is preprocessed for instruction purposes:
|
||||
|
||||
```json
|
||||
{"instruction": "...", "output": "..."}
|
||||
```
|
||||
|
||||
You can use this example in your YAML config:
|
||||
|
||||
Using yaml. Example:
|
||||
```yaml
|
||||
datasets:
|
||||
- path: repo
|
||||
type:
|
||||
system_prompt: ""
|
||||
field_system: system
|
||||
format: "[INST] {instruction} [/INST]"
|
||||
no_input_format: "[INST] {instruction} [/INST]"
|
||||
no_input_format: |-
|
||||
User: {instruction}<|end_of_turn|>
|
||||
Assistant:
|
||||
format: |-
|
||||
User: {instruction}
|
||||
{input}<|end_of_turn|>
|
||||
Assistant:
|
||||
```
|
||||
|
||||
Using file:
|
||||
1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
|
||||
2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
|
||||
|
||||
#### How to use your custom pretokenized dataset
|
||||
|
||||
- Do not pass a `type:`
|
||||
@@ -414,13 +355,6 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
- typescript
|
||||
type: ... # unimplemented custom format
|
||||
|
||||
# fastchat conversation
|
||||
# See 'conversation' options: https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
datasets:
|
||||
- path: ...
|
||||
type: sharegpt
|
||||
conversation: chatml
|
||||
|
||||
# local
|
||||
datasets:
|
||||
- path: data.jsonl # or json
|
||||
@@ -432,12 +366,6 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
- path: knowrohit07/know_sql
|
||||
type: context_qa.load_v2
|
||||
train_on_split: validation
|
||||
|
||||
# loading from s3 or gcs
|
||||
# s3 creds will be loaded from the system default and gcs only supports public access
|
||||
dataset:
|
||||
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
|
||||
...
|
||||
```
|
||||
|
||||
- loading
|
||||
@@ -465,18 +393,18 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
|
||||
<details>
|
||||
|
||||
<summary>All yaml options (click me)</summary>
|
||||
<summary>All yaml options</summary>
|
||||
|
||||
```yaml
|
||||
# This is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
||||
# This can also be a relative path to a model on disk
|
||||
# this is the huggingface model that contains *.pt, *.safetensors, or *.bin files
|
||||
# this can also be a relative path to a model on disk
|
||||
base_model: ./llama-7b-hf
|
||||
# You can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
|
||||
# you can specify an ignore pattern if the model repo contains more than 1 model type (*.pt, etc)
|
||||
base_model_ignore_patterns:
|
||||
# If the base_model repo on hf hub doesn't include configuration .json files,
|
||||
# You can set that here, or leave this empty to default to base_model
|
||||
# if the base_model repo on hf hub doesn't include configuration .json files,
|
||||
# you can set that here, or leave this empty to default to base_model
|
||||
base_model_config: ./llama-7b-hf
|
||||
# You can specify to choose a specific model revision from huggingface hub
|
||||
# you can specify to choose a specific model revision from huggingface hub
|
||||
model_revision:
|
||||
# Optional tokenizer configuration override in case you want to use a different tokenizer
|
||||
# than the one defined in the base model
|
||||
@@ -491,33 +419,23 @@ trust_remote_code:
|
||||
tokenizer_use_fast:
|
||||
# Whether to use the legacy tokenizer setting, defaults to True
|
||||
tokenizer_legacy:
|
||||
# Resize the model embeddings when new tokens are added to multiples of 32
|
||||
# This is reported to improve training speed on some models
|
||||
# resize the model embeddings when new tokens are added to multiples of 32
|
||||
# this is reported to improve training speed on some models
|
||||
resize_token_embeddings_to_32x:
|
||||
|
||||
# Used to identify which the model is based on
|
||||
# used to identify which the model is based on
|
||||
is_falcon_derived_model:
|
||||
is_llama_derived_model:
|
||||
# Please note that if you set this to true, `padding_side` will be set to "left" by default
|
||||
is_mistral_derived_model:
|
||||
is_qwen_derived_model:
|
||||
|
||||
# optional overrides to the base model configuration
|
||||
model_config:
|
||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||
rope_scaling:
|
||||
type: # linear | dynamic
|
||||
factor: # float
|
||||
|
||||
|
||||
# Whether you are training a 4-bit GPTQ quantized model
|
||||
# whether you are training a 4-bit GPTQ quantized model
|
||||
gptq: true
|
||||
gptq_groupsize: 128 # group size
|
||||
gptq_model_v1: false # v1 or v2
|
||||
|
||||
# This will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
||||
# this will attempt to quantize the model down to 8 bits and use adam 8 bit optimizer
|
||||
load_in_8bit: true
|
||||
# Use bitsandbytes 4 bit
|
||||
# use bitsandbytes 4 bit
|
||||
load_in_4bit:
|
||||
|
||||
# Use CUDA bf16
|
||||
@@ -531,9 +449,9 @@ tf32: true # require >=ampere
|
||||
bfloat16: true # require >=ampere
|
||||
float16: true
|
||||
|
||||
# A list of one or more datasets to finetune the model with
|
||||
# a list of one or more datasets to finetune the model with
|
||||
datasets:
|
||||
# HuggingFace dataset repo | s3://,gs:// path | "json" for local dataset, make sure to fill data_files
|
||||
# hf dataset repo | "json" for local dataset, make sure to fill data_files
|
||||
- path: vicgalle/alpaca-gpt4
|
||||
# The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
|
||||
type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
|
||||
@@ -541,25 +459,19 @@ datasets:
|
||||
data_files: # Optional[str] path to source data files
|
||||
shards: # Optional[int] number of shards to split data into
|
||||
name: # Optional[str] name of dataset configuration to load
|
||||
train_on_split: train # Optional[str] name of dataset split to load from
|
||||
conversation: # Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||
|
||||
# Optional[str] fastchat conversation type, only used with type: sharegpt
|
||||
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
field_human: # Optional[str]. Human key to use for conversation.
|
||||
field_model: # Optional[str]. Assistant key to use for conversation.
|
||||
|
||||
# Custom user prompt
|
||||
# custom user prompt
|
||||
- path: repo
|
||||
type:
|
||||
# The below are defaults. only set what's needed.
|
||||
# the below are defaults. only set what's needed.
|
||||
system_prompt: ""
|
||||
system_format: "{system}"
|
||||
field_system: system
|
||||
field_instruction: instruction
|
||||
field_input: input
|
||||
field_output: output
|
||||
field_output: input
|
||||
|
||||
# Customizable to be single line or multi-line
|
||||
# customizable to be single line or multi-line
|
||||
system_format: "{system}"
|
||||
# 'format' can include {input}
|
||||
format: |-
|
||||
User: {instruction} {input}
|
||||
@@ -567,13 +479,13 @@ datasets:
|
||||
# 'no_input_format' cannot include {input}
|
||||
no_input_format: "{instruction} "
|
||||
|
||||
# For `completion` datsets only, uses the provided field instead of `text` column
|
||||
# for completions datsets, uses the provided field if not `text`
|
||||
field:
|
||||
|
||||
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
# axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||
# subsequent training attempts load faster, relative path
|
||||
dataset_prepared_path: data/last_run_prepared
|
||||
# Push prepared dataset to hub
|
||||
# push prepared dataset to hub
|
||||
push_dataset_to_hub: # repo path
|
||||
# The maximum number of processes to use while preprocessing your input dataset. This defaults to `os.cpu_count()`
|
||||
# if not set.
|
||||
@@ -583,8 +495,8 @@ hub_model_id: # repo path to push finetuned model
|
||||
# how to push checkpoints to hub
|
||||
# https://huggingface.co/docs/transformers/v4.31.0/en/main_classes/trainer#transformers.TrainingArguments.hub_strategy
|
||||
hub_strategy:
|
||||
# Whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
|
||||
# Required to be true when used in combination with `push_dataset_to_hub`
|
||||
# whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
|
||||
# required to be true when used in combination with `push_dataset_to_hub`
|
||||
hf_use_auth_token: # boolean
|
||||
# How much of the dataset to set aside as evaluation. 1 = 100%, 0.50 = 50%, etc. 0 for no eval.
|
||||
val_set_size: 0.04
|
||||
@@ -593,40 +505,30 @@ dataset_shard_num:
|
||||
# Index of shard to use for whole dataset
|
||||
dataset_shard_idx:
|
||||
|
||||
# The maximum length of an input to train with, this should typically be less than 2048
|
||||
# the maximum length of an input to train with, this should typically be less than 2048
|
||||
# as most models have a token/context limit of 2048
|
||||
sequence_len: 2048
|
||||
# Pad inputs so each step uses constant sized buffers
|
||||
# This will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
|
||||
# pad inputs so each step uses constant sized buffers
|
||||
# this will reduce memory fragmentation and may prevent OOMs, by re-using memory more efficiently
|
||||
pad_to_sequence_len:
|
||||
# Max sequence length to concatenate training samples together up to
|
||||
# Inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
||||
# max sequence length to concatenate training samples together up to
|
||||
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
|
||||
# FutureWarning: This will soon be DEPRECATED
|
||||
max_packed_sequence_len: 1024
|
||||
# Use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
|
||||
# use efficient multi-packing with block diagonal attention and per sequence position_ids. Recommend set to 'true'
|
||||
sample_packing:
|
||||
# Set to 'false' if getting errors during eval with sample_packing on.
|
||||
# set to 'false' if getting errors during eval with sample_packing on.
|
||||
eval_sample_packing:
|
||||
# You can set these packing optimizations AFTER starting a training at least once.
|
||||
# you can set these packing optimizations AFTER starting a training at least once.
|
||||
# The trainer will provide recommended values for these values.
|
||||
sample_packing_eff_est:
|
||||
total_num_tokens:
|
||||
|
||||
# Passed through to transformers when loading the model when launched without accelerate
|
||||
# Use `sequential` when training w/ model parallelism to limit memory
|
||||
device_map:
|
||||
# Defines the max memory usage per gpu on the system. Passed through to transformers when loading the model.
|
||||
max_memory:
|
||||
|
||||
# If you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||
# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
|
||||
adapter: lora
|
||||
# If you already have a lora model trained that you want to load, put that here.
|
||||
# This means after training, if you want to test the model, you should set this to the value of `lora_out_dir`.
|
||||
# if you already have a lora model trained that you want to load, put that here
|
||||
# lora hyperparameters
|
||||
lora_model_dir:
|
||||
|
||||
# LoRA hyperparameters
|
||||
# For more details about the following options, see:
|
||||
# https://www.anyscale.com/blog/fine-tuning-llms-lora-or-full-parameter-an-in-depth-analysis-with-llama-2
|
||||
lora_r: 8
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
@@ -638,101 +540,81 @@ lora_target_modules:
|
||||
# - gate_proj
|
||||
# - down_proj
|
||||
# - up_proj
|
||||
lora_target_linear: # If true, will target all linear layers
|
||||
|
||||
# If you added new tokens to the tokenizer, you may need to save some LoRA modules because they need to know the new tokens.
|
||||
# For LLaMA and Mistral, you need to save `embed_tokens` and `lm_head`. It may vary for other models.
|
||||
# `embed_tokens` converts tokens to embeddings, and `lm_head` converts embeddings to token probabilities.
|
||||
# https://github.com/huggingface/peft/issues/334#issuecomment-1561727994
|
||||
lora_target_linear: # if true, will target all linear layers
|
||||
lora_modules_to_save:
|
||||
# - embed_tokens
|
||||
# - lm_head
|
||||
|
||||
# Once you complete training, the model will be saved to the following directory.
|
||||
# If you merge the adapter to the base model, a subdirectory `merged` will be created under this directory.
|
||||
# Make sure `lora_model_dir` points to this directory if you want to use the trained model.
|
||||
lora_out_dir:
|
||||
lora_fan_in_fan_out: false
|
||||
|
||||
# ReLoRA configuration
|
||||
# Must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
|
||||
relora_steps: # Number of steps per ReLoRA restart
|
||||
relora_warmup_steps: # Number of per-restart warmup steps
|
||||
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
||||
# must use either 'lora' or 'qlora' adapter, and does not support fsdp or deepspeed
|
||||
relora_steps: # number of steps per ReLoRA restart
|
||||
relora_warmup_steps: # number of per-restart warmup steps
|
||||
relora_cpu_offload: # true to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
||||
|
||||
# wandb configuration if you're using it
|
||||
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
||||
wandb_project: # Your wandb project name
|
||||
wandb_entity: # A wandb Team name if using a Team
|
||||
wandb_project: # your wandb project name
|
||||
wandb_entity: # a wandb Team name if using a Team
|
||||
wandb_watch:
|
||||
wandb_name: # Set the name of your wandb run
|
||||
wandb_run_id: # Set the ID of your wandb run
|
||||
wandb_run_id: # set the name of your wandb run
|
||||
wandb_log_model: # "checkpoint" to log model to wandb Artifacts every `save_steps` or "end" to log only at the end of training
|
||||
|
||||
# Where to save the full-finetuned model to
|
||||
# where to save the finished model to
|
||||
output_dir: ./completed-model
|
||||
|
||||
# Whether to use torch.compile and which backend to use
|
||||
# whether to use torch.compile and which backend to use
|
||||
torch_compile: # bool
|
||||
torch_compile_backend: # Optional[str]
|
||||
|
||||
# Training hyperparameters
|
||||
|
||||
# If greater than 1, backpropagation will be skipped and the gradients will be accumulated for the given number of steps.
|
||||
# training hyperparameters
|
||||
gradient_accumulation_steps: 1
|
||||
# The number of samples to include in each batch. This is the number of samples sent to each GPU.
|
||||
micro_batch_size: 2
|
||||
eval_batch_size:
|
||||
num_epochs: 4
|
||||
warmup_steps: 100 # cannot use with warmup_ratio
|
||||
warmup_ratio: 0.05 # cannot use with warmup_steps
|
||||
num_epochs: 3
|
||||
warmup_steps: 100
|
||||
learning_rate: 0.00003
|
||||
lr_quadratic_warmup:
|
||||
logging_steps:
|
||||
save_strategy: # Set to `no` to skip checkpoint saves
|
||||
save_steps: # Leave empty to save at each epoch
|
||||
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
|
||||
save_total_limit: # Checkpoints saved at a time
|
||||
# Maximum number of iterations to train for. It precedes num_epochs which means that
|
||||
# if both are set, num_epochs will not be guaranteed.
|
||||
# e.g., when 1 epoch is 1000 steps => `num_epochs: 2` and `max_steps: 100` will train for 100 steps
|
||||
save_strategy: # set to `no` to skip checkpoint saves
|
||||
save_steps: # leave empty to save at each epoch
|
||||
eval_steps: # leave empty to eval at each epoch
|
||||
save_total_limit: # checkpoints saved at a time
|
||||
max_steps:
|
||||
|
||||
eval_table_size: # Approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||
eval_table_max_new_tokens: # Total number of tokens generated for predictions sent to wandb. Default is 128
|
||||
eval_table_size: # approximate number of predictions sent to wandb depending on batch size. Enabled above 0. Default is 0
|
||||
eval_table_max_new_tokens: # total number of tokens generated for predictions sent to wandb. Default is 128
|
||||
|
||||
loss_watchdog_threshold: # High loss value, indicating the learning has broken down (a good estimate is ~2 times the loss at the start of training)
|
||||
loss_watchdog_patience: # Number of high-loss steps in a row before the trainer aborts (default: 3)
|
||||
|
||||
# Save model as safetensors (require safetensors package)
|
||||
# save model as safetensors (require safetensors package)
|
||||
save_safetensors:
|
||||
|
||||
# Whether to mask out or include the human's prompt from the training labels
|
||||
# whether to mask out or include the human's prompt from the training labels
|
||||
train_on_inputs: false
|
||||
# Group similarly sized data to minimize padding.
|
||||
# May be slower to start, as it must download and sort the entire dataset.
|
||||
# Note that training loss may have an oscillating pattern with this enabled.
|
||||
# group similarly sized data to minimize padding
|
||||
# may be slower to start, as it must download and sort the entire dataset
|
||||
# note that training loss may have an oscillating pattern with this enabled
|
||||
group_by_length: false
|
||||
|
||||
# Whether to use gradient checkpointing https://huggingface.co/docs/transformers/v4.18.0/en/performance#gradient-checkpointing
|
||||
gradient_checkpointing: false
|
||||
|
||||
# Stop training after this many evaluation losses have increased in a row
|
||||
# stop training after this many evaluation losses have increased in a row
|
||||
# https://huggingface.co/transformers/v4.2.2/_modules/transformers/trainer_callback.html#EarlyStoppingCallback
|
||||
early_stopping_patience: 3
|
||||
|
||||
# Specify a scheduler and kwargs to use with the optimizer
|
||||
# specify a scheduler and kwargs to use with the optimizer
|
||||
lr_scheduler: # 'one_cycle' | 'log_sweep' | empty for cosine
|
||||
lr_scheduler_kwargs:
|
||||
|
||||
# For one_cycle optim
|
||||
lr_div_factor: # Learning rate div factor
|
||||
# for one_cycle optim
|
||||
lr_div_factor: # learning rate div factor
|
||||
|
||||
# For log_sweep optim
|
||||
# for log_sweep optim
|
||||
log_sweep_min_lr:
|
||||
log_sweep_max_lr:
|
||||
|
||||
# Specify optimizer
|
||||
# specify optimizer
|
||||
# Valid values are driven by the Transformers OptimizerNames class, see:
|
||||
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
|
||||
#
|
||||
@@ -758,7 +640,7 @@ log_sweep_max_lr:
|
||||
# - paged_lion_32bit
|
||||
# - paged_lion_8bit
|
||||
optimizer:
|
||||
# Specify weight decay
|
||||
# specify weight decay
|
||||
weight_decay:
|
||||
# adamw hyperparams
|
||||
adam_beta1:
|
||||
@@ -767,54 +649,49 @@ adam_epsilon:
|
||||
# Gradient clipping max norm
|
||||
max_grad_norm:
|
||||
|
||||
# Augmentation techniques
|
||||
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
|
||||
# currently only supported on Llama and Mistral
|
||||
noisy_embedding_alpha:
|
||||
|
||||
# Whether to bettertransformers
|
||||
# whether to bettertransformers
|
||||
flash_optimum:
|
||||
# Whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||
# whether to use xformers attention patch https://github.com/facebookresearch/xformers:
|
||||
xformers_attention:
|
||||
# Whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
||||
# whether to use flash attention patch https://github.com/Dao-AILab/flash-attention:
|
||||
flash_attention:
|
||||
flash_attn_cross_entropy: # Whether to use flash-attention cross entropy implementation - advanced use only
|
||||
flash_attn_rms_norm: # Whether to use flash-attention rms norm implementation - advanced use only
|
||||
flash_attn_fuse_qkv: # Whether to fuse QKV into a single operation
|
||||
flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
|
||||
# Whether to use scaled-dot-product attention
|
||||
# whether to use scaled-dot-product attention
|
||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||
sdp_attention:
|
||||
# Landmark attention (only llama)
|
||||
landmark_attention:
|
||||
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
||||
# LLaMA only
|
||||
# llama only
|
||||
xpos_rope:
|
||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||
rope_scaling:
|
||||
type: # linear | dynamic
|
||||
factor: # float
|
||||
|
||||
# Resume from a specific checkpoint dir
|
||||
# resume from a specific checkpoint dir
|
||||
resume_from_checkpoint:
|
||||
# If resume_from_checkpoint isn't set and you simply want it to start where it left off.
|
||||
# Be careful with this being turned on between different models.
|
||||
# if resume_from_checkpoint isn't set and you simply want it to start where it left off
|
||||
# be careful with this being turned on between different models
|
||||
auto_resume_from_checkpoints: false
|
||||
|
||||
# Don't mess with this, it's here for accelerate and torchrun
|
||||
# don't mess with this, it's here for accelerate and torchrun
|
||||
local_rank:
|
||||
|
||||
# Add or change special tokens.
|
||||
# If you add tokens here, you don't need to add them to the `tokens` list.
|
||||
# add or change special tokens
|
||||
special_tokens:
|
||||
# bos_token: "<s>"
|
||||
# eos_token: "</s>"
|
||||
# unk_token: "<unk>"
|
||||
|
||||
# Add extra tokens.
|
||||
# add extra tokens
|
||||
tokens:
|
||||
|
||||
# FSDP
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
|
||||
# Deepspeed config path. e.g., deepspeed/zero3.json
|
||||
# Deepspeed config path
|
||||
deepspeed:
|
||||
|
||||
# Advanced DDP Arguments
|
||||
@@ -840,66 +717,6 @@ strict:
|
||||
|
||||
</details>
|
||||
|
||||
<details>
|
||||
<summary> Understanding of batch size and gradient accumulation steps </summary>
|
||||
<br/>
|
||||
Gradient accumulation means accumulating gradients over several mini-batches and updating the model weights afterward. When the samples in each batch are diverse, this technique doesn't significantly impact learning.
|
||||
|
||||
This method allows for effective training with larger effective batch sizes without needing proportionally larger memory. Here's why:
|
||||
|
||||
1. **Memory Consumption with Batch Size**: The primary reason increasing the batch size impacts memory is due to the storage requirements for intermediate activations. When you forward propagate a batch through a network, you have to store the activations at each layer for each sample in the batch, because these activations are used during backpropagation to compute gradients. Therefore, larger batches mean more activations, leading to greater GPU memory consumption.
|
||||
|
||||
2. **Gradient Accumulation**: With gradient accumulation, you're effectively simulating a larger batch size by accumulating gradients over several smaller batches (or micro-batches). However, at any given time, you're only forward and backward propagating a micro-batch. This means you only store activations for the micro-batch, not the full accumulated batch. As a result, you can simulate the effect of a larger batch size without the memory cost of storing activations for a large batch.
|
||||
|
||||
**Example 1:**
|
||||
Micro batch size: 3
|
||||
Gradient accumulation steps: 2
|
||||
Number of GPUs: 3
|
||||
Total batch size = 3 * 2 * 3 = 18
|
||||
|
||||
```
|
||||
| GPU 1 | GPU 2 | GPU 3 |
|
||||
|----------------|----------------|----------------|
|
||||
| S1, S2, S3 | S4, S5, S6 | S7, S8, S9 |
|
||||
| e1, e2, e3 | e4, e5, e6 | e7, e8, e9 |
|
||||
|----------------|----------------|----------------|
|
||||
| → (accumulate) | → (accumulate) | → (accumulate) |
|
||||
|----------------|----------------|----------------|
|
||||
| S10, S11, S12 | S13, S14, S15 | S16, S17, S18 |
|
||||
| e10, e11, e12 | e13, e14, e15 | e16, e17, e18 |
|
||||
|----------------|----------------|----------------|
|
||||
| → (apply) | → (apply) | → (apply) |
|
||||
|
||||
Accumulated gradient for the weight w1 after the second iteration (considering all GPUs):
|
||||
Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6 + e7 + e8 + e9 + e10 + e11 + e12 + e13 + e14 + e15 + e16 + e17 + e18
|
||||
|
||||
Weight update for w1:
|
||||
w1_new = w1_old - learning rate x (Total gradient for w1 / 18)
|
||||
```
|
||||
|
||||
**Example 2:**
|
||||
Micro batch size: 2
|
||||
Gradient accumulation steps: 1
|
||||
Number of GPUs: 3
|
||||
Total batch size = 2 * 1 * 3 = 6
|
||||
|
||||
```
|
||||
| GPU 1 | GPU 2 | GPU 3 |
|
||||
|-----------|-----------|-----------|
|
||||
| S1, S2 | S3, S4 | S5, S6 |
|
||||
| e1, e2 | e3, e4 | e5, e6 |
|
||||
|-----------|-----------|-----------|
|
||||
| → (apply) | → (apply) | → (apply) |
|
||||
|
||||
Accumulated gradient for the weight w1 (considering all GPUs):
|
||||
Total gradient for w1 = e1 + e2 + e3 + e4 + e5 + e6
|
||||
|
||||
Weight update for w1:
|
||||
w1_new = w1_old - learning rate × (Total gradient for w1 / 6)
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
### Train
|
||||
|
||||
Run
|
||||
@@ -907,41 +724,14 @@ Run
|
||||
accelerate launch -m axolotl.cli.train your_config.yml
|
||||
```
|
||||
|
||||
#### Preprocess dataset
|
||||
|
||||
You can optionally pre-tokenize dataset with the following before finetuning.
|
||||
This is recommended for large datasets.
|
||||
|
||||
- Set `push_dataset_to_hub: hf_user/repo` to push it to Huggingface.
|
||||
- Use `--debug` to see preprocessed examples.
|
||||
|
||||
```bash
|
||||
python -m axolotl.cli.preprocess your_config.yml
|
||||
```
|
||||
|
||||
#### Multi-GPU
|
||||
|
||||
Below are the options available in axolotl for training with multiple GPUs. Note that DeepSpeed
|
||||
is the recommended multi-GPU option currently because FSDP may experience
|
||||
[loss instability](https://github.com/huggingface/transformers/issues/26498).
|
||||
|
||||
##### DeepSpeed
|
||||
|
||||
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
|
||||
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
|
||||
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
|
||||
|
||||
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
|
||||
|
||||
```yaml
|
||||
deepspeed: deepspeed/zero1.json
|
||||
You can optionally pre-tokenize dataset with the following before finetuning:
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES="" accelerate launch -m axolotl.cli.train your_config.yml --prepare_ds_only
|
||||
```
|
||||
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
|
||||
```
|
||||
|
||||
##### FSDP
|
||||
##### Config
|
||||
|
||||
- llama FSDP
|
||||
```yaml
|
||||
@@ -962,10 +752,28 @@ wandb_mode:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
```
|
||||
|
||||
### Training with Deepspeed
|
||||
|
||||
Deepspeed is an optimization suite for multi-gpu systems allowing you to train much larger models than you
|
||||
might typically be able to fit into your GPU's VRAM. More information about the various optimization types
|
||||
for deepspeed is available at https://huggingface.co/docs/accelerate/main/en/usage_guides/deepspeed#what-is-integrated
|
||||
|
||||
We provide several default deepspeed JSON configurations for ZeRO stage 1, 2, and 3.
|
||||
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/config.py --deepspeed deepspeed/zero1.json
|
||||
```
|
||||
|
||||
or
|
||||
|
||||
```yaml
|
||||
deepspeed: deepspeed/zero1.json
|
||||
```
|
||||
|
||||
### Inference
|
||||
|
||||
Pass the appropriate flag to the train command:
|
||||
@@ -983,14 +791,6 @@ Pass the appropriate flag to the train command:
|
||||
cat /tmp/prompt.txt | python -m axolotl.cli.inference examples/your_config.yml \
|
||||
--base_model="./completed-model" --prompter=None --load_in_8bit=True
|
||||
```
|
||||
-- With gradio hosting
|
||||
```bash
|
||||
python -m axolotl.cli.inference examples/your_config.yml --gradio
|
||||
```
|
||||
|
||||
Please use `--sample_packing False` if you have it on and receive the error similar to below:
|
||||
|
||||
> RuntimeError: stack expects each tensor to be equal size, but got [1, 32, 1, 128] at entry 0 and [1, 32, 8, 128] at entry 1
|
||||
|
||||
### Merge LORA to base
|
||||
|
||||
@@ -1008,8 +808,6 @@ CUDA_VISIBLE_DEVICES="" python3 -m axolotl.cli.merge_lora ...
|
||||
|
||||
## Common Errors 🧰
|
||||
|
||||
See also the [FAQ's](./docs/faq.md).
|
||||
|
||||
> If you encounter a 'Cuda out of memory' error, it means your GPU ran out of memory during the training process. Here's how to resolve it:
|
||||
|
||||
Please reduce any below
|
||||
|
||||
@@ -24,6 +24,16 @@
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupDecayLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"warmup_type": "linear",
|
||||
"total_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
|
||||
@@ -28,6 +28,16 @@
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupDecayLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"warmup_type": "linear",
|
||||
"total_num_steps": "auto"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
|
||||
@@ -1,6 +1,14 @@
|
||||
{
|
||||
"zero_optimization": {
|
||||
"stage": 3,
|
||||
"offload_optimizer": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"offload_param": {
|
||||
"device": "cpu",
|
||||
"pin_memory": true
|
||||
},
|
||||
"overlap_comm": true,
|
||||
"contiguous_gradients": true,
|
||||
"sub_group_size": 0,
|
||||
@@ -32,6 +40,15 @@
|
||||
"weight_decay": "auto"
|
||||
}
|
||||
},
|
||||
"scheduler": {
|
||||
"type": "WarmupLR",
|
||||
"params": {
|
||||
"warmup_min_lr": "auto",
|
||||
"warmup_max_lr": "auto",
|
||||
"warmup_num_steps": "auto",
|
||||
"warmup_type": "linear"
|
||||
}
|
||||
},
|
||||
"gradient_accumulation_steps": "auto",
|
||||
"train_batch_size": "auto",
|
||||
"train_micro_batch_size_per_gpu": "auto",
|
||||
|
||||
@@ -5,9 +5,6 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX"
|
||||
ARG AXOLOTL_EXTRAS=""
|
||||
ARG CUDA="118"
|
||||
ENV BNB_CUDA_VERSION=$CUDA
|
||||
ARG PYTORCH_VERSION="2.0.1"
|
||||
|
||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||
|
||||
RUN apt-get update && \
|
||||
apt-get install -y vim curl
|
||||
@@ -19,11 +16,10 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
||||
WORKDIR /workspace/axolotl
|
||||
|
||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||
RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
|
||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||
pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
|
||||
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
|
||||
else \
|
||||
pip install -e .[deepspeed,flash-attn]; \
|
||||
pip install -e .[flash-attn]; \
|
||||
fi
|
||||
|
||||
# fix so that git fetch/pull from remote works
|
||||
|
||||
@@ -10,13 +10,11 @@ ENV PATH="/root/miniconda3/bin:${PATH}"
|
||||
ARG PYTHON_VERSION="3.9"
|
||||
ARG PYTORCH_VERSION="2.0.1"
|
||||
ARG CUDA="118"
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
ENV PYTHON_VERSION=$PYTHON_VERSION
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/* \
|
||||
&& apt-get install -y wget git build-essential ninja-build git-lfs libaio-dev && rm -rf /var/lib/apt/lists/*
|
||||
&& wget \
|
||||
https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh \
|
||||
&& mkdir /root/.conda \
|
||||
@@ -29,9 +27,52 @@ ENV PATH="/root/miniconda3/envs/py${PYTHON_VERSION}/bin:${PATH}"
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN python3 -m pip install --upgrade pip && pip3 install packaging && \
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} deepspeed-kernels --extra-index-url https://download.pytorch.org/whl/cu$CUDA
|
||||
python3 -m pip install --no-cache-dir -U torch==${PYTORCH_VERSION}+cu${CUDA} --extra-index-url https://download.pytorch.org/whl/cu$CUDA
|
||||
|
||||
RUN git lfs install --skip-repo && \
|
||||
pip3 install awscli && \
|
||||
FROM base-builder AS deepspeed-builder
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN git clone https://github.com/microsoft/DeepSpeed.git && \
|
||||
cd DeepSpeed && \
|
||||
MAX_CONCURRENCY=8 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_OPS=1 DS_BUILD_EVOFORMER_ATTN=0 python3 setup.py bdist_wheel
|
||||
|
||||
FROM base-builder AS bnb-builder
|
||||
|
||||
WORKDIR /workspace
|
||||
ARG CUDA="118"
|
||||
ENV CUDA=$CUDA
|
||||
ARG MAX_JOBS="-1"
|
||||
ENV MAX_JOBS=$MAX_JOBS
|
||||
|
||||
RUN git clone https://github.com/TimDettmers/bitsandbytes.git && \
|
||||
cd bitsandbytes && \
|
||||
CUDA_VERSION=$CUDA make cuda11x && \
|
||||
python setup.py bdist_wheel
|
||||
|
||||
FROM base-builder
|
||||
|
||||
ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6 9.0+PTX"
|
||||
ENV TORCH_CUDA_ARCH_LIST=$TORCH_CUDA_ARCH_LIST
|
||||
|
||||
# recompile apex
|
||||
RUN python3 -m pip uninstall -y apex
|
||||
RUN git clone https://github.com/NVIDIA/apex
|
||||
RUN cd apex && python3 -m pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
|
||||
|
||||
RUN mkdir -p /workspace/builds
|
||||
COPY --from=bnb-builder /workspace/bitsandbytes /workspace/builds/bitsandbytes
|
||||
|
||||
RUN mkdir -p /workspace/wheels/bitsandbytes
|
||||
COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels
|
||||
COPY --from=bnb-builder /workspace/bitsandbytes/dist/bitsandbytes-*.whl wheels
|
||||
COPY --from=bnb-builder /workspace/bitsandbytes/bitsandbytes/libbitsandbytes*.so wheels/bitsandbytes
|
||||
|
||||
RUN pip3 install wheels/deepspeed-*.whl
|
||||
RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
|
||||
RUN git lfs install --skip-repo
|
||||
RUN pip3 install awscli && \
|
||||
# The base image ships with `pydantic==1.8.2` which is not working
|
||||
pip3 install -U --no-cache-dir pydantic==1.10.10
|
||||
|
||||
18
docs/faq.md
18
docs/faq.md
@@ -1,18 +0,0 @@
|
||||
# Axolotl FAQ's
|
||||
|
||||
|
||||
> The trainer stopped and hasn't progressed in several minutes.
|
||||
|
||||
Usually an issue with the GPU's communicating with each other. See the [NCCL doc](../docs/nccl.md)
|
||||
|
||||
> Exitcode -9
|
||||
|
||||
This usually happens when you run out of system RAM.
|
||||
|
||||
> Exitcode -7 while using deepspeed
|
||||
|
||||
Try upgrading deepspeed w: `pip install -U deepspeed`
|
||||
|
||||
> AttributeError: 'DummyOptim' object has no attribute 'step'
|
||||
|
||||
You may be using deepspeed with single gpu. Please don't set `deepspeed:` in yaml or cli.
|
||||
@@ -1,51 +0,0 @@
|
||||
# Multipack
|
||||
|
||||
4k context, bsz =4,
|
||||
each character represents 256 tokens
|
||||
X represents a padding token
|
||||
|
||||
```
|
||||
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
||||
[[ A A A A A A A A A A A ]
|
||||
B B B B B B ]
|
||||
C C C C C C C ]
|
||||
D D D D ]]
|
||||
|
||||
[[ E E E E E E E E ]
|
||||
[ F F F F ]
|
||||
[ G G G ]
|
||||
[ H H H H ]]
|
||||
|
||||
[[ I I I ]
|
||||
[ J J J ]
|
||||
[ K K K K K]
|
||||
[ L L L ]]
|
||||
```
|
||||
|
||||
after padding to longest input in each step
|
||||
```
|
||||
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
||||
[[ A A A A A A A A A A A ]
|
||||
B B B B B B X X X X X X ]
|
||||
C C C C C C C X X X X ]
|
||||
D D D D X X X X X X X ]]
|
||||
|
||||
[[ E E E E E E E E ]
|
||||
[ F F F F X X X X ]
|
||||
[ G G G X X X X X ]
|
||||
[ H H H H X X X X ]]
|
||||
|
||||
[[ I I I X X ]
|
||||
[ J J J X X ]
|
||||
[ K K K K K ]
|
||||
[ L L L X X ]]
|
||||
```
|
||||
|
||||
w packing ( note it's the same effective number of tokens per step, but a true bsz of 1)
|
||||
```
|
||||
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5
|
||||
[[ A A A A A A A A A A A B B B B B
|
||||
B C C C C C C C D D D D E E E E
|
||||
E E E E F F F F F G G G H H H H
|
||||
I I I J J J J K K K K K L L L X ]]
|
||||
```
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: cerebras/btlm-3b-8k-base
|
||||
base_model_config: cerebras/btlm-3b-8k-base
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: GPT2Tokenizer
|
||||
trust_remote_code: true
|
||||
@@ -14,7 +15,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_prepared_run
|
||||
val_set_size: 0.05
|
||||
val_set_size: 0.01
|
||||
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
@@ -35,7 +36,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
output_dir: btlm-out
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: cerebras/Cerebras-GPT-1.3B
|
||||
base_model_config: cerebras/Cerebras-GPT-1.3B
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
@@ -7,7 +8,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
val_set_size: 0.01
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
@@ -24,7 +25,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
batch_size: 4
|
||||
@@ -49,7 +50,7 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: codellama/CodeLlama-13b-hf
|
||||
base_model_config: codellama/CodeLlama-13b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -11,7 +12,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
val_set_size: 0.01
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
@@ -29,12 +30,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -54,7 +55,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: codellama/CodeLlama-13b-hf
|
||||
base_model_config: codellama/CodeLlama-13b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -11,7 +12,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
val_set_size: 0.01
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -31,12 +32,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -56,7 +57,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: codellama/CodeLlama-34b-hf
|
||||
base_model_config: codellama/CodeLlama-34b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -11,7 +12,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
val_set_size: 0.01
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
@@ -29,12 +30,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -54,7 +55,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: codellama/CodeLlama-34b-hf
|
||||
base_model_config: codellama/CodeLlama-34b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -11,7 +12,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
val_set_size: 0.01
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -31,12 +32,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -56,7 +57,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: codellama/CodeLlama-7b-hf
|
||||
base_model_config: codellama/CodeLlama-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -11,7 +12,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
val_set_size: 0.01
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
@@ -29,12 +30,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -54,7 +55,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: codellama/CodeLlama-7b-hf
|
||||
base_model_config: codellama/CodeLlama-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: CodeLlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -11,7 +12,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
val_set_size: 0.01
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -31,12 +32,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -56,7 +57,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: tiiuae/falcon-7b
|
||||
base_model_config: tiiuae/falcon-7b
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
@@ -12,7 +13,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca:chat
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
val_set_size: 0.01
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
@@ -26,7 +27,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./falcon-7b
|
||||
batch_size: 2
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# 1b: tiiuae/falcon-rw-1b
|
||||
# 40b: tiiuae/falcon-40b
|
||||
base_model: tiiuae/falcon-7b
|
||||
base_model_config: tiiuae/falcon-7b
|
||||
# required by falcon custom model code: https://huggingface.co/tiiuae/falcon-7b/tree/main
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
@@ -18,7 +19,7 @@ datasets:
|
||||
- Chain-of-Thought/formatted_cot_data/gsm8k_train.json
|
||||
type: "alpaca:chat"
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
val_set_size: 0.01
|
||||
# enable QLoRA
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
@@ -40,7 +41,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
|
||||
@@ -53,7 +54,7 @@ output_dir: ./qlora-out
|
||||
# decrease if OOM, increase for max VRAM utilization
|
||||
micro_batch_size: 1
|
||||
gradient_accumulation_steps: 2
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
# Optimizer for QLoRA
|
||||
optimizer: paged_adamw_32bit
|
||||
torchdistx_path:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: tiiuae/falcon-7b
|
||||
base_model_config: tiiuae/falcon-7b
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
@@ -12,7 +13,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca:chat
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
val_set_size: 0.01
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
@@ -26,7 +27,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./falcon-7b
|
||||
batch_size: 2
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: EleutherAI/gpt-j-6b
|
||||
base_model_config: EleutherAI/gpt-j-6b
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
@@ -7,7 +8,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
val_set_size: 0.01
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 2048
|
||||
@@ -21,7 +22,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
gradient_accumulation_steps: 2
|
||||
@@ -46,7 +47,7 @@ flash_attention:
|
||||
gptq_groupsize:
|
||||
gptq_model_v1:
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_steps: 20
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: huggyllama/llama-7b
|
||||
base_model_config: huggyllama/llama-7b
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: false
|
||||
@@ -19,12 +20,12 @@ lora_fan_in_fan_out: false
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./jeopardy-bot-7b
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
|
||||
@@ -9,16 +9,12 @@ gradient_accumulation_steps: 2
|
||||
micro_batch_size: 1
|
||||
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/qlora.yml
|
||||
accelerate launch scripts/finetune.py examples/llama-2/qlora.yml
|
||||
|
||||
```
|
||||
or
|
||||
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/lora.yml
|
||||
```
|
||||
accelerate launch scripts/finetune.py examples/llama-2/lora.yml
|
||||
|
||||
To launch a full finetuning with 16-bit precision:
|
||||
|
||||
```shell
|
||||
accelerate launch -m axolotl.cli.train examples/llama-2/fft_optimized.yml
|
||||
```
|
||||
|
||||
@@ -1,72 +0,0 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
output_dir: ./out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
|
||||
adapter:
|
||||
lora_model_dir:
|
||||
lora_r:
|
||||
lora_alpha:
|
||||
lora_dropout:
|
||||
lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
flash_attn_cross_entropy: false
|
||||
flash_attn_rms_norm: true
|
||||
flash_attn_fuse_qkv: false
|
||||
flash_attn_fuse_mlp: true
|
||||
|
||||
warmup_steps: 100
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed: #deepspeed/zero2.json # multi-gpu only
|
||||
weight_decay: 0.1
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
bos_token: "<s>"
|
||||
eos_token: "</s>"
|
||||
unk_token: "<unk>"
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: TheBloke/Llama-2-7B-GPTQ
|
||||
base_model_config: TheBloke/Llama-2-7B-GPTQ
|
||||
is_llama_derived_model: false
|
||||
gptq: true
|
||||
gptq_disable_exllama: true
|
||||
@@ -15,7 +16,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
val_set_size: 0.01
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
sequence_len: 4096
|
||||
@@ -32,12 +33,12 @@ lora_target_linear:
|
||||
lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./model-out
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_torch
|
||||
adam_beta2: 0.95
|
||||
adam_eps: 0.00001
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
base_model_config: NousResearch/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -11,7 +12,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
val_set_size: 0.01
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
@@ -29,12 +30,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -54,7 +55,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_steps: 20
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
base_model_config: NousResearch/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -11,7 +12,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
val_set_size: 0.01
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -31,12 +32,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: paged_adamw_32bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -56,7 +57,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_steps: 20
|
||||
eval_table_size:
|
||||
save_steps:
|
||||
debug:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: NousResearch/Llama-2-7b-hf
|
||||
base_model_config: NousResearch/Llama-2-7b-hf
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
@@ -11,7 +12,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
val_set_size: 0.01
|
||||
output_dir: ./relora-out
|
||||
|
||||
adapter: qlora
|
||||
@@ -35,12 +36,12 @@ relora_cpu_offload: false
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 4
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -60,7 +61,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_steps: 20
|
||||
save_steps: 50
|
||||
debug:
|
||||
deepspeed:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: PY007/TinyLlama-1.1B-intermediate-step-715k-1.5T
|
||||
base_model: PY007/TinyLlama-1.1B-step-50K-105b
|
||||
base_model_config: PY007/TinyLlama-1.1B-step-50K-105b
|
||||
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
@@ -12,7 +13,7 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
val_set_size: 0.01
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
@@ -29,12 +30,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
@@ -54,7 +55,7 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_steps: 20
|
||||
eval_table_size:
|
||||
save_steps:
|
||||
debug:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: mistralai/Mistral-7B-v0.1
|
||||
base_model_config: mistralai/Mistral-7B-v0.1
|
||||
model_type: MistralForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_mistral_derived_model: true
|
||||
@@ -11,25 +12,25 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
val_set_size: 0.01
|
||||
output_dir: ./out
|
||||
|
||||
sequence_len: 8192
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
sample_packing:
|
||||
pad_to_sequence_len:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.000005
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
@@ -46,8 +47,8 @@ xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
eval_steps: 20
|
||||
eval_table_size: 5
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
debug:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: mistralai/Mistral-7B-v0.1
|
||||
base_model_config: mistralai/Mistral-7B-v0.1
|
||||
model_type: MistralForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_mistral_derived_model: true
|
||||
@@ -11,15 +12,15 @@ datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path: last_run_prepared
|
||||
val_set_size: 0.05
|
||||
val_set_size: 0.01
|
||||
output_dir: ./qlora-out
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
|
||||
sequence_len: 8192
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
sample_packing: True
|
||||
pad_to_sequence_len: True
|
||||
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
@@ -38,11 +39,11 @@ lora_target_modules:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
micro_batch_size: 4
|
||||
num_epochs: 1
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
@@ -62,12 +63,9 @@ logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: true
|
||||
|
||||
loss_watchdog_threshold: 5.0
|
||||
loss_watchdog_patience: 3
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
eval_steps: 20
|
||||
eval_table_size: 5
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
debug:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: mosaicml/mpt-7b
|
||||
base_model_config: mosaicml/mpt-7b
|
||||
tokenizer_type: AutoTokenizer
|
||||
trust_remote_code: true # required for mpt as their model class is not merged into transformers yet
|
||||
load_in_8bit: false
|
||||
@@ -21,12 +22,12 @@ lora_fan_in_fan_out: false
|
||||
wandb_project: mpt-alpaca-7b
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./mpt-alpaca-7b
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: openlm-research/open_llama_3b_v2
|
||||
base_model_config: openlm-research/open_llama_3b_v2
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: false
|
||||
@@ -23,7 +24,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./openllama-out
|
||||
gradient_accumulation_steps: 1
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: openlm-research/open_llama_3b_v2
|
||||
base_model_config: openlm-research/open_llama_3b_v2
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: true
|
||||
@@ -29,7 +30,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-out
|
||||
gradient_accumulation_steps: 1
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: openlm-research/open_llama_3b_v2
|
||||
base_model_config: openlm-research/open_llama_3b_v2
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
load_in_8bit: false
|
||||
@@ -9,7 +10,7 @@ datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
val_set_size: 0.01
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
sequence_len: 1024
|
||||
@@ -23,7 +24,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
gradient_accumulation_steps: 1
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
base_model: microsoft/phi-1_5
|
||||
model_type: PhiForCausalLM
|
||||
base_model_config: microsoft/phi-1_5
|
||||
model_type: MixFormerSequentialForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
is_llama_derived_model: false
|
||||
trust_remote_code: true
|
||||
@@ -31,7 +32,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: microsoft/phi-1_5
|
||||
base_model_config: microsoft/phi-1_5
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
is_llama_derived_model: false
|
||||
@@ -31,7 +32,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 1
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: EleutherAI/pythia-12b-deduped
|
||||
base_model_config: EleutherAI/pythia-12b-deduped
|
||||
base_model_ignore_patterns: pytorch* # prefer safetensors
|
||||
model_type: GPTNeoXForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
@@ -24,7 +25,7 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./pythia-12b
|
||||
gradient_accumulation_steps: 1
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: EleutherAI/pythia-1.4b-deduped
|
||||
base_model_config: EleutherAI/pythia-1.4b-deduped
|
||||
load_in_8bit: true
|
||||
datasets:
|
||||
- path: teknium/GPT4-LLM-Cleaned
|
||||
@@ -18,20 +19,20 @@ lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-alpaca-pythia
|
||||
gradient_accumulation_steps: 1
|
||||
micro_batch_size: 4
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
learning_rate: 0.00001
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
tf32: true
|
||||
bf16: True
|
||||
tf32: True
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
weight_decay: 0.1
|
||||
eval_steps: 0.05
|
||||
eval_steps: 20
|
||||
logging_steps: 1
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
base_model: Qwen/Qwen-7B
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
is_qwen_derived_model: true
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 2048 # supports up to 8192
|
||||
sample_packing: false
|
||||
pad_to_sequence_len:
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -1,68 +0,0 @@
|
||||
base_model: Qwen/Qwen-7B
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
|
||||
is_qwen_derived_model: true
|
||||
trust_remote_code: true
|
||||
|
||||
load_in_8bit: false
|
||||
load_in_4bit: true
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 2048 # supports up to 8192
|
||||
sample_packing: false
|
||||
pad_to_sequence_len:
|
||||
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_bnb_8bit
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: true
|
||||
fp16: false
|
||||
tf32: false
|
||||
|
||||
gradient_checkpointing: false
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention:
|
||||
|
||||
warmup_steps: 10
|
||||
eval_steps: 0.05
|
||||
eval_table_size:
|
||||
eval_table_max_new_tokens: 128
|
||||
save_steps:
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: togethercomputer/RedPajama-INCITE-Chat-3B-v1
|
||||
base_model_config: togethercomputer/RedPajama-INCITE-Chat-3B-v1
|
||||
model_type: GPTNeoXForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
trust_remote_code:
|
||||
@@ -22,12 +23,12 @@ lora_fan_in_fan_out: false
|
||||
wandb_project: redpajama-alpaca-3b
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./redpajama-alpaca-3b
|
||||
batch_size: 4
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer: adamw_bnb_8bit
|
||||
torchdistx_path:
|
||||
lr_scheduler: cosine
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
base_model: replit/replit-code-v1-3b
|
||||
base_model_config: replit/replit-code-v1-3b
|
||||
trust_remote_code: true
|
||||
load_in_8bit: false
|
||||
datasets:
|
||||
@@ -21,12 +22,12 @@ lora_fan_in_fan_out:
|
||||
wandb_project: lora-replit
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./lora-replit
|
||||
batch_size: 8
|
||||
micro_batch_size: 1
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
optimizer:
|
||||
torchdistx_path:
|
||||
lr_scheduler:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# An example finetuning Saleforce's XGen-7b model with 8k context using qlora
|
||||
# on Tim Dettmer's Guanaco dataset.
|
||||
base_model: Salesforce/xgen-7b-8k-base
|
||||
base_model_config: Salesforce/xgen-7b-8k-base
|
||||
trust_remote_code: true
|
||||
model_type: AutoModelForCausalLM
|
||||
tokenizer_type: AutoTokenizer
|
||||
@@ -16,7 +17,7 @@ datasets:
|
||||
- openassistant_best_replies_train.jsonl
|
||||
type: "completion"
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.05
|
||||
val_set_size: 0.01
|
||||
# enable QLoRA
|
||||
adapter: qlora
|
||||
lora_model_dir:
|
||||
@@ -38,7 +39,7 @@ lora_fan_in_fan_out:
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_run_id:
|
||||
wandb_log_model:
|
||||
output_dir: ./qlora-out
|
||||
|
||||
@@ -51,7 +52,7 @@ output_dir: ./qlora-out
|
||||
# decrease if OOM, increase for max VRAM utilization
|
||||
micro_batch_size: 1
|
||||
gradient_accumulation_steps: 1
|
||||
num_epochs: 4
|
||||
num_epochs: 3
|
||||
# Optimizer for QLoRA
|
||||
optimizer: paged_adamw_32bit
|
||||
torchdistx_path:
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 370 KiB |
@@ -1,22 +1,23 @@
|
||||
--extra-index-url https://download.pytorch.org/whl/cu118
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
auto-gptq==0.5.1
|
||||
torch==2.0.1
|
||||
auto-gptq
|
||||
packaging
|
||||
peft==0.6.0
|
||||
transformers==4.35.2
|
||||
tokenizers==0.15.0
|
||||
peft @ git+https://github.com/huggingface/peft.git
|
||||
transformers @ git+https://github.com/huggingface/transformers.git@bd6205919aad4d3a2300a39a98a642f1cc3a5348
|
||||
bitsandbytes>=0.41.1
|
||||
accelerate==0.24.1
|
||||
accelerate @ git+https://github.com/huggingface/accelerate@80da9cfb09bb3cc9f1b385cb55d6b90d025a5fd9
|
||||
deepspeed
|
||||
addict
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
datasets>=2.15.0
|
||||
flash-attn==2.3.3
|
||||
datasets
|
||||
flash-attn>=2.3.0
|
||||
sentencepiece
|
||||
wandb
|
||||
einops
|
||||
xformers==0.0.22
|
||||
optimum==1.13.2
|
||||
xformers
|
||||
optimum
|
||||
hf_transfer
|
||||
colorama
|
||||
numba
|
||||
@@ -30,10 +31,3 @@ scikit-learn==1.2.2
|
||||
pynvml
|
||||
art
|
||||
fschat==0.2.29
|
||||
gradio==3.50.2
|
||||
tensorboard
|
||||
|
||||
# remote filesystems
|
||||
s3fs
|
||||
gcsfs
|
||||
# adlfs
|
||||
|
||||
@@ -45,6 +45,8 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
shard(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
if parsed_cli_args.prepare_ds_only:
|
||||
return
|
||||
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
||||
|
||||
|
||||
|
||||
10
setup.py
10
setup.py
@@ -21,14 +21,6 @@ def parse_requirements():
|
||||
):
|
||||
# Handle standard packages
|
||||
_install_requires.append(line)
|
||||
|
||||
# TODO(wing) remove once xformers release supports torch 2.1.0
|
||||
if "torch==2.1.0" in _install_requires:
|
||||
_install_requires.pop(_install_requires.index("xformers>=0.0.22"))
|
||||
_install_requires.append(
|
||||
"xformers @ git+https://github.com/facebookresearch/xformers.git@main"
|
||||
)
|
||||
|
||||
return _install_requires, _dependency_links
|
||||
|
||||
|
||||
@@ -46,7 +38,7 @@ setup(
|
||||
dependency_links=dependency_links,
|
||||
extras_require={
|
||||
"flash-attn": [
|
||||
"flash-attn>=2.3.0",
|
||||
"flash-attn>=2.2.1",
|
||||
],
|
||||
"deepspeed": [
|
||||
"deepspeed",
|
||||
|
||||
@@ -6,10 +6,8 @@ import os
|
||||
import random
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import gradio as gr
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
@@ -18,7 +16,7 @@ from accelerate.commands.config import config_args
|
||||
from art import text2art
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||
from transformers import GenerationConfig, TextStreamer
|
||||
|
||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.logging_config import configure_logging
|
||||
@@ -29,7 +27,6 @@ from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
from axolotl.utils.models import load_tokenizer
|
||||
from axolotl.utils.tokenization import check_dataset_labels
|
||||
from axolotl.utils.trainer import prepare_optim_env
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
@@ -47,7 +44,7 @@ def print_axolotl_text_art(suffix=None):
|
||||
ascii_text = " axolotl"
|
||||
if suffix:
|
||||
ascii_text += f" x {suffix}"
|
||||
ascii_art = text2art(ascii_text, font=font)
|
||||
ascii_art = text2art(" axolotl", font=font)
|
||||
|
||||
if is_main_process():
|
||||
print(ascii_art)
|
||||
@@ -72,7 +69,7 @@ def do_merge_lora(
|
||||
|
||||
LOG.info("running merge of LoRA with base model")
|
||||
model = model.merge_and_unload()
|
||||
model.to(dtype=cfg.torch_dtype)
|
||||
model.to(dtype=torch.float16)
|
||||
|
||||
if cfg.local_rank == 0:
|
||||
LOG.info(f"saving merged model to: {str(Path(cfg.output_dir) / 'merged')}")
|
||||
@@ -156,91 +153,6 @@ def do_inference(
|
||||
print(tokenizer.decode(generated["sequences"].cpu().tolist()[0]))
|
||||
|
||||
|
||||
def do_inference_gradio(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
):
|
||||
model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args)
|
||||
prompter = cli_args.prompter
|
||||
default_tokens = {"unk_token": "<unk>", "bos_token": "<s>", "eos_token": "</s>"}
|
||||
|
||||
for token, symbol in default_tokens.items():
|
||||
# If the token isn't already specified in the config, add it
|
||||
if not (cfg.special_tokens and token in cfg.special_tokens):
|
||||
tokenizer.add_special_tokens({token: symbol})
|
||||
|
||||
prompter_module = None
|
||||
if prompter:
|
||||
prompter_module = getattr(
|
||||
importlib.import_module("axolotl.prompters"), prompter
|
||||
)
|
||||
|
||||
if cfg.landmark_attention:
|
||||
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
||||
|
||||
set_model_mem_id(model, tokenizer)
|
||||
model.set_mem_cache_args(
|
||||
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
||||
)
|
||||
|
||||
model = model.to(cfg.device)
|
||||
|
||||
def generate(instruction):
|
||||
if not instruction:
|
||||
return
|
||||
if prompter_module:
|
||||
# pylint: disable=stop-iteration-return
|
||||
prompt: str = next(
|
||||
prompter_module().build_prompt(instruction=instruction.strip("\n"))
|
||||
)
|
||||
else:
|
||||
prompt = instruction.strip()
|
||||
batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True)
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
generation_config = GenerationConfig(
|
||||
repetition_penalty=1.1,
|
||||
max_new_tokens=1024,
|
||||
temperature=0.9,
|
||||
top_p=0.95,
|
||||
top_k=40,
|
||||
bos_token_id=tokenizer.bos_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
do_sample=True,
|
||||
use_cache=True,
|
||||
return_dict_in_generate=True,
|
||||
output_attentions=False,
|
||||
output_hidden_states=False,
|
||||
output_scores=False,
|
||||
)
|
||||
streamer = TextIteratorStreamer(tokenizer)
|
||||
generation_kwargs = {
|
||||
"inputs": batch["input_ids"].to(cfg.device),
|
||||
"generation_config": generation_config,
|
||||
"streamer": streamer,
|
||||
}
|
||||
|
||||
thread = Thread(target=model.generate, kwargs=generation_kwargs)
|
||||
thread.start()
|
||||
|
||||
all_text = ""
|
||||
|
||||
for new_text in streamer:
|
||||
all_text += new_text
|
||||
yield all_text
|
||||
|
||||
demo = gr.Interface(
|
||||
fn=generate,
|
||||
inputs="textbox",
|
||||
outputs="text",
|
||||
title=cfg.get("gradio_title", "Axolotl Gradio Interface"),
|
||||
)
|
||||
demo.queue().launch(show_api=False, share=True)
|
||||
|
||||
|
||||
def choose_config(path: Path):
|
||||
yaml_files = list(path.glob("*.yml"))
|
||||
|
||||
@@ -282,7 +194,6 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
||||
# load the config from the yaml file
|
||||
with open(config, encoding="utf-8") as file:
|
||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||
cfg.axolotl_config_path = config
|
||||
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
||||
# then overwrite the value
|
||||
cfg_keys = cfg.keys()
|
||||
@@ -297,8 +208,6 @@ def load_cfg(config: Path = Path("examples/"), **kwargs):
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
prepare_optim_env(cfg)
|
||||
|
||||
normalize_config(cfg)
|
||||
|
||||
setup_wandb_env_vars(cfg)
|
||||
@@ -312,9 +221,7 @@ def load_datasets(
|
||||
) -> TrainDatasetMeta:
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
|
||||
cfg, tokenizer
|
||||
)
|
||||
train_dataset, eval_dataset, total_num_steps = prepare_dataset(cfg, tokenizer)
|
||||
|
||||
if cli_args.debug or cfg.debug:
|
||||
LOG.info("check_dataset_labels...")
|
||||
@@ -330,10 +237,6 @@ def load_datasets(
|
||||
text_only=cli_args.debug_text_only,
|
||||
)
|
||||
|
||||
LOG.info("printing prompters...")
|
||||
for prompter in prompters:
|
||||
LOG.info(prompter)
|
||||
|
||||
return TrainDatasetMeta(
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
|
||||
@@ -6,30 +6,21 @@ from pathlib import Path
|
||||
import fire
|
||||
import transformers
|
||||
|
||||
from axolotl.cli import (
|
||||
do_inference,
|
||||
do_inference_gradio,
|
||||
load_cfg,
|
||||
print_axolotl_text_art,
|
||||
)
|
||||
from axolotl.cli import do_inference, load_cfg, print_axolotl_text_art
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), gradio=False, **kwargs):
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parsed_cfg.sample_packing = False
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
parsed_cli_args.inference = True
|
||||
|
||||
if gradio:
|
||||
do_inference_gradio(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
else:
|
||||
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
do_inference(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,53 +0,0 @@
|
||||
"""
|
||||
CLI to run training on a model
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
from colorama import Fore
|
||||
|
||||
from axolotl.cli import (
|
||||
check_accelerate_default_config,
|
||||
check_user_token,
|
||||
load_cfg,
|
||||
load_datasets,
|
||||
print_axolotl_text_art,
|
||||
)
|
||||
from axolotl.common.cli import PreprocessCliArgs
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.preprocess")
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
check_accelerate_default_config()
|
||||
check_user_token()
|
||||
parser = transformers.HfArgumentParser((PreprocessCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
if not parsed_cfg.dataset_prepared_path:
|
||||
msg = (
|
||||
Fore.RED
|
||||
+ "preprocess CLI called without dataset_prepared_path set, "
|
||||
+ f"using default path: {DEFAULT_DATASET_PREPARED_PATH}"
|
||||
+ Fore.RESET
|
||||
)
|
||||
LOG.warning(msg)
|
||||
parsed_cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
_ = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
LOG.info(
|
||||
Fore.GREEN
|
||||
+ f"Success! Preprocessed data path: `dataset_prepared_path: {parsed_cfg.dataset_prepared_path}`"
|
||||
+ Fore.RESET
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(do_cli)
|
||||
@@ -1,7 +1,6 @@
|
||||
"""
|
||||
CLI to run training on a model
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
import fire
|
||||
@@ -17,8 +16,6 @@ from axolotl.cli import (
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
|
||||
LOG = logging.getLogger("axolotl.cli.train")
|
||||
|
||||
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -30,7 +27,10 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
|
||||
dataset_meta = load_datasets(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||
if parsed_cli_args.prepare_ds_only:
|
||||
return
|
||||
train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
|
||||
|
||||
|
||||
|
||||
@@ -25,22 +25,11 @@ class TrainerCliArgs:
|
||||
debug_num_examples: int = field(default=5)
|
||||
inference: bool = field(default=False)
|
||||
merge_lora: bool = field(default=False)
|
||||
prepare_ds_only: bool = field(default=False)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
shard: bool = field(default=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreprocessCliArgs:
|
||||
"""
|
||||
dataclass representing arguments for preprocessing only
|
||||
"""
|
||||
|
||||
debug: bool = field(default=False)
|
||||
debug_text_only: bool = field(default=False)
|
||||
debug_num_examples: int = field(default=1)
|
||||
prompter: Optional[str] = field(default=None)
|
||||
|
||||
|
||||
def load_model_and_tokenizer(
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
"""
|
||||
Various shared constants
|
||||
"""
|
||||
|
||||
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||
@@ -1,757 +0,0 @@
|
||||
"""
|
||||
Builder for the training args and trainer
|
||||
"""
|
||||
|
||||
import abc
|
||||
import importlib
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from datasets import Dataset
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
||||
from transformers.trainer_utils import seed_worker
|
||||
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.utils.callbacks import (
|
||||
EvalFirstStepCallback,
|
||||
GPUStatsCallback,
|
||||
LossWatchDogCallback,
|
||||
SaveAxolotlConfigtoWandBCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
log_prediction_callback_factory,
|
||||
)
|
||||
from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
|
||||
from axolotl.utils.samplers import MultipackBatchSampler
|
||||
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
||||
|
||||
try:
|
||||
import torch._dynamo # pylint: disable=ungrouped-imports
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingArguments(TrainingArguments):
|
||||
"""
|
||||
Extend the base TrainingArguments for axolotl helpers
|
||||
"""
|
||||
|
||||
lr_quadratic_warmup: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||
)
|
||||
sample_packing: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
eval_sample_packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "Use sample packing for efficient evals."},
|
||||
)
|
||||
sample_packing_efficiency: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "The maximum sequence length the model can handle"},
|
||||
)
|
||||
sample_packing_seq_len_multiplier: int = field(
|
||||
default=1,
|
||||
metadata={"help": "the multiplier for the max len for packed sequences"},
|
||||
)
|
||||
relora_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to reset for ReLoRA"},
|
||||
)
|
||||
relora_warmup_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
bench_split: Optional[str] = field(
|
||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||
)
|
||||
bench_dataset: Optional[str] = field(
|
||||
default="pharaouk/dharma-1/dharma_1_mini.json",
|
||||
metadata={
|
||||
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
|
||||
},
|
||||
)
|
||||
do_bench_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
||||
)
|
||||
max_bench_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
||||
},
|
||||
)
|
||||
bench_source_max_len: int = field(
|
||||
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
||||
)
|
||||
dataloader_prefetch_factor: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "prefetch_factor argument to the dataloader"},
|
||||
)
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
"""
|
||||
Extend the base Trainer for axolotl helpers
|
||||
"""
|
||||
|
||||
args = None # type: AxolotlTrainingArguments
|
||||
|
||||
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
|
||||
self.num_epochs = num_epochs
|
||||
self.bench_data_collator = bench_data_collator
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||
):
|
||||
"""
|
||||
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||
passed as an argument.
|
||||
|
||||
Args:
|
||||
num_training_steps (int): The number of training steps to do.
|
||||
optimizer (torch.optim.Optimizer): The training optimizer
|
||||
"""
|
||||
|
||||
# fmt: off
|
||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||
# fmt: on
|
||||
if (
|
||||
self.args.lr_scheduler_type == "cosine"
|
||||
and self.args.lr_quadratic_warmup is True
|
||||
):
|
||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
)
|
||||
else:
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
return self.lr_scheduler
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing:
|
||||
return MultipackBatchSampler(
|
||||
RandomSampler(self.train_dataset),
|
||||
self.args.train_batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=self._train_batch_size * self.args.max_seq_length,
|
||||
lengths=(
|
||||
self.train_dataset.data.column("position_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: x[-1] + 1)
|
||||
.values
|
||||
),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
)
|
||||
return super()._get_train_sampler()
|
||||
|
||||
def _get_eval_sampler(
|
||||
self, eval_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
return MultipackBatchSampler(
|
||||
SequentialSampler(eval_dataset),
|
||||
self.args.per_device_eval_batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
|
||||
lengths=(
|
||||
eval_dataset.data.column("position_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: x[-1] + 1)
|
||||
.values
|
||||
),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
)
|
||||
return super()._get_eval_sampler(eval_dataset)
|
||||
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
if self.args.sample_packing:
|
||||
train_dataset = self.train_dataset
|
||||
train_dataset = train_dataset.remove_columns(["length"])
|
||||
data_collator = self.data_collator
|
||||
dataloader_params = {
|
||||
"batch_size": self._train_batch_size,
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params[
|
||||
"prefetch_factor"
|
||||
] = self.args.dataloader_prefetch_factor
|
||||
|
||||
sampler = self._get_train_sampler()
|
||||
if isinstance(sampler, BatchSampler):
|
||||
dataloader_params["batch_sampler"] = sampler
|
||||
del dataloader_params["batch_size"]
|
||||
else:
|
||||
dataloader_params["sampler"] = sampler
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
dataloader_params["worker_init_fn"] = seed_worker
|
||||
|
||||
self.accelerator.even_batches = False
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(train_dataset, **dataloader_params)
|
||||
)
|
||||
return super().get_train_dataloader()
|
||||
|
||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
eval_dataset = (
|
||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
)
|
||||
|
||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||
data_collator = self.data_collator
|
||||
dataloader_params = {
|
||||
"batch_size": self.args.eval_batch_size,
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params[
|
||||
"prefetch_factor"
|
||||
] = self.args.dataloader_prefetch_factor
|
||||
|
||||
if isinstance(eval_sampler, BatchSampler):
|
||||
dataloader_params["batch_sampler"] = eval_sampler
|
||||
del dataloader_params["batch_size"]
|
||||
else:
|
||||
dataloader_params["sampler"] = eval_sampler
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
self.accelerator.even_batches = False
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(eval_dataset, **dataloader_params)
|
||||
)
|
||||
return super().get_eval_dataloader(eval_dataset)
|
||||
|
||||
def _get_bench_sampler(
|
||||
self, bench_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.world_size <= 1:
|
||||
return SequentialSampler(bench_dataset)
|
||||
return None
|
||||
|
||||
def get_bench_dataloader(
|
||||
self,
|
||||
bench_dataset: Dataset,
|
||||
) -> DataLoader:
|
||||
dataloader_params = {
|
||||
"batch_size": self.args.eval_batch_size,
|
||||
"collate_fn": self.bench_data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||
|
||||
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
|
||||
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
return DataLoader(bench_dataset, **dataloader_params)
|
||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
# use one's weighted cross entropy loss calc
|
||||
# if self.args.sample_packing:
|
||||
# labels = inputs.pop("labels")
|
||||
# outputs = model(**inputs)
|
||||
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||
# return (loss, outputs) if return_outputs else loss
|
||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||
|
||||
|
||||
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
||||
"""
|
||||
Trainer subclass that uses the OneCycleLR scheduler
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lr_scheduler = None
|
||||
|
||||
def create_scheduler(
|
||||
self,
|
||||
num_training_steps: int,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
):
|
||||
optimizer = self.optimizer if optimizer is None else optimizer
|
||||
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||
pct_start = num_warmup_steps / num_training_steps
|
||||
|
||||
self.lr_scheduler = OneCycleLR(
|
||||
optimizer,
|
||||
max_lr=self.args.learning_rate,
|
||||
total_steps=num_training_steps,
|
||||
pct_start=pct_start,
|
||||
div_factor=6,
|
||||
)
|
||||
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
class ReLoRATrainer(AxolotlTrainer):
|
||||
"""
|
||||
Trainer subclass that uses the OneCycleLR scheduler
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lr_scheduler = None
|
||||
|
||||
def create_scheduler(
|
||||
self,
|
||||
num_training_steps: int,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
):
|
||||
optimizer = self.optimizer if optimizer is None else optimizer
|
||||
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
if self.args.relora_steps:
|
||||
warmup_steps = (
|
||||
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
||||
)
|
||||
self.lr_scheduler = ReLoRAScheduler(
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
self.args.relora_steps,
|
||||
warmup_steps,
|
||||
)
|
||||
else:
|
||||
self.lr_scheduler = lr_scheduler
|
||||
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
class TrainerBuilderBase(abc.ABC):
|
||||
"""
|
||||
Base class for trainer builder
|
||||
"""
|
||||
|
||||
_train_dataset = None
|
||||
_eval_dataset = None
|
||||
|
||||
def __init__(self, cfg, model, tokenizer):
|
||||
self.cfg = cfg
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
|
||||
@property
|
||||
def train_dataset(self):
|
||||
return self._train_dataset
|
||||
|
||||
@train_dataset.setter
|
||||
def train_dataset(self, dataset):
|
||||
self._train_dataset = dataset
|
||||
|
||||
@property
|
||||
def eval_dataset(self):
|
||||
return self._eval_dataset
|
||||
|
||||
@eval_dataset.setter
|
||||
def eval_dataset(self, dataset):
|
||||
self._eval_dataset = dataset
|
||||
|
||||
@abstractmethod
|
||||
def build(self, total_num_steps):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_callbacks(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
"""
|
||||
Callbacks added after the trainer is created, usually b/c these need access to the trainer
|
||||
"""
|
||||
|
||||
|
||||
class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
"""
|
||||
Build the HuggingFace training args/trainer for Causal models
|
||||
"""
|
||||
|
||||
def hook_pre_create_training_args(self, training_arguments_kwargs):
|
||||
# TODO
|
||||
return training_arguments_kwargs
|
||||
|
||||
def hook_post_create_training_args(self, training_arguments):
|
||||
# TODO
|
||||
return training_arguments
|
||||
|
||||
def hook_pre_create_trainer(self, trainer_kwargs, trainer_cls):
|
||||
# TODO
|
||||
return trainer_kwargs, trainer_cls
|
||||
|
||||
def hook_post_create_trainer(self, trainer):
|
||||
# TODO
|
||||
return trainer
|
||||
|
||||
def get_callbacks(self):
|
||||
callbacks = []
|
||||
callbacks.append(GPUStatsCallback(self.cfg))
|
||||
callbacks.append(EvalFirstStepCallback)
|
||||
|
||||
if self.cfg.relora_steps:
|
||||
callbacks.append(ReLoRACallback(self.cfg))
|
||||
|
||||
if (
|
||||
hasattr(self.model, "use_bettertransformer")
|
||||
and self.model.use_bettertransformer is True
|
||||
):
|
||||
callbacks.append(SaveBetterTransformerModelCallback)
|
||||
|
||||
if self.cfg.use_wandb:
|
||||
callbacks.append(
|
||||
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
|
||||
)
|
||||
|
||||
if self.cfg.loss_watchdog_threshold is not None:
|
||||
callbacks.append(LossWatchDogCallback(self.cfg))
|
||||
|
||||
return callbacks
|
||||
|
||||
def get_post_trainer_create_callbacks(self, trainer):
|
||||
callbacks = []
|
||||
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
|
||||
LogPredictionCallback = log_prediction_callback_factory(
|
||||
trainer, self.tokenizer
|
||||
)
|
||||
callbacks.append(LogPredictionCallback(self.cfg))
|
||||
|
||||
if self.cfg.do_bench_eval:
|
||||
callbacks.append(bench_eval_callback_factory(trainer, self.tokenizer))
|
||||
|
||||
if self.cfg.early_stopping_patience:
|
||||
early_stop_cb = EarlyStoppingCallback(
|
||||
self.cfg.early_stopping_patience,
|
||||
)
|
||||
callbacks.append(early_stop_cb)
|
||||
|
||||
return callbacks
|
||||
|
||||
def _get_trainer_cls(self):
|
||||
if self.cfg.lr_scheduler == "one_cycle" and (
|
||||
self.cfg.fsdp or self.cfg.adapter == "qlora"
|
||||
):
|
||||
return OneCycleLRSchedulerTrainer
|
||||
if self.cfg.relora_steps:
|
||||
return ReLoRATrainer
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
warmup_steps = None
|
||||
if self.cfg.warmup_steps is not None:
|
||||
warmup_steps = self.cfg.warmup_steps
|
||||
elif self.cfg.warmup_ratio is not None:
|
||||
warmup_steps = max(int(self.cfg.warmup_ratio * total_num_steps), 0)
|
||||
else:
|
||||
warmup_steps = min(int(0.03 * total_num_steps), 100)
|
||||
|
||||
logging_steps = (
|
||||
self.cfg.logging_steps
|
||||
if self.cfg.logging_steps is not None
|
||||
else max(min(int(0.005 * total_num_steps), 10), 1)
|
||||
)
|
||||
|
||||
training_arguments_kwargs = {}
|
||||
if self.cfg.bf16 == "full":
|
||||
training_arguments_kwargs["bf16_full_eval"] = True
|
||||
else:
|
||||
training_arguments_kwargs["bf16"] = self.cfg.bf16
|
||||
training_arguments_kwargs["fp16"] = (
|
||||
self.cfg.fp16 and not self.cfg.bf16
|
||||
) or False
|
||||
training_arguments_kwargs["tf32"] = self.cfg.tf32
|
||||
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
||||
training_arguments_kwargs["logging_steps"] = logging_steps
|
||||
|
||||
if self.cfg.seed:
|
||||
training_arguments_kwargs["seed"] = self.cfg.seed
|
||||
|
||||
if self.cfg.gradient_checkpointing:
|
||||
training_arguments_kwargs[
|
||||
"gradient_checkpointing"
|
||||
] = self.cfg.gradient_checkpointing
|
||||
if self.cfg.fsdp:
|
||||
training_arguments_kwargs["fsdp"] = self.cfg.fsdp
|
||||
if self.cfg.fsdp_config:
|
||||
training_arguments_kwargs["fsdp_config"] = dict(self.cfg.fsdp_config)
|
||||
|
||||
# deepspeed
|
||||
if self.cfg.deepspeed:
|
||||
training_arguments_kwargs["deepspeed"] = self.cfg.deepspeed
|
||||
|
||||
if self.cfg.lr_quadratic_warmup is not None:
|
||||
training_arguments_kwargs[
|
||||
"lr_quadratic_warmup"
|
||||
] = self.cfg.lr_quadratic_warmup
|
||||
|
||||
if self.cfg.adam_beta1:
|
||||
training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1
|
||||
if self.cfg.adam_beta2:
|
||||
training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2
|
||||
if self.cfg.adam_epsilon:
|
||||
training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon
|
||||
if self.cfg.max_grad_norm:
|
||||
training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm
|
||||
|
||||
if self.cfg.hub_model_id:
|
||||
training_arguments_kwargs["hub_model_id"] = self.cfg.hub_model_id
|
||||
training_arguments_kwargs["push_to_hub"] = True
|
||||
training_arguments_kwargs["hub_private_repo"] = True
|
||||
|
||||
if self.cfg.hub_strategy:
|
||||
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
||||
|
||||
if self.cfg.save_safetensors:
|
||||
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||
|
||||
if self.cfg.sample_packing_eff_est:
|
||||
training_arguments_kwargs[
|
||||
"sample_packing_efficiency"
|
||||
] = self.cfg.sample_packing_eff_est
|
||||
|
||||
if self.cfg.dataloader_pin_memory is not None:
|
||||
training_arguments_kwargs[
|
||||
"dataloader_pin_memory"
|
||||
] = self.cfg.dataloader_pin_memory
|
||||
if self.cfg.dataloader_num_workers is not None:
|
||||
training_arguments_kwargs[
|
||||
"dataloader_num_workers"
|
||||
] = self.cfg.dataloader_num_workers
|
||||
if self.cfg.dataloader_prefetch_factor is not None:
|
||||
training_arguments_kwargs[
|
||||
"dataloader_prefetch_factor"
|
||||
] = self.cfg.dataloader_prefetch_factor
|
||||
|
||||
if self.cfg.val_set_size == 0:
|
||||
# no eval set, so don't eval
|
||||
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||
elif self.cfg.eval_steps:
|
||||
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
||||
training_arguments_kwargs["eval_steps"] = self.cfg.eval_steps
|
||||
elif self.cfg.evaluation_strategy:
|
||||
training_arguments_kwargs[
|
||||
"evaluation_strategy"
|
||||
] = self.cfg.evaluation_strategy
|
||||
else:
|
||||
# we have an eval set, but no steps defined, default to use epoch
|
||||
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
||||
|
||||
if self.cfg.save_steps:
|
||||
training_arguments_kwargs["save_strategy"] = "steps"
|
||||
training_arguments_kwargs["save_steps"] = self.cfg.save_steps
|
||||
elif self.cfg.save_strategy:
|
||||
training_arguments_kwargs["save_strategy"] = self.cfg.save_strategy
|
||||
else:
|
||||
# default to saving each epoch if not defined
|
||||
training_arguments_kwargs["save_strategy"] = "epoch"
|
||||
|
||||
if self.cfg.do_bench_eval:
|
||||
training_arguments_kwargs["do_bench_eval"] = self.cfg.do_bench_eval
|
||||
if self.cfg.bench_dataset:
|
||||
training_arguments_kwargs["bench_dataset"] = self.cfg.bench_dataset
|
||||
if self.cfg.metric_for_best_model:
|
||||
training_arguments_kwargs[
|
||||
"metric_for_best_model"
|
||||
] = self.cfg.metric_for_best_model
|
||||
if self.cfg.greater_is_better:
|
||||
training_arguments_kwargs["greater_is_better"] = self.cfg.greater_is_better
|
||||
|
||||
if self.cfg.torch_compile:
|
||||
if torch.__version__ < "2.1.0": # pylint: disable=protected-access
|
||||
LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
|
||||
elif torch._dynamo: # pylint: disable=protected-access
|
||||
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
|
||||
True
|
||||
)
|
||||
training_arguments_kwargs["torch_compile"] = self.cfg.torch_compile
|
||||
if self.cfg.torch_compile_backend:
|
||||
training_arguments_kwargs[
|
||||
"torch_compile_backend"
|
||||
] = self.cfg.torch_compile_backend
|
||||
|
||||
# DDP Config
|
||||
if self.cfg.ddp_timeout:
|
||||
training_arguments_kwargs["ddp_timeout"] = self.cfg.ddp_timeout
|
||||
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
|
||||
if self.cfg.ddp_bucket_cap_mb:
|
||||
training_arguments_kwargs["ddp_bucket_cap_mb"] = self.cfg.ddp_bucket_cap_mb
|
||||
if self.cfg.ddp_broadcast_buffers is not None:
|
||||
training_arguments_kwargs[
|
||||
"ddp_broadcast_buffers"
|
||||
] = self.cfg.ddp_broadcast_buffers
|
||||
|
||||
# these are all the "standard" kwargs that are def used
|
||||
training_arguments_kwargs["max_steps"] = (
|
||||
total_num_steps if self.cfg.max_steps else -1
|
||||
)
|
||||
training_arguments_kwargs["max_seq_length"] = self.cfg.sequence_len
|
||||
training_arguments_kwargs[
|
||||
"per_device_train_batch_size"
|
||||
] = self.cfg.micro_batch_size
|
||||
training_arguments_kwargs[
|
||||
"per_device_eval_batch_size"
|
||||
] = self.cfg.eval_batch_size
|
||||
training_arguments_kwargs[
|
||||
"gradient_accumulation_steps"
|
||||
] = self.cfg.gradient_accumulation_steps
|
||||
training_arguments_kwargs[
|
||||
"eval_accumulation_steps"
|
||||
] = self.cfg.gradient_accumulation_steps
|
||||
training_arguments_kwargs["num_train_epochs"] = self.cfg.num_epochs
|
||||
training_arguments_kwargs["learning_rate"] = self.cfg.learning_rate
|
||||
training_arguments_kwargs["output_dir"] = self.cfg.output_dir
|
||||
training_arguments_kwargs["save_total_limit"] = (
|
||||
self.cfg.save_total_limit if self.cfg.save_total_limit else 4
|
||||
)
|
||||
training_arguments_kwargs["load_best_model_at_end"] = (
|
||||
(
|
||||
self.cfg.load_best_model_at_end is not False
|
||||
or self.cfg.early_stopping_patience
|
||||
)
|
||||
and self.cfg.val_set_size > 0
|
||||
and self.cfg.save_steps
|
||||
and self.cfg.eval_steps
|
||||
and self.cfg.save_steps % self.cfg.eval_steps == 0
|
||||
) or False
|
||||
training_arguments_kwargs["ddp_find_unused_parameters"] = (
|
||||
False if self.cfg.ddp else None
|
||||
)
|
||||
training_arguments_kwargs["group_by_length"] = self.cfg.group_by_length
|
||||
training_arguments_kwargs["report_to"] = "wandb" if self.cfg.use_wandb else None
|
||||
training_arguments_kwargs["run_name"] = (
|
||||
self.cfg.wandb_name if self.cfg.use_wandb else None
|
||||
)
|
||||
training_arguments_kwargs["optim"] = (
|
||||
self.cfg.optimizer if self.cfg.optimizer else "adamw_hf"
|
||||
)
|
||||
training_arguments_kwargs["lr_scheduler_type"] = (
|
||||
self.cfg.lr_scheduler
|
||||
if self.cfg.lr_scheduler
|
||||
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
|
||||
else "cosine"
|
||||
)
|
||||
training_arguments_kwargs["weight_decay"] = (
|
||||
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
||||
)
|
||||
training_arguments_kwargs["sample_packing"] = (
|
||||
self.cfg.sample_packing if self.cfg.sample_packing else False
|
||||
)
|
||||
training_arguments_kwargs["eval_sample_packing"] = (
|
||||
self.cfg.sample_packing
|
||||
if self.cfg.eval_sample_packing is not False
|
||||
else False
|
||||
)
|
||||
training_arguments_kwargs[
|
||||
"sample_packing_seq_len_multiplier"
|
||||
] = self.cfg.micro_batch_size
|
||||
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
||||
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
|
||||
training_arguments_kwargs = self.hook_pre_create_training_args(
|
||||
training_arguments_kwargs
|
||||
)
|
||||
training_args = (
|
||||
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||
**training_arguments_kwargs,
|
||||
)
|
||||
)
|
||||
training_args = self.hook_post_create_training_args(training_args)
|
||||
trainer_kwargs = {}
|
||||
|
||||
if self.cfg.optimizer == "adamw_anyprecision":
|
||||
if Path(self.cfg.torchdistx_path).exists():
|
||||
sys.path.append(self.cfg.torchdistx_path)
|
||||
importlib.import_module("torchdistx")
|
||||
|
||||
data_collator_kwargs = {
|
||||
"padding": True, # True/"longest" is the default
|
||||
}
|
||||
if self.cfg.pad_to_sequence_len:
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
|
||||
self.cfg.sequence_len / 64
|
||||
)
|
||||
else:
|
||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 64
|
||||
|
||||
if self.cfg.is_llama_derived_model and self.cfg.landmark_attention:
|
||||
from axolotl.monkeypatch.llama_landmark_attn import (
|
||||
add_mem_tokens,
|
||||
get_mem_id,
|
||||
set_model_mem_id,
|
||||
)
|
||||
|
||||
set_model_mem_id(self.model, self.tokenizer)
|
||||
|
||||
LOG.info("Adding landmark attention tokens to dataset")
|
||||
|
||||
for dataset in [self.train_dataset, self.eval_dataset]:
|
||||
dataset = dataset.map(
|
||||
partial(
|
||||
add_mem_tokens, mem_freq=50, mem_id=get_mem_id(self.tokenizer)
|
||||
),
|
||||
batched=False,
|
||||
num_proc=32,
|
||||
)
|
||||
|
||||
trainer_cls = self._get_trainer_cls()
|
||||
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
||||
trainer_kwargs, trainer_cls
|
||||
)
|
||||
trainer = trainer_cls(
|
||||
model=self.model,
|
||||
train_dataset=self.train_dataset,
|
||||
eval_dataset=self.eval_dataset,
|
||||
args=training_args,
|
||||
data_collator=BatchSamplerDataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
),
|
||||
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
),
|
||||
callbacks=self.get_callbacks(),
|
||||
num_epochs=self.cfg.num_epochs,
|
||||
**trainer_kwargs,
|
||||
)
|
||||
trainer = self.hook_post_create_trainer(trainer)
|
||||
for callback in self.get_post_trainer_create_callbacks(trainer):
|
||||
trainer.add_callback(callback)
|
||||
|
||||
if self.cfg.deepspeed and self.cfg.sample_packing:
|
||||
trainer.accelerator.state.deepspeed_plugin.deepspeed_config[
|
||||
"train_micro_batch_size_per_gpu"
|
||||
] = self.cfg.micro_batch_size
|
||||
|
||||
return trainer
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, IterableDataset
|
||||
from datasets import Dataset, IterableDataset, Sequence, Value
|
||||
|
||||
from .prompt_tokenizers import PromptTokenizingStrategy
|
||||
|
||||
@@ -30,29 +30,27 @@ class TokenizedPromptDataset(Dataset):
|
||||
self,
|
||||
prompt_tokenizer: PromptTokenizingStrategy,
|
||||
dataset: IterableDataset,
|
||||
process_count: Optional[int] = None,
|
||||
**kwargs,
|
||||
):
|
||||
self.prompt_tokenizer = prompt_tokenizer
|
||||
self.process_count = process_count
|
||||
super().__init__(self.process(dataset).data, **kwargs)
|
||||
|
||||
def process(self, dataset):
|
||||
features = dataset.features.keys()
|
||||
num_proc = (
|
||||
min(64, self.process_count)
|
||||
if self.process_count
|
||||
else min(64, os.cpu_count())
|
||||
)
|
||||
num_proc = min(64, os.cpu_count())
|
||||
map_kwargs = {}
|
||||
if self.prompt_tokenizer.supports_batched:
|
||||
map_kwargs["batched"] = True
|
||||
map_kwargs["batch_size"] = 100
|
||||
return dataset.map(
|
||||
self.prompt_tokenizer.tokenize_prompt,
|
||||
num_proc=num_proc,
|
||||
remove_columns=features,
|
||||
**map_kwargs,
|
||||
return (
|
||||
dataset.map(
|
||||
self.prompt_tokenizer.tokenize_prompt,
|
||||
num_proc=num_proc,
|
||||
remove_columns=features,
|
||||
**map_kwargs,
|
||||
)
|
||||
.cast_column("input_ids", Sequence(feature=Value(dtype="int32", id=None)))
|
||||
.cast_column("labels", Sequence(feature=Value(dtype="int32", id=None)))
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,4 @@ MixFormers model architecture used for phi models
|
||||
"""
|
||||
|
||||
from .configuration_mixformer_sequential import MixFormerSequentialConfig # noqa
|
||||
from .configuration_phi import PhiConfig # noqa
|
||||
from .modeling_mixformer_sequential import MixFormerSequentialForCausalLM # noqa
|
||||
from .modeling_phi import PhiForCausalLM # noqa
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
# pylint: skip-file
|
||||
# Copyright (c) Microsoft Corporation.
|
||||
# Licensed under the MIT license.
|
||||
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
from transformers import PretrainedConfig
|
||||
|
||||
|
||||
class PhiConfig(PretrainedConfig):
|
||||
"""Phi configuration."""
|
||||
|
||||
model_type = "phi"
|
||||
attribute_map = {
|
||||
"max_position_embeddings": "n_positions",
|
||||
"hidden_size": "n_embd",
|
||||
"num_attention_heads": "n_head",
|
||||
"num_hidden_layers": "n_layer",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int = 50304,
|
||||
n_positions: int = 2048,
|
||||
n_embd: int = 1024,
|
||||
n_layer: int = 20,
|
||||
n_inner: Optional[int] = None,
|
||||
n_head: int = 16,
|
||||
n_head_kv: Optional[int] = None,
|
||||
rotary_dim: Optional[int] = 32,
|
||||
activation_function: Optional[str] = "gelu_new",
|
||||
flash_attn: bool = False,
|
||||
flash_rotary: bool = False,
|
||||
fused_dense: bool = False,
|
||||
attn_pdrop: float = 0.0,
|
||||
embd_pdrop: float = 0.0,
|
||||
resid_pdrop: float = 0.0,
|
||||
layer_norm_epsilon: float = 1e-5,
|
||||
initializer_range: float = 0.02,
|
||||
tie_word_embeddings: bool = False,
|
||||
pad_vocab_size_multiple: int = 64,
|
||||
**kwargs
|
||||
) -> None:
|
||||
self.vocab_size = int(
|
||||
math.ceil(vocab_size / pad_vocab_size_multiple) * pad_vocab_size_multiple
|
||||
)
|
||||
self.n_positions = n_positions
|
||||
self.n_embd = n_embd
|
||||
self.n_layer = n_layer
|
||||
self.n_inner = n_inner
|
||||
self.n_head = n_head
|
||||
self.n_head_kv = n_head_kv
|
||||
self.rotary_dim = min(rotary_dim, n_embd // n_head)
|
||||
self.activation_function = activation_function
|
||||
self.flash_attn = flash_attn
|
||||
self.flash_rotary = flash_rotary
|
||||
self.fused_dense = fused_dense
|
||||
self.attn_pdrop = attn_pdrop
|
||||
self.embd_pdrop = embd_pdrop
|
||||
self.resid_pdrop = resid_pdrop
|
||||
self.layer_norm_epsilon = layer_norm_epsilon
|
||||
self.initializer_range = initializer_range
|
||||
|
||||
super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,168 +0,0 @@
|
||||
# Adapted from Unsloth
|
||||
# https://github.com/unslothai/unsloth/blob/4b97a810b509c93f44be4c037c7aa18fb8922884/unsloth/kernels/cross_entropy_loss.py
|
||||
|
||||
import triton
|
||||
import triton.language as tl
|
||||
import torch
|
||||
|
||||
MAX_FUSED_SIZE = 65536
|
||||
|
||||
def calculate_settings(n):
|
||||
BLOCK_SIZE = triton.next_power_of_2(n)
|
||||
# CUDA only supports 65536 - 2^16 threads per block
|
||||
if BLOCK_SIZE > MAX_FUSED_SIZE:
|
||||
raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
|
||||
f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
|
||||
num_warps = 4
|
||||
if BLOCK_SIZE >= 32768: num_warps = 32
|
||||
elif BLOCK_SIZE >= 8192: num_warps = 16
|
||||
elif BLOCK_SIZE >= 2048: num_warps = 8
|
||||
return BLOCK_SIZE, num_warps
|
||||
pass
|
||||
|
||||
@triton.jit
|
||||
def _cross_entropy_forward(logits_ptr, logits_row_stride,
|
||||
loss_ptr,
|
||||
lse_ptr,
|
||||
labels_ptr,
|
||||
n_cols,
|
||||
BLOCK_SIZE: tl.constexpr,):
|
||||
"""
|
||||
Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
|
||||
Pi = exp(xi) / sum(exp(xi))
|
||||
CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
|
||||
= -y [ x - log[sum(exp(x))] ]
|
||||
= y * (log[sum(exp(x))] - x)
|
||||
If y == 0: CE_i = 0
|
||||
If y == 1: CE_i = logsumexp - x
|
||||
"""
|
||||
row_idx = tl.program_id(0)
|
||||
logits_ptr += row_idx * logits_row_stride
|
||||
loss_ptr += row_idx
|
||||
lse_ptr += row_idx
|
||||
labels_ptr += row_idx
|
||||
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
|
||||
# TODO: Fixup int32 locations to int64
|
||||
label_idx = tl.load(labels_ptr).to(tl.int32)
|
||||
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
|
||||
max_logits = tl.max(logits, 0)
|
||||
# Maximum stops overflow
|
||||
lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
|
||||
tl.store(lse_ptr, lse)
|
||||
|
||||
if label_idx != -100:
|
||||
logits_label = tl.load(logits_ptr + label_idx).to(tl.float32)
|
||||
loss = lse - logits_label
|
||||
else:
|
||||
loss = 0.0
|
||||
tl.store(loss_ptr, loss)
|
||||
pass
|
||||
|
||||
|
||||
@triton.jit
|
||||
def _cross_entropy_backward(logits_ptr, logits_row_stride,
|
||||
dloss_ptr, dloss_row_stride,
|
||||
lse_ptr,
|
||||
labels_ptr,
|
||||
n_cols,
|
||||
BLOCK_SIZE: tl.constexpr,):
|
||||
"""
|
||||
CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
|
||||
dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
|
||||
|
||||
From https://en.wikipedia.org/wiki/LogSumExp
|
||||
d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
|
||||
|
||||
dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
|
||||
dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
|
||||
dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
|
||||
|
||||
If y == 0: dC/dx = 0
|
||||
If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
|
||||
If y == 1 and x != label: dC/dx = exp[x - logsumexp]
|
||||
"""
|
||||
row_idx = tl.program_id(0)
|
||||
logits_ptr += row_idx * logits_row_stride
|
||||
dloss_ptr += row_idx * dloss_row_stride
|
||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
||||
mask = col_offsets < n_cols
|
||||
# TODO: Fixup int32 locations to int64
|
||||
label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)
|
||||
|
||||
if label_idx != -100:
|
||||
dloss = tl.load(dloss_ptr)
|
||||
else:
|
||||
dloss = 0.0
|
||||
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = 0).to(tl.float32)
|
||||
lse = tl.load(lse_ptr + row_idx)
|
||||
probs = tl.exp(logits - lse)
|
||||
|
||||
probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
|
||||
tl.store(logits_ptr + col_offsets, dloss * probs, mask = mask)
|
||||
|
||||
|
||||
|
||||
class CrossEntropyLoss(torch.autograd.Function):
|
||||
@staticmethod
|
||||
def forward(ctx, logits, labels):
|
||||
n_rows, n_cols = logits.shape
|
||||
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
||||
losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
|
||||
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
|
||||
|
||||
_cross_entropy_forward[(n_rows,)](
|
||||
logits, logits.stride(0),
|
||||
losses,
|
||||
logsumexp,
|
||||
labels,
|
||||
n_cols,
|
||||
BLOCK_SIZE = BLOCK_SIZE,
|
||||
num_warps = num_warps,
|
||||
)
|
||||
|
||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
||||
ctx.num_warps = num_warps
|
||||
ctx.save_for_backward(logits, logsumexp, labels)
|
||||
return losses
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def backward(ctx, dlosses):
|
||||
logits, logsumexp, labels = ctx.saved_tensors
|
||||
n_rows, n_cols = logits.shape
|
||||
|
||||
_cross_entropy_backward[(n_rows,)](
|
||||
logits, logits.stride(0),
|
||||
dlosses, dlosses.stride(0),
|
||||
logsumexp,
|
||||
labels,
|
||||
n_cols,
|
||||
BLOCK_SIZE = ctx.BLOCK_SIZE,
|
||||
num_warps = ctx.num_warps,
|
||||
)
|
||||
return logits, None, None,
|
||||
pass
|
||||
pass
|
||||
|
||||
|
||||
def fast_cross_entropy_loss(logits, labels):
|
||||
"""
|
||||
Arguments:
|
||||
logits: (batch, seq_len, vocab_size)
|
||||
labels: (batch, seq_len,)
|
||||
Returns:
|
||||
losses: float
|
||||
"""
|
||||
batch, seq_len, d = logits.shape
|
||||
assert(labels.shape == (batch, seq_len))
|
||||
|
||||
loss = CrossEntropyLoss.apply(
|
||||
logits.view(batch*seq_len, d),
|
||||
labels.view(-1),
|
||||
)
|
||||
n_items = torch.count_nonzero(labels != -100)
|
||||
return loss.sum() / n_items
|
||||
pass
|
||||
@@ -13,18 +13,12 @@ import transformers
|
||||
from einops import rearrange
|
||||
from flash_attn.bert_padding import pad_input, unpad_input
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaDecoderLayer as OriginalLlamaDecoderLayer,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaMLP,
|
||||
apply_rotary_pos_emb,
|
||||
repeat_kv,
|
||||
)
|
||||
from xformers.ops import SwiGLU
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids, set_module_name
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-imports
|
||||
@@ -44,28 +38,6 @@ except ImportError:
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def replace_llama_mlp_with_swiglu(model):
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaMLP):
|
||||
mlp = FusedMLP(
|
||||
module.config, module.gate_proj, module.up_proj, module.down_proj
|
||||
)
|
||||
set_module_name(model, name, mlp)
|
||||
|
||||
|
||||
def replace_llama_qkv_with_fused(model):
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, LlamaAttention):
|
||||
qkv = FusedAttention(
|
||||
module.config,
|
||||
module.q_proj,
|
||||
module.k_proj,
|
||||
module.v_proj,
|
||||
module.o_proj,
|
||||
)
|
||||
set_module_name(model, name, qkv)
|
||||
|
||||
|
||||
def replace_llama_attn_with_flash_attn(
|
||||
packed: Optional[bool] = False,
|
||||
cross_entropy: Optional[bool] = False,
|
||||
@@ -114,92 +86,6 @@ def replace_llama_attn_with_flash_attn(
|
||||
)
|
||||
|
||||
|
||||
class FusedAttention(LlamaAttention):
|
||||
"""
|
||||
Fused QKV Attention layer for incrementally improved training efficiency
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
q: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
k: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
v: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
o: torch.nn.Linear, # pylint: disable=invalid-name
|
||||
):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.init_device = next(iter(q.state_dict().values())).device
|
||||
|
||||
# define equivalent fused qkv projection
|
||||
self.out_features: List[int] = [q.out_features, k.out_features, v.out_features]
|
||||
self.qkv_proj = torch.nn.Linear(
|
||||
q.in_features, sum(self.out_features), device=self.init_device, bias=False
|
||||
)
|
||||
self.o_proj = o
|
||||
|
||||
# overwrite initialized weights with pretrained weights
|
||||
self.qkv_proj.weight.data = torch.cat(
|
||||
(q.weight.data, k.weight.data, v.weight.data), dim=0
|
||||
)
|
||||
|
||||
def _post_training(self, model, name):
|
||||
q_proj, k_proj, v_proj = torch.split(
|
||||
self.qkv_proj.weight.data, self.out_features, dim=0
|
||||
)
|
||||
|
||||
new_attn = LlamaAttention(self.config)
|
||||
new_attn.q_proj.weight.data = q_proj
|
||||
new_attn.k_proj.weight.data = k_proj
|
||||
new_attn.v_proj.weight.data = v_proj
|
||||
new_attn.o_proj.weight.data = self.o_proj.weight.data
|
||||
|
||||
set_module_name(model, name, new_attn)
|
||||
|
||||
|
||||
class FusedMLP(torch.nn.Module):
|
||||
"""
|
||||
Fused MLP layer for incrementally improved training efficiency
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config,
|
||||
gate_proj: torch.nn.Linear,
|
||||
up_proj: torch.nn.Linear,
|
||||
down_proj: torch.nn.Linear,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.swiglu = SwiGLU(
|
||||
in_features=config.hidden_size,
|
||||
hidden_features=config.intermediate_size,
|
||||
bias=False,
|
||||
_pack_weights=True,
|
||||
)
|
||||
# overwrite initialized weights with pretrained weights
|
||||
self.swiglu.w12.weight.data = torch.cat(
|
||||
(gate_proj.weight.data, up_proj.weight.data), dim=0
|
||||
)
|
||||
self.swiglu.w3.weight.data = down_proj.weight.data
|
||||
|
||||
def _post_training(self, model, name):
|
||||
w1, w2 = torch.split( # pylint: disable=invalid-name
|
||||
self.swiglu.w12.weight.data, self.config.intermediate_size, dim=0
|
||||
)
|
||||
|
||||
# Assign the split weights back to the original layers
|
||||
new_mlp = LlamaMLP(self.config)
|
||||
new_mlp.gate_proj.weight.data = w1
|
||||
new_mlp.up_proj.weight.data = w2
|
||||
new_mlp.down_proj.weight.data = self.swiglu.w3.weight.data
|
||||
|
||||
set_module_name(model, name, new_mlp)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name
|
||||
return self.swiglu(x)
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
@@ -261,14 +147,9 @@ def flashattn_forward(
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
if isinstance(self, FusedAttention):
|
||||
query_states, key_states, value_states = self.qkv_proj(hidden_states).split(
|
||||
self.out_features, dim=-1
|
||||
)
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
@@ -321,8 +202,6 @@ def flashattn_forward(
|
||||
# only on first autoregressive step q,k,v have same seqlen
|
||||
is_causal = key_states.shape == query_states.shape
|
||||
|
||||
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
|
||||
|
||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||
# special handling using sample packing
|
||||
qkv = torch.stack(
|
||||
@@ -332,12 +211,7 @@ def flashattn_forward(
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
elif query_states.shape == key_states.shape:
|
||||
@@ -360,7 +234,7 @@ def flashattn_forward(
|
||||
qkv_unpad,
|
||||
cu_seqlens_q,
|
||||
max_seqlen_q,
|
||||
dropout_p=dropout_rate,
|
||||
0.0,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
)
|
||||
@@ -373,7 +247,6 @@ def flashattn_forward(
|
||||
output = flash_attn_kvpacked_func(
|
||||
query_states,
|
||||
torch.stack([key_states, value_states], 2),
|
||||
dropout_p=dropout_rate,
|
||||
causal=is_causal,
|
||||
)
|
||||
else:
|
||||
@@ -406,7 +279,7 @@ def flashattn_forward(
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p=dropout_rate,
|
||||
0.0,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
)
|
||||
|
||||
@@ -25,8 +25,6 @@ def sdp_attention_forward(
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@@ -29,8 +29,6 @@ def xformers_forward(
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
@@ -13,20 +13,13 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
|
||||
flash_attn_varlen_kvpacked_func,
|
||||
flash_attn_varlen_qkvpacked_func,
|
||||
)
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralAttention as OriginalMistralAttention,
|
||||
)
|
||||
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import (
|
||||
MistralForCausalLM as OriginalMistralForCausalLM,
|
||||
)
|
||||
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
from axolotl.monkeypatch.cross_entropy import fast_cross_entropy_loss
|
||||
|
||||
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
|
||||
|
||||
@@ -40,9 +33,6 @@ def replace_mistral_attn_with_flash_attn(
|
||||
transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
|
||||
flashattn_forward
|
||||
)
|
||||
transformers.models.mistral.modeling_mistral.MistralForCausalLM.forward = (
|
||||
mistral_causallm_forward
|
||||
)
|
||||
if packed:
|
||||
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
|
||||
MistralDecoderLayer
|
||||
@@ -52,44 +42,6 @@ def replace_mistral_attn_with_flash_attn(
|
||||
)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def _make_sliding_window_causal_mask(
|
||||
bsz: int,
|
||||
tgt_len: int,
|
||||
dtype: torch.dtype,
|
||||
device: torch.device,
|
||||
past_key_values_length: int = 0,
|
||||
sliding_window: int = 4096,
|
||||
):
|
||||
"""
|
||||
Make causal mask used for sliding window attention
|
||||
"""
|
||||
tensor = torch.full(
|
||||
(tgt_len, tgt_len),
|
||||
fill_value=1,
|
||||
device=device,
|
||||
)
|
||||
mask = torch.tril(tensor, diagonal=0)
|
||||
# make the mask banded to account for sliding window
|
||||
# NOTE: HF implementation is wrong as of 14-10-2023 for torch.triu, needs +1
|
||||
mask = torch.triu(mask, diagonal=-sliding_window + 1)
|
||||
mask = torch.log(mask).to(dtype)
|
||||
|
||||
if past_key_values_length > 0:
|
||||
mask = torch.cat(
|
||||
[
|
||||
torch.zeros(
|
||||
tgt_len, past_key_values_length, dtype=dtype, device=device
|
||||
),
|
||||
mask,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
return mask[None, None, :, :].expand(
|
||||
bsz, 1, tgt_len, tgt_len + past_key_values_length
|
||||
)
|
||||
|
||||
|
||||
# Disable the transformation of the attention mask in LlamaModel as the flash attention
|
||||
# requires the attention mask to be the same as the key_padding_mask
|
||||
def _prepare_decoder_attention_mask(
|
||||
@@ -101,29 +53,11 @@ def _prepare_decoder_attention_mask(
|
||||
sliding_window,
|
||||
): # pylint: disable=unused-argument
|
||||
# [bsz, seq_len]
|
||||
if attention_mask is None:
|
||||
return attention_mask
|
||||
|
||||
# NOTE: attention mask and sliding masks are only broadcastable in certain scenarios.
|
||||
# Without attention_mask.shape[0] == 1, error will trigger after eval loss but only when wandb is enabled.
|
||||
if input_shape[-1] > 1 and attention_mask.shape[0] == 1:
|
||||
sliding_window_mask = _make_sliding_window_causal_mask(
|
||||
bsz=input_shape[0],
|
||||
tgt_len=input_shape[1],
|
||||
dtype=inputs_embeds.dtype,
|
||||
device=inputs_embeds.device,
|
||||
past_key_values_length=past_key_values_length,
|
||||
sliding_window=sliding_window,
|
||||
)
|
||||
attention_mask = attention_mask + sliding_window_mask
|
||||
else:
|
||||
LOG.info("skipping sliding window mask, not broadcastable with attention mask")
|
||||
|
||||
return attention_mask
|
||||
|
||||
|
||||
def flashattn_forward(
|
||||
self: OriginalMistralAttention,
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
@@ -157,41 +91,10 @@ def flashattn_forward(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
|
||||
use_sliding_windows = (
|
||||
hasattr(self.config, "sliding_window") is not None
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
)
|
||||
|
||||
if use_sliding_windows:
|
||||
window_size = (self.config.sliding_window, self.config.sliding_window)
|
||||
else:
|
||||
window_size = (-1, -1)
|
||||
|
||||
if past_key_value is not None:
|
||||
# Activate slicing cache only if the config has a value `sliding_windows` attribute
|
||||
if (
|
||||
hasattr(self.config, "sliding_window")
|
||||
and kv_seq_len > self.config.sliding_window
|
||||
):
|
||||
slicing_tokens = kv_seq_len - self.config.sliding_window
|
||||
|
||||
past_key = past_key_value[0]
|
||||
past_value = past_key_value[1]
|
||||
|
||||
past_key = past_key[:, :, slicing_tokens:, :].contiguous()
|
||||
past_value = past_value[:, :, slicing_tokens:, :].contiguous()
|
||||
|
||||
if past_key.shape[-2] != self.config.sliding_window - 1:
|
||||
raise ValueError(
|
||||
f"past key much have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
|
||||
f" {past_key.shape}"
|
||||
)
|
||||
|
||||
past_key_value = (past_key, past_value) if use_cache else None
|
||||
|
||||
if past_key_value is not None:
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
@@ -208,8 +111,6 @@ def flashattn_forward(
|
||||
# only on first autoregressive step q,k,v have same seqlen
|
||||
is_causal = key_states.shape == query_states.shape
|
||||
|
||||
dropout_rate = 0.0 if not self.training else getattr(self, "attention_dropout", 0.0)
|
||||
|
||||
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||
# special handling using sample packing
|
||||
qkv = torch.stack(
|
||||
@@ -219,13 +120,7 @@ def flashattn_forward(
|
||||
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||
|
||||
output = flash_attn_varlen_qkvpacked_func(
|
||||
qkv,
|
||||
cu_seqlens,
|
||||
max_seqlen,
|
||||
dropout_p=dropout_rate,
|
||||
softmax_scale=None,
|
||||
causal=True,
|
||||
window_size=window_size,
|
||||
qkv, cu_seqlens, max_seqlen, 0.0, softmax_scale=None, causal=True
|
||||
)
|
||||
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
|
||||
elif query_states.shape == key_states.shape:
|
||||
@@ -248,10 +143,9 @@ def flashattn_forward(
|
||||
qkv_unpad,
|
||||
cu_seqlens_q,
|
||||
max_seqlen_q,
|
||||
dropout_p=dropout_rate,
|
||||
0.0,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
output = output_pad_fn(output_unpad)
|
||||
else:
|
||||
@@ -262,9 +156,7 @@ def flashattn_forward(
|
||||
output = flash_attn_kvpacked_func(
|
||||
query_states,
|
||||
torch.stack([key_states, value_states], 2),
|
||||
dropout_p=dropout_rate,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
else:
|
||||
( # pylint: disable=unbalanced-tuple-unpacking
|
||||
@@ -296,10 +188,9 @@ def flashattn_forward(
|
||||
cu_seqlens_k,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
dropout_p=dropout_rate,
|
||||
0.0,
|
||||
softmax_scale=None,
|
||||
causal=is_causal,
|
||||
window_size=window_size,
|
||||
)
|
||||
output = output_pad_fn(output_unpad)
|
||||
|
||||
@@ -648,71 +539,3 @@ class MistralDecoderLayer(OriginalMistralDecoderLayer):
|
||||
outputs += (present_key_value,)
|
||||
|
||||
return outputs
|
||||
|
||||
def mistral_causallm_forward(
|
||||
self: OriginalMistralForCausalLM,
|
||||
input_ids: torch.LongTensor = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||
labels: Optional[torch.LongTensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
output_attentions: Optional[bool] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
*args, **kwargs
|
||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
||||
r"""
|
||||
Args:
|
||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
||||
```"""
|
||||
|
||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
||||
output_hidden_states = (
|
||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
||||
|
||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
||||
outputs = self.model(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
output_attentions=output_attentions,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
|
||||
hidden_states = outputs[0]
|
||||
logits = self.lm_head(hidden_states)
|
||||
|
||||
loss = None
|
||||
if labels is not None:
|
||||
shift_logits = logits
|
||||
if not hasattr(self, "extra_ignored_labels"):
|
||||
self.extra_ignored_labels = torch.full((self.model.config.max_position_embeddings, 1), -100, device=shift_logits.device)
|
||||
|
||||
shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
|
||||
shift_labels = shift_labels.to(shift_logits.device)
|
||||
|
||||
# FAST CROSS ENTROPY
|
||||
loss = fast_cross_entropy_loss(shift_logits, shift_labels)
|
||||
|
||||
if not return_dict:
|
||||
output = (logits,) + outputs[1:]
|
||||
return (loss,) + output if loss is not None else output
|
||||
|
||||
return CausalLMOutputWithPast(
|
||||
loss=loss,
|
||||
logits=logits,
|
||||
past_key_values=outputs.past_key_values,
|
||||
hidden_states=outputs.hidden_states,
|
||||
attentions=outputs.attentions,
|
||||
)
|
||||
@@ -1,65 +0,0 @@
|
||||
"""
|
||||
patches implemented through the trainer hooks to enable NEFT/noisy embeddings per https://arxiv.org/abs/2310.05914
|
||||
"""
|
||||
import torch
|
||||
from peft import PeftModel
|
||||
from transformers import PreTrainedModel
|
||||
|
||||
|
||||
def patch_neft(alpha, model):
|
||||
embeddings = None
|
||||
if isinstance(model, PreTrainedModel):
|
||||
embeddings = model.get_input_embeddings()
|
||||
if isinstance(model, PeftModel):
|
||||
embeddings = model.base_model.get_input_embeddings()
|
||||
if not embeddings:
|
||||
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
|
||||
embeddings.noisy_embedding_alpha = alpha
|
||||
old_forward = embeddings.forward
|
||||
|
||||
# This hack seems to be needed to properly use a custom forward pass
|
||||
# all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
|
||||
bound_method = neft_forward.__get__( # pylint: disable=no-value-for-parameter
|
||||
embeddings, embeddings.__class__
|
||||
)
|
||||
setattr(embeddings, "forward", bound_method)
|
||||
|
||||
embeddings._old_forward = old_forward # pylint: disable=protected-access
|
||||
return model
|
||||
|
||||
|
||||
def unpatch_neft(model):
|
||||
embeddings = None
|
||||
if isinstance(model, PreTrainedModel):
|
||||
embeddings = model.get_input_embeddings()
|
||||
if isinstance(model, PeftModel):
|
||||
embeddings = model.base_model.get_input_embeddings()
|
||||
if not embeddings:
|
||||
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
|
||||
if hasattr(embeddings, "_old_forward"):
|
||||
embeddings.forward = embeddings._old_forward # pylint: disable=protected-access
|
||||
del embeddings._old_forward # pylint: disable=protected-access
|
||||
del embeddings.noisy_embedding_alpha
|
||||
|
||||
|
||||
def neft_forward(self, inputs: torch.Tensor):
|
||||
embeddings = self._old_forward(inputs) # pylint: disable=protected-access
|
||||
|
||||
if self.training:
|
||||
dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
|
||||
mag_norm = self.noisy_embedding_alpha / torch.sqrt(dims)
|
||||
embeddings = embeddings + torch.zeros_like(embeddings).uniform_(
|
||||
-mag_norm, mag_norm
|
||||
)
|
||||
|
||||
return embeddings
|
||||
|
||||
|
||||
def pretrain_hook(cfg, trainer):
|
||||
if cfg.noisy_embedding_alpha:
|
||||
trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)
|
||||
|
||||
|
||||
def post_train_hook(cfg, trainer):
|
||||
if cfg.noisy_embedding_alpha:
|
||||
unpatch_neft(trainer.model)
|
||||
@@ -101,16 +101,3 @@ def get_cu_seqlens_from_pos_ids(position_ids):
|
||||
max_seq_lens.append(max_seq_len)
|
||||
|
||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||
|
||||
|
||||
def set_module_name(model, name, value):
|
||||
if "." in name:
|
||||
parent_name = name.rsplit(".", 1)[0]
|
||||
child_name = name[len(parent_name) + 1 :]
|
||||
parent = model.get_submodule(parent_name)
|
||||
else:
|
||||
parent_name = ""
|
||||
parent = model
|
||||
child_name = name
|
||||
|
||||
setattr(parent, child_name, value)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Module for Alpaca prompt strategy classes"""
|
||||
"""Module containing the AlpacaQAPromptTokenizingStrategy class"""
|
||||
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Tuple
|
||||
|
||||
from axolotl.prompt_tokenizers import (
|
||||
AlpacaPromptTokenizingStrategy,
|
||||
@@ -9,13 +9,9 @@ from axolotl.prompt_tokenizers import (
|
||||
from axolotl.prompters import AlpacaPrompter, PromptStyle, UnpromptedPrompter
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
prompt_style = PromptStyle.CHAT.value
|
||||
if ds_cfg and "conversation" in ds_cfg:
|
||||
prompt_style = ds_cfg["conversation"]
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
return AlpacaPromptTokenizingStrategy(
|
||||
AlpacaPrompter(prompt_style),
|
||||
AlpacaPrompter(PromptStyle.CHAT.value),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
|
||||
@@ -24,7 +24,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
)
|
||||
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||
strategy = SimpleShareGPTPromptTokenizingStrategy(
|
||||
strat = ShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation=conversation,
|
||||
role_key_model=field_model,
|
||||
@@ -34,9 +34,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
if ds_cfg and "strict" in ds_cfg:
|
||||
strategy.strict = ds_cfg["strict"]
|
||||
return strategy
|
||||
if ds_cfg and ds_cfg["skip"]:
|
||||
strat.skip_invalid = True
|
||||
return strat
|
||||
|
||||
|
||||
def load_role(tokenizer, cfg):
|
||||
@@ -57,30 +57,37 @@ def load_guanaco(tokenizer, cfg):
|
||||
)
|
||||
|
||||
|
||||
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
def load_nous(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
conversation = (
|
||||
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
||||
)
|
||||
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||
return NousShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation=conversation,
|
||||
role_key_model=field_model,
|
||||
role_key_human=field_human,
|
||||
),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
class NousShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
basic sharegpt strategy to grab conversations from the sample row
|
||||
basic sharegpt strategy used by nous/ldj for input/output keyed data
|
||||
"""
|
||||
|
||||
_strict = True
|
||||
def get_conversation_thread(self):
|
||||
return "conversation"
|
||||
|
||||
@property
|
||||
def strict(self):
|
||||
return self._strict
|
||||
|
||||
@strict.setter
|
||||
def strict(self, strict):
|
||||
self._strict = strict
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
if self.strict:
|
||||
return conversations
|
||||
# remap roles - allow for assistant turn
|
||||
role_map = {"human": "human", "assistant": "gpt", "gpt": "gpt"}
|
||||
turns = [
|
||||
{"from": role_map[t["from"]], "value": t["value"]} for t in conversations
|
||||
]
|
||||
def map_conversation_thread(self, conversation):
|
||||
turns = []
|
||||
for turn in conversation:
|
||||
turns.append({"from": "human", "value": turn["input"]})
|
||||
turns.append({"from": "gpt", "value": turn["output"]})
|
||||
return turns
|
||||
|
||||
|
||||
@@ -89,10 +96,11 @@ class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrateg
|
||||
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
def map_conversation_thread(self, conversation):
|
||||
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
||||
turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
|
||||
turns = [
|
||||
{"from": turn["role"], "value": turn["value"]} for turn in conversation
|
||||
]
|
||||
return turns
|
||||
|
||||
|
||||
@@ -101,11 +109,11 @@ class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
sharegpt strategy that remaps oasst data to sharegpt format
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
def map_conversation_thread(self, conversation):
|
||||
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
||||
role_map = {"prompter": "human", "assistant": "gpt"}
|
||||
turns = [
|
||||
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
|
||||
{"from": role_map[turn["role"]], "value": turn["text"]}
|
||||
for turn in conversation
|
||||
]
|
||||
return turns
|
||||
|
||||
@@ -2,7 +2,9 @@
|
||||
|
||||
import abc
|
||||
import copy
|
||||
import functools
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
from fastchat.conversation import Conversation
|
||||
@@ -45,8 +47,6 @@ class PromptTokenizingStrategy(abc.ABC):
|
||||
self.prompter = prompter
|
||||
self.tokenizer: PreTrainedTokenizer = tokenizer
|
||||
self.train_on_inputs = train_on_inputs
|
||||
# sequence_len and max_length can be different for CompletionPromptTokenizingStrategy.
|
||||
# TODO: Document how they are different.
|
||||
self.sequence_len = sequence_len
|
||||
self.max_length = sequence_len
|
||||
|
||||
@@ -58,34 +58,57 @@ class PromptTokenizingStrategy(abc.ABC):
|
||||
def supports_batched(self):
|
||||
return False
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _get_user_token(self):
|
||||
try:
|
||||
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>")
|
||||
if isinstance(id_or_ids, (int,)):
|
||||
return id_or_ids
|
||||
except KeyError:
|
||||
pass
|
||||
return False
|
||||
|
||||
@functools.lru_cache(maxsize=128)
|
||||
def _get_assistant_token(self):
|
||||
try:
|
||||
id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>")
|
||||
if isinstance(id_or_ids, (int,)):
|
||||
return id_or_ids
|
||||
except KeyError:
|
||||
pass
|
||||
return False
|
||||
|
||||
def _tokenize(
|
||||
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
||||
) -> BatchEncoding:
|
||||
empty = BatchEncoding(data={"input_ids": [], "attention_mask": []})
|
||||
result: BatchEncoding
|
||||
if not prompt:
|
||||
LOG.warning("Empty text requested for tokenization.")
|
||||
return empty
|
||||
|
||||
result = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
padding=False,
|
||||
return_tensors=None,
|
||||
)
|
||||
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
|
||||
else:
|
||||
result = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
padding=False,
|
||||
return_tensors=None,
|
||||
)
|
||||
if len(result["input_ids"]) == 0:
|
||||
LOG.warning("Tokenizer result is empty. You may want to audit your dataset")
|
||||
return empty
|
||||
|
||||
if (
|
||||
result["input_ids"][-1] != self.tokenizer.eos_token_id
|
||||
len(result["input_ids"]) > 0
|
||||
and result["input_ids"][-1] != self.tokenizer.eos_token_id
|
||||
and len(result["input_ids"]) < self.max_length
|
||||
and add_eos_token
|
||||
):
|
||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||
result["attention_mask"].append(1)
|
||||
|
||||
if result["input_ids"][0] == self.tokenizer.bos_token_id and strip_bos_token:
|
||||
if (
|
||||
len(result["input_ids"]) > 0
|
||||
and result["input_ids"][0] == self.tokenizer.bos_token_id
|
||||
and strip_bos_token
|
||||
):
|
||||
result["input_ids"] = result["input_ids"][1:]
|
||||
result["attention_mask"] = result["attention_mask"][1:]
|
||||
|
||||
@@ -121,7 +144,7 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
if not self.train_on_inputs:
|
||||
user_prompt_len = len(tokenized_prompt["input_ids"])
|
||||
# TODO this could be sped up using numpy array slicing
|
||||
tokenized_prompt["labels"] = [IGNORE_INDEX] * user_prompt_len
|
||||
tokenized_prompt["labels"] = [-100] * user_prompt_len
|
||||
tokenized_res_prompt = self._tokenize(
|
||||
response, strip_bos_token=True, add_eos_token=True
|
||||
)
|
||||
@@ -245,7 +268,6 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
raise NotImplementedError
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
# pylint: disable=duplicate-code
|
||||
(
|
||||
instruction,
|
||||
input, # pylint: disable=redefined-builtin
|
||||
@@ -270,7 +292,7 @@ class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
user_prompt_len = len(tokenized_user_prompt["input_ids"])
|
||||
# TODO this could be sped up using numpy array slicing
|
||||
tokenized_full_prompt["labels"] = [
|
||||
IGNORE_INDEX
|
||||
-100
|
||||
] * user_prompt_len + tokenized_full_prompt["labels"][user_prompt_len:]
|
||||
|
||||
return tokenized_full_prompt
|
||||
@@ -330,105 +352,141 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
Tokenizing strategy for ShareGPT prompts.
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
return prompt["conversations"]
|
||||
_skip_invalid = False
|
||||
|
||||
@property
|
||||
def supports_batched(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def skip_invalid(self):
|
||||
return self._skip_invalid
|
||||
|
||||
@skip_invalid.setter
|
||||
def skip_invalid(self, value):
|
||||
self._skip_invalid = value
|
||||
|
||||
def get_conversation_thread(self):
|
||||
return "conversations"
|
||||
|
||||
def map_conversation_thread(self, conversation):
|
||||
return conversation
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
# Initial values. We will append to these as we go through the conversation.
|
||||
result, current_len = tokenize_prompt_default()
|
||||
conversation: Conversation = (
|
||||
self.prompter._conversation.copy() # pylint: disable=protected-access
|
||||
)
|
||||
tokenized_res = defaultdict(lambda: [])
|
||||
conv_field = self.get_conversation_thread()
|
||||
for prmpt in prompt[conv_field]:
|
||||
result, current_len = tokenize_prompt_default()
|
||||
user_token = self._get_user_token()
|
||||
assistant_token = self._get_assistant_token()
|
||||
conversation: Conversation = (
|
||||
self.prompter._conversation # pylint: disable=protected-access
|
||||
)
|
||||
try:
|
||||
for _, part in enumerate(
|
||||
self.prompter.build_prompt(self.map_conversation_thread(prmpt))
|
||||
):
|
||||
if isinstance(part, tuple):
|
||||
if conversation.roles[0] in part[0]:
|
||||
turn = part[0] + part[1] if not user_token else part[1]
|
||||
# this is still the user query, we should
|
||||
if not part[1].strip():
|
||||
err_msg = f"user turn has empty text: {prmpt}"
|
||||
if self.skip_invalid:
|
||||
raise ValueError(err_msg)
|
||||
LOG.warning(err_msg)
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
if user_token:
|
||||
res["input_ids"] = [user_token, *res["input_ids"]]
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
elif conversation.roles[1] in part[0]:
|
||||
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
||||
turn = part[0] + part[1] if not assistant_token else part[1]
|
||||
# this should be the assistant response, should end with an eos token
|
||||
if not part[1].strip():
|
||||
err_msg = f"assistant turn has empty text: {prmpt}"
|
||||
if self.skip_invalid:
|
||||
raise ValueError(err_msg)
|
||||
LOG.warning(err_msg)
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=True,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
if assistant_token:
|
||||
res["input_ids"] = [
|
||||
assistant_token,
|
||||
*res["input_ids"],
|
||||
]
|
||||
# not masked out from labels
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
elif part[0] == "":
|
||||
turn = part[1]
|
||||
# this is only ever the first part, should include the bos token and the user query
|
||||
res = self._tokenize(
|
||||
turn, add_eos_token=False, strip_bos_token=False
|
||||
)
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
else:
|
||||
err_msg = f"unhandled role: {part[0]}"
|
||||
if self.skip_invalid:
|
||||
raise ValueError(err_msg)
|
||||
LOG.warning(err_msg)
|
||||
continue
|
||||
|
||||
# support for custom roles from the dataset, only useful for vicuna style prompts/roles
|
||||
role_remap = []
|
||||
# pylint: disable=duplicate-code
|
||||
result, current_len = parse_tokenized_to_result(
|
||||
result,
|
||||
current_len,
|
||||
res,
|
||||
labels,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
)
|
||||
for key, val in sorted(result.items(), key=lambda x: x[0]):
|
||||
tokenized_res[key].append(val)
|
||||
except (KeyError, AssertionError, IndexError) as err:
|
||||
raise InvalidDataException(str(err)) from err
|
||||
except ValueError as err:
|
||||
LOG.warning("skipping prompt: %s", str(err))
|
||||
return tokenized_res
|
||||
|
||||
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
||||
if not prompt.strip():
|
||||
LOG.warning("Empty text requested for tokenization.")
|
||||
result = BatchEncoding(data={"input_ids": [], "attention_mask": []})
|
||||
else:
|
||||
result = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
max_length=self.sequence_len,
|
||||
padding=False,
|
||||
return_tensors=None,
|
||||
)
|
||||
if (
|
||||
conversation.name == "vicuna_v1.1"
|
||||
and "roles" in prompt
|
||||
and len(prompt["roles"]) >= 2
|
||||
len(result["input_ids"]) > 0
|
||||
and result["input_ids"][-1] != self.tokenizer.eos_token_id
|
||||
and len(result["input_ids"]) < self.sequence_len
|
||||
and add_eos_token
|
||||
):
|
||||
role_remap = [
|
||||
{"from": conversation.roles[0], "to": prompt["roles"][0]},
|
||||
{"from": conversation.roles[1], "to": prompt["roles"][1]},
|
||||
]
|
||||
result["input_ids"].append(self.tokenizer.eos_token_id)
|
||||
result["attention_mask"].append(1)
|
||||
|
||||
try:
|
||||
for _, part in enumerate(
|
||||
self.prompter.build_prompt(self.get_conversation_thread(prompt))
|
||||
):
|
||||
if not isinstance(part, tuple):
|
||||
LOG.warning(f"expected tuple, got {part}")
|
||||
continue
|
||||
if (
|
||||
len(result["input_ids"]) > 0
|
||||
and result["input_ids"][0] == self.tokenizer.bos_token_id
|
||||
and strip_bos_token
|
||||
):
|
||||
result["input_ids"] = result["input_ids"][1:]
|
||||
result["attention_mask"] = result["attention_mask"][1:]
|
||||
|
||||
user, assistant = conversation.roles
|
||||
role, content = part
|
||||
|
||||
# Uses "in" because role contains extra characters
|
||||
if user in role:
|
||||
role = (
|
||||
role.replace(role_remap[0]["from"], role_remap[0]["to"])
|
||||
if role_remap
|
||||
else role
|
||||
)
|
||||
turn = role + content
|
||||
# this is still the user query, we should
|
||||
if not content.strip():
|
||||
LOG.warning(f"user turn has empty text: {prompt}")
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
elif assistant in role:
|
||||
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
||||
role = (
|
||||
role.replace(role_remap[1]["from"], role_remap[1]["to"])
|
||||
if role_remap
|
||||
else role
|
||||
)
|
||||
turn = role + content
|
||||
# this should be the assistant response, should end with an eos token
|
||||
if not content.strip():
|
||||
LOG.warning(f"assistant turn has empty text: {prompt}")
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=True,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
role_res = self._tokenize(
|
||||
role.rstrip(),
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
# not masked out from labels
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
len_role = len(role_res["input_ids"])
|
||||
labels[:len_role] = [IGNORE_TOKEN_ID] * min(len_role, len(labels))
|
||||
elif role == "":
|
||||
turn = content
|
||||
# this is only ever the first part, should include the bos token and the user query
|
||||
res = self._tokenize(
|
||||
turn, add_eos_token=False, strip_bos_token=False
|
||||
)
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
else:
|
||||
LOG.warning(f"unhandled role: {role}")
|
||||
continue
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
result, current_len = parse_tokenized_to_result(
|
||||
result,
|
||||
current_len,
|
||||
res,
|
||||
labels,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
)
|
||||
return result
|
||||
except (KeyError, AssertionError, IndexError) as err:
|
||||
raise InvalidDataException(str(err)) from err
|
||||
result["labels"] = result["input_ids"].copy()
|
||||
return result
|
||||
|
||||
|
||||
def tokenize_prompt_default() -> Tuple[Dict[str, List[int]], int]:
|
||||
|
||||
@@ -4,12 +4,10 @@ import logging
|
||||
from enum import Enum
|
||||
from typing import Generator, Optional, Union
|
||||
|
||||
from colorama import Fore
|
||||
from fastchat.conversation import Conversation, get_conv_template
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
IGNORE_TOKEN_ID = -100
|
||||
REPR_TEMPLATE = "\n<start>\n" + Fore.CYAN + "{full_prompt}" + Fore.RESET + "\n<end>\n"
|
||||
|
||||
|
||||
class PromptStyle(Enum):
|
||||
@@ -22,13 +20,7 @@ class PromptStyle(Enum):
|
||||
CHATML = "chatml"
|
||||
|
||||
|
||||
class Prompter:
|
||||
"""
|
||||
Base prompter class for all prompters
|
||||
"""
|
||||
|
||||
|
||||
class AlpacaPrompter(Prompter):
|
||||
class AlpacaPrompter:
|
||||
"""
|
||||
Base class for alpaca prompters
|
||||
"""
|
||||
@@ -63,38 +55,29 @@ class AlpacaPrompter(Prompter):
|
||||
)
|
||||
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
|
||||
|
||||
def _build_result(self, instruction, input_text, output):
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
if input_text:
|
||||
res = (
|
||||
self.system_format.format(system=self.system_prompt)
|
||||
if self.system_prompt
|
||||
else ""
|
||||
) + self.turn_format.format(instruction=instruction, input=input_text)
|
||||
else:
|
||||
res = (
|
||||
self.system_format.format(system=self.system_no_input_prompt)
|
||||
if self.system_no_input_prompt
|
||||
else ""
|
||||
) + self.turn_no_input_format.format(instruction=instruction)
|
||||
if output:
|
||||
res = f"{res}{output}"
|
||||
|
||||
return res
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
yield self._build_result(instruction, input, output)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return REPR_TEMPLATE.format(
|
||||
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
|
||||
)
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
if input:
|
||||
res = (
|
||||
self.system_format.format(system=self.system_prompt)
|
||||
if self.system_prompt
|
||||
else ""
|
||||
) + self.turn_format.format(instruction=instruction, input=input)
|
||||
else:
|
||||
res = (
|
||||
self.system_format.format(system=self.system_no_input_prompt)
|
||||
if self.system_prompt
|
||||
else ""
|
||||
) + self.turn_no_input_format.format(instruction=instruction)
|
||||
if output:
|
||||
res = f"{res}{output}"
|
||||
yield res
|
||||
|
||||
|
||||
class UnpromptedPrompter(AlpacaPrompter):
|
||||
@@ -165,7 +148,7 @@ class NomicGPT4AllPrompter(AlpacaPrompter):
|
||||
"""
|
||||
|
||||
|
||||
class ReflectAlpacaPrompter(Prompter):
|
||||
class ReflectAlpacaPrompter:
|
||||
"""
|
||||
Prompter for ReflectAlpaca
|
||||
"""
|
||||
@@ -208,14 +191,14 @@ class ReflectAlpacaPrompter(Prompter):
|
||||
)
|
||||
self.response_split = "ASSISTANT:"
|
||||
|
||||
def _build_result(
|
||||
def build_prompt(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
reflection: Union[None, str] = None,
|
||||
corrected: Union[None, str] = None,
|
||||
):
|
||||
) -> Generator[str, None, None]:
|
||||
# returns the full prompt from instruction and optional input
|
||||
# if a label (=response, =output) is provided, it's also appended.
|
||||
if input:
|
||||
@@ -229,30 +212,7 @@ class ReflectAlpacaPrompter(Prompter):
|
||||
corrected=corrected,
|
||||
)
|
||||
res = f"{res}{label}"
|
||||
|
||||
return res
|
||||
|
||||
def build_prompt(
|
||||
self,
|
||||
instruction: str,
|
||||
input: Union[None, str] = None, # pylint: disable=redefined-builtin
|
||||
output: Union[None, str] = None,
|
||||
reflection: Union[None, str] = None,
|
||||
corrected: Union[None, str] = None,
|
||||
) -> Generator[str, None, None]:
|
||||
# pylint: disable=duplicate-code
|
||||
yield self._build_result(
|
||||
instruction,
|
||||
input,
|
||||
output,
|
||||
reflection,
|
||||
corrected,
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return REPR_TEMPLATE.format(
|
||||
full_prompt=self._build_result("{instruction}", "{input}", "{output}")
|
||||
)
|
||||
yield res
|
||||
|
||||
|
||||
SHAREGPT_ASSERTION_FAILED_ROLE = (
|
||||
@@ -260,7 +220,7 @@ SHAREGPT_ASSERTION_FAILED_ROLE = (
|
||||
)
|
||||
|
||||
|
||||
class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||
class ShareGPTPrompter: # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
A prompter that generates prompts for the ShareGPT
|
||||
"""
|
||||
@@ -287,7 +247,7 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||
if role_key_model:
|
||||
self.role_key_model = role_key_model
|
||||
|
||||
def _build_result(self, source):
|
||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||
if len(source) < 2:
|
||||
# If there isn't a back and forth conversation, ignore it
|
||||
# also happens on the data splitting leaving empty conversations
|
||||
@@ -314,28 +274,17 @@ class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods
|
||||
raise err
|
||||
|
||||
conv.messages = []
|
||||
for _, sentence in enumerate(source):
|
||||
for j, sentence in enumerate(source):
|
||||
role = roles[sentence["from"]]
|
||||
if len(conv.messages) > 0 and (
|
||||
(role == conv.messages[-1][0]) or (role not in conv.roles)
|
||||
):
|
||||
if role != conv.roles[j % 2]:
|
||||
LOG.warning(f"{SHAREGPT_ASSERTION_FAILED_ROLE}: {sentence}")
|
||||
conv.append_message(role, sentence["value"])
|
||||
|
||||
return conv.get_turns()
|
||||
|
||||
def build_prompt(self, source) -> Generator[str, None, None]:
|
||||
turns = self._build_result(source)
|
||||
|
||||
for part in turns:
|
||||
for part in conv.get_turns():
|
||||
if part[0] and not part[1]:
|
||||
LOG.warning(f"role with empty message: {part[0]}")
|
||||
yield part
|
||||
|
||||
def __repr__(self) -> str:
|
||||
turns = self._build_result([{"from": "{from}", "value": "{value}"}])
|
||||
return "\n".join([REPR_TEMPLATE.format(full_prompt=part) for part in turns])
|
||||
|
||||
|
||||
class ShareGPTPrompterV2(ShareGPTPrompter):
|
||||
"""
|
||||
@@ -353,15 +302,3 @@ class ShareGPTPrompterV2(ShareGPTPrompter):
|
||||
role_key_human=role_key_human,
|
||||
role_key_model=role_key_model,
|
||||
)
|
||||
|
||||
|
||||
class UnsupportedPrompter(Prompter):
|
||||
"""
|
||||
A dummy class for custom prompters
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def __repr__(self):
|
||||
return "Pre-tokenized or custom dataset types are unsupported for logging"
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
@@ -9,14 +10,11 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
import transformers.modelcard
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import Dataset
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.monkeypatch import neft_embeddings
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
@@ -26,7 +24,7 @@ src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
|
||||
configure_logging()
|
||||
LOG = get_logger("axolotl.train")
|
||||
LOG = logging.getLogger("axolotl.train")
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -41,13 +39,13 @@ class TrainDatasetMeta:
|
||||
|
||||
|
||||
def train(
|
||||
*, cfg: DictDefault, cli_args: TrainerCliArgs, dataset_meta: TrainDatasetMeta
|
||||
*,
|
||||
cfg: DictDefault,
|
||||
cli_args: TrainerCliArgs,
|
||||
dataset_meta: TrainDatasetMeta,
|
||||
):
|
||||
# load the tokenizer first
|
||||
LOG.debug(
|
||||
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
||||
main_process_only=True,
|
||||
)
|
||||
LOG.info(f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}")
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
train_dataset = dataset_meta.train_dataset
|
||||
@@ -55,15 +53,14 @@ def train(
|
||||
total_num_steps = dataset_meta.total_num_steps
|
||||
|
||||
# Load the model and tokenizer
|
||||
msg = "loading model"
|
||||
if cfg.adapter:
|
||||
msg += " and peft_config..."
|
||||
LOG.debug(msg)
|
||||
LOG.info("loading model and (optionally) peft_config...")
|
||||
model, peft_config = load_model(cfg, tokenizer, inference=cli_args.inference)
|
||||
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
||||
if (
|
||||
cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints
|
||||
) or cfg.resume_from_checkpoint is True:
|
||||
possible_checkpoints = [
|
||||
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
||||
]
|
||||
@@ -76,7 +73,9 @@ def train(
|
||||
LOG.info(
|
||||
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
||||
)
|
||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||
resume_from_checkpoint = (
|
||||
cfg.resume_from_checkpoint if cfg.resume_from_checkpoint is not True else None
|
||||
)
|
||||
|
||||
trainer = setup_trainer(
|
||||
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
||||
@@ -114,7 +113,6 @@ def train(
|
||||
if cfg.group_by_length:
|
||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||
|
||||
pretrain_hooks(cfg, trainer)
|
||||
if cfg.flash_optimum:
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
||||
@@ -122,15 +120,9 @@ def train(
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
else:
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
post_train_hooks(cfg, trainer)
|
||||
|
||||
LOG.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}")
|
||||
|
||||
# post training
|
||||
for name, module in model.named_modules():
|
||||
if hasattr(module, "_post_training"):
|
||||
module._post_training(model, name) # pylint: disable=protected-access
|
||||
|
||||
if trainer.is_fsdp_enabled:
|
||||
trainer.accelerator.state.fsdp_plugin.set_state_dict_type("FULL_STATE_DICT")
|
||||
LOG.info("Set FSDP state dict type to FULL_STATE_DICT for saving.")
|
||||
@@ -146,22 +138,6 @@ def train(
|
||||
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
|
||||
if cfg.fsdp:
|
||||
trainer.save_model(cfg.output_dir)
|
||||
elif cfg.deepspeed and is_deepspeed_zero3_enabled():
|
||||
# Copied over from: https://github.com/huggingface/accelerate/blob/5ae611118057232f441055f7ef9ba0b0f2b8d533/docs/source/usage_guides/deepspeed.md#saving-and-loading
|
||||
trainer.accelerator.wait_for_everyone()
|
||||
unwrapped_model = trainer.accelerator.unwrap_model(trainer.model_wrapped)
|
||||
|
||||
# Saves the whole/unpartitioned fp16 model when in ZeRO Stage-3 to the output directory if
|
||||
# `stage3_gather_16bit_weights_on_model_save` is True in DeepSpeed Config file or
|
||||
# `zero3_save_16bit_model` is True in DeepSpeed Plugin.
|
||||
# For Zero Stages 1 and 2, models are saved as usual in the output directory.
|
||||
# The model name saved is `pytorch_model.bin`
|
||||
unwrapped_model.save_pretrained(
|
||||
cfg.output_dir,
|
||||
is_main_process=trainer.accelerator.is_main_process,
|
||||
save_function=trainer.accelerator.save,
|
||||
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
|
||||
)
|
||||
elif cfg.local_rank == 0:
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.reverse(model)
|
||||
@@ -172,23 +148,3 @@ def train(
|
||||
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
||||
|
||||
return model, tokenizer
|
||||
|
||||
|
||||
def pretrain_hooks(cfg, trainer):
|
||||
"""
|
||||
Run hooks right before kicking off the training
|
||||
:param cfg:
|
||||
:param trainer:
|
||||
:return:
|
||||
"""
|
||||
neft_embeddings.pretrain_hook(cfg, trainer)
|
||||
|
||||
|
||||
def post_train_hooks(cfg, trainer):
|
||||
"""
|
||||
Run hooks right after training completes
|
||||
:param cfg:
|
||||
:param trainer:
|
||||
:return:
|
||||
"""
|
||||
neft_embeddings.post_train_hook(cfg, trainer)
|
||||
|
||||
@@ -37,7 +37,7 @@ from axolotl.utils.distributed import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from axolotl.core.trainer_builder import AxolotlTrainingArguments
|
||||
from axolotl.utils.trainer import AxolotlTrainingArguments
|
||||
|
||||
LOG = logging.getLogger("axolotl.callbacks")
|
||||
IGNORE_INDEX = -100
|
||||
@@ -124,36 +124,6 @@ class GPUStatsCallback(
|
||||
return control
|
||||
|
||||
|
||||
class LossWatchDogCallback(TrainerCallback):
|
||||
"""Callback to track loss and stop training if loss is too high"""
|
||||
|
||||
def __init__(self, cfg):
|
||||
self.cfg = cfg
|
||||
self.logged = False
|
||||
self.violations = 0
|
||||
self.threshold = cfg.loss_watchdog_threshold
|
||||
self.patience = cfg.loss_watchdog_patience or 3
|
||||
|
||||
def on_step_end(
|
||||
self,
|
||||
_args: TrainingArguments,
|
||||
state: TrainerState,
|
||||
control: TrainerControl,
|
||||
**_kwargs,
|
||||
):
|
||||
if len(state.log_history) > 0 and "loss" in state.log_history[-1]:
|
||||
if state.log_history[-1]["loss"] > self.threshold:
|
||||
self.violations += 1
|
||||
if self.violations >= self.patience:
|
||||
LOG.warning(
|
||||
"Loss is too high, stopping training (loss_watchdog_threshold)"
|
||||
)
|
||||
control.should_training_stop = True
|
||||
else:
|
||||
self.violations = 0
|
||||
return control
|
||||
|
||||
|
||||
def bench_eval_callback_factory(trainer, tokenizer):
|
||||
accuracy = evaluate.load("accuracy")
|
||||
abcd_idx = [
|
||||
@@ -544,27 +514,3 @@ def log_prediction_callback_factory(trainer: Trainer, tokenizer):
|
||||
return control
|
||||
|
||||
return LogPredictionCallback
|
||||
|
||||
|
||||
class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
||||
"""Callback to save axolotl config to wandb"""
|
||||
|
||||
def __init__(self, axolotl_config_path):
|
||||
self.axolotl_config_path = axolotl_config_path
|
||||
|
||||
def on_train_begin(
|
||||
self,
|
||||
args: AxolotlTrainingArguments, # pylint: disable=unused-argument
|
||||
state: TrainerState, # pylint: disable=unused-argument
|
||||
control: TrainerControl,
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
if is_main_process():
|
||||
try:
|
||||
artifact = wandb.Artifact(name="axolotl-config", type="config")
|
||||
artifact.add_file(local_path=self.axolotl_config_path)
|
||||
wandb.run.log_artifact(artifact)
|
||||
LOG.info("Axolotl config has been saved to WandB as an artifact.")
|
||||
except (FileNotFoundError, ConnectionError) as err:
|
||||
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
||||
return control
|
||||
|
||||
@@ -119,30 +119,3 @@ class DataCollatorForSeq2Seq:
|
||||
features["decoder_input_ids"] = decoder_input_ids
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
"""
|
||||
Collator for multipack specific to the using the BatchSampler
|
||||
"""
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
chunked_data = {}
|
||||
for feature in features[0].keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(1) * np.array(item[feature])
|
||||
for item in features
|
||||
if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
features = [chunked_data]
|
||||
return super().__call__(features, return_tensors=return_tensors)
|
||||
|
||||
@@ -27,7 +27,7 @@ def choose_device(cfg):
|
||||
|
||||
cfg.device = get_device()
|
||||
if cfg.world_size == 1:
|
||||
cfg.device_map = cfg.device_map or "auto"
|
||||
cfg.device_map = "auto"
|
||||
else:
|
||||
if cfg.device.startswith("cuda"):
|
||||
cfg.device_map = {"": torch.cuda.current_device()}
|
||||
@@ -79,9 +79,6 @@ def normalize_config(cfg):
|
||||
|
||||
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
|
||||
|
||||
if not cfg.base_model_config:
|
||||
cfg.base_model_config = cfg.base_model
|
||||
|
||||
model_config = load_model_config(cfg)
|
||||
cfg.model_config_type = model_config.model_type
|
||||
|
||||
@@ -122,22 +119,6 @@ def normalize_config(cfg):
|
||||
or (cfg.model_type and "mistral" in cfg.model_type.lower())
|
||||
)
|
||||
|
||||
cfg.is_qwen_derived_model = (
|
||||
(
|
||||
hasattr(model_config, "model_type")
|
||||
and model_config.model_type
|
||||
in [
|
||||
"qwen",
|
||||
]
|
||||
)
|
||||
or cfg.is_qwen_derived_model
|
||||
or "qwen" in cfg.base_model.lower()
|
||||
or (cfg.model_type and "qwen" in cfg.model_type.lower())
|
||||
)
|
||||
|
||||
if isinstance(cfg.learning_rate, str):
|
||||
cfg.learning_rate = float(cfg.learning_rate)
|
||||
|
||||
log_gpu_memory_usage(LOG, "baseline", cfg.device)
|
||||
|
||||
|
||||
@@ -178,11 +159,7 @@ def validate_config(cfg):
|
||||
"batch_size is not recommended. Please use gradient_accumulation_steps instead.",
|
||||
"To calculate the equivalent gradient_accumulation_steps, divide batch_size / micro_batch_size / number of gpus.",
|
||||
)
|
||||
if (
|
||||
cfg.eval_batch_size
|
||||
and cfg.micro_batch_size
|
||||
and cfg.eval_batch_size != cfg.micro_batch_size
|
||||
):
|
||||
if cfg.eval_batch_size != cfg.micro_batch_size:
|
||||
LOG.warning(
|
||||
"eval_batch_size != micro_batch_size. This can lead to VRAM instability."
|
||||
)
|
||||
@@ -212,15 +189,9 @@ def validate_config(cfg):
|
||||
if not cfg.load_in_4bit:
|
||||
raise ValueError("Require cfg.load_in_4bit to be True for qlora")
|
||||
|
||||
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
||||
raise ValueError("Fused modules are not supported with QLoRA")
|
||||
|
||||
if not cfg.load_in_8bit and cfg.adapter == "lora":
|
||||
LOG.warning("We recommend setting `load_in_8bit: true` for LORA finetuning")
|
||||
|
||||
if cfg.adapter == "lora" and (cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp):
|
||||
raise ValueError("Fused modules are not supported with LoRA")
|
||||
|
||||
if cfg.relora_steps:
|
||||
if cfg.adapter not in ("lora", "qlora"):
|
||||
raise ValueError("cfg.adapter must be lora or qlora to use ReLoRA")
|
||||
@@ -234,9 +205,6 @@ def validate_config(cfg):
|
||||
if cfg.lr_scheduler == "one_cycle":
|
||||
raise ValueError("ReLoRA is not compatible with the one_cycle scheduler")
|
||||
|
||||
if cfg.flash_attn_fuse_qkv or cfg.flash_attn_fuse_mlp:
|
||||
raise ValueError("Fused modules are not supported with ReLoRA")
|
||||
|
||||
if cfg.trust_remote_code:
|
||||
LOG.warning(
|
||||
"`trust_remote_code` is set to true. Please make sure that you reviewed the remote code/model."
|
||||
@@ -371,39 +339,6 @@ def validate_config(cfg):
|
||||
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.sample_packing
|
||||
and cfg.eval_table_size
|
||||
and cfg.eval_sample_packing is not False
|
||||
):
|
||||
raise ValueError(
|
||||
"eval_table_size and eval_sample_packing are not supported together with sample_packing. Please set 'eval_sample_packing' to false."
|
||||
)
|
||||
|
||||
if not cfg.adapter and (cfg.load_in_8bit or cfg.load_in_4bit):
|
||||
raise ValueError(
|
||||
"load_in_8bit and load_in_4bit are not supported without setting an adapter."
|
||||
"If you want to full finetune, please turn off load_in_8bit and load_in_4bit."
|
||||
)
|
||||
|
||||
if cfg.rope_scaling:
|
||||
LOG.warning("`rope_scaling` should now be be a key under `model_config`")
|
||||
|
||||
if cfg.warmup_steps and cfg.warmup_ratio:
|
||||
raise ValueError("warmup_steps and warmup_ratio are mutually exclusive")
|
||||
|
||||
if cfg.is_qwen_derived_model and cfg.gradient_checkpointing:
|
||||
LOG.warning(
|
||||
"Gradient checkpointing is broken for Qwen models for transformers>=4.35.0, except main branch."
|
||||
)
|
||||
|
||||
if cfg.wandb_run_id and not cfg.wandb_name:
|
||||
cfg.wandb_name = cfg.wandb_run_id
|
||||
|
||||
LOG.warning(
|
||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
||||
)
|
||||
|
||||
# TODO
|
||||
# MPT 7b
|
||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||
|
||||
@@ -16,7 +16,6 @@ from datasets import (
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.datasets import ConstantLengthDataset, TokenizedPromptDataset
|
||||
from axolotl.prompt_strategies import load
|
||||
from axolotl.prompt_tokenizers import (
|
||||
@@ -34,10 +33,8 @@ from axolotl.prompters import (
|
||||
JeopardyPrompter,
|
||||
MultipleChoiceConcisePrompter,
|
||||
MultipleChoiceExplainPrompter,
|
||||
Prompter,
|
||||
ReflectAlpacaPrompter,
|
||||
SummarizeTLDRPrompter,
|
||||
UnsupportedPrompter,
|
||||
)
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import is_main_process, zero_first
|
||||
@@ -47,6 +44,7 @@ from axolotl.utils.trainer import (
|
||||
)
|
||||
|
||||
LOG = logging.getLogger("axolotl")
|
||||
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||
|
||||
|
||||
def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
||||
@@ -57,10 +55,9 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
||||
|
||||
|
||||
def prepare_dataset(cfg, tokenizer):
|
||||
prompters = []
|
||||
if not cfg.pretraining_dataset:
|
||||
with zero_first(is_main_process()):
|
||||
train_dataset, eval_dataset, prompters = load_prepare_datasets(
|
||||
train_dataset, eval_dataset = load_prepare_datasets(
|
||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||
)
|
||||
else:
|
||||
@@ -73,33 +70,25 @@ def prepare_dataset(cfg, tokenizer):
|
||||
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
||||
train_dataset = train_dataset.with_format("torch")
|
||||
eval_dataset = None
|
||||
return train_dataset, eval_dataset, cfg.max_steps, prompters
|
||||
return train_dataset, eval_dataset, cfg.max_steps
|
||||
|
||||
with zero_first(is_main_process()):
|
||||
train_dataset, eval_dataset = process_datasets_for_packing(
|
||||
cfg, train_dataset, eval_dataset, tokenizer
|
||||
)
|
||||
|
||||
if eval_dataset and cfg.sample_packing and cfg.eval_sample_packing is not False:
|
||||
total_eval_steps = calculate_total_num_steps(cfg, eval_dataset, update=False)
|
||||
if total_eval_steps == 0:
|
||||
raise ValueError(
|
||||
"eval dataset split is too small for sample_packing. You should set `eval_sample_packing: False`. "
|
||||
)
|
||||
|
||||
if cfg.max_steps:
|
||||
total_num_steps = min(
|
||||
calculate_total_num_steps(cfg, train_dataset), cfg.max_steps
|
||||
calculate_total_num_steps(cfg, train_dataset, tokenizer), cfg.max_steps
|
||||
)
|
||||
LOG.info(f"Maximum number of steps set at {total_num_steps}")
|
||||
else:
|
||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset)
|
||||
return train_dataset, eval_dataset, total_num_steps, prompters
|
||||
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
|
||||
return train_dataset, eval_dataset, total_num_steps
|
||||
|
||||
|
||||
def load_tokenized_prepared_datasets(
|
||||
tokenizer, cfg, default_dataset_prepared_path
|
||||
) -> Tuple[DatasetDict, List[Prompter]]:
|
||||
) -> DatasetDict:
|
||||
tokenizer_name = tokenizer.__class__.__name__
|
||||
ds_hash = str(
|
||||
md5(
|
||||
@@ -107,12 +96,7 @@ def load_tokenized_prepared_datasets(
|
||||
str(cfg.sequence_len)
|
||||
+ "@"
|
||||
+ "|".join(
|
||||
sorted(
|
||||
[
|
||||
f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
|
||||
for d in cfg.datasets
|
||||
]
|
||||
)
|
||||
sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
|
||||
)
|
||||
+ "|"
|
||||
+ tokenizer_name
|
||||
@@ -125,7 +109,6 @@ def load_tokenized_prepared_datasets(
|
||||
else Path(default_dataset_prepared_path) / ds_hash
|
||||
)
|
||||
dataset = None
|
||||
prompters = []
|
||||
use_auth_token = cfg.hf_use_auth_token
|
||||
try:
|
||||
if cfg.push_dataset_to_hub:
|
||||
@@ -164,99 +147,50 @@ def load_tokenized_prepared_datasets(
|
||||
yield dataset
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
for config_dataset in for_d_in_datasets(cfg.datasets):
|
||||
for d in for_d_in_datasets(cfg.datasets):
|
||||
ds: Union[Dataset, DatasetDict] = None
|
||||
ds_from_hub = False
|
||||
try:
|
||||
load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
d.path,
|
||||
name=d.name,
|
||||
streaming=True,
|
||||
token=use_auth_token,
|
||||
)
|
||||
ds_from_hub = True
|
||||
except (FileNotFoundError, ConnectionError):
|
||||
pass
|
||||
|
||||
ds_from_cloud = False
|
||||
storage_options = {}
|
||||
remote_file_system = None
|
||||
if config_dataset.path.startswith("s3://"):
|
||||
try:
|
||||
import aiobotocore.session # type: ignore
|
||||
import s3fs # type: ignore
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"s3:// paths require aiobotocore and s3fs to be installed"
|
||||
) from exc
|
||||
|
||||
# Takes credentials from ~/.aws/credentials for default profile
|
||||
s3_session = aiobotocore.session.AioSession(profile="default")
|
||||
storage_options = {"session": s3_session}
|
||||
remote_file_system = s3fs.S3FileSystem(**storage_options)
|
||||
elif config_dataset.path.startswith(
|
||||
"gs://"
|
||||
) or config_dataset.path.startswith("gcs://"):
|
||||
try:
|
||||
import gcsfs # type: ignore
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"gs:// or gcs:// paths require gcsfs to be installed"
|
||||
) from exc
|
||||
|
||||
# gcsfs will use default credentials from the environment else anon
|
||||
# https://gcsfs.readthedocs.io/en/latest/#credentials
|
||||
storage_options = {"token": None}
|
||||
remote_file_system = gcsfs.GCSFileSystem(**storage_options)
|
||||
# TODO: Figure out how to get auth creds passed
|
||||
# elif config_dataset.path.startswith("adl://") or config_dataset.path.startswith("abfs://"):
|
||||
# try:
|
||||
# import adlfs
|
||||
# except ImportError as exc:
|
||||
# raise ImportError(
|
||||
# "adl:// or abfs:// paths require adlfs to be installed"
|
||||
# ) from exc
|
||||
|
||||
# # Gen 1
|
||||
# storage_options = {
|
||||
# "tenant_id": TENANT_ID,
|
||||
# "client_id": CLIENT_ID,
|
||||
# "client_secret": CLIENT_SECRET,
|
||||
# }
|
||||
# # Gen 2
|
||||
# storage_options = {
|
||||
# "account_name": ACCOUNT_NAME,
|
||||
# "account_key": ACCOUNT_KEY,
|
||||
# }
|
||||
|
||||
# remote_file_system = adlfs.AzureBlobFileSystem(**storage_options)
|
||||
try:
|
||||
if remote_file_system and remote_file_system.exists(
|
||||
config_dataset.path
|
||||
):
|
||||
ds_from_cloud = True
|
||||
except (FileNotFoundError, ConnectionError):
|
||||
except (FileNotFoundError, ValueError):
|
||||
pass
|
||||
|
||||
# prefer local dataset, even if hub exists
|
||||
local_path = Path(config_dataset.path)
|
||||
local_path = Path(d.path)
|
||||
if local_path.exists():
|
||||
if local_path.is_dir():
|
||||
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
|
||||
ds = load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.data_files,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
if not d.type:
|
||||
ds = load_from_disk(d.path)
|
||||
else:
|
||||
ds = load_dataset(
|
||||
d.path,
|
||||
name=d.name,
|
||||
data_files=d.data_files,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
elif local_path.is_file():
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
|
||||
ds_type = "json"
|
||||
if d.ds_type:
|
||||
ds_type = d.ds_type
|
||||
elif ".parquet" in d.path:
|
||||
ds_type = "parquet"
|
||||
elif ".arrow" in d.path:
|
||||
ds_type = "arrow"
|
||||
elif ".csv" in d.path:
|
||||
ds_type = "csv"
|
||||
elif ".txt" in d.path:
|
||||
ds_type = "text"
|
||||
ds = load_dataset(
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
name=d.name,
|
||||
data_files=d.path,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
@@ -266,41 +200,25 @@ def load_tokenized_prepared_datasets(
|
||||
)
|
||||
elif ds_from_hub:
|
||||
ds = load_dataset(
|
||||
config_dataset.path,
|
||||
name=config_dataset.name,
|
||||
d.path,
|
||||
name=d.name,
|
||||
streaming=False,
|
||||
data_files=config_dataset.data_files,
|
||||
data_files=d.data_files,
|
||||
token=use_auth_token,
|
||||
)
|
||||
elif ds_from_cloud and remote_file_system:
|
||||
if remote_file_system.isdir(config_dataset.path):
|
||||
ds = load_from_disk(
|
||||
config_dataset.path,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
elif remote_file_system.isfile(config_dataset.path):
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
ds = load_dataset(
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
streaming=False,
|
||||
split=None,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
else:
|
||||
if isinstance(config_dataset.data_files, str):
|
||||
if isinstance(d.data_files, str):
|
||||
fp = hf_hub_download(
|
||||
repo_id=config_dataset.path,
|
||||
repo_id=d.path,
|
||||
repo_type="dataset",
|
||||
filename=config_dataset.data_files,
|
||||
filename=d.data_files,
|
||||
)
|
||||
elif isinstance(config_dataset.data_files, list):
|
||||
elif isinstance(d.data_files, list):
|
||||
fp = []
|
||||
for file in config_dataset.data_files:
|
||||
for file in d.data_files:
|
||||
fp.append(
|
||||
hf_hub_download(
|
||||
repo_id=config_dataset.path,
|
||||
repo_id=d.path,
|
||||
repo_type="dataset",
|
||||
filename=file,
|
||||
)
|
||||
@@ -310,27 +228,21 @@ def load_tokenized_prepared_datasets(
|
||||
"data_files must be either a string or list of strings"
|
||||
)
|
||||
ds = load_dataset(
|
||||
"json",
|
||||
name=config_dataset.name,
|
||||
data_files=fp,
|
||||
streaming=False,
|
||||
split=None,
|
||||
"json", name=d.name, data_files=fp, streaming=False, split=None
|
||||
)
|
||||
if not ds:
|
||||
raise ValueError("unhandled dataset load")
|
||||
# support for using a subset of the data
|
||||
if config_dataset.shards:
|
||||
if d.shards:
|
||||
if "train" in ds:
|
||||
ds = ds.shuffle(seed=seed)["train"].shard(
|
||||
num_shards=config_dataset.shards, index=0
|
||||
num_shards=d.shards, index=0
|
||||
)
|
||||
else:
|
||||
ds = ds.shuffle(seed=seed).shard(
|
||||
num_shards=config_dataset.shards, index=0
|
||||
)
|
||||
ds = ds.shuffle(seed=seed).shard(num_shards=d.shards, index=0)
|
||||
|
||||
d_base_type = d_prompt_style = None
|
||||
d_type = config_dataset.type
|
||||
d_type = d.type
|
||||
if isinstance(d_type, str):
|
||||
d_type_split = d_type.split(":")
|
||||
d_base_type = d_type_split[0]
|
||||
@@ -339,33 +251,115 @@ def load_tokenized_prepared_datasets(
|
||||
ds = ds["train"]
|
||||
elif (
|
||||
isinstance(ds, DatasetDict)
|
||||
and config_dataset.train_on_split
|
||||
and config_dataset.train_on_split in ds
|
||||
and d.train_on_split
|
||||
and d.train_on_split in ds
|
||||
):
|
||||
ds = ds[config_dataset.train_on_split]
|
||||
ds = ds[d.train_on_split]
|
||||
elif isinstance(ds, DatasetDict):
|
||||
raise ValueError(
|
||||
f"no train split found for dataset {config_dataset.path}, you may specify a split with 'train_on_split: `"
|
||||
f"no train split found for dataset {d.path}, you may specify a split with 'train_on_split: `"
|
||||
)
|
||||
if (
|
||||
"input_ids" in ds.features
|
||||
and "attention_mask" in ds.features
|
||||
and "labels" in ds.features
|
||||
):
|
||||
# dataset is already tokenized, just drop it straight in
|
||||
datasets.append(ds)
|
||||
elif isinstance(d.type, DictDefault):
|
||||
ds_strategy = load("user_defined", tokenizer, cfg, d.type.to_dict())
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif ds_strategy := load(d.type, tokenizer, cfg, d):
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "alpaca":
|
||||
ds_strategy = AlpacaPromptTokenizingStrategy(
|
||||
AlpacaPrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "explainchoice":
|
||||
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
||||
MultipleChoiceExplainPrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "concisechoice":
|
||||
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
||||
MultipleChoiceConcisePrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "summarizetldr":
|
||||
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
|
||||
SummarizeTLDRPrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "jeopardy":
|
||||
ds_strategy = JeopardyPromptTokenizingStrategy(
|
||||
JeopardyPrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "oasst":
|
||||
ds_strategy = OpenAssistantPromptTokenizingStrategy(
|
||||
AlpacaPrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "gpteacher":
|
||||
ds_strategy = GPTeacherPromptTokenizingStrategy(
|
||||
GPTeacherPrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "reflection":
|
||||
ds_strategy = AlpacaReflectionPTStrategy(
|
||||
ReflectAlpacaPrompter(d_prompt_style),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
else:
|
||||
suffix = ""
|
||||
if ":load_" in d.type:
|
||||
suffix = f" Did you mean {d.type.replace(':load_', '.load_')}?"
|
||||
LOG.error(f"unhandled prompt tokenization strategy: {d.type}. {suffix}")
|
||||
raise ValueError(
|
||||
f"unhandled prompt tokenization strategy: {d.type} {suffix}"
|
||||
)
|
||||
|
||||
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
|
||||
config_dataset=config_dataset,
|
||||
dataset=ds,
|
||||
tokenizer=tokenizer,
|
||||
cfg=cfg,
|
||||
d_base_type=d_base_type,
|
||||
d_prompt_style=d_prompt_style,
|
||||
)
|
||||
datasets.append(dataset_wrapper)
|
||||
prompters.append(dataset_prompter)
|
||||
|
||||
LOG.info("merging datasets")
|
||||
dataset = concatenate_datasets(datasets)
|
||||
|
||||
if len(datasets) > 1:
|
||||
LOG.info("shuffle merged datasets")
|
||||
dataset = dataset.shuffle(seed=seed)
|
||||
if cfg.local_rank == 0:
|
||||
if cfg.local_rank == 0 and cfg.dataset_prepared_path:
|
||||
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||
dataset.save_to_disk(prepared_ds_path)
|
||||
if cfg.push_dataset_to_hub:
|
||||
@@ -376,32 +370,14 @@ def load_tokenized_prepared_datasets(
|
||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
||||
)
|
||||
|
||||
return dataset, prompters
|
||||
|
||||
|
||||
def get_ds_type(config_dataset: DictDefault):
|
||||
"""
|
||||
Get the dataset type from the path if it's not specified
|
||||
"""
|
||||
ds_type = "json"
|
||||
if config_dataset.ds_type:
|
||||
ds_type = config_dataset.ds_type
|
||||
elif ".parquet" in config_dataset.path:
|
||||
ds_type = "parquet"
|
||||
elif ".arrow" in config_dataset.path:
|
||||
ds_type = "arrow"
|
||||
elif ".csv" in config_dataset.path:
|
||||
ds_type = "csv"
|
||||
elif ".txt" in config_dataset.path:
|
||||
ds_type = "text"
|
||||
return ds_type
|
||||
return dataset
|
||||
|
||||
|
||||
def load_prepare_datasets(
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
cfg,
|
||||
default_dataset_prepared_path,
|
||||
) -> Tuple[Dataset, Dataset, List[Prompter]]:
|
||||
) -> Tuple[Dataset, Dataset]:
|
||||
max_packed_sequence_len = (
|
||||
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
||||
)
|
||||
@@ -410,7 +386,6 @@ def load_prepare_datasets(
|
||||
) # make sure we don't accidentally set it larger than sequence_len
|
||||
|
||||
tokenizer_name = tokenizer.__class__.__name__
|
||||
prompters: List[Prompter] = []
|
||||
if cfg.max_packed_sequence_len is not None:
|
||||
# see if we can go ahead and load the stacked dataset
|
||||
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
||||
@@ -466,7 +441,7 @@ def load_prepare_datasets(
|
||||
f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
|
||||
)
|
||||
else:
|
||||
dataset, prompters = load_tokenized_prepared_datasets(
|
||||
dataset = load_tokenized_prepared_datasets(
|
||||
tokenizer, cfg, default_dataset_prepared_path
|
||||
)
|
||||
|
||||
@@ -508,7 +483,7 @@ def load_prepare_datasets(
|
||||
private=True,
|
||||
)
|
||||
else:
|
||||
dataset, prompters = load_tokenized_prepared_datasets(
|
||||
dataset = load_tokenized_prepared_datasets(
|
||||
tokenizer, cfg, default_dataset_prepared_path
|
||||
)
|
||||
|
||||
@@ -544,13 +519,14 @@ def load_prepare_datasets(
|
||||
train_fingerprint = md5(to_hash_train)
|
||||
test_fingerprint = md5(to_hash_test)
|
||||
|
||||
dataset = dataset.train_test_split(
|
||||
test_size=cfg.val_set_size,
|
||||
shuffle=False,
|
||||
seed=cfg.seed or 42,
|
||||
train_new_fingerprint=train_fingerprint,
|
||||
test_new_fingerprint=test_fingerprint,
|
||||
)
|
||||
with zero_first(is_main_process()):
|
||||
dataset = dataset.train_test_split(
|
||||
test_size=cfg.val_set_size,
|
||||
shuffle=False,
|
||||
seed=cfg.seed or 42,
|
||||
train_new_fingerprint=train_fingerprint,
|
||||
test_new_fingerprint=test_fingerprint,
|
||||
)
|
||||
|
||||
train_dataset = dataset["train"]
|
||||
eval_dataset = dataset["test"]
|
||||
@@ -558,144 +534,7 @@ def load_prepare_datasets(
|
||||
train_dataset = dataset
|
||||
eval_dataset = None
|
||||
|
||||
return train_dataset, eval_dataset, prompters
|
||||
|
||||
|
||||
def get_dataset_wrapper(
|
||||
config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
|
||||
):
|
||||
dataset_wrapper = None
|
||||
dataset_prompter = None
|
||||
|
||||
if (
|
||||
"input_ids" in dataset.features
|
||||
and "attention_mask" in dataset.features
|
||||
and "labels" in dataset.features
|
||||
):
|
||||
# dataset is already tokenized, just drop it straight in
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = dataset
|
||||
elif isinstance(config_dataset.type, DictDefault):
|
||||
ds_strategy = load(
|
||||
"user_defined", tokenizer, cfg, config_dataset.type.to_dict()
|
||||
)
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
elif ds_strategy := load(config_dataset.type, tokenizer, cfg, config_dataset):
|
||||
dataset_prompter = UnsupportedPrompter()
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
elif d_base_type == "alpaca":
|
||||
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
||||
ds_strategy = AlpacaPromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "explainchoice":
|
||||
dataset_prompter = MultipleChoiceExplainPrompter(d_prompt_style)
|
||||
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "concisechoice":
|
||||
dataset_prompter = MultipleChoiceConcisePrompter(d_prompt_style)
|
||||
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "summarizetldr":
|
||||
dataset_prompter = SummarizeTLDRPrompter(d_prompt_style)
|
||||
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "jeopardy":
|
||||
dataset_prompter = JeopardyPrompter(d_prompt_style)
|
||||
ds_strategy = JeopardyPromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "oasst":
|
||||
dataset_prompter = AlpacaPrompter(d_prompt_style)
|
||||
ds_strategy = OpenAssistantPromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "gpteacher":
|
||||
dataset_prompter = GPTeacherPrompter(d_prompt_style)
|
||||
ds_strategy = GPTeacherPromptTokenizingStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
elif d_base_type == "reflection":
|
||||
dataset_prompter = ReflectAlpacaPrompter(d_prompt_style)
|
||||
ds_strategy = AlpacaReflectionPTStrategy(
|
||||
dataset_prompter,
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(
|
||||
ds_strategy, dataset, process_count=cfg.dataset_processes
|
||||
)
|
||||
dataset_wrapper = ds_wrapper
|
||||
else:
|
||||
suffix = ""
|
||||
if ":load_" in config_dataset.type:
|
||||
suffix = f" Did you mean {config_dataset.type.replace(':load_', '.load_')}?"
|
||||
LOG.error(
|
||||
f"unhandled prompt tokenization strategy: {config_dataset.type}. {suffix}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"unhandled prompt tokenization strategy: {config_dataset.type} {suffix}"
|
||||
)
|
||||
|
||||
return dataset_wrapper, dataset_prompter
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
|
||||
def encode_pretraining(
|
||||
|
||||
302
src/axolotl/utils/dataloader.py
Normal file
302
src/axolotl/utils/dataloader.py
Normal file
@@ -0,0 +1,302 @@
|
||||
# pylint: skip-file
|
||||
import hashlib
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
from typing import Any, Callable, List, Union
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
from torch.utils.data import DistributedSampler, Sampler
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.dataloader")
|
||||
|
||||
|
||||
@numba.njit
|
||||
def ffd_check(a: np.ndarray, c: int, n: int):
|
||||
# First-fit-decreasing bin packing
|
||||
# Check if a[] could fit in n bins with capacity c
|
||||
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
|
||||
|
||||
a = np.sort(a)[::-1]
|
||||
bins = np.full((n,), c, dtype=a.dtype)
|
||||
for size in a:
|
||||
not_found = True
|
||||
for idx in range(n):
|
||||
if bins[idx] >= size:
|
||||
bins[idx] -= size
|
||||
not_found = False
|
||||
break
|
||||
|
||||
if not_found:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@numba.njit
|
||||
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
|
||||
# First-fit-decreasing bin packing (with result return)
|
||||
|
||||
indices = np.argsort(a)[::-1]
|
||||
a = a[indices]
|
||||
|
||||
bins: List[Any] = []
|
||||
bins_result: List[Any] = []
|
||||
for a_id, size in enumerate(a):
|
||||
add_new = True
|
||||
for idx in range(len(bins)):
|
||||
if bins[idx] >= size:
|
||||
bins[idx] -= size
|
||||
bins_result[idx].append(indices[a_id] + start_index)
|
||||
add_new = False
|
||||
break
|
||||
|
||||
if add_new:
|
||||
bins.append(c - size)
|
||||
bins_result.append([indices[a_id] + start_index])
|
||||
|
||||
return bins_result, len(a)
|
||||
|
||||
|
||||
@numba.njit
|
||||
def allocate(
|
||||
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
|
||||
):
|
||||
"""
|
||||
:param lengths: array of lengths of each sample
|
||||
:param lengths_cumsum: cumulative sum of consecutive lengths
|
||||
:param rank: rank for this process
|
||||
:param c: length of tokens per batch
|
||||
:param n: number of ranks
|
||||
:return:
|
||||
"""
|
||||
# Dynamic batch allocator, similar to Multifit
|
||||
# https://en.wikipedia.org/wiki/Multifit_algorithm
|
||||
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
|
||||
|
||||
s = 0
|
||||
start_index = 0
|
||||
result = []
|
||||
result_totseqs = []
|
||||
|
||||
while True:
|
||||
# binary search [left, right)
|
||||
left = 1
|
||||
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
|
||||
|
||||
while right - left > 1:
|
||||
mid = (left + right) // 2
|
||||
if ffd_check(lengths[start_index : start_index + mid], c, n):
|
||||
left = mid
|
||||
else:
|
||||
right = mid
|
||||
|
||||
# use length left
|
||||
batch, tot_seqs = ffd_with_result(
|
||||
lengths[start_index : start_index + left], c, start_index
|
||||
)
|
||||
if len(batch) < n:
|
||||
break
|
||||
|
||||
start_index += left
|
||||
s = lengths_cumsum[start_index - 1]
|
||||
|
||||
# add local rank
|
||||
result.append(batch[rank])
|
||||
# add total seqs for all ranks
|
||||
result_totseqs.append(tot_seqs)
|
||||
# yield batch[rank], tot_seqs, s, len(result) * c * n
|
||||
return result, result_totseqs, s, len(result) * c * n
|
||||
|
||||
|
||||
def chunk(iterable, n):
|
||||
"""
|
||||
Chunk data into tuples of length n
|
||||
"""
|
||||
# batched('ABCDEFG', 3) --> ABC DEF G
|
||||
if n < 1:
|
||||
raise ValueError("n must be at least one")
|
||||
it = iter(iterable)
|
||||
while batch := tuple(itertools.islice(it, n)):
|
||||
yield batch
|
||||
|
||||
|
||||
def hash_indices(lst: List[int]) -> str:
|
||||
# Convert the list of integers to a string representation
|
||||
concatenated = ",".join(map(str, lst))
|
||||
|
||||
# Generate the hash
|
||||
sha256 = hashlib.sha256()
|
||||
sha256.update(concatenated.encode())
|
||||
|
||||
return sha256.hexdigest()
|
||||
|
||||
|
||||
class MultipackDistributedDataloader:
|
||||
"""Unpadded data loading using Multipack.
|
||||
Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
|
||||
Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dataset: Any,
|
||||
collate_fn: Callable,
|
||||
seq_max_length: int = 2048,
|
||||
batch_size: int = 1,
|
||||
sampler: Union[Sampler, DistributedSampler] = None,
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
sample_packing_seq_len_multiplier: int = 1,
|
||||
device_count: int = 1,
|
||||
):
|
||||
# Dataset
|
||||
self.dataset = dataset
|
||||
self.lengths = (
|
||||
dataset.data.column("position_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: x[-1] + 1)
|
||||
.values
|
||||
)
|
||||
assert isinstance(self.lengths, np.ndarray)
|
||||
assert batch_size % sample_packing_seq_len_multiplier == 0
|
||||
assert batch_size >= sample_packing_seq_len_multiplier
|
||||
self.sampler = sampler
|
||||
self.batch_size = batch_size
|
||||
self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier
|
||||
self.seq_max_length = seq_max_length
|
||||
self.batch_max_length = batch_size * seq_max_length
|
||||
self.collate_fn = collate_fn
|
||||
|
||||
self.num_replicas = 1
|
||||
self.rank = 0
|
||||
|
||||
# statistics
|
||||
self.eff_total_used = 0
|
||||
self.eff_total_slots = 0
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||
self.device_count = device_count
|
||||
|
||||
def generate_batches(self, set_stats=False):
|
||||
LOG.info("generating packed batches")
|
||||
if self.sampler:
|
||||
indices = [idx for idx in self.sampler]
|
||||
else:
|
||||
indices = range(0, len(self.dataset))
|
||||
|
||||
LOG.info(hash_indices(indices))
|
||||
lengths = self.lengths[indices]
|
||||
lengths_cumsum = np.cumsum(lengths)
|
||||
|
||||
batches, totseqs, total_used, total_slots = allocate(
|
||||
lengths=lengths,
|
||||
lengths_cumsum=lengths_cumsum,
|
||||
rank=self.rank,
|
||||
# c=self.batch_max_length,
|
||||
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
|
||||
n=self.num_replicas,
|
||||
)
|
||||
|
||||
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
|
||||
|
||||
# statistics
|
||||
if set_stats:
|
||||
self.eff_total_used += total_used
|
||||
self.eff_total_slots += total_slots
|
||||
|
||||
return batches, totseqs
|
||||
|
||||
def __iter__(self):
|
||||
if hasattr(self.sampler, "set_epoch"):
|
||||
new_epoch = self.sampler.epoch + 1
|
||||
self.sampler.set_epoch(new_epoch)
|
||||
LOG.info(f"calling sampler.set_epoch({new_epoch})")
|
||||
all_batches, _ = self.generate_batches(set_stats=True)
|
||||
features = self.dataset.features.keys()
|
||||
len_remaining = self._len_est()
|
||||
for batches in chunk(
|
||||
all_batches, self.batch_size // self.sample_packing_seq_len_multiplier
|
||||
):
|
||||
chunked_data = []
|
||||
attn_mask_cum_idx = 0
|
||||
for batch in batches:
|
||||
concatenated = {}
|
||||
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
|
||||
for feature in features:
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
|
||||
for idx, item in enumerate(batched_data)
|
||||
if feature in item
|
||||
]
|
||||
attn_mask_cum_idx += len(batched_data)
|
||||
concatenated[feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature])
|
||||
for item in batched_data
|
||||
if feature in item
|
||||
]
|
||||
concatenated[feature] = np.concatenate(arrays)
|
||||
chunked_data.append(concatenated)
|
||||
yield self.collate_fn(chunked_data)
|
||||
len_remaining -= 1
|
||||
if not len_remaining:
|
||||
return
|
||||
# yield a no-op for cases where we don't have any data left to pack
|
||||
for i in range(0, len_remaining):
|
||||
yield self.collate_fn(
|
||||
[
|
||||
{
|
||||
"input_ids": [0],
|
||||
"labels": [-100],
|
||||
"attention_mask": [True],
|
||||
"position_ids": [0],
|
||||
}
|
||||
]
|
||||
)
|
||||
|
||||
def _len_est(self):
|
||||
lengths_sum = np.sum(self.lengths)
|
||||
lengths_sum_per_device = lengths_sum // self.device_count
|
||||
LOG.info(
|
||||
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||
f"total_num_tokens per device: {lengths_sum_per_device}"
|
||||
)
|
||||
|
||||
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
||||
return (
|
||||
math.floor(
|
||||
0.99
|
||||
* lengths_sum_per_device
|
||||
/ self.packing_efficiency_estimate
|
||||
// self.seq_max_length
|
||||
// self.batch_size
|
||||
)
|
||||
- 1
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
# this doesn't return the actual length b/c with distributed samplers, not all dataloaders get
|
||||
# the same share of total tokens
|
||||
# if not self.eff_total_used:
|
||||
# batches, _ = self.generate_batches(set_stats=True)
|
||||
# LOG.info(
|
||||
# f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||
# f"actual packing efficiency: {self.efficiency()}"
|
||||
# )
|
||||
return max(1, self._len_est())
|
||||
|
||||
def len_w_stats(self):
|
||||
if not self.eff_total_used:
|
||||
batches, _ = self.generate_batches(set_stats=True)
|
||||
LOG.info(
|
||||
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||
f"actual packing efficiency: {self.efficiency()}"
|
||||
)
|
||||
return max(1, self._len_est())
|
||||
|
||||
def efficiency(self):
|
||||
return self.eff_total_used / self.eff_total_slots
|
||||
@@ -50,17 +50,6 @@ def get_world_size():
|
||||
return int(os.getenv("WORLD_SIZE", "1"))
|
||||
|
||||
|
||||
@contextmanager
|
||||
def zero_only():
|
||||
"""
|
||||
Context manager that only runs the enclosed block on the main rank.
|
||||
"""
|
||||
if is_main_process():
|
||||
yield
|
||||
else:
|
||||
yield None
|
||||
|
||||
|
||||
@contextmanager
|
||||
def zero_first(is_main):
|
||||
"""
|
||||
|
||||
@@ -17,6 +17,7 @@ from transformers import ( # noqa: F401
|
||||
AutoTokenizer,
|
||||
BitsAndBytesConfig,
|
||||
GPTQConfig,
|
||||
LlamaConfig,
|
||||
PreTrainedModel,
|
||||
PreTrainedTokenizerBase,
|
||||
)
|
||||
@@ -28,40 +29,12 @@ from axolotl.utils.dict import DictDefault
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
def check_model_config(cfg: DictDefault, model_config: AutoConfig):
|
||||
quant_config_exists = hasattr(model_config, "quantization_config")
|
||||
quant_config_method_is_gptq = (
|
||||
quant_config_exists
|
||||
and "quant_method" in model_config.quantization_config
|
||||
and model_config.quantization_config["quant_method"] == "gptq"
|
||||
)
|
||||
|
||||
if cfg.gptq and not quant_config_method_is_gptq:
|
||||
raise ValueError(
|
||||
"model_config.quantization_config is not set or quant_method is not set to gptq. "
|
||||
"Please make sure to point to a GPTQ model."
|
||||
)
|
||||
|
||||
if not cfg.gptq and quant_config_exists:
|
||||
raise ValueError(
|
||||
"model_config.quantization_config is set but `gptq` flag is not. "
|
||||
"Please use the `gptq` flag to train quantized model or point to a non-quantized model."
|
||||
)
|
||||
|
||||
|
||||
def load_model_config(cfg):
|
||||
model_config_name = cfg.base_model_config or cfg.base_model
|
||||
trust_remote_code = cfg.trust_remote_code is True
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
trust_remote_code: bool = False or cfg.trust_remote_code
|
||||
return AutoConfig.from_pretrained(
|
||||
model_config_name, trust_remote_code=trust_remote_code
|
||||
)
|
||||
if cfg.model_config:
|
||||
for key, val in cfg.model_config.items():
|
||||
setattr(model_config, key, val)
|
||||
|
||||
check_model_config(cfg, model_config)
|
||||
|
||||
return model_config
|
||||
|
||||
|
||||
def load_tokenizer(cfg):
|
||||
@@ -78,7 +51,7 @@ def load_tokenizer(cfg):
|
||||
if cfg.tokenizer_type:
|
||||
tokenizer_cls = getattr(transformers, cfg.tokenizer_type)
|
||||
|
||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config or cfg.base_model
|
||||
tokenizer_config = cfg.tokenizer_config or cfg.base_model_config
|
||||
tokenizer = tokenizer_cls.from_pretrained(
|
||||
tokenizer_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
@@ -99,6 +72,11 @@ def load_tokenizer(cfg):
|
||||
# set a pad_token, but use eos_token so we don't add a new token
|
||||
tokenizer.pad_token = LLAMA_DEFAULT_EOS_TOKEN
|
||||
|
||||
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||
|
||||
if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
|
||||
tokenizer.add_special_tokens({"pad_token": "[PAD]"})
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
@@ -107,18 +85,6 @@ def load_tokenizer(cfg):
|
||||
if cfg.is_mistral_derived_model and cfg.flash_attention and not cfg.sample_packing:
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
# Qwen base only has single token, so we need to set the special tokens
|
||||
if cfg.is_qwen_derived_model:
|
||||
token_ids = ["bos_token_id", "eos_token_id", "pad_token_id", "unk_token_id"]
|
||||
for attr_name in token_ids:
|
||||
if getattr(tokenizer, attr_name) is None:
|
||||
setattr(tokenizer, attr_name, tokenizer.eod_id)
|
||||
|
||||
token_names = ["bos_token", "eos_token", "pad_token", "unk_token"]
|
||||
for attr_name in token_names:
|
||||
if getattr(tokenizer, attr_name) is None:
|
||||
setattr(tokenizer, attr_name, "<|endoftext|>")
|
||||
|
||||
if cfg.special_tokens:
|
||||
for k, val in cfg.special_tokens.items():
|
||||
tokenizer.add_special_tokens(
|
||||
@@ -132,11 +98,6 @@ def load_tokenizer(cfg):
|
||||
]
|
||||
)
|
||||
|
||||
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
@@ -149,6 +110,7 @@ def load_model(
|
||||
Load a model for a given configuration and tokenizer.
|
||||
"""
|
||||
base_model = cfg.base_model
|
||||
base_model_config = cfg.base_model_config
|
||||
model_type = cfg.model_type
|
||||
model_config = load_model_config(cfg)
|
||||
|
||||
@@ -239,7 +201,6 @@ def load_model(
|
||||
model_kwargs = {}
|
||||
|
||||
model_kwargs["device_map"] = cfg.device_map
|
||||
model_kwargs["max_memory"] = cfg.max_memory
|
||||
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
||||
|
||||
if cfg.model_revision:
|
||||
@@ -277,27 +238,20 @@ def load_model(
|
||||
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
||||
from transformers import LlamaForCausalLM
|
||||
|
||||
config_kwargs = {}
|
||||
if cfg.rope_scaling:
|
||||
config_kwargs["rope_scaling"] = cfg.rope_scaling
|
||||
config = LlamaConfig.from_pretrained(
|
||||
base_model_config,
|
||||
**config_kwargs,
|
||||
)
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=model_config,
|
||||
config=config,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
if cfg.flash_attention and not inference:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
replace_llama_mlp_with_swiglu,
|
||||
replace_llama_qkv_with_fused,
|
||||
)
|
||||
|
||||
if cfg.flash_attn_fuse_mlp:
|
||||
LOG.info("patching with SwiGLU")
|
||||
replace_llama_mlp_with_swiglu(model)
|
||||
|
||||
if cfg.flash_attn_fuse_qkv:
|
||||
LOG.info("patching with fused QKV")
|
||||
replace_llama_qkv_with_fused(model)
|
||||
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
|
||||
# This is a WIP, still an issue with the backward pass
|
||||
# RuntimeError: grad can be implicitly created only for scalar outputs
|
||||
@@ -324,10 +278,10 @@ def load_model(
|
||||
# device=cfg.device,
|
||||
# )
|
||||
# model.train() # sets to train instead of eval mode
|
||||
elif model_type == "PhiForCausalLM":
|
||||
from axolotl.models.phi import PhiForCausalLM
|
||||
elif model_type == "MixFormerSequentialForCausalLM":
|
||||
from axolotl.models.phi import MixFormerSequentialForCausalLM
|
||||
|
||||
model = PhiForCausalLM.from_pretrained(
|
||||
model = MixFormerSequentialForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
@@ -337,55 +291,66 @@ def load_model(
|
||||
if cfg.gptq:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=model_config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
model = getattr(transformers, model_type).from_pretrained(
|
||||
base_model,
|
||||
config=model_config,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
config = AutoConfig.from_pretrained(
|
||||
base_model,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
)
|
||||
# Shouldn't be a problem most of the time. will obviously error if the model doesn't support this
|
||||
# when training starts
|
||||
if (
|
||||
hasattr(model_config, "max_seq_len")
|
||||
and model_config.max_seq_len
|
||||
and cfg.sequence_len > model_config.max_seq_len
|
||||
hasattr(config, "max_seq_len")
|
||||
and config.max_seq_len
|
||||
and cfg.sequence_len > config.max_seq_len
|
||||
):
|
||||
model_config.max_seq_len = cfg.sequence_len
|
||||
config.max_seq_len = cfg.sequence_len
|
||||
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
||||
elif (
|
||||
hasattr(model_config, "max_sequence_length")
|
||||
and model_config.max_sequence_length
|
||||
and cfg.sequence_len > model_config.max_sequence_length
|
||||
hasattr(config, "max_sequence_length")
|
||||
and config.max_sequence_length
|
||||
and cfg.sequence_len > config.max_sequence_length
|
||||
):
|
||||
model_config.max_sequence_length = cfg.sequence_len
|
||||
config.max_sequence_length = cfg.sequence_len
|
||||
LOG.warning(f"increasing context length to {cfg.sequence_len}")
|
||||
if cfg.gptq:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=model_config,
|
||||
config=config,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
else:
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=model_config,
|
||||
config=config,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
except Exception as err: # pylint: disable=broad-exception-caught
|
||||
LOG.error(
|
||||
"Exception raised attempting to load model, retrying with AutoModelForCausalLM"
|
||||
)
|
||||
LOG.exception(err)
|
||||
raise err
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
embeddings_len = (
|
||||
math.ceil(len(tokenizer) / 32) * 32
|
||||
@@ -407,20 +372,6 @@ def load_model(
|
||||
)
|
||||
model.config.max_position_embeddings = cfg.sequence_len
|
||||
|
||||
if (
|
||||
hasattr(model.config, "bos_token_id")
|
||||
and model.config.bos_token_id
|
||||
and model.config.bos_token_id != tokenizer.bos_token_id
|
||||
):
|
||||
model.config.bos_token_id = tokenizer.bos_token_id
|
||||
|
||||
if (
|
||||
hasattr(model.config, "eos_token_id")
|
||||
and model.config.eos_token_id
|
||||
and model.config.eos_token_id != tokenizer.eos_token_id
|
||||
):
|
||||
model.config.eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
if model.device.type == "cuda":
|
||||
log_gpu_memory_usage(LOG, "after model load", model.device)
|
||||
|
||||
@@ -431,27 +382,20 @@ def load_model(
|
||||
if model_config.model_type == "btlm":
|
||||
# don't upcast lm_head for btlm
|
||||
continue
|
||||
if "lm_head" in name or "embed_tokens" in name:
|
||||
if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
|
||||
if hasattr(module, "weight"):
|
||||
module.to(torch.float32)
|
||||
|
||||
needs_fa2_dtype = cfg.adapter or cfg.fsdp
|
||||
skip_prepare_model_for_kbit_training = False
|
||||
|
||||
if cfg.model_config_type == "qwen" and cfg.adapter == "lora":
|
||||
# Qwen doesn't play nicely with LoRA if this is enabled
|
||||
skip_prepare_model_for_kbit_training = True
|
||||
|
||||
if (cfg.adapter == "lora" and load_in_8bit) or (
|
||||
cfg.adapter == "qlora" and cfg.load_in_4bit
|
||||
):
|
||||
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
|
||||
if cfg.gradient_checkpointing:
|
||||
model.gradient_checkpointing_enable()
|
||||
if not skip_prepare_model_for_kbit_training:
|
||||
model = prepare_model_for_kbit_training(
|
||||
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
||||
)
|
||||
model = prepare_model_for_kbit_training(
|
||||
model, use_gradient_checkpointing=cfg.gradient_checkpointing
|
||||
)
|
||||
needs_fa2_dtype = True
|
||||
|
||||
# LlamaRMSNorm layers are in fp32 after kbit_training or full finetune, so we need to
|
||||
@@ -470,7 +414,14 @@ def load_model(
|
||||
if cfg.ddp and not load_in_8bit:
|
||||
model.to(f"cuda:{cfg.local_rank}")
|
||||
|
||||
if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) == 1:
|
||||
if (
|
||||
torch.cuda.device_count() > 1
|
||||
and int(os.getenv("WORLD_SIZE", "1")) > 1
|
||||
and (cfg.load_in_4bit)
|
||||
):
|
||||
# llama is PROBABLY model parallelizable, but the default isn't that it is
|
||||
# so let's only set it for the 4bit, see
|
||||
# https://github.com/johnsmith0031/alpaca_lora_4bit/blob/08b3fca4a4a9e0d3945be1bab4529f100a428636/finetune.py#L130-L133
|
||||
setattr(model, "is_parallelizable", True)
|
||||
setattr(model, "model_parallel", True)
|
||||
|
||||
@@ -536,11 +487,7 @@ def find_all_linear_names(model):
|
||||
cls = (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt, torch.nn.Linear, QuantLinear)
|
||||
lora_module_names = set()
|
||||
for name, module in model.named_modules():
|
||||
if (
|
||||
isinstance(module, cls)
|
||||
or "Linear" in module.__class__.__name__
|
||||
and module.__class__.__name__ not in ("LlamaLinearScalingRotaryEmbedding",)
|
||||
):
|
||||
if isinstance(module, cls) or "Linear" in module.__class__.__name__:
|
||||
names = name.split(".")
|
||||
lora_module_names.add(names[0] if len(names) == 1 else names[-1])
|
||||
|
||||
|
||||
@@ -1,4 +0,0 @@
|
||||
"""
|
||||
axolotl samplers module
|
||||
"""
|
||||
from .multipack import MultipackBatchSampler # noqa: F401
|
||||
@@ -1,196 +0,0 @@
|
||||
# pylint: skip-file
|
||||
"""
|
||||
Multipack Batch Sampler
|
||||
"""
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
from typing import Any, Iterable, List, Union
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
from torch.utils.data import BatchSampler, Sampler
|
||||
|
||||
LOG = logging.getLogger("axolotl.utils.samplers.multipack")
|
||||
|
||||
|
||||
@numba.njit
|
||||
def ffd_check(a: np.ndarray, c: int, n: int):
|
||||
# First-fit-decreasing bin packing
|
||||
# Check if a[] could fit in n bins with capacity c
|
||||
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
|
||||
|
||||
a = np.sort(a)[::-1]
|
||||
bins = np.full((n,), c, dtype=a.dtype)
|
||||
for size in a:
|
||||
not_found = True
|
||||
for idx in range(n):
|
||||
if bins[idx] >= size:
|
||||
bins[idx] -= size
|
||||
not_found = False
|
||||
break
|
||||
|
||||
if not_found:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@numba.njit
|
||||
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
|
||||
# First-fit-decreasing bin packing (with result return)
|
||||
|
||||
indices = np.argsort(a)[::-1]
|
||||
a = a[indices]
|
||||
|
||||
bins: List[Any] = []
|
||||
bins_result: List[Any] = []
|
||||
for a_id, size in enumerate(a):
|
||||
add_new = True
|
||||
for idx in range(len(bins)):
|
||||
if bins[idx] >= size:
|
||||
bins[idx] -= size
|
||||
bins_result[idx].append(indices[a_id] + start_index)
|
||||
add_new = False
|
||||
break
|
||||
|
||||
if add_new:
|
||||
bins.append(c - size)
|
||||
bins_result.append([indices[a_id] + start_index])
|
||||
|
||||
return bins_result
|
||||
|
||||
|
||||
@numba.njit
|
||||
def allocate(
|
||||
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
|
||||
):
|
||||
# Dynamic batch allocator, similar to Multifit
|
||||
# https://en.wikipedia.org/wiki/Multifit_algorithm
|
||||
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
|
||||
|
||||
s = 0
|
||||
start_index = 0
|
||||
result = []
|
||||
|
||||
while True:
|
||||
# binary search [l, r)
|
||||
left = 1
|
||||
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
|
||||
|
||||
while right - left > 1:
|
||||
mid = (left + right) // 2
|
||||
if ffd_check(lengths[start_index : start_index + mid], c, n):
|
||||
left = mid
|
||||
else:
|
||||
right = mid
|
||||
|
||||
# use length l
|
||||
batch = ffd_with_result(
|
||||
lengths[start_index : start_index + left], c, start_index
|
||||
)
|
||||
assert len(batch) <= n
|
||||
if len(batch) < n:
|
||||
break
|
||||
|
||||
start_index += left
|
||||
s = lengths_cumsum[start_index - 1]
|
||||
|
||||
# add local rank
|
||||
result.append(batch[rank])
|
||||
|
||||
return result, s, len(result) * c * n
|
||||
|
||||
|
||||
class MultipackBatchSampler(BatchSampler):
|
||||
"""
|
||||
Batch Sampler class for multipack
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sampler: Union[Sampler[int], Iterable[int]],
|
||||
batch_size: int,
|
||||
drop_last: bool,
|
||||
batch_max_len: int,
|
||||
lengths: np.ndarray,
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
):
|
||||
super().__init__(sampler, batch_size, drop_last)
|
||||
self.batch_size = None
|
||||
self.batch_max_len = batch_max_len
|
||||
self.lengths: np.ndarray = lengths
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||
|
||||
assert isinstance(self.lengths, np.ndarray)
|
||||
|
||||
self.epoch = 0
|
||||
|
||||
# statistics
|
||||
self.eff_total_used = 0
|
||||
self.eff_total_slots = 0
|
||||
|
||||
def set_epoch(self, epoch: int):
|
||||
self.epoch = epoch
|
||||
|
||||
def generate_batches(self, set_stats=False):
|
||||
indices = [idx for idx in self.sampler]
|
||||
|
||||
lengths = self.lengths[indices]
|
||||
lengths_cumsum = np.cumsum(lengths)
|
||||
|
||||
batches, total_used, total_slots = allocate(
|
||||
lengths=lengths,
|
||||
lengths_cumsum=lengths_cumsum,
|
||||
rank=0,
|
||||
c=self.batch_max_len,
|
||||
n=1,
|
||||
)
|
||||
|
||||
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
|
||||
|
||||
# statistics
|
||||
if set_stats:
|
||||
self.eff_total_used += total_used
|
||||
self.eff_total_slots += total_slots
|
||||
|
||||
return batches
|
||||
|
||||
def __iter__(self):
|
||||
batches = self.generate_batches(set_stats=True)
|
||||
return iter(batches)
|
||||
|
||||
def num_batches(self):
|
||||
batches = self.generate_batches(set_stats=True)
|
||||
return len(batches)
|
||||
|
||||
def efficiency(self):
|
||||
return self.eff_total_used / self.eff_total_slots
|
||||
|
||||
def __len__(self):
|
||||
self.num_batches()
|
||||
return self._len_est()
|
||||
|
||||
def _len_est(self):
|
||||
world_size = int(os.getenv("WORLD_SIZE", "1"))
|
||||
lengths_sum = np.sum(self.lengths)
|
||||
lengths_sum_per_device = lengths_sum // world_size
|
||||
LOG.info(
|
||||
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
|
||||
f"total_num_tokens per device: {lengths_sum_per_device}"
|
||||
)
|
||||
|
||||
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
|
||||
return max(
|
||||
0,
|
||||
(
|
||||
world_size
|
||||
* math.floor(
|
||||
0.99
|
||||
* lengths_sum_per_device
|
||||
/ self.packing_efficiency_estimate
|
||||
// self.batch_max_len
|
||||
)
|
||||
- 1
|
||||
),
|
||||
)
|
||||
@@ -34,5 +34,6 @@ def check_example_labels(example, tokenizer, text_only=False):
|
||||
delimiter = "" if text_only else " "
|
||||
LOG.info(delimiter.join(colored_tokens))
|
||||
LOG.info("\n\n\n")
|
||||
print(" ".join(colored_tokens))
|
||||
|
||||
return " ".join(colored_tokens)
|
||||
|
||||
@@ -1,22 +1,50 @@
|
||||
"""Module containing the Trainer class and related functions"""
|
||||
import importlib
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from functools import partial
|
||||
from typing import List
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.cuda
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import set_caching_enabled
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
from datasets import Dataset, set_caching_enabled
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import (
|
||||
DataLoader,
|
||||
DistributedSampler,
|
||||
RandomSampler,
|
||||
SequentialSampler,
|
||||
)
|
||||
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
||||
from transformers.trainer_pt_utils import SequentialDistributedSampler
|
||||
|
||||
from axolotl.core.trainer_builder import HFCausalTrainerBuilder
|
||||
from axolotl.utils.distributed import is_main_process, reduce_and_broadcast, zero_first
|
||||
from axolotl.utils.samplers import MultipackBatchSampler
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.utils.callbacks import (
|
||||
EvalFirstStepCallback,
|
||||
GPUStatsCallback,
|
||||
SaveBetterTransformerModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
log_prediction_callback_factory,
|
||||
)
|
||||
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
||||
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
||||
from axolotl.utils.distributed import (
|
||||
is_distributed,
|
||||
is_main_process,
|
||||
reduce_and_broadcast,
|
||||
zero_first,
|
||||
)
|
||||
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
||||
|
||||
LOG = get_logger("axolotl")
|
||||
LOG = logging.getLogger("axolotl")
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@@ -81,6 +109,269 @@ def trainer_weighted_loss(model_output, labels, shift_labels=True):
|
||||
return weighted_cross_entropy(logits, labels, weights)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingArguments(TrainingArguments):
|
||||
"""
|
||||
Extend the base TrainingArguments for axolotl helpers
|
||||
"""
|
||||
|
||||
lr_quadratic_warmup: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||
)
|
||||
sample_packing: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
eval_sample_packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "Use sample packing for efficient evals."},
|
||||
)
|
||||
sample_packing_efficiency: float = field(
|
||||
default=1.0,
|
||||
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "The maximum sequence length the model can handle"},
|
||||
)
|
||||
sample_packing_seq_len_multiplier: int = field(
|
||||
default=1,
|
||||
metadata={"help": "the multiplier for the max len for packed sequences"},
|
||||
)
|
||||
relora_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how often to reset for ReLoRA"},
|
||||
)
|
||||
relora_warmup_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
bench_split: Optional[str] = field(
|
||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||
)
|
||||
bench_dataset: Optional[str] = field(
|
||||
default="pharaouk/dharma-1/dharma_1_mini.json",
|
||||
metadata={
|
||||
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
|
||||
},
|
||||
)
|
||||
do_bench_eval: Optional[bool] = field(
|
||||
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
||||
)
|
||||
max_bench_samples: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
||||
},
|
||||
)
|
||||
bench_source_max_len: int = field(
|
||||
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
||||
)
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
"""
|
||||
Extend the base Trainer for axolotl helpers
|
||||
"""
|
||||
|
||||
args = None # type: AxolotlTrainingArguments
|
||||
|
||||
def __init__(self, *args, bench_data_collator=None, **kwargs):
|
||||
self.bench_data_collator = bench_data_collator
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
|
||||
):
|
||||
"""
|
||||
Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
|
||||
passed as an argument.
|
||||
|
||||
Args:
|
||||
num_training_steps (int): The number of training steps to do.
|
||||
optimizer (torch.optim.Optimizer): The training optimizer
|
||||
"""
|
||||
|
||||
# fmt: off
|
||||
if self.lr_scheduler is None: # type: ignore # pylint: disable=access-member-before-definition
|
||||
# fmt: on
|
||||
if (
|
||||
self.args.lr_scheduler_type == "cosine"
|
||||
and self.args.lr_quadratic_warmup is True
|
||||
):
|
||||
self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup( # pylint: disable=attribute-defined-outside-init
|
||||
optimizer,
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
|
||||
num_training_steps=num_training_steps,
|
||||
)
|
||||
else:
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
return self.lr_scheduler
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.world_size > 1 and self.args.sample_packing:
|
||||
return DistributedSampler(
|
||||
self.train_dataset,
|
||||
num_replicas=self.args.world_size,
|
||||
rank=self.args.process_index,
|
||||
seed=self.args.seed,
|
||||
)
|
||||
return super()._get_train_sampler()
|
||||
|
||||
def _get_eval_sampler(
|
||||
self, eval_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if (
|
||||
self.args.world_size > 1
|
||||
and self.args.sample_packing
|
||||
and self.args.eval_sample_packing is not False
|
||||
):
|
||||
return SequentialDistributedSampler(
|
||||
eval_dataset,
|
||||
num_replicas=self.args.world_size,
|
||||
rank=self.args.process_index,
|
||||
batch_size=self.args.per_device_eval_batch_size,
|
||||
)
|
||||
return super()._get_eval_sampler(eval_dataset)
|
||||
|
||||
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||
if self.args.sample_packing:
|
||||
train_sampler = self._get_train_sampler()
|
||||
return self.accelerator.prepare(
|
||||
MultipackDistributedDataloader(
|
||||
self.train_dataset,
|
||||
batch_size=self._train_batch_size,
|
||||
seq_max_length=self.args.max_seq_length,
|
||||
collate_fn=self.data_collator,
|
||||
sampler=train_sampler,
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
)
|
||||
)
|
||||
return super().get_train_dataloader()
|
||||
|
||||
def get_eval_dataloader(
|
||||
self, eval_dataset: Optional[Dataset] = None
|
||||
) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
eval_dataset = (
|
||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
)
|
||||
|
||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||
return self.accelerator.prepare(
|
||||
MultipackDistributedDataloader(
|
||||
eval_dataset,
|
||||
batch_size=self.args.eval_batch_size,
|
||||
seq_max_length=self.args.max_seq_length,
|
||||
collate_fn=self.data_collator,
|
||||
sampler=eval_sampler,
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
)
|
||||
)
|
||||
return super().get_eval_dataloader(eval_dataset)
|
||||
|
||||
def _get_bench_sampler(
|
||||
self, bench_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.world_size <= 1:
|
||||
return SequentialSampler(bench_dataset)
|
||||
return None
|
||||
|
||||
def get_bench_dataloader(
|
||||
self,
|
||||
bench_dataset: Dataset,
|
||||
) -> Union[DataLoader, MultipackDistributedDataloader]:
|
||||
dataloader_params = {
|
||||
"batch_size": self.args.eval_batch_size,
|
||||
"collate_fn": self.bench_data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
|
||||
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
|
||||
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
return DataLoader(bench_dataset, **dataloader_params)
|
||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
# use one's weighted cross entropy loss calc
|
||||
# if self.args.sample_packing:
|
||||
# labels = inputs.pop("labels")
|
||||
# outputs = model(**inputs)
|
||||
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
|
||||
# return (loss, outputs) if return_outputs else loss
|
||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||
|
||||
|
||||
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
||||
"""
|
||||
Trainer subclass that uses the OneCycleLR scheduler
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lr_scheduler = None
|
||||
|
||||
def create_scheduler(
|
||||
self,
|
||||
num_training_steps: int,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
):
|
||||
optimizer = self.optimizer if optimizer is None else optimizer
|
||||
num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
|
||||
pct_start = num_warmup_steps / num_training_steps
|
||||
|
||||
self.lr_scheduler = OneCycleLR(
|
||||
optimizer,
|
||||
max_lr=self.args.learning_rate,
|
||||
total_steps=num_training_steps,
|
||||
pct_start=pct_start,
|
||||
div_factor=6,
|
||||
)
|
||||
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
class ReLoRATrainer(AxolotlTrainer):
|
||||
"""
|
||||
Trainer subclass that uses the OneCycleLR scheduler
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.lr_scheduler = None
|
||||
|
||||
def create_scheduler(
|
||||
self,
|
||||
num_training_steps: int,
|
||||
optimizer: Optional[torch.optim.Optimizer] = None,
|
||||
):
|
||||
optimizer = self.optimizer if optimizer is None else optimizer
|
||||
lr_scheduler = super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
if self.args.relora_steps:
|
||||
warmup_steps = (
|
||||
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
||||
)
|
||||
self.lr_scheduler = ReLoRAScheduler(
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
self.args.relora_steps,
|
||||
warmup_steps,
|
||||
)
|
||||
else:
|
||||
self.lr_scheduler = lr_scheduler
|
||||
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
def add_position_ids(sample):
|
||||
sample_len = len(sample["input_ids"])
|
||||
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
|
||||
@@ -131,9 +422,7 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
||||
)
|
||||
|
||||
# Phi doesn't want the attention_mask feature when training
|
||||
if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
|
||||
cfg.is_mistral_derived_model and cfg.flash_attention
|
||||
):
|
||||
if "CodeGenTokenizer" in tokenizer.__class__.__name__:
|
||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||
if eval_dataset:
|
||||
eval_dataset = eval_dataset.remove_columns("attention_mask")
|
||||
@@ -141,35 +430,30 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
||||
return train_dataset, eval_dataset
|
||||
|
||||
|
||||
def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
if not cfg.total_num_tokens:
|
||||
total_num_tokens = np.sum(
|
||||
train_dataset.data.column("input_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
|
||||
.values
|
||||
)
|
||||
LOG.debug(f"total_num_tokens: {total_num_tokens}", main_process_only=True)
|
||||
if update:
|
||||
cfg.total_num_tokens = total_num_tokens
|
||||
|
||||
if not cfg.total_supervised_tokens:
|
||||
total_supervised_tokens = (
|
||||
train_dataset.data.column("labels")
|
||||
.to_pandas()
|
||||
.apply(lambda x: np.sum(np.array(x) != -100))
|
||||
.sum()
|
||||
)
|
||||
LOG.debug(
|
||||
f"`total_supervised_tokens: {total_supervised_tokens}`",
|
||||
main_process_only=True,
|
||||
)
|
||||
if update:
|
||||
cfg.total_supervised_tokens = total_supervised_tokens
|
||||
|
||||
def calculate_total_num_steps(cfg, train_dataset, tokenizer):
|
||||
if cfg.sample_packing:
|
||||
# we have to drop anything longer then sequence len otherwise
|
||||
# flash attention with position ids fails
|
||||
if not cfg.total_num_tokens:
|
||||
LOG.info("calculating total_num_tokens")
|
||||
total_num_tokens = np.sum(
|
||||
train_dataset.data.column("input_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
|
||||
.values
|
||||
)
|
||||
LOG.info(f"total_num_tokens: {total_num_tokens}")
|
||||
cfg.total_num_tokens = total_num_tokens
|
||||
|
||||
if not cfg.total_supervised_tokens:
|
||||
total_supervised_tokens = (
|
||||
train_dataset.data.column("labels")
|
||||
.to_pandas()
|
||||
.apply(lambda x: np.sum(np.array(x) != -100))
|
||||
.sum()
|
||||
)
|
||||
LOG.info(f"`total_supervised_tokens: {total_supervised_tokens}`")
|
||||
cfg.total_supervised_tokens = total_supervised_tokens
|
||||
|
||||
if cfg.sample_packing_eff_est:
|
||||
total_num_steps = (
|
||||
@@ -187,41 +471,40 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
)
|
||||
* cfg.num_epochs
|
||||
)
|
||||
LOG.debug(
|
||||
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}",
|
||||
main_process_only=True,
|
||||
LOG.info(
|
||||
f"total_num_tokens: {cfg.total_num_tokens}, total_num_steps: {total_num_steps}"
|
||||
)
|
||||
else:
|
||||
sampler = MultipackBatchSampler(
|
||||
sampler=RandomSampler(train_dataset),
|
||||
batch_size=cfg.micro_batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=cfg.micro_batch_size
|
||||
* (cfg.max_packed_sequence_len or cfg.sequence_len),
|
||||
lengths=(
|
||||
train_dataset.data.column("position_ids")
|
||||
.to_pandas()
|
||||
.apply(lambda x: x[-1] + 1)
|
||||
.values
|
||||
),
|
||||
)
|
||||
if cfg.world_size > 1 and is_distributed():
|
||||
sampler = DistributedSampler(
|
||||
train_dataset,
|
||||
num_replicas=cfg.world_size,
|
||||
rank=dist.get_rank(),
|
||||
seed=cfg.seed or 42,
|
||||
)
|
||||
else:
|
||||
sampler = RandomSampler(train_dataset)
|
||||
|
||||
data_loader = DataLoader(
|
||||
train_dataset.remove_columns(["length"]),
|
||||
batch_sampler=sampler,
|
||||
data_loader = MultipackDistributedDataloader(
|
||||
train_dataset,
|
||||
batch_size=cfg.micro_batch_size,
|
||||
seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
|
||||
collate_fn=DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
padding="longest",
|
||||
),
|
||||
sampler=sampler,
|
||||
packing_efficiency_estimate=cfg.sample_packing_eff_est,
|
||||
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
||||
device_count=int(os.environ.get("WORLD_SIZE", 1)),
|
||||
)
|
||||
data_loader_len = len(data_loader)
|
||||
actual_eff = sampler.efficiency()
|
||||
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
||||
data_loader_len = data_loader.len_w_stats()
|
||||
actual_eff = data_loader.efficiency()
|
||||
LOG.info(f"data_loader_len: {data_loader_len}")
|
||||
# FIXME: is there a bug here somewhere? the total num steps depends
|
||||
# on the agreed on value for sample_packing_eff_est
|
||||
total_num_steps = int(
|
||||
math.floor(
|
||||
data_loader_len
|
||||
* cfg.num_epochs
|
||||
/ int(os.environ.get("WORLD_SIZE", 1))
|
||||
)
|
||||
)
|
||||
total_num_steps = int(math.floor(data_loader_len * cfg.num_epochs))
|
||||
|
||||
def calc_sample_packing_eff_est(estimates: List[float]):
|
||||
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
|
||||
@@ -234,22 +517,13 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
sample_packing_eff_est = (
|
||||
math.ceil(sample_packing_actual_eff_all * 100.0) / 100.0
|
||||
)
|
||||
if update:
|
||||
cfg.sample_packing_eff_est = sample_packing_eff_est
|
||||
LOG.debug(
|
||||
f"sample_packing_eff_est: {cfg.sample_packing_eff_est}",
|
||||
main_process_only=True,
|
||||
)
|
||||
cfg.sample_packing_eff_est = sample_packing_eff_est
|
||||
LOG.info(f"sample_packing_eff_est: {cfg.sample_packing_eff_est}")
|
||||
else:
|
||||
total_num_steps = int(
|
||||
math.ceil(
|
||||
len(train_dataset)
|
||||
* cfg.num_epochs
|
||||
/ int(os.environ.get("WORLD_SIZE", 1))
|
||||
/ cfg.batch_size
|
||||
)
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
)
|
||||
LOG.debug(f"total_num_steps: {total_num_steps}", main_process_only=True)
|
||||
LOG.info(f"total_num_steps: {total_num_steps}")
|
||||
return total_num_steps
|
||||
|
||||
|
||||
@@ -267,16 +541,248 @@ def setup_fsdp_envs(cfg):
|
||||
] = cfg.fsdp_config.fsdp_transformer_layer_cls_to_wrap
|
||||
|
||||
|
||||
def prepare_optim_env(cfg):
|
||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||
if cfg.fsdp:
|
||||
setup_fsdp_envs(cfg)
|
||||
elif cfg.deepspeed:
|
||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||
|
||||
warmup_steps = (
|
||||
cfg.warmup_steps
|
||||
if cfg.warmup_steps is not None
|
||||
else min(int(0.03 * total_num_steps), 100)
|
||||
)
|
||||
logging_steps = (
|
||||
cfg.logging_steps
|
||||
if cfg.logging_steps is not None
|
||||
else max(min(int(0.005 * total_num_steps), 10), 1)
|
||||
)
|
||||
|
||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||
trainer_builder = HFCausalTrainerBuilder(cfg, model, tokenizer)
|
||||
trainer_builder.train_dataset = train_dataset
|
||||
trainer_builder.eval_dataset = eval_dataset
|
||||
training_arguments_kwargs = {}
|
||||
if cfg.bf16 == "full":
|
||||
training_arguments_kwargs["bf16_full_eval"] = True
|
||||
else:
|
||||
training_arguments_kwargs["bf16"] = cfg.bf16
|
||||
training_arguments_kwargs["fp16"] = (cfg.fp16 and not cfg.bf16) or False
|
||||
training_arguments_kwargs["tf32"] = cfg.tf32
|
||||
training_arguments_kwargs["warmup_steps"] = warmup_steps
|
||||
training_arguments_kwargs["logging_steps"] = logging_steps
|
||||
|
||||
return trainer_builder.build(total_num_steps)
|
||||
if cfg.seed:
|
||||
training_arguments_kwargs["seed"] = cfg.seed
|
||||
|
||||
if cfg.gradient_checkpointing:
|
||||
training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
|
||||
if cfg.fsdp:
|
||||
training_arguments_kwargs["fsdp"] = cfg.fsdp
|
||||
if cfg.fsdp_config:
|
||||
training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
|
||||
|
||||
# deepspeed
|
||||
if cfg.deepspeed:
|
||||
training_arguments_kwargs["deepspeed"] = cfg.deepspeed
|
||||
|
||||
if cfg.lr_quadratic_warmup is not None:
|
||||
training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup
|
||||
|
||||
if cfg.adam_beta1:
|
||||
training_arguments_kwargs["adam_beta1"] = cfg.adam_beta1
|
||||
if cfg.adam_beta2:
|
||||
training_arguments_kwargs["adam_beta2"] = cfg.adam_beta2
|
||||
if cfg.adam_epsilon:
|
||||
training_arguments_kwargs["adam_epsilon"] = cfg.adam_epsilon
|
||||
if cfg.max_grad_norm:
|
||||
training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
|
||||
|
||||
if cfg.hub_model_id:
|
||||
training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
|
||||
training_arguments_kwargs["push_to_hub"] = True
|
||||
training_arguments_kwargs["hub_private_repo"] = True
|
||||
|
||||
if cfg.hub_strategy:
|
||||
training_arguments_kwargs["hub_strategy"] = cfg.hub_strategy
|
||||
|
||||
if cfg.save_safetensors:
|
||||
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
|
||||
|
||||
if cfg.sample_packing_eff_est:
|
||||
training_arguments_kwargs[
|
||||
"sample_packing_efficiency"
|
||||
] = cfg.sample_packing_eff_est
|
||||
|
||||
if cfg.eval_steps:
|
||||
training_arguments_kwargs["evaluation_strategy"] = "steps"
|
||||
training_arguments_kwargs["eval_steps"] = cfg.eval_steps
|
||||
elif cfg.evaluation_strategy:
|
||||
training_arguments_kwargs["evaluation_strategy"] = cfg.evaluation_strategy
|
||||
elif cfg.val_set_size == 0:
|
||||
# no eval set, so don't eval
|
||||
training_arguments_kwargs["evaluation_strategy"] = "no"
|
||||
else:
|
||||
# we have an eval set, but no steps defined, default to use epoch
|
||||
training_arguments_kwargs["evaluation_strategy"] = "epoch"
|
||||
|
||||
if cfg.save_steps:
|
||||
training_arguments_kwargs["save_strategy"] = "steps"
|
||||
training_arguments_kwargs["save_steps"] = cfg.save_steps
|
||||
elif cfg.save_strategy:
|
||||
training_arguments_kwargs["save_strategy"] = cfg.save_strategy
|
||||
else:
|
||||
# default to saving each epoch if not defined
|
||||
training_arguments_kwargs["save_strategy"] = "epoch"
|
||||
|
||||
if cfg.do_bench_eval:
|
||||
training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
|
||||
if cfg.bench_dataset:
|
||||
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
|
||||
if cfg.metric_for_best_model:
|
||||
training_arguments_kwargs["metric_for_best_model"] = cfg.metric_for_best_model
|
||||
if cfg.greater_is_better:
|
||||
training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
|
||||
|
||||
if cfg.torch_compile:
|
||||
if torch.__version__ < "2.1.0": # pylint: disable=protected-access
|
||||
LOG.warning("torch>=2.1.0 required for torch_compile to work properly")
|
||||
else:
|
||||
import torch._dynamo # pylint: disable=redefined-outer-name
|
||||
|
||||
torch._dynamo.config.suppress_errors = ( # pylint: disable=protected-access
|
||||
True
|
||||
)
|
||||
training_arguments_kwargs["torch_compile"] = cfg.torch_compile
|
||||
if cfg.torch_compile_backend:
|
||||
training_arguments_kwargs[
|
||||
"torch_compile_backend"
|
||||
] = cfg.torch_compile_backend
|
||||
|
||||
# DDP Config
|
||||
if cfg.ddp_timeout:
|
||||
training_arguments_kwargs["ddp_timeout"] = cfg.ddp_timeout
|
||||
# see https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html
|
||||
if cfg.ddp_bucket_cap_mb:
|
||||
training_arguments_kwargs["ddp_bucket_cap_mb"] = cfg.ddp_bucket_cap_mb
|
||||
if cfg.ddp_broadcast_buffers is not None:
|
||||
training_arguments_kwargs["ddp_broadcast_buffers"] = cfg.ddp_broadcast_buffers
|
||||
|
||||
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||
max_steps=total_num_steps if cfg.max_steps else -1,
|
||||
max_seq_length=cfg.sequence_len,
|
||||
per_device_train_batch_size=cfg.micro_batch_size,
|
||||
per_device_eval_batch_size=cfg.eval_batch_size,
|
||||
gradient_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||
eval_accumulation_steps=cfg.gradient_accumulation_steps,
|
||||
num_train_epochs=cfg.num_epochs,
|
||||
learning_rate=cfg.learning_rate,
|
||||
output_dir=cfg.output_dir,
|
||||
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
|
||||
load_best_model_at_end=(
|
||||
(cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
|
||||
and cfg.val_set_size > 0
|
||||
and cfg.save_steps
|
||||
and cfg.eval_steps
|
||||
and cfg.save_steps % cfg.eval_steps == 0
|
||||
)
|
||||
or False,
|
||||
ddp_find_unused_parameters=False if cfg.ddp else None,
|
||||
group_by_length=cfg.group_by_length,
|
||||
report_to="wandb" if cfg.use_wandb else None,
|
||||
run_name=cfg.wandb_run_id if cfg.use_wandb else None,
|
||||
optim=cfg.optimizer if cfg.optimizer else "adamw_hf",
|
||||
lr_scheduler_type=cfg.lr_scheduler
|
||||
if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
|
||||
else "cosine",
|
||||
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
|
||||
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
|
||||
eval_sample_packing=cfg.eval_sample_packing,
|
||||
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
|
||||
relora_steps=cfg.relora_steps,
|
||||
relora_warmup_steps=cfg.relora_warmup_steps,
|
||||
**training_arguments_kwargs,
|
||||
)
|
||||
|
||||
trainer_kwargs = {}
|
||||
|
||||
if cfg.optimizer == "adamw_anyprecision":
|
||||
if Path(cfg.torchdistx_path).exists():
|
||||
sys.path.append(cfg.torchdistx_path)
|
||||
importlib.import_module("torchdistx")
|
||||
|
||||
callbacks = []
|
||||
callbacks.append(GPUStatsCallback(cfg))
|
||||
callbacks.append(EvalFirstStepCallback)
|
||||
|
||||
if cfg.relora_steps:
|
||||
callbacks.append(ReLoRACallback(cfg))
|
||||
|
||||
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
|
||||
callbacks.append(SaveBetterTransformerModelCallback)
|
||||
|
||||
data_collator_kwargs = {
|
||||
"padding": True, # True/"longest" is the default
|
||||
}
|
||||
if cfg.pad_to_sequence_len:
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 64 * math.ceil(
|
||||
cfg.sequence_len / 64
|
||||
)
|
||||
else:
|
||||
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
|
||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 64
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.landmark_attention:
|
||||
from axolotl.monkeypatch.llama_landmark_attn import (
|
||||
add_mem_tokens,
|
||||
get_mem_id,
|
||||
set_model_mem_id,
|
||||
)
|
||||
|
||||
set_model_mem_id(model, tokenizer)
|
||||
|
||||
LOG.info("Adding landmark attention tokens to dataset")
|
||||
|
||||
for dataset in [train_dataset, eval_dataset]:
|
||||
dataset = dataset.map(
|
||||
partial(add_mem_tokens, mem_freq=50, mem_id=get_mem_id(tokenizer)),
|
||||
batched=False,
|
||||
num_proc=32,
|
||||
)
|
||||
|
||||
trainer_cls = AxolotlTrainer
|
||||
if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora"):
|
||||
trainer_cls = OneCycleLRSchedulerTrainer
|
||||
elif cfg.relora_steps:
|
||||
trainer_cls = ReLoRATrainer
|
||||
trainer = trainer_cls(
|
||||
model=model,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
args=training_args,
|
||||
data_collator=DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
),
|
||||
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
**data_collator_kwargs,
|
||||
),
|
||||
callbacks=callbacks,
|
||||
**trainer_kwargs,
|
||||
)
|
||||
|
||||
if cfg.use_wandb and cfg.eval_table_size > 0:
|
||||
LogPredictionCallback = log_prediction_callback_factory(trainer, tokenizer)
|
||||
trainer.add_callback(LogPredictionCallback(cfg))
|
||||
|
||||
if cfg.do_bench_eval:
|
||||
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
|
||||
|
||||
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
||||
if cfg.early_stopping_patience:
|
||||
early_stop_cb = EarlyStoppingCallback(
|
||||
cfg.early_stopping_patience,
|
||||
)
|
||||
trainer.add_callback(early_stop_cb)
|
||||
|
||||
return trainer
|
||||
|
||||
@@ -2,20 +2,20 @@
|
||||
|
||||
import os
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def setup_wandb_env_vars(cfg: DictDefault):
|
||||
for key in cfg.keys():
|
||||
if key.startswith("wandb_"):
|
||||
value = cfg.get(key, "")
|
||||
|
||||
if value and isinstance(value, str) and len(value) > 0:
|
||||
os.environ[key.upper()] = value
|
||||
|
||||
# Enable wandb if project name is present
|
||||
if cfg.wandb_project and len(cfg.wandb_project) > 0:
|
||||
def setup_wandb_env_vars(cfg):
|
||||
if cfg.wandb_mode and cfg.wandb_mode == "offline":
|
||||
os.environ["WANDB_MODE"] = cfg.wandb_mode
|
||||
elif cfg.wandb_project and len(cfg.wandb_project) > 0:
|
||||
os.environ["WANDB_PROJECT"] = cfg.wandb_project
|
||||
cfg.use_wandb = True
|
||||
os.environ.pop("WANDB_DISABLED", None) # Remove if present
|
||||
if cfg.wandb_entity and len(cfg.wandb_entity) > 0:
|
||||
os.environ["WANDB_ENTITY"] = cfg.wandb_entity
|
||||
if cfg.wandb_watch and len(cfg.wandb_watch) > 0:
|
||||
os.environ["WANDB_WATCH"] = cfg.wandb_watch
|
||||
if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0:
|
||||
os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model
|
||||
if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0:
|
||||
os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id
|
||||
else:
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
"""
|
||||
E2E tests for lora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestFusedLlama(unittest.TestCase):
|
||||
"""
|
||||
Test case for Llama models using Fused layers
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_fft_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"flash_attention": True,
|
||||
"flash_attn_fuse_qkv": True,
|
||||
"flash_attn_fuse_mlp": True,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 1024,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
@@ -4,6 +4,7 @@ E2E tests for lora llama
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
@@ -13,8 +14,6 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
@@ -24,12 +23,13 @@ class TestLoraLlama(unittest.TestCase):
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora(self, temp_dir):
|
||||
def test_lora(self):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"base_model_config": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
@@ -53,7 +53,7 @@ class TestLoraLlama(unittest.TestCase):
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"output_dir": output_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -64,14 +64,15 @@ class TestLoraLlama(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora_packing(self, temp_dir):
|
||||
def test_lora_packing(self):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"base_model_config": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"sample_packing": True,
|
||||
@@ -97,11 +98,10 @@ class TestLoraLlama(unittest.TestCase):
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"output_dir": output_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"bf16": True,
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
@@ -109,14 +109,15 @@ class TestLoraLlama(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora_gptq(self, temp_dir):
|
||||
def test_lora_gptq(self):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "TheBlokeAI/jackfram_llama-68m-GPTQ",
|
||||
"base_model_config": "TheBlokeAI/jackfram_llama-68m-GPTQ",
|
||||
"model_type": "AutoModelForCausalLM",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 1024,
|
||||
@@ -146,7 +147,7 @@ class TestLoraLlama(unittest.TestCase):
|
||||
"save_steps": 0.5,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"output_dir": output_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -157,4 +158,4 @@ class TestLoraLlama(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||
|
||||
@@ -4,6 +4,7 @@ E2E tests for lora llama
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
@@ -15,8 +16,6 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
@@ -26,12 +25,13 @@ class TestMistral(unittest.TestCase):
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora(self, temp_dir):
|
||||
def test_lora(self):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
"base_model_config": "openaccess-ai-collective/tiny-mistral",
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
@@ -55,7 +55,7 @@ class TestMistral(unittest.TestCase):
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"output_dir": output_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -69,14 +69,15 @@ class TestMistral(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
def test_ft(self):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
"base_model_config": "openaccess-ai-collective/tiny-mistral",
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"val_set_size": 0.1,
|
||||
@@ -94,7 +95,7 @@ class TestMistral(unittest.TestCase):
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"output_dir": output_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -112,4 +113,4 @@ class TestMistral(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
assert (Path(output_dir) / "pytorch_model.bin").exists()
|
||||
|
||||
@@ -4,6 +4,7 @@ E2E tests for lora llama
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
@@ -15,8 +16,6 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
@@ -26,12 +25,13 @@ class TestMistral(unittest.TestCase):
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora_packing(self, temp_dir):
|
||||
def test_lora_packing(self):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
"base_model_config": "openaccess-ai-collective/tiny-mistral",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 1024,
|
||||
@@ -56,7 +56,7 @@ class TestMistral(unittest.TestCase):
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"output_dir": output_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -70,14 +70,15 @@ class TestMistral(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
assert (Path(output_dir) / "adapter_model.bin").exists()
|
||||
|
||||
@with_temp_dir
|
||||
def test_ft_packing(self, temp_dir):
|
||||
def test_ft_packing(self):
|
||||
# pylint: disable=duplicate-code
|
||||
output_dir = tempfile.mkdtemp()
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
"base_model_config": "openaccess-ai-collective/tiny-mistral",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 1024,
|
||||
@@ -96,7 +97,7 @@ class TestMistral(unittest.TestCase):
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"output_dir": output_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -114,4 +115,4 @@ class TestMistral(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
assert (Path(output_dir) / "pytorch_model.bin").exists()
|
||||
|
||||
@@ -4,8 +4,8 @@ E2E tests for lora llama
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
@@ -13,8 +13,6 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
@@ -24,14 +22,14 @@ class TestPhi(unittest.TestCase):
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
def test_ft(self):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "microsoft/phi-1_5",
|
||||
"base_model_config": "microsoft/phi-1_5",
|
||||
"trust_remote_code": True,
|
||||
"model_type": "PhiForCausalLM",
|
||||
"model_type": "MixFormerSequentialForCausalLM",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"sequence_len": 512,
|
||||
"sample_packing": False,
|
||||
@@ -55,7 +53,7 @@ class TestPhi(unittest.TestCase):
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"output_dir": tempfile.mkdtemp(),
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -67,16 +65,15 @@ class TestPhi(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
|
||||
@with_temp_dir
|
||||
def test_ft_packed(self, temp_dir):
|
||||
def test_ft_packed(self):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "microsoft/phi-1_5",
|
||||
"base_model_config": "microsoft/phi-1_5",
|
||||
"trust_remote_code": True,
|
||||
"model_type": "PhiForCausalLM",
|
||||
"model_type": "MixFormerSequentialForCausalLM",
|
||||
"tokenizer_type": "AutoTokenizer",
|
||||
"sequence_len": 512,
|
||||
"sample_packing": True,
|
||||
@@ -100,7 +97,7 @@ class TestPhi(unittest.TestCase):
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"output_dir": tempfile.mkdtemp(),
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_bnb_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
@@ -112,4 +109,3 @@ class TestPhi(unittest.TestCase):
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
"""
|
||||
E2E tests for resuming training
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import most_recent_subdir, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestResumeLlama(unittest.TestCase):
|
||||
"""
|
||||
Test case for resuming training of llama models
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_resume_qlora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"sample_packing": True,
|
||||
"flash_attention": True,
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 32,
|
||||
"lora_alpha": 64,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "vicgalle/alpaca-gpt4",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"save_steps": 10,
|
||||
"save_total_limit": 5,
|
||||
"max_steps": 40,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
|
||||
resume_cfg = cfg | DictDefault(
|
||||
{
|
||||
"resume_from_checkpoint": f"{temp_dir}/checkpoint-30/",
|
||||
}
|
||||
)
|
||||
normalize_config(resume_cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
|
||||
train(cfg=resume_cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
|
||||
tb_log_path_1 = most_recent_subdir(temp_dir + "/runs")
|
||||
cmd = f"tensorboard --inspect --logdir {tb_log_path_1}"
|
||||
res = subprocess.run(
|
||||
cmd, shell=True, text=True, capture_output=True, check=True
|
||||
)
|
||||
pattern = r"first_step\s+(\d+)"
|
||||
first_steps = int(re.findall(pattern, res.stdout)[0])
|
||||
assert first_steps == 31
|
||||
@@ -1,33 +0,0 @@
|
||||
"""
|
||||
helper utils for tests
|
||||
"""
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from functools import wraps
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def with_temp_dir(test_func):
|
||||
@wraps(test_func)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Create a temporary directory
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
# Pass the temporary directory to the test function
|
||||
test_func(*args, temp_dir=temp_dir, **kwargs)
|
||||
finally:
|
||||
# Clean up the directory after the test
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def most_recent_subdir(path):
|
||||
base_path = Path(path)
|
||||
subdirectories = [d for d in base_path.iterdir() if d.is_dir()]
|
||||
if not subdirectories:
|
||||
return None
|
||||
subdir = max(subdirectories, key=os.path.getctime)
|
||||
|
||||
return subdir
|
||||
2
tests/fixtures/conversation.tokenized.json
vendored
2
tests/fixtures/conversation.tokenized.json
vendored
File diff suppressed because one or more lines are too long
@@ -1,46 +0,0 @@
|
||||
"""
|
||||
Test classes for checking functionality of the cfg normalization
|
||||
"""
|
||||
import unittest
|
||||
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class NormalizeConfigTestCase(unittest.TestCase):
|
||||
"""
|
||||
test class for normalize_config checks
|
||||
"""
|
||||
|
||||
def _get_base_cfg(self):
|
||||
return DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"base_model_config": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
}
|
||||
)
|
||||
|
||||
def test_lr_as_float(self):
|
||||
cfg = (
|
||||
self._get_base_cfg()
|
||||
| DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
{
|
||||
"learning_rate": "5e-5",
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
normalize_config(cfg)
|
||||
|
||||
assert cfg.learning_rate == 0.00005
|
||||
|
||||
def test_base_model_config_set_when_empty(self):
|
||||
cfg = self._get_base_cfg()
|
||||
del cfg.base_model_config
|
||||
normalize_config(cfg)
|
||||
|
||||
assert cfg.base_model_config == cfg.base_model
|
||||
@@ -90,73 +90,6 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||
strat.tokenize_prompt(conversation)
|
||||
assert "assistant turn has empty text" in self._caplog.records[1].message
|
||||
|
||||
def test_sharegpt_warnings_turns(self):
|
||||
conversation = {
|
||||
"conversations": [
|
||||
{"from": "system", "value": "lorem"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
{"from": "human", "value": "dolor"},
|
||||
{"from": "human", "value": "dolor"},
|
||||
{"from": "gpt", "value": "sit"},
|
||||
]
|
||||
}
|
||||
prompter = ShareGPTPrompterV2()
|
||||
strat = ShareGPTPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
strat.tokenize_prompt(conversation)
|
||||
assert (
|
||||
"Role did not alternate between turns (gpt and human)"
|
||||
in self._caplog.records[0].message
|
||||
)
|
||||
|
||||
def test_sharegpt_changes_roles(self):
|
||||
conversation = {
|
||||
"roles": ["USER", "CHARACTER"],
|
||||
"conversations": [
|
||||
{"from": "system", "value": "lorem"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
{"from": "human", "value": "dolor"},
|
||||
{"from": "gpt", "value": "sit"},
|
||||
],
|
||||
}
|
||||
prompter = ShareGPTPrompterV2()
|
||||
strat = ShareGPTPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
res = strat.tokenize_prompt(conversation)
|
||||
assert "CHARACTER" in self.tokenizer.decode(res["input_ids"])
|
||||
|
||||
def test_sharegpt_assistant_label_ignore(self):
|
||||
conversation = {
|
||||
"roles": ["user", "assistant"],
|
||||
"conversations": [
|
||||
{"from": "system", "value": "lorem"},
|
||||
{"from": "gpt", "value": "ipsum"},
|
||||
{"from": "human", "value": "dolor"},
|
||||
{"from": "gpt", "value": "sit"},
|
||||
],
|
||||
}
|
||||
prompter = ShareGPTPrompterV2()
|
||||
strat = ShareGPTPromptTokenizingStrategy(
|
||||
prompter,
|
||||
self.tokenizer,
|
||||
False,
|
||||
2048,
|
||||
)
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
res = strat.tokenize_prompt(conversation)
|
||||
idx = res["input_ids"].index(20255) # assistant token
|
||||
assert res["labels"][idx] == -100
|
||||
|
||||
def test_no_sys_prompt(self):
|
||||
"""
|
||||
tests the interface between the user and assistant parts
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Module for testing the validation module"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from typing import Optional
|
||||
|
||||
@@ -9,7 +8,6 @@ import pytest
|
||||
|
||||
from axolotl.utils.config import validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
|
||||
class ValidationTest(unittest.TestCase):
|
||||
@@ -567,197 +565,3 @@ class ValidationTest(unittest.TestCase):
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
def test_eval_table_size_conflict_eval_packing(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"sample_packing": True,
|
||||
"eval_table_size": 100,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match=r".*Please set 'eval_sample_packing' to false.*"
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"sample_packing": True,
|
||||
"eval_sample_packing": False,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"sample_packing": False,
|
||||
"eval_table_size": 100,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"sample_packing": True,
|
||||
"eval_table_size": 100,
|
||||
"eval_sample_packing": False,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
def test_load_in_x_bit_without_adapter(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"load_in_4bit": True,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"load_in_8bit": True,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*load_in_8bit and load_in_4bit are not supported without setting an adapter.*",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
def test_warmup_step_no_conflict(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"warmup_steps": 10,
|
||||
"warmup_ratio": 0.1,
|
||||
}
|
||||
)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match=r".*warmup_steps and warmup_ratio are mutually exclusive*",
|
||||
):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"warmup_steps": 10,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"warmup_ratio": 0.1,
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
|
||||
class ValidationWandbTest(ValidationTest):
|
||||
"""
|
||||
Validation test for wandb
|
||||
"""
|
||||
|
||||
def test_wandb_set_run_id_to_name(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"wandb_run_id": "foo",
|
||||
}
|
||||
)
|
||||
|
||||
with self._caplog.at_level(logging.WARNING):
|
||||
validate_config(cfg)
|
||||
assert any(
|
||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
||||
in record.message
|
||||
for record in self._caplog.records
|
||||
)
|
||||
|
||||
assert cfg.wandb_name == "foo" and cfg.wandb_run_id == "foo"
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"wandb_name": "foo",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
assert cfg.wandb_name == "foo" and cfg.wandb_run_id is None
|
||||
|
||||
def test_wandb_sets_env(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"wandb_project": "foo",
|
||||
"wandb_name": "bar",
|
||||
"wandb_run_id": "bat",
|
||||
"wandb_entity": "baz",
|
||||
"wandb_mode": "online",
|
||||
"wandb_watch": "false",
|
||||
"wandb_log_model": "checkpoint",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
setup_wandb_env_vars(cfg)
|
||||
|
||||
assert os.environ.get("WANDB_PROJECT", "") == "foo"
|
||||
assert os.environ.get("WANDB_NAME", "") == "bar"
|
||||
assert os.environ.get("WANDB_RUN_ID", "") == "bat"
|
||||
assert os.environ.get("WANDB_ENTITY", "") == "baz"
|
||||
assert os.environ.get("WANDB_MODE", "") == "online"
|
||||
assert os.environ.get("WANDB_WATCH", "") == "false"
|
||||
assert os.environ.get("WANDB_LOG_MODEL", "") == "checkpoint"
|
||||
assert os.environ.get("WANDB_DISABLED", "") != "true"
|
||||
|
||||
def test_wandb_set_disabled(self):
|
||||
cfg = DictDefault({})
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
setup_wandb_env_vars(cfg)
|
||||
|
||||
assert os.environ.get("WANDB_DISABLED", "") == "true"
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"wandb_project": "foo",
|
||||
}
|
||||
)
|
||||
|
||||
validate_config(cfg)
|
||||
|
||||
setup_wandb_env_vars(cfg)
|
||||
|
||||
assert os.environ.get("WANDB_DISABLED", "") != "true"
|
||||
|
||||
Reference in New Issue
Block a user