diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index e2ad1f579..d6d19607f 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -149,6 +149,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): return AxolotlTrainer def build(self, total_num_steps): + from axolotl.core.training_args import ( + AxolotlPRMConfig, + AxolotlRewardConfig, + AxolotlTrainingArguments, + ) + training_arguments_kwargs, trainer_kwargs = self._set_base_training_args( total_num_steps ) @@ -317,12 +323,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs["image_resize_algorithm"] = ( self.cfg.image_resize_algorithm ) - if self.cfg.kd_ce_alpha is not None: - training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha - if self.cfg.kd_alpha is not None: - training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha - if self.cfg.kd_temperature is not None: - training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature + + if self.cfg.plugins: + plugin_manager = PluginManager.get_instance() + plugin_training_args = plugin_manager.get_training_args(self.cfg) + if plugin_training_args: + training_arguments_kwargs.update(plugin_training_args) if self.cfg.reward_model: training_args_cls = AxolotlRewardConfig @@ -403,7 +409,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): return trainer def build_collator( - self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs + self, + training_args, # type: "AxolotlTrainingArguments" + is_eval=False, + **kwargs, ): if training_args.pretraining: if ( diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py index 14dbfa715..e7591131e 100644 --- a/src/axolotl/core/builders/rl.py +++ b/src/axolotl/core/builders/rl.py @@ -12,11 +12,6 @@ from axolotl.core.trainers import ( from axolotl.core.trainers.dpo import DPOStrategy from axolotl.core.trainers.dpo.args import AxolotlDPOConfig from 
axolotl.core.trainers.grpo import GRPOStrategy -from axolotl.core.training_args import ( - AxolotlCPOConfig, - AxolotlKTOConfig, - AxolotlORPOConfig, -) from axolotl.integrations.base import PluginManager from axolotl.loaders.utils import ensure_dtype from axolotl.utils.logging import get_logger @@ -79,6 +74,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase): """ Returns training_args and trainer_kwargs """ + from axolotl.core.training_args import ( + AxolotlCPOConfig, + AxolotlKTOConfig, + AxolotlORPOConfig, + ) + training_args_kwargs, trainer_kwargs = self._set_base_training_args( total_num_steps=total_num_steps ) @@ -165,6 +166,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if blocklist_key in training_args_kwargs: del training_args_kwargs[blocklist_key] + + if self.cfg.plugins: + plugin_manager = PluginManager.get_instance() + plugin_training_args = plugin_manager.get_training_args(self.cfg) + if plugin_training_args: + training_args_kwargs.update(plugin_training_args) + training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg logging_first_step=True, **training_args_kwargs, diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py index 03cad93e7..d5be9fc62 100644 --- a/src/axolotl/core/training_args.py +++ b/src/axolotl/core/training_args.py @@ -2,224 +2,17 @@ extra axolotl specific training args """ -from dataclasses import dataclass, field -from typing import Optional +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional, Type -from PIL.Image import Resampling from transformers import TrainingArguments from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig +from axolotl.integrations.config import merge_training_args -@dataclass -class AxolotlTrainingMixins: - """ - Mixin class for the Axolotl training args. 
- """ - - # pylint: disable=duplicate-code - model_type: Optional[str] = field( - default=None, metadata={"help": "HF model configuration model_type."} - ) - lr_quadratic_warmup: bool = field( - default=False, - metadata={"help": "Use quadratic warmup for cosine scheduling."}, - ) - pretraining: bool = field( - default=False, - metadata={ - "help": "Indicates to trainer whether we are doing continued pretraining." - }, - ) - sample_packing: bool = field( - default=False, - metadata={"help": "Use sample packing for efficient training."}, - ) - sample_packing_sequentially: bool = field( - default=False, - metadata={ - "help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing." - }, - ) - multipack_real_batches: bool = field( - default=False, - metadata={"help": "Use real batches for efficient training."}, - ) - eval_sample_packing: Optional[bool] = field( - default=None, - metadata={"help": "Use sample packing for efficient evals."}, - ) - sample_packing_efficiency: float = field( - default=1.0, - metadata={"help": "Sample packing efficiency for calculating batch length."}, - ) - sample_packing_bin_size: int = field( - default=200, - metadata={ - "help": "The max number of samples that packed sample can contain after packing. Increase for better packing." - }, - ) - sample_packing_group_size: int = field( - default=100000, - metadata={ - "help": "The number of samples to group together for packing. Increase for better packing." 
- }, - ) - max_seq_length: int = field( - default=2048, - metadata={"help": "The maximum sequence length the model can handle"}, - ) - relora_steps: Optional[int] = field( - default=None, - metadata={"help": "how often to reset for ReLoRA"}, - ) - relora_warmup_steps: Optional[int] = field( - default=None, - metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, - ) - relora_anneal_steps: Optional[int] = field( - default=None, - metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, - ) - relora_prune_ratio: Optional[float] = field( - default=0.9, - metadata={"help": "prune ratio for magnitude pruning of the optimizer"}, - ) - bench_split: Optional[str] = field( - default="eval", metadata={"help": "The benchmark split to run on"} - ) - bench_dataset: Optional[str] = field( - default="pharaouk/dharma-1/dharma_1_mini.json", - metadata={ - "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file" - }, - ) - do_bench_eval: Optional[bool] = field( - default=False, metadata={"help": "Whether to run the Benchmark evaluation."} - ) - do_causal_lm_eval: Optional[bool] = field( - default=False, metadata={"help": "Whether to run the Causal LM evaluation."} - ) - max_bench_samples: Optional[int] = field( - default=None, - metadata={ - "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset." 
- }, - ) - bench_source_max_len: int = field( - default=2048, metadata={"help": "Maximum source sequence length for bench."} - ) - dataloader_prefetch_factor: Optional[int] = field( - default=None, - metadata={"help": "prefetch_factor argument to the dataloader"}, - ) - cosine_min_lr_ratio: Optional[float] = field( - default=None, - metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"}, - ) - cosine_constant_lr_ratio: Optional[float] = field( - default=None, - metadata={ - "help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps" - }, - ) - loraplus_lr_ratio: Optional[float] = field( - default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."} - ) - loraplus_lr_embedding: Optional[float] = field( - default=1e-6, - metadata={"help": "loraplus learning rate for lora embedding layers."}, - ) - embedding_lr_scale: Optional[float] = field( - default=None, - metadata={"help": "Scale the learning rate for the embedding layers."}, - ) - lr_groups: Optional[list[dict]] = field( - default=None, - metadata={"help": "Specify learning rate groups for with different LRs."}, - ) - embedding_lr: Optional[float] = field( - default=None, - metadata={"help": "absolute learning rate for the embedding layers."}, - ) - qlora: bool = field( - default=False, - metadata={"help": "whether this is a qlora training"}, - ) - orpo_alpha: Optional[float] = field( - default=None, - ) - lisa_n_layers: Optional[int] = field( - default=None, - metadata={"help": "the number of activate layers in LISA"}, - ) - lisa_step_interval: Optional[int] = field( - default=None, - metadata={"help": "how often to switch layers in LISA"}, - ) - lisa_layers_attribute: Optional[str] = field( - default=None, - metadata={"help": "path under the model to access the layers"}, - ) - curriculum_sampling: Optional[bool] = field( - default=None, - metadata={"help": "whether to use sequential sampling for curriculum learning"}, - ) - 
alternate_lr_scheduler_type: Optional[str] = field( - default=None, - metadata={ - "help": "workaround to pass an alternate lr scheduler to the HF trainer" - }, - ) - chat_template: Optional[str] = field( - default=None, - metadata={"help": "Chat template converting chat messages to text"}, - ) - - kd_ce_alpha: Optional[float] = field( - default=None, - metadata={ - "help": "The alpha scaling parameter for SFT cross entropy loss when using KD" - }, - ) - - kd_alpha: Optional[float] = field( - default=1.0, - metadata={"help": "The alpha scaling parameter for KD loss"}, - ) - - kd_temperature: Optional[float] = field( - default=1.0, - metadata={ - "help": "the temperature parameter for KL divergence loss when using KD" - }, - ) - - adam_beta3: Optional[float] = field( - default=None, - metadata={ - "help": "The beta3 hyperparameter used in some optimizers such as CAME" - }, - ) - adam_epsilon2: Optional[float] = field( - default=None, - metadata={ - "help": "The epsilon2 hyperparameter used in some optimizers such as CAME" - }, - ) - - # multi-modal section - - image_size: int | tuple[int, int] | None = field( - default=None, - metadata={"help": "The size of the image to resize to"}, - ) - - image_resize_algorithm: Resampling | None = field( - default=None, - metadata={"help": "The algorithm to use for image resizing"}, - ) - - # end of multi-modal section +AxolotlTrainingMixins: Type = merge_training_args() @dataclass diff --git a/src/axolotl/core/training_args_base.py b/src/axolotl/core/training_args_base.py new file mode 100644 index 000000000..fbd387e5f --- /dev/null +++ b/src/axolotl/core/training_args_base.py @@ -0,0 +1,222 @@ +""" +Base Axolotl Training Mixins shared across various trainer configs +""" + +from dataclasses import dataclass, field +from typing import Optional + +from PIL.Image import Resampling + +from axolotl.utils.schemas.enums import RingAttnFunc + + +@dataclass +class AxolotlTrainingMixins: + """ + Mixin class for the Axolotl training args. 
+ """ + + # pylint: disable=duplicate-code + model_type: Optional[str] = field( + default=None, metadata={"help": "HF model configuration model_type."} + ) + lr_quadratic_warmup: bool = field( + default=False, + metadata={"help": "Use quadratic warmup for cosine scheduling."}, + ) + pretraining: bool = field( + default=False, + metadata={ + "help": "Indicates to trainer whether we are doing continued pretraining." + }, + ) + sample_packing: bool = field( + default=False, + metadata={"help": "Use sample packing for efficient training."}, + ) + sample_packing_sequentially: bool = field( + default=False, + metadata={ + "help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing." + }, + ) + multipack_real_batches: bool = field( + default=False, + metadata={"help": "Use real batches for efficient training."}, + ) + eval_sample_packing: Optional[bool] = field( + default=None, + metadata={"help": "Use sample packing for efficient evals."}, + ) + sample_packing_efficiency: float = field( + default=1.0, + metadata={"help": "Sample packing efficiency for calculating batch length."}, + ) + sample_packing_bin_size: int = field( + default=200, + metadata={ + "help": "The max number of samples that packed sample can contain after packing. Increase for better packing." + }, + ) + sample_packing_group_size: int = field( + default=100000, + metadata={ + "help": "The number of samples to group together for packing. Increase for better packing." 
+ }, + ) + max_seq_length: int = field( + default=2048, + metadata={"help": "The maximum sequence length the model can handle"}, + ) + relora_steps: Optional[int] = field( + default=None, + metadata={"help": "how often to reset for ReLoRA"}, + ) + relora_warmup_steps: Optional[int] = field( + default=None, + metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, + ) + relora_anneal_steps: Optional[int] = field( + default=None, + metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, + ) + relora_prune_ratio: Optional[float] = field( + default=0.9, + metadata={"help": "prune ratio for magnitude pruning of the optimizer"}, + ) + bench_split: Optional[str] = field( + default="eval", metadata={"help": "The benchmark split to run on"} + ) + bench_dataset: Optional[str] = field( + default="pharaouk/dharma-1/dharma_1_mini.json", + metadata={ + "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file" + }, + ) + do_bench_eval: Optional[bool] = field( + default=False, metadata={"help": "Whether to run the Benchmark evaluation."} + ) + do_causal_lm_eval: Optional[bool] = field( + default=False, metadata={"help": "Whether to run the Causal LM evaluation."} + ) + max_bench_samples: Optional[int] = field( + default=None, + metadata={ + "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset." 
+ }, + ) + bench_source_max_len: int = field( + default=2048, metadata={"help": "Maximum source sequence length for bench."} + ) + dataloader_prefetch_factor: Optional[int] = field( + default=None, + metadata={"help": "prefetch_factor argument to the dataloader"}, + ) + cosine_min_lr_ratio: Optional[float] = field( + default=None, + metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"}, + ) + cosine_constant_lr_ratio: Optional[float] = field( + default=None, + metadata={ + "help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps" + }, + ) + loraplus_lr_ratio: Optional[float] = field( + default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."} + ) + loraplus_lr_embedding: Optional[float] = field( + default=1e-6, + metadata={"help": "loraplus learning rate for lora embedding layers."}, + ) + embedding_lr_scale: Optional[float] = field( + default=None, + metadata={"help": "Scale the learning rate for the embedding layers."}, + ) + lr_groups: Optional[list[dict]] = field( + default=None, + metadata={"help": "Specify learning rate groups for with different LRs."}, + ) + embedding_lr: Optional[float] = field( + default=None, + metadata={"help": "absolute learning rate for the embedding layers."}, + ) + qlora: bool = field( + default=False, + metadata={"help": "whether this is a qlora training"}, + ) + orpo_alpha: Optional[float] = field( + default=None, + ) + lisa_n_layers: Optional[int] = field( + default=None, + metadata={"help": "the number of activate layers in LISA"}, + ) + lisa_step_interval: Optional[int] = field( + default=None, + metadata={"help": "how often to switch layers in LISA"}, + ) + lisa_layers_attribute: Optional[str] = field( + default=None, + metadata={"help": "path under the model to access the layers"}, + ) + curriculum_sampling: Optional[bool] = field( + default=None, + metadata={"help": "whether to use sequential sampling for curriculum learning"}, + ) + 
alternate_lr_scheduler_type: Optional[str] = field( + default=None, + metadata={ + "help": "workaround to pass an alternate lr scheduler to the HF trainer" + }, + ) + chat_template: Optional[str] = field( + default=None, + metadata={"help": "Chat template converting chat messages to text"}, + ) + + # kd_ce_alpha: Optional[float] = field( + # default=None, + # metadata={ + # "help": "The alpha scaling parameter for SFT cross entropy loss when using KD" + # }, + # ) + # + # kd_alpha: Optional[float] = field( + # default=1.0, + # metadata={"help": "The alpha scaling parameter for KD loss"}, + # ) + # + # kd_temperature: Optional[float] = field( + # default=1.0, + # metadata={ + # "help": "the temperature parameter for KL divergence loss when using KD" + # }, + # ) + + adam_beta3: Optional[float] = field( + default=None, + metadata={ + "help": "The beta3 hyperparameter used in some optimizers such as CAME" + }, + ) + adam_epsilon2: Optional[float] = field( + default=None, + metadata={ + "help": "The epsilon2 hyperparameter used in some optimizers such as CAME" + }, + ) + + # multi-modal section + + image_size: int | tuple[int, int] | None = field( + default=None, + metadata={"help": "The size of the image to resize to"}, + ) + + image_resize_algorithm: Resampling | None = field( + default=None, + metadata={"help": "The algorithm to use for image resizing"}, + ) + + # end of multi-modal section diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index f89dc5049..001436152 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -22,6 +22,8 @@ from __future__ import annotations import collections import importlib +import logging +import traceback from typing import TYPE_CHECKING, Callable, OrderedDict, Union from peft import PeftModel @@ -83,6 +85,11 @@ class BasePlugin: def get_input_args(self) -> str | None: """Returns a pydantic model for the plugin's input arguments.""" + def get_training_args_mixin(self) -> 
str | None: + """ + Returns a dataclass model for the plugin's training arguments. + """ + def load_datasets( self, cfg: DictDefault, preprocess: bool = False ) -> Union["TrainDatasetMeta", None]: @@ -158,6 +165,32 @@ class BasePlugin: trainer: The trainer object for training. """ + + def get_training_args(self, cfg: DictDefault): # pylint: disable=unused-argument): + """ + Returns custom training arguments to set on TrainingArgs. + + Args: + cfg: The global axolotl configuration. + + Returns: + object: dict containing the training arguments. + """ + + def get_collator_cls_and_kwargs( + self, cfg: DictDefault, is_eval: bool=False + ): # pylint: disable=unused-argument): + """ + Returns a custom class for the collator. + + Args: + cfg: The global axolotl configuration. + is_eval: Whether this is an eval split. + + Returns: + class: The class for the collator. + """ + # pylint: disable=unused-argument def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None: """Creates and returns an optimizer for training. @@ -167,84 +200,7 @@ class BasePlugin: trainer: The trainer object for training. Returns: -<<<<<<< HEAD The created optimizer. -======= - None - """ - - def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument - """ - Performs actions before LoRA weights are loaded. - - Args: - cfg (dict): The configuration for the plugin. - model (object): The loaded model. - - Returns: - None - """ - - def post_lora_load(self, cfg, model): # pylint: disable=unused-argument - """ - Performs actions after LoRA weights are loaded. - - Args: - cfg (dict): The configuration for the plugin. - model (object): The loaded model. - - Returns: - None - """ - - def get_trainer_cls(self, cfg): # pylint: disable=unused-argument): - """ - Returns a custom class for the trainer. - - Args: - cfg (dict): The global axolotl configuration. - - Returns: - class: The class for the trainer. 
- """ - - def get_collator_cls_and_kwargs( - self, cfg, is_eval=False - ): # pylint: disable=unused-argument): - """ - Returns a custom class for the collator. - - Args: - cfg (dict): The global axolotl configuration. - is_eval (bool): Whether this is an eval split. - - Returns: - class: The class for the collator. - """ - - def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument - """ - Performs actions after the trainer is created. - - Args: - cfg (dict): The configuration for the plugin. - trainer (object): The trainer object for training. - - Returns: - None - """ - - def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument - """ - Creates and returns an optimizer for training. - - Args: - cfg (dict): The configuration for the plugin. - trainer (object): The trainer object for training. - - Returns: - object: The created optimizer. ->>>>>>> f8df1563d (collator cls for plugins) """ # pylint: disable=unused-argument @@ -355,7 +311,7 @@ def load_plugin(plugin_name: str) -> BasePlugin: return plugin -class PluginManager: +class PluginManager: # pylint: disable=too-many-public-methods """The `PluginManager` class is responsible for loading and managing plugins. It should be a singleton so it can be accessed from anywhere in the codebase. @@ -414,8 +370,11 @@ class PluginManager: plugin = load_plugin(plugin_name) self.plugins[plugin_name] = plugin LOG.info(f"Plugin loaded successfully: {plugin_name}") - except ImportError: + except ImportError as exc: LOG.error(f"Failed to load plugin: {plugin_name}") + # print stacktrace + traceback.print_exc() + print(f"Error: {exc}") def get_input_args(self) -> list[str]: """Returns a list of Pydantic classes for all registered plugins' input arguments.' 
def get_training_args_mixin(self) -> list[str]:
    """Collect the training-args mixins declared by the registered plugins.

    Calls ``get_training_args_mixin()`` on every plugin in ``self.plugins``
    (in iteration order) and keeps each result that is not ``None``.

    Returns:
        The non-``None`` hook results. ``merge_training_args`` splits each
        entry on its last ``"."``, so entries are expected to be dotted
        import paths of the form ``"package.module.ClassName"``.
    """
    mixin_paths = []
    for plugin in self.plugins.values():
        mixin_path = plugin.get_training_args_mixin()
        # Plugins that do not extend the training args return None.
        if mixin_path is not None:
            mixin_paths.append(mixin_path)
    return mixin_paths


