Compare commits
12 Commits
datasets-3
...
revert-mul
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e910e3e164 | ||
|
|
0f3587174d | ||
|
|
25e6c5f9bd | ||
|
|
32f51bca35 | ||
|
|
9daa04da90 | ||
|
|
0d71b0aa5f | ||
|
|
63aaccf85b | ||
|
|
ff0fe767c8 | ||
|
|
8e4158cc0b | ||
|
|
cd84325253 | ||
|
|
0b140fef83 | ||
|
|
e4cfebe995 |
87
.github/workflows/tests-nightly.yml
vendored
87
.github/workflows/tests-nightly.yml
vendored
@@ -18,9 +18,96 @@ jobs:
|
||||
env:
|
||||
SKIP: no-commit-to-branch
|
||||
|
||||
preload-cache:
|
||||
name: Preload HF cache
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
python_version: ["3.11"]
|
||||
pytorch_version: ["2.6.0"]
|
||||
timeout-minutes: 20
|
||||
|
||||
env:
|
||||
AXOLOTL_IS_CI_CACHE_PRELOAD: "1"
|
||||
|
||||
steps:
|
||||
- name: Check out repository code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Restore HF cache
|
||||
id: hf-cache-restore
|
||||
uses: actions/cache/restore@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||
|
||||
- name: Setup Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: ${{ matrix.python_version }}
|
||||
cache: 'pip' # caching pip dependencies
|
||||
|
||||
- name: upgrade pip
|
||||
run: |
|
||||
pip3 install --upgrade pip
|
||||
pip3 install --upgrade packaging==23.2 setuptools==75.8.0 wheel
|
||||
|
||||
- name: Install PyTorch
|
||||
run: |
|
||||
pip3 install torch==${{ matrix.pytorch_version }}
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip3 show torch
|
||||
pip3 install --no-build-isolation -U -e .
|
||||
python scripts/unsloth_install.py | sh
|
||||
python scripts/cutcrossentropy_install.py | sh
|
||||
pip3 install -r requirements-dev.txt -r requirements-tests.txt
|
||||
|
||||
- name: Make sure PyTorch version wasn't clobbered
|
||||
run: |
|
||||
python -c "import torch; assert '${{ matrix.pytorch_version }}' in torch.__version__"
|
||||
|
||||
- name: Ensure axolotl CLI was installed
|
||||
run: |
|
||||
axolotl --help
|
||||
|
||||
- name: Pre-Download dataset fixture
|
||||
run: |
|
||||
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
pytest -v tests/conftest.py
|
||||
|
||||
- name: Upload coverage to Codecov
|
||||
uses: codecov/codecov-action@v5
|
||||
with:
|
||||
token: ${{ secrets.CODECOV_TOKEN }}
|
||||
files: ./coverage.xml
|
||||
flags: unittests,pytorch-${{ matrix.pytorch_version }}
|
||||
fail_ci_if_error: false
|
||||
|
||||
- name: cleanup pip cache
|
||||
run: |
|
||||
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
|
||||
|
||||
- name: Save HF cache
|
||||
id: hf-cache
|
||||
uses: actions/cache/save@v4
|
||||
with:
|
||||
path: |
|
||||
/home/runner/.cache/huggingface/hub/datasets--*
|
||||
/home/runner/.cache/huggingface/hub/models--*
|
||||
key: ${{ steps.hf-cache-restore.outputs.cache-primary-key }}
|
||||
|
||||
pytest:
|
||||
name: PyTest
|
||||
runs-on: ubuntu-latest
|
||||
needs: [preload-cache]
|
||||
strategy:
|
||||
fail-fast: false
|
||||
max-parallel: 2
|
||||
|
||||
@@ -32,6 +32,8 @@ tokenizer_legacy:
|
||||
resize_token_embeddings_to_32x:
|
||||
# Optional[bool] Whether to shrink the embeddings to len(tokenizer). By default, we won't shrink.
|
||||
shrink_embeddings:
|
||||
# Optional[bool] Don't upcast the embeddings to float32 when using PEFT. Useful for low-VRAM GPUs
|
||||
embeddings_skip_upcast:
|
||||
# Whether to load the model with randomly initialized weights. Useful for
|
||||
# pre-training a model from scratch or debugging purposes.
|
||||
random_init_weights:
|
||||
@@ -73,11 +75,12 @@ load_in_8bit: true
|
||||
load_in_4bit:
|
||||
|
||||
# Use CUDA bf16
|
||||
bf16: true # bool or 'full' for `bf16_full_eval`. require >=ampere
|
||||
bf16: true # bool or 'full' for `bf16_full_eval`, or 'auto' for automatic detection. require >=ampere
|
||||
# Use CUDA fp16
|
||||
fp16: true
|
||||
# Use CUDA tf32
|
||||
tf32: true # require >=ampere
|
||||
# Note: if bf16 is set to 'auto', and fp16 is set to true, we will prefer the explict fp16 setting
|
||||
|
||||
# No AMP (automatic mixed precision)
|
||||
bfloat16: true # require >=ampere
|
||||
@@ -184,8 +187,8 @@ datasets:
|
||||
# adding a system turn with empty content.
|
||||
drop_system_message:
|
||||
|
||||
# Optional[bool]. Whether to split the assistant turn based on a reasoning trace inside delimited tags
|
||||
# defaults to False
|
||||
# Optional[bool]. (for Qwen3 template only) Whether to split the assistant content based on a reasoning trace inside delimited tags
|
||||
# See example at `docs/dataset-formats/conversation.qmd`
|
||||
split_thinking:
|
||||
|
||||
# IMPORTANT: The following fields determine which parts of the conversation to train on.
|
||||
@@ -609,6 +612,7 @@ lr_div_factor: # Learning rate div factor
|
||||
# - optimi_adamw
|
||||
# - ao_adamw_8bit
|
||||
# - ao_adamw_fp8
|
||||
# - came_pytorch
|
||||
optimizer:
|
||||
# Dictionary of arguments to pass to the optimizer
|
||||
optim_args:
|
||||
|
||||
@@ -196,6 +196,34 @@ datasets:
|
||||
It is not necessary to set both `message_field_training` and `message_field_training_detail` at once.
|
||||
:::
|
||||
|
||||
8. (For Qwen3 template only) Enable reasoning split, where the reasoning is split from the content and passed as a separate field into the template.
|
||||
|
||||
```yaml
|
||||
datasets:
|
||||
- path: ...
|
||||
type: chat_template
|
||||
chat_template: qwen3
|
||||
split_thinking: true
|
||||
```
|
||||
|
||||
For example, a content can look like:
|
||||
|
||||
```json
|
||||
{
|
||||
"content": "<think>Some thinking outputs</think>Output after thinking."
|
||||
}
|
||||
```
|
||||
|
||||
After split, it will look like:
|
||||
|
||||
```json
|
||||
{
|
||||
"reasoning_content": "Some thinking outputs",
|
||||
"content": "Output after thinking..."
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
## sharegpt
|
||||
|
||||
::: {.callout-important}
|
||||
|
||||
@@ -34,3 +34,5 @@ We provide a script to delinearize Llama 4 linearized models into regular Huggin
|
||||
```bash
|
||||
axolotl delinearize-llama4 --model path/to/model_dir --output path/to/output_dir
|
||||
```
|
||||
|
||||
Note: This only works with the non-quantized linearized model. If you have an adapter, merge it with the *non-quantized linearized* model before delinearizing.
|
||||
|
||||
@@ -6,11 +6,12 @@ triton>=3.0.0
|
||||
mamba-ssm==1.2.0.post1
|
||||
xformers>=0.0.23.post1
|
||||
autoawq==0.2.7.post3
|
||||
liger-kernel==0.5.8
|
||||
liger-kernel==0.5.9
|
||||
# END section
|
||||
|
||||
packaging==23.2
|
||||
|
||||
huggingface_hub==0.31.0
|
||||
peft==0.15.2
|
||||
transformers==4.51.3
|
||||
tokenizers>=0.21.1
|
||||
|
||||
5
setup.py
5
setup.py
@@ -67,13 +67,13 @@ def parse_requirements(extras_require_map):
|
||||
if (major, minor) >= (2, 7):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
# _install_requires.append("xformers==0.0.29.post3") # xformers seems to be hard pinned to 2.6.0
|
||||
extras_require_map["vllm"] = ["vllm==0.8.5"]
|
||||
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
|
||||
elif (major, minor) >= (2, 6):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
_install_requires.append(
|
||||
"xformers==0.0.29.post2"
|
||||
) # vllm needs post2 w torch 2.6
|
||||
extras_require_map["vllm"] = ["vllm==0.8.5"]
|
||||
extras_require_map["vllm"] = ["vllm==0.8.5.post1"]
|
||||
elif (major, minor) >= (2, 5):
|
||||
_install_requires.pop(_install_requires.index(xformers_version))
|
||||
if patch == 0:
|
||||
@@ -142,6 +142,7 @@ extras_require = {
|
||||
"apollo-torch",
|
||||
"lomo-optim==0.1.1",
|
||||
"torch-optimi==0.2.1",
|
||||
"came_pytorch==0.1.3",
|
||||
],
|
||||
"ray": [
|
||||
"ray[train]",
|
||||
|
||||
@@ -18,6 +18,7 @@ from axolotl.cli.checks import check_accelerate_default_config, check_user_token
|
||||
from axolotl.cli.config import load_cfg
|
||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||
from axolotl.common.datasets import load_datasets, load_preference_datasets
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.trainer import disable_datasets_caching
|
||||
|
||||
@@ -47,7 +48,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
||||
cfg.dataset_prepared_path = DEFAULT_DATASET_PREPARED_PATH
|
||||
|
||||
with disable_datasets_caching():
|
||||
if cfg.rl:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
if plugin_manager.load_datasets(cfg, preprocess=True):
|
||||
pass
|
||||
elif cfg.rl:
|
||||
load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
else:
|
||||
load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -43,10 +43,13 @@ def do_train(cfg: DictDefault, cli_args: TrainerCliArgs):
|
||||
if int(os.getenv("LOCAL_RANK", "0")) == 0:
|
||||
check_user_token()
|
||||
|
||||
if cfg.rl:
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
dataset_meta = plugin_manager.load_datasets(cfg, preprocess=False)
|
||||
if not dataset_meta:
|
||||
if cfg.rl:
|
||||
dataset_meta = load_preference_datasets(cfg=cfg, cli_args=cli_args)
|
||||
else:
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
model, tokenizer, trainer = train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ import importlib.util
|
||||
import inspect
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from abc import abstractmethod
|
||||
from pathlib import Path
|
||||
@@ -72,6 +73,7 @@ from axolotl.utils.callbacks import (
|
||||
SaveBetterTransformerModelCallback,
|
||||
bench_eval_callback_factory,
|
||||
causal_lm_bench_eval_callback_factory,
|
||||
colab_inference_post_train_callback,
|
||||
log_prediction_callback_factory,
|
||||
)
|
||||
from axolotl.utils.callbacks.lisa import lisa_callback_factory
|
||||
@@ -293,6 +295,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
if self.cfg.lisa_step_interval and self.cfg.lisa_n_layers:
|
||||
callbacks.append(lisa_callback_factory(trainer))
|
||||
|
||||
if any("COLAB_" in key for key in os.environ):
|
||||
ColabCallback = colab_inference_post_train_callback(trainer)
|
||||
callbacks.append(ColabCallback(self.cfg))
|
||||
|
||||
callbacks.extend(super().get_post_trainer_create_callbacks(trainer=trainer))
|
||||
return callbacks
|
||||
|
||||
@@ -702,6 +708,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
optimizer_cls = ADOPT
|
||||
adam_kwargs["decouple"] = True
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
elif self.cfg.optimizer == "came_pytorch":
|
||||
from came_pytorch import CAME
|
||||
|
||||
optimizer_cls = CAME
|
||||
|
||||
beta1 = training_arguments_kwargs.get("adam_beta1", 0.9)
|
||||
beta2 = training_arguments_kwargs.get("adam_beta2", 0.999)
|
||||
beta3 = training_arguments_kwargs.get("adam_beta2", 0.9999)
|
||||
eps1 = training_arguments_kwargs.get("adam_epsilon", 1e-30)
|
||||
eps2 = training_arguments_kwargs.get("adam_epsilon2", 1e-16)
|
||||
adam_kwargs["betas"] = (beta1, beta2, beta3)
|
||||
adam_kwargs["eps"] = (eps1, eps2)
|
||||
|
||||
optimizer_kwargs.update(adam_kwargs)
|
||||
|
||||
# Parse any additional optimizer args from config
|
||||
if self.cfg.optim_args:
|
||||
|
||||
@@ -26,6 +26,8 @@ from typing import OrderedDict
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class BasePlugin:
|
||||
"""
|
||||
@@ -36,11 +38,13 @@ class BasePlugin:
|
||||
|
||||
Methods:
|
||||
register(cfg): Registers the plugin with the given configuration.
|
||||
load_datasets(cfg): Loads and preprocesses the dataset for training.
|
||||
pre_model_load(cfg): Performs actions before the model is loaded.
|
||||
post_model_build(cfg, model): Performs actions after the model is loaded, but before LoRA adapters are applied.
|
||||
pre_lora_load(cfg, model): Performs actions before LoRA weights are loaded.
|
||||
post_lora_load(cfg, model): Performs actions after LoRA weights are loaded.
|
||||
post_model_load(cfg, model): Performs actions after the model is loaded, inclusive of any adapters.
|
||||
post_trainer_create(cfg, trainer): Performs actions after the trainer is created.
|
||||
create_optimizer(cfg, trainer): Creates and returns an optimizer for training.
|
||||
create_lr_scheduler(cfg, trainer, optimizer, num_training_steps): Creates and returns a learning rate scheduler.
|
||||
add_callbacks_pre_trainer(cfg, model): Adds callbacks to the trainer before training.
|
||||
@@ -63,20 +67,32 @@ class BasePlugin:
|
||||
None
|
||||
"""
|
||||
|
||||
def get_input_args(self):
|
||||
def get_input_args(self) -> str | None:
|
||||
"""
|
||||
Returns a pydantic model for the plugin's input arguments.
|
||||
"""
|
||||
|
||||
def load_datasets(self, cfg: DictDefault, preprocess: bool = False):
|
||||
"""
|
||||
Loads and preprocesses the dataset for training.
|
||||
|
||||
Args:
|
||||
cfg: The configuration for the plugin.
|
||||
preprocess: Whether this is the preprocess step of the datasets.
|
||||
|
||||
Returns:
|
||||
dataset_meta: The metadata for the training dataset.
|
||||
"""
|
||||
|
||||
def pre_model_load(self, cfg): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions before the model is loaded.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
Args:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
|
||||
Returns:
|
||||
None
|
||||
None
|
||||
"""
|
||||
|
||||
def post_model_build(self, cfg, model): # pylint: disable=unused-argument
|
||||
@@ -91,59 +107,71 @@ class BasePlugin:
|
||||
"""
|
||||
Performs actions after the model is loaded.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
model (object): The loaded model.
|
||||
Args:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
model (object): The loaded model.
|
||||
|
||||
Returns:
|
||||
None
|
||||
None
|
||||
"""
|
||||
|
||||
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions before LoRA weights are loaded.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
model (object): The loaded model.
|
||||
Args:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
model (object): The loaded model.
|
||||
|
||||
Returns:
|
||||
None
|
||||
None
|
||||
"""
|
||||
|
||||
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions after LoRA weights are loaded.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
model (object): The loaded model.
|
||||
Args:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
model (object): The loaded model.
|
||||
|
||||
Returns:
|
||||
None
|
||||
None
|
||||
"""
|
||||
|
||||
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument):
|
||||
"""
|
||||
Returns a custom class for the trainer.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The global axolotl configuration.
|
||||
Args:
|
||||
cfg (dict): The global axolotl configuration.
|
||||
|
||||
Returns:
|
||||
class: The class for the trainer.
|
||||
class: The class for the trainer.
|
||||
"""
|
||||
|
||||
def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions after the trainer is created.
|
||||
|
||||
Args:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
trainer (object): The trainer object for training.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
|
||||
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
|
||||
"""
|
||||
Creates and returns an optimizer for training.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
trainer (object): The trainer object for training.
|
||||
Args:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
trainer (object): The trainer object for training.
|
||||
|
||||
Returns:
|
||||
object: The created optimizer.
|
||||
object: The created optimizer.
|
||||
"""
|
||||
|
||||
def create_lr_scheduler(
|
||||
@@ -152,26 +180,26 @@ class BasePlugin:
|
||||
"""
|
||||
Creates and returns a learning rate scheduler.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
trainer (object): The trainer object for training.
|
||||
optimizer (object): The optimizer for training.
|
||||
num_training_steps (int): Total number of training steps
|
||||
Args:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
trainer (object): The trainer object for training.
|
||||
optimizer (object): The optimizer for training.
|
||||
num_training_steps (int): Total number of training steps
|
||||
|
||||
Returns:
|
||||
object (LRScheduler): The created learning rate scheduler.
|
||||
object (LRScheduler): The created learning rate scheduler.
|
||||
"""
|
||||
|
||||
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
|
||||
"""
|
||||
setup callbacks before creating the trainer.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
model (object): The loaded model.
|
||||
Args:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
model (object): The loaded model.
|
||||
|
||||
Returns:
|
||||
List[callable]: A list of callback functions to be added to the TrainingArgs
|
||||
List[callable]: A list of callback functions to be added to the TrainingArgs
|
||||
"""
|
||||
return []
|
||||
|
||||
@@ -182,12 +210,12 @@ class BasePlugin:
|
||||
Adds callbacks to the trainer after creating the trainer.
|
||||
This is useful for callbacks that require access to the model or trainer.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
trainer (object): The trainer object for training.
|
||||
Args:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
trainer (object): The trainer object for training.
|
||||
|
||||
Returns:
|
||||
List[callable]: A list of callback functions to be added
|
||||
List[callable]: A list of callback functions to be added
|
||||
"""
|
||||
return []
|
||||
|
||||
@@ -195,23 +223,23 @@ class BasePlugin:
|
||||
"""
|
||||
Performs actions after training is complete.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The axolotl configuration
|
||||
model (object): The loaded model.
|
||||
Args:
|
||||
cfg (dict): The axolotl configuration
|
||||
model (object): The loaded model.
|
||||
|
||||
Returns:
|
||||
None
|
||||
None
|
||||
"""
|
||||
|
||||
def post_train_unload(self, cfg): # pylint: disable=unused-argument
|
||||
"""
|
||||
Performs actions after training is complete and the model is unloaded.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
Args:
|
||||
cfg (dict): The configuration for the plugin.
|
||||
|
||||
Returns:
|
||||
None
|
||||
None
|
||||
"""
|
||||
|
||||
|
||||
@@ -338,6 +366,27 @@ class PluginManager:
|
||||
input_args.append(input_args_from_plugin)
|
||||
return input_args
|
||||
|
||||
def load_datasets(self, cfg, preprocess: bool = False):
|
||||
"""
|
||||
Calls the load_datasets method of each registered plugin.
|
||||
|
||||
Args:
|
||||
cfg: The configuration for the plugins.
|
||||
preprocess : Whether this is preprocess step of the datasets.
|
||||
|
||||
Returns:
|
||||
dataset_meta: The dataset metadata loaded from all registered plugins.
|
||||
"""
|
||||
return_ds_meta = None
|
||||
for plugin in self.plugins.values():
|
||||
dataset_meta = plugin.load_datasets(cfg, preprocess)
|
||||
if dataset_meta is not None:
|
||||
if return_ds_meta is None:
|
||||
return_ds_meta = dataset_meta
|
||||
else:
|
||||
raise RuntimeError("Multiple plugins loaded datasets")
|
||||
return return_ds_meta
|
||||
|
||||
def pre_model_load(self, cfg):
|
||||
"""
|
||||
Calls the pre_model_load method of all registered plugins.
|
||||
@@ -422,6 +471,20 @@ class PluginManager:
|
||||
return trainer_cls
|
||||
return None
|
||||
|
||||
def post_trainer_create(self, cfg, trainer):
|
||||
"""
|
||||
Calls the post_trainer_create method of all registered plugins.
|
||||
|
||||
Parameters:
|
||||
cfg (dict): The configuration for the plugins.
|
||||
trainer (object): The trainer object for training.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
for plugin in self.plugins.values():
|
||||
plugin.post_trainer_create(cfg, trainer)
|
||||
|
||||
def create_optimizer(self, trainer):
|
||||
"""
|
||||
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
"""
|
||||
attention module for attention monkeypatches
|
||||
"""
|
||||
|
||||
from transformers.integrations.flash_attention import flash_attention_forward
|
||||
|
||||
|
||||
def patch_xformers_attn_over_fa2():
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
from .xformers import xformers_attention_forward
|
||||
|
||||
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = xformers_attention_forward
|
||||
|
||||
|
||||
def unpatch_xformers_attn_over_fa2():
|
||||
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS
|
||||
|
||||
ALL_ATTENTION_FUNCTIONS["flash_attention_2"] = flash_attention_forward()
|
||||
|
||||
160
src/axolotl/monkeypatch/attention/xformers.py
Normal file
160
src/axolotl/monkeypatch/attention/xformers.py
Normal file
@@ -0,0 +1,160 @@
|
||||
"""
|
||||
xformers attention implementation for packing
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import xformers
|
||||
import xformers.ops.fmha
|
||||
from transformers.modeling_flash_attention_utils import (
|
||||
_upad_input,
|
||||
)
|
||||
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
|
||||
xformers_attention = xformers.ops.fmha.memory_efficient_attention
|
||||
|
||||
|
||||
def xformers_attention_forward(
|
||||
module: torch.nn.Module,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
dropout: float = 0.0, # pylint: disable=unused-argument
|
||||
scaling: Optional[float] = None, # pylint: disable=unused-argument
|
||||
sliding_window: Optional[int] = None, # pylint: disable=unused-argument
|
||||
softcap: Optional[float] = None, # pylint: disable=unused-argument
|
||||
cu_seq_lens_q: Optional[torch.LongTensor] = None,
|
||||
cu_seq_lens_k: Optional[torch.LongTensor] = None,
|
||||
max_length_q: Optional[int] = None,
|
||||
max_length_k: Optional[int] = None, # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
):
|
||||
# Get dimensions
|
||||
# query: [batch, heads, seq_len, hidden_dim]
|
||||
batch_size = query.size(0)
|
||||
query_length = query.shape[2]
|
||||
key_length = key.shape[2]
|
||||
|
||||
# Default causal mask
|
||||
attn_bias = xformers.ops.LowerTriangularMask()
|
||||
|
||||
# Check if we have sliding window attention
|
||||
has_sliding_window = sliding_window is not None and sliding_window < query_length
|
||||
|
||||
# Transpose dimensions for xformers (Q: [b, h, s, d] -> [b, s, h, d])
|
||||
query = query.transpose(1, 2)
|
||||
key = key.transpose(1, 2)
|
||||
value = value.transpose(1, 2)
|
||||
|
||||
# Get GQA parameters
|
||||
num_attention_heads = module.config.num_attention_heads
|
||||
num_key_value_heads = module.config.num_key_value_heads
|
||||
head_dim = query.size(-1)
|
||||
is_gqa = num_attention_heads != num_key_value_heads
|
||||
n_groups = num_attention_heads // num_key_value_heads if is_gqa else 1
|
||||
|
||||
# If position_ids is provided and check all examples do not contain only 1 sequence, If tensor in increasing
|
||||
# then we probably have one sequence, otherwise it is packed. Additionally check we are in pre-fill/training stage.
|
||||
# Use `flash_attn_varlen_func` to prevent cross-example attention and also allow padding free approach
|
||||
if position_ids is not None and (
|
||||
max_length_q is not None
|
||||
or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all())
|
||||
):
|
||||
if cu_seq_lens_q is None or cu_seq_lens_k is None:
|
||||
cu_seq_lens_q = get_cu_seqlens_from_pos_ids(position_ids)[0]
|
||||
cu_seq_lens_q = cu_seq_lens_q.squeeze()
|
||||
seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
|
||||
attn_bias = (
|
||||
xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
|
||||
q_seqlen=seq_lengths.tolist(),
|
||||
)
|
||||
)
|
||||
else:
|
||||
query = query.reshape(-1, query.size(-2), query.size(-1))
|
||||
key = key.reshape(-1, key.size(-2), key.size(-1))
|
||||
value = value.reshape(-1, value.size(-2), value.size(-1))
|
||||
|
||||
# Handle GQA
|
||||
if is_gqa:
|
||||
key = key.repeat_interleave(n_groups, dim=2)
|
||||
value = value.repeat_interleave(n_groups, dim=2)
|
||||
|
||||
elif attention_mask is not None:
|
||||
query, key, value, _, cu_seq_lens, _ = _upad_input(
|
||||
query, key, value, attention_mask, query_length
|
||||
)
|
||||
cu_seq_lens_q, cu_seq_lens_k = cu_seq_lens
|
||||
seq_lengths = []
|
||||
for i in range(len(cu_seq_lens_q) - 1):
|
||||
seq_lengths.append(cu_seq_lens_q[i + 1] - cu_seq_lens_q[i])
|
||||
attn_bias = xformers.ops.fmha.attn_bias.BlockDiagonalCausalMask.from_seqlens(
|
||||
q_seqlen=seq_lengths,
|
||||
kv_seqlen=seq_lengths,
|
||||
)
|
||||
|
||||
# Handle GQA
|
||||
if is_gqa:
|
||||
key = key.repeat_interleave(n_groups, dim=2)
|
||||
value = value.repeat_interleave(n_groups, dim=2)
|
||||
else:
|
||||
# Handle Group Query Attention (GQA) using view/expand approach from reference
|
||||
key = key.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
|
||||
value = value.view(batch_size, key_length, num_key_value_heads, 1, head_dim)
|
||||
key = key.expand(
|
||||
batch_size, key_length, num_key_value_heads, n_groups, head_dim
|
||||
)
|
||||
value = value.expand(
|
||||
batch_size, key_length, num_key_value_heads, n_groups, head_dim
|
||||
)
|
||||
|
||||
if module.training:
|
||||
key = key.reshape(batch_size, key_length, num_attention_heads, head_dim)
|
||||
value = value.reshape(batch_size, key_length, num_attention_heads, head_dim)
|
||||
|
||||
if has_sliding_window:
|
||||
query = query.view(
|
||||
1, batch_size * query_length, num_attention_heads, head_dim
|
||||
)
|
||||
key = key.view(
|
||||
1, batch_size * key_length, num_attention_heads, head_dim
|
||||
)
|
||||
value = value.view(
|
||||
1, batch_size * key_length, num_attention_heads, head_dim
|
||||
)
|
||||
else:
|
||||
query = query.view(
|
||||
batch_size, query_length, num_key_value_heads, n_groups, head_dim
|
||||
)
|
||||
|
||||
# If we need a sliding window attention
|
||||
if has_sliding_window:
|
||||
query = query.view(
|
||||
1,
|
||||
batch_size * query_length,
|
||||
num_key_value_heads,
|
||||
n_groups,
|
||||
head_dim,
|
||||
)
|
||||
key = key.view(
|
||||
1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
|
||||
)
|
||||
value = value.view(
|
||||
1, batch_size * key_length, num_key_value_heads, n_groups, head_dim
|
||||
)
|
||||
|
||||
# Run the xformers attention
|
||||
attn_output = xformers_attention(
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
attn_bias=attn_bias,
|
||||
)
|
||||
|
||||
attn_output = attn_output.view(
|
||||
batch_size, -1, attn_output.size(-2), attn_output.size(-1)
|
||||
)
|
||||
return attn_output, None
|
||||
0
src/axolotl/monkeypatch/peft/__init__.py
Normal file
0
src/axolotl/monkeypatch/peft/__init__.py
Normal file
78
src/axolotl/monkeypatch/peft/utils.py
Normal file
78
src/axolotl/monkeypatch/peft/utils.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Patch prepare_model_for_kbit_training to not upcast everything
|
||||
"""
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
import peft
|
||||
|
||||
import axolotl
|
||||
from axolotl.monkeypatch.utils import detab_code
|
||||
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
ORIGINAL_PREPARE_CODE = """
|
||||
for param in model.parameters():
|
||||
if (
|
||||
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
|
||||
) and param.__class__.__name__ != "Params4bit":
|
||||
param.data = param.data.to(torch.float32)
|
||||
"""
|
||||
|
||||
PATCHED_PREPARE_CODE = """
|
||||
for name, param in model.named_parameters():
|
||||
if (
|
||||
(param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
|
||||
) and param.__class__.__name__ != "Params4bit" and all(embed_name not in name for embed_name in ["embed_tokens", "lm_head"]):
|
||||
param.data = param.data.to(torch.float32)
|
||||
"""
|
||||
|
||||
|
||||
def get_peft_prep_code() -> str:
|
||||
prepare = inspect.getsource(peft.utils.other.prepare_model_for_kbit_training)
|
||||
return prepare
|
||||
|
||||
|
||||
def check_peft_prep_code_is_patchable() -> bool:
|
||||
prep_code = get_peft_prep_code()
|
||||
prep_code, _ = detab_code(prep_code)
|
||||
return ORIGINAL_PREPARE_CODE in prep_code
|
||||
|
||||
|
||||
def patch_peft_prep_code():
|
||||
"""
|
||||
monkeypatch create_accelerator_and_postprocess so it checks for additional kwargs
|
||||
"""
|
||||
|
||||
try:
|
||||
prep_code = get_peft_prep_code()
|
||||
except OSError:
|
||||
return
|
||||
peft.utils.other._original_create_accelerator_and_postprocess = ( # pylint: disable=protected-access
|
||||
prep_code
|
||||
)
|
||||
prep_code, _ = detab_code(prep_code)
|
||||
if ORIGINAL_PREPARE_CODE not in prep_code:
|
||||
return
|
||||
|
||||
prep_code = prep_code.replace(ORIGINAL_PREPARE_CODE, PATCHED_PREPARE_CODE)
|
||||
prep_code = prep_code.replace(
|
||||
"def prepare_model_for_kbit_training(",
|
||||
"def fixed_prepare_model_for_kbit_training(",
|
||||
1,
|
||||
)
|
||||
|
||||
items_to_import = []
|
||||
for item in dir(peft.utils.other):
|
||||
if item in prep_code:
|
||||
items_to_import.append(item)
|
||||
|
||||
exec( # pylint: disable=exec-used # nosec B102
|
||||
"from peft.utils.other import (" + ", ".join(x for x in items_to_import) + ")",
|
||||
globals(),
|
||||
)
|
||||
exec(prep_code, globals()) # pylint: disable=exec-used # nosec B102
|
||||
LOG.info("patching prepare_model_for_kbit_training to allow for overrides")
|
||||
peft.utils.other.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821
|
||||
axolotl.utils.models.prepare_model_for_kbit_training = fixed_prepare_model_for_kbit_training # pylint: disable=protected-access # pylint: disable=undefined-variable # noqa: F821
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import importlib
|
||||
import inspect
|
||||
import logging
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
@@ -12,7 +13,6 @@ from typing import Any, Dict
|
||||
|
||||
import torch
|
||||
import transformers.modelcard
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.utils import save_fsdp_model
|
||||
from datasets import Dataset
|
||||
from huggingface_hub.errors import OfflineModeIsEnabled
|
||||
@@ -42,7 +42,7 @@ try:
|
||||
except ImportError:
|
||||
BetterTransformer = None
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
LOG = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def setup_model_and_tokenizer(
|
||||
@@ -63,7 +63,6 @@ def setup_model_and_tokenizer(
|
||||
# Load tokenizer
|
||||
LOG.debug(
|
||||
f"loading tokenizer... {cfg.tokenizer_config or cfg.base_model_config}",
|
||||
main_process_only=True,
|
||||
)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
|
||||
@@ -528,6 +527,9 @@ def train(
|
||||
processor,
|
||||
) = setup_model_and_trainer(cfg, dataset_meta)
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.post_trainer_create(cfg, trainer)
|
||||
|
||||
# Handle untrained tokens if configured
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
train_dataset = dataset_meta.train_dataset
|
||||
@@ -550,7 +552,6 @@ def train(
|
||||
if not cfg.use_ray:
|
||||
cleanup_distributed()
|
||||
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
plugin_manager.post_train(cfg, model)
|
||||
|
||||
return model, tokenizer, trainer
|
||||
|
||||
@@ -868,3 +868,28 @@ class GCCallback(TrainerCallback):
|
||||
):
|
||||
torch.cuda.empty_cache()
|
||||
gc.collect()
|
||||
|
||||
|
||||
def colab_inference_post_train_callback(trainer: Trainer):
|
||||
class ColabCallback(TrainerCallback):
|
||||
"""Callback to prep model for inference on Google Colab"""
|
||||
|
||||
def __init__(self, cfg):
|
||||
self.gpu_name = torch.cuda.get_device_name(0)
|
||||
self.cfg = cfg
|
||||
|
||||
def on_train_end(
|
||||
self, args, state, control, **kwargs
|
||||
): # pylint: disable=unused-argument
|
||||
"""
|
||||
handle T4 gpu, we need to convert attention to eager for inference
|
||||
"""
|
||||
if "Tesla T4" in self.gpu_name and self.cfg.xformers_attention:
|
||||
trainer.model.config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"eager"
|
||||
)
|
||||
trainer.model.gradient_checkpointing_disable()
|
||||
trainer.model.config.use_cache = True
|
||||
trainer.model.eval()
|
||||
|
||||
return ColabCallback
|
||||
|
||||
@@ -70,6 +70,9 @@ def resolve_dtype(cfg):
|
||||
if cfg.fp16 is None and not cfg.float16:
|
||||
cfg.fp16 = True
|
||||
|
||||
if cfg.fp16 and cfg.bf16 == "auto":
|
||||
cfg.bf16 = False
|
||||
|
||||
if cfg.device == "mps":
|
||||
cfg.load_in_8bit = False
|
||||
cfg.tf32 = False
|
||||
|
||||
@@ -281,6 +281,10 @@ def load_dataset_w_config(
|
||||
**load_ds_kwargs,
|
||||
)
|
||||
if not ds:
|
||||
raise ValueError("unhandled dataset load")
|
||||
raise ValueError(
|
||||
"The dataset could not be loaded. This could be due to a misconfigured dataset path "
|
||||
f"({config_dataset.path}). Try double-check your path / name / data_files. "
|
||||
"This is not caused by the dataset type."
|
||||
)
|
||||
|
||||
return ds
|
||||
|
||||
@@ -1,15 +1,36 @@
|
||||
"""custom checkpointing utils"""
|
||||
|
||||
import importlib
|
||||
from functools import partial
|
||||
|
||||
from packaging import version
|
||||
|
||||
from axolotl.utils.gradient_checkpointing.unsloth import (
|
||||
Unsloth_Offloaded_Gradient_Checkpointer,
|
||||
)
|
||||
|
||||
transformers_version = version.parse(importlib.metadata.version("transformers"))
|
||||
if transformers_version > version.parse("4.51.3"):
|
||||
from transformers.modeling_layers import GradientCheckpointingLayer
|
||||
|
||||
def uses_gc_layers(decoder_layer):
|
||||
return isinstance(decoder_layer.func.__self__, GradientCheckpointingLayer)
|
||||
|
||||
else:
|
||||
|
||||
def uses_gc_layers(_):
|
||||
return False
|
||||
|
||||
|
||||
def hf_grad_checkpoint_offload_wrapper(
|
||||
decoder_layer, *args, use_reentrant=None
|
||||
): # pylint: disable=unused-argument
|
||||
if uses_gc_layers(decoder_layer):
|
||||
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
||||
decoder_layer,
|
||||
*args,
|
||||
)
|
||||
|
||||
return Unsloth_Offloaded_Gradient_Checkpointer.apply(
|
||||
(
|
||||
decoder_layer.func.__self__
|
||||
|
||||
@@ -556,11 +556,21 @@ class ModelLoader:
|
||||
self.auto_model_loader = AutoModelForCausalLM # pylint: disable=invalid-name
|
||||
|
||||
def apply_patches(self) -> None:
|
||||
if self.cfg.xformers_attention and self.cfg.sample_packing:
|
||||
from axolotl.monkeypatch.attention import patch_xformers_attn_over_fa2
|
||||
|
||||
patch_xformers_attn_over_fa2()
|
||||
self.cfg.flash_attention = True
|
||||
if self.cfg.fsdp_config and str(self.cfg.fsdp_config.fsdp_version) == "2":
|
||||
from axolotl.monkeypatch.accelerate.fsdp2 import patch_accelerate_fsdp_utils
|
||||
|
||||
patch_accelerate_fsdp_utils()
|
||||
|
||||
if self.cfg.adapter and self.cfg.embeddings_skip_upcast:
|
||||
from axolotl.monkeypatch.peft.utils import patch_peft_prep_code
|
||||
|
||||
patch_peft_prep_code()
|
||||
|
||||
if self.cfg.flex_attention:
|
||||
from axolotl.monkeypatch.attention.flex_attn import (
|
||||
patch_flex_make_mask,
|
||||
@@ -1180,7 +1190,7 @@ class ModelLoader:
|
||||
],
|
||||
)
|
||||
|
||||
def prepare_model(self, qlora_fsdp) -> None:
|
||||
def prepare_model(self, qlora_fsdp: bool) -> None:
|
||||
skip_prepare_model_for_kbit_training = False
|
||||
if self.cfg.model_config_type == "qwen" and self.cfg.adapter == "lora":
|
||||
# Qwen doesn't play nicely with LoRA if this is enabled
|
||||
@@ -1310,7 +1320,10 @@ class ModelLoader:
|
||||
# make sure these are fp32 per Ramesh et al. (2021)
|
||||
embedding_modules = get_linear_embedding_layers(self.cfg.model_config_type)
|
||||
if not self.cfg.fsdp:
|
||||
# FSDP doesn't like mixed Float and BFloat16
|
||||
# we don't run this during FSDP because this will leave mixed
|
||||
# float and bfloat16 dtypes in the model which FSDP doesn't like
|
||||
if self.cfg.load_in_4bit and self.cfg.embeddings_skip_upcast:
|
||||
embedding_modules = []
|
||||
self.convert_embedding_modules_dtype(
|
||||
embedding_modules,
|
||||
dist_dtype=torch.float32,
|
||||
|
||||
@@ -82,6 +82,7 @@ class AxolotlInputConfig(
|
||||
mean_resizing_embeddings: bool | None = False
|
||||
# optionally shrink the embeddings when the tokenizer vocab size is smaller
|
||||
shrink_embeddings: bool | None = None
|
||||
embeddings_skip_upcast: bool | None = None
|
||||
|
||||
rl: RLType | None = None
|
||||
trl: TRLConfig | None = Field(
|
||||
@@ -435,16 +436,6 @@ class AxolotlInputConfig(
|
||||
)
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_sample_packing_w_xformers(cls, data):
|
||||
if data.get("sample_packing") and data.get("xformers_attention"):
|
||||
raise ValueError(
|
||||
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -471,9 +462,10 @@ class AxolotlInputConfig(
|
||||
and not data.get("flash_attention")
|
||||
and not data.get("sdp_attention")
|
||||
and not data.get("flex_attention")
|
||||
and not data.get("xformers_attention")
|
||||
):
|
||||
LOG.warning(
|
||||
"sample_packing without flash, sdp or flex attention does not handle cross sample decontamination."
|
||||
"sample_packing without flash, sdp, xformers or flex attention does not handle cross sample decontamination."
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@@ -53,4 +53,5 @@ class CustomSupportedOptimizers(str, Enum):
|
||||
ao_adamw_8bit = "ao_adamw_8bit" # pylint: disable=invalid-name
|
||||
ao_adamw_fp8 = "ao_adamw_fp8" # pylint: disable=invalid-name
|
||||
adopt_adamw = "adopt_adamw" # pylint: disable=invalid-name
|
||||
came_pytorch = "came_pytorch" # pylint: disable=invalid-name
|
||||
muon = "muon" # pylint: disable=invalid-name
|
||||
|
||||
@@ -75,8 +75,10 @@ class HyperparametersConfig(BaseModel):
|
||||
lr_groups: list[LrGroup] | None = None
|
||||
|
||||
adam_epsilon: float | None = None
|
||||
adam_epsilon2: float | None = None
|
||||
adam_beta1: float | None = None
|
||||
adam_beta2: float | None = None
|
||||
adam_beta3: float | None = None
|
||||
max_grad_norm: float | None = None
|
||||
num_epochs: float = Field(default=1.0)
|
||||
|
||||
|
||||
@@ -29,6 +29,12 @@ class LogHooksPlugin(BasePlugin):
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("post_trainer_create\n")
|
||||
|
||||
def pre_model_load(self, cfg): # pylint: disable=unused-argument
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
@@ -165,6 +171,7 @@ class TestPluginHooks:
|
||||
) as f:
|
||||
file_contents = f.readlines()
|
||||
file_contents = "\n".join(file_contents)
|
||||
assert "post_trainer_create" in file_contents
|
||||
assert "pre_model_load" in file_contents
|
||||
assert "post_model_build" in file_contents
|
||||
assert "pre_lora_load" in file_contents
|
||||
|
||||
@@ -479,7 +479,7 @@ class TestMultiGPULlama:
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"sequence_len": 2048,
|
||||
"val_set_size": 0.05,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
|
||||
@@ -29,12 +29,12 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
MODEL_CONFIGS = [
|
||||
{
|
||||
"name": "openaccess-ai-collective/tiny-mistral",
|
||||
"name": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
"expected_activation": apply_lora_mlp_swiglu,
|
||||
"dtype": torch.float16,
|
||||
},
|
||||
{
|
||||
"name": "Qwen/Qwen2-7B",
|
||||
"name": "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
|
||||
"expected_activation": apply_lora_mlp_swiglu,
|
||||
"dtype": torch.float16,
|
||||
},
|
||||
@@ -44,7 +44,7 @@ MODEL_CONFIGS = [
|
||||
"dtype": torch.float32,
|
||||
},
|
||||
{
|
||||
"name": "mhenrichsen/gemma-2b",
|
||||
"name": "trl-internal-testing/tiny-Gemma2ForCausalLM",
|
||||
"expected_activation": apply_lora_mlp_geglu,
|
||||
"dtype": torch.float16,
|
||||
},
|
||||
@@ -156,7 +156,9 @@ def test_swiglu_mlp_integration(small_llama_model):
|
||||
def test_geglu_model_integration():
|
||||
"""Test GeGLU activation with Gemma model."""
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
"mhenrichsen/gemma-2b", torch_dtype=torch.float16, device_map="cuda:0"
|
||||
"trl-internal-testing/tiny-Gemma2ForCausalLM",
|
||||
torch_dtype=torch.float16,
|
||||
device_map="cuda:0",
|
||||
)
|
||||
peft_config = get_peft_config(
|
||||
{
|
||||
|
||||
@@ -6,6 +6,8 @@ import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
@@ -23,6 +25,7 @@ class TestFalconPatched(unittest.TestCase):
|
||||
Test case for Falcon models
|
||||
"""
|
||||
|
||||
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
|
||||
@with_temp_dir
|
||||
def test_qlora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -71,6 +74,7 @@ class TestFalconPatched(unittest.TestCase):
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
@@ -28,7 +28,7 @@ class TestMistral(unittest.TestCase):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 1024,
|
||||
@@ -76,7 +76,7 @@ class TestMistral(unittest.TestCase):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 1024,
|
||||
|
||||
@@ -56,7 +56,7 @@ class TestModelPatches(unittest.TestCase):
|
||||
def test_mistral_multipack(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
"flash_attention": True,
|
||||
"sample_packing": True,
|
||||
"sequence_len": 2048,
|
||||
|
||||
63
tests/e2e/patched/test_peft_embeddings.py
Normal file
63
tests/e2e/patched/test_peft_embeddings.py
Normal file
@@ -0,0 +1,63 @@
|
||||
"""
|
||||
Test case for handling embeddings when using peft
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.train import setup_model_and_tokenizer
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
class TestLlamaPeftEmbeddings:
|
||||
"""
|
||||
test class for handling embeddings when using peft
|
||||
"""
|
||||
|
||||
def test_peft_embeddings_upcast(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"load_in_4bit": True,
|
||||
"adapter": "qlora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_target_linear": True,
|
||||
"trust_remote_code": True,
|
||||
"sequence_len": 512,
|
||||
"val_set_size": 0.01,
|
||||
"special_tokens": {
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"max_steps": 2,
|
||||
"micro_batch_size": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_8bit",
|
||||
"lr_scheduler": "cosine",
|
||||
"flash_attention": True,
|
||||
"sample_packing": False,
|
||||
"bf16": "auto",
|
||||
"save_safetensors": True,
|
||||
"embeddings_skip_upcast": True,
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
|
||||
model, _, _, _ = setup_model_and_tokenizer(cfg)
|
||||
|
||||
# Check if the embeddings are upcast correctly
|
||||
# only embed_tokens is a parameter that may be upcast
|
||||
assert model.base_model.model.model.embed_tokens.weight.dtype == torch.bfloat16
|
||||
assert model.base_model.model.lm_head.weight.dtype == torch.bfloat16
|
||||
@@ -15,7 +15,7 @@ from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists, most_recent_subdir
|
||||
from ..utils import check_model_output_exists, most_recent_subdir, require_torch_2_6_0
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
@@ -26,6 +26,7 @@ class TestResumeLlama:
|
||||
Test case for resuming training of llama models
|
||||
"""
|
||||
|
||||
@require_torch_2_6_0
|
||||
def test_resume_lora_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
@@ -62,6 +63,7 @@ class TestResumeLlama:
|
||||
"save_total_limit": 5,
|
||||
"max_steps": 15,
|
||||
"use_tensorboard": True,
|
||||
"save_safetensors": True,
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
|
||||
@@ -19,14 +19,11 @@ class TestE2eEvaluate:
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"sequence_len": 1024,
|
||||
"val_set_size": 0.02,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
"pad_token": "<|endoftext|>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
|
||||
@@ -6,6 +6,8 @@ import logging
|
||||
import os
|
||||
import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.cli.args import TrainerCliArgs
|
||||
from axolotl.common.datasets import load_datasets
|
||||
from axolotl.train import train
|
||||
@@ -23,6 +25,7 @@ class TestFalcon(unittest.TestCase):
|
||||
Test case for falcon
|
||||
"""
|
||||
|
||||
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
|
||||
@with_temp_dir
|
||||
def test_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -74,6 +77,7 @@ class TestFalcon(unittest.TestCase):
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
|
||||
@with_temp_dir
|
||||
def test_lora_added_vocab(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -129,6 +133,7 @@ class TestFalcon(unittest.TestCase):
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
@@ -30,7 +30,7 @@ class TestMistral(unittest.TestCase):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
@@ -77,7 +77,7 @@ class TestMistral(unittest.TestCase):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "openaccess-ai-collective/tiny-mistral",
|
||||
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
"flash_attention": True,
|
||||
"sequence_len": 1024,
|
||||
"val_set_size": 0.02,
|
||||
|
||||
@@ -199,3 +199,50 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@with_temp_dir
|
||||
def test_came_pytorch(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 8,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.1,
|
||||
"special_tokens": {
|
||||
"unk_token": "<unk>",
|
||||
"bos_token": "<s>",
|
||||
"eos_token": "</s>",
|
||||
},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 1,
|
||||
"micro_batch_size": 8,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "came_pytorch",
|
||||
"adam_beta3": 0.9999,
|
||||
"adam_epsilon2": 1e-16,
|
||||
"max_steps": 5,
|
||||
"lr_scheduler": "cosine",
|
||||
}
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
@@ -414,7 +414,6 @@ class TestDatasetPreparation:
|
||||
snapshot_path = snapshot_download(
|
||||
repo_id="mhenrichsen/alpaca_2k_test",
|
||||
repo_type="dataset",
|
||||
local_dir=tmp_ds_path,
|
||||
)
|
||||
shutil.copytree(snapshot_path, tmp_ds_path, dirs_exist_ok=True)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user