From 33e117088f0299ab80414ed972bceaddf3629308 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Thu, 25 Jan 2024 09:31:55 -0500
Subject: [PATCH] precompute dpo logprobs setting and fixes (#1199) [skip ci]

* add support for precompute_ref_log_probs for dpo
* add chatml.icr type for argilla orca dpo
* update inline doc
* also set use_reentrant to false for dpo when not set
* don't set use_reentrant to true for rl
* make sure to set gradient checkpointing too
---
 src/axolotl/core/trainer_builder.py         | 21 +++++++++++++----
 src/axolotl/prompt_strategies/dpo/chatml.py | 25 +++++++++++++++++++++
 src/axolotl/utils/config.py                 |  1 +
 3 files changed, 43 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 15ad71470..7d47d0e49 100644
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -651,7 +651,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs[
                 "gradient_checkpointing"
             ] = self.cfg.gradient_checkpointing
-            if self.cfg.gradient_checkpointing_kwargs:
+            if self.cfg.gradient_checkpointing_kwargs is not None:
                 training_arguments_kwargs[
                     "gradient_checkpointing_kwargs"
                 ] = self.cfg.gradient_checkpointing_kwargs
@@ -1028,6 +1028,18 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
             training_args_kwargs[
                 "dataloader_prefetch_factor"
             ] = self.cfg.dataloader_prefetch_factor
+        if self.cfg.gradient_checkpointing:
+            training_args_kwargs[
+                "gradient_checkpointing"
+            ] = self.cfg.gradient_checkpointing
+            if self.cfg.gradient_checkpointing_kwargs is not None:
+                training_args_kwargs[
+                    "gradient_checkpointing_kwargs"
+                ] = self.cfg.gradient_checkpointing_kwargs
+            else:
+                training_args_kwargs["gradient_checkpointing_kwargs"] = {
+                    "use_reentrant": False
+                }
 
         training_args = TrainingArguments(
             per_device_train_batch_size=self.cfg.micro_batch_size,
@@ -1038,9 +1050,6 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
             save_steps=self.cfg.save_steps,
             output_dir=self.cfg.output_dir,
             warmup_steps=self.cfg.warmup_steps,
-            gradient_checkpointing=self.cfg.gradient_checkpointing,
-            gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs
-            or {"use_reentrant": False},
             logging_first_step=True,
             logging_steps=1,
             optim=self.cfg.optimizer,
@@ -1063,6 +1072,10 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
             dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
         if self.cfg.adapter and self.peft_config:
             dpo_trainer_kwargs["peft_config"] = self.peft_config
+        if self.cfg.precompute_ref_log_probs is not None:
+            dpo_trainer_kwargs[
+                "precompute_ref_log_probs"
+            ] = self.cfg.precompute_ref_log_probs
         dpo_trainer = DPOTrainer(
             self.model,
             self.model_ref,
diff --git a/src/axolotl/prompt_strategies/dpo/chatml.py b/src/axolotl/prompt_strategies/dpo/chatml.py
index e0840f762..8f62a5088 100644
--- a/src/axolotl/prompt_strategies/dpo/chatml.py
+++ b/src/axolotl/prompt_strategies/dpo/chatml.py
@@ -23,6 +23,31 @@ def argilla(
     return transform_fn
 
 
+def icr(
+    cfg,
+):  # pylint: disable=possibly-unused-variable,unused-argument
+    """
+    chatml transforms for datasets with system, input, chosen, rejected
+    ex. https://huggingface.co/datasets/argilla/distilabel-intel-orca-dpo-pairs
+    """
+
+    def transform_fn(sample):
+        if "system" in sample and sample["system"]:
+            sample["prompt"] = (
+                f"<|im_start|>system\n{sample['system']}<|im_end|>\n"
+                f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
+            )
+        else:
+            sample[
+                "prompt"
+            ] = f"<|im_start|>user\n{sample['input']}<|im_end|>\n<|im_start|>assistant\n"
+        sample["chosen"] = f"{sample['chosen']}<|im_end|>"
+        sample["rejected"] = f"{sample['rejected']}<|im_end|>"
+        return sample
+
+    return transform_fn
+
+
 def intel(cfg):  # pylint: disable=possibly-unused-variable,unused-argument
     """
     For Intel Orca DPO Pairs
diff --git a/src/axolotl/utils/config.py b/src/axolotl/utils/config.py
index 59cbef15d..dcd795099 100644
--- a/src/axolotl/utils/config.py
+++ b/src/axolotl/utils/config.py
@@ -163,6 +163,7 @@ def normalize_config(cfg):
         cfg.gradient_checkpointing
         and cfg.unfrozen_parameters is None
         and cfg.gradient_checkpointing_kwargs is None
+        and cfg.rl is None
     ):
         cfg.gradient_checkpointing_kwargs = {"use_reentrant": True}
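A quick way to sanity-check the new chatml.icr strategy is to run the transform it returns over a toy row shaped like the linked argilla/distilabel-intel-orca-dpo-pairs dataset (system, input, chosen, rejected fields). A minimal sketch, not part of the patch: the sample values below are invented, and only the icr import comes from the code above.

    from axolotl.prompt_strategies.dpo.chatml import icr

    # cfg is accepted but unused by this strategy
    transform_fn = icr(cfg=None)

    sample = {
        "system": "You are a helpful assistant.",
        "input": "What is 2 + 2?",
        "chosen": "4",
        "rejected": "5",
    }
    out = transform_fn(sample)

    # Expected, per the f-strings added in this patch:
    # out["prompt"]   == "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
    #                    "<|im_start|>user\nWhat is 2 + 2?<|im_end|>\n<|im_start|>assistant\n"
    # out["chosen"]   == "4<|im_end|>"
    # out["rejected"] == "5<|im_end|>"

In an axolotl YAML config this strategy would presumably be selected with type: chatml.icr on an rl: dpo dataset (per the commit message), alongside the new precompute_ref_log_probs setting that this patch forwards to DPOTrainer.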