support for custom trainer classes from plugins
This commit is contained in:
@@ -295,6 +295,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return callbacks
|
return callbacks
|
||||||
|
|
||||||
def _get_trainer_cls(self):
|
def _get_trainer_cls(self):
|
||||||
|
if self.cfg.plugins:
|
||||||
|
plugin_manager = PluginManager.get_instance()
|
||||||
|
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||||
|
if trainer_cls:
|
||||||
|
return trainer_cls
|
||||||
if self.cfg.relora_steps:
|
if self.cfg.relora_steps:
|
||||||
return ReLoRATrainer
|
return ReLoRATrainer
|
||||||
if self.cfg.model_config_type == "mamba":
|
if self.cfg.model_config_type == "mamba":
|
||||||
|
|||||||
@@ -111,6 +111,17 @@ class BasePlugin:
|
|||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument):
|
||||||
|
"""
|
||||||
|
Returns a custom class for the trainer.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
cfg (dict): The global axolotl configuration.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
class: The class for the trainer.
|
||||||
|
"""
|
||||||
|
|
||||||
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
|
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
|
||||||
"""
|
"""
|
||||||
Creates and returns an optimizer for training.
|
Creates and returns an optimizer for training.
|
||||||
@@ -346,6 +357,22 @@ class PluginManager:
|
|||||||
for plugin in self.plugins.values():
|
for plugin in self.plugins.values():
|
||||||
plugin.post_lora_load(cfg, model)
|
plugin.post_lora_load(cfg, model)
|
||||||
|
|
||||||
|
def get_trainer_cls(self, cfg):
|
||||||
|
"""
|
||||||
|
Calls the get_trainer_cls method of all registered plugins and returns the first non-None trainer class.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
cfg (dict): The configuration for the plugins.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
object: The trainer class, or None if none was found.
|
||||||
|
"""
|
||||||
|
for plugin in self.plugins.values():
|
||||||
|
trainer_cls = plugin.get_trainer_cls(cfg)
|
||||||
|
if trainer_cls is not None:
|
||||||
|
return trainer_cls
|
||||||
|
return None
|
||||||
|
|
||||||
def create_optimizer(self, cfg, trainer):
|
def create_optimizer(self, cfg, trainer):
|
||||||
"""
|
"""
|
||||||
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
|
Calls the create_optimizer method of all registered plugins and returns the first non-None optimizer.
|
||||||
|
|||||||
Reference in New Issue
Block a user