def get_training_args(self, cfg) -> dict:
    """Merge the training-argument overrides supplied by all plugins.

    Calls ``get_training_args(cfg)`` on every plugin in ``self.plugins`` and
    folds each non-``None`` result into one dict with ``dict.update``, so a
    plugin that comes later in iteration order wins on duplicate keys.

    Args:
        cfg: The global axolotl configuration, passed through to each plugin.

    Returns:
        The combined keyword arguments; an empty dict when no plugin
        contributes any.
    """
    merged_kwargs = {}
    for plugin in self.plugins.values():
        plugin_kwargs = plugin.get_training_args(cfg)
        if plugin_kwargs is not None:
            merged_kwargs.update(plugin_kwargs)
    return merged_kwargs
""" for plugin in self.plugins.values(): - collator = plugin.get_collator_cls_and_kwargs( - cfg, is_eval=is_eval - ) + collator = plugin.get_collator_cls_and_kwargs(cfg, is_eval=is_eval) if collator is not None: collator_cls, collator_kwargs = collator return collator_cls, collator_kwargs diff --git a/src/axolotl/integrations/config.py b/src/axolotl/integrations/config.py index b443f228e..f5fc07e9e 100644 --- a/src/axolotl/integrations/config.py +++ b/src/axolotl/integrations/config.py @@ -16,7 +16,7 @@ Module to handle merging the plugins' input arguments with the base configuratio This was moved here to prevent circular imports. """ -from typing import Any, Dict, List +from typing import Any, Dict, List, Type from axolotl.utils.schemas.config import ( AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase, @@ -61,3 +61,43 @@ def merge_input_args(): ] return AxolotlConfigWCapabilities, AxolotlInputConfig return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase + + +def merge_training_args() -> Type: + """ + Merges training arguments from registered plugins with the base TrainingArguments. + + This function retrieves the training arguments from registered plugins using the PluginManager. + It then dynamically creates new classes, AxolotlTrainingMixins, + that inherit from the base configurations and include the training arguments from the plugins. + + Returns: + tuple: A tuple containing the newly created classes, AxolotlTrainingMixins. 
+ """ + # pylint: disable=duplicate-code + from axolotl.core.training_args_base import ( + AxolotlTrainingMixins as AxolotlTrainingMixinsBase, + ) + from axolotl.integrations.base import PluginManager + + plugin_manager = PluginManager.get_instance() + training_args_mixins: List[str] = plugin_manager.get_training_args_mixin() + mixin_classes = [] + dynamic_input = "" + for plugin_args in training_args_mixins: + plugin_module, plugin_cls = plugin_args.rsplit(".", 1) + dynamic_input += f"from {plugin_module} import {plugin_cls}\n" + mixin_classes.append(plugin_cls) + if dynamic_input: + dynamic_input += f"class AxolotlTrainingMixins(AxolotlTrainingMixinsBase, {', '.join(mixin_classes)}):\n pass\n" + + namespace: Dict[Any, Any] = {} + local_vars = {"AxolotlTrainingMixinsBase": AxolotlTrainingMixinsBase} + exec( # pylint: disable=exec-used # nosec B102 + dynamic_input, {**globals(), **local_vars}, namespace + ) + AxolotlTrainingMixins = namespace[ # pylint: disable=invalid-name + "AxolotlTrainingMixins" + ] + return AxolotlTrainingMixins + return AxolotlTrainingMixinsBase diff --git a/src/axolotl/integrations/kd/__init__.py b/src/axolotl/integrations/kd/__init__.py index e648bcd25..b53c62f8a 100644 --- a/src/axolotl/integrations/kd/__init__.py +++ b/src/axolotl/integrations/kd/__init__.py @@ -15,7 +15,12 @@ """ Plugin init to add KD support to Axolotl. """ +from typing import Any + +from transformers import Trainer + from axolotl.integrations.base import BasePlugin +from axolotl.integrations.kd.callbacks import KDTemperatureSchedulerCallback from .args import KDArgs # pylint: disable=unused-import. 
def get_training_args_mixin(self):
    """Return the dotted import path of the KD training-args mixin class."""
    return "axolotl.integrations.kd.args.KDTrainingArgsMixin"


