paired kto support (#1069)
This commit is contained in:
@@ -595,6 +595,9 @@ datasets:
|
|||||||
# For `completion` datasets only, uses the provided field instead of `text` column
|
# For `completion` datasets only, uses the provided field instead of `text` column
|
||||||
field:
|
field:
|
||||||
|
|
||||||
|
# Use RL training; supported values: dpo, ipo, kto_pair
|
||||||
|
rl:
|
||||||
|
|
||||||
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
||||||
# Currently supports chatml and inst (mistral/mixtral)
|
# Currently supports chatml and inst (mistral/mixtral)
|
||||||
chat_template: chatml
|
chat_template: chatml
|
||||||
|
|||||||
@@ -40,4 +40,4 @@ s3fs
|
|||||||
gcsfs
|
gcsfs
|
||||||
# adlfs
|
# adlfs
|
||||||
|
|
||||||
trl @ git+https://github.com/huggingface/trl.git@main
|
trl>=0.7.9
|
||||||
|
|||||||
@@ -927,6 +927,8 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs["loss_type"] = "ipo"
|
dpo_trainer_kwargs["loss_type"] = "ipo"
|
||||||
if self.cfg.dpo_label_smoothing:
|
if self.cfg.dpo_label_smoothing:
|
||||||
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||||
|
elif self.cfg.rl == "kto_pair":
|
||||||
|
dpo_trainer_kwargs["loss_type"] = "kto_pair"
|
||||||
|
|
||||||
dpo_trainer = DPOTrainer(
|
dpo_trainer = DPOTrainer(
|
||||||
self.model,
|
self.model,
|
||||||
|
|||||||
Reference in New Issue
Block a user