support for dynamic plugin training args mixins and symmetric kl
This commit is contained in:
@@ -149,6 +149,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return AxolotlTrainer
|
return AxolotlTrainer
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
|
from axolotl.core.training_args import (
|
||||||
|
AxolotlPRMConfig,
|
||||||
|
AxolotlRewardConfig,
|
||||||
|
AxolotlTrainingArguments,
|
||||||
|
)
|
||||||
|
|
||||||
training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
|
training_arguments_kwargs, trainer_kwargs = self._set_base_training_args(
|
||||||
total_num_steps
|
total_num_steps
|
||||||
)
|
)
|
||||||
@@ -317,12 +323,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs["image_resize_algorithm"] = (
|
training_arguments_kwargs["image_resize_algorithm"] = (
|
||||||
self.cfg.image_resize_algorithm
|
self.cfg.image_resize_algorithm
|
||||||
)
|
)
|
||||||
if self.cfg.kd_ce_alpha is not None:
|
|
||||||
training_arguments_kwargs["kd_ce_alpha"] = self.cfg.kd_ce_alpha
|
if self.cfg.plugins:
|
||||||
if self.cfg.kd_alpha is not None:
|
plugin_manager = PluginManager.get_instance()
|
||||||
training_arguments_kwargs["kd_alpha"] = self.cfg.kd_alpha
|
plugin_training_args = plugin_manager.get_training_args(self.cfg)
|
||||||
if self.cfg.kd_temperature is not None:
|
if plugin_training_args:
|
||||||
training_arguments_kwargs["kd_temperature"] = self.cfg.kd_temperature
|
training_arguments_kwargs.update(plugin_training_args)
|
||||||
|
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
training_args_cls = AxolotlRewardConfig
|
training_args_cls = AxolotlRewardConfig
|
||||||
@@ -403,7 +409,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return trainer
|
return trainer
|
||||||
|
|
||||||
def build_collator(
|
def build_collator(
|
||||||
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
self,
|
||||||
|
training_args, # type: "AxolotlTrainingArguments"
|
||||||
|
is_eval=False,
|
||||||
|
**kwargs,
|
||||||
):
|
):
|
||||||
if training_args.pretraining:
|
if training_args.pretraining:
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -12,11 +12,6 @@ from axolotl.core.trainers import (
|
|||||||
from axolotl.core.trainers.dpo import DPOStrategy
|
from axolotl.core.trainers.dpo import DPOStrategy
|
||||||
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
from axolotl.core.trainers.dpo.args import AxolotlDPOConfig
|
||||||
from axolotl.core.trainers.grpo import GRPOStrategy
|
from axolotl.core.trainers.grpo import GRPOStrategy
|
||||||
from axolotl.core.training_args import (
|
|
||||||
AxolotlCPOConfig,
|
|
||||||
AxolotlKTOConfig,
|
|
||||||
AxolotlORPOConfig,
|
|
||||||
)
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.loaders.utils import ensure_dtype
|
from axolotl.loaders.utils import ensure_dtype
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
@@ -79,6 +74,12 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
"""
|
"""
|
||||||
Returns training_args and trainer_kwargs
|
Returns training_args and trainer_kwargs
|
||||||
"""
|
"""
|
||||||
|
from axolotl.core.training_args import (
|
||||||
|
AxolotlCPOConfig,
|
||||||
|
AxolotlKTOConfig,
|
||||||
|
AxolotlORPOConfig,
|
||||||
|
)
|
||||||
|
|
||||||
training_args_kwargs, trainer_kwargs = self._set_base_training_args(
|
training_args_kwargs, trainer_kwargs = self._set_base_training_args(
|
||||||
total_num_steps=total_num_steps
|
total_num_steps=total_num_steps
|
||||||
)
|
)
|
||||||
@@ -165,6 +166,13 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if blocklist_key in training_args_kwargs:
|
if blocklist_key in training_args_kwargs:
|
||||||
del training_args_kwargs[blocklist_key]
|
del training_args_kwargs[blocklist_key]
|
||||||
|
|
||||||
|
|
||||||
|
if self.cfg.plugins:
|
||||||
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
plugin_training_args = plugin_manager.get_training_args(self.cfg)
|
||||||
|
if plugin_training_args:
|
||||||
|
training_args_kwargs.update(plugin_training_args)
|
||||||
|
|
||||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||||
logging_first_step=True,
|
logging_first_step=True,
|
||||||
**training_args_kwargs,
|
**training_args_kwargs,
|
||||||
|
|||||||
@@ -2,224 +2,17 @@
|
|||||||
extra axolotl specific training args
|
extra axolotl specific training args
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from __future__ import annotations
|
||||||
from typing import Optional
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional, Type
|
||||||
|
|
||||||
from PIL.Image import Resampling
|
|
||||||
from transformers import TrainingArguments
|
from transformers import TrainingArguments
|
||||||
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
from trl import CPOConfig, KTOConfig, ORPOConfig, PRMConfig, RewardConfig
|
||||||
|
|
||||||
|
from axolotl.integrations.config import merge_training_args
|
||||||
|
|
||||||
@dataclass
|
AxolotlTrainingMixins: Type = merge_training_args()
|
||||||
class AxolotlTrainingMixins:
|
|
||||||
"""
|
|
||||||
Mixin class for the Axolotl training args.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
model_type: Optional[str] = field(
|
|
||||||
default=None, metadata={"help": "HF model configuration model_type."}
|
|
||||||
)
|
|
||||||
lr_quadratic_warmup: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
|
||||||
)
|
|
||||||
pretraining: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={
|
|
||||||
"help": "Indicates to trainer whether we are doing continued pretraining."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
sample_packing: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Use sample packing for efficient training."},
|
|
||||||
)
|
|
||||||
sample_packing_sequentially: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={
|
|
||||||
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
multipack_real_batches: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "Use real batches for efficient training."},
|
|
||||||
)
|
|
||||||
eval_sample_packing: Optional[bool] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Use sample packing for efficient evals."},
|
|
||||||
)
|
|
||||||
sample_packing_efficiency: float = field(
|
|
||||||
default=1.0,
|
|
||||||
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
|
||||||
)
|
|
||||||
sample_packing_bin_size: int = field(
|
|
||||||
default=200,
|
|
||||||
metadata={
|
|
||||||
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
sample_packing_group_size: int = field(
|
|
||||||
default=100000,
|
|
||||||
metadata={
|
|
||||||
"help": "The number of samples to group together for packing. Increase for better packing."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
max_seq_length: int = field(
|
|
||||||
default=2048,
|
|
||||||
metadata={"help": "The maximum sequence length the model can handle"},
|
|
||||||
)
|
|
||||||
relora_steps: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "how often to reset for ReLoRA"},
|
|
||||||
)
|
|
||||||
relora_warmup_steps: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
|
||||||
)
|
|
||||||
relora_anneal_steps: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
|
||||||
)
|
|
||||||
relora_prune_ratio: Optional[float] = field(
|
|
||||||
default=0.9,
|
|
||||||
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
|
||||||
)
|
|
||||||
bench_split: Optional[str] = field(
|
|
||||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
|
||||||
)
|
|
||||||
bench_dataset: Optional[str] = field(
|
|
||||||
default="pharaouk/dharma-1/dharma_1_mini.json",
|
|
||||||
metadata={
|
|
||||||
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
do_bench_eval: Optional[bool] = field(
|
|
||||||
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
|
||||||
)
|
|
||||||
do_causal_lm_eval: Optional[bool] = field(
|
|
||||||
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
|
|
||||||
)
|
|
||||||
max_bench_samples: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
|
||||||
},
|
|
||||||
)
|
|
||||||
bench_source_max_len: int = field(
|
|
||||||
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
|
||||||
)
|
|
||||||
dataloader_prefetch_factor: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "prefetch_factor argument to the dataloader"},
|
|
||||||
)
|
|
||||||
cosine_min_lr_ratio: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
|
|
||||||
)
|
|
||||||
cosine_constant_lr_ratio: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
loraplus_lr_ratio: Optional[float] = field(
|
|
||||||
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
|
|
||||||
)
|
|
||||||
loraplus_lr_embedding: Optional[float] = field(
|
|
||||||
default=1e-6,
|
|
||||||
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
|
||||||
)
|
|
||||||
embedding_lr_scale: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Scale the learning rate for the embedding layers."},
|
|
||||||
)
|
|
||||||
lr_groups: Optional[list[dict]] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Specify learning rate groups for with different LRs."},
|
|
||||||
)
|
|
||||||
embedding_lr: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "absolute learning rate for the embedding layers."},
|
|
||||||
)
|
|
||||||
qlora: bool = field(
|
|
||||||
default=False,
|
|
||||||
metadata={"help": "whether this is a qlora training"},
|
|
||||||
)
|
|
||||||
orpo_alpha: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
)
|
|
||||||
lisa_n_layers: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "the number of activate layers in LISA"},
|
|
||||||
)
|
|
||||||
lisa_step_interval: Optional[int] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "how often to switch layers in LISA"},
|
|
||||||
)
|
|
||||||
lisa_layers_attribute: Optional[str] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "path under the model to access the layers"},
|
|
||||||
)
|
|
||||||
curriculum_sampling: Optional[bool] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
|
||||||
)
|
|
||||||
alternate_lr_scheduler_type: Optional[str] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
chat_template: Optional[str] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "Chat template converting chat messages to text"},
|
|
||||||
)
|
|
||||||
|
|
||||||
kd_ce_alpha: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
kd_alpha: Optional[float] = field(
|
|
||||||
default=1.0,
|
|
||||||
metadata={"help": "The alpha scaling parameter for KD loss"},
|
|
||||||
)
|
|
||||||
|
|
||||||
kd_temperature: Optional[float] = field(
|
|
||||||
default=1.0,
|
|
||||||
metadata={
|
|
||||||
"help": "the temperature parameter for KL divergence loss when using KD"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
adam_beta3: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
adam_epsilon2: Optional[float] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={
|
|
||||||
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# multi-modal section
|
|
||||||
|
|
||||||
image_size: int | tuple[int, int] | None = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "The size of the image to resize to"},
|
|
||||||
)
|
|
||||||
|
|
||||||
image_resize_algorithm: Resampling | None = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "The algorithm to use for image resizing"},
|
|
||||||
)
|
|
||||||
|
|
||||||
# end of multi-modal section
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
222
src/axolotl/core/training_args_base.py
Normal file
222
src/axolotl/core/training_args_base.py
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
"""
|
||||||
|
Base Axolotl Training Mixins shared across various trainer configs
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from PIL.Image import Resampling
|
||||||
|
|
||||||
|
from axolotl.utils.schemas.enums import RingAttnFunc
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AxolotlTrainingMixins:
|
||||||
|
"""
|
||||||
|
Mixin class for the Axolotl training args.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
model_type: Optional[str] = field(
|
||||||
|
default=None, metadata={"help": "HF model configuration model_type."}
|
||||||
|
)
|
||||||
|
lr_quadratic_warmup: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||||
|
)
|
||||||
|
pretraining: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "Indicates to trainer whether we are doing continued pretraining."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
sample_packing: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Use sample packing for efficient training."},
|
||||||
|
)
|
||||||
|
sample_packing_sequentially: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={
|
||||||
|
"help": "Use next-fit sample packing that preserves the order of samples coming from the sampler. Use in combination with curriculum_sampling for fully sequential packing."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
multipack_real_batches: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Use real batches for efficient training."},
|
||||||
|
)
|
||||||
|
eval_sample_packing: Optional[bool] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Use sample packing for efficient evals."},
|
||||||
|
)
|
||||||
|
sample_packing_efficiency: float = field(
|
||||||
|
default=1.0,
|
||||||
|
metadata={"help": "Sample packing efficiency for calculating batch length."},
|
||||||
|
)
|
||||||
|
sample_packing_bin_size: int = field(
|
||||||
|
default=200,
|
||||||
|
metadata={
|
||||||
|
"help": "The max number of samples that packed sample can contain after packing. Increase for better packing."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
sample_packing_group_size: int = field(
|
||||||
|
default=100000,
|
||||||
|
metadata={
|
||||||
|
"help": "The number of samples to group together for packing. Increase for better packing."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
max_seq_length: int = field(
|
||||||
|
default=2048,
|
||||||
|
metadata={"help": "The maximum sequence length the model can handle"},
|
||||||
|
)
|
||||||
|
relora_steps: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "how often to reset for ReLoRA"},
|
||||||
|
)
|
||||||
|
relora_warmup_steps: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||||
|
)
|
||||||
|
relora_anneal_steps: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||||
|
)
|
||||||
|
relora_prune_ratio: Optional[float] = field(
|
||||||
|
default=0.9,
|
||||||
|
metadata={"help": "prune ratio for magnitude pruning of the optimizer"},
|
||||||
|
)
|
||||||
|
bench_split: Optional[str] = field(
|
||||||
|
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||||
|
)
|
||||||
|
bench_dataset: Optional[str] = field(
|
||||||
|
default="pharaouk/dharma-1/dharma_1_mini.json",
|
||||||
|
metadata={
|
||||||
|
"help": "Benchmark dataset to use: options are `mmlu-zs`, `mmlu-fs`, or the full path to the dataset file"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
do_bench_eval: Optional[bool] = field(
|
||||||
|
default=False, metadata={"help": "Whether to run the Benchmark evaluation."}
|
||||||
|
)
|
||||||
|
do_causal_lm_eval: Optional[bool] = field(
|
||||||
|
default=False, metadata={"help": "Whether to run the Causal LM evaluation."}
|
||||||
|
)
|
||||||
|
max_bench_samples: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "If set, only evaluates on `max_bench_samples` of the benchmark dataset."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
bench_source_max_len: int = field(
|
||||||
|
default=2048, metadata={"help": "Maximum source sequence length for bench."}
|
||||||
|
)
|
||||||
|
dataloader_prefetch_factor: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "prefetch_factor argument to the dataloader"},
|
||||||
|
)
|
||||||
|
cosine_min_lr_ratio: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Minimum learning rate is min_lr_ratio * learning_rate"},
|
||||||
|
)
|
||||||
|
cosine_constant_lr_ratio: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "Starting constant learning rate step is cosine_constant_lr_ratio * max_steps"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
loraplus_lr_ratio: Optional[float] = field(
|
||||||
|
default=None, metadata={"help": "loraplus learning rate ratio lr_B / lr_A."}
|
||||||
|
)
|
||||||
|
loraplus_lr_embedding: Optional[float] = field(
|
||||||
|
default=1e-6,
|
||||||
|
metadata={"help": "loraplus learning rate for lora embedding layers."},
|
||||||
|
)
|
||||||
|
embedding_lr_scale: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Scale the learning rate for the embedding layers."},
|
||||||
|
)
|
||||||
|
lr_groups: Optional[list[dict]] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Specify learning rate groups for with different LRs."},
|
||||||
|
)
|
||||||
|
embedding_lr: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "absolute learning rate for the embedding layers."},
|
||||||
|
)
|
||||||
|
qlora: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "whether this is a qlora training"},
|
||||||
|
)
|
||||||
|
orpo_alpha: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
lisa_n_layers: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "the number of activate layers in LISA"},
|
||||||
|
)
|
||||||
|
lisa_step_interval: Optional[int] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "how often to switch layers in LISA"},
|
||||||
|
)
|
||||||
|
lisa_layers_attribute: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "path under the model to access the layers"},
|
||||||
|
)
|
||||||
|
curriculum_sampling: Optional[bool] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "whether to use sequential sampling for curriculum learning"},
|
||||||
|
)
|
||||||
|
alternate_lr_scheduler_type: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
chat_template: Optional[str] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "Chat template converting chat messages to text"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# kd_ce_alpha: Optional[float] = field(
|
||||||
|
# default=None,
|
||||||
|
# metadata={
|
||||||
|
# "help": "The alpha scaling parameter for SFT cross entropy loss when using KD"
|
||||||
|
# },
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# kd_alpha: Optional[float] = field(
|
||||||
|
# default=1.0,
|
||||||
|
# metadata={"help": "The alpha scaling parameter for KD loss"},
|
||||||
|
# )
|
||||||
|
#
|
||||||
|
# kd_temperature: Optional[float] = field(
|
||||||
|
# default=1.0,
|
||||||
|
# metadata={
|
||||||
|
# "help": "the temperature parameter for KL divergence loss when using KD"
|
||||||
|
# },
|
||||||
|
# )
|
||||||
|
|
||||||
|
adam_beta3: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "The beta3 hyperparameter used in some optimizers such as CAME"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
adam_epsilon2: Optional[float] = field(
|
||||||
|
default=None,
|
||||||
|
metadata={
|
||||||
|
"help": "The epsilon2 hyperparameter used in some optimizers such as CAME"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# multi-modal section
|
||||||
|
|
||||||
|
image_size: int | tuple[int, int] | None = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The size of the image to resize to"},
|
||||||
|
)
|
||||||
|
|
||||||
|
image_resize_algorithm: Resampling | None = field(
|
||||||
|
default=None,
|
||||||
|
metadata={"help": "The algorithm to use for image resizing"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# end of multi-modal section
|
||||||
@@ -22,6 +22,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import collections
|
import collections
|
||||||
import importlib
|
import importlib
|
||||||
|
import logging
|
||||||
|
import traceback
|
||||||
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
|
from typing import TYPE_CHECKING, Callable, OrderedDict, Union
|
||||||
|
|
||||||
from peft import PeftModel
|
from peft import PeftModel
|
||||||
@@ -83,6 +85,11 @@ class BasePlugin:
|
|||||||
def get_input_args(self) -> str | None:
|
def get_input_args(self) -> str | None:
|
||||||
"""Returns a pydantic model for the plugin's input arguments."""
|
"""Returns a pydantic model for the plugin's input arguments."""
|
||||||
|
|
||||||
|
def get_training_args_mixin(self) -> str | None:
|
||||||
|
"""
|
||||||
|
Returns a dataclass model for the plugin's training arguments.
|
||||||
|
"""
|
||||||
|
|
||||||
def load_datasets(
|
def load_datasets(
|
||||||
self, cfg: DictDefault, preprocess: bool = False
|
self, cfg: DictDefault, preprocess: bool = False
|
||||||
) -> Union["TrainDatasetMeta", None]:
|
) -> Union["TrainDatasetMeta", None]:
|
||||||
@@ -158,6 +165,32 @@ class BasePlugin:
|
|||||||
trainer: The trainer object for training.
|
trainer: The trainer object for training.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def get_training_args(self, cfg: DictDefault): # pylint: disable=unused-argument):
|
||||||
|
"""
|
||||||
|
Returns custom training arguments to set on TrainingArgs.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: The global axolotl configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
object: dict containing the training arguments.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_collator_cls_and_kwargs(
|
||||||
|
self, cfg: DictDefault, is_eval: bool=False
|
||||||
|
): # pylint: disable=unused-argument):
|
||||||
|
"""
|
||||||
|
Returns a custom class for the collator.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: The global axolotl configuration.
|
||||||
|
is_eval: Whether this is an eval split.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
class: The class for the collator.
|
||||||
|
"""
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None:
|
def create_optimizer(self, cfg: DictDefault, trainer: Trainer) -> Optimizer | None:
|
||||||
"""Creates and returns an optimizer for training.
|
"""Creates and returns an optimizer for training.
|
||||||
@@ -167,84 +200,7 @@ class BasePlugin:
|
|||||||
trainer: The trainer object for training.
|
trainer: The trainer object for training.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
<<<<<<< HEAD
|
|
||||||
The created optimizer.
|
The created optimizer.
|
||||||
=======
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
|
|
||||||
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
|
||||||
"""
|
|
||||||
Performs actions before LoRA weights are loaded.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cfg (dict): The configuration for the plugin.
|
|
||||||
model (object): The loaded model.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
|
|
||||||
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
|
||||||
"""
|
|
||||||
Performs actions after LoRA weights are loaded.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cfg (dict): The configuration for the plugin.
|
|
||||||
model (object): The loaded model.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument):
|
|
||||||
"""
|
|
||||||
Returns a custom class for the trainer.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cfg (dict): The global axolotl configuration.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
class: The class for the trainer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def get_collator_cls_and_kwargs(
|
|
||||||
self, cfg, is_eval=False
|
|
||||||
): # pylint: disable=unused-argument):
|
|
||||||
"""
|
|
||||||
Returns a custom class for the collator.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cfg (dict): The global axolotl configuration.
|
|
||||||
is_eval (bool): Whether this is an eval split.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
class: The class for the collator.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
|
|
||||||
"""
|
|
||||||
Performs actions after the trainer is created.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cfg (dict): The configuration for the plugin.
|
|
||||||
trainer (object): The trainer object for training.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
None
|
|
||||||
"""
|
|
||||||
|
|
||||||
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
|
|
||||||
"""
|
|
||||||
Creates and returns an optimizer for training.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
cfg (dict): The configuration for the plugin.
|
|
||||||
trainer (object): The trainer object for training.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
object: The created optimizer.
|
|
||||||
>>>>>>> f8df1563d (collator cls for plugins)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
@@ -355,7 +311,7 @@ def load_plugin(plugin_name: str) -> BasePlugin:
|
|||||||
return plugin
|
return plugin
|
||||||
|
|
||||||
|
|
||||||
class PluginManager:
|
class PluginManager: # pylint: disable=too-many-public-methods
|
||||||
"""The `PluginManager` class is responsible for loading and managing plugins. It
|
"""The `PluginManager` class is responsible for loading and managing plugins. It
|
||||||
should be a singleton so it can be accessed from anywhere in the codebase.
|
should be a singleton so it can be accessed from anywhere in the codebase.
|
||||||
|
|
||||||
@@ -414,8 +370,11 @@ class PluginManager:
|
|||||||
plugin = load_plugin(plugin_name)
|
plugin = load_plugin(plugin_name)
|
||||||
self.plugins[plugin_name] = plugin
|
self.plugins[plugin_name] = plugin
|
||||||
LOG.info(f"Plugin loaded successfully: {plugin_name}")
|
LOG.info(f"Plugin loaded successfully: {plugin_name}")
|
||||||
except ImportError:
|
except ImportError as exc:
|
||||||
LOG.error(f"Failed to load plugin: {plugin_name}")
|
LOG.error(f"Failed to load plugin: {plugin_name}")
|
||||||
|
# print stacktrace
|
||||||
|
traceback.print_exc()
|
||||||
|
print(f"Error: {exc}")
|
||||||
|
|
||||||
def get_input_args(self) -> list[str]:
|
def get_input_args(self) -> list[str]:
|
||||||
"""Returns a list of Pydantic classes for all registered plugins' input arguments.'
|
"""Returns a list of Pydantic classes for all registered plugins' input arguments.'
|
||||||
@@ -430,6 +389,21 @@ class PluginManager:
|
|||||||
input_args.append(input_args_from_plugin)
|
input_args.append(input_args_from_plugin)
|
||||||
return input_args
|
return input_args
|
||||||
|
|
||||||
|
def get_training_args_mixin(self):
|
||||||
|
"""
|
||||||
|
Returns a list of dataclasses for all registered plugins' training args mixins.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
list[str]: A list of dataclasses
|
||||||
|
"""
|
||||||
|
training_args = []
|
||||||
|
for plugin in self.plugins.values():
|
||||||
|
training_args_from_plugin = plugin.get_training_args_mixin()
|
||||||
|
print(f"Training args from plugin: {plugin.__class__.__name__}")
|
||||||
|
if training_args_from_plugin is not None:
|
||||||
|
training_args.append(training_args_from_plugin)
|
||||||
|
return training_args
|
||||||
|
|
||||||
def load_datasets(
|
def load_datasets(
|
||||||
self, cfg: DictDefault, preprocess: bool = False
|
self, cfg: DictDefault, preprocess: bool = False
|
||||||
) -> Union["TrainDatasetMeta", None]:
|
) -> Union["TrainDatasetMeta", None]:
|
||||||
@@ -519,6 +493,24 @@ class PluginManager:
|
|||||||
return trainer_cls
|
return trainer_cls
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_training_args(self, cfg):
|
||||||
|
"""
|
||||||
|
Calls the get_training_args method of all registered plugins and returns the combined training arguments.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
cfg (dict): The configuration for the plugins.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
object: The training arguments
|
||||||
|
"""
|
||||||
|
training_args_kwargs = {}
|
||||||
|
for plugin in self.plugins.values():
|
||||||
|
training_args = plugin.get_training_args(cfg)
|
||||||
|
if training_args is not None:
|
||||||
|
training_args_kwargs.update(training_args)
|
||||||
|
|
||||||
|
return training_args_kwargs
|
||||||
|
|
||||||
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
|
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
|
||||||
"""
|
"""
|
||||||
Calls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class.
|
Calls the get_collator_cls_and_kwargs method of all registered plugins and returns the first non-None collator class.
|
||||||
@@ -531,9 +523,7 @@ class PluginManager:
|
|||||||
object: The collator class, or None if none was found.
|
object: The collator class, or None if none was found.
|
||||||
"""
|
"""
|
||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins.values():
|
||||||
collator = plugin.get_collator_cls_and_kwargs(
|
collator = plugin.get_collator_cls_and_kwargs(cfg, is_eval=is_eval)
|
||||||
cfg, is_eval=is_eval
|
|
||||||
)
|
|
||||||
if collator is not None:
|
if collator is not None:
|
||||||
collator_cls, collator_kwargs = collator
|
collator_cls, collator_kwargs = collator
|
||||||
return collator_cls, collator_kwargs
|
return collator_cls, collator_kwargs
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ Module to handle merging the plugins' input arguments with the base configuratio
|
|||||||
This was moved here to prevent circular imports.
|
This was moved here to prevent circular imports.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List, Type
|
||||||
|
|
||||||
from axolotl.utils.schemas.config import (
|
from axolotl.utils.schemas.config import (
|
||||||
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
AxolotlConfigWCapabilities as AxolotlConfigWCapabilitiesBase,
|
||||||
@@ -61,3 +61,43 @@ def merge_input_args():
|
|||||||
]
|
]
|
||||||
return AxolotlConfigWCapabilities, AxolotlInputConfig
|
return AxolotlConfigWCapabilities, AxolotlInputConfig
|
||||||
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
|
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
|
||||||
|
|
||||||
|
|
||||||
|
def merge_training_args() -> Type:
    """
    Merges training-args mixins from registered plugins with the base mixins class.

    The PluginManager reports each plugin's training-args mixin as a dotted
    import path (``package.module.ClassName``). Every reported class is imported
    and combined with the base ``AxolotlTrainingMixins`` into a dynamically
    created subclass.

    Returns:
        Type: A new ``AxolotlTrainingMixins`` class inheriting from the base class
            followed by each plugin mixin (in the order reported by the plugin
            manager), or the base class itself when no plugin provides a mixin.
    """
    # pylint: disable=duplicate-code
    import importlib

    from axolotl.core.training_args_base import (
        AxolotlTrainingMixins as AxolotlTrainingMixinsBase,
    )
    from axolotl.integrations.base import PluginManager

    plugin_manager = PluginManager.get_instance()
    training_args_mixins: List[str] = plugin_manager.get_training_args_mixin()
    if not training_args_mixins:
        return AxolotlTrainingMixinsBase

    mixin_classes: List[type] = []
    for plugin_args in training_args_mixins:
        plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
        # Resolve the class object directly instead of exec-ing generated source:
        # two mixins that share a class name but live in different modules can no
        # longer shadow each other.
        mixin_cls = getattr(importlib.import_module(plugin_module), plugin_cls)
        # A class may appear only once in a bases tuple.
        if mixin_cls not in mixin_classes:
            mixin_classes.append(mixin_cls)

    return type(
        "AxolotlTrainingMixins",
        (AxolotlTrainingMixinsBase, *mixin_classes),
        {},
    )
|
||||||
|
|||||||
@@ -15,7 +15,12 @@
|
|||||||
"""
|
"""
|
||||||
Plugin init to add KD support to Axolotl.
|
Plugin init to add KD support to Axolotl.
|
||||||
"""
|
"""
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from transformers import Trainer
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
from axolotl.integrations.kd.callbacks import KDTemperatureSchedulerCallback
|
||||||
|
|
||||||
from .args import KDArgs # pylint: disable=unused-import. # noqa: F401
|
from .args import KDArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
|
|
||||||
@@ -28,6 +33,9 @@ class KDPlugin(BasePlugin):
|
|||||||
def get_input_args(self):
|
def get_input_args(self):
|
||||||
return "axolotl.integrations.kd.KDArgs"
|
return "axolotl.integrations.kd.KDArgs"
|
||||||
|
|
||||||
|
def get_training_args_mixin(self):
    """Return the dotted import path of the KD training-args mixin dataclass."""
    mixin_path = "axolotl.integrations.kd.args.KDTrainingArgsMixin"
    return mixin_path
|
||||||
|
|
||||||
def get_trainer_cls(self, cfg):
|
def get_trainer_cls(self, cfg):
|
||||||
if cfg.kd_trainer:
|
if cfg.kd_trainer:
|
||||||
from .trainer import AxolotlKDTrainer
|
from .trainer import AxolotlKDTrainer
|
||||||
@@ -35,6 +43,14 @@ class KDPlugin(BasePlugin):
|
|||||||
return AxolotlKDTrainer
|
return AxolotlKDTrainer
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_training_args(self, cfg):
    """
    Build the KD-specific training-argument kwargs from the config.

    Each value is read straight from the attribute of the same name on ``cfg``
    and passed through unchanged.

    Returns:
        dict: ``kd_ce_alpha``, ``kd_alpha``, ``kd_temperature`` and ``kd_beta``.
    """
    kd_arg_names = ("kd_ce_alpha", "kd_alpha", "kd_temperature", "kd_beta")
    return {arg_name: getattr(cfg, arg_name) for arg_name in kd_arg_names}
|
||||||
|
|
||||||
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
|
def get_collator_cls_and_kwargs(self, cfg, is_eval=False):
|
||||||
if not cfg.kd_trainer:
|
if not cfg.kd_trainer:
|
||||||
return None, None
|
return None, None
|
||||||
@@ -66,3 +82,24 @@ class KDPlugin(BasePlugin):
|
|||||||
from .kernels.models import apply_kernel
|
from .kernels.models import apply_kernel
|
||||||
|
|
||||||
apply_kernel(cfg.model_config_type)
|
apply_kernel(cfg.model_config_type)
|
||||||
|
|
||||||
|
def add_callbacks_post_trainer(self, cfg: Any, trainer: Trainer) -> list:
    """
    Creates the KD temperature scheduler callback for the Trainer instance.

    The callback is only created when ``cfg.kd_temperature_min`` is set and
    ``cfg.kd_online_server_base_url`` is truthy.

    Args:
        cfg (Any): Configuration object providing the ``kd_*`` settings.
        trainer (Trainer): Huggingface Trainer instance handed to the callback.

    Returns:
        list: A single-element list holding the scheduler callback, or an empty
            list when the conditions above are not met.
    """
    scheduling_enabled = (
        cfg.kd_temperature_min is not None and cfg.kd_online_server_base_url
    )
    if not scheduling_enabled:
        return []

    return [
        KDTemperatureSchedulerCallback(
            cfg.kd_temperature,
            cfg.kd_temperature_min,
            trainer,
        )
    ]
|
||||||
|
|||||||
@@ -15,14 +15,19 @@
|
|||||||
"""
|
"""
|
||||||
Plugin args for KD support.
|
Plugin args for KD support.
|
||||||
"""
|
"""
|
||||||
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
class InferenceServerType(str, Enum):
|
class InferenceServerType(str, Enum):
|
||||||
vllm = "vllm"
|
"""
|
||||||
sglang = "sglang"
|
Online inference server types to handle different request args
|
||||||
|
"""
|
||||||
|
|
||||||
|
vllm = "vllm" # pylint: disable=invalid-name
|
||||||
|
sglang = "sglang" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class KDArgs(BaseModel):
|
class KDArgs(BaseModel):
|
||||||
@@ -36,9 +41,29 @@ class KDArgs(BaseModel):
|
|||||||
)
|
)
|
||||||
kd_alpha: float | None = None # loss coefficient for KD loss
|
kd_alpha: float | None = None # loss coefficient for KD loss
|
||||||
kd_temperature: float | None = None # temperature for sampling during KD
|
kd_temperature: float | None = None # temperature for sampling during KD
|
||||||
|
kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
|
||||||
|
|
||||||
# TODO online kd
|
# TODO online kd
|
||||||
kd_online_server_base_url: str | None = None
|
kd_online_server_base_url: str | None = None
|
||||||
kd_online_topk: int | None = None
|
kd_online_topk: int | None = None
|
||||||
kd_online_server: InferenceServerType | None = "vllm"
|
kd_online_server: InferenceServerType | None = Field(
|
||||||
|
default_factory=lambda: InferenceServerType.vllm
|
||||||
|
)
|
||||||
kd_online_timeout: int | None = 120
|
kd_online_timeout: int | None = 120
|
||||||
|
kd_temperature_min: float | None = (
|
||||||
|
None # kd temperature scheduling during online kd
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class KDTrainingArgsMixin:
|
||||||
|
"""
|
||||||
|
Additional args for KD training.
|
||||||
|
"""
|
||||||
|
|
||||||
|
kd_ce_alpha: float | None = (
|
||||||
|
None # loss coefficient for cross-entropy loss during KD
|
||||||
|
)
|
||||||
|
kd_alpha: float | None = None # loss coefficient for KD loss
|
||||||
|
kd_temperature: float | None = None # temperature for sampling during KD
|
||||||
|
kd_beta: float | None = None # beta coefficient for ratio of fwd and reverse KL
|
||||||
|
|||||||
36
src/axolotl/integrations/kd/callbacks.py
Normal file
36
src/axolotl/integrations/kd/callbacks.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
"""
|
||||||
|
Transformers trainer callbacks to schedule the KD temperature during training
|
||||||
|
"""
|
||||||
|
|
||||||
|
import math

from transformers.trainer_callback import CallbackHandler, TrainerCallback
|
||||||
|
|
||||||
|
|
||||||
|
class KDTemperatureSchedulerCallback(TrainerCallback):
    """
    KD temperature scheduler callback for the trainer.

    After each training step the temperature is annealed along a cosine curve
    from ``temperature_start`` (at step 0) down to ``temperature_min`` (at
    ``state.max_steps``). The new value is written to the trainer's data
    collator when that collator exposes a ``kd_temperature`` attribute.

    Subclasses ``TrainerCallback`` rather than ``CallbackHandler``: the handler
    is the trainer's internal dispatcher and its event hooks do not accept the
    keyword arguments that callbacks are invoked with.
    """

    def __init__(self, temperature_start, temperature_min, trainer):
        self.temperature_start = temperature_start
        self.temperature_min = temperature_min
        self.temperature = temperature_start

        self.trainer = trainer

    def on_step_end(
        self, args, state, control, **kwargs
    ):  # pylint: disable=unused-argument
        # cosine decay temperature over the max steps
        if state.max_steps <= 0:
            # No schedule length to decay over; leave the temperature untouched
            # instead of dividing by zero.
            return

        # Clamp so steps past max_steps stay at the minimum rather than
        # climbing back up the cosine curve.
        progress = min(state.global_step / state.max_steps, 1.0)
        # Cosine decay factor: 0.5 * (1 + cos(pi * progress))
        # This factor goes from 1 (at progress=0) to 0 (at progress=1)
        decay_factor = 0.5 * (1.0 + math.cos(math.pi * progress))
        # min + (start - min) * factor: start at progress=0, min at progress=1
        self.temperature = self.temperature_min + (
            (self.temperature_start - self.temperature_min) * decay_factor
        )

        if hasattr(self.trainer.data_collator, "kd_temperature"):
            self.trainer.data_collator.kd_temperature = self.temperature
|
||||||
@@ -2,12 +2,14 @@
|
|||||||
Packed data loader for online teacher training supporting vllm and sglang.
|
Packed data loader for online teacher training supporting vllm and sglang.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
import logging
|
import logging
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import pandas as pd
|
|
||||||
import requests
|
import requests
|
||||||
import torch
|
import torch
|
||||||
|
from orjson import orjson
|
||||||
|
|
||||||
from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq
|
from axolotl.integrations.kd.collator import KDBatchSamplerDataCollatorForSeq2Seq
|
||||||
from axolotl.utils.data.utils import retry_on_request_exceptions
|
from axolotl.utils.data.utils import retry_on_request_exceptions
|
||||||
@@ -15,6 +17,31 @@ from axolotl.utils.data.utils import retry_on_request_exceptions
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def hmac_sha_from_int_list(int_list, key, hash_func=hashlib.sha256):
    """
    Compute a keyed HMAC digest over a sequence of integers.

    Args:
        int_list: Iterable of integers; each is serialized with
            ``int.to_bytes(4, byteorder="big")``, so it must be non-negative
            and fit in 4 bytes (``OverflowError`` otherwise).
        key: Secret key; a ``str`` is UTF-8 encoded, any other value is passed
            to ``hmac.new`` unchanged.
        hash_func: Digest constructor handed to ``hmac.new`` (default: sha256)

    Returns:
        HMAC digest as hex string
    """
    key_bytes = key.encode("utf-8") if isinstance(key, str) else key

    # Serialize every integer as a fixed-width 4-byte big-endian word.
    payload = bytearray()
    for value in int_list:
        payload += value.to_bytes(4, byteorder="big")

    return hmac.new(key_bytes, bytes(payload), hash_func).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
||||||
"""
|
"""
|
||||||
Collator for online teacher training.
|
Collator for online teacher training.
|
||||||
@@ -30,6 +57,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
kd_temperature: Optional[float] = 1.0,
|
kd_temperature: Optional[float] = 1.0,
|
||||||
kd_online_server: Optional[str] = "vllm",
|
kd_online_server: Optional[str] = "vllm",
|
||||||
kd_online_timeout: Optional[int] = 120,
|
kd_online_timeout: Optional[int] = 120,
|
||||||
|
kd_cache_dir: Optional[str] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
):
|
):
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
@@ -49,6 +77,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
self.kd_online_server = kd_online_server
|
self.kd_online_server = kd_online_server
|
||||||
self.http_session = requests.Session()
|
self.http_session = requests.Session()
|
||||||
self.kd_online_timeout = kd_online_timeout
|
self.kd_online_timeout = kd_online_timeout
|
||||||
|
self.kd_cache_dir = kd_cache_dir
|
||||||
|
|
||||||
def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]:
|
def _normalize_logprobs(self, raw_logprobs: List[float]) -> List[float]:
|
||||||
"""
|
"""
|
||||||
@@ -109,7 +138,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
|
|
||||||
return final_logprobs_tensor.tolist()
|
return final_logprobs_tensor.tolist()
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e: # pylint: disable=broad-exception-caught
|
||||||
LOG.error(
|
LOG.error(
|
||||||
f"Error during online logprob scaling: {e}. Returning raw logprobs.",
|
f"Error during online logprob scaling: {e}. Returning raw logprobs.",
|
||||||
exc_info=True,
|
exc_info=True,
|
||||||
@@ -142,11 +171,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Initialize with empty lists, so if API call fails, these are returned.
|
# Initialize with empty lists, so if API call fails, these are returned.
|
||||||
ret_logprobs_data = {
|
ret_data_target_token_ids: List[List[List[int]]] = []
|
||||||
"target_token_ids": [],
|
ret_data_target_logprobs: List[List[List[float]]] = []
|
||||||
"target_logprobs": [],
|
ret_data_target_mask: List[List[List[int]]] = []
|
||||||
"target_mask": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.http_session.post(
|
response = self.http_session.post(
|
||||||
@@ -162,7 +189,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}."
|
f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}."
|
||||||
)
|
)
|
||||||
# Return empty data; items processed later will get default empty KD fields
|
# Return empty data; items processed later will get default empty KD fields
|
||||||
return ret_logprobs_data
|
return {
|
||||||
|
"target_token_ids": ret_data_target_token_ids,
|
||||||
|
"target_logprobs": ret_data_target_logprobs,
|
||||||
|
"target_mask": ret_data_target_mask,
|
||||||
|
}
|
||||||
|
|
||||||
for sequence_data, seq_input_ids, seq_labels in zip(
|
for sequence_data, seq_input_ids, seq_labels in zip(
|
||||||
api_data, batch_input_ids, labels
|
api_data, batch_input_ids, labels
|
||||||
@@ -185,7 +216,7 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
# basic check that the logprob data len matches the input len, so no need to handle padding
|
# basic check that the logprob data len matches the input len, so no need to handle padding
|
||||||
assert len(seq_input_ids) == len(input_top_logprobs)
|
assert len(seq_input_ids) == len(input_top_logprobs)
|
||||||
|
|
||||||
for i, input_id, label in zip(
|
for i, _, label in zip(
|
||||||
range(len(seq_input_ids)), seq_input_ids, seq_labels
|
range(len(seq_input_ids)), seq_input_ids, seq_labels
|
||||||
):
|
):
|
||||||
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
|
if i < len(input_top_logprobs) and input_top_logprobs[i] is None:
|
||||||
@@ -254,9 +285,9 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
current_target_token_ids.append([0] * self.kd_online_topk)
|
current_target_token_ids.append([0] * self.kd_online_topk)
|
||||||
current_target_mask.append([0] * self.kd_online_topk)
|
current_target_mask.append([0] * self.kd_online_topk)
|
||||||
|
|
||||||
ret_logprobs_data["target_token_ids"].append(current_target_token_ids)
|
ret_data_target_token_ids.append(current_target_token_ids)
|
||||||
ret_logprobs_data["target_logprobs"].append(current_target_logprobs)
|
ret_data_target_logprobs.append(current_target_logprobs)
|
||||||
ret_logprobs_data["target_mask"].append(current_target_mask)
|
ret_data_target_mask.append(current_target_mask)
|
||||||
|
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
LOG.error(f"Error fetching logprobs from online teacher: {e}")
|
LOG.error(f"Error fetching logprobs from online teacher: {e}")
|
||||||
@@ -269,7 +300,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
return ret_logprobs_data
|
return {
|
||||||
|
"target_token_ids": ret_data_target_token_ids,
|
||||||
|
"target_logprobs": ret_data_target_logprobs,
|
||||||
|
"target_mask": ret_data_target_mask,
|
||||||
|
}
|
||||||
|
|
||||||
@retry_on_request_exceptions(max_retries=10, delay=5)
|
@retry_on_request_exceptions(max_retries=10, delay=5)
|
||||||
def fetch_online_logprobs_vllm(
|
def fetch_online_logprobs_vllm(
|
||||||
@@ -296,18 +331,19 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
}
|
}
|
||||||
|
|
||||||
# Initialize with empty lists, so if API call fails, these are returned.
|
# Initialize with empty lists, so if API call fails, these are returned.
|
||||||
ret_logprobs_data = {
|
ret_data_target_token_ids: List[List[List[int]]] = []
|
||||||
"target_token_ids": [],
|
ret_data_target_logprobs: List[List[List[float]]] = []
|
||||||
"target_logprobs": [],
|
ret_data_target_mask: List[List[List[int]]] = []
|
||||||
"target_mask": [],
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = self.http_session.post(
|
response = self.http_session.post(
|
||||||
api_endpoint, json=payload, timeout=self.kd_online_timeout
|
api_endpoint,
|
||||||
|
json=payload,
|
||||||
|
timeout=self.kd_online_timeout,
|
||||||
|
# json_decoder=orjson.loads,
|
||||||
)
|
)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
api_data: dict = response.json()
|
api_data: dict = orjson.loads(response.content)
|
||||||
choices: list[dict] = api_data["choices"]
|
choices: list[dict] = api_data["choices"]
|
||||||
|
|
||||||
# Ensure api_data is a list, and its length matches batch_input_ids
|
# Ensure api_data is a list, and its length matches batch_input_ids
|
||||||
@@ -317,7 +353,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}."
|
f"items, got {type(api_data)} with length {len(api_data) if isinstance(api_data, list) else 'N/A'}."
|
||||||
)
|
)
|
||||||
# Return empty data; items processed later will get default empty KD fields
|
# Return empty data; items processed later will get default empty KD fields
|
||||||
return ret_logprobs_data
|
return {
|
||||||
|
"target_token_ids": ret_data_target_token_ids,
|
||||||
|
"target_logprobs": ret_data_target_logprobs,
|
||||||
|
"target_mask": ret_data_target_mask,
|
||||||
|
}
|
||||||
|
|
||||||
for sequence_data, seq_input_ids, seq_labels in zip(
|
for sequence_data, seq_input_ids, seq_labels in zip(
|
||||||
choices, batch_input_ids, labels
|
choices, batch_input_ids, labels
|
||||||
@@ -330,29 +370,10 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
current_target_mask = []
|
current_target_mask = []
|
||||||
|
|
||||||
# Ensure input_top_logprobs is a list
|
# Ensure input_top_logprobs is a list
|
||||||
input_top_logprobs: Optional[list[None | list[tuple]]] = (
|
input_top_logprobs: Optional[list[None | dict[str, dict]]] = (
|
||||||
sequence_data.pop("prompt_logprobs", [])
|
sequence_data.pop("prompt_logprobs", [])
|
||||||
)
|
)
|
||||||
"""
|
|
||||||
vllm api data for prompt logprobs looks like:
|
|
||||||
"prompt_logprobs": [
|
|
||||||
null, # first token is always null
|
|
||||||
{ # second token logprobs
|
|
||||||
"8948": { # token ID
|
|
||||||
"logprob": -2.3841830625315197e-06,
|
|
||||||
"rank": 1,
|
|
||||||
"decoded_token": "system"
|
|
||||||
},
|
|
||||||
"1849": { # token ID
|
|
||||||
"logprob": -13.187501907348633,
|
|
||||||
"rank": 2,
|
|
||||||
"decoded_token": "Ġsystem"
|
|
||||||
},
|
|
||||||
... # rest of the top-k tokens/logprobs
|
|
||||||
},
|
|
||||||
... # more tokens
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
if not isinstance(input_top_logprobs, list):
|
if not isinstance(input_top_logprobs, list):
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
|
f"Received non-list input_top_logprobs: {input_top_logprobs}. Skipping sequence."
|
||||||
@@ -369,11 +390,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
# this is always the case for the first token.
|
# this is always the case for the first token.
|
||||||
# there is never logprob data for the first token since that's a true input
|
# there is never logprob data for the first token since that's a true input
|
||||||
continue
|
continue
|
||||||
elif (
|
if (
|
||||||
i < len(input_top_logprobs)
|
i < len(input_top_logprobs)
|
||||||
and input_top_logprobs[i] is not None
|
and input_top_logprobs[i] is not None
|
||||||
):
|
):
|
||||||
pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i]
|
pos_top_logprobs_data: dict[str, dict] = input_top_logprobs[i] # type: ignore[assignment]
|
||||||
# Ensure pos_top_logprobs_data is a list of lists as expected
|
# Ensure pos_top_logprobs_data is a list of lists as expected
|
||||||
if not (
|
if not (
|
||||||
isinstance(pos_top_logprobs_data, dict)
|
isinstance(pos_top_logprobs_data, dict)
|
||||||
@@ -396,9 +417,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids
|
# pos_top_logprobs: list of logprobs, pos_token_ids: list of token_ids
|
||||||
pos_token_ids = pos_top_logprobs_data.keys()
|
pos_token_ids_str = list(pos_top_logprobs_data.keys())
|
||||||
pos_logprobs_dict = pos_top_logprobs_data.values()
|
pos_logprobs_dict = pos_top_logprobs_data.values()
|
||||||
pos_token_ids = [int(token_id) for token_id in pos_token_ids]
|
pos_token_ids = [
|
||||||
|
int(token_id) for token_id in pos_token_ids_str
|
||||||
|
]
|
||||||
pos_logprobs_raw = [
|
pos_logprobs_raw = [
|
||||||
float(logprob.get("logprob", -float("inf")))
|
float(logprob.get("logprob", -float("inf")))
|
||||||
for logprob in pos_logprobs_dict
|
for logprob in pos_logprobs_dict
|
||||||
@@ -446,17 +469,18 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
current_target_token_ids.append(list(range(self.kd_online_topk)))
|
current_target_token_ids.append(list(range(self.kd_online_topk)))
|
||||||
current_target_mask.append([0] * self.kd_online_topk)
|
current_target_mask.append([0] * self.kd_online_topk)
|
||||||
|
|
||||||
ret_logprobs_data["target_token_ids"].append(current_target_token_ids)
|
ret_data_target_token_ids.append(current_target_token_ids)
|
||||||
ret_logprobs_data["target_logprobs"].append(current_target_logprobs)
|
ret_data_target_logprobs.append(current_target_logprobs)
|
||||||
ret_logprobs_data["target_mask"].append(current_target_mask)
|
ret_data_target_mask.append(current_target_mask)
|
||||||
|
|
||||||
# TODO save and load targets to disk for caching for next epoch
|
# TODO save and load targets to disk for caching for next epoch
|
||||||
# generate a hash over seq_input_ids and convert it to an int
|
# generate a hmac SHA256 hash over the list seq_input_ids and convert it to an int
|
||||||
# hash_input_ids: int = hash(tuple(seq_input_ids))
|
# if self.kd_cache_dir:
|
||||||
# with open(f"/tmp/target_logprobs_{hash_input_ids}.parquet", "wb") as f:
|
# hash_input_ids = hmac_sha_from_int_list(
|
||||||
# pd.DataFrame(current_target_logprobs).to_parquet(f, index=False)
|
# seq_input_ids, f"{self.kd_online_server_base_url}:{self.kd_online_topk}"
|
||||||
# with open(f"/tmp/target_token_ids_{hash_input_ids}.parquet", "wb") as f:
|
# )
|
||||||
# pd.DataFrame(current_target_token_ids).to_parquet(f, index=False)
|
# with open(f"{self.kd_cache_dir}/{hash_input_ids}.parquet", "wb") as f:
|
||||||
|
# pd.DataFrame(ret_logprobs_data).to_parquet(f, index=False)
|
||||||
|
|
||||||
except requests.exceptions.RequestException as e:
|
except requests.exceptions.RequestException as e:
|
||||||
LOG.error(f"Error fetching logprobs from online teacher: {e}")
|
LOG.error(f"Error fetching logprobs from online teacher: {e}")
|
||||||
@@ -469,7 +493,11 @@ class OnlineTeacherCollator(KDBatchSamplerDataCollatorForSeq2Seq):
|
|||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
return ret_logprobs_data
|
return {
|
||||||
|
"target_token_ids": ret_data_target_token_ids,
|
||||||
|
"target_logprobs": ret_data_target_logprobs,
|
||||||
|
"target_mask": ret_data_target_mask,
|
||||||
|
}
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self, features: List[List[Dict[str, Any]]], return_tensors: Optional[str] = None
|
self, features: List[List[Dict[str, Any]]], return_tensors: Optional[str] = None
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
target_token_ids_chunk: torch.Tensor, # [chunk_size, top_k]
|
target_token_ids_chunk: torch.Tensor, # [chunk_size, top_k]
|
||||||
target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs
|
target_logprobs_chunk: torch.Tensor, # [chunk_size, top_k], already temp-scaled and normalized logprobs
|
||||||
target_mask_chunk: torch.Tensor, # [chunk_size, top_k]
|
target_mask_chunk: torch.Tensor, # [chunk_size, top_k]
|
||||||
|
beta: float = 0.0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Compute Top-K KL divergence loss for a chunk.
|
Compute Top-K KL divergence loss for a chunk.
|
||||||
@@ -28,6 +29,10 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K).
|
target_token_ids_chunk: Top-k teacher token IDs. Shape: (N, K).
|
||||||
target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K).
|
target_logprobs_chunk: Top-k teacher log probabilities (temp-scaled, normalized). Shape: (N, K).
|
||||||
target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K).
|
target_mask_chunk: Mask for valid top-k tokens. Shape: (N, K).
|
||||||
|
beta: Controls the type of KL divergence.
|
||||||
|
0.0 for Forward KL (P_teacher || P_student).
|
||||||
|
1.0 for Reverse KL (P_student || P_teacher).
|
||||||
|
0.5 for Symmetric KL (average of Forward and Reverse).
|
||||||
Returns:
|
Returns:
|
||||||
Sum of KL divergence losses for the chunk.
|
Sum of KL divergence losses for the chunk.
|
||||||
"""
|
"""
|
||||||
@@ -59,6 +64,10 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
# Teacher probabilities P(y|x_teacher) from logprobs
|
# Teacher probabilities P(y|x_teacher) from logprobs
|
||||||
# target_logprobs_valid are already normalized (log(softmax(teacher_logits/T)))
|
# target_logprobs_valid are already normalized (log(softmax(teacher_logits/T)))
|
||||||
teacher_probs_valid = target_logprobs_valid.exp()
|
teacher_probs_valid = target_logprobs_valid.exp()
|
||||||
|
# Student probabilities P_student from log P_student
|
||||||
|
student_probs_topk_valid = student_logprobs_topk_valid.exp()
|
||||||
|
|
||||||
|
kd_loss_per_token = torch.zeros_like(target_logprobs_valid)
|
||||||
|
|
||||||
# KL divergence: sum(P_teacher * (log P_teacher - log P_student))
|
# KL divergence: sum(P_teacher * (log P_teacher - log P_student))
|
||||||
# = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student)
|
# = sum(P_teacher * log P_teacher) - sum(P_teacher * log P_student)
|
||||||
@@ -66,9 +75,17 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
# or as sum(P_teacher * (log_softmax_teacher - log_softmax_student))
|
# or as sum(P_teacher * (log_softmax_teacher - log_softmax_student))
|
||||||
# Here, target_logprobs_valid are log_softmax_teacher.
|
# Here, target_logprobs_valid are log_softmax_teacher.
|
||||||
# student_logprobs_topk_valid are log_softmax_student (for the selected K indices).
|
# student_logprobs_topk_valid are log_softmax_student (for the selected K indices).
|
||||||
kd_loss_per_token = teacher_probs_valid * (
|
if beta < 1.0: # Contribution from Forward KL
|
||||||
target_logprobs_valid - student_logprobs_topk_valid
|
fwd_kl_per_token = teacher_probs_valid * (
|
||||||
)
|
target_logprobs_valid - student_logprobs_topk_valid
|
||||||
|
)
|
||||||
|
kd_loss_per_token += (1.0 - beta) * fwd_kl_per_token
|
||||||
|
if beta > 0.0: # Contribution from Reverse KL
|
||||||
|
rev_kl_per_token = student_probs_topk_valid * (
|
||||||
|
student_logprobs_topk_valid - target_logprobs_valid
|
||||||
|
)
|
||||||
|
kd_loss_per_token += beta * rev_kl_per_token
|
||||||
|
|
||||||
kd_loss = kd_loss_per_token.sum()
|
kd_loss = kd_loss_per_token.sum()
|
||||||
|
|
||||||
return kd_loss
|
return kd_loss
|
||||||
@@ -91,6 +108,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
weight_soft_loss: float = 0.5,
|
weight_soft_loss: float = 0.5,
|
||||||
compute_ce_loss: bool = True,
|
compute_ce_loss: bool = True,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
|
beta: float = 0.0,
|
||||||
):
|
):
|
||||||
# Compute student logits for the chunk from hidden states and LM head
|
# Compute student logits for the chunk from hidden states and LM head
|
||||||
# student_input_chunk: [chunk_size, hidden_dim]
|
# student_input_chunk: [chunk_size, hidden_dim]
|
||||||
@@ -125,10 +143,9 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
target_token_ids_chunk,
|
target_token_ids_chunk,
|
||||||
target_logprobs_chunk,
|
target_logprobs_chunk,
|
||||||
target_mask_chunk,
|
target_mask_chunk,
|
||||||
|
beta=beta,
|
||||||
)
|
)
|
||||||
|
|
||||||
loss = weight_hard_loss * ce_loss + weight_soft_loss * soft_loss
|
|
||||||
|
|
||||||
return soft_loss, ce_loss
|
return soft_loss, ce_loss
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@@ -146,6 +163,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
weight_soft_loss: float = 0.5,
|
weight_soft_loss: float = 0.5,
|
||||||
ignore_index: int = -100,
|
ignore_index: int = -100,
|
||||||
temperature: float = 1.0,
|
temperature: float = 1.0,
|
||||||
|
beta: float = 0.0,
|
||||||
compiled: bool = False,
|
compiled: bool = False,
|
||||||
chunk_size: int = 1024,
|
chunk_size: int = 1024,
|
||||||
compute_ce_loss: bool = True,
|
compute_ce_loss: bool = True,
|
||||||
@@ -192,6 +210,7 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
weight_soft_loss=weight_soft_loss,
|
weight_soft_loss=weight_soft_loss,
|
||||||
compute_ce_loss=compute_ce_loss,
|
compute_ce_loss=compute_ce_loss,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
|
beta=beta,
|
||||||
)
|
)
|
||||||
|
|
||||||
def accumulate_chunk_grads(
|
def accumulate_chunk_grads(
|
||||||
@@ -288,13 +307,13 @@ class LigerFusedLinearKLTopKLogprobFunction(LigerFusedLinearDistillationBase):
|
|||||||
ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)
|
ctx.save_for_backward(grad_inputs_combined, grad_weight_acc, grad_bias_acc)
|
||||||
|
|
||||||
# For matching None returns in backward for non-tensor/non-grad_requiring inputs
|
# For matching None returns in backward for non-tensor/non-grad_requiring inputs
|
||||||
ctx.hyperparams_count = 7 # Corresponds to number of hyperparams after main tensors in fwd signature
|
ctx.hyperparams_count = 8 # Corresponds to number of hyperparams after main tensors in fwd signature
|
||||||
ctx.bias_was_none = student_lm_head_bias is None
|
ctx.bias_was_none = student_lm_head_bias is None
|
||||||
ctx.orig_dims = (B, N, D, K)
|
ctx.orig_dims = (B, N, D, K)
|
||||||
|
|
||||||
# since this is packed, there is simply a single batch, so batchmean reduciton of kl-div is simply the accumulatedsum
|
# since this is packed, there is simply a single batch, so batchmean reduciton of kl-div is simply the accumulated sum
|
||||||
# we still need to scale the kd_loss by the temp
|
# we still need to scale the kd_loss by the temp^2
|
||||||
kd_loss_acc = kd_loss_acc * (temperature ** 2)
|
kd_loss_acc = kd_loss_acc * (temperature**2)
|
||||||
final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc
|
final_loss = weight_soft_loss * kd_loss_acc + weight_hard_loss * ce_loss_acc
|
||||||
|
|
||||||
return final_loss
|
return final_loss
|
||||||
@@ -373,6 +392,7 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
|
|||||||
weight_hard_loss: float = 0.5,
|
weight_hard_loss: float = 0.5,
|
||||||
weight_soft_loss: float = 0.5,
|
weight_soft_loss: float = 0.5,
|
||||||
temperature: float = 1.0, # This is the kd_temperature
|
temperature: float = 1.0, # This is the kd_temperature
|
||||||
|
beta: float = 1.0,
|
||||||
ignore_index: int = -100,
|
ignore_index: int = -100,
|
||||||
compiled: bool = True,
|
compiled: bool = True,
|
||||||
chunk_size: int = 1024,
|
chunk_size: int = 1024,
|
||||||
@@ -387,6 +407,7 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
|
|||||||
self.weight_hard_loss = weight_hard_loss
|
self.weight_hard_loss = weight_hard_loss
|
||||||
self.weight_soft_loss = weight_soft_loss
|
self.weight_soft_loss = weight_soft_loss
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
|
self.beta = beta
|
||||||
self.ignore_index = ignore_index
|
self.ignore_index = ignore_index
|
||||||
self.compiled = compiled
|
self.compiled = compiled
|
||||||
self.chunk_size = chunk_size
|
self.chunk_size = chunk_size
|
||||||
@@ -424,6 +445,7 @@ class LigerFusedLinearKLTopKLogprobLoss(torch.nn.Module):
|
|||||||
self.weight_soft_loss,
|
self.weight_soft_loss,
|
||||||
self.ignore_index,
|
self.ignore_index,
|
||||||
self.temperature,
|
self.temperature,
|
||||||
|
self.beta,
|
||||||
self.compiled,
|
self.compiled,
|
||||||
self.chunk_size,
|
self.chunk_size,
|
||||||
self.compute_ce_loss,
|
self.compute_ce_loss,
|
||||||
|
|||||||
@@ -75,8 +75,8 @@ def kldiv_forward_llama_like(
|
|||||||
target_mask,
|
target_mask,
|
||||||
true_labels=labels,
|
true_labels=labels,
|
||||||
)
|
)
|
||||||
num_items_in_batch = kwargs.pop("num_items_in_batch", None)
|
num_items_in_batch = kwargs.pop("num_items_in_batch", -1)
|
||||||
if num_items_in_batch is not None:
|
if num_items_in_batch is not None and num_items_in_batch > 0:
|
||||||
loss = loss / num_items_in_batch
|
loss = loss / num_items_in_batch
|
||||||
|
|
||||||
return CausalLMOutputWithPast(
|
return CausalLMOutputWithPast(
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ class AxolotlKDTrainer(AxolotlTrainer):
|
|||||||
self.args.kd_ce_alpha, # hard label loss
|
self.args.kd_ce_alpha, # hard label loss
|
||||||
self.args.kd_alpha, # kd loss
|
self.args.kd_alpha, # kd loss
|
||||||
self.args.kd_temperature,
|
self.args.kd_temperature,
|
||||||
|
self.args.kd_beta,
|
||||||
compute_ce_loss=bool(self.args.kd_ce_alpha),
|
compute_ce_loss=bool(self.args.kd_ce_alpha),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,13 @@
|
|||||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import importlib
|
import importlib
|
||||||
import inspect
|
import inspect
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
|
import typing
|
||||||
import weakref
|
import weakref
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -25,7 +28,6 @@ from axolotl.common.datasets import TrainDatasetMeta
|
|||||||
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
from axolotl.contribs.lgpl import ( # pylint: disable = no-name-in-module
|
||||||
fix_untrained_tokens,
|
fix_untrained_tokens,
|
||||||
)
|
)
|
||||||
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
|
||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.loaders import (
|
from axolotl.loaders import (
|
||||||
ModelLoader,
|
ModelLoader,
|
||||||
@@ -45,6 +47,9 @@ try:
|
|||||||
except ImportError:
|
except ImportError:
|
||||||
BetterTransformer = None
|
BetterTransformer = None
|
||||||
|
|
||||||
|
if typing.TYPE_CHECKING:
|
||||||
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -472,7 +477,7 @@ def handle_untrained_tokens_fix(
|
|||||||
|
|
||||||
|
|
||||||
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
|
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
|
||||||
HFRLTrainerBuilder | HFCausalTrainerBuilder,
|
"HFRLTrainerBuilder" | "HFCausalTrainerBuilder",
|
||||||
PeftModel | PreTrainedModel,
|
PeftModel | PreTrainedModel,
|
||||||
PreTrainedTokenizer,
|
PreTrainedTokenizer,
|
||||||
PeftConfig | None,
|
PeftConfig | None,
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -16,7 +16,6 @@ from datasets import IterableDataset, disable_caching, enable_caching
|
|||||||
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
from axolotl.core.builders import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
|
||||||
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
|
from axolotl.monkeypatch.trainer_eval_guard import patch_evaluation_loop_for_fsdp2
|
||||||
from axolotl.utils.distributed import reduce_and_broadcast
|
from axolotl.utils.distributed import reduce_and_broadcast
|
||||||
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
from axolotl.utils.environment import check_cuda_p2p_ib_support
|
||||||
@@ -629,6 +628,8 @@ def setup_trainer(
|
|||||||
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
|
A trainer instance (either `HFRLTrainer` or `HFCausalTrainer`) configured based
|
||||||
on the provided parameters.
|
on the provided parameters.
|
||||||
"""
|
"""
|
||||||
|
from axolotl.core.trainer_builder import HFCausalTrainerBuilder, HFRLTrainerBuilder
|
||||||
|
|
||||||
if (
|
if (
|
||||||
cfg.torch_compile
|
cfg.torch_compile
|
||||||
and cfg.fsdp_config
|
and cfg.fsdp_config
|
||||||
|
|||||||
Reference in New Issue
Block a user