def get_training_args(self, cfg):
    """Return the KD options from ``cfg`` as training-argument kwargs.

    Only options that are actually set (not ``None``) are included, matching
    the ``is not None`` guards the trainer builder applied to these keys
    before they moved into this plugin. Unset options are omitted so they do
    not override values supplied elsewhere.

    Args:
        cfg: The global axolotl configuration; ``kd_ce_alpha``, ``kd_alpha``,
            ``kd_temperature`` and ``kd_beta`` are read from it.

    Returns:
        dict: The subset of those four options whose value is not ``None``.
    """
    option_names = ("kd_ce_alpha", "kd_alpha", "kd_temperature", "kd_beta")
    return {
        name: value
        for name in option_names
        if (value := getattr(cfg, name)) is not None
    }


def add_callbacks_post_trainer(self, cfg: "Any", trainer: "Trainer") -> list:
    """Create the KD temperature scheduler callback when it is configured.

    Args:
        cfg: The global axolotl configuration; ``kd_temperature``,
            ``kd_temperature_min`` and ``kd_online_server_base_url`` are read.
        trainer: The trainer instance handed to the callback.

    Returns:
        list: A single ``KDTemperatureSchedulerCallback`` when
        ``kd_temperature_min`` is set and ``kd_online_server_base_url`` is
        truthy; otherwise an empty list.
    """
    if cfg.kd_temperature_min is None or not cfg.kd_online_server_base_url:
        return []

    # Imported lazily, like the trainer class in get_trainer_cls.
    from axolotl.integrations.kd.callbacks import KDTemperatureSchedulerCallback

    return [
        KDTemperatureSchedulerCallback(
            cfg.kd_temperature,
            cfg.kd_temperature_min,
            trainer,
        )
    ]
