From 317761406ead12154e00fcf5ecf94e2b44e61530 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 6 May 2024 17:01:14 -0400 Subject: [PATCH] add support for NCA --- docs/config.qmd | 2 +- src/axolotl/core/trainer_builder.py | 10 ++++------ .../utils/config/models/input/v0_4_1/__init__.py | 1 + src/axolotl/utils/models.py | 2 +- src/axolotl/utils/trainer.py | 2 +- 5 files changed, 8 insertions(+), 9 deletions(-) diff --git a/docs/config.qmd b/docs/config.qmd index 7cc4a712f..54173cf0c 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -138,7 +138,7 @@ test_datasets: data_files: - /workspace/data/eval.jsonl -# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo_hard' +# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo_hard', 'nca_pair' rl: # Saves the desired chat template to the tokenizer_config.json for easier inferencing diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 0974f6f61..10b14d3da 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1526,7 +1526,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.rl == "orpo": training_args_cls = ORPOConfig training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes - elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]: + elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard", "nca_pair"]: training_args_cls = DPOConfig training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes @@ -1553,10 +1553,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase): dpo_trainer_kwargs["loss_type"] = "ipo" if self.cfg.dpo_label_smoothing: dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing - elif self.cfg.rl == "kto_pair": - dpo_trainer_kwargs["loss_type"] = "kto_pair" - elif self.cfg.rl == "sppo_hard": - dpo_trainer_kwargs["loss_type"] = "sppo_hard" + elif self.cfg.rl in ["kto_pair", "sppo_hard", "nca_pair"]: + dpo_trainer_kwargs["loss_type"] = self.cfg.rl if self.eval_dataset: dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset if self.cfg.adapter and self.peft_config: @@ -1565,7 +1563,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase): dpo_trainer_kwargs[ "precompute_ref_log_probs" ] = self.cfg.precompute_ref_log_probs - if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]: + if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard", "nca_pair"]: trainer_cls = AxolotlDPOTrainer dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1 trainer_cls_args = [self.model, self.model_ref] diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 78a36232c..55c456ad3 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -134,6 +134,7 @@ class RLType(str, Enum): kto_pair = "kto_pair" # pylint: disable=invalid-name orpo = "orpo" # pylint: disable=invalid-name sppo_hard = "sppo_hard" # pylint: disable=invalid-name + nca_pair = "nca_pair" # pylint: disable=invalid-name class ChatTemplate(str, Enum): diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index fc8a67acf..287418a55 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -791,7 +791,7 @@ def load_model( # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config if ( cfg.adapter - and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"] + and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard", "nca_pair"] and not cfg.merge_lora ): _, lora_config = load_lora(model, cfg, inference=False, config_only=True) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 1a0e55010..d819e9c0f 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -438,7 +438,7 @@ def prepare_optim_env(cfg): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): - if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo_hard"]: + if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo_hard", "nca_pair"]: trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer) trainer_builder.model_ref = model[1] trainer_builder.peft_config = model[2]