move more things to kd plugin
This commit is contained in:
@@ -47,7 +47,6 @@ from axolotl.core.trainers.base import (
|
|||||||
AxolotlTrainer,
|
AxolotlTrainer,
|
||||||
ReLoRATrainer,
|
ReLoRATrainer,
|
||||||
)
|
)
|
||||||
from axolotl.core.trainers.kd import AxolotlKDTrainer
|
|
||||||
from axolotl.core.training_args import (
|
from axolotl.core.training_args import (
|
||||||
AxolotlCPOConfig,
|
AxolotlCPOConfig,
|
||||||
AxolotlDPOConfig,
|
AxolotlDPOConfig,
|
||||||
@@ -77,7 +76,6 @@ from axolotl.utils.callbacks.profiler import PytorchProfilerCallback
|
|||||||
from axolotl.utils.chat_templates import get_chat_template_from_config
|
from axolotl.utils.chat_templates import get_chat_template_from_config
|
||||||
from axolotl.utils.collators import (
|
from axolotl.utils.collators import (
|
||||||
BatchSamplerDataCollatorForSeq2Seq,
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
DataCollatorForKD,
|
|
||||||
DataCollatorForSeq2Seq,
|
DataCollatorForSeq2Seq,
|
||||||
MambaDataCollator,
|
MambaDataCollator,
|
||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
@@ -306,8 +304,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return AxolotlMambaTrainer
|
return AxolotlMambaTrainer
|
||||||
if self.cfg.reward_model:
|
if self.cfg.reward_model:
|
||||||
return AxolotlRewardTrainer
|
return AxolotlRewardTrainer
|
||||||
if self.cfg.trainer == "kd":
|
|
||||||
return AxolotlKDTrainer
|
|
||||||
return AxolotlTrainer
|
return AxolotlTrainer
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
@@ -797,7 +793,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
Union[
|
Union[
|
||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
BatchSamplerDataCollatorForSeq2Seq,
|
BatchSamplerDataCollatorForSeq2Seq,
|
||||||
DataCollatorForKD,
|
|
||||||
DataCollatorForSeq2Seq,
|
DataCollatorForSeq2Seq,
|
||||||
DataCollatorWithFlattening,
|
DataCollatorWithFlattening,
|
||||||
RewardDataCollatorWithPadding,
|
RewardDataCollatorWithPadding,
|
||||||
@@ -828,7 +823,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
collator_args.pop(0)
|
collator_args.pop(0)
|
||||||
kwargs.pop("pad_to_multiple_of", None)
|
kwargs.pop("pad_to_multiple_of", None)
|
||||||
kwargs.pop("padding", None)
|
kwargs.pop("padding", None)
|
||||||
elif self.cfg.trainer == "kd":
|
elif self.cfg.kd_trainer:
|
||||||
|
from axolotl.integrations.kd.collator import DataCollatorForKD
|
||||||
|
|
||||||
collator = DataCollatorForKD
|
collator = DataCollatorForKD
|
||||||
else:
|
else:
|
||||||
collator = DataCollatorForSeq2Seq
|
collator = DataCollatorForSeq2Seq
|
||||||
|
|||||||
@@ -0,0 +1,22 @@
|
|||||||
|
"""
|
||||||
|
Plugin init to add KD support to Axolotl.
|
||||||
|
"""
|
||||||
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
|
||||||
|
from .args import KDArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
|
class KDPlugin(BasePlugin):
    """
    Axolotl plugin that adds knowledge-distillation (KD) support.
    """

    def get_input_args(self):
        """Return the dotted-path string that names the KD input-args model."""
        return "axolotl.integrations.kd.KDArgs"

    def get_trainer_cls(self, cfg):
        """
        Pick the trainer class for the given config.

        Returns ``AxolotlKDTrainer`` when ``cfg.kd_trainer`` is truthy and
        ``None`` otherwise.
        """
        if not cfg.kd_trainer:
            return None

        # The import stays on this path so the trainer module is only loaded
        # when the KD trainer has actually been requested.
        from .trainer import AxolotlKDTrainer

        return AxolotlKDTrainer
|
||||||
|
|||||||
19
src/axolotl/integrations/kd/args.py
Normal file
19
src/axolotl/integrations/kd/args.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
"""
|
||||||
|
Plugin args for KD support.
|
||||||
|
"""
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class KDArgs(BaseModel):
    """
    Configuration options for knowledge distillation (KD).

    Every field is optional and defaults to ``None``.
    """

    # whether to use KD trainer
    kd_trainer: Optional[bool] = None
    # loss coefficient for cross-entropy loss during KD
    kd_ce_alpha: Optional[float] = None
    # loss coefficient for KD loss
    kd_alpha: Optional[float] = None
    # temperature for sampling during KD
    kd_temperature: Optional[float] = None
|
||||||
@@ -623,13 +623,6 @@ class AxolotlInputConfig(
|
|||||||
bool
|
bool
|
||||||
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
] = None # whether to use weighting in DPO trainer. If none, default is false in the trainer.
|
||||||
|
|
||||||
trainer: Optional[Literal["kd"]] = None
|
|
||||||
kd_ce_alpha: Optional[
|
|
||||||
float
|
|
||||||
] = None # loss coefficient for cross-entropy loss during KD
|
|
||||||
kd_alpha: Optional[float] = None # loss coefficient for KD loss
|
|
||||||
kd_temperature: Optional[float] = None # temperature for sampling during KD
|
|
||||||
|
|
||||||
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
||||||
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None # type: ignore
|
||||||
shuffle_merged_datasets: Optional[bool] = True
|
shuffle_merged_datasets: Optional[bool] = True
|
||||||
|
|||||||
Reference in New Issue
Block a user