Compare commits

..

3 Commits

Author SHA1 Message Date
Wing Lian
efa1209a92 add smoke test training 2024-10-30 15:40:27 -04:00
Wing Lian
67b9e31bbc make sure to set alternate optimizer and set lr and eps from adam 2024-10-30 15:33:37 -04:00
Wing Lian
ad60916323 add soap optimizer support 2024-10-30 15:33:37 -04:00
16 changed files with 640 additions and 299 deletions

View File

@@ -40,7 +40,7 @@ jobs:
cuda_version: 12.4.1 cuda_version: 12.4.1
cudnn_version: "" cudnn_version: ""
python_version: "3.11" python_version: "3.11"
pytorch: 2.5.1 pytorch: 2.5.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
steps: steps:
- name: Checkout - name: Checkout

View File

@@ -82,6 +82,13 @@ jobs:
num_gpus: 1 num_gpus: 1
axolotl_extras: mamba-ssm axolotl_extras: mamba-ssm
nightly_build: "true" nightly_build: "true"
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
nightly_build: "true"
- cuda: 124 - cuda: 124
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"

View File

@@ -72,52 +72,12 @@ jobs:
run: | run: |
find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \; find "$(pip cache dir)/http-v2" -type f -mtime +14 -exec rm {} \;
docker-e2e-tests-1st:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
timeout-minutes: 90
needs: [pre-commit, pytest]
strategy:
fail-fast: false
matrix:
include:
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
steps:
- name: Checkout
uses: actions/checkout@v4
- name: Install Python
uses: actions/setup-python@v5
with:
python-version: "3.10"
- name: Install Modal
run: |
python -m pip install --upgrade pip
pip install modal==0.63.64 jinja2
- name: Update env vars
run: |
echo "BASE_TAG=main-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}" >> $GITHUB_ENV
echo "PYTORCH_VERSION=${{ matrix.pytorch}}" >> $GITHUB_ENV
echo "AXOLOTL_ARGS=${{ matrix.axolotl_args}}" >> $GITHUB_ENV
echo "AXOLOTL_EXTRAS=${{ matrix.axolotl_extras}}" >> $GITHUB_ENV
echo "CUDA=${{ matrix.cuda }}" >> $GITHUB_ENV
echo "N_GPUS=${{ matrix.num_gpus }}" >> $GITHUB_ENV
- name: Run tests job on Modal
run: |
modal run cicd.tests
docker-e2e-tests: docker-e2e-tests:
if: github.repository_owner == 'axolotl-ai-cloud' if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners... # this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal] runs-on: [self-hosted, modal]
timeout-minutes: 90 timeout-minutes: 90
needs: [pre-commit, pytest, docker-e2e-tests-1st] needs: [pre-commit, pytest]
strategy: strategy:
fail-fast: false fail-fast: false
@@ -129,6 +89,18 @@ jobs:
pytorch: 2.3.1 pytorch: 2.3.1
num_gpus: 1 num_gpus: 1
axolotl_extras: mamba-ssm axolotl_extras: mamba-ssm
- cuda: 121
cuda_version: 12.1.1
python_version: "3.11"
pytorch: 2.3.1
num_gpus: 1
axolotl_extras: mamba-ssm
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.4.1
num_gpus: 1
axolotl_extras:
- cuda: 124 - cuda: 124
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"

View File

@@ -183,8 +183,6 @@ test_datasets:
# use RL training: 'dpo', 'ipo', 'kto' # use RL training: 'dpo', 'ipo', 'kto'
rl: rl:
# whether to perform weighting if doing DPO training. Boolean.
dpo_use_weighting:
# The name of the chat template to use for training, following values are supported: # The name of the chat template to use for training, following values are supported:
# - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value. # - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.

View File

@@ -43,7 +43,7 @@ s3fs>=2024.5.0
gcsfs>=2024.5.0 gcsfs>=2024.5.0
# adlfs # adlfs
trl @ git++https://github.com/huggingface/trl.git@5e90682836969310e16ed8aa711dd429f85863b7 trl @ git+https://github.com/huggingface/trl.git@31d02cfb795284591a084416b9dcb7bef5d08924
zstandard==0.22.0 zstandard==0.22.0
fastcore fastcore

View File

