Compare commits

..

2 Commits

Author SHA1 Message Date
Wing Lian
a6056e35de enable torch compile on the optimizer step
make optimizer compile independent of torch compile on the model
2025-06-10 00:07:49 -07:00
Wing Lian
09c685fd2c fix worker_init_fn signature handling (#2769) 2025-06-08 23:14:10 -07:00
31 changed files with 547 additions and 2130 deletions

View File

@@ -1,31 +0,0 @@
{
"compile": {
"disable": false,
"backend": "inductor"
},
"zero_optimization": {
"stage": 2,
"offload_optimizer": {
"device": "cpu"
},
"contiguous_gradients": true,
"overlap_comm": true
},
"bf16": {
"enabled": "auto"
},
"fp16": {
"enabled": "auto",
"auto_cast": false,
"loss_scale": 0,
"initial_scale_power": 32,
"loss_scale_window": 1000,
"hysteresis": 2,
"min_loss_scale": 1
},
"gradient_accumulation_steps": "auto",
"gradient_clipping": "auto",
"train_batch_size": "auto",
"train_micro_batch_size_per_gpu": "auto",
"wall_clock_breakdown": false
}

View File

@@ -422,6 +422,9 @@ class TrainerBuilderBase(abc.ABC):
if self.cfg.torch_compile_mode:
training_args_kwargs["torch_compile_mode"] = self.cfg.torch_compile_mode
if self.cfg.compile_optimizer:
training_args_kwargs["compile_optimizer"] = True
def _configure_gradient_checkpointing(self, training_args_kwargs: dict):
if self.cfg.gradient_checkpointing:
training_args_kwargs["gradient_checkpointing"] = (

View File

@@ -21,6 +21,11 @@ from axolotl.core.trainers import (
AxolotlTrainer,
ReLoRATrainer,
)
from axolotl.core.training_args import (
AxolotlPRMConfig,
AxolotlRewardConfig,
AxolotlTrainingArguments,
)
from axolotl.integrations.base import PluginManager
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback
@@ -125,9 +130,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return callbacks
def _get_trainer_cls(self):
"""
Gets the trainer class for the given configuration.
"""
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
@@ -144,12 +146,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return AxolotlTrainer
def build(self, total_num_steps):
from axolotl.core.training_args import (
AxolotlPRMConfig,
AxolotlRewardConfig,
AxolotlTrainingArguments,
)
training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
total_num_steps
)
@@ -318,12 +314,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["image_resize_algorithm"] = (
self.cfg.image_resize_algorithm
)
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
plugin_training_args = plugin_manager.get_training_args(self.cfg)
if plugin_training_args:
training_arguments_kwargs.update(plugin_training_args)
if self.cfg.kd_ce_alpha is not None:
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
if self.cfg.kd_alpha is not None:
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
if self.cfg.kd_temperature is not None:
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
if self.cfg.kd_zscore_base_temp is not None:
training_arguments_kwargs["kd_zscore_base_temp"] = (
self.cfg.kd_zscore_base_temp
)
if self.cfg.kd_top_k_before_softmax is not None:
training_arguments_kwargs["kd_top_k_before_softmax"] = (
self.cfg.kd_top_k_before_softmax
)
if self.cfg.reward_model:
training_args_cls = AxolotlRewardConfig
@@ -404,10 +408,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return trainer
def build_collator(
self,
training_args, # type: "AxolotlTrainingArguments" # type: ignore
is_eval=False,
**kwargs,
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
):
if training_args.pretraining:
if (
@@ -436,18 +437,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
]
]
collator_args = [self.tokenizer]
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
collator_cls_and_kwargs = plugin_manager.get_collator_cls_and_kwargs(
self.cfg, is_eval=is_eval
)
if collator_cls_and_kwargs:
collator = collator_cls_and_kwargs[0]
if kwargs and isinstance(kwargs, dict):
kwargs.update(collator_cls_and_kwargs[1])
elif self.cfg.reward_model:
if self.cfg.reward_model:
collator = RewardDataCollatorWithPadding
elif use_batch_sampler_collator:
# Use V2BatchSamplerDataCollatorForSeq2Seq for flex attention,
@@ -478,6 +468,16 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator_args.pop(0)
kwargs.pop("pad_to_multiple_of", None)
kwargs.pop("padding", None)
elif self.cfg.kd_trainer:
from axolotl.integrations.kd.collator import (
DataCollatorForKD,
KDBatchSamplerDataCollatorForSeq2Seq,
)
if self.cfg.sample_packing:
collator = KDBatchSamplerDataCollatorForSeq2Seq
else:
collator = DataCollatorForKD
else:
collator = DataCollatorForSeq2Seq

View File

