collator for grpo and prompt loader

This commit is contained in:
Wing Lian
2025-02-03 00:18:17 -05:00
parent 79159b4871
commit 626db6cb84
3 changed files with 19 additions and 6 deletions

View File

@@ -797,6 +797,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
def build_collator( def build_collator(
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
): ):
if self.cfg.rl == "grpo":
return GRPOStrategy.get_collator(self.cfg, training_args, **kwargs)
if training_args.pretraining: if training_args.pretraining:
if self.cfg.pretraining_sample_concatenation is False: if self.cfg.pretraining_sample_concatenation is False:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs) return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)

View File

@@ -50,3 +50,12 @@ class GRPOStrategy:
trainer_kwargs["reward_funcs"] = reward_funcs trainer_kwargs["reward_funcs"] = reward_funcs
return trainer_kwargs return trainer_kwargs
@classmethod
def get_collator(cls, cfg, training_args, **kwargs):  # pylint: disable=unused-argument
    """Build the data collator used for GRPO runs.

    GRPO needs no collation, so the returned callable hands the batch of
    features back untouched. ``cfg``, ``training_args`` and any extra
    keyword arguments are accepted but not used.

    Returns:
        A callable that takes the features and returns that same object.
    """

    def passthrough_collator(batch):
        # No data collation is needed in GRPO: return the input as-is.
        return batch

    return passthrough_collator

View File

@@ -13,18 +13,19 @@ def load(strategy, cfg, module_base=None, **kwargs):
if len(strategy.split(".")) == 1: if len(strategy.split(".")) == 1:
strategy = strategy + ".default" strategy = strategy + ".default"
load_fn = strategy.split(".")[-1] load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
if len(strategy.split(".")) > 1: if len(strategy.split(".")) > 1:
try: try:
importlib.import_module( importlib.import_module(
"." + strategy.split(".")[-1], strategy.split(".")[-2],
".".join(strategy.split(".")[:-1]), ".".join(strategy.split(".")[:-2]),
) )
module_base = ".".join(strategy.split(".")[:-1]) module_base = ".".join(strategy.split(".")[:-2])
strategy = strategy.split(".")[-1] strategy = strategy.split(".")[-2]
except ModuleNotFoundError: except ModuleNotFoundError:
pass pass
mod = importlib.import_module(f".{strategy}", module_base) else:
strategy = "." + ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(strategy, module_base)
func = getattr(mod, load_fn) func = getattr(mod, load_fn)
return func(cfg, **kwargs) return func(cfg, **kwargs)
except Exception: # pylint: disable=broad-exception-caught except Exception: # pylint: disable=broad-exception-caught