Update for SPPO: rename the `sppo` RL type to `sppo_hard` and bump the pinned trl revision
This commit is contained in:
@@ -138,7 +138,7 @@ test_datasets:
|
|||||||
data_files:
|
data_files:
|
||||||
- /workspace/data/eval.jsonl
|
- /workspace/data/eval.jsonl
|
||||||
|
|
||||||
# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo'
|
# use RL training: 'dpo', 'ipo', 'kto_pair', 'orpo', 'sppo_hard'
|
||||||
rl:
|
rl:
|
||||||
|
|
||||||
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
# Saves the desired chat template to the tokenizer_config.json for easier inferencing
|
||||||
|
|||||||
@@ -39,6 +39,6 @@ s3fs
|
|||||||
gcsfs
|
gcsfs
|
||||||
# adlfs
|
# adlfs
|
||||||
|
|
||||||
trl @ git+https://github.com/huggingface/trl.git@7075cec94df1a0c5be90e75214e996efaf9a6c0b
|
trl @ git+https://github.com/huggingface/trl.git@75de236c09bd5846f79c24d9bf371481b0b7582c
|
||||||
zstandard==0.22.0
|
zstandard==0.22.0
|
||||||
fastcore
|
fastcore
|
||||||
|
|||||||
@@ -1526,7 +1526,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
if self.cfg.rl == "orpo":
|
if self.cfg.rl == "orpo":
|
||||||
training_args_cls = ORPOConfig
|
training_args_cls = ORPOConfig
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"]:
|
elif self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]:
|
||||||
training_args_cls = DPOConfig
|
training_args_cls = DPOConfig
|
||||||
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
training_args_kwargs["dataset_num_proc"] = self.cfg.dataset_processes
|
||||||
|
|
||||||
@@ -1555,8 +1555,8 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
dpo_trainer_kwargs["label_smoothing"] = self.cfg.dpo_label_smoothing
|
||||||
elif self.cfg.rl == "kto_pair":
|
elif self.cfg.rl == "kto_pair":
|
||||||
dpo_trainer_kwargs["loss_type"] = "kto_pair"
|
dpo_trainer_kwargs["loss_type"] = "kto_pair"
|
||||||
elif self.cfg.rl == "sppo":
|
elif self.cfg.rl == "sppo_hard":
|
||||||
dpo_trainer_kwargs["loss_type"] = "sppo"
|
dpo_trainer_kwargs["loss_type"] = "sppo_hard"
|
||||||
if self.eval_dataset:
|
if self.eval_dataset:
|
||||||
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
|
dpo_trainer_kwargs["eval_dataset"] = self.eval_dataset
|
||||||
if self.cfg.adapter and self.peft_config:
|
if self.cfg.adapter and self.peft_config:
|
||||||
@@ -1565,7 +1565,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
dpo_trainer_kwargs[
|
dpo_trainer_kwargs[
|
||||||
"precompute_ref_log_probs"
|
"precompute_ref_log_probs"
|
||||||
] = self.cfg.precompute_ref_log_probs
|
] = self.cfg.precompute_ref_log_probs
|
||||||
if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"]:
|
if self.cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]:
|
||||||
trainer_cls = AxolotlDPOTrainer
|
trainer_cls = AxolotlDPOTrainer
|
||||||
dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
|
dpo_trainer_kwargs["beta"] = self.cfg.dpo_beta or 0.1
|
||||||
trainer_cls_args = [self.model, self.model_ref]
|
trainer_cls_args = [self.model, self.model_ref]
|
||||||
|
|||||||
@@ -133,7 +133,7 @@ class RLType(str, Enum):
|
|||||||
ipo = "ipo" # pylint: disable=invalid-name
|
ipo = "ipo" # pylint: disable=invalid-name
|
||||||
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
kto_pair = "kto_pair" # pylint: disable=invalid-name
|
||||||
orpo = "orpo" # pylint: disable=invalid-name
|
orpo = "orpo" # pylint: disable=invalid-name
|
||||||
sppo = "sppo" # pylint: disable=invalid-name
|
sppo = "sppo_hard" # pylint: disable=invalid-name
|
||||||
|
|
||||||
|
|
||||||
class ChatTemplate(str, Enum):
|
class ChatTemplate(str, Enum):
|
||||||
|
|||||||
@@ -791,7 +791,7 @@ def load_model(
|
|||||||
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
|
# then the dpo trainer doesn't want the peft model loaded over it, it just wants the lora/peft config
|
||||||
if (
|
if (
|
||||||
cfg.adapter
|
cfg.adapter
|
||||||
and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo"]
|
and cfg.rl in ["dpo", "ipo", "kto_pair", "sppo_hard"]
|
||||||
and not cfg.merge_lora
|
and not cfg.merge_lora
|
||||||
):
|
):
|
||||||
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
|
_, lora_config = load_lora(model, cfg, inference=False, config_only=True)
|
||||||
|
|||||||
@@ -438,7 +438,7 @@ def prepare_optim_env(cfg):
|
|||||||
|
|
||||||
|
|
||||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
|
||||||
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo"]:
|
if cfg.rl in ["dpo", "ipo", "kto_pair", "orpo", "sppo_hard"]:
|
||||||
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
trainer_builder = HFRLTrainerBuilder(cfg, model[0], tokenizer)
|
||||||
trainer_builder.model_ref = model[1]
|
trainer_builder.model_ref = model[1]
|
||||||
trainer_builder.peft_config = model[2]
|
trainer_builder.peft_config = model[2]
|
||||||
|
|||||||
Reference in New Issue
Block a user