chore: refactor if condition
This commit is contained in:
@@ -183,9 +183,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||||
if self.eval_dataset:
|
if self.eval_dataset:
|
||||||
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||||
if self.cfg.adapter and self.peft_config:
|
if self.cfg.adapter and self.peft_config and self.cfg.rl is not RLType.GRPO:
|
||||||
if self.cfg.rl is not RLType.GRPO:
|
trainer_kwargs["peft_config"] = self.peft_config
|
||||||
trainer_kwargs["peft_config"] = self.peft_config
|
|
||||||
if self.cfg.precompute_ref_log_probs is not None:
|
if self.cfg.precompute_ref_log_probs is not None:
|
||||||
trainer_kwargs["precompute_ref_log_probs"] = (
|
trainer_kwargs["precompute_ref_log_probs"] = (
|
||||||
self.cfg.precompute_ref_log_probs
|
self.cfg.precompute_ref_log_probs
|
||||||
@@ -243,5 +242,5 @@ class HFPPOTrainerBuilder(TrainerBuilderBase):
|
|||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
# build PPOConfig
|
# TODO: build PPOConfig
|
||||||
raise NotImplementedError("PPO trainer builder is not implemented yet.")
|
raise NotImplementedError("PPO trainer builder is not implemented yet.")
|
||||||
|
|||||||
Reference in New Issue
Block a user