Compare commits

...

13 Commits

Author SHA1 Message Date
Wing Lian
5e50d1e8f0 batch flattening with xformers too 2025-05-08 18:23:25 -04:00
Wing Lian
7fb01f0461 also support xformers w/o packing 2025-05-08 15:22:48 -04:00
Wing Lian
0f3587174d swap tinymodels that have safetensors for some ci tests (#2641) 2025-05-07 15:06:07 -04:00
xzuyn
25e6c5f9bd Add CAME Optimizer (#2385) 2025-05-07 10:31:46 -04:00
NanoCode012
32f51bca35 fix(doc): clarify instruction to delinearize llama4 similar to cli doc (#2644) [skip ci] 2025-05-07 10:29:47 -04:00
NanoCode012
9daa04da90 Fix: improve error message on failed dataset load (#2637) [skip ci]
* fix(log): clarify error on dataset loading failed

* fix: add path for easy tracking of broken config

* fix: improve error message based on pr feedback
2025-05-07 10:29:05 -04:00
Wing Lian
0d71b0aa5f Configurable embeddings upcast (#2621)
* fsdp embeddings should be float32 per comment

* patch peft to not upcast everything

* add tabs back to code check

* fix import

* add configurable option and fix check

* add check for dtypes

* move embeddings test to patch dir

* fix test

* fix comment and logic
2025-05-06 23:40:44 -04:00
Eric Meier
63aaccf85b Fix cut_cross_entropy plugin install (#2642) [skip ci] 2025-05-06 22:56:00 -04:00
Wing Lian
ff0fe767c8 xformers attention with packing (#2619)
* xformers attention with packing

* wire up the patch

* fix xformers + packing validation

* fix warning

* reorder the packing check

* fix fp16 / bf16 reset when using fp16 with bf16 auto

* fix seq lens calc to drop hanging sequences

* handle xformers patch for inference too

* fix batch size setter

* fix xformers inference

* add colab callback to fix inference post train

* PR feedback
2025-05-06 22:49:22 -04:00
Wing Lian
8e4158cc0b Multipack parallel bin packing (#2631)
* improve readability of multipack sampler

* parallel bin packing
fix error with lambda and pickling

make sure things are in float instead of np.float

* annotations and comments update

* support for configurable group and bin size for sample packing

* fix missing map back to original indices
2025-05-06 20:08:08 -04:00
Wing Lian
cd84325253 allow plugins to return their own dataset (#2617) [skip ci]
* allow plugins to return their own dataset

* add post_trainer_create and wire up

* add hook check

* address PR feedback:

* remove annotation causing circular import
2025-05-06 20:05:51 -04:00
NanoCode012
0b140fef83 feat(doc): add split_thinking docs (#2613) [skip ci]
* feat(doc): add split_thinking docs

* fix: link config.qmd to conversation.qmd for split_thinking example

* update thinking => reasoning_content in messages format

---------

Co-authored-by: Wing Lian <wing@axolotl.ai>
2025-05-06 20:05:32 -04:00
Wing Lian
e4cfebe995 bump liger dep to 0.5.9 (#2640) [skip ci]
* bump liger dep to 0.5.9

* also upgrade vllm to post1, and datasets to 3.5.1
2025-05-06 20:05:19 -04:00
40 changed files with 1049 additions and 232 deletions

View File

@@ -18,9 +18,96 @@ jobs:
env:
SKIP: no-commit-to-branch
preload-cache:
name: Preload HF cache
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
python_version: ["3.11"]
pytorch_version: ["2.6.0"]
timeout-minutes: 20
env:
AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
steps:
- name: Check out repository code
uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python_version }}
cache: 'pip' # caching pip dependencies
- name: upgrade pip
run: |
pip3 install --upgrade pip
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
- name: Install PyTorch
run: |
pip3 install torch==${{ matrix.pytorch_version }}
- name: Install dependencies
run: |
pip3 show torch
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
pip3 install -r requirements-dev.txt -r requirements-tests.txt
- name: Make sure PyTorch version wasn't clobbered
run: |
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
- name: Ensure axolotl CLI was installed
run: |
axolotl --help
- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Run tests
run: |
pytest -v tests/conftest.py
- name: Upload coverage to Codecov
uses: codecov/codecov-action@v5
with:
token: ${{ secrets.CODECOV_TOKEN }}
files: ./coverage.xml
flags: unittests,pytorch-${{ matrix.pytorch_version }}
fail_ci_if_error: false
- name: cleanup pip cache
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
- name: Save HF cache
id: hf-cache
uses: actions/cache/save@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
pytest:
name: PyTest
runs-on: ubuntu-latest
needs: [preload-cache]
strategy:
fail-fast: false
max-parallel: 2

View File

@@ -32,6 +32,8 @@ tokenizer_legacy:
resize_token_embeddings_to_32x:
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
shrink_embeddings:
# Optional[bool] Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs
embeddings_skip_upcast:
# Whether to load the model with randomly initialized weights. Useful for
# pre-training a model from scratch or debugging purposes.
random_init_weights:
@@ -73,11 +75,12 @@ load_in_8bit: true
load_in_4bit:
# Use CUDA bf16
bf16: true # bool or 'full' for `bf16_full_eval`. require >=ampere
bf16: true # bool or 'full' for `bf16_full_eval`, or 'auto' for automatic detection. require >=ampere
# Use CUDA fp16
fp16: true
# Use CUDA tf32
tf32: true # require >=ampere
# Note: if bf16 is set to 'auto', and fp16 is set to true, we will prefer the explict fp16 setting
# No AMP (automatic mixed precision)
bfloat16: true # require >=ampere
@@ -184,8 +187,8 @@ datasets:
# adding a system turn with empty content.
drop_system_message:
# Optional[bool]. Whether to split the assistant turn based on a reasoning trace inside delimited tags
# defaults to False
# Optional[bool]. (for Qwen3 template only) Whether to split the assistant content based on a reasoning trace inside delimited tags
# See example at `docs/dataset-formats/conversation.qmd`
split_thinking:
# IMPORTANT: The following fields determine which parts of the conversation to train on.
@@ -609,6 +612,7 @@ lr_div_factor: # Learning rate div factor
# - optimi_adamw
# - ao_adamw_8bit
# - ao_adamw_fp8
# - came_pytorch
optimizer:
# Dictionary of arguments to pass to the optimizer
optim_args:

View File

@@ -196,6 +196,34 @@ datasets:
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
:::
8. (For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
```yaml
datasets:
- path: ...
type: chat_template
chat_template: qwen3
split_thinking: true
```
For example, a content can look like:
```json
{
"content": "<think>Some thinking outputs</think>Output after thinking."
}
```
After split, it will look like:
```json
{
"reasoning_content": "Some thinking outputs",
"content": "Output after thinking..."
}
```
## sharegpt
::: {.callout-important}

View File

@@ -34,3 +34,5 @@ We provide a script to delinearize Llama 4 linearized models into regular Huggin
```bash
axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
```
Note: This only works with the non-quantized linearized model. If you have an adapter, merge it with the *non-quantized linearized* model before delinearizing.

View File

@@ -6,16 +6,17 @@ triton>=3.0.0
mamba-ssm==1.2.0.post1
xformers>=0.0.23.post1
autoawq==0.2.7.post3
liger-kernel==0.5.8
liger-kernel==0.5.9
# END section
packaging==23.2
huggingface_hub==0.31.0
peft==0.15.2
transformers==4.51.3
tokenizers>=0.21.1
accelerate==1.6.0
datasets==3.5.0
datasets==3.5.1
deepspeed>=0.15.4
trl==0.17.0
hf_xet==1.1.0

View File

@@ -67,13 +67,13 @@ def parse_requirements(extras_require_map):
if (major, minor) >= (2, 7):
_install_requires.pop(_install_requires.index(xformers_version))
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
extras_require_map["vllm"] = ["vllm==0.8.5"]
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
elif (major, minor) >= (2, 6):
_install_requires.pop(_install_requires.index(xformers_version))
_install_requires.append(
"xformers==0.0.29.post2"
) # vllm needs post2 w torch 2.6
extras_require_map["vllm"] = ["vllm==0.8.5"]
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
elif (major, minor) >= (2, 5):
_install_requires.pop(_install_requires.index(xformers_version))
if patch == 0:
@@ -142,6 +142,7 @@ extras_require = {
"apollo-torch",
"lomo-optim==0.1.1",
"torch-optimi==0.2.1",
"came_pytorch==0.1.3",
],
"ray": [
"ray[train]",

View File

@@ -18,6 +18,7 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
from axolotl.cli.config import load_cfg
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
from axolotl.common.datasets import load_datasets, load_preference_datasets
from axolotl.integrations.base import PluginManager
from axolotl.utils.dict import DictDefault
from axolotl.utils.trainer import disable_datasets_caching
@@ -47,7 +48,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
with disable_datasets_caching():
if cfg.rl:
plugin_manager = PluginManager.get_instance()
if plugin_manager.load_datasets(cfg, preprocess=True):
pass
elif cfg.rl:
load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -43,10 +43,13 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
if int(os.getenv("LOCAL_RANK", "0")) == 0:
check_user_token()
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
plugin_manager = PluginManager.get_instance()
dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
if not dataset_meta:
if cfg.rl:
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)

View File

@@ -21,6 +21,7 @@ import importlib.util
import inspect
import logging
import math
import os
import sys
from abc import abstractmethod
from pathlib import Path
@@ -72,6 +73,7 @@ from axolotl.utils.callbacks import (
SaveBetterTransformerModelCallback,
bench_eval_callback_factory,
causal_lm_bench_eval_callback_factory,
colab_inference_post_train_callback,
log_prediction_callback_factory,
)
from axolotl.utils.callbacks.lisa import lisa_callback_factory
@@ -293,6 +295,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
callbacks.append(lisa_callback_factory(trainer))
if any("COLAB_" in key for key in os.environ):
ColabCallback = colab_inference_post_train_callback(trainer)
callbacks.append(ColabCallback(self.cfg))
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
return callbacks
@@ -702,6 +708,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
optimizer_cls = ADOPT
adam_kwargs["decouple"] = True
optimizer_kwargs.update(adam_kwargs)
elif self.cfg.optimizer == "came_pytorch":
from came_pytorch import CAME
optimizer_cls = CAME
beta1 = training_arguments_kwargs.get("adam_beta1", 0.9)
beta2 = training_arguments_kwargs.get("adam_beta2", 0.999)
beta3 = training_arguments_kwargs.get("adam_beta2", 0.9999)
eps1 = training_arguments_kwargs.get("adam_epsilon", 1e-30)
eps2 = training_arguments_kwargs.get("adam_epsilon2", 1e-16)
adam_kwargs["betas"] = (beta1, beta2, beta3)
adam_kwargs["eps"] = (eps1, eps2)
optimizer_kwargs.update(adam_kwargs)
# Parse any additional optimizer args from config
if self.cfg.optim_args:

View File

@@ -114,6 +114,8 @@ class AxolotlTrainer(
packing_efficiency_estimate=self.args.sample_packing_efficiency,
batch_max_len=batch_max_len,
batch_size=batch_size,
group_size=self.args.sample_packing_group_size,
bin_size=self.args.sample_packing_bin_size,
sequential=self.args.sample_packing_sequentially,
drop_last=True,
)

View File

@@ -26,6 +26,8 @@ from typing import OrderedDict
import torch
from torch.optim.lr_scheduler import LRScheduler
from axolotl.utils.dict import DictDefault
class BasePlugin:
"""
@@ -36,11 +38,13 @@ class BasePlugin:
Methods:
register(cfg): Registers the plugin with the given configuration.
load_datasets(cfg): Loads and preprocesses the dataset for training.
pre_model_load(cfg): Performs actions before the model is loaded.
post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied.
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
post_trainer_create(cfg, trainer): Performs actions after the trainer is created.
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler.
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
@@ -63,20 +67,32 @@ class BasePlugin:
None
"""
def get_input_args(self):
def get_input_args(self) -> str | None:
"""
Returns a pydantic model for the plugin's input arguments.
"""
def load_datasets(self, cfg: DictDefault, preprocess: bool = False):
"""
Loads and preprocesses the dataset for training.
Args:
cfg: The configuration for the plugin.
preprocess: Whether this is the preprocess step of the datasets.
Returns:
dataset_meta: The metadata for the training dataset.
"""
def pre_model_load(self, cfg): # pylint: disable=unused-argument
"""
Performs actions before the model is loaded.
Parameters:
cfg (dict): The configuration for the plugin.
Args:
cfg (dict): The configuration for the plugin.
Returns:
None
None
"""
def post_model_build(self, cfg, model): # pylint: disable=unused-argument
@@ -91,59 +107,71 @@ class BasePlugin:
"""
Performs actions after the model is loaded.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
None
"""
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions before LoRA weights are loaded.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
None
"""
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after LoRA weights are loaded.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
None
None
"""
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument):
"""
Returns a custom class for the trainer.
Parameters:
cfg (dict): The global axolotl configuration.
Args:
cfg (dict): The global axolotl configuration.
Returns:
class: The class for the trainer.
class: The class for the trainer.
"""
def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
"""
Performs actions after the trainer is created.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
None
"""
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
"""
Creates and returns an optimizer for training.
Parameters:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
object: The created optimizer.
object: The created optimizer.
"""
def create_lr_scheduler(
@@ -152,26 +180,26 @@ class BasePlugin:
"""
Creates and returns a learning rate scheduler.
Parameters:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.
num_training_steps (int): Total number of training steps
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
optimizer (object): The optimizer for training.
num_training_steps (int): Total number of training steps
Returns:
object (LRScheduler): The created learning rate scheduler.
object (LRScheduler): The created learning rate scheduler.
"""
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
"""
setup callbacks before creating the trainer.
Parameters:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Args:
cfg (dict): The configuration for the plugin.
model (object): The loaded model.
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
return []
@@ -182,12 +210,12 @@ class BasePlugin:
Adds callbacks to the trainer after creating the trainer.
This is useful for callbacks that require access to the model or trainer.
Parameters:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Args:
cfg (dict): The configuration for the plugin.
trainer (object): The trainer object for training.
Returns:
List[callable]: A list of callback functions to be added
List[callable]: A list of callback functions to be added
"""
return []
@@ -195,23 +223,23 @@ class BasePlugin:
"""
Performs actions after training is complete.
Parameters:
cfg (dict): The axolotl configuration
model (object): The loaded model.
Args:
cfg (dict): The axolotl configuration
model (object): The loaded model.
Returns:
None
None
"""
def post_train_unload(self, cfg): # pylint: disable=unused-argument
"""
Performs actions after training is complete and the model is unloaded.
Parameters:
cfg (dict): The configuration for the plugin.
Args:
cfg (dict): The configuration for the plugin.
Returns:
None
None
"""
@@ -338,6 +366,27 @@ class PluginManager:
input_args.append(input_args_from_plugin)
return input_args
def load_datasets(self, cfg, preprocess: bool = False):
"""
Calls the load_datasets method of each registered plugin.
Args:
cfg: The configuration for the plugins.
preprocess : Whether this is preprocess step of the datasets.
Returns:
dataset_meta: The dataset metadata loaded from all registered plugins.
"""
return_ds_meta = None
for plugin in self.plugins.values():
dataset_meta = plugin.load_datasets(cfg, preprocess)
if dataset_meta is not None:
if return_ds_meta is None:
return_ds_meta = dataset_meta
else:
raise RuntimeError("Multiple plugins loaded datasets")
return return_ds_meta
def pre_model_load(self, cfg):
"""
Calls the pre_model_load method of all registered plugins.
@@ -422,6 +471,20 @@ class PluginManager:
return trainer_cls
return None
def post_trainer_create(self, cfg, trainer):
"""
Calls the post_trainer_create method of all registered plugins.
Parameters:
cfg (dict): The configuration for the plugins.
trainer (object): The trainer object for training.
Returns:
None
"""
for plugin in self.plugins.values():
plugin.post_trainer_create(cfg, trainer)
def create_optimizer(self, trainer):
"""
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.

View File

@@ -0,0 +1,19 @@
"""
attention module for attention monkeypatches
"""
from transformers.integrations.flash_attention import flash_attention_forward
def patch_xformers_attn_over_fa2():
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
from .xformers import xformers_attention_forward
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = xformers_attention_forward
def unpatch_xformers_attn_over_fa2():
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward()

View File

@@ -0,0 +1,160 @@
"""
xformers attention implementation for packing
"""
from typing import Optional
import torch
import xformers
import xformers.ops.fmha
from transformers.modeling_flash_attention_utils import (
_upad_input,
)
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
xformers_attention = xformers.ops.fmha.memory_efficient_attention
def xformers_attention_forward(
module: torch.nn.Module,
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
dropout: float = 0.0, # pylint: disable=unused-argument
scaling: Optional[float] = None, # pylint: disable=unused-argument
sliding_window: Optional[int] = None, # pylint: disable=unused-argument
softcap: Optional[float] = None, # pylint: disable=unused-argument
cu_seq_lens_q: Optional[torch.LongTensor] = None,
cu_seq_lens_k: Optional[torch.LongTensor] = None,
max_length_q: Optional[int] = None,
max_length_k: Optional[int] = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
):
# Get dimensions
# query: [batch, heads, seq_len, hidden_dim]
batch_size = query.size(0)
query_length = query.shape[2]
key_length = key.shape[2]
# Default causal mask
attn_bias = xformers.ops.LowerTriangularMask()
# Check if we have sliding window attention
has_sliding_window = sliding_window is not None and sliding_window < query_length
# Transpose dimensions for xformers (Q: [b, h, s, d] -> [b, s, h, d])
query = query.transpose(1, 2)
key = key.transpose(1, 2)
value = value.transpose(1, 2)
# Get GQA parameters
num_attention_heads = module.config.num_attention_heads
num_key_value_heads = module.config.num_key_value_heads
head_dim = query.size(-1)
is_gqa = num_attention_heads != num_key_value_heads
n_groups = num_attention_heads // num_key_value_heads if is_gqa else 1
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
if position_ids is not None and (
max_length_q is not None
or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
):
if cu_seq_lens_q is None or cu_seq_lens_k is None:
cu_seq_lens_q = get_cu_seqlens_from_pos_ids(position_ids)[0]
cu_seq_lens_q = cu_seq_lens_q.squeeze()
seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
attn_bias = (
xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
q_seqlen=seq_lengths.tolist(),
)
)
else:
query = query.reshape(-1, query.size(-2), query.size(-1))
key = key.reshape(-1, key.size(-2), key.size(-1))
value = value.reshape(-1, value.size(-2), value.size(-1))
# Handle GQA
if is_gqa:
key = key.repeat_interleave(n_groups, dim=2)
value = value.repeat_interleave(n_groups, dim=2)
elif attention_mask is not None:
query, key, value, _, cu_seq_lens, _ = _upad_input(
query, key, value, attention_mask, query_length
)
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
seq_lengths = []
for i in range(len(cu_seq_lens_q) - 1):
seq_lengths.append(cu_seq_lens_q[i + 1] - cu_seq_lens_q[i])
attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
q_seqlen=seq_lengths,
kv_seqlen=seq_lengths,
)
# Handle GQA
if is_gqa:
key = key.repeat_interleave(n_groups, dim=2)
value = value.repeat_interleave(n_groups, dim=2)
else:
# Handle Group Query Attention (GQA) using view/expand approach from reference
key = key.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
value = value.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
key = key.expand(
batch_size, key_length, num_key_value_heads, n_groups, head_dim
)
value = value.expand(
batch_size, key_length, num_key_value_heads, n_groups, head_dim
)
if module.training:
key = key.reshape(batch_size, key_length, num_attention_heads, head_dim)
value = value.reshape(batch_size, key_length, num_attention_heads, head_dim)
if has_sliding_window:
query = query.view(
1, batch_size * query_length, num_attention_heads, head_dim
)
key = key.view(
1, batch_size * key_length, num_attention_heads, head_dim
)
value = value.view(
1, batch_size * key_length, num_attention_heads, head_dim
)
else:
query = query.view(
batch_size, query_length, num_key_value_heads, n_groups, head_dim
)
# If we need a sliding window attention
if has_sliding_window:
query = query.view(
1,
batch_size * query_length,
num_key_value_heads,
n_groups,
head_dim,
)
key = key.view(
1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
)
value = value.view(
1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
)
# Run the xformers attention
attn_output = xformers_attention(
query,
key,
value,
attn_bias=attn_bias,
)
attn_output = attn_output.view(
batch_size, -1, attn_output.size(-2), attn_output.size(-1)
)
return attn_output, None

View File

View File

@@ -0,0 +1,78 @@
"""
Patch prepare_model_for_kbit_training to not upcast everything
"""
import inspect
import logging
import peft
import axolotl
from axolotl.monkeypatch.utils import detab_code
LOG = logging.getLogger(__name__)
ORIGINAL_PREPARE_CODE = """
for param in model.parameters():
if (
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
) and param.__class__.__name__ != "Params4bit":
param.data = param.data.to(torch.float32)
"""
PATCHED_PREPARE_CODE = """
for name, param in model.named_parameters():
if (
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
) and param.__class__.__name__ != "Params4bit" and all(embed_name not in name for embed_name in ["embed_tokens", "lm_head"]):
param.data = param.data.to(torch.float32)
"""
def get_peft_prep_code() -> str:
prepare = inspect.getsource(peft.utils.other.prepare_model_for_kbit_training)
return prepare
def check_peft_prep_code_is_patchable() -> bool:
prep_code = get_peft_prep_code()
prep_code, _ = detab_code(prep_code)
return ORIGINAL_PREPARE_CODE in prep_code
def patch_peft_prep_code():
"""
monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs
"""
try:
prep_code = get_peft_prep_code()
except OSError:
return
peft.utils.other._original_create_accelerator_and_postprocess = ( # pylint: disable=protected-access
prep_code
)
prep_code, _ = detab_code(prep_code)
if ORIGINAL_PREPARE_CODE not in prep_code:
return
prep_code = prep_code.replace(ORIGINAL_PREPARE_CODE, PATCHED_PREPARE_CODE)
prep_code = prep_code.replace(
"def prepare_model_for_kbit_training(",
"def fixed_prepare_model_for_kbit_training(",
1,
)
items_to_import = []
for item in dir(peft.utils.other):
if item in prep_code:
items_to_import.append(item)
exec( # pylint: disable=exec-used # nosec B102
"from peft.utils.other import (" + ", ".join(x for x in items_to_import) + ")",
globals(),
)
exec(prep_code, globals()) # pylint: disable=exec-used # nosec B102
LOG.info("patching prepare_model_for_kbit_training to allow for overrides")
peft.utils.other.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821
axolotl.utils.models.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821

View File

@@ -2,6 +2,7 @@
import importlib
import inspect
import logging
import os
import signal
import sys
@@ -12,7 +13,6 @@ from typing import Any, Dict
import torch
import transformers.modelcard
from accelerate.logging import get_logger
from accelerate.utils import save_fsdp_model
from datasets import Dataset
from huggingface_hub.errors import OfflineModeIsEnabled
@@ -42,7 +42,7 @@ try:
except ImportError:
BetterTransformer = None
LOG = get_logger(__name__)
LOG = logging.getLogger(__name__)
def setup_model_and_tokenizer(
@@ -63,7 +63,6 @@ def setup_model_and_tokenizer(
# Load tokenizer
LOG.debug(
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
main_process_only=True,
)
tokenizer = load_tokenizer(cfg)
@@ -528,6 +527,9 @@ def train(
processor,
) = setup_model_and_trainer(cfg, dataset_meta)
plugin_manager = PluginManager.get_instance()
plugin_manager.post_trainer_create(cfg, trainer)
# Handle untrained tokens if configured
safe_serialization = cfg.save_safetensors is True
train_dataset = dataset_meta.train_dataset
@@ -550,7 +552,6 @@ def train(
if not cfg.use_ray:
cleanup_distributed()
plugin_manager = PluginManager.get_instance()
plugin_manager.post_train(cfg, model)
return model, tokenizer, trainer

View File

@@ -868,3 +868,28 @@ class GCCallback(TrainerCallback):
):
torch.cuda.empty_cache()
gc.collect()
def colab_inference_post_train_callback(trainer: Trainer):
class ColabCallback(TrainerCallback):
"""Callback to prep model for inference on Google Colab"""
def __init__(self, cfg):
self.gpu_name = torch.cuda.get_device_name(0)
self.cfg = cfg
def on_train_end(
self, args, state, control, **kwargs
): # pylint: disable=unused-argument
"""
handle T4 gpu, we need to convert attention to eager for inference
"""
if "Tesla T4" in self.gpu_name and self.cfg.xformers_attention:
trainer.model.config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)
trainer.model.gradient_checkpointing_disable()
trainer.model.config.use_cache = True
trainer.model.eval()
return ColabCallback

View File

@@ -70,6 +70,9 @@ def resolve_dtype(cfg):
if cfg.fp16 is None and not cfg.float16:
cfg.fp16 = True
if cfg.fp16 and cfg.bf16 == "auto":
cfg.bf16 = False
if cfg.device == "mps":
cfg.load_in_8bit = False
cfg.tf32 = False

View File

@@ -281,6 +281,10 @@ def load_dataset_w_config(
**load_ds_kwargs,
)
if not ds:
raise ValueError("unhandled dataset load")
raise ValueError(
"The dataset could not be loaded. This could be due to a misconfigured dataset path "
f"({config_dataset.path}). Try double-check your path / name / data_files. "
"This is not caused by the dataset type."
)
return ds

View File

@@ -1,15 +1,36 @@
"""custom checkpointing utils"""
import importlib
from functools import partial
from packaging import version
from axolotl.utils.gradient_checkpointing.unsloth import (
Unsloth_Offloaded_Gradient_Checkpointer,
)
transformers_version = version.parse(importlib.metadata.version("transformers"))
if transformers_version > version.parse("4.51.3"):
from transformers.modeling_layers import GradientCheckpointingLayer
def uses_gc_layers(decoder_layer):
return isinstance(decoder_layer.func.__self__, GradientCheckpointingLayer)
else:
def uses_gc_layers(_):
return False
def hf_grad_checkpoint_offload_wrapper(
decoder_layer, *args, use_reentrant=None
): # pylint: disable=unused-argument
if uses_gc_layers(decoder_layer):
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
decoder_layer,
*args,
)
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
(
decoder_layer.func.__self__

View File

@@ -556,11 +556,21 @@ class ModelLoader:
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
def apply_patches(self) -> None:
if self.cfg.xformers_attention:
from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
patch_xformers_attn_over_fa2()
self.cfg.flash_attention = True
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
patch_accelerate_fsdp_utils()
if self.cfg.adapter and self.cfg.embeddings_skip_upcast:
from axolotl.monkeypatch.peft.utils import patch_peft_prep_code
patch_peft_prep_code()
if self.cfg.flex_attention:
from axolotl.monkeypatch.attention.flex_attn import (
patch_flex_make_mask,
@@ -761,13 +771,6 @@ class ModelLoader:
cross_entropy=self.cfg.flash_attn_cross_entropy,
rms_norm=self.cfg.flash_attn_rms_norm,
)
elif self.cfg.xformers_attention:
from axolotl.monkeypatch.llama_attn_hijack_xformers import (
hijack_llama_attention,
)
LOG.info("patching with xformers attention")
hijack_llama_attention()
elif self.cfg.sample_packing:
from axolotl.monkeypatch.llama_patch_multipack import (
hijack_llama_prepare_4d_mask,
@@ -1180,7 +1183,7 @@ class ModelLoader:
],
)
def prepare_model(self, qlora_fsdp) -> None:
def prepare_model(self, qlora_fsdp: bool) -> None:
skip_prepare_model_for_kbit_training = False
if self.cfg.model_config_type == "qwen" and self.cfg.adapter == "lora":
# Qwen doesn't play nicely with LoRA if this is enabled
@@ -1310,7 +1313,10 @@ class ModelLoader:
# make sure these are fp32 per Ramesh et al. (2021)
embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type)
if not self.cfg.fsdp:
# FSDP doesn't like mixed Float and BFloat16
# we don't run this during FSDP because this will leave mixed
# float and bfloat16 dtypes in the model which FSDP doesn't like
if self.cfg.load_in_4bit and self.cfg.embeddings_skip_upcast:
embedding_modules = []
self.convert_embedding_modules_dtype(
embedding_modules,
dist_dtype=torch.float32,

View File

@@ -1,10 +1,13 @@
# pylint: skip-file
"""
Multipack Batch Sampler
Multipack Batch Sampler - An efficient batch sampler for packing variable-length sequences
into fixed-capacity batches to optimize memory usage and training throughput.
"""
import logging
import math
from typing import Any, Iterable, List, Union
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count
from typing import Iterable, Union
import numba
import numpy as np
@@ -13,26 +16,39 @@ from torch.utils.data import BatchSampler, Sampler, SequentialSampler
from axolotl.utils.distributed import reduce_and_broadcast
LOG = logging.getLogger(__name__)
LOG.setLevel(logging.INFO)
@numba.njit
def ffd_check(a: np.ndarray, c: int, n: int):
# First-fit-decreasing bin packing
# Check if a[] could fit in n bins with capacity c
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
def ffd_check(sequence_lengths: np.ndarray, bin_capacity: int, num_bins: int):
"""
First-fit-decreasing bin packing algorithm check
a = np.sort(a)[::-1]
bins = np.full((n,), c, dtype=a.dtype)
for size in a:
Checks if sequences with the given lengths could fit in the specified number of bins
Args:
sequence_lengths: Array of sequence lengths
bin_capacity: Maximum capacity of each bin
num_bins: Number of bins available
Returns:
True if all sequences can be packed, False otherwise
"""
# Sort sequence lengths in descending order for optimal packing
sequence_lengths = np.sort(sequence_lengths)[::-1]
# Initialize all bins with full capacity
bins = np.full((num_bins,), bin_capacity, dtype=sequence_lengths.dtype)
# Try to place each sequence in the first bin it fits
for size in sequence_lengths:
not_found = True
for idx in range(n):
for idx in range(num_bins):
if bins[idx] >= size:
bins[idx] -= size
not_found = False
break
# If no bin could fit this sequence, packing failed
if not_found:
return False
@@ -40,86 +56,132 @@ def ffd_check(a: np.ndarray, c: int, n: int):
@numba.njit
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
# First-fit-decreasing bin packing (with result return)
def pack_group(
sequence_lengths: np.ndarray,
group_offset: int,
bin_capacity: int,
max_bins: int,
bin_size: int,
safe_mode: bool = True,
):
"""
Pack a group of sequences into bins using First-Fit Decreasing algorithm
indices = np.argsort(a)[::-1]
a = a[indices]
Args:
sequence_lengths: Array of sequence lengths
group_offset: Offset to apply to indices when returning results
bin_capacity: Maximum capacity of each bin
max_bins: Maximum number of bins to use
bin_size: Maximum number of sequences per bin
safe_mode: If True, use a more conservative packing approach
bins: List[Any] = []
bins_result: List[Any] = []
for a_id, size in enumerate(a):
add_new = True
for idx in range(len(bins)):
if bins[idx] >= size:
bins[idx] -= size
bins_result[idx].append(indices[a_id] + start_index)
add_new = False
Returns:
List of bins, where each bin contains indices of sequences assigned to it
"""
# Get sorting indices and sort lengths in descending order
indices = np.argsort(sequence_lengths)[::-1]
sorted_lengths = sequence_lengths[indices]
bins_remaining_space: list = [] # Tracks remaining capacity in each bin
bins_assigned_sequences: list = [] # Tracks sequence indices assigned to each bin
for seq_id, size in enumerate(sorted_lengths):
global_idx = indices[seq_id] + group_offset
# Try to place sequence in existing bins
add_new_bin = True
for bin_idx, _ in enumerate(bins_remaining_space):
if (
bins_remaining_space[bin_idx] >= size
and len(bins_assigned_sequences[bin_idx]) < bin_size
):
bins_remaining_space[bin_idx] -= size
bins_assigned_sequences[bin_idx].append(global_idx)
add_new_bin = False
break
if add_new:
bins.append(c - size)
bins_result.append([indices[a_id] + start_index])
# Create a new bin if needed and if we haven't reached the limit
if add_new_bin:
if len(bins_remaining_space) >= max_bins and safe_mode:
# In safe mode, skip items that would exceed max_bins
continue
bins_remaining_space.append(bin_capacity - size)
bins_assigned_sequences.append([global_idx])
return bins_result
# Safety check to avoid infinite bins
if len(bins_remaining_space) > len(sequence_lengths):
break
return bins_assigned_sequences
@numba.njit
def allocate(
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
# Define a standalone function for multiprocessing
def _process_group(args):
group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode = args
return pack_group(
group_lengths, start_idx, bin_capacity, max_bins, bin_size, safe_mode
)
def pack_parallel(
sequence_lengths: np.ndarray,
bin_capacity: int,
group_size: int,
bin_size: int,
num_processes: int | None = None,
safe_mode: bool = True,
):
# Dynamic batch allocator, similar to Multifit
# https://en.wikipedia.org/wiki/Multifit_algorithm
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
"""
Pack sequences into bins using parallel processing
s = 0
start_index = 0
result = []
Args:
sequence_lengths: Array of sequence lengths
bin_capacity: Maximum capacity of each bin as total number of tokens
group_size: Number of sequences to process in each group
bin_size: Maximum number of bins to use
num_processes: Number of parallel processes to use
safe_mode: If True, use a more conservative packing approach
while True:
# binary search [l, r)
left = 1
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
Returns:
List of bins, where each bin contains indices of sequences assigned to it
"""
num_items = len(sequence_lengths)
if num_processes is None:
num_processes = max(1, min(num_items // group_size, cpu_count()))
while right - left > 1:
mid = (left + right) // 2
if ffd_check(lengths[start_index : start_index + mid], c, n):
left = mid
else:
right = mid
# Create tasks for parallel processing
tasks = []
for i in range(0, num_items, group_size):
group_lengths = sequence_lengths[i : i + group_size]
max_bins = len(group_lengths) # Allow as many bins as items in the group
tasks.append((group_lengths, i, bin_capacity, max_bins, bin_size, safe_mode))
# use length l
batch = ffd_with_result(
lengths[start_index : start_index + left], c, start_index
)
assert len(batch) <= n
if len(batch) < n:
break
# Process groups in parallel
all_bins = []
with ProcessPoolExecutor(max_workers=num_processes) as executor:
for group_bins in executor.map(_process_group, tasks):
all_bins.extend(group_bins)
start_index += left
s = lengths_cumsum[start_index - 1]
# add local rank
result.append(batch[rank])
return result, s, len(result) * c * n
return all_bins
@numba.njit
def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
def allocate_sequentially(
sequence_lengths: np.ndarray, rank: int, bin_capacity: int, num_ranks: int
):
"""
Sequential allocator that preserves example order
Parameters:
- lengths: The lengths of all examples
- rank: The current rank (for distributed training)
- c: The capacity of each bin (maximum sequence length)
- n: Number of ranks
Args:
sequence_lengths: The lengths of all examples
rank: The current rank (for distributed training)
bin_capacity: The capacity of each bin (maximum sequence length)
num_ranks: Number of ranks (processes/GPUs)
Returns:
- result: List of batches for the current rank
- total_used: Number of actual example tokens
- total_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
rank_batches: List of batches for the current rank
total_tokens_used: Number of actual example tokens
total_token_slots: Maximum theoretical number of example tokens (number of bins * bin capacity)
"""
result = []
total_used = 0
@@ -127,9 +189,9 @@ def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
# First, do sequential packing into bins
all_bins = []
current_bin = [0 for i in range(0)] # numba hint
remaining_capacity = c
remaining_capacity = bin_capacity
for idx, size in enumerate(lengths):
for idx, size in enumerate(sequence_lengths):
if size <= remaining_capacity:
# Example fits in current bin
current_bin.append(idx)
@@ -140,7 +202,7 @@ def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
if current_bin: # Add non-empty bin to all_bins
all_bins.append(current_bin)
current_bin = [idx]
remaining_capacity = c - size
remaining_capacity = bin_capacity - size
total_used += size
# Add the last bin if not empty
@@ -148,132 +210,227 @@ def allocate_sequentially(lengths: np.ndarray, rank: int, c: int, n: int):
all_bins.append(current_bin)
# Assign bins to ranks - each rank gets every n-th bin
for bin_idx in range(rank, len(all_bins), n):
for bin_idx in range(rank, len(all_bins), num_ranks):
result.append(all_bins[bin_idx])
return result, total_used, len(all_bins) * c
return result, total_used, len(all_bins) * bin_capacity
class MultipackBatchSampler(BatchSampler):
"""Batch sampler class for multipack"""
"""
Batch sampler class for efficient packing of variable-length sequences
This sampler packs sequences into fixed-capacity bins (batches) to maximize
GPU memory utilization and training throughput by reducing padding.
It supports both parallel packing (using FFD algorithm) and
sequential packing (preserving original sequence order).
"""
def __init__(
self,
sampler: Union[Sampler[int], Iterable[int]],
batch_size: int,
batch_max_len: int,
lengths: np.ndarray,
packing_efficiency_estimate: float = 1.0,
drop_last: bool = False,
num_count_samples: int = 16,
sequential: bool = False,
**kwargs,
batch_size: int, # Number of bins per batch
batch_max_len: int, # Maximum sequence length (bin capacity)
lengths: np.ndarray, # Sequence lengths
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
drop_last: bool = False, # Whether to drop final batches (might be incomplete)
num_count_samples: int = 16, # Number of times to estimate batch count
sequential: bool = False, # Whether to use sequential packing
group_size: int = 100_000, # Size of groups for parallel packing
bin_size: int = 200, # The max number of samples that can be packed in a single bin
num_processes: int | None = None, # Number of processes for parallel packing
safe_mode: bool = True, # Conservative packing to prevent training instability
**kwargs, # pylint: disable=unused-argument
):
super().__init__(sampler, batch_size, drop_last)
self.batch_size = batch_size
self.batch_max_len = batch_max_len
self.lengths: np.ndarray = lengths
self.lengths = np.array(lengths, dtype=np.int32)
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.sequential = sequential
self.group_size = group_size
self.bin_size = bin_size
self.num_processes = num_processes
self.safe_mode = safe_mode
assert isinstance(self.lengths, np.ndarray)
self.epoch = 0
# statistics
self.eff_total_used = 0
self.eff_total_slots = 0
# Efficiency statistics tracking
self.total_tokens_used = 0
self.total_token_slots = 0
# The number of times to calculate the batches to determine the minimum packed dataset length for the local rank
# The number of times to calculate batches to determine minimum packed dataset length
self.num_count_samples = num_count_samples
# the minimum packed dataset length across all ranks determined by a gather/broadcast
# Minimum packed dataset length across all ranks (determined by gather/broadcast)
self.len_across_ranks = None
# Cache for batches
self._batches = None
if self.sequential and not isinstance(sampler, SequentialSampler):
LOG.warning(
"using sequential sample packing with non-sequential sampler, did you want to also enable curriculum_sampling?"
)
def set_epoch(self, epoch: int):
"""Set the epoch number, used for reproducible shuffling across epochs"""
self.epoch = epoch
self._batches = None # Invalidate batch cache
def generate_batches(self, set_stats=False):
indices = [idx for idx in self.sampler]
"""
Generate packed batches for training
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
Args:
set_stats: Whether to update efficiency statistics
if self.sequential:
batches, total_used, total_slots = allocate_sequentially(
lengths=lengths,
rank=0,
c=self.batch_max_len,
n=1,
)
else:
batches, total_used, total_slots = allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=0,
c=self.batch_max_len,
n=1,
)
Returns:
List of batches, where each batch contains multiple bins,
and each bin contains multiple sequence indices
"""
if self._batches is not None:
return self._batches
batches = [
[
[indices[b_idx] for b_idx in batch]
for batch in batches[i : i + self.batch_size]
]
for i in range(0, len(batches), self.batch_size)
# Get indices from the sampler
indices = [ # pylint: disable=unnecessary-comprehension
idx for idx in self.sampler
]
# statistics
if set_stats:
self.eff_total_used += total_used
self.eff_total_slots += total_slots
# Get lengths of the selected sequences
lengths = self.lengths[indices]
# Pack sequences into bins using either sequential or parallel packing
if self.sequential:
bins, total_used, total_slots = allocate_sequentially(
lengths,
rank=0,
bin_capacity=self.batch_max_len,
num_ranks=1,
)
# Map bin indices back to original indices
bins = [[indices[b_idx] for b_idx in bin_indices] for bin_indices in bins]
else:
# Use parallel packing
all_bins = pack_parallel(
lengths,
bin_capacity=self.batch_max_len,
group_size=self.group_size,
bin_size=self.bin_size,
num_processes=self.num_processes,
safe_mode=self.safe_mode,
)
# Map bin indices back to original indices
bins = [
[indices[b_idx] for b_idx in bin_indices] for bin_indices in all_bins
]
# Calculate efficiency statistics
total_used = lengths.sum()
total_slots = len(all_bins) * self.batch_max_len
# Group bins into batches (each batch contains batch_size bins)
batches = [
bins[i : i + self.batch_size] for i in range(0, len(bins), self.batch_size)
]
# Drop last batch if requested and it's incomplete
if self.drop_last and len(batches[-1]) < self.batch_size:
batches = batches[:-1]
# Adjust total_slots if we dropped a batch
if not self.sequential:
total_slots -= (self.batch_size - len(batches[-1])) * self.batch_max_len
# Update statistics if requested
if set_stats:
self.total_tokens_used += total_used
self.total_token_slots += total_slots
self._batches = batches
return batches
def __iter__(self):
"""
Return an iterator over batches
The batches are truncated to match the minimum number of batches across all ranks
to ensure distributed training balance
"""
batches = self.generate_batches(set_stats=True)
if self.len_across_ranks:
# make sure the batches we iterate over is truncated to the same min length across all ranks
# Truncate batches to ensure all ranks have the same number of batches
batches = batches[: self.len_across_ranks]
return iter(batches)
def num_batches(self):
batches = self.generate_batches(set_stats=True)
return len(batches)
def efficiency(self):
return self.eff_total_used / self.eff_total_slots
"""
Calculate the packing efficiency (ratio of tokens used to total token slots)
Higher is better - 1.0 would mean perfect packing with no wasted space
"""
if self.total_token_slots == 0:
self.generate_batches(set_stats=True)
if self.total_token_slots == 0:
return 0.0
# Return a Python float instead of potentially a numpy float
return float(self.total_tokens_used / self.total_token_slots)
def gather_efficiency(self):
def calc_sample_packing_eff_est(estimates: List[float]):
LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}")
return math.floor(0.997 * max(estimates))
"""
Gather and synchronize packing efficiency estimates across all distributed ranks
Returns a conservative efficiency estimate based on the measurements
"""
def calc_sample_packing_eff_est(estimates: list[float]):
LOG.debug(f"sample_packing_eff_est across ranks: {repr(estimates)}")
# Use 99.7% of max observed efficiency as a safe estimate
max_eff = max(float(eff) for eff in estimates)
return math.floor(0.997 * max_eff)
# Gather efficiency from all ranks and apply the calculation function
sample_packing_actual_eff_all = reduce_and_broadcast(
lambda: self.efficiency(), # pylint: disable=unnecessary-lambda
lambda: float(self.efficiency()), # pylint: disable=unnecessary-lambda
calc_sample_packing_eff_est,
)
# Quantize to 0.5% intervals for stability
sample_packing_eff_est = (
math.ceil(sample_packing_actual_eff_all * 200.0) / 200.0
)
return sample_packing_eff_est
def gather_len_batches(self, num):
"""
Gather and synchronize batch counts across all distributed ranks
Returns the minimum number of batches available on any rank
"""
def calc_min_len(estimates: list[(int, float)]):
LOG.info(f"gather_len_batches: {repr(estimates)}")
return math.floor(min(estimates))
# Find minimum batch count across ranks to ensure balance
min_len_batches = reduce_and_broadcast(lambda: num, calc_min_len)
return min_len_batches
def __len__(self):
if not self.len_across_ranks:
len_batches = min(
[self.num_batches() for _ in range(self.num_count_samples)]
"""
Return the total number of batches that will be yielded by this sampler
This is calculated as the minimum number of batches available on any rank
to ensure balanced distributed training
"""
if self._batches is None:
self._batches = self.generate_batches(set_stats=True)
if self.len_across_ranks is None:
# Sample multiple times to get stable estimate
len_batches = min( # pylint: disable=consider-using-generator
[len(self._batches) for _ in range(self.num_count_samples)]
)
# Gather minimum across all ranks
self.len_across_ranks = self.gather_len_batches(len_batches)
return self.len_across_ranks

View File

@@ -82,6 +82,7 @@ class AxolotlInputConfig(
mean_resizing_embeddings: bool | None = False
# optionally shrink the embeddings when the tokenizer vocab size is smaller
shrink_embeddings: bool | None = None
embeddings_skip_upcast: bool | None = None
rl: RLType | None = None
trl: TRLConfig | None = Field(
@@ -435,16 +436,6 @@ class AxolotlInputConfig(
)
return data
@model_validator(mode="before")
@classmethod
def check_sample_packing_w_xformers(cls, data):
if data.get("sample_packing") and data.get("xformers_attention"):
raise ValueError(
"sample_packing not compatible with xformers_attention. Use flash_attention"
)
return data
@model_validator(mode="before")
@classmethod
# pylint: disable=duplicate-code
@@ -471,9 +462,10 @@ class AxolotlInputConfig(
and not data.get("flash_attention")
and not data.get("sdp_attention")
and not data.get("flex_attention")
and not data.get("xformers_attention")
):
LOG.warning(
"sample_packing without flash, sdp or flex attention does not handle cross sample decontamination."
"sample_packing without flash, sdp, xformers or flex attention does not handle cross sample decontamination."
)
return data
@@ -483,8 +475,14 @@ class AxolotlInputConfig(
def check_batch_flattening_fa(cls, data):
if data.get("batch_flattening"):
batch_flattening_auto = data.get("batch_flattening") == "auto"
if not data.get("flash_attention") and not batch_flattening_auto:
raise ValueError("batch_flattening requires flash attention")
if (
not data.get("flash_attention")
and not data.get("xformers_attention")
and not batch_flattening_auto
):
raise ValueError(
"batch_flattening requires flash attention or xformers"
)
if data.get("sample_packing") and not batch_flattening_auto:
raise ValueError("batch_flattening not compatible with sample_packing")
if data.get("micro_batch_size") == 1 and not batch_flattening_auto:

View File

@@ -53,4 +53,5 @@ class CustomSupportedOptimizers(str, Enum):
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
came_pytorch = "came_pytorch" # pylint: disable=invalid-name
muon = "muon" # pylint: disable=invalid-name

View File

@@ -41,6 +41,7 @@ class WandbConfig(BaseModel):
use_wandb: bool | None = None
wandb_name: str | None = None
wandb_run_id: str | None = None
wandb_run_group: str | None = None
wandb_mode: str | None = None
wandb_project: str | None = None
wandb_entity: str | None = None

View File

@@ -75,8 +75,10 @@ class HyperparametersConfig(BaseModel):
lr_groups: list[LrGroup] | None = None
adam_epsilon: float | None = None
adam_epsilon2: float | None = None
adam_beta1: float | None = None
adam_beta2: float | None = None
adam_beta3: float | None = None
max_grad_norm: float | None = None
num_epochs: float = Field(default=1.0)

View File

@@ -29,6 +29,12 @@ class LogHooksPlugin(BasePlugin):
except FileNotFoundError:
pass
def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
) as f:
f.write("post_trainer_create\n")
def pre_model_load(self, cfg): # pylint: disable=unused-argument
with open(
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
@@ -165,6 +171,7 @@ class TestPluginHooks:
) as f:
file_contents = f.readlines()
file_contents = "\n".join(file_contents)
assert "post_trainer_create" in file_contents
assert "pre_model_load" in file_contents
assert "post_model_build" in file_contents
assert "pre_lora_load" in file_contents

View File

@@ -479,7 +479,7 @@ class TestMultiGPULlama:
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"val_set_size": 0.05,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},

View File

@@ -29,12 +29,12 @@ from axolotl.utils.dict import DictDefault
MODEL_CONFIGS = [
{
"name": "openaccess-ai-collective/tiny-mistral",
"name": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"expected_activation": apply_lora_mlp_swiglu,
"dtype": torch.float16,
},
{
"name": "Qwen/Qwen2-7B",
"name": "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
"expected_activation": apply_lora_mlp_swiglu,
"dtype": torch.float16,
},
@@ -44,7 +44,7 @@ MODEL_CONFIGS = [
"dtype": torch.float32,
},
{
"name": "mhenrichsen/gemma-2b",
"name": "trl-internal-testing/tiny-Gemma2ForCausalLM",
"expected_activation": apply_lora_mlp_geglu,
"dtype": torch.float16,
},
@@ -156,7 +156,9 @@ def test_swiglu_mlp_integration(small_llama_model):
def test_geglu_model_integration():
"""Test GeGLU activation with Gemma model."""
model = AutoModelForCausalLM.from_pretrained(
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda:0"
"trl-internal-testing/tiny-Gemma2ForCausalLM",
torch_dtype=torch.float16,
device_map="cuda:0",
)
peft_config = get_peft_config(
{

View File

@@ -6,6 +6,8 @@ import logging
import os
import unittest
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
@@ -23,6 +25,7 @@ class TestFalconPatched(unittest.TestCase):
Test case for Falcon models
"""
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_qlora(self, temp_dir):
# pylint: disable=duplicate-code
@@ -71,6 +74,7 @@ class TestFalconPatched(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_ft(self, temp_dir):
# pylint: disable=duplicate-code

View File

@@ -28,7 +28,7 @@ class TestMistral(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 1024,
@@ -76,7 +76,7 @@ class TestMistral(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 1024,

View File

@@ -56,7 +56,7 @@ class TestModelPatches(unittest.TestCase):
def test_mistral_multipack(self, temp_dir):
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sample_packing": True,
"sequence_len": 2048,

View File

@@ -0,0 +1,63 @@
"""
Test case for handling embeddings when using peft
"""
import torch
from axolotl.train import setup_model_and_tokenizer
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
class TestLlamaPeftEmbeddings:
"""
test class for handling embeddings when using peft
"""
def test_peft_embeddings_upcast(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_target_linear": True,
"trust_remote_code": True,
"sequence_len": 512,
"val_set_size": 0.01,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"max_steps": 2,
"micro_batch_size": 1,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
"sample_packing": False,
"bf16": "auto",
"save_safetensors": True,
"embeddings_skip_upcast": True,
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
model, _, _, _ = setup_model_and_tokenizer(cfg)
# Check if the embeddings are upcast correctly
# only embed_tokens is a parameter that may be upcast
assert model.base_model.model.model.embed_tokens.weight.dtype == torch.bfloat16
assert model.base_model.model.lm_head.weight.dtype == torch.bfloat16

View File

@@ -15,7 +15,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
from ..utils import check_model_output_exists, most_recent_subdir
from ..utils import check_model_output_exists, most_recent_subdir, require_torch_2_6_0
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -26,6 +26,7 @@ class TestResumeLlama:
Test case for resuming training of llama models
"""
@require_torch_2_6_0
def test_resume_lora_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
@@ -62,6 +63,7 @@ class TestResumeLlama:
"save_total_limit": 5,
"max_steps": 15,
"use_tensorboard": True,
"save_safetensors": True,
}
)
if is_torch_bf16_gpu_available():

View File

@@ -19,14 +19,11 @@ class TestE2eEvaluate:
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"base_model": "HuggingFaceTB/SmolLM2-135M",
"sequence_len": 1024,
"val_set_size": 0.02,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
"pad_token": "<|endoftext|>",
},
"datasets": [
{

View File

@@ -6,6 +6,8 @@ import logging
import os
import unittest
import pytest
from axolotl.cli.args import TrainerCliArgs
from axolotl.common.datasets import load_datasets
from axolotl.train import train
@@ -23,6 +25,7 @@ class TestFalcon(unittest.TestCase):
Test case for falcon
"""
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_lora(self, temp_dir):
# pylint: disable=duplicate-code
@@ -74,6 +77,7 @@ class TestFalcon(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_lora_added_vocab(self, temp_dir):
# pylint: disable=duplicate-code
@@ -129,6 +133,7 @@ class TestFalcon(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
@with_temp_dir
def test_ft(self, temp_dir):
# pylint: disable=duplicate-code

View File

@@ -30,7 +30,7 @@ class TestMistral(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sequence_len": 1024,
"load_in_8bit": True,
@@ -77,7 +77,7 @@ class TestMistral(unittest.TestCase):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "openaccess-ai-collective/tiny-mistral",
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
"flash_attention": True,
"sequence_len": 1024,
"val_set_size": 0.02,

View File

@@ -199,3 +199,50 @@ class TestCustomOptimizers(unittest.TestCase):
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)
@with_temp_dir
def test_came_pytorch(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "came_pytorch",
"adam_beta3": 0.9999,
"adam_epsilon2": 1e-16,
"max_steps": 5,
"lr_scheduler": "cosine",
}
)
cfg = validate_config(cfg)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, dataset_meta=dataset_meta)
check_model_output_exists(temp_dir, cfg)

View File

@@ -414,7 +414,6 @@ class TestDatasetPreparation:
snapshot_path = snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
)
shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)