Update for SPPO: rename the `sppo` RL type value to `sppo_hard` and bump the pinned trl commit

This commit is contained in:
Wing Lian
2024-05-03 08:41:59 -04:00
parent 0554105baa
commit 027f7d54f0
6 changed files with 9 additions and 9 deletions

View File

@@ -138,7 +138,7 @@ test_datasets:
data_files: data_files:
- /workspace/data/eval.jsonl - /workspace/data/eval.jsonl
# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo' # use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo_hard'
rl: rl:
# Saves the desired chat template to the tokenizer_config.json for easier inferencing # Saves the desired chat template to the tokenizer_config.json for easier inferencing

View File

@@ -39,6 +39,6 @@ s3fs
gcsfs gcsfs
# adlfs # adlfs
trl @ git+https://github.com/huggingface/trl.git@7075cec94df1a0c5be90e75214e996efaf9a6c0b trl @ git+https://github.com/huggingface/trl.git@75de236c09bd5846f79c24d9bf371481b0b7582c
zstandard==0.22.0 zstandard==0.22.0
fastcore fastcore

View File

@@ -1526,7 +1526,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.rl == "orpo": if self.cfg.rl == "orpo":
training_args_cls = ORPOConfig training_args_cls = ORPOConfig
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"]: elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]:
training_args_cls = DPOConfig training_args_cls = DPOConfig
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
@@ -1555,8 +1555,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
elif self.cfg.rl == "kto_pair": elif self.cfg.rl == "kto_pair":
dpo_trainer_kwargs["loss_type"] = "kto_pair" dpo_trainer_kwargs["loss_type"] = "kto_pair"
elif self.cfg.rl == "sppo": elif self.cfg.rl == "sppo_hard":
dpo_trainer_kwargs["loss_type"] = "sppo" dpo_trainer_kwargs["loss_type"] = "sppo_hard"
if self.eval_dataset: if self.eval_dataset:
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
if self.cfg.adapter and self.peft_config: if self.cfg.adapter and self.peft_config:
@@ -1565,7 +1565,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
dpo_trainer_kwargs[ dpo_trainer_kwargs[
"precompute_ref_log_probs" "precompute_ref_log_probs"
] = self.cfg.precompute_ref_log_probs ] = self.cfg.precompute_ref_log_probs
if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"]: if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]:
trainer_cls = AxolotlDPOTrainer trainer_cls = AxolotlDPOTrainer
dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1 dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
trainer_cls_args = [self.model, self.model_ref] trainer_cls_args = [self.model, self.model_ref]

View File

@@ -133,7 +133,7 @@ class RLType(str, Enum):
ipo = "ipo" # pylint: disable=invalid-name ipo = "ipo" # pylint: disable=invalid-name
kto_pair = "kto_pair" # pylint: disable=invalid-name kto_pair = "kto_pair" # pylint: disable=invalid-name
orpo = "orpo" # pylint: disable=invalid-name orpo = "orpo" # pylint: disable=invalid-name
sppo = "sppo" # pylint: disable=invalid-name sppo = "sppo_hard" # pylint: disable=invalid-name
class ChatTemplate(str, Enum): class ChatTemplate(str, Enum):

View File

@@ -791,7 +791,7 @@ def load_model(
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config # then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
if ( if (
cfg.adapter cfg.adapter
and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"] and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]
and not cfg.merge_lora and not cfg.merge_lora
): ):
_, lora_config = load_lora(model, cfg, inference=False, config_only=True) _, lora_config = load_lora(model, cfg, inference=False, config_only=True)

View File

@@ -438,7 +438,7 @@ def prepare_optim_env(cfg):
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps): def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo"]: if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo_hard"]:
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer) trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
trainer_builder.model_ref = model[1] trainer_builder.model_ref = model[1]
trainer_builder.peft_config = model[2] trainer_builder.peft_config = model[2]