Compare commits

...

1 Commits

Author SHA1 Message Date
NanoCode012
d47093fcdd fix: simplify fn same as sft and pass model to plugin 2025-07-08 22:29:56 +07:00

View File

@@ -6,6 +6,7 @@ from pathlib import Path
from axolotl.core.builders.base import TrainerBuilderBase from axolotl.core.builders.base import TrainerBuilderBase
from axolotl.core.trainers import ( from axolotl.core.trainers import (
AxolotlCPOTrainer, AxolotlCPOTrainer,
AxolotlDPOTrainer,
AxolotlKTOTrainer, AxolotlKTOTrainer,
AxolotlORPOTrainer, AxolotlORPOTrainer,
) )
@@ -36,33 +37,23 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer) callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
return callbacks return callbacks
def _get_trainer_cls(self, trainer_kwargs: dict): def _get_trainer_cls(self):
""" """Returns trainer_cls"""
Returns trainer_cls and trainer_cls_args
"""
if self.cfg.plugins: if self.cfg.plugins:
plugin_manager = PluginManager.get_instance() plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg) trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
trainer_cls_args = [] # type: ignore
if trainer_cls is not None: if trainer_cls is not None:
return trainer_cls, trainer_cls_args return trainer_cls
trainer_cls = None trainer_cls = None
trainer_cls_args = [self.model]
if self.cfg.rl is RLType.GRPO: if self.cfg.rl is RLType.GRPO:
trainer_cls = GRPOStrategy.get_trainer_class( trainer_cls = GRPOStrategy.get_trainer_class(
sequence_parallel=self.cfg.sequence_parallel_degree > 1 sequence_parallel=self.cfg.sequence_parallel_degree > 1
) )
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
elif self.cfg.rl in [RLType.DPO, RLType.IPO]: elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
trainer_cls = DPOStrategy.get_trainer_class() trainer_cls = AxolotlDPOTrainer
trainer_cls_args.append(self.model_ref)
elif self.cfg.rl is RLType.ORPO: elif self.cfg.rl is RLType.ORPO:
trainer_cls = AxolotlORPOTrainer trainer_cls = AxolotlORPOTrainer
elif self.cfg.rl is RLType.KTO: elif self.cfg.rl is RLType.KTO:
@@ -72,7 +63,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
else: else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}") raise ValueError(f"Unsupported RL: {self.cfg.rl}")
return trainer_cls, trainer_cls_args return trainer_cls
def _build_training_arguments(self, total_num_steps): def _build_training_arguments(self, total_num_steps):
""" """
@@ -182,7 +173,15 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
self.cfg.precompute_ref_log_probs self.cfg.precompute_ref_log_probs
) )
trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs) trainer_cls = self._get_trainer_cls()
trainer_cls_args = [self.model]
if self.cfg.rl is RLType.GRPO:
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
if self.cfg.rl in [RLType.DPO, RLType.IPO]:
trainer_cls_args.append(self.model_ref)
sig = inspect.signature(trainer_cls) sig = inspect.signature(trainer_cls)
if "tokenizer" in sig.parameters: if "tokenizer" in sig.parameters:
@@ -190,9 +189,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
else: else:
trainer_kwargs["processing_class"] = self.tokenizer trainer_kwargs["processing_class"] = self.tokenizer
if self.cfg.datasets is not None and ( if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
trainer_cls is DPOStrategy.get_trainer_class()
):
trainer_kwargs["dataset_tags"] = [ trainer_kwargs["dataset_tags"] = [
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir() d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
] ]