diff --git a/docs/config.qmd b/docs/config.qmd
index a7bf9080b..238f7201d 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -183,6 +183,8 @@ test_datasets:
 
 # use RL training: 'dpo', 'ipo', 'kto'
 rl:
+# whether to use weighting when doing DPO training. Boolean.
+dpo_use_weighting:
 
 # The name of the chat template to use for training, following values are supported:
 # - tokenizer_default: Uses the chat template that is available in the tokenizer_config.json. If the chat template is not available in the tokenizer, it will raise an error. This is the default value.
diff --git a/requirements.txt b/requirements.txt
index ec823a82a..735f860a5 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -43,7 +43,7 @@ s3fs>=2024.5.0
 gcsfs>=2024.5.0
 # adlfs
 
-trl @ git+https://github.com/huggingface/trl.git@31d02cfb795284591a084416b9dcb7bef5d08924
+trl==0.12.0
 zstandard==0.22.0
 fastcore
 
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 4fadd7eb4..7b83707b8 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -1890,17 +1890,18 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             # default to saving each epoch if not defined
             training_args_kwargs["save_strategy"] = "epoch"
 
+        training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
+
         if self.cfg.rl_beta:
             training_args_kwargs["beta"] = self.cfg.rl_beta
         if self.cfg.orpo_alpha:
             # trl does some odd mapping of alpha to beta to reuse the beta parameter ???
             training_args_kwargs["beta"] = self.cfg.orpo_alpha
 
-        training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
-
-        training_args_cls = AxolotlDPOConfig
         if self.cfg.rpo_alpha is not None:
             training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
 
+        training_args_cls = None
         if self.cfg.rl == "simpo":
             training_args_cls = AxolotlCPOConfig
             training_args_kwargs["loss_type"] = "simpo"
@@ -1909,13 +1910,13 @@
             if self.cfg.cpo_alpha is not None:
                 training_args_kwargs["cpo_alpha"] = self.cfg.cpo_alpha
 
-        if self.cfg.rl == "orpo":
+        elif self.cfg.rl == "orpo":
             training_args_cls = AxolotlORPOConfig
             training_args_kwargs["max_length"] = self.cfg.sequence_len
             if self.cfg.max_prompt_len:
                 training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
 
-        if self.cfg.rl == "kto":
+        elif self.cfg.rl == "kto":
             training_args_cls = AxolotlKTOConfig
 
             training_args_kwargs["desirable_weight"] = (
@@ -1930,6 +1931,11 @@
             if self.cfg.max_prompt_len:
                 training_args_kwargs["max_prompt_length"] = self.cfg.max_prompt_len
 
+        else:
+            training_args_cls = AxolotlDPOConfig
+            if self.cfg.dpo_use_weighting is not None:
+                training_args_kwargs["use_weighting"] = self.cfg.dpo_use_weighting
+
         training_args = training_args_cls(  # pylint: disable=unexpected-keyword-arg
             output_dir=self.cfg.output_dir,
             per_device_train_batch_size=self.cfg.micro_batch_size,
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 2e5749230..64fec67e0 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -588,6 +588,9 @@ class AxolotlInputConfig(
 
     rl: Optional[RLType] = None
     reward_model: Optional[bool] = None
+    dpo_use_weighting: Optional[
+        bool
+    ] = None  # whether to use weighting in the DPO trainer. If None, the trainer's default (False) is used.
 
     datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None  # type: ignore
     test_datasets: Optional[conlist(Union[SFTDataset, DPODataset, KTODataset], min_length=1)] = None  # type: ignore
diff --git a/tests/e2e/test_dpo.py b/tests/e2e/test_dpo.py
index 1c354e9a0..4a705922f 100644
--- a/tests/e2e/test_dpo.py
+++ b/tests/e2e/test_dpo.py
@@ -115,6 +115,51 @@ class TestDPOLlamaLora(unittest.TestCase):
         train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
         assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
 
+    @with_temp_dir
+    def test_dpo_use_weighting(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 64,
+                "lora_alpha": 32,
+                "lora_dropout": 0.1,
+                "lora_target_linear": True,
+                "special_tokens": {},
+                "rl": "dpo",
+                "dpo_use_weighting": True,
+                "datasets": [
+                    {
+                        "path": "arcee-ai/distilabel-intel-orca-dpo-pairs-binarized",
+                        "type": "chatml.ultra",
+                        "split": "train",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 4,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "paged_adamw_8bit",
+                "lr_scheduler": "cosine",
+                "max_steps": 20,
+                "save_steps": 10,
+                "warmup_steps": 5,
+                "gradient_checkpointing": True,
+                "gradient_checkpointing_kwargs": {"use_reentrant": True},
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_rl_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "checkpoint-20/adapter_model.safetensors").exists()
+
     @pytest.mark.skip("kto_pair no longer supported in trl")
     @with_temp_dir
     def test_kto_pair_lora(self, temp_dir):
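Usage note: the new flag is driven entirely from the axolotl YAML config. Below is a minimal sketch of a DPO run exercising it, assembled from the e2e test above; the base model, dataset, and hyperparameters are the test's values rather than recommendations, and the output path is illustrative.

```yaml
# Minimal DPO config with trl's preference weighting enabled.
base_model: JackFram/llama-68m
sequence_len: 1024

rl: dpo
# Maps to DPOConfig(use_weighting=...) in the trainer builder. When left
# unset (None), the kwarg is not passed and trl's default (False) applies.
dpo_use_weighting: true

datasets:
  - path: arcee-ai/distilabel-intel-orca-dpo-pairs-binarized
    type: chatml.ultra
    split: train

micro_batch_size: 4
gradient_accumulation_steps: 1
num_epochs: 1
learning_rate: 0.00001
output_dir: ./outputs/dpo-weighted  # illustrative path
```

Design note on the builder change: `training_args_cls` now starts as `None` and the simpo/orpo/kto checks form an `if`/`elif` chain, so plain DPO becomes the `else` fallthrough that selects `AxolotlDPOConfig`. Since `use_weighting` is only forwarded in that branch, the flag stays inert for the other RL modes. The trl requirement correspondingly moves from a pinned git commit to the `0.12.0` release, which this `DPOConfig` kwarg is pinned against.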