collator for grpo and prompt loader

This commit is contained in:
Wing Lian
2025-02-03 00:18:17 -05:00
parent 79159b4871
commit 626db6cb84
3 changed files with 19 additions and 6 deletions

View File

@@ -797,6 +797,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
def build_collator(
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
):
if self.cfg.rl == "grpo":
return GRPOStrategy.get_collator(self.cfg, training_args, **kwargs)
if training_args.pretraining:
if self.cfg.pretraining_sample_concatenation is False:
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)

View File

@@ -50,3 +50,12 @@ class GRPOStrategy:
trainer_kwargs["reward_funcs"] = reward_funcs
return trainer_kwargs
@classmethod
def get_collator(cls, cfg, training_args, **kwargs):  # pylint: disable=unused-argument
    """Build the data collator used for GRPO training.

    ``cfg``, ``training_args`` and any extra keyword arguments are accepted
    but not consulted; the returned collator does not depend on them.

    Returns:
        A callable that takes the list of features and returns it unchanged.
    """

    def identity_collator(features):
        """Return ``features`` as-is; no data collation is needed in GRPO."""
        return features

    return identity_collator

View File

@@ -13,18 +13,19 @@ def load(strategy, cfg, module_base=None, **kwargs):
if len(strategy.split(".")) == 1:
strategy = strategy + ".default"
load_fn = strategy.split(".")[-1]
strategy = ".".join(strategy.split(".")[:-1])
if len(strategy.split(".")) > 1:
try:
importlib.import_module(
"." + strategy.split(".")[-1],
".".join(strategy.split(".")[:-1]),
strategy.split(".")[-2],
".".join(strategy.split(".")[:-2]),
)
module_base = ".".join(strategy.split(".")[:-1])
strategy = strategy.split(".")[-1]
module_base = ".".join(strategy.split(".")[:-2])
strategy = strategy.split(".")[-2]
except ModuleNotFoundError:
pass
mod = importlib.import_module(f".{strategy}", module_base)
else:
strategy = "." + ".".join(strategy.split(".")[:-1])
mod = importlib.import_module(strategy, module_base)
func = getattr(mod, load_fn)
return func(cfg, **kwargs)
except Exception: # pylint: disable=broad-exception-caught