diff --git a/src/axolotl/core/trainer_builder/rl.py b/src/axolotl/core/trainer_builder/rl.py
index cc5a9cf86..3406b0f00 100644
--- a/src/axolotl/core/trainer_builder/rl.py
+++ b/src/axolotl/core/trainer_builder/rl.py
@@ -183,9 +183,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
             trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
         if self.eval_dataset:
             trainer_kwargs["eval_dataset"] = self.eval_dataset
-        if self.cfg.adapter and self.peft_config:
-            if self.cfg.rl is not RLType.GRPO:
-                trainer_kwargs["peft_config"] = self.peft_config
+        if self.cfg.adapter and self.peft_config and self.cfg.rl is not RLType.GRPO:
+            trainer_kwargs["peft_config"] = self.peft_config
         if self.cfg.precompute_ref_log_probs is not None:
             trainer_kwargs["precompute_ref_log_probs"] = (
                 self.cfg.precompute_ref_log_probs
@@ -243,5 +242,5 @@ class HFPPOTrainerBuilder(TrainerBuilderBase):
         return callbacks
 
     def build(self, total_num_steps):
-        # build PPOConfig
+        # TODO: build PPOConfig
         raise NotImplementedError("PPO trainer builder is not implemented yet.")