diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 7eadd3e59..e78b16e65 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -47,7 +47,6 @@ from axolotl.core.trainers.base import ( AxolotlTrainer, ReLoRATrainer, ) -from axolotl.core.trainers.kd import AxolotlKDTrainer from axolotl.core.training_args import ( AxolotlCPOConfig, AxolotlDPOConfig, @@ -77,7 +76,6 @@ from axolotl.utils.callbacks.profiler import PytorchProfilerCallback from axolotl.utils.chat_templates import get_chat_template_from_config from axolotl.utils.collators import ( BatchSamplerDataCollatorForSeq2Seq, - DataCollatorForKD, DataCollatorForSeq2Seq, MambaDataCollator, V2BatchSamplerDataCollatorForSeq2Seq, @@ -306,8 +304,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): return AxolotlMambaTrainer if self.cfg.reward_model: return AxolotlRewardTrainer - if self.cfg.trainer == "kd": - return AxolotlKDTrainer return AxolotlTrainer def build(self, total_num_steps): @@ -797,7 +793,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): Union[ V2BatchSamplerDataCollatorForSeq2Seq, BatchSamplerDataCollatorForSeq2Seq, - DataCollatorForKD, DataCollatorForSeq2Seq, DataCollatorWithFlattening, RewardDataCollatorWithPadding, @@ -828,7 +823,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): collator_args.pop(0) kwargs.pop("pad_to_multiple_of", None) kwargs.pop("padding", None) - elif self.cfg.trainer == "kd": + elif self.cfg.kd_trainer: + from axolotl.integrations.kd.collator import DataCollatorForKD + collator = DataCollatorForKD else: collator = DataCollatorForSeq2Seq diff --git a/src/axolotl/integrations/kd/__init__.py b/src/axolotl/integrations/kd/__init__.py index e69de29bb..a3f7c6036 100644 --- a/src/axolotl/integrations/kd/__init__.py +++ b/src/axolotl/integrations/kd/__init__.py @@ -0,0 +1,22 @@ +""" +Plugin init to add KD support to Axolotl. 
+"""
+from axolotl.integrations.base import BasePlugin
+
+from .args import KDArgs  # pylint: disable=unused-import. # noqa: F401
+
+
+class KDPlugin(BasePlugin):
+    """
+    Axolotl plugin for knowledge distillation (KD): names the KD args and KD trainer.
+    """
+
+    def get_input_args(self):  # dotted path to the KDArgs model imported above
+        return "axolotl.integrations.kd.KDArgs"
+
+    def get_trainer_cls(self, cfg):  # KD trainer class if cfg.kd_trainer, else None
+        if cfg.kd_trainer:
+            from .trainer import AxolotlKDTrainer  # lazy: only when KD is enabled
+
+            return AxolotlKDTrainer
+        return None
diff --git a/src/axolotl/integrations/kd/args.py b/src/axolotl/integrations/kd/args.py
new file mode 100644
index 000000000..78ecacab6
--- /dev/null
+++ b/src/axolotl/integrations/kd/args.py
@@ -0,0 +1,19 @@
+"""
+Pydantic model declaring the config fields for knowledge distillation (KD).
+"""
+from typing import Optional
+
+from pydantic import BaseModel
+
+
+class KDArgs(BaseModel):
+    """
+    Input args for knowledge distillation (KD); all fields default to None.
+    """
+
+    kd_trainer: Optional[bool] = None  # enables AxolotlKDTrainer when truthy
+    kd_ce_alpha: Optional[
+        float
+    ] = None  # loss coefficient for cross-entropy loss during KD
+    kd_alpha: Optional[float] = None  # loss coefficient for KD loss
+    kd_temperature: Optional[float] = None  # temperature for sampling during KD
diff --git a/src/axolotl/utils/collators/kd.py b/src/axolotl/integrations/kd/collator.py
similarity index 100%
rename from src/axolotl/utils/collators/kd.py
rename to src/axolotl/integrations/kd/collator.py
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 3a3923565..90d8a8bc0 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -615,13 +615,6 @@ class AxolotlInputConfig(
         bool
     ] = None  # whether to use weighting in DPO trainer. If none, default is false in the trainer.
 
- trainer: Optional[Literal["kd"]] = None - kd_ce_alpha: Optional[ - float - ] = None # loss coefficient for cross-entropy loss during KD - kd_alpha: Optional[float] = None # loss coefficient for KD loss - kd_temperature: Optional[float] = None # temperature for sampling during KD - datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore shuffle_merged_datasets: Optional[bool] = True