diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py
index e2ad1f579..d6d19607f 100644
--- a/src/axolotl/core/builders/causal.py
+++ b/src/axolotl/core/builders/causal.py
@@ -149,6 +149,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return AxolotlTrainer
def build(self, total_num_steps):
+ from axolotl.core.training_args import (
+ AxolotlPRMConfig,
+ AxolotlRewardConfig,
+ AxolotlTrainingArguments,
+ )
+
training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
total_num_steps
)
@@ -317,12 +323,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["image_resize_algorithm"] = (
self.cfg.image_resize_algorithm
)
- if self.cfg.kd_ce_alpha is not None:
- training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
- if self.cfg.kd_alpha is not None:
- training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
- if self.cfg.kd_temperature is not None:
- training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
+
+ if self.cfg.plugins:
+ plugin_manager = PluginManager.get_instance()
+ plugin_training_args = plugin_manager.get_training_args(self.cfg)
+ if plugin_training_args:
+ training_arguments_kwargs.update(plugin_training_args)
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
@@ -403,7 +409,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return trainer
def build_collator(
- self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
+ self,
+ training_args, # type: "AxolotlTrainingArguments"
+ is_eval=False,
+ **kwargs,
):
if training_args.pretraining:
if (
diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py
index 14dbfa715..e7591131e 100644
--- a/src/axolotl/core/builders/rl.py
+++ b/src/axolotl/core/builders/rl.py
@@ -12,11 +12,6 @@ from axolotl.core.trainers import (
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
-from axolotl.core.training_args import (
- AxolotlCPOConfig,
- AxolotlKTOConfig,
- AxolotlORPOConfig,
-)
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype
from axolotl.utils.logging import get_logger
@@ -79,6 +74,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
"""
Returns training_args and trainer_kwargs
"""
+ from axolotl.core.training_args import (
+ AxolotlCPOConfig,
+ AxolotlKTOConfig,
+ AxolotlORPOConfig,
+ )
+
training_args_kwargs, trainer_kwargs = self._set_base_training_args(
total_num_steps=total_num_steps
)
@@ -165,6 +166,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if blocklist_key in training_args_kwargs:
del training_args_kwargs[blocklist_key]
+
+ if self.cfg.plugins:
+ plugin_manager = PluginManager.get_instance()
+ plugin_training_args = plugin_manager.get_training_args(self.cfg)
+ if plugin_training_args:
+ training_args_kwargs.update(plugin_training_args)
+
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
logging_first_step=True,
**training_args_kwargs,
diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py
index 03cad93e7..d5be9fc62 100644
--- a/src/axolotl/core/training_args.py
+++ b/src/axolotl/core/training_args.py
@@ -2,224 +2,17 @@
extra axolotl specific training args
"""
-from dataclasses import dataclass, field
-from typing import Optional
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Optional, Type
-from PIL.Image import Resampling
from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
+from axolotl.integrations.config import merge_training_args
-@dataclass
-class AxolotlTrainingMixins:
- """
- Mixin class for the Axolotl training args.
- """
-
- # pylint: disable=duplicate-code
- model_type: Optional[str] = field(
- default=None, metadata={"help": "HF model configuration model_type."}
- )
- lr_quadratic_warmup: bool = field(
- default=False,
- metadata={"help": "Use quadratic warmup for cosine scheduling."},
- )
- pretraining: bool = field(
- default=False,
- metadata={
- "help": "Indicates to trainer whether we are doing continued pretraining."
- },
- )
- sample_packing: bool = field(
- default=False,
- metadata={"help": "Use sample packing for efficient training."},
- )
- sample_packing_sequentially: bool = field(
- default=False,
- metadata={
- "help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
- },
- )
- multipack_real_batches: bool = field(
- default=False,
- metadata={"help": "Use real batches for efficient training."},
- )
- eval_sample_packing: Optional[bool] = field(
- default=None,
- metadata={"help": "Use sample packing for efficient evals."},
- )
- sample_packing_efficiency: float = field(
- default=1.0,
- metadata={"help": "Sample packing efficiency for calculating batch length."},
- )
- sample_packing_bin_size: int = field(
- default=200,
- metadata={
- "help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
- },
- )
- sample_packing_group_size: int = field(
- default=100000,
- metadata={
- "help": "The number of samples to group together for packing. Increase for better packing."
- },
- )
- max_seq_length: int = field(
- default=2048,
- metadata={"help": "The maximum sequence length the model can handle"},
- )
- relora_steps: Optional[int] = field(
- default=None,
- metadata={"help": "how often to reset for ReLoRA"},
- )
- relora_warmup_steps: Optional[int] = field(
- default=None,
- metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
- )
- relora_anneal_steps: Optional[int] = field(
- default=None,
- metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
- )
- relora_prune_ratio: Optional[float] = field(
- default=0.9,
- metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
- )
- bench_split: Optional[str] = field(
- default="eval", metadata={"help": "The benchmark split to run on"}
- )
- bench_dataset: Optional[str] = field(
- default="pharaouk/dharma-1/dharma_1_mini.json",
- metadata={
- "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
- },
- )
- do_bench_eval: Optional[bool] = field(
- default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
- )
- do_causal_lm_eval: Optional[bool] = field(
- default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
- )
- max_bench_samples: Optional[int] = field(
- default=None,
- metadata={
- "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
- },
- )
- bench_source_max_len: int = field(
- default=2048, metadata={"help": "Maximum source sequence length for bench."}
- )
- dataloader_prefetch_factor: Optional[int] = field(
- default=None,
- metadata={"help": "prefetch_factor argument to the dataloader"},
- )
- cosine_min_lr_ratio: Optional[float] = field(
- default=None,
- metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
- )
- cosine_constant_lr_ratio: Optional[float] = field(
- default=None,
- metadata={
- "help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
- },
- )
- loraplus_lr_ratio: Optional[float] = field(
- default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
- )
- loraplus_lr_embedding: Optional[float] = field(
- default=1e-6,
- metadata={"help": "loraplus learning rate for lora embedding layers."},
- )
- embedding_lr_scale: Optional[float] = field(
- default=None,
- metadata={"help": "Scale the learning rate for the embedding layers."},
- )
- lr_groups: Optional[list[dict]] = field(
- default=None,
- metadata={"help": "Specify learning rate groups for with different LRs."},
- )
- embedding_lr: Optional[float] = field(
- default=None,
- metadata={"help": "absolute learning rate for the embedding layers."},
- )
- qlora: bool = field(
- default=False,
- metadata={"help": "whether this is a qlora training"},
- )
- orpo_alpha: Optional[float] = field(
- default=None,
- )
- lisa_n_layers: Optional[int] = field(
- default=None,
- metadata={"help": "the number of activate layers in LISA"},
- )
- lisa_step_interval: Optional[int] = field(
- default=None,
- metadata={"help": "how often to switch layers in LISA"},
- )
- lisa_layers_attribute: Optional[str] = field(
- default=None,
- metadata={"help": "path under the model to access the layers"},
- )
- curriculum_sampling: Optional[bool] = field(
- default=None,
- metadata={"help": "whether to use sequential sampling for curriculum learning"},
- )
- alternate_lr_scheduler_type: Optional[str] = field(
- default=None,
- metadata={
- "help": "workaround to pass an alternate lr scheduler to the HF trainer"
- },
- )
- chat_template: Optional[str] = field(
- default=None,
- metadata={"help": "Chat template converting chat messages to text"},
- )
-
- kd_ce_alpha: Optional[float] = field(
- default=None,
- metadata={
- "help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
- },
- )
-
- kd_alpha: Optional[float] = field(
- default=1.0,
- metadata={"help": "The alpha scaling parameter for KD loss"},
- )
-
- kd_temperature: Optional[float] = field(
- default=1.0,
- metadata={
- "help": "the temperature parameter for KL divergence loss when using KD"
- },
- )
-
- adam_beta3: Optional[float] = field(
- default=None,
- metadata={
- "help": "The beta3 hyperparameter used in some optimizers such as CAME"
- },
- )
- adam_epsilon2: Optional[float] = field(
- default=None,
- metadata={
- "help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
- },
- )
-
- # multi-modal section
-
- image_size: int | tuple[int, int] | None = field(
- default=None,
- metadata={"help": "The size of the image to resize to"},
- )
-
- image_resize_algorithm: Resampling | None = field(
- default=None,
- metadata={"help": "The algorithm to use for image resizing"},
- )
-
- # end of multi-modal section
+AxolotlTrainingMixins: Type = merge_training_args()
@dataclass
diff --git a/src/axolotl/core/training_args_base.py b/src/axolotl/core/training_args_base.py
new file mode 100644
index 000000000..fbd387e5f
--- /dev/null
+++ b/src/axolotl/core/training_args_base.py
@@ -0,0 +1,222 @@
+"""
+Base Axolotl Training Mixins shared across various trainer configs
+"""
+
+from dataclasses import dataclass, field
+from typing import Optional
+
+from PIL.Image import Resampling
+
+from axolotl.utils.schemas.enums import RingAttnFunc
+
+
+@dataclass
+class AxolotlTrainingMixins:
+ """
+ Mixin class for the Axolotl training args.
+ """
+
+ # pylint: disable=duplicate-code
+ model_type: Optional[str] = field(
+ default=None, metadata={"help": "HF model configuration model_type."}
+ )
+ lr_quadratic_warmup: bool = field(
+ default=False,
+ metadata={"help": "Use quadratic warmup for cosine scheduling."},
+ )
+ pretraining: bool = field(
+ default=False,
+ metadata={
+ "help": "Indicates to trainer whether we are doing continued pretraining."
+ },
+ )
+ sample_packing: bool = field(
+ default=False,
+ metadata={"help": "Use sample packing for efficient training."},
+ )
+ sample_packing_sequentially: bool = field(
+ default=False,
+ metadata={
+ "help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
+ },
+ )
+ multipack_real_batches: bool = field(
+ default=False,
+ metadata={"help": "Use real batches for efficient training."},
+ )
+ eval_sample_packing: Optional[bool] = field(
+ default=None,
+ metadata={"help": "Use sample packing for efficient evals."},
+ )
+ sample_packing_efficiency: float = field(
+ default=1.0,
+ metadata={"help": "Sample packing efficiency for calculating batch length."},
+ )
+ sample_packing_bin_size: int = field(
+ default=200,
+ metadata={
+ "help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
+ },
+ )
+ sample_packing_group_size: int = field(
+ default=100000,
+ metadata={
+ "help": "The number of samples to group together for packing. Increase for better packing."
+ },
+ )
+ max_seq_length: int = field(
+ default=2048,
+ metadata={"help": "The maximum sequence length the model can handle"},
+ )
+ relora_steps: Optional[int] = field(
+ default=None,
+ metadata={"help": "how often to reset for ReLoRA"},
+ )
+ relora_warmup_steps: Optional[int] = field(
+ default=None,
+ metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
+ )
+ relora_anneal_steps: Optional[int] = field(
+ default=None,
+ metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
+ )
+ relora_prune_ratio: Optional[float] = field(
+ default=0.9,
+ metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
+ )
+ bench_split: Optional[str] = field(
+ default="eval", metadata={"help": "The benchmark split to run on"}
+ )
+ bench_dataset: Optional[str] = field(
+ default="pharaouk/dharma-1/dharma_1_mini.json",
+ metadata={
+ "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
+ },
+ )
+ do_bench_eval: Optional[bool] = field(
+ default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
+ )
+ do_causal_lm_eval: Optional[bool] = field(
+ default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
+ )
+ max_bench_samples: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
+ },
+ )
+ bench_source_max_len: int = field(
+ default=2048, metadata={"help": "Maximum source sequence length for bench."}
+ )
+ dataloader_prefetch_factor: Optional[int] = field(
+ default=None,
+ metadata={"help": "prefetch_factor argument to the dataloader"},
+ )
+ cosine_min_lr_ratio: Optional[float] = field(
+ default=None,
+ metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
+ )
+ cosine_constant_lr_ratio: Optional[float] = field(
+ default=None,
+ metadata={
+ "help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
+ },
+ )
+ loraplus_lr_ratio: Optional[float] = field(
+ default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
+ )
+ loraplus_lr_embedding: Optional[float] = field(
+ default=1e-6,
+ metadata={"help": "loraplus learning rate for lora embedding layers."},
+ )
+ embedding_lr_scale: Optional[float] = field(
+ default=None,
+ metadata={"help": "Scale the learning rate for the embedding layers."},
+ )
+ lr_groups: Optional[list[dict]] = field(
+ default=None,
+ metadata={"help": "Specify learning rate groups for with different LRs."},
+ )
+ embedding_lr: Optional[float] = field(
+ default=None,
+ metadata={"help": "absolute learning rate for the embedding layers."},
+ )
+ qlora: bool = field(
+ default=False,
+ metadata={"help": "whether this is a qlora training"},
+ )
+ orpo_alpha: Optional[float] = field(
+ default=None,
+ )
+ lisa_n_layers: Optional[int] = field(
+ default=None,
+ metadata={"help": "the number of activate layers in LISA"},
+ )
+ lisa_step_interval: Optional[int] = field(
+ default=None,
+ metadata={"help": "how often to switch layers in LISA"},
+ )
+ lisa_layers_attribute: Optional[str] = field(
+ default=None,
+ metadata={"help": "path under the model to access the layers"},
+ )
+ curriculum_sampling: Optional[bool] = field(
+ default=None,
+ metadata={"help": "whether to use sequential sampling for curriculum learning"},
+ )
+ alternate_lr_scheduler_type: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "workaround to pass an alternate lr scheduler to the HF trainer"
+ },
+ )
+ chat_template: Optional[str] = field(
+ default=None,
+ metadata={"help": "Chat template converting chat messages to text"},
+ )
+
+ # kd_ce_alpha: Optional[float] = field(
+ # default=None,
+ # metadata={
+ # "help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
+ # },
+ # )
+ #
+ # kd_alpha: Optional[float] = field(
+ # default=1.0,
+ # metadata={"help": "The alpha scaling parameter for KD loss"},
+ # )
+ #
+ # kd_temperature: Optional[float] = field(
+ # default=1.0,
+ # metadata={
+ # "help": "the temperature parameter for KL divergence loss when using KD"
+ # },
+ # )
+
+ adam_beta3: Optional[float] = field(
+ default=None,
+ metadata={
+ "help": "The beta3 hyperparameter used in some optimizers such as CAME"
+ },
+ )
+ adam_epsilon2: Optional[float] = field(
+ default=None,
+ metadata={
+ "help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
+ },
+ )
+
+ # multi-modal section
+
+ image_size: int | tuple[int, int] | None = field(
+ default=None,
+ metadata={"help": "The size of the image to resize to"},
+ )
+
+ image_resize_algorithm: Resampling | None = field(
+ default=None,
+ metadata={"help": "The algorithm to use for image resizing"},
+ )
+
+ # end of multi-modal section
diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py
index f89dc5049..001436152 100644
--- a/src/axolotl/integrations/base.py
+++ b/src/axolotl/integrations/base.py
@@ -22,6 +22,8 @@ from __future__ import annotations
import collections
import importlib
+import logging
+import traceback
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
from peft import PeftModel
@@ -83,6 +85,11 @@ class BasePlugin:
def get_input_args(self) -> str | None:
"""Returns a pydantic model for the plugin's input arguments."""
+ def get_training_args_mixin(self) -> str | None:
+ """
+ Returns a dataclass model for the plugin's training arguments.
+ """
+
def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:
@@ -158,6 +165,32 @@ class BasePlugin:
trainer: The trainer object for training.
"""
+
+ def get_training_args(self, cfg: DictDefault): # pylint: disable=unused-argument):
+ """
+ Returns custom training arguments to set on TrainingArgs.
+
+ Args:
+ cfg: The global axolotl configuration.
+
+ Returns:
+ object: dict containing the training arguments.
+ """
+
+ def get_collator_cls_and_kwargs(
+ self, cfg: DictDefault, is_eval: bool=False
+ ): # pylint: disable=unused-argument):
+ """
+ Returns a custom class for the collator.
+
+ Args:
+ cfg: The global axolotl configuration.
+ is_eval: Whether this is an eval split.
+
+ Returns:
+ class: The class for the collator.
+ """
+
# pylint: disable=unused-argument
def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None:
"""Creates and returns an optimizer for training.
@@ -167,84 +200,7 @@ class BasePlugin:
trainer: The trainer object for training.
Returns:
-<<<<<<< HEAD
The created optimizer.
-=======
- None
- """
-
- def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
- """
- Performs actions before LoRA weights are loaded.
-
- Args:
- cfg (dict): The configuration for the plugin.
- model (object): The loaded model.
-
- Returns:
- None
- """
-
- def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
- """
- Performs actions after LoRA weights are loaded.
-
- Args:
- cfg (dict): The configuration for the plugin.
- model (object): The loaded model.
-
- Returns:
- None
- """
-
- def get_trainer_cls(self, cfg): # pylint: disable=unused-argument):
- """
- Returns a custom class for the trainer.
-
- Args:
- cfg (dict): The global axolotl configuration.
-
- Returns:
- class: The class for the trainer.
- """
-
- def get_collator_cls_and_kwargs(
- self, cfg, is_eval=False
- ): # pylint: disable=unused-argument):
- """
- Returns a custom class for the collator.
-
- Args:
- cfg (dict): The global axolotl configuration.
- is_eval (bool): Whether this is an eval split.
-
- Returns:
- class: The class for the collator.
- """
-
- def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
- """
- Performs actions after the trainer is created.
-
- Args:
- cfg (dict): The configuration for the plugin.
- trainer (object): The trainer object for training.
-
- Returns:
- None
- """
-
- def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
- """
- Creates and returns an optimizer for training.
-
- Args:
- cfg (dict): The configuration for the plugin.
- trainer (object): The trainer object for training.
-
- Returns:
- object: The created optimizer.
->>>>>>> f8df1563d (collator cls for plugins)
"""
# pylint: disable=unused-argument
@@ -355,7 +311,7 @@ def load_plugin(plugin_name: str) -> BasePlugin:
return plugin
-class PluginManager:
+class PluginManager: # pylint: disable=too-many-public-methods
"""The `PluginManager` class is responsible for loading and managing plugins. It
should be a singleton so it can be accessed from anywhere in the codebase.
@@ -414,8 +370,11 @@ class PluginManager:
plugin = load_plugin(plugin_name)
self.plugins[plugin_name] = plugin
LOG.info(f"Plugin loaded successfully: {plugin_name}")
- except ImportError:
+ except ImportError as exc:
LOG.error(f"Failed to load plugin: {plugin_name}")
+ # print stacktrace
+ traceback.print_exc()
+ print(f"Error: {exc}")
def get_input_args(self) -> list[str]:
"""Returns a list of Pydantic classes for all registered plugins' input arguments.'
@@ -430,6 +389,21 @@ class PluginManager:
input_args.append(input_args_from_plugin)
return input_args
+ def get_training_args_mixin(self):
+ """
+ Returns a list of dataclasses for all registered plugins' training args mixins'
+
+ Returns:
+ list[str]: A list of dataclsses
+ """
+ training_args = []
+ for plugin in self.plugins.values():
+ training_args_from_plugin = plugin.get_training_args_mixin()
+ print(f"Training args from plugin: {plugin.__class__.__name__}")
+ if training_args_from_plugin is not None:
+ training_args.append(training_args_from_plugin)
+ return training_args
+
def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:
@@ -519,6 +493,24 @@ class PluginManager:
return trainer_cls
return None
+ def get_training_args(self, cfg):
+ """
+ Calls the get_training_args method of all registered plugins and returns the combined training arguments.
+
+ Parameters:
+ cfg (dict): The configuration for the plugins.
+
+ Returns:
+ object: The training arguments
+ """
+ training_args_kwargs = {}
+ for plugin in self.plugins.values():
+ training_args = plugin.get_training_args(cfg)
+ if training_args is not None:
+ training_args_kwargs.update(training_args)
+
+ return training_args_kwargs
+
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
"""
Calls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class.
@@ -531,9 +523,7 @@ class PluginManager:
object: The collator class, or None if none was found.
"""
for plugin in self.plugins.values():
- collator = plugin.get_collator_cls_and_kwargs(
- cfg, is_eval=is_eval
- )
+ collator = plugin.get_collator_cls_and_kwargs(cfg, is_eval=is_eval)
if collator is not None:
collator_cls, collator_kwargs = collator
return collator_cls, collator_kwargs
diff --git a/src/axolotl/integrations/config.py b/src/axolotl/integrations/config.py
index b443f228e..f5fc07e9e 100644
--- a/src/axolotl/integrations/config.py
+++ b/src/axolotl/integrations/config.py
@@ -16,7 +16,7 @@ Module to handle merging the plugins' input arguments with the base configuratio
This was moved here to prevent circular imports.
"""
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Type
from axolotl.utils.schemas.config import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
@@ -61,3 +61,43 @@ def merge_input_args():
]
return AxolotlConfigWCapabilities, AxolotlInputConfig
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
+
+
+def merge_training_args() -> Type:
+ """
+ Merges training arguments from registered plugins with the base TrainingArguments.
+
+ This function retrieves the training arguments from registered plugins using the PluginManager.
+ It then dynamically creates new classes, AxolotlTrainingMixins,
+ that inherit from the base configurations and include the training arguments from the plugins.
+
+ Returns:
+ tuple: A tuple containing the newly created classes, AxolotlTrainingMixins.
+ """
+ # pylint: disable=duplicate-code
+ from axolotl.core.training_args_base import (
+ AxolotlTrainingMixins as AxolotlTrainingMixinsBase,
+ )
+ from axolotl.integrations.base import PluginManager
+
+ plugin_manager = PluginManager.get_instance()
+ training_args_mixins: List[str] = plugin_manager.get_training_args_mixin()
+ mixin_classes = []
+ dynamic_input = ""
+ for plugin_args in training_args_mixins:
+ plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
+ dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
+ mixin_classes.append(plugin_cls)
+ if dynamic_input:
+ dynamic_input += f"class AxolotlTrainingMixins(AxolotlTrainingMixinsBase, {', '.join(mixin_classes)}):\n pass\n"
+
+ namespace: Dict[Any, Any] = {}
+ local_vars = {"AxolotlTrainingMixinsBase": AxolotlTrainingMixinsBase}
+ exec( # pylint: disable=exec-used # nosec B102
+ dynamic_input, {**globals(), **local_vars}, namespace
+ )
+ AxolotlTrainingMixins = namespace[ # pylint: disable=invalid-name
+ "AxolotlTrainingMixins"
+ ]
+ return AxolotlTrainingMixins
+ return AxolotlTrainingMixinsBase
diff --git a/src/axolotl/integrations/kd/__init__.py b/src/axolotl/integrations/kd/__init__.py
index e648bcd25..b53c62f8a 100644
--- a/src/axolotl/integrations/kd/__init__.py
+++ b/src/axolotl/integrations/kd/__init__.py
@@ -15,7 +15,12 @@
"""
Plugin init to add KD support to Axolotl.
"""
+from typing import Any
+
+from transformers import Trainer
+
from axolotl.integrations.base import BasePlugin
+from axolotl.integrations.kd.callbacks import KDTemperatureSchedulerCallback
from .args import KDArgs # pylint: disable=unused-import. # noqa: F401
@@ -28,6 +33,9 @@ class KDPlugin(BasePlugin):
def get_input_args(self):
return "axolotl.integrations.kd.KDArgs"
+ def get_training_args_mixin(self):
+ return "axolotl.integrations.kd.args.KDTrainingArgsMixin"
+
def get_trainer_cls(self, cfg):
if cfg.kd_trainer:
from .trainer import AxolotlKDTrainer
@@ -35,6 +43,14 @@ class KDPlugin(BasePlugin):
return AxolotlKDTrainer
return None
+ def get_training_args(self, cfg):
+ return {
+ "kd_ce_alpha": cfg.kd_ce_alpha,
+ "kd_alpha": cfg.kd_alpha,
+ "kd_temperature": cfg.kd_temperature,
+ "kd_beta": cfg.kd_beta,
+ }
+
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
if not cfg.kd_trainer:
return None, None
@@ -66,3 +82,24 @@ class KDPlugin(BasePlugin):
from .kernels.models import apply_kernel
apply_kernel(cfg.model_config_type)
+
+ def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
+ """
+ Adds temp scheduler callback to the Trainer instance.
+
+ Args:
+ cfg (Any): Configuration object containing the sparse recipe.
+ trainer (Trainer): Huggingface Trainer instance.
+
+ Returns:
+ list: List containing the configured callback instances.
+ """
+ if cfg.kd_temperature_min is not None and cfg.kd_online_server_base_url:
+ callback = KDTemperatureSchedulerCallback(
+ cfg.kd_temperature,
+ cfg.kd_temperature_min,
+ trainer,
+ )
+ return [callback]
+
+ return []
diff --git a/src/axolotl/integrations/kd/args.py b/src/axolotl/integrations/kd/args.py
index 5c97e7bdd..8b6d6b6f5 100644
--- a/src/axolotl/integrations/kd/args.py
+++ b/src/axolotl/integrations/kd/args.py
@@ -15,14 +15,19 @@
"""
Plugin args for KD support.
"""
+from dataclasses import dataclass
from enum import Enum
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
class InferenceServerType(str, Enum):
- vllm = "vllm"
- sglang = "sglang"
+ """
+ Online inferences server types to handle different request args
+ """
+
+ vllm = "vllm" # pylint: disable=invalid-name
+ sglang = "sglang" # pylint: disable=invalid-name
class KDArgs(BaseModel):
@@ -36,9 +41,29 @@ class KDArgs(BaseModel):
)
kd_alpha: float | None = None # loss coefficient for KD loss
kd_temperature: float | None = None # temperature for sampling during KD
+ kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
# TODO online kd
kd_online_server_base_url: str | None = None
kd_online_topk: int | None = None
- kd_online_server: InferenceServerType | None = "vllm"
+ kd_online_server: InferenceServerType | None = Field(
+ default_factory=lambda: InferenceServerType.vllm
+ )
kd_online_timeout: int | None = 120
+ kd_temperature_min: float | None = (
+ None # kd temperature scheduling during online kd
+ )
+
+
+@dataclass
+class KDTrainingArgsMixin:
+ """
+ Additional args for KD training.
+ """
+
+ kd_ce_alpha: float | None = (
+ None # loss coefficient for cross-entropy loss during KD
+ )
+ kd_alpha: float | None = None # loss coefficient for KD loss
+ kd_temperature: float | None = None # temperature for sampling during KD
+ kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
diff --git a/src/axolotl/integrations/kd/callbacks.py b/src/axolotl/integrations/kd/callbacks.py
new file mode 100644
index 000000000..b8a806f69
--- /dev/null
+++ b/src/axolotl/integrations/kd/callbacks.py
@@ -0,0 +1,36 @@
+"""
+Transformers trainer callbacks to schedule the KD temperature during training
+"""
+
+import math
+
+from transformers.trainer_callback import CallbackHandler
+
+
+class KDTemperatureSchedulerCallback(CallbackHandler):
+ """
+ KD temperature scheduler callback for the trainer.
+ """
+
+ def __init__(self, temperature_start, temperature_min, trainer):
+ self.temperature_start = temperature_start
+ self.temperature_min = temperature_min
+ self.temperature = temperature_start
+
+ self.trainer = trainer
+
+ def on_step_end(
+ self, args, state, control, **kwargs
+ ): # pylint: disable=unused-argument
+ # cosine decay temperature over the max steps
+
+ progress = state.global_step / state.max_steps
+ # Cosine decay factor: 0.5 * (1 + cos(pi * progress))
+ # This factor goes from 1 (at progress=0) to 0 (at progress=1)
+ decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
+ self.temperature = self.temperature_start - (
+ (self.temperature_start - self.temperature_min) * decay_factor
+ )
+
+ if hasattr(self.trainer.data_collator, "kd_temperature"):
+ self.trainer.data_collator.kd_temperature = self.temperature
diff --git a/src/axolotl/integrations/kd/collator_online_teacher.py b/src/axolotl/integrations/kd/collator_online_teacher.py
index 39f6d16fb..6a6a4e2ee 100644
--- a/src/axolotl/integrations/kd/collator_online_teacher.py
+++ b/src/axolotl/integrations/kd/collator_online_teacher.py
@@ -2,12 +2,14 @@
Packed data loader for online teacher training supporting vllm and sglang.
"""
+import hashlib
+import hmac
import logging
from typing import Any, Dict, List, Optional
-import pandas as pd
import requests
import torch
+from orjson import orjson
from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data.utils import retry_on_request_exceptions
@@ -15,6 +17,31 @@ from axolotl.utils.data.utils import retry_on_request_exceptions
LOG = logging.getLogger(__name__)
+def hmac_sha_from_int_list(int_list, key, hash_func=hashlib.sha256):
+ """
+ Create HMAC-SHA hash from a list of integers
+
+ Args:
+ int_list: List of integers
+ key: Secret key (string or bytes)
+ hash_func: Hash function (default: sha256)
+
+ Returns:
+ HMAC digest as hex string
+ """
+ # Convert key to bytes if it's a string
+ if isinstance(key, str):
+ key = key.encode("utf-8")
+
+ # Convert list of ints to bytes
+ # Method 1: Convert each int to bytes and concatenate
+ data = b"".join(i.to_bytes(4, byteorder="big") for i in int_list)
+
+ # Create HMAC
+ h = hmac.new(key, data, hash_func)
+ return h.hexdigest()
+
+
class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
"""
Collator for online teacher training.
@@ -30,6 +57,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
kd_temperature: Optional[float] = 1.0,
kd_online_server: Optional[str] = "vllm",
kd_online_timeout: Optional[int] = 120,
+ kd_cache_dir: Optional[str] = None,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
@@ -49,6 +77,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
self.kd_online_server = kd_online_server
self.http_session = requests.Session()
self.kd_online_timeout = kd_online_timeout
+ self.kd_cache_dir = kd_cache_dir
def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]:
"""
@@ -109,7 +138,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
return final_logprobs_tensor.tolist()
- except Exception as e:
+ except Exception as e: # pylint: disable=broad-exception-caught
LOG.error(
f"Error during online logprob scaling: {e}. Returning raw logprobs.",
exc_info=True,
@@ -142,11 +171,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
}
# Initialize with empty lists, so if API call fails, these are returned.
- ret_logprobs_data = {
- "target_token_ids": [],
- "target_logprobs": [],
- "target_mask": [],
- }
+ ret_data_target_token_ids: List[List[List[int]]] = []
+ ret_data_target_logprobs: List[List[List[float]]] = []
+ ret_data_target_mask: List[List[List[int]]] = []
try:
response = self.http_session.post(
@@ -162,7 +189,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}."
)
# Return empty data; items processed later will get default empty KD fields
- return ret_logprobs_data
+ return {
+ "target_token_ids": ret_data_target_token_ids,
+ "target_logprobs": ret_data_target_logprobs,
+ "target_mask": ret_data_target_mask,
+ }
for sequence_data, seq_input_ids, seq_labels in zip(
api_data, batch_input_ids, labels
@@ -185,7 +216,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
# basic check that the logprob data len matches the input len, so no need to handle padding
assert len(seq_input_ids) == len(input_top_logprobs)
- for i, input_id, label in zip(
+ for i, _, label in zip(
range(len(seq_input_ids)), seq_input_ids, seq_labels
):
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
@@ -254,9 +285,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
current_target_token_ids.append([0] * self.kd_online_topk)
current_target_mask.append([0] * self.kd_online_topk)
- ret_logprobs_data["target_token_ids"].append(current_target_token_ids)
- ret_logprobs_data["target_logprobs"].append(current_target_logprobs)
- ret_logprobs_data["target_mask"].append(current_target_mask)
+ ret_data_target_token_ids.append(current_target_token_ids)
+ ret_data_target_logprobs.append(current_target_logprobs)
+ ret_data_target_mask.append(current_target_mask)
except requests.exceptions.RequestException as e:
LOG.error(f"Error fetching logprobs from online teacher: {e}")
@@ -269,7 +300,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
)
raise e
- return ret_logprobs_data
+ return {
+ "target_token_ids": ret_data_target_token_ids,
+ "target_logprobs": ret_data_target_logprobs,
+ "target_mask": ret_data_target_mask,
+ }
@retry_on_request_exceptions(max_retries=10, delay=5)
def fetch_online_logprobs_vllm(
@@ -296,18 +331,19 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
}
# Initialize with empty lists, so if API call fails, these are returned.
- ret_logprobs_data = {
- "target_token_ids": [],
- "target_logprobs": [],
- "target_mask": [],
- }
+ ret_data_target_token_ids: List[List[List[int]]] = []
+ ret_data_target_logprobs: List[List[List[float]]] = []
+ ret_data_target_mask: List[List[List[int]]] = []
try:
response = self.http_session.post(
- api_endpoint, json=payload, timeout=self.kd_online_timeout
+ api_endpoint,
+ json=payload,
+ timeout=self.kd_online_timeout,
+ # json_decoder=orjson.loads,
)
response.raise_for_status()
- api_data: dict = response.json()
+ api_data: dict = orjson.loads(response.content)
choices: list[dict] = api_data["choices"]
# Ensure api_data is a list, and its length matches batch_input_ids
@@ -317,7 +353,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}."
)
# Return empty data; items processed later will get default empty KD fields
- return ret_logprobs_data
+ return {
+ "target_token_ids": ret_data_target_token_ids,
+ "target_logprobs": ret_data_target_logprobs,
+ "target_mask": ret_data_target_mask,
+ }
for sequence_data, seq_input_ids, seq_labels in zip(
choices, batch_input_ids, labels
@@ -330,29 +370,10 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
current_target_mask = []
# Ensure input_top_logprobs is a list
- input_top_logprobs: Optional[list[None | list[tuple]]] = (
+ input_top_logprobs: Optional[list[None | dict[str, dict]]] = (
sequence_data.pop("prompt_logprobs", [])
)
- """
- vllm api data for prompt logprobs looks like:
- "prompt_logprobs": [
- null, # first token is always null
- { # second token logprobs
- "8948": { # token ID
- "logprob": -2.3841830625315197e-06,
- "rank": 1,
- "decoded_token": "system"
- },
- "1849": { # token ID
- "logprob": -13.187501907348633,
- "rank": 2,
- "decoded_token": "Ġsystem"
- },
- ... # rest of the top-k tokens/logprobs
- },
- ... # more tokens
- }
- """
+
if not isinstance(input_top_logprobs, list):
LOG.warning(
f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
@@ -369,11 +390,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
# this is always the case for the first token.
# there is never logprob data for the first token since that's a true input
continue
- elif (
+ if (
i < len(input_top_logprobs)
and input_top_logprobs[i] is not None
):
- pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i]
+ pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i] # type: ignore[assignment]
# Ensure pos_top_logprobs_data is a list of lists as expected
if not (
isinstance(pos_top_logprobs_data, dict)
@@ -396,9 +417,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
continue
# pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids
- pos_token_ids = pos_top_logprobs_data.keys()
+ pos_token_ids_str = list(pos_top_logprobs_data.keys())
pos_logprobs_dict = pos_top_logprobs_data.values()
- pos_token_ids = [int(token_id) for token_id in pos_token_ids]
+ pos_token_ids = [
+ int(token_id) for token_id in pos_token_ids_str
+ ]
pos_logprobs_raw = [
float(logprob.get("logprob", -float("inf")))
for logprob in pos_logprobs_dict
@@ -446,17 +469,18 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
current_target_token_ids.append(list(range(self.kd_online_topk)))
current_target_mask.append([0] * self.kd_online_topk)
- ret_logprobs_data["target_token_ids"].append(current_target_token_ids)
- ret_logprobs_data["target_logprobs"].append(current_target_logprobs)
- ret_logprobs_data["target_mask"].append(current_target_mask)
+ ret_data_target_token_ids.append(current_target_token_ids)
+ ret_data_target_logprobs.append(current_target_logprobs)
+ ret_data_target_mask.append(current_target_mask)
# TODO save and load targets to disk for caching for next epoch
- # generate a hash over seq_input_ids and convert it to an int
- # hash_input_ids: int = hash(tuple(seq_input_ids))
- # with open(f"/tmp/target_logprobs_{hash_input_ids}.parquet", "wb") as f:
- # pd.DataFrame(current_target_logprobs).to_parquet(f, index=False)
- # with open(f"/tmp/target_token_ids_{hash_input_ids}.parquet", "wb") as f:
- # pd.DataFrame(current_target_token_ids).to_parquet(f, index=False)
+ # generate a hmac SHA256 hash over the list seq_input_ids and convert it to an int
+ # if self.kd_cache_dir:
+ # hash_input_ids = hmac_sha_from_int_list(
+ # seq_input_ids, f"{self.kd_online_server_base_url}:{self.kd_online_topk}"
+ # )
+ # with open(f"{self.kd_cache_dir}/{hash_input_ids}.parquet", "wb") as f:
+ # pd.DataFrame(ret_logprobs_data).to_parquet(f, index=False)
except requests.exceptions.RequestException as e:
LOG.error(f"Error fetching logprobs from online teacher: {e}")
@@ -469,7 +493,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
)
raise e
- return ret_logprobs_data
+ return {
+ "target_token_ids": ret_data_target_token_ids,
+ "target_logprobs": ret_data_target_logprobs,
+ "target_mask": ret_data_target_mask,
+ }
def __call__(
self, features: List[List[Dict[str, Any]]], return_tensors: Optional[str] = None
diff --git a/src/axolotl/integrations/kd/kernels/liger.py b/src/axolotl/integrations/kd/kernels/liger.py
index f30904d5a..0050ffe33 100644
--- a/src/axolotl/integrations/kd/kernels/liger.py
+++ b/src/axolotl/integrations/kd/kernels/liger.py
@@ -20,6 +20,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
target_token_ids_chunk: torch.Tensor, # [chunk_size, top_k]
target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs
target_mask_chunk: torch.Tensor, # [chunk_size, top_k]
+ beta: float = 0.0,
) -> torch.Tensor:
"""
Compute Top-K KL divergence loss for a chunk.
@@ -28,6 +29,10 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K).
target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K).
target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K).
+ beta: Controls the type of KL divergence.
+ 0.0 for Forward KL (P_teacher || P_student).
+ 1.0 for Reverse KL (P_student || P_teacher).
+ 0.5 for Symmetric KL (average of Forward and Reverse).
Returns:
Sum of KL divergence losses for the chunk.
"""
@@ -59,6 +64,10 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
# Teacher probabilities P(y|x_teacher) from logprobs
# target_logprobs_valid are already normalized (log(softmax(teacher_logits/T)))
teacher_probs_valid = target_logprobs_valid.exp()
+ # Student probabilities P_student from log P_student
+ student_probs_topk_valid = student_logprobs_topk_valid.exp()
+
+ kd_loss_per_token = torch.zeros_like(target_logprobs_valid)
# KL divergence: sum(P_teacher * (log P_teacher - log P_student))
# = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student)
@@ -66,9 +75,17 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
# or as sum(P_teacher * (log_softmax_teacher - log_softmax_student))
# Here, target_logprobs_valid are log_softmax_teacher.
# student_logprobs_topk_valid are log_softmax_student (for the selected K indices).
- kd_loss_per_token = teacher_probs_valid * (
- target_logprobs_valid - student_logprobs_topk_valid
- )
+ if beta < 1.0: # Contribution from Forward KL
+ fwd_kl_per_token = teacher_probs_valid * (
+ target_logprobs_valid - student_logprobs_topk_valid
+ )
+ kd_loss_per_token += (1.0 - beta) * fwd_kl_per_token
+ if beta > 0.0: # Contribution from Reverse KL
+ rev_kl_per_token = student_probs_topk_valid * (
+ student_logprobs_topk_valid - target_logprobs_valid
+ )
+ kd_loss_per_token += beta * rev_kl_per_token
+
kd_loss = kd_loss_per_token.sum()
return kd_loss
@@ -91,6 +108,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
weight_soft_loss: float = 0.5,
compute_ce_loss: bool = True,
temperature: float = 1.0,
+ beta: float = 0.0,
):
# Compute student logits for the chunk from hidden states and LM head
# student_input_chunk: [chunk_size, hidden_dim]
@@ -125,10 +143,9 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
target_token_ids_chunk,
target_logprobs_chunk,
target_mask_chunk,
+ beta=beta,
)
- loss = weight_hard_loss * ce_loss + weight_soft_loss * soft_loss
-
return soft_loss, ce_loss
@classmethod
@@ -146,6 +163,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
weight_soft_loss: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
+ beta: float = 0.0,
compiled: bool = False,
chunk_size: int = 1024,
compute_ce_loss: bool = True,
@@ -192,6 +210,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
weight_soft_loss=weight_soft_loss,
compute_ce_loss=compute_ce_loss,
temperature=temperature,
+ beta=beta,
)
def accumulate_chunk_grads(
@@ -288,13 +307,13 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)
# For matching None returns in backward for non-tensor/non-grad_requiring inputs
- ctx.hyperparams_count = 7 # Corresponds to number of hyperparams after main tensors in fwd signature
+ ctx.hyperparams_count = 8 # Corresponds to number of hyperparams after main tensors in fwd signature
ctx.bias_was_none = student_lm_head_bias is None
ctx.orig_dims = (B, N, D, K)
- # since this is packed, there is simply a single batch, so batchmean reduciton of kl-div is simply the accumulatedsum
- # we still need to scale the kd_loss by the temp
- kd_loss_acc = kd_loss_acc * (temperature ** 2)
+ # since this is packed, there is simply a single batch, so batchmean reduciton of kl-div is simply the accumulated sum
+ # we still need to scale the kd_loss by the temp^2
+ kd_loss_acc = kd_loss_acc * (temperature**2)
final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc
return final_loss
@@ -373,6 +392,7 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
temperature: float = 1.0, # This is the kd_temperature
+ beta: float = 1.0,
ignore_index: int = -100,
compiled: bool = True,
chunk_size: int = 1024,
@@ -387,6 +407,7 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
self.weight_hard_loss = weight_hard_loss
self.weight_soft_loss = weight_soft_loss
self.temperature = temperature
+ self.beta = beta
self.ignore_index = ignore_index
self.compiled = compiled
self.chunk_size = chunk_size
@@ -424,6 +445,7 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
self.weight_soft_loss,
self.ignore_index,
self.temperature,
+ self.beta,
self.compiled,
self.chunk_size,
self.compute_ce_loss,
diff --git a/src/axolotl/integrations/kd/kernels/models.py b/src/axolotl/integrations/kd/kernels/models.py
index f346ad21a..5a7c286bc 100644
--- a/src/axolotl/integrations/kd/kernels/models.py
+++ b/src/axolotl/integrations/kd/kernels/models.py
@@ -75,8 +75,8 @@ def kldiv_forward_llama_like(
target_mask,
true_labels=labels,
)
- num_items_in_batch = kwargs.pop("num_items_in_batch", None)
- if num_items_in_batch is not None:
+ num_items_in_batch = kwargs.pop("num_items_in_batch", -1)
+ if num_items_in_batch is not None and num_items_in_batch > 0:
loss = loss / num_items_in_batch
return CausalLMOutputWithPast(
diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py
index bb7daa092..131e1695d 100644
--- a/src/axolotl/integrations/kd/trainer.py
+++ b/src/axolotl/integrations/kd/trainer.py
@@ -33,6 +33,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
self.args.kd_ce_alpha, # hard label loss
self.args.kd_alpha, # kd loss
self.args.kd_temperature,
+ self.args.kd_beta,
compute_ce_loss=bool(self.args.kd_ce_alpha),
)
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 866a9c454..59d9ec595 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -1,10 +1,13 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
+from __future__ import annotations
+
import importlib
import inspect
import os
import signal
import sys
+import typing
import weakref
from contextlib import ExitStack
from pathlib import Path
@@ -25,7 +28,6 @@ from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
-from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.integrations.base import PluginManager
from axolotl.loaders import (
ModelLoader,
@@ -45,6 +47,9 @@ try:
except ImportError:
BetterTransformer = None
+if typing.TYPE_CHECKING:
+ from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
+
LOG = get_logger(__name__)
@@ -472,7 +477,7 @@ def handle_untrained_tokens_fix(
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
- HFRLTrainerBuilder | HFCausalTrainerBuilder,
+ "HFRLTrainerBuilder" | "HFCausalTrainerBuilder",
PeftModel | PreTrainedModel,
PreTrainedTokenizer,
PeftConfig | None,
diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py
index bf496d2c5..09bfb5576 100644
--- a/src/axolotl/utils/chat_templates.py
+++ b/src/axolotl/utils/chat_templates.py
@@ -36,7 +36,7 @@ _CHAT_TEMPLATES = {
"deepseek_v3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\\n\\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' in message %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- else %}{{'<|Assistant|>' + message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- endfor %}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' not in message %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}",
"jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- "\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." -}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined %}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n',
"qwen_25": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
- "qwen3": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('') and message.content.endswith('')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '' in message.content %}\n {%- set content = message.content.split('')[-1].lstrip('\\n') %}\n {%- set reasoning_content = message.content.split('')[0].rstrip('\\n').split('')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content.strip('\\n') + '\\n\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '\\n\\n\\n\\n' }}\n {%- endif %}\n{%- endif %}",
+ "qwen3": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('') and message.content.endswith('')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '' in message.content %}\n {%- set content = message.content.split('')[-1].lstrip('\\n') %}\n {%- set reasoning_content = message.content.split('')[0].rstrip('\\n').split('')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content.strip('\\n') + '\\n\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '\\n\\n\\n\\n' }}\n {%- else %}\n {{- '\\n\\n' }}\n {%- endif %}\n{%- endif %}",
"exaone": "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '[|system|][|endofturn|]\n' }}{% endif %}{{ '[|' + message['role'] + '|]' + message['content'] }}{% if message['role'] == 'user' %}{{ '\n' }}{% else %}{{ '[|endofturn|]\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[|assistant|]' }}{% endif %}",
"metharme": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'Enter RP mode. You shall reply to the user while staying in character. Your responses must be detailed, creative, immersive, and drive the scenario forward.' %}{% endif %}{{ '<|system|>' + system_message }}{% for message in loop_messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|user|>' + content.strip() }}{% elif message['role'] == 'assistant' %}{{ '<|model|>' + content.strip() }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|model|>' }}{% else %}{{ eos_token }}{% endif %}",
"pixtral": '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) %}\n {{- raise_exception(\'After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\') }}\n {%- endif %}\n {%- if message["role"] == "user" %}\n {%- if loop.last and system_message is defined %}\n {{- "[INST]" + system_message + "\n\n" }}\n {%- else %}\n {{- "[INST]" }}\n {%- endif %}\n {%- if message["content"] is not string %}\n {%- for chunk in message["content"] %}\n {%- if chunk["type"] == "text" %}\n {{- chunk["text"] }}\n {%- elif chunk["type"] == "image" %}\n {{- "[IMG]" }}\n {%- else %}\n {{- raise_exception("Unrecognized content type!") }}\n {%- endif %}\n {%- endfor %}\n {%- else %}\n {{- message["content"] }}\n {%- endif %}\n {{- "[/INST]" }}\n {%- elif message["role"] == "assistant" %}\n {%- if message["content"] is not string %}\n {%- for chunk in message["content"] %}\n {%- if chunk["type"] == "text" %}\n {{- chunk["text"] }}\n {%- elif chunk["type"] == "image" %}\n {{- "[IMG]" }}\n {%- else %}\n {{- raise_exception("Unrecognized content type!") }}\n{%- endif %}\n{%- endfor %}\n{{- eos_token }}\n{%- else %}\n{{- message["content"] + eos_token }}\n{%- endif %}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}',
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 67f590a37..275b1f414 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -16,7 +16,6 @@ from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available
-from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
@@ -629,6 +628,8 @@ def setup_trainer(
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
on the provided parameters.
"""
+ from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
+
if (
cfg.torch_compile
and cfg.fsdp_config