Compare commits
57 Commits
unsloth_mo
...
yayi2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
272bced137 | ||
|
|
c371d6b546 | ||
|
|
d6273188f0 | ||
|
|
72797b04a5 | ||
|
|
de47bb5eb0 | ||
|
|
c04df54b4b | ||
|
|
e3716db386 | ||
|
|
97943d8fc4 | ||
|
|
9d3f80cd40 | ||
|
|
bfae79a634 | ||
|
|
5a85ee16eb | ||
|
|
3678a6c41d | ||
|
|
f8ae59b0a8 | ||
|
|
4f4d638b84 | ||
|
|
ba043a361e | ||
|
|
41353d2ea0 | ||
|
|
f6ecf14dd4 | ||
|
|
dec66d7c53 | ||
|
|
76357dc5da | ||
|
|
70b46ca4f4 | ||
|
|
85dd4d525b | ||
|
|
384b817dc0 | ||
|
|
db9094df0f | ||
|
|
6ef46f8dca | ||
|
|
628b754824 | ||
|
|
37820f6540 | ||
|
|
7d4185ffcb | ||
|
|
93ebec1ac5 | ||
|
|
2e61dc3180 | ||
|
|
1ffa3866f2 | ||
|
|
62ba1609b6 | ||
|
|
7bbaac98f7 | ||
|
|
161bcb6517 | ||
|
|
d25c34caa6 | ||
|
|
13e938149d | ||
|
|
85de004dd4 | ||
|
|
80ec7af358 | ||
|
|
f28e75513b | ||
|
|
5ada140ff0 | ||
|
|
712fd27b3f | ||
|
|
ef24342538 | ||
|
|
5ea3aa31f0 | ||
|
|
f1f60cb5b2 | ||
|
|
450e04d3c4 | ||
|
|
b0cf397ecb | ||
|
|
5f79b8242f | ||
|
|
f1de29dd1e | ||
|
|
7fabc4d95e | ||
|
|
9a5eb3990c | ||
|
|
86487c2e96 | ||
|
|
35f9b0f149 | ||
|
|
68b227a7d8 | ||
|
|
03c6318ba3 | ||
|
|
40a6362c92 | ||
|
|
d339beb9d9 | ||
|
|
fde091cb12 | ||
|
|
06ae39200b |
7
.github/workflows/base.yml
vendored
7
.github/workflows/base.yml
vendored
@@ -28,7 +28,12 @@ jobs:
|
|||||||
- cuda: "118"
|
- cuda: "118"
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.0
|
pytorch: 2.1.1
|
||||||
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||||
|
- cuda: "121"
|
||||||
|
cuda_version: 12.1.0
|
||||||
|
python_version: "3.10"
|
||||||
|
pytorch: 2.1.1
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 9.0+PTX"
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
51
.github/workflows/main.yml
vendored
51
.github/workflows/main.yml
vendored
@@ -27,38 +27,56 @@ jobs:
|
|||||||
- cuda: 118
|
- cuda: 118
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.0
|
pytorch: 2.1.1
|
||||||
|
axolotl_extras:
|
||||||
|
- cuda: 121
|
||||||
|
cuda_version: 12.1.0
|
||||||
|
python_version: "3.10"
|
||||||
|
pytorch: 2.1.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: [self-hosted, gpu, docker]
|
runs-on: [self-hosted, gpu, docker]
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
- name: Docker metadata
|
- name: Docker metadata
|
||||||
id: metadata
|
id: metadata
|
||||||
uses: docker/metadata-action@v3
|
uses: docker/metadata-action@v5
|
||||||
with:
|
with:
|
||||||
images: winglian/axolotl
|
images: winglian/axolotl
|
||||||
|
- name: Set up Docker Buildx
|
||||||
|
uses: docker/setup-buildx-action@v3
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
uses: docker/login-action@v2
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
- name: Set up Docker Buildx
|
# guidance for testing before pushing: https://docs.docker.com/build/ci/github-actions/test-before-push/
|
||||||
uses: docker/setup-buildx-action@v2
|
- name: Build and export to Docker
|
||||||
- name: Build
|
uses: docker/build-push-action@v5
|
||||||
uses: docker/build-push-action@v4
|
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
|
load: true
|
||||||
build-args: |
|
build-args: |
|
||||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||||
CUDA=${{ matrix.cuda }}
|
CUDA=${{ matrix.cuda }}
|
||||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||||
file: ./docker/Dockerfile
|
file: ./docker/Dockerfile
|
||||||
push: ${{ github.event_name != 'pull_request' }}
|
|
||||||
tags: |
|
tags: |
|
||||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||||
labels: ${{ steps.metadata.outputs.labels }}
|
labels: ${{ steps.metadata.outputs.labels }}
|
||||||
|
- name: Unit Tests
|
||||||
|
run: |
|
||||||
|
docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||||
|
- name: Push to Docker Hub
|
||||||
|
if: github.event_name != 'pull_request'
|
||||||
|
run: |
|
||||||
|
docker push ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||||
|
latest_tag=${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||||
|
if [ -n "$latest_tag" ]; then
|
||||||
|
docker push "$latest_tag"
|
||||||
|
fi
|
||||||
|
|
||||||
build-axolotl-runpod:
|
build-axolotl-runpod:
|
||||||
needs: build-axolotl
|
needs: build-axolotl
|
||||||
if: github.repository_owner == 'OpenAccess-AI-Collective'
|
if: github.repository_owner == 'OpenAccess-AI-Collective'
|
||||||
@@ -80,26 +98,31 @@ jobs:
|
|||||||
- cuda: 118
|
- cuda: 118
|
||||||
cuda_version: 11.8.0
|
cuda_version: 11.8.0
|
||||||
python_version: "3.10"
|
python_version: "3.10"
|
||||||
pytorch: 2.1.0
|
pytorch: 2.1.1
|
||||||
|
axolotl_extras:
|
||||||
|
- cuda: 121
|
||||||
|
cuda_version: 12.1.0
|
||||||
|
python_version: "3.10"
|
||||||
|
pytorch: 2.1.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
runs-on: [self-hosted, gpu, docker]
|
runs-on: [self-hosted, gpu, docker]
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
uses: actions/checkout@v3
|
uses: actions/checkout@v4
|
||||||
- name: Docker metadata
|
- name: Docker metadata
|
||||||
id: metadata
|
id: metadata
|
||||||
uses: docker/metadata-action@v3
|
uses: docker/metadata-action@v5
|
||||||
with:
|
with:
|
||||||
images: winglian/axolotl-runpod
|
images: winglian/axolotl-runpod
|
||||||
- name: Login to Docker Hub
|
- name: Login to Docker Hub
|
||||||
uses: docker/login-action@v2
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
username: ${{ secrets.DOCKERHUB_USERNAME }}
|
||||||
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
password: ${{ secrets.DOCKERHUB_TOKEN }}
|
||||||
- name: Set up Docker Buildx
|
- name: Set up Docker Buildx
|
||||||
uses: docker/setup-buildx-action@v2
|
uses: docker/setup-buildx-action@v2
|
||||||
- name: Build
|
- name: Build
|
||||||
uses: docker/build-push-action@v4
|
uses: docker/build-push-action@v5
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
build-args: |
|
build-args: |
|
||||||
|
|||||||
2
.github/workflows/tests.yml
vendored
2
.github/workflows/tests.yml
vendored
@@ -73,7 +73,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
|
pip3 install --extra-index-url https://download.pytorch.org/whl/cu118 -U torch==2.0.1
|
||||||
pip3 uninstall -y transformers accelerate
|
pip3 uninstall -y transformers accelerate
|
||||||
pip3 install -U -e .[flash-attn]
|
pip3 install -U -e .[flash-attn,mamba-ssm]
|
||||||
pip3 install -r requirements-tests.txt
|
pip3 install -r requirements-tests.txt
|
||||||
|
|
||||||
- name: Run e2e tests
|
- name: Run e2e tests
|
||||||
|
|||||||
@@ -8,6 +8,9 @@ ignore_missing_imports = True
|
|||||||
[mypy-axolotl.monkeypatch.*]
|
[mypy-axolotl.monkeypatch.*]
|
||||||
ignore_errors = True
|
ignore_errors = True
|
||||||
|
|
||||||
|
[mypy-axolotl.models.mixtral.*]
|
||||||
|
ignore_errors = True
|
||||||
|
|
||||||
[mypy-axolotl.models.phi.*]
|
[mypy-axolotl.models.phi.*]
|
||||||
ignore_errors = True
|
ignore_errors = True
|
||||||
|
|
||||||
|
|||||||
107
README.md
107
README.md
@@ -36,7 +36,9 @@ Features:
|
|||||||
- [Train](#train)
|
- [Train](#train)
|
||||||
- [Inference](#inference)
|
- [Inference](#inference)
|
||||||
- [Merge LORA to Base](#merge-lora-to-base)
|
- [Merge LORA to Base](#merge-lora-to-base)
|
||||||
|
- [Special Tokens](#special-tokens)
|
||||||
- [Common Errors](#common-errors-)
|
- [Common Errors](#common-errors-)
|
||||||
|
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
|
||||||
- [Need Help?](#need-help-)
|
- [Need Help?](#need-help-)
|
||||||
- [Badge](#badge-)
|
- [Badge](#badge-)
|
||||||
- [Community Showcase](#community-showcase)
|
- [Community Showcase](#community-showcase)
|
||||||
@@ -65,19 +67,21 @@ Features:
|
|||||||
|
|
||||||
## Axolotl supports
|
## Axolotl supports
|
||||||
|
|
||||||
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|
| | fp16/fp32 | lora | qlora | gptq | gptq w/flash attn | flash attn | xformers attn |
|
||||||
|----------|:----------|:-----|-------|------|-------------------|------------|--------------|
|
|-------------|:----------|:-----|-------|------|-------------------|------------|--------------|
|
||||||
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| llama | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||||
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
| Mistral | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ |
|
||||||
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
| Mixtral-MoE | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||||
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
| Pythia | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
| cerebras | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
| btlm | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
| mpt | ✅ | ❌ | ❓ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
| falcon | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ❓ |
|
||||||
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
| gpt-j | ✅ | ✅ | ✅ | ❌ | ❌ | ❓ | ❓ |
|
||||||
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
|
| XGen | ✅ | ❓ | ✅ | ❓ | ❓ | ❓ | ✅ |
|
||||||
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
| phi | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||||
|
| RWKV | ✅ | ❓ | ❓ | ❓ | ❓ | ❓ | ❓ |
|
||||||
|
| Qwen | ✅ | ✅ | ✅ | ❓ | ❓ | ❓ | ❓ |
|
||||||
|
|
||||||
|
|
||||||
## Quickstart ⚡
|
## Quickstart ⚡
|
||||||
@@ -245,10 +249,17 @@ Have dataset(s) in one of the following format (JSONL recommended):
|
|||||||
```json
|
```json
|
||||||
{"instruction": "...", "input": "...", "output": "..."}
|
{"instruction": "...", "input": "...", "output": "..."}
|
||||||
```
|
```
|
||||||
- `sharegpt`: conversations where `from` is `human`/`gpt`
|
- `sharegpt`: conversations where `from` is `human`/`gpt`. (optional: `system` to override default system prompt)
|
||||||
```json
|
```json
|
||||||
{"conversations": [{"from": "...", "value": "..."}]}
|
{"conversations": [{"from": "...", "value": "..."}]}
|
||||||
```
|
```
|
||||||
|
- `llama-2`: the json is the same format as `sharegpt` above, with the following config (see the [config section](#config) for more details)
|
||||||
|
```yml
|
||||||
|
datasets:
|
||||||
|
- path: <your-path>
|
||||||
|
type: sharegpt
|
||||||
|
conversation: llama-2
|
||||||
|
```
|
||||||
- `completion`: raw corpus
|
- `completion`: raw corpus
|
||||||
```json
|
```json
|
||||||
{"text": "..."}
|
{"text": "..."}
|
||||||
@@ -509,6 +520,14 @@ model_config:
|
|||||||
type: # linear | dynamic
|
type: # linear | dynamic
|
||||||
factor: # float
|
factor: # float
|
||||||
|
|
||||||
|
# optional overrides to the bnb 4bit quantization configuration
|
||||||
|
# https://huggingface.co/docs/transformers/main/main_classes/quantization#transformers.BitsAndBytesConfig
|
||||||
|
bnb_config_kwargs:
|
||||||
|
# These are default values
|
||||||
|
llm_int8_has_fp16_weight: false
|
||||||
|
bnb_4bit_quant_type: nf4
|
||||||
|
bnb_4bit_use_double_quant: true
|
||||||
|
|
||||||
|
|
||||||
# Whether you are training a 4-bit GPTQ quantized model
|
# Whether you are training a 4-bit GPTQ quantized model
|
||||||
gptq: true
|
gptq: true
|
||||||
@@ -570,6 +589,9 @@ datasets:
|
|||||||
# For `completion` datsets only, uses the provided field instead of `text` column
|
# For `completion` datsets only, uses the provided field instead of `text` column
|
||||||
field:
|
field:
|
||||||
|
|
||||||
|
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
||||||
|
# Currently supports chatml and inst (mistral/mixtral)
|
||||||
|
chat_template: chatml
|
||||||
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
# Axolotl attempts to save the dataset as an arrow after packing the data together so
|
||||||
# subsequent training attempts load faster, relative path
|
# subsequent training attempts load faster, relative path
|
||||||
dataset_prepared_path: data/last_run_prepared
|
dataset_prepared_path: data/last_run_prepared
|
||||||
@@ -661,6 +683,7 @@ relora_warmup_steps: # Number of per-restart warmup steps
|
|||||||
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
relora_cpu_offload: # True to perform lora weight merges on cpu during restarts, for modest gpu memory savings
|
||||||
|
|
||||||
# wandb configuration if you're using it
|
# wandb configuration if you're using it
|
||||||
|
# Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
|
||||||
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
wandb_mode: # "offline" to save run metadata locally and not sync to the server, "disabled" to turn off wandb
|
||||||
wandb_project: # Your wandb project name
|
wandb_project: # Your wandb project name
|
||||||
wandb_entity: # A wandb Team name if using a Team
|
wandb_entity: # A wandb Team name if using a Team
|
||||||
@@ -689,9 +712,11 @@ warmup_ratio: 0.05 # cannot use with warmup_steps
|
|||||||
learning_rate: 0.00003
|
learning_rate: 0.00003
|
||||||
lr_quadratic_warmup:
|
lr_quadratic_warmup:
|
||||||
logging_steps:
|
logging_steps:
|
||||||
|
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
|
||||||
|
evals_per_epoch: # number of times per epoch to run evals, mutually exclusive with eval_steps
|
||||||
save_strategy: # Set to `no` to skip checkpoint saves
|
save_strategy: # Set to `no` to skip checkpoint saves
|
||||||
save_steps: # Leave empty to save at each epoch
|
save_steps: # Leave empty to save at each epoch
|
||||||
eval_steps: # Leave empty to eval at each epoch, integers for every N steps. decimal for fraction of total steps
|
saves_per_epoch: # number of times per epoch to save a checkpoint, mutually exclusive with save_steps
|
||||||
save_total_limit: # Checkpoints saved at a time
|
save_total_limit: # Checkpoints saved at a time
|
||||||
# Maximum number of iterations to train for. It precedes num_epochs which means that
|
# Maximum number of iterations to train for. It precedes num_epochs which means that
|
||||||
# if both are set, num_epochs will not be guaranteed.
|
# if both are set, num_epochs will not be guaranteed.
|
||||||
@@ -770,7 +795,7 @@ max_grad_norm:
|
|||||||
# Augmentation techniques
|
# Augmentation techniques
|
||||||
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
|
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
|
||||||
# currently only supported on Llama and Mistral
|
# currently only supported on Llama and Mistral
|
||||||
noisy_embedding_alpha:
|
neftune_noise_alpha:
|
||||||
|
|
||||||
# Whether to bettertransformers
|
# Whether to bettertransformers
|
||||||
flash_optimum:
|
flash_optimum:
|
||||||
@@ -785,11 +810,6 @@ flash_attn_fuse_mlp: # Whether to fuse part of the MLP into a single operation
|
|||||||
# Whether to use scaled-dot-product attention
|
# Whether to use scaled-dot-product attention
|
||||||
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
# https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
|
||||||
sdp_attention:
|
sdp_attention:
|
||||||
# Landmark attention (only llama)
|
|
||||||
landmark_attention:
|
|
||||||
# xpos RoPE see https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
|
||||||
# LLaMA only
|
|
||||||
xpos_rope:
|
|
||||||
|
|
||||||
# Resume from a specific checkpoint dir
|
# Resume from a specific checkpoint dir
|
||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
@@ -956,6 +976,8 @@ fsdp_config:
|
|||||||
|
|
||||||
##### Weights & Biases Logging
|
##### Weights & Biases Logging
|
||||||
|
|
||||||
|
Make sure your `WANDB_API_KEY` environment variable is set (recommended) or you login to wandb with `wandb login`.
|
||||||
|
|
||||||
- wandb options
|
- wandb options
|
||||||
```yaml
|
```yaml
|
||||||
wandb_mode:
|
wandb_mode:
|
||||||
@@ -966,9 +988,28 @@ wandb_name:
|
|||||||
wandb_log_model:
|
wandb_log_model:
|
||||||
```
|
```
|
||||||
|
|
||||||
### Inference
|
##### Special Tokens
|
||||||
|
|
||||||
Pass the appropriate flag to the train command:
|
It is important to have special tokens like delimiters, end-of-sequence, beginning-of-sequence in your tokenizer's vocabulary. This will help you avoid tokenization issues and help your model train better. You can do this in axolotl like this:
|
||||||
|
|
||||||
|
```yml
|
||||||
|
special_tokens:
|
||||||
|
bos_token: "<s>"
|
||||||
|
eos_token: "</s>"
|
||||||
|
unk_token: "<unk>"
|
||||||
|
tokens: # these are delimiters
|
||||||
|
- "<|im_start|>"
|
||||||
|
- "<|im_end|>"
|
||||||
|
```
|
||||||
|
|
||||||
|
When you include these tokens in your axolotl config, axolotl adds these tokens to the tokenizer's vocabulary.
|
||||||
|
|
||||||
|
### Inference Playground
|
||||||
|
|
||||||
|
Axolotl allows you to load your model in an interactive terminal playground for quick experimentation.
|
||||||
|
The config file is the same config file used for training.
|
||||||
|
|
||||||
|
Pass the appropriate flag to the inference command, depending upon what kind of model was trained:
|
||||||
|
|
||||||
- Pretrained LORA:
|
- Pretrained LORA:
|
||||||
```bash
|
```bash
|
||||||
@@ -997,7 +1038,7 @@ Please use `--sample_packing False` if you have it on and receive the error simi
|
|||||||
Add below flag to train command above
|
Add below flag to train command above
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False
|
python3 -m axolotl.cli.merge_lora examples/your_config.yml --lora_model_dir="./completed-model"
|
||||||
```
|
```
|
||||||
|
|
||||||
If you run out of CUDA memory, you can try to merge in system RAM with
|
If you run out of CUDA memory, you can try to merge in system RAM with
|
||||||
@@ -1018,6 +1059,10 @@ Please reduce any below
|
|||||||
- `gradient_accumulation_steps`
|
- `gradient_accumulation_steps`
|
||||||
- `sequence_len`
|
- `sequence_len`
|
||||||
|
|
||||||
|
If it does not help, try running without deepspeed and without accelerate (replace "accelerate launch" with "python") in the command.
|
||||||
|
|
||||||
|
Using adamw_bnb_8bit might also save you some memory.
|
||||||
|
|
||||||
> `failed (exitcode: -9)`
|
> `failed (exitcode: -9)`
|
||||||
|
|
||||||
Usually means your system has run out of system memory.
|
Usually means your system has run out of system memory.
|
||||||
@@ -1040,6 +1085,20 @@ It's safe to ignore it.
|
|||||||
|
|
||||||
See the [NCCL](docs/nccl.md) guide.
|
See the [NCCL](docs/nccl.md) guide.
|
||||||
|
|
||||||
|
|
||||||
|
### Tokenization Mismatch b/w Inference & Training
|
||||||
|
|
||||||
|
For many formats, Axolotl constructs prompts by concatenating token ids _after_ tokenizing strings. The reason for concatenating token ids rather than operating on strings is to maintain precise accounting for attention masks.
|
||||||
|
|
||||||
|
If you decode a prompt constructed by axolotl, you might see spaces between tokens (or lack thereof) that you do not expect, especially around delimiters and special tokens. When you are starting out with a new format, you should always do the following:
|
||||||
|
|
||||||
|
1. Materialize some data using `python -m axolotl.cli.preprocess your_config.yml --debug`, and then decode the first few rows with your model's tokenizer.
|
||||||
|
2. During inference, right before you pass a tensor of token ids to your model, decode these tokens back into a string.
|
||||||
|
3. Make sure the inference string from #2 looks **exactly** like the data you fine tuned on from #1, including spaces and new lines. If they aren't the same adjust your inference server accordingly.
|
||||||
|
4. As an additional troubleshooting step, you can look look at the token ids between 1 and 2 to make sure they are identical.
|
||||||
|
|
||||||
|
Having misalignment between your prompts during training and inference can cause models to perform very poorly, so it is worth checking this. See [this blog post](https://hamel.dev/notes/llm/05_tokenizer_gotchas.html) for a concrete example.
|
||||||
|
|
||||||
## Need help? 🙋♂️
|
## Need help? 🙋♂️
|
||||||
|
|
||||||
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
|
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
|
||||||
|
|||||||
39
deepspeed/zero3_bf16.json
Normal file
39
deepspeed/zero3_bf16.json
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
{
|
||||||
|
"zero_optimization": {
|
||||||
|
"stage": 3,
|
||||||
|
"overlap_comm": true,
|
||||||
|
"contiguous_gradients": true,
|
||||||
|
"sub_group_size": 0,
|
||||||
|
"reduce_bucket_size": "auto",
|
||||||
|
"stage3_prefetch_bucket_size": "auto",
|
||||||
|
"stage3_param_persistence_threshold": "auto",
|
||||||
|
"stage3_max_live_parameters": 0,
|
||||||
|
"stage3_max_reuse_distance": 0,
|
||||||
|
"stage3_gather_16bit_weights_on_model_save": true
|
||||||
|
},
|
||||||
|
"bf16": {
|
||||||
|
"enabled": true
|
||||||
|
},
|
||||||
|
"fp16": {
|
||||||
|
"enabled": "auto",
|
||||||
|
"auto_cast": false,
|
||||||
|
"loss_scale": 0,
|
||||||
|
"initial_scale_power": 32,
|
||||||
|
"loss_scale_window": 1000,
|
||||||
|
"hysteresis": 2,
|
||||||
|
"min_loss_scale": 1
|
||||||
|
},
|
||||||
|
"optimizer": {
|
||||||
|
"type": "AdamW",
|
||||||
|
"params": {
|
||||||
|
"lr": "auto",
|
||||||
|
"betas": "auto",
|
||||||
|
"eps": "auto",
|
||||||
|
"weight_decay": "auto"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"gradient_accumulation_steps": "auto",
|
||||||
|
"train_batch_size": "auto",
|
||||||
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
"wall_clock_breakdown": false
|
||||||
|
}
|
||||||
47
deepspeed/zero3_cpu.json
Normal file
47
deepspeed/zero3_cpu.json
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
{
|
||||||
|
"zero_optimization": {
|
||||||
|
"stage": 3,
|
||||||
|
"offload_optimizer": {
|
||||||
|
"device": "cpu",
|
||||||
|
"pin_memory": true
|
||||||
|
},
|
||||||
|
"offload_param": {
|
||||||
|
"device": "cpu",
|
||||||
|
"pin_memory": true
|
||||||
|
},
|
||||||
|
"overlap_comm": true,
|
||||||
|
"contiguous_gradients": true,
|
||||||
|
"sub_group_size": 0,
|
||||||
|
"reduce_bucket_size": "auto",
|
||||||
|
"stage3_prefetch_bucket_size": "auto",
|
||||||
|
"stage3_param_persistence_threshold": "auto",
|
||||||
|
"stage3_max_live_parameters": 0,
|
||||||
|
"stage3_max_reuse_distance": 0,
|
||||||
|
"stage3_gather_16bit_weights_on_model_save": true
|
||||||
|
},
|
||||||
|
"bf16": {
|
||||||
|
"enabled": "auto"
|
||||||
|
},
|
||||||
|
"fp16": {
|
||||||
|
"enabled": "auto",
|
||||||
|
"auto_cast": false,
|
||||||
|
"loss_scale": 0,
|
||||||
|
"initial_scale_power": 32,
|
||||||
|
"loss_scale_window": 1000,
|
||||||
|
"hysteresis": 2,
|
||||||
|
"min_loss_scale": 1
|
||||||
|
},
|
||||||
|
"optimizer": {
|
||||||
|
"type": "AdamW",
|
||||||
|
"params": {
|
||||||
|
"lr": "auto",
|
||||||
|
"betas": "auto",
|
||||||
|
"eps": "auto",
|
||||||
|
"weight_decay": "auto"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"gradient_accumulation_steps": "auto",
|
||||||
|
"train_batch_size": "auto",
|
||||||
|
"train_micro_batch_size_per_gpu": "auto",
|
||||||
|
"wall_clock_breakdown": false
|
||||||
|
}
|
||||||
@@ -10,7 +10,7 @@ ARG PYTORCH_VERSION="2.0.1"
|
|||||||
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
ENV PYTORCH_VERSION=$PYTORCH_VERSION
|
||||||
|
|
||||||
RUN apt-get update && \
|
RUN apt-get update && \
|
||||||
apt-get install -y vim curl
|
apt-get install -y --allow-change-held-packages vim curl nano libnccl2 libnccl-dev
|
||||||
|
|
||||||
WORKDIR /workspace
|
WORKDIR /workspace
|
||||||
|
|
||||||
@@ -19,13 +19,15 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
|||||||
WORKDIR /workspace/axolotl
|
WORKDIR /workspace/axolotl
|
||||||
|
|
||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN sed -i "s/torch==.*/torch==$PYTORCH_VERSION/" requirements.txt
|
|
||||||
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
|
pip install -e .[deepspeed,flash-attn,$AXOLOTL_EXTRAS]; \
|
||||||
else \
|
else \
|
||||||
pip install -e .[deepspeed,flash-attn]; \
|
pip install -e .[deepspeed,flash-attn]; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# So we can test the Docker image
|
||||||
|
RUN pip install pytest
|
||||||
|
|
||||||
# fix so that git fetch/pull from remote works
|
# fix so that git fetch/pull from remote works
|
||||||
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
RUN git config remote.origin.fetch "+refs/heads/*:refs/remotes/origin/*" && \
|
||||||
git config --get remote.origin.fetch
|
git config --get remote.origin.fetch
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ FROM winglian/axolotl:$BASE_TAG
|
|||||||
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
ENV HF_DATASETS_CACHE="/workspace/data/huggingface-cache/datasets"
|
||||||
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
ENV HUGGINGFACE_HUB_CACHE="/workspace/data/huggingface-cache/hub"
|
||||||
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
|
ENV TRANSFORMERS_CACHE="/workspace/data/huggingface-cache/hub"
|
||||||
|
ENV HF_HOME="/workspace/data/huggingface-cache/hub"
|
||||||
|
|
||||||
COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh
|
COPY scripts/runpod-entrypoint.sh /root/runpod-entrypoint.sh
|
||||||
|
|
||||||
|
|||||||
@@ -72,8 +72,8 @@ gptq_groupsize:
|
|||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
|
|
||||||
warmup_steps: 32
|
warmup_steps: 32
|
||||||
eval_steps:
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
save_total_limit:
|
save_total_limit:
|
||||||
|
|
||||||
debug:
|
debug:
|
||||||
|
|||||||
@@ -49,8 +49,8 @@ flash_attention:
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
|
|||||||
@@ -54,8 +54,8 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -56,8 +56,8 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -54,8 +54,8 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -56,8 +56,8 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -54,8 +54,8 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -56,8 +56,8 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -51,8 +51,8 @@ flash_attention:
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 40
|
warmup_steps: 40
|
||||||
eval_steps: 5
|
evals_per_epoch: 4
|
||||||
save_steps: 43
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -80,8 +80,8 @@ flash_attention:
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 5
|
evals_per_epoch: 4
|
||||||
save_steps: 10
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.000001
|
weight_decay: 0.000001
|
||||||
|
|||||||
@@ -51,8 +51,8 @@ flash_attention:
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 40
|
warmup_steps: 40
|
||||||
eval_steps: 5
|
evals_per_epoch: 4
|
||||||
save_steps: 43
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -46,8 +46,8 @@ flash_attention:
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
|
|||||||
@@ -42,8 +42,8 @@ flash_attention:
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 20
|
warmup_steps: 20
|
||||||
eval_steps: 110
|
evals_per_epoch: 4
|
||||||
save_steps: 660
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
|
|||||||
@@ -58,9 +58,9 @@ flash_attn_fuse_qkv: false
|
|||||||
flash_attn_fuse_mlp: true
|
flash_attn_fuse_mlp: true
|
||||||
|
|
||||||
warmup_steps: 100
|
warmup_steps: 100
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed: #deepspeed/zero2.json # multi-gpu only
|
deepspeed: #deepspeed/zero2.json # multi-gpu only
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
|
|||||||
@@ -62,8 +62,8 @@ flash_attention:
|
|||||||
sdp_attention:
|
sdp_attention:
|
||||||
flash_optimum:
|
flash_optimum:
|
||||||
warmup_steps: 100
|
warmup_steps: 100
|
||||||
eval_steps:
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
|
|||||||
@@ -54,10 +54,10 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
eval_table_max_new_tokens: 128
|
eval_table_max_new_tokens: 128
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -56,9 +56,9 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -60,8 +60,8 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
save_steps: 50
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -54,9 +54,9 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
61
examples/mamba/config.yml
Normal file
61
examples/mamba/config.yml
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
base_model: state-spaces/mamba-2.8b
|
||||||
|
model_type: MambaLMHeadModel
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
tokenizer_config: EleutherAI/gpt-neox-20b
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./out
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 2
|
||||||
|
optimizer: paged_adamw_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 5e-5
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: true
|
||||||
|
|
||||||
|
bf16: true
|
||||||
|
fp16: false
|
||||||
|
tf32: true
|
||||||
|
|
||||||
|
gradient_checkpointing: false
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention:
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_table_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
tokens:
|
||||||
|
save_safetensors: False
|
||||||
@@ -17,6 +17,7 @@ output_dir: ./out
|
|||||||
sequence_len: 8192
|
sequence_len: 8192
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
pad_to_sequence_len: true
|
pad_to_sequence_len: true
|
||||||
|
eval_sample_packing: false
|
||||||
|
|
||||||
wandb_project:
|
wandb_project:
|
||||||
wandb_entity:
|
wandb_entity:
|
||||||
@@ -46,10 +47,10 @@ xformers_attention:
|
|||||||
flash_attention: true
|
flash_attention: true
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
eval_table_max_new_tokens: 128
|
eval_table_max_new_tokens: 128
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
91
examples/mistral/mixtral.yml
Normal file
91
examples/mistral/mixtral.yml
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
base_model: mistralai/Mixtral-8x7B-v0.1
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
trust_remote_code: true
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: tatsu-lab/alpaca
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.0
|
||||||
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
|
## You can optionally freeze the entire model and unfreeze a subset of parameters
|
||||||
|
unfrozen_parameters:
|
||||||
|
# - lm_head.*
|
||||||
|
# - model.embed_tokens.*
|
||||||
|
# - model.layers.2[0-9]+.block_sparse_moe.gate.*
|
||||||
|
# - model.layers.2[0-9]+.block_sparse_moe.experts.*
|
||||||
|
# - model.layers.3[0-9]+.block_sparse_moe.gate.*
|
||||||
|
# - model.layers.3[0-9]+.block_sparse_moe.experts.*
|
||||||
|
|
||||||
|
model_config:
|
||||||
|
output_router_logits: true
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 4096
|
||||||
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
#lora_target_modules:
|
||||||
|
# - gate
|
||||||
|
# - q_proj
|
||||||
|
# - k_proj
|
||||||
|
# - v_proj
|
||||||
|
# - o_proj
|
||||||
|
# - w1
|
||||||
|
# - w2
|
||||||
|
# - w3
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 2
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
loss_watchdog_threshold: 5.0
|
||||||
|
loss_watchdog_patience: 3
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_table_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed: deepspeed/zero2.json
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
@@ -11,7 +11,7 @@ datasets:
|
|||||||
- path: mhenrichsen/alpaca_2k_test
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
type: alpaca
|
type: alpaca
|
||||||
dataset_prepared_path: last_run_prepared
|
dataset_prepared_path: last_run_prepared
|
||||||
val_set_size: 0.05
|
val_set_size: 0.1
|
||||||
output_dir: ./qlora-out
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
adapter: qlora
|
adapter: qlora
|
||||||
@@ -66,10 +66,10 @@ loss_watchdog_threshold: 5.0
|
|||||||
loss_watchdog_patience: 3
|
loss_watchdog_patience: 3
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
eval_table_max_new_tokens: 128
|
eval_table_max_new_tokens: 128
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -44,8 +44,8 @@ flash_attention:
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 20
|
warmup_steps: 20
|
||||||
eval_steps: 110
|
evals_per_epoch: 4
|
||||||
save_steps: 660
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0001
|
weight_decay: 0.0001
|
||||||
|
|||||||
@@ -49,8 +49,8 @@ flash_attention: true
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 20
|
warmup_steps: 20
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
|
|||||||
@@ -54,8 +54,8 @@ flash_attention: true
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 20
|
warmup_steps: 20
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
|
|||||||
@@ -48,8 +48,8 @@ flash_attention: true
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 20
|
warmup_steps: 20
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
|
|||||||
@@ -59,8 +59,8 @@ xformers_attention:
|
|||||||
flash_attention:
|
flash_attention:
|
||||||
|
|
||||||
warmup_steps: 100
|
warmup_steps: 100
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
|
|||||||
@@ -59,8 +59,8 @@ xformers_attention:
|
|||||||
flash_attention:
|
flash_attention:
|
||||||
|
|
||||||
warmup_steps: 100
|
warmup_steps: 100
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
|
|||||||
@@ -33,5 +33,5 @@ early_stopping_patience:
|
|||||||
resume_from_checkpoint:
|
resume_from_checkpoint:
|
||||||
local_rank:
|
local_rank:
|
||||||
weight_decay: 0.1
|
weight_decay: 0.1
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
logging_steps: 1
|
logging_steps: 1
|
||||||
|
|||||||
@@ -56,10 +56,10 @@ xformers_attention:
|
|||||||
flash_attention:
|
flash_attention:
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
eval_table_max_new_tokens: 128
|
eval_table_max_new_tokens: 128
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -56,10 +56,10 @@ xformers_attention:
|
|||||||
flash_attention:
|
flash_attention:
|
||||||
|
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 0.05
|
evals_per_epoch: 4
|
||||||
eval_table_size:
|
eval_table_size:
|
||||||
eval_table_max_new_tokens: 128
|
eval_table_max_new_tokens: 128
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
@@ -45,8 +45,8 @@ flash_attention:
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 20
|
warmup_steps: 20
|
||||||
eval_steps: 110
|
evals_per_epoch: 4
|
||||||
save_steps: 660
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0001
|
weight_decay: 0.0001
|
||||||
|
|||||||
@@ -45,8 +45,8 @@ flash_attention:
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 20
|
warmup_steps: 20
|
||||||
eval_steps: 50
|
evals_per_epoch: 4
|
||||||
save_steps:
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0
|
weight_decay: 0
|
||||||
|
|||||||
@@ -78,8 +78,8 @@ flash_attention:
|
|||||||
gptq_groupsize:
|
gptq_groupsize:
|
||||||
gptq_model_v1:
|
gptq_model_v1:
|
||||||
warmup_steps: 10
|
warmup_steps: 10
|
||||||
eval_steps: 50
|
evals_per_epoch: 4
|
||||||
save_steps: 50
|
saves_per_epoch: 1
|
||||||
debug:
|
debug:
|
||||||
deepspeed:
|
deepspeed:
|
||||||
weight_decay: 0.0
|
weight_decay: 0.0
|
||||||
|
|||||||
64
examples/yayi2-30b/fft.yml
Normal file
64
examples/yayi2-30b/fft.yml
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
base_model: models/yayi2-30b
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
is_mistral_derived_model: false
|
||||||
|
trust_remote_code: true
|
||||||
|
model_revision: refs/pr/5
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: false
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path:
|
||||||
|
val_set_size: 0.05
|
||||||
|
output_dir: ./out
|
||||||
|
|
||||||
|
sequence_len: 2048
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
eval_sample_packing: false
|
||||||
|
|
||||||
|
wandb_project:
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.000005
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: true
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_table_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed: deepspeed/zero3_cpu.json
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
bos_token: "<s>"
|
||||||
|
eos_token: "</s>"
|
||||||
|
unk_token: "<unk>"
|
||||||
76
examples/yayi2-30b/qlora.yml
Normal file
76
examples/yayi2-30b/qlora.yml
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
base_model: wenge-research/yayi2-30b
|
||||||
|
model_type: AutoModelForCausalLM
|
||||||
|
tokenizer_type: AutoTokenizer
|
||||||
|
is_mistral_derived_model: false
|
||||||
|
trust_remote_code: true
|
||||||
|
model_revision: refs/pr/5
|
||||||
|
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
|
||||||
|
datasets:
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
dataset_prepared_path: last_run_prepared
|
||||||
|
val_set_size: 0.1
|
||||||
|
output_dir: ./qlora-out
|
||||||
|
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
|
||||||
|
sequence_len: 2048 # Fits in 40gb VRAM. Can easily do 4096 in A100 80 or a A6000
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
lora_target_modules:
|
||||||
|
|
||||||
|
wandb_project: yayi2
|
||||||
|
wandb_entity:
|
||||||
|
wandb_watch:
|
||||||
|
wandb_name:
|
||||||
|
wandb_log_model:
|
||||||
|
|
||||||
|
gradient_accumulation_steps: 1
|
||||||
|
micro_batch_size: 1
|
||||||
|
num_epochs: 1
|
||||||
|
optimizer: adamw_bnb_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0005
|
||||||
|
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
bf16: true
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
|
||||||
|
gradient_checkpointing: true
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
flash_attention: false
|
||||||
|
|
||||||
|
loss_watchdog_threshold: 5.0
|
||||||
|
loss_watchdog_patience: 3
|
||||||
|
|
||||||
|
warmup_steps: 10
|
||||||
|
evals_per_epoch: 4
|
||||||
|
eval_table_size:
|
||||||
|
eval_table_max_new_tokens: 128
|
||||||
|
saves_per_epoch: 1
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0.0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
|
special_tokens:
|
||||||
|
bos_token: "<s>"
|
||||||
|
eos_token: "</s>"
|
||||||
|
unk_token: "<unk>"
|
||||||
5
examples/yi-34B-chat/README.md
Normal file
5
examples/yi-34B-chat/README.md
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
# Overview
|
||||||
|
|
||||||
|
This is an example of a Yi-34B-Chat configuration. It demonstrates that it is possible to finetune a 34B model on a GPU with 24GB of VRAM.
|
||||||
|
|
||||||
|
Tested on an RTX 4090 with `python -m axolotl.cli.train examples/mistral/qlora.yml`, a single epoch of finetuning on the alpaca dataset using qlora runs in 47 mins, using 97% of available memory.
|
||||||
76
examples/yi-34B-chat/qlora.yml
Normal file
76
examples/yi-34B-chat/qlora.yml
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
base_model: 01-ai/Yi-34B-Chat
|
||||||
|
model_type: LlamaForCausalLM
|
||||||
|
tokenizer_type: LlamaTokenizer
|
||||||
|
is_mistral_derived_model: false
|
||||||
|
is_llama_derived_model: true
|
||||||
|
load_in_8bit: false
|
||||||
|
load_in_4bit: true
|
||||||
|
strict: false
|
||||||
|
sequence_len: 1024
|
||||||
|
bf16: true
|
||||||
|
fp16: false
|
||||||
|
tf32: false
|
||||||
|
flash_attention: true
|
||||||
|
special_tokens:
|
||||||
|
bos_token: "<|startoftext|>"
|
||||||
|
eos_token: "<|endoftext|>"
|
||||||
|
unk_token: "<unk>"
|
||||||
|
|
||||||
|
# Data
|
||||||
|
datasets:
|
||||||
|
- path: mhenrichsen/alpaca_2k_test
|
||||||
|
type: alpaca
|
||||||
|
warmup_steps: 10
|
||||||
|
|
||||||
|
# Iterations
|
||||||
|
num_epochs: 1
|
||||||
|
|
||||||
|
# Evaluation
|
||||||
|
val_set_size: 0.1
|
||||||
|
evals_per_epoch: 5
|
||||||
|
eval_table_size:
|
||||||
|
eval_table_max_new_tokens: 128
|
||||||
|
eval_sample_packing: false
|
||||||
|
eval_batch_size: 1
|
||||||
|
|
||||||
|
# LoRA
|
||||||
|
output_dir: ./qlora-out
|
||||||
|
adapter: qlora
|
||||||
|
lora_model_dir:
|
||||||
|
lora_r: 32
|
||||||
|
lora_alpha: 16
|
||||||
|
lora_dropout: 0.05
|
||||||
|
lora_target_linear: true
|
||||||
|
lora_fan_in_fan_out:
|
||||||
|
lora_target_modules:
|
||||||
|
|
||||||
|
# Sampling
|
||||||
|
sample_packing: false
|
||||||
|
pad_to_sequence_len: false
|
||||||
|
|
||||||
|
# Batching
|
||||||
|
gradient_accumulation_steps: 4
|
||||||
|
micro_batch_size: 1
|
||||||
|
gradient_checkpointing: true
|
||||||
|
|
||||||
|
# wandb
|
||||||
|
wandb_project:
|
||||||
|
|
||||||
|
# Optimizer
|
||||||
|
optimizer: paged_adamw_8bit
|
||||||
|
lr_scheduler: cosine
|
||||||
|
learning_rate: 0.0002
|
||||||
|
|
||||||
|
# Misc
|
||||||
|
train_on_inputs: false
|
||||||
|
group_by_length: false
|
||||||
|
early_stopping_patience:
|
||||||
|
resume_from_checkpoint:
|
||||||
|
local_rank:
|
||||||
|
logging_steps: 1
|
||||||
|
xformers_attention:
|
||||||
|
debug:
|
||||||
|
deepspeed:
|
||||||
|
weight_decay: 0
|
||||||
|
fsdp:
|
||||||
|
fsdp_config:
|
||||||
@@ -2,7 +2,7 @@
|
|||||||
auto-gptq==0.5.1
|
auto-gptq==0.5.1
|
||||||
packaging
|
packaging
|
||||||
peft==0.6.0
|
peft==0.6.0
|
||||||
transformers==4.35.2
|
transformers==4.36.2
|
||||||
tokenizers==0.15.0
|
tokenizers==0.15.0
|
||||||
bitsandbytes>=0.41.1
|
bitsandbytes>=0.41.1
|
||||||
accelerate==0.24.1
|
accelerate==0.24.1
|
||||||
@@ -29,7 +29,7 @@ scipy
|
|||||||
scikit-learn==1.2.2
|
scikit-learn==1.2.2
|
||||||
pynvml
|
pynvml
|
||||||
art
|
art
|
||||||
fschat==0.2.29
|
fschat==0.2.34
|
||||||
gradio==3.50.2
|
gradio==3.50.2
|
||||||
tensorboard
|
tensorboard
|
||||||
|
|
||||||
|
|||||||
20
setup.py
20
setup.py
@@ -1,5 +1,7 @@
|
|||||||
"""setup.py for axolotl"""
|
"""setup.py for axolotl"""
|
||||||
|
|
||||||
|
from importlib.metadata import PackageNotFoundError, version
|
||||||
|
|
||||||
from setuptools import find_packages, setup
|
from setuptools import find_packages, setup
|
||||||
|
|
||||||
|
|
||||||
@@ -22,12 +24,13 @@ def parse_requirements():
|
|||||||
# Handle standard packages
|
# Handle standard packages
|
||||||
_install_requires.append(line)
|
_install_requires.append(line)
|
||||||
|
|
||||||
# TODO(wing) remove once xformers release supports torch 2.1.0
|
try:
|
||||||
if "torch==2.1.0" in _install_requires:
|
torch_version = version("torch")
|
||||||
_install_requires.pop(_install_requires.index("xformers>=0.0.22"))
|
if torch_version.startswith("2.1.1"):
|
||||||
_install_requires.append(
|
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
|
||||||
"xformers @ git+https://github.com/facebookresearch/xformers.git@main"
|
_install_requires.append("xformers==0.0.23")
|
||||||
)
|
except PackageNotFoundError:
|
||||||
|
pass
|
||||||
|
|
||||||
return _install_requires, _dependency_links
|
return _install_requires, _dependency_links
|
||||||
|
|
||||||
@@ -46,10 +49,13 @@ setup(
|
|||||||
dependency_links=dependency_links,
|
dependency_links=dependency_links,
|
||||||
extras_require={
|
extras_require={
|
||||||
"flash-attn": [
|
"flash-attn": [
|
||||||
"flash-attn>=2.3.0",
|
"flash-attn==2.3.3",
|
||||||
],
|
],
|
||||||
"deepspeed": [
|
"deepspeed": [
|
||||||
"deepspeed",
|
"deepspeed",
|
||||||
],
|
],
|
||||||
|
"mamba-ssm": [
|
||||||
|
"mamba-ssm==1.0.1",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -103,15 +103,7 @@ def do_inference(
|
|||||||
importlib.import_module("axolotl.prompters"), prompter
|
importlib.import_module("axolotl.prompters"), prompter
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.landmark_attention:
|
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||||
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
|
||||||
|
|
||||||
set_model_mem_id(model, tokenizer)
|
|
||||||
model.set_mem_cache_args(
|
|
||||||
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
|
||||||
)
|
|
||||||
|
|
||||||
model = model.to(cfg.device)
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
print("=" * 80)
|
print("=" * 80)
|
||||||
@@ -176,15 +168,7 @@ def do_inference_gradio(
|
|||||||
importlib.import_module("axolotl.prompters"), prompter
|
importlib.import_module("axolotl.prompters"), prompter
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.landmark_attention:
|
model = model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||||
from axolotl.monkeypatch.llama_landmark_attn import set_model_mem_id
|
|
||||||
|
|
||||||
set_model_mem_id(model, tokenizer)
|
|
||||||
model.set_mem_cache_args(
|
|
||||||
max_seq_len=255, mem_freq=50, top_k=5, max_cache_size=None
|
|
||||||
)
|
|
||||||
|
|
||||||
model = model.to(cfg.device)
|
|
||||||
|
|
||||||
def generate(instruction):
|
def generate(instruction):
|
||||||
if not instruction:
|
if not instruction:
|
||||||
|
|||||||
@@ -18,7 +18,15 @@ def do_cli(config: Path = Path("examples/"), **kwargs):
|
|||||||
return_remaining_strings=True
|
return_remaining_strings=True
|
||||||
)
|
)
|
||||||
parsed_cli_args.merge_lora = True
|
parsed_cli_args.merge_lora = True
|
||||||
parsed_cfg = load_cfg(config, merge_lora=True, **kwargs)
|
|
||||||
|
parsed_cfg = load_cfg(
|
||||||
|
config,
|
||||||
|
merge_lora=True,
|
||||||
|
load_in_8bit=False,
|
||||||
|
load_in_4bit=False,
|
||||||
|
flash_attention=False,
|
||||||
|
**kwargs
|
||||||
|
)
|
||||||
|
|
||||||
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
do_merge_lora(cfg=parsed_cfg, cli_args=parsed_cli_args)
|
||||||
|
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ LOG = logging.getLogger("axolotl.cli.train")
|
|||||||
|
|
||||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
print_axolotl_text_art()
|
|
||||||
parsed_cfg = load_cfg(config, **kwargs)
|
parsed_cfg = load_cfg(config, **kwargs)
|
||||||
|
print_axolotl_text_art()
|
||||||
check_accelerate_default_config()
|
check_accelerate_default_config()
|
||||||
check_user_token()
|
check_user_token()
|
||||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ import math
|
|||||||
import sys
|
import sys
|
||||||
from abc import abstractmethod
|
from abc import abstractmethod
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import partial
|
from functools import wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -31,7 +31,10 @@ from axolotl.utils.callbacks import (
|
|||||||
bench_eval_callback_factory,
|
bench_eval_callback_factory,
|
||||||
log_prediction_callback_factory,
|
log_prediction_callback_factory,
|
||||||
)
|
)
|
||||||
from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq
|
from axolotl.utils.collators import (
|
||||||
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
|
MambaDataCollator,
|
||||||
|
)
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler
|
from axolotl.utils.samplers import MultipackBatchSampler
|
||||||
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
||||||
|
|
||||||
@@ -49,6 +52,9 @@ class AxolotlTrainingArguments(TrainingArguments):
|
|||||||
Extend the base TrainingArguments for axolotl helpers
|
Extend the base TrainingArguments for axolotl helpers
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
model_type: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "HF model configuration model_type."}
|
||||||
|
)
|
||||||
lr_quadratic_warmup: bool = field(
|
lr_quadratic_warmup: bool = field(
|
||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||||
@@ -114,6 +120,7 @@ class AxolotlTrainer(Trainer):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
args = None # type: AxolotlTrainingArguments
|
args = None # type: AxolotlTrainingArguments
|
||||||
|
tag_names = ["axolotl"]
|
||||||
|
|
||||||
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
|
def __init__(self, *args, num_epochs=1, bench_data_collator=None, **kwargs):
|
||||||
self.num_epochs = num_epochs
|
self.num_epochs = num_epochs
|
||||||
@@ -284,12 +291,69 @@ class AxolotlTrainer(Trainer):
|
|||||||
# return (loss, outputs) if return_outputs else loss
|
# return (loss, outputs) if return_outputs else loss
|
||||||
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
return super().compute_loss(model, inputs, return_outputs=return_outputs)
|
||||||
|
|
||||||
|
def _sanitize_kwargs_for_tagging(self, tag_names, kwargs=None):
|
||||||
|
if isinstance(tag_names, str):
|
||||||
|
tag_names = [tag_names]
|
||||||
|
|
||||||
|
if kwargs is not None:
|
||||||
|
if "tags" not in kwargs:
|
||||||
|
kwargs["tags"] = tag_names
|
||||||
|
elif "tags" in kwargs and isinstance(kwargs["tags"], list):
|
||||||
|
kwargs["tags"].extend(tag_names)
|
||||||
|
elif "tags" in kwargs and isinstance(kwargs["tags"], str):
|
||||||
|
tag_names.append(kwargs["tags"])
|
||||||
|
kwargs["tags"] = tag_names
|
||||||
|
|
||||||
|
return kwargs
|
||||||
|
|
||||||
|
@wraps(Trainer.push_to_hub)
|
||||||
|
def push_to_hub(self, *args, **kwargs) -> str:
|
||||||
|
"""
|
||||||
|
Overwrite the `push_to_hub` method in order to force-add the tags when pushing the
|
||||||
|
model on the Hub. Please refer to `~transformers.Trainer.push_to_hub` for more details.
|
||||||
|
"""
|
||||||
|
kwargs = self._sanitize_kwargs_for_tagging(
|
||||||
|
tag_names=self.tag_names, kwargs=kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
return super().push_to_hub(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||||
|
"""
|
||||||
|
Mamba specific trainer to handle loss calculation
|
||||||
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "mamba"]
|
||||||
|
|
||||||
|
def compute_loss(
|
||||||
|
self,
|
||||||
|
model,
|
||||||
|
inputs,
|
||||||
|
return_outputs=False, # pylint: disable=unused-argument
|
||||||
|
):
|
||||||
|
input_ids = inputs.pop("input_ids")
|
||||||
|
lm_logits = model(input_ids).logits
|
||||||
|
|
||||||
|
labels = input_ids.to(lm_logits.device)
|
||||||
|
shift_logits = lm_logits[:, :-1, :].contiguous()
|
||||||
|
labels = labels[:, 1:].contiguous()
|
||||||
|
|
||||||
|
loss_fct = torch.nn.CrossEntropyLoss()
|
||||||
|
lm_loss = loss_fct(
|
||||||
|
shift_logits.view(-1, shift_logits.size(-1)), labels.view(-1)
|
||||||
|
)
|
||||||
|
|
||||||
|
return lm_loss
|
||||||
|
|
||||||
|
|
||||||
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
||||||
"""
|
"""
|
||||||
Trainer subclass that uses the OneCycleLR scheduler
|
Trainer subclass that uses the OneCycleLR scheduler
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "onecycle"]
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.lr_scheduler = None
|
self.lr_scheduler = None
|
||||||
@@ -319,6 +383,8 @@ class ReLoRATrainer(AxolotlTrainer):
|
|||||||
Trainer subclass that uses the OneCycleLR scheduler
|
Trainer subclass that uses the OneCycleLR scheduler
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
tag_names = ["axolotl", "relora"]
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.lr_scheduler = None
|
self.lr_scheduler = None
|
||||||
@@ -462,6 +528,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return OneCycleLRSchedulerTrainer
|
return OneCycleLRSchedulerTrainer
|
||||||
if self.cfg.relora_steps:
|
if self.cfg.relora_steps:
|
||||||
return ReLoRATrainer
|
return ReLoRATrainer
|
||||||
|
if self.cfg.model_config_type == "mamba":
|
||||||
|
return AxolotlMambaTrainer
|
||||||
return AxolotlTrainer
|
return AxolotlTrainer
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
@@ -529,7 +597,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.hub_strategy:
|
if self.cfg.hub_strategy:
|
||||||
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
training_arguments_kwargs["hub_strategy"] = self.cfg.hub_strategy
|
||||||
|
|
||||||
if self.cfg.save_safetensors:
|
if self.cfg.save_safetensors is not None:
|
||||||
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
training_arguments_kwargs["save_safetensors"] = self.cfg.save_safetensors
|
||||||
|
|
||||||
if self.cfg.sample_packing_eff_est:
|
if self.cfg.sample_packing_eff_est:
|
||||||
@@ -658,6 +726,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
|
and self.cfg.lr_scheduler not in ("one_cycle", "log_sweep")
|
||||||
else "cosine"
|
else "cosine"
|
||||||
)
|
)
|
||||||
|
training_arguments_kwargs["lr_scheduler_kwargs"] = (
|
||||||
|
self.cfg.lr_scheduler_kwargs if self.cfg.lr_scheduler_kwargs else {}
|
||||||
|
)
|
||||||
training_arguments_kwargs["weight_decay"] = (
|
training_arguments_kwargs["weight_decay"] = (
|
||||||
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
self.cfg.weight_decay if self.cfg.weight_decay is not None else 0.0
|
||||||
)
|
)
|
||||||
@@ -677,6 +748,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs = self.hook_pre_create_training_args(
|
training_arguments_kwargs = self.hook_pre_create_training_args(
|
||||||
training_arguments_kwargs
|
training_arguments_kwargs
|
||||||
)
|
)
|
||||||
|
training_arguments_kwargs["model_type"] = self.cfg.model_config_type
|
||||||
|
|
||||||
|
if self.cfg.neftune_noise_alpha is not None:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"neftune_noise_alpha"
|
||||||
|
] = self.cfg.neftune_noise_alpha
|
||||||
|
|
||||||
training_args = (
|
training_args = (
|
||||||
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||||
**training_arguments_kwargs,
|
**training_arguments_kwargs,
|
||||||
@@ -702,26 +780,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
|
||||||
data_collator_kwargs["pad_to_multiple_of"] = 64
|
data_collator_kwargs["pad_to_multiple_of"] = 64
|
||||||
|
|
||||||
if self.cfg.is_llama_derived_model and self.cfg.landmark_attention:
|
|
||||||
from axolotl.monkeypatch.llama_landmark_attn import (
|
|
||||||
add_mem_tokens,
|
|
||||||
get_mem_id,
|
|
||||||
set_model_mem_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
set_model_mem_id(self.model, self.tokenizer)
|
|
||||||
|
|
||||||
LOG.info("Adding landmark attention tokens to dataset")
|
|
||||||
|
|
||||||
for dataset in [self.train_dataset, self.eval_dataset]:
|
|
||||||
dataset = dataset.map(
|
|
||||||
partial(
|
|
||||||
add_mem_tokens, mem_freq=50, mem_id=get_mem_id(self.tokenizer)
|
|
||||||
),
|
|
||||||
batched=False,
|
|
||||||
num_proc=32,
|
|
||||||
)
|
|
||||||
|
|
||||||
trainer_cls = self._get_trainer_cls()
|
trainer_cls = self._get_trainer_cls()
|
||||||
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
trainer_kwargs, trainer_cls = self.hook_pre_create_trainer(
|
||||||
trainer_kwargs, trainer_cls
|
trainer_kwargs, trainer_cls
|
||||||
@@ -731,11 +789,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
train_dataset=self.train_dataset,
|
train_dataset=self.train_dataset,
|
||||||
eval_dataset=self.eval_dataset,
|
eval_dataset=self.eval_dataset,
|
||||||
args=training_args,
|
args=training_args,
|
||||||
data_collator=BatchSamplerDataCollatorForSeq2Seq(
|
data_collator=self.build_collator(**data_collator_kwargs),
|
||||||
self.tokenizer,
|
|
||||||
return_tensors="pt",
|
|
||||||
**data_collator_kwargs,
|
|
||||||
),
|
|
||||||
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
bench_data_collator=transformers.DataCollatorForSeq2Seq(
|
||||||
self.tokenizer,
|
self.tokenizer,
|
||||||
return_tensors="pt",
|
return_tensors="pt",
|
||||||
@@ -755,3 +809,13 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
] = self.cfg.micro_batch_size
|
] = self.cfg.micro_batch_size
|
||||||
|
|
||||||
return trainer
|
return trainer
|
||||||
|
|
||||||
|
def build_collator(self, **kwargs):
|
||||||
|
if self.cfg.model_config_type == "mamba":
|
||||||
|
return MambaDataCollator(tokenizer=self.tokenizer)
|
||||||
|
|
||||||
|
return BatchSamplerDataCollatorForSeq2Seq(
|
||||||
|
self.tokenizer,
|
||||||
|
return_tensors="pt",
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|||||||
12
src/axolotl/models/mamba/__init__.py
Normal file
12
src/axolotl/models/mamba/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
"""
|
||||||
|
Modeling module for Mamba models
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def fix_mamba_attn_for_loss():
|
||||||
|
from mamba_ssm.models import mixer_seq_simple
|
||||||
|
|
||||||
|
from .modeling_mamba import MambaLMHeadModel as MambaLMHeadModelFixed
|
||||||
|
|
||||||
|
mixer_seq_simple.MambaLMHeadModel = MambaLMHeadModelFixed
|
||||||
|
return mixer_seq_simple.MambaLMHeadModel # pylint: disable=invalid-name
|
||||||
42
src/axolotl/models/mamba/configuration_mamba.py
Normal file
42
src/axolotl/models/mamba/configuration_mamba.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
"""
|
||||||
|
HF Transformers MambaConfig
|
||||||
|
"""
|
||||||
|
from transformers import PretrainedConfig
|
||||||
|
|
||||||
|
|
||||||
|
class MambaConfig(PretrainedConfig):
|
||||||
|
"""
|
||||||
|
modeling configuration for state space model/mamba
|
||||||
|
"""
|
||||||
|
|
||||||
|
model_type = "mamba"
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab_size=50280,
|
||||||
|
d_model=2560,
|
||||||
|
n_layer=64,
|
||||||
|
rms_norm=True,
|
||||||
|
residual_in_fp32=True,
|
||||||
|
fused_add_norm=True,
|
||||||
|
pad_vocab_size_multiple=8,
|
||||||
|
pad_token_id=50277,
|
||||||
|
bos_token_id=0,
|
||||||
|
eos_token_id=0,
|
||||||
|
tie_word_embeddings=False,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
self.vocab_size = vocab_size
|
||||||
|
self.d_model = d_model
|
||||||
|
self.n_layer = n_layer
|
||||||
|
self.rms_norm = rms_norm
|
||||||
|
self.residual_in_fp32 = residual_in_fp32
|
||||||
|
self.fused_add_norm = fused_add_norm
|
||||||
|
self.pad_vocab_size_multiple = pad_vocab_size_multiple
|
||||||
|
super().__init__(
|
||||||
|
pad_token_id=pad_token_id,
|
||||||
|
bos_token_id=bos_token_id,
|
||||||
|
eos_token_id=eos_token_id,
|
||||||
|
tie_word_embeddings=tie_word_embeddings,
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
128
src/axolotl/models/mamba/modeling_mamba.py
Normal file
128
src/axolotl/models/mamba/modeling_mamba.py
Normal file
@@ -0,0 +1,128 @@
|
|||||||
|
# pylint: skip-file
|
||||||
|
import os
|
||||||
|
from collections import namedtuple
|
||||||
|
from functools import partial
|
||||||
|
from typing import Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from mamba_ssm.models.mixer_seq_simple import MixerModel, _init_weights
|
||||||
|
from mamba_ssm.utils.generation import GenerationMixin
|
||||||
|
from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
|
||||||
|
from torch import nn
|
||||||
|
from torch.nn import CrossEntropyLoss
|
||||||
|
|
||||||
|
from axolotl.models.mamba.configuration_mamba import MambaConfig
|
||||||
|
|
||||||
|
|
||||||
|
class MambaLMHeadModel(nn.Module, GenerationMixin):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
d_model: int,
|
||||||
|
n_layer: int,
|
||||||
|
vocab_size: int,
|
||||||
|
initializer_cfg=None,
|
||||||
|
pad_vocab_size_multiple: int = 1,
|
||||||
|
device=None,
|
||||||
|
dtype=None,
|
||||||
|
**backbone_kwargs,
|
||||||
|
) -> None:
|
||||||
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
|
super().__init__()
|
||||||
|
if vocab_size % pad_vocab_size_multiple != 0:
|
||||||
|
vocab_size += pad_vocab_size_multiple - (
|
||||||
|
vocab_size % pad_vocab_size_multiple
|
||||||
|
)
|
||||||
|
self.config = MambaConfig(
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
d_model=d_model,
|
||||||
|
n_layer=n_layer,
|
||||||
|
pad_vocab_size_multiple=pad_vocab_size_multiple,
|
||||||
|
)
|
||||||
|
self.backbone = MixerModel(
|
||||||
|
d_model=d_model,
|
||||||
|
n_layer=n_layer,
|
||||||
|
vocab_size=vocab_size,
|
||||||
|
initializer_cfg=initializer_cfg,
|
||||||
|
**backbone_kwargs,
|
||||||
|
**factory_kwargs,
|
||||||
|
)
|
||||||
|
self.lm_head = nn.Linear(d_model, vocab_size, bias=False, **factory_kwargs)
|
||||||
|
|
||||||
|
# Initialize weights and apply final processing
|
||||||
|
self.apply(
|
||||||
|
partial(
|
||||||
|
_init_weights,
|
||||||
|
n_layer=n_layer,
|
||||||
|
**(initializer_cfg if initializer_cfg is not None else {}),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
self.tie_weights()
|
||||||
|
|
||||||
|
def tie_weights(self):
|
||||||
|
self.lm_head.weight = self.backbone.embedding.weight
|
||||||
|
|
||||||
|
def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
|
||||||
|
return self.backbone.allocate_inference_cache(
|
||||||
|
batch_size, max_seqlen, dtype=dtype, **kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
input_ids,
|
||||||
|
position_ids=None,
|
||||||
|
inference_params=None,
|
||||||
|
num_last_tokens=0,
|
||||||
|
labels=None,
|
||||||
|
**kwargs,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
"position_ids" is just to be compatible with Transformer generation. We don't use it.
|
||||||
|
num_last_tokens: if > 0, only return the logits for the last n tokens
|
||||||
|
"""
|
||||||
|
hidden_states = self.backbone(input_ids, inference_params=inference_params)
|
||||||
|
if num_last_tokens > 0:
|
||||||
|
hidden_states = hidden_states[:, -num_last_tokens:]
|
||||||
|
lm_logits = self.lm_head(hidden_states)
|
||||||
|
|
||||||
|
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
||||||
|
return CausalLMOutput(logits=lm_logits)
|
||||||
|
|
||||||
|
loss = None
|
||||||
|
if labels is not None:
|
||||||
|
logits = lm_logits
|
||||||
|
# Shift so that tokens < n predict n
|
||||||
|
shift_logits = logits[..., :-1, :].contiguous()
|
||||||
|
shift_labels = labels[..., 1:].contiguous()
|
||||||
|
# Flatten the tokens
|
||||||
|
loss_fct = CrossEntropyLoss()
|
||||||
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||||
|
shift_labels = shift_labels.view(-1)
|
||||||
|
# Enable model parallelism
|
||||||
|
shift_labels = shift_labels.to(shift_logits.device)
|
||||||
|
loss = loss_fct(shift_logits, shift_labels)
|
||||||
|
CausalLMOutput = namedtuple("CausalLMOutput", ["logits", "loss"])
|
||||||
|
print(loss)
|
||||||
|
return CausalLMOutput(logits=lm_logits, loss=loss)
|
||||||
|
|
||||||
|
else:
|
||||||
|
CausalLMOutput = namedtuple("CausalLMOutput", ["logits"])
|
||||||
|
return CausalLMOutput(logits=lm_logits)
|
||||||
|
|
||||||
|
def save_pretrained(
|
||||||
|
self,
|
||||||
|
save_directory: Union[str, os.PathLike],
|
||||||
|
state_dict: Optional[dict] = None,
|
||||||
|
safe_serialization: Optional[bool] = None, # pylint: disable=unused-argument
|
||||||
|
):
|
||||||
|
if state_dict is None:
|
||||||
|
state_dict = self.state_dict()
|
||||||
|
torch.save(state_dict, os.path.join(save_directory, "pytorch_model.bin"))
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_pretrained(cls, pretrained_model_name, device=None, dtype=None, **kwargs):
|
||||||
|
config = load_config_hf(pretrained_model_name)
|
||||||
|
model = cls(**config, device=device, dtype=dtype, **kwargs)
|
||||||
|
model.load_state_dict(
|
||||||
|
load_state_dict_hf(pretrained_model_name, device={"": device}, dtype=dtype)
|
||||||
|
)
|
||||||
|
return model
|
||||||
@@ -1,168 +0,0 @@
|
|||||||
# Adapted from Unsloth
|
|
||||||
# https://github.com/unslothai/unsloth/blob/4b97a810b509c93f44be4c037c7aa18fb8922884/unsloth/kernels/cross_entropy_loss.py
|
|
||||||
|
|
||||||
import triton
|
|
||||||
import triton.language as tl
|
|
||||||
import torch
|
|
||||||
|
|
||||||
MAX_FUSED_SIZE = 65536
|
|
||||||
|
|
||||||
def calculate_settings(n):
|
|
||||||
BLOCK_SIZE = triton.next_power_of_2(n)
|
|
||||||
# CUDA only supports 65536 - 2^16 threads per block
|
|
||||||
if BLOCK_SIZE > MAX_FUSED_SIZE:
|
|
||||||
raise RuntimeError(f"Cannot launch Triton kernel since n = {n} exceeds "\
|
|
||||||
f"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.")
|
|
||||||
num_warps = 4
|
|
||||||
if BLOCK_SIZE >= 32768: num_warps = 32
|
|
||||||
elif BLOCK_SIZE >= 8192: num_warps = 16
|
|
||||||
elif BLOCK_SIZE >= 2048: num_warps = 8
|
|
||||||
return BLOCK_SIZE, num_warps
|
|
||||||
pass
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _cross_entropy_forward(logits_ptr, logits_row_stride,
|
|
||||||
loss_ptr,
|
|
||||||
lse_ptr,
|
|
||||||
labels_ptr,
|
|
||||||
n_cols,
|
|
||||||
BLOCK_SIZE: tl.constexpr,):
|
|
||||||
"""
|
|
||||||
Cross Entropy Loss = 1/n sum [ -yi log(Pi) ]
|
|
||||||
Pi = exp(xi) / sum(exp(xi))
|
|
||||||
CE_i = -y log(p) = -y log[ exp(x) / sum(exp(x)) ]
|
|
||||||
= -y [ x - log[sum(exp(x))] ]
|
|
||||||
= y * (log[sum(exp(x))] - x)
|
|
||||||
If y == 0: CE_i = 0
|
|
||||||
If y == 1: CE_i = logsumexp - x
|
|
||||||
"""
|
|
||||||
row_idx = tl.program_id(0)
|
|
||||||
logits_ptr += row_idx * logits_row_stride
|
|
||||||
loss_ptr += row_idx
|
|
||||||
lse_ptr += row_idx
|
|
||||||
labels_ptr += row_idx
|
|
||||||
|
|
||||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
|
||||||
mask = col_offsets < n_cols
|
|
||||||
|
|
||||||
# TODO: Fixup int32 locations to int64
|
|
||||||
label_idx = tl.load(labels_ptr).to(tl.int32)
|
|
||||||
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = -float("inf")).to(tl.float32)
|
|
||||||
max_logits = tl.max(logits, 0)
|
|
||||||
# Maximum stops overflow
|
|
||||||
lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
|
|
||||||
tl.store(lse_ptr, lse)
|
|
||||||
|
|
||||||
if label_idx != -100:
|
|
||||||
logits_label = tl.load(logits_ptr + label_idx).to(tl.float32)
|
|
||||||
loss = lse - logits_label
|
|
||||||
else:
|
|
||||||
loss = 0.0
|
|
||||||
tl.store(loss_ptr, loss)
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def _cross_entropy_backward(logits_ptr, logits_row_stride,
|
|
||||||
dloss_ptr, dloss_row_stride,
|
|
||||||
lse_ptr,
|
|
||||||
labels_ptr,
|
|
||||||
n_cols,
|
|
||||||
BLOCK_SIZE: tl.constexpr,):
|
|
||||||
"""
|
|
||||||
CE_i = -y log(P) = y * (log[sum(exp(x))] - x)
|
|
||||||
dC/dx = d/dx (y * log[sum(exp(x))] - x * y)
|
|
||||||
|
|
||||||
From https://en.wikipedia.org/wiki/LogSumExp
|
|
||||||
d/dx logsumexp = exp(x) / sum(exp(x)) = softmax(x)
|
|
||||||
|
|
||||||
dC/dx = y * exp(x) / sum(exp(x)) - d/dx (x * y)
|
|
||||||
dC/dx = y * exp[ log[exp(x) / sum(exp(x))] ] using x = exp(log(x)) trick
|
|
||||||
dC/dx = y * exp[x - logsumexp] - d/dx (x * y)
|
|
||||||
|
|
||||||
If y == 0: dC/dx = 0
|
|
||||||
If y == 1 and x == label: dC/dlabel = exp[x - logsumexp] - 1
|
|
||||||
If y == 1 and x != label: dC/dx = exp[x - logsumexp]
|
|
||||||
"""
|
|
||||||
row_idx = tl.program_id(0)
|
|
||||||
logits_ptr += row_idx * logits_row_stride
|
|
||||||
dloss_ptr += row_idx * dloss_row_stride
|
|
||||||
col_offsets = tl.arange(0, BLOCK_SIZE)
|
|
||||||
mask = col_offsets < n_cols
|
|
||||||
# TODO: Fixup int32 locations to int64
|
|
||||||
label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)
|
|
||||||
|
|
||||||
if label_idx != -100:
|
|
||||||
dloss = tl.load(dloss_ptr)
|
|
||||||
else:
|
|
||||||
dloss = 0.0
|
|
||||||
logits = tl.load(logits_ptr + col_offsets, mask = mask, other = 0).to(tl.float32)
|
|
||||||
lse = tl.load(lse_ptr + row_idx)
|
|
||||||
probs = tl.exp(logits - lse)
|
|
||||||
|
|
||||||
probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
|
|
||||||
tl.store(logits_ptr + col_offsets, dloss * probs, mask = mask)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class CrossEntropyLoss(torch.autograd.Function):
|
|
||||||
@staticmethod
|
|
||||||
def forward(ctx, logits, labels):
|
|
||||||
n_rows, n_cols = logits.shape
|
|
||||||
BLOCK_SIZE, num_warps = calculate_settings(n_cols)
|
|
||||||
losses = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
|
|
||||||
logsumexp = torch.empty(n_rows, dtype = torch.float32, device = "cuda")
|
|
||||||
|
|
||||||
_cross_entropy_forward[(n_rows,)](
|
|
||||||
logits, logits.stride(0),
|
|
||||||
losses,
|
|
||||||
logsumexp,
|
|
||||||
labels,
|
|
||||||
n_cols,
|
|
||||||
BLOCK_SIZE = BLOCK_SIZE,
|
|
||||||
num_warps = num_warps,
|
|
||||||
)
|
|
||||||
|
|
||||||
ctx.BLOCK_SIZE = BLOCK_SIZE
|
|
||||||
ctx.num_warps = num_warps
|
|
||||||
ctx.save_for_backward(logits, logsumexp, labels)
|
|
||||||
return losses
|
|
||||||
pass
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def backward(ctx, dlosses):
|
|
||||||
logits, logsumexp, labels = ctx.saved_tensors
|
|
||||||
n_rows, n_cols = logits.shape
|
|
||||||
|
|
||||||
_cross_entropy_backward[(n_rows,)](
|
|
||||||
logits, logits.stride(0),
|
|
||||||
dlosses, dlosses.stride(0),
|
|
||||||
logsumexp,
|
|
||||||
labels,
|
|
||||||
n_cols,
|
|
||||||
BLOCK_SIZE = ctx.BLOCK_SIZE,
|
|
||||||
num_warps = ctx.num_warps,
|
|
||||||
)
|
|
||||||
return logits, None, None,
|
|
||||||
pass
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def fast_cross_entropy_loss(logits, labels):
|
|
||||||
"""
|
|
||||||
Arguments:
|
|
||||||
logits: (batch, seq_len, vocab_size)
|
|
||||||
labels: (batch, seq_len,)
|
|
||||||
Returns:
|
|
||||||
losses: float
|
|
||||||
"""
|
|
||||||
batch, seq_len, d = logits.shape
|
|
||||||
assert(labels.shape == (batch, seq_len))
|
|
||||||
|
|
||||||
loss = CrossEntropyLoss.apply(
|
|
||||||
logits.view(batch*seq_len, d),
|
|
||||||
labels.view(-1),
|
|
||||||
)
|
|
||||||
n_items = torch.count_nonzero(labels != -100)
|
|
||||||
return loss.sum() / n_items
|
|
||||||
pass
|
|
||||||
@@ -82,15 +82,44 @@ def get_turns( # pylint: disable=too-many-return-statements
|
|||||||
else:
|
else:
|
||||||
yield role + ":", ""
|
yield role + ":", ""
|
||||||
return
|
return
|
||||||
if self.sep_style == SeparatorStyle.LLAMA2:
|
if self.sep_style == SeparatorStyle.LLAMA2 and self.name != "mistral":
|
||||||
seps = [self.sep, self.sep2]
|
|
||||||
if self.system_message:
|
if self.system_message:
|
||||||
|
if self.messages:
|
||||||
|
# For llama, the system message is incorporated into the first human instruction
|
||||||
|
first_role, first_msg = self.messages[0]
|
||||||
|
if first_role == self.roles[0]:
|
||||||
|
system_prompt += first_msg
|
||||||
|
self.messages.pop(0)
|
||||||
yield "", system_prompt
|
yield "", system_prompt
|
||||||
else:
|
for i, (role, message) in enumerate(self.messages):
|
||||||
yield "", "[INST] "
|
|
||||||
for i, (role, message) in enumerate(self.messages[1:]):
|
|
||||||
if message:
|
if message:
|
||||||
yield role + " ", message + seps[i % 2]
|
if (i % 2 == 0 and not self.system_message) or (
|
||||||
|
i % 2 != 0 and self.system_message
|
||||||
|
):
|
||||||
|
role = "<s> " + role
|
||||||
|
yield role + " ", message
|
||||||
|
else:
|
||||||
|
yield role, ""
|
||||||
|
return
|
||||||
|
if self.sep_style == SeparatorStyle.LLAMA2 and self.name == "mistral":
|
||||||
|
contains_sys_msg = False
|
||||||
|
if self.system_message:
|
||||||
|
contains_sys_msg = True
|
||||||
|
if self.messages:
|
||||||
|
# There is no clear guidance on how to handle system messages in Mistral so we just prepend it to the first human instruction seperated by a newline
|
||||||
|
first_role, first_msg = self.messages[0]
|
||||||
|
if first_role == self.roles[0]:
|
||||||
|
system_prompt = self.system_template.format(
|
||||||
|
system_message=" " + self.system_message
|
||||||
|
)
|
||||||
|
system_prompt += first_msg
|
||||||
|
self.messages.pop(0)
|
||||||
|
yield "", system_prompt
|
||||||
|
for i, (role, message) in enumerate(self.messages):
|
||||||
|
if message and i == 0 and not contains_sys_msg:
|
||||||
|
yield "", system_prompt.strip() + " " + message # if there is no system message, we need to make sure there is the a `<s> [INST]` at the beginning of the first instruction.
|
||||||
|
elif message:
|
||||||
|
yield role + " ", message
|
||||||
else:
|
else:
|
||||||
yield role, ""
|
yield role, ""
|
||||||
return
|
return
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -13,20 +13,16 @@ from flash_attn.flash_attn_interface import ( # pylint: disable=ungrouped-impor
|
|||||||
flash_attn_varlen_kvpacked_func,
|
flash_attn_varlen_kvpacked_func,
|
||||||
flash_attn_varlen_qkvpacked_func,
|
flash_attn_varlen_qkvpacked_func,
|
||||||
)
|
)
|
||||||
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
|
from transformers.modeling_outputs import BaseModelOutputWithPast
|
||||||
from transformers.models.mistral.modeling_mistral import (
|
from transformers.models.mistral.modeling_mistral import (
|
||||||
MistralAttention as OriginalMistralAttention,
|
MistralAttention as OriginalMistralAttention,
|
||||||
)
|
)
|
||||||
from transformers.models.mistral.modeling_mistral import (
|
from transformers.models.mistral.modeling_mistral import (
|
||||||
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
MistralDecoderLayer as OriginalMistralDecoderLayer,
|
||||||
)
|
)
|
||||||
from transformers.models.mistral.modeling_mistral import (
|
|
||||||
MistralForCausalLM as OriginalMistralForCausalLM,
|
|
||||||
)
|
|
||||||
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
|
from transformers.models.mistral.modeling_mistral import apply_rotary_pos_emb, repeat_kv
|
||||||
|
|
||||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
from axolotl.monkeypatch.cross_entropy import fast_cross_entropy_loss
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
|
LOG = logging.getLogger("axolotl.monkeypatch.mistral")
|
||||||
|
|
||||||
@@ -40,9 +36,6 @@ def replace_mistral_attn_with_flash_attn(
|
|||||||
transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
|
transformers.models.mistral.modeling_mistral.MistralAttention.forward = (
|
||||||
flashattn_forward
|
flashattn_forward
|
||||||
)
|
)
|
||||||
transformers.models.mistral.modeling_mistral.MistralForCausalLM.forward = (
|
|
||||||
mistral_causallm_forward
|
|
||||||
)
|
|
||||||
if packed:
|
if packed:
|
||||||
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
|
transformers.models.mistral.modeling_mistral.MistralDecoderLayer = (
|
||||||
MistralDecoderLayer
|
MistralDecoderLayer
|
||||||
@@ -648,71 +641,3 @@ class MistralDecoderLayer(OriginalMistralDecoderLayer):
|
|||||||
outputs += (present_key_value,)
|
outputs += (present_key_value,)
|
||||||
|
|
||||||
return outputs
|
return outputs
|
||||||
|
|
||||||
def mistral_causallm_forward(
|
|
||||||
self: OriginalMistralForCausalLM,
|
|
||||||
input_ids: torch.LongTensor = None,
|
|
||||||
attention_mask: Optional[torch.Tensor] = None,
|
|
||||||
position_ids: Optional[torch.LongTensor] = None,
|
|
||||||
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
|
||||||
inputs_embeds: Optional[torch.FloatTensor] = None,
|
|
||||||
labels: Optional[torch.LongTensor] = None,
|
|
||||||
use_cache: Optional[bool] = None,
|
|
||||||
output_attentions: Optional[bool] = None,
|
|
||||||
output_hidden_states: Optional[bool] = None,
|
|
||||||
return_dict: Optional[bool] = None,
|
|
||||||
*args, **kwargs
|
|
||||||
) -> Union[Tuple, CausalLMOutputWithPast]:
|
|
||||||
r"""
|
|
||||||
Args:
|
|
||||||
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
|
||||||
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
|
|
||||||
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
|
|
||||||
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
|
|
||||||
```"""
|
|
||||||
|
|
||||||
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
|
||||||
output_hidden_states = (
|
|
||||||
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
|
||||||
)
|
|
||||||
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
|
||||||
|
|
||||||
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
|
|
||||||
outputs = self.model(
|
|
||||||
input_ids=input_ids,
|
|
||||||
attention_mask=attention_mask,
|
|
||||||
position_ids=position_ids,
|
|
||||||
past_key_values=past_key_values,
|
|
||||||
inputs_embeds=inputs_embeds,
|
|
||||||
use_cache=use_cache,
|
|
||||||
output_attentions=output_attentions,
|
|
||||||
output_hidden_states=output_hidden_states,
|
|
||||||
return_dict=return_dict,
|
|
||||||
)
|
|
||||||
|
|
||||||
hidden_states = outputs[0]
|
|
||||||
logits = self.lm_head(hidden_states)
|
|
||||||
|
|
||||||
loss = None
|
|
||||||
if labels is not None:
|
|
||||||
shift_logits = logits
|
|
||||||
if not hasattr(self, "extra_ignored_labels"):
|
|
||||||
self.extra_ignored_labels = torch.full((self.model.config.max_position_embeddings, 1), -100, device=shift_logits.device)
|
|
||||||
|
|
||||||
shift_labels = torch.hstack((labels[..., 1:], self.extra_ignored_labels[:labels.shape[0]]))
|
|
||||||
shift_labels = shift_labels.to(shift_logits.device)
|
|
||||||
|
|
||||||
# FAST CROSS ENTROPY
|
|
||||||
loss = fast_cross_entropy_loss(shift_logits, shift_labels)
|
|
||||||
|
|
||||||
if not return_dict:
|
|
||||||
output = (logits,) + outputs[1:]
|
|
||||||
return (loss,) + output if loss is not None else output
|
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
|
||||||
loss=loss,
|
|
||||||
logits=logits,
|
|
||||||
past_key_values=outputs.past_key_values,
|
|
||||||
hidden_states=outputs.hidden_states,
|
|
||||||
attentions=outputs.attentions,
|
|
||||||
)
|
|
||||||
22
src/axolotl/monkeypatch/mixtral/__init__.py
Normal file
22
src/axolotl/monkeypatch/mixtral/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
"""
|
||||||
|
Patches to support multipack for mixtral
|
||||||
|
"""
|
||||||
|
import transformers
|
||||||
|
|
||||||
|
|
||||||
|
def replace_mixtral_attn_with_multipack_flash_attn():
|
||||||
|
from .modeling_mixtral import (
|
||||||
|
MixtralMultipackFlashAttention2,
|
||||||
|
mixtral_decoder_layer_forward,
|
||||||
|
mixtral_model_forward,
|
||||||
|
)
|
||||||
|
|
||||||
|
transformers.models.mixtral.modeling_mixtral.MixtralDecoderLayer.forward = (
|
||||||
|
mixtral_decoder_layer_forward
|
||||||
|
)
|
||||||
|
transformers.models.mixtral.modeling_mixtral.MixtralModel.forward = (
|
||||||
|
mixtral_model_forward
|
||||||
|
)
|
||||||
|
transformers.models.mixtral.modeling_mixtral.MISTRAL_ATTENTION_CLASSES[
|
||||||
|
"flash_attention_2"
|
||||||
|
] = MixtralMultipackFlashAttention2
|
||||||
379
src/axolotl/monkeypatch/mixtral/modeling_mixtral.py
Normal file
379
src/axolotl/monkeypatch/mixtral/modeling_mixtral.py
Normal file
@@ -0,0 +1,379 @@
|
|||||||
|
"""
|
||||||
|
Mixtral modeling for multipack
|
||||||
|
"""
|
||||||
|
# pylint: disable=missing-module-docstring,unused-argument,protected-access,pointless-string-statement,duplicate-code
|
||||||
|
import logging
|
||||||
|
import warnings
|
||||||
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from einops import rearrange
|
||||||
|
from flash_attn import flash_attn_varlen_qkvpacked_func
|
||||||
|
from transformers import Cache, DynamicCache
|
||||||
|
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
|
||||||
|
from transformers.modeling_outputs import MoeModelOutputWithPast
|
||||||
|
from transformers.models.mixtral.modeling_mixtral import (
|
||||||
|
MixtralFlashAttention2,
|
||||||
|
apply_rotary_pos_emb,
|
||||||
|
repeat_kv,
|
||||||
|
)
|
||||||
|
|
||||||
|
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.monkeypatch.mixtral")
|
||||||
|
|
||||||
|
|
||||||
|
class MixtralMultipackFlashAttention2(MixtralFlashAttention2):
|
||||||
|
"""
|
||||||
|
Custom multipack implementation w flash attention 2
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self._flash_attn_uses_top_left_mask = True
|
||||||
|
|
||||||
|
def forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Cache] = None,
|
||||||
|
output_attentions: bool = False,
|
||||||
|
use_cache: bool = False,
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
max_seqlen: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||||
|
if "padding_mask" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||||
|
)
|
||||||
|
bsz, q_len, _ = hidden_states.size()
|
||||||
|
|
||||||
|
query_states = self.q_proj(hidden_states)
|
||||||
|
key_states = self.k_proj(hidden_states)
|
||||||
|
value_states = self.v_proj(hidden_states)
|
||||||
|
|
||||||
|
query_states = query_states.view(
|
||||||
|
bsz, q_len, self.num_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
key_states = key_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
value_states = value_states.view(
|
||||||
|
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||||
|
).transpose(1, 2)
|
||||||
|
|
||||||
|
kv_seq_len = key_states.shape[-2]
|
||||||
|
if past_key_value is not None:
|
||||||
|
if self.layer_idx is None:
|
||||||
|
raise ValueError(
|
||||||
|
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
|
||||||
|
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
|
||||||
|
"with a layer index."
|
||||||
|
)
|
||||||
|
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
|
||||||
|
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||||
|
query_states, key_states = apply_rotary_pos_emb(
|
||||||
|
query_states, key_states, cos, sin, position_ids
|
||||||
|
)
|
||||||
|
|
||||||
|
if past_key_value is not None:
|
||||||
|
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
|
||||||
|
key_states, value_states = past_key_value.update(
|
||||||
|
key_states, value_states, self.layer_idx, cache_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
# repeat k/v heads if n_kv_heads < n_heads
|
||||||
|
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||||
|
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||||
|
|
||||||
|
if cu_seqlens is not None and max_seqlen is not None and cu_seqlens.dim() == 1:
|
||||||
|
# special handling using sample packing
|
||||||
|
qkv = torch.stack(
|
||||||
|
[query_states, key_states, value_states], dim=2
|
||||||
|
) # [bsz, nh, 3, q_len, hd]
|
||||||
|
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
|
||||||
|
qkv = rearrange(qkv, "b s ... -> (b s) ...")
|
||||||
|
|
||||||
|
attn_output = flash_attn_varlen_qkvpacked_func(
|
||||||
|
qkv,
|
||||||
|
cu_seqlens,
|
||||||
|
max_seqlen,
|
||||||
|
dropout_p=self.attention_dropout,
|
||||||
|
softmax_scale=None,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
attn_output = rearrange(attn_output, "(b s) ... -> b s ...", b=bsz)
|
||||||
|
|
||||||
|
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()
|
||||||
|
attn_output = self.o_proj(attn_output)
|
||||||
|
|
||||||
|
if not output_attentions:
|
||||||
|
attn_weights = None
|
||||||
|
|
||||||
|
return attn_output, attn_weights, past_key_value
|
||||||
|
|
||||||
|
|
||||||
|
def mixtral_decoder_layer_forward(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||||
|
output_attentions: Optional[bool] = False,
|
||||||
|
output_router_logits: Optional[bool] = False,
|
||||||
|
use_cache: Optional[bool] = False,
|
||||||
|
cu_seqlens: Optional[torch.Tensor] = None,
|
||||||
|
max_seqlen: Optional[torch.Tensor] = None,
|
||||||
|
**kwargs,
|
||||||
|
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
|
||||||
|
if "padding_mask" in kwargs:
|
||||||
|
warnings.warn(
|
||||||
|
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||||
|
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
|
||||||
|
`(batch, sequence_length)` where padding elements are indicated by 0.
|
||||||
|
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
|
||||||
|
output_attentions (`bool`, *optional*):
|
||||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
|
||||||
|
returned tensors for more detail.
|
||||||
|
output_router_logits (`bool`, *optional*):
|
||||||
|
Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
|
||||||
|
should not be returned during inference.
|
||||||
|
use_cache (`bool`, *optional*):
|
||||||
|
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
|
||||||
|
(see `past_key_values`).
|
||||||
|
"""
|
||||||
|
|
||||||
|
residual = hidden_states
|
||||||
|
|
||||||
|
hidden_states = self.input_layernorm(hidden_states)
|
||||||
|
|
||||||
|
# Self Attention
|
||||||
|
hidden_states, self_attn_weights, present_key_value = self.self_attn(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_value,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
# Fully Connected
|
||||||
|
residual = hidden_states
|
||||||
|
hidden_states = self.post_attention_layernorm(hidden_states)
|
||||||
|
hidden_states, router_logits = self.block_sparse_moe(hidden_states)
|
||||||
|
hidden_states = residual + hidden_states
|
||||||
|
|
||||||
|
outputs = (hidden_states,)
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
outputs += (self_attn_weights,)
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
outputs += (present_key_value,)
|
||||||
|
|
||||||
|
if output_router_logits:
|
||||||
|
outputs += (router_logits,)
|
||||||
|
|
||||||
|
return outputs
|
||||||
|
|
||||||
|
|
||||||
|
def mixtral_model_forward(
|
||||||
|
self,
|
||||||
|
input_ids: torch.LongTensor = None,
|
||||||
|
attention_mask: Optional[torch.Tensor] = None,
|
||||||
|
position_ids: Optional[torch.LongTensor] = None,
|
||||||
|
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
||||||
|
inputs_embeds: Optional[torch.FloatTensor] = None,
|
||||||
|
use_cache: Optional[bool] = None,
|
||||||
|
output_attentions: Optional[bool] = None,
|
||||||
|
output_hidden_states: Optional[bool] = None,
|
||||||
|
output_router_logits: Optional[bool] = None,
|
||||||
|
return_dict: Optional[bool] = None,
|
||||||
|
) -> Union[Tuple, MoeModelOutputWithPast]:
|
||||||
|
output_attentions = (
|
||||||
|
output_attentions
|
||||||
|
if output_attentions is not None
|
||||||
|
else self.config.output_attentions
|
||||||
|
)
|
||||||
|
output_router_logits = (
|
||||||
|
output_router_logits
|
||||||
|
if output_router_logits is not None
|
||||||
|
else self.config.output_router_logits
|
||||||
|
)
|
||||||
|
output_hidden_states = (
|
||||||
|
output_hidden_states
|
||||||
|
if output_hidden_states is not None
|
||||||
|
else self.config.output_hidden_states
|
||||||
|
)
|
||||||
|
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
||||||
|
|
||||||
|
return_dict = (
|
||||||
|
return_dict if return_dict is not None else self.config.use_return_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
# retrieve input_ids and inputs_embeds
|
||||||
|
if input_ids is not None and inputs_embeds is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time"
|
||||||
|
)
|
||||||
|
if input_ids is not None:
|
||||||
|
batch_size, seq_length = input_ids.shape
|
||||||
|
elif inputs_embeds is not None:
|
||||||
|
batch_size, seq_length, _ = inputs_embeds.shape
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"You have to specify either decoder_input_ids or decoder_inputs_embeds"
|
||||||
|
)
|
||||||
|
|
||||||
|
past_key_values_length = 0
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
use_legacy_cache = not isinstance(past_key_values, Cache)
|
||||||
|
if use_legacy_cache:
|
||||||
|
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
|
||||||
|
past_key_values_length = past_key_values.get_usable_length(seq_length)
|
||||||
|
|
||||||
|
cu_seqlens = None
|
||||||
|
max_seqlen = None
|
||||||
|
if position_ids is None:
|
||||||
|
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
||||||
|
position_ids = torch.arange(
|
||||||
|
past_key_values_length,
|
||||||
|
seq_length + past_key_values_length,
|
||||||
|
dtype=torch.long,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
|
||||||
|
else:
|
||||||
|
position_ids = position_ids.view(-1, seq_length).long()
|
||||||
|
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(position_ids)
|
||||||
|
cu_seqlens = cu_seqlens.squeeze()
|
||||||
|
|
||||||
|
if inputs_embeds is None:
|
||||||
|
inputs_embeds = self.embed_tokens(input_ids)
|
||||||
|
|
||||||
|
if attention_mask is not None and self._use_flash_attention_2 and use_cache:
|
||||||
|
is_padding_right = attention_mask[:, -1].sum().item() != batch_size
|
||||||
|
if is_padding_right:
|
||||||
|
raise ValueError(
|
||||||
|
"You are attempting to perform batched generation with padding_side='right'"
|
||||||
|
" this may lead to unexpected behaviour for Flash Attention version of Mixtral. Make sure to "
|
||||||
|
" call `tokenizer.padding_side = 'left'` before tokenizing the input. "
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._use_flash_attention_2:
|
||||||
|
# 2d mask is passed through the layers
|
||||||
|
attention_mask = (
|
||||||
|
attention_mask
|
||||||
|
if (attention_mask is not None and 0 in attention_mask)
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# 4d mask is passed through the layers
|
||||||
|
attention_mask = _prepare_4d_causal_attention_mask(
|
||||||
|
attention_mask,
|
||||||
|
(batch_size, seq_length),
|
||||||
|
inputs_embeds,
|
||||||
|
past_key_values_length,
|
||||||
|
sliding_window=self.config.sliding_window,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
if use_cache:
|
||||||
|
LOG.warning_once(
|
||||||
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
|
||||||
|
)
|
||||||
|
use_cache = False
|
||||||
|
|
||||||
|
# decoder layers
|
||||||
|
all_hidden_states = () if output_hidden_states else None
|
||||||
|
all_self_attns = () if output_attentions else None
|
||||||
|
all_router_logits = () if output_router_logits else None
|
||||||
|
next_decoder_cache = None
|
||||||
|
|
||||||
|
for decoder_layer in self.layers:
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
if self.gradient_checkpointing and self.training:
|
||||||
|
layer_outputs = self._gradient_checkpointing_func(
|
||||||
|
decoder_layer.__call__,
|
||||||
|
hidden_states,
|
||||||
|
attention_mask,
|
||||||
|
position_ids,
|
||||||
|
past_key_values,
|
||||||
|
output_attentions,
|
||||||
|
output_router_logits,
|
||||||
|
use_cache,
|
||||||
|
cu_seqlens,
|
||||||
|
max_seqlen,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
layer_outputs = decoder_layer(
|
||||||
|
hidden_states,
|
||||||
|
attention_mask=attention_mask,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_value=past_key_values,
|
||||||
|
output_attentions=output_attentions,
|
||||||
|
output_router_logits=output_router_logits,
|
||||||
|
use_cache=use_cache,
|
||||||
|
cu_seqlens=cu_seqlens,
|
||||||
|
max_seqlen=max_seqlen,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = layer_outputs[0]
|
||||||
|
|
||||||
|
if use_cache:
|
||||||
|
next_decoder_cache = layer_outputs[2 if output_attentions else 1]
|
||||||
|
|
||||||
|
if output_attentions:
|
||||||
|
all_self_attns += (layer_outputs[1],)
|
||||||
|
|
||||||
|
if output_router_logits:
|
||||||
|
all_router_logits += (layer_outputs[-1],)
|
||||||
|
|
||||||
|
hidden_states = self.norm(hidden_states)
|
||||||
|
|
||||||
|
# add hidden states from the last decoder layer
|
||||||
|
if output_hidden_states:
|
||||||
|
all_hidden_states += (hidden_states,)
|
||||||
|
|
||||||
|
next_cache = None
|
||||||
|
if use_cache:
|
||||||
|
next_cache = (
|
||||||
|
next_decoder_cache.to_legacy_cache()
|
||||||
|
if use_legacy_cache
|
||||||
|
else next_decoder_cache
|
||||||
|
)
|
||||||
|
|
||||||
|
if not return_dict:
|
||||||
|
return tuple(
|
||||||
|
v
|
||||||
|
for v in [
|
||||||
|
hidden_states,
|
||||||
|
next_cache,
|
||||||
|
all_hidden_states,
|
||||||
|
all_self_attns,
|
||||||
|
all_router_logits,
|
||||||
|
]
|
||||||
|
if v is not None
|
||||||
|
)
|
||||||
|
|
||||||
|
return MoeModelOutputWithPast(
|
||||||
|
last_hidden_state=hidden_states,
|
||||||
|
past_key_values=next_cache,
|
||||||
|
hidden_states=all_hidden_states,
|
||||||
|
attentions=all_self_attns,
|
||||||
|
router_logits=all_router_logits,
|
||||||
|
)
|
||||||
@@ -1,65 +0,0 @@
|
|||||||
"""
|
|
||||||
patches implemented through the trainer hooks to enable NEFT/noisy embeddings per https://arxiv.org/abs/2310.05914
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
from peft import PeftModel
|
|
||||||
from transformers import PreTrainedModel
|
|
||||||
|
|
||||||
|
|
||||||
def patch_neft(alpha, model):
|
|
||||||
embeddings = None
|
|
||||||
if isinstance(model, PreTrainedModel):
|
|
||||||
embeddings = model.get_input_embeddings()
|
|
||||||
if isinstance(model, PeftModel):
|
|
||||||
embeddings = model.base_model.get_input_embeddings()
|
|
||||||
if not embeddings:
|
|
||||||
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
|
|
||||||
embeddings.noisy_embedding_alpha = alpha
|
|
||||||
old_forward = embeddings.forward
|
|
||||||
|
|
||||||
# This hack seems to be needed to properly use a custom forward pass
|
|
||||||
# all credits to: https://discuss.pytorch.org/t/how-can-i-replace-the-forward-method-of-a-predefined-torchvision-model-with-my-customized-forward-function/54224/11
|
|
||||||
bound_method = neft_forward.__get__( # pylint: disable=no-value-for-parameter
|
|
||||||
embeddings, embeddings.__class__
|
|
||||||
)
|
|
||||||
setattr(embeddings, "forward", bound_method)
|
|
||||||
|
|
||||||
embeddings._old_forward = old_forward # pylint: disable=protected-access
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def unpatch_neft(model):
|
|
||||||
embeddings = None
|
|
||||||
if isinstance(model, PreTrainedModel):
|
|
||||||
embeddings = model.get_input_embeddings()
|
|
||||||
if isinstance(model, PeftModel):
|
|
||||||
embeddings = model.base_model.get_input_embeddings()
|
|
||||||
if not embeddings:
|
|
||||||
raise ValueError(f"unhandled model class for neft: {model.__class__.__name__}")
|
|
||||||
if hasattr(embeddings, "_old_forward"):
|
|
||||||
embeddings.forward = embeddings._old_forward # pylint: disable=protected-access
|
|
||||||
del embeddings._old_forward # pylint: disable=protected-access
|
|
||||||
del embeddings.noisy_embedding_alpha
|
|
||||||
|
|
||||||
|
|
||||||
def neft_forward(self, inputs: torch.Tensor):
|
|
||||||
embeddings = self._old_forward(inputs) # pylint: disable=protected-access
|
|
||||||
|
|
||||||
if self.training:
|
|
||||||
dims = torch.tensor(embeddings.size(1) * embeddings.size(2))
|
|
||||||
mag_norm = self.noisy_embedding_alpha / torch.sqrt(dims)
|
|
||||||
embeddings = embeddings + torch.zeros_like(embeddings).uniform_(
|
|
||||||
-mag_norm, mag_norm
|
|
||||||
)
|
|
||||||
|
|
||||||
return embeddings
|
|
||||||
|
|
||||||
|
|
||||||
def pretrain_hook(cfg, trainer):
|
|
||||||
if cfg.noisy_embedding_alpha:
|
|
||||||
trainer.model = patch_neft(cfg.noisy_embedding_alpha, trainer.model)
|
|
||||||
|
|
||||||
|
|
||||||
def post_train_hook(cfg, trainer):
|
|
||||||
if cfg.noisy_embedding_alpha:
|
|
||||||
unpatch_neft(trainer.model)
|
|
||||||
@@ -1,94 +0,0 @@
|
|||||||
# pylint: skip-file
|
|
||||||
"""
|
|
||||||
Copied from https://github.com/kaiokendev/cutoff-len-is-context-len/blob/main/util/xpos_rope_llama_monkey_patch.py
|
|
||||||
"""
|
|
||||||
import torch
|
|
||||||
import transformers
|
|
||||||
import transformers.models.llama.modeling_llama
|
|
||||||
from einops import rearrange
|
|
||||||
|
|
||||||
|
|
||||||
class XposRotaryEmbedding(torch.nn.Module):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
dim,
|
|
||||||
max_position_embeddings=2048,
|
|
||||||
base=10000,
|
|
||||||
device=None,
|
|
||||||
scale_base=2048,
|
|
||||||
use_xpos=True,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
self.max_seq_len_cached = max_position_embeddings
|
|
||||||
self.scale_base = scale_base
|
|
||||||
|
|
||||||
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
|
|
||||||
t = torch.arange(self.max_seq_len_cached, device=device).type_as(inv_freq)
|
|
||||||
freqs = torch.einsum("i , j -> i j", t, inv_freq)
|
|
||||||
freqs = torch.cat((freqs, freqs), dim=-1)
|
|
||||||
|
|
||||||
self.register_buffer("inv_freq", inv_freq, persistent=False)
|
|
||||||
self.register_buffer("freqs_cached", freqs, persistent=False)
|
|
||||||
|
|
||||||
if not use_xpos:
|
|
||||||
self.register_buffer("scale", None)
|
|
||||||
self.register_buffer("scale_cached", torch.ones(1))
|
|
||||||
return
|
|
||||||
|
|
||||||
scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim)
|
|
||||||
power = (t - (self.max_seq_len_cached // 2)) / self.scale_base
|
|
||||||
scale_cached = scale ** rearrange(power, "n -> n 1")
|
|
||||||
scale_cached = torch.cat((scale_cached, scale_cached), dim=-1)
|
|
||||||
|
|
||||||
self.register_buffer("scale", scale, persistent=False)
|
|
||||||
self.register_buffer("scale_cached", scale_cached, persistent=False)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
x,
|
|
||||||
seq_len,
|
|
||||||
):
|
|
||||||
if seq_len > self.max_seq_len_cached:
|
|
||||||
self.max_seq_len_cached = seq_len
|
|
||||||
t = torch.arange(self.max_seq_len_cached, device=x.device).type_as(
|
|
||||||
self.inv_freq
|
|
||||||
)
|
|
||||||
freqs = torch.einsum("i , j -> i j", t, self.inv_freq)
|
|
||||||
freqs = torch.cat((freqs, freqs), dim=-1).to(dtype=x.dtype)
|
|
||||||
|
|
||||||
self.register_buffer("freqs_cached", freqs)
|
|
||||||
|
|
||||||
if self.scale is None:
|
|
||||||
self.register_buffer(
|
|
||||||
"scale_cached", torch.ones(1, device=x.device).to(dtype=x.dtype)
|
|
||||||
)
|
|
||||||
|
|
||||||
return self.freqs_cached.to(dtype=x.dtype), self.scale_cached
|
|
||||||
|
|
||||||
power = (t - (seq_len // 2)) / self.scale_base
|
|
||||||
scale = self.scale ** rearrange(power, "n -> n 1")
|
|
||||||
scale = torch.cat((scale, scale), dim=-1).to(dtype=x.dtype)
|
|
||||||
self.register_buffer("scale_cached", scale)
|
|
||||||
|
|
||||||
return self.freqs_cached.to(dtype=x.dtype), self.scale_cached.to(dtype=x.dtype)
|
|
||||||
|
|
||||||
|
|
||||||
def rotate_half(x):
|
|
||||||
x1, x2 = x.chunk(2, dim=-1)
|
|
||||||
return torch.cat((-x2, x1), dim=-1)
|
|
||||||
|
|
||||||
|
|
||||||
def apply_rotary_pos_emb(q, k, freqs, scale=1, position_ids=None):
|
|
||||||
freqs = freqs[position_ids, :]
|
|
||||||
if scale.shape[-1] != 1:
|
|
||||||
scale = scale[position_ids, :]
|
|
||||||
|
|
||||||
q_embed = (q * freqs.cos() * scale) + (rotate_half(q) * freqs.sin() * scale)
|
|
||||||
k_embed = (k * freqs.cos() * 1 / scale) + (rotate_half(k) * freqs.sin() * 1 / scale)
|
|
||||||
|
|
||||||
return q_embed, k_embed
|
|
||||||
|
|
||||||
|
|
||||||
def replace_llama_rope_with_xpos_rope():
|
|
||||||
transformers.models.llama.modeling_llama.LlamaRotaryEmbedding = XposRotaryEmbedding
|
|
||||||
transformers.models.llama.modeling_llama.apply_rotary_pos_emb = apply_rotary_pos_emb
|
|
||||||
@@ -81,8 +81,9 @@ class LLama2ChatTokenizingStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self.sequence_len = 4096
|
self.tokenizer.add_special_tokens(
|
||||||
self.tokenizer.add_special_tokens({"pad_token": "<pad>"})
|
{"pad_token": getattr(self.tokenizer, "pad_token", "<pad>")}
|
||||||
|
)
|
||||||
# https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json
|
# https://huggingface.co/meta-llama/Llama-2-7b-chat-hf/blob/main/added_tokens.json
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
def tokenize_prompt(self, prompt):
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ register_conv_template(
|
|||||||
system_message="You are a helpful assistant.",
|
system_message="You are a helpful assistant.",
|
||||||
roles=["<|im_start|>user", "<|im_start|>assistant"],
|
roles=["<|im_start|>user", "<|im_start|>assistant"],
|
||||||
sep_style=SeparatorStyle.CHATML,
|
sep_style=SeparatorStyle.CHATML,
|
||||||
sep="<|im_end|>\n",
|
sep="<|im_end|>",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -39,6 +39,23 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
|||||||
return strategy
|
return strategy
|
||||||
|
|
||||||
|
|
||||||
|
def load_ultrachat(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||||
|
conversation = (
|
||||||
|
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
||||||
|
)
|
||||||
|
strategy = UltrachatShareGPTPromptTokenizingStrategy(
|
||||||
|
ShareGPTPrompterV2(
|
||||||
|
conversation=conversation,
|
||||||
|
),
|
||||||
|
tokenizer,
|
||||||
|
cfg.train_on_inputs,
|
||||||
|
cfg.sequence_len,
|
||||||
|
)
|
||||||
|
if ds_cfg and "strict" in ds_cfg:
|
||||||
|
strategy.strict = ds_cfg["strict"]
|
||||||
|
return strategy
|
||||||
|
|
||||||
|
|
||||||
def load_role(tokenizer, cfg):
|
def load_role(tokenizer, cfg):
|
||||||
return SimpleRoleShareGPTPromptTokenizingStrategy(
|
return SimpleRoleShareGPTPromptTokenizingStrategy(
|
||||||
ShareGPTPrompterV2(),
|
ShareGPTPrompterV2(),
|
||||||
@@ -109,3 +126,17 @@ class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
|||||||
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
|
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
|
||||||
]
|
]
|
||||||
return turns
|
return turns
|
||||||
|
|
||||||
|
|
||||||
|
class UltrachatShareGPTPromptTokenizingStrategy(SimpleShareGPTPromptTokenizingStrategy):
|
||||||
|
"""
|
||||||
|
sharegpt strategy that remaps ultrachat data to sharegpt format
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_conversation_thread(self, prompt):
|
||||||
|
conversations = prompt["messages"]
|
||||||
|
role_map = {"user": "human", "assistant": "gpt"}
|
||||||
|
turns = [
|
||||||
|
{"from": role_map[t["role"]], "value": t["content"]} for t in conversations
|
||||||
|
]
|
||||||
|
return turns
|
||||||
|
|||||||
@@ -33,8 +33,8 @@ class AlpacaPrompter(Prompter):
|
|||||||
Base class for alpaca prompters
|
Base class for alpaca prompters
|
||||||
"""
|
"""
|
||||||
|
|
||||||
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
|
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request."
|
||||||
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
|
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request."
|
||||||
system_format: str = "{system}"
|
system_format: str = "{system}"
|
||||||
turn_format: str
|
turn_format: str
|
||||||
turn_no_input_format: str
|
turn_no_input_format: str
|
||||||
|
|||||||
@@ -12,12 +12,13 @@ import transformers.modelcard
|
|||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
|
from pkg_resources import get_distribution # type: ignore
|
||||||
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
from axolotl.common.cli import TrainerCliArgs
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
from axolotl.logging_config import configure_logging
|
from axolotl.logging_config import configure_logging
|
||||||
from axolotl.monkeypatch import neft_embeddings
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.freeze import freeze_parameters_except
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
from axolotl.utils.trainer import setup_trainer
|
from axolotl.utils.trainer import setup_trainer
|
||||||
|
|
||||||
@@ -78,11 +79,15 @@ def train(
|
|||||||
)
|
)
|
||||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||||
|
|
||||||
|
if cfg.unfrozen_parameters:
|
||||||
|
freeze_parameters_except(model, cfg.unfrozen_parameters)
|
||||||
|
|
||||||
trainer = setup_trainer(
|
trainer = setup_trainer(
|
||||||
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
||||||
)
|
)
|
||||||
|
|
||||||
model.config.use_cache = False
|
if hasattr(model, "config"):
|
||||||
|
model.config.use_cache = False
|
||||||
|
|
||||||
# go ahead and presave, so we have the adapter config available to inspect
|
# go ahead and presave, so we have the adapter config available to inspect
|
||||||
if peft_config:
|
if peft_config:
|
||||||
@@ -92,7 +97,8 @@ def train(
|
|||||||
if not Path(cfg.output_dir).is_dir():
|
if not Path(cfg.output_dir).is_dir():
|
||||||
os.makedirs(cfg.output_dir, exist_ok=True)
|
os.makedirs(cfg.output_dir, exist_ok=True)
|
||||||
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
|
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
|
||||||
model.config.save_pretrained(str(Path(cfg.output_dir)))
|
if hasattr(model, "config"):
|
||||||
|
model.config.save_pretrained(str(Path(cfg.output_dir)))
|
||||||
|
|
||||||
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
@@ -110,6 +116,12 @@ def train(
|
|||||||
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
|
badge_markdown = """[<img src="https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/image/axolotl-badge-web.png" alt="Built with Axolotl" width="200" height="32"/>](https://github.com/OpenAccess-AI-Collective/axolotl)"""
|
||||||
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
|
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n{badge_markdown}"
|
||||||
|
|
||||||
|
if getattr(cfg, "axolotl_config_path"):
|
||||||
|
raw_axolotl_cfg = Path(cfg.axolotl_config_path)
|
||||||
|
version = get_distribution("axolotl").version
|
||||||
|
if raw_axolotl_cfg.is_file():
|
||||||
|
transformers.modelcard.AUTOGENERATED_TRAINER_COMMENT += f"\n<details><summary>See axolotl config</summary>\n\naxolotl version: `{version}`\n```yaml\n{raw_axolotl_cfg.read_text(encoding='utf-8')}\n```\n\n</details><br>\n"
|
||||||
|
|
||||||
LOG.info("Starting trainer...")
|
LOG.info("Starting trainer...")
|
||||||
if cfg.group_by_length:
|
if cfg.group_by_length:
|
||||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||||
@@ -174,21 +186,19 @@ def train(
|
|||||||
return model, tokenizer
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
def pretrain_hooks(cfg, trainer):
|
def pretrain_hooks(_cfg, _trainer):
|
||||||
"""
|
"""
|
||||||
Run hooks right before kicking off the training
|
Run hooks right before kicking off the training
|
||||||
:param cfg:
|
:param cfg:
|
||||||
:param trainer:
|
:param trainer:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
neft_embeddings.pretrain_hook(cfg, trainer)
|
|
||||||
|
|
||||||
|
|
||||||
def post_train_hooks(cfg, trainer):
|
def post_train_hooks(_cfg, _trainer):
|
||||||
"""
|
"""
|
||||||
Run hooks right after training completes
|
Run hooks right after training completes
|
||||||
:param cfg:
|
:param cfg:
|
||||||
:param trainer:
|
:param trainer:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
neft_embeddings.post_train_hook(cfg, trainer)
|
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from shutil import copyfile
|
||||||
|
from tempfile import NamedTemporaryFile
|
||||||
from typing import TYPE_CHECKING, Dict, List
|
from typing import TYPE_CHECKING, Dict, List
|
||||||
|
|
||||||
import evaluate
|
import evaluate
|
||||||
@@ -561,10 +563,15 @@ class SaveAxolotlConfigtoWandBCallback(TrainerCallback):
|
|||||||
):
|
):
|
||||||
if is_main_process():
|
if is_main_process():
|
||||||
try:
|
try:
|
||||||
artifact = wandb.Artifact(name="axolotl-config", type="config")
|
# sync config to top level in run, cannot delete file right away because wandb schedules it to be synced even w/policy = 'now', so let OS delete it later.
|
||||||
artifact.add_file(local_path=self.axolotl_config_path)
|
with NamedTemporaryFile(
|
||||||
wandb.run.log_artifact(artifact)
|
mode="w", delete=False, suffix=".yml", prefix="axolotl_config_"
|
||||||
LOG.info("Axolotl config has been saved to WandB as an artifact.")
|
) as temp_file:
|
||||||
|
copyfile(self.axolotl_config_path, temp_file.name)
|
||||||
|
wandb.save(temp_file.name)
|
||||||
|
LOG.info(
|
||||||
|
"The Axolotl config has been saved to the WandB run under files."
|
||||||
|
)
|
||||||
except (FileNotFoundError, ConnectionError) as err:
|
except (FileNotFoundError, ConnectionError) as err:
|
||||||
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
LOG.warning(f"Error while saving Axolotl config to WandB: {err}")
|
||||||
return control
|
return control
|
||||||
|
|||||||
29
src/axolotl/utils/chat_templates.py
Normal file
29
src/axolotl/utils/chat_templates.py
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
"""
|
||||||
|
This module provides functionality for selecting chat templates based on user choices.
|
||||||
|
These templates are used for formatting messages in a conversation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def chat_templates(user_choice: str):
|
||||||
|
"""
|
||||||
|
Finds the correct chat_template for the tokenizer_config.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_choice (str): The user's choice of template.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The chosen template string.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the user_choice is not found in the templates.
|
||||||
|
"""
|
||||||
|
|
||||||
|
templates = {
|
||||||
|
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
||||||
|
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||||
|
}
|
||||||
|
|
||||||
|
if user_choice in templates:
|
||||||
|
return templates[user_choice]
|
||||||
|
|
||||||
|
raise ValueError(f"Template '{user_choice}' not found.")
|
||||||
@@ -2,12 +2,16 @@
|
|||||||
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
DataCollator for axolotl to pad labels and position_ids for packed sequences
|
||||||
"""
|
"""
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Dict, Optional, Sequence, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import transformers
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
from transformers.utils import PaddingStrategy
|
from transformers.utils import PaddingStrategy
|
||||||
|
|
||||||
|
IGNORE_INDEX = -100
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DataCollatorForSeq2Seq:
|
class DataCollatorForSeq2Seq:
|
||||||
@@ -146,3 +150,31 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
|||||||
chunked_data[feature] = np.concatenate(arrays)
|
chunked_data[feature] = np.concatenate(arrays)
|
||||||
features = [chunked_data]
|
features = [chunked_data]
|
||||||
return super().__call__(features, return_tensors=return_tensors)
|
return super().__call__(features, return_tensors=return_tensors)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MambaDataCollator:
|
||||||
|
"""
|
||||||
|
Collator for State Space Models (Mamba)
|
||||||
|
"""
|
||||||
|
|
||||||
|
tokenizer: transformers.PreTrainedTokenizer
|
||||||
|
|
||||||
|
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
||||||
|
input_ids, labels = tuple(
|
||||||
|
[torch.LongTensor(instance[key]) for instance in instances]
|
||||||
|
for key in ("input_ids", "labels")
|
||||||
|
)
|
||||||
|
input_ids = torch.nn.utils.rnn.pad_sequence(
|
||||||
|
input_ids,
|
||||||
|
batch_first=True,
|
||||||
|
padding_value=self.tokenizer.pad_token_id,
|
||||||
|
)
|
||||||
|
labels = torch.nn.utils.rnn.pad_sequence(
|
||||||
|
labels, batch_first=True, padding_value=IGNORE_INDEX
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"input_ids": input_ids,
|
||||||
|
"labels": labels,
|
||||||
|
}
|
||||||
|
|||||||
@@ -77,6 +77,15 @@ def normalize_config(cfg):
|
|||||||
else:
|
else:
|
||||||
cfg.torch_dtype = torch.float32
|
cfg.torch_dtype = torch.float32
|
||||||
|
|
||||||
|
if cfg.saves_per_epoch:
|
||||||
|
save_steps = 1.0 / (cfg.saves_per_epoch * cfg.num_epochs)
|
||||||
|
if save_steps < 1.0: # prevent saves on every step
|
||||||
|
cfg.save_steps = save_steps
|
||||||
|
if cfg.evals_per_epoch:
|
||||||
|
eval_steps = 1.0 / (cfg.evals_per_epoch * cfg.num_epochs)
|
||||||
|
if eval_steps < 1.0: # prevent evals on every step
|
||||||
|
cfg.eval_steps = eval_steps
|
||||||
|
|
||||||
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
|
cfg.dataset_processes = cfg.dataset_processes or os.cpu_count()
|
||||||
|
|
||||||
if not cfg.base_model_config:
|
if not cfg.base_model_config:
|
||||||
@@ -352,6 +361,27 @@ def validate_config(cfg):
|
|||||||
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
|
cfg.datasets[idx].type = cfg.datasets[idx].type.replace(
|
||||||
"sharegpt_simple", "sharegpt"
|
"sharegpt_simple", "sharegpt"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.saves_per_epoch and cfg.save_steps:
|
||||||
|
raise ValueError(
|
||||||
|
"save_steps and saves_per_epoch are mutually exclusive and cannot be used together."
|
||||||
|
)
|
||||||
|
if cfg.saves_per_epoch and cfg.save_strategy and cfg.save_strategy != "steps":
|
||||||
|
raise ValueError(
|
||||||
|
"save_strategy must be empty or set to `steps` when used with saves_per_epoch."
|
||||||
|
)
|
||||||
|
if cfg.evals_per_epoch and cfg.eval_steps:
|
||||||
|
raise ValueError(
|
||||||
|
"eval_steps and evals_per_epoch are mutually exclusive and cannot be used together."
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
cfg.evals_per_epoch
|
||||||
|
and cfg.evaluation_strategy
|
||||||
|
and cfg.evaluation_strategy != "steps"
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch."
|
||||||
|
)
|
||||||
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
if cfg.save_strategy and cfg.save_steps and cfg.save_strategy != "steps":
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
"save_strategy and save_steps mismatch. Please set save_strategy to 'steps' or remove save_steps."
|
||||||
@@ -404,6 +434,34 @@ def validate_config(cfg):
|
|||||||
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
"wandb_run_id sets the ID of the run. If you would like to set the name, please use wandb_name instead."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.noisy_embedding_alpha is not None:
|
||||||
|
# Deprecated, use neftune_noise_alpha
|
||||||
|
LOG.warning("noisy_embedding_alpha is deprecated, use neftune_noise_alpha")
|
||||||
|
if cfg.neftune_noise_alpha is None:
|
||||||
|
cfg.neftune_noise_alpha = cfg.noisy_embedding_alpha
|
||||||
|
else:
|
||||||
|
# User is providing both; bail and have them sort out their settings
|
||||||
|
raise ValueError(
|
||||||
|
"noisy_embedding_alpha is deprecated, use neftune_noise_alpha; both are set, please remove the deprecated noisy_embedding_alpha setting"
|
||||||
|
)
|
||||||
|
|
||||||
|
if cfg.neftune_noise_alpha is not None and cfg.neftune_noise_alpha <= 0.0:
|
||||||
|
raise ValueError("neftune_noise_alpha must be > 0.0")
|
||||||
|
|
||||||
|
if (
|
||||||
|
cfg.adapter
|
||||||
|
and cfg.tokens
|
||||||
|
and (
|
||||||
|
not cfg.lora_modules_to_save
|
||||||
|
or not all(
|
||||||
|
x in cfg.lora_modules_to_save for x in ["embed_tokens", "lm_head"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"lora_modules_to_save not properly set yet adding new tokens. Please add `embed_tokens` and `lm_head` to `lora_modules_to_save`."
|
||||||
|
)
|
||||||
|
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
38
src/axolotl/utils/freeze.py
Normal file
38
src/axolotl/utils/freeze.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
"""
|
||||||
|
module to freeze/unfreeze parameters by name
|
||||||
|
"""
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
from axolotl.utils.distributed import is_main_process
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.utils.freeze")
|
||||||
|
|
||||||
|
|
||||||
|
def freeze_parameters_except(model, regex_patterns):
|
||||||
|
"""
|
||||||
|
Freezes all layers of the given model except for the layers that match given regex patterns.
|
||||||
|
Periods in the patterns are treated as literal periods, not as wildcard characters.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- model (nn.Module): The PyTorch model to be modified.
|
||||||
|
- regex_patterns (list of str): List of regex patterns to match layer names to keep unfrozen.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
None; the model is modified in place.
|
||||||
|
"""
|
||||||
|
# Escape periods and compile the regex patterns
|
||||||
|
compiled_patterns = [
|
||||||
|
re.compile(pattern.replace(".", "\\.")) for pattern in regex_patterns
|
||||||
|
]
|
||||||
|
|
||||||
|
# First, freeze all parameters in the model
|
||||||
|
for param in model.parameters():
|
||||||
|
param.requires_grad = False
|
||||||
|
|
||||||
|
# Unfreeze layers that match the regex patterns
|
||||||
|
for name, param in model.named_parameters():
|
||||||
|
if any(pattern.match(name) for pattern in compiled_patterns):
|
||||||
|
if is_main_process():
|
||||||
|
LOG.debug(f"unfreezing {name}")
|
||||||
|
param.requires_grad = True
|
||||||
@@ -4,6 +4,7 @@ import math
|
|||||||
import os
|
import os
|
||||||
from typing import Optional, Tuple # noqa: F401
|
from typing import Optional, Tuple # noqa: F401
|
||||||
|
|
||||||
|
import addict
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
@@ -20,9 +21,12 @@ from transformers import ( # noqa: F401
|
|||||||
PreTrainedModel,
|
PreTrainedModel,
|
||||||
PreTrainedTokenizerBase,
|
PreTrainedTokenizerBase,
|
||||||
)
|
)
|
||||||
|
from transformers.deepspeed import is_deepspeed_zero3_enabled
|
||||||
|
|
||||||
|
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
|
from axolotl.utils.chat_templates import chat_templates
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
@@ -52,9 +56,20 @@ def check_model_config(cfg: DictDefault, model_config: AutoConfig):
|
|||||||
def load_model_config(cfg):
|
def load_model_config(cfg):
|
||||||
model_config_name = cfg.base_model_config or cfg.base_model
|
model_config_name = cfg.base_model_config or cfg.base_model
|
||||||
trust_remote_code = cfg.trust_remote_code is True
|
trust_remote_code = cfg.trust_remote_code is True
|
||||||
model_config = AutoConfig.from_pretrained(
|
|
||||||
model_config_name, trust_remote_code=trust_remote_code
|
try:
|
||||||
)
|
model_config = AutoConfig.from_pretrained(
|
||||||
|
model_config_name, trust_remote_code=trust_remote_code
|
||||||
|
)
|
||||||
|
except ValueError as err:
|
||||||
|
if "mamba" in model_config_name:
|
||||||
|
return addict.Dict(
|
||||||
|
{
|
||||||
|
"model_type": "mamba",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
raise err
|
||||||
|
|
||||||
if cfg.model_config:
|
if cfg.model_config:
|
||||||
for key, val in cfg.model_config.items():
|
for key, val in cfg.model_config.items():
|
||||||
setattr(model_config, key, val)
|
setattr(model_config, key, val)
|
||||||
@@ -92,6 +107,7 @@ def load_tokenizer(cfg):
|
|||||||
"LlamaTokenizer",
|
"LlamaTokenizer",
|
||||||
"LlamaTokenizerFast",
|
"LlamaTokenizerFast",
|
||||||
"CodeLlamaTokenizer",
|
"CodeLlamaTokenizer",
|
||||||
|
"CodeLlamaTokenizerFast",
|
||||||
]
|
]
|
||||||
and hasattr(tokenizer, "pad_token")
|
and hasattr(tokenizer, "pad_token")
|
||||||
and not tokenizer.pad_token
|
and not tokenizer.pad_token
|
||||||
@@ -121,9 +137,43 @@ def load_tokenizer(cfg):
|
|||||||
|
|
||||||
if cfg.special_tokens:
|
if cfg.special_tokens:
|
||||||
for k, val in cfg.special_tokens.items():
|
for k, val in cfg.special_tokens.items():
|
||||||
|
# check if new special token is not already in tokenizer and
|
||||||
|
# is adapter training to make sure lora_modules_to_save is set
|
||||||
|
if (
|
||||||
|
(getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
|
||||||
|
and cfg.adapter
|
||||||
|
and (
|
||||||
|
not cfg.lora_modules_to_save
|
||||||
|
or not all(
|
||||||
|
x in cfg.lora_modules_to_save
|
||||||
|
for x in ["embed_tokens", "lm_head"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Please set lora_modules_to_save to ['embed_tokens', 'lm_head'] when using an adapter and changing the special tokens."
|
||||||
|
)
|
||||||
|
|
||||||
tokenizer.add_special_tokens(
|
tokenizer.add_special_tokens(
|
||||||
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
|
{k: AddedToken(val, rstrip=False, lstrip=False, normalized=False)}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If we add bos_token and eos_token, we need to update the post processor to
|
||||||
|
# handle them correctly.
|
||||||
|
# https://github.com/huggingface/transformers/pull/24132
|
||||||
|
bos_or_eos_in_special_tokens = (
|
||||||
|
"bos_token" in cfg.special_tokens and "eos_token" in cfg.special_tokens
|
||||||
|
)
|
||||||
|
if (
|
||||||
|
tokenizer.__class__.__name__
|
||||||
|
in (
|
||||||
|
"LlamaTokenizerFast",
|
||||||
|
"CodeLlamaTokenizerFast",
|
||||||
|
)
|
||||||
|
and bos_or_eos_in_special_tokens
|
||||||
|
):
|
||||||
|
tokenizer.update_post_processor()
|
||||||
|
|
||||||
if cfg.tokens:
|
if cfg.tokens:
|
||||||
tokenizer.add_tokens(
|
tokenizer.add_tokens(
|
||||||
[
|
[
|
||||||
@@ -137,6 +187,12 @@ def load_tokenizer(cfg):
|
|||||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||||
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
LOG.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}")
|
||||||
|
|
||||||
|
if cfg.chat_template:
|
||||||
|
tokenizer.chat_template = chat_templates(cfg.chat_template)
|
||||||
|
else:
|
||||||
|
LOG.info(
|
||||||
|
"No Chat template selected. Consider adding a chat template for easier inference."
|
||||||
|
)
|
||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
@@ -198,17 +254,6 @@ def load_model(
|
|||||||
|
|
||||||
LOG.info("patching with sdp attention")
|
LOG.info("patching with sdp attention")
|
||||||
hijack_llama_sdp_attention()
|
hijack_llama_sdp_attention()
|
||||||
elif cfg.is_llama_derived_model and cfg.landmark_attention:
|
|
||||||
from axolotl.monkeypatch.llama_landmark_attn import (
|
|
||||||
MEM_TOKEN,
|
|
||||||
patch_llama_with_landmark_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
LOG.info("patching with landmark attention")
|
|
||||||
patch_llama_with_landmark_attn()
|
|
||||||
|
|
||||||
# Note: This might overwrite previous additional_special_tokens
|
|
||||||
tokenizer.add_special_tokens({"additional_special_tokens": [MEM_TOKEN]})
|
|
||||||
|
|
||||||
if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
|
if cfg.is_mistral_derived_model and cfg.flash_attention and cfg.sample_packing:
|
||||||
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
from axolotl.monkeypatch.mistral_attn_hijack_flash import (
|
||||||
@@ -218,13 +263,17 @@ def load_model(
|
|||||||
LOG.info("patching with flash attention")
|
LOG.info("patching with flash attention")
|
||||||
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||||
|
|
||||||
if cfg.is_llama_derived_model and cfg.xpos_rope:
|
if (
|
||||||
from axolotl.monkeypatch.xpos_rope_llama_monkey_patch import (
|
cfg.model_config_type == "mixtral"
|
||||||
replace_llama_rope_with_xpos_rope,
|
and cfg.flash_attention
|
||||||
|
and cfg.sample_packing
|
||||||
|
):
|
||||||
|
from axolotl.monkeypatch.mixtral import (
|
||||||
|
replace_mixtral_attn_with_multipack_flash_attn,
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG.info("patching with xpos rope")
|
LOG.info("patching with flash attention")
|
||||||
replace_llama_rope_with_xpos_rope()
|
replace_mixtral_attn_with_multipack_flash_attn()
|
||||||
|
|
||||||
if (
|
if (
|
||||||
cfg.is_llama_derived_model
|
cfg.is_llama_derived_model
|
||||||
@@ -242,6 +291,9 @@ def load_model(
|
|||||||
model_kwargs["max_memory"] = cfg.max_memory
|
model_kwargs["max_memory"] = cfg.max_memory
|
||||||
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
||||||
|
|
||||||
|
if is_deepspeed_zero3_enabled():
|
||||||
|
del model_kwargs["device_map"]
|
||||||
|
|
||||||
if cfg.model_revision:
|
if cfg.model_revision:
|
||||||
model_kwargs["revision"] = cfg.model_revision
|
model_kwargs["revision"] = cfg.model_revision
|
||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
@@ -256,22 +308,42 @@ def load_model(
|
|||||||
**model_config.quantization_config
|
**model_config.quantization_config
|
||||||
)
|
)
|
||||||
if cfg.adapter == "qlora" and cfg.load_in_4bit:
|
if cfg.adapter == "qlora" and cfg.load_in_4bit:
|
||||||
|
bnb_config = {
|
||||||
|
"load_in_4bit": True,
|
||||||
|
"llm_int8_threshold": 6.0,
|
||||||
|
"llm_int8_has_fp16_weight": False,
|
||||||
|
"bnb_4bit_compute_dtype": cfg.torch_dtype,
|
||||||
|
"bnb_4bit_use_double_quant": True,
|
||||||
|
"bnb_4bit_quant_type": "nf4",
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.bnb_config_kwargs:
|
||||||
|
bnb_config.update(cfg.bnb_config_kwargs)
|
||||||
|
|
||||||
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
model_kwargs["quantization_config"] = BitsAndBytesConfig(
|
||||||
load_in_4bit=True,
|
**bnb_config,
|
||||||
llm_int8_threshold=6.0,
|
|
||||||
llm_int8_has_fp16_weight=False,
|
|
||||||
bnb_4bit_compute_dtype=cfg.torch_dtype,
|
|
||||||
bnb_4bit_use_double_quant=True,
|
|
||||||
bnb_4bit_quant_type="nf4",
|
|
||||||
)
|
)
|
||||||
# sample packing uses custom FA2 patch
|
# sample packing uses custom FA2 patch
|
||||||
if cfg.flash_attention and not cfg.sample_packing:
|
if cfg.flash_attention:
|
||||||
if (
|
if not cfg.sample_packing:
|
||||||
cfg.is_llama_derived_model
|
if (
|
||||||
or cfg.is_falcon_derived_model
|
cfg.is_llama_derived_model
|
||||||
or cfg.is_mistral_derived_model
|
or cfg.is_falcon_derived_model
|
||||||
):
|
or cfg.is_mistral_derived_model
|
||||||
model_kwargs["use_flash_attention_2"] = True
|
or model_config.model_type == "mixtral"
|
||||||
|
):
|
||||||
|
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
|
"flash_attention_2"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
if model_config.model_type == "mixtral":
|
||||||
|
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
|
"flash_attention_2"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||||
|
"eager"
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
if cfg.is_llama_derived_model and not cfg.trust_remote_code and not cfg.gptq:
|
||||||
@@ -333,6 +405,20 @@ def load_model(
|
|||||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
elif model_type == "MambaLMHeadModel":
|
||||||
|
# FIXME this is janky at best and hacked together to make it work
|
||||||
|
MambaLMHeadModel = fix_mamba_attn_for_loss() # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
model_kwargs["dtype"] = model_kwargs["torch_dtype"]
|
||||||
|
model_kwargs["device"] = torch.cuda.current_device()
|
||||||
|
del model_kwargs["torch_dtype"]
|
||||||
|
del model_kwargs["device_map"]
|
||||||
|
del model_kwargs["max_memory"]
|
||||||
|
|
||||||
|
model = MambaLMHeadModel.from_pretrained(
|
||||||
|
base_model,
|
||||||
|
**model_kwargs,
|
||||||
|
)
|
||||||
elif model_type and not cfg.trust_remote_code:
|
elif model_type and not cfg.trust_remote_code:
|
||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
model = AutoModelForCausalLM.from_pretrained(
|
model = AutoModelForCausalLM.from_pretrained(
|
||||||
@@ -392,13 +478,17 @@ def load_model(
|
|||||||
if cfg.resize_token_embeddings_to_32x
|
if cfg.resize_token_embeddings_to_32x
|
||||||
else len(tokenizer)
|
else len(tokenizer)
|
||||||
)
|
)
|
||||||
if model.get_input_embeddings().num_embeddings < embeddings_len:
|
if (
|
||||||
|
hasattr(model, "get_input_embeddings")
|
||||||
|
and model.get_input_embeddings().num_embeddings < embeddings_len
|
||||||
|
):
|
||||||
model.resize_token_embeddings(embeddings_len)
|
model.resize_token_embeddings(embeddings_len)
|
||||||
else:
|
else:
|
||||||
model.tie_weights()
|
model.tie_weights()
|
||||||
|
|
||||||
if (
|
if (
|
||||||
hasattr(model.config, "max_position_embeddings")
|
hasattr(model, "config")
|
||||||
|
and hasattr(model.config, "max_position_embeddings")
|
||||||
and model.config.max_position_embeddings
|
and model.config.max_position_embeddings
|
||||||
and cfg.sequence_len > model.config.max_position_embeddings
|
and cfg.sequence_len > model.config.max_position_embeddings
|
||||||
):
|
):
|
||||||
@@ -408,20 +498,22 @@ def load_model(
|
|||||||
model.config.max_position_embeddings = cfg.sequence_len
|
model.config.max_position_embeddings = cfg.sequence_len
|
||||||
|
|
||||||
if (
|
if (
|
||||||
hasattr(model.config, "bos_token_id")
|
hasattr(model, "config")
|
||||||
|
and hasattr(model.config, "bos_token_id")
|
||||||
and model.config.bos_token_id
|
and model.config.bos_token_id
|
||||||
and model.config.bos_token_id != tokenizer.bos_token_id
|
and model.config.bos_token_id != tokenizer.bos_token_id
|
||||||
):
|
):
|
||||||
model.config.bos_token_id = tokenizer.bos_token_id
|
model.config.bos_token_id = tokenizer.bos_token_id
|
||||||
|
|
||||||
if (
|
if (
|
||||||
hasattr(model.config, "eos_token_id")
|
hasattr(model, "config")
|
||||||
|
and hasattr(model.config, "eos_token_id")
|
||||||
and model.config.eos_token_id
|
and model.config.eos_token_id
|
||||||
and model.config.eos_token_id != tokenizer.eos_token_id
|
and model.config.eos_token_id != tokenizer.eos_token_id
|
||||||
):
|
):
|
||||||
model.config.eos_token_id = tokenizer.eos_token_id
|
model.config.eos_token_id = tokenizer.eos_token_id
|
||||||
|
|
||||||
if model.device.type == "cuda":
|
if hasattr(model, "device") and model.device.type == "cuda":
|
||||||
log_gpu_memory_usage(LOG, "after model load", model.device)
|
log_gpu_memory_usage(LOG, "after model load", model.device)
|
||||||
|
|
||||||
# make sure these are fp32 per Ramesh et al. (2021)
|
# make sure these are fp32 per Ramesh et al. (2021)
|
||||||
@@ -480,7 +572,8 @@ def load_model(
|
|||||||
requires_grad.append(f"{name}: {param.requires_grad}")
|
requires_grad.append(f"{name}: {param.requires_grad}")
|
||||||
if len(requires_grad) == 0:
|
if len(requires_grad) == 0:
|
||||||
LOG.warning("there are no parameters that require gradient updates")
|
LOG.warning("there are no parameters that require gradient updates")
|
||||||
model.config.use_cache = False
|
if hasattr(model, "config"):
|
||||||
|
model.config.use_cache = False
|
||||||
|
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
model = BetterTransformer.transform(model)
|
model = BetterTransformer.transform(model)
|
||||||
|
|||||||
@@ -131,8 +131,10 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset, tokenizer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Phi doesn't want the attention_mask feature when training
|
# Phi doesn't want the attention_mask feature when training
|
||||||
if "CodeGenTokenizer" in tokenizer.__class__.__name__ or (
|
if (
|
||||||
cfg.is_mistral_derived_model and cfg.flash_attention
|
"CodeGenTokenizer" in tokenizer.__class__.__name__
|
||||||
|
or (cfg.is_mistral_derived_model and cfg.flash_attention)
|
||||||
|
or cfg.model_config_type == "mamba"
|
||||||
):
|
):
|
||||||
train_dataset = train_dataset.remove_columns("attention_mask")
|
train_dataset = train_dataset.remove_columns("attention_mask")
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
@@ -153,7 +155,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
if update:
|
if update:
|
||||||
cfg.total_num_tokens = total_num_tokens
|
cfg.total_num_tokens = total_num_tokens
|
||||||
|
|
||||||
if not cfg.total_supervised_tokens:
|
skip_estimates = cfg.model_config_type == "mamba"
|
||||||
|
|
||||||
|
if not skip_estimates and not cfg.total_supervised_tokens:
|
||||||
total_supervised_tokens = (
|
total_supervised_tokens = (
|
||||||
train_dataset.data.column("labels")
|
train_dataset.data.column("labels")
|
||||||
.to_pandas()
|
.to_pandas()
|
||||||
@@ -167,7 +171,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
if update:
|
if update:
|
||||||
cfg.total_supervised_tokens = total_supervised_tokens
|
cfg.total_supervised_tokens = total_supervised_tokens
|
||||||
|
|
||||||
if cfg.sample_packing:
|
if not skip_estimates and cfg.sample_packing:
|
||||||
# we have to drop anything longer then sequence len otherwise
|
# we have to drop anything longer then sequence len otherwise
|
||||||
# flash attention with position ids fails
|
# flash attention with position ids fails
|
||||||
|
|
||||||
@@ -272,6 +276,7 @@ def prepare_optim_env(cfg):
|
|||||||
setup_fsdp_envs(cfg)
|
setup_fsdp_envs(cfg)
|
||||||
elif cfg.deepspeed:
|
elif cfg.deepspeed:
|
||||||
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
os.environ["ACCELERATE_USE_DEEPSPEED"] = "true"
|
||||||
|
os.environ["ACCELERATE_DEEPSPEED_CONFIG_FILE"] = cfg.deepspeed
|
||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||||
|
|||||||
65
tests/e2e/test_mamba.py
Normal file
65
tests/e2e/test_mamba.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
"""
|
||||||
|
E2E tests for lora llama
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from axolotl.cli import load_datasets
|
||||||
|
from axolotl.common.cli import TrainerCliArgs
|
||||||
|
from axolotl.train import train
|
||||||
|
from axolotl.utils.config import normalize_config
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
from .utils import with_temp_dir
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||||
|
os.environ["WANDB_DISABLED"] = "true"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMistral(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
Test case for Llama models using LoRA
|
||||||
|
"""
|
||||||
|
|
||||||
|
@with_temp_dir
|
||||||
|
def test_fft(self, temp_dir):
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"base_model": "state-spaces/mamba-130m",
|
||||||
|
"model_type": "MambaLMHeadModel",
|
||||||
|
"tokenizer_type": "AutoTokenizer",
|
||||||
|
"tokenizer_config": "EleutherAI/gpt-neox-20b",
|
||||||
|
"flash_attention": False,
|
||||||
|
"sequence_len": 1024,
|
||||||
|
"load_in_8bit": False,
|
||||||
|
"val_set_size": 0.0,
|
||||||
|
"datasets": [
|
||||||
|
{
|
||||||
|
"path": "mhenrichsen/alpaca_2k_test",
|
||||||
|
"type": "alpaca",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"gradient_checkpointing": False,
|
||||||
|
"num_epochs": 2,
|
||||||
|
"micro_batch_size": 2,
|
||||||
|
"gradient_accumulation_steps": 1,
|
||||||
|
"output_dir": temp_dir,
|
||||||
|
"learning_rate": 0.00001,
|
||||||
|
"optimizer": "adamw_torch",
|
||||||
|
"lr_scheduler": "cosine",
|
||||||
|
"max_steps": 20,
|
||||||
|
"save_steps": 10,
|
||||||
|
"eval_steps": None,
|
||||||
|
"save_safetensors": False,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
normalize_config(cfg)
|
||||||
|
cli_args = TrainerCliArgs()
|
||||||
|
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||||
|
|
||||||
|
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||||
|
assert (Path(temp_dir) / "pytorch_model.bin").exists()
|
||||||
File diff suppressed because one or more lines are too long
@@ -2,6 +2,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import unittest
|
import unittest
|
||||||
|
from copy import deepcopy
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
@@ -25,6 +26,50 @@ from axolotl.prompters import AlpacaPrompter, PromptStyle, ShareGPTPrompterV2
|
|||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
test_data = {
|
||||||
|
"multi_turn_sys": {
|
||||||
|
"conversations": [
|
||||||
|
{"from": "system", "value": "lorem"},
|
||||||
|
{"from": "human", "value": "abc"},
|
||||||
|
{"from": "gpt", "value": "ipsum"},
|
||||||
|
{"from": "human", "value": "123"},
|
||||||
|
{"from": "gpt", "value": "sit"},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"single_turn_sys": {
|
||||||
|
"conversations": [
|
||||||
|
{"from": "system", "value": "lorem"},
|
||||||
|
{"from": "human", "value": "abc"},
|
||||||
|
{"from": "gpt", "value": "ipsum"},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"single_turn_no_sys": {
|
||||||
|
"conversations": [
|
||||||
|
{"from": "human", "value": "abc"},
|
||||||
|
{"from": "gpt", "value": "ipsum"},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
"multi_turn_no_sys": {
|
||||||
|
"conversations": [
|
||||||
|
{"from": "human", "value": "abc"},
|
||||||
|
{"from": "gpt", "value": "ipsum"},
|
||||||
|
{"from": "human", "value": "123"},
|
||||||
|
{"from": "gpt", "value": "sit"},
|
||||||
|
]
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def prompt_strat(conversation, tokenizer):
|
||||||
|
"Helper function to create a prompt strategy for testing."
|
||||||
|
prompter = ShareGPTPrompterV2(conversation=conversation)
|
||||||
|
return ShareGPTPromptTokenizingStrategy(
|
||||||
|
prompter,
|
||||||
|
tokenizer,
|
||||||
|
False,
|
||||||
|
2048,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestPromptTokenizationStrategies(unittest.TestCase):
|
class TestPromptTokenizationStrategies(unittest.TestCase):
|
||||||
"""
|
"""
|
||||||
@@ -114,6 +159,70 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|||||||
in self._caplog.records[0].message
|
in self._caplog.records[0].message
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_sharegpt_llama(self):
|
||||||
|
"Make sure the sharegpt/llama is tokenized and formatted correctly."
|
||||||
|
strat = prompt_strat("llama-2", self.tokenizer)
|
||||||
|
|
||||||
|
def tokenize(conv):
|
||||||
|
return strat.tokenize_prompt(deepcopy(conv))["input_ids"]
|
||||||
|
|
||||||
|
def decode(ids):
|
||||||
|
return strat.tokenizer.decode(ids)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
# System message, multi-turn conversations
|
||||||
|
mt_ids = tokenize(test_data['multi_turn_sys'])
|
||||||
|
assert decode(mt_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
|
||||||
|
assert mt_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||||
|
|
||||||
|
# System message, single-turn conversations
|
||||||
|
st_ids = tokenize(test_data['single_turn_sys'])
|
||||||
|
assert decode(st_ids) == '<s> [INST] <<SYS>>\nlorem\n<</SYS>>\n\nabc [/INST] ipsum</s>'
|
||||||
|
assert st_ids == [1, 518, 25580, 29962, 3532, 14816, 29903, 6778, 13, 29880, 3668, 13, 29966, 829, 14816, 29903, 6778, 13, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
|
||||||
|
|
||||||
|
# No system message, single-turn
|
||||||
|
ns_ids = tokenize(test_data['single_turn_no_sys'])
|
||||||
|
assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
|
||||||
|
assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
|
||||||
|
|
||||||
|
# No system message, multi-turn
|
||||||
|
ns_mt_ids = tokenize(test_data['multi_turn_no_sys'])
|
||||||
|
assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s><s> [INST] 123 [/INST] sit</s>'
|
||||||
|
assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 1, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
def test_sharegpt_mistral(self):
|
||||||
|
"Make sure the sharegpt/mistral is tokenized and formatted correctly."
|
||||||
|
strat = prompt_strat("mistral", self.tokenizer)
|
||||||
|
|
||||||
|
def tokenize(conv):
|
||||||
|
return strat.tokenize_prompt(deepcopy(conv))["input_ids"]
|
||||||
|
|
||||||
|
def decode(ids):
|
||||||
|
return strat.tokenizer.decode(ids)
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
# System message, multi-turn conversations
|
||||||
|
mt_ids = tokenize(test_data['multi_turn_sys'])
|
||||||
|
assert decode(mt_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
|
||||||
|
assert mt_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||||
|
|
||||||
|
# System message, single-turn conversations
|
||||||
|
st_ids = tokenize(test_data['single_turn_sys'])
|
||||||
|
assert decode(st_ids) == '<s> [INST] lorem\nabc [/INST] ipsum</s>'
|
||||||
|
assert st_ids == [1, 518, 25580, 29962, 301, 3668, 13, 10736, 518, 29914, 25580, 29962, 23421, 2]
|
||||||
|
|
||||||
|
# No system message, single-turn
|
||||||
|
ns_ids = tokenize(test_data['single_turn_no_sys'])
|
||||||
|
assert decode(ns_ids) == '<s> [INST] abc [/INST] ipsum</s>'
|
||||||
|
assert ns_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2]
|
||||||
|
|
||||||
|
# No system message, multi-turn
|
||||||
|
ns_mt_ids = tokenize(test_data['multi_turn_no_sys'])
|
||||||
|
assert decode(ns_mt_ids) == '<s> [INST] abc [/INST] ipsum</s> [INST] 123 [/INST] sit</s>'
|
||||||
|
assert ns_mt_ids == [1, 518, 25580, 29962, 25638, 518, 29914, 25580, 29962, 23421, 2, 518, 25580, 29962, 29871, 29896, 29906, 29941, 518, 29914, 25580, 29962, 7845, 2]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
def test_sharegpt_changes_roles(self):
|
def test_sharegpt_changes_roles(self):
|
||||||
conversation = {
|
conversation = {
|
||||||
"roles": ["USER", "CHARACTER"],
|
"roles": ["USER", "CHARACTER"],
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ Test cases for the tokenizer loading
|
|||||||
"""
|
"""
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_tokenizer
|
from axolotl.utils.models import load_tokenizer
|
||||||
|
|
||||||
@@ -31,6 +33,40 @@ class TestTokenizers(unittest.TestCase):
|
|||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
assert "Fast" not in tokenizer.__class__.__name__
|
assert "Fast" not in tokenizer.__class__.__name__
|
||||||
|
|
||||||
|
def test_special_tokens_modules_to_save(self):
|
||||||
|
# setting special_tokens to new token
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
|
"adapter": "lora",
|
||||||
|
"special_tokens": {"bos_token": "[INST]"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r".*Please set lora_modules_to_save*",
|
||||||
|
):
|
||||||
|
load_tokenizer(cfg)
|
||||||
|
|
||||||
|
# setting special_tokens but not changing from default
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
|
"adapter": "lora",
|
||||||
|
"special_tokens": {"bos_token": "<s>"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
load_tokenizer(cfg)
|
||||||
|
|
||||||
|
# non-adapter setting special_tokens
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"tokenizer_config": "huggyllama/llama-7b",
|
||||||
|
"special_tokens": {"bos_token": "[INST]"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
load_tokenizer(cfg)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -682,6 +682,43 @@ class ValidationTest(unittest.TestCase):
|
|||||||
|
|
||||||
validate_config(cfg)
|
validate_config(cfg)
|
||||||
|
|
||||||
|
def test_add_tokens_adapter(self):
|
||||||
|
cfg = DictDefault(
|
||||||
|
{"adapter": "qlora", "load_in_4bit": True, "tokens": ["<|imstart|>"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r".*lora_modules_to_save not properly set yet adding new tokens*",
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"adapter": "qlora",
|
||||||
|
"load_in_4bit": True,
|
||||||
|
"tokens": ["<|imstart|>"],
|
||||||
|
"lora_modules_to_save": ["embed_tokens"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
ValueError,
|
||||||
|
match=r".*lora_modules_to_save not properly set yet adding new tokens*",
|
||||||
|
):
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"adapter": "qlora",
|
||||||
|
"load_in_4bit": True,
|
||||||
|
"tokens": ["<|imstart|>"],
|
||||||
|
"lora_modules_to_save": ["embed_tokens", "lm_head"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
validate_config(cfg)
|
||||||
|
|
||||||
|
|
||||||
class ValidationWandbTest(ValidationTest):
|
class ValidationWandbTest(ValidationTest):
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user