Compare commits
1 Commits
v0.16.0
...
fix/rl-tra
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d47093fcdd |
@@ -6,6 +6,7 @@ from pathlib import Path
|
||||
from axolotl.core.builders.base import TrainerBuilderBase
|
||||
from axolotl.core.trainers import (
|
||||
AxolotlCPOTrainer,
|
||||
AxolotlDPOTrainer,
|
||||
AxolotlKTOTrainer,
|
||||
AxolotlORPOTrainer,
|
||||
)
|
||||
@@ -36,33 +37,23 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||
return callbacks
|
||||
|
||||
def _get_trainer_cls(self, trainer_kwargs: dict):
|
||||
"""
|
||||
Returns trainer_cls and trainer_cls_args
|
||||
"""
|
||||
def _get_trainer_cls(self):
|
||||
"""Returns trainer_cls"""
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||
trainer_cls_args = [] # type: ignore
|
||||
|
||||
if trainer_cls is not None:
|
||||
return trainer_cls, trainer_cls_args
|
||||
return trainer_cls
|
||||
|
||||
trainer_cls = None
|
||||
trainer_cls_args = [self.model]
|
||||
|
||||
if self.cfg.rl is RLType.GRPO:
|
||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||
sequence_parallel=self.cfg.sequence_parallel_degree > 1
|
||||
)
|
||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||
|
||||
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||
|
||||
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||
trainer_cls = DPOStrategy.get_trainer_class()
|
||||
trainer_cls_args.append(self.model_ref)
|
||||
|
||||
trainer_cls = AxolotlDPOTrainer
|
||||
elif self.cfg.rl is RLType.ORPO:
|
||||
trainer_cls = AxolotlORPOTrainer
|
||||
elif self.cfg.rl is RLType.KTO:
|
||||
@@ -72,7 +63,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
else:
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
|
||||
return trainer_cls, trainer_cls_args
|
||||
return trainer_cls
|
||||
|
||||
def _build_training_arguments(self, total_num_steps):
|
||||
"""
|
||||
@@ -182,7 +173,15 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
self.cfg.precompute_ref_log_probs
|
||||
)
|
||||
|
||||
trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs)
|
||||
trainer_cls = self._get_trainer_cls()
|
||||
trainer_cls_args = [self.model]
|
||||
|
||||
if self.cfg.rl is RLType.GRPO:
|
||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||
|
||||
if self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||
trainer_cls_args.append(self.model_ref)
|
||||
|
||||
sig = inspect.signature(trainer_cls)
|
||||
if "tokenizer" in sig.parameters:
|
||||
@@ -190,9 +189,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
else:
|
||||
trainer_kwargs["processing_class"] = self.tokenizer
|
||||
|
||||
if self.cfg.datasets is not None and (
|
||||
trainer_cls is DPOStrategy.get_trainer_class()
|
||||
):
|
||||
if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
|
||||
trainer_kwargs["dataset_tags"] = [
|
||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user