Compare commits

..

8 Commits

Author SHA1 Message Date
Wing Lian
d790371b64 bump peft to 3.5.1 2025-05-06 11:38:14 -04:00
mhenrichsen
a6cac5dd32 Update lr_scheduler options in config.qmd to include additional scheduling strategies for improved training flexibility. (#2636) [skip ci] 2025-05-06 11:24:07 -04:00
Wing Lian
b71c0e3447 Print axolotl art if train is called outside of cli: (#2627) [skip ci] 2025-05-06 11:18:45 -04:00
Wing Lian
ddaebf8309 fix dpo eval override to call grandparent instead of the broken super (#2628) [skip ci] 2025-05-06 11:18:25 -04:00
Wing Lian
679743087a make sure gc_steps is used for all trainers (#2638) 2025-05-06 11:18:00 -04:00
Wing Lian
f720b6e72d repop cache (#2639)
* repop cache

* pre-cache as a step

* fix the name

* add reason for pytest skipif

* restore pytorch matrix

* remove max-parallel now that we've optimized this a bit
2025-05-06 11:09:07 -04:00
mhenrichsen
a980618fd0 Adds example for training a TTS model on top of a LLM. (#2614)
* Adds example for training a TTS model on top of a LLM.

* Update examples/orpheus/finetune.yml

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* Update examples/orpheus/finetune.yml

Co-authored-by: NanoCode012 <nano@axolotl.ai>

* Update README.md to clarify GPU requirements for finetuning Orpheus TTS model

* Update finetune.yml to use the new base model canopylabs/orpheus-3b-0.1-pretrained

* Update finetune.yml and README.md for consistency and clarity

---------

Co-authored-by: NanoCode012 <nano@axolotl.ai>
2025-05-06 10:11:06 +02:00
Emmanuel Ferdman
54960d4de0 Fix logging deprecation warnings (#2623)
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2025-05-04 08:22:45 -04:00
22 changed files with 670 additions and 809 deletions

View File

@@ -44,12 +44,98 @@ jobs:
env:
SKIP: no-commit-to-branch
pytest:
name: PyTest
preload-cache:
name: Preload HF cache
runs-on: ubuntu-latest
strategy:
fail-fast: false
max-parallel: 2
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0"]
timeout-minutes: 20
env:
AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
steps:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
- name: Install dependencies
run: |
pip3 show torch
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
- name: Ensure axolotl CLI was installed
run: |
axolotl --help
- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Run tests
run: |
pytest -v tests/conftest.py
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.xml
flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest:
name: PyTest
runs-on: ubuntu-latest
needs: [preload-cache]
strategy:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
@@ -121,21 +207,12 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest-sdist:
name: PyTest from Source Dist
runs-on: ubuntu-latest
needs: [preload-cache]
strategy:
fail-fast: false
max-parallel: 1
matrix:
python_version: ["3.11"]
pytorch_version: ["2.5.1", "2.6.0", "2.7.0"]
@@ -199,15 +276,6 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
docker-e2e-tests-1st:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...

View File

@@ -547,7 +547,7 @@ gradient_checkpointing: false
early_stopping_patience: 3
# Specify a scheduler and kwargs to use with the optimizer
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | empty for cosine
lr_scheduler: # 'one_cycle' | 'rex' | 'log_sweep' | 'linear' | 'cosine_with_restarts' | 'polynomial' | 'constant' | 'constant_with_warmup' | 'inverse_sqrt' | 'reduce_lr_on_plateau' | 'cosine_with_min_lr' | 'warmup_stable_decay' | empty for cosine
lr_scheduler_kwargs:
cosine_min_lr_ratio: # decay lr to some percentage of the peak lr, e.g. cosine_min_lr_ratio=0.1 for 10% of peak lr
cosine_constant_lr_ratio: # freeze lr at some percentage of the step, e.g. cosine_constant_lr_ratio=0.8 means start cosine_min_lr at 80% of training step (https://arxiv.org/pdf/2308.04014.pdf)

341
examples/orpheus/README.md Normal file
View File

@@ -0,0 +1,341 @@
# Finetuning LLMs to output audio
In this example, we finetune Orpcanopylabs/orpheus-tts-0.1-pretrained (a LLaMA 3.2 3b model) to output audio.
The `finetune.yml` withe current settings will run on any Nvidia GPU with 45GB VRAM or more. If you adjust the batch size it can easily run on any GPU under 24GB.
## Dataset pre-processing for pre-training
If you are adding another voice in English, please jump ahead to finetuning pre-processing.
For this to work, we need to preprocess our dataset. Since we are expecting to output audio, we will need to add tokens to the tokenizer.
Using this code, it will download the SNAC model and add the correct tokens and upload the final dataset.
```python
import torch
from snac import SNAC
from datasets import load_dataset
from huggingface_hub import snapshot_download
from datasets import load_dataset
import random
import torchaudio.transforms as T
from transformers import AutoTokenizer
import os
my_original_dataset_name = "<huggingface-id-of-dataset-that-we-want-to-preprocess>"
name_to_push_dataset_to = "<huggingface-id-of-where-to-save-dataset>"
dsn = my_original_dataset_name
snapshot_download(
repo_id=dsn,
repo_type="dataset",
revision="main",
max_workers=64,
)
ds = load_dataset(dsn, split="train")
ds_sample_rate = ds[0]["audio"]["sampling_rate"]
model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
model = model.to("mps")
def tokenise_audio(waveform):
waveform = torch.from_numpy(waveform).unsqueeze(0)
waveform = waveform.to(dtype=torch.float32)
resample_transform = T.Resample(orig_freq=ds_sample_rate, new_freq=24000)
waveform = resample_transform(waveform)
waveform = waveform.unsqueeze(0).to("cuda")
#generate the codes from snac
with torch.inference_mode():
codes = model.encode(waveform)
all_codes = []
for i in range(codes[0].shape[1]):
all_codes.append(codes[0][0][i].item()+128266)
all_codes.append(codes[1][0][2*i].item()+128266+4096)
all_codes.append(codes[2][0][4*i].item()+128266+(2*4096))
all_codes.append(codes[2][0][(4*i)+1].item()+128266+(3*4096))
all_codes.append(codes[1][0][(2*i)+1].item()+128266+(4*4096))
all_codes.append(codes[2][0][(4*i)+2].item()+128266+(5*4096))
all_codes.append(codes[2][0][(4*i)+3].item()+128266+(6*4096))
return all_codes
def add_codes(example):
# Always initialize codes_list to None
codes_list = None
try:
answer_audio = example.get("audio")
# If there's a valid audio array, tokenise it
if answer_audio and "array" in answer_audio:
audio_array = answer_audio["array"]
codes_list = tokenise_audio(audio_array)
except Exception as e:
print(f"Skipping row due to error: {e}")
# Keep codes_list as None if we fail
example["codes_list"] = codes_list
return example
ds = ds.map(add_codes, remove_columns=["audio"])
#@title Load Tokenizer
tokeniser_length = 128256
start_of_text = 128000
end_of_text = 128009
start_of_speech = tokeniser_length + 1
end_of_speech = tokeniser_length + 2
start_of_human = tokeniser_length + 3
end_of_human = tokeniser_length + 4
start_of_ai = tokeniser_length + 5
end_of_ai = tokeniser_length + 6
pad_token = tokeniser_length + 7
audio_tokens_start = tokeniser_length + 10
tokenizer_name = "canopylabs/orpheus-3b-0.1-pretrained"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
num_proc = os.cpu_count() - 2
ds = ds.filter(lambda x: x["codes_list"] is not None)
ds = ds.filter(lambda x: len(x["codes_list"]) > 0)
#@title Create Input Ids
def remove_duplicate_frames(example):
vals = example["codes_list"]
if len(vals) % 7 != 0:
raise ValueError("Input list length must be divisible by 7")
result = vals[:7]
removed_frames = 0
for i in range(7, len(vals), 7):
current_first = vals[i]
previous_first = result[-7]
if current_first != previous_first:
result.extend(vals[i:i+7])
else:
removed_frames += 1
example["codes_list"] = result
return example
ds = ds.map(remove_duplicate_frames, num_proc=num_proc)
def create_input_ids(example):
text_ids = tokenizer.encode({example['text']}, add_special_tokens=True)
text_ids.append(end_of_text)
example["text_tokens"] = text_ids
input_ids = (
[start_of_human]
+ example["text_tokens"]
+ [end_of_human]
+ [start_of_ai]
+ [start_of_speech]
+ example["codes_list"]
+ [end_of_speech]
+ [end_of_ai]
)
example["input_ids"] = input_ids
example["labels"] = input_ids
example["attention_mask"] = [1] * len(input_ids)
return example
ds = ds.map(create_input_ids, num_proc=num_proc, remove_columns=["text", "codes_list"])
#@title Remove unnecessary columns
columns_to_keep = ["input_ids", "labels", "attention_mask"]
columns_to_remove = [col for col in ds.column_names if col not in columns_to_keep]
ds = ds.remove_columns(columns_to_remove)
ds.push_to_hub(name_to_push_dataset_to)
```
## Finetune pre-processing
Use this code to add a new voice.
```python
import torch
from snac import SNAC
from datasets import load_dataset
from huggingface_hub import snapshot_download
from datasets import load_dataset
import random
import torchaudio.transforms as T
from transformers import AutoTokenizer
import os
my_original_dataset_name = "<huggingface-id-of-dataset-that-we-want-to-preprocess>"
name_to_push_dataset_to = "<huggingface-id-of-where-to-save-dataset>"
dsn = my_original_dataset_name
snapshot_download(
repo_id=dsn,
repo_type="dataset",
revision="main",
max_workers=64,
)
ds = load_dataset(dsn, split="train")
ds_sample_rate = ds[0]["audio"]["sampling_rate"]
model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz")
model = model.to("mps")
def tokenise_audio(waveform):
waveform = torch.from_numpy(waveform).unsqueeze(0)
waveform = waveform.to(dtype=torch.float32)
resample_transform = T.Resample(orig_freq=ds_sample_rate, new_freq=24000)
waveform = resample_transform(waveform)
waveform = waveform.unsqueeze(0).to("cuda")
#generate the codes from snac
with torch.inference_mode():
codes = model.encode(waveform)
all_codes = []
for i in range(codes[0].shape[1]):
all_codes.append(codes[0][0][i].item()+128266)
all_codes.append(codes[1][0][2*i].item()+128266+4096)
all_codes.append(codes[2][0][4*i].item()+128266+(2*4096))
all_codes.append(codes[2][0][(4*i)+1].item()+128266+(3*4096))
all_codes.append(codes[1][0][(2*i)+1].item()+128266+(4*4096))
all_codes.append(codes[2][0][(4*i)+2].item()+128266+(5*4096))
all_codes.append(codes[2][0][(4*i)+3].item()+128266+(6*4096))
return all_codes
def add_codes(example):
# Always initialize codes_list to None
codes_list = None
try:
answer_audio = example.get("audio")
# If there's a valid audio array, tokenise it
if answer_audio and "array" in answer_audio:
audio_array = answer_audio["array"]
codes_list = tokenise_audio(audio_array)
except Exception as e:
print(f"Skipping row due to error: {e}")
# Keep codes_list as None if we fail
example["codes_list"] = codes_list
return example
ds = ds.map(add_codes, remove_columns=["audio"])
#@title Load Tokenizer
tokeniser_length = 128256
start_of_text = 128000
end_of_text = 128009
start_of_speech = tokeniser_length + 1
end_of_speech = tokeniser_length + 2
start_of_human = tokeniser_length + 3
end_of_human = tokeniser_length + 4
start_of_ai = tokeniser_length + 5
end_of_ai = tokeniser_length + 6
pad_token = tokeniser_length + 7
audio_tokens_start = tokeniser_length + 10
tokenizer_name = "canopylabs/orpheus-3b-0.1-pretrained"
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
num_proc = os.cpu_count() - 2
ds = ds.filter(lambda x: x["codes_list"] is not None)
ds = ds.filter(lambda x: len(x["codes_list"]) > 0)
#@title Create Input Ids
def remove_duplicate_frames(example):
vals = example["codes_list"]
if len(vals) % 7 != 0:
raise ValueError("Input list length must be divisible by 7")
result = vals[:7]
removed_frames = 0
for i in range(7, len(vals), 7):
current_first = vals[i]
previous_first = result[-7]
if current_first != previous_first:
result.extend(vals[i:i+7])
else:
removed_frames += 1
example["codes_list"] = result
return example
ds = ds.map(remove_duplicate_frames, num_proc=num_proc)
tok_info = '''*** HERE you can modify the text prompt
i.e. if you wanted a multispeaker model like canopylabs/orpheus-3b-0.1-ft, you can pass:
f"{example["source"]}: {example["text"]}", as is passed.
'''
print(tok_info)
def create_input_ids(example):
text_ids = tokenizer.encode(f"{example['speaker_id']}: {example['text']}", add_special_tokens=True)
text_ids.append(end_of_text)
example["text_tokens"] = text_ids
input_ids = (
[start_of_human]
+ example["text_tokens"]
+ [end_of_human]
+ [start_of_ai]
+ [start_of_speech]
+ example["codes_list"]
+ [end_of_speech]
+ [end_of_ai]
)
example["input_ids"] = input_ids
example["labels"] = input_ids
example["attention_mask"] = [1] * len(input_ids)
return example
ds = ds.map(create_input_ids, num_proc=num_proc, remove_columns=["text", "codes_list"])
#@title Remove unnecessary columns
columns_to_keep = ["input_ids", "labels", "attention_mask"]
columns_to_remove = [col for col in ds.column_names if col not in columns_to_keep]
ds = ds.remove_columns(columns_to_remove)
ds.push_to_hub(name_to_push_dataset_to)
```
## Training
After preprocessing is done, fill out the blanks in finetune.yml and simply run `axolotl train finetune.yml`
## Inference
For inference, please refer to the original [orpheus github](https://github.com/canopyai/Orpheus-TTS/tree/main).

View File

@@ -0,0 +1,52 @@
base_model: canopylabs/orpheus-3b-0.1-pretrained
hub_model_id: <your-hub-model-id>
plugins:
- axolotl.integrations.liger.LigerPlugin
liger_rope: true
liger_rms_norm: true
liger_glu_activation: true
liger_fused_linear_cross_entropy: true
datasets:
- path: <your-hf-dataset-id>
type: # leave empty to load pre-tokenized
dataset_prepared_path: last_run_prepared
val_set_size: 0.01
output_dir: ./outputs/out
sequence_len: 8192
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 8
micro_batch_size: 4
num_epochs: 3
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 2e-5
bf16: auto
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true
warmup_steps: 20
evals_per_epoch: 5
saves_per_epoch: 5
weight_decay: 0.05
special_tokens:
pad_token: <custom_token_7>

View File

@@ -15,7 +15,7 @@ peft==0.15.2
transformers==4.51.3
tokenizers>=0.21.1
accelerate==1.6.0
datasets==3.5.0
datasets==3.5.1
deepspeed>=0.15.4
trl==0.17.0
hf_xet==1.1.0

View File

@@ -168,6 +168,9 @@ class TrainerBuilderBase(abc.ABC):
)
)
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
@@ -249,9 +252,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.loss_watchdog_threshold is not None:
callbacks.append(LossWatchDogCallback(self.cfg))
if self.cfg.gc_steps:
callbacks.append(GCCallback(gc_steps=self.cfg.gc_steps))
return callbacks
def get_post_trainer_create_callbacks(self, trainer):

View File

@@ -114,8 +114,6 @@ class AxolotlTrainer(
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
group_size=self.args.sample_packing_group_size,
bin_size=self.args.sample_packing_bin_size,
sequential=self.args.sample_packing_sequentially,
drop_last=True,
)

View File

@@ -247,7 +247,9 @@ class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
)
# Base evaluation
initial_output = super().evaluation_loop(
initial_output = super( # pylint: disable=bad-super-call
DPOTrainer, self
).evaluation_loop(
dataloader,
description,
prediction_loss_only,

View File

@@ -72,7 +72,7 @@ class CutCrossEntropyPlugin(BasePlugin):
if cfg.cut_cross_entropy:
self._check_requirements()
from .monkeypatch.patch import (
from axolotl.integrations.cut_cross_entropy.monkeypatch.patch import (
cce_patch,
)

View File

@@ -1,19 +0,0 @@
"""
attention module for attention monkeypatches
"""
from transformers.integrations.flash_attention import flash_attention_forward
def patch_xformers_attn_over_fa2():
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from .xformers import xformers_attention_forward
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = xformers_attention_forward
def unpatch_xformers_attn_over_fa2():
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward()

View File

@@ -1,160 +0,0 @@
"""
xformers attention implementation for packing
"""
from typing import Optional
import torch
import xformers
import xformers.ops.fmha
from transformers.modeling_flash_attention_utils import (
_upad_input,
)
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
xformers_attention = xformers.ops.fmha.memory_efficient_attention
def xformers_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
dropout: float = 0.0, # pylint: disable=unused-argument
scaling: Optional[float] = None, # pylint: disable=unused-argument
sliding_window: Optional[int] = None, # pylint: disable=unused-argument
softcap: Optional[float] = None, # pylint: disable=unused-argument
cu_seq_lens_q: Optional[torch.LongTensor] = None,
cu_seq_lens_k: Optional[torch.LongTensor] = None,
max_length_q: Optional[int] = None,
max_length_k: Optional[int] = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
# Get dimensions
# query: [batch, heads, seq_len, hidden_dim]
batch_size = query.size(0)
query_length = query.shape[2]
key_length = key.shape[2]
# Default causal mask
attn_bias = xformers.ops.LowerTriangularMask()
# Check if we have sliding window attention
has_sliding_window = sliding_window is not None and sliding_window < query_length
# Transpose dimensions for xformers (Q: [b, h, s, d] -> [b, s, h, d])
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# Get GQA parameters
num_attention_heads = module.config.num_attention_heads
num_key_value_heads = module.config.num_key_value_heads
head_dim = query.size(-1)
is_gqa = num_attention_heads != num_key_value_heads
n_groups = num_attention_heads // num_key_value_heads if is_gqa else 1
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
if position_ids is not None and (
max_length_q is not None
or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
):
if cu_seq_lens_q is None or cu_seq_lens_k is None:
cu_seq_lens_q = get_cu_seqlens_from_pos_ids(position_ids)[0]
cu_seq_lens_q = cu_seq_lens_q.squeeze()
seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
attn_bias = (
xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
q_seqlen=seq_lengths.tolist(),
)
)
else:
query = query.reshape(-1, query.size(-2), query.size(-1))
key = key.reshape(-1, key.size(-2), key.size(-1))
value = value.reshape(-1, value.size(-2), value.size(-1))
# Handle GQA
if is_gqa:
key = key.repeat_interleave(n_groups, dim=2)
value = value.repeat_interleave(n_groups, dim=2)
elif attention_mask is not None:
query, key, value, _, cu_seq_lens, _ = _upad_input(
query, key, value, attention_mask, query_length
)
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
seq_lengths = []
for i in range(len(cu_seq_lens_q) - 1):
seq_lengths.append(cu_seq_lens_q[i + 1] - cu_seq_lens_q[i])
attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
q_seqlen=seq_lengths,
kv_seqlen=seq_lengths,
)
# Handle GQA
if is_gqa:
key = key.repeat_interleave(n_groups, dim=2)
value = value.repeat_interleave(n_groups, dim=2)
else:
# Handle Group Query Attention (GQA) using view/expand approach from reference
key = key.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
value = value.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
key = key.expand(
batch_size, key_length, num_key_value_heads, n_groups, head_dim
)
value = value.expand(
batch_size, key_length, num_key_value_heads, n_groups, head_dim
)
if module.training:
key = key.reshape(batch_size, key_length, num_attention_heads, head_dim)
value = value.reshape(batch_size, key_length, num_attention_heads, head_dim)
if has_sliding_window:
query = query.view(
1, batch_size * query_length, num_attention_heads, head_dim
)
key = key.view(
1, batch_size * key_length, num_attention_heads, head_dim
)
value = value.view(
1, batch_size * key_length, num_attention_heads, head_dim
)
else:
query = query.view(
batch_size, query_length, num_key_value_heads, n_groups, head_dim
)
# If we need a sliding window attention
if has_sliding_window:
query = query.view(
1,
batch_size * query_length,
num_key_value_heads,
n_groups,
head_dim,
)
key = key.view(
1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
)
value = value.view(
1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
)
# Run the xformers attention
attn_output = xformers_attention(
query,
key,
value,
attn_bias=attn_bias,
)
attn_output = attn_output.view(
batch_size, -1, attn_output.size(-2), attn_output.size(-1)
)
return attn_output, None

View File

@@ -1,134 +0,0 @@
"""
chunked ce loss
"""
from typing import List, Optional
import torch
import torch.nn.functional as F
# copied and modified from torchtune.modules.loss.CEWithChunkedOutputLoss
class CEWithChunkedOutputLoss(torch.nn.Module):
"""
Cross-entropy with chunked outputs that saves memory by only upcasting one chunk at a time.
For more details, please refer to: https://github.com/pytorch/torchtune/pull/1390
"""
def __init__(self, num_output_chunks: int = 8, ignore_index: int = -100):
super().__init__()
self.num_output_chunks = num_output_chunks
self.ignore_index = ignore_index
def compute_cross_entropy(
self,
logits: torch.Tensor,
labels: torch.Tensor,
normalize: bool = True, # pylint: disable=unused-argument
) -> torch.Tensor:
"""
Upcast logits to fp32 and compute cross entropy loss.
"""
return F.cross_entropy(
logits.float(), labels, ignore_index=self.ignore_index, reduction="sum"
)
def forward(
self, logits: List[torch.Tensor], labels: torch.Tensor, reduction="sum"
) -> torch.Tensor:
"""
Args:
logits (List[torch.Tensor]): List of chunked logits of length
``self.num_output_chunks``, where each chunk has shape
``(batch_size, num_tokens / num_output_chunks, vocab_size)``.
labels (torch.Tensor): Ground truth labels of shape ``(batch_size, num_tokens)``.
reduction (str): The reduction to apply to the output.
Returns:
torch.Tensor: Cross entropy loss of shape (1,).
"""
total_elements = (labels != self.ignore_index).sum()
# chunk and reshape labels (bsz, num_tokens, vocab) -> [(bsz*num_tokens/num_chunks, vocab)]
labels = [
target_chunk.reshape(-1)
for target_chunk in labels.chunk(self.num_output_chunks, dim=1)
]
# reshape logits [(bsz, num_tokens/num_chunks, vocab)] -> [(bsz*num_tokens/num_chunks, vocab)]
logits = [
logit_chunk.reshape(-1, logit_chunk.size(-1)) for logit_chunk in logits
]
# compute one chunk at a time
total_loss = 0.0
for logits_chunk, labels_chunk in zip(logits, labels):
total_loss += self.compute_cross_entropy(logits_chunk, labels_chunk)
if reduction == "sum":
return total_loss
return total_loss / total_elements
def _build_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
loss_fn_ce = CEWithChunkedOutputLoss(num_output_chunks, ignore_index)
loss_fn_ce.compute_cross_entropy = torch.compile(
loss_fn_ce.compute_cross_entropy, backend="inductor"
)
return loss_fn_ce
def get_causal_lm_loss(num_output_chunks: int = 8, ignore_index: int = -100):
loss_fn_ce = _build_chunked_ce_loss_fn(num_output_chunks, ignore_index)
def chunked_fix_cross_entropy(
source,
target,
num_items_in_batch: int = None,
ignore_index: int = -100,
**kwargs,
): # pylint: disable=unused-argument
reduction = "sum" if num_items_in_batch is not None else "mean"
logit_chunks = [ # pylint: disable=unnecessary-comprehension
chunk for chunk in source.chunk(loss_fn_ce.num_output_chunks, dim=1)
]
loss = loss_fn_ce(logit_chunks, target, reduction=reduction)
if reduction == "sum":
loss = loss / num_items_in_batch
return loss
def for_causal_lm_chunked_loss(
logits,
labels,
vocab_size: int = None, # pylint: disable=unused-argument
num_items_in_batch: Optional[int] = None,
ignore_index: int = -100,
shift_labels: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
# skip the upcast to float since we handle that in the chunking loss
if shift_labels is None:
# Shift so that tokens < n predict n
labels = F.pad(labels, (0, 1), value=ignore_index)
shift_labels = labels[..., 1:].contiguous()
# Skip Flattening the tokens
# Enable model parallelism
shift_labels = shift_labels.to(logits.device)
loss = chunked_fix_cross_entropy(
logits, shift_labels, num_items_in_batch, ignore_index, **kwargs
)
return loss
return for_causal_lm_chunked_loss
def patch_chunked_ce_loss_fn(num_output_chunks: int = 8, ignore_index: int = -100):
import transformers.loss.loss_utils
for_causal_lm_chunked_loss = get_causal_lm_loss(num_output_chunks, ignore_index)
transformers.loss.loss_utils.ForCausalLMLoss = for_causal_lm_chunked_loss
transformers.loss.loss_utils.LOSS_MAPPING["ForCausalLM"] = (
for_causal_lm_chunked_loss
)

View File

@@ -1,78 +0,0 @@
"""
Patch prepare_model_for_kbit_training to not upcast everything
"""
import inspect
import logging
import peft
import axolotl
from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger(__name__)
ORIGINAL_PREPARE_CODE = """
for param in model.parameters():
if (
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
) and param.__class__.__name__ != "Params4bit":
param.data = param.data.to(torch.float32)
"""
PATCHED_PREPARE_CODE = """
for name, param in model.named_parameters():
if (
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
) and param.__class__.__name__ != "Params4bit" and "norm" in name:
param.data = param.data.to(torch.float32)
"""
def get_peft_prep_code() -> str:
prepare = inspect.getsource(peft.utils.other.prepare_model_for_kbit_training)
return prepare
def check_peft_prep_code_is_patchable() -> bool:
prep_code = get_peft_prep_code()
prep_code, _ = detab_code(prep_code)
return ORIGINAL_PREPARE_CODE in prep_code
def patch_peft_prep_code():
"""
monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs
"""
try:
prep_code = get_peft_prep_code()
except OSError:
return
peft.utils.other._original_create_accelerator_and_postprocess = ( # pylint: disable=protected-access
prep_code
)
prep_code, _ = detab_code(prep_code)
if ORIGINAL_PREPARE_CODE not in prep_code:
return
prep_code = prep_code.replace(ORIGINAL_PREPARE_CODE, PATCHED_PREPARE_CODE)
prep_code = prep_code.replace(
"def prepare_model_for_kbit_training(",
"def fixed_prepare_model_for_kbit_training(",
1,
)
items_to_import = []
for item in dir(peft.utils.other):
if item in prep_code:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from peft.utils.other import (" + ", ".join(x for x in items_to_import) + ")",
globals(),
)
exec(prep_code, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching prepare_model_for_kbit_training to allow for overrides")
peft.utils.other.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821
axolotl.utils.models.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821

View File

@@ -70,9 +70,6 @@ def resolve_dtype(cfg):
if cfg.fp16 is None and not cfg.float16:
cfg.fp16 = True
if cfg.fp16 and cfg.bf16 == "auto":
cfg.bf16 = False
if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False

View File

@@ -556,30 +556,11 @@ class ModelLoader:
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
def apply_patches(self) -> None:
if self.cfg.xformers_attention and self.cfg.sample_packing:
from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
patch_xformers_attn_over_fa2()
self.cfg.flash_attention = True
if self.cfg.chunked_cross_entropy:
from axolotl.monkeypatch.loss.chunked import patch_chunked_ce_loss_fn
if self.cfg.chunked_cross_entropy_num_chunks:
patch_chunked_ce_loss_fn(self.cfg.chunked_cross_entropy_num_chunks)
else:
patch_chunked_ce_loss_fn()
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
patch_accelerate_fsdp_utils()
if self.cfg.adapter:
from axolotl.monkeypatch.peft.utils import patch_peft_prep_code
patch_peft_prep_code()
if self.cfg.flex_attention:
from axolotl.monkeypatch.attention.flex_attn import (
patch_flex_make_mask,
@@ -912,7 +893,7 @@ class ModelLoader:
"bnb_4bit_compute_dtype": self.cfg.torch_dtype,
"bnb_4bit_use_double_quant": True,
"bnb_4bit_quant_type": "nf4",
"bnb_4bit_quant_storage": torch.uint8,
"bnb_4bit_quant_storage": torch.bfloat16,
}
if self.cfg.model_config_type in ["jamba", "qwen2_moe"] and not (
self.cfg.deepspeed or self.cfg.fsdp
@@ -1199,7 +1180,7 @@ class ModelLoader:
],
)
def prepare_model(self, qlora_fsdp: bool) -> None:
def prepare_model(self, qlora_fsdp) -> None:
skip_prepare_model_for_kbit_training = False
if self.cfg.model_config_type == "qwen" and self.cfg.adapter == "lora":
# Qwen doesn't play nicely with LoRA if this is enabled
@@ -1328,7 +1309,7 @@ class ModelLoader:
# make sure these are fp32 per Ramesh et al. (2021)
embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type)
if self.cfg.fsdp:
if not self.cfg.fsdp:
# FSDP doesn't like mixed Float and BFloat16
self.convert_embedding_modules_dtype(
embedding_modules,

View File

@@ -1,13 +1,10 @@
# pylint: skip-file
"""
Multipack Batch Sampler - An efficient batch sampler for packing variable-length sequences
into fixed-capacity batches to optimize memory usage and training throughput.
Multipack Batch Sampler
"""
import logging
import math
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count
from typing import Iterable, List, Union
from typing import Any, Iterable, List, Union
import numba
import numpy as np
@@ -16,39 +13,26 @@ from torch.utils.data import BatchSampler, Sampler, SequentialSampler
from axolotl.utils.distributed import reduce_and_broadcast
LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)
@numba.njit
def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int):
"""
First-fit-decreasing bin packing algorithm check
def ffd_check(a: np.ndarray, c: int, n: int):
# First-fit-decreasing bin packing
# Check if a[] could fit in n bins with capacity c
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
Checks if sequences with the given lengths could fit in the specified number of bins
Args:
sequence_lengths: Array of sequence lengths
bin_capacity: Maximum capacity of each bin
num_bins: Number of bins available
Returns:
True if all sequences can be packed, False otherwise
"""
# Sort sequence lengths in descending order for optimal packing
sequence_lengths = np.sort(sequence_lengths)[::-1]
# Initialize all bins with full capacity
bins = np.full((num_bins,), bin_capacity, dtype=sequence_lengths.dtype)
# Try to place each sequence in the first bin it fits
for size in sequence_lengths:
a = np.sort(a)[::-1]
bins = np.full((n,), c, dtype=a.dtype)
for size in a:
not_found = True
for idx in range(num_bins):
for idx in range(n):
if bins[idx] >= size:
bins[idx] -= size
not_found = False
break
# If no bin could fit this sequence, packing failed
if not_found:
return False
@@ -56,380 +40,240 @@ def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int):
@numba.njit
def pack_group(
sequence_lengths: np.ndarray,
group_offset: int,
bin_capacity: int,
max_bins: int,
bin_size: int,
safe_mode: bool = True,
):
"""
Pack a group of sequences into bins using First-Fit Decreasing algorithm
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
# First-fit-decreasing bin packing (with result return)
Args:
sequence_lengths: Array of sequence lengths
group_offset: Offset to apply to indices when returning results
bin_capacity: Maximum capacity of each bin
max_bins: Maximum number of bins to use
bin_size: Maximum number of sequences per bin
safe_mode: If True, use a more conservative packing approach
indices = np.argsort(a)[::-1]
a = a[indices]
Returns:
List of bins, where each bin contains indices of sequences assigned to it
"""
# Get sorting indices and sort lengths in descending order
indices = np.argsort(sequence_lengths)[::-1]
sorted_lengths = sequence_lengths[indices]
bins_remaining_space: list = [] # Tracks remaining capacity in each bin
bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin
for seq_id, size in enumerate(sorted_lengths):
global_idx = indices[seq_id] + group_offset
# Try to place sequence in existing bins
add_new_bin = True
for bin_idx, _ in enumerate(bins_remaining_space):
if (
bins_remaining_space[bin_idx] >= size
and len(bins_assigned_sequences[bin_idx]) < bin_size
):
bins_remaining_space[bin_idx] -= size
bins_assigned_sequences[bin_idx].append(global_idx)
add_new_bin = False
bins: List[Any] = []
bins_result: List[Any] = []
for a_id, size in enumerate(a):
add_new = True
for idx in range(len(bins)):
if bins[idx] >= size:
bins[idx] -= size
bins_result[idx].append(indices[a_id] + start_index)
add_new = False
break
# Create a new bin if needed and if we haven't reached the limit
if add_new_bin:
if len(bins_remaining_space) >= max_bins and safe_mode:
# In safe mode, skip items that would exceed max_bins
continue
bins_remaining_space.append(bin_capacity - size)
bins_assigned_sequences.append([global_idx])
if add_new:
bins.append(c - size)
bins_result.append([indices[a_id] + start_index])
# Safety check to avoid infinite bins
if len(bins_remaining_space) > len(sequence_lengths):
break
return bins_assigned_sequences
# Define a standalone function for multiprocessing
def _process_group(args):
group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode = args
return pack_group(
group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode
)
def pack_parallel(
sequence_lengths: np.ndarray,
bin_capacity: int,
group_size: int,
bin_size: int,
num_processes: int | None = None,
safe_mode: bool = True,
):
"""
Pack sequences into bins using parallel processing
Args:
sequence_lengths: Array of sequence lengths
bin_capacity: Maximum capacity of each bin as total number of tokens
group_size: Number of sequences to process in each group
bin_size: Maximum number of bins to use
num_processes: Number of parallel processes to use
safe_mode: If True, use a more conservative packing approach
Returns:
List of bins, where each bin contains indices of sequences assigned to it
"""
num_items = len(sequence_lengths)
if num_processes is None:
num_processes = max(1, min(num_items // group_size, cpu_count()))
# Create tasks for parallel processing
tasks = []
for i in range(0, num_items, group_size):
group_lengths = sequence_lengths[i : i + group_size]
max_bins = len(group_lengths) # Allow as many bins as items in the group
tasks.append((group_lengths, i, bin_capacity, max_bins, bin_size, safe_mode))
# Process groups in parallel
all_bins = []
with ProcessPoolExecutor(max_workers=num_processes) as executor:
for group_bins in executor.map(_process_group, tasks):
all_bins.extend(group_bins)
return all_bins
return bins_result
@numba.njit
def allocate_sequentially(
sequence_lengths: np.ndarray, rank: int, bin_capacity: int, num_ranks: int
def allocate(
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
):
# Dynamic batch allocator, similar to Multifit
# https://en.wikipedia.org/wiki/Multifit_algorithm
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
s = 0
start_index = 0
result = []
while True:
# binary search [l, r)
left = 1
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
while right - left > 1:
mid = (left + right) // 2
if ffd_check(lengths[start_index : start_index + mid], c, n):
left = mid
else:
right = mid
# use length l
batch = ffd_with_result(
lengths[start_index : start_index + left], c, start_index
)
assert len(batch) <= n
if len(batch) < n:
break
start_index += left
s = lengths_cumsum[start_index - 1]
# add local rank
result.append(batch[rank])
return result, s, len(result) * c * n
@numba.njit
def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
"""
Sequential allocator that preserves example order
Parameters:
sequence_lengths: The lengths of all examples
rank: The current rank (for distributed training)
bin_capacity: The capacity of each bin (maximum sequence length)
num_ranks: Number of ranks (processes/GPUs)
- lengths: The lengths of all examples
- rank: The current rank (for distributed training)
- c: The capacity of each bin (maximum sequence length)
- n: Number of ranks
Returns:
rank_batches: List of batches for the current rank
total_tokens_used: Number of actual example tokens
total_token_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
- result: List of batches for the current rank
- total_used: Number of actual example tokens
- total_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
"""
rank_batches = []
total_tokens_used = 0
result = []
total_used = 0
# First, do sequential packing into bins
all_bins = []
current_bin = []
remaining_capacity = bin_capacity
current_bin = [0 for i in range(0)] # numba hint
remaining_capacity = c
# Process each sequence in order
for idx, size in enumerate(sequence_lengths):
for idx, size in enumerate(lengths):
if size <= remaining_capacity:
# Example fits in current bin
current_bin.append(idx)
remaining_capacity -= size
total_tokens_used += size
total_used += size
else:
# Example doesn't fit, start a new bin
if current_bin: # Add non-empty bin to all_bins
all_bins.append(current_bin)
current_bin = [idx]
remaining_capacity = bin_capacity - size
total_tokens_used += size
remaining_capacity = c - size
total_used += size
# Add the last bin if not empty
if current_bin:
all_bins.append(current_bin)
# Assign bins to ranks - each rank gets every num_ranks-th bin
for bin_idx in range(rank, len(all_bins), num_ranks):
rank_batches.append(all_bins[bin_idx])
# Assign bins to ranks - each rank gets every n-th bin
for bin_idx in range(rank, len(all_bins), n):
result.append(all_bins[bin_idx])
return rank_batches, total_tokens_used, len(all_bins) * bin_capacity
return result, total_used, len(all_bins) * c
class MultipackBatchSampler(BatchSampler):
"""
Batch sampler class for efficient packing of variable-length sequences
This sampler packs sequences into fixed-capacity bins (batches) to maximize
GPU memory utilization and training throughput by reducing padding.
It supports both parallel packing (using FFD algorithm) and
sequential packing (preserving original sequence order).
"""
"""Batch sampler class for multipack"""
def __init__(
self,
sampler: Union[Sampler[int], Iterable[int]],
batch_size: int, # Number of bins per batch
batch_max_len: int, # Maximum sequence length (bin capacity)
lengths: np.ndarray, # Sequence lengths
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
drop_last: bool = False, # Whether to drop incomplete batches
num_count_samples: int = 16, # Number of samples to estimate batch count
sequential: bool = False, # Whether to use sequential packing
group_size: int = 100_000, # Size of groups for parallel packing
bin_size: int = 200, # The max number of samples that can be packed in a single bin
num_processes: int | None = None, # Number of processes for parallel packing
safe_mode: bool = True, # Conservative packing to prevent training instability
**kwargs, # pylint: disable=unused-argument
batch_size: int,
batch_max_len: int,
lengths: np.ndarray,
packing_efficiency_estimate: float = 1.0,
drop_last: bool = False,
num_count_samples: int = 16,
sequential: bool = False,
**kwargs,
):
super().__init__(sampler, batch_size, drop_last)
self.batch_size = batch_size
self.batch_max_len = batch_max_len
self.lengths = np.array(lengths, dtype=np.int32)
self.lengths: np.ndarray = lengths
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.sequential = sequential
self.group_size = group_size
self.bin_size = bin_size
self.num_processes = num_processes
self.safe_mode = safe_mode
assert isinstance(self.lengths, np.ndarray)
self.epoch = 0
# Efficiency statistics tracking
self.total_tokens_used = 0
self.total_token_slots = 0
# statistics
self.eff_total_used = 0
self.eff_total_slots = 0
# The number of times to calculate batches to determine minimum packed dataset length
# The number of times to calculate the batches to determine the minimum packed dataset length for the local rank
self.num_count_samples = num_count_samples
# Minimum packed dataset length across all ranks (determined by gather/broadcast)
# the minimum packed dataset length across all ranks determined by a gather/broadcast
self.len_across_ranks = None
# Cache for batches
self._batches = None
if self.sequential and not isinstance(sampler, SequentialSampler):
LOG.warning(
"using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
)
def set_epoch(self, epoch: int):
"""Set the epoch number, used for reproducible shuffling across epochs"""
self.epoch = epoch
self._batches = None # Invalidate batch cache
def generate_batches(self, set_stats=False):
"""
Generate packed batches for training
indices = [idx for idx in self.sampler]
Args:
set_stats: Whether to update efficiency statistics
Returns:
List of batches, where each batch contains multiple bins,
and each bin contains multiple sequence indices
"""
if self._batches is not None:
return self._batches
# Get indices from the sampler
indices = [ # pylint: disable=unnecessary-comprehension
idx for idx in self.sampler
]
# Get lengths of the selected sequences
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
# Pack sequences into bins using either sequential or parallel packing
if self.sequential:
bins, total_used, total_slots = allocate_sequentially(
lengths,
batches, total_used, total_slots = allocate_sequentially(
lengths=lengths,
rank=0,
bin_capacity=self.batch_max_len,
num_ranks=1,
c=self.batch_max_len,
n=1,
)
else:
# Use parallel packing
all_bins = pack_parallel(
lengths,
bin_capacity=self.batch_max_len,
group_size=self.group_size,
bin_size=self.bin_size,
num_processes=self.num_processes,
safe_mode=self.safe_mode,
batches, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=0,
c=self.batch_max_len,
n=1,
)
# Map bin indices back to original indices
bins = [
[indices[b_idx] for b_idx in bin_indices] for bin_indices in all_bins
]
# Calculate efficiency statistics
total_used = lengths.sum()
total_slots = len(all_bins) * self.batch_max_len
# Group bins into batches (each batch contains batch_size bins)
batches = [
bins[i : i + self.batch_size] for i in range(0, len(bins), self.batch_size)
[
[indices[b_idx] for b_idx in batch]
for batch in batches[i : i + self.batch_size]
]
for i in range(0, len(batches), self.batch_size)
]
# Drop last batch if requested and it's incomplete
if self.drop_last and len(batches[-1]) < self.batch_size:
batches = batches[:-1]
# Adjust total_slots if we dropped a batch
if not self.sequential:
total_slots -= (self.batch_size - len(batches[-1])) * self.batch_max_len
# Update statistics if requested
# statistics
if set_stats:
self.total_tokens_used += total_used
self.total_token_slots += total_slots
self.eff_total_used += total_used
self.eff_total_slots += total_slots
self._batches = batches
return batches
def __iter__(self):
"""
Return an iterator over batches
The batches are truncated to match the minimum number of batches across all ranks
to ensure distributed training balance
"""
batches = self.generate_batches(set_stats=True)
if self.len_across_ranks:
# Truncate batches to ensure all ranks have the same number of batches
# make sure the batches we iterate over is truncated to the same min length across all ranks
batches = batches[: self.len_across_ranks]
return iter(batches)
def num_batches(self):
batches = self.generate_batches(set_stats=True)
return len(batches)
def efficiency(self):
"""
Calculate the packing efficiency (ratio of tokens used to total token slots)
Higher is better - 1.0 would mean perfect packing with no wasted space
"""
if self.total_token_slots == 0:
self.generate_batches(set_stats=True)
if self.total_token_slots == 0:
return 0.0
# Return a Python float instead of potentially a numpy float
return float(self.total_tokens_used / self.total_token_slots)
return self.eff_total_used / self.eff_total_slots
def gather_efficiency(self):
"""
Gather and synchronize packing efficiency estimates across all distributed ranks
Returns a conservative efficiency estimate based on the measurements
"""
def calc_sample_packing_eff_est(estimates: List[float]):
LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}")
# Use 99.7% of max observed efficiency as a safe estimate
max_eff = max(float(eff) for eff in estimates)
return math.floor(0.997 * max_eff)
return math.floor(0.997 * max(estimates))
# Gather efficiency from all ranks and apply the calculation function
sample_packing_actual_eff_all = reduce_and_broadcast(
lambda: float(self.efficiency()), # pylint: disable=unnecessary-lambda
lambda: self.efficiency(), # pylint: disable=unnecessary-lambda
calc_sample_packing_eff_est,
)
# Quantize to 0.5% intervals for stability
sample_packing_eff_est = (
math.ceil(sample_packing_actual_eff_all * 200.0) / 200.0
)
return sample_packing_eff_est
def gather_len_batches(self, num):
"""
Gather and synchronize batch counts across all distributed ranks
Returns the minimum number of batches available on any rank
"""
def calc_min_len(estimates: list[(int, float)]):
LOG.info(f"gather_len_batches: {repr(estimates)}")
return math.floor(min(estimates))
# Find minimum batch count across ranks to ensure balance
min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
return min_len_batches
def __len__(self):
"""
Return the total number of batches that will be yielded by this sampler
This is calculated as the minimum number of batches available on any rank
to ensure balanced distributed training
"""
if self._batches is None:
self._batches = self.generate_batches(set_stats=True)
if self.len_across_ranks is None:
# Sample multiple times to get stable estimate
len_batches = min( # pylint: disable=consider-using-generator
[len(self._batches) for _ in range(self.num_count_samples)]
if not self.len_across_ranks:
len_batches = min(
[self.num_batches() for _ in range(self.num_count_samples)]
)
# Gather minimum across all ranks
self.len_across_ranks = self.gather_len_batches(len_batches)
return self.len_across_ranks

View File

@@ -242,9 +242,6 @@ class AxolotlInputConfig(
unsloth_rms_norm: bool | None = None
unsloth_rope: bool | None = None
chunked_cross_entropy: bool | None = None
chunked_cross_entropy_num_chunks: int | None = None
lora_mlp_kernel: bool | None = None
lora_qkv_kernel: bool | None = None
lora_o_kernel: bool | None = None
@@ -438,6 +435,16 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_w_xformers(cls, data):
if data.get("sample_packing") and data.get("xformers_attention"):
raise ValueError(
"sample_packing not compatible with xformers_attention. Use flash_attention"
)
return data
@model_validator(mode="before")
@classmethod
# pylint: disable=duplicate-code

View File

@@ -4,6 +4,7 @@ shared pytest fixtures
import functools
import importlib
import os
import shutil
import sys
import tempfile
@@ -529,31 +530,32 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
# # pylint: disable=redefined-outer-name,unused-argument
# def test_load_fixtures(
# download_smollm2_135m_model,
# download_llama_68m_random_model,
# download_qwen_2_5_half_billion_model,
# download_tatsu_lab_alpaca_dataset,
# download_mhenrichsen_alpaca_2k_dataset,
# download_mhenrichsen_alpaca_2k_w_revision_dataset,
# download_mlabonne_finetome_100k_dataset,
# download_argilla_distilabel_capybara_dpo_7k_binarized_dataset,
# download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset,
# download_fozzie_alpaca_dpo_dataset,
# download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset,
# download_argilla_dpo_pairs_dataset,
# download_tiny_shakespeare_dataset,
# download_deepseek_model_fixture,
# download_huggyllama_model_fixture,
# download_llama_1b_model_fixture,
# download_llama3_8b_model_fixture,
# download_llama3_8b_instruct_model_fixture,
# download_phi_35_mini_model_fixture,
# download_phi_3_medium_model_fixture,
# download_mistral_7b_model_fixture,
# download_gemma_2b_model_fixture,
# download_gemma2_9b_model_fixture,
# download_mlx_mistral_7b_model_fixture,
# download_llama2_model_fixture,
# ):
# pass
@pytest.mark.skipif(
os.environ.get("AXOLOTL_IS_CI_CACHE_PRELOAD", "-1") != "1",
reason="Not running in CI cache preload",
)
def test_load_fixtures(
download_smollm2_135m_model,
download_qwen_2_5_half_billion_model,
download_tatsu_lab_alpaca_dataset,
download_mhenrichsen_alpaca_2k_dataset,
download_mhenrichsen_alpaca_2k_w_revision_dataset,
download_mlabonne_finetome_100k_dataset,
download_argilla_distilabel_capybara_dpo_7k_binarized_dataset,
download_arcee_ai_distilabel_intel_orca_dpo_pairs_dataset,
download_argilla_dpo_pairs_dataset,
download_tiny_shakespeare_dataset,
download_deepseek_model_fixture,
download_huggyllama_model_fixture,
download_llama_1b_model_fixture,
download_llama3_8b_model_fixture,
download_llama3_8b_instruct_model_fixture,
download_phi_35_mini_model_fixture,
download_phi_3_medium_model_fixture,
download_mistral_7b_model_fixture,
download_gemma_2b_model_fixture,
download_gemma2_9b_model_fixture,
download_mlx_mistral_7b_model_fixture,
download_llama2_model_fixture,
):
pass

View File

@@ -1,40 +0,0 @@
"""
test suite for chunked cross entropy
"""
import pytest
import torch
from torch import nn
from axolotl.monkeypatch.loss.chunked import get_causal_lm_loss
@pytest.fixture
def chunked_fixtures():
model_dim = 512
vocab_size = 1024 * 256
seq_len = 2048
batch_size = 1
lm_head = nn.Linear(model_dim, vocab_size)
hidden_state = torch.randn(batch_size, seq_len, model_dim)
labels = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))
return lm_head, hidden_state, labels, vocab_size
def test_chunked_forward(chunked_fixtures): # pylint: disable=redefined-outer-name
lm_head, hidden_state, labels, vocab_size = chunked_fixtures
lm_loss = get_causal_lm_loss()
logits = lm_head(hidden_state)
chunked_lm_loss = lm_loss(logits, labels)
logits_flattened = logits.view(-1, vocab_size)
labels_flattened = labels.view(-1)
loss = nn.functional.cross_entropy(
logits_flattened.float(), labels_flattened, reduction="mean"
)
assert torch.allclose(chunked_lm_loss, loss, atol=1e-2, rtol=1e-2)