Compare commits
1 Commits
moekernels
...
fix/rl-tra
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d47093fcdd |
@@ -6,6 +6,7 @@ from pathlib import Path
|
|||||||
from axolotl.core.builders.base import TrainerBuilderBase
|
from axolotl.core.builders.base import TrainerBuilderBase
|
||||||
from axolotl.core.trainers import (
|
from axolotl.core.trainers import (
|
||||||
AxolotlCPOTrainer,
|
AxolotlCPOTrainer,
|
||||||
|
AxolotlDPOTrainer,
|
||||||
AxolotlKTOTrainer,
|
AxolotlKTOTrainer,
|
||||||
AxolotlORPOTrainer,
|
AxolotlORPOTrainer,
|
||||||
)
|
)
|
||||||
@@ -36,33 +37,23 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
callbacks = super().get_post_trainer_create_callbacks(trainer=trainer)
|
||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def _get_trainer_cls(self, trainer_kwargs: dict):
|
def _get_trainer_cls(self):
|
||||||
"""
|
"""Returns trainer_cls"""
|
||||||
Returns trainer_cls and trainer_cls_args
|
|
||||||
"""
|
|
||||||
if self.cfg.plugins:
|
if self.cfg.plugins:
|
||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||||
trainer_cls_args = [] # type: ignore
|
|
||||||
|
|
||||||
if trainer_cls is not None:
|
if trainer_cls is not None:
|
||||||
return trainer_cls, trainer_cls_args
|
return trainer_cls
|
||||||
|
|
||||||
trainer_cls = None
|
trainer_cls = None
|
||||||
trainer_cls_args = [self.model]
|
|
||||||
|
|
||||||
if self.cfg.rl is RLType.GRPO:
|
if self.cfg.rl is RLType.GRPO:
|
||||||
trainer_cls = GRPOStrategy.get_trainer_class(
|
trainer_cls = GRPOStrategy.get_trainer_class(
|
||||||
sequence_parallel=self.cfg.sequence_parallel_degree > 1
|
sequence_parallel=self.cfg.sequence_parallel_degree > 1
|
||||||
)
|
)
|
||||||
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
|
||||||
|
|
||||||
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
|
||||||
|
|
||||||
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
elif self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||||
trainer_cls = DPOStrategy.get_trainer_class()
|
trainer_cls = AxolotlDPOTrainer
|
||||||
trainer_cls_args.append(self.model_ref)
|
|
||||||
|
|
||||||
elif self.cfg.rl is RLType.ORPO:
|
elif self.cfg.rl is RLType.ORPO:
|
||||||
trainer_cls = AxolotlORPOTrainer
|
trainer_cls = AxolotlORPOTrainer
|
||||||
elif self.cfg.rl is RLType.KTO:
|
elif self.cfg.rl is RLType.KTO:
|
||||||
@@ -72,7 +63,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||||
|
|
||||||
return trainer_cls, trainer_cls_args
|
return trainer_cls
|
||||||
|
|
||||||
def _build_training_arguments(self, total_num_steps):
|
def _build_training_arguments(self, total_num_steps):
|
||||||
"""
|
"""
|
||||||
@@ -182,7 +173,15 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
self.cfg.precompute_ref_log_probs
|
self.cfg.precompute_ref_log_probs
|
||||||
)
|
)
|
||||||
|
|
||||||
trainer_cls, trainer_cls_args = self._get_trainer_cls(trainer_kwargs)
|
trainer_cls = self._get_trainer_cls()
|
||||||
|
trainer_cls_args = [self.model]
|
||||||
|
|
||||||
|
if self.cfg.rl is RLType.GRPO:
|
||||||
|
trainer_cls_args.extend(GRPOStrategy.set_trainer_args(self.cfg))
|
||||||
|
trainer_kwargs.update(GRPOStrategy.set_trainer_kwargs(self.cfg))
|
||||||
|
|
||||||
|
if self.cfg.rl in [RLType.DPO, RLType.IPO]:
|
||||||
|
trainer_cls_args.append(self.model_ref)
|
||||||
|
|
||||||
sig = inspect.signature(trainer_cls)
|
sig = inspect.signature(trainer_cls)
|
||||||
if "tokenizer" in sig.parameters:
|
if "tokenizer" in sig.parameters:
|
||||||
@@ -190,9 +189,7 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
else:
|
else:
|
||||||
trainer_kwargs["processing_class"] = self.tokenizer
|
trainer_kwargs["processing_class"] = self.tokenizer
|
||||||
|
|
||||||
if self.cfg.datasets is not None and (
|
if self.cfg.datasets is not None and (trainer_cls is AxolotlDPOTrainer):
|
||||||
trainer_cls is DPOStrategy.get_trainer_class()
|
|
||||||
):
|
|
||||||
trainer_kwargs["dataset_tags"] = [
|
trainer_kwargs["dataset_tags"] = [
|
||||||
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
d["path"] for d in self.cfg.datasets if not Path(d["path"]).is_dir()
|
||||||
]
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user