""" +from dataclasses import dataclass from enum import Enum -from pydantic import BaseModel +from pydantic import BaseModel, Field class InferenceServerType(str, Enum): - vllm = "vllm" - sglang = "sglang" + """ + Online inferences server types to handle different request args + """ + + vllm = "vllm" # pylint: disable=invalid-name + sglang = "sglang" # pylint: disable=invalid-name class KDArgs(BaseModel): @@ -36,9 +41,29 @@ class KDArgs(BaseModel): ) kd_alpha: float | None = None # loss coefficient for KD loss kd_temperature: float | None = None # temperature for sampling during KD + kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL # TODO online kd kd_online_server_base_url: str | None = None kd_online_topk: int | None = None - kd_online_server: InferenceServerType | None = "vllm" + kd_online_server: InferenceServerType | None = Field( + default_factory=lambda: InferenceServerType.vllm + ) kd_online_timeout: int | None = 120 + kd_temperature_min: float | None = ( + None # kd temperature scheduling during online kd + ) + + +@dataclass +class KDTrainingArgsMixin: + """ + Additional args for KD training. + """ + + kd_ce_alpha: float | None = ( + None # loss coefficient for cross-entropy loss during KD + ) + kd_alpha: float | None = None # loss coefficient for KD loss + kd_temperature: float | None = None # temperature for sampling during KD + kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL diff --git a/src/axolotl/integrations/kd/callbacks.py b/src/axolotl/integrations/kd/callbacks.py new file mode 100644 index 000000000..b8a806f69 --- /dev/null +++ b/src/axolotl/integrations/kd/callbacks.py @@ -0,0 +1,36 @@ +""" +Transformers trainer callbacks to schedule the KD temperature during training +""" + +import math + +from transformers.trainer_callback import CallbackHandler + + +class KDTemperatureSchedulerCallback(CallbackHandler): + """ + KD temperature scheduler callback for the trainer. 
+ """ + + def __init__(self, temperature_start, temperature_min, trainer): + self.temperature_start = temperature_start + self.temperature_min = temperature_min + self.temperature = temperature_start + + self.trainer = trainer + + def on_step_end( + self, args, state, control, **kwargs + ): # pylint: disable=unused-argument + # cosine decay temperature over the max steps + + progress = state.global_step / state.max_steps + # Cosine decay factor: 0.5 * (1 + cos(pi * progress)) + # This factor goes from 1 (at progress=0) to 0 (at progress=1) + decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress)) + self.temperature = self.temperature_start - ( + (self.temperature_start - self.temperature_min) * decay_factor + ) + + if hasattr(self.trainer.data_collator, "kd_temperature"): + self.trainer.data_collator.kd_temperature = self.temperature diff --git a/src/axolotl/integrations/kd/collator_online_teacher.py b/src/axolotl/integrations/kd/collator_online_teacher.py index 39f6d16fb..6a6a4e2ee 100644 --- a/src/axolotl/integrations/kd/collator_online_teacher.py +++ b/src/axolotl/integrations/kd/collator_online_teacher.py @@ -2,12 +2,14 @@ Packed data loader for online teacher training supporting vllm and sglang. 
""" +import hashlib +import hmac import logging from typing import Any, Dict, List, Optional -import pandas as pd import requests import torch +from orjson import orjson from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq from axolotl.utils.data.utils import retry_on_request_exceptions @@ -15,6 +17,31 @@ from axolotl.utils.data.utils import retry_on_request_exceptions LOG = logging.getLogger(__name__) +def hmac_sha_from_int_list(int_list, key, hash_func=hashlib.sha256): + """ + Create HMAC-SHA hash from a list of integers + + Args: + int_list: List of integers + key: Secret key (string or bytes) + hash_func: Hash function (default: sha256) + + Returns: + HMAC digest as hex string + """ + # Convert key to bytes if it's a string + if isinstance(key, str): + key = key.encode("utf-8") + + # Convert list of ints to bytes + # Method 1: Convert each int to bytes and concatenate + data = b"".join(i.to_bytes(4, byteorder="big") for i in int_list) + + # Create HMAC + h = hmac.new(key, data, hash_func) + return h.hexdigest() + + class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): """ Collator for online teacher training. 
@@ -30,6 +57,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): kd_temperature: Optional[float] = 1.0, kd_online_server: Optional[str] = "vllm", kd_online_timeout: Optional[int] = 120, + kd_cache_dir: Optional[str] = None, **kwargs: Any, ): super().__init__(*args, **kwargs) @@ -49,6 +77,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): self.kd_online_server = kd_online_server self.http_session = requests.Session() self.kd_online_timeout = kd_online_timeout + self.kd_cache_dir = kd_cache_dir def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]: """ @@ -109,7 +138,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): return final_logprobs_tensor.tolist() - except Exception as e: + except Exception as e: # pylint: disable=broad-exception-caught LOG.error( f"Error during online logprob scaling: {e}. Returning raw logprobs.", exc_info=True, @@ -142,11 +171,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): } # Initialize with empty lists, so if API call fails, these are returned. - ret_logprobs_data = { - "target_token_ids": [], - "target_logprobs": [], - "target_mask": [], - } + ret_data_target_token_ids: List[List[List[int]]] = [] + ret_data_target_logprobs: List[List[List[float]]] = [] + ret_data_target_mask: List[List[List[int]]] = [] try: response = self.http_session.post( @@ -162,7 +189,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}." 
) # Return empty data; items processed later will get default empty KD fields - return ret_logprobs_data + return { + "target_token_ids": ret_data_target_token_ids, + "target_logprobs": ret_data_target_logprobs, + "target_mask": ret_data_target_mask, + } for sequence_data, seq_input_ids, seq_labels in zip( api_data, batch_input_ids, labels @@ -185,7 +216,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): # basic check that the logprob data len matches the input len, so no need to handle padding assert len(seq_input_ids) == len(input_top_logprobs) - for i, input_id, label in zip( + for i, _, label in zip( range(len(seq_input_ids)), seq_input_ids, seq_labels ): if i < len(input_top_logprobs) and input_top_logprobs[i] is None: @@ -254,9 +285,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): current_target_token_ids.append([0] * self.kd_online_topk) current_target_mask.append([0] * self.kd_online_topk) - ret_logprobs_data["target_token_ids"].append(current_target_token_ids) - ret_logprobs_data["target_logprobs"].append(current_target_logprobs) - ret_logprobs_data["target_mask"].append(current_target_mask) + ret_data_target_token_ids.append(current_target_token_ids) + ret_data_target_logprobs.append(current_target_logprobs) + ret_data_target_mask.append(current_target_mask) except requests.exceptions.RequestException as e: LOG.error(f"Error fetching logprobs from online teacher: {e}") @@ -269,7 +300,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): ) raise e - return ret_logprobs_data + return { + "target_token_ids": ret_data_target_token_ids, + "target_logprobs": ret_data_target_logprobs, + "target_mask": ret_data_target_mask, + } @retry_on_request_exceptions(max_retries=10, delay=5) def fetch_online_logprobs_vllm( @@ -296,18 +331,19 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): } # Initialize with empty lists, so if API call fails, these are returned. 
- ret_logprobs_data = { - "target_token_ids": [], - "target_logprobs": [], - "target_mask": [], - } + ret_data_target_token_ids: List[List[List[int]]] = [] + ret_data_target_logprobs: List[List[List[float]]] = [] + ret_data_target_mask: List[List[List[int]]] = [] try: response = self.http_session.post( - api_endpoint, json=payload, timeout=self.kd_online_timeout + api_endpoint, + json=payload, + timeout=self.kd_online_timeout, + # json_decoder=orjson.loads, ) response.raise_for_status() - api_data: dict = response.json() + api_data: dict = orjson.loads(response.content) choices: list[dict] = api_data["choices"] # Ensure api_data is a list, and its length matches batch_input_ids @@ -317,7 +353,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}." ) # Return empty data; items processed later will get default empty KD fields - return ret_logprobs_data + return { + "target_token_ids": ret_data_target_token_ids, + "target_logprobs": ret_data_target_logprobs, + "target_mask": ret_data_target_mask, + } for sequence_data, seq_input_ids, seq_labels in zip( choices, batch_input_ids, labels @@ -330,29 +370,10 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): current_target_mask = [] # Ensure input_top_logprobs is a list - input_top_logprobs: Optional[list[None | list[tuple]]] = ( + input_top_logprobs: Optional[list[None | dict[str, dict]]] = ( sequence_data.pop("prompt_logprobs", []) ) - """ - vllm api data for prompt logprobs looks like: - "prompt_logprobs": [ - null, # first token is always null - { # second token logprobs - "8948": { # token ID - "logprob": -2.3841830625315197e-06, - "rank": 1, - "decoded_token": "system" - }, - "1849": { # token ID - "logprob": -13.187501907348633, - "rank": 2, - "decoded_token": "Ġsystem" - }, - ... # rest of the top-k tokens/logprobs - }, - ... 
# more tokens - } - """ + if not isinstance(input_top_logprobs, list): LOG.warning( f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence." @@ -369,11 +390,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): # this is always the case for the first token. # there is never logprob data for the first token since that's a true input continue - elif ( + if ( i < len(input_top_logprobs) and input_top_logprobs[i] is not None ): - pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i] + pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i] # type: ignore[assignment] # Ensure pos_top_logprobs_data is a list of lists as expected if not ( isinstance(pos_top_logprobs_data, dict) @@ -396,9 +417,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): continue # pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids - pos_token_ids = pos_top_logprobs_data.keys() + pos_token_ids_str = list(pos_top_logprobs_data.keys()) pos_logprobs_dict = pos_top_logprobs_data.values() - pos_token_ids = [int(token_id) for token_id in pos_token_ids] + pos_token_ids = [ + int(token_id) for token_id in pos_token_ids_str + ] pos_logprobs_raw = [ float(logprob.get("logprob", -float("inf"))) for logprob in pos_logprobs_dict @@ -446,17 +469,18 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): current_target_token_ids.append(list(range(self.kd_online_topk))) current_target_mask.append([0] * self.kd_online_topk) - ret_logprobs_data["target_token_ids"].append(current_target_token_ids) - ret_logprobs_data["target_logprobs"].append(current_target_logprobs) - ret_logprobs_data["target_mask"].append(current_target_mask) + ret_data_target_token_ids.append(current_target_token_ids) + ret_data_target_logprobs.append(current_target_logprobs) + ret_data_target_mask.append(current_target_mask) # TODO save and load targets to disk for caching for next epoch - # generate a hash over seq_input_ids and 
convert it to an int - # hash_input_ids: int = hash(tuple(seq_input_ids)) - # with open(f"/tmp/target_logprobs_{hash_input_ids}.parquet", "wb") as f: - # pd.DataFrame(current_target_logprobs).to_parquet(f, index=False) - # with open(f"/tmp/target_token_ids_{hash_input_ids}.parquet", "wb") as f: - # pd.DataFrame(current_target_token_ids).to_parquet(f, index=False) + # generate a hmac SHA256 hash over the list seq_input_ids and convert it to an int + # if self.kd_cache_dir: + # hash_input_ids = hmac_sha_from_int_list( + # seq_input_ids, f"{self.kd_online_server_base_url}:{self.kd_online_topk}" + # ) + # with open(f"{self.kd_cache_dir}/{hash_input_ids}.parquet", "wb") as f: + # pd.DataFrame(ret_logprobs_data).to_parquet(f, index=False) except requests.exceptions.RequestException as e: LOG.error(f"Error fetching logprobs from online teacher: {e}") @@ -469,7 +493,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): ) raise e - return ret_logprobs_data + return { + "target_token_ids": ret_data_target_token_ids, + "target_logprobs": ret_data_target_logprobs, + "target_mask": ret_data_target_mask, + } def __call__( self, features: List[List[Dict[str, Any]]], return_tensors: Optional[str] = None diff --git a/src/axolotl/integrations/kd/kernels/liger.py b/src/axolotl/integrations/kd/kernels/liger.py index f30904d5a..0050ffe33 100644 --- a/src/axolotl/integrations/kd/kernels/liger.py +++ b/src/axolotl/integrations/kd/kernels/liger.py @@ -20,6 +20,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase): target_token_ids_chunk: torch.Tensor, # [chunk_size, top_k] target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs target_mask_chunk: torch.Tensor, # [chunk_size, top_k] + beta: float = 0.0, ) -> torch.Tensor: """ Compute Top-K KL divergence loss for a chunk. 
@@ -28,6 +29,10 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase): target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K). target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K). target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K). + beta: Controls the type of KL divergence. + 0.0 for Forward KL (P_teacher || P_student). + 1.0 for Reverse KL (P_student || P_teacher). + 0.5 for Symmetric KL (average of Forward and Reverse). Returns: Sum of KL divergence losses for the chunk. """ @@ -59,6 +64,10 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase): # Teacher probabilities P(y|x_teacher) from logprobs # target_logprobs_valid are already normalized (log(softmax(teacher_logits/T))) teacher_probs_valid = target_logprobs_valid.exp() + # Student probabilities P_student from log P_student + student_probs_topk_valid = student_logprobs_topk_valid.exp() + + kd_loss_per_token = torch.zeros_like(target_logprobs_valid) # KL divergence: sum(P_teacher * (log P_teacher - log P_student)) # = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student) @@ -66,9 +75,17 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase): # or as sum(P_teacher * (log_softmax_teacher - log_softmax_student)) # Here, target_logprobs_valid are log_softmax_teacher. # student_logprobs_topk_valid are log_softmax_student (for the selected K indices). 
- kd_loss_per_token = teacher_probs_valid * ( - target_logprobs_valid - student_logprobs_topk_valid - ) + if beta < 1.0: # Contribution from Forward KL + fwd_kl_per_token = teacher_probs_valid * ( + target_logprobs_valid - student_logprobs_topk_valid + ) + kd_loss_per_token += (1.0 - beta) * fwd_kl_per_token + if beta > 0.0: # Contribution from Reverse KL + rev_kl_per_token = student_probs_topk_valid * ( + student_logprobs_topk_valid - target_logprobs_valid + ) + kd_loss_per_token += beta * rev_kl_per_token + kd_loss = kd_loss_per_token.sum() return kd_loss @@ -91,6 +108,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase): weight_soft_loss: float = 0.5, compute_ce_loss: bool = True, temperature: float = 1.0, + beta: float = 0.0, ): # Compute student logits for the chunk from hidden states and LM head # student_input_chunk: [chunk_size, hidden_dim] @@ -125,10 +143,9 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase): target_token_ids_chunk, target_logprobs_chunk, target_mask_chunk, + beta=beta, ) - loss = weight_hard_loss * ce_loss + weight_soft_loss * soft_loss - return soft_loss, ce_loss @classmethod @@ -146,6 +163,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase): weight_soft_loss: float = 0.5, ignore_index: int = -100, temperature: float = 1.0, + beta: float = 0.0, compiled: bool = False, chunk_size: int = 1024, compute_ce_loss: bool = True, @@ -192,6 +210,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase): weight_soft_loss=weight_soft_loss, compute_ce_loss=compute_ce_loss, temperature=temperature, + beta=beta, ) def accumulate_chunk_grads( @@ -288,13 +307,13 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase): ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc) # For matching None returns in backward for non-tensor/non-grad_requiring inputs - ctx.hyperparams_count = 7 # Corresponds to 
number of hyperparams after main tensors in fwd signature + ctx.hyperparams_count = 8 # Corresponds to number of hyperparams after main tensors in fwd signature ctx.bias_was_none = student_lm_head_bias is None ctx.orig_dims = (B, N, D, K) - # since this is packed, there is simply a single batch, so batchmean reduciton of kl-div is simply the accumulatedsum - # we still need to scale the kd_loss by the temp - kd_loss_acc = kd_loss_acc * (temperature ** 2) + # since this is packed, there is simply a single batch, so batchmean reduciton of kl-div is simply the accumulated sum + # we still need to scale the kd_loss by the temp^2 + kd_loss_acc = kd_loss_acc * (temperature**2) final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc return final_loss @@ -373,6 +392,7 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module): weight_hard_loss: float = 0.5, weight_soft_loss: float = 0.5, temperature: float = 1.0, # This is the kd_temperature + beta: float = 1.0, ignore_index: int = -100, compiled: bool = True, chunk_size: int = 1024, @@ -387,6 +407,7 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module): self.weight_hard_loss = weight_hard_loss self.weight_soft_loss = weight_soft_loss self.temperature = temperature + self.beta = beta self.ignore_index = ignore_index self.compiled = compiled self.chunk_size = chunk_size @@ -424,6 +445,7 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module): self.weight_soft_loss, self.ignore_index, self.temperature, + self.beta, self.compiled, self.chunk_size, self.compute_ce_loss, diff --git a/src/axolotl/integrations/kd/kernels/models.py b/src/axolotl/integrations/kd/kernels/models.py index f346ad21a..5a7c286bc 100644 --- a/src/axolotl/integrations/kd/kernels/models.py +++ b/src/axolotl/integrations/kd/kernels/models.py @@ -75,8 +75,8 @@ def kldiv_forward_llama_like( target_mask, true_labels=labels, ) - num_items_in_batch = kwargs.pop("num_items_in_batch", None) - if num_items_in_batch is not None: + 
num_items_in_batch = kwargs.pop("num_items_in_batch", -1) + if num_items_in_batch is not None and num_items_in_batch > 0: loss = loss / num_items_in_batch return CausalLMOutputWithPast( diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index bb7daa092..131e1695d 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -33,6 +33,7 @@ class AxolotlKDTrainer(AxolotlTrainer): self.args.kd_ce_alpha, # hard label loss self.args.kd_alpha, # kd loss self.args.kd_temperature, + self.args.kd_beta, compute_ce_loss=bool(self.args.kd_ce_alpha), ) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index 866a9c454..59d9ec595 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -1,10 +1,13 @@ """Prepare and train a model on a dataset. Can also infer from a model or merge lora""" +from __future__ import annotations + import importlib import inspect import os import signal import sys +import typing import weakref from contextlib import ExitStack from pathlib import Path @@ -25,7 +28,6 @@ from axolotl.common.datasets import TrainDatasetMeta from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module fix_untrained_tokens, ) -from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.integrations.base import PluginManager from axolotl.loaders import ( ModelLoader, @@ -45,6 +47,9 @@ try: except ImportError: BetterTransformer = None +if typing.TYPE_CHECKING: + from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder + LOG = get_logger(__name__) @@ -472,7 +477,7 @@ def handle_untrained_tokens_fix( def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[ - HFRLTrainerBuilder | HFCausalTrainerBuilder, + "HFRLTrainerBuilder" | "HFCausalTrainerBuilder", PeftModel | PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, diff --git a/src/axolotl/utils/chat_templates.py 
b/src/axolotl/utils/chat_templates.py index bf496d2c5..09bfb5576 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -36,7 +36,7 @@ _CHAT_TEMPLATES = { "deepseek_v3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\\n\\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' in message %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- else %}{{'<|Assistant|>' + message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- endfor %}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' not in message %}{%- if ns.is_tool 
%}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}", "jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n 
{% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ 
= is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- "\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." 
-}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set 
has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined %}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if 
add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n', "qwen_25": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. 
You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n", - "qwen3": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif 
%}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('') and message.content.endswith('')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '' in message.content %}\n {%- set content = message.content.split('')[-1].lstrip('\\n') %}\n {%- set reasoning_content = message.content.split('')[0].rstrip('\\n').split('')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content.strip('\\n') + '\\n\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- 
tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '\\n\\n\\n\\n' }}\n {%- endif %}\n{%- endif %}", + "qwen3": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('') and message.content.endswith('')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + 
message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '' in message.content %}\n {%- set content = message.content.split('')[-1].lstrip('\\n') %}\n {%- set reasoning_content = message.content.split('')[0].rstrip('\\n').split('')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content.strip('\\n') + '\\n\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n 
{{- '\\n\\n\\n\\n' }}\n {%- else %}\n {{- '\\n\\n' }}\n {%- endif %}\n{%- endif %}", "exaone": "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '[|system|][|endofturn|]\n' }}{% endif %}{{ '[|' + message['role'] + '|]' + message['content'] }}{% if message['role'] == 'user' %}{{ '\n' }}{% else %}{{ '[|endofturn|]\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[|assistant|]' }}{% endif %}", "metharme": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'Enter RP mode. You shall reply to the user while staying in character. Your responses must be detailed, creative, immersive, and drive the scenario forward.' %}{% endif %}{{ '<|system|>' + system_message }}{% for message in loop_messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|user|>' + content.strip() }}{% elif message['role'] == 'assistant' %}{{ '<|model|>' + content.strip() }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|model|>' }}{% else %}{{ eos_token }}{% endif %}", "pixtral": '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) %}\n {{- raise_exception(\'After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\') }}\n {%- endif %}\n {%- if message["role"] == "user" %}\n {%- if loop.last and system_message is defined %}\n {{- "[INST]" + system_message + "\n\n" }}\n {%- else %}\n {{- "[INST]" }}\n {%- endif %}\n {%- if message["content"] is not string %}\n {%- for chunk in message["content"] %}\n {%- if chunk["type"] == "text" %}\n {{- 
chunk["text"] }}\n {%- elif chunk["type"] == "image" %}\n {{- "[IMG]" }}\n {%- else %}\n {{- raise_exception("Unrecognized content type!") }}\n {%- endif %}\n {%- endfor %}\n {%- else %}\n {{- message["content"] }}\n {%- endif %}\n {{- "[/INST]" }}\n {%- elif message["role"] == "assistant" %}\n {%- if message["content"] is not string %}\n {%- for chunk in message["content"] %}\n {%- if chunk["type"] == "text" %}\n {{- chunk["text"] }}\n {%- elif chunk["type"] == "image" %}\n {{- "[IMG]" }}\n {%- else %}\n {{- raise_exception("Unrecognized content type!") }}\n{%- endif %}\n{%- endfor %}\n{{- eos_token }}\n{%- else %}\n{{- message["content"] + eos_token }}\n{%- endif %}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}', diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 67f590a37..275b1f414 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -16,7 +16,6 @@ from datasets import IterableDataset, disable_caching, enable_caching from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from transformers.utils import is_torch_bf16_gpu_available -from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2 from axolotl.utils.distributed import reduce_and_broadcast from axolotl.utils.environment import check_cuda_p2p_ib_support @@ -629,6 +628,8 @@ def setup_trainer( A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based on the provided parameters. """ + from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder + if ( cfg.torch_compile and cfg.fsdp_config