collator for grpo and prompt loader
This commit is contained in:
@@ -797,6 +797,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
def build_collator(
|
def build_collator(
|
||||||
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
self, training_args: AxolotlTrainingArguments, is_eval=False, **kwargs
|
||||||
):
|
):
|
||||||
|
if self.cfg.rl == "grpo":
|
||||||
|
return GRPOStrategy.get_collator(self.cfg, training_args, **kwargs)
|
||||||
|
|
||||||
if training_args.pretraining:
|
if training_args.pretraining:
|
||||||
if self.cfg.pretraining_sample_concatenation is False:
|
if self.cfg.pretraining_sample_concatenation is False:
|
||||||
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
return DataCollatorForSeq2Seq(self.tokenizer, **kwargs)
|
||||||
|
|||||||
@@ -50,3 +50,12 @@ class GRPOStrategy:
|
|||||||
trainer_kwargs["reward_funcs"] = reward_funcs
|
trainer_kwargs["reward_funcs"] = reward_funcs
|
||||||
|
|
||||||
return trainer_kwargs
|
return trainer_kwargs
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_collator(
|
||||||
|
cls, cfg, training_args, **kwargs
|
||||||
|
): # pylint: disable=unused-argument
|
||||||
|
def data_collator(features): # No data collation is needed in GRPO
|
||||||
|
return features
|
||||||
|
|
||||||
|
return data_collator
|
||||||
|
|||||||
@@ -13,18 +13,19 @@ def load(strategy, cfg, module_base=None, **kwargs):
|
|||||||
if len(strategy.split(".")) == 1:
|
if len(strategy.split(".")) == 1:
|
||||||
strategy = strategy + ".default"
|
strategy = strategy + ".default"
|
||||||
load_fn = strategy.split(".")[-1]
|
load_fn = strategy.split(".")[-1]
|
||||||
strategy = ".".join(strategy.split(".")[:-1])
|
|
||||||
if len(strategy.split(".")) > 1:
|
if len(strategy.split(".")) > 1:
|
||||||
try:
|
try:
|
||||||
importlib.import_module(
|
importlib.import_module(
|
||||||
"." + strategy.split(".")[-1],
|
strategy.split(".")[-2],
|
||||||
".".join(strategy.split(".")[:-1]),
|
".".join(strategy.split(".")[:-2]),
|
||||||
)
|
)
|
||||||
module_base = ".".join(strategy.split(".")[:-1])
|
module_base = ".".join(strategy.split(".")[:-2])
|
||||||
strategy = strategy.split(".")[-1]
|
strategy = strategy.split(".")[-2]
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
pass
|
pass
|
||||||
mod = importlib.import_module(f".{strategy}", module_base)
|
else:
|
||||||
|
strategy = "." + ".".join(strategy.split(".")[:-1])
|
||||||
|
mod = importlib.import_module(strategy, module_base)
|
||||||
func = getattr(mod, load_fn)
|
func = getattr(mod, load_fn)
|
||||||
return func(cfg, **kwargs)
|
return func(cfg, **kwargs)
|
||||||
except Exception: # pylint: disable=broad-exception-caught
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
|
|||||||
Reference in New Issue
Block a user