move more things to kd plugin

This commit is contained in:
Wing Lian
2024-12-30 13:15:28 -05:00
parent 3da6a652fa
commit 06370b386a
5 changed files with 44 additions and 13 deletions

View File

@@ -47,7 +47,6 @@ from axolotl.core.trainers.base import (
AxolotlTrainer,
ReLoRATrainer,
)
from axolotl.core.trainers.kd import AxolotlKDTrainer
from axolotl.core.training_args import (
AxolotlCPOConfig,
AxolotlDPOConfig,
@@ -77,7 +76,6 @@ from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
from axolotl.utils.chat_templates import get_chat_template_from_config
from axolotl.utils.collators import (
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForKD,
DataCollatorForSeq2Seq,
MambaDataCollator,
V2BatchSamplerDataCollatorForSeq2Seq,
@@ -306,8 +304,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return AxolotlMambaTrainer
if self.cfg.reward_model:
return AxolotlRewardTrainer
if self.cfg.trainer == "kd":
return AxolotlKDTrainer
return AxolotlTrainer
def build(self, total_num_steps):
@@ -797,7 +793,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
Union[
V2BatchSamplerDataCollatorForSeq2Seq,
BatchSamplerDataCollatorForSeq2Seq,
DataCollatorForKD,
DataCollatorForSeq2Seq,
DataCollatorWithFlattening,
RewardDataCollatorWithPadding,
@@ -828,7 +823,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
collator_args.pop(0)
kwargs.pop("pad_to_multiple_of", None)
kwargs.pop("padding", None)
elif self.cfg.trainer == "kd":
elif self.cfg.kd_trainer:
from axolotl.integrations.kd.collator import DataCollatorForKD
collator = DataCollatorForKD
else:
collator = DataCollatorForSeq2Seq

View File

@@ -0,0 +1,22 @@
"""
Plugin init to add KD support to Axolotl.
"""
from axolotl.integrations.base import BasePlugin
from .args import KDArgs # pylint: disable=unused-import. # noqa: F401
class KDPlugin(BasePlugin):
    """
    Axolotl plugin that wires in knowledge-distillation (KD) support.
    """

    def get_input_args(self):
        """Return the dotted path of the model holding the KD config fields."""
        return "axolotl.integrations.kd.KDArgs"

    def get_trainer_cls(self, cfg):
        """
        Pick the trainer class for the given config.

        Returns ``AxolotlKDTrainer`` when ``cfg.kd_trainer`` is truthy and
        ``None`` otherwise.
        """
        if not cfg.kd_trainer:
            return None

        # Imported at call time rather than at module import; presumably so
        # the trainer module is only loaded when KD is enabled -- confirm.
        from .trainer import AxolotlKDTrainer

        return AxolotlKDTrainer

View File

@@ -0,0 +1,19 @@
"""
Plugin args for KD support.
"""
from typing import Optional
from pydantic import BaseModel
class KDArgs(BaseModel):
    """
    Configuration fields for knowledge distillation (KD).

    Every field is optional and defaults to ``None``.
    """

    # whether to use KD trainer
    kd_trainer: Optional[bool] = None
    # loss coefficient for cross-entropy loss during KD
    kd_ce_alpha: Optional[float] = None
    # loss coefficient for KD loss
    kd_alpha: Optional[float] = None
    # temperature for sampling during KD
    kd_temperature: Optional[float] = None

View File

@@ -615,13 +615,6 @@ class AxolotlInputConfig(
bool
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
trainer: Optional[Literal["kd"]] = None
kd_ce_alpha: Optional[
float
] = None # loss coefficient for cross-entropy loss during KD
kd_alpha: Optional[float] = None # loss coefficient for KD loss
kd_temperature: Optional[float] = None # temperature for sampling during KD
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
shuffle_merged_datasets: Optional[bool] = True