Compare commits

..

19 Commits

Author SHA1 Message Date
Dan Saunders
e910e3e164 Revert "Multipack parallel bin packing (#2631)"
This reverts commit 8e4158cc0b.
2025-05-09 17:33:31 +00:00
Wing Lian
0f3587174d swap tinymodels that have safetensors for some ci tests (#2641) 2025-05-07 15:06:07 -04:00
xzuyn
25e6c5f9bd Add CAME Optimizer (#2385) 2025-05-07 10:31:46 -04:00
NanoCode012
32f51bca35 fix(doc): clarify instruction to delinearize llama4 similar to cli doc (#2644) [skip ci] 2025-05-07 10:29:47 -04:00
NanoCode012
9daa04da90 Fix: improve error message on failed dataset load (#2637) [skip ci]
* fix(log): clarify error on dataset loading failed

* fix: add path for easy tracking of broken config

* fix: improve error message based on pr feedback
2025-05-07 10:29:05 -04:00
Wing Lian
0d71b0aa5f Configurable embeddings upcast (#2621)
* fsdp embeddings should be float32 per comment

* patch peft to not upcast everything

* add tabs back to code check

* fix import

* add configurable option and fix check

* add check for dtypes

* move embeddings test to patch dir

* fix test

* fix comment and logic
2025-05-06 23:40:44 -04:00
Eric Meier
63aaccf85b Fix cut_cross_entropy plugin install (#2642) [skip ci] 2025-05-06 22:56:00 -04:00
Wing Lian
ff0fe767c8 xformers attention with packing (#2619)
* xformers attention with packing

* wire up the patch

* fix xformers + packing validation

* fix warning

* reorder the packing check

* fix fp16 / bf16 reset when using fp16 with bf16 auto

* fix seq lens calc to drop hanging sequences

* handle xformers patch for inference too

* fix batch size setter

* fix xformers inference

* add colab callback to fix inference post train

* PR feedback
2025-05-06 22:49:22 -04:00
Wing Lian
8e4158cc0b Multipack parallel bin packing (#2631)
* improve readability of multipack sampler

* parallel bin packing
fix error with lambda and pickling

make sure things are in float instead of np.float

* annotations and comments update

* support for configurable group and bin size for sample packing

* fix missing map back to original indices
2025-05-06 20:08:08 -04:00
Wing Lian
cd84325253 allow plugins to return their own dataset (#2617) [skip ci]
* allow plugins to return their own dataset

* add post_trainer_create and wire up

* add hook check

* address PR feedback:

* remove annotation causing circular import
2025-05-06 20:05:51 -04:00
NanoCode012
0b140fef83 feat(doc): add split_thinking docs (#2613) [skip ci]
* feat(doc): add split_thinking docs

* fix: link config.qmd to conversation.qmd for split_thinking example

* update thinking => reasoning_content in messages format

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-06 20:05:32 -04:00
Wing Lian
e4cfebe995 bump liger dep to 0.5.9 (#2640) [skip ci]
* bump liger dep to 0.5.9

* also upgrade vllm to post1, and datasets to 3.5.1
2025-05-06 20:05:19 -04:00
mhenrichsen
a6cac5dd32 Update lr_scheduler options in config.qmd to include additional scheduling strategies for improved training flexibility. (#2636) [skip ci] 2025-05-06 11:24:07 -04:00
Wing Lian
b71c0e3447 Print axolotl art if train is called outside of cli: (#2627) [skip ci] 2025-05-06 11:18:45 -04:00
Wing Lian
ddaebf8309 fix dpo eval override to call grandparent instead of the broken super (#2628) [skip ci] 2025-05-06 11:18:25 -04:00
Wing Lian
679743087a make sure gc_steps is used for all trainers (#2638) 2025-05-06 11:18:00 -04:00
Wing Lian
f720b6e72d repop cache (#2639)
* repop cache

* pre-cache as a step

* fix the name

* add reason for pytest skipif

* restore pytorch matrix

* remove max-parallel now that we've optimized this a bit
2025-05-06 11:09:07 -04:00
mhenrichsen
a980618fd0 Adds example for training a TTS model on top of a LLM. (#2614)
* Adds example for training a TTS model on top of a LLM.

* Update examples/orpheus/finetune.yml

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* Update examples/orpheus/finetune.yml

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* Update README.md to clarify GPU requirements for finetuning Orpheus TTS model

* Update finetune.yml to use the new base model canopylabs/orpheus-3b-0.1-pretrained

* Update finetune.yml and README.md for consistency and clarity

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-05-06 10:11:06 +02:00
Emmanuel Ferdman
54960d4de0 Fix logging deprecation warnings (#2623)
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2025-05-04 08:22:45 -04:00
43 changed files with 1138 additions and 619 deletions

View File

@@ -18,9 +18,96 @@ jobs:
env:
SKIP: no-commit-to-branch
preload-cache:
name: Preload HF cache
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0"]
timeout-minutes: 20
env:
AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
steps:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
- name: Install dependencies
run: |
pip3 show torch
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
- name: Ensure axolotl CLI was installed
run: |
axolotl --help
- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Run tests
run: |
pytest -v tests/conftest.py
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.xml
flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest:
name: PyTest
runs-on: ubuntu-latest
needs: [preload-cache]
strategy:
fail-fast: false
max-parallel: 2

View File

@@ -44,12 +44,98 @@ jobs:
env:
SKIP: no-commit-to-branch
pytest:
name: PyTest
preload-cache:
name: Preload HF cache
runs-on: ubuntu-latest
strategy:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0"]
timeout-minutes: 20
env:
AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
steps:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
- name: Install dependencies
run: |
pip3 show torch
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
- name: Ensure axolotl CLI was installed
run: |
axolotl --help
- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Run tests
run: |
pytest -v tests/conftest.py
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.xml
flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest:
name: PyTest
runs-on: ubuntu-latest
needs: [preload-cache]
strategy:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
@@ -121,21 +207,12 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
needs: [preload-cache]
strategy:
fail-fast: false
max-parallel: 1
matrix:
python_version: ["3.11"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
@@ -199,15 +276,6 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
docker-e2e-tests-1st:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...

View File

@@ -32,6 +32,8 @@ tokenizer_legacy:
resize_token_embeddings_to_32x:
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
shrink_embeddings:
# Optional[bool] Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs
embeddings_skip_upcast:
# Whether to load the model with randomly initialized weights. Useful for
# pre-training a model from scratch or debugging purposes.
random_init_weights:
@@ -73,11 +75,12 @@ load_in_8bit: true
load_in_4bit:
# Use CUDA bf16
bf16: true # bool or 'full' for `bf16_full_eval`. require >=ampere
bf16: true # bool or 'full' for `bf16_full_eval`, or 'auto' for automatic detection. require >=ampere
# Use CUDA fp16
fp16: true
# Use CUDA tf32
tf32: true # require >=ampere
# Note: if bf16 is set to 'auto', and fp16 is set to true, we will prefer the explict fp16 setting
# No AMP (automatic mixed precision)
bfloat16: true # require >=ampere
@@ -184,8 +187,8 @@ datasets:
# adding a system turn with empty content.
drop_system_message:
# Optional[bool]. Whether to split the assistant turn based on a reasoning trace inside delimited tags
# defaults to False
# Optional[bool]. (for Qwen3 template only) Whether to split the assistant content based on a reasoning trace inside delimited tags
# See example at `docs/dataset-formats/conversation.qmd`
split_thinking:
# IMPORTANT: The following fields determine which parts of the conversation to train on.
@@ -547,7 +550,7 @@ gradient_checkpointing: false
early_stopping_patience: 3
# Specify a scheduler and kwargs to use with the optimizer
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | empty for cosine
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | 'linear' | 'cosine_with_restarts' | 'polynomial' | 'constant' | 'constant_with_warmup' | 'inverse_sqrt' | 'reduce_lr_on_plateau' | 'cosine_with_min_lr' | 'warmup_stable_decay' | empty for cosine
lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)
@@ -609,6 +612,7 @@ lr_div_factor: # Learning rate div factor
# - optimi_adamw
# - ao_adamw_8bit
# - ao_adamw_fp8
# - came_pytorch
optimizer:
# Dictionary of arguments to pass to the optimizer
optim_args:

View File

@@ -196,6 +196,34 @@ datasets:
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
:::
8. (For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
```yaml
datasets:
- path: ...
type: chat_template
chat_template: qwen3
split_thinking: true
```
For example, a content can look like:
```json
{
"content": "<think>Some thinking outputs</think>Output after thinking."
}
```
After split, it will look like:
```json
{
"reasoning_content": "Some thinking outputs",
"content": "Output after thinking..."
}
```
## sharegpt
::: {.callout-important}

View File

@@ -34,3 +34,5 @@ We provide a script to delinearize Llama 4 linearized models into regular Huggin
```bash
axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
```
Note: This only works with the non-quantized linearized model. If you have an adapter, merge it with the *non-quantized linearized* model before delinearizing.

341
examples/orpheus/README.md Normal file
View File

@@ -0,0 +1,341 @@
# Finetuning LLMs to output audio
In this example, we finetune Orpcanopylabs/orpheus-tts-0.1-pretrained (a LLaMA 3.2 3b model) to output audio.
The `finetune.yml` withe current settings will run on any Nvidia GPU with 45GB VRAM or more. If you adjust the batch size it can easily run on any GPU under 24GB.
## Dataset pre-processing for pre-training
If you are adding another voice in English, please jump ahead to finetuning pre-processing.
For this to work, we need to preprocess our dataset. Since we are expecting to output audio, we will need to add tokens to the tokenizer.
Using this code, it will download the SNAC model and add the correct tokens and upload the final dataset.
```python
import torch
from snac import SNAC
from datasets import load_dataset
from huggingface_hub import snapshot_download
from datasets import load_dataset
import random
import torchaudio.transforms as T
from transformers import AutoTokenizer
import os
my_original_dataset_name = "<huggingface-id-of-dataset-that-we-want-to-preprocess>"
name_to_push_dataset_to = "<huggingface-id-of-where-to-save-dataset>"
dsn = my_original_dataset_name
snapshot_download(
repo_id=dsn,
repo_type="dataset",
revision="main",
max_workers=64,
)
ds = load_dataset(dsn, split="train")
ds_sample_rate = ds[0]["audio"]["sampling_rate"]
model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
model = model.to("mps")
def tokenise_audio(waveform):
waveform = torch.from_numpy(waveform).unsqueeze(0)
waveform = waveform.to(dtype=torch.float32)
resample_transform = T.Resample(orig_freq=ds_sample_rate, new_freq=24000)
waveform = resample_transform(waveform)
waveform = waveform.unsqueeze(0).to("cuda")
#generate the codes from snac
with torch.inference_mode():
codes = model.encode(waveform)
all_codes = []
for i in range(codes[0].shape[1]):
all_codes.append(codes[0][0][i].item()+128266)
all_codes.append(codes[1][0][2*i].item()+128266+4096)
all_codes.append(codes[2][0][4*i].item()+128266+(2*4096))
all_codes.append(codes[2][0][(4*i)+1].item()+128266+(3*4096))
all_codes.append(codes[1][0][(2*i)+1].item()+128266+(4*4096))
all_codes.append(codes[2][0][(4*i)+2].item()+128266+(5*4096))
all_codes.append(codes[2][0][(4*i)+3].item()+128266+(6*4096))
return all_codes
def add_codes(example):
# Always initialize codes_list to None
codes_list = None
try:
answer_audio = example.get("audio")
# If there's a valid audio array, tokenise it
if answer_audio and "array" in answer_audio:
audio_array = answer_audio["array"]
codes_list = tokenise_audio(audio_array)
except Exception as e:
print(f"Skipping row due to error: {e}")
# Keep codes_list as None if we fail
example["codes_list"] = codes_list
return example
ds = ds.map(add_codes, remove_columns=["audio"])
#@title Load Tokenizer
tokeniser_length = 128256
start_of_text = 128000
end_of_text = 128009
start_of_speech = tokeniser_length + 1
end_of_speech = tokeniser_length + 2
start_of_human = tokeniser_length + 3
end_of_human = tokeniser_length + 4
start_of_ai = tokeniser_length + 5
end_of_ai = tokeniser_length + 6
pad_token = tokeniser_length + 7
audio_tokens_start = tokeniser_length + 10
tokenizer_name = "canopylabs/orpheus-3b-0.1-pretrained"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
num_proc = os.cpu_count() - 2
ds = ds.filter(lambda x: x["codes_list"] is not None)
ds = ds.filter(lambda x: len(x["codes_list"]) > 0)
#@title Create Input Ids
def remove_duplicate_frames(example):
vals = example["codes_list"]
if len(vals) % 7 != 0:
raise ValueError("Input list length must be divisible by 7")
result = vals[:7]
removed_frames = 0
for i in range(7, len(vals), 7):
current_first = vals[i]
previous_first = result[-7]
if current_first != previous_first:
result.extend(vals[i:i+7])
else:
removed_frames += 1
example["codes_list"] = result
return example
ds = ds.map(remove_duplicate_frames, num_proc=num_proc)
def create_input_ids(example):
text_ids = tokenizer.encode({example['text']}, add_special_tokens=True)
text_ids.append(end_of_text)
example["text_tokens"] = text_ids
input_ids = (
[start_of_human]
+ example["text_tokens"]
+ [end_of_human]
+ [start_of_ai]
+ [start_of_speech]
+ example["codes_list"]
+ [end_of_speech]
+ [end_of_ai]
)
example["input_ids"] = input_ids
example["labels"] = input_ids
example["attention_mask"] = [1] * len(input_ids)
return example
ds = ds.map(create_input_ids, num_proc=num_proc, remove_columns=["text", "codes_list"])
#@title Remove unnecessary columns
columns_to_keep = ["input_ids", "labels", "attention_mask"]
columns_to_remove = [col for col in ds.column_names if col not in columns_to_keep]
ds = ds.remove_columns(columns_to_remove)
ds.push_to_hub(name_to_push_dataset_to)
```
## Finetune pre-processing
Use this code to add a new voice.
```python
import torch
from snac import SNAC
from datasets import load_dataset
from huggingface_hub import snapshot_download
from datasets import load_dataset
import random
import torchaudio.transforms as T
from transformers import AutoTokenizer
import os
my_original_dataset_name = "<huggingface-id-of-dataset-that-we-want-to-preprocess>"
name_to_push_dataset_to = "<huggingface-id-of-where-to-save-dataset>"
dsn = my_original_dataset_name
snapshot_download(
repo_id=dsn,
repo_type="dataset",
revision="main",
max_workers=64,
)
ds = load_dataset(dsn, split="train")
ds_sample_rate = ds[0]["audio"]["sampling_rate"]
model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
model = model.to("mps")
def tokenise_audio(waveform):
waveform = torch.from_numpy(waveform).unsqueeze(0)
waveform = waveform.to(dtype=torch.float32)
resample_transform = T.Resample(orig_freq=ds_sample_rate, new_freq=24000)
waveform = resample_transform(waveform)
waveform = waveform.unsqueeze(0).to("cuda")
#generate the codes from snac
with torch.inference_mode():
codes = model.encode(waveform)
all_codes = []
for i in range(codes[0].shape[1]):
all_codes.append(codes[0][0][i].item()+128266)
all_codes.append(codes[1][0][2*i].item()+128266+4096)
all_codes.append(codes[2][0][4*i].item()+128266+(2*4096))
all_codes.append(codes[2][0][(4*i)+1].item()+128266+(3*4096))
all_codes.append(codes[1][0][(2*i)+1].item()+128266+(4*4096))
all_codes.append(codes[2][0][(4*i)+2].item()+128266+(5*4096))
all_codes.append(codes[2][0][(4*i)+3].item()+128266+(6*4096))
return all_codes
def add_codes(example):
# Always initialize codes_list to None
codes_list = None
try:
answer_audio = example.get("audio")
# If there's a valid audio array, tokenise it
if answer_audio and "array" in answer_audio:
audio_array = answer_audio["array"]
codes_list = tokenise_audio(audio_array)
except Exception as e:
print(f"Skipping row due to error: {e}")
# Keep codes_list as None if we fail
example["codes_list"] = codes_list
return example
ds = ds.map(add_codes, remove_columns=["audio"])
#@title Load Tokenizer
tokeniser_length = 128256
start_of_text = 128000
end_of_text = 128009
start_of_speech = tokeniser_length + 1
end_of_speech = tokeniser_length + 2
start_of_human = tokeniser_length + 3
end_of_human = tokeniser_length + 4
start_of_ai = tokeniser_length + 5
end_of_ai = tokeniser_length + 6
pad_token = tokeniser_length + 7
audio_tokens_start = tokeniser_length + 10
tokenizer_name = "canopylabs/orpheus-3b-0.1-pretrained"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
num_proc = os.cpu_count() - 2
ds = ds.filter(lambda x: x["codes_list"] is not None)
ds = ds.filter(lambda x: len(x["codes_list"]) > 0)
#@title Create Input Ids
def remove_duplicate_frames(example):
vals = example["codes_list"]
if len(vals) % 7 != 0:
raise ValueError("Input list length must be divisible by 7")
result = vals[:7]
removed_frames = 0
for i in range(7, len(vals), 7):
current_first = vals[i]
previous_first = result[-7]
if current_first != previous_first:
result.extend(vals[i:i+7])
else:
removed_frames += 1
example["codes_list"] = result
return example
ds = ds.map(remove_duplicate_frames, num_proc=num_proc)
tok_info = '''*** HERE you can modify the text prompt
i.e. if you wanted a multispeaker model like canopylabs/orpheus-3b-0.1-ft, you can pass:
f"{example["source"]}: {example["text"]}", as is passed.
'''
print(tok_info)
def create_input_ids(example):
text_ids = tokenizer.encode(f"{example['speaker_id']}: {example['text']}", add_special_tokens=True)
text_ids.append(end_of_text)
example["text_tokens"] = text_ids
input_ids = (
[start_of_human]
+ example["text_tokens"]
+ [end_of_human]
+ [start_of_ai]
+ [start_of_speech]
+ example["codes_list"]
+ [end_of_speech]
+ [end_of_ai]
)
example["input_ids"] = input_ids
example["labels"] = input_ids
example["attention_mask"] = [1] * len(input_ids)
return example
ds = ds.map(create_input_ids, num_proc=num_proc, remove_columns=["text", "codes_list"])
#@title Remove unnecessary columns
columns_to_keep = ["input_ids", "labels", "attention_mask"]
columns_to_remove = [col for col in ds.column_names if col not in columns_to_keep]
ds = ds.remove_columns(columns_to_remove)
ds.push_to_hub(name_to_push_dataset_to)
```
## Training
After preprocessing is done, fill out the blanks in finetune.yml and simply run `axolotl train finetune.yml`
## Inference
For inference, please refer to the original [orpheus github](https://github.com/canopyai/Orpheus-TTS/tree/main).

View File

@@ -0,0 +1,52 @@
base_model: canopylabs/orpheus-3b-0.1-pretrained
hub_model_id: <your-hub-model-id>
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: true
datasets:
- path: <your-hf-dataset-id>
type: # leave empty to load pre-tokenized
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 4
num_epochs: 3
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5
bf16: auto
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 20
evals_per_epoch: 5
saves_per_epoch: 5
weight_decay: 0.05
special_tokens:
pad_token: <custom_token_7>

View File

@@ -6,16 +6,17 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.8
liger-kernel==0.5.9
# END section
packaging==23.2
huggingface_hub==0.31.0
peft==0.15.2
transformers==4.51.3
tokenizers>=0.21.1
accelerate==1.6.0
datasets==3.5.0
datasets==3.5.1
deepspeed>=0.15.4
trl==0.17.0
hf_xet==1.1.0

View File

@@ -67,13 +67,13 @@ def parse_requirements(extras_require_map):
if (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
extras_require_map["vllm"] = ["vllm==0.8.5"]
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append(
"xformers==0.0.29.post2"
) # vllm needs post2 w torch 2.6
extras_require_map["vllm"] = ["vllm==0.8.5"]
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
@@ -142,6 +142,7 @@ extras_require = {
"apollo-torch",
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
"came_pytorch==0.1.3",
],
"ray": [
"ray[train]",

View File

@@ -18,6 +18,7 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import disable_datasets_caching
@@ -47,7 +48,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
with disable_datasets_caching():
if cfg.rl:
plugin_manager = PluginManager.get_instance()
if plugin_manager.load_datasets(cfg, preprocess=True):
pass
elif cfg.rl:
load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -43,10 +43,13 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
plugin_manager = PluginManager.get_instance()
dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
if not dataset_meta:
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)

View File

@@ -21,6 +21,7 @@ import importlib.util
import inspect
import logging
import math
import os
import sys
from abc import abstractmethod
from pathlib import Path
@@ -72,6 +73,7 @@ from axolotl.utils.callbacks import (
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
colab_inference_post_train_callback,
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
@@ -168,6 +170,9 @@ class TrainerBuilderBase(abc.ABC):
)
)
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
@@ -249,9 +254,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
@@ -293,6 +295,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
callbacks.append(lisa_callback_factory(trainer))
if any("COLAB_" in key for key in os.environ):
ColabCallback = colab_inference_post_train_callback(trainer)
callbacks.append(ColabCallback(self.cfg))
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
return callbacks
@@ -702,6 +708,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
optimizer_cls = ADOPT
adam_kwargs["decouple"] = True
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "came_pytorch":
from came_pytorch import CAME
optimizer_cls = CAME
beta1 = training_arguments_kwargs.get("adam_beta1", 0.9)
beta2 = training_arguments_kwargs.get("adam_beta2", 0.999)
beta3 = training_arguments_kwargs.get("adam_beta2", 0.9999)
eps1 = training_arguments_kwargs.get("adam_epsilon", 1e-30)
eps2 = training_arguments_kwargs.get("adam_epsilon2", 1e-16)
adam_kwargs["betas"] = (beta1, beta2, beta3)
adam_kwargs["eps"] = (eps1, eps2)
optimizer_kwargs.update(adam_kwargs)
# Parse any additional optimizer args from config
if self.cfg.optim_args:

View File

@@ -114,8 +114,6 @@ class AxolotlTrainer(
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
group_size=self.args.sample_packing_group_size,
bin_size=self.args.sample_packing_bin_size,
sequential=self.args.sample_packing_sequentially,
drop_last=True,
)

View File

@@ -247,7 +247,9 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
)
# Base evaluation
initial_output = super().evaluation_loop(
initial_output = super( # pylint: disable=bad-super-call
DPOTrainer, self
).evaluation_loop(
dataloader,
description,
prediction_loss_only,

View File

@@ -26,6 +26,8 @@ from typing import OrderedDict
import torch
from torch.optim.lr_scheduler import LRScheduler
from axolotl.utils.dict import DictDefault
class BasePlugin:
"""
@@ -36,11 +38,13 @@ class BasePlugin:
Methods:
register(cfg): Registers the plugin with the given configuration.
load_datasets(cfg): Loads and preprocesses the dataset for training.
pre_model_load(cfg): Performs actions before the model is loaded.
post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied.
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
post_trainer_create(cfg, trainer): Performs actions after the trainer is created.
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler.
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
@@ -63,20 +67,32 @@ class BasePlugin:
None
"""
def get_input_args(self):
def get_input_args(self) -> str | None:
"""
Returns a pydantic model for the plugin's input arguments.
"""
def load_datasets(self, cfg: DictDefault, preprocess: bool = False):
"""
Loads and preprocesses the dataset for training.
Args:
cfg: The configuration for the plugin.
preprocess: Whether this is the preprocess step of the datasets.
Returns:
dataset_meta: The metadata for the training dataset.
"""
def pre_model_load(self, cfg): # pylint: disable=unused-argument
"""
Performs actions before the model is loaded.
Parameters:
cfg (dict): The configuration for the plugin.
Args:
cfg (dict): The configuration for the plugin.
Returns:
None
None
"""
def post_model_build(self, cfg, model): # pylint: disable=unused-argument
@@ -91,59 +107,71 @@ class BasePlugin:
"""
Performs actions after the model is loaded.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
None
"""
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions before LoRA weights are loaded.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
None
"""
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after LoRA weights are loaded.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
None
"""
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument):
"""
Returns a custom class for the trainer.
Parameters:
cfg (dict): The global axolotl configuration.
Args:
cfg (dict): The global axolotl configuration.
Returns:
class: The class for the trainer.
class: The class for the trainer.
"""
def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
"""
Performs actions after the trainer is created.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
None
"""
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
"""
Creates and returns an optimizer for training.
Parameters:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
object: The created optimizer.
object: The created optimizer.
"""
def create_lr_scheduler(
@@ -152,26 +180,26 @@ class BasePlugin:
"""
Creates and returns a learning rate scheduler.
Parameters:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.
num_training_steps (int): Total number of training steps
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.
num_training_steps (int): Total number of training steps
Returns:
object (LRScheduler): The created learning rate scheduler.
object (LRScheduler): The created learning rate scheduler.
"""
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
"""
setup callbacks before creating the trainer.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
return []
@@ -182,12 +210,12 @@ class BasePlugin:
Adds callbacks to the trainer after creating the trainer.
This is useful for callbacks that require access to the model or trainer.
Parameters:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
List[callable]: A list of callback functions to be added
List[callable]: A list of callback functions to be added
"""
return []
@@ -195,23 +223,23 @@ class BasePlugin:
"""
Performs actions after training is complete.
Parameters:
cfg (dict): The axolotl configuration
model (object): The loaded model.
Args:
cfg (dict): The axolotl configuration
model (object): The loaded model.
Returns:
None
None
"""
def post_train_unload(self, cfg): # pylint: disable=unused-argument
"""
Performs actions after training is complete and the model is unloaded.
Parameters:
cfg (dict): The configuration for the plugin.
Args:
cfg (dict): The configuration for the plugin.
Returns:
None
None
"""
@@ -338,6 +366,27 @@ class PluginManager:
input_args.append(input_args_from_plugin)
return input_args
def load_datasets(self, cfg, preprocess: bool = False):
"""
Calls the load_datasets method of each registered plugin.
Args:
cfg: The configuration for the plugins.
preprocess : Whether this is preprocess step of the datasets.
Returns:
dataset_meta: The dataset metadata loaded from all registered plugins.
"""
return_ds_meta = None
for plugin in self.plugins.values():
dataset_meta = plugin.load_datasets(cfg, preprocess)
if dataset_meta is not None:
if return_ds_meta is None:
return_ds_meta = dataset_meta
else:
raise RuntimeError("Multiple plugins loaded datasets")
return return_ds_meta
def pre_model_load(self, cfg):
"""
Calls the pre_model_load method of all registered plugins.
@@ -422,6 +471,20 @@ class PluginManager:
return trainer_cls
return None
def post_trainer_create(self, cfg, trainer):
"""
Calls the post_trainer_create method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training.
Returns:
None
"""
for plugin in self.plugins.values():
plugin.post_trainer_create(cfg, trainer)
def create_optimizer(self, trainer):
"""
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.

View File

@@ -72,7 +72,7 @@ class CutCrossEntropyPlugin(BasePlugin):
if cfg.cut_cross_entropy:
self._check_requirements()
from .monkeypatch.patch import (
from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import (
cce_patch,
)

View File

@@ -1,134 +0,0 @@
"""
chunked ce loss
"""
from typing import List, Optional
import torch
import torch.nn.functional as F
# copied and modified from torchtune.modules.loss.CEWithChunkedOutputLoss
class CEWithChunkedOutputLoss(torch.nn.Module):
"""
Cross-entropy with chunked outputs that saves memory by only upcasting one chunk at a time.
For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390
"""
def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100):
super().__init__()
self.num_output_chunks = num_output_chunks
self.ignore_index = ignore_index
def compute_cross_entropy(
self,
logits: torch.Tensor,
labels: torch.Tensor,
normalize: bool = True, # pylint: disable=unused-argument
) -> torch.Tensor:
"""
Upcast logits to fp32 and compute cross entropy loss.
"""
return F.cross_entropy(
logits.float(), labels, ignore_index=self.ignore_index, reduction="sum"
)
def forward(
self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="sum"
) -> torch.Tensor:
"""
Args:
logits (List[torch.Tensor]): List of chunked logits of length
``self.num_output_chunks``, where each chunk has shape
``(batch_size, num_tokens / num_output_chunks, vocab_size)``.
labels (torch.Tensor): Ground truth labels of shape ``(batch_size, num_tokens)``.
reduction (str): The reduction to apply to the output.
Returns:
torch.Tensor: Cross entropy loss of shape (1,).
"""
total_elements = (labels != self.ignore_index).sum()
# chunk and reshape labels (bsz, num_tokens, vocab) -> [(bsz*num_tokens/num_chunks, vocab)]
labels = [
target_chunk.reshape(-1)
for target_chunk in labels.chunk(self.num_output_chunks, dim=1)
]
# reshape logits [(bsz, num_tokens/num_chunks, vocab)] -> [(bsz*num_tokens/num_chunks, vocab)]
logits = [
logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits
]
# compute one chunk at a time
total_loss = 0.0
for logits_chunk, labels_chunk in zip(logits, labels):
total_loss += self.compute_cross_entropy(logits_chunk, labels_chunk)
if reduction == "sum":
return total_loss
return total_loss / total_elements
def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)
loss_fn_ce.compute_cross_entropy = torch.compile(
loss_fn_ce.compute_cross_entropy, backend="inductor"
)
return loss_fn_ce
def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100):
loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index)
def chunked_fix_cross_entropy(
source,
target,
num_items_in_batch: int = None,
ignore_index: int = -100,
**kwargs,
): # pylint: disable=unused-argument
reduction = "sum" if num_items_in_batch is not None else "mean"
logit_chunks = [ # pylint: disable=unnecessary-comprehension
chunk for chunk in source.chunk(loss_fn_ce.num_output_chunks, dim=1)
]
loss = loss_fn_ce(logit_chunks, target, reduction=reduction)
if reduction == "sum":
loss = loss / num_items_in_batch
return loss
def for_causal_lm_chunked_loss(
logits,
labels,
vocab_size: int = None, # pylint: disable=unused-argument
num_items_in_batch: Optional[int] = None,
ignore_index: int = -100,
shift_labels: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
# skip the upcast to float since we handle that in the chunking loss
if shift_labels is None:
# Shift so that tokens < n predict n
labels = F.pad(labels, (0, 1), value=ignore_index)
shift_labels = labels[..., 1:].contiguous()
# Skip Flattening the tokens
# Enable model parallelism
shift_labels = shift_labels.to(logits.device)
loss = chunked_fix_cross_entropy(
logits, shift_labels, num_items_in_batch, ignore_index, **kwargs
)
return loss
return for_causal_lm_chunked_loss
def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
import transformers.loss.loss_utils
for_causal_lm_chunked_loss = get_causal_lm_loss(num_output_chunks, ignore_index)
transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss
transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = (
for_causal_lm_chunked_loss
)

View File

@@ -24,7 +24,7 @@ PATCHED_PREPARE_CODE = """
for name, param in model.named_parameters():
if (
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
) and param.__class__.__name__ != "Params4bit" and "norm" in name:
) and param.__class__.__name__ != "Params4bit" and all(embed_name not in name for embed_name in ["embed_tokens", "lm_head"]):
param.data = param.data.to(torch.float32)
"""

View File

@@ -2,6 +2,7 @@
import importlib
import inspect
import logging
import os
import signal
import sys
@@ -12,7 +13,6 @@ from typing import Any, Dict
import torch
import transformers.modelcard
from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model
from datasets import Dataset
from huggingface_hub.errors import OfflineModeIsEnabled
@@ -42,7 +42,7 @@ try:
except ImportError:
BetterTransformer = None
LOG = get_logger(__name__)
LOG = logging.getLogger(__name__)
def setup_model_and_tokenizer(
@@ -63,7 +63,6 @@ def setup_model_and_tokenizer(
# Load tokenizer
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)
@@ -528,6 +527,9 @@ def train(
processor,
) = setup_model_and_trainer(cfg, dataset_meta)
plugin_manager = PluginManager.get_instance()
plugin_manager.post_trainer_create(cfg, trainer)
# Handle untrained tokens if configured
safe_serialization = cfg.save_safetensors is True
train_dataset = dataset_meta.train_dataset
@@ -550,7 +552,6 @@ def train(
if not cfg.use_ray:
cleanup_distributed()
plugin_manager = PluginManager.get_instance()
plugin_manager.post_train(cfg, model)
return model, tokenizer, trainer

View File

@@ -868,3 +868,28 @@ class GCCallback(TrainerCallback):
):
torch.cuda.empty_cache()
gc.collect()
def colab_inference_post_train_callback(trainer: Trainer):
class ColabCallback(TrainerCallback):
"""Callback to prep model for inference on Google Colab"""
def __init__(self, cfg):
self.gpu_name = torch.cuda.get_device_name(0)
self.cfg = cfg
def on_train_end(
self, args, state, control, **kwargs
): # pylint: disable=unused-argument
"""
handle T4 gpu, we need to convert attention to eager for inference
"""
if "Tesla T4" in self.gpu_name and self.cfg.xformers_attention:
trainer.model.config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)
trainer.model.gradient_checkpointing_disable()
trainer.model.config.use_cache = True
trainer.model.eval()
return ColabCallback

View File

@@ -281,6 +281,10 @@ def load_dataset_w_config(
**load_ds_kwargs,
)
if not ds:
raise ValueError("unhandled dataset load")
raise ValueError(
"The dataset could not be loaded. This could be due to a misconfigured dataset path "
f"({config_dataset.path}). Try double-check your path / name / data_files. "
"This is not caused by the dataset type."
)
return ds

View File

@@ -1,15 +1,36 @@
"""custom checkpointing utils"""
import importlib
from functools import partial
from packaging import version
from axolotl.utils.gradient_checkpointing.unsloth import (
Unsloth_Offloaded_Gradient_Checkpointer,
)
transformers_version = version.parse(importlib.metadata.version("transformers"))
if transformers_version > version.parse("4.51.3"):
from transformers.modeling_layers import GradientCheckpointingLayer
def uses_gc_layers(decoder_layer):
return isinstance(decoder_layer.func.__self__, GradientCheckpointingLayer)
else:
def uses_gc_layers(_):
return False
def hf_grad_checkpoint_offload_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
if uses_gc_layers(decoder_layer):
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
decoder_layer,
*args,
)
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
(
decoder_layer.func.__self__

View File

@@ -561,21 +561,12 @@ class ModelLoader:
patch_xformers_attn_over_fa2()
self.cfg.flash_attention = True
if self.cfg.chunked_cross_entropy:
from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn
if self.cfg.chunked_cross_entropy_num_chunks:
patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks)
else:
patch_chunked_ce_loss_fn()
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
patch_accelerate_fsdp_utils()
if self.cfg.adapter:
if self.cfg.adapter and self.cfg.embeddings_skip_upcast:
from axolotl.monkeypatch.peft.utils import patch_peft_prep_code
patch_peft_prep_code()
@@ -912,7 +903,7 @@ class ModelLoader:
"bnb_4bit_compute_dtype": self.cfg.torch_dtype,
"bnb_4bit_use_double_quant": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_quant_storage": torch.uint8,
"bnb_4bit_quant_storage": torch.bfloat16,
}
if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not (
self.cfg.deepspeed or self.cfg.fsdp
@@ -1328,8 +1319,11 @@ class ModelLoader:
# make sure these are fp32 per Ramesh et al. (2021)
embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type)
if self.cfg.fsdp:
# FSDP doesn't like mixed Float and BFloat16
if not self.cfg.fsdp:
# we don't run this during FSDP because this will leave mixed
# float and bfloat16 dtypes in the model which FSDP doesn't like
if self.cfg.load_in_4bit and self.cfg.embeddings_skip_upcast:
embedding_modules = []
self.convert_embedding_modules_dtype(
embedding_modules,
dist_dtype=torch.float32,

View File

@@ -1,13 +1,10 @@
# pylint: skip-file
"""
Multipack Batch Sampler - An efficient batch sampler for packing variable-length sequences
into fixed-capacity batches to optimize memory usage and training throughput.
Multipack Batch Sampler
"""
import logging
import math
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count
from typing import Iterable, List, Union
from typing import Any, Iterable, List, Union
import numba
import numpy as np
@@ -16,39 +13,26 @@ from torch.utils.data import BatchSampler, Sampler, SequentialSampler
from axolotl.utils.distributed import reduce_and_broadcast
LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)
@numba.njit
def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int):
"""
First-fit-decreasing bin packing algorithm check
def ffd_check(a: np.ndarray, c: int, n: int):
# First-fit-decreasing bin packing
# Check if a[] could fit in n bins with capacity c
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
Checks if sequences with the given lengths could fit in the specified number of bins
Args:
sequence_lengths: Array of sequence lengths
bin_capacity: Maximum capacity of each bin
num_bins: Number of bins available
Returns:
True if all sequences can be packed, False otherwise
"""
# Sort sequence lengths in descending order for optimal packing
sequence_lengths = np.sort(sequence_lengths)[::-1]
# Initialize all bins with full capacity
bins = np.full((num_bins,), bin_capacity, dtype=sequence_lengths.dtype)
# Try to place each sequence in the first bin it fits
for size in sequence_lengths:
a = np.sort(a)[::-1]
bins = np.full((n,), c, dtype=a.dtype)
for size in a:
not_found = True
for idx in range(num_bins):
for idx in range(n):
if bins[idx] >= size:
bins[idx] -= size
not_found = False
break
# If no bin could fit this sequence, packing failed
if not_found:
return False
@@ -56,380 +40,240 @@ def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int):
@numba.njit
def pack_group(
sequence_lengths: np.ndarray,
group_offset: int,
bin_capacity: int,
max_bins: int,
bin_size: int,
safe_mode: bool = True,
):
"""
Pack a group of sequences into bins using First-Fit Decreasing algorithm
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
# First-fit-decreasing bin packing (with result return)
Args:
sequence_lengths: Array of sequence lengths
group_offset: Offset to apply to indices when returning results
bin_capacity: Maximum capacity of each bin
max_bins: Maximum number of bins to use
bin_size: Maximum number of sequences per bin
safe_mode: If True, use a more conservative packing approach
indices = np.argsort(a)[::-1]
a = a[indices]
Returns:
List of bins, where each bin contains indices of sequences assigned to it
"""
# Get sorting indices and sort lengths in descending order
indices = np.argsort(sequence_lengths)[::-1]
sorted_lengths = sequence_lengths[indices]
bins_remaining_space: list = [] # Tracks remaining capacity in each bin
bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin
for seq_id, size in enumerate(sorted_lengths):
global_idx = indices[seq_id] + group_offset
# Try to place sequence in existing bins
add_new_bin = True
for bin_idx, _ in enumerate(bins_remaining_space):
if (
bins_remaining_space[bin_idx] >= size
and len(bins_assigned_sequences[bin_idx]) < bin_size
):
bins_remaining_space[bin_idx] -= size
bins_assigned_sequences[bin_idx].append(global_idx)
add_new_bin = False
bins: List[Any] = []
bins_result: List[Any] = []
for a_id, size in enumerate(a):
add_new = True
for idx in range(len(bins)):
if bins[idx] >= size:
bins[idx] -= size
bins_result[idx].append(indices[a_id] + start_index)
add_new = False
break
# Create a new bin if needed and if we haven't reached the limit
if add_new_bin:
if len(bins_remaining_space) >= max_bins and safe_mode:
# In safe mode, skip items that would exceed max_bins
continue
bins_remaining_space.append(bin_capacity - size)
bins_assigned_sequences.append([global_idx])
if add_new:
bins.append(c - size)
bins_result.append([indices[a_id] + start_index])
# Safety check to avoid infinite bins
if len(bins_remaining_space) > len(sequence_lengths):
break
return bins_assigned_sequences
# Define a standalone function for multiprocessing
def _process_group(args):
group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode = args
return pack_group(
group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode
)
def pack_parallel(
sequence_lengths: np.ndarray,
bin_capacity: int,
group_size: int,
bin_size: int,
num_processes: int | None = None,
safe_mode: bool = True,
):
"""
Pack sequences into bins using parallel processing
Args:
sequence_lengths: Array of sequence lengths
bin_capacity: Maximum capacity of each bin as total number of tokens
group_size: Number of sequences to process in each group
bin_size: Maximum number of bins to use
num_processes: Number of parallel processes to use
safe_mode: If True, use a more conservative packing approach
Returns:
List of bins, where each bin contains indices of sequences assigned to it
"""
num_items = len(sequence_lengths)
if num_processes is None:
num_processes = max(1, min(num_items // group_size, cpu_count()))
# Create tasks for parallel processing
tasks = []
for i in range(0, num_items, group_size):
group_lengths = sequence_lengths[i : i + group_size]
max_bins = len(group_lengths) # Allow as many bins as items in the group
tasks.append((group_lengths, i, bin_capacity, max_bins, bin_size, safe_mode))
# Process groups in parallel
all_bins = []
with ProcessPoolExecutor(max_workers=num_processes) as executor:
for group_bins in executor.map(_process_group, tasks):
all_bins.extend(group_bins)
return all_bins
return bins_result
@numba.njit
def allocate_sequentially(
sequence_lengths: np.ndarray, rank: int, bin_capacity: int, num_ranks: int
def allocate(
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
):
# Dynamic batch allocator, similar to Multifit
# https://en.wikipedia.org/wiki/Multifit_algorithm
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
s = 0
start_index = 0
result = []
while True:
# binary search [l, r)
left = 1
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
while right - left > 1:
mid = (left + right) // 2
if ffd_check(lengths[start_index : start_index + mid], c, n):
left = mid
else:
right = mid
# use length l
batch = ffd_with_result(
lengths[start_index : start_index + left], c, start_index
)
assert len(batch) <= n
if len(batch) < n:
break
start_index += left
s = lengths_cumsum[start_index - 1]
# add local rank
result.append(batch[rank])
return result, s, len(result) * c * n
@numba.njit
def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
"""
Sequential allocator that preserves example order
Parameters:
sequence_lengths: The lengths of all examples
rank: The current rank (for distributed training)
bin_capacity: The capacity of each bin (maximum sequence length)
num_ranks: Number of ranks (processes/GPUs)
- lengths: The lengths of all examples
- rank: The current rank (for distributed training)
- c: The capacity of each bin (maximum sequence length)
- n: Number of ranks
Returns:
rank_batches: List of batches for the current rank
total_tokens_used: Number of actual example tokens
total_token_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
- result: List of batches for the current rank
- total_used: Number of actual example tokens
- total_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
"""
rank_batches = []
total_tokens_used = 0
result = []
total_used = 0
# First, do sequential packing into bins
all_bins = []
current_bin = []
remaining_capacity = bin_capacity
current_bin = [0 for i in range(0)] # numba hint
remaining_capacity = c
# Process each sequence in order
for idx, size in enumerate(sequence_lengths):
for idx, size in enumerate(lengths):
if size <= remaining_capacity:
# Example fits in current bin
current_bin.append(idx)
remaining_capacity -= size
total_tokens_used += size
total_used += size
else:
# Example doesn't fit, start a new bin
if current_bin: # Add non-empty bin to all_bins
all_bins.append(current_bin)
current_bin = [idx]
remaining_capacity = bin_capacity - size
total_tokens_used += size
remaining_capacity = c - size
total_used += size
# Add the last bin if not empty
if current_bin:
all_bins.append(current_bin)
# Assign bins to ranks - each rank gets every num_ranks-th bin
for bin_idx in range(rank, len(all_bins), num_ranks):
rank_batches.append(all_bins[bin_idx])
# Assign bins to ranks - each rank gets every n-th bin
for bin_idx in range(rank, len(all_bins), n):
result.append(all_bins[bin_idx])
return rank_batches, total_tokens_used, len(all_bins) * bin_capacity
return result, total_used, len(all_bins) * c
class MultipackBatchSampler(BatchSampler):
"""
Batch sampler class for efficient packing of variable-length sequences
This sampler packs sequences into fixed-capacity bins (batches) to maximize
GPU memory utilization and training throughput by reducing padding.
It supports both parallel packing (using FFD algorithm) and
sequential packing (preserving original sequence order).
"""
"""Batch sampler class for multipack"""
def __init__(
self,
sampler: Union[Sampler[int], Iterable[int]],
batch_size: int, # Number of bins per batch
batch_max_len: int, # Maximum sequence length (bin capacity)
lengths: np.ndarray, # Sequence lengths
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
drop_last: bool = False, # Whether to drop incomplete batches
num_count_samples: int = 16, # Number of samples to estimate batch count
sequential: bool = False, # Whether to use sequential packing
group_size: int = 100_000, # Size of groups for parallel packing
bin_size: int = 200, # The max number of samples that can be packed in a single bin
num_processes: int | None = None, # Number of processes for parallel packing
safe_mode: bool = True, # Conservative packing to prevent training instability
**kwargs, # pylint: disable=unused-argument
batch_size: int,
batch_max_len: int,
lengths: np.ndarray,
packing_efficiency_estimate: float = 1.0,
drop_last: bool = False,
num_count_samples: int = 16,
sequential: bool = False,
**kwargs,
):
super().__init__(sampler, batch_size, drop_last)
self.batch_size = batch_size
self.batch_max_len = batch_max_len
self.lengths = np.array(lengths, dtype=np.int32)
self.lengths: np.ndarray = lengths
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.sequential = sequential
self.group_size = group_size
self.bin_size = bin_size
self.num_processes = num_processes
self.safe_mode = safe_mode
assert isinstance(self.lengths, np.ndarray)
self.epoch = 0
# Efficiency statistics tracking
self.total_tokens_used = 0
self.total_token_slots = 0
# statistics
self.eff_total_used = 0
self.eff_total_slots = 0
# The number of times to calculate batches to determine minimum packed dataset length
# The number of times to calculate the batches to determine the minimum packed dataset length for the local rank
self.num_count_samples = num_count_samples
# Minimum packed dataset length across all ranks (determined by gather/broadcast)
# the minimum packed dataset length across all ranks determined by a gather/broadcast
self.len_across_ranks = None
# Cache for batches
self._batches = None
if self.sequential and not isinstance(sampler, SequentialSampler):
LOG.warning(
"using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
)
def set_epoch(self, epoch: int):
"""Set the epoch number, used for reproducible shuffling across epochs"""
self.epoch = epoch
self._batches = None # Invalidate batch cache
def generate_batches(self, set_stats=False):
"""
Generate packed batches for training
indices = [idx for idx in self.sampler]
Args:
set_stats: Whether to update efficiency statistics
Returns:
List of batches, where each batch contains multiple bins,
and each bin contains multiple sequence indices
"""
if self._batches is not None:
return self._batches
# Get indices from the sampler
indices = [ # pylint: disable=unnecessary-comprehension
idx for idx in self.sampler
]
# Get lengths of the selected sequences
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
# Pack sequences into bins using either sequential or parallel packing
if self.sequential:
bins, total_used, total_slots = allocate_sequentially(
lengths,
batches, total_used, total_slots = allocate_sequentially(
lengths=lengths,
rank=0,
bin_capacity=self.batch_max_len,
num_ranks=1,
c=self.batch_max_len,
n=1,
)
else:
# Use parallel packing
all_bins = pack_parallel(
lengths,
bin_capacity=self.batch_max_len,
group_size=self.group_size,
bin_size=self.bin_size,
num_processes=self.num_processes,
safe_mode=self.safe_mode,
batches, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=0,
c=self.batch_max_len,
n=1,
)
# Map bin indices back to original indices
bins = [
[indices[b_idx] for b_idx in bin_indices] for bin_indices in all_bins
]
# Calculate efficiency statistics
total_used = lengths.sum()
total_slots = len(all_bins) * self.batch_max_len
# Group bins into batches (each batch contains batch_size bins)
batches = [
bins[i : i + self.batch_size] for i in range(0, len(bins), self.batch_size)
[
[indices[b_idx] for b_idx in batch]
for batch in batches[i : i + self.batch_size]
]
for i in range(0, len(batches), self.batch_size)
]
# Drop last batch if requested and it's incomplete
if self.drop_last and len(batches[-1]) < self.batch_size:
batches = batches[:-1]
# Adjust total_slots if we dropped a batch
if not self.sequential:
total_slots -= (self.batch_size - len(batches[-1])) * self.batch_max_len
# Update statistics if requested
# statistics
if set_stats:
self.total_tokens_used += total_used
self.total_token_slots += total_slots
self.eff_total_used += total_used
self.eff_total_slots += total_slots
self._batches = batches
return batches
def __iter__(self):
"""
Return an iterator over batches
The batches are truncated to match the minimum number of batches across all ranks
to ensure distributed training balance
"""
batches = self.generate_batches(set_stats=True)
if self.len_across_ranks:
# Truncate batches to ensure all ranks have the same number of batches
# make sure the batches we iterate over is truncated to the same min length across all ranks
batches = batches[: self.len_across_ranks]
return iter(batches)
def num_batches(self):
batches = self.generate_batches(set_stats=True)
return len(batches)
def efficiency(self):
"""
Calculate the packing efficiency (ratio of tokens used to total token slots)
Higher is better - 1.0 would mean perfect packing with no wasted space
"""
if self.total_token_slots == 0:
self.generate_batches(set_stats=True)
if self.total_token_slots == 0:
return 0.0
# Return a Python float instead of potentially a numpy float
return float(self.total_tokens_used / self.total_token_slots)
return self.eff_total_used / self.eff_total_slots
def gather_efficiency(self):
"""
Gather and synchronize packing efficiency estimates across all distributed ranks
Returns a conservative efficiency estimate based on the measurements
"""
def calc_sample_packing_eff_est(estimates: List[float]):
LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}")
# Use 99.7% of max observed efficiency as a safe estimate
max_eff = max(float(eff) for eff in estimates)
return math.floor(0.997 * max_eff)
return math.floor(0.997 * max(estimates))
# Gather efficiency from all ranks and apply the calculation function
sample_packing_actual_eff_all = reduce_and_broadcast(
lambda: float(self.efficiency()), # pylint: disable=unnecessary-lambda
lambda: self.efficiency(), # pylint: disable=unnecessary-lambda
calc_sample_packing_eff_est,
)
# Quantize to 0.5% intervals for stability
sample_packing_eff_est = (
math.ceil(sample_packing_actual_eff_all * 200.0) / 200.0
)
return sample_packing_eff_est
def gather_len_batches(self, num):
"""
Gather and synchronize batch counts across all distributed ranks
Returns the minimum number of batches available on any rank
"""
def calc_min_len(estimates: list[(int, float)]):
LOG.info(f"gather_len_batches: {repr(estimates)}")
return math.floor(min(estimates))
# Find minimum batch count across ranks to ensure balance
min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
return min_len_batches
def __len__(self):
"""
Return the total number of batches that will be yielded by this sampler
This is calculated as the minimum number of batches available on any rank
to ensure balanced distributed training
"""
if self._batches is None:
self._batches = self.generate_batches(set_stats=True)
if self.len_across_ranks is None:
# Sample multiple times to get stable estimate
len_batches = min( # pylint: disable=consider-using-generator
[len(self._batches) for _ in range(self.num_count_samples)]
if not self.len_across_ranks:
len_batches = min(
[self.num_batches() for _ in range(self.num_count_samples)]
)
# Gather minimum across all ranks
self.len_across_ranks = self.gather_len_batches(len_batches)
return self.len_across_ranks

View File

@@ -82,6 +82,7 @@ class AxolotlInputConfig(
mean_resizing_embeddings: bool | None = False
# optionally shrink the embeddings when the tokenizer vocab size is smaller
shrink_embeddings: bool | None = None
embeddings_skip_upcast: bool | None = None
rl: RLType | None = None
trl: TRLConfig | None = Field(
@@ -242,9 +243,6 @@ class AxolotlInputConfig(
unsloth_rms_norm: bool | None = None
unsloth_rope: bool | None = None
chunked_cross_entropy: bool | None = None
chunked_cross_entropy_num_chunks: int | None = None
lora_mlp_kernel: bool | None = None
lora_qkv_kernel: bool | None = None
lora_o_kernel: bool | None = None
@@ -464,9 +462,10 @@ class AxolotlInputConfig(
and not data.get("flash_attention")
and not data.get("sdp_attention")
and not data.get("flex_attention")
and not data.get("xformers_attention")
):
LOG.warning(
"sample_packing without flash, sdp or flex attention does not handle cross sample decontamination."
"sample_packing without flash, sdp, xformers or flex attention does not handle cross sample decontamination."
)
return data

View File

@@ -53,4 +53,5 @@ class CustomSupportedOptimizers(str, Enum):
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
came_pytorch = "came_pytorch" # pylint: disable=invalid-name
muon = "muon" # pylint: disable=invalid-name

View File

@@ -75,8 +75,10 @@ class HyperparametersConfig(BaseModel):
lr_groups: list[LrGroup] | None = None
adam_epsilon: float | None = None
adam_epsilon2: float | None = None
adam_beta1: float | None = None
adam_beta2: float | None = None
adam_beta3: float | None = None
max_grad_norm: float | None = None
num_epochs: float = Field(default=1.0)

View File

@@ -4,6 +4,7 @@ shared pytest fixtures
import functools
import importlib
import os
import shutil
import sys
import tempfile
@@ -529,31 +530,32 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
# # pylint: disable=redefined-outer-name,unused-argument
# def test_load_fixtures(
# download_smollm2_135m_model,
# download_llama_68m_random_model,
# download_qwen_2_5_half_billion_model,
# download_tatsu_lab_alpaca_dataset,
# download_mhenrichsen_alpaca_2k_dataset,
# download_mhenrichsen_alpaca_2k_w_revision_dataset,
# download_mlabonne_finetome_100k_dataset,
# download_argilla_distilabel_capybara_dpo_7k_binarized_dataset,
# download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset,
# download_fozzie_alpaca_dpo_dataset,
# download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset,
# download_argilla_dpo_pairs_dataset,
# download_tiny_shakespeare_dataset,
# download_deepseek_model_fixture,
# download_huggyllama_model_fixture,
# download_llama_1b_model_fixture,
# download_llama3_8b_model_fixture,
# download_llama3_8b_instruct_model_fixture,
# download_phi_35_mini_model_fixture,
# download_phi_3_medium_model_fixture,
# download_mistral_7b_model_fixture,
# download_gemma_2b_model_fixture,
# download_gemma2_9b_model_fixture,
# download_mlx_mistral_7b_model_fixture,
# download_llama2_model_fixture,
# ):
# pass
@pytest.mark.skipif(
os.environ.get("AXOLOTL_IS_CI_CACHE_PRELOAD", "-1") != "1",
reason="Not running in CI cache preload",
)
def test_load_fixtures(
download_smollm2_135m_model,
download_qwen_2_5_half_billion_model,
download_tatsu_lab_alpaca_dataset,
download_mhenrichsen_alpaca_2k_dataset,
download_mhenrichsen_alpaca_2k_w_revision_dataset,
download_mlabonne_finetome_100k_dataset,
download_argilla_distilabel_capybara_dpo_7k_binarized_dataset,
download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset,
download_argilla_dpo_pairs_dataset,
download_tiny_shakespeare_dataset,
download_deepseek_model_fixture,
download_huggyllama_model_fixture,
download_llama_1b_model_fixture,
download_llama3_8b_model_fixture,
download_llama3_8b_instruct_model_fixture,
download_phi_35_mini_model_fixture,
download_phi_3_medium_model_fixture,
download_mistral_7b_model_fixture,
download_gemma_2b_model_fixture,
download_gemma2_9b_model_fixture,
download_mlx_mistral_7b_model_fixture,
download_llama2_model_fixture,
):
pass

View File

@@ -29,6 +29,12 @@ class LogHooksPlugin(BasePlugin):
except FileNotFoundError:
pass
def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("post_trainer_create\n")
def pre_model_load(self, cfg): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
@@ -165,6 +171,7 @@ class TestPluginHooks:
) as f:
file_contents = f.readlines()
file_contents = "\n".join(file_contents)
assert "post_trainer_create" in file_contents
assert "pre_model_load" in file_contents
assert "post_model_build" in file_contents
assert "pre_lora_load" in file_contents

View File

@@ -479,7 +479,7 @@ class TestMultiGPULlama:
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},

View File

@@ -29,12 +29,12 @@ from axolotl.utils.dict import DictDefault
MODEL_CONFIGS = [
{
"name": "openaccess-ai-collective/tiny-mistral",
"name": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"expected_activation": apply_lora_mlp_swiglu,
"dtype": torch.float16,
},
{
"name": "Qwen/Qwen2-7B",
"name": "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
"expected_activation": apply_lora_mlp_swiglu,
"dtype": torch.float16,
},
@@ -44,7 +44,7 @@ MODEL_CONFIGS = [
"dtype": torch.float32,
},
{
"name": "mhenrichsen/gemma-2b",
"name": "trl-internal-testing/tiny-Gemma2ForCausalLM",
"expected_activation": apply_lora_mlp_geglu,
"dtype": torch.float16,
},
@@ -156,7 +156,9 @@ def test_swiglu_mlp_integration(small_llama_model):
def test_geglu_model_integration():
"""Test GeGLU activation with Gemma model."""
model = AutoModelForCausalLM.from_pretrained(
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda:0"
"trl-internal-testing/tiny-Gemma2ForCausalLM",
torch_dtype=torch.float16,
device_map="cuda:0",
)
peft_config = get_peft_config(
{

View File

@@ -6,6 +6,8 @@ import logging
import os
import unittest
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
@@ -23,6 +25,7 @@ class TestFalconPatched(unittest.TestCase):
Test case for Falcon models
"""
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_qlora(self, temp_dir):
# pylint: disable=duplicate-code
@@ -71,6 +74,7 @@ class TestFalconPatched(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_ft(self, temp_dir):
# pylint: disable=duplicate-code

View File

@@ -28,7 +28,7 @@ class TestMistral(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 1024,
@@ -76,7 +76,7 @@ class TestMistral(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 1024,

View File

@@ -56,7 +56,7 @@ class TestModelPatches(unittest.TestCase):
def test_mistral_multipack(self, temp_dir):
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,

View File

@@ -0,0 +1,63 @@
"""
Test case for handling embeddings when using peft
"""
import torch
from axolotl.train import setup_model_and_tokenizer
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
class TestLlamaPeftEmbeddings:
"""
test class for handling embeddings when using peft
"""
def test_peft_embeddings_upcast(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_target_linear": True,
"trust_remote_code": True,
"sequence_len": 512,
"val_set_size": 0.01,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": False,
"bf16": "auto",
"save_safetensors": True,
"embeddings_skip_upcast": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
model, _, _, _ = setup_model_and_tokenizer(cfg)
# Check if the embeddings are upcast correctly
# only embed_tokens is a parameter that may be upcast
assert model.base_model.model.model.embed_tokens.weight.dtype == torch.bfloat16
assert model.base_model.model.lm_head.weight.dtype == torch.bfloat16

View File

@@ -15,7 +15,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, most_recent_subdir
from ..utils import check_model_output_exists, most_recent_subdir, require_torch_2_6_0
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -26,6 +26,7 @@ class TestResumeLlama:
Test case for resuming training of llama models
"""
@require_torch_2_6_0
def test_resume_lora_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
@@ -62,6 +63,7 @@ class TestResumeLlama:
"save_total_limit": 5,
"max_steps": 15,
"use_tensorboard": True,
"save_safetensors": True,
}
)
if is_torch_bf16_gpu_available():

View File

@@ -19,14 +19,11 @@ class TestE2eEvaluate:
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"val_set_size": 0.02,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{

View File

@@ -6,6 +6,8 @@ import logging
import os
import unittest
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
@@ -23,6 +25,7 @@ class TestFalcon(unittest.TestCase):
Test case for falcon
"""
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_lora(self, temp_dir):
# pylint: disable=duplicate-code
@@ -74,6 +77,7 @@ class TestFalcon(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_lora_added_vocab(self, temp_dir):
# pylint: disable=duplicate-code
@@ -129,6 +133,7 @@ class TestFalcon(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_ft(self, temp_dir):
# pylint: disable=duplicate-code

View File

@@ -30,7 +30,7 @@ class TestMistral(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sequence_len": 1024,
"load_in_8bit": True,
@@ -77,7 +77,7 @@ class TestMistral(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sequence_len": 1024,
"val_set_size": 0.02,

View File

@@ -199,3 +199,50 @@ class TestCustomOptimizers(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_came_pytorch(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "came_pytorch",
"adam_beta3": 0.9999,
"adam_epsilon2": 1e-16,
"max_steps": 5,
"lr_scheduler": "cosine",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -1,40 +0,0 @@
"""
test suite for chunked cross entropy
"""
import pytest
import torch
from torch import nn
from axolotl.monkeypatch.loss.chunked import get_causal_lm_loss
@pytest.fixture
def chunked_fixtures():
model_dim = 512
vocab_size = 1024 * 256
seq_len = 2048
batch_size = 1
lm_head = nn.Linear(model_dim, vocab_size)
hidden_state = torch.randn(batch_size, seq_len, model_dim)
labels = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))
return lm_head, hidden_state, labels, vocab_size
def test_chunked_forward(chunked_fixtures): # pylint: disable=redefined-outer-name
lm_head, hidden_state, labels, vocab_size = chunked_fixtures
lm_loss = get_causal_lm_loss()
logits = lm_head(hidden_state)
chunked_lm_loss = lm_loss(logits, labels)
logits_flattened = logits.view(-1, vocab_size)
labels_flattened = labels.view(-1)
loss = nn.functional.cross_entropy(
logits_flattened.float(), labels_flattened, reduction="mean"
)
assert torch.allclose(chunked_lm_loss, loss, atol=1e-2, rtol=1e-2)

View File

@@ -414,7 +414,6 @@ class TestDatasetPreparation:
snapshot_path = snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
)
shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)