@@ -12,6 +12,11 @@ from axolotl.core.trainers import (
from axolotl.core.trainers.dpo import DPOStrategy
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.core.training_args import (
AxolotlCPOConfig,
AxolotlKTOConfig,
AxolotlORPOConfig,
)
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype
from axolotl.utils.logging import get_logger
@@ -74,12 +79,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
"""
Returns training_args and trainer_kwargs
"""
from axolotl.core.training_args import (
AxolotlCPOConfig,
AxolotlKTOConfig,
AxolotlORPOConfig,
)
training_args_kwargs, trainer_kwargs = self._set_base_training_args(
total_num_steps=total_num_steps
)
@@ -166,13 +165,6 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if blocklist_key in training_args_kwargs:
del training_args_kwargs[blocklist_key]
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
plugin_training_args = plugin_manager.get_training_args(self.cfg)
if plugin_training_args:
training_args_kwargs.update(plugin_training_args)
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
logging_first_step=True,
**training_args_kwargs,

View File

@@ -33,7 +33,6 @@ from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
)
from axolotl.utils import get_not_null
from axolotl.utils.logging import get_logger
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@@ -102,7 +101,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
)
batch_max_len = train_batch_size * self.args.max_seq_length
sampler = MultipackBatchSampler(
return MultipackBatchSampler(
base_sampler,
lengths=get_dataset_lengths(dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
@@ -114,9 +113,6 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
drop_last=True,
)
len(sampler)
return sampler
def _get_train_sampler(
self, train_dataset: Optional[Dataset] = None
) -> Optional[Sampler]:
@@ -224,9 +220,7 @@ class AxolotlTrainer(SchedulerMixin, OptimizerMixin, RngLoaderMixin, Trainer):
}
if not isinstance(dataset, torch.utils.data.IterableDataset):
dataloader_params["drop_last"] = get_not_null(
self.args.dataloader_drop_last, True
)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
if sampler_fn is not None:
sampler = sampler_fn(dataset)
if isinstance(sampler, BatchSampler):

View File

@@ -3,6 +3,7 @@
# pylint: disable=too-many-lines,duplicate-code,protected-access,no-member
import warnings
from functools import partial
from typing import Any
import datasets
@@ -58,6 +59,42 @@ class AxolotlGRPOTrainer(
_tag_names = ["trl", "grpo", "axolotl"]
def get_train_dataloader(self):
if self.train_dataset is None:
raise ValueError("Trainer: training requires a train_dataset.")
train_dataset = self.train_dataset
data_collator = self.data_collator
if isinstance(train_dataset, datasets.Dataset):
train_dataset = self._remove_unused_columns(
train_dataset, description="training"
)
else:
data_collator = self._get_collator_with_removed_columns(
data_collator, description="training"
)
dataloader_params = {
"batch_size": self._train_batch_size
* self.args.steps_per_generation, # < this is the change
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
"persistent_workers": self.args.dataloader_persistent_workers,
}
if not isinstance(train_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_train_sampler()
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = partial(
seed_worker,
num_workers=self.args.dataloader_num_workers,
rank=self.args.process_index,
)
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
return self.accelerator.prepare(DataLoader(train_dataset, **dataloader_params))
class AxolotlGRPOSequenceParallelTrainer(AxolotlGRPOTrainer):
"""Extend the base GRPOTrainer for sequence parallelism handling"""

View File

@@ -1,5 +1,6 @@
"""Module for Axolotl trainer optimizer mixin"""
import torch
from peft.optimizers import create_loraplus_optimizer
from torch import nn
from transformers.trainer import Trainer
@@ -185,12 +186,12 @@ class OptimizerMixin(Trainer):
p.data_ptr(): p.numel() for p in module.parameters()
}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
LOG.info(f"skipped {module}: {skipped / 2 ** 20}M params")
manager.register_module_override(
module, "weight", {"optim_bits": 32}
)
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
LOG.info(f"skipped: {skipped / 2 ** 20}M params")
if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init
@@ -199,6 +200,11 @@ class OptimizerMixin(Trainer):
return self.optimizer
def create_optimizer_and_scheduler(self, num_training_steps: int):
super().create_optimizer_and_scheduler(num_training_steps)
if self.args.compile_optimizer:
self.optimizer.step = torch.compile(self.optimizer.step)
class OptimizerInitMixin:
"""

View File

@@ -2,17 +2,242 @@
extra axolotl specific training args
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Optional, Type
from typing import Optional
from PIL.Image import Resampling
from transformers import TrainingArguments
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
from axolotl.integrations.config import merge_training_args
AxolotlTrainingMixins: Type = merge_training_args()
@dataclass
class AxolotlTrainingMixins:
"""
Mixin class for the Axolotl training args.
"""
# pylint: disable=duplicate-code
model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."}
)
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
pretraining: bool = field(
default=False,
metadata={
"help": "Indicates to trainer whether we are doing continued pretraining."
},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
sample_packing_sequentially: bool = field(
default=False,
metadata={
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
},
)
multipack_real_batches: bool = field(
default=False,
metadata={"help": "Use real batches for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
sample_packing_bin_size: int = field(
default=200,
metadata={
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
},
)
sample_packing_group_size: int = field(
default=100000,
metadata={
"help": "The number of samples to group together for packing. Increase for better packing."
},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_anneal_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_prune_ratio: Optional[float] = field(
default=0.9,
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
do_causal_lm_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
dataloader_prefetch_factor: Optional[int] = field(
default=None,
metadata={"help": "prefetch_factor argument to the dataloader"},
)
cosine_min_lr_ratio: Optional[float] = field(
default=None,
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
)
cosine_constant_lr_ratio: Optional[float] = field(
default=None,
metadata={
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
},
)
loraplus_lr_ratio: Optional[float] = field(
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
)
loraplus_lr_embedding: Optional[float] = field(
default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
embedding_lr_scale: Optional[float] = field(
default=None,
metadata={"help": "Scale the learning rate for the embedding layers."},
)
lr_groups: Optional[list[dict]] = field(
default=None,
metadata={"help": "Specify learning rate groups for with different LRs."},
)
embedding_lr: Optional[float] = field(
default=None,
metadata={"help": "absolute learning rate for the embedding layers."},
)
compile_optimizer: Optional[bool] = field(
default=None,
metadata={"help": "Whether to compile the optimizer for faster training."},
)
qlora: bool = field(
default=False,
metadata={"help": "whether this is a qlora training"},
)
orpo_alpha: Optional[float] = field(
default=None,
)
lisa_n_layers: Optional[int] = field(
default=None,
metadata={"help": "the number of activate layers in LISA"},
)
lisa_step_interval: Optional[int] = field(
default=None,
metadata={"help": "how often to switch layers in LISA"},
)
lisa_layers_attribute: Optional[str] = field(
default=None,
metadata={"help": "path under the model to access the layers"},
)
curriculum_sampling: Optional[bool] = field(
default=None,
metadata={"help": "whether to use sequential sampling for curriculum learning"},
)
alternate_lr_scheduler_type: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
},
)
chat_template: Optional[str] = field(
default=None,
metadata={"help": "Chat template converting chat messages to text"},
)
kd_ce_alpha: Optional[float] = field(
default=None,
metadata={
"help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
},
)
kd_alpha: Optional[float] = field(
default=1.0,
metadata={"help": "The alpha scaling parameter for KD loss"},
)
kd_temperature: Optional[float] = field(
default=1.0,
metadata={
"help": "the temperature parameter for KL divergence loss when using KD"
},
)
kd_zscore_base_temp: Optional[float] = field(
default=None,
metadata={
"help": "the base temperature parameter for KL divergence with z-score when using KD"
},
)
kd_top_k_before_softmax: Optional[bool] = field(
default=None,
metadata={
"help": "Whether to apply top_k_before_softmax to the logits when using KD"
},
)
adam_beta3: Optional[float] = field(
default=None,
metadata={
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
},
)
adam_epsilon2: Optional[float] = field(
default=None,
metadata={
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
},
)
# multi-modal section
image_size: int | tuple[int, int] | None = field(
default=None,
metadata={"help": "The size of the image to resize to"},
)
image_resize_algorithm: Resampling | None = field(
default=None,
metadata={"help": "The algorithm to use for image resizing"},
)
# end of multi-modal section
@dataclass

View File

@@ -1,220 +0,0 @@
"""
Base Axolotl Training Mixins shared across various trainer configs
"""
from dataclasses import dataclass, field
from typing import Optional
from PIL.Image import Resampling
@dataclass
class AxolotlTrainingMixins:
"""
Mixin class for the Axolotl training args.
"""
# pylint: disable=duplicate-code
model_type: Optional[str] = field(
default=None, metadata={"help": "HF model configuration model_type."}
)
lr_quadratic_warmup: bool = field(
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
pretraining: bool = field(
default=False,
metadata={
"help": "Indicates to trainer whether we are doing continued pretraining."
},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
sample_packing_sequentially: bool = field(
default=False,
metadata={
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
},
)
multipack_real_batches: bool = field(
default=False,
metadata={"help": "Use real batches for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
sample_packing_bin_size: int = field(
default=200,
metadata={
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
},
)
sample_packing_group_size: int = field(
default=100000,
metadata={
"help": "The number of samples to group together for packing. Increase for better packing."
},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
relora_steps: Optional[int] = field(
default=None,
metadata={"help": "how often to reset for ReLoRA"},
)
relora_warmup_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_anneal_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_prune_ratio: Optional[float] = field(
default=0.9,
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
bench_dataset: Optional[str] = field(
default="pharaouk/dharma-1/dharma_1_mini.json",
metadata={
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
},
)
do_bench_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
)
do_causal_lm_eval: Optional[bool] = field(
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
)
max_bench_samples: Optional[int] = field(
default=None,
metadata={
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
},
)
bench_source_max_len: int = field(
default=2048, metadata={"help": "Maximum source sequence length for bench."}
)
dataloader_prefetch_factor: Optional[int] = field(
default=None,
metadata={"help": "prefetch_factor argument to the dataloader"},
)
cosine_min_lr_ratio: Optional[float] = field(
default=None,
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
)
cosine_constant_lr_ratio: Optional[float] = field(
default=None,
metadata={
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
},
)
loraplus_lr_ratio: Optional[float] = field(
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
)
loraplus_lr_embedding: Optional[float] = field(
default=1e-6,
metadata={"help": "loraplus learning rate for lora embedding layers."},
)
embedding_lr_scale: Optional[float] = field(
default=None,
metadata={"help": "Scale the learning rate for the embedding layers."},
)
lr_groups: Optional[list[dict]] = field(
default=None,
metadata={"help": "Specify learning rate groups for with different LRs."},
)
embedding_lr: Optional[float] = field(
default=None,
metadata={"help": "absolute learning rate for the embedding layers."},
)
qlora: bool = field(
default=False,
metadata={"help": "whether this is a qlora training"},
)
orpo_alpha: Optional[float] = field(
default=None,
)
lisa_n_layers: Optional[int] = field(
default=None,
metadata={"help": "the number of activate layers in LISA"},
)
lisa_step_interval: Optional[int] = field(
default=None,
metadata={"help": "how often to switch layers in LISA"},
)
lisa_layers_attribute: Optional[str] = field(
default=None,
metadata={"help": "path under the model to access the layers"},
)
curriculum_sampling: Optional[bool] = field(
default=None,
metadata={"help": "whether to use sequential sampling for curriculum learning"},
)
alternate_lr_scheduler_type: Optional[str] = field(
default=None,
metadata={
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
},
)
chat_template: Optional[str] = field(
default=None,
metadata={"help": "Chat template converting chat messages to text"},
)
# kd_ce_alpha: Optional[float] = field(
# default=None,
# metadata={
# "help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
# },
# )
#
# kd_alpha: Optional[float] = field(
# default=1.0,
# metadata={"help": "The alpha scaling parameter for KD loss"},
# )
#
# kd_temperature: Optional[float] = field(
# default=1.0,
# metadata={
# "help": "the temperature parameter for KL divergence loss when using KD"
# },
# )
adam_beta3: Optional[float] = field(
default=None,
metadata={
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
},
)
adam_epsilon2: Optional[float] = field(
default=None,
metadata={
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
},
)
# multi-modal section
image_size: int | tuple[int, int] | None = field(
default=None,
metadata={"help": "The size of the image to resize to"},
)
image_resize_algorithm: Resampling | None = field(
default=None,
metadata={"help": "The algorithm to use for image resizing"},
)
# end of multi-modal section

View File

@@ -22,7 +22,6 @@ from __future__ import annotations
import collections
import importlib
import traceback
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
from peft import PeftModel
@@ -84,11 +83,6 @@ class BasePlugin:
def get_input_args(self) -> str | None:
"""Returns a pydantic model for the plugin's input arguments."""
def get_training_args_mixin(self) -> str | None:
"""
Returns a dataclass model for the plugin's training arguments.
"""
def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:
@@ -164,31 +158,6 @@ class BasePlugin:
trainer: The trainer object for training.
"""
def get_training_args(self, cfg: DictDefault): # pylint: disable=unused-argument):
"""
Returns custom training arguments to set on TrainingArgs.
Args:
cfg: The global axolotl configuration.
Returns:
object: dict containing the training arguments.
"""
def get_collator_cls_and_kwargs(
self, cfg: DictDefault, is_eval: bool = False
): # pylint: disable=unused-argument):
"""
Returns a custom class for the collator.
Args:
cfg: The global axolotl configuration.
is_eval: Whether this is an eval split.
Returns:
class: The class for the collator.
"""
# pylint: disable=unused-argument
def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None:
"""Creates and returns an optimizer for training.
@@ -309,7 +278,7 @@ def load_plugin(plugin_name: str) -> BasePlugin:
return plugin
class PluginManager: # pylint: disable=too-many-public-methods
class PluginManager:
"""The `PluginManager` class is responsible for loading and managing plugins. It
should be a singleton so it can be accessed from anywhere in the codebase.
@@ -368,11 +337,8 @@ class PluginManager: # pylint: disable=too-many-public-methods
plugin = load_plugin(plugin_name)
self.plugins[plugin_name] = plugin
LOG.info(f"Plugin loaded successfully: {plugin_name}")
except ImportError as exc:
except ImportError:
LOG.error(f"Failed to load plugin: {plugin_name}")
# print stacktrace
traceback.print_exc()
print(f"Error: {exc}")
def get_input_args(self) -> list[str]:
"""Returns a list of Pydantic classes for all registered plugins' input arguments.'
@@ -387,20 +353,6 @@ class PluginManager: # pylint: disable=too-many-public-methods
input_args.append(input_args_from_plugin)
return input_args
def get_training_args_mixin(self):
"""
Returns a list of dataclasses for all registered plugins' training args mixins'
Returns:
list[str]: A list of dataclsses
"""
training_args = []
for plugin in self.plugins.values():
training_args_from_plugin = plugin.get_training_args_mixin()
if training_args_from_plugin is not None:
training_args.append(training_args_from_plugin)
return training_args
def load_datasets(
self, cfg: DictDefault, preprocess: bool = False
) -> Union["TrainDatasetMeta", None]:
@@ -490,42 +442,6 @@ class PluginManager: # pylint: disable=too-many-public-methods
return trainer_cls
return None
def get_training_args(self, cfg):
"""
Calls the get_training_args method of all registered plugins and returns the combined training arguments.
Parameters:
cfg (dict): The configuration for the plugins.
Returns:
object: The training arguments
"""
training_args_kwargs = {}
for plugin in self.plugins.values():
training_args = plugin.get_training_args(cfg)
if training_args is not None:
training_args_kwargs.update(training_args)
return training_args_kwargs
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
"""
Calls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class.
Parameters:
cfg (dict): The configuration for the plugins.
is_eval (bool): Whether this is an eval split.
Returns:
object: The collator class, or None if none was found.
"""
for plugin in self.plugins.values():
collator = plugin.get_collator_cls_and_kwargs(cfg, is_eval=is_eval)
if collator is not None:
collator_cls, collator_kwargs = collator
return collator_cls, collator_kwargs
return None
def post_trainer_create(self, cfg: DictDefault, trainer: Trainer):
"""Calls the `post_trainer_create` method of all registered plugins.

View File

@@ -16,7 +16,7 @@ Module to handle merging the plugins' input arguments with the base configuratio
This was moved here to prevent circular imports.
"""
from typing import Any, Dict, List, Type
from typing import Any, Dict, List
from axolotl.utils.schemas.config import (
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
@@ -61,43 +61,3 @@ def merge_input_args():
]
return AxolotlConfigWCapabilities, AxolotlInputConfig
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
def merge_training_args() -> Type:
"""
Merges training arguments from registered plugins with the base TrainingArguments.
This function retrieves the training arguments from registered plugins using the PluginManager.
It then dynamically creates new classes, AxolotlTrainingMixins,
that inherit from the base configurations and include the training arguments from the plugins.
Returns:
tuple: A tuple containing the newly created classes, AxolotlTrainingMixins.
"""
# pylint: disable=duplicate-code
from axolotl.core.training_args_base import (
AxolotlTrainingMixins as AxolotlTrainingMixinsBase,
)
from axolotl.integrations.base import PluginManager
plugin_manager = PluginManager.get_instance()
training_args_mixins: List[str] = plugin_manager.get_training_args_mixin()
mixin_classes = []
dynamic_input = ""
for plugin_args in training_args_mixins:
plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
mixin_classes.append(plugin_cls)
if dynamic_input:
dynamic_input += f"class AxolotlTrainingMixins(AxolotlTrainingMixinsBase, {', '.join(mixin_classes)}):\n pass\n"
namespace: Dict[Any, Any] = {}
local_vars = {"AxolotlTrainingMixinsBase": AxolotlTrainingMixinsBase}
exec( # pylint: disable=exec-used # nosec B102
dynamic_input, {**globals(), **local_vars}, namespace
)
AxolotlTrainingMixins = namespace[ # pylint: disable=invalid-name
"AxolotlTrainingMixins"
]
return AxolotlTrainingMixins
return AxolotlTrainingMixinsBase

View File

@@ -21,32 +21,3 @@ datasets:
```
An example dataset can be found at [`axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample`](https://huggingface.co/datasets/axolotl-ai-co/evolkit-logprobs-pipeline-75k-v2-sample)
## Online KD (sglang)
```bash
export UV_TORCH_BACKEND=cu124
uv venv sglang --python 3.11
source sglang/bin/activate
uv pip install --upgrade pip
uv pip install setuptools
uv pip install torch~=2.5.1 --index-url https://download.pytorch.org/whl/cu124
uv pip install sgl-kernel --force-reinstall --no-deps
uv pip install "sglang[all]>=0.4.2.post4" --find-links https://flashinfer.ai/whl/cu124/torch2.5/flashinfer/
```
## Online KD (vllm)
```bash
VLLM_USE_V1=0 vllm serve open-r1/OlympicCoder-32B --max-model-len 16400 --port 8888 --max-logprobs 128 --return-tokens-as-token-ids --tensor-parallel-size 8 --max-num-seqs
256 --gpu_memory_utilization 0.2 --enable-chunked-prefill
```
```bash
vllm serve open-r1/OlympicCoder-32B --max-model-len 16400 --port 8888 --max-logprobs 128 --return-tokens-as-token-ids --tensor-parallel-size 8 --no-enable-prefix-caching --gpu-memory-utilization 0.3 --max-num-batched-tokens 131072 --host 0.0.0.0
```
```bash
python -m sglang.launch_server --model-path open-r1/OlympicCoder-32B --tensor-parallel-size 8 --port 8080 --host 0.0.0.0 --max-running-requests 256 --context-length 16400 --mem-fraction-static 0.2 --schedule-conservativeness 0.3 --chunked-prefill-size 131072 --schedule-policy fcfs --skip-tokenizer-init
```

View File

@@ -15,12 +15,7 @@
"""
Plugin init to add KD support to Axolotl.
"""
from typing import Any
from transformers import Trainer
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.kd.callbacks import KDTemperatureSchedulerCallback
from .args import KDArgs # pylint: disable=unused-import. # noqa: F401
@@ -33,75 +28,9 @@ class KDPlugin(BasePlugin):
def get_input_args(self):
return "axolotl.integrations.kd.KDArgs"
def get_training_args_mixin(self):
return "axolotl.integrations.kd.args.KDTrainingArgsMixin"
def get_trainer_cls(self, cfg):
if cfg.kd_trainer:
from .trainer import AxolotlKDTrainer
return AxolotlKDTrainer
return None
def get_training_args(self, cfg):
return {
"kd_ce_alpha": cfg.kd_ce_alpha,
"kd_alpha": cfg.kd_alpha,
"kd_temperature": cfg.kd_temperature,
"kd_beta": cfg.kd_beta,
"kd_normalize_topk": cfg.kd_normalize_topk,
}
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
if not cfg.kd_trainer:
return None, None
from .collator import DataCollatorForKD, KDBatchSamplerDataCollatorForSeq2Seq
use_batch_sampler_collator = False
if is_eval is False and cfg.sample_packing:
use_batch_sampler_collator = True
if cfg.eval_sample_packing and is_eval:
use_batch_sampler_collator = True
if cfg.kd_online_server_base_url:
from .collator_online_teacher import OnlineTeacherCollator
return OnlineTeacherCollator, {
"kd_online_server_base_url": cfg.kd_online_server_base_url,
"kd_online_topk": cfg.kd_online_topk,
"kd_temperature": cfg.kd_temperature,
"kd_online_server": cfg.kd_online_server,
"kd_online_timeout": cfg.kd_online_timeout,
"kd_normalize_topk": cfg.kd_normalize_topk,
}
if use_batch_sampler_collator:
return KDBatchSamplerDataCollatorForSeq2Seq, {}
return DataCollatorForKD, {}
def pre_model_load(self, cfg):
from .kernels.models import apply_kernel
apply_kernel(cfg.model_config_type)
def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
"""
Adds temp scheduler callback to the Trainer instance.
Args:
cfg (Any): Configuration object containing the sparse recipe.
trainer (Trainer): Huggingface Trainer instance.
Returns:
list: List containing the configured callback instances.
"""
if cfg.kd_temperature_min is not None and cfg.kd_online_server_base_url:
callback = KDTemperatureSchedulerCallback(
cfg.kd_temperature,
cfg.kd_temperature_min,
trainer,
)
return [callback]
return []

View File

@@ -15,19 +15,9 @@
"""
Plugin args for KD support.
"""
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from pydantic import BaseModel, Field
class InferenceServerType(str, Enum):
"""
Online inferences server types to handle different request args
"""
vllm = "vllm" # pylint: disable=invalid-name
sglang = "sglang" # pylint: disable=invalid-name
from pydantic import BaseModel
class KDArgs(BaseModel):
@@ -35,41 +25,13 @@ class KDArgs(BaseModel):
Input args for knowledge distillation.
"""
kd_trainer: float | None = None # whether to use KD trainer
kd_ce_alpha: float | None = (
kd_trainer: Optional[bool] = None # whether to use KD trainer
kd_ce_alpha: Optional[float] = (
None # loss coefficient for cross-entropy loss during KD
)
kd_alpha: float | None = None # loss coefficient for KD loss
kd_temperature: float | None = None # temperature for sampling during KD
kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
kd_normalize_topk: bool | None = (
None # whether to normalize student logits during KD
)
# TODO online kd
kd_online_server_base_url: str | None = None
kd_online_topk: int | None = None
kd_online_server: InferenceServerType | None = Field(
default_factory=lambda: InferenceServerType.vllm
)
kd_online_timeout: int | None = 120
kd_temperature_min: float | None = (
None # kd temperature scheduling during online kd
)
@dataclass
class KDTrainingArgsMixin:
"""
Additional args for KD training.
"""
kd_ce_alpha: float | None = (
None # loss coefficient for cross-entropy loss during KD
)
kd_alpha: float | None = None # loss coefficient for KD loss
kd_temperature: float | None = None # temperature for sampling during KD
kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
kd_normalize_topk: float | None = (
None # whether to normalize student logits during KD
kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_temperature: Optional[float] = None # temperature for sampling during KD
kd_zscore_base_temp: Optional[float] = None # base temperature for zscore scaling
kd_top_k_before_softmax: Optional[bool] = (
None # whether to sample top k before softmax during KD
)

View File

@@ -1,36 +0,0 @@
"""
Transformers trainer callbacks to schedule the KD temperature during training
"""
import math
from transformers.trainer_callback import TrainerCallback
class KDTemperatureSchedulerCallback(TrainerCallback):
"""
KD temperature scheduler callback for the trainer.
"""
def __init__(self, temperature_start, temperature_min, trainer):
self.temperature_start = temperature_start
self.temperature_min = temperature_min
self.temperature = temperature_start
self.trainer = trainer
def on_step_end(
self, args, state, control, **kwargs
): # pylint: disable=unused-argument
# cosine decay temperature over the max steps
progress = state.global_step / state.max_steps
# Cosine decay factor: 0.5 * (1 + cos(pi * progress))
# This factor goes from 1 (at progress=0) to 0 (at progress=1)
decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
self.temperature = self.temperature_start - (
(self.temperature_start - self.temperature_min) * (1.0 - decay_factor)
)
if hasattr(self.trainer.data_collator, "kd_temperature"):
self.trainer.data_collator.kd_temperature = self.temperature

View File

@@ -15,15 +15,12 @@
"""
Chat template prompt strategy loader with KD support
"""
import logging
from typing import Any, Dict
import torch
from axolotl.prompt_strategies.chat_template import ChatTemplateStrategy, StrategyLoader
LOG = logging.getLogger(__name__)
class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
"""
@@ -104,8 +101,10 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
# fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
# we shift for causal models in the trainer, so start the range from 0
for _ in range(0, input_padding_len):
# for causal models, if we start the range at 1, then we don't need to shift in the trainer
# otherwise, we need to shift in the trainer
shift = 0
for _ in range(shift, input_padding_len):
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
@@ -144,10 +143,6 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
#
# Convert from log to probability
teacher_probs_t1 = position_logprobs_tensor.exp()
# normalize probabilities to sum to 1 in case they aren't already
teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True)
if teacher_probs_t1_sum > 1e-9:
teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum
if self.kd_temperature != self.gen_temperature:
# Exponentiate by factor (T1 / T2)
exponent = self.gen_temperature / self.kd_temperature
@@ -167,115 +162,12 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
target_logprobs.append(position_logprobs_scaled)
target_token_ids.append(position_token_ids)
# Update sample with transformed logprobs
sample["target_logprobs"] = target_logprobs
sample["target_token_ids"] = target_token_ids
sample["target_mask"] = target_mask
return sample
def _tokenize_single_prompt(self, prompt):
logprobs = prompt.pop(self.logprobs_field)
tokenized_prompt = super()._tokenize_single_prompt(prompt)
tokenized_prompt[self.logprobs_field] = logprobs
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
return tokenized_prompt
class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
"""
Strat for datasets with complete structured KD logprob data
"""
def transform_logprobs(self, sample):
"""
Transform logprobs to target format for KD training
"""
# pylint: disable=duplicate-code
logprobs = sample.pop(self.logprobs_field)
target_seq_len = len(logprobs)
input_seq_len = len(sample["input_ids"])
input_padding_len = input_seq_len - target_seq_len
# get non-zero top-k (prune None logprobs from vllm data step)
top_k_vals = [
len(logprobs[i])
for i in range(len(logprobs))
if logprobs[i] is not None and len(logprobs[i])
]
max_top_k = max(set(top_k_vals), key=top_k_vals.count)
min_top_k = min(set(top_k_vals), key=top_k_vals.count)
top_k = min(max_top_k, min_top_k)
if top_k == 0:
raise ValueError("No non-zero top-k logprobs found.")
target_logprobs = []
target_token_ids = []
target_mask = []
if input_padding_len < 0:
# logprobs is longer than target_seq_len,
# so we need to slice from the left/beginning of logprobs
logprobs = logprobs[:-input_seq_len]
input_padding_len = 0
# target_seq_len = input_seq_len
# truncate the second dimension of the logprobs to top_k
logprobs = [row[:top_k] for row in logprobs]
# fill with -inf for padding_len tokens for top_k tokens
# extend target_logprobs with a padding_len x top_k 2D list filled with -inf
# we shift for causal models in the trainer, so start the range from 0
for _ in range(0, input_padding_len):
if shift == 1:
# since we started at index 1 for causal, we need one more padding token
target_logprobs.append([-float("inf")] * top_k)
target_token_ids.append(list(range(top_k)))
target_mask.append([0] * top_k)
for position in range(input_padding_len, input_seq_len):
if sample["labels"][position] == -100:
target_mask.append([0] * top_k)
else:
target_mask.append([1] * top_k)
for token_pos_logprobs, pos_target_token_ids in zip(
logprobs, sample["target_token_ids"]
):
# Convert to a tensor for easier manipulation
position_logprobs_tensor = torch.tensor(
token_pos_logprobs, dtype=torch.float
)
# Now we have distribution at T1 in log form, i.e. log p_{T1}(k).
# Next, re-scale to T2 = self.kd_temperature via exponent-based trick
# p_{T2}(k) = [p_{T1}(k)]^(T1 / T2) / Z
#
# Convert from log to probability
teacher_probs_t1 = position_logprobs_tensor.exp()
# normalize probabilities to sum to 1 in case they aren't already
teacher_probs_t1_sum = teacher_probs_t1.sum(dim=0, keepdim=True)
if teacher_probs_t1_sum > 1e-9:
teacher_probs_t1 = teacher_probs_t1 / teacher_probs_t1_sum
if self.kd_temperature != self.gen_temperature:
# Exponentiate by factor (T1 / T2)
exponent = self.gen_temperature / self.kd_temperature
teacher_probs_t2 = teacher_probs_t1**exponent
else:
teacher_probs_t2 = teacher_probs_t1
# Re-normalize
teacher_probs_t2 = teacher_probs_t2 / teacher_probs_t2.sum(
dim=0, keepdim=True
)
# Convert back to log
position_logprobs_tensor = torch.log(teacher_probs_t2)
# Now we have log p_{teacher, T2}(k) stored in position_logprobs_tensor
position_logprobs_scaled = position_logprobs_tensor.tolist()
target_logprobs.append(position_logprobs_scaled)
target_token_ids.append(pos_target_token_ids)
# Update sample with transformed logprobs
sample["target_logprobs"] = target_logprobs
sample["target_token_ids"] = target_token_ids
@@ -285,10 +177,8 @@ class ChatTemplateStrategyWithKDv2(ChatTemplateStrategyWithKD):
def _tokenize_single_prompt(self, prompt):
logprobs = prompt.pop(self.logprobs_field)
target_token_ids = prompt.pop("target_token_ids")
tokenized_prompt = super()._tokenize_single_prompt(prompt)
tokenized_prompt[self.logprobs_field] = logprobs
tokenized_prompt["target_token_ids"] = target_token_ids
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
return tokenized_prompt
@@ -300,7 +190,7 @@ class KDStrategyLoader(StrategyLoader):
"""
def _get_strategy_cls(self):
return ChatTemplateStrategyWithKDv2
return ChatTemplateStrategyWithKD
def _get_strategy_params(self, cfg, ds_cfg: Dict[str, Any]):
strategy_params = super()._get_strategy_params(cfg, ds_cfg)

View File

@@ -47,16 +47,11 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
position_pad_token_id: int = 0
return_tensors: str = "pt"
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
def __call__(self, features, return_tensors=None):
if return_tensors is None:
return_tensors = self.return_tensors
padding_side = self.tokenizer.padding_side
max_len = 0
# Pad labels and position_ids first
for feature_name, pad_token_id in [
@@ -107,9 +102,7 @@ class DataCollatorForKD(DataCollatorForSeq2Seq):
target_mask_list.append(f.pop("target_mask"))
# Determine max lengths
max_teacher_seq_len = max_len or max(
len(seq) for seq in target_logprobs_list
)
max_teacher_seq_len = max(len(seq) for seq in target_logprobs_list)
max_k = max(len(seq_k) for seq in target_logprobs_list for seq_k in seq)
padded_target_logprobs = []
@@ -216,9 +209,7 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
# We want to produce a single "merged" feature dict for each sub-batch.
out_features = [{} for _ in features]
for i, sub_features in enumerate( # pylint: disable=too-many-nested-blocks
features
):
for i, sub_features in enumerate(features):
# sub_features is a list of dicts, each dict = one sequences features
# We'll merge them into out_features[i].
#
@@ -252,17 +243,10 @@ class KDBatchSamplerDataCollatorForSeq2Seq(DataCollatorForKD):
# For example, input_ids or labels are often arrays.
arrays = []
for feat in sub_features:
if field_name in feat and isinstance(
feat[field_name], (list, torch.Tensor)
):
if isinstance(
feat[field_name][0], (dict, str)
): # pylint: disable=too-many-nested-blocks
continue
if field_name in feat:
arr = np.array(feat[field_name])
arrays.append(arr)
if arrays:
out_features[i][field_name] = np.concatenate(arrays)
out_features[i][field_name] = np.concatenate(arrays)
# 3) Now call the parent collator, which will do:
# - padding of labels/position_ids

View File

@@ -1,561 +0,0 @@
"""
Packed data loader for online teacher training supporting vllm and sglang.
"""
import hashlib
import hmac
import logging
from typing import Any, Dict, List, Optional
import requests
import torch
from orjson import orjson
from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq
from axolotl.integrations.kd.utils import normalize_logprobs
from axolotl.utils.data.utils import retry_on_request_exceptions
LOG = logging.getLogger(__name__)
def hmac_sha_from_int_list(int_list, key, hash_func=hashlib.sha256):
"""
Create HMAC-SHA hash from a list of integers
Args:
int_list: List of integers
key: Secret key (string or bytes)
hash_func: Hash function (default: sha256)
Returns:
HMAC digest as hex string
"""
# Convert key to bytes if it's a string
if isinstance(key, str):
key = key.encode("utf-8")
# Convert list of ints to bytes
# Method 1: Convert each int to bytes and concatenate
data = b"".join(i.to_bytes(4, byteorder="big") for i in int_list)
# Create HMAC
h = hmac.new(key, data, hash_func)
return h.hexdigest()
class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
"""
Collator for online teacher training.
"""
DEFAULT_LABEL_PAD_TOKEN_ID: int = -100
def __init__(
self,
*args: Any,
kd_online_server_base_url: Optional[str] = None,
kd_online_topk: Optional[int] = None,
kd_temperature: Optional[float] = 1.0,
kd_online_server: Optional[str] = "vllm",
kd_online_timeout: Optional[int] = 120,
kd_cache_dir: Optional[str] = None,
kd_normalize_topk: Optional[bool] = True,
**kwargs: Any,
):
super().__init__(*args, **kwargs)
if kd_online_server_base_url is None:
raise ValueError(
"kd_online_server_base_url must be provided for OnlineTeacherDataloader"
)
if kd_online_topk is None or kd_online_topk <= 0:
raise ValueError(
"kd_online_topk must be a positive integer for OnlineTeacherDataloader"
)
self.kd_online_server_base_url = kd_online_server_base_url.rstrip("/")
self.kd_online_topk = kd_online_topk
self.kd_temperature = kd_temperature
self.kd_online_server = kd_online_server
self.http_session = requests.Session()
self.kd_online_timeout = kd_online_timeout
self.kd_cache_dir = kd_cache_dir
self.kd_normalize_topk = kd_normalize_topk
def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]:
"""
Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.
"""
if not raw_logprobs or self.kd_online_topk == 0:
return (
[-float("inf")] * self.kd_online_topk if self.kd_online_topk > 0 else []
)
raw_logprobs_tensor = torch.tensor(raw_logprobs, dtype=torch.float32)
return normalize_logprobs(raw_logprobs_tensor, self.kd_online_topk).tolist()
@retry_on_request_exceptions(max_retries=10, delay=5)
def fetch_online_logprobs_sglang(
self, batch_input_ids: List[List[int]], labels: List[List[int]]
):
"""
Fetches logprobs from an online teacher served by sglang for a batch of input_ids.
Assumes API returns token IDs as strings in logprob dictionary keys.
"""
api_endpoint = f"{self.kd_online_server_base_url}/generate"
payload = {
"input_ids": batch_input_ids,
"return_logprob": True,
"top_logprobs_num": self.kd_online_topk,
"logprob_start_len": 0,
"return_text_in_logprobs": True,
"echo": True,
"sampling_params": {
"max_new_tokens": 0,
"temperature": self.kd_temperature,
"skip_special_tokens": False,
},
}
# Initialize with empty lists, so if API call fails, these are returned.
ret_data_target_token_ids: List[List[List[int]]] = []
ret_data_target_logprobs: List[List[List[float]]] = []
ret_data_target_mask: List[List[List[int]]] = []
try:
response = self.http_session.post(
api_endpoint, json=payload, timeout=self.kd_online_timeout
)
response.raise_for_status()
api_data: list[dict] = response.json()
# Ensure api_data is a list, and its length matches batch_input_ids
if not isinstance(api_data, list) or len(api_data) != len(batch_input_ids):
LOG.error(
f"API response format error. Expected a list of {len(batch_input_ids)} "
f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}."
)
# Return empty data; items processed later will get default empty KD fields
return {
"target_token_ids": ret_data_target_token_ids,
"target_logprobs": ret_data_target_logprobs,
"target_mask": ret_data_target_mask,
}
for sequence_data, seq_input_ids, seq_labels in zip(
api_data, batch_input_ids, labels
):
current_target_logprobs = []
current_target_token_ids = []
current_target_mask = []
meta_info = sequence_data.pop("meta_info", {})
# Ensure input_top_logprobs is a list
input_top_logprobs: Optional[list[None | list[tuple]]] = meta_info.pop(
"input_top_logprobs", []
)
if not isinstance(input_top_logprobs, list):
LOG.warning(
f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
)
input_top_logprobs = [] # Treat as empty
# basic check that the logprob data len matches the input len, so no need to handle padding
assert len(seq_input_ids) == len(input_top_logprobs)
for i, _, label in zip(
range(len(seq_input_ids)), seq_input_ids, seq_labels
):
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
# this is always the case for the first token.
# there is never logprob data for the first token since that's a true input
# so we replace the None value with padding data
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append([0] * self.kd_online_topk)
current_target_mask.append([0] * self.kd_online_topk)
elif (
i < len(input_top_logprobs)
and input_top_logprobs[i] is not None
):
pos_top_logprobs_data = input_top_logprobs[i]
# Ensure pos_top_logprobs_data is a list of lists as expected
if not (
isinstance(pos_top_logprobs_data, list)
and all(
isinstance(item, list) for item in pos_top_logprobs_data
)
and len(pos_top_logprobs_data) > 0
and len(pos_top_logprobs_data[0]) == 3
): # [logprob, token_id, token_str]
LOG.warning(
f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position."
)
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append([0] * self.kd_online_topk)
current_target_mask.append([0] * self.kd_online_topk)
continue
# pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids
pos_logprobs_raw, pos_token_ids, _ = [
list(row) for row in zip(*pos_top_logprobs_data)
]
# Ensure correct length (top_k)
if len(pos_logprobs_raw) < self.kd_online_topk:
pad_len = self.kd_online_topk - len(pos_logprobs_raw)
pos_logprobs_raw.extend([-float("inf")] * pad_len)
pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
# truncate to top_k in case the response was longer
current_target_token_ids.append(
pos_token_ids[: self.kd_online_topk]
)
if self.kd_normalize_topk:
normalized_logprobs_for_position = self._normalize_logprobs(
pos_logprobs_raw[: self.kd_online_topk]
)
current_target_logprobs.append(
normalized_logprobs_for_position
)
else:
current_target_logprobs.append(
pos_logprobs_raw[: self.kd_online_topk]
)
# Mask depends on the corresponding label for the student
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
current_target_mask.append([0] * self.kd_online_topk)
else:
current_target_mask.append([1] * self.kd_online_topk)
else:
# Pad if no logprobs for this position (either due to length mismatch or None entry)
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append([0] * self.kd_online_topk)
current_target_mask.append([0] * self.kd_online_topk)
ret_data_target_token_ids.append(current_target_token_ids)
ret_data_target_logprobs.append(current_target_logprobs)
ret_data_target_mask.append(current_target_mask)
except requests.exceptions.RequestException as e:
LOG.error(f"Error fetching logprobs from online teacher: {e}")
raise e
# ret_logprobs_data will be returned with empty lists, handled by the caller.
except Exception as e: # Catch other potential errors during processing
LOG.error(
f"Unexpected error processing API response in fetch_online_logprobs: {e}",
exc_info=True,
)
raise e
return {
"target_token_ids": ret_data_target_token_ids,
"target_logprobs": ret_data_target_logprobs,
"target_mask": ret_data_target_mask,
}
@retry_on_request_exceptions(max_retries=10, delay=5)
def fetch_online_logprobs_vllm(
self, batch_input_ids: List[List[int]], labels: List[List[int]]
):
"""
Fetches logprobs from an online teacher served by vllm for a batch of input_ids.
Assumes API returns token IDs as strings in logprob dictionary keys.
"""
api_endpoint = f"{self.kd_online_server_base_url}/v1/completions"
payload = {
"prompt": batch_input_ids,
"echo": True,
"logprobs": True,
"prompt_logprobs": self.kd_online_topk,
"top_logprobs": self.kd_online_topk,
"max_new_tokens": 0,
"skip_special_tokens": False,
"temperature": self.kd_temperature,
"sampling_params": {
"max_tokens": 0,
},
}
# Initialize with empty lists, so if API call fails, these are returned.
ret_data_target_token_ids: List[List[List[int]]] = []
ret_data_target_logprobs: List[List[List[float]]] = []
ret_data_target_mask: List[List[List[int]]] = []
try:
headers = {"Accept-Encoding": "deflate, gzip, br, zstd"}
response = self.http_session.post(
api_endpoint,
json=payload,
headers=headers,
timeout=self.kd_online_timeout,
)
response.raise_for_status()
api_data: dict = orjson.loads(response.content)
choices: list[dict] = api_data["choices"]
# Ensure api_data is a list, and its length matches batch_input_ids
if not isinstance(choices, list) or len(choices) != len(batch_input_ids):
LOG.error(
f"API response format error. Expected a list of {len(batch_input_ids)} "
f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}."
)
# Return empty data; items processed later will get default empty KD fields
return {
"target_token_ids": ret_data_target_token_ids,
"target_logprobs": ret_data_target_logprobs,
"target_mask": ret_data_target_mask,
}
for sequence_data, seq_input_ids, seq_labels in zip(
choices, batch_input_ids, labels
):
# seq_input_ids: List[int]
# seq_labels: List[int]
current_target_logprobs = []
current_target_token_ids = []
current_target_mask = []
# Ensure input_top_logprobs is a list
input_top_logprobs: Optional[list[None | dict[str, dict]]] = (
sequence_data.pop("prompt_logprobs", [])
)
if not isinstance(input_top_logprobs, list):
LOG.warning(
f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
)
input_top_logprobs = [] # Treat as empty
# basic check that the logprob data len matches the input len, so no need to handle padding
assert len(seq_input_ids) == len(input_top_logprobs)
seq_len = len(seq_input_ids)
for i, _, label in zip(range(seq_len), seq_input_ids, seq_labels):
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
# this is always the case for the first token.
# there is never logprob data for the first token since that's a true input
continue
if (
i < len(input_top_logprobs)
and input_top_logprobs[i] is not None
):
pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i] # type: ignore[assignment]
# Ensure pos_top_logprobs_data is a list of lists as expected
if not (
isinstance(pos_top_logprobs_data, dict)
and all(
isinstance(item, dict)
for item in pos_top_logprobs_data.values()
)
and len(pos_top_logprobs_data.keys()) > 0
): # [logprob, token_id, token_str]
LOG.warning(
f"Malformed pos_top_logprobs_data: {pos_top_logprobs_data}. Padding this position."
)
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append(
list(range(self.kd_online_topk))
)
current_target_mask.append([0] * self.kd_online_topk)
continue
# pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids
pos_token_ids_str = list(pos_top_logprobs_data.keys())
pos_logprobs_dict = pos_top_logprobs_data.values()
pos_token_ids = [
int(token_id) for token_id in pos_token_ids_str
]
pos_logprobs_raw = [
float(logprob.get("logprob", -float("inf")))
for logprob in pos_logprobs_dict
]
# Ensure correct length (top_k)
if len(pos_logprobs_raw) < self.kd_online_topk:
pad_len = self.kd_online_topk - len(pos_logprobs_raw)
LOG.warning(
f"Padding position {i} with {pad_len} top-k tokens and logprobs."
)
pos_logprobs_raw.extend([-float("inf")] * pad_len)
pos_token_ids.extend([0] * pad_len) # Pad with 0 token_id
# truncate to top_k in case the response was longer
current_target_token_ids.append(
pos_token_ids[: self.kd_online_topk]
)
if self.kd_normalize_topk:
normalized_logprobs_for_position = self._normalize_logprobs(
pos_logprobs_raw[: self.kd_online_topk]
)
current_target_logprobs.append(
normalized_logprobs_for_position
)
else:
current_target_logprobs.append(
pos_logprobs_raw[: self.kd_online_topk]
)
# Mask depends on the corresponding label for the student
if label == self.DEFAULT_LABEL_PAD_TOKEN_ID:
current_target_mask.append([0] * self.kd_online_topk)
else:
current_target_mask.append([1] * self.kd_online_topk)
else:
# Pad if no logprobs for this position (either due to length mismatch or None entry)
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append(
list(range(self.kd_online_topk))
)
current_target_mask.append([0] * self.kd_online_topk)
for i in range(max(0, seq_len - len(current_target_logprobs))):
current_target_logprobs.append(
[-float("inf")] * self.kd_online_topk
)
current_target_token_ids.append(list(range(self.kd_online_topk)))
current_target_mask.append([0] * self.kd_online_topk)
ret_data_target_token_ids.append(current_target_token_ids)
ret_data_target_logprobs.append(current_target_logprobs)
ret_data_target_mask.append(current_target_mask)
# TODO save and load targets to disk for caching for next epoch
# generate a hmac SHA256 hash over the list seq_input_ids and convert it to an int
# if self.kd_cache_dir:
# hash_input_ids = hmac_sha_from_int_list(
# seq_input_ids, f"{self.kd_online_server_base_url}:{self.kd_online_topk}"
# )
# with open(f"{self.kd_cache_dir}/{hash_input_ids}.parquet", "wb") as f:
# pd.DataFrame(ret_logprobs_data).to_parquet(f, index=False)
except requests.exceptions.RequestException as e:
LOG.error(f"Error fetching logprobs from online teacher: {e}")
raise e
# ret_logprobs_data will be returned with empty lists, handled by the caller.
except Exception as e: # Catch other potential errors during processing
LOG.error(
f"Unexpected error processing API response in fetch_online_logprobs: {e}",
exc_info=True,
)
raise e
return {
"target_token_ids": ret_data_target_token_ids,
"target_logprobs": ret_data_target_logprobs,
"target_mask": ret_data_target_mask,
}
def __call__(
self, features: List[List[Dict[str, Any]]], return_tensors: Optional[str] = None
) -> Dict[str, Any]:
if not features:
return super().__call__(features, return_tensors=return_tensors)
for (
sub_batch_features
) in features: # sub_batch_features is List[Dict[str, Any]]
if not sub_batch_features:
continue
input_ids_for_api_call: List[List[int]] = []
labels_for_api_call: List[List[int]] = []
# Store references to the original item dictionaries to update them in-place
items_for_api_call: List[Dict[str, Any]] = []
for item_dict in sub_batch_features:
if not isinstance(item_dict, dict):
LOG.warning(
f"Skipping non-dict item in sub_batch_features: {item_dict}"
)
continue
current_input_ids = item_dict.get("input_ids")
current_labels = item_dict.get("labels")
if current_input_ids is not None and current_labels is not None:
# Ensure input_ids and labels are lists of ints for JSON serialization
input_ids_list = (
current_input_ids.tolist()
if hasattr(current_input_ids, "tolist")
else list(current_input_ids)
)
labels_list = (
current_labels.tolist()
if hasattr(current_labels, "tolist")
else list(current_labels)
)
input_ids_for_api_call.append(input_ids_list)
labels_for_api_call.append(labels_list)
items_for_api_call.append(item_dict)
else:
# This item will not get teacher logprobs from the API.
# Initialize KD fields to empty lists so downstream collators handle them uniformly.
item_dict.setdefault("target_token_ids", [])
item_dict.setdefault("target_logprobs", [])
item_dict.setdefault("target_mask", [])
# print(items_for_api_call)
if items_for_api_call: # Only call API if there's something to process
if self.kd_online_server == "sglang":
api_responses_for_sub_batch = self.fetch_online_logprobs_sglang(
input_ids_for_api_call, labels_for_api_call
)
else:
api_responses_for_sub_batch = self.fetch_online_logprobs_vllm(
input_ids_for_api_call, labels_for_api_call
)
# api_responses_for_sub_batch has keys: "target_token_ids", "target_logprobs", "target_mask"
# Each value is a list, corresponding to items_for_api_call
for i, item_to_update in enumerate(items_for_api_call):
# TODO make sure to figure out which input in sub_batch_features to update the batch in the original `features` object so the super class can handle it properly.
if api_responses_for_sub_batch and i < len(
api_responses_for_sub_batch["target_token_ids"]
): # Check bounds
assert len(
api_responses_for_sub_batch["target_token_ids"][i]
) == len(item_to_update["input_ids"])
assert len(
api_responses_for_sub_batch["target_logprobs"][i]
) == len(item_to_update["input_ids"])
assert len(
api_responses_for_sub_batch["target_mask"][i]
) == len(item_to_update["labels"])
item_to_update["target_token_ids"] = (
api_responses_for_sub_batch["target_token_ids"][i]
)
item_to_update["target_logprobs"] = api_responses_for_sub_batch[
"target_logprobs"
][i]
item_to_update["target_mask"] = api_responses_for_sub_batch[
"target_mask"
][i]
else:
# API call failed for this item, or response was shorter than expected.
# Ensure KD fields are initialized as empty lists.
LOG.warning(
f" (index {i}), or API response was too short. "
f"API response keys: {list(api_responses_for_sub_batch.keys()) if api_responses_for_sub_batch else 'None'}"
)
item_to_update.setdefault("target_token_ids", [])
item_to_update.setdefault("target_logprobs", [])
item_to_update.setdefault("target_mask", [])
return super().__call__(features, return_tensors=return_tensors)

View File

@@ -1,485 +0,0 @@
"""
Liger Kernels for Chunked Top-K Log-Prob Distillation
"""
import torch
import torch.nn.functional as F
from liger_kernel.chunked_loss.fused_linear_distillation import (
LigerFusedLinearDistillationBase,
)
from axolotl.integrations.kd.utils import normalize_logprobs
class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
"""
Chunked kl-div loss for top-k logprobs
"""
@staticmethod
def distillation_loss_fn(
student_logits_temp_scaled: torch.Tensor, # [chunk_size, vocab_size], already temp-scaled
target_token_ids_chunk: torch.Tensor, # [chunk_size, top_k]
target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs
target_mask_chunk: torch.Tensor, # [chunk_size, top_k]
beta: float = 0.0,
normalize_topk: bool = True,
) -> torch.Tensor:
"""
Compute Top-K KL divergence loss for a chunk.
Args:
student_logits_temp_scaled: Student logits, scaled by temperature. Shape: (N, V).
target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K).
target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K).
target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K).
beta: Controls the type of KL divergence.
0.0 for Forward KL (P_teacher || P_student).
1.0 for Reverse KL (P_student || P_teacher).
0.5 for Symmetric KL (average of Forward and Reverse).
normalize_topk: Whether to normalize the log probabilities
Returns:
Sum of KL divergence losses for the chunk.
"""
topk = target_token_ids_chunk.shape[-1]
student_logits_temp_scaled = ( # [chunk_size, vocab_size]
student_logits_temp_scaled.float()
)
target_logprobs_chunk = target_logprobs_chunk.float()
# Gather student logits for the top-k teacher token IDs
# target_token_ids_chunk: [chunk_size, top_k]
# student_logits_topk_temp_scaled: [chunk_size, top_k]
student_logits_topk_temp_scaled = torch.gather(
student_logits_temp_scaled, dim=-1, index=target_token_ids_chunk
)
# Student log-probabilities for the gathered top-k tokens
student_lse = torch.logsumexp(
student_logits_temp_scaled, dim=-1, keepdim=True
) # [chunk_size, 1]
student_logprobs_topk_temp_scaled = (
student_logits_topk_temp_scaled - student_lse
)
# we have the top-k student logprobs, normalize them
if normalize_topk:
student_logprobs_topk_temp_scaled = normalize_logprobs(
student_logprobs_topk_temp_scaled, topk
)
valid_mask = target_mask_chunk.to(torch.bool) # [chunk_size, top_k]
student_logprobs_topk_valid = student_logprobs_topk_temp_scaled[valid_mask]
teacher_logprobs_valid = target_logprobs_chunk[valid_mask]
# Teacher probabilities P(y|x_teacher) from logprobs
# target_logprobs_valid are already normalized (log(softmax(teacher_logits/T)))
teacher_probs_valid = teacher_logprobs_valid.exp()
# Student probabilities P_student from log P_student
student_probs_topk_valid = student_logprobs_topk_valid.exp()
# kd_loss_per_token = torch.zeros_like(target_logprobs_valid)
# KL divergence: sum(P_teacher * (log P_teacher - log P_student))
# = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student)
# The distillation loss is often formulated as -sum(P_teacher * log P_student)
# or as sum(P_teacher * (log_softmax_teacher - log_softmax_student))
# Here, target_logprobs_valid are log_softmax_teacher.
# student_logprobs_topk_valid are log_softmax_student (for the selected K indices).
if beta == 0.0: # Contribution from Forward KL
fwd_kl_per_token = teacher_probs_valid * (
teacher_logprobs_valid - student_logprobs_topk_valid
)
kd_loss = fwd_kl_per_token.sum()
elif beta == 1.0: # Contribution from Reverse KL
rev_kl_per_token = student_probs_topk_valid * (
student_logprobs_topk_valid - teacher_logprobs_valid
)
kd_loss = rev_kl_per_token.sum()
else:
# JSD - Jensen-Shannon Divergence / Symmetric
mean_probs = (
1 - beta
) * student_probs_topk_valid + beta * teacher_probs_valid
log_mean_probs = mean_probs.log()
student_kl = F.kl_div(
log_mean_probs,
student_logprobs_topk_valid,
reduction="sum",
log_target=True,
)
teacher_kl = F.kl_div(
log_mean_probs, teacher_logprobs_valid, reduction="sum", log_target=True
)
jsd_loss = beta * teacher_kl + (1 - beta) * student_kl
kd_loss = jsd_loss
return kd_loss
@staticmethod
def _compute_loss_kl_topk(
student_input_chunk: torch.Tensor,
student_weight: torch.Tensor,
# Args for student_bias, target_token_ids_chunk etc. are passed to the lambda wrapped by grad_and_value
# or through `partial`. Let's make them explicit here for clarity.
target_token_ids_chunk: torch.Tensor,
target_logprobs_chunk: torch.Tensor,
target_mask_chunk: torch.Tensor,
target_chunk: torch.Tensor, # For hard loss (true labels)
student_bias: torch.Tensor = None, # This will be one of the grad targets
# Other params passed via `partial` from `forward`
distillation_loss_fn=None,
ignore_index: int = -100,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
compute_ce_loss: bool = True,
temperature: float = 1.0,
beta: float = 0.0,
normalize_topk: bool = True,
):
# Compute student logits for the chunk from hidden states and LM head
# student_input_chunk: [chunk_size, hidden_dim]
# student_lm_head_weight: [vocab_size, hidden_dim]
# student_logits_chunk: [chunk_size, vocab_size]
student_logits_chunk = F.linear(
student_input_chunk, student_weight, student_bias
)
ce_loss = torch.tensor(
0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype
)
if compute_ce_loss and weight_hard_loss > 0.0:
ce_loss = F.cross_entropy(
student_logits_chunk.view(-1, student_logits_chunk.shape[-1]),
target_chunk.view(-1),
reduction="sum",
ignore_index=ignore_index,
)
soft_loss = torch.tensor(
0.0, device=student_logits_chunk.device, dtype=student_logits_chunk.dtype
)
if weight_soft_loss > 0.0:
student_logits_chunk_temp_scaled = student_logits_chunk / temperature
# Assuming student_weight.shape[0] (vocab_size) is adequate for target_token_ids_chunk.max()
# No explicit padding here; user must ensure vocab alignment or pre-pad student_weight.
soft_loss = distillation_loss_fn(
student_logits_chunk_temp_scaled,
target_token_ids_chunk,
target_logprobs_chunk,
target_mask_chunk,
beta=beta,
normalize_topk=normalize_topk,
)
return soft_loss, ce_loss
@classmethod
def forward(
cls,
ctx,
student_input: torch.Tensor, # [batch_size, seq_len, dim]
student_lm_head_weight: torch.Tensor, # [dim, vocab_size]
target_token_ids: torch.Tensor, # [batch_size, seq_len, top_k]
target_logprobs: torch.Tensor, # [batch_size, seq_len, top_k]
target_mask: torch.Tensor, # [batch_size, seq_len, top_k]
true_labels: torch.Tensor, # [batch_size, seq_len]
student_lm_head_bias: torch.Tensor = None,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
ignore_index: int = -100,
temperature: float = 1.0,
beta: float = 0.0,
compiled: bool = False,
chunk_size: int = 1024,
compute_ce_loss: bool = True,
normalize_topk: bool = True,
):
CHUNK_SIZE = chunk_size # pylint: disable=invalid-name
grad_weight_acc = torch.zeros_like(student_lm_head_weight)
grad_inputs_list = []
grad_bias_acc = (
torch.zeros_like(student_lm_head_bias)
if student_lm_head_bias is not None
else None
)
kd_loss_acc = torch.zeros(
(), device=student_input.device, dtype=student_input.dtype
)
ce_loss_acc = torch.zeros(
(), device=student_input.device, dtype=student_input.dtype
)
# This function will be what torch.func.grad_and_value differentiates.
# It takes student_input_chunk, student_weight (full), student_bias (full) as primals.
# Other necessary data (target_*, etc.) are passed as non-differentiable arguments.
def loss_fn_for_grad(
_student_input_chunk,
_student_lm_head_weight, # full weight
_student_lm_head_bias, # full bias
# Fixed arguments for a given chunk, not differentiated:
_target_token_ids_chunk,
_target_logprobs_chunk,
_target_mask_chunk,
_true_labels_chunk,
):
return cls._compute_loss_kl_topk(
student_input_chunk=_student_input_chunk,
student_weight=_student_lm_head_weight,
target_token_ids_chunk=_target_token_ids_chunk,
target_logprobs_chunk=_target_logprobs_chunk,
target_mask_chunk=_target_mask_chunk,
target_chunk=_true_labels_chunk,
student_bias=_student_lm_head_bias,
distillation_loss_fn=cls.distillation_loss_fn,
ignore_index=ignore_index,
weight_hard_loss=weight_hard_loss,
weight_soft_loss=weight_soft_loss,
compute_ce_loss=compute_ce_loss,
temperature=temperature,
beta=beta,
normalize_topk=normalize_topk,
)
def accumulate_chunk_grads(
student_input_chunk_ac,
target_token_ids_chunk_ac,
target_logprobs_chunk_ac,
target_mask_chunk_ac,
true_labels_chunk_ac,
):
# student_weight and student_bias are closed over from the outer scope (full tensors)
if student_lm_head_bias is not None:
(
(chunk_grad_input, chunk_grad_weight, chunk_grad_bias),
(chunk_kd_loss, chunk_ce_loss),
) = torch.func.grad_and_value(
loss_fn_for_grad, argnums=(0, 1, 2), has_aux=True
)(
student_input_chunk_ac,
student_lm_head_weight,
student_lm_head_bias, # primals
target_token_ids_chunk_ac,
target_logprobs_chunk_ac,
target_mask_chunk_ac,
true_labels_chunk_ac,
) # non-primals
grad_bias_acc.add_(chunk_grad_bias)
else:
argnums_for_grad = (0, 1) # Differentiate wrt input_chunk, weight
(
(chunk_grad_input, chunk_grad_weight), # No grad for bias
(chunk_kd_loss, chunk_ce_loss),
) = torch.func.grad_and_value(
loss_fn_for_grad, argnums=argnums_for_grad, has_aux=True
)(
student_input_chunk_ac,
student_lm_head_weight,
None, # Pass None for student_bias primal
target_token_ids_chunk_ac,
target_logprobs_chunk_ac,
target_mask_chunk_ac,
true_labels_chunk_ac,
)
grad_weight_acc.add_(chunk_grad_weight)
kd_loss_acc.add_(chunk_kd_loss)
ce_loss_acc.add_(chunk_ce_loss)
return chunk_grad_input
if compiled:
accumulate_chunk_grads_compiled = torch.compile(
accumulate_chunk_grads, dynamic=True, backend="inductor"
) # dynamic=True often helpful
else:
accumulate_chunk_grads_compiled = accumulate_chunk_grads
# Use the same chunking logic as LigerFusedLinearDistillationBase.forward
B, N, D = student_input.shape # pylint: disable=invalid-name
K = target_token_ids.shape[-1] # pylint: disable=invalid-name
student_input_flat = student_input.reshape(-1, student_input.shape[-1])
target_token_ids_flat = target_token_ids.reshape(-1, target_token_ids.shape[-1])
target_logprobs_flat = target_logprobs.reshape(-1, target_logprobs.shape[-1])
target_mask_flat = target_mask.reshape(-1, target_mask.shape[-1])
# pad and shift for cross entropy loss
true_labels = torch.nn.functional.pad(true_labels, (0, 1), value=ignore_index)
true_labels_flat = true_labels[:, 1:].contiguous().view(-1)
num_chunks = max(1, student_input_flat.shape[0] // CHUNK_SIZE)
_student_input_chunks = torch.chunk(
student_input_flat, chunks=num_chunks, dim=0
)
_target_token_ids_chunks = torch.chunk(
target_token_ids_flat, chunks=num_chunks, dim=0
)
_target_logprobs_chunks = torch.chunk(
target_logprobs_flat, chunks=num_chunks, dim=0
)
_target_mask_chunks = torch.chunk(target_mask_flat, chunks=num_chunks, dim=0)
_true_labels_chunks = torch.chunk(true_labels_flat, chunks=num_chunks, dim=0)
for i in range(num_chunks):
grad_input_chunk = accumulate_chunk_grads_compiled(
_student_input_chunks[i],
_target_token_ids_chunks[i],
_target_logprobs_chunks[i],
_target_mask_chunks[i],
_true_labels_chunks[i],
)
grad_inputs_list.append(grad_input_chunk)
grad_inputs_combined = torch.cat(grad_inputs_list, dim=0)
ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)
# For matching None returns in backward for non-tensor/non-grad_requiring inputs
ctx.hyperparams_count = 9 # Corresponds to number of hyperparams after main tensors in fwd signature
ctx.bias_was_none = student_lm_head_bias is None
ctx.orig_dims = (B, N, D, K)
# since this is packed, there is simply a single batch, so batchmean reduction of kl-div is simply the accumulated sum
# we still need to scale the kd_loss by the temp^2
kd_loss_acc = kd_loss_acc * (temperature**2)
final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc
return final_loss
@staticmethod
def backward(ctx, grad_output):
grad_input_flat, grad_weight, grad_bias_maybe = (
ctx.saved_tensors
) # grad_input_flat is (B*N, D)
# Scale gradients by grad_output if it's not 1.0
if not torch.equal(
grad_output,
torch.tensor(1.0, device=grad_output.device, dtype=grad_output.dtype),
):
grad_input_flat = grad_input_flat * grad_output
grad_weight = grad_weight * grad_output
if grad_bias_maybe is not None:
grad_bias_maybe = grad_bias_maybe * grad_output
# Reshape grad_input_flat to match original student_input shape (B, N, D)
# ctx.orig_dims stores (B, N, D, K)
# We need the first three dimensions for student_input's shape.
# Ensure that orig_dims are not (0,0,0,K) for empty inputs leading to view errors
if (
ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0
and grad_input_flat.numel() == 0
):
# If original input was empty, gradient should also be empty with correct shape
grad_input_reshaped = torch.zeros(
ctx.orig_dims[0],
ctx.orig_dims[1],
ctx.orig_dims[2],
dtype=grad_input_flat.dtype,
device=grad_input_flat.device,
)
elif grad_input_flat.numel() == 0 and not (
ctx.orig_dims[0] * ctx.orig_dims[1] * ctx.orig_dims[2] == 0
):
# This case should ideally not happen if forward path is correct (non-empty input -> non-empty flat grad)
# but as a safeguard:
grad_input_reshaped = torch.zeros(
ctx.orig_dims[0],
ctx.orig_dims[1],
ctx.orig_dims[2],
dtype=grad_input_flat.dtype,
device=grad_input_flat.device,
)
else:
grad_input_reshaped = grad_input_flat.view(
ctx.orig_dims[0], ctx.orig_dims[1], ctx.orig_dims[2]
)
nones_for_hyperparams = [None] * ctx.hyperparams_count
grad_bias_return = grad_bias_maybe if not ctx.bias_was_none else None
return (
grad_input_reshaped, # Gradient for student_input (reshaped)
grad_weight, # Gradient for student_lm_head_weight
None, # Gradient for target_token_ids
None, # Gradient for target_logprobs
None, # Gradient for target_mask
None, # Gradient for true_labels
grad_bias_return, # Gradient for student_lm_head_bias
*nones_for_hyperparams, # Grads for weight_hard_loss, ..., compute_ce_loss
)
class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
"""
wrapper for chunked top-k logprob kl-d
"""
def __init__(
self,
weight_hard_loss: float = 0.5,
weight_soft_loss: float = 0.5,
temperature: float = 1.0, # This is the kd_temperature
beta: float = 1.0,
ignore_index: int = -100,
compiled: bool = True,
chunk_size: int = 1024,
compute_ce_loss: bool = True,
normalize_topk: bool = True,
):
super().__init__()
if not (0.0 <= weight_hard_loss <= 1.0 and 0.0 <= weight_soft_loss <= 1.0):
raise ValueError("Loss weights must be between 0.0 and 1.0.")
if temperature <= 0:
raise ValueError("Temperature must be positive.")
self.weight_hard_loss = weight_hard_loss
self.weight_soft_loss = weight_soft_loss
self.temperature = temperature
self.beta = beta
self.ignore_index = ignore_index
self.compiled = compiled
self.chunk_size = chunk_size
self.compute_ce_loss = compute_ce_loss
self.normalize_topk = normalize_topk
if not self.compute_ce_loss and self.weight_hard_loss > 0.0:
print(
f"Warning: compute_ce_loss is False, but weight_hard_loss ({self.weight_hard_loss}) > 0. Hard loss will effectively be zero."
)
# self.weight_hard_loss = 0.0 # Or let user manage this
if self.weight_soft_loss == 0.0:
print(
"Warning: weight_soft_loss is 0.0. Soft (KD) loss will not be computed."
)
def forward(
self,
lm_head_weight: torch.Tensor, # Weights of the linear layer in the LM head
student_hidden_states: torch.Tensor, # student_hidden_states before the lm_head
target_token_ids: torch.Tensor,
target_logprobs: torch.Tensor,
target_mask: torch.Tensor,
true_labels: torch.Tensor,
student_bias: torch.Tensor = None,
) -> torch.Tensor:
return LigerFusedLinearKLTopKLogprobFunction.apply(
student_hidden_states,
lm_head_weight,
target_token_ids,
target_logprobs,
target_mask,
true_labels,
student_bias,
self.weight_hard_loss,
self.weight_soft_loss,
self.ignore_index,
self.temperature,
self.beta,
self.compiled,
self.chunk_size,
self.compute_ce_loss,
self.normalize_topk,
)

View File

@@ -1,97 +0,0 @@
"""
model patcher for chunked top-k kl-div
"""
from typing import Optional, Union, Unpack
import torch
from transformers import Cache
from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.utils import LossKwargs
class KwargsForCausalLM(FlashAttentionKwargs, LossKwargs):
"""
placeholder kwargs for hf model classes
"""
def kldiv_forward_llama_like(
self,
input_ids: Optional[torch.LongTensor] = None,
target_logprobs: Optional[torch.Tensor] = None,
target_token_ids: Optional[torch.LongTensor] = None,
target_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
labels: Optional[torch.LongTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
cache_position: Optional[torch.LongTensor] = None,
logits_to_keep: Union[int, torch.Tensor] = 0, # pylint: disable=unused-argument
**kwargs: Unpack[KwargsForCausalLM], # type: ignore[misc]
) -> CausalLMOutputWithPast:
# pylint: disable=duplicate-code
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
cache_position=cache_position,
**kwargs,
)
hidden_states = outputs.last_hidden_state
# Only compute necessary logits, and do not upcast them to float if we are not computing the loss
# TODO, we can optimize this further by filtering hidden_states on sequence dimension using labels != -100
# self.loss_function should be LigerFusedLinearKLTopKLogprobLoss
loss = self.loss_function(
self.lm_head.weight,
hidden_states,
target_token_ids,
target_logprobs,
target_mask,
true_labels=labels,
)
num_items_in_batch = kwargs.pop("num_items_in_batch", -1)
if num_items_in_batch is not None and num_items_in_batch > 0:
loss = loss / num_items_in_batch
return CausalLMOutputWithPast(
loss=loss,
logits=None,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
def apply_kernel(model_type):
# Dynamically import the module and attention class
module_path = f"transformers.models.{model_type}.modeling_{model_type}"
model_cls_prefix = "".join([part.capitalize() for part in model_type.split("_")])
module = __import__(module_path, fromlist=[f"{model_cls_prefix}ForCausalLM"])
model_cls = getattr(module, f"{model_cls_prefix}ForCausalLM")
model_cls.forward = kldiv_forward_llama_like

View File

@@ -16,7 +16,40 @@
loss for top_k KL divergence
"""
import torch
from torch import nn
def zscore_standardize(
logits: torch.Tensor,
mask: torch.Tensor = None,
base_temperature: float = 1.0,
eps: float = 1e-9,
):
"""
Z-score standardize along the last dimension of `logits`.
i.e., for each [B, seq_len] row, across K entries:
z = (logits - mean) / std,
then scale by 1 / base_temperature if desired.
mask can be broadcastable or None. If None, we standardize all elements.
"""
if mask is None:
# shape: [B, seq_len, K]
# Mean and std over dim=-1
mean = logits.mean(dim=-1, keepdim=True)
var = logits.var(dim=-1, unbiased=False, keepdim=True)
else:
# If you have to exclude some tokens, multiply by mask, etc.
float_mask = mask.to(logits.dtype)
count = float_mask.sum(dim=-1, keepdim=True).clamp_min(1.0)
mean = (logits * float_mask).sum(dim=-1, keepdim=True) / count
var = (float_mask * (logits - mean) ** 2).sum(dim=-1, keepdim=True) / count
std = torch.sqrt(var.clamp_min(eps))
z = (logits - mean) / std
# Scale by 1 / base_temperature
z = z / base_temperature
return z
@torch.jit.script
@@ -27,6 +60,7 @@ def loss(
target_mask: torch.Tensor,
num_items_in_batch: int = -1, # Use -1 to indicate "None"
kd_temperature: float = 1.0,
top_k_before_softmax: int = 0,
) -> torch.Tensor:
"""
A KD loss function that is TorchScript-friendly.
@@ -43,6 +77,8 @@ def loss(
num_items_in_batch (int, optional): The number of items in the batch.
kd_temperature (float, optional): The temperature for KD.
Default: 1.0
top_k_before_softmax (int, optional): Flag of whether to apply softmax before gathering student top-k logits
Default: 0
"""
target_logprobs = target_logprobs.float()
@@ -52,24 +88,46 @@ def loss(
# student_logits shape: [B, student_seq_len, vocab_size]
teacher_seq_len = target_token_ids.shape[1]
# Slice student logits to match teacher-provided sequence length
student_logits_for_kd = (
student_logits[:, :teacher_seq_len, :] / kd_temperature
) # [B, teacher_seq_len, vocab_size]
if top_k_before_softmax:
# Slice student logits to match teacher-provided sequence length
student_logits_for_kd = student_logits[
:, :teacher_seq_len, :
] # [B, teacher_seq_len, vocab_size]
# keep in full precision for numerical stability of loss
student_logits_for_kd = student_logits_for_kd.float()
# Gather student logits for teacher's top-K tokens
student_logits_topk = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, teacher_seq_len, K]
# Gather student logits for teacher's top-K tokens
student_logits_topk = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, teacher_seq_len, K]
student_logits_topk = student_logits_topk.float()
# Compute logsumexp across full vocabulary
student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
# Apply KD temperature to students logits
if kd_temperature != 1.0:
student_logits_topk = student_logits_topk / kd_temperature
# Convert just the top-k logits to logprobs
student_logprobs_topk = student_logits_topk - student_lse
# Convert student top-k logits to logprobs
student_logprobs_topk = student_logits_topk - torch.logsumexp(
student_logits_topk, dim=-1, keepdim=True
) # [B, teacher_seq_len, K]
else:
# Slice student logits to match teacher-provided sequence length
student_logits_for_kd = (
student_logits[:, :teacher_seq_len, :] / kd_temperature
) # [B, teacher_seq_len, vocab_size]
# keep in full precision for numerical stability of loss
student_logits_for_kd = student_logits_for_kd.float()
# Gather student logits for teacher's top-K tokens
student_logits_topk = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, teacher_seq_len, K]
# Compute logsumexp across full vocabulary
student_lse = torch.logsumexp(student_logits_for_kd, dim=-1, keepdim=True)
# Convert just the top-k logits to logprobs
student_logprobs_topk = student_logits_topk - student_lse
# Convert teacher_mask to boolean for indexing
# In TorchScript, .bool() is sometimes unsupported, so we do:
@@ -86,6 +144,10 @@ def loss(
kd_loss_per_token = teacher_probs * (target_logprobs - student_logprobs_topk)
kd_loss = kd_loss_per_token.sum()
# Multiply by T^2 (classical KD scaling)
if kd_temperature != 1.0:
kd_loss = kd_loss * (kd_temperature**2)
# Normalize by number of items (if provided) or by valid tokens
if num_items_in_batch > 0:
kd_loss = kd_loss / float(num_items_in_batch)
@@ -96,74 +158,80 @@ def loss(
return kd_loss
class ChunkedTopKKDLoss(nn.Module):
def topk_kd_loss_with_zscore(
student_logits: torch.Tensor, # [B, seq_len, vocab_size]
target_token_ids: torch.Tensor, # [B, seq_len, K]
target_logprobs: torch.Tensor, # [B, seq_len, K], sums to 1.0 in prob space
target_mask: torch.Tensor, # [B, seq_len, K] or [B, seq_len]
kd_temperature: float = 1.0, # classic KD temperature
zscore_base_temp: float = 1.0, # from the paper
num_items_in_batch: int = -1,
):
"""
A wrapper that chunks (splits) the student and teacher outputs along the time dimension
to reduce peak memory usage when upcasting from bf16 to fp32, especially for large vocabularies.
Usage is analogous to ForwardKLWithChunkedOutputLoss but adapted to top-K teacher logprobs.
A variant of top_k KL divergence with Z-score scaling
from "Logit Standardization in Knowledge Distillation".
"""
def __init__(self, num_output_chunks: int = 8, kd_temperature: float = 1.0):
super().__init__()
self.num_output_chunks = num_output_chunks
self.kd_temperature = kd_temperature
target_logprobs = target_logprobs.float()
def forward(
self,
student_logits: torch.Tensor, # [B, seq_len, vocab_size]
target_token_ids: torch.Tensor, # [B, seq_len, K]
target_logprobs: torch.Tensor, # [B, seq_len, K]
target_mask: torch.Tensor, # [B, seq_len, K]
num_items_in_batch: int = -1, # optional batch size for normalization
) -> torch.Tensor:
B, teacher_seq_len, K = target_logprobs.shape # pylint: disable=invalid-name
# 1) Gather the student's top-k logits to match teacher
student_logits_for_kd = student_logits[
:, :teacher_seq_len, :
] # [B, seq_len, vocab]
student_topk_logits = torch.gather(
student_logits_for_kd, dim=-1, index=target_token_ids
) # [B, seq_len, K]
# 1. Split along the "token" dimension (dim=1).
student_logits_chunks = student_logits.chunk(self.num_output_chunks, dim=1)
token_ids_chunks = target_token_ids.chunk(self.num_output_chunks, dim=1)
logprobs_chunks = target_logprobs.chunk(self.num_output_chunks, dim=1)
mask_chunks = target_mask.chunk(self.num_output_chunks, dim=1)
student_topk_logits = student_topk_logits.float()
# We'll accumulate a global "sum of losses" and "sum of valid tokens"
# so that our final average is consistent with the entire sequence/batch.
total_loss = 0.0
total_valid_tokens = 0
# 2) If you want to keep the "classical" T scaling, apply it first
if kd_temperature != 1.0:
student_topk_logits = student_topk_logits / kd_temperature
# 2. Loop over each chunk and compute a chunk-specific loss.
for st_chunk, tid_chunk, lp_chunk, msk_chunk in zip(
student_logits_chunks, token_ids_chunks, logprobs_chunks, mask_chunks
):
# We pass num_items_in_batch=-1 so that the kd_loss
# will average over *this chunk's* valid tokens only.
chunk_loss = loss(
student_logits=st_chunk,
target_token_ids=tid_chunk,
target_logprobs=lp_chunk,
target_mask=msk_chunk,
num_items_in_batch=-1, # ensure per-chunk averaging by valid tokens
kd_temperature=self.kd_temperature,
)
# 3) Convert teacher logprobs -> treat them as “logits” for z-score
# (They differ by +some_constant from real logits, but in z-score
# that constant is subtracted out anyway.)
teacher_logits_for_zscore = target_logprobs # rename variable for clarity
# kd_loss returns an average over the chunk's valid tokens.
# We want a global average in the end, so we need to reweight
# by the number of valid tokens in this chunk and keep track of the total.
chunk_valid_mask = msk_chunk.to(torch.bool)
chunk_valid_count = chunk_valid_mask.sum() # scalar tensor
# 4) Z-score teacher and student
# If target_mask is 2D, expand to 3D for the K dimension
if target_mask.dim() == 2 and target_mask.shape[:2] == (B, teacher_seq_len):
target_mask = target_mask.unsqueeze(-1).expand(-1, -1, K)
# Re-scale "chunk average" back to "chunk sum"
chunk_loss_sum = chunk_loss * chunk_valid_count
teacher_z = zscore_standardize(
teacher_logits_for_zscore, mask=target_mask, base_temperature=zscore_base_temp
)
student_z = zscore_standardize(
student_topk_logits, mask=target_mask, base_temperature=zscore_base_temp
)
total_loss += chunk_loss_sum
total_valid_tokens += chunk_valid_count
# 5) Convert to log-probs for KL
teacher_logprobs_z = teacher_z - torch.logsumexp(teacher_z, dim=-1, keepdim=True)
student_logprobs_z = student_z - torch.logsumexp(student_z, dim=-1, keepdim=True)
# 3. Normalize *once* at the end.
if num_items_in_batch > 0:
# If the user gave us a manual denominator (e.g. total items in batch),
# we divide by it. Typically used if each item is of different length.
final_loss = total_loss / float(num_items_in_batch)
else:
# Otherwise, divide by total valid tokens across all chunks.
# to get the same result as a non-chunked approach.
final_loss = total_loss / float(total_valid_tokens)
# 6) Restrict to valid tokens if needed
valid_mask = target_mask.bool() # shape [B, seq_len, K]
teacher_probs_z = teacher_logprobs_z.exp()
teacher_probs_z = teacher_probs_z[valid_mask]
teacher_logprobs_z = teacher_logprobs_z[valid_mask]
student_logprobs_z = student_logprobs_z[valid_mask]
return final_loss
# 7) forward KL: sum( p_teacher * [log(p_teacher) - log(p_student)] )
kd_loss_per_token = teacher_probs_z * (teacher_logprobs_z - student_logprobs_z)
kd_loss = kd_loss_per_token.sum()
# 8) If using classical KD scaling by T^2
if kd_temperature != 1.0:
kd_loss = kd_loss * (kd_temperature**2)
# Optionally scale by zscore_base_temp**2 if you want (paper might differ).
# kd_loss = kd_loss * (zscore_base_temp**2)
# 9) Normalize
if num_items_in_batch is not None and num_items_in_batch > 0:
kd_loss = kd_loss / float(num_items_in_batch)
else:
kd_loss = kd_loss / float(kd_loss_per_token.size(0))
return kd_loss

View File

@@ -18,7 +18,8 @@ KD trainer
from axolotl.core.trainers.base import AxolotlTrainer
from .kernels.liger import LigerFusedLinearKLTopKLogprobLoss
from .topk_logprob.forward_kl import loss as topk_kd_loss
from .topk_logprob.forward_kl import topk_kd_loss_with_zscore
class AxolotlKDTrainer(AxolotlTrainer):
@@ -26,18 +27,6 @@ class AxolotlKDTrainer(AxolotlTrainer):
Custom trainer subclass for Knowledge Distillation (KD)
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.model_accepts_loss_kwargs = True
self.model._loss_function = LigerFusedLinearKLTopKLogprobLoss(
self.args.kd_ce_alpha, # hard label loss
self.args.kd_alpha, # kd loss
self.args.kd_temperature,
self.args.kd_beta,
compute_ce_loss=bool(self.args.kd_ce_alpha),
normalize_topk=self.args.kd_normalize_topk,
)
def _set_signature_columns_if_needed(self):
super()._set_signature_columns_if_needed()
columns_to_add = []
@@ -63,12 +52,12 @@ class AxolotlKDTrainer(AxolotlTrainer):
Subclass and override for custom behavior.
"""
if (
self.args.sample_packing
and hasattr(inputs, "attention_mask")
and hasattr(inputs, "position_ids")
):
del inputs["attention_mask"]
target_logprobs = inputs.pop("target_logprobs")
target_token_ids = inputs.pop("target_token_ids")
target_mask = inputs.pop("target_mask")
seq_len = target_token_ids.shape[1]
if self.model_accepts_loss_kwargs:
loss_kwargs = {}
@@ -76,4 +65,49 @@ class AxolotlKDTrainer(AxolotlTrainer):
loss_kwargs["num_items_in_batch"] = num_items_in_batch
inputs = {**inputs, **loss_kwargs}
outputs = model(**inputs)
return outputs[0]
# FIXME: account for tokenizer.padding_side
student_logits = outputs["logits"][:, : seq_len - 1, :].contiguous()
shift_logits = student_logits.contiguous()
target_logprobs_for_loss = target_logprobs[..., 1:, :].contiguous()
target_token_ids_for_loss = target_token_ids[..., 1:, :].contiguous()
target_mask_for_loss = target_mask[..., 1:, :].contiguous()
if self.args.kd_zscore_base_temp:
loss_kd = topk_kd_loss_with_zscore(
shift_logits,
target_token_ids_for_loss,
target_logprobs_for_loss,
target_mask_for_loss,
kd_temperature=self.args.kd_temperature,
zscore_base_temp=self.args.kd_zscore_base_temp,
num_items_in_batch=num_items_in_batch,
)
else:
loss_kd = topk_kd_loss(
shift_logits,
target_token_ids_for_loss,
target_logprobs_for_loss,
target_mask_for_loss,
num_items_in_batch=num_items_in_batch,
kd_temperature=self.args.kd_temperature,
top_k_before_softmax=1 if self.args.kd_top_k_before_softmax else 0,
)
if self.args.kd_ce_alpha > 0:
kd_alpha = self.args.kd_alpha
loss = self.args.kd_ce_alpha * outputs["loss"] + kd_alpha * loss_kd
else:
loss = loss_kd
# Save past state if it exists
# TODO: this needs to be fixed and made cleaner later.
if self.args.past_index >= 0:
self._past = outputs[ # pylint: disable=attribute-defined-outside-init
self.args.past_index
]
if self.args.average_tokens_across_devices and self.model_accepts_loss_kwargs:
loss *= self.accelerator.num_processes
return (loss, outputs) if return_outputs else loss

View File

@@ -1,100 +0,0 @@
"""Helper KD utils"""
import math
from typing import List, Union
import numpy as np
import torch
from torch import FloatTensor, Tensor
def normalize_logprobs(logprobs: FloatTensor, topk: int) -> FloatTensor:
"""
Re-normalizes top-k raw logprobs as probabilities, and converts back to logprobs.
"""
# Ensure raw_logprobs matches kd_online_topk length for tensor operations
# This should ideally be handled by the caller ensuring correct padding/truncation first
if logprobs.shape[-1] != topk:
# pad last dimension of logprobs to match topk length with -inf
padding_len = topk - logprobs.shape[-1]
padding_tensor = torch.full(
(
*logprobs.shape[:-1],
padding_len,
), # Takes all dimensions of logprobs except the last, then appends padding_needed
float("-inf"),
dtype=logprobs.dtype,
device=logprobs.device,
)
logprobs = torch.cat((logprobs, padding_tensor), dim=-1)
# Convert logprobs at T_online to probabilities
# use log sum exp trick to avoid underflow
position_logprobs_lse = torch.logsumexp(logprobs, dim=-1, keepdim=True)
teacher_probs_t_online = torch.exp(logprobs - position_logprobs_lse)
# Normalize probabilities (sum to 1)
# This is important if the top-k from server aren't a full distribution
teacher_probs_t_online_sum = teacher_probs_t_online.sum(dim=-1, keepdim=True)
teacher_probs_t_online = teacher_probs_t_online / teacher_probs_t_online_sum
final_logprobs_tensor = torch.log(teacher_probs_t_online)
return final_logprobs_tensor
def strided_chunk_views(
tensor: Union[np.ndarray, torch.Tensor],
chunks: int,
dim: int = 0,
stride: int = 1,
chunk_size: int | None = None,
) -> List[Union[np.ndarray, torch.Tensor]]:
"""
Split a tensor into chunks along a dimension with striding, prioritizing views over copies.
Args:
tensor: Input tensor (numpy array or torch tensor)
chunks: Number of chunks to create
dim: Dimension along which to chunk (default: 0)
stride: Stride between chunk starting positions (default: 1)
chunk_size: Size of each chunk. If None, calculated automatically (default: None)
Returns:
List of tensor chunks (views when possible, copies when necessary)
"""
# Get the size of the specified dimension
dim_size = tensor.shape[dim]
# Calculate chunk size if not provided
if chunk_size is None:
chunk_size = (dim_size + chunks - 1) // chunks # Ceiling division
chunks_list = []
for i in range(chunks):
start_idx = i * stride
end_idx = min(start_idx + chunk_size, dim_size)
# Break if we've gone beyond the tensor
if start_idx >= dim_size:
break
# Create slice objects for all dimensions
slices = [slice(None)] * tensor.ndim
slices[dim] = slice(start_idx, end_idx)
chunk = tensor[tuple(slices)]
chunks_list.append(chunk)
return chunks_list
def chunk_overlap(input_tensor: Tensor, chunks: int, dim: int = 0, overlap: int = 1):
dim_size = input_tensor.shape[dim]
stride = math.ceil(dim_size / chunks)
return strided_chunk_views(
input_tensor, chunks, dim, stride=stride, chunk_size=stride + overlap
)

View File

@@ -1,13 +1,10 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
from __future__ import annotations
import importlib
import inspect
import os
import signal
import sys
import typing
import weakref
from contextlib import ExitStack
from pathlib import Path
@@ -28,6 +25,7 @@ from axolotl.common.datasets import TrainDatasetMeta
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
fix_untrained_tokens,
)
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.integrations.base import PluginManager
from axolotl.loaders import (
ModelLoader,
@@ -47,9 +45,6 @@ try:
except ImportError:
BetterTransformer = None
if typing.TYPE_CHECKING:
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
LOG = get_logger(__name__)
@@ -477,7 +472,7 @@ def handle_untrained_tokens_fix(
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
"HFRLTrainerBuilder" | "HFCausalTrainerBuilder",
HFRLTrainerBuilder | HFCausalTrainerBuilder,
PeftModel | PreTrainedModel,
PreTrainedTokenizer,
PeftConfig | None,

View File

@@ -52,10 +52,3 @@ def patch_optimized_env():
if os.getenv("HF_HUB_ENABLE_HF_TRANSFER") is None:
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
set_pytorch_cuda_alloc_conf()
def get_not_null(value, default=None):
"""
return the value if it's not None, otherwise return the default value
"""
return value if value is not None else default

File diff suppressed because one or more lines are too long

View File

@@ -1,7 +1,7 @@
"""Data collators for axolotl to pad labels and position_ids for packed sequences"""
from dataclasses import dataclass
from typing import Any, List
from typing import Any
import numpy as np
from transformers import PreTrainedTokenizerBase
@@ -161,7 +161,7 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
def __call__(self, features, return_tensors=None):
if not isinstance(features[0], list):
features: List[List[dict]] = [features]
features = [features]
out_features = [{} for _ in features]
for i, features_ in enumerate(features):
for feature in features_[0].keys():

View File

@@ -40,7 +40,6 @@ def retry_on_request_exceptions(
except (
requests.exceptions.ReadTimeout,
requests.exceptions.ConnectionError,
requests.exceptions.HTTPError,
huggingface_hub.errors.HfHubHTTPError,
) as exc:
if attempt < max_retries - 1:

View File

@@ -258,7 +258,7 @@ class MultipackBatchSampler(BatchSampler):
batch_max_len: int, # Maximum sequence length (bin capacity)
lengths: np.ndarray, # Sequence lengths
packing_efficiency_estimate: float = 1.0, # Initial efficiency estimate
drop_last: bool = True, # Whether to drop final batches (might be incomplete)
drop_last: bool = False, # Whether to drop final batches (might be incomplete)
num_count_samples: int = 16, # Number of times to estimate batch count
sequential: bool = False, # Whether to use sequential packing
group_size: int = 100_000, # Size of groups for parallel packing
@@ -443,18 +443,10 @@ class MultipackBatchSampler(BatchSampler):
if self._len_across_ranks is None:
# Sample multiple times to get stable estimate
_sampled_lens = []
for _ in range(self.num_count_samples):
self._batches = None # Reset cached batches
_sampled_lens.append(len(self.generate_batches(set_stats=False)))
len_batches = min(_sampled_lens)
len_batches = min( # pylint: disable=consider-using-generator
[len(self._batches) for _ in range(self.num_count_samples)]
)
# Gather minimum across all ranks
if self._len_across_ranks is None:
self._len_across_ranks = self.gather_len_batches(len_batches)
else:
self._len_across_ranks = min(
self._len_across_ranks, self.gather_len_batches(len_batches)
)
self._len_across_ranks = self.gather_len_batches(len_batches)
return self._len_across_ranks

View File

@@ -275,6 +275,7 @@ class AxolotlInputConfig(
torch_compile_mode: Literal["default", "reduce-overhead", "max-autotune"] | None = (
None
)
compile_optimizer: bool | None = None
max_steps: int | None = None
warmup_steps: int | None = None

View File

@@ -16,6 +16,7 @@ from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
from axolotl.utils.distributed import reduce_and_broadcast
from axolotl.utils.environment import check_cuda_p2p_ib_support
@@ -481,9 +482,6 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
data_loader_len * cfg.num_epochs * cfg.sequence_parallel_degree
)
)
if cfg.dataloader_drop_last:
# drop the last batch for each epoch
total_num_steps -= int(math.ceil(cfg.num_epochs))
def calc_sample_packing_eff_est(estimates: List[float]):
LOG.info(f"sample_packing_eff_est across ranks: {repr(estimates)}")
@@ -631,8 +629,6 @@ def setup_trainer(
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
on the provided parameters.
"""
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
if (
cfg.torch_compile
and cfg.fsdp_config