diff --git a/requirements.txt b/requirements.txt
index 557c5293f..b5114bbf6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -39,6 +39,6 @@ s3fs
 gcsfs
 # adlfs
 
-trl==0.8.6
+trl @ git+https://github.com/huggingface/trl.git@f18253bf2d747f68acc9cd89da95c85ebf59dbb9
 zstandard==0.22.0
 fastcore
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 0551ddbc0..b88cc4221 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -30,7 +30,7 @@ from transformers import (
 )
 from transformers.trainer_utils import seed_worker
 from transformers.utils import is_sagemaker_mp_enabled
-from trl import DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
+from trl import DPOConfig, DPOTrainer, KTOConfig, KTOTrainer, ORPOConfig, ORPOTrainer
 from trl.trainer.utils import pad_to_length
 
 from axolotl.loraplus import create_loraplus_optimizer
@@ -238,6 +238,13 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
     """
 
 
+@dataclass
+class AxolotlDPOConfig(AxolotlTrainingMixins, DPOConfig):
+    """
+    DPO config for DPO training
+    """
+
+
 @dataclass
 class AxolotlORPOConfig(AxolotlTrainingMixins, ORPOConfig):
     """
@@ -1608,7 +1615,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
             training_args_kwargs["beta"] = self.cfg.orpo_alpha
 
-        training_args_cls = AxolotlTrainingArguments
+        training_args_cls = AxolotlDPOConfig
+        if self.cfg.rpo_alpha is not None:
+            training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
         if self.cfg.rl == "orpo":
             training_args_cls = AxolotlORPOConfig
             training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index f363ebfdc..240a816ed 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -619,6 +619,7 @@ class AxolotlInputConfig(
     neftune_noise_alpha: Optional[float] = None
 
     orpo_alpha: Optional[float] = None
+    rpo_alpha: Optional[float] = None
 
     kto_desirable_weight: Optional[float] = None
     kto_undesirable_weight: Optional[float] = None
diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py
index 5d2522bdf..5f03e6bc1 100644
--- a/tests/e2e/test_dpo.py
+++ b/tests/e2e/test_dpo.py
@@ -70,6 +70,51 @@ class TestDPOLlamaLora(unittest.TestCase):
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
 
+    @with_temp_dir
+    def test_dpo_nll_lora(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 64,
+                "lora_alpha": 32,
+                "lora_dropout": 0.1,
+                "lora_target_linear": True,
+                "special_tokens": {},
+                "rl": "dpo",
+                "rpo_alpha": 0.5,
+                "datasets": [
+                    {
+                        "path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
+                        "type": "chatml.ultra",
+                        "split": "train",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "paged_adamw_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "warmup_steps": 5,
+                "gradient_checkpointing": True,
+                "gradient_checkpointing_kwargs": {"use_reentrant": True},
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+
     @with_temp_dir
     def test_kto_pair_lora(self, temp_dir):
         # pylint: disable=duplicate-code