collator for grpo and prompt loader
This commit is contained in:
@@ -797,6 +797,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
def build_collator(
|
||||
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
||||
):
|
||||
if self.cfg.rl == "grpo":
|
||||
return GRPOStrategy.get_collator(self.cfg, training_args, **kwargs)
|
||||
|
||||
if training_args.pretraining:
|
||||
if self.cfg.pretraining_sample_concatenation is False:
|
||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||
|
||||
@@ -50,3 +50,12 @@ class GRPOStrategy:
|
||||
trainer_kwargs["reward_funcs"] = reward_funcs
|
||||
|
||||
return trainer_kwargs
|
||||
|
||||
@classmethod
def get_collator(
    cls, cfg, training_args, **kwargs
):  # pylint: disable=unused-argument
    """Build the data collator used when training with GRPO.

    The returned collator is a pass-through: it hands back the list of
    features it receives without batching, padding or copying them.

    Args:
        cfg: Training configuration. Accepted for interface parity with
            other collator factories; not read here.
        training_args: Trainer arguments. Not read here.
        **kwargs: Extra collator options. Accepted and ignored.

    Returns:
        A callable taking ``features`` and returning the same object.
    """

    def data_collator(features):  # No data collation is needed in GRPO
        return features

    return data_collator
|
||||
|
||||
@@ -13,18 +13,19 @@ def load(strategy, cfg, module_base=None, **kwargs):
|
||||
if len(strategy.split(".")) == 1:
|
||||
strategy = strategy + ".default"
|
||||
load_fn = strategy.split(".")[-1]
|
||||
strategy = ".".join(strategy.split(".")[:-1])
|
||||
if len(strategy.split(".")) > 1:
|
||||
try:
|
||||
importlib.import_module(
|
||||
"." + strategy.split(".")[-1],
|
||||
".".join(strategy.split(".")[:-1]),
|
||||
strategy.split(".")[-2],
|
||||
".".join(strategy.split(".")[:-2]),
|
||||
)
|
||||
module_base = ".".join(strategy.split(".")[:-1])
|
||||
strategy = strategy.split(".")[-1]
|
||||
module_base = ".".join(strategy.split(".")[:-2])
|
||||
strategy = strategy.split(".")[-2]
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
mod = importlib.import_module(f".{strategy}", module_base)
|
||||
else:
|
||||
strategy = "." + ".".join(strategy.split(".")[:-1])
|
||||
mod = importlib.import_module(strategy, module_base)
|
||||
func = getattr(mod, load_fn)
|
||||
return func(cfg, **kwargs)
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
|
||||
Reference in New Issue
Block a user