diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 11fe13713..bb865e98d 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -188,7 +188,7 @@ jobs:
if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }}
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
- timeout-minutes: 90
+ timeout-minutes: 120
needs: [pre-commit, pytest, pytest-sdist]
strategy:
@@ -238,7 +238,7 @@ jobs:
if: github.repository_owner == 'axolotl-ai-cloud'
# this job needs to be run on self-hosted GPU runners...
runs-on: [self-hosted, modal]
- timeout-minutes: 90
+ timeout-minutes: 120
# Only run the remainder of the matrix if the first e2e check passed;
# this is to save on wasted compute costs for known failures that get caught in the first run
needs: [pre-commit, pytest, docker-e2e-tests-1st]
diff --git a/deepspeed_configs/zero2_torch_compile.json b/deepspeed_configs/zero2_torch_compile.json
new file mode 100644
index 000000000..c3bcf98cf
--- /dev/null
+++ b/deepspeed_configs/zero2_torch_compile.json
@@ -0,0 +1,31 @@
+{
+ "compile": {
+ "disable": false,
+ "backend": "inductor"
+ },
+ "zero_optimization": {
+ "stage": 2,
+ "offload_optimizer": {
+ "device": "cpu"
+ },
+ "contiguous_gradients": true,
+ "overlap_comm": true
+ },
+ "bf16": {
+ "enabled": "auto"
+ },
+ "fp16": {
+ "enabled": "auto",
+ "auto_cast": false,
+ "loss_scale": 0,
+ "initial_scale_power": 32,
+ "loss_scale_window": 1000,
+ "hysteresis": 2,
+ "min_loss_scale": 1
+ },
+ "gradient_accumulation_steps": "auto",
+ "gradient_clipping": "auto",
+ "train_batch_size": "auto",
+ "train_micro_batch_size_per_gpu": "auto",
+ "wall_clock_breakdown": false
+}
diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py
index 8ff565dbb..6ed298d9f 100644
--- a/src/axolotl/core/builders/causal.py
+++ b/src/axolotl/core/builders/causal.py
@@ -21,11 +21,6 @@ from axolotl.core.trainers import (
AxolotlTrainer,
ReLoRATrainer,
)
-from axolotl.core.training_args import (
- AxolotlPRMConfig,
- AxolotlRewardConfig,
- AxolotlTrainingArguments,
-)
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback
@@ -130,6 +125,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return callbacks
def _get_trainer_cls(self):
+ """
+ Gets the trainer class for the given configuration.
+ """
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
@@ -146,6 +144,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return AxolotlTrainer
def build(self, total_num_steps):
+ from axolotl.core.training_args import (
+ AxolotlPRMConfig,
+ AxolotlRewardConfig,
+ AxolotlTrainingArguments,
+ )
+
training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
total_num_steps
)
@@ -314,20 +318,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["image_resize_algorithm"] = (
self.cfg.image_resize_algorithm
)
- if self.cfg.kd_ce_alpha is not None:
- training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
- if self.cfg.kd_alpha is not None:
- training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
- if self.cfg.kd_temperature is not None:
- training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
- if self.cfg.kd_zscore_base_temp is not None:
- training_arguments_kwargs["kd_zscore_base_temp"] = (
- self.cfg.kd_zscore_base_temp
- )
- if self.cfg.kd_top_k_before_softmax is not None:
- training_arguments_kwargs["kd_top_k_before_softmax"] = (
- self.cfg.kd_top_k_before_softmax
- )
+
+ if self.cfg.plugins:
+ plugin_manager = PluginManager.get_instance()
+ plugin_training_args = plugin_manager.get_training_args(self.cfg)
+ if plugin_training_args:
+ training_arguments_kwargs.update(plugin_training_args)
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
@@ -408,7 +404,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return trainer
def build_collator(
- self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
+ self,
+ training_args, # type: "AxolotlTrainingArguments" # type: ignore
+ is_eval=False,
+ **kwargs,
):
if training_args.pretraining:
if (
@@ -437,7 +436,19 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
]
]
collator_args = [self.tokenizer]
- if self.cfg.reward_model:
+
+ collator_cls_and_kwargs = None
+ if self.cfg.plugins:
+ plugin_manager = PluginManager.get_instance()
+ collator_cls_and_kwargs = plugin_manager.get_collator_cls_and_kwargs(
+ self.cfg, is_eval=is_eval
+ )
+
+ if collator_cls_and_kwargs:
+ collator = collator_cls_and_kwargs[0]
+ if isinstance(kwargs, dict):
+ kwargs.update(collator_cls_and_kwargs[1])
+ elif self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
elif use_batch_sampler_collator:
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
@@ -468,16 +479,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator_args.pop(0)
kwargs.pop("pad_to_multiple_of", None)
kwargs.pop("padding", None)
- elif self.cfg.kd_trainer:
- from axolotl.integrations.kd.collator import (
- DataCollatorForKD,
- KDBatchSamplerDataCollatorForSeq2Seq,
- )
-
- if self.cfg.sample_packing:
- collator = KDBatchSamplerDataCollatorForSeq2Seq
- else:
- collator = DataCollatorForKD
else:
collator = DataCollatorForSeq2Seq
diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py
index 47ace7451..c5f01dd41 100644
--- a/src/axolotl/core/builders/rl.py
+++ b/src/axolotl/core/builders/rl.py
@@ -12,11 +12,6 @@ from axolotl.core.trainers import (
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
-from axolotl.core.training_args import (
- AxolotlCPOConfig,
- AxolotlKTOConfig,
- AxolotlORPOConfig,
-)
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype
from axolotl.utils.callbacks.qat import QATCallback
@@ -83,6 +78,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
"""
Returns training_args and trainer_kwargs
"""
+ from axolotl.core.training_args import (
+ AxolotlCPOConfig,
+ AxolotlKTOConfig,
+ AxolotlORPOConfig,
+ )
+
training_args_kwargs, trainer_kwargs = self._set_base_training_args(
total_num_steps=total_num_steps
)
@@ -150,6 +151,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if blocklist_key in training_args_kwargs:
del training_args_kwargs[blocklist_key]
+ if self.cfg.plugins:
+ plugin_manager = PluginManager.get_instance()
+ plugin_training_args = plugin_manager.get_training_args(self.cfg)
+ if plugin_training_args:
+ training_args_kwargs.update(plugin_training_args)
+
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
logging_first_step=True,
**training_args_kwargs,
diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index 25ffb4cbf..fbae253d6 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -34,6 +34,7 @@ from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
)
+from axolotl.utils import get_not_null
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -104,7 +105,7 @@ class AxolotlTrainer(
)
batch_max_len = train_batch_size * self.args.max_seq_length
- return MultipackBatchSampler(
+ sampler = MultipackBatchSampler(
base_sampler,
lengths=get_dataset_lengths(dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
@@ -117,6 +118,9 @@ class AxolotlTrainer(
num_processes=self.args.dataset_num_proc,
)
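+ # the bare len() call is presumably here to force the sampler to compute its
+ # packed batches (and thus its length) eagerly; the result itself is unused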
+ len(sampler)
+ return sampler
+
def _get_train_sampler(
self, train_dataset: Optional[Dataset] = None
) -> Optional[Sampler]:
@@ -224,7 +228,9 @@ class AxolotlTrainer(
}
if not isinstance(dataset, torch.utils.data.IterableDataset):
- dataloader_params["drop_last"] = self.args.dataloader_drop_last
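+ # get_not_null falls back to drop_last=True when dataloader_drop_last is unset (None)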
+ dataloader_params["drop_last"] = get_not_null(
+ self.args.dataloader_drop_last, True
+ )
if sampler_fn is not None:
sampler = sampler_fn(dataset)
if isinstance(sampler, BatchSampler):
diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py
index 2b53c6798..d5be9fc62 100644
--- a/src/axolotl/core/training_args.py
+++ b/src/axolotl/core/training_args.py
@@ -2,242 +2,17 @@
extra axolotl specific training args
"""
-from dataclasses import dataclass, field
-from typing import Optional
+from __future__ import annotations
+
+from dataclasses import dataclass, field
+from typing import Optional, Type
-from PIL.Image import Resampling
from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
+from axolotl.integrations.config import merge_training_args
-@dataclass
-class AxolotlTrainingMixins:
- """
- Mixin class for the Axolotl training args.
- """
-
- # pylint: disable=duplicate-code
- model_type: Optional[str] = field(
- default=None, metadata={"help": "HF model configuration model_type."}
- )
- lr_quadratic_warmup: bool = field(
- default=False,
- metadata={"help": "Use quadratic warmup for cosine scheduling."},
- )
- pretraining: bool = field(
- default=False,
- metadata={
- "help": "Indicates to trainer whether we are doing continued pretraining."
- },
- )
- sample_packing: bool = field(
- default=False,
- metadata={"help": "Use sample packing for efficient training."},
- )
- sample_packing_sequentially: bool = field(
- default=False,
- metadata={
- "help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
- },
- )
- multipack_real_batches: bool = field(
- default=False,
- metadata={"help": "Use real batches for efficient training."},
- )
- eval_sample_packing: Optional[bool] = field(
- default=None,
- metadata={"help": "Use sample packing for efficient evals."},
- )
- sample_packing_efficiency: float = field(
- default=1.0,
- metadata={"help": "Sample packing efficiency for calculating batch length."},
- )
- sample_packing_bin_size: int = field(
- default=200,
- metadata={
- "help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
- },
- )
- sample_packing_group_size: int = field(
- default=100000,
- metadata={
- "help": "The number of samples to group together for packing. Increase for better packing."
- },
- )
- max_seq_length: int = field(
- default=2048,
- metadata={"help": "The maximum sequence length the model can handle"},
- )
- dataset_num_proc: int | None = field(
- default=None,
- metadata={"help": "The number of processes to use for data processing"},
- )
- relora_steps: Optional[int] = field(
- default=None,
- metadata={"help": "how often to reset for ReLoRA"},
- )
- relora_warmup_steps: Optional[int] = field(
- default=None,
- metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
- )
- relora_anneal_steps: Optional[int] = field(
- default=None,
- metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
- )
- relora_prune_ratio: Optional[float] = field(
- default=0.9,
- metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
- )
- bench_split: Optional[str] = field(
- default="eval", metadata={"help": "The benchmark split to run on"}
- )
- bench_dataset: Optional[str] = field(
- default="pharaouk/dharma-1/dharma_1_mini.json",
- metadata={
- "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
- },
- )
- do_bench_eval: Optional[bool] = field(
- default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
- )
- do_causal_lm_eval: Optional[bool] = field(
- default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
- )
- max_bench_samples: Optional[int] = field(
- default=None,
- metadata={
- "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
- },
- )
- bench_source_max_len: int = field(
- default=2048, metadata={"help": "Maximum source sequence length for bench."}
- )
- dataloader_prefetch_factor: Optional[int] = field(
- default=None,
- metadata={"help": "prefetch_factor argument to the dataloader"},
- )
- cosine_min_lr_ratio: Optional[float] = field(
- default=None,
- metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
- )
- cosine_constant_lr_ratio: Optional[float] = field(
- default=None,
- metadata={
- "help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
- },
- )
- loraplus_lr_ratio: Optional[float] = field(
- default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
- )
- loraplus_lr_embedding: Optional[float] = field(
- default=1e-6,
- metadata={"help": "loraplus learning rate for lora embedding layers."},
- )
- embedding_lr_scale: Optional[float] = field(
- default=None,
- metadata={"help": "Scale the learning rate for the embedding layers."},
- )
- lr_groups: Optional[list[dict]] = field(
- default=None,
- metadata={"help": "Specify learning rate groups for with different LRs."},
- )
- embedding_lr: Optional[float] = field(
- default=None,
- metadata={"help": "absolute learning rate for the embedding layers."},
- )
- qlora: bool = field(
- default=False,
- metadata={"help": "whether this is a qlora training"},
- )
- orpo_alpha: Optional[float] = field(
- default=None,
- )
- lisa_n_layers: Optional[int] = field(
- default=None,
- metadata={"help": "the number of activate layers in LISA"},
- )
- lisa_step_interval: Optional[int] = field(
- default=None,
- metadata={"help": "how often to switch layers in LISA"},
- )
- lisa_layers_attribute: Optional[str] = field(
- default=None,
- metadata={"help": "path under the model to access the layers"},
- )
- curriculum_sampling: Optional[bool] = field(
- default=None,
- metadata={"help": "whether to use sequential sampling for curriculum learning"},
- )
- alternate_lr_scheduler_type: Optional[str] = field(
- default=None,
- metadata={
- "help": "workaround to pass an alternate lr scheduler to the HF trainer"
- },
- )
- chat_template: Optional[str] = field(
- default=None,
- metadata={"help": "Chat template converting chat messages to text"},
- )
-
- kd_ce_alpha: Optional[float] = field(
- default=None,
- metadata={
- "help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
- },
- )
-
- kd_alpha: Optional[float] = field(
- default=1.0,
- metadata={"help": "The alpha scaling parameter for KD loss"},
- )
-
- kd_temperature: Optional[float] = field(
- default=1.0,
- metadata={
- "help": "the temperature parameter for KL divergence loss when using KD"
- },
- )
-
- kd_zscore_base_temp: Optional[float] = field(
- default=None,
- metadata={
- "help": "the base temperature parameter for KL divergence with z-score when using KD"
- },
- )
-
- kd_top_k_before_softmax: Optional[bool] = field(
- default=None,
- metadata={
- "help": "Whether to apply top_k_before_softmax to the logits when using KD"
- },
- )
-
- adam_beta3: Optional[float] = field(
- default=None,
- metadata={
- "help": "The beta3 hyperparameter used in some optimizers such as CAME"
- },
- )
- adam_epsilon2: Optional[float] = field(
- default=None,
- metadata={
- "help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
- },
- )
-
- # multi-modal section
-
- image_size: int | tuple[int, int] | None = field(
- default=None,
- metadata={"help": "The size of the image to resize to"},
- )
-
- image_resize_algorithm: Resampling | None = field(
- default=None,
- metadata={"help": "The algorithm to use for image resizing"},
- )
-
- # end of multi-modal section
+AxolotlTrainingMixins: Type = merge_training_args()
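+# NOTE: merge_training_args() runs at import time, so plugin training args mixins
+# must be registered before this module is imported; the trainer builders import
+# AxolotlTrainingArguments lazily inside build(), presumably for this reason.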
@dataclass
diff --git a/src/axolotl/core/training_args_base.py b/src/axolotl/core/training_args_base.py
new file mode 100644
index 000000000..8fcaff632
--- /dev/null
+++ b/src/axolotl/core/training_args_base.py
@@ -0,0 +1,224 @@
+"""
+Base Axolotl Training Mixins shared across various trainer configs
+"""
+
+from dataclasses import dataclass, field
+from typing import Optional
+
+from PIL.Image import Resampling
+
+
+@dataclass
+class AxolotlTrainingMixins:
+ """
+ Mixin class for the Axolotl training args.
+ """
+
+ # pylint: disable=duplicate-code
+ model_type: Optional[str] = field(
+ default=None, metadata={"help": "HF model configuration model_type."}
+ )
+ lr_quadratic_warmup: bool = field(
+ default=False,
+ metadata={"help": "Use quadratic warmup for cosine scheduling."},
+ )
+ pretraining: bool = field(
+ default=False,
+ metadata={
+ "help": "Indicates to trainer whether we are doing continued pretraining."
+ },
+ )
+ sample_packing: bool = field(
+ default=False,
+ metadata={"help": "Use sample packing for efficient training."},
+ )
+ sample_packing_sequentially: bool = field(
+ default=False,
+ metadata={
+ "help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
+ },
+ )
+ multipack_real_batches: bool = field(
+ default=False,
+ metadata={"help": "Use real batches for efficient training."},
+ )
+ eval_sample_packing: Optional[bool] = field(
+ default=None,
+ metadata={"help": "Use sample packing for efficient evals."},
+ )
+ sample_packing_efficiency: float = field(
+ default=1.0,
+ metadata={"help": "Sample packing efficiency for calculating batch length."},
+ )
+ sample_packing_bin_size: int = field(
+ default=200,
+ metadata={
+ "help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
+ },
+ )
+ sample_packing_group_size: int = field(
+ default=100000,
+ metadata={
+ "help": "The number of samples to group together for packing. Increase for better packing."
+ },
+ )
+ max_seq_length: int = field(
+ default=2048,
+ metadata={"help": "The maximum sequence length the model can handle"},
+ )
+ dataset_num_proc: int | None = field(
+ default=None,
+ metadata={"help": "The number of processes to use for data processing"},
+ )
+ relora_steps: Optional[int] = field(
+ default=None,
+ metadata={"help": "how often to reset for ReLoRA"},
+ )
+ relora_warmup_steps: Optional[int] = field(
+ default=None,
+ metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
+ )
+ relora_anneal_steps: Optional[int] = field(
+ default=None,
+ metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
+ )
+ relora_prune_ratio: Optional[float] = field(
+ default=0.9,
+ metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
+ )
+ bench_split: Optional[str] = field(
+ default="eval", metadata={"help": "The benchmark split to run on"}
+ )
+ bench_dataset: Optional[str] = field(
+ default="pharaouk/dharma-1/dharma_1_mini.json",
+ metadata={
+ "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
+ },
+ )
+ do_bench_eval: Optional[bool] = field(
+ default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
+ )
+ do_causal_lm_eval: Optional[bool] = field(
+ default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
+ )
+ max_bench_samples: Optional[int] = field(
+ default=None,
+ metadata={
+ "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
+ },
+ )
+ bench_source_max_len: int = field(
+ default=2048, metadata={"help": "Maximum source sequence length for bench."}
+ )
+ dataloader_prefetch_factor: Optional[int] = field(
+ default=None,
+ metadata={"help": "prefetch_factor argument to the dataloader"},
+ )
+ cosine_min_lr_ratio: Optional[float] = field(
+ default=None,
+ metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
+ )
+ cosine_constant_lr_ratio: Optional[float] = field(
+ default=None,
+ metadata={
+ "help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
+ },
+ )
+ loraplus_lr_ratio: Optional[float] = field(
+ default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
+ )
+ loraplus_lr_embedding: Optional[float] = field(
+ default=1e-6,
+ metadata={"help": "loraplus learning rate for lora embedding layers."},
+ )
+ embedding_lr_scale: Optional[float] = field(
+ default=None,
+ metadata={"help": "Scale the learning rate for the embedding layers."},
+ )
+ lr_groups: Optional[list[dict]] = field(
+ default=None,
+ metadata={"help": "Specify learning rate groups for with different LRs."},
+ )
+ embedding_lr: Optional[float] = field(
+ default=None,
+ metadata={"help": "absolute learning rate for the embedding layers."},
+ )
+ qlora: bool = field(
+ default=False,
+ metadata={"help": "whether this is a qlora training"},
+ )
+ orpo_alpha: Optional[float] = field(
+ default=None,
+ )
+ lisa_n_layers: Optional[int] = field(
+ default=None,
+ metadata={"help": "the number of activate layers in LISA"},
+ )
+ lisa_step_interval: Optional[int] = field(
+ default=None,
+ metadata={"help": "how often to switch layers in LISA"},
+ )
+ lisa_layers_attribute: Optional[str] = field(
+ default=None,
+ metadata={"help": "path under the model to access the layers"},
+ )
+ curriculum_sampling: Optional[bool] = field(
+ default=None,
+ metadata={"help": "whether to use sequential sampling for curriculum learning"},
+ )
+ alternate_lr_scheduler_type: Optional[str] = field(
+ default=None,
+ metadata={
+ "help": "workaround to pass an alternate lr scheduler to the HF trainer"
+ },
+ )
+ chat_template: Optional[str] = field(
+ default=None,
+ metadata={"help": "Chat template converting chat messages to text"},
+ )
+
+ # kd_ce_alpha: Optional[float] = field(
+ # default=None,
+ # metadata={
+ # "help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
+ # },
+ # )
+ #
+ # kd_alpha: Optional[float] = field(
+ # default=1.0,
+ # metadata={"help": "The alpha scaling parameter for KD loss"},
+ # )
+ #
+ # kd_temperature: Optional[float] = field(
+ # default=1.0,
+ # metadata={
+ # "help": "the temperature parameter for KL divergence loss when using KD"
+ # },
+ # )
+
+ adam_beta3: Optional[float] = field(
+ default=None,
+ metadata={
+ "help": "The beta3 hyperparameter used in some optimizers such as CAME"
+ },
+ )
+ adam_epsilon2: Optional[float] = field(
+ default=None,
+ metadata={
+ "help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
+ },
+ )
+
+ # multi-modal section
+
+ image_size: int | tuple[int, int] | None = field(
+ default=None,
+ metadata={"help": "The size of the image to resize to"},
+ )
+
+ image_resize_algorithm: Resampling | None = field(
+ default=None,
+ metadata={"help": "The algorithm to use for image resizing"},
+ )
+
+ # end of multi-modal section
diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py
index 0edc9fdea..9162bc745 100644
--- a/src/axolotl/integrations/base.py
+++ b/src/axolotl/integrations/base.py
@@ -22,6 +22,7 @@ from __future__ import annotations
import collections
import importlib
+import traceback
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
from peft import PeftModel
@@ -83,6 +84,11 @@ class BasePlugin:
def get_input_args(self) -> str | None:
"""Returns a pydantic model for the plugin's input arguments."""
+ def get_training_args_mixin(self) -> str | None:
+ """
+ Returns the import path of a dataclass mixin for the plugin's training arguments.
+ """
+
def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:
@@ -158,6 +164,31 @@ class BasePlugin:
trainer: The trainer object for training.
"""
+ def get_training_args(self, cfg: DictDefault): # pylint: disable=unused-argument
+ """
+ Returns custom training arguments to set on TrainingArgs.
+
+ Args:
+ cfg: The global axolotl configuration.
+
+ Returns:
+ object: dict containing the training arguments.
+ """
+
+ def get_collator_cls_and_kwargs(
+ self, cfg: DictDefault, is_eval: bool = False
+ ): # pylint: disable=unused-argument
+ """
+ Returns a custom collator class and its kwargs.
+
+ Args:
+ cfg: The global axolotl configuration.
+ is_eval: Whether this is an eval split.
+
+ Returns:
+ tuple | None: The collator class and its kwargs, or None if not applicable.
+ """
+
# pylint: disable=unused-argument
def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None:
"""Creates and returns an optimizer for training.
@@ -278,7 +309,7 @@ def load_plugin(plugin_name: str) -> BasePlugin:
return plugin
-class PluginManager:
+class PluginManager: # pylint: disable=too-many-public-methods
"""The `PluginManager` class is responsible for loading and managing plugins. It
should be a singleton so it can be accessed from anywhere in the codebase.
@@ -337,8 +368,11 @@ class PluginManager:
plugin = load_plugin(plugin_name)
self.plugins[plugin_name] = plugin
LOG.info(f"Plugin loaded successfully: {plugin_name}")
- except ImportError:
+ except ImportError as exc:
LOG.error(f"Failed to load plugin: {plugin_name}")
+ # print stacktrace
+ traceback.print_exc()
+ print(f"Error: {exc}")
def get_input_args(self) -> list[str]:
"""Returns a list of Pydantic classes for all registered plugins' input arguments.'
@@ -353,6 +387,20 @@ class PluginManager:
input_args.append(input_args_from_plugin)
return input_args
+ def get_training_args_mixin(self):
+ """
+ Returns a list of import paths for all registered plugins' training args mixins.
+
+ Returns:
+ list[str]: A list of dataclass import paths.
+ """
+ training_args = []
+ for plugin in self.plugins.values():
+ training_args_from_plugin = plugin.get_training_args_mixin()
+ if training_args_from_plugin is not None:
+ training_args.append(training_args_from_plugin)
+ return training_args
+
def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:
@@ -442,6 +490,42 @@ class PluginManager:
return trainer_cls
return None
+ def get_training_args(self, cfg):
+ """
+ Calls the get_training_args method of all registered plugins and returns the combined training arguments.
+
+ Parameters:
+ cfg (dict): The configuration for the plugins.
+
+ Returns:
+ dict: The combined training argument kwargs from all plugins.
+ """
+ training_args_kwargs = {}
+ for plugin in self.plugins.values():
+ training_args = plugin.get_training_args(cfg)
+ if training_args is not None:
+ training_args_kwargs.update(training_args)
+
+ return training_args_kwargs
+
+ def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
+ """
+ Calls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None result.
+
+ Parameters:
+ cfg (dict): The configuration for the plugins.
+ is_eval (bool): Whether this is an eval split.
+
+ Returns:
+ tuple | None: The (collator class, kwargs) pair, or None if no plugin provides one.
+ """
+ for plugin in self.plugins.values():
+ collator = plugin.get_collator_cls_and_kwargs(cfg, is_eval=is_eval)
+ if collator is not None:
+ collator_cls, collator_kwargs = collator
+ return collator_cls, collator_kwargs
+ return None
+
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
"""Calls the `post_trainer_create` method of all registered plugins.
diff --git a/src/axolotl/integrations/config.py b/src/axolotl/integrations/config.py
index b443f228e..f5fc07e9e 100644
--- a/src/axolotl/integrations/config.py
+++ b/src/axolotl/integrations/config.py
@@ -16,7 +16,7 @@ Module to handle merging the plugins' input arguments with the base configuratio
This was moved here to prevent circular imports.
"""
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Type
from axolotl.utils.schemas.config import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
@@ -61,3 +61,43 @@ def merge_input_args():
]
return AxolotlConfigWCapabilities, AxolotlInputConfig
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
+
+
+def merge_training_args() -> Type:
+ """
+ Merges training arguments from registered plugins with the base TrainingArguments.
+
+ This function retrieves the training arguments from registered plugins using the PluginManager.
+ It then dynamically creates a new class, AxolotlTrainingMixins,
+ that inherits from the base mixin and the plugins' training args mixins.
+
+ Returns:
+ Type: The merged AxolotlTrainingMixins class.
+ """
+ # pylint: disable=duplicate-code
+ from axolotl.core.training_args_base import (
+ AxolotlTrainingMixins as AxolotlTrainingMixinsBase,
+ )
+ from axolotl.integrations.base import PluginManager
+
+ plugin_manager = PluginManager.get_instance()
+ training_args_mixins: List[str] = plugin_manager.get_training_args_mixin()
+ mixin_classes = []
+ dynamic_input = ""
+ for plugin_args in training_args_mixins:
+ plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
+ dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
+ mixin_classes.append(plugin_cls)
+ if dynamic_input:
+ dynamic_input += f"class AxolotlTrainingMixins(AxolotlTrainingMixinsBase, {', '.join(mixin_classes)}):\n pass\n"
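+ # e.g. with the KD plugin registered, dynamic_input becomes roughly:
+ # from axolotl.integrations.kd.args import KDTrainingArgsMixin
+ # class AxolotlTrainingMixins(AxolotlTrainingMixinsBase, KDTrainingArgsMixin):
+ #     pass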
+
+ namespace: Dict[Any, Any] = {}
+ local_vars = {"AxolotlTrainingMixinsBase": AxolotlTrainingMixinsBase}
+ exec( # pylint: disable=exec-used # nosec B102
+ dynamic_input, {**globals(), **local_vars}, namespace
+ )
+ AxolotlTrainingMixins = namespace[ # pylint: disable=invalid-name
+ "AxolotlTrainingMixins"
+ ]
+ return AxolotlTrainingMixins
+ return AxolotlTrainingMixinsBase
diff --git a/src/axolotl/integrations/kd/__init__.py b/src/axolotl/integrations/kd/__init__.py
index 8a6e3eda1..4c8535a0a 100644
--- a/src/axolotl/integrations/kd/__init__.py
+++ b/src/axolotl/integrations/kd/__init__.py
@@ -15,7 +15,12 @@
"""
Plugin init to add KD support to Axolotl.
"""
+from typing import Any
+
+from transformers import Trainer
+
from axolotl.integrations.base import BasePlugin
+from axolotl.integrations.kd.callbacks import KDTemperatureSchedulerCallback
from .args import KDArgs # pylint: disable=unused-import. # noqa: F401
@@ -28,9 +33,75 @@ class KDPlugin(BasePlugin):
def get_input_args(self):
return "axolotl.integrations.kd.KDArgs"
+ def get_training_args_mixin(self):
+ return "axolotl.integrations.kd.args.KDTrainingArgsMixin"
+
def get_trainer_cls(self, cfg):
if cfg.kd_trainer:
from .trainer import AxolotlKDTrainer
return AxolotlKDTrainer
return None
+
+ def get_training_args(self, cfg):
+ return {
+ "kd_ce_alpha": cfg.kd_ce_alpha,
+ "kd_alpha": cfg.kd_alpha,
+ "kd_temperature": cfg.kd_temperature,
+ "kd_beta": cfg.kd_beta,
+ "kd_normalize_topk": cfg.kd_normalize_topk,
+ }
+
+ def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
+ if not cfg.kd_trainer:
+ return None
+
+ from .collator import DataCollatorForKD, KDBatchSamplerDataCollatorForSeq2Seq
+
+ use_batch_sampler_collator = False
+ if is_eval is False and cfg.sample_packing:
+ use_batch_sampler_collator = True
+ if cfg.eval_sample_packing and is_eval:
+ use_batch_sampler_collator = True
+
+ if cfg.kd_online_server_base_url:
+ from .collator_online_teacher import OnlineTeacherCollator
+
+ return OnlineTeacherCollator, {
+ "kd_online_server_base_url": cfg.kd_online_server_base_url,
+ "kd_online_topk": cfg.kd_online_topk,
+ "kd_temperature": cfg.kd_temperature,
+ "kd_online_server": cfg.kd_online_server,
+ "kd_online_timeout": cfg.kd_online_timeout,
+ "kd_normalize_topk": cfg.kd_normalize_topk,
+ }
+
+ if use_batch_sampler_collator:
+ return KDBatchSamplerDataCollatorForSeq2Seq, {}
+ return DataCollatorForKD, {}
+
+ def pre_model_load(self, cfg):
+ from .kernels.models import apply_kernel
+
+ apply_kernel(cfg.model_config_type)
+
+ def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
+ """
+ Adds temp scheduler callback to the Trainer instance.
+
+ Args:
+ cfg (Any): Configuration object containing the KD settings.
+ trainer (Trainer): Huggingface Trainer instance.
+
+ Returns:
+ list: List containing the configured callback instances.
+ """
+ if cfg.kd_temperature_min is not None and cfg.kd_online_server_base_url:
+ callback = KDTemperatureSchedulerCallback(
+ cfg.kd_temperature,
+ cfg.kd_temperature_min,
+ trainer,
+ )
+ return [callback]
+
+ return []
diff --git a/src/axolotl/integrations/kd/args.py b/src/axolotl/integrations/kd/args.py
index 2fbba2c6a..758bc8917 100644
--- a/src/axolotl/integrations/kd/args.py
+++ b/src/axolotl/integrations/kd/args.py
@@ -15,9 +15,19 @@
"""
Plugin args for KD support.
"""
-from typing import Optional
+from dataclasses import dataclass
+from enum import Enum
-from pydantic import BaseModel
+from pydantic import BaseModel, Field
+
+
+class InferenceServerType(str, Enum):
+ """
+ Online inference server types to handle different request args
+ """
+
+ vllm = "vllm" # pylint: disable=invalid-name
+ sglang = "sglang" # pylint: disable=invalid-name
class KDArgs(BaseModel):
@@ -25,13 +35,41 @@ class KDArgs(BaseModel):
Input args for knowledge distillation.
"""
- kd_trainer: Optional[bool] = None # whether to use KD trainer
- kd_ce_alpha: Optional[float] = (
+ kd_trainer: bool | None = None # whether to use KD trainer
+ kd_ce_alpha: float | None = (
None # loss coefficient for cross-entropy loss during KD
)
- kd_alpha: Optional[float] = None # loss coefficient for KD loss
- kd_temperature: Optional[float] = None # temperature for sampling during KD
- kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
- kd_top_k_before_softmax: Optional[bool] = (
- None # whether to sample top k before softmax during KD
+ kd_alpha: float | None = None # loss coefficient for KD loss
+ kd_temperature: float | None = None # temperature for sampling during KD
+ kd_beta: float | None = 0.0 # beta coefficient for ratio of fwd and reverse KL
+ kd_normalize_topk: bool | None = (
+ None # whether to normalize student logits during KD
+ )
+
+ # TODO online kd
+ kd_online_server_base_url: str | None = None
+ kd_online_topk: int | None = None
+ kd_online_server: InferenceServerType | None = Field(
+ default_factory=lambda: InferenceServerType.vllm
+ )
+ kd_online_timeout: int | None = 120
+ kd_temperature_min: float | None = (
+ None # kd temperature scheduling during online kd
+ )
+
+
+@dataclass
+class KDTrainingArgsMixin:
+ """
+ Additional args for KD training.
+ """
+
+ kd_ce_alpha: float | None = (
+ None # loss coefficient for cross-entropy loss during KD
+ )
+ kd_alpha: float | None = None # loss coefficient for KD loss
+ kd_temperature: float | None = None # temperature for sampling during KD
+ kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
+ kd_normalize_topk: bool | None = (
+ None # whether to normalize student logits during KD
)
diff --git a/src/axolotl/integrations/kd/callbacks.py b/src/axolotl/integrations/kd/callbacks.py
new file mode 100644
index 000000000..911c3d517
--- /dev/null
+++ b/src/axolotl/integrations/kd/callbacks.py
@@ -0,0 +1,36 @@
+"""
+Transformers trainer callbacks to schedule the KD temperature during training
+"""
+
+import math
+
+from transformers.trainer_callback import TrainerCallback
+
+
+class KDTemperatureSchedulerCallback(TrainerCallback):
+ """
+ KD temperature scheduler callback for the trainer.
+ """
+
+ def __init__(self, temperature_start, temperature_min, trainer):
+ self.temperature_start = temperature_start
+ self.temperature_min = temperature_min
+ self.temperature = temperature_start
+
+ self.trainer = trainer
+
+ def on_step_end(
+ self, args, state, control, **kwargs
+ ): # pylint: disable=unused-argument
+ # cosine decay temperature over the max steps
+
+ progress = state.global_step / state.max_steps
+ # Cosine decay factor: 0.5 * (1 + cos(pi * progress))
+ # This factor goes from 1 (at progress=0) to 0 (at progress=1)
+ decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
+ self.temperature = self.temperature_start - (
+ (self.temperature_start - self.temperature_min) * (1.0 - decay_factor)
+ )
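+ # e.g. with temperature_start=2.0 and temperature_min=1.0 this yields 2.0 at
+ # progress=0, 1.5 at progress=0.5, and 1.0 at progress=1 (cosine decay)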
+
+ if hasattr(self.trainer.data_collator, "kd_temperature"):
+ self.trainer.data_collator.kd_temperature = self.temperature
diff --git a/src/axolotl/integrations/kd/chat_template.py b/src/axolotl/integrations/kd/chat_template.py
index 7c99a9c3d..f99dfe458 100644
--- a/src/axolotl/integrations/kd/chat_template.py
+++ b/src/axolotl/integrations/kd/chat_template.py
@@ -15,12 +15,15 @@
"""
Chat template prompt strategy loader with KD support
"""
+import logging
from typing import Any, Dict
import torch
from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader
+LOG = logging.getLogger(__name__)
+
class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
"""
@@ -101,10 +104,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
# fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
- # for causal models, if we start the range at 1, then we don't need to shift in the trainer
- # otherwise, we need to shift in the trainer
- shift = 0
- for _ in range(shift, input_padding_len):
+ # we shift for causal models in the trainer, so start the range from 0
+ for _ in range(0, input_padding_len):
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
@@ -143,6 +144,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
#
# Convert from log to probability
teacher_probs_t1 = position_logprobs_tensor.exp()
+ # normalize probabilities to sum to 1 in case they aren't already
+ teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True)
+ if teacher_probs_t1_sum > 1e-9:
+ teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum
if self.kd_temperature != self.gen_temperature:
# Exponentiate by factor (T1 / T2)
exponent = self.gen_temperature / self.kd_temperature
@@ -162,12 +167,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_logprobs.append(position_logprobs_scaled)
target_token_ids.append(position_token_ids)
- if shift == 1:
- # since we started at index 1 for causal, we need one more padding token
- target_logprobs.append([-float("inf")] * top_k)
- target_token_ids.append(list(range(top_k)))
- target_mask.append([0] * top_k)
-
# Update sample with transformed logprobs
sample["target_logprobs"] = target_logprobs
sample["target_token_ids"] = target_token_ids
@@ -184,6 +183,117 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
return tokenized_prompt
+class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
+ """
+ Strategy for datasets with complete, structured KD logprob data
+ """
+
+ def transform_logprobs(self, sample):
+ """
+ Transform logprobs to target format for KD training
+ """
+ # pylint: disable=duplicate-code
+
+ logprobs = sample.pop(self.logprobs_field)
+ target_seq_len = len(logprobs)
+ input_seq_len = len(sample["input_ids"])
+ input_padding_len = input_seq_len - target_seq_len
+ # get non-zero top-k (prune None logprobs from vllm data step)
+ top_k_vals = [
+ len(logprobs[i])
+ for i in range(len(logprobs))
+ if logprobs[i] is not None and len(logprobs[i])
+ ]
+ max_top_k = max(set(top_k_vals), key=top_k_vals.count)
+ min_top_k = min(set(top_k_vals), key=top_k_vals.count)
+ top_k = min(max_top_k, min_top_k)
+ if top_k == 0:
+ raise ValueError("No non-zero top-k logprobs found.")
+
+ target_logprobs = []
+ target_token_ids = []
+ target_mask = []
+
+ if input_padding_len < 0:
+ # logprobs is longer than the input sequence, so drop the extra
+ # entries from the left/beginning of logprobs
+ logprobs = logprobs[-input_seq_len:]
+ input_padding_len = 0
+ # target_seq_len = input_seq_len
+
+ # truncate the second dimension of the logprobs to top_k
+ logprobs = [row[:top_k] for row in logprobs]
+
+ # fill with -inf for padding_len tokens for top_k tokens
+ # extend target_logprobs with a padding_len x top_k 2D list filled with -inf
+
+ # we shift for causal models in the trainer, so start the range from 0
+ for _ in range(0, input_padding_len):
+ target_logprobs.append([-float("inf")] * top_k)
+ target_token_ids.append(list(range(top_k)))
+ target_mask.append([0] * top_k)
+
+ for position in range(input_padding_len, input_seq_len):
+ if sample["labels"][position] == -100:
+ target_mask.append([0] * top_k)
+ else:
+ target_mask.append([1] * top_k)
+
+ for token_pos_logprobs, pos_target_token_ids in zip(
+ logprobs, sample["target_token_ids"]
+ ):
+ # Convert to a tensor for easier manipulation
+ position_logprobs_tensor = torch.tensor(
+ token_pos_logprobs, dtype=torch.float
+ )
+
+ # Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
+ # Next, re-scale to T2 = self.kd_temperature via exponent-based trick
+ # p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
+ #
+ # Convert from log to probability
+ teacher_probs_t1 = position_logprobs_tensor.exp()
+ # normalize probabilities to sum to 1 in case they aren't already
+ teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True)
+ if teacher_probs_t1_sum > 1e-9:
+ teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum
+ if self.kd_temperature != self.gen_temperature:
+ # Exponentiate by factor (T1 / T2)
+ exponent = self.gen_temperature / self.kd_temperature
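+ # e.g. gen_temperature=1.0 with kd_temperature=2.0 gives exponent 0.5,
+ # which flattens (softens) the teacher distribution before re-normalizing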
+ teacher_probs_t2 = teacher_probs_t1**exponent
+ else:
+ teacher_probs_t2 = teacher_probs_t1
+ # Re-normalize
+ teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
+ dim=0, keepdim=True
+ )
+ # Convert back to log
+ position_logprobs_tensor = torch.log(teacher_probs_t2)
+
+ # Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
+ position_logprobs_scaled = position_logprobs_tensor.tolist()
+
+ target_logprobs.append(position_logprobs_scaled)
+ target_token_ids.append(pos_target_token_ids)
+
+ # Update sample with transformed logprobs
+ sample["target_logprobs"] = target_logprobs
+ sample["target_token_ids"] = target_token_ids
+ sample["target_mask"] = target_mask
+
+ return sample
+
+ def _tokenize_single_prompt(self, prompt):
+ logprobs = prompt.pop(self.logprobs_field)
+ target_token_ids = prompt.pop("target_token_ids")
+ tokenized_prompt = super()._tokenize_single_prompt(prompt)
+ tokenized_prompt[self.logprobs_field] = logprobs
+ tokenized_prompt["target_token_ids"] = target_token_ids
+ tokenized_prompt = self.transform_logprobs(tokenized_prompt)
+
+ return tokenized_prompt
+
+
class KDStrategyLoader(StrategyLoader):
"""
Load ChatTemplateStrategy with KD support using StrategyLoader.
@@ -204,4 +314,14 @@ class KDStrategyLoader(StrategyLoader):
return strategy_params
-load = KDStrategyLoader()
+class KDStrategyLoaderV2(KDStrategyLoader):
+ """
+ Load KD chat template datasets with pre-tokenized logprob data
+ """
+
+ def _get_strategy_cls(self, cfg): # pylint: disable=unused-argument
+ return ChatTemplateStrategyWithKDv2
+
+
+load_legacy = KDStrategyLoader()
+load = KDStrategyLoaderV2()
diff --git a/src/axolotl/integrations/kd/collator.py b/src/axolotl/integrations/kd/collator.py
index de63869c7..0cc745b78 100644
--- a/src/axolotl/integrations/kd/collator.py
+++ b/src/axolotl/integrations/kd/collator.py
@@ -47,11 +47,16 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
position_pad_token_id: int = 0
return_tensors: str = "pt"
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
+
def __call__(self, features, return_tensors=None):
if return_tensors is None:
return_tensors = self.return_tensors
padding_side = self.tokenizer.padding_side
+ max_len = 0
# Pad labels and position_ids first
for feature_name, pad_token_id in [
@@ -102,7 +107,9 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
target_mask_list.append(f.pop("target_mask"))
# Determine max lengths
- max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list)
+ max_teacher_seq_len = max_len or max(
+ len(seq) for seq in target_logprobs_list
+ )
max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq)
padded_target_logprobs = []
@@ -209,7 +216,9 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
# We want to produce a single "merged" feature dict for each sub-batch.
out_features = [{} for _ in features]
- for i, sub_features in enumerate(features):
+ for i, sub_features in enumerate( # pylint: disable=too-many-nested-blocks
+ features
+ ):
# sub_features is a list of dicts, each dict = one sequence’s features
# We'll merge them into out_features[i].
#
@@ -243,10 +252,17 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
# For example, input_ids or labels are often arrays.
arrays = []
for feat in sub_features:
- if field_name in feat:
+ if field_name in feat and isinstance(
+ feat[field_name], (list, torch.Tensor)
+ ):
+ if isinstance(
+ feat[field_name][0], (dict, str)
+ ): # pylint: disable=too-many-nested-blocks
+ continue
arr = np.array(feat[field_name])
arrays.append(arr)
- out_features[i][field_name] = np.concatenate(arrays)
+ if arrays:
+ out_features[i][field_name] = np.concatenate(arrays)
# 3) Now call the parent collator, which will do:
# - padding of labels/position_ids
diff --git a/src/axolotl/integrations/kd/collator_online_teacher.py b/src/axolotl/integrations/kd/collator_online_teacher.py
new file mode 100644
index 000000000..584ace481
--- /dev/null
+++ b/src/axolotl/integrations/kd/collator_online_teacher.py
@@ -0,0 +1,561 @@
+"""
+Packed data loader for online teacher training supporting vllm and sglang.
+"""
+
+import hashlib
+import hmac
+import logging
+from typing import Any, Dict, List, Optional
+
+import requests
+import torch
+from orjson import orjson
+
+from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq
+from axolotl.integrations.kd.utils import normalize_logprobs
+from axolotl.utils.data.utils import retry_on_request_exceptions
+
+LOG = logging.getLogger(__name__)
+
+
+def hmac_sha_from_int_list(int_list, key, hash_func=hashlib.sha256):
+ """
+ Create HMAC-SHA hash from a list of integers
+
+ Args:
+ int_list: List of integers
+ key: Secret key (string or bytes)
+ hash_func: Hash function (default: sha256)
+
+ Returns:
+ HMAC digest as hex string
+ """
+ # Convert key to bytes if it's a string
+ if isinstance(key, str):
+ key = key.encode("utf-8")
+
+ # Convert list of ints to bytes
+ # Method 1: Convert each int to bytes and concatenate
+ data = b"".join(i.to_bytes(4, byteorder="big") for i in int_list)
+
+ # Create HMAC
+ h = hmac.new(key, data, hash_func)
+ return h.hexdigest()
+
+
+class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
+ """
+ Collator for online teacher training.
+ """
+
+ DEFAULT_LABEL_PAD_TOKEN_ID: int = -100
+
+ def __init__(
+ self,
+ *args: Any,
+ kd_online_server_base_url: Optional[str] = None,
+ kd_online_topk: Optional[int] = None,
+ kd_temperature: Optional[float] = 1.0,
+ kd_online_server: Optional[str] = "vllm",
+ kd_online_timeout: Optional[int] = 120,
+ kd_cache_dir: Optional[str] = None,
+ kd_normalize_topk: Optional[bool] = True,
+ **kwargs: Any,
+ ):
+ super().__init__(*args, **kwargs)
+
+ if kd_online_server_base_url is None:
+ raise ValueError(
+ "kd_online_server_base_url must be provided for OnlineTeacherDataloader"
+ )
+ if kd_online_topk is None or kd_online_topk <= 0:
+ raise ValueError(
+ "kd_online_topk must be a positive integer for OnlineTeacherDataloader"
+ )
+
+ self.kd_online_server_base_url = kd_online_server_base_url.rstrip("/")
+ self.kd_online_topk = kd_online_topk
+ self.kd_temperature = kd_temperature
+ self.kd_online_server = kd_online_server
+ self.http_session = requests.Session()
+ self.kd_online_timeout = kd_online_timeout
+ self.kd_cache_dir = kd_cache_dir
+ self.kd_normalize_topk = kd_normalize_topk
+
+ def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]:
+ """
+ Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.
+ """
+ if not raw_logprobs or self.kd_online_topk == 0:
+ return (
+ [-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else []
+ )
+
+ raw_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32)
+ return normalize_logprobs(raw_logprobs_tensor, self.kd_online_topk).tolist()
+
+ @retry_on_request_exceptions(max_retries=10, delay=5)
+ def fetch_online_logprobs_sglang(
+ self, batch_input_ids: List[List[int]], labels: List[List[int]]
+ ):
+ """
+ Fetches logprobs from an online teacher served by sglang for a batch of input_ids.
+ Assumes API returns token IDs as strings in logprob dictionary keys.
+ """
+ api_endpoint = f"{self.kd_online_server_base_url}/generate"
+
+ payload = {
+ "input_ids": batch_input_ids,
+ "return_logprob": True,
+ "top_logprobs_num": self.kd_online_topk,
+ "logprob_start_len": 0,
+ "return_text_in_logprobs": True,
+ "echo": True,
+ "sampling_params": {
+ "max_new_tokens": 0,
+ "temperature": self.kd_temperature,
+ "skip_special_tokens": False,
+ },
+ }
+
+ # Initialize with empty lists, so if API call fails, these are returned.
+ ret_data_target_token_ids: List[List[List[int]]] = []
+ ret_data_target_logprobs: List[List[List[float]]] = []
+ ret_data_target_mask: List[List[List[int]]] = []
+
+ try:
+ response = self.http_session.post(
+ api_endpoint, json=payload, timeout=self.kd_online_timeout
+ )
+ response.raise_for_status()
+ api_data: list[dict] = response.json()
+
+ # Ensure api_data is a list, and its length matches batch_input_ids
+ if not isinstance(api_data, list) or len(api_data) != len(batch_input_ids):
+ LOG.error(
+ f"API response format error. Expected a list of {len(batch_input_ids)} "
+ f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}."
+ )
+ # Return empty data; items processed later will get default empty KD fields
+ return {
+ "target_token_ids": ret_data_target_token_ids,
+ "target_logprobs": ret_data_target_logprobs,
+ "target_mask": ret_data_target_mask,
+ }
+
+ for sequence_data, seq_input_ids, seq_labels in zip(
+ api_data, batch_input_ids, labels
+ ):
+ current_target_logprobs = []
+ current_target_token_ids = []
+ current_target_mask = []
+
+ meta_info = sequence_data.pop("meta_info", {})
+ # Ensure input_top_logprobs is a list
+ input_top_logprobs: Optional[list[None | list[tuple]]] = meta_info.pop(
+ "input_top_logprobs", []
+ )
+ if not isinstance(input_top_logprobs, list):
+ LOG.warning(
+ f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
+ )
+ input_top_logprobs = [] # Treat as empty
+
+ # basic check that the logprob data len matches the input len, so no need to handle padding
+ assert len(seq_input_ids) == len(input_top_logprobs)
+
+ for i, _, label in zip(
+ range(len(seq_input_ids)), seq_input_ids, seq_labels
+ ):
+ if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
+ # this is always the case for the first token.
+ # there is never logprob data for the first token since that's a true input
+ # so we replace the None value with padding data
+ current_target_logprobs.append(
+ [-float("inf")] * self.kd_online_topk
+ )
+ current_target_token_ids.append([0] * self.kd_online_topk)
+ current_target_mask.append([0] * self.kd_online_topk)
+ elif (
+ i < len(input_top_logprobs)
+ and input_top_logprobs[i] is not None
+ ):
+ pos_top_logprobs_data = input_top_logprobs[i]
+ # Ensure pos_top_logprobs_data is a list of lists as expected
+ if not (
+ isinstance(pos_top_logprobs_data, list)
+ and all(
+ isinstance(item, list) for item in pos_top_logprobs_data
+ )
+ and len(pos_top_logprobs_data) > 0
+ and len(pos_top_logprobs_data[0]) == 3
+ ): # [logprob, token_id, token_str]
+ LOG.warning(
+ f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position."
+ )
+ current_target_logprobs.append(
+ [-float("inf")] * self.kd_online_topk
+ )
+ current_target_token_ids.append([0] * self.kd_online_topk)
+ current_target_mask.append([0] * self.kd_online_topk)
+ continue
+
+ # pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids
+ pos_logprobs_raw, pos_token_ids, _ = [
+ list(row) for row in zip(*pos_top_logprobs_data)
+ ]
+
+ # Ensure correct length (top_k)
+ if len(pos_logprobs_raw) < self.kd_online_topk:
+ pad_len = self.kd_online_topk - len(pos_logprobs_raw)
+ pos_logprobs_raw.extend([-float("inf")] * pad_len)
+ pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
+
+ # truncate to top_k in case the response was longer
+ current_target_token_ids.append(
+ pos_token_ids[: self.kd_online_topk]
+ )
+
+ if self.kd_normalize_topk:
+ normalized_logprobs_for_position = self._normalize_logprobs(
+ pos_logprobs_raw[: self.kd_online_topk]
+ )
+ current_target_logprobs.append(
+ normalized_logprobs_for_position
+ )
+ else:
+ current_target_logprobs.append(
+ pos_logprobs_raw[: self.kd_online_topk]
+ )
+
+ # Mask depends on the corresponding label for the student
+ if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
+ current_target_mask.append([0] * self.kd_online_topk)
+ else:
+ current_target_mask.append([1] * self.kd_online_topk)
+ else:
+ # Pad if no logprobs for this position (either due to length mismatch or None entry)
+ current_target_logprobs.append(
+ [-float("inf")] * self.kd_online_topk
+ )
+ current_target_token_ids.append([0] * self.kd_online_topk)
+ current_target_mask.append([0] * self.kd_online_topk)
+
+ ret_data_target_token_ids.append(current_target_token_ids)
+ ret_data_target_logprobs.append(current_target_logprobs)
+ ret_data_target_mask.append(current_target_mask)
+
+ except requests.exceptions.RequestException as e:
+ LOG.error(f"Error fetching logprobs from online teacher: {e}")
+ raise e
+ # ret_logprobs_data will be returned with empty lists, handled by the caller.
+ except Exception as e: # Catch other potential errors during processing
+ LOG.error(
+ f"Unexpected error processing API response in fetch_online_logprobs: {e}",
+ exc_info=True,
+ )
+ raise e
+
+ return {
+ "target_token_ids": ret_data_target_token_ids,
+ "target_logprobs": ret_data_target_logprobs,
+ "target_mask": ret_data_target_mask,
+ }
+
+ @retry_on_request_exceptions(max_retries=10, delay=5)
+ def fetch_online_logprobs_vllm(
+ self, batch_input_ids: List[List[int]], labels: List[List[int]]
+ ):
+ """
+ Fetches logprobs from an online teacher served by vllm for a batch of input_ids.
+ Assumes API returns token IDs as strings in logprob dictionary keys.
+ """
+ api_endpoint = f"{self.kd_online_server_base_url}/v1/completions"
+
+ payload = {
+ "prompt": batch_input_ids,
+ "echo": True,
+ "logprobs": True,
+ "prompt_logprobs": self.kd_online_topk,
+ "top_logprobs": self.kd_online_topk,
+ "max_new_tokens": 0,
+ "skip_special_tokens": False,
+ "temperature": self.kd_temperature,
+ "sampling_params": {
+ "max_tokens": 0,
+ },
+ }
+
+ # Initialize with empty lists, so if API call fails, these are returned.
+ ret_data_target_token_ids: List[List[List[int]]] = []
+ ret_data_target_logprobs: List[List[List[float]]] = []
+ ret_data_target_mask: List[List[List[int]]] = []
+
+ try:
+ headers = {"Accept-Encoding": "deflate, gzip, br, zstd"}
+ response = self.http_session.post(
+ api_endpoint,
+ json=payload,
+ headers=headers,
+ timeout=self.kd_online_timeout,
+ )
+ response.raise_for_status()
+ api_data: dict = orjson.loads(response.content)
+ choices: list[dict] = api_data["choices"]
+
+ # Ensure api_data is a list, and its length matches batch_input_ids
+ if not isinstance(choices, list) or len(choices) != len(batch_input_ids):
+ LOG.error(
+ f"API response format error. Expected a list of {len(batch_input_ids)} "
+ f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}."
+ )
+ # Return empty data; items processed later will get default empty KD fields
+ return {
+ "target_token_ids": ret_data_target_token_ids,
+ "target_logprobs": ret_data_target_logprobs,
+ "target_mask": ret_data_target_mask,
+ }
+
+ for sequence_data, seq_input_ids, seq_labels in zip(
+ choices, batch_input_ids, labels
+ ):
+ # seq_input_ids: List[int]
+ # seq_labels: List[int]
+
+ current_target_logprobs = []
+ current_target_token_ids = []
+ current_target_mask = []
+
+ # Ensure input_top_logprobs is a list
+ input_top_logprobs: Optional[list[None | dict[str, dict]]] = (
+ sequence_data.pop("prompt_logprobs", [])
+ )
+
+ if not isinstance(input_top_logprobs, list):
+ LOG.warning(
+ f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
+ )
+ input_top_logprobs = [] # Treat as empty
+
+ # basic check that the logprob data len matches the input len, so no need to handle padding
+ assert len(seq_input_ids) == len(input_top_logprobs)
+
+ seq_len = len(seq_input_ids)
+
+ for i, _, label in zip(range(seq_len), seq_input_ids, seq_labels):
+ if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
+ # this is always the case for the first token.
+ # there is never logprob data for the first token since that's a true input
+ continue
+ if (
+ i < len(input_top_logprobs)
+ and input_top_logprobs[i] is not None
+ ):
+ pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i] # type: ignore[assignment]
+ # Ensure pos_top_logprobs_data is a dict of {token_id_str: logprob_dict} as expected
+ if not (
+ isinstance(pos_top_logprobs_data, dict)
+ and all(
+ isinstance(item, dict)
+ for item in pos_top_logprobs_data.values()
+ )
+ and len(pos_top_logprobs_data.keys()) > 0
+ ): # expects {token_id_str: {"logprob": ...}}
+ LOG.warning(
+ f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position."
+ )
+ current_target_logprobs.append(
+ [-float("inf")] * self.kd_online_topk
+ )
+ current_target_token_ids.append(
+ list(range(self.kd_online_topk))
+ )
+ current_target_mask.append([0] * self.kd_online_topk)
+ continue
+
+ # pos_token_ids: top-k token ids, pos_logprobs_raw: their corresponding logprobs
+ pos_token_ids_str = list(pos_top_logprobs_data.keys())
+ pos_logprobs_dict = pos_top_logprobs_data.values()
+ pos_token_ids = [
+ int(token_id) for token_id in pos_token_ids_str
+ ]
+ pos_logprobs_raw = [
+ float(logprob.get("logprob", -float("inf")))
+ for logprob in pos_logprobs_dict
+ ]
+
+ # Ensure correct length (top_k)
+ if len(pos_logprobs_raw) < self.kd_online_topk:
+ pad_len = self.kd_online_topk - len(pos_logprobs_raw)
+ LOG.warning(
+ f"Padding position {i} with {pad_len} top-k tokens and logprobs."
+ )
+ pos_logprobs_raw.extend([-float("inf")] * pad_len)
+ pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
+
+ # truncate to top_k in case the response was longer
+ current_target_token_ids.append(
+ pos_token_ids[: self.kd_online_topk]
+ )
+
+ if self.kd_normalize_topk:
+ normalized_logprobs_for_position = self._normalize_logprobs(
+ pos_logprobs_raw[: self.kd_online_topk]
+ )
+ current_target_logprobs.append(
+ normalized_logprobs_for_position
+ )
+ else:
+ current_target_logprobs.append(
+ pos_logprobs_raw[: self.kd_online_topk]
+ )
+
+ # Mask depends on the corresponding label for the student
+ if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
+ current_target_mask.append([0] * self.kd_online_topk)
+ else:
+ current_target_mask.append([1] * self.kd_online_topk)
+ else:
+ # Pad if no logprobs for this position (either due to length mismatch or None entry)
+ current_target_logprobs.append(
+ [-float("inf")] * self.kd_online_topk
+ )
+ current_target_token_ids.append(
+ list(range(self.kd_online_topk))
+ )
+ current_target_mask.append([0] * self.kd_online_topk)
+ # Pad any remaining positions so the KD targets match seq_len
+ for _ in range(max(0, seq_len - len(current_target_logprobs))):
+ current_target_logprobs.append(
+ [-float("inf")] * self.kd_online_topk
+ )
+ current_target_token_ids.append(list(range(self.kd_online_topk)))
+ current_target_mask.append([0] * self.kd_online_topk)
+
+ ret_data_target_token_ids.append(current_target_token_ids)
+ ret_data_target_logprobs.append(current_target_logprobs)
+ ret_data_target_mask.append(current_target_mask)
+
+ # TODO save and load targets to disk for caching for next epoch
+ # generate a hmac SHA256 hash over the list seq_input_ids and convert it to an int
+ # if self.kd_cache_dir:
+ # hash_input_ids = hmac_sha_from_int_list(
+ # seq_input_ids, f"{self.kd_online_server_base_url}:{self.kd_online_topk}"
+ # )
+ # with open(f"{self.kd_cache_dir}/{hash_input_ids}.parquet", "wb") as f:
+ # pd.DataFrame(ret_logprobs_data).to_parquet(f, index=False)
+
+ except requests.exceptions.RequestException as e:
+ LOG.error(f"Error fetching logprobs from online teacher: {e}")
+ # Re-raise so @retry_on_request_exceptions can retry the request.
+ raise e
+ except Exception as e: # Catch other potential errors during processing
+ LOG.error(
+ f"Unexpected error processing API response in fetch_online_logprobs: {e}",
+ exc_info=True,
+ )
+ raise e
+
+ return {
+ "target_token_ids": ret_data_target_token_ids,
+ "target_logprobs": ret_data_target_logprobs,
+ "target_mask": ret_data_target_mask,
+ }
+
+ def __call__(
+ self, features: List[List[Dict[str, Any]]], return_tensors: Optional[str] = None
+ ) -> Dict[str, Any]:
+ if not features:
+ return super().__call__(features, return_tensors=return_tensors)
+
+ for (
+ sub_batch_features
+ ) in features: # sub_batch_features is List[Dict[str, Any]]
+ if not sub_batch_features:
+ continue
+
+ input_ids_for_api_call: List[List[int]] = []
+ labels_for_api_call: List[List[int]] = []
+ # Store references to the original item dictionaries to update them in-place
+ items_for_api_call: List[Dict[str, Any]] = []
+
+ for item_dict in sub_batch_features:
+ if not isinstance(item_dict, dict):
+ LOG.warning(
+ f"Skipping non-dict item in sub_batch_features: {item_dict}"
+ )
+ continue
+
+ current_input_ids = item_dict.get("input_ids")
+ current_labels = item_dict.get("labels")
+
+ if current_input_ids is not None and current_labels is not None:
+ # Ensure input_ids and labels are lists of ints for JSON serialization
+ input_ids_list = (
+ current_input_ids.tolist()
+ if hasattr(current_input_ids, "tolist")
+ else list(current_input_ids)
+ )
+ labels_list = (
+ current_labels.tolist()
+ if hasattr(current_labels, "tolist")
+ else list(current_labels)
+ )
+
+ input_ids_for_api_call.append(input_ids_list)
+ labels_for_api_call.append(labels_list)
+ items_for_api_call.append(item_dict)
+ else:
+ # This item will not get teacher logprobs from the API.
+ # Initialize KD fields to empty lists so downstream collators handle them uniformly.
+ item_dict.setdefault("target_token_ids", [])
+ item_dict.setdefault("target_logprobs", [])
+ item_dict.setdefault("target_mask", [])
+
+ if items_for_api_call: # Only call API if there's something to process
+ if self.kd_online_server == "sglang":
+ api_responses_for_sub_batch = self.fetch_online_logprobs_sglang(
+ input_ids_for_api_call, labels_for_api_call
+ )
+ else:
+ api_responses_for_sub_batch = self.fetch_online_logprobs_vllm(
+ input_ids_for_api_call, labels_for_api_call
+ )
+
+ # api_responses_for_sub_batch has keys: "target_token_ids", "target_logprobs", "target_mask"
+ # Each value is a list, corresponding to items_for_api_call
+ for i, item_to_update in enumerate(items_for_api_call):
+ # TODO make sure to figure out which input in sub_batch_features to update the batch in the original `features` object so the super class can handle it properly.
+ if api_responses_for_sub_batch and i < len(
+ api_responses_for_sub_batch["target_token_ids"]
+ ): # Check bounds
+ assert len(
+ api_responses_for_sub_batch["target_token_ids"][i]
+ ) == len(item_to_update["input_ids"])
+ assert len(
+ api_responses_for_sub_batch["target_logprobs"][i]
+ ) == len(item_to_update["input_ids"])
+ assert len(
+ api_responses_for_sub_batch["target_mask"][i]
+ ) == len(item_to_update["labels"])
+ item_to_update["target_token_ids"] = (
+ api_responses_for_sub_batch["target_token_ids"][i]
+ )
+ item_to_update["target_logprobs"] = api_responses_for_sub_batch[
+ "target_logprobs"
+ ][i]
+ item_to_update["target_mask"] = api_responses_for_sub_batch[
+ "target_mask"
+ ][i]
+ else:
+ # API call failed for this item, or response was shorter than expected.
+ # Ensure KD fields are initialized as empty lists.
+ LOG.warning(
+ f" (index {i}), or API response was too short. "
+ f"API response keys: {list(api_responses_for_sub_batch.keys()) if api_responses_for_sub_batch else 'None'}"
+ )
+ item_to_update.setdefault("target_token_ids", [])
+ item_to_update.setdefault("target_logprobs", [])
+ item_to_update.setdefault("target_mask", [])
+
+ return super().__call__(features, return_tensors=return_tensors)
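For reference, a minimal sketch of how the collator above flattens one non-`None` `prompt_logprobs` entry into the fixed-width `target_token_ids` / `target_logprobs` / `target_mask` rows it attaches to each item, before the optional `kd_normalize_topk` renormalization. The function name and example values are illustrative only:

```python
import math
from typing import Dict, List, Tuple


def flatten_topk_entry(
    entry: Dict[str, dict], top_k: int, label_is_pad: bool
) -> Tuple[List[int], List[float], List[int]]:
    """One vllm-style entry ({token_id_str: {"logprob": ...}}) -> fixed-width
    (token_ids, logprobs, mask) lists, padded/truncated to top_k.
    None entries (the first prompt token) are skipped by the collator and
    padded at the end of the sequence instead."""
    token_ids = [int(tid) for tid in entry]
    logprobs = [float(v.get("logprob", -math.inf)) for v in entry.values()]

    if len(token_ids) < top_k:  # pad short rows with dummy token 0 / -inf
        pad = top_k - len(token_ids)
        token_ids += [0] * pad
        logprobs += [-math.inf] * pad

    mask = [0] * top_k if label_is_pad else [1] * top_k
    return token_ids[:top_k], logprobs[:top_k], mask


entry = {"1234": {"logprob": -0.1}, "42": {"logprob": -2.3}, "7": {"logprob": -4.0}}
print(flatten_topk_entry(entry, top_k=5, label_is_pad=False))
# ([1234, 42, 7, 0, 0], [-0.1, -2.3, -4.0, -inf, -inf], [1, 1, 1, 1, 1])
```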
diff --git a/src/axolotl/integrations/kd/kernels/__init__.py b/src/axolotl/integrations/kd/kernels/__init__.py
index e69de29bb..3f1144a45 100644
--- a/src/axolotl/integrations/kd/kernels/__init__.py
+++ b/src/axolotl/integrations/kd/kernels/__init__.py
@@ -0,0 +1,8 @@
+"""
+Liger Chunked loss optimizations module
+"""
+
+from .liger import LigerFusedLinearKLTopKLogprobLoss
+from .models import apply_kernel
+
+__all__ = ["LigerFusedLinearKLTopKLogprobLoss", "apply_kernel"]
diff --git a/src/axolotl/integrations/kd/kernels/liger.py b/src/axolotl/integrations/kd/kernels/liger.py
new file mode 100644
index 000000000..6356643c2
--- /dev/null
+++ b/src/axolotl/integrations/kd/kernels/liger.py
@@ -0,0 +1,485 @@
+"""
+Liger Kernels for Chunked Top-K Log-Prob Distillation
+"""
+
+import torch
+import torch.nn.functional as F
+from liger_kernel.chunked_loss.fused_linear_distillation import (
+ LigerFusedLinearDistillationBase,
+)
+
+from axolotl.integrations.kd.utils import normalize_logprobs
+
+
+class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
+ """
+ Chunked kl-div loss for top-k logprobs
+ """
+
+ @staticmethod
+ def distillation_loss_fn(
+ student_logits_temp_scaled: torch.Tensor, # [chunk_size, vocab_size], already temp-scaled
+ target_token_ids_chunk: torch.Tensor, # [chunk_size, top_k]
+ target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs
+ target_mask_chunk: torch.Tensor, # [chunk_size, top_k]
+ beta: float = 0.0,
+ normalize_topk: bool = True,
+ ) -> torch.Tensor:
+ """
+ Compute Top-K KL divergence loss for a chunk.
+ Args:
+ student_logits_temp_scaled: Student logits, scaled by temperature. Shape: (N, V).
+ target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K).
+ target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K).
+ target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K).
+ beta: Controls the type of KL divergence.
+ 0.0 for Forward KL (P_teacher || P_student).
+ 1.0 for Reverse KL (P_student || P_teacher).
+ 0.5 for Symmetric KL (average of Forward and Reverse).
+ normalize_topk: Whether to normalize the log probabilities
+ Returns:
+ Sum of KL divergence losses for the chunk.
+ """
+ topk = target_token_ids_chunk.shape[-1]
+ student_logits_temp_scaled = ( # [chunk_size, vocab_size]
+ student_logits_temp_scaled.float()
+ )
+ target_logprobs_chunk = target_logprobs_chunk.float()
+
+ # Gather student logits for the top-k teacher token IDs
+ # target_token_ids_chunk: [chunk_size, top_k]
+ # student_logits_topk_temp_scaled: [chunk_size, top_k]
+ student_logits_topk_temp_scaled = torch.gather(
+ student_logits_temp_scaled, dim=-1, index=target_token_ids_chunk
+ )
+
+ # Student log-probabilities for the gathered top-k tokens
+ student_lse = torch.logsumexp(
+ student_logits_temp_scaled, dim=-1, keepdim=True
+ ) # [chunk_size, 1]
+ student_logprobs_topk_temp_scaled = (
+ student_logits_topk_temp_scaled - student_lse
+ )
+
+ # we have the top-k student logprobs, normalize them
+ if normalize_topk:
+ student_logprobs_topk_temp_scaled = normalize_logprobs(
+ student_logprobs_topk_temp_scaled, topk
+ )
+
+ valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k]
+
+ student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask]
+ teacher_logprobs_valid = target_logprobs_chunk[valid_mask]
+
+ # Teacher probabilities P(y|x_teacher) from logprobs
+ # target_logprobs_valid are already normalized (log(softmax(teacher_logits/T)))
+ teacher_probs_valid = teacher_logprobs_valid.exp()
+ # Student probabilities P_student from log P_student
+ student_probs_topk_valid = student_logprobs_topk_valid.exp()
+
+
+ # KL divergence: sum(P_teacher * (log P_teacher - log P_student))
+ # = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student)
+ # The distillation loss is often formulated as -sum(P_teacher * log P_student)
+ # or as sum(P_teacher * (log_softmax_teacher - log_softmax_student))
+ # Here, target_logprobs_valid are log_softmax_teacher.
+ # student_logprobs_topk_valid are log_softmax_student (for the selected K indices).
+ if beta == 0.0: # Contribution from Forward KL
+ fwd_kl_per_token = teacher_probs_valid * (
+ teacher_logprobs_valid - student_logprobs_topk_valid
+ )
+ kd_loss = fwd_kl_per_token.sum()
+ elif beta == 1.0: # Contribution from Reverse KL
+ rev_kl_per_token = student_probs_topk_valid * (
+ student_logprobs_topk_valid - teacher_logprobs_valid
+ )
+ kd_loss = rev_kl_per_token.sum()
+ else:
+ # JSD - Jensen-Shannon Divergence / Symmetric
+ mean_probs = (
+ 1 - beta
+ ) * student_probs_topk_valid + beta * teacher_probs_valid
+ log_mean_probs = mean_probs.log()
+ student_kl = F.kl_div(
+ log_mean_probs,
+ student_logprobs_topk_valid,
+ reduction="sum",
+ log_target=True,
+ )
+ teacher_kl = F.kl_div(
+ log_mean_probs, teacher_logprobs_valid, reduction="sum", log_target=True
+ )
+ jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
+ kd_loss = jsd_loss
+
+ return kd_loss
+
+ @staticmethod
+ def _compute_loss_kl_topk(
+ student_input_chunk: torch.Tensor,
+ student_weight: torch.Tensor,
+ # Args for student_bias, target_token_ids_chunk etc. are passed to the lambda wrapped by grad_and_value
+ # or through `partial`. Let's make them explicit here for clarity.
+ target_token_ids_chunk: torch.Tensor,
+ target_logprobs_chunk: torch.Tensor,
+ target_mask_chunk: torch.Tensor,
+ target_chunk: torch.Tensor, # For hard loss (true labels)
+ student_bias: torch.Tensor = None, # This will be one of the grad targets
+ # Remaining params are fixed (non-differentiated) arguments supplied from `forward`
+ distillation_loss_fn=None,
+ ignore_index: int = -100,
+ weight_hard_loss: float = 0.5,
+ weight_soft_loss: float = 0.5,
+ compute_ce_loss: bool = True,
+ temperature: float = 1.0,
+ beta: float = 0.0,
+ normalize_topk: bool = True,
+ ):
+ # Compute student logits for the chunk from hidden states and LM head
+ # student_input_chunk: [chunk_size, hidden_dim]
+ # student_lm_head_weight: [vocab_size, hidden_dim]
+ # student_logits_chunk: [chunk_size, vocab_size]
+ student_logits_chunk = F.linear(
+ student_input_chunk, student_weight, student_bias
+ )
+
+ ce_loss = torch.tensor(
+ 0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype
+ )
+ if compute_ce_loss and weight_hard_loss > 0.0:
+ ce_loss = F.cross_entropy(
+ student_logits_chunk.view(-1, student_logits_chunk.shape[-1]),
+ target_chunk.view(-1),
+ reduction="sum",
+ ignore_index=ignore_index,
+ )
+
+ soft_loss = torch.tensor(
+ 0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype
+ )
+ if weight_soft_loss > 0.0:
+ student_logits_chunk_temp_scaled = student_logits_chunk / temperature
+
+ # Assuming student_weight.shape[0] (vocab_size) is adequate for target_token_ids_chunk.max()
+ # No explicit padding here; user must ensure vocab alignment or pre-pad student_weight.
+
+ soft_loss = distillation_loss_fn(
+ student_logits_chunk_temp_scaled,
+ target_token_ids_chunk,
+ target_logprobs_chunk,
+ target_mask_chunk,
+ beta=beta,
+ normalize_topk=normalize_topk,
+ )
+
+ return soft_loss, ce_loss
+
+ @classmethod
+ def forward(
+ cls,
+ ctx,
+ student_input: torch.Tensor, # [batch_size, seq_len, dim]
+ student_lm_head_weight: torch.Tensor, # [vocab_size, dim]
+ target_token_ids: torch.Tensor, # [batch_size, seq_len, top_k]
+ target_logprobs: torch.Tensor, # [batch_size, seq_len, top_k]
+ target_mask: torch.Tensor, # [batch_size, seq_len, top_k]
+ true_labels: torch.Tensor, # [batch_size, seq_len]
+ student_lm_head_bias: torch.Tensor = None,
+ weight_hard_loss: float = 0.5,
+ weight_soft_loss: float = 0.5,
+ ignore_index: int = -100,
+ temperature: float = 1.0,
+ beta: float = 0.0,
+ compiled: bool = False,
+ chunk_size: int = 1024,
+ compute_ce_loss: bool = True,
+ normalize_topk: bool = True,
+ ):
+ CHUNK_SIZE = chunk_size # pylint: disable=invalid-name
+ grad_weight_acc = torch.zeros_like(student_lm_head_weight)
+ grad_inputs_list = []
+ grad_bias_acc = (
+ torch.zeros_like(student_lm_head_bias)
+ if student_lm_head_bias is not None
+ else None
+ )
+ kd_loss_acc = torch.zeros(
+ (), device=student_input.device, dtype=student_input.dtype
+ )
+ ce_loss_acc = torch.zeros(
+ (), device=student_input.device, dtype=student_input.dtype
+ )
+
+ # This function will be what torch.func.grad_and_value differentiates.
+ # It takes student_input_chunk, student_weight (full), student_bias (full) as primals.
+ # Other necessary data (target_*, etc.) are passed as non-differentiable arguments.
+ def loss_fn_for_grad(
+ _student_input_chunk,
+ _student_lm_head_weight, # full weight
+ _student_lm_head_bias, # full bias
+ # Fixed arguments for a given chunk, not differentiated:
+ _target_token_ids_chunk,
+ _target_logprobs_chunk,
+ _target_mask_chunk,
+ _true_labels_chunk,
+ ):
+ return cls._compute_loss_kl_topk(
+ student_input_chunk=_student_input_chunk,
+ student_weight=_student_lm_head_weight,
+ target_token_ids_chunk=_target_token_ids_chunk,
+ target_logprobs_chunk=_target_logprobs_chunk,
+ target_mask_chunk=_target_mask_chunk,
+ target_chunk=_true_labels_chunk,
+ student_bias=_student_lm_head_bias,
+ distillation_loss_fn=cls.distillation_loss_fn,
+ ignore_index=ignore_index,
+ weight_hard_loss=weight_hard_loss,
+ weight_soft_loss=weight_soft_loss,
+ compute_ce_loss=compute_ce_loss,
+ temperature=temperature,
+ beta=beta,
+ normalize_topk=normalize_topk,
+ )
+
+ def accumulate_chunk_grads(
+ student_input_chunk_ac,
+ target_token_ids_chunk_ac,
+ target_logprobs_chunk_ac,
+ target_mask_chunk_ac,
+ true_labels_chunk_ac,
+ ):
+ # student_weight and student_bias are closed over from the outer scope (full tensors)
+ if student_lm_head_bias is not None:
+ (
+ (chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
+ (chunk_kd_loss, chunk_ce_loss),
+ ) = torch.func.grad_and_value(
+ loss_fn_for_grad, argnums=(0, 1, 2), has_aux=True
+ )(
+ student_input_chunk_ac,
+ student_lm_head_weight,
+ student_lm_head_bias, # primals
+ target_token_ids_chunk_ac,
+ target_logprobs_chunk_ac,
+ target_mask_chunk_ac,
+ true_labels_chunk_ac,
+ ) # non-primals
+ grad_bias_acc.add_(chunk_grad_bias)
+ else:
+ argnums_for_grad = (0, 1) # Differentiate wrt input_chunk, weight
+ (
+ (chunk_grad_input, chunk_grad_weight), # No grad for bias
+ (chunk_kd_loss, chunk_ce_loss),
+ ) = torch.func.grad_and_value(
+ loss_fn_for_grad, argnums=argnums_for_grad, has_aux=True
+ )(
+ student_input_chunk_ac,
+ student_lm_head_weight,
+ None, # Pass None for student_bias primal
+ target_token_ids_chunk_ac,
+ target_logprobs_chunk_ac,
+ target_mask_chunk_ac,
+ true_labels_chunk_ac,
+ )
+
+ grad_weight_acc.add_(chunk_grad_weight)
+ kd_loss_acc.add_(chunk_kd_loss)
+ ce_loss_acc.add_(chunk_ce_loss)
+
+ return chunk_grad_input
+
+ if compiled:
+ accumulate_chunk_grads_compiled = torch.compile(
+ accumulate_chunk_grads, dynamic=True, backend="inductor"
+ ) # dynamic=True often helpful
+ else:
+ accumulate_chunk_grads_compiled = accumulate_chunk_grads
+
+ # Use the same chunking logic as LigerFusedLinearDistillationBase.forward
+ B, N, D = student_input.shape # pylint: disable=invalid-name
+ K = target_token_ids.shape[-1] # pylint: disable=invalid-name
+
+ student_input_flat = student_input.reshape(-1, student_input.shape[-1])
+ target_token_ids_flat = target_token_ids.reshape(-1, target_token_ids.shape[-1])
+ target_logprobs_flat = target_logprobs.reshape(-1, target_logprobs.shape[-1])
+ target_mask_flat = target_mask.reshape(-1, target_mask.shape[-1])
+ # pad and shift for cross entropy loss
+ true_labels = torch.nn.functional.pad(true_labels, (0, 1), value=ignore_index)
+ true_labels_flat = true_labels[:, 1:].contiguous().view(-1)
+
+ num_chunks = max(1, student_input_flat.shape[0] // CHUNK_SIZE)
+
+ _student_input_chunks = torch.chunk(
+ student_input_flat, chunks=num_chunks, dim=0
+ )
+ _target_token_ids_chunks = torch.chunk(
+ target_token_ids_flat, chunks=num_chunks, dim=0
+ )
+ _target_logprobs_chunks = torch.chunk(
+ target_logprobs_flat, chunks=num_chunks, dim=0
+ )
+ _target_mask_chunks = torch.chunk(target_mask_flat, chunks=num_chunks, dim=0)
+ _true_labels_chunks = torch.chunk(true_labels_flat, chunks=num_chunks, dim=0)
+
+ for i in range(num_chunks):
+ grad_input_chunk = accumulate_chunk_grads_compiled(
+ _student_input_chunks[i],
+ _target_token_ids_chunks[i],
+ _target_logprobs_chunks[i],
+ _target_mask_chunks[i],
+ _true_labels_chunks[i],
+ )
+ grad_inputs_list.append(grad_input_chunk)
+
+ grad_inputs_combined = torch.cat(grad_inputs_list, dim=0)
+ ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)
+
+ # For matching None returns in backward for non-tensor/non-grad_requiring inputs
+ ctx.hyperparams_count = 9 # Corresponds to number of hyperparams after main tensors in fwd signature
+ ctx.bias_was_none = student_lm_head_bias is None
+ ctx.orig_dims = (B, N, D, K)
+
+ # with sample packing there is effectively a single batch, so the batchmean reduction of the
+ # kl-div is just the accumulated sum; the kd loss still needs the classical T^2 scaling
+ kd_loss_acc = kd_loss_acc * (temperature**2)
+ final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc
+
+ return final_loss
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_input_flat, grad_weight, grad_bias_maybe = (
+ ctx.saved_tensors
+ ) # grad_input_flat is (B*N, D)
+
+ # Scale gradients by grad_output if it's not 1.0
+ if not torch.equal(
+ grad_output,
+ torch.tensor(1.0, device=grad_output.device, dtype=grad_output.dtype),
+ ):
+ grad_input_flat = grad_input_flat * grad_output
+ grad_weight = grad_weight * grad_output
+ if grad_bias_maybe is not None:
+ grad_bias_maybe = grad_bias_maybe * grad_output
+
+ # Reshape grad_input_flat to match original student_input shape (B, N, D)
+ # ctx.orig_dims stores (B, N, D, K)
+ # We need the first three dimensions for student_input's shape.
+ # Ensure that orig_dims are not (0,0,0,K) for empty inputs leading to view errors
+ if (
+ ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0
+ and grad_input_flat.numel() == 0
+ ):
+ # If original input was empty, gradient should also be empty with correct shape
+ grad_input_reshaped = torch.zeros(
+ ctx.orig_dims[0],
+ ctx.orig_dims[1],
+ ctx.orig_dims[2],
+ dtype=grad_input_flat.dtype,
+ device=grad_input_flat.device,
+ )
+ elif grad_input_flat.numel() == 0 and not (
+ ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0
+ ):
+ # This case should ideally not happen if forward path is correct (non-empty input -> non-empty flat grad)
+ # but as a safeguard:
+ grad_input_reshaped = torch.zeros(
+ ctx.orig_dims[0],
+ ctx.orig_dims[1],
+ ctx.orig_dims[2],
+ dtype=grad_input_flat.dtype,
+ device=grad_input_flat.device,
+ )
+ else:
+ grad_input_reshaped = grad_input_flat.view(
+ ctx.orig_dims[0], ctx.orig_dims[1], ctx.orig_dims[2]
+ )
+
+ nones_for_hyperparams = [None] * ctx.hyperparams_count
+ grad_bias_return = grad_bias_maybe if not ctx.bias_was_none else None
+
+ return (
+ grad_input_reshaped, # Gradient for student_input (reshaped)
+ grad_weight, # Gradient for student_lm_head_weight
+ None, # Gradient for target_token_ids
+ None, # Gradient for target_logprobs
+ None, # Gradient for target_mask
+ None, # Gradient for true_labels
+ grad_bias_return, # Gradient for student_lm_head_bias
+ *nones_for_hyperparams, # Grads for weight_hard_loss, ..., normalize_topk
+ )
+
+
+class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
+ """
+ Module wrapper for the chunked top-k logprob KL-divergence loss
+ """
+
+ def __init__(
+ self,
+ weight_hard_loss: float = 0.5,
+ weight_soft_loss: float = 0.5,
+ temperature: float = 1.0, # This is the kd_temperature
+ beta: float = 1.0,
+ ignore_index: int = -100,
+ compiled: bool = True,
+ chunk_size: int = 1024,
+ compute_ce_loss: bool = True,
+ normalize_topk: bool = True,
+ ):
+ super().__init__()
+ if not (0.0 <= weight_hard_loss <= 1.0 and 0.0 <= weight_soft_loss <= 1.0):
+ raise ValueError("Loss weights must be between 0.0 and 1.0.")
+ if temperature <= 0:
+ raise ValueError("Temperature must be positive.")
+
+ self.weight_hard_loss = weight_hard_loss
+ self.weight_soft_loss = weight_soft_loss
+ self.temperature = temperature
+ self.beta = beta
+ self.ignore_index = ignore_index
+ self.compiled = compiled
+ self.chunk_size = chunk_size
+ self.compute_ce_loss = compute_ce_loss
+ self.normalize_topk = normalize_topk
+
+ if not self.compute_ce_loss and self.weight_hard_loss > 0.0:
+ print(
+ f"Warning: compute_ce_loss is False, but weight_hard_loss ({self.weight_hard_loss}) > 0. Hard loss will effectively be zero."
+ )
+ # self.weight_hard_loss = 0.0 # Or let user manage this
+ if self.weight_soft_loss == 0.0:
+ print(
+ "Warning: weight_soft_loss is 0.0. Soft (KD) loss will not be computed."
+ )
+
+ def forward(
+ self,
+ lm_head_weight: torch.Tensor, # Weights of the linear layer in the LM head
+ student_hidden_states: torch.Tensor, # student_hidden_states before the lm_head
+ target_token_ids: torch.Tensor,
+ target_logprobs: torch.Tensor,
+ target_mask: torch.Tensor,
+ true_labels: torch.Tensor,
+ student_bias: torch.Tensor = None,
+ ) -> torch.Tensor:
+ return LigerFusedLinearKLTopKLogprobFunction.apply(
+ student_hidden_states,
+ lm_head_weight,
+ target_token_ids,
+ target_logprobs,
+ target_mask,
+ true_labels,
+ student_bias,
+ self.weight_hard_loss,
+ self.weight_soft_loss,
+ self.ignore_index,
+ self.temperature,
+ self.beta,
+ self.compiled,
+ self.chunk_size,
+ self.compute_ce_loss,
+ self.normalize_topk,
+ )
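As a quick sanity check of the interface introduced above, a minimal sketch of driving `LigerFusedLinearKLTopKLogprobLoss` directly with toy shapes (everything except the import path is illustrative; requires `liger-kernel` to be installed):

```python
import torch
import torch.nn.functional as F

from axolotl.integrations.kd.kernels.liger import LigerFusedLinearKLTopKLogprobLoss

B, T, D, V, K = 2, 16, 64, 128, 8  # toy batch, seq len, hidden, vocab, top-k

lm_head_weight = torch.randn(V, D, requires_grad=True)
hidden_states = torch.randn(B, T, D, requires_grad=True)

# fake "teacher" targets: top-k token ids, normalized logprobs, and an all-valid mask
target_token_ids = torch.randint(0, V, (B, T, K))
target_logprobs = F.log_softmax(torch.randn(B, T, K), dim=-1)
target_mask = torch.ones(B, T, K, dtype=torch.long)
true_labels = torch.randint(0, V, (B, T))

loss_fn = LigerFusedLinearKLTopKLogprobLoss(
    weight_hard_loss=0.5,   # CE weight (kd_ce_alpha)
    weight_soft_loss=0.5,   # KD weight (kd_alpha)
    temperature=1.0,
    beta=0.0,               # 0.0 forward KL, 1.0 reverse KL, otherwise a JSD-style mix
    compiled=False,         # skip torch.compile for the toy run
    chunk_size=8,
)
loss = loss_fn(
    lm_head_weight, hidden_states, target_token_ids,
    target_logprobs, target_mask, true_labels,
)
loss.backward()
print(loss.item(), hidden_states.grad.shape, lm_head_weight.grad.shape)
```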
diff --git a/src/axolotl/integrations/kd/kernels/models.py b/src/axolotl/integrations/kd/kernels/models.py
new file mode 100644
index 000000000..bfd752964
--- /dev/null
+++ b/src/axolotl/integrations/kd/kernels/models.py
@@ -0,0 +1,98 @@
+"""
+model patcher for chunked top-k kl-div
+"""
+
+from types import MethodType
+from typing import Optional, Union, Unpack
+
+import torch
+from transformers import Cache
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_outputs import CausalLMOutputWithPast
+from transformers.utils import LossKwargs
+
+
+class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
+ """
+ placeholder kwargs for hf model classes
+ """
+
+
+def kldiv_forward_llama_like(
+ self,
+ input_ids: Optional[torch.LongTensor] = None,
+ target_logprobs: Optional[torch.Tensor] = None,
+ target_token_ids: Optional[torch.LongTensor] = None,
+ target_mask: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[Cache] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ logits_to_keep: Union[int, torch.Tensor] = 0, # pylint: disable=unused-argument
+ **kwargs: Unpack[KwargsForCausalLM], # type: ignore[misc]
+) -> CausalLMOutputWithPast:
+ # pylint: disable=duplicate-code
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.config.output_hidden_states
+ )
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ cache_position=cache_position,
+ **kwargs,
+ )
+
+ hidden_states = outputs.last_hidden_state
+
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+ # TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
+ # self.loss_function should be LigerFusedLinearKLTopKLogprobLoss
+
+ loss = self.loss_function(
+ self.lm_head.weight,
+ hidden_states,
+ target_token_ids,
+ target_logprobs,
+ target_mask,
+ true_labels=labels,
+ )
+ num_items_in_batch = kwargs.pop("num_items_in_batch", -1)
+ if num_items_in_batch is not None and num_items_in_batch > 0:
+ loss = loss / num_items_in_batch
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=None,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+
+def apply_kernel(model_type):
+ # Dynamically import the module and attention class
+ module_path = f"transformers.models.{model_type}.modeling_{model_type}"
+ model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")])
+ module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
+ model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
+ # assign the plain function so it binds to each model instance like a normal method
+ model_cls.forward = kldiv_forward_llama_like
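A rough end-to-end sketch of what `apply_kernel` plus the trainer wiring amount to, on a tiny randomly initialized model. The model choice, config sizes, and manual `_loss_function` assignment are illustrative, and a recent transformers release that resolves `model.loss_function` from `_loss_function` is assumed:

```python
import torch
import torch.nn.functional as F
from transformers import LlamaConfig, LlamaForCausalLM

from axolotl.integrations.kd.kernels import LigerFusedLinearKLTopKLogprobLoss, apply_kernel

apply_kernel("llama")  # patches LlamaForCausalLM.forward with kldiv_forward_llama_like

config = LlamaConfig(
    vocab_size=128, hidden_size=64, intermediate_size=128,
    num_hidden_layers=2, num_attention_heads=4, num_key_value_heads=4,
)
model = LlamaForCausalLM(config)
# the KD trainer does this in __init__; done by hand here (hard, soft, T, beta)
model._loss_function = LigerFusedLinearKLTopKLogprobLoss(0.5, 0.5, 1.0, 0.0, compiled=False)

B, T, K = 2, 8, 4
input_ids = torch.randint(0, 128, (B, T))
target_token_ids = torch.randint(0, 128, (B, T, K))
target_logprobs = F.log_softmax(torch.randn(B, T, K), dim=-1)
target_mask = torch.ones(B, T, K, dtype=torch.long)

out = model(
    input_ids=input_ids,
    labels=input_ids.clone(),
    target_token_ids=target_token_ids,
    target_logprobs=target_logprobs,
    target_mask=target_mask,
)
print(out.loss, out.logits)  # logits stay None; only the combined KD loss is materialized
```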
diff --git a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
index 3c9515091..74184455f 100644
--- a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
+++ b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py
@@ -16,40 +16,7 @@
loss for top_k KL divergence
"""
import torch
-
-
-def zscore_standardize(
- logits: torch.Tensor,
- mask: torch.Tensor = None,
- base_temperature: float = 1.0,
- eps: float = 1e-9,
-):
- """
- Z-score standardize along the last dimension of `logits`.
- i.e., for each [B, seq_len] row, across K entries:
- z = (logits - mean) / std,
- then scale by 1 / base_temperature if desired.
-
- mask can be broadcastable or None. If None, we standardize all elements.
- """
- if mask is None:
- # shape: [B, seq_len, K]
- # Mean and std over dim=-1
- mean = logits.mean(dim=-1, keepdim=True)
- var = logits.var(dim=-1, unbiased=False, keepdim=True)
- else:
- # If you have to exclude some tokens, multiply by mask, etc.
- float_mask = mask.to(logits.dtype)
- count = float_mask.sum(dim=-1, keepdim=True).clamp_min(1.0)
- mean = (logits * float_mask).sum(dim=-1, keepdim=True) / count
- var = (float_mask * (logits - mean) ** 2).sum(dim=-1, keepdim=True) / count
-
- std = torch.sqrt(var.clamp_min(eps))
- z = (logits - mean) / std
-
- # Scale by 1 / base_temperature
- z = z / base_temperature
- return z
+from torch import nn
@torch.jit.script
@@ -60,7 +27,6 @@ def loss(
target_mask: torch.Tensor,
num_items_in_batch: int = -1, # Use -1 to indicate "None"
kd_temperature: float = 1.0,
- top_k_before_softmax: int = 0,
) -> torch.Tensor:
"""
A KD loss function that is TorchScript-friendly.
@@ -77,8 +43,6 @@ def loss(
num_items_in_batch (int, optional): The number of items in the batch.
kd_temperature (float, optional): The temperature for KD.
Default: 1.0
- top_k_before_softmax (int, optional): Flag of whether to apply softmax before gathering student top-k logits
- Default: 0
"""
target_logprobs = target_logprobs.float()
@@ -88,46 +52,24 @@ def loss(
# student_logits shape: [B, student_seq_len, vocab_size]
teacher_seq_len = target_token_ids.shape[1]
- if top_k_before_softmax:
- # Slice student logits to match teacher-provided sequence length
- student_logits_for_kd = student_logits[
- :, :teacher_seq_len, :
- ] # [B, teacher_seq_len, vocab_size]
+ # Slice student logits to match teacher-provided sequence length
+ student_logits_for_kd = (
+ student_logits[:, :teacher_seq_len, :] / kd_temperature
+ ) # [B, teacher_seq_len, vocab_size]
- # Gather student logits for teacher's top-K tokens
- student_logits_topk = torch.gather(
- student_logits_for_kd, dim=-1, index=target_token_ids
- ) # [B, teacher_seq_len, K]
+ # keep in full precision for numerical stability of loss
+ student_logits_for_kd = student_logits_for_kd.float()
- student_logits_topk = student_logits_topk.float()
+ # Gather student logits for teacher's top-K tokens
+ student_logits_topk = torch.gather(
+ student_logits_for_kd, dim=-1, index=target_token_ids
+ ) # [B, teacher_seq_len, K]
- # Apply KD temperature to student’s logits
- if kd_temperature != 1.0:
- student_logits_topk = student_logits_topk / kd_temperature
+ # Compute logsumexp across full vocabulary
+ student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
- # Convert student top-k logits to logprobs
- student_logprobs_topk = student_logits_topk - torch.logsumexp(
- student_logits_topk, dim=-1, keepdim=True
- ) # [B, teacher_seq_len, K]
- else:
- # Slice student logits to match teacher-provided sequence length
- student_logits_for_kd = (
- student_logits[:, :teacher_seq_len, :] / kd_temperature
- ) # [B, teacher_seq_len, vocab_size]
-
- # keep in full precision for numerical stability of loss
- student_logits_for_kd = student_logits_for_kd.float()
-
- # Gather student logits for teacher's top-K tokens
- student_logits_topk = torch.gather(
- student_logits_for_kd, dim=-1, index=target_token_ids
- ) # [B, teacher_seq_len, K]
-
- # Compute logsumexp across full vocabulary
- student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
-
- # Convert just the top-k logits to logprobs
- student_logprobs_topk = student_logits_topk - student_lse
+ # Convert just the top-k logits to logprobs
+ student_logprobs_topk = student_logits_topk - student_lse
# Convert teacher_mask to boolean for indexing
# In TorchScript, .bool() is sometimes unsupported, so we do:
@@ -144,10 +86,6 @@ def loss(
kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
kd_loss = kd_loss_per_token.sum()
- # Multiply by T^2 (classical KD scaling)
- if kd_temperature != 1.0:
- kd_loss = kd_loss * (kd_temperature**2)
-
# Normalize by number of items (if provided) or by valid tokens
if num_items_in_batch > 0:
kd_loss = kd_loss / float(num_items_in_batch)
@@ -158,80 +96,74 @@ def loss(
return kd_loss
-def topk_kd_loss_with_zscore(
- student_logits: torch.Tensor, # [B, seq_len, vocab_size]
- target_token_ids: torch.Tensor, # [B, seq_len, K]
- target_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space
- target_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len]
- kd_temperature: float = 1.0, # classic KD temperature
- zscore_base_temp: float = 1.0, # from the paper
- num_items_in_batch: int = -1,
-):
+class ChunkedTopKKDLoss(nn.Module):
"""
- A variant of top_k KL divergence with Z-score scaling
- from "Logit Standardization in Knowledge Distillation".
+ A wrapper that chunks (splits) the student and teacher outputs along the time dimension
+ to reduce peak memory usage when upcasting from bf16 to fp32, especially for large vocabularies.
+
+ Usage is analogous to ForwardKLWithChunkedOutputLoss but adapted to top-K teacher logprobs.
"""
- target_logprobs = target_logprobs.float()
+ def __init__(self, num_output_chunks: int = 8, kd_temperature: float = 1.0):
+ super().__init__()
+ self.num_output_chunks = num_output_chunks
+ self.kd_temperature = kd_temperature
- B, teacher_seq_len, K = target_logprobs.shape # pylint: disable=invalid-name
- # 1) Gather the student's top-k logits to match teacher
- student_logits_for_kd = student_logits[
- :, :teacher_seq_len, :
- ] # [B, seq_len, vocab]
- student_topk_logits = torch.gather(
- student_logits_for_kd, dim=-1, index=target_token_ids
- ) # [B, seq_len, K]
+ def forward(
+ self,
+ student_logits: torch.Tensor, # [B, seq_len, vocab_size]
+ target_token_ids: torch.Tensor, # [B, seq_len, K]
+ target_logprobs: torch.Tensor, # [B, seq_len, K]
+ target_mask: torch.Tensor, # [B, seq_len, K]
+ num_items_in_batch: int = -1, # optional batch size for normalization
+ ) -> torch.Tensor:
- student_topk_logits = student_topk_logits.float()
+ # 1. Split along the "token" dimension (dim=1).
+ student_logits_chunks = student_logits.chunk(self.num_output_chunks, dim=1)
+ token_ids_chunks = target_token_ids.chunk(self.num_output_chunks, dim=1)
+ logprobs_chunks = target_logprobs.chunk(self.num_output_chunks, dim=1)
+ mask_chunks = target_mask.chunk(self.num_output_chunks, dim=1)
- # 2) If you want to keep the "classical" T scaling, apply it first
- if kd_temperature != 1.0:
- student_topk_logits = student_topk_logits / kd_temperature
+ # We'll accumulate a global "sum of losses" and "sum of valid tokens"
+ # so that our final average is consistent with the entire sequence/batch.
+ total_loss = 0.0
+ total_valid_tokens = 0
- # 3) Convert teacher logprobs -> treat them as “logits” for z-score
- # (They differ by +some_constant from real logits, but in z-score
- # that constant is subtracted out anyway.)
- teacher_logits_for_zscore = target_logprobs # rename variable for clarity
+ # 2. Loop over each chunk and compute a chunk-specific loss.
+ for st_chunk, tid_chunk, lp_chunk, msk_chunk in zip(
+ student_logits_chunks, token_ids_chunks, logprobs_chunks, mask_chunks
+ ):
+ # We pass num_items_in_batch=-1 so that the kd_loss
+ # will average over *this chunk's* valid tokens only.
+ chunk_loss = loss(
+ student_logits=st_chunk,
+ target_token_ids=tid_chunk,
+ target_logprobs=lp_chunk,
+ target_mask=msk_chunk,
+ num_items_in_batch=-1, # ensure per-chunk averaging by valid tokens
+ kd_temperature=self.kd_temperature,
+ )
- # 4) Z-score teacher and student
- # If target_mask is 2D, expand to 3D for the K dimension
- if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len):
- target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K)
+ # kd_loss returns an average over the chunk's valid tokens.
+ # We want a global average in the end, so we need to re-weight
+ # by the number of valid tokens in this chunk and keep track of the total.
+ chunk_valid_mask = msk_chunk.to(torch.bool)
+ chunk_valid_count = chunk_valid_mask.sum() # scalar tensor
- teacher_z = zscore_standardize(
- teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp
- )
- student_z = zscore_standardize(
- student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp
- )
+ # Re-scale "chunk average" back to "chunk sum"
+ chunk_loss_sum = chunk_loss * chunk_valid_count
- # 5) Convert to log-probs for KL
- teacher_logprobs_z = teacher_z - torch.logsumexp(teacher_z, dim=-1, keepdim=True)
- student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)
+ total_loss += chunk_loss_sum
+ total_valid_tokens += chunk_valid_count
- # 6) Restrict to valid tokens if needed
- valid_mask = target_mask.bool() # shape [B, seq_len, K]
- teacher_probs_z = teacher_logprobs_z.exp()
- teacher_probs_z = teacher_probs_z[valid_mask]
- teacher_logprobs_z = teacher_logprobs_z[valid_mask]
- student_logprobs_z = student_logprobs_z[valid_mask]
+ # 3. Normalize *once* at the end.
+ if num_items_in_batch > 0:
+ # If the user gave us a manual denominator (e.g. total items in batch),
+ # we divide by it. Typically used if each item is of different length.
+ final_loss = total_loss / float(num_items_in_batch)
+ else:
+ # Otherwise, divide by the total valid tokens across all chunks
+ # to get the same result as a non-chunked approach.
+ final_loss = total_loss / float(total_valid_tokens)
- # 7) forward KL: sum( p_teacher * [log(p_teacher) - log(p_student)] )
- kd_loss_per_token = teacher_probs_z * (teacher_logprobs_z - student_logprobs_z)
- kd_loss = kd_loss_per_token.sum()
-
- # 8) If using classical KD scaling by T^2
- if kd_temperature != 1.0:
- kd_loss = kd_loss * (kd_temperature**2)
-
- # Optionally scale by zscore_base_temp**2 if you want (paper might differ).
- # kd_loss = kd_loss * (zscore_base_temp**2)
-
- # 9) Normalize
- if num_items_in_batch is not None and num_items_in_batch > 0:
- kd_loss = kd_loss / float(num_items_in_batch)
- else:
- kd_loss = kd_loss / float(kd_loss_per_token.size(0))
-
- return kd_loss
+ return final_loss
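A small sketch of the chunked wrapper on toy tensors; per the comments above, with `num_items_in_batch` left at -1 it should agree with the unchunked `loss` up to floating point error (shapes and values are illustrative):

```python
import torch
import torch.nn.functional as F

from axolotl.integrations.kd.topk_logprob.forward_kl import ChunkedTopKKDLoss, loss

B, T, V, K = 2, 32, 512, 8
student_logits = torch.randn(B, T, V)
target_token_ids = torch.randint(0, V, (B, T, K))
target_logprobs = F.log_softmax(torch.randn(B, T, K), dim=-1)
target_mask = torch.ones(B, T, K, dtype=torch.long)

chunked = ChunkedTopKKDLoss(num_output_chunks=4, kd_temperature=1.0)
chunked_loss = chunked(student_logits, target_token_ids, target_logprobs, target_mask)

full_loss = loss(student_logits, target_token_ids, target_logprobs, target_mask)
print(chunked_loss.item(), full_loss.item())  # should match up to fp error
```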
diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py
index f99f2ca28..7ec43333a 100644
--- a/src/axolotl/integrations/kd/trainer.py
+++ b/src/axolotl/integrations/kd/trainer.py
@@ -18,8 +18,7 @@ KD trainer
from axolotl.core.trainers.base import AxolotlTrainer
-from .topk_logprob.forward_kl import loss as topk_kd_loss
-from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
+from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
class AxolotlKDTrainer(AxolotlTrainer):
@@ -27,6 +26,18 @@ class AxolotlKDTrainer(AxolotlTrainer):
Custom trainer subclass for Knowledge Distillation (KD)
"""
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.model_accepts_loss_kwargs = True
+ self.model._loss_function = LigerFusedLinearKLTopKLogprobLoss(
+ self.args.kd_ce_alpha, # hard label loss
+ self.args.kd_alpha, # kd loss
+ self.args.kd_temperature,
+ self.args.kd_beta or 0.0,
+ compute_ce_loss=bool(self.args.kd_ce_alpha),
+ normalize_topk=self.args.kd_normalize_topk,
+ )
+
def _set_signature_columns_if_needed(self):
super()._set_signature_columns_if_needed()
columns_to_add = []
@@ -52,12 +63,12 @@ class AxolotlKDTrainer(AxolotlTrainer):
Subclass and override for custom behavior.
"""
-
- target_logprobs = inputs.pop("target_logprobs")
- target_token_ids = inputs.pop("target_token_ids")
- target_mask = inputs.pop("target_mask")
-
- seq_len = target_token_ids.shape[1]
+ if (
+ self.args.sample_packing
+ and "attention_mask" in inputs
+ and "position_ids" in inputs
+ ):
+ del inputs["attention_mask"]
if self.model_accepts_loss_kwargs:
loss_kwargs = {}
@@ -65,49 +76,4 @@ class AxolotlKDTrainer(AxolotlTrainer):
loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs}
outputs = model(**inputs)
-
- # FIXME: account for tokenizer.padding_side
- student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous()
-
- shift_logits = student_logits.contiguous()
- target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
- target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
- target_mask_for_loss = target_mask[..., 1:, :].contiguous()
-
- if self.args.kd_zscore_base_temp:
- loss_kd = topk_kd_loss_with_zscore(
- shift_logits,
- target_token_ids_for_loss,
- target_logprobs_for_loss,
- target_mask_for_loss,
- kd_temperature=self.args.kd_temperature,
- zscore_base_temp=self.args.kd_zscore_base_temp,
- num_items_in_batch=num_items_in_batch,
- )
- else:
- loss_kd = topk_kd_loss(
- shift_logits,
- target_token_ids_for_loss,
- target_logprobs_for_loss,
- target_mask_for_loss,
- num_items_in_batch=num_items_in_batch,
- kd_temperature=self.args.kd_temperature,
- top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
- )
-
- if self.args.kd_ce_alpha > 0:
- kd_alpha = self.args.kd_alpha
- loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
- else:
- loss = loss_kd
- # Save past state if it exists
- # TODO: this needs to be fixed and made cleaner later.
- if self.args.past_index >= 0:
- self._past = outputs[ # pylint: disable=attribute-defined-outside-init
- self.args.past_index
- ]
-
- if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
- loss *= self.accelerator.num_processes
-
- return (loss, outputs) if return_outputs else loss
+ return outputs[0]
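For reference, the scalar that now comes back as `outputs[0]` corresponds to the objective assembled by the fused kernel, divided by `num_items_in_batch` in the patched model forward (forward-KL case, beta = 0):

```latex
\mathcal{L} \;=\;
  \alpha_{\mathrm{kd}}\, T^{2} \sum_{(t,k)\,:\,m_{t,k}=1}
    p^{\mathcal{T}}_{t,k}\left(\log p^{\mathcal{T}}_{t,k} - \log p^{\mathcal{S}}_{t,k}\right)
  \;+\;
  \alpha_{\mathrm{ce}} \sum_{t\,:\,y_t \neq \mathrm{ignore}} -\log q^{\mathcal{S}}_{t}(y_t)
```

where $\alpha_{\mathrm{ce}}$ is `kd_ce_alpha` (hard-label weight), $\alpha_{\mathrm{kd}}$ is `kd_alpha`, $T$ is `kd_temperature`, $p^{\mathcal{T}}_{t,k}$ are the teacher's top-k probabilities, $p^{\mathcal{S}}_{t,k}$ are the student's temperature-scaled probabilities at the teacher's top-k token ids (renormalized over the top-k when `kd_normalize_topk` is set), and $q^{\mathcal{S}}_{t}(y_t)$ is the student's unscaled softmax probability of the true label.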
diff --git a/src/axolotl/integrations/kd/utils.py b/src/axolotl/integrations/kd/utils.py
new file mode 100644
index 000000000..ba60694a5
--- /dev/null
+++ b/src/axolotl/integrations/kd/utils.py
@@ -0,0 +1,100 @@
+"""Helper KD utils"""
+
+import math
+from typing import List, Union
+
+import numpy as np
+import torch
+from torch import FloatTensor, Tensor
+
+
+def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor:
+ """
+ Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.
+ """
+ # Ensure raw_logprobs matches kd_online_topk length for tensor operations
+ # This should ideally be handled by the caller ensuring correct padding/truncation first
+ if logprobs.shape[-1] != topk:
+ # pad last dimension of logprobs to match topk length with -inf
+ padding_len = topk - logprobs.shape[-1]
+ padding_tensor = torch.full(
+ (
+ *logprobs.shape[:-1],
+ padding_len,
+ ), # all leading dims of logprobs, plus padding_len on the last dim
+ float("-inf"),
+ dtype=logprobs.dtype,
+ device=logprobs.device,
+ )
+ logprobs = torch.cat((logprobs, padding_tensor), dim=-1)
+
+ # Convert logprobs at T_online to probabilities
+ # use log sum exp trick to avoid underflow
+ position_logprobs_lse = torch.logsumexp(logprobs, dim=-1, keepdim=True)
+ teacher_probs_t_online = torch.exp(logprobs - position_logprobs_lse)
+
+ # Normalize probabilities (sum to 1)
+ # This is important if the top-k from server aren't a full distribution
+ teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=-1, keepdim=True)
+ teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online_sum
+
+ final_logprobs_tensor = torch.log(teacher_probs_t_online)
+
+ return final_logprobs_tensor
+
+
+def strided_chunk_views(
+ tensor: Union[np.ndarray, torch.Tensor],
+ chunks: int,
+ dim: int = 0,
+ stride: int = 1,
+ chunk_size: int | None = None,
+) -> List[Union[np.ndarray, torch.Tensor]]:
+ """
+ Split a tensor into chunks along a dimension with striding, prioritizing views over copies.
+
+ Args:
+ tensor: Input tensor (numpy array or torch tensor)
+ chunks: Number of chunks to create
+ dim: Dimension along which to chunk (default: 0)
+ stride: Stride between chunk starting positions (default: 1)
+ chunk_size: Size of each chunk. If None, calculated automatically (default: None)
+
+ Returns:
+ List of tensor chunks (views when possible, copies when necessary)
+ """
+
+ # Get the size of the specified dimension
+ dim_size = tensor.shape[dim]
+
+ # Calculate chunk size if not provided
+ if chunk_size is None:
+ chunk_size = (dim_size + chunks - 1) // chunks # Ceiling division
+
+ chunks_list = []
+
+ for i in range(chunks):
+ start_idx = i * stride
+ end_idx = min(start_idx + chunk_size, dim_size)
+
+ # Break if we've gone beyond the tensor
+ if start_idx >= dim_size:
+ break
+
+ # Create slice objects for all dimensions
+ slices = [slice(None)] * tensor.ndim
+ slices[dim] = slice(start_idx, end_idx)
+
+ chunk = tensor[tuple(slices)]
+ chunks_list.append(chunk)
+
+ return chunks_list
+
+
+def chunk_overlap(input_tensor: Tensor, chunks: int, dim: int = 0, overlap: int = 1):
+ """
+ Split a tensor into `chunks` views along `dim`, with consecutive chunks overlapping by `overlap` elements.
+ """
+ dim_size = input_tensor.shape[dim]
+ stride = math.ceil(dim_size / chunks)
+
+ return strided_chunk_views(
+ input_tensor, chunks, dim, stride=stride, chunk_size=stride + overlap
+ )
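A quick illustration of `normalize_logprobs` on toy values: rows shorter than `topk` are padded with `-inf` (probability zero) and every row is renormalized so its top-k probabilities sum to 1:

```python
import torch

from axolotl.integrations.kd.utils import normalize_logprobs

# two positions with only 3 teacher logprobs each, targeting top-k = 5
raw = torch.log(torch.tensor([[0.50, 0.30, 0.10],
                              [0.25, 0.25, 0.25]]))
out = normalize_logprobs(raw, topk=5)

print(out.shape)          # torch.Size([2, 5])
print(out.exp().sum(-1))  # tensor([1., 1.]) -- renormalized per position
```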
diff --git a/src/axolotl/prompt_strategies/__init__.py b/src/axolotl/prompt_strategies/__init__.py
index 3cdbbb6f3..cf936481e 100644
--- a/src/axolotl/prompt_strategies/__init__.py
+++ b/src/axolotl/prompt_strategies/__init__.py
@@ -17,7 +17,10 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None):
return messages_load(tokenizer, cfg, ds_cfg, processor=processor)
load_fn = "load"
package = "axolotl.prompt_strategies"
- if strategy.split(".")[-1].startswith("load_"):
+ if (
+ strategy.split(".")[-1].startswith("load_")
+ or strategy.split(".")[-1] == "load"
+ ):
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
elif len(strategy.split(".")) > 1:
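The effect of the extra `== "load"` check, mirrored outside the loader (a partial sketch of the dispatch above; the strategy strings are illustrative):

```python
def split_strategy(strategy: str) -> tuple[str, str]:
    """Peel a trailing load-function name off a dotted strategy string."""
    load_fn = "load"
    last = strategy.split(".")[-1]
    if last.startswith("load_") or last == "load":
        load_fn = last
        strategy = ".".join(strategy.split(".")[:-1])
    return strategy, load_fn


print(split_strategy("chat_template"))          # ('chat_template', 'load')
print(split_strategy("chat_template.load"))     # ('chat_template', 'load')  <- newly accepted form
print(split_strategy("my_module.load_custom"))  # ('my_module', 'load_custom')
```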
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index 13ac8ec0d..fa7d56913 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -1,10 +1,13 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
+from __future__ import annotations
+
import importlib
import inspect
import os
import signal
import sys
+import typing
import weakref
from contextlib import ExitStack
from pathlib import Path
@@ -25,7 +28,6 @@ from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
-from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.integrations.base import PluginManager
from axolotl.loaders import (
ModelLoader,
@@ -45,6 +47,9 @@ try:
except ImportError:
BetterTransformer = None
+if typing.TYPE_CHECKING:
+ from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
+
LOG = get_logger(__name__)
@@ -472,7 +477,7 @@ def handle_untrained_tokens_fix(
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
- HFRLTrainerBuilder | HFCausalTrainerBuilder,
+ "HFRLTrainerBuilder" | "HFCausalTrainerBuilder",
PeftModel | PreTrainedModel,
PreTrainedTokenizer,
PeftConfig | None,
diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py
index 3d0ba7c9c..e669413f8 100644
--- a/src/axolotl/utils/__init__.py
+++ b/src/axolotl/utils/__init__.py
@@ -52,3 +52,10 @@ def patch_optimized_env():
if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
set_pytorch_cuda_alloc_conf()
+
+
+def get_not_null(value, default=None):
+ """
+ return the value if it's not None, otherwise return the default value
+ """
+ return value if value is not None else default
diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py
index bf496d2c5..09bfb5576 100644
--- a/src/axolotl/utils/chat_templates.py
+++ b/src/axolotl/utils/chat_templates.py
@@ -36,7 +36,7 @@ _CHAT_TEMPLATES = {
"deepseek_v3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\\n\\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' in message %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- else %}{{'<|Assistant|>' + message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- endfor %}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' not in message %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}",
"jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- 
"\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." -}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined 
%}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- handle_tool_calls(message.tool_calls) -}}\n {% endif %}\n {% set _ = is_param_set(message, field="citations", is_list=True) %}\n {% set is_citations_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_citations_set %}\n {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n {% set _ = is_param_set(generation_preamble) %}\n {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n {{- " " + generation_preamble -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n{% endif %}\n',
"qwen_25": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0]['role'] == 'system' %}\n {{- messages[0]['content'] }}\n {%- else %}\n {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n {%- endif %}\n {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0]['role'] == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n {%- else %}\n {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {{- '<|im_start|>' + message.role }}\n {%- if message.content %}\n {{- '\\n' + message.content }}\n {%- endif %}\n {%- for tool_call in message.tool_calls %}\n {%- if tool_call.function is defined %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {{- tool_call.arguments | tojson }}\n {{- '}\\n' }}\n {%- endfor %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
- "qwen3": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('') and message.content.endswith('')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '' in message.content %}\n {%- set content = message.content.split('')[-1].lstrip('\\n') %}\n {%- set reasoning_content = message.content.split('')[0].rstrip('\\n').split('')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content.strip('\\n') + '\\n\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '\\n\\n\\n\\n' }}\n {%- endif %}\n{%- endif %}",
+ "qwen3": "{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within XML tags:\\n\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n\\n\\nFor each function call, return a json object with function name and arguments within XML tags:\\n\\n{\\\"name\\\": , \\\"arguments\\\": }\\n<|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('') and message.content.endswith('')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set content = message.content %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '' in message.content %}\n {%- set content = message.content.split('')[-1].lstrip('\\n') %}\n {%- set reasoning_content = message.content.split('')[0].rstrip('\\n').split('')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n\\n' + reasoning_content.strip('\\n') + '\\n\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n\\n' }}\n {{- message.content }}\n {{- '\\n' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '\\n\\n\\n\\n' }}\n {%- else %}\n {{- '\\n\\n' }}\n {%- endif %}\n{%- endif %}",
"exaone": "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '[|system|][|endofturn|]\n' }}{% endif %}{{ '[|' + message['role'] + '|]' + message['content'] }}{% if message['role'] == 'user' %}{{ '\n' }}{% else %}{{ '[|endofturn|]\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[|assistant|]' }}{% endif %}",
"metharme": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'Enter RP mode. You shall reply to the user while staying in character. Your responses must be detailed, creative, immersive, and drive the scenario forward.' %}{% endif %}{{ '<|system|>' + system_message }}{% for message in loop_messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|user|>' + content.strip() }}{% elif message['role'] == 'assistant' %}{{ '<|model|>' + content.strip() }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|model|>' }}{% else %}{{ eos_token }}{% endif %}",
"pixtral": '{%- if messages[0]["role"] == "system" %}\n {%- set system_message = messages[0]["content"] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) %}\n {{- raise_exception(\'After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\') }}\n {%- endif %}\n {%- if message["role"] == "user" %}\n {%- if loop.last and system_message is defined %}\n {{- "[INST]" + system_message + "\n\n" }}\n {%- else %}\n {{- "[INST]" }}\n {%- endif %}\n {%- if message["content"] is not string %}\n {%- for chunk in message["content"] %}\n {%- if chunk["type"] == "text" %}\n {{- chunk["text"] }}\n {%- elif chunk["type"] == "image" %}\n {{- "[IMG]" }}\n {%- else %}\n {{- raise_exception("Unrecognized content type!") }}\n {%- endif %}\n {%- endfor %}\n {%- else %}\n {{- message["content"] }}\n {%- endif %}\n {{- "[/INST]" }}\n {%- elif message["role"] == "assistant" %}\n {%- if message["content"] is not string %}\n {%- for chunk in message["content"] %}\n {%- if chunk["type"] == "text" %}\n {{- chunk["text"] }}\n {%- elif chunk["type"] == "image" %}\n {{- "[IMG]" }}\n {%- else %}\n {{- raise_exception("Unrecognized content type!") }}\n{%- endif %}\n{%- endfor %}\n{{- eos_token }}\n{%- else %}\n{{- message["content"] + eos_token }}\n{%- endif %}\n {%- else %}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}',
diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py
index d8414d117..a28f360be 100644
--- a/src/axolotl/utils/collators/batching.py
+++ b/src/axolotl/utils/collators/batching.py
@@ -1,7 +1,7 @@
"""Data collators for axolotl to pad labels and position_ids for packed sequences"""
from dataclasses import dataclass
-from typing import Any
+from typing import Any, List
import numpy as np
from transformers import PreTrainedTokenizerBase
@@ -163,7 +163,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
def __call__(self, features, return_tensors=None):
if not isinstance(features[0], list):
- features = [features]
+ features: List[List[dict]] = [features]
out_features = [{} for _ in features]
for i, features_ in enumerate(features):
for feature in features_[0].keys():
diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py
index 0ffaa932f..4f7f6f8dd 100644
--- a/src/axolotl/utils/data/utils.py
+++ b/src/axolotl/utils/data/utils.py
@@ -51,6 +51,7 @@ def retry_on_request_exceptions(
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
+ requests.exceptions.HTTPError,
huggingface_hub.errors.HfHubHTTPError,
) as exc:
if attempt < max_retries - 1:
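For context, the hunk above adds `requests.exceptions.HTTPError` to the set of exceptions treated as transient and retried. A minimal, self-contained sketch of that retry pattern follows; the default values and the `fetch` helper are assumptions for illustration, not axolotl's actual implementation:

```python
import time
from functools import wraps

import requests


def retry_on_request_exceptions(max_retries: int = 3, delay: float = 5.0):
    """Retry a flaky HTTP call; HTTPError is now treated as transient as well."""

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return func(*args, **kwargs)
                except (
                    requests.exceptions.ReadTimeout,
                    requests.exceptions.ConnectionError,
                    requests.exceptions.HTTPError,  # the class added by this change
                ) as exc:
                    if attempt < max_retries - 1:
                        time.sleep(delay)
                    else:
                        raise exc

        return wrapper

    return decorator


@retry_on_request_exceptions(max_retries=3, delay=2.0)
def fetch(url: str) -> str:
    response = requests.get(url, timeout=10)
    response.raise_for_status()  # raises HTTPError on 4xx/5xx, which is now retried
    return response.text
```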
diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py
index eabfc2d84..7fb5e1b41 100644
--- a/src/axolotl/utils/samplers/multipack.py
+++ b/src/axolotl/utils/samplers/multipack.py
@@ -259,7 +259,7 @@ class MultipackBatchSampler(BatchSampler):
batch_max_len: int, # Maximum sequence length (bin capacity)
lengths: np.ndarray, # Sequence lengths
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
- drop_last: bool = False, # Whether to drop final batches (might be incomplete)
+ drop_last: bool = True, # Whether to drop final batches (might be incomplete)
num_count_samples: int = 8, # Number of times to estimate batch count
sequential: bool = False, # Whether to use sequential packing
group_size: int = 100_000, # Size of groups for parallel packing
@@ -446,10 +446,18 @@ class MultipackBatchSampler(BatchSampler):
if self._len_across_ranks is None:
# Sample multiple times to get stable estimate
- len_batches = min( # pylint: disable=consider-using-generator
- [len(self._batches) for _ in range(self.num_count_samples)]
- )
+ _sampled_lens = []
+ for _ in range(self.num_count_samples):
+ self._batches = None # Reset cached batches
+ _sampled_lens.append(len(self.generate_batches(set_stats=False)))
+ len_batches = min(_sampled_lens)
+
# Gather minimum across all ranks
- self._len_across_ranks = self.gather_len_batches(len_batches)
+ if self._len_across_ranks is None:
+ self._len_across_ranks = self.gather_len_batches(len_batches)
+ else:
+ self._len_across_ranks = min(
+ self._len_across_ranks, self.gather_len_batches(len_batches)
+ )
return self._len_across_ranks
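For context, the rewritten block above re-packs the dataset several times and keeps the minimum batch count, because packing is stochastic and each pass (and each rank) may yield a different number of bins; advertising the minimum keeps every rank's reported length in agreement. A toy illustration of that idea, not the sampler's real packing algorithm:

```python
import random


def pack_once(lengths: list[int], bin_capacity: int) -> int:
    """Greedy first-fit packing after shuffling; returns the number of bins used."""
    random.shuffle(lengths)
    bins: list[list[int]] = []
    for length in lengths:
        for current in bins:
            if sum(current) + length <= bin_capacity:
                current.append(length)
                break
        else:
            bins.append([length])
    return len(bins)


lengths = [random.randint(100, 2000) for _ in range(500)]
estimates = [pack_once(list(lengths), bin_capacity=4096) for _ in range(8)]
len_batches = min(estimates)  # conservative per-rank estimate
# In distributed training, the per-rank minimum would then be reduced with MIN
# across ranks so every rank reports the same number of batches.
print(len_batches)
```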
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index ec5360fa3..33ddadf78 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -16,7 +16,6 @@ from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available
-from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
@@ -483,6 +482,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
)
)
+ if cfg.dataloader_drop_last:
+ # account for the last batch dropped in each epoch

+ total_num_steps -= int(math.ceil(cfg.num_epochs))
def calc_sample_packing_eff_est(estimates: List[float]):
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
@@ -630,6 +632,8 @@ def setup_trainer(
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
on the provided parameters.
"""
+ from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
+
if (
cfg.torch_compile
and cfg.fsdp_config
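For context, the `dataloader_drop_last` branch added in `calculate_total_num_steps` above removes roughly one optimizer step per epoch when the dataloader drops its last, possibly incomplete, batch. A worked example with assumed numbers:

```python
import math

# Assumed values for illustration only.
data_loader_len = 125  # steps per epoch before dropping the last batch
num_epochs = 3
dataloader_drop_last = True

total_num_steps = data_loader_len * num_epochs  # 375
if dataloader_drop_last:
    # one (possibly incomplete) batch disappears per epoch
    total_num_steps -= int(math.ceil(num_epochs))  # 375 - 3 = 372
print(total_num_steps)
```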
diff --git a/tests/e2e/integrations/test_kd.py b/tests/e2e/integrations/test_kd.py
index 2bd1fbf3d..212450e89 100644
--- a/tests/e2e/integrations/test_kd.py
+++ b/tests/e2e/integrations/test_kd.py
@@ -5,10 +5,9 @@ e2e tests for kd trainer support in Axolotl
from pathlib import Path
import pytest
+import yaml
+from accelerate.test_utils import execute_subprocess_async, get_torch_dist_unique_port
-from axolotl.common.datasets import load_datasets
-from axolotl.train import train
-from axolotl.utils.config import normalize_config, prepare_plugins, validate_config
from axolotl.utils.dict import DictDefault
from tests.e2e.utils import check_tensorboard, require_torch_2_5_1
@@ -17,8 +16,8 @@ from tests.e2e.utils import check_tensorboard, require_torch_2_5_1
@pytest.fixture(name="kd_min_cfg")
def min_cfg(temp_dir):
return {
- "base_model": "osllmai-community/Llama-3.2-1B",
- "tokenizer_config": "axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer",
+ "base_model": "Qwen/Qwen3-0.6B",
+ "tokenizer_config": "winglian/qwen3-14b-math",
"plugins": [
"axolotl.integrations.kd.KDPlugin",
"axolotl.integrations.liger.LigerPlugin",
@@ -31,20 +30,22 @@ def min_cfg(temp_dir):
"kd_ce_alpha": 0.1,
"kd_alpha": 0.9,
"kd_temperature": 1.0,
+ "kd_beta": 0.0,
+ "kd_normalize_topk": True,
"dataloader_prefetch_factor": 8,
"dataloader_num_workers": 4,
"dataloader_pin_memory": True,
"datasets": [
{
- "path": "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample",
- "type": "axolotl.integrations.kd.chat_template",
- "field_messages": "messages_combined",
+ "path": "winglian/OpenThoughts-114k-math-correct-qwen3-14b-math-prepared-topk128-normalized",
+ "type": "chat_template",
"split": "train",
- "logprobs_field": "llm_text_generation_vllm_logprobs",
- "temperature": 1.0,
- "preprocess_shards": 2,
+ "split_thinking": True,
+ "eot_tokens": ["<|im_end|>"],
+ "data_files": ["train/batch-000000.parquet"],
},
],
+ "skip_prepare_dataset": True,
"val_set_size": 0.0,
"sequence_len": 2048,
"sample_packing": True,
@@ -80,17 +81,29 @@ class TestKnowledgeDistillation:
def test_llama_kd(self, temp_dir, kd_min_cfg):
cfg = DictDefault(kd_min_cfg)
# pylint: disable=duplicate-code
- cfg = validate_config(cfg)
- prepare_plugins(cfg)
- normalize_config(cfg)
- dataset_meta = load_datasets(cfg=cfg)
+ # write cfg to yaml file
+ Path(temp_dir).mkdir(parents=True, exist_ok=True)
+ with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+ fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
+
+ execute_subprocess_async(
+ [
+ "axolotl",
+ "train",
+ str(Path(temp_dir) / "config.yaml"),
+ "--num-processes",
+ "1",
+ "--main-process-port",
+ f"{get_torch_dist_unique_port()}",
+ ]
+ )
- train(cfg=cfg, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.4, "Train Loss (%s) is too high"
)
+ @pytest.mark.skip(reason="Chunked KD loss doesn't support PEFT/LoRA")
@pytest.mark.parametrize(
"load_in_8bit",
[True, False],
@@ -110,12 +123,22 @@ class TestKnowledgeDistillation:
| kd_min_cfg
)
# pylint: disable=duplicate-code
- cfg = validate_config(cfg)
- prepare_plugins(cfg)
- normalize_config(cfg)
- dataset_meta = load_datasets(cfg=cfg)
+ # write cfg to yaml file
+ Path(temp_dir).mkdir(parents=True, exist_ok=True)
+ with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
+ fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper))
- train(cfg=cfg, dataset_meta=dataset_meta)
+ execute_subprocess_async(
+ [
+ "axolotl",
+ "train",
+ str(Path(temp_dir) / "config.yaml"),
+ "--num-processes",
+ "1",
+ "--main-process-port",
+ f"{get_torch_dist_unique_port()}",
+ ]
+ )
assert (Path(temp_dir) / "adapter_model.safetensors").exists()
check_tensorboard(
temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high"
diff --git a/tests/test_packed_batch_sampler.py b/tests/test_packed_batch_sampler.py
index 2b03c62f8..d91f63d94 100644
--- a/tests/test_packed_batch_sampler.py
+++ b/tests/test_packed_batch_sampler.py
@@ -81,6 +81,7 @@ class TestBatchedSamplerPacking:
group_size=100000,
bin_size=200,
sequential=sequential,
+ drop_last=False,
)
loader = DataLoader(