From ccc94da8ad3e0e99e97357220a720fe095e645b4 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Tue, 17 Jun 2025 12:09:13 -0400 Subject: [PATCH] KD fix w/ online distillation (#2700) [skip ci] * kd fixes * fix collator setup * fix input args * better handling to drop string fields for kd with raw dataset * kd trainer has kd temp as part of the init * drop top_k before softmax * simplfy and remove zscore * WIP chunked KD loss with autograd wrapper * more fixes and liger-type chunked loss * collator cls for plugins * remove debugging * additional plugin collator kwargs, don't scale up kd loss by t^2 * don't need temp arg to distill method * online kd wip * add close to comment block * suport sampling params/max new tokens * handle when no custom collator is used in plugins * logsumexp trick: * fix check * shift off the first empty token * fix length of padding * use max not min * temp scale kd loss at end * support for dynamic plugin training args mixins and symmetric kl * chore: lint * fix trainer callback base class * Fix decay * accept compressed responses for smaller wire payload * post-rebase lint * more KD updates * increase hyperparams_count for gradients for added normalize_topk * fix to remove attention_mask * rename vars for consistency * fix rebase issues * default to dropping last batch in multipack batch sampler * improve handling of train len * init collator_cls_and_kwargs * explicit drop_last=False when checking for multipack completeness * use separate v2 loader for kd * fix kd tests to use subprocess so it picks up kd training args * default value for kd_beta arg * use updated dataset for ci * longer timeout for e2e --- .github/workflows/tests.yml | 4 +- deepspeed_configs/zero2_torch_compile.json | 31 + src/axolotl/core/builders/causal.py | 63 +- src/axolotl/core/builders/rl.py | 17 +- src/axolotl/core/trainers/base.py | 10 +- src/axolotl/core/training_args.py | 237 +------- src/axolotl/core/training_args_base.py | 224 +++++++ src/axolotl/integrations/base.py | 88 ++- src/axolotl/integrations/config.py | 42 +- src/axolotl/integrations/kd/__init__.py | 71 +++ src/axolotl/integrations/kd/args.py | 56 +- src/axolotl/integrations/kd/callbacks.py | 36 ++ src/axolotl/integrations/kd/chat_template.py | 142 ++++- src/axolotl/integrations/kd/collator.py | 24 +- .../kd/collator_online_teacher.py | 561 ++++++++++++++++++ .../integrations/kd/kernels/__init__.py | 8 + src/axolotl/integrations/kd/kernels/liger.py | 485 +++++++++++++++ src/axolotl/integrations/kd/kernels/models.py | 98 +++ .../kd/topk_logprob/forward_kl.py | 216 +++---- src/axolotl/integrations/kd/trainer.py | 74 +-- src/axolotl/integrations/kd/utils.py | 100 ++++ src/axolotl/prompt_strategies/__init__.py | 5 +- src/axolotl/train.py | 9 +- src/axolotl/utils/__init__.py | 7 + src/axolotl/utils/chat_templates.py | 2 +- src/axolotl/utils/collators/batching.py | 4 +- src/axolotl/utils/data/utils.py | 1 + src/axolotl/utils/samplers/multipack.py | 18 +- src/axolotl/utils/trainer.py | 6 +- tests/e2e/integrations/test_kd.py | 65 +- tests/test_packed_batch_sampler.py | 1 + 31 files changed, 2178 insertions(+), 527 deletions(-) create mode 100644 deepspeed_configs/zero2_torch_compile.json create mode 100644 src/axolotl/core/training_args_base.py create mode 100644 src/axolotl/integrations/kd/callbacks.py create mode 100644 src/axolotl/integrations/kd/collator_online_teacher.py create mode 100644 src/axolotl/integrations/kd/kernels/liger.py create mode 100644 src/axolotl/integrations/kd/kernels/models.py create mode 100644 
src/axolotl/integrations/kd/utils.py diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 11fe13713..bb865e98d 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -188,7 +188,7 @@ jobs: if: ${{ ! contains(github.event.commits[0].message, '[skip e2e]') && github.repository_owner == 'axolotl-ai-cloud' }} # this job needs to be run on self-hosted GPU runners... runs-on: [self-hosted, modal] - timeout-minutes: 90 + timeout-minutes: 120 needs: [pre-commit, pytest, pytest-sdist] strategy: @@ -238,7 +238,7 @@ jobs: if: github.repository_owner == 'axolotl-ai-cloud' # this job needs to be run on self-hosted GPU runners... runs-on: [self-hosted, modal] - timeout-minutes: 90 + timeout-minutes: 120 # Only run the remainder of the matrix if the first e2e check passed; # this is to save on wasted compute costs for known failures that get caught in the first run needs: [pre-commit, pytest, docker-e2e-tests-1st] diff --git a/deepspeed_configs/zero2_torch_compile.json b/deepspeed_configs/zero2_torch_compile.json new file mode 100644 index 000000000..c3bcf98cf --- /dev/null +++ b/deepspeed_configs/zero2_torch_compile.json @@ -0,0 +1,31 @@ +{ + "compile": { + "disable": false, + "backend": "inductor" + }, + "zero_optimization": { + "stage": 2, + "offload_optimizer": { + "device": "cpu" + }, + "contiguous_gradients": true, + "overlap_comm": true + }, + "bf16": { + "enabled": "auto" + }, + "fp16": { + "enabled": "auto", + "auto_cast": false, + "loss_scale": 0, + "initial_scale_power": 32, + "loss_scale_window": 1000, + "hysteresis": 2, + "min_loss_scale": 1 + }, + "gradient_accumulation_steps": "auto", + "gradient_clipping": "auto", + "train_batch_size": "auto", + "train_micro_batch_size_per_gpu": "auto", + "wall_clock_breakdown": false +} diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index 8ff565dbb..6ed298d9f 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -21,11 +21,6 @@ from axolotl.core.trainers import ( AxolotlTrainer, ReLoRATrainer, ) -from axolotl.core.training_args import ( - AxolotlPRMConfig, - AxolotlRewardConfig, - AxolotlTrainingArguments, -) from axolotl.integrations.base import PluginManager from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.relora import ReLoRACallback @@ -130,6 +125,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): return callbacks def _get_trainer_cls(self): + """ + Gets the trainer class for the given configuration. 
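+
+        Plugin-provided trainer classes take precedence; otherwise this falls
+        back to the built-in AxolotlTrainer variants based on the config.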
+ """ if self.cfg.plugins: plugin_manager = PluginManager.get_instance() trainer_cls = plugin_manager.get_trainer_cls(self.cfg) @@ -146,6 +144,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): return AxolotlTrainer def build(self, total_num_steps): + from axolotl.core.training_args import ( + AxolotlPRMConfig, + AxolotlRewardConfig, + AxolotlTrainingArguments, + ) + training_arguments_kwargs, trainer_kwargs = self._set_base_training_args( total_num_steps ) @@ -314,20 +318,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs["image_resize_algorithm"] = ( self.cfg.image_resize_algorithm ) - if self.cfg.kd_ce_alpha is not None: - training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha - if self.cfg.kd_alpha is not None: - training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha - if self.cfg.kd_temperature is not None: - training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature - if self.cfg.kd_zscore_base_temp is not None: - training_arguments_kwargs["kd_zscore_base_temp"] = ( - self.cfg.kd_zscore_base_temp - ) - if self.cfg.kd_top_k_before_softmax is not None: - training_arguments_kwargs["kd_top_k_before_softmax"] = ( - self.cfg.kd_top_k_before_softmax - ) + + if self.cfg.plugins: + plugin_manager = PluginManager.get_instance() + plugin_training_args = plugin_manager.get_training_args(self.cfg) + if plugin_training_args: + training_arguments_kwargs.update(plugin_training_args) if self.cfg.reward_model: training_args_cls = AxolotlRewardConfig @@ -408,7 +404,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): return trainer def build_collator( - self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs + self, + training_args, # type: "AxolotlTrainingArguments" # type: ignore + is_eval=False, + **kwargs, ): if training_args.pretraining: if ( @@ -437,7 +436,19 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): ] ] collator_args = [self.tokenizer] - if self.cfg.reward_model: + + collator_cls_and_kwargs = None + if self.cfg.plugins: + plugin_manager = PluginManager.get_instance() + collator_cls_and_kwargs = plugin_manager.get_collator_cls_and_kwargs( + self.cfg, is_eval=is_eval + ) + + if collator_cls_and_kwargs: + collator = collator_cls_and_kwargs[0] + if kwargs and isinstance(kwargs, dict): + kwargs.update(collator_cls_and_kwargs[1]) + elif self.cfg.reward_model: collator = RewardDataCollatorWithPadding elif use_batch_sampler_collator: # Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention, @@ -468,16 +479,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): collator_args.pop(0) kwargs.pop("pad_to_multiple_of", None) kwargs.pop("padding", None) - elif self.cfg.kd_trainer: - from axolotl.integrations.kd.collator import ( - DataCollatorForKD, - KDBatchSamplerDataCollatorForSeq2Seq, - ) - - if self.cfg.sample_packing: - collator = KDBatchSamplerDataCollatorForSeq2Seq - else: - collator = DataCollatorForKD else: collator = DataCollatorForSeq2Seq diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py index 47ace7451..c5f01dd41 100644 --- a/src/axolotl/core/builders/rl.py +++ b/src/axolotl/core/builders/rl.py @@ -12,11 +12,6 @@ from axolotl.core.trainers import ( from axolotl.core.trainers.dpo import DPOStrategy from axolotl.core.trainers.dpo.args import AxolotlDPOConfig from axolotl.core.trainers.grpo import GRPOStrategy -from axolotl.core.training_args import ( - AxolotlCPOConfig, - AxolotlKTOConfig, - AxolotlORPOConfig, -) from axolotl.integrations.base import PluginManager from 
axolotl.loaders.utils import ensure_dtype from axolotl.utils.callbacks.qat import QATCallback @@ -83,6 +78,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase): """ Returns training_args and trainer_kwargs """ + from axolotl.core.training_args import ( + AxolotlCPOConfig, + AxolotlKTOConfig, + AxolotlORPOConfig, + ) + training_args_kwargs, trainer_kwargs = self._set_base_training_args( total_num_steps=total_num_steps ) @@ -150,6 +151,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if blocklist_key in training_args_kwargs: del training_args_kwargs[blocklist_key] + if self.cfg.plugins: + plugin_manager = PluginManager.get_instance() + plugin_training_args = plugin_manager.get_training_args(self.cfg) + if plugin_training_args: + training_args_kwargs.update(plugin_training_args) + training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg logging_first_step=True, **training_args_kwargs, diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py index 25ffb4cbf..fbae253d6 100644 --- a/src/axolotl/core/trainers/base.py +++ b/src/axolotl/core/trainers/base.py @@ -34,6 +34,7 @@ from axolotl.core.trainers.utils import ( sanitize_kwargs_for_ds_tagging, sanitize_kwargs_for_tagging, ) +from axolotl.utils import get_not_null from axolotl.utils.logging import get_logger from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths @@ -104,7 +105,7 @@ class AxolotlTrainer( ) batch_max_len = train_batch_size * self.args.max_seq_length - return MultipackBatchSampler( + sampler = MultipackBatchSampler( base_sampler, lengths=get_dataset_lengths(dataset), packing_efficiency_estimate=self.args.sample_packing_efficiency, @@ -117,6 +118,9 @@ class AxolotlTrainer( num_processes=self.args.dataset_num_proc, ) + len(sampler) + return sampler + def _get_train_sampler( self, train_dataset: Optional[Dataset] = None ) -> Optional[Sampler]: @@ -224,7 +228,9 @@ class AxolotlTrainer( } if not isinstance(dataset, torch.utils.data.IterableDataset): - dataloader_params["drop_last"] = self.args.dataloader_drop_last + dataloader_params["drop_last"] = get_not_null( + self.args.dataloader_drop_last, True + ) if sampler_fn is not None: sampler = sampler_fn(dataset) if isinstance(sampler, BatchSampler): diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py index 2b53c6798..d5be9fc62 100644 --- a/src/axolotl/core/training_args.py +++ b/src/axolotl/core/training_args.py @@ -2,242 +2,17 @@ extra axolotl specific training args """ -from dataclasses import dataclass, field -from typing import Optional +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Optional, Type -from PIL.Image import Resampling from transformers import TrainingArguments from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig +from axolotl.integrations.config import merge_training_args -@dataclass -class AxolotlTrainingMixins: - """ - Mixin class for the Axolotl training args. - """ - - # pylint: disable=duplicate-code - model_type: Optional[str] = field( - default=None, metadata={"help": "HF model configuration model_type."} - ) - lr_quadratic_warmup: bool = field( - default=False, - metadata={"help": "Use quadratic warmup for cosine scheduling."}, - ) - pretraining: bool = field( - default=False, - metadata={ - "help": "Indicates to trainer whether we are doing continued pretraining." 
- }, - ) - sample_packing: bool = field( - default=False, - metadata={"help": "Use sample packing for efficient training."}, - ) - sample_packing_sequentially: bool = field( - default=False, - metadata={ - "help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing." - }, - ) - multipack_real_batches: bool = field( - default=False, - metadata={"help": "Use real batches for efficient training."}, - ) - eval_sample_packing: Optional[bool] = field( - default=None, - metadata={"help": "Use sample packing for efficient evals."}, - ) - sample_packing_efficiency: float = field( - default=1.0, - metadata={"help": "Sample packing efficiency for calculating batch length."}, - ) - sample_packing_bin_size: int = field( - default=200, - metadata={ - "help": "The max number of samples that packed sample can contain after packing. Increase for better packing." - }, - ) - sample_packing_group_size: int = field( - default=100000, - metadata={ - "help": "The number of samples to group together for packing. Increase for better packing." - }, - ) - max_seq_length: int = field( - default=2048, - metadata={"help": "The maximum sequence length the model can handle"}, - ) - dataset_num_proc: int | None = field( - default=None, - metadata={"help": "The number of processes to use for data processing"}, - ) - relora_steps: Optional[int] = field( - default=None, - metadata={"help": "how often to reset for ReLoRA"}, - ) - relora_warmup_steps: Optional[int] = field( - default=None, - metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, - ) - relora_anneal_steps: Optional[int] = field( - default=None, - metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, - ) - relora_prune_ratio: Optional[float] = field( - default=0.9, - metadata={"help": "prune ratio for magnitude pruning of the optimizer"}, - ) - bench_split: Optional[str] = field( - default="eval", metadata={"help": "The benchmark split to run on"} - ) - bench_dataset: Optional[str] = field( - default="pharaouk/dharma-1/dharma_1_mini.json", - metadata={ - "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file" - }, - ) - do_bench_eval: Optional[bool] = field( - default=False, metadata={"help": "Whether to run the Benchmark evaluation."} - ) - do_causal_lm_eval: Optional[bool] = field( - default=False, metadata={"help": "Whether to run the Causal LM evaluation."} - ) - max_bench_samples: Optional[int] = field( - default=None, - metadata={ - "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset." 
- }, - ) - bench_source_max_len: int = field( - default=2048, metadata={"help": "Maximum source sequence length for bench."} - ) - dataloader_prefetch_factor: Optional[int] = field( - default=None, - metadata={"help": "prefetch_factor argument to the dataloader"}, - ) - cosine_min_lr_ratio: Optional[float] = field( - default=None, - metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"}, - ) - cosine_constant_lr_ratio: Optional[float] = field( - default=None, - metadata={ - "help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps" - }, - ) - loraplus_lr_ratio: Optional[float] = field( - default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."} - ) - loraplus_lr_embedding: Optional[float] = field( - default=1e-6, - metadata={"help": "loraplus learning rate for lora embedding layers."}, - ) - embedding_lr_scale: Optional[float] = field( - default=None, - metadata={"help": "Scale the learning rate for the embedding layers."}, - ) - lr_groups: Optional[list[dict]] = field( - default=None, - metadata={"help": "Specify learning rate groups for with different LRs."}, - ) - embedding_lr: Optional[float] = field( - default=None, - metadata={"help": "absolute learning rate for the embedding layers."}, - ) - qlora: bool = field( - default=False, - metadata={"help": "whether this is a qlora training"}, - ) - orpo_alpha: Optional[float] = field( - default=None, - ) - lisa_n_layers: Optional[int] = field( - default=None, - metadata={"help": "the number of activate layers in LISA"}, - ) - lisa_step_interval: Optional[int] = field( - default=None, - metadata={"help": "how often to switch layers in LISA"}, - ) - lisa_layers_attribute: Optional[str] = field( - default=None, - metadata={"help": "path under the model to access the layers"}, - ) - curriculum_sampling: Optional[bool] = field( - default=None, - metadata={"help": "whether to use sequential sampling for curriculum learning"}, - ) - alternate_lr_scheduler_type: Optional[str] = field( - default=None, - metadata={ - "help": "workaround to pass an alternate lr scheduler to the HF trainer" - }, - ) - chat_template: Optional[str] = field( - default=None, - metadata={"help": "Chat template converting chat messages to text"}, - ) - - kd_ce_alpha: Optional[float] = field( - default=None, - metadata={ - "help": "The alpha scaling parameter for SFT cross entropy loss when using KD" - }, - ) - - kd_alpha: Optional[float] = field( - default=1.0, - metadata={"help": "The alpha scaling parameter for KD loss"}, - ) - - kd_temperature: Optional[float] = field( - default=1.0, - metadata={ - "help": "the temperature parameter for KL divergence loss when using KD" - }, - ) - - kd_zscore_base_temp: Optional[float] = field( - default=None, - metadata={ - "help": "the base temperature parameter for KL divergence with z-score when using KD" - }, - ) - - kd_top_k_before_softmax: Optional[bool] = field( - default=None, - metadata={ - "help": "Whether to apply top_k_before_softmax to the logits when using KD" - }, - ) - - adam_beta3: Optional[float] = field( - default=None, - metadata={ - "help": "The beta3 hyperparameter used in some optimizers such as CAME" - }, - ) - adam_epsilon2: Optional[float] = field( - default=None, - metadata={ - "help": "The epsilon2 hyperparameter used in some optimizers such as CAME" - }, - ) - - # multi-modal section - - image_size: int | tuple[int, int] | None = field( - default=None, - metadata={"help": "The size of the image to resize to"}, - ) - - image_resize_algorithm: 
Resampling | None = field( - default=None, - metadata={"help": "The algorithm to use for image resizing"}, - ) - - # end of multi-modal section +AxolotlTrainingMixins: Type = merge_training_args() @dataclass diff --git a/src/axolotl/core/training_args_base.py b/src/axolotl/core/training_args_base.py new file mode 100644 index 000000000..8fcaff632 --- /dev/null +++ b/src/axolotl/core/training_args_base.py @@ -0,0 +1,224 @@ +""" +Base Axolotl Training Mixins shared across various trainer configs +""" + +from dataclasses import dataclass, field +from typing import Optional + +from PIL.Image import Resampling + + +@dataclass +class AxolotlTrainingMixins: + """ + Mixin class for the Axolotl training args. + """ + + # pylint: disable=duplicate-code + model_type: Optional[str] = field( + default=None, metadata={"help": "HF model configuration model_type."} + ) + lr_quadratic_warmup: bool = field( + default=False, + metadata={"help": "Use quadratic warmup for cosine scheduling."}, + ) + pretraining: bool = field( + default=False, + metadata={ + "help": "Indicates to trainer whether we are doing continued pretraining." + }, + ) + sample_packing: bool = field( + default=False, + metadata={"help": "Use sample packing for efficient training."}, + ) + sample_packing_sequentially: bool = field( + default=False, + metadata={ + "help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing." + }, + ) + multipack_real_batches: bool = field( + default=False, + metadata={"help": "Use real batches for efficient training."}, + ) + eval_sample_packing: Optional[bool] = field( + default=None, + metadata={"help": "Use sample packing for efficient evals."}, + ) + sample_packing_efficiency: float = field( + default=1.0, + metadata={"help": "Sample packing efficiency for calculating batch length."}, + ) + sample_packing_bin_size: int = field( + default=200, + metadata={ + "help": "The max number of samples that packed sample can contain after packing. Increase for better packing." + }, + ) + sample_packing_group_size: int = field( + default=100000, + metadata={ + "help": "The number of samples to group together for packing. Increase for better packing." 
+ }, + ) + max_seq_length: int = field( + default=2048, + metadata={"help": "The maximum sequence length the model can handle"}, + ) + dataset_num_proc: int | None = field( + default=None, + metadata={"help": "The number of processes to use for data processing"}, + ) + relora_steps: Optional[int] = field( + default=None, + metadata={"help": "how often to reset for ReLoRA"}, + ) + relora_warmup_steps: Optional[int] = field( + default=None, + metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, + ) + relora_anneal_steps: Optional[int] = field( + default=None, + metadata={"help": "how many warmup steps to take after reset for ReLoRA"}, + ) + relora_prune_ratio: Optional[float] = field( + default=0.9, + metadata={"help": "prune ratio for magnitude pruning of the optimizer"}, + ) + bench_split: Optional[str] = field( + default="eval", metadata={"help": "The benchmark split to run on"} + ) + bench_dataset: Optional[str] = field( + default="pharaouk/dharma-1/dharma_1_mini.json", + metadata={ + "help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file" + }, + ) + do_bench_eval: Optional[bool] = field( + default=False, metadata={"help": "Whether to run the Benchmark evaluation."} + ) + do_causal_lm_eval: Optional[bool] = field( + default=False, metadata={"help": "Whether to run the Causal LM evaluation."} + ) + max_bench_samples: Optional[int] = field( + default=None, + metadata={ + "help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset." + }, + ) + bench_source_max_len: int = field( + default=2048, metadata={"help": "Maximum source sequence length for bench."} + ) + dataloader_prefetch_factor: Optional[int] = field( + default=None, + metadata={"help": "prefetch_factor argument to the dataloader"}, + ) + cosine_min_lr_ratio: Optional[float] = field( + default=None, + metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"}, + ) + cosine_constant_lr_ratio: Optional[float] = field( + default=None, + metadata={ + "help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps" + }, + ) + loraplus_lr_ratio: Optional[float] = field( + default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."} + ) + loraplus_lr_embedding: Optional[float] = field( + default=1e-6, + metadata={"help": "loraplus learning rate for lora embedding layers."}, + ) + embedding_lr_scale: Optional[float] = field( + default=None, + metadata={"help": "Scale the learning rate for the embedding layers."}, + ) + lr_groups: Optional[list[dict]] = field( + default=None, + metadata={"help": "Specify learning rate groups for with different LRs."}, + ) + embedding_lr: Optional[float] = field( + default=None, + metadata={"help": "absolute learning rate for the embedding layers."}, + ) + qlora: bool = field( + default=False, + metadata={"help": "whether this is a qlora training"}, + ) + orpo_alpha: Optional[float] = field( + default=None, + ) + lisa_n_layers: Optional[int] = field( + default=None, + metadata={"help": "the number of activate layers in LISA"}, + ) + lisa_step_interval: Optional[int] = field( + default=None, + metadata={"help": "how often to switch layers in LISA"}, + ) + lisa_layers_attribute: Optional[str] = field( + default=None, + metadata={"help": "path under the model to access the layers"}, + ) + curriculum_sampling: Optional[bool] = field( + default=None, + metadata={"help": "whether to use sequential sampling for curriculum learning"}, + ) + alternate_lr_scheduler_type: 
Optional[str] = field( + default=None, + metadata={ + "help": "workaround to pass an alternate lr scheduler to the HF trainer" + }, + ) + chat_template: Optional[str] = field( + default=None, + metadata={"help": "Chat template converting chat messages to text"}, + ) + + # kd_ce_alpha: Optional[float] = field( + # default=None, + # metadata={ + # "help": "The alpha scaling parameter for SFT cross entropy loss when using KD" + # }, + # ) + # + # kd_alpha: Optional[float] = field( + # default=1.0, + # metadata={"help": "The alpha scaling parameter for KD loss"}, + # ) + # + # kd_temperature: Optional[float] = field( + # default=1.0, + # metadata={ + # "help": "the temperature parameter for KL divergence loss when using KD" + # }, + # ) + + adam_beta3: Optional[float] = field( + default=None, + metadata={ + "help": "The beta3 hyperparameter used in some optimizers such as CAME" + }, + ) + adam_epsilon2: Optional[float] = field( + default=None, + metadata={ + "help": "The epsilon2 hyperparameter used in some optimizers such as CAME" + }, + ) + + # multi-modal section + + image_size: int | tuple[int, int] | None = field( + default=None, + metadata={"help": "The size of the image to resize to"}, + ) + + image_resize_algorithm: Resampling | None = field( + default=None, + metadata={"help": "The algorithm to use for image resizing"}, + ) + + # end of multi-modal section diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index 0edc9fdea..9162bc745 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -22,6 +22,7 @@ from __future__ import annotations import collections import importlib +import traceback from typing import TYPE_CHECKING, Callable, OrderedDict, Union from peft import PeftModel @@ -83,6 +84,11 @@ class BasePlugin: def get_input_args(self) -> str | None: """Returns a pydantic model for the plugin's input arguments.""" + def get_training_args_mixin(self) -> str | None: + """ + Returns a dataclass model for the plugin's training arguments. + """ + def load_datasets( self, cfg: DictDefault, preprocess: bool = False ) -> Union["TrainDatasetMeta", None]: @@ -158,6 +164,31 @@ class BasePlugin: trainer: The trainer object for training. """ + def get_training_args(self, cfg: DictDefault): # pylint: disable=unused-argument): + """ + Returns custom training arguments to set on TrainingArgs. + + Args: + cfg: The global axolotl configuration. + + Returns: + object: dict containing the training arguments. + """ + + def get_collator_cls_and_kwargs( + self, cfg: DictDefault, is_eval: bool = False + ): # pylint: disable=unused-argument): + """ + Returns a custom class for the collator. + + Args: + cfg: The global axolotl configuration. + is_eval: Whether this is an eval split. + + Returns: + class: The class for the collator. + """ + # pylint: disable=unused-argument def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None: """Creates and returns an optimizer for training. @@ -278,7 +309,7 @@ def load_plugin(plugin_name: str) -> BasePlugin: return plugin -class PluginManager: +class PluginManager: # pylint: disable=too-many-public-methods """The `PluginManager` class is responsible for loading and managing plugins. It should be a singleton so it can be accessed from anywhere in the codebase. 
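    A minimal, hypothetical plugin sketch showing how the two new training-args
    hooks pair up (MyPlugin, my_pkg.args.MyTrainingArgsMixin, and my_alpha are
    invented names for illustration only):

        class MyPlugin(BasePlugin):
            def get_training_args_mixin(self):
                # dotted path; merge_training_args() imports this string's
                # target and mixes it into AxolotlTrainingMixins dynamically
                return "my_pkg.args.MyTrainingArgsMixin"

            def get_training_args(self, cfg):
                # the returned dict is merged into the TrainingArguments kwargs
                return {"my_alpha": cfg.my_alpha}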
@@ -337,8 +368,11 @@ class PluginManager: plugin = load_plugin(plugin_name) self.plugins[plugin_name] = plugin LOG.info(f"Plugin loaded successfully: {plugin_name}") - except ImportError: + except ImportError as exc: LOG.error(f"Failed to load plugin: {plugin_name}") + # surface the stacktrace to make plugin import failures debuggable + traceback.print_exc() + LOG.error(f"Error: {exc}") def get_input_args(self) -> list[str]: """Returns a list of Pydantic classes for all registered plugins' input arguments.' @@ -353,6 +387,20 @@ input_args.append(input_args_from_plugin) return input_args + def get_training_args_mixin(self): + """ + Returns a list of dataclasses for all registered plugins' training args mixins. + + Returns: + list[str]: A list of dotted-path strings pointing to the mixin dataclasses + """ + training_args = [] + for plugin in self.plugins.values(): + training_args_from_plugin = plugin.get_training_args_mixin() + if training_args_from_plugin is not None: + training_args.append(training_args_from_plugin) + return training_args + def load_datasets( self, cfg: DictDefault, preprocess: bool = False ) -> Union["TrainDatasetMeta", None]: @@ -442,6 +490,42 @@ return trainer_cls return None + def get_training_args(self, cfg): + """ + Calls the get_training_args method of all registered plugins and returns the combined training arguments. + + Parameters: + cfg (dict): The configuration for the plugins. + + Returns: + dict: The combined training argument kwargs + """ + training_args_kwargs = {} + for plugin in self.plugins.values(): + training_args = plugin.get_training_args(cfg) + if training_args is not None: + training_args_kwargs.update(training_args) + + return training_args_kwargs + + def get_collator_cls_and_kwargs(self, cfg, is_eval=False): + """ + Calls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class. + + Parameters: + cfg (dict): The configuration for the plugins. + is_eval (bool): Whether this is an eval split. + + Returns: + tuple | None: The collator class and its kwargs, or None if none was found. + """ + for plugin in self.plugins.values(): + collator = plugin.get_collator_cls_and_kwargs(cfg, is_eval=is_eval) + if collator is not None: + collator_cls, collator_kwargs = collator + return collator_cls, collator_kwargs + return None + def post_trainer_create(self, cfg: DictDefault, trainer: Trainer): """Calls the `post_trainer_create` method of all registered plugins. diff --git a/src/axolotl/integrations/config.py b/src/axolotl/integrations/config.py index b443f228e..f5fc07e9e 100644 --- a/src/axolotl/integrations/config.py +++ b/src/axolotl/integrations/config.py @@ -16,7 +16,7 @@ Module to handle merging the plugins' input arguments with the base configuration. This was moved here to prevent circular imports. """ -from typing import Any, Dict, List +from typing import Any, Dict, List, Type from axolotl.utils.schemas.config import ( AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase, ) @@ -61,3 +61,43 @@ ] return AxolotlConfigWCapabilities, AxolotlInputConfig return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase + + +def merge_training_args() -> Type: + """ + Merges training arguments from registered plugins with the base TrainingArguments. + + This function retrieves the training arguments from registered plugins using the PluginManager. + It then dynamically creates a new class, AxolotlTrainingMixins, + that inherits from the base mixins and includes the training arguments from the plugins. 
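+    The merged class is assembled by exec-ing dynamically generated import and
+    class statements, since plugins hand back their mixins as dotted-path strings.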
+ + Returns: + Type: The merged AxolotlTrainingMixins class. + """ + # pylint: disable=duplicate-code + from axolotl.core.training_args_base import ( + AxolotlTrainingMixins as AxolotlTrainingMixinsBase, + ) + from axolotl.integrations.base import PluginManager + + plugin_manager = PluginManager.get_instance() + training_args_mixins: List[str] = plugin_manager.get_training_args_mixin() + mixin_classes = [] + dynamic_input = "" + for plugin_args in training_args_mixins: + plugin_module, plugin_cls = plugin_args.rsplit(".", 1) + dynamic_input += f"from {plugin_module} import {plugin_cls}\n" + mixin_classes.append(plugin_cls) + if dynamic_input: + dynamic_input += f"class AxolotlTrainingMixins(AxolotlTrainingMixinsBase, {', '.join(mixin_classes)}):\n pass\n" + + namespace: Dict[Any, Any] = {} + local_vars = {"AxolotlTrainingMixinsBase": AxolotlTrainingMixinsBase} + exec( # pylint: disable=exec-used # nosec B102 + dynamic_input, {**globals(), **local_vars}, namespace + ) + AxolotlTrainingMixins = namespace[ # pylint: disable=invalid-name + "AxolotlTrainingMixins" + ] + return AxolotlTrainingMixins + return AxolotlTrainingMixinsBase diff --git a/src/axolotl/integrations/kd/__init__.py b/src/axolotl/integrations/kd/__init__.py index 8a6e3eda1..4c8535a0a 100644 --- a/src/axolotl/integrations/kd/__init__.py +++ b/src/axolotl/integrations/kd/__init__.py @@ -15,7 +15,12 @@ """ Plugin init to add KD support to Axolotl. """ +from typing import Any + +from transformers import Trainer + from axolotl.integrations.base import BasePlugin +from axolotl.integrations.kd.callbacks import KDTemperatureSchedulerCallback from .args import KDArgs # pylint: disable=unused-import. # noqa: F401 @@ -28,9 +33,75 @@ class KDPlugin(BasePlugin): def get_input_args(self): return "axolotl.integrations.kd.KDArgs" + def get_training_args_mixin(self): + return "axolotl.integrations.kd.args.KDTrainingArgsMixin" + def get_trainer_cls(self, cfg): if cfg.kd_trainer: from .trainer import AxolotlKDTrainer return AxolotlKDTrainer return None + + def get_training_args(self, cfg): + return { + "kd_ce_alpha": cfg.kd_ce_alpha, + "kd_alpha": cfg.kd_alpha, + "kd_temperature": cfg.kd_temperature, + "kd_beta": cfg.kd_beta, + "kd_normalize_topk": cfg.kd_normalize_topk, + } + + def get_collator_cls_and_kwargs(self, cfg, is_eval=False): + if not cfg.kd_trainer: + return None + + from .collator import DataCollatorForKD, KDBatchSamplerDataCollatorForSeq2Seq + + use_batch_sampler_collator = False + if is_eval is False and cfg.sample_packing: + use_batch_sampler_collator = True + if cfg.eval_sample_packing and is_eval: + use_batch_sampler_collator = True + + if cfg.kd_online_server_base_url: + from .collator_online_teacher import OnlineTeacherCollator + + return OnlineTeacherCollator, { + "kd_online_server_base_url": cfg.kd_online_server_base_url, + "kd_online_topk": cfg.kd_online_topk, + "kd_temperature": cfg.kd_temperature, + "kd_online_server": cfg.kd_online_server, + "kd_online_timeout": cfg.kd_online_timeout, + "kd_normalize_topk": cfg.kd_normalize_topk, + } + + if use_batch_sampler_collator: + return KDBatchSamplerDataCollatorForSeq2Seq, {} + return DataCollatorForKD, {} + + def pre_model_load(self, cfg): + from .kernels.models import apply_kernel + + apply_kernel(cfg.model_config_type) + + def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list: + """ + Adds the KD temperature scheduler callback to the Trainer instance. + + Args: + cfg (Any): Configuration object containing the KD settings. 
+ trainer (Trainer): Huggingface Trainer instance. + + Returns: + list: List containing the configured callback instances. + """ + if cfg.kd_temperature_min is not None and cfg.kd_online_server_base_url: + callback = KDTemperatureSchedulerCallback( + cfg.kd_temperature, + cfg.kd_temperature_min, + trainer, + ) + return [callback] + + return [] diff --git a/src/axolotl/integrations/kd/args.py b/src/axolotl/integrations/kd/args.py index 2fbba2c6a..758bc8917 100644 --- a/src/axolotl/integrations/kd/args.py +++ b/src/axolotl/integrations/kd/args.py @@ -15,9 +15,19 @@ """ Plugin args for KD support. """ -from typing import Optional +from dataclasses import dataclass +from enum import Enum -from pydantic import BaseModel +from pydantic import BaseModel, Field + + +class InferenceServerType(str, Enum): + """ + Online inference server types to handle different request args + """ + + vllm = "vllm" # pylint: disable=invalid-name + sglang = "sglang" # pylint: disable=invalid-name class KDArgs(BaseModel): @@ -25,13 +35,41 @@ """ Input args for knowledge distillation. """ - kd_trainer: Optional[bool] = None # whether to use KD trainer - kd_ce_alpha: Optional[float] = ( + kd_trainer: bool | None = None # whether to use KD trainer + kd_ce_alpha: float | None = ( None # loss coefficient for cross-entropy loss during KD ) - kd_alpha: Optional[float] = None # loss coefficient for KD loss - kd_temperature: Optional[float] = None # temperature for sampling during KD - kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling - kd_top_k_before_softmax: Optional[bool] = ( - None # whether to sample top k before softmax during KD + kd_alpha: float | None = None # loss coefficient for KD loss + kd_temperature: float | None = None # temperature for sampling during KD + kd_beta: float | None = 0.0 # beta coefficient for ratio of fwd and reverse KL + kd_normalize_topk: bool | None = ( + None # whether to re-normalize top-k logprob distributions during KD + ) + + # TODO online kd + kd_online_server_base_url: str | None = None + kd_online_topk: int | None = None + kd_online_server: InferenceServerType | None = Field( + default_factory=lambda: InferenceServerType.vllm + ) + kd_online_timeout: int | None = 120 + kd_temperature_min: float | None = ( + None # kd temperature scheduling during online kd + ) + + +@dataclass +class KDTrainingArgsMixin: + """ + Additional args for KD training. + """ + + kd_ce_alpha: float | None = ( + None # loss coefficient for cross-entropy loss during KD + ) + kd_alpha: float | None = None # loss coefficient for KD loss + kd_temperature: float | None = None # temperature for sampling during KD + kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL + kd_normalize_topk: bool | None = ( + None # whether to re-normalize top-k logprob distributions during KD ) diff --git a/src/axolotl/integrations/kd/callbacks.py b/src/axolotl/integrations/kd/callbacks.py new file mode 100644 index 000000000..911c3d517 --- /dev/null +++ b/src/axolotl/integrations/kd/callbacks.py @@ -0,0 +1,36 @@ +""" +Transformers trainer callbacks to schedule the KD temperature during training +""" + +import math + +from transformers.trainer_callback import TrainerCallback + + +class KDTemperatureSchedulerCallback(TrainerCallback): + """ + KD temperature scheduler callback for the trainer. 
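+
+    Cosine-decays the data collator's kd_temperature from temperature_start to
+    temperature_min over max_steps; e.g. with start 2.0 and min 1.0 the
+    temperature is 2.0 at step 0, 1.5 at the halfway point, and 1.0 at the end.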
+ """ + + def __init__(self, temperature_start, temperature_min, trainer): + self.temperature_start = temperature_start + self.temperature_min = temperature_min + self.temperature = temperature_start + + self.trainer = trainer + + def on_step_end( + self, args, state, control, **kwargs + ): # pylint: disable=unused-argument + # cosine decay temperature over the max steps + + progress = state.global_step / state.max_steps + # Cosine decay factor: 0.5 * (1 + cos(pi * progress)) + # This factor goes from 1 (at progress=0) to 0 (at progress=1) + decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress)) + self.temperature = self.temperature_start - ( + (self.temperature_start - self.temperature_min) * (1.0 - decay_factor) + ) + + if hasattr(self.trainer.data_collator, "kd_temperature"): + self.trainer.data_collator.kd_temperature = self.temperature diff --git a/src/axolotl/integrations/kd/chat_template.py b/src/axolotl/integrations/kd/chat_template.py index 7c99a9c3d..f99dfe458 100644 --- a/src/axolotl/integrations/kd/chat_template.py +++ b/src/axolotl/integrations/kd/chat_template.py @@ -15,12 +15,15 @@ """ Chat template prompt strategy loader with KD support """ +import logging from typing import Any, Dict import torch from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader +LOG = logging.getLogger(__name__) + class ChatTemplateStrategyWithKD(ChatTemplateStrategy): """ @@ -101,10 +104,8 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): # fill with -inf for padding_len tokens for top_k tokens # extend target_logprobs with a padding_len x top_k 2D list filled with -inf - # for causal models, if we start the range at 1, then we don't need to shift in the trainer - # otherwise, we need to shift in the trainer - shift = 0 - for _ in range(shift, input_padding_len): + # we shift for causal models in the trainer, so start the range from 0 + for _ in range(0, input_padding_len): target_logprobs.append([-float("inf")] * top_k) target_token_ids.append(list(range(top_k))) target_mask.append([0] * top_k) @@ -143,6 +144,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): # # Convert from log to probability teacher_probs_t1 = position_logprobs_tensor.exp() + # normalize probabilities to sum to 1 in case they aren't already + teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True) + if teacher_probs_t1_sum > 1e-9: + teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum if self.kd_temperature != self.gen_temperature: # Exponentiate by factor (T1 / T2) exponent = self.gen_temperature / self.kd_temperature @@ -162,12 +167,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): target_logprobs.append(position_logprobs_scaled) target_token_ids.append(position_token_ids) - if shift == 1: - # since we started at index 1 for causal, we need one more padding token - target_logprobs.append([-float("inf")] * top_k) - target_token_ids.append(list(range(top_k))) - target_mask.append([0] * top_k) - # Update sample with transformed logprobs sample["target_logprobs"] = target_logprobs sample["target_token_ids"] = target_token_ids @@ -184,6 +183,117 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy): return tokenized_prompt +class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD): + """ + Strat for datasets with complete structured KD logprob data + """ + + def transform_logprobs(self, sample): + """ + Transform logprobs to target format for KD training + """ + # pylint: disable=duplicate-code + + logprobs = sample.pop(self.logprobs_field) 
+ target_seq_len = len(logprobs) + input_seq_len = len(sample["input_ids"]) + input_padding_len = input_seq_len - target_seq_len + # get non-zero top-k (prune None logprobs from vllm data step) + top_k_vals = [ + len(logprobs[i]) + for i in range(len(logprobs)) + if logprobs[i] is not None and len(logprobs[i]) + ] + max_top_k = max(set(top_k_vals), key=top_k_vals.count) + min_top_k = min(set(top_k_vals), key=top_k_vals.count) + top_k = min(max_top_k, min_top_k) + if top_k == 0: + raise ValueError("No non-zero top-k logprobs found.") + + target_logprobs = [] + target_token_ids = [] + target_mask = [] + + if input_padding_len < 0: + # logprobs is longer than the input sequence, + # so we need to slice from the left/beginning of logprobs + logprobs = logprobs[-input_seq_len:] + input_padding_len = 0 + # target_seq_len = input_seq_len + + # truncate the second dimension of the logprobs to top_k + logprobs = [row[:top_k] for row in logprobs] + + # fill with -inf for padding_len tokens for top_k tokens + # extend target_logprobs with a padding_len x top_k 2D list filled with -inf + + # we shift for causal models in the trainer, so start the range from 0 + for _ in range(0, input_padding_len): + target_logprobs.append([-float("inf")] * top_k) + target_token_ids.append(list(range(top_k))) + target_mask.append([0] * top_k) + + for position in range(input_padding_len, input_seq_len): + if sample["labels"][position] == -100: + target_mask.append([0] * top_k) + else: + target_mask.append([1] * top_k) + + for token_pos_logprobs, pos_target_token_ids in zip( + logprobs, sample["target_token_ids"] + ): + # Convert to a tensor for easier manipulation + position_logprobs_tensor = torch.tensor( + token_pos_logprobs, dtype=torch.float + ) + + # Now we have distribution at T1 in log form, i.e. log p_{T1}(k). 
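+            # (T1 is the generation temperature the teacher sampled with; T2 is
+            # the KD training temperature. e.g. T1=1.0, T2=2.0 gives exponent
+            # 0.5 below: each probability is square-rooted, then re-normalized,
+            # flattening the teacher distribution.)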
+ # Next, re-scale to T2 = self.kd_temperature via exponent-based trick + # p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z + # + # Convert from log to probability + teacher_probs_t1 = position_logprobs_tensor.exp() + # normalize probabilities to sum to 1 in case they aren't already + teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True) + if teacher_probs_t1_sum > 1e-9: + teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum + if self.kd_temperature != self.gen_temperature: + # Exponentiate by factor (T1 / T2) + exponent = self.gen_temperature / self.kd_temperature + teacher_probs_t2 = teacher_probs_t1**exponent + else: + teacher_probs_t2 = teacher_probs_t1 + # Re-normalize + teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum( + dim=0, keepdim=True + ) + # Convert back to log + position_logprobs_tensor = torch.log(teacher_probs_t2) + + # Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor + position_logprobs_scaled = position_logprobs_tensor.tolist() + + target_logprobs.append(position_logprobs_scaled) + target_token_ids.append(pos_target_token_ids) + + # Update sample with transformed logprobs + sample["target_logprobs"] = target_logprobs + sample["target_token_ids"] = target_token_ids + sample["target_mask"] = target_mask + + return sample + + def _tokenize_single_prompt(self, prompt): + logprobs = prompt.pop(self.logprobs_field) + target_token_ids = prompt.pop("target_token_ids") + tokenized_prompt = super()._tokenize_single_prompt(prompt) + tokenized_prompt[self.logprobs_field] = logprobs + tokenized_prompt["target_token_ids"] = target_token_ids + tokenized_prompt = self.transform_logprobs(tokenized_prompt) + + return tokenized_prompt + + class KDStrategyLoader(StrategyLoader): """ Load ChatTemplateStrategy with KD support using StrategyLoader. @@ -204,4 +314,14 @@ class KDStrategyLoader(StrategyLoader): return strategy_params -load = KDStrategyLoader() +class KDStrategyLoaderV2(KDStrategyLoader): + """ + Load KD chat template datasets with pre-tokenized logprob data + """ + + def _get_strategy_cls(self, cfg): # pylint: disable=unused-argument + return ChatTemplateStrategyWithKDv2 + + +load_legacy = KDStrategyLoader() +load = KDStrategyLoaderV2() diff --git a/src/axolotl/integrations/kd/collator.py b/src/axolotl/integrations/kd/collator.py index de63869c7..0cc745b78 100644 --- a/src/axolotl/integrations/kd/collator.py +++ b/src/axolotl/integrations/kd/collator.py @@ -47,11 +47,16 @@ class DataCollatorForKD(DataCollatorForSeq2Seq): position_pad_token_id: int = 0 return_tensors: str = "pt" + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True + def __call__(self, features, return_tensors=None): if return_tensors is None: return_tensors = self.return_tensors padding_side = self.tokenizer.padding_side + max_len = 0 # Pad labels and position_ids first for feature_name, pad_token_id in [ @@ -102,7 +107,9 @@ class DataCollatorForKD(DataCollatorForSeq2Seq): target_mask_list.append(f.pop("target_mask")) # Determine max lengths - max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list) + max_teacher_seq_len = max_len or max( + len(seq) for seq in target_logprobs_list + ) max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq) padded_target_logprobs = [] @@ -209,7 +216,9 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD): # We want to produce a single "merged" feature dict for each sub-batch. 
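        # e.g. features = [[seq_a, seq_b], [seq_c]] yields two packed rows; the
        # per-field numpy arrays from each sub-sequence are concatenated end to end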
out_features = [{} for _ in features] - for i, sub_features in enumerate(features): + for i, sub_features in enumerate( # pylint: disable=too-many-nested-blocks + features + ): # sub_features is a list of dicts, each dict = one sequence’s features # We'll merge them into out_features[i]. # @@ -243,10 +252,17 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD): # For example, input_ids or labels are often arrays. arrays = [] for feat in sub_features: - if field_name in feat: + if field_name in feat and isinstance( + feat[field_name], (list, torch.Tensor) + ): + if isinstance( + feat[field_name][0], (dict, str) + ): # pylint: disable=too-many-nested-blocks + continue arr = np.array(feat[field_name]) arrays.append(arr) - out_features[i][field_name] = np.concatenate(arrays) + if arrays: + out_features[i][field_name] = np.concatenate(arrays) # 3) Now call the parent collator, which will do: # - padding of labels/position_ids diff --git a/src/axolotl/integrations/kd/collator_online_teacher.py b/src/axolotl/integrations/kd/collator_online_teacher.py new file mode 100644 index 000000000..584ace481 --- /dev/null +++ b/src/axolotl/integrations/kd/collator_online_teacher.py @@ -0,0 +1,561 @@ +""" +Packed data loader for online teacher training supporting vllm and sglang. +""" + +import hashlib +import hmac +import logging +from typing import Any, Dict, List, Optional + +import requests +import torch +from orjson import orjson + +from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq +from axolotl.integrations.kd.utils import normalize_logprobs +from axolotl.utils.data.utils import retry_on_request_exceptions + +LOG = logging.getLogger(__name__) + + +def hmac_sha_from_int_list(int_list, key, hash_func=hashlib.sha256): + """ + Create HMAC-SHA hash from a list of integers + + Args: + int_list: List of integers + key: Secret key (string or bytes) + hash_func: Hash function (default: sha256) + + Returns: + HMAC digest as hex string + """ + # Convert key to bytes if it's a string + if isinstance(key, str): + key = key.encode("utf-8") + + # Convert list of ints to bytes + # Method 1: Convert each int to bytes and concatenate + data = b"".join(i.to_bytes(4, byteorder="big") for i in int_list) + + # Create HMAC + h = hmac.new(key, data, hash_func) + return h.hexdigest() + + +class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq): + """ + Collator for online teacher training. 
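+
+    For each batch, posts the packed input_ids to the teacher server (vllm or
+    sglang), converts the returned top-k prompt logprobs into target_token_ids,
+    target_logprobs, and target_mask in the same format the offline KD
+    collators expect.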
+ """ + + DEFAULT_LABEL_PAD_TOKEN_ID: int = -100 + + def __init__( + self, + *args: Any, + kd_online_server_base_url: Optional[str] = None, + kd_online_topk: Optional[int] = None, + kd_temperature: Optional[float] = 1.0, + kd_online_server: Optional[str] = "vllm", + kd_online_timeout: Optional[int] = 120, + kd_cache_dir: Optional[str] = None, + kd_normalize_topk: Optional[bool] = True, + **kwargs: Any, + ): + super().__init__(*args, **kwargs) + + if kd_online_server_base_url is None: + raise ValueError( + "kd_online_server_base_url must be provided for OnlineTeacherDataloader" + ) + if kd_online_topk is None or kd_online_topk <= 0: + raise ValueError( + "kd_online_topk must be a positive integer for OnlineTeacherDataloader" + ) + + self.kd_online_server_base_url = kd_online_server_base_url.rstrip("/") + self.kd_online_topk = kd_online_topk + self.kd_temperature = kd_temperature + self.kd_online_server = kd_online_server + self.http_session = requests.Session() + self.kd_online_timeout = kd_online_timeout + self.kd_cache_dir = kd_cache_dir + self.kd_normalize_topk = kd_normalize_topk + + def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]: + """ + Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs. + """ + if not raw_logprobs or self.kd_online_topk == 0: + return ( + [-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else [] + ) + + raw_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32) + return normalize_logprobs(raw_logprobs_tensor, self.kd_online_topk).tolist() + + @retry_on_request_exceptions(max_retries=10, delay=5) + def fetch_online_logprobs_sglang( + self, batch_input_ids: List[List[int]], labels: List[List[int]] + ): + """ + Fetches logprobs from an online teacher served by sglang for a batch of input_ids. + Assumes API returns token IDs as strings in logprob dictionary keys. + """ + api_endpoint = f"{self.kd_online_server_base_url}/generate" + + payload = { + "input_ids": batch_input_ids, + "return_logprob": True, + "top_logprobs_num": self.kd_online_topk, + "logprob_start_len": 0, + "return_text_in_logprobs": True, + "echo": True, + "sampling_params": { + "max_new_tokens": 0, + "temperature": self.kd_temperature, + "skip_special_tokens": False, + }, + } + + # Initialize with empty lists, so if API call fails, these are returned. + ret_data_target_token_ids: List[List[List[int]]] = [] + ret_data_target_logprobs: List[List[List[float]]] = [] + ret_data_target_mask: List[List[List[int]]] = [] + + try: + response = self.http_session.post( + api_endpoint, json=payload, timeout=self.kd_online_timeout + ) + response.raise_for_status() + api_data: list[dict] = response.json() + + # Ensure api_data is a list, and its length matches batch_input_ids + if not isinstance(api_data, list) or len(api_data) != len(batch_input_ids): + LOG.error( + f"API response format error. Expected a list of {len(batch_input_ids)} " + f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}." 
+ ) + # Return empty data; items processed later will get default empty KD fields + return { + "target_token_ids": ret_data_target_token_ids, + "target_logprobs": ret_data_target_logprobs, + "target_mask": ret_data_target_mask, + } + + for sequence_data, seq_input_ids, seq_labels in zip( + api_data, batch_input_ids, labels + ): + current_target_logprobs = [] + current_target_token_ids = [] + current_target_mask = [] + + meta_info = sequence_data.pop("meta_info", {}) + # Ensure input_top_logprobs is a list + input_top_logprobs: Optional[list[None | list[tuple]]] = meta_info.pop( + "input_top_logprobs", [] + ) + if not isinstance(input_top_logprobs, list): + LOG.warning( + f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence." + ) + input_top_logprobs = [] # Treat as empty + + # basic check that the logprob data len matches the input len, so no need to handle padding + assert len(seq_input_ids) == len(input_top_logprobs) + + for i, _, label in zip( + range(len(seq_input_ids)), seq_input_ids, seq_labels + ): + if i < len(input_top_logprobs) and input_top_logprobs[i] is None: + # this is always the case for the first token. + # there is never logprob data for the first token since that's a true input + # so we replace the None value with padding data + current_target_logprobs.append( + [-float("inf")] * self.kd_online_topk + ) + current_target_token_ids.append([0] * self.kd_online_topk) + current_target_mask.append([0] * self.kd_online_topk) + elif ( + i < len(input_top_logprobs) + and input_top_logprobs[i] is not None + ): + pos_top_logprobs_data = input_top_logprobs[i] + # Ensure pos_top_logprobs_data is a list of lists as expected + if not ( + isinstance(pos_top_logprobs_data, list) + and all( + isinstance(item, list) for item in pos_top_logprobs_data + ) + and len(pos_top_logprobs_data) > 0 + and len(pos_top_logprobs_data[0]) == 3 + ): # [logprob, token_id, token_str] + LOG.warning( + f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position." 
+ ) + current_target_logprobs.append( + [-float("inf")] * self.kd_online_topk + ) + current_target_token_ids.append([0] * self.kd_online_topk) + current_target_mask.append([0] * self.kd_online_topk) + continue + + # pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids + pos_logprobs_raw, pos_token_ids, _ = [ + list(row) for row in zip(*pos_top_logprobs_data) + ] + + # Ensure correct length (top_k) + if len(pos_logprobs_raw) < self.kd_online_topk: + pad_len = self.kd_online_topk - len(pos_logprobs_raw) + pos_logprobs_raw.extend([-float("inf")] * pad_len) + pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id + + # truncate to top_k in case the response was longer + current_target_token_ids.append( + pos_token_ids[: self.kd_online_topk] + ) + + if self.kd_normalize_topk: + normalized_logprobs_for_position = self._normalize_logprobs( + pos_logprobs_raw[: self.kd_online_topk] + ) + current_target_logprobs.append( + normalized_logprobs_for_position + ) + else: + current_target_logprobs.append( + pos_logprobs_raw[: self.kd_online_topk] + ) + + # Mask depends on the corresponding label for the student + if label == self.DEFAULT_LABEL_PAD_TOKEN_ID: + current_target_mask.append([0] * self.kd_online_topk) + else: + current_target_mask.append([1] * self.kd_online_topk) + else: + # Pad if no logprobs for this position (either due to length mismatch or None entry) + current_target_logprobs.append( + [-float("inf")] * self.kd_online_topk + ) + current_target_token_ids.append([0] * self.kd_online_topk) + current_target_mask.append([0] * self.kd_online_topk) + + ret_data_target_token_ids.append(current_target_token_ids) + ret_data_target_logprobs.append(current_target_logprobs) + ret_data_target_mask.append(current_target_mask) + + except requests.exceptions.RequestException as e: + LOG.error(f"Error fetching logprobs from online teacher: {e}") + raise e + # ret_logprobs_data will be returned with empty lists, handled by the caller. + except Exception as e: # Catch other potential errors during processing + LOG.error( + f"Unexpected error processing API response in fetch_online_logprobs: {e}", + exc_info=True, + ) + raise e + + return { + "target_token_ids": ret_data_target_token_ids, + "target_logprobs": ret_data_target_logprobs, + "target_mask": ret_data_target_mask, + } + + @retry_on_request_exceptions(max_retries=10, delay=5) + def fetch_online_logprobs_vllm( + self, batch_input_ids: List[List[int]], labels: List[List[int]] + ): + """ + Fetches logprobs from an online teacher served by vllm for a batch of input_ids. + Assumes API returns token IDs as strings in logprob dictionary keys. + """ + api_endpoint = f"{self.kd_online_server_base_url}/v1/completions" + + payload = { + "prompt": batch_input_ids, + "echo": True, + "logprobs": True, + "prompt_logprobs": self.kd_online_topk, + "top_logprobs": self.kd_online_topk, + "max_new_tokens": 0, + "skip_special_tokens": False, + "temperature": self.kd_temperature, + "sampling_params": { + "max_tokens": 0, + }, + } + + # Initialize with empty lists, so if API call fails, these are returned. 
+ ret_data_target_token_ids: List[List[List[int]]] = []
+ ret_data_target_logprobs: List[List[List[float]]] = []
+ ret_data_target_mask: List[List[List[int]]] = []
+
+ try:
+ headers = {"Accept-Encoding": "deflate, gzip, br, zstd"}
+ response = self.http_session.post(
+ api_endpoint,
+ json=payload,
+ headers=headers,
+ timeout=self.kd_online_timeout,
+ )
+ response.raise_for_status()
+ api_data: dict = orjson.loads(response.content)
+ choices: list[dict] = api_data["choices"]
+
+ # Ensure choices is a list and its length matches batch_input_ids
+ if not isinstance(choices, list) or len(choices) != len(batch_input_ids):
+ LOG.error(
+ f"API response format error. Expected a list of {len(batch_input_ids)} "
+ f"items, got {type(choices)} with length {len(choices) if isinstance(choices, list) else 'N/A'}."
+ )
+ # Return empty data; items processed later will get default empty KD fields
+ return {
+ "target_token_ids": ret_data_target_token_ids,
+ "target_logprobs": ret_data_target_logprobs,
+ "target_mask": ret_data_target_mask,
+ }
+
+ for sequence_data, seq_input_ids, seq_labels in zip(
+ choices, batch_input_ids, labels
+ ):
+ # seq_input_ids: List[int]
+ # seq_labels: List[int]
+
+ current_target_logprobs = []
+ current_target_token_ids = []
+ current_target_mask = []
+
+ # Ensure input_top_logprobs is a list
+ input_top_logprobs: Optional[list[None | dict[str, dict]]] = (
+ sequence_data.pop("prompt_logprobs", [])
+ )
+
+ if not isinstance(input_top_logprobs, list):
+ LOG.warning(
+ f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
+ )
+ input_top_logprobs = [] # Treat as empty
+
+ # basic check that the logprob data len matches the input len, so no need to handle padding
+ assert len(seq_input_ids) == len(input_top_logprobs)
+
+ seq_len = len(seq_input_ids)
+
+ for i, _, label in zip(range(seq_len), seq_input_ids, seq_labels):
+ if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
+ # this is always the case for the first token.
+ # there is never logprob data for the first token since that's a true input
+ continue
+ if (
+ i < len(input_top_logprobs)
+ and input_top_logprobs[i] is not None
+ ):
+ pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i] # type: ignore[assignment]
+ # Ensure pos_top_logprobs_data is a dict of dicts as expected
+ if not (
+ isinstance(pos_top_logprobs_data, dict)
+ and all(
+ isinstance(item, dict)
+ for item in pos_top_logprobs_data.values()
+ )
+ and len(pos_top_logprobs_data.keys()) > 0
+ ): # values are dicts carrying a "logprob" field
+ LOG.warning(
+ f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position."
+ )
+ current_target_logprobs.append(
+ [-float("inf")] * self.kd_online_topk
+ )
+ current_target_token_ids.append(
+ list(range(self.kd_online_topk))
+ )
+ current_target_mask.append([0] * self.kd_online_topk)
+ continue
+
+ # keys are token ids (as strings); values are dicts holding the logprob
+ pos_token_ids_str = list(pos_top_logprobs_data.keys())
+ pos_logprobs_dict = pos_top_logprobs_data.values()
+ pos_token_ids = [
+ int(token_id) for token_id in pos_token_ids_str
+ ]
+ pos_logprobs_raw = [
+ float(logprob.get("logprob", -float("inf")))
+ for logprob in pos_logprobs_dict
+ ]
+
+ # Ensure correct length (top_k)
+ if len(pos_logprobs_raw) < self.kd_online_topk:
+ pad_len = self.kd_online_topk - len(pos_logprobs_raw)
+ LOG.warning(
+ f"Padding position {i} with {pad_len} top-k tokens and logprobs."
+ )
+ pos_logprobs_raw.extend([-float("inf")] * pad_len)
+ pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
+
+ # truncate to top_k in case the response was longer
+ current_target_token_ids.append(
+ pos_token_ids[: self.kd_online_topk]
+ )
+
+ if self.kd_normalize_topk:
+ normalized_logprobs_for_position = self._normalize_logprobs(
+ pos_logprobs_raw[: self.kd_online_topk]
+ )
+ current_target_logprobs.append(
+ normalized_logprobs_for_position
+ )
+ else:
+ current_target_logprobs.append(
+ pos_logprobs_raw[: self.kd_online_topk]
+ )
+
+ # Mask depends on the corresponding label for the student
+ if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
+ current_target_mask.append([0] * self.kd_online_topk)
+ else:
+ current_target_mask.append([1] * self.kd_online_topk)
+ else:
+ # Pad if no logprobs for this position (either due to length mismatch or None entry)
+ current_target_logprobs.append(
+ [-float("inf")] * self.kd_online_topk
+ )
+ current_target_token_ids.append(
+ list(range(self.kd_online_topk))
+ )
+ current_target_mask.append([0] * self.kd_online_topk)
+ for _ in range(max(0, seq_len - len(current_target_logprobs))):
+ current_target_logprobs.append(
+ [-float("inf")] * self.kd_online_topk
+ )
+ current_target_token_ids.append(list(range(self.kd_online_topk)))
+ current_target_mask.append([0] * self.kd_online_topk)
+
+ ret_data_target_token_ids.append(current_target_token_ids)
+ ret_data_target_logprobs.append(current_target_logprobs)
+ ret_data_target_mask.append(current_target_mask)
+
+ # TODO save and load targets to disk for caching for next epoch
+ # generate a hmac SHA256 hash over the list seq_input_ids and convert it to an int
+ # if self.kd_cache_dir:
+ # hash_input_ids = hmac_sha_from_int_list(
+ # seq_input_ids, f"{self.kd_online_server_base_url}:{self.kd_online_topk}"
+ # )
+ # with open(f"{self.kd_cache_dir}/{hash_input_ids}.parquet", "wb") as f:
+ # pd.DataFrame(ret_logprobs_data).to_parquet(f, index=False)
+
+ except requests.exceptions.RequestException as e:
+ LOG.error(f"Error fetching logprobs from online teacher: {e}")
+ raise e
+ # Re-raised so the retry decorator can retry transient request failures.
+ except Exception as e: # Catch other potential errors during processing
+ LOG.error(
+ f"Unexpected error processing API response in fetch_online_logprobs: {e}",
+ exc_info=True,
+ )
+ raise e
+
+ return {
+ "target_token_ids": ret_data_target_token_ids,
+ "target_logprobs": ret_data_target_logprobs,
+ "target_mask": ret_data_target_mask,
+ }
+
+ def __call__(
+ self, features: List[List[Dict[str, Any]]], return_tensors: Optional[str] = None
+ ) -> Dict[str, Any]:
+ if not features:
+ return super().__call__(features, return_tensors=return_tensors)
+
+ for (
+ sub_batch_features
+ ) in features: # sub_batch_features is List[Dict[str, Any]]
+ if not sub_batch_features:
+ continue
+
+ input_ids_for_api_call: List[List[int]] = []
+ labels_for_api_call: List[List[int]] = []
+ # Store references to the original item dictionaries to update them in-place
+ items_for_api_call: List[Dict[str, Any]] = []
+
+ for item_dict in sub_batch_features:
+ if not isinstance(item_dict, dict):
+ LOG.warning(
+ f"Skipping non-dict item in sub_batch_features: {item_dict}"
+ )
+ continue
+
+ current_input_ids = item_dict.get("input_ids")
+ current_labels = item_dict.get("labels")
+
+ if current_input_ids is not None and current_labels is not None:
+ # Ensure input_ids and labels are lists of ints for JSON serialization
+ input_ids_list = (
+ current_input_ids.tolist()
+ if hasattr(current_input_ids, "tolist")
+ else list(current_input_ids)
+ )
+ labels_list = (
+ current_labels.tolist()
+ if hasattr(current_labels, "tolist")
+ else list(current_labels)
+ )
+
+ input_ids_for_api_call.append(input_ids_list)
+ labels_for_api_call.append(labels_list)
+ items_for_api_call.append(item_dict)
+ else:
+ # This item will not get teacher logprobs from the API.
+ # Initialize KD fields to empty lists so downstream collators handle them uniformly.
+ item_dict.setdefault("target_token_ids", [])
+ item_dict.setdefault("target_logprobs", [])
+ item_dict.setdefault("target_mask", [])
+
+ if items_for_api_call: # Only call API if there's something to process
+ if self.kd_online_server == "sglang":
+ api_responses_for_sub_batch = self.fetch_online_logprobs_sglang(
+ input_ids_for_api_call, labels_for_api_call
+ )
+ else:
+ api_responses_for_sub_batch = self.fetch_online_logprobs_vllm(
+ input_ids_for_api_call, labels_for_api_call
+ )
+
+ # api_responses_for_sub_batch has keys: "target_token_ids", "target_logprobs", "target_mask"
+ # Each value is a list, corresponding to items_for_api_call
+ for i, item_to_update in enumerate(items_for_api_call):
+ # TODO: map each updated item back to its position in the original
+ # `features` object so the superclass can process the batch properly.
+ if api_responses_for_sub_batch and i < len(
+ api_responses_for_sub_batch["target_token_ids"]
+ ): # Check bounds
+ assert len(
+ api_responses_for_sub_batch["target_token_ids"][i]
+ ) == len(item_to_update["input_ids"])
+ assert len(
+ api_responses_for_sub_batch["target_logprobs"][i]
+ ) == len(item_to_update["input_ids"])
+ assert len(
+ api_responses_for_sub_batch["target_mask"][i]
+ ) == len(item_to_update["labels"])
+ item_to_update["target_token_ids"] = (
+ api_responses_for_sub_batch["target_token_ids"][i]
+ )
+ item_to_update["target_logprobs"] = api_responses_for_sub_batch[
+ "target_logprobs"
+ ][i]
+ item_to_update["target_mask"] = api_responses_for_sub_batch[
+ "target_mask"
+ ][i]
+ else:
+ # API call failed for this item, or response was shorter than expected.
+ # Ensure KD fields are initialized as empty lists.
+ LOG.warning(
+ f"No teacher logprobs returned for item (index {i}), or API response was too short. "
+ f"API response keys: {list(api_responses_for_sub_batch.keys()) if api_responses_for_sub_batch else 'None'}"
+ )
+ item_to_update.setdefault("target_token_ids", [])
+ item_to_update.setdefault("target_logprobs", [])
+ item_to_update.setdefault("target_mask", [])
+
+ return super().__call__(features, return_tensors=return_tensors)
diff --git a/src/axolotl/integrations/kd/kernels/__init__.py b/src/axolotl/integrations/kd/kernels/__init__.py
index e69de29bb..3f1144a45 100644
--- a/src/axolotl/integrations/kd/kernels/__init__.py
+++ b/src/axolotl/integrations/kd/kernels/__init__.py
@@ -0,0 +1,8 @@
+"""
+Liger Chunked loss optimizations module
+"""
+
+from .liger import LigerFusedLinearKLTopKLogprobLoss
+from .models import apply_kernel
+
+__all__ = ["LigerFusedLinearKLTopKLogprobLoss", "apply_kernel"]
diff --git a/src/axolotl/integrations/kd/kernels/liger.py b/src/axolotl/integrations/kd/kernels/liger.py
new file mode 100644
index 000000000..6356643c2
--- /dev/null
+++ b/src/axolotl/integrations/kd/kernels/liger.py
@@ -0,0 +1,485 @@
+"""
+Liger Kernels for Chunked Top-K Log-Prob Distillation
+"""
+
+import torch
+import torch.nn.functional as F
+from liger_kernel.chunked_loss.fused_linear_distillation import (
+ LigerFusedLinearDistillationBase,
+)
+
+from axolotl.integrations.kd.utils import normalize_logprobs
+
+
+class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
+ """
+ Chunked KL-divergence loss for top-k logprobs
+ """
+
+ @staticmethod
+ def distillation_loss_fn(
+ student_logits_temp_scaled: torch.Tensor, # [chunk_size, vocab_size], already temp-scaled
+ target_token_ids_chunk: torch.Tensor, # [chunk_size, top_k]
+ target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs
+ target_mask_chunk: torch.Tensor, # [chunk_size, top_k]
+ beta: float = 0.0,
+ normalize_topk: bool = True,
+ ) -> torch.Tensor:
+ """
+ Compute Top-K KL divergence loss for a chunk.
+ Args:
+ student_logits_temp_scaled: Student logits, scaled by temperature. Shape: (N, V).
+ target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K).
+ target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K).
+ target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K).
+ beta: Controls the type of KL divergence.
+ 0.0 for Forward KL (P_teacher || P_student).
+ 1.0 for Reverse KL (P_student || P_teacher).
+ 0.5 for Symmetric KL (average of Forward and Reverse).
+ normalize_topk: Whether to normalize the log probabilities
+ Returns:
+ Sum of KL divergence losses for the chunk.
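+ Note:
+ For 0.0 < beta < 1.0, the branch below computes a generalized
+ Jensen-Shannon divergence over the teacher's top-k support, mixing
+ probabilities as (1 - beta) * student + beta * teacher.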
+ """ + topk = target_token_ids_chunk.shape[-1] + student_logits_temp_scaled = ( # [chunk_size, vocab_size] + student_logits_temp_scaled.float() + ) + target_logprobs_chunk = target_logprobs_chunk.float() + + # Gather student logits for the top-k teacher token IDs + # target_token_ids_chunk: [chunk_size, top_k] + # student_logits_topk_temp_scaled: [chunk_size, top_k] + student_logits_topk_temp_scaled = torch.gather( + student_logits_temp_scaled, dim=-1, index=target_token_ids_chunk + ) + + # Student log-probabilities for the gathered top-k tokens + student_lse = torch.logsumexp( + student_logits_temp_scaled, dim=-1, keepdim=True + ) # [chunk_size, 1] + student_logprobs_topk_temp_scaled = ( + student_logits_topk_temp_scaled - student_lse + ) + + # we have the top-k student logprobs, normalize them + if normalize_topk: + student_logprobs_topk_temp_scaled = normalize_logprobs( + student_logprobs_topk_temp_scaled, topk + ) + + valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k] + + student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask] + teacher_logprobs_valid = target_logprobs_chunk[valid_mask] + + # Teacher probabilities P(y|x_teacher) from logprobs + # target_logprobs_valid are already normalized (log(softmax(teacher_logits/T))) + teacher_probs_valid = teacher_logprobs_valid.exp() + # Student probabilities P_student from log P_student + student_probs_topk_valid = student_logprobs_topk_valid.exp() + + # kd_loss_per_token = torch.zeros_like(target_logprobs_valid) + + # KL divergence: sum(P_teacher * (log P_teacher - log P_student)) + # = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student) + # The distillation loss is often formulated as -sum(P_teacher * log P_student) + # or as sum(P_teacher * (log_softmax_teacher - log_softmax_student)) + # Here, target_logprobs_valid are log_softmax_teacher. + # student_logprobs_topk_valid are log_softmax_student (for the selected K indices). + if beta == 0.0: # Contribution from Forward KL + fwd_kl_per_token = teacher_probs_valid * ( + teacher_logprobs_valid - student_logprobs_topk_valid + ) + kd_loss = fwd_kl_per_token.sum() + elif beta == 1.0: # Contribution from Reverse KL + rev_kl_per_token = student_probs_topk_valid * ( + student_logprobs_topk_valid - teacher_logprobs_valid + ) + kd_loss = rev_kl_per_token.sum() + else: + # JSD - Jensen-Shannon Divergence / Symmetric + mean_probs = ( + 1 - beta + ) * student_probs_topk_valid + beta * teacher_probs_valid + log_mean_probs = mean_probs.log() + student_kl = F.kl_div( + log_mean_probs, + student_logprobs_topk_valid, + reduction="sum", + log_target=True, + ) + teacher_kl = F.kl_div( + log_mean_probs, teacher_logprobs_valid, reduction="sum", log_target=True + ) + jsd_loss = beta * teacher_kl + (1 - beta) * student_kl + kd_loss = jsd_loss + + return kd_loss + + @staticmethod + def _compute_loss_kl_topk( + student_input_chunk: torch.Tensor, + student_weight: torch.Tensor, + # Args for student_bias, target_token_ids_chunk etc. are passed to the lambda wrapped by grad_and_value + # or through `partial`. Let's make them explicit here for clarity. 
+ target_token_ids_chunk: torch.Tensor,
+ target_logprobs_chunk: torch.Tensor,
+ target_mask_chunk: torch.Tensor,
+ target_chunk: torch.Tensor, # For hard loss (true labels)
+ student_bias: torch.Tensor = None, # This will be one of the grad targets
+ # Other params passed via `partial` from `forward`
+ distillation_loss_fn=None,
+ ignore_index: int = -100,
+ weight_hard_loss: float = 0.5,
+ weight_soft_loss: float = 0.5,
+ compute_ce_loss: bool = True,
+ temperature: float = 1.0,
+ beta: float = 0.0,
+ normalize_topk: bool = True,
+ ):
+ # Compute student logits for the chunk from hidden states and LM head
+ # student_input_chunk: [chunk_size, hidden_dim]
+ # student_weight: [vocab_size, hidden_dim]
+ # student_logits_chunk: [chunk_size, vocab_size]
+ student_logits_chunk = F.linear(
+ student_input_chunk, student_weight, student_bias
+ )
+
+ ce_loss = torch.tensor(
+ 0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype
+ )
+ if compute_ce_loss and weight_hard_loss > 0.0:
+ ce_loss = F.cross_entropy(
+ student_logits_chunk.view(-1, student_logits_chunk.shape[-1]),
+ target_chunk.view(-1),
+ reduction="sum",
+ ignore_index=ignore_index,
+ )
+
+ soft_loss = torch.tensor(
+ 0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype
+ )
+ if weight_soft_loss > 0.0:
+ student_logits_chunk_temp_scaled = student_logits_chunk / temperature
+
+ # Assuming student_weight.shape[0] (vocab_size) is adequate for target_token_ids_chunk.max()
+ # No explicit padding here; user must ensure vocab alignment or pre-pad student_weight.
+
+ soft_loss = distillation_loss_fn(
+ student_logits_chunk_temp_scaled,
+ target_token_ids_chunk,
+ target_logprobs_chunk,
+ target_mask_chunk,
+ beta=beta,
+ normalize_topk=normalize_topk,
+ )
+
+ return soft_loss, ce_loss
+
+ @classmethod
+ def forward(
+ cls,
+ ctx,
+ student_input: torch.Tensor, # [batch_size, seq_len, dim]
+ student_lm_head_weight: torch.Tensor, # [vocab_size, dim]
+ target_token_ids: torch.Tensor, # [batch_size, seq_len, top_k]
+ target_logprobs: torch.Tensor, # [batch_size, seq_len, top_k]
+ target_mask: torch.Tensor, # [batch_size, seq_len, top_k]
+ true_labels: torch.Tensor, # [batch_size, seq_len]
+ student_lm_head_bias: torch.Tensor = None,
+ weight_hard_loss: float = 0.5,
+ weight_soft_loss: float = 0.5,
+ ignore_index: int = -100,
+ temperature: float = 1.0,
+ beta: float = 0.0,
+ compiled: bool = False,
+ chunk_size: int = 1024,
+ compute_ce_loss: bool = True,
+ normalize_topk: bool = True,
+ ):
+ CHUNK_SIZE = chunk_size # pylint: disable=invalid-name
+ grad_weight_acc = torch.zeros_like(student_lm_head_weight)
+ grad_inputs_list = []
+ grad_bias_acc = (
+ torch.zeros_like(student_lm_head_bias)
+ if student_lm_head_bias is not None
+ else None
+ )
+ kd_loss_acc = torch.zeros(
+ (), device=student_input.device, dtype=student_input.dtype
+ )
+ ce_loss_acc = torch.zeros(
+ (), device=student_input.device, dtype=student_input.dtype
+ )
+
+ # This function will be what torch.func.grad_and_value differentiates.
+ # It takes student_input_chunk, student_weight (full), student_bias (full) as primals.
+ # Other necessary data (target_*, etc.) are passed as non-differentiable arguments.
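+ # torch.func.grad_and_value(..., has_aux=True) differentiates only the first
+ # returned value (the soft KD loss); the CE loss comes back as the auxiliary
+ # output alongside it.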
+ def loss_fn_for_grad( + _student_input_chunk, + _student_lm_head_weight, # full weight + _student_lm_head_bias, # full bias + # Fixed arguments for a given chunk, not differentiated: + _target_token_ids_chunk, + _target_logprobs_chunk, + _target_mask_chunk, + _true_labels_chunk, + ): + return cls._compute_loss_kl_topk( + student_input_chunk=_student_input_chunk, + student_weight=_student_lm_head_weight, + target_token_ids_chunk=_target_token_ids_chunk, + target_logprobs_chunk=_target_logprobs_chunk, + target_mask_chunk=_target_mask_chunk, + target_chunk=_true_labels_chunk, + student_bias=_student_lm_head_bias, + distillation_loss_fn=cls.distillation_loss_fn, + ignore_index=ignore_index, + weight_hard_loss=weight_hard_loss, + weight_soft_loss=weight_soft_loss, + compute_ce_loss=compute_ce_loss, + temperature=temperature, + beta=beta, + normalize_topk=normalize_topk, + ) + + def accumulate_chunk_grads( + student_input_chunk_ac, + target_token_ids_chunk_ac, + target_logprobs_chunk_ac, + target_mask_chunk_ac, + true_labels_chunk_ac, + ): + # student_weight and student_bias are closed over from the outer scope (full tensors) + if student_lm_head_bias is not None: + ( + (chunk_grad_input, chunk_grad_weight, chunk_grad_bias), + (chunk_kd_loss, chunk_ce_loss), + ) = torch.func.grad_and_value( + loss_fn_for_grad, argnums=(0, 1, 2), has_aux=True + )( + student_input_chunk_ac, + student_lm_head_weight, + student_lm_head_bias, # primals + target_token_ids_chunk_ac, + target_logprobs_chunk_ac, + target_mask_chunk_ac, + true_labels_chunk_ac, + ) # non-primals + grad_bias_acc.add_(chunk_grad_bias) + else: + argnums_for_grad = (0, 1) # Differentiate wrt input_chunk, weight + ( + (chunk_grad_input, chunk_grad_weight), # No grad for bias + (chunk_kd_loss, chunk_ce_loss), + ) = torch.func.grad_and_value( + loss_fn_for_grad, argnums=argnums_for_grad, has_aux=True + )( + student_input_chunk_ac, + student_lm_head_weight, + None, # Pass None for student_bias primal + target_token_ids_chunk_ac, + target_logprobs_chunk_ac, + target_mask_chunk_ac, + true_labels_chunk_ac, + ) + + grad_weight_acc.add_(chunk_grad_weight) + kd_loss_acc.add_(chunk_kd_loss) + ce_loss_acc.add_(chunk_ce_loss) + + return chunk_grad_input + + if compiled: + accumulate_chunk_grads_compiled = torch.compile( + accumulate_chunk_grads, dynamic=True, backend="inductor" + ) # dynamic=True often helpful + else: + accumulate_chunk_grads_compiled = accumulate_chunk_grads + + # Use the same chunking logic as LigerFusedLinearDistillationBase.forward + B, N, D = student_input.shape # pylint: disable=invalid-name + K = target_token_ids.shape[-1] # pylint: disable=invalid-name + + student_input_flat = student_input.reshape(-1, student_input.shape[-1]) + target_token_ids_flat = target_token_ids.reshape(-1, target_token_ids.shape[-1]) + target_logprobs_flat = target_logprobs.reshape(-1, target_logprobs.shape[-1]) + target_mask_flat = target_mask.reshape(-1, target_mask.shape[-1]) + # pad and shift for cross entropy loss + true_labels = torch.nn.functional.pad(true_labels, (0, 1), value=ignore_index) + true_labels_flat = true_labels[:, 1:].contiguous().view(-1) + + num_chunks = max(1, student_input_flat.shape[0] // CHUNK_SIZE) + + _student_input_chunks = torch.chunk( + student_input_flat, chunks=num_chunks, dim=0 + ) + _target_token_ids_chunks = torch.chunk( + target_token_ids_flat, chunks=num_chunks, dim=0 + ) + _target_logprobs_chunks = torch.chunk( + target_logprobs_flat, chunks=num_chunks, dim=0 + ) + _target_mask_chunks = torch.chunk(target_mask_flat, 
chunks=num_chunks, dim=0) + _true_labels_chunks = torch.chunk(true_labels_flat, chunks=num_chunks, dim=0) + + for i in range(num_chunks): + grad_input_chunk = accumulate_chunk_grads_compiled( + _student_input_chunks[i], + _target_token_ids_chunks[i], + _target_logprobs_chunks[i], + _target_mask_chunks[i], + _true_labels_chunks[i], + ) + grad_inputs_list.append(grad_input_chunk) + + grad_inputs_combined = torch.cat(grad_inputs_list, dim=0) + ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc) + + # For matching None returns in backward for non-tensor/non-grad_requiring inputs + ctx.hyperparams_count = 9 # Corresponds to number of hyperparams after main tensors in fwd signature + ctx.bias_was_none = student_lm_head_bias is None + ctx.orig_dims = (B, N, D, K) + + # since this is packed, there is simply a single batch, so batchmean reduction of kl-div is simply the accumulated sum + # we still need to scale the kd_loss by the temp^2 + kd_loss_acc = kd_loss_acc * (temperature**2) + final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc + + return final_loss + + @staticmethod + def backward(ctx, grad_output): + grad_input_flat, grad_weight, grad_bias_maybe = ( + ctx.saved_tensors + ) # grad_input_flat is (B*N, D) + + # Scale gradients by grad_output if it's not 1.0 + if not torch.equal( + grad_output, + torch.tensor(1.0, device=grad_output.device, dtype=grad_output.dtype), + ): + grad_input_flat = grad_input_flat * grad_output + grad_weight = grad_weight * grad_output + if grad_bias_maybe is not None: + grad_bias_maybe = grad_bias_maybe * grad_output + + # Reshape grad_input_flat to match original student_input shape (B, N, D) + # ctx.orig_dims stores (B, N, D, K) + # We need the first three dimensions for student_input's shape. 
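+ # (K is unused here: the top-k teacher tensors get no gradients in backward.)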
+ # Ensure that orig_dims are not (0,0,0,K) for empty inputs leading to view errors
+ if (
+ ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0
+ and grad_input_flat.numel() == 0
+ ):
+ # If original input was empty, gradient should also be empty with correct shape
+ grad_input_reshaped = torch.zeros(
+ ctx.orig_dims[0],
+ ctx.orig_dims[1],
+ ctx.orig_dims[2],
+ dtype=grad_input_flat.dtype,
+ device=grad_input_flat.device,
+ )
+ elif grad_input_flat.numel() == 0 and not (
+ ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0
+ ):
+ # This case should ideally not happen if forward path is correct (non-empty input -> non-empty flat grad)
+ # but as a safeguard:
+ grad_input_reshaped = torch.zeros(
+ ctx.orig_dims[0],
+ ctx.orig_dims[1],
+ ctx.orig_dims[2],
+ dtype=grad_input_flat.dtype,
+ device=grad_input_flat.device,
+ )
+ else:
+ grad_input_reshaped = grad_input_flat.view(
+ ctx.orig_dims[0], ctx.orig_dims[1], ctx.orig_dims[2]
+ )
+
+ nones_for_hyperparams = [None] * ctx.hyperparams_count
+ grad_bias_return = grad_bias_maybe if not ctx.bias_was_none else None
+
+ return (
+ grad_input_reshaped, # Gradient for student_input (reshaped)
+ grad_weight, # Gradient for student_lm_head_weight
+ None, # Gradient for target_token_ids
+ None, # Gradient for target_logprobs
+ None, # Gradient for target_mask
+ None, # Gradient for true_labels
+ grad_bias_return, # Gradient for student_lm_head_bias
+ *nones_for_hyperparams, # Grads for weight_hard_loss, ..., normalize_topk
+ )
+
+
+class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
+ """
+ Wrapper for chunked top-k logprob KL-divergence loss
+ """
+
+ def __init__(
+ self,
+ weight_hard_loss: float = 0.5,
+ weight_soft_loss: float = 0.5,
+ temperature: float = 1.0, # This is the kd_temperature
+ beta: float = 1.0,
+ ignore_index: int = -100,
+ compiled: bool = True,
+ chunk_size: int = 1024,
+ compute_ce_loss: bool = True,
+ normalize_topk: bool = True,
+ ):
+ super().__init__()
+ if not (0.0 <= weight_hard_loss <= 1.0 and 0.0 <= weight_soft_loss <= 1.0):
+ raise ValueError("Loss weights must be between 0.0 and 1.0.")
+ if temperature <= 0:
+ raise ValueError("Temperature must be positive.")
+
+ self.weight_hard_loss = weight_hard_loss
+ self.weight_soft_loss = weight_soft_loss
+ self.temperature = temperature
+ self.beta = beta
+ self.ignore_index = ignore_index
+ self.compiled = compiled
+ self.chunk_size = chunk_size
+ self.compute_ce_loss = compute_ce_loss
+ self.normalize_topk = normalize_topk
+
+ if not self.compute_ce_loss and self.weight_hard_loss > 0.0:
+ print(
+ f"Warning: compute_ce_loss is False, but weight_hard_loss ({self.weight_hard_loss}) > 0. Hard loss will effectively be zero."
+ )
+ # self.weight_hard_loss = 0.0 # Or let user manage this
+ if self.weight_soft_loss == 0.0:
+ print(
+ "Warning: weight_soft_loss is 0.0. Soft (KD) loss will not be computed."
+ ) + + def forward( + self, + lm_head_weight: torch.Tensor, # Weights of the linear layer in the LM head + student_hidden_states: torch.Tensor, # student_hidden_states before the lm_head + target_token_ids: torch.Tensor, + target_logprobs: torch.Tensor, + target_mask: torch.Tensor, + true_labels: torch.Tensor, + student_bias: torch.Tensor = None, + ) -> torch.Tensor: + return LigerFusedLinearKLTopKLogprobFunction.apply( + student_hidden_states, + lm_head_weight, + target_token_ids, + target_logprobs, + target_mask, + true_labels, + student_bias, + self.weight_hard_loss, + self.weight_soft_loss, + self.ignore_index, + self.temperature, + self.beta, + self.compiled, + self.chunk_size, + self.compute_ce_loss, + self.normalize_topk, + ) diff --git a/src/axolotl/integrations/kd/kernels/models.py b/src/axolotl/integrations/kd/kernels/models.py new file mode 100644 index 000000000..bfd752964 --- /dev/null +++ b/src/axolotl/integrations/kd/kernels/models.py @@ -0,0 +1,98 @@ +""" +model patcher for chunked top-k kl-div +""" + +from types import MethodType +from typing import Optional, Union, Unpack + +import torch +from transformers import Cache +from transformers.modeling_flash_attention_utils import FlashAttentionKwargs +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.utils import LossKwargs + + +class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs): + """ + placeholder kwargs for hf model classes + """ + + +def kldiv_forward_llama_like( + self, + input_ids: Optional[torch.LongTensor] = None, + target_logprobs: Optional[torch.Tensor] = None, + target_token_ids: Optional[torch.LongTensor] = None, + target_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Cache] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + logits_to_keep: Union[int, torch.Tensor] = 0, # pylint: disable=unused-argument + **kwargs: Unpack[KwargsForCausalLM], # type: ignore[misc] +) -> CausalLMOutputWithPast: + # pylint: disable=duplicate-code + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + # TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100 + # self.loss_function should be LigerFusedLinearKLTopKLogprobLoss + + loss = self.loss_function( + self.lm_head.weight, + hidden_states, + target_token_ids, + target_logprobs, + target_mask, + true_labels=labels, + ) + num_items_in_batch = kwargs.pop("num_items_in_batch", -1) + if 
num_items_in_batch is not None and num_items_in_batch > 0: + loss = loss / num_items_in_batch + + return CausalLMOutputWithPast( + loss=loss, + logits=None, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + +def apply_kernel(model_type): + # Dynamically import the module and attention class + module_path = f"transformers.models.{model_type}.modeling_{model_type}" + model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")]) + module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"]) + model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM") + model_cls.forward = MethodType(kldiv_forward_llama_like, model_cls) diff --git a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py index 3c9515091..74184455f 100644 --- a/src/axolotl/integrations/kd/topk_logprob/forward_kl.py +++ b/src/axolotl/integrations/kd/topk_logprob/forward_kl.py @@ -16,40 +16,7 @@ loss for top_k KL divergence """ import torch - - -def zscore_standardize( - logits: torch.Tensor, - mask: torch.Tensor = None, - base_temperature: float = 1.0, - eps: float = 1e-9, -): - """ - Z-score standardize along the last dimension of `logits`. - i.e., for each [B, seq_len] row, across K entries: - z = (logits - mean) / std, - then scale by 1 / base_temperature if desired. - - mask can be broadcastable or None. If None, we standardize all elements. - """ - if mask is None: - # shape: [B, seq_len, K] - # Mean and std over dim=-1 - mean = logits.mean(dim=-1, keepdim=True) - var = logits.var(dim=-1, unbiased=False, keepdim=True) - else: - # If you have to exclude some tokens, multiply by mask, etc. - float_mask = mask.to(logits.dtype) - count = float_mask.sum(dim=-1, keepdim=True).clamp_min(1.0) - mean = (logits * float_mask).sum(dim=-1, keepdim=True) / count - var = (float_mask * (logits - mean) ** 2).sum(dim=-1, keepdim=True) / count - - std = torch.sqrt(var.clamp_min(eps)) - z = (logits - mean) / std - - # Scale by 1 / base_temperature - z = z / base_temperature - return z +from torch import nn @torch.jit.script @@ -60,7 +27,6 @@ def loss( target_mask: torch.Tensor, num_items_in_batch: int = -1, # Use -1 to indicate "None" kd_temperature: float = 1.0, - top_k_before_softmax: int = 0, ) -> torch.Tensor: """ A KD loss function that is TorchScript-friendly. @@ -77,8 +43,6 @@ def loss( num_items_in_batch (int, optional): The number of items in the batch. kd_temperature (float, optional): The temperature for KD. 
Default: 1.0 - top_k_before_softmax (int, optional): Flag of whether to apply softmax before gathering student top-k logits - Default: 0 """ target_logprobs = target_logprobs.float() @@ -88,46 +52,24 @@ def loss( # student_logits shape: [B, student_seq_len, vocab_size] teacher_seq_len = target_token_ids.shape[1] - if top_k_before_softmax: - # Slice student logits to match teacher-provided sequence length - student_logits_for_kd = student_logits[ - :, :teacher_seq_len, : - ] # [B, teacher_seq_len, vocab_size] + # Slice student logits to match teacher-provided sequence length + student_logits_for_kd = ( + student_logits[:, :teacher_seq_len, :] / kd_temperature + ) # [B, teacher_seq_len, vocab_size] - # Gather student logits for teacher's top-K tokens - student_logits_topk = torch.gather( - student_logits_for_kd, dim=-1, index=target_token_ids - ) # [B, teacher_seq_len, K] + # keep in full precision for numerical stability of loss + student_logits_for_kd = student_logits_for_kd.float() - student_logits_topk = student_logits_topk.float() + # Gather student logits for teacher's top-K tokens + student_logits_topk = torch.gather( + student_logits_for_kd, dim=-1, index=target_token_ids + ) # [B, teacher_seq_len, K] - # Apply KD temperature to student’s logits - if kd_temperature != 1.0: - student_logits_topk = student_logits_topk / kd_temperature + # Compute logsumexp across full vocabulary + student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True) - # Convert student top-k logits to logprobs - student_logprobs_topk = student_logits_topk - torch.logsumexp( - student_logits_topk, dim=-1, keepdim=True - ) # [B, teacher_seq_len, K] - else: - # Slice student logits to match teacher-provided sequence length - student_logits_for_kd = ( - student_logits[:, :teacher_seq_len, :] / kd_temperature - ) # [B, teacher_seq_len, vocab_size] - - # keep in full precision for numerical stability of loss - student_logits_for_kd = student_logits_for_kd.float() - - # Gather student logits for teacher's top-K tokens - student_logits_topk = torch.gather( - student_logits_for_kd, dim=-1, index=target_token_ids - ) # [B, teacher_seq_len, K] - - # Compute logsumexp across full vocabulary - student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True) - - # Convert just the top-k logits to logprobs - student_logprobs_topk = student_logits_topk - student_lse + # Convert just the top-k logits to logprobs + student_logprobs_topk = student_logits_topk - student_lse # Convert teacher_mask to boolean for indexing # In TorchScript, .bool() is sometimes unsupported, so we do: @@ -144,10 +86,6 @@ def loss( kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk) kd_loss = kd_loss_per_token.sum() - # Multiply by T^2 (classical KD scaling) - if kd_temperature != 1.0: - kd_loss = kd_loss * (kd_temperature**2) - # Normalize by number of items (if provided) or by valid tokens if num_items_in_batch > 0: kd_loss = kd_loss / float(num_items_in_batch) @@ -158,80 +96,74 @@ def loss( return kd_loss -def topk_kd_loss_with_zscore( - student_logits: torch.Tensor, # [B, seq_len, vocab_size] - target_token_ids: torch.Tensor, # [B, seq_len, K] - target_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space - target_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len] - kd_temperature: float = 1.0, # classic KD temperature - zscore_base_temp: float = 1.0, # from the paper - num_items_in_batch: int = -1, -): +class ChunkedTopKKDLoss(nn.Module): """ - A variant of top_k KL divergence 
with Z-score scaling - from "Logit Standardization in Knowledge Distillation". + A wrapper that chunks (splits) the student and teacher outputs along the time dimension + to reduce peak memory usage when upcasting from bf16 to fp32, especially for large vocabularies. + + Usage is analogous to ForwardKLWithChunkedOutputLoss but adapted to top-K teacher logprobs. """ - target_logprobs = target_logprobs.float() + def __init__(self, num_output_chunks: int = 8, kd_temperature: float = 1.0): + super().__init__() + self.num_output_chunks = num_output_chunks + self.kd_temperature = kd_temperature - B, teacher_seq_len, K = target_logprobs.shape # pylint: disable=invalid-name - # 1) Gather the student's top-k logits to match teacher - student_logits_for_kd = student_logits[ - :, :teacher_seq_len, : - ] # [B, seq_len, vocab] - student_topk_logits = torch.gather( - student_logits_for_kd, dim=-1, index=target_token_ids - ) # [B, seq_len, K] + def forward( + self, + student_logits: torch.Tensor, # [B, seq_len, vocab_size] + target_token_ids: torch.Tensor, # [B, seq_len, K] + target_logprobs: torch.Tensor, # [B, seq_len, K] + target_mask: torch.Tensor, # [B, seq_len, K] + num_items_in_batch: int = -1, # optional batch size for normalization + ) -> torch.Tensor: - student_topk_logits = student_topk_logits.float() + # 1. Split along the "token" dimension (dim=1). + student_logits_chunks = student_logits.chunk(self.num_output_chunks, dim=1) + token_ids_chunks = target_token_ids.chunk(self.num_output_chunks, dim=1) + logprobs_chunks = target_logprobs.chunk(self.num_output_chunks, dim=1) + mask_chunks = target_mask.chunk(self.num_output_chunks, dim=1) - # 2) If you want to keep the "classical" T scaling, apply it first - if kd_temperature != 1.0: - student_topk_logits = student_topk_logits / kd_temperature + # We'll accumulate a global "sum of losses" and "sum of valid tokens" + # so that our final average is consistent with the entire sequence/batch. + total_loss = 0.0 + total_valid_tokens = 0 - # 3) Convert teacher logprobs -> treat them as “logits” for z-score - # (They differ by +some_constant from real logits, but in z-score - # that constant is subtracted out anyway.) - teacher_logits_for_zscore = target_logprobs # rename variable for clarity + # 2. Loop over each chunk and compute a chunk-specific loss. + for st_chunk, tid_chunk, lp_chunk, msk_chunk in zip( + student_logits_chunks, token_ids_chunks, logprobs_chunks, mask_chunks + ): + # We pass num_items_in_batch=-1 so that the kd_loss + # will average over *this chunk's* valid tokens only. + chunk_loss = loss( + student_logits=st_chunk, + target_token_ids=tid_chunk, + target_logprobs=lp_chunk, + target_mask=msk_chunk, + num_items_in_batch=-1, # ensure per-chunk averaging by valid tokens + kd_temperature=self.kd_temperature, + ) - # 4) Z-score teacher and student - # If target_mask is 2D, expand to 3D for the K dimension - if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len): - target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K) + # kd_loss returns an average over the chunk's valid tokens. + # We want a global average in the end, so we need to re‐weight + # by the number of valid tokens in this chunk and keep track of the total. 
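+ # i.e. global_mean = sum_j(chunk_mean_j * n_j) / sum_j(n_j), where n_j counts
+ # the valid (unmasked) top-k entries in chunk j.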
+ chunk_valid_mask = msk_chunk.to(torch.bool) + chunk_valid_count = chunk_valid_mask.sum() # scalar tensor - teacher_z = zscore_standardize( - teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp - ) - student_z = zscore_standardize( - student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp - ) + # Re-scale "chunk average" back to "chunk sum" + chunk_loss_sum = chunk_loss * chunk_valid_count - # 5) Convert to log-probs for KL - teacher_logprobs_z = teacher_z - torch.logsumexp(teacher_z, dim=-1, keepdim=True) - student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True) + total_loss += chunk_loss_sum + total_valid_tokens += chunk_valid_count - # 6) Restrict to valid tokens if needed - valid_mask = target_mask.bool() # shape [B, seq_len, K] - teacher_probs_z = teacher_logprobs_z.exp() - teacher_probs_z = teacher_probs_z[valid_mask] - teacher_logprobs_z = teacher_logprobs_z[valid_mask] - student_logprobs_z = student_logprobs_z[valid_mask] + # 3. Normalize *once* at the end. + if num_items_in_batch > 0: + # If the user gave us a manual denominator (e.g. total items in batch), + # we divide by it. Typically used if each item is of different length. + final_loss = total_loss / float(num_items_in_batch) + else: + # Otherwise, divide by total valid tokens across all chunks. + # to get the same result as a non-chunked approach. + final_loss = total_loss / float(total_valid_tokens) - # 7) forward KL: sum( p_teacher * [log(p_teacher) - log(p_student)] ) - kd_loss_per_token = teacher_probs_z * (teacher_logprobs_z - student_logprobs_z) - kd_loss = kd_loss_per_token.sum() - - # 8) If using classical KD scaling by T^2 - if kd_temperature != 1.0: - kd_loss = kd_loss * (kd_temperature**2) - - # Optionally scale by zscore_base_temp**2 if you want (paper might differ). - # kd_loss = kd_loss * (zscore_base_temp**2) - - # 9) Normalize - if num_items_in_batch is not None and num_items_in_batch > 0: - kd_loss = kd_loss / float(num_items_in_batch) - else: - kd_loss = kd_loss / float(kd_loss_per_token.size(0)) - - return kd_loss + return final_loss diff --git a/src/axolotl/integrations/kd/trainer.py b/src/axolotl/integrations/kd/trainer.py index f99f2ca28..7ec43333a 100644 --- a/src/axolotl/integrations/kd/trainer.py +++ b/src/axolotl/integrations/kd/trainer.py @@ -18,8 +18,7 @@ KD trainer from axolotl.core.trainers.base import AxolotlTrainer -from .topk_logprob.forward_kl import loss as topk_kd_loss -from .topk_logprob.forward_kl import topk_kd_loss_with_zscore +from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss class AxolotlKDTrainer(AxolotlTrainer): @@ -27,6 +26,18 @@ class AxolotlKDTrainer(AxolotlTrainer): Custom trainer subclass for Knowledge Distillation (KD) """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.model_accepts_loss_kwargs = True + self.model._loss_function = LigerFusedLinearKLTopKLogprobLoss( + self.args.kd_ce_alpha, # hard label loss + self.args.kd_alpha, # kd loss + self.args.kd_temperature, + self.args.kd_beta or 0.0, + compute_ce_loss=bool(self.args.kd_ce_alpha), + normalize_topk=self.args.kd_normalize_topk, + ) + def _set_signature_columns_if_needed(self): super()._set_signature_columns_if_needed() columns_to_add = [] @@ -52,12 +63,12 @@ class AxolotlKDTrainer(AxolotlTrainer): Subclass and override for custom behavior. 
""" - - target_logprobs = inputs.pop("target_logprobs") - target_token_ids = inputs.pop("target_token_ids") - target_mask = inputs.pop("target_mask") - - seq_len = target_token_ids.shape[1] + if ( + self.args.sample_packing + and hasattr(inputs, "attention_mask") + and hasattr(inputs, "position_ids") + ): + del inputs["attention_mask"] if self.model_accepts_loss_kwargs: loss_kwargs = {} @@ -65,49 +76,4 @@ class AxolotlKDTrainer(AxolotlTrainer): loss_kwargs["num_items_in_batch"] = num_items_in_batch inputs = {**inputs, **loss_kwargs} outputs = model(**inputs) - - # FIXME: account for tokenizer.padding_side - student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous() - - shift_logits = student_logits.contiguous() - target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous() - target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous() - target_mask_for_loss = target_mask[..., 1:, :].contiguous() - - if self.args.kd_zscore_base_temp: - loss_kd = topk_kd_loss_with_zscore( - shift_logits, - target_token_ids_for_loss, - target_logprobs_for_loss, - target_mask_for_loss, - kd_temperature=self.args.kd_temperature, - zscore_base_temp=self.args.kd_zscore_base_temp, - num_items_in_batch=num_items_in_batch, - ) - else: - loss_kd = topk_kd_loss( - shift_logits, - target_token_ids_for_loss, - target_logprobs_for_loss, - target_mask_for_loss, - num_items_in_batch=num_items_in_batch, - kd_temperature=self.args.kd_temperature, - top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0, - ) - - if self.args.kd_ce_alpha > 0: - kd_alpha = self.args.kd_alpha - loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd - else: - loss = loss_kd - # Save past state if it exists - # TODO: this needs to be fixed and made cleaner later. - if self.args.past_index >= 0: - self._past = outputs[ # pylint: disable=attribute-defined-outside-init - self.args.past_index - ] - - if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs: - loss *= self.accelerator.num_processes - - return (loss, outputs) if return_outputs else loss + return outputs[0] diff --git a/src/axolotl/integrations/kd/utils.py b/src/axolotl/integrations/kd/utils.py new file mode 100644 index 000000000..ba60694a5 --- /dev/null +++ b/src/axolotl/integrations/kd/utils.py @@ -0,0 +1,100 @@ +"""Helper KD utils""" + +import math +from typing import List, Union + +import numpy as np +import torch +from torch import FloatTensor, Tensor + + +def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor: + """ + Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs. 
+ """ + # Ensure raw_logprobs matches kd_online_topk length for tensor operations + # This should ideally be handled by the caller ensuring correct padding/truncation first + if logprobs.shape[-1] != topk: + # pad last dimension of logprobs to match topk length with -inf + padding_len = topk - logprobs.shape[-1] + padding_tensor = torch.full( + ( + *logprobs.shape[:-1], + padding_len, + ), # Takes all dimensions of logprobs except the last, then appends padding_needed + float("-inf"), + dtype=logprobs.dtype, + device=logprobs.device, + ) + logprobs = torch.cat((logprobs, padding_tensor), dim=-1) + + # Convert logprobs at T_online to probabilities + # use log sum exp trick to avoid underflow + position_logprobs_lse = torch.logsumexp(logprobs, dim=-1, keepdim=True) + teacher_probs_t_online = torch.exp(logprobs - position_logprobs_lse) + + # Normalize probabilities (sum to 1) + # This is important if the top-k from server aren't a full distribution + teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=-1, keepdim=True) + teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online_sum + + final_logprobs_tensor = torch.log(teacher_probs_t_online) + + return final_logprobs_tensor + + +def strided_chunk_views( + tensor: Union[np.ndarray, torch.Tensor], + chunks: int, + dim: int = 0, + stride: int = 1, + chunk_size: int | None = None, +) -> List[Union[np.ndarray, torch.Tensor]]: + """ + Split a tensor into chunks along a dimension with striding, prioritizing views over copies. + + Args: + tensor: Input tensor (numpy array or torch tensor) + chunks: Number of chunks to create + dim: Dimension along which to chunk (default: 0) + stride: Stride between chunk starting positions (default: 1) + chunk_size: Size of each chunk. If None, calculated automatically (default: None) + + Returns: + List of tensor chunks (views when possible, copies when necessary) + """ + + # Get the size of the specified dimension + dim_size = tensor.shape[dim] + + # Calculate chunk size if not provided + if chunk_size is None: + chunk_size = (dim_size + chunks - 1) // chunks # Ceiling division + + chunks_list = [] + + for i in range(chunks): + start_idx = i * stride + end_idx = min(start_idx + chunk_size, dim_size) + + # Break if we've gone beyond the tensor + if start_idx >= dim_size: + break + + # Create slice objects for all dimensions + slices = [slice(None)] * tensor.ndim + slices[dim] = slice(start_idx, end_idx) + + chunk = tensor[tuple(slices)] + chunks_list.append(chunk) + + return chunks_list + + +def chunk_overlap(input_tensor: Tensor, chunks: int, dim: int = 0, overlap: int = 1): + dim_size = input_tensor.shape[dim] + stride = math.ceil(dim_size / chunks) + + return strided_chunk_views( + input_tensor, chunks, dim, stride=stride, chunk_size=stride + overlap + ) diff --git a/src/axolotl/prompt_strategies/__init__.py b/src/axolotl/prompt_strategies/__init__.py index 3cdbbb6f3..cf936481e 100644 --- a/src/axolotl/prompt_strategies/__init__.py +++ b/src/axolotl/prompt_strategies/__init__.py @@ -17,7 +17,10 @@ def load(strategy, tokenizer, cfg, ds_cfg, processor=None): return messages_load(tokenizer, cfg, ds_cfg, processor=processor) load_fn = "load" package = "axolotl.prompt_strategies" - if strategy.split(".")[-1].startswith("load_"): + if ( + strategy.split(".")[-1].startswith("load_") + or strategy.split(".")[-1] == "load" + ): load_fn = strategy.split(".")[-1] strategy = ".".join(strategy.split(".")[:-1]) elif len(strategy.split(".")) > 1: diff --git a/src/axolotl/train.py b/src/axolotl/train.py 
index 13ac8ec0d..fa7d56913 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -1,10 +1,13 @@ """Prepare and train a model on a dataset. Can also infer from a model or merge lora""" +from __future__ import annotations + import importlib import inspect import os import signal import sys +import typing import weakref from contextlib import ExitStack from pathlib import Path @@ -25,7 +28,6 @@ from axolotl.common.datasets import TrainDatasetMeta from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module fix_untrained_tokens, ) -from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder from axolotl.integrations.base import PluginManager from axolotl.loaders import ( ModelLoader, @@ -45,6 +47,9 @@ try: except ImportError: BetterTransformer = None +if typing.TYPE_CHECKING: + from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder + LOG = get_logger(__name__) @@ -472,7 +477,7 @@ def handle_untrained_tokens_fix( def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[ - HFRLTrainerBuilder | HFCausalTrainerBuilder, + "HFRLTrainerBuilder" | "HFCausalTrainerBuilder", PeftModel | PreTrainedModel, PreTrainedTokenizer, PeftConfig | None, diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py index 3d0ba7c9c..e669413f8 100644 --- a/src/axolotl/utils/__init__.py +++ b/src/axolotl/utils/__init__.py @@ -52,3 +52,10 @@ def patch_optimized_env(): if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None: os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" set_pytorch_cuda_alloc_conf() + + +def get_not_null(value, default=None): + """ + return the value if it's not None, otherwise return the default value + """ + return value if value is not None else default diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index bf496d2c5..09bfb5576 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -36,7 +36,7 @@ _CHAT_TEMPLATES = { "deepseek_v3": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% set ns = namespace(is_first=false, is_tool=false, is_output_first=true, system_prompt='', is_first_sp=true) %}{%- for message in messages %}{%- if message['role'] == 'system' %}{%- if ns.is_first_sp %}{% set ns.system_prompt = ns.system_prompt + message['content'] %}{% set ns.is_first_sp = false %}{%- else %}{% set ns.system_prompt = ns.system_prompt + '\\n\\n' + message['content'] %}{%- endif %}{%- endif %}{%- endfor %}{{ bos_token }}{{ ns.system_prompt }}{%- for message in messages %}{%- if message['role'] == 'user' %}{%- set ns.is_tool = false -%}{{'<|User|>' + message['content']}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' in message %}{%- set ns.is_tool = false -%}{%- for tool in message['tool_calls'] %}{%- if not ns.is_first %}{%- if message['content'] is none %}{{'<|Assistant|><|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- else %}{{'<|Assistant|>' + message['content'] + '<|tool▁calls▁begin|><|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' + '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- set ns.is_first = true -%}{%- else %}{{'\\n' + '<|tool▁call▁begin|>' + tool['type'] + '<|tool▁sep|>' + tool['function']['name'] + '\\n' 
+ '```json' + '\\n' + tool['function']['arguments'] + '\\n' + '```' + '<|tool▁call▁end|>'}}{%- endif %}{%- endfor %}{{'<|tool▁calls▁end|><|end▁of▁sentence|>'}}{%- endif %}{%- if message['role'] == 'assistant' and 'tool_calls' not in message %}{%- if ns.is_tool %}{{'<|tool▁outputs▁end|>' + message['content'] + '<|end▁of▁sentence|>'}}{%- set ns.is_tool = false -%}{%- else %}{% set content = message['content'] %}{% if '' in content %}{% set content = content.split('')[-1] %}{% endif %}{{'<|Assistant|>' + content + '<|end▁of▁sentence|>'}}{%- endif %}{%- endif %}{%- if message['role'] == 'tool' %}{%- set ns.is_tool = true -%}{%- if ns.is_output_first %}{{'<|tool▁outputs▁begin|><|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- set ns.is_output_first = false %}{%- else %}{{'<|tool▁output▁begin|>' + message['content'] + '<|tool▁output▁end|>'}}{%- endif %}{%- endif %}{%- endfor -%}{% if ns.is_tool %}{{'<|tool▁outputs▁end|>'}}{% endif %}{% if add_generation_prompt and not ns.is_tool %}{{'<|Assistant|>'}}{% endif %}", "jamba": '{# Variables #}\n{% set ns = namespace(message_count=0, is_last_checked_defined=False) %}\n{##}\n{% set bom_str = bom_str or "<|bom|>" %}\n{% set eom_str = eom_str or "<|eom|>" %}\n{% set default_system_message = "" %}\n{##}\n{% set documents_prefix = "" %}\n{% set documents_suffix = "" %}\n{% set tool_definitions_prefix = "" %}\n{% set tool_definitions_suffix = "" %}\n{% set active_modes_prefix = "" %}\n{% set active_modes_suffix = "" %}\n{##}\n{% set tool_calls_prefix = "" %}\n{% set tool_calls_suffix = "" %}\n{% set citations_prefix = "" %}\n{% set citations_suffix = "" %}\n{##}\n{% if add_generation_prompt is not defined %}\n {% set add_generation_prompt = True %}\n{% endif %}\n{% set role_to_predict = role_to_predict or "assistant" %}\n{% if messages|length > 0 and messages[0].role == "system" %}\n {% set system_message = messages[0].content %}\n {% set loop_messages = messages[1:] %}\n{% else %}\n {% set system_message = default_system_message %}\n {% set loop_messages = messages %}\n{% endif %}\n{##}\n{##}\n{# Macros #}\n{% macro handle_tool_definitions(tools) %}\n {{- tool_definitions_prefix -}}\n {{- "\\n# Tools" -}}\n {{- "\\n\\n## Functions" -}}\n {% for tool in tools %}\n {% set _ = is_param_set(tool, field="type") %}\n {% set is_tool_type_set = ns.is_last_checked_defined %}\n {% if is_tool_type_set %}\n {% if tool.type == "function" %}\n {% set tool = tool.function %}\n {% else %}\n {{ raise_exception("Currently, the only supported tool type is `function`") }}\n {% endif %}\n {% endif %}\n {{- "\\n\\n" + (tool|tojson(indent=2)) -}}\n {% endfor %}\n {{- "\\n" + tool_definitions_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_first_system_message(system_message, tools) %}\n {{- bom_str + handle_role("system") -}}\n {% set _ = is_param_set(system_message) %}\n {% set is_system_message_set = ns.is_last_checked_defined %}\n {% if is_system_message_set %}\n {{- system_message -}}\n {% endif %}\n {% set _ = is_param_set(tools, is_list=True) %}\n {% set is_tools_set = ns.is_last_checked_defined %}\n {% if is_tools_set %}\n {% if system_message %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- handle_tool_definitions(tools) -}}\n {% endif %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_tool_calls(tool_calls) %}\n {{- tool_calls_prefix + "[\\n" -}}\n {% for tool_call in tool_calls %}\n {% set _ = is_param_set(tool_call, field="function") %}\n {% set is_tool_call_function_set = ns.is_last_checked_defined %}\n {% 
if is_tool_call_function_set %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {% set arguments = tool_call.arguments %}\n {% if arguments is not string %}\n {%- set arguments = arguments|tojson -%}\n {%- endif %}\n {{ "{\\"name\\": \\"" + tool_call.name + "\\", \\"arguments\\": " + arguments + "}" -}}\n {% if not loop.last %}\n {{- "," }}\n {% endif %}\n {% endfor %}\n {{- "\\n]" + tool_calls_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_documents(documents) %}\n {{- documents_prefix -}}\n {{- "\\n# Documents" -}}\n {{- "\\n\\nYou can use the following documents for reference:" -}}\n {% for doc in documents %}\n {{- "\\n\\n## Document ID: " + loop.index0|string -}}\n {% set _ = is_param_set(doc, field="title") %}\n {% set is_doc_title_set = ns.is_last_checked_defined %}\n {% if is_doc_title_set %}\n {{- "\\nTitle: " + doc.title -}}\n {% endif %}\n {% for key, value in doc.items() %}\n {% if key not in ["title", "text"] %}\n {{- "\\n" + key|title + ": " + value|string -}}\n {% endif %}\n {% endfor %}\n {{- "\\nText: " + doc.text -}}\n {% endfor %}\n {{- "\\n" + documents_suffix -}}\n{% endmacro %}\n{##}\n{% macro handle_knobs(knobs) %}\n {{- active_modes_prefix -}}\n {{- "\\n# Active Modes" -}}\n {{ "\\n\\nThe following modes configure the format or style of your responses. You should adhere to all currently" -}}\n {{ " active modes simultaneously." -}}\n {% if knobs.citation_mode == "fast" %}\n {{- "\\n\\n## Citation Mode" -}}\n {{- "\\n\\nProvide a list of references only for the documents you base your response on. Format your response" -}}\n {{ " with the original answer followed by a citation section. Use this template:" -}}\n {{ " `{answer}" + citations_prefix + "DOCUMENT_IDS" + citations_suffix + "`, where DOCUMENT_IDS are the relevant document numbers" -}}\n {{ " (e.g. [2, 5, 9]), or [] if the answer cannot be supported by the provided documents." -}}\n {% endif %}\n {% if knobs.response_format == "json_object" %}\n {{- "\\n\\n## JSON Mode" -}}\n {{ "\\n\\nProvide your response in JSON format. Adhere strictly to any schema given by the user." -}}\n {{ " If an appropriate JSON format exists, use it without modification." 
-}}\n {% endif %}\n {{- "\\n" + active_modes_suffix -}}\n{% endmacro %}\n{##}\n{% macro get_last_user_index(messages) %}\n {% set ns.last_user_index = 0 %}\n {% for message in messages %}\n {% if message.role == \'user\' %}\n {% set ns.last_user_index = loop.index0 %}\n {% endif %}\n {% endfor %}\n {{- ns.last_user_index -}}\n{% endmacro %}\n{##}\n{% macro handle_last_system_message(documents, knobs, use_documents, use_knobs) %}\n {{- bom_str + handle_role("system") -}}\n {% set macros_to_call = [] %}\n {% set params_for_macros = [] %}\n {% if use_documents %}\n {% set macros_to_call = macros_to_call + [handle_documents] %}\n {% set params_for_macros = params_for_macros + [[documents]] %}\n {% endif %}\n {% if use_knobs %}\n {% set macros_to_call = macros_to_call + [handle_knobs] %}\n {% set params_for_macros = params_for_macros + [[knobs]] %}\n {% endif %}\n {% for i in range(macros_to_call|length) %}\n {% if i > 0 %}\n {{- "\\n\\n" -}}\n {% endif %}\n {{- macros_to_call[i](*params_for_macros[i]) -}}\n {% endfor %}\n {% set ns.message_count = ns.message_count + 1 %}\n{% endmacro %}\n{##}\n{% macro handle_role(role, add_space=True) %}\n {{- "<|" + role + "|>" -}}\n {% if add_space %}\n {{- " " -}}\n {% endif %}\n{% endmacro %}\n{##}\n{% macro is_param_set(param, field=none, is_list=False) %}\n {% if field is not none %}\n {% if field in param %}\n {% set param = param[field] %}\n {% else %}\n {% set param = none %}\n {% endif %}\n {% endif %}\n {% set is_defined = param is defined and param is not none %}\n {% if is_list %}\n {% set ns.is_last_checked_defined = is_defined and param|length > 0 %}\n {% else %}\n {% set ns.is_last_checked_defined = is_defined %}\n {% endif %}\n{% endmacro %}\n{##}\n{##}\n{# Template #}\n{{- "<|startoftext|>" -}}\n{% set _ = is_param_set(system_message) %}\n{% set is_system_message_set = ns.is_last_checked_defined %}\n{% set _ = is_param_set(tools, is_list=True) %}\n{% set is_tools_set = ns.is_last_checked_defined %}\n{% set has_system_message = (is_system_message_set or is_tools_set) %}\n{% if has_system_message %}\n {{- handle_first_system_message(system_message, tools) -}}\n{% endif %}\n{% set last_user_index = get_last_user_index(loop_messages)|int %}\n{% for message in loop_messages %}\n {% if loop.index0 == last_user_index %}\n {% set _ = is_param_set(documents, is_list=True) %}\n {% set use_documents = ns.is_last_checked_defined %}\n {% set _ = is_param_set(knobs) %}\n {% set use_knobs = ns.is_last_checked_defined and knobs.is_set %}\n {% set add_last_system_message = use_documents or use_knobs %}\n {% if add_last_system_message %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- handle_last_system_message(documents, knobs, use_documents, use_knobs) -}}\n {% endif %}\n {% endif %}\n {% set role = message.role %}\n {% set _ = is_param_set(message, field="name") %}\n {% set is_message_name_set = ns.is_last_checked_defined %}\n {% if is_message_name_set %}\n {% set message_prefix = handle_role(role) + "(" + message.name + ")" %}\n {% else %}\n {% set message_prefix = handle_role(role) %}\n {% endif %}\n {% set content = (message.content or "") %}\n {% if content is not string %}\n {% set content = content|tojson %}\n {% endif %}\n {% if ns.message_count > 0 %}\n {{- eom_str -}}\n {% endif %}\n {{- bom_str + message_prefix + content -}}\n {% set _ = is_param_set(message, field="tool_calls", is_list=True) %}\n {% set is_tool_calls_set = ns.is_last_checked_defined %}\n {% if role == "assistant" and is_tool_calls_set %}\n {{- 
handle_tool_calls(message.tool_calls) -}}\n    {% endif %}\n    {% set _ = is_param_set(message, field="citations", is_list=True) %}\n    {% set is_citations_set = ns.is_last_checked_defined %}\n    {% if role == "assistant" and is_citations_set %}\n        {{- citations_prefix + message.citations|map(attribute="document_id")|list|string + citations_suffix -}}\n    {% endif %}\n    {% set ns.message_count = ns.message_count + 1 %}\n{% endfor %}\n{% if add_generation_prompt %}\n    {% if ns.message_count > 0 %}\n        {{- eom_str -}}\n    {% endif %}\n    {{- bom_str + handle_role(role_to_predict, add_space=False) -}}\n    {% set _ = is_param_set(generation_preamble) %}\n    {% set is_generation_preamble_set = ns.is_last_checked_defined %}\n    {% if is_generation_preamble_set and generation_preamble.strip() != "" %}\n        {{- " " + generation_preamble -}}\n    {% endif %}\n    {% set ns.message_count = ns.message_count + 1 %}\n{% else %}\n    {% if ns.message_count > 0 %}\n        {{- eom_str -}}\n    {% endif %}\n{% endif %}\n',
"qwen_25": "{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0]['role'] == 'system' %}\n        {{- messages[0]['content'] }}\n    {%- else %}\n        {{- 'You are Qwen, created by Alibaba Cloud. You are a helpful assistant.' }}\n    {%- endif %}\n    {{- \"\\n\\n# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n    {%- if messages[0]['role'] == 'system' %}\n        {{- '<|im_start|>system\\n' + messages[0]['content'] + '<|im_end|>\\n' }}\n    {%- else %}\n        {{- '<|im_start|>system\\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\\n' }}\n    {%- endif %}\n{%- endif %}\n{%- for message in messages %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) or (message.role == \"assistant\" and not message.tool_calls) %}\n        {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n    {%- elif message.role == \"assistant\" %}\n        {{- '<|im_start|>' + message.role }}\n        {%- if message.content %}\n            {{- '\\n' + message.content }}\n        {%- endif %}\n        {%- for tool_call in message.tool_calls %}\n            {%- if tool_call.function is defined %}\n                {%- set tool_call = tool_call.function %}\n            {%- endif %}\n            {{- '\\n<tool_call>\\n{\"name\": \"' }}\n            {{- tool_call.name }}\n            {{- '\", \"arguments\": ' }}\n            {{- tool_call.arguments | tojson }}\n            {{- '}\\n</tool_call>' }}\n        {%- endfor %}\n        {{- '<|im_end|>\\n' }}\n    {%- elif message.role == \"tool\" %}\n        {%- if (loop.index0 == 0) or (messages[loop.index0 - 1].role != \"tool\") %}\n            {{- '<|im_start|>user' }}\n        {%- endif %}\n        {{- '\\n<tool_response>\\n' }}\n        {{- message.content }}\n        {{- '\\n</tool_response>' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- '<|im_end|>\\n' }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n{%- endif %}\n",
- "qwen3": "{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0].role == 'system' %}\n        {{- messages[0].content + '\\n\\n' }}\n    {%- endif %}\n    {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n    {%- if messages[0].role == 'system' %}\n        {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n    {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n    {%- set index = (messages|length - 1) - loop.index0 %}\n    {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n        {%- set ns.multi_step_tool = false %}\n        {%- set ns.last_query_index = index %}\n    {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n        {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n    {%- elif message.role == \"assistant\" %}\n        {%- set content = message.content %}\n        {%- set reasoning_content = '' %}\n        {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n            {%- set reasoning_content = message.reasoning_content %}\n        {%- else %}\n            {%- if '</think>' in message.content %}\n                {%- set content = message.content.split('</think>')[-1].lstrip('\\n') %}\n                {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n            {%- endif %}\n        {%- endif %}\n        {%- if loop.index0 > ns.last_query_index %}\n            {%- if loop.last or (not loop.last and reasoning_content) %}\n                {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n            {%- else %}\n                {{- '<|im_start|>' + message.role + '\\n' + content }}\n            {%- endif %}\n        {%- else %}\n            {{- '<|im_start|>' + message.role + '\\n' + content }}\n        {%- endif %}\n        {%- if message.tool_calls %}\n            {%- for tool_call in message.tool_calls %}\n                {%- if (loop.first and content) or (not loop.first) %}\n                    {{- '\\n' }}\n                {%- endif %}\n                {%- if tool_call.function %}\n                    {%- set tool_call = tool_call.function %}\n                {%- endif %}\n                {{- '<tool_call>\\n{\"name\": \"' }}\n                {{- tool_call.name }}\n                {{- '\", \"arguments\": ' }}\n                {%- if tool_call.arguments is string %}\n                    {{- tool_call.arguments }}\n                {%- else %}\n                    {{- tool_call.arguments | tojson }}\n                {%- endif %}\n                {{- '}\\n</tool_call>' }}\n            {%- endfor %}\n        {%- endif %}\n        {{- '<|im_end|>\\n' }}\n    {%- elif message.role == \"tool\" %}\n        {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n            {{- '<|im_start|>user' }}\n        {%- endif %}\n        {{- '\\n<tool_response>\\n' }}\n        {{- message.content }}\n        {{- '\\n</tool_response>' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- '<|im_end|>\\n' }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n    {%- if enable_thinking is defined and enable_thinking is false %}\n        {{- '<think>\\n\\n</think>\\n\\n' }}\n    {%- endif %}\n{%- endif %}",
+ "qwen3": "{%- if tools %}\n    {{- '<|im_start|>system\\n' }}\n    {%- if messages[0].role == 'system' %}\n        {{- messages[0].content + '\\n\\n' }}\n    {%- endif %}\n    {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n    {%- for tool in tools %}\n        {{- \"\\n\" }}\n        {{- tool | tojson }}\n    {%- endfor %}\n    {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n    {%- if messages[0].role == 'system' %}\n        {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n    {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n    {%- set index = (messages|length - 1) - loop.index0 %}\n    {%- if ns.multi_step_tool and message.role == \"user\" and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n        {%- set ns.multi_step_tool = false %}\n        {%- set ns.last_query_index = index %}\n    {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n    {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n        {{- '<|im_start|>' + message.role + '\\n' + message.content + '<|im_end|>' + '\\n' }}\n    {%- elif message.role == \"assistant\" %}\n        {%- set content = message.content %}\n        {%- set reasoning_content = '' %}\n        {%- if message.reasoning_content is defined and message.reasoning_content is not none %}\n            {%- set reasoning_content = message.reasoning_content %}\n        {%- else %}\n            {%- if '</think>' in message.content %}\n                {%- set content = message.content.split('</think>')[-1].lstrip('\\n') %}\n                {%- set reasoning_content = message.content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n            {%- endif %}\n        {%- endif %}\n        {%- if loop.index0 > ns.last_query_index %}\n            {%- if loop.last or (not loop.last and reasoning_content) %}\n                {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n            {%- else %}\n                {{- '<|im_start|>' + message.role + '\\n' + content }}\n            {%- endif %}\n        {%- else %}\n            {{- '<|im_start|>' + message.role + '\\n' + content }}\n        {%- endif %}\n        {%- if message.tool_calls %}\n            {%- for tool_call in message.tool_calls %}\n                {%- if (loop.first and content) or (not loop.first) %}\n                    {{- '\\n' }}\n                {%- endif %}\n                {%- if tool_call.function %}\n                    {%- set tool_call = tool_call.function %}\n                {%- endif %}\n                {{- '<tool_call>\\n{\"name\": \"' }}\n                {{- tool_call.name }}\n                {{- '\", \"arguments\": ' }}\n                {%- if tool_call.arguments is string %}\n                    {{- tool_call.arguments }}\n                {%- else %}\n                    {{- tool_call.arguments | tojson }}\n                {%- endif %}\n                {{- '}\\n</tool_call>' }}\n            {%- endfor %}\n        {%- endif %}\n        {{- '<|im_end|>\\n' }}\n    {%- elif message.role == \"tool\" %}\n        {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n            {{- '<|im_start|>user' }}\n        {%- endif %}\n        {{- '\\n<tool_response>\\n' }}\n        {{- message.content }}\n        {{- '\\n</tool_response>' }}\n        {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n            {{- '<|im_end|>\\n' }}\n        {%- endif %}\n    {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n    {{- '<|im_start|>assistant\\n' }}\n    {%- if enable_thinking is defined and enable_thinking is false %}\n        {{- '<think>\\n\\n</think>\\n\\n' }}\n    {%- else %}\n        {{- '<think>\\n\\n' }}\n    {%- endif %}\n{%- endif %}",
"exaone": "{% for message in messages %}{% if loop.first and message['role'] != 'system' %}{{ '[|system|][|endofturn|]\n' }}{% endif %}{{ '[|' + message['role'] + '|]' + message['content'] }}{% if message['role'] == 'user' %}{{ '\n' }}{% else %}{{ '[|endofturn|]\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '[|assistant|]' }}{% endif %}",
"metharme": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'Enter RP mode. You shall reply to the user while staying in character. Your responses must be detailed, creative, immersive, and drive the scenario forward.' %}{% endif %}{{ '<|system|>' + system_message }}{% for message in loop_messages %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|user|>' + content.strip() }}{% elif message['role'] == 'assistant' %}{{ '<|model|>' + content.strip() }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|model|>' }}{% else %}{{ eos_token }}{% endif %}",
"pixtral": '{%- if messages[0]["role"] == "system" %}\n    {%- set system_message = messages[0]["content"] %}\n    {%- set loop_messages = messages[1:] %}\n{%- else %}\n    {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n    {%- if (message[\'role\'] == \'user\') != (loop.index0 % 2 == 0) %}\n        {{- raise_exception(\'After the optional system message, conversation roles must alternate user/assistant/user/assistant/...\') }}\n    {%- endif %}\n    {%- if message["role"] == "user" %}\n        {%- if loop.last and system_message is defined %}\n            {{- "[INST]" + system_message + "\n\n" }}\n        {%- else %}\n            {{- "[INST]" }}\n        {%- endif %}\n        {%- if message["content"] is not string %}\n            {%- for chunk in message["content"] %}\n                {%- if chunk["type"] == "text" %}\n                    {{- chunk["text"] }}\n                {%- elif chunk["type"] == "image" %}\n                    {{- "[IMG]" }}\n                {%- else %}\n                    {{- raise_exception("Unrecognized content type!") }}\n                {%- endif %}\n            {%- endfor %}\n        {%- else %}\n            {{- message["content"] }}\n        {%- endif %}\n        {{- "[/INST]" }}\n    {%- elif message["role"] == "assistant" %}\n        {%- if message["content"] is not string %}\n            {%- for chunk in message["content"] %}\n                {%- if chunk["type"] == "text" %}\n                    {{- chunk["text"] }}\n                {%- elif chunk["type"] == "image" %}\n                    {{- "[IMG]" }}\n                {%- else %}\n                    {{- raise_exception("Unrecognized content type!") }}\n{%- endif %}\n{%- endfor %}\n{{- eos_token }}\n{%- else %}\n{{- message["content"] + eos_token }}\n{%- endif %}\n    {%- else 
%}\n {{- raise_exception("Only user and assistant roles are supported, with the exception of an initial optional system message!") }}\n {%- endif %}\n{%- endfor %}', diff --git a/src/axolotl/utils/collators/batching.py b/src/axolotl/utils/collators/batching.py index d8414d117..a28f360be 100644 --- a/src/axolotl/utils/collators/batching.py +++ b/src/axolotl/utils/collators/batching.py @@ -1,7 +1,7 @@ """Data collators for axolotl to pad labels and position_ids for packed sequences""" from dataclasses import dataclass -from typing import Any +from typing import Any, List import numpy as np from transformers import PreTrainedTokenizerBase @@ -163,7 +163,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): def __call__(self, features, return_tensors=None): if not isinstance(features[0], list): - features = [features] + features: List[List[dict]] = [features] out_features = [{} for _ in features] for i, features_ in enumerate(features): for feature in features_[0].keys(): diff --git a/src/axolotl/utils/data/utils.py b/src/axolotl/utils/data/utils.py index 0ffaa932f..4f7f6f8dd 100644 --- a/src/axolotl/utils/data/utils.py +++ b/src/axolotl/utils/data/utils.py @@ -51,6 +51,7 @@ def retry_on_request_exceptions( except ( requests.exceptions.ReadTimeout, requests.exceptions.ConnectionError, + requests.exceptions.HTTPError, huggingface_hub.errors.HfHubHTTPError, ) as exc: if attempt < max_retries - 1: diff --git a/src/axolotl/utils/samplers/multipack.py b/src/axolotl/utils/samplers/multipack.py index eabfc2d84..7fb5e1b41 100644 --- a/src/axolotl/utils/samplers/multipack.py +++ b/src/axolotl/utils/samplers/multipack.py @@ -259,7 +259,7 @@ class MultipackBatchSampler(BatchSampler): batch_max_len: int, # Maximum sequence length (bin capacity) lengths: np.ndarray, # Sequence lengths packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate - drop_last: bool = False, # Whether to drop final batches (might be incomplete) + drop_last: bool = True, # Whether to drop final batches (might be incomplete) num_count_samples: int = 8, # Number of times to estimate batch count sequential: bool = False, # Whether to use sequential packing group_size: int = 100_000, # Size of groups for parallel packing @@ -446,10 +446,18 @@ class MultipackBatchSampler(BatchSampler): if self._len_across_ranks is None: # Sample multiple times to get stable estimate - len_batches = min( # pylint: disable=consider-using-generator - [len(self._batches) for _ in range(self.num_count_samples)] - ) + _sampled_lens = [] + for _ in range(self.num_count_samples): + self._batches = None # Reset cached batches + _sampled_lens.append(len(self.generate_batches(set_stats=False))) + len_batches = min(_sampled_lens) + # Gather minimum across all ranks - self._len_across_ranks = self.gather_len_batches(len_batches) + if self._len_across_ranks is None: + self._len_across_ranks = self.gather_len_batches(len_batches) + else: + self._len_across_ranks = min( + self._len_across_ranks, self.gather_len_batches(len_batches) + ) return self._len_across_ranks diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index ec5360fa3..33ddadf78 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -16,7 +16,6 @@ from datasets import IterableDataset, disable_caching, enable_caching from torch.utils.data import DataLoader, RandomSampler, SequentialSampler from transformers.utils import is_torch_bf16_gpu_available -from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder 
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2 from axolotl.utils.distributed import reduce_and_broadcast from axolotl.utils.environment import check_cuda_p2p_ib_support @@ -483,6 +482,9 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree ) ) + if cfg.dataloader_drop_last: + # drop the last batch for each epoch + total_num_steps -= int(math.ceil(cfg.num_epochs)) def calc_sample_packing_eff_est(estimates: List[float]): LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}") @@ -630,6 +632,8 @@ def setup_trainer( A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based on the provided parameters. """ + from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder + if ( cfg.torch_compile and cfg.fsdp_config diff --git a/tests/e2e/integrations/test_kd.py b/tests/e2e/integrations/test_kd.py index 2bd1fbf3d..212450e89 100644 --- a/tests/e2e/integrations/test_kd.py +++ b/tests/e2e/integrations/test_kd.py @@ -5,10 +5,9 @@ e2e tests for kd trainer support in Axolotl from pathlib import Path import pytest +import yaml +from accelerate.test_utils import execute_subprocess_async, get_torch_dist_unique_port -from axolotl.common.datasets import load_datasets -from axolotl.train import train -from axolotl.utils.config import normalize_config, prepare_plugins, validate_config from axolotl.utils.dict import DictDefault from tests.e2e.utils import check_tensorboard, require_torch_2_5_1 @@ -17,8 +16,8 @@ from tests.e2e.utils import check_tensorboard, require_torch_2_5_1 @pytest.fixture(name="kd_min_cfg") def min_cfg(temp_dir): return { - "base_model": "osllmai-community/Llama-3.2-1B", - "tokenizer_config": "axolotl-ai-co/Llama-3.3-70B-Instruct-tokenizer", + "base_model": "Qwen/Qwen3-0.6B", + "tokenizer_config": "winglian/qwen3-14b-math", "plugins": [ "axolotl.integrations.kd.KDPlugin", "axolotl.integrations.liger.LigerPlugin", @@ -31,20 +30,22 @@ def min_cfg(temp_dir): "kd_ce_alpha": 0.1, "kd_alpha": 0.9, "kd_temperature": 1.0, + "kd_beta": 0.0, + "kd_normalize_topk": True, "dataloader_prefetch_factor": 8, "dataloader_num_workers": 4, "dataloader_pin_memory": True, "datasets": [ { - "path": "axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample", - "type": "axolotl.integrations.kd.chat_template", - "field_messages": "messages_combined", + "path": "winglian/OpenThoughts-114k-math-correct-qwen3-14b-math-prepared-topk128-normalized", + "type": "chat_template", "split": "train", - "logprobs_field": "llm_text_generation_vllm_logprobs", - "temperature": 1.0, - "preprocess_shards": 2, + "split_thinking": True, + "eot_tokens": ["<|im_end|>"], + "data_files": ["train/batch-000000.parquet"], }, ], + "skip_prepare_dataset": True, "val_set_size": 0.0, "sequence_len": 2048, "sample_packing": True, @@ -80,17 +81,29 @@ class TestKnowledgeDistillation: def test_llama_kd(self, temp_dir, kd_min_cfg): cfg = DictDefault(kd_min_cfg) # pylint: disable=duplicate-code - cfg = validate_config(cfg) - prepare_plugins(cfg) - normalize_config(cfg) - dataset_meta = load_datasets(cfg=cfg) + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) + + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "1", + "--main-process-port", + 
f"{get_torch_dist_unique_port()}", + ] + ) - train(cfg=cfg, dataset_meta=dataset_meta) assert (Path(temp_dir) / "model.safetensors").exists() check_tensorboard( temp_dir + "/runs", "train/loss", 1.4, "Train Loss (%s) is too high" ) + @pytest.mark.skip(reason="Chunked KD loss doesn't support PEFT/LoRA") @pytest.mark.parametrize( "load_in_8bit", [True, False], @@ -110,12 +123,22 @@ class TestKnowledgeDistillation: | kd_min_cfg ) # pylint: disable=duplicate-code - cfg = validate_config(cfg) - prepare_plugins(cfg) - normalize_config(cfg) - dataset_meta = load_datasets(cfg=cfg) + # write cfg to yaml file + Path(temp_dir).mkdir(parents=True, exist_ok=True) + with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout: + fout.write(yaml.dump(cfg.to_dict(), Dumper=yaml.Dumper)) - train(cfg=cfg, dataset_meta=dataset_meta) + execute_subprocess_async( + [ + "axolotl", + "train", + str(Path(temp_dir) / "config.yaml"), + "--num-processes", + "1", + "--main-process-port", + f"{get_torch_dist_unique_port()}", + ] + ) assert (Path(temp_dir) / "adapter_model.safetensors").exists() check_tensorboard( temp_dir + "/runs", "train/loss", 1.2, "Train Loss (%s) is too high" diff --git a/tests/test_packed_batch_sampler.py b/tests/test_packed_batch_sampler.py index 2b03c62f8..d91f63d94 100644 --- a/tests/test_packed_batch_sampler.py +++ b/tests/test_packed_batch_sampler.py @@ -81,6 +81,7 @@ class TestBatchedSamplerPacking: group_size=100000, bin_size=200, sequential=sequential, + drop_last=False, ) loader = DataLoader(