@@ -48,7 +48,6 @@ from trl import (
) )
from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length from trl.trainer.utils import RewardDataCollatorWithPadding, pad_to_length
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils import is_comet_available, is_mlflow_available from axolotl.utils import is_comet_available, is_mlflow_available
@@ -436,7 +435,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
if ( if (
self.args.loraplus_lr_ratio is None self.args.loraplus_lr_ratio is None
and self.args.alternate_optimizer and self.args.alternate_optimizer
not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"] not in [
"optimi_adamw",
"ao_adamw_8bit",
"ao_adamw_4bit",
"ao_adamw_fp8",
"soap",
]
): ):
return super().create_optimizer() return super().create_optimizer()
@@ -479,6 +484,25 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
loraplus_lr_embedding=loraplus_lr_embedding, loraplus_lr_embedding=loraplus_lr_embedding,
**optimizer_kwargs, **optimizer_kwargs,
) )
elif self.args.alternate_optimizer == "soap":
from axolotl.utils.optimizers.soap import SOAP
optim_args = {
"lr": optimizer_kwargs.pop("lr"),
"eps": optimizer_kwargs.pop("eps"),
}
if self.cfg.optim_args:
optim_args.update(self.cfg.optim_args)
optim_args["betas"] = (
self.args.optim_soap_beta1,
self.args.optim_soap_beta2,
)
self.optimizer = SOAP( # pylint: disable=attribute-defined-outside-init
optimizer_grouped_parameters,
**optim_args,
)
elif self.args.alternate_optimizer == "optimi_adamw": elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW from optimi import AdamW
@@ -1148,12 +1172,6 @@ class TrainerBuilderBase(abc.ABC):
def get_callbacks(self) -> List[TrainerCallback]: def get_callbacks(self) -> List[TrainerCallback]:
callbacks = [] callbacks = []
plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_pre_trainer(cfg=self.cfg, model=self.model)
)
if self.cfg.use_wandb: if self.cfg.use_wandb:
callbacks.append( callbacks.append(
SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path) SaveAxolotlConfigtoWandBCallback(self.cfg.axolotl_config_path)
@@ -1180,17 +1198,11 @@ class TrainerBuilderBase(abc.ABC):
return callbacks return callbacks
@abstractmethod
def get_post_trainer_create_callbacks(self, trainer): def get_post_trainer_create_callbacks(self, trainer):
""" """
Callbacks added after the trainer is created, usually b/c these need access to the trainer Callbacks added after the trainer is created, usually b/c these need access to the trainer
""" """
callbacks = []
plugin_manager = PluginManager.get_instance()
callbacks.extend(
plugin_manager.add_callbacks_post_trainer(cfg=self.cfg, trainer=trainer)
)
return callbacks
def hook_pre_create_training_args(self, training_arguments_kwargs): def hook_pre_create_training_args(self, training_arguments_kwargs):
# TODO # TODO
@@ -1236,7 +1248,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return callbacks return callbacks
def get_post_trainer_create_callbacks(self, trainer): def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer) callbacks = []
if self.cfg.use_wandb and self.cfg.eval_table_size > 0: if self.cfg.use_wandb and self.cfg.eval_table_size > 0:
LogPredictionCallback = log_prediction_callback_factory( LogPredictionCallback = log_prediction_callback_factory(
trainer, self.tokenizer, "wandb" trainer, self.tokenizer, "wandb"
@@ -1626,10 +1638,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs["max_length"] = self.cfg.sequence_len trainer_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.optimizer in [ if self.cfg.optimizer in [
# pylint: disable=duplicate-code
"optimi_adamw", "optimi_adamw",
"ao_adamw_4bit", "ao_adamw_4bit",
"ao_adamw_8bit", "ao_adamw_8bit",
"ao_adamw_fp8", "ao_adamw_fp8",
"soap",
]: ]:
# Set default so transformers doesn't throw # Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf" training_arguments_kwargs["optim"] = "adamw_hf"
@@ -1804,7 +1818,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
return callbacks return callbacks
def get_post_trainer_create_callbacks(self, trainer): def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer) callbacks = []
return callbacks return callbacks
def build_training_arguments(self, total_num_steps): def build_training_arguments(self, total_num_steps):
@@ -1890,18 +1904,17 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
# default to saving each epoch if not defined # default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch" training_args_kwargs["save_strategy"] = "epoch"
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
if self.cfg.rl_beta: if self.cfg.rl_beta:
training_args_kwargs["beta"] = self.cfg.rl_beta training_args_kwargs["beta"] = self.cfg.rl_beta
if self.cfg.orpo_alpha: if self.cfg.orpo_alpha:
# trl does some odd mapping of alpha to beta to reuse the beta parameter ??? # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
training_args_kwargs["beta"] = self.cfg.orpo_alpha training_args_kwargs["beta"] = self.cfg.orpo_alpha
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args_cls = AxolotlDPOConfig
if self.cfg.rpo_alpha is not None: if self.cfg.rpo_alpha is not None:
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
training_args_cls = None
if self.cfg.rl == "simpo": if self.cfg.rl == "simpo":
training_args_cls = AxolotlCPOConfig training_args_cls = AxolotlCPOConfig
training_args_kwargs["loss_type"] = "simpo" training_args_kwargs["loss_type"] = "simpo"
@@ -1910,13 +1923,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.cpo_alpha is not None: if self.cfg.cpo_alpha is not None:
training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
elif self.cfg.rl == "orpo": if self.cfg.rl == "orpo":
training_args_cls = AxolotlORPOConfig training_args_cls = AxolotlORPOConfig
training_args_kwargs["max_length"] = self.cfg.sequence_len training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len: if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
elif self.cfg.rl == "kto": if self.cfg.rl == "kto":
training_args_cls = AxolotlKTOConfig training_args_cls = AxolotlKTOConfig
training_args_kwargs["desirable_weight"] = ( training_args_kwargs["desirable_weight"] = (
@@ -1926,32 +1939,11 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
self.cfg.kto_undesirable_weight or 1.0 self.cfg.kto_undesirable_weight or 1.0
) )
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
training_args_kwargs["max_length"] = self.cfg.sequence_len training_args_kwargs["max_length"] = self.cfg.sequence_len
if self.cfg.max_prompt_len: if self.cfg.max_prompt_len:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
else:
training_args_cls = AxolotlDPOConfig
training_args_kwargs["max_length"] = self.cfg.sequence_len
training_args_kwargs["max_target_length"] = None
if self.cfg.max_prompt_len is not None:
training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
if self.cfg.dpo_use_weighting is not None:
training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
if self.cfg.rl == "ipo":
training_args_kwargs["loss_type"] = "ipo"
if self.cfg.dpo_label_smoothing:
training_args_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
if self.cfg.precompute_ref_log_probs is not None:
training_args_kwargs["precompute_ref_log_probs"] = self.cfg.precompute_ref_log_probs
training_args_kwargs["generate_during_eval"] = self.cfg.use_wandb
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
output_dir=self.cfg.output_dir, output_dir=self.cfg.output_dir,
per_device_train_batch_size=self.cfg.micro_batch_size, per_device_train_batch_size=self.cfg.micro_batch_size,
@@ -1971,16 +1963,27 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
def build(self, total_num_steps): def build(self, total_num_steps):
training_args = self.build_training_arguments(total_num_steps) training_args = self.build_training_arguments(total_num_steps)
dpo_trainer_kwargs = {} dpo_trainer_kwargs = {}
if self.cfg.rl == "ipo":
dpo_trainer_kwargs["loss_type"] = "ipo"
if self.cfg.dpo_label_smoothing:
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
if self.eval_dataset: if self.eval_dataset:
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config: if self.cfg.adapter and self.peft_config:
dpo_trainer_kwargs["peft_config"] = self.peft_config dpo_trainer_kwargs["peft_config"] = self.peft_config
if self.cfg.precompute_ref_log_probs is not None:
dpo_trainer_kwargs[
"precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs
if self.cfg.rl in ["dpo", "ipo"]: if self.cfg.rl in ["dpo", "ipo"]:
trainer_cls = AxolotlDPOTrainer trainer_cls = AxolotlDPOTrainer
trainer_cls_args = [self.model, self.model_ref] trainer_cls_args = [self.model, self.model_ref]
# these aren't used for the ORPO trainer
dpo_trainer_kwargs["max_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["max_target_length"] = None
dpo_trainer_kwargs["max_prompt_length"] = self.cfg.sequence_len
dpo_trainer_kwargs["generate_during_eval"] = self.cfg.use_wandb
elif self.cfg.rl == "orpo": elif self.cfg.rl == "orpo":
trainer_cls = AxolotlORPOTrainer trainer_cls = AxolotlORPOTrainer
trainer_cls_args = [self.model] trainer_cls_args = [self.model]
@@ -2024,11 +2027,11 @@ class HFPPOTrainerBuilder(TrainerBuilderBase):
""" """
def get_callbacks(self): def get_callbacks(self):
callbacks = super().get_callbacks() callbacks = []
return callbacks return callbacks
def get_post_trainer_create_callbacks(self, trainer): def get_post_trainer_create_callbacks(self, trainer):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer) callbacks = []
return callbacks return callbacks
def build(self, total_num_steps): def build(self, total_num_steps):

View File

@@ -18,10 +18,9 @@ Plugins can be used to integrate third-party models, modify the training process
To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods. To create a new plugin, you need to inherit from the BasePlugin class and implement the required methods.
""" """
import collections
import importlib import importlib
import logging import logging
from typing import OrderedDict from typing import List
class BasePlugin: class BasePlugin:
@@ -48,7 +47,7 @@ class BasePlugin:
Initializes the BasePlugin. Initializes the BasePlugin.
""" """
def register(self, cfg): # pylint: disable=unused-argument def register(self, cfg):
""" """
Registers the plugin with the given configuration. Registers the plugin with the given configuration.
@@ -64,7 +63,7 @@ class BasePlugin:
Returns a pydantic model for the plugin's input arguments. Returns a pydantic model for the plugin's input arguments.
""" """
def pre_model_load(self, cfg): # pylint: disable=unused-argument def pre_model_load(self, cfg):
""" """
Performs actions before the model is loaded. Performs actions before the model is loaded.
@@ -75,7 +74,7 @@ class BasePlugin:
None None
""" """
def post_model_load(self, cfg, model): # pylint: disable=unused-argument def post_model_load(self, cfg, model):
""" """
Performs actions after the model is loaded. Performs actions after the model is loaded.
@@ -87,7 +86,7 @@ class BasePlugin:
None None
""" """
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument def pre_lora_load(self, cfg, model):
""" """
Performs actions before LoRA weights are loaded. Performs actions before LoRA weights are loaded.
@@ -99,7 +98,7 @@ class BasePlugin:
None None
""" """
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument def post_lora_load(self, cfg, model):
""" """
Performs actions after LoRA weights are loaded. Performs actions after LoRA weights are loaded.
@@ -111,7 +110,7 @@ class BasePlugin:
None None
""" """
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument def create_optimizer(self, cfg, trainer):
""" """
Creates and returns an optimizer for training. Creates and returns an optimizer for training.
@@ -123,9 +122,7 @@ class BasePlugin:
object: The created optimizer. object: The created optimizer.
""" """
def create_lr_scheduler( def create_lr_scheduler(self, cfg, trainer, optimizer):
self, cfg, trainer, optimizer
): # pylint: disable=unused-argument
""" """
Creates and returns a learning rate scheduler. Creates and returns a learning rate scheduler.
@@ -138,7 +135,7 @@ class BasePlugin:
object: The created learning rate scheduler. object: The created learning rate scheduler.
""" """
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument def add_callbacks_pre_trainer(self, cfg, model):
""" """
Adds callbacks to the trainer before training. Adds callbacks to the trainer before training.
@@ -149,11 +146,8 @@ class BasePlugin:
Returns: Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs List[callable]: A list of callback functions to be added to the TrainingArgs
""" """
return []
def add_callbacks_post_trainer( def add_callbacks_post_trainer(self, cfg, trainer):
self, cfg, trainer
): # pylint: disable=unused-argument
""" """
Adds callbacks to the trainer after training. Adds callbacks to the trainer after training.
@@ -164,9 +158,8 @@ class BasePlugin:
Returns: Returns:
List[callable]: A list of callback functions to be added to the TrainingArgs List[callable]: A list of callback functions to be added to the TrainingArgs
""" """
return []
def post_train(self, cfg, model): # pylint: disable=unused-argument def post_train(self, cfg, model):
""" """
Performs actions after training is complete. Performs actions after training is complete.
@@ -178,7 +171,7 @@ class BasePlugin:
None None
""" """
def post_train_unload(self, cfg): # pylint: disable=unused-argument def post_train_unload(self, cfg):
""" """
Performs actions after training is complete and the model is unloaded. Performs actions after training is complete and the model is unloaded.
@@ -234,7 +227,7 @@ class PluginManager:
pre_model_load(cfg): Calls the pre_model_load method of all registered plugins. pre_model_load(cfg): Calls the pre_model_load method of all registered plugins.
""" """
plugins: OrderedDict[str, BasePlugin] = collections.OrderedDict() plugins: List[BasePlugin] = []
_instance = None _instance = None
@@ -244,7 +237,7 @@ class PluginManager:
""" """
if cls._instance is None: if cls._instance is None:
cls._instance = super(PluginManager, cls).__new__(cls) cls._instance = super(PluginManager, cls).__new__(cls)
cls._instance.plugins = collections.OrderedDict() cls._instance.plugins: List[BasePlugin] = []
return cls._instance return cls._instance
@staticmethod @staticmethod
@@ -272,7 +265,7 @@ class PluginManager:
""" """
try: try:
plugin = load_plugin(plugin_name) plugin = load_plugin(plugin_name)
self.plugins[plugin_name] = plugin self.plugins.append(plugin)
except ImportError: except ImportError:
logging.error(f"Failed to load plugin: {plugin_name}") logging.error(f"Failed to load plugin: {plugin_name}")
@@ -284,7 +277,7 @@ class PluginManager:
list[str]: A list of Pydantic classes for all registered plugins' input arguments.' list[str]: A list of Pydantic classes for all registered plugins' input arguments.'
""" """
input_args = [] input_args = []
for plugin in self.plugins.values(): for plugin in self.plugins:
input_args_from_plugin = plugin.get_input_args() input_args_from_plugin = plugin.get_input_args()
if input_args_from_plugin is not None: if input_args_from_plugin is not None:
input_args.append(input_args_from_plugin) input_args.append(input_args_from_plugin)
@@ -300,7 +293,7 @@ class PluginManager:
Returns: Returns:
None None
""" """
for plugin in self.plugins.values(): for plugin in self.plugins:
plugin.pre_model_load(cfg) plugin.pre_model_load(cfg)
def post_model_load(self, cfg, model): def post_model_load(self, cfg, model):
@@ -314,7 +307,7 @@ class PluginManager:
Returns: Returns:
None None
""" """
for plugin in self.plugins.values(): for plugin in self.plugins:
plugin.post_model_load(cfg, model) plugin.post_model_load(cfg, model)
def pre_lora_load(self, cfg, model): def pre_lora_load(self, cfg, model):
@@ -328,7 +321,7 @@ class PluginManager:
Returns: Returns:
None None
""" """
for plugin in self.plugins.values(): for plugin in self.plugins:
plugin.pre_lora_load(cfg, model) plugin.pre_lora_load(cfg, model)
def post_lora_load(self, cfg, model): def post_lora_load(self, cfg, model):
@@ -342,7 +335,7 @@ class PluginManager:
Returns: Returns:
None None
""" """
for plugin in self.plugins.values(): for plugin in self.plugins:
plugin.post_lora_load(cfg, model) plugin.post_lora_load(cfg, model)
def create_optimizer(self, cfg, trainer): def create_optimizer(self, cfg, trainer):
@@ -356,7 +349,7 @@ class PluginManager:
Returns: Returns:
object: The created optimizer, or None if none was found. object: The created optimizer, or None if none was found.
""" """
for plugin in self.plugins.values(): for plugin in self.plugins:
optimizer = plugin.create_optimizer(cfg, trainer) optimizer = plugin.create_optimizer(cfg, trainer)
if optimizer is not None: if optimizer is not None:
return optimizer return optimizer
@@ -374,7 +367,7 @@ class PluginManager:
Returns: Returns:
object: The created learning rate scheduler, or None if none was found. object: The created learning rate scheduler, or None if none was found.
""" """
for plugin in self.plugins.values(): for plugin in self.plugins:
scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer) scheduler = plugin.create_lr_scheduler(cfg, trainer, optimizer)
if scheduler is not None: if scheduler is not None:
return scheduler return scheduler
@@ -392,7 +385,7 @@ class PluginManager:
List[callable]: A list of callback functions to be added to the TrainingArgs. List[callable]: A list of callback functions to be added to the TrainingArgs.
""" """
callbacks = [] callbacks = []
for plugin in self.plugins.values(): for plugin in self.plugins:
callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model)) callbacks.extend(plugin.add_callbacks_pre_trainer(cfg, model))
return callbacks return callbacks
@@ -408,7 +401,7 @@ class PluginManager:
List[callable]: A list of callback functions to be added to the TrainingArgs. List[callable]: A list of callback functions to be added to the TrainingArgs.
""" """
callbacks = [] callbacks = []
for plugin in self.plugins.values(): for plugin in self.plugins:
callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer)) callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer))
return callbacks return callbacks
@@ -423,5 +416,5 @@ class PluginManager:
Returns: Returns:
None None
""" """
for plugin in self.plugins.values(): for plugin in self.plugins:
plugin.post_train_unload(cfg) plugin.post_train_unload(cfg)

