Compare commits

...

22 Commits

Author SHA1 Message Date
sunny
432b17eee1 test 2024-11-07 11:20:32 -05:00
sunny
58cca816f8 trl version requirement 2024-11-06 10:01:05 -05:00
sunny
28e134e6a8 commenting out 2024-11-05 14:57:35 -05:00
sunny
39af2a41a5 linting 2024-11-05 12:46:05 -05:00
sunny
41d10278bf test 2024-11-05 12:38:33 -05:00
sunny
d9b65f69fb test 2024-11-05 12:35:36 -05:00
sunny
bcb1205e39 test 2024-11-05 12:30:45 -05:00
sunny
04b532bd37 test 2024-11-05 12:20:00 -05:00
sunny
8ac149e317 test 2024-11-05 12:03:06 -05:00
sunny
98d819d3f7 trl 2024-11-05 11:59:10 -05:00
sunny
9da9916ff2 trl 2024-11-05 11:57:26 -05:00
sunny
027ccdab4d update trl version requirements 2024-11-05 11:53:49 -05:00
sunny
7a00dbc367 trlv0.12.0 integration 2024-11-05 11:44:46 -05:00
Wing Lian
052a9a79b4 only run the remainder of the gpu test suite if one case passes first (#2009) [skip ci]
* only run the remainder of the gpu test suite if one case passes first

* also reduce the test matrix
2024-10-31 13:45:01 -04:00
Wing Lian
3591bcfaf9 add torch 2.5.1 for base image (#2010) 2024-10-31 13:27:49 -04:00
Wing Lian
dc1de7d81b add retries for load datasets requests failures (#2007) 2024-10-31 13:26:14 -04:00
Chirag Jain
d4dbfa02fe Add plugin manager's callback hooks to training flow (#2006)
* Add plugin manager's callback hooks to training flow

* Use .values() instead of .items()
2024-10-31 12:13:46 -04:00
NanoCode012
5c7e89105d Fix: modelloader handling of model_kwargs load_in*bit (#1999)
* fix: load_in_*bit not properly read

* fix: load_*bit check

* fix: typo

* refactor: load * bit handling

* feat: add test dpo lora multi-gpu

* fix: turn off sample packing for dpo

* fix: missing warmup_steps

* fix: test to load in 8bit for lora

* skip 8bit lora on h100, add 4bit lora on h100 to multi gpu tests

* chore: reduce max_steps

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
2024-10-30 14:41:34 -04:00
Chirag Jain
74db2a1bae Fix get_chat_template call for trainer builder (#2003) 2024-10-30 14:27:00 -04:00
Geun, Lim
e62554c419 feat: add Exaone3 chat_template (#1995) 2024-10-30 12:30:12 -04:00
Wing Lian
32c60765ef remove skipped test (#2002)
* remove skipped test

* use mean_resizing_embeddings with qlora and added tokens

* use </s> as pad_token to prevent resize of embeddings

* make sure local hub test saves to a tmp dir

* use Path so concatenation works

* make sure to use tmp_ds_path for data files
2024-10-30 12:27:04 -04:00
NanoCode012
8c3a727f9d feat: update yml chat_template to specify dataset field (#2001) [skip ci]
* feat: update yml chat_template to specify dataset field

* feat: replace sharegpt references with chat_template
2024-10-29 10:26:03 -04:00
26 changed files with 474 additions and 149 deletions

View File

@@ -40,7 +40,7 @@ jobs:
cuda_version: 12.4.1
cudnn_version: ""
python_version: "3.11"
pytorch: 2.5.0
pytorch: 2.5.1
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps:
- name: Checkout

View File

@@ -82,13 +82,6 @@ jobs:
num_gpus: 1
axolotl_extras: mamba-ssm
nightly_build: "true"
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
nightly_build: "true"
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -72,13 +72,53 @@ jobs:
run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests:
docker-e2e-tests-1st:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
needs: [pre-commit, pytest]
strategy:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.tests
docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
needs: [pre-commit, pytest, docker-e2e-tests-1st]
strategy:
fail-fast: false
matrix:
@@ -89,18 +129,6 @@ jobs:
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"

View File

@@ -7,8 +7,8 @@ load_in_8bit: true
load_in_4bit: false
datasets:
- path: philschmid/guanaco-sharegpt-style
type: sharegpt
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
shards: 10
val_set_size: 0
output_dir: temp_debug/axolotl_outputs/model

View File

@@ -183,6 +183,8 @@ test_datasets:
# use RL training: 'dpo', 'ipo', 'kto'
rl:
# whether to perform weighting if doing DPO training. Boolean.
dpo_use_weighting:
# The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.

View File

@@ -51,12 +51,12 @@ While debugging it's helpful to simplify your test scenario as much as possible.
### Background
The below example shows how to configure VSCode to debug data preprocessing of the `sharegpt` format. This is the format used when you have the following in your axolotl config:
The below example shows how to configure VSCode to debug data preprocessing of the `chat_template` format. This is the format used when you have the following in your axolotl config:
```yaml
datasets:
- path: <path to your sharegpt formatted dataset> # example on HF Hub: philschmid/guanaco-sharegpt-style
type: sharegpt
- path: <path to your chat_template formatted dataset> # example on HF Hub: fozziethebeat/alpaca_messages_2k_test
type: chat_template
```
>[!Important]
@@ -83,7 +83,7 @@ If you developing on a remote host, you can easily use VSCode to debug remotely.
The easiest way to get started is to modify the [.vscode/launch.json](../.vscode/launch.json) file in this project. This is just an example configuration, so you may need to modify or copy it to suit your needs.
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_sharegpt.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 accelerate launch -m axolotl.cli.train dev_chat_template.yml`, you would use the below configuration[^1]. Note that we add additional flags that override the axolotl config and incorporate the tips above (see the comments). We also set the working directory to `devtools` and set the `env` variable `HF_HOME` to a temporary folder that is later partially deleted. This is because we want to delete the HF dataset cache before each run in order to ensure that the data preprocessing code is run from scratch.
```jsonc
// .vscode/launch.json
@@ -91,12 +91,12 @@ For example, to mimic the command `cd devtools && CUDA_VISIBLE_DEVICES=0 acceler
"version": "0.2.0",
"configurations": [
{
"name": "Debug axolotl prompt - sharegpt",
"name": "Debug axolotl prompt - chat_template",
"type": "python",
"module": "accelerate.commands.launch",
"request": "launch",
"args": [
"-m", "axolotl.cli.train", "dev_sharegpt.yml",
"-m", "axolotl.cli.train", "dev_chat_template.yml",
// The flags below simplify debugging by overriding the axolotl config
// with the debugging tips above. Modify as needed.
"--dataset_processes=1", // limits data preprocessing to one process
@@ -240,6 +240,6 @@ style="border-radius: 10px; display: block; margin: auto;" width="560" height="3
</div>
<br>
[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/sharegpt.yml`, but this is the same thing.
[^1]: The config actually mimics the command `CUDA_VISIBLE_DEVICES=0 python -m accelerate.commands.launch -m axolotl.cli.train devtools/chat_template.yml`, but this is the same thing.
[^2]: Many of the below flags are recommended best practices by Nvidia when using nvidia-container-toolkit. You can read more about these flags [here](https://docs.nvidia.com/deeplearning/frameworks/user-guide/index.html).

View File

@@ -16,7 +16,10 @@ chat_template: deepseek_v2
datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train
split: train[:20%]
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0

View File

@@ -11,8 +11,11 @@ chat_template: gemma
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
chat_template: gemma
drop_system_message: true
field_messages: conversations
message_field_role: from
message_field_content: value
val_set_size: 0.0
output_dir: ./outputs/out

View File

@@ -4,11 +4,15 @@ tokenizer_type: AutoTokenizer
load_in_4bit: true
strict: false
use_tensorboard: true
chat_template: jamba
datasets:
- path: cgato/SlimOrcaDedupCleaned
type: chat_template
chat_template: jamba
drop_system_message: true
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: jamba-large-fsdp-qlora-ft

View File

@@ -14,6 +14,10 @@ datasets:
- path: mlabonne/FineTome-100k
type: chat_template
split: train[:20%]
field_messages: conversations
message_field_role: from
message_field_content: value
dataset_prepared_path: last_run_prepared
val_set_size: 0.02
output_dir: ./outputs/out

View File

@@ -10,7 +10,6 @@ chat_template: phi_3
datasets:
- path: fozziethebeat/alpaca_messages_2k_test
type: chat_template
chat_template: phi_3
field_messages: messages
message_field_role: role
message_field_content: content

View File

@@ -43,7 +43,7 @@ s3fs>=2024.5.0
gcsfs>=2024.5.0
# adlfs
trl @ git+https://github.com/huggingface/trl.git@31d02cfb795284591a084416b9dcb7bef5d08924
trl==0.12.0
zstandard==0.22.0
fastcore

View File

@@ -272,7 +272,7 @@ def do_inference_gradio(
importlib.import_module("axolotl.prompters"), prompter
)
elif cfg.chat_template:
chat_template_str = get_chat_template(cfg.chat_template)
chat_template_str = get_chat_template(cfg.chat_template, tokenizer=tokenizer)
model = model.to(cfg.device, dtype=cfg.torch_dtype)

View File

@@ -48,6 +48,7 @@ from trl import (
)
from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_comet_available, is_mlflow_available
@@ -1147,6 +1148,12 @@ class TrainerBuilderBase(abc.ABC):
def get_callbacks(self) -> List[TrainerCallback]:
callbacks = []
plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
if self.cfg.use_wandb:
callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
@@ -1173,11 +1180,17 @@ class TrainerBuilderBase(abc.ABC):
return callbacks
@abstractmethod
def get_post_trainer_create_callbacks(self, trainer):
"""
Callbacks added after the trainer is created, usually b/c these need access to the trainer
"""
callbacks = []
plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer)
)
return callbacks
def hook_pre_create_training_args(self, training_arguments_kwargs):
# TODO
@@ -1223,7 +1236,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "wandb"
@@ -1595,7 +1608,8 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["pretraining"] = bool(self.cfg.pretraining_dataset)
if self.cfg.chat_template:
training_arguments_kwargs["chat_template"] = get_chat_template(
self.cfg.chat_template
self.cfg.chat_template,
tokenizer=self.tokenizer,
)
if self.cfg.rl == "orpo":
@@ -1790,7 +1804,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def build_training_arguments(self, total_num_steps):
@@ -1876,17 +1890,18 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if self.cfg.rl_beta:
training_args_kwargs["beta"] = self.cfg.rl_beta
if self.cfg.orpo_alpha:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args_cls = AxolotlDPOConfig
if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
training_args_cls = None
if self.cfg.rl == "simpo":
training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo"
@@ -1895,13 +1910,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
if self.cfg.rl == "orpo":
elif self.cfg.rl == "orpo":
training_args_cls = AxolotlORPOConfig
training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
if self.cfg.rl == "kto":
elif self.cfg.rl == "kto":
training_args_cls = AxolotlKTOConfig
training_args_kwargs["desirable_weight"] = (
@@ -1916,6 +1931,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
else:
training_args_cls = AxolotlDPOConfig
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
output_dir=self.cfg.output_dir,
per_device_train_batch_size=self.cfg.micro_batch_size,
@@ -1999,11 +2019,11 @@ class HFPPOTrainerBuilder(TrainerBuilderBase):
"""
def get_callbacks(self):
callbacks = []
callbacks = super().get_callbacks()
return callbacks
def get_post_trainer_create_callbacks(self, trainer):
callbacks = []
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks
def build(self, total_num_steps):

View File

@@ -18,9 +18,10 @@ Plugins can be used to integrate third-party models, modify the training process
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
"""
import collections
import importlib
import logging
from typing import List
from typing import OrderedDict
class BasePlugin:
@@ -47,7 +48,7 @@ class BasePlugin:
Initializes the BasePlugin.
"""
def register(self, cfg):
def register(self, cfg): # pylint: disable=unused-argument
"""
Registers the plugin with the given configuration.
@@ -63,7 +64,7 @@ class BasePlugin:
Returns a pydantic model for the plugin's input arguments.
"""
def pre_model_load(self, cfg):
def pre_model_load(self, cfg): # pylint: disable=unused-argument
"""
Performs actions before the model is loaded.
@@ -74,7 +75,7 @@ class BasePlugin:
None
"""
def post_model_load(self, cfg, model):
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after the model is loaded.
@@ -86,7 +87,7 @@ class BasePlugin:
None
"""
def pre_lora_load(self, cfg, model):
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions before LoRA weights are loaded.
@@ -98,7 +99,7 @@ class BasePlugin:
None
"""
def post_lora_load(self, cfg, model):
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after LoRA weights are loaded.
@@ -110,7 +111,7 @@ class BasePlugin:
None
"""
def create_optimizer(self, cfg, trainer):
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
"""
Creates and returns an optimizer for training.
@@ -122,7 +123,9 @@ class BasePlugin:
object: The created optimizer.
"""
def create_lr_scheduler(self, cfg, trainer, optimizer):
def create_lr_scheduler(
self, cfg, trainer, optimizer
): # pylint: disable=unused-argument
"""
Creates and returns a learning rate scheduler.
@@ -135,7 +138,7 @@ class BasePlugin:
object: The created learning rate scheduler.
"""
def add_callbacks_pre_trainer(self, cfg, model):
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
"""
Adds callbacks to the trainer before training.
@@ -146,8 +149,11 @@ class BasePlugin:
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
return []
def add_callbacks_post_trainer(self, cfg, trainer):
def add_callbacks_post_trainer(
self, cfg, trainer
): # pylint: disable=unused-argument
"""
Adds callbacks to the trainer after training.
@@ -158,8 +164,9 @@ class BasePlugin:
Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs
"""
return []
def post_train(self, cfg, model):
def post_train(self, cfg, model): # pylint: disable=unused-argument
"""
Performs actions after training is complete.
@@ -171,7 +178,7 @@ class BasePlugin:
None
"""
def post_train_unload(self, cfg):
def post_train_unload(self, cfg): # pylint: disable=unused-argument
"""
Performs actions after training is complete and the model is unloaded.
@@ -227,7 +234,7 @@ class PluginManager:
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
"""
plugins: List[BasePlugin] = []
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict()
_instance = None
@@ -237,7 +244,7 @@ class PluginManager:
"""
if cls._instance is None:
cls._instance = super(PluginManager, cls).__new__(cls)
cls._instance.plugins: List[BasePlugin] = []
cls._instance.plugins = collections.OrderedDict()
return cls._instance
@staticmethod
@@ -265,7 +272,7 @@ class PluginManager:
"""
try:
plugin = load_plugin(plugin_name)
self.plugins.append(plugin)
self.plugins[plugin_name] = plugin
except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}")
@@ -277,7 +284,7 @@ class PluginManager:
list[str]: A list of Pydantic classes for all registered plugins' input arguments.'
"""
input_args = []
for plugin in self.plugins:
for plugin in self.plugins.values():
input_args_from_plugin = plugin.get_input_args()
if input_args_from_plugin is not None:
input_args.append(input_args_from_plugin)
@@ -293,7 +300,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
plugin.pre_model_load(cfg)
def post_model_load(self, cfg, model):
@@ -307,7 +314,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
plugin.post_model_load(cfg, model)
def pre_lora_load(self, cfg, model):
@@ -321,7 +328,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
plugin.pre_lora_load(cfg, model)
def post_lora_load(self, cfg, model):
@@ -335,7 +342,7 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
plugin.post_lora_load(cfg, model)
def create_optimizer(self, cfg, trainer):
@@ -349,7 +356,7 @@ class PluginManager:
Returns:
object: The created optimizer, or None if none was found.
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
optimizer = plugin.create_optimizer(cfg, trainer)
if optimizer is not None:
return optimizer
@@ -367,7 +374,7 @@ class PluginManager:
Returns:
object: The created learning rate scheduler, or None if none was found.
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
if scheduler is not None:
return scheduler
@@ -385,7 +392,7 @@ class PluginManager:
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
for plugin in self.plugins:
for plugin in self.plugins.values():
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
return callbacks
@@ -401,7 +408,7 @@ class PluginManager:
List[callable]: A list of callback functions to be added to the TrainingArgs.
"""
callbacks = []
for plugin in self.plugins:
for plugin in self.plugins.values():
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
return callbacks
@@ -416,5 +423,5 @@ class PluginManager:
Returns:
None
"""
for plugin in self.plugins:
for plugin in self.plugins.values():
plugin.post_train_unload(cfg)

File diff suppressed because one or more lines are too long

View File

@@ -57,6 +57,7 @@ class ChatTemplate(str, Enum):
jinja = "jinja" # pylint: disable=invalid-name
qwen_25 = "qwen_25" # pylint: disable=invalid-name
tokenizer_default = "tokenizer_default" # pylint: disable=invalid-name
exaone = "exaone" # pylint: disable=invalid-name
class DeprecatedParameters(BaseModel):
@@ -587,6 +588,9 @@ class AxolotlInputConfig(
rl: Optional[RLType] = None
reward_model: Optional[bool] = None
dpo_use_weighting: Optional[
bool
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore

View File

@@ -2,9 +2,11 @@
import functools
import logging
import time
from pathlib import Path
from typing import List, Optional, Tuple, Union
import requests
from datasets import (
Dataset,
DatasetDict,
@@ -53,6 +55,28 @@ from axolotl.utils.trainer import (
LOG = logging.getLogger("axolotl")
def retry_on_request_exceptions(max_retries=3, delay=1):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
) as exc:
if attempt < max_retries - 1:
time.sleep(delay)
else:
raise exc
return wrapper
return decorator
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_dataset(cfg, tokenizer, processor=None):
prompters = []
if not cfg.pretraining_dataset:

View File

@@ -640,9 +640,7 @@ class ModelLoader:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**self.model_config.quantization_config
)
elif self.cfg.adapter == "qlora" and (
"load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
):
elif self.cfg.adapter == "qlora" and self.model_kwargs["load_in_4bit"]:
bnb_config = {
"load_in_4bit": True,
"llm_int8_threshold": 6.0,
@@ -665,9 +663,7 @@ class ModelLoader:
self.model_kwargs["quantization_config"] = BitsAndBytesConfig(
**bnb_config,
)
elif self.cfg.adapter == "lora" and (
"load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
):
elif self.cfg.adapter == "lora" and self.model_kwargs["load_in_8bit"]:
bnb_config = {
"load_in_8bit": True,
}
@@ -680,10 +676,8 @@ class ModelLoader:
# no longer needed per https://github.com/huggingface/transformers/pull/26610
if "quantization_config" in self.model_kwargs or self.cfg.gptq:
if "load_in_8bit" in self.model_kwargs:
del self.model_kwargs["load_in_8bit"]
if "load_in_4bit" in self.model_kwargs:
del self.model_kwargs["load_in_4bit"]
self.model_kwargs.pop("load_in_8bit", None)
self.model_kwargs.pop("load_in_4bit", None)
def set_attention_config(self) -> None:
"""
@@ -968,17 +962,10 @@ class ModelLoader:
if is_deepspeed_zero3_enabled():
skip_prepare_model_for_kbit_training = True
is_load_in_8bit = (
"load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
)
is_load_in_4bit = (
"load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
)
if (
not skip_prepare_model_for_kbit_training
and self.cfg.adapter in ["lora", "qlora"]
and (is_load_in_8bit or is_load_in_4bit)
and (self.cfg.load_in_8bit or self.cfg.load_in_4bit)
):
LOG.info("converting PEFT model w/ prepare_model_for_kbit_training")
self.model = prepare_model_for_kbit_training(
@@ -1116,16 +1103,10 @@ class ModelLoader:
# ---------------------------------------------------------
# put model to accelerator
# ---------------------------------------------------------
is_load_in_8bit = (
"load_in_8bit" in self.model_kwargs and self.model_kwargs["load_in_8bit"]
)
is_load_in_4bit = (
"load_in_4bit" in self.model_kwargs and self.model_kwargs["load_in_4bit"]
)
if (
self.cfg.ddp
and not is_load_in_8bit
and not (self.cfg.rl and is_load_in_4bit)
and not self.cfg.load_in_8bit
and not (self.cfg.rl and self.cfg.load_in_4bit)
and not skip_move_to_device
):
# TODO revaldate this conditional

59
test.yml Normal file
View File

@@ -0,0 +1,59 @@
base_model: JackFram/llama-68m
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: tatsu-lab/alpaca
type: alpaca
dataset_prepared_path: last_run_prepared
val_set_size: 0.2
output_dir: ./outputs/out
sequence_len: 2048
sample_packing: true
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
rl: dpo
dpo_use_weighting: true
warmup_steps: 10
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -14,7 +14,7 @@ from huggingface_hub import snapshot_download
from axolotl.utils.dict import DictDefault
from ..utils import with_temp_dir
from ..utils import is_hopper, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e.multigpu")
os.environ["WANDB_DISABLED"] = "true"
@@ -59,7 +59,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -116,7 +116,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 50,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -144,6 +144,146 @@ class TestMultiGPULlama(unittest.TestCase):
]
)
@pytest.mark.skipif(is_hopper(), reason="h100 doesn't support 8-bit lora")
@with_temp_dir
def test_dpo_lora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "TinyLlama/TinyLlama_v1.1",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 2048,
"sample_packing": False,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"unk_token": "<unk>",
"bos_token": "<s>",
"eos_token": "</s>",
},
"rl": "dpo",
"chat_template": "llama3",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
},
],
"num_epochs": 1,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"warmup_steps": 0,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_dpo_qlora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM-135M",
"sequence_len": 2048,
"sample_packing": False,
"eval_sample_packing": False,
"pad_to_sequence_len": True,
"load_in_4bit": True,
"adapter": "qlora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"rl": "dpo",
"chat_template": "chatml",
"datasets": [
{
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
"type": "chat_template.default",
"field_messages": "conversation",
"field_chosen": "chosen",
"field_rejected": "rejected",
"message_field_role": "role",
"message_field_content": "content",
"roles": {
"system": ["system"],
"user": ["user"],
"assistant": ["assistant"],
},
},
],
"num_epochs": 1,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
"warmup_steps": 0,
"learning_rate": 0.00001,
"optimizer": "adamw_8bit",
"lr_scheduler": "cosine",
"flash_attention": True,
}
)
# write cfg to yaml file
Path(temp_dir).mkdir(parents=True, exist_ok=True)
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
execute_subprocess_async(
[
"accelerate",
"launch",
"--num-processes",
"2",
"-m",
"axolotl.cli.train",
str(Path(temp_dir) / "config.yaml"),
]
)
@with_temp_dir
def test_fsdp(self, temp_dir):
# pylint: disable=duplicate-code
@@ -165,7 +305,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -231,7 +371,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -273,7 +413,6 @@ class TestMultiGPULlama(unittest.TestCase):
]
)
@pytest.mark.skip("disabled due to upstream issue")
@with_temp_dir
def test_fsdp_qlora_prequant_packed(self, temp_dir):
# pylint: disable=duplicate-code
@@ -282,6 +421,7 @@ class TestMultiGPULlama(unittest.TestCase):
"base_model": "axolotl-ai-co/TinyLlama_v1.1-bnb-nf4-bf16",
"tokenizer_type": "AutoTokenizer",
"adapter": "qlora",
"mean_resizing_embeddings": True,
"load_in_4bit": True,
"lora_r": 8,
"lora_alpha": 16,
@@ -297,7 +437,7 @@ class TestMultiGPULlama(unittest.TestCase):
"sequence_len": 2048,
"val_set_size": 0.05,
"special_tokens": {
"pad_token": "<|end_of_text|>",
"pad_token": "</s>",
},
"datasets": [
{
@@ -307,7 +447,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -373,7 +513,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,
@@ -432,7 +572,7 @@ class TestMultiGPULlama(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"micro_batch_size": 4,
"gradient_accumulation_steps": 4,
"output_dir": temp_dir,

View File

@@ -47,7 +47,7 @@ class TestMultiGPUQwen2(unittest.TestCase):
},
],
"num_epochs": 1,
"max_steps": 100,
"max_steps": 15,
"warmup_steps": 20,
"micro_batch_size": 4,
"gradient_accumulation_steps": 2,

View File

@@ -13,7 +13,7 @@ from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import require_torch_2_1_1, with_temp_dir
from ..utils import require_torch_2_3_1, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
@@ -24,7 +24,7 @@ class Test4dMultipackLlama(unittest.TestCase):
Test case for Llama models using 4d attention with multipack
"""
@require_torch_2_1_1
@require_torch_2_3_1
@with_temp_dir
def test_sdp_lora_packing(self, temp_dir):
# pylint: disable=duplicate-code

View File

@@ -115,6 +115,51 @@ class TestDPOLlamaLora(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@with_temp_dir
def test_dpo_use_weighting(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 64,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {},
"rl": "dpo",
"dpo_use_weighting": True,
"datasets": [
{
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
"type": "chatml.ultra",
"split": "train",
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "paged_adamw_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@pytest.mark.skip("kto_pair no longer supported in trl")
@with_temp_dir
def test_kto_pair_lora(self, temp_dir):

View File

@@ -9,6 +9,8 @@ from functools import wraps
from importlib.metadata import version
from pathlib import Path
import torch
def with_temp_dir(test_func):
@wraps(test_func)
@@ -35,13 +37,18 @@ def most_recent_subdir(path):
return subdir
def require_torch_2_1_1(test_case):
def require_torch_2_3_1(test_case):
"""
Decorator marking a test that requires torch >= 2.1.1
Decorator marking a test that requires torch >= 2.3.1
"""
def is_min_2_1_1():
def is_min_2_3_1():
torch_version = version("torch")
return torch_version >= "2.1.1"
return torch_version >= "2.3.1"
return unittest.skipUnless(is_min_2_1_1(), "test torch 2.1.1")(test_case)
return unittest.skipUnless(is_min_2_3_1(), "test torch 2.3.1")(test_case)
def is_hopper():
compute_capability = torch.cuda.get_device_capability()
return compute_capability == (9, 0)

View File

@@ -367,43 +367,44 @@ class TestDatasetPreparation(unittest.TestCase):
def test_load_local_hub_with_revision(self):
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path("mhenrichsen/alpaca_2k_test")
tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
revision="d05c1cb",
)
with tempfile.TemporaryDirectory() as tmp_dir2:
tmp_ds_path = Path(tmp_dir2) / "mhenrichsen/alpaca_2k_test"
tmp_ds_path.mkdir(parents=True, exist_ok=True)
snapshot_download(
repo_id="mhenrichsen/alpaca_2k_test",
repo_type="dataset",
local_dir=tmp_ds_path,
revision="d05c1cb",
)
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"ds_type": "parquet",
"type": "alpaca",
"data_files": [
"mhenrichsen/alpaca_2k_test/alpaca_2000.parquet",
],
"revision": "d05c1cb",
},
],
}
)
prepared_path = Path(tmp_dir) / "prepared"
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"sequence_len": 1024,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"ds_type": "parquet",
"type": "alpaca",
"data_files": [
f"{tmp_ds_path}/alpaca_2000.parquet",
],
"revision": "d05c1cb",
},
],
}
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
dataset, _ = load_tokenized_prepared_datasets(
self.tokenizer, cfg, prepared_path
)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
if __name__ == "__main__":