more fixes to get grpo working
This commit is contained in:
@@ -797,9 +797,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
def build_collator(
|
def build_collator(
|
||||||
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
||||||
):
|
):
|
||||||
if self.cfg.rl == "grpo":
|
|
||||||
return GRPOStrategy.get_collator(self.cfg, training_args, **kwargs)
|
|
||||||
|
|
||||||
if training_args.pretraining:
|
if training_args.pretraining:
|
||||||
if self.cfg.pretraining_sample_concatenation is False:
|
if self.cfg.pretraining_sample_concatenation is False:
|
||||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||||
@@ -982,6 +979,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
training_args_kwargs["rpo_alpha"] = self.cfg.rpo_alpha
|
||||||
|
|
||||||
training_args_cls = None
|
training_args_cls = None
|
||||||
|
blocklist_args_kwargs = []
|
||||||
if self.cfg.rl == "simpo":
|
if self.cfg.rl == "simpo":
|
||||||
training_args_cls = AxolotlCPOConfig
|
training_args_cls = AxolotlCPOConfig
|
||||||
training_args_kwargs["loss_type"] = "simpo"
|
training_args_kwargs["loss_type"] = "simpo"
|
||||||
@@ -1014,14 +1012,15 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_args_cls = GRPOStrategy.get_training_args_class()
|
training_args_cls = GRPOStrategy.get_training_args_class()
|
||||||
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
training_args_kwargs.update(GRPOStrategy.set_training_args_kwargs(self.cfg))
|
||||||
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
|
blocklist_args_kwargs = GRPOStrategy.get_blocklist_args_kwargs()
|
||||||
for blocklist_key in blocklist_args_kwargs:
|
|
||||||
if blocklist_key in training_args_kwargs:
|
|
||||||
del training_args_kwargs[blocklist_key]
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
training_args_cls = DPOConfig.get_training_args_class()
|
training_args_cls = DPOConfig.get_training_args_class()
|
||||||
training_args_kwargs.update(DPOConfig.set_training_args_kwargs(self.cfg))
|
training_args_kwargs.update(DPOConfig.set_training_args_kwargs(self.cfg))
|
||||||
|
|
||||||
|
for blocklist_key in blocklist_args_kwargs:
|
||||||
|
if blocklist_key in training_args_kwargs:
|
||||||
|
del training_args_kwargs[blocklist_key]
|
||||||
|
|
||||||
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
training_args = training_args_cls( # pylint: disable=unexpected-keyword-arg
|
||||||
self.cfg.output_dir,
|
self.cfg.output_dir,
|
||||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||||
@@ -1054,6 +1053,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
] = self.cfg.precompute_ref_log_probs
|
] = self.cfg.precompute_ref_log_probs
|
||||||
if self.cfg.rl == "grpo":
|
if self.cfg.rl == "grpo":
|
||||||
trainer_cls = GRPOStrategy.get_trainer_class()
|
trainer_cls = GRPOStrategy.get_trainer_class()
|
||||||
|
trainer_cls_args = [self.model]
|
||||||
dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
dpo_trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||||
elif self.cfg.rl in ["dpo", "ipo"]:
|
elif self.cfg.rl in ["dpo", "ipo"]:
|
||||||
trainer_cls = DPOStrategy.get_trainer_class()
|
trainer_cls = DPOStrategy.get_trainer_class()
|
||||||
|
|||||||
@@ -44,19 +44,18 @@ class GRPOStrategy:
|
|||||||
for reward_func_module in cfg.grpo_reward_funcs:
|
for reward_func_module in cfg.grpo_reward_funcs:
|
||||||
# use importlib to dynamically load the reward function from the module
|
# use importlib to dynamically load the reward function from the module
|
||||||
reward_func_module_name = reward_func_module.split(".")[-1]
|
reward_func_module_name = reward_func_module.split(".")[-1]
|
||||||
reward_func_module = importlib.import_module(reward_func_module)
|
reward_func_module = importlib.import_module(
|
||||||
|
reward_func_module.split(".")[-2]
|
||||||
|
)
|
||||||
reward_func = getattr(reward_func_module, reward_func_module_name)
|
reward_func = getattr(reward_func_module, reward_func_module_name)
|
||||||
reward_funcs.append(reward_func)
|
reward_funcs.append(reward_func)
|
||||||
trainer_kwargs["reward_funcs"] = reward_funcs
|
trainer_kwargs["reward_funcs"] = reward_funcs
|
||||||
trainer_kwargs["data_collator"] = cls.get_collator(cfg)
|
|
||||||
return trainer_kwargs
|
return trainer_kwargs
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_collator(cls, *args, **kwargs): # pylint: disable=unused-argument
|
def get_collator(cls, *args, **kwargs): # pylint: disable=unused-argument
|
||||||
def data_collator(features): # No data collation is needed in GRPO
|
# No data collation is needed in GRPO, handled by trl's trainer __init__
|
||||||
return features
|
return None
|
||||||
|
|
||||||
return data_collator
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_blocklist_args_kwargs(cls):
|
def get_blocklist_args_kwargs(cls):
|
||||||
|
|||||||
Reference in New Issue
Block a user