View File

@@ -427,6 +427,7 @@ class HyperparametersConfig(BaseModel):
"ao_adamw_4bit", "ao_adamw_4bit",
"ao_adamw_8bit", "ao_adamw_8bit",
"ao_adamw_fp8", "ao_adamw_fp8",
"soap",
], ],
] ]
] = OptimizerNames.ADAMW_HF.value ] = OptimizerNames.ADAMW_HF.value
@@ -439,6 +440,10 @@ class HyperparametersConfig(BaseModel):
"help": "The target modules to optimize, i.e. the module names that you would like to train." "help": "The target modules to optimize, i.e. the module names that you would like to train."
}, },
) )
optim_soap_beta1: Optional[float] = None
optim_soap_beta2: Optional[float] = None
torchdistx_path: Optional[str] = None torchdistx_path: Optional[str] = None
lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine" lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
@@ -588,9 +593,6 @@ class AxolotlInputConfig(
rl: Optional[RLType] = None rl: Optional[RLType] = None
reward_model: Optional[bool] = None reward_model: Optional[bool] = None
dpo_use_weighting: Optional[
bool
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore

View File

@@ -2,11 +2,9 @@
import functools import functools
import logging import logging
import time
from pathlib import Path from pathlib import Path
from typing import List, Optional, Tuple, Union from typing import List, Optional, Tuple, Union
import requests
from datasets import ( from datasets import (
Dataset, Dataset,
DatasetDict, DatasetDict,
@@ -55,28 +53,6 @@ from axolotl.utils.trainer import (
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
def retry_on_request_exceptions(max_retries=3, delay=1):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
) as exc:
if attempt < max_retries - 1:
time.sleep(delay)
else:
raise exc
return wrapper
return decorator
@retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_dataset(cfg, tokenizer, processor=None): def prepare_dataset(cfg, tokenizer, processor=None):
prompters = [] prompters = []
if not cfg.pretraining_dataset: if not cfg.pretraining_dataset:

View File

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2024 Nikhil Vyas
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,475 @@
# pylint: skip-file
# Copied from https://github.com/nikhilvyas/SOAP
from itertools import chain
import torch
import torch.optim as optim
# Parts of the code are modifications of Pytorch's AdamW optimizer
# Parts of the code are modifications of code from https://github.com/jiaweizzhao/GaLore/blob/master/galore_torch/galore_projector.py
class SOAP(optim.Optimizer):
"""
Implements SOAP algorithm (https://arxiv.org/abs/2409.11321).
Parameters:
params (`Iterable[nn.parameter.Parameter]`):
Iterable of parameters to optimize or dictionaries defining parameter groups.
lr (`float`, *optional*, defaults to 0.003):
The learning rate to use.
betas (`Tuple[float,float]`, *optional*, defaults to `(0.95, 0.95)`):
Adam's betas parameters (b1, b2).
shampoo_beta (`float`, *optional*, defaults to -1):
If >= 0, use this beta for the preconditioner (L and R in paper, state['GG'] below) moving average instead of betas[1].
eps (`float`, *optional*, defaults to 1e-08):
Adam's epsilon for numerical stability.
weight_decay (`float`, *optional*, defaults to 0.01): weight decay coefficient.
precondition_frequency (`int`, *optional*, defaults to 10):
How often to update the preconditioner.
max_precond_dim (`int`, *optional*, defaults to 10000):
Maximum dimension of the preconditioner.
Set to 10000, so that we exclude most common vocab sizes while including layers.
merge_dims (`bool`, *optional*, defaults to `False`):
Whether or not to merge dimensions of the preconditioner.
precondition_1d (`bool`, *optional*, defaults to `False`):
Whether or not to precondition 1D gradients.
normalize_grads (`bool`, *optional*, defaults to `False`):
Whether or not to normalize gradients per layer.
Helps at large precondition_frequency (~100 in our experiments),
but hurts performance at small precondition_frequency (~10 in our experiments).
data_format (`str`, *optional*, defaults to `channels_first`):
Data format of the input for convolutional layers.
Should be "channels_last" for data_format of NHWC and "channels_first" for NCHW.
correct_bias (`bool`, *optional*, defaults to `True`):
Whether or not to use bias correction in Adam.
"""
def __init__(
self,
params,
lr: float = 3e-3,
betas=(0.95, 0.95),
shampoo_beta: float = -1,
eps: float = 1e-8,
weight_decay: float = 0.01,
precondition_frequency: int = 10,
max_precond_dim: int = 10000, #
merge_dims: bool = False, # Merge dimensions till the product of the dimensions is less than or equal to max_precond_dim.
precondition_1d: bool = False,
normalize_grads: bool = False,
data_format: str = "channels_first",
correct_bias: bool = True,
):
defaults = {
"lr": lr,
"betas": betas,
"shampoo_beta": shampoo_beta,
"eps": eps,
"weight_decay": weight_decay,
"precondition_frequency": precondition_frequency,
"max_precond_dim": max_precond_dim,
"merge_dims": merge_dims,
"precondition_1d": precondition_1d,
"normalize_grads": normalize_grads,
"correct_bias": correct_bias,
}
super().__init__(params, defaults)
self._data_format = data_format
def merge_dims(self, grad, max_precond_dim):
"""
Merges dimensions of the gradient tensor till the product of the dimensions is less than or equal to max_precond_dim.
"""
assert self._data_format in ["channels_first", "channels_last"]
if self._data_format == "channels_last" and grad.dim() == 4:
grad = grad.permute(0, 3, 1, 2)
shape = grad.shape
new_shape = []
curr_shape = 1
for sh in shape:
temp_shape = curr_shape * sh
if temp_shape > max_precond_dim:
if curr_shape > 1:
new_shape.append(curr_shape)
curr_shape = sh
else:
new_shape.append(sh)
curr_shape = 1
else:
curr_shape = temp_shape
if curr_shape > 1 or len(new_shape) == 0:
new_shape.append(curr_shape)
new_grad = grad.reshape(new_shape)
return new_grad
@torch.no_grad()
def step(self):
"""
Performs a single optimization step.
Arguments:
closure (`Callable`, *optional*): A closure that reevaluates the model and returns the loss.
"""
loss = None
for group in self.param_groups:
for p in group["params"]:
if p.grad is None:
continue
grad = p.grad
state = self.state[p]
if "step" not in state:
state["step"] = 0
# State initialization
if "exp_avg" not in state:
# Exponential moving average of gradient values
state["exp_avg"] = torch.zeros_like(grad)
# Exponential moving average of squared gradient values
state["exp_avg_sq"] = torch.zeros_like(grad)
if "Q" not in state:
self.init_preconditioner(
grad,
state,
precondition_frequency=group["precondition_frequency"],
precondition_1d=group["precondition_1d"],
shampoo_beta=(
group["shampoo_beta"]
if group["shampoo_beta"] >= 0
else group["betas"][1]
),
max_precond_dim=group["max_precond_dim"],
merge_dims=group["merge_dims"],
)
self.update_preconditioner(
grad,
state,
max_precond_dim=group["max_precond_dim"],
merge_dims=group["merge_dims"],
precondition_1d=group["precondition_1d"],
)
continue # first step is skipped so that we never use the current gradients in the projection.
# Projecting gradients to the eigenbases of Shampoo's preconditioner
# i.e. projecting to the eigenbases of matrices in state['GG']
grad_projected = self.project(
grad,
state,
merge_dims=group["merge_dims"],
max_precond_dim=group["max_precond_dim"],
)
exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]
beta1, beta2 = group["betas"]
state["step"] += 1
# Decay the first and second moment running average coefficient
# In-place operations to update the averages at the same time
exp_avg.mul_(beta1).add_(grad, alpha=(1.0 - beta1))
exp_avg_sq.mul_(beta2).add_(
grad_projected.square(), alpha=(1.0 - beta2)
)
denom = exp_avg_sq.sqrt().add_(group["eps"])
# Projecting the exponential moving average of gradients to the eigenbases of Shampoo's preconditioner
# i.e. projecting to the eigenbases of matrices in state['GG']
exp_avg_projected = self.project(
exp_avg,
state,
merge_dims=group["merge_dims"],
max_precond_dim=group["max_precond_dim"],
)
step_size = group["lr"]
if group["correct_bias"]:
bias_correction1 = 1.0 - beta1 ** (state["step"])
bias_correction2 = 1.0 - beta2 ** (state["step"])
step_size = step_size * (bias_correction2**0.5) / bias_correction1
# Projecting back the preconditioned (by Adam) exponential moving average of gradients
# to the original space
norm_grad = self.project_back(
exp_avg_projected / denom,
state,
merge_dims=group["merge_dims"],
max_precond_dim=group["max_precond_dim"],
)
if group["normalize_grads"]:
norm_grad = norm_grad / (1e-30 + torch.mean(norm_grad**2) ** 0.5)
p.add_(norm_grad, alpha=-step_size)
# From AdamW code: Just adding the square of the weights to the loss function is *not*
# the correct way of using L2 regularization/weight decay with Adam,
# since that will interact with the m and v parameters in strange ways.
#
# Instead we want to decay the weights in a manner that doesn't interact
# with the m/v parameters. This is equivalent to adding the square
# of the weights to the loss with plain (non-momentum) SGD.
# Add weight decay at the end (fixed version)
if group["weight_decay"] > 0.0:
p.add_(p, alpha=(-group["lr"] * group["weight_decay"]))
# Update is done after the gradient step to avoid using current gradients in the projection.
self.update_preconditioner(
grad,
state,
max_precond_dim=group["max_precond_dim"],
merge_dims=group["merge_dims"],
precondition_1d=group["precondition_1d"],
)
return loss
def init_preconditioner(
self,
grad,
state,
precondition_frequency=10,
shampoo_beta=0.95,
max_precond_dim=10000,
precondition_1d=False,
merge_dims=False,
):
"""
Initializes the preconditioner matrices (L and R in the paper).
"""
state[
"GG"
] = [] # Will hold all the preconditioner matrices (L and R in the paper).
if grad.dim() == 1:
if not precondition_1d or grad.shape[0] > max_precond_dim:
state["GG"].append([])
else:
state["GG"].append(
torch.zeros(grad.shape[0], grad.shape[0], device=grad.device)
)
else:
if merge_dims:
grad = self.merge_dims(grad, max_precond_dim)
for sh in grad.shape:
if sh > max_precond_dim:
state["GG"].append([])
else:
state["GG"].append(torch.zeros(sh, sh, device=grad.device))
state["Q"] = None # Will hold all the eigenbases of the preconditioner.
state["precondition_frequency"] = precondition_frequency
state["shampoo_beta"] = shampoo_beta
def project(self, grad, state, merge_dims=False, max_precond_dim=10000):
"""
Projects the gradient to the eigenbases of the preconditioner.
"""
original_shape = grad.shape
if merge_dims:
if grad.dim() == 4 and self._data_format == "channels_last":
permuted_shape = grad.permute(0, 3, 1, 2).shape
grad = self.merge_dims(grad, max_precond_dim)
for mat in state["Q"]:
if len(mat) > 0:
grad = torch.tensordot(
grad,
mat,
dims=[[0], [0]],
)
else:
permute_order = list(range(1, len(grad.shape))) + [0]
grad = grad.permute(permute_order)
if merge_dims:
if self._data_format == "channels_last" and len(original_shape) == 4:
grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
else:
grad = grad.reshape(original_shape)
return grad
def update_preconditioner(
self,
grad,
state,
max_precond_dim=10000,
merge_dims=False,
precondition_1d=False,
):
"""
Updates the preconditioner matrices and the eigenbases (L, R, Q_L, Q_R in the paper).
"""
if grad.dim() == 1:
if precondition_1d and grad.shape[0] <= max_precond_dim:
state["GG"][0].lerp_(
grad.unsqueeze(1) @ grad.unsqueeze(0), 1 - state["shampoo_beta"]
)
else:
if merge_dims:
new_grad = self.merge_dims(grad, max_precond_dim)
for idx, sh in enumerate(new_grad.shape):
if sh <= max_precond_dim:
outer_product = torch.tensordot(
new_grad,
new_grad,
dims=[
[
*chain(
range(idx), range(idx + 1, len(new_grad.shape))
)
]
]
* 2,
)
state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"])
else:
for idx, sh in enumerate(grad.shape):
if sh <= max_precond_dim:
outer_product = torch.tensordot(
grad,
grad,
# Contracts across all dimensions except for k.
dims=[[*chain(range(idx), range(idx + 1, len(grad.shape)))]]
* 2,
)
state["GG"][idx].lerp_(outer_product, 1 - state["shampoo_beta"])
if state["Q"] is None:
state["Q"] = self.get_orthogonal_matrix(state["GG"])
if state["step"] > 0 and state["step"] % state["precondition_frequency"] == 0:
state["Q"] = self.get_orthogonal_matrix_QR(
state, max_precond_dim, merge_dims
)
def project_back(self, grad, state, merge_dims=False, max_precond_dim=10000):
"""
Projects the gradient back to the original space.
"""
original_shape = grad.shape
if merge_dims:
if self._data_format == "channels_last" and grad.dim() == 4:
permuted_shape = grad.permute(0, 3, 1, 2).shape
grad = self.merge_dims(grad, max_precond_dim)
for mat in state["Q"]:
if len(mat) > 0:
grad = torch.tensordot(
grad,
mat,
dims=[[0], [1]],
)
else:
permute_order = list(range(1, len(grad.shape))) + [0]
grad = grad.permute(permute_order)
if merge_dims:
if self._data_format == "channels_last" and len(original_shape) == 4:
grad = grad.reshape(permuted_shape).permute(0, 2, 3, 1)
else:
grad = grad.reshape(original_shape)
return grad
def get_orthogonal_matrix(self, mat):
"""
Computes the eigenbases of the preconditioner using torch.linalg.eigh decomposition.
"""
matrix = []
for m in mat:
if len(m) == 0:
matrix.append([])
continue
if m.data.dtype != torch.float:
float_data = False
original_type = m.data.dtype
original_device = m.data.device
matrix.append(m.data.float())
else:
float_data = True
matrix.append(m.data)
final = []
for m in matrix:
if len(m) == 0:
final.append([])
continue
try:
_, Q = torch.linalg.eigh(
m + 1e-30 * torch.eye(m.shape[0], device=m.device)
)
except: # pylint: disable=bare-except # noqa: E722
_, Q = torch.linalg.eigh(
m.to(torch.float64) + 1e-30 * torch.eye(m.shape[0], device=m.device)
)
Q = Q.to(m.dtype)
Q = torch.flip(Q, [1])
if not float_data:
Q = Q.to(original_device).type(original_type)
final.append(Q)
return final
def get_orthogonal_matrix_QR(self, state, max_precond_dim=10000, merge_dims=False):
"""
Computes the eigenbases of the preconditioner using one round of power iteration
followed by torch.linalg.qr decomposition.
"""
precond_list = state["GG"]
orth_list = state["Q"]
matrix = []
orth_matrix = []
for m, o in zip(precond_list, orth_list):
if len(m) == 0:
matrix.append([])
orth_matrix.append([])
continue
if m.data.dtype != torch.float:
float_data = False
original_type = m.data.dtype
original_device = m.data.device
matrix.append(m.data.float())
orth_matrix.append(o.data.float())
else:
float_data = True
matrix.append(m.data.float())
orth_matrix.append(o.data.float())
orig_shape = state["exp_avg_sq"].shape
if self._data_format == "channels_last" and len(orig_shape) == 4:
permuted_shape = state["exp_avg_sq"].permute(0, 3, 1, 2).shape
if merge_dims:
exp_avg_sq = self.merge_dims(state["exp_avg_sq"], max_precond_dim)
else:
exp_avg_sq = state["exp_avg_sq"]
final = []
for ind, (m, o) in enumerate(zip(matrix, orth_matrix)):
if len(m) == 0:
final.append([])
continue
est_eig = torch.diag(o.T @ m @ o)
sort_idx = torch.argsort(est_eig, descending=True)
exp_avg_sq = exp_avg_sq.index_select(ind, sort_idx)
o = o[:, sort_idx]
power_iter = m @ o
Q, _ = torch.linalg.qr(power_iter)
if not float_data:
Q = Q.to(original_device).type(original_type)
final.append(Q)
if merge_dims:
if self._data_format == "channels_last" and len(orig_shape) == 4:
exp_avg_sq = exp_avg_sq.reshape(permuted_shape).permute(0, 2, 3, 1)
else:
exp_avg_sq = exp_avg_sq.reshape(orig_shape)
state["exp_avg_sq"] = exp_avg_sq
return final

View File

@@ -1,59 +0,0 @@
base_model: JackFram/llama-68m
load_in_8bit: false
load_in_4bit: false
strict: false
datasets:
- path: arcee-ai/distilabel-intel-orca-dpo-pairs-binarized
type: chatml.ultra
split: train
dataset_prepared_path: last_run_prepared
val_set_size: 0.1
output_dir: ./outputs/out
sequence_len: 2048
pad_to_sequence_len: true
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: paged_adamw_8bit
lr_scheduler: cosine
learning_rate: 2e-5
train_on_inputs: false
group_by_length: false
bf16: auto
fp16:
tf32: false
gradient_checkpointing: true
gradient_checkpointing_kwargs:
use_reentrant: false
early_stopping_patience:
resume_from_checkpoint:
logging_steps: 1
xformers_attention:
flash_attention: true
rl: dpo
dpo_use_weighting: true
warmup_steps: 10
evals_per_epoch: 2
eval_table_size:
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
pad_token: <|end_of_text|>

View File

@@ -1,43 +0,0 @@
base_model: JackFram/llama-68m
load_in_8bit: true
datasets:
- path: arcee-ai/distilabel-intel-orca-dpo-pairs-binarized
type: chatml.ultra
split: train
output_dir: ./outputs/lora-out
sequence_len: 1024
adapter: lora
lora_r: 64
lora_alpha: 32
lora_dropout: 0.1
lora_target_linear: true
rl: dpo
dpo_use_weighting: true
wandb_project: check_dpotrainer
wandb_entity: axolotl-ai
wandb_watch:
wandb_name: baseline/dpo_base/dpo_use_weighting
wandb_log_model:
num_epochs: 1
micro_batch_size: 4
gradient_accumulation_steps: 1
learning_rate: 0.00001
optimizer: paged_adamw_8bit
lr_scheduler: cosine
max_steps": 20
save_steps: 10
warmup_steps: 5
gradient_checkpointing: True
gradient_checkpointing_kwargs:
use_reentrant: false
#special_tokens:
# pad_token: <|end_of_text|>

View File

@@ -115,51 +115,6 @@ class TestDPOLlamaLora(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists() assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@with_temp_dir
def test_dpo_use_weighting(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 64,
"lora_alpha": 32,
"lora_dropout": 0.1,
"lora_target_linear": True,
"special_tokens": {},
"rl": "dpo",
"dpo_use_weighting": True,
"datasets": [
{
"path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
"type": "chatml.ultra",
"split": "train",
},
],
"num_epochs": 1,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "paged_adamw_8bit",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"warmup_steps": 5,
"gradient_checkpointing": True,
"gradient_checkpointing_kwargs": {"use_reentrant": True},
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
@pytest.mark.skip("kto_pair no longer supported in trl") @pytest.mark.skip("kto_pair no longer supported in trl")
@with_temp_dir @with_temp_dir
def test_kto_pair_lora(self, temp_dir): def test_kto_pair_lora(self, temp_dir):

View File

@@ -65,3 +65,44 @@ class TestCustomOptimizers(unittest.TestCase):
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists() assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_soap(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM-135M",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 8,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"special_tokens": {
"pad_token": "<|endoftext|>",
},
"datasets": [
{
"path": "vicgalle/alpaca-gpt4",
"type": "alpaca",
},
],
"num_epochs": 1,
"micro_batch_size": 8,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "soap",
"optim_soap_beta1": 0.95,
"optim_soap_beta2": 0.95,
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()