allow custom trainer_cls to be defined as a module reference in the YAML (#3024) [skip ci]

* allow custom trainer_cls to be defined as a module reference in the YAML

* address PR feedback and add test

* add tests
This commit is contained in:
Wing Lian
2025-08-06 22:49:19 -04:00
committed by GitHub
parent d09290f2f4
commit 4bce713b39
5 changed files with 96 additions and 0 deletions

View File

@@ -43,6 +43,7 @@ from axolotl.utils.collators import (
V2BatchSamplerDataCollatorForSeq2Seq,
)
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
from axolotl.utils.import_helper import get_cls_from_module_str
from axolotl.utils.logging import get_logger
LOG = get_logger(__name__)
@@ -136,6 +137,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return AxolotlRewardTrainer
if self.cfg.process_reward_model:
return AxolotlPRMTrainer
if self.cfg.trainer_cls:
# override the trainer cls
try:
trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
return trainer_cls
except (ImportError, AttributeError, ValueError) as e:
raise ValueError(
f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
) from e
return AxolotlTrainer
def build(self, total_num_steps):

View File

@@ -15,6 +15,7 @@ from axolotl.core.trainers.grpo import GRPOStrategy
from axolotl.integrations.base import PluginManager
from axolotl.loaders.utils import ensure_dtype
from axolotl.utils.callbacks.qat import QATCallback
from axolotl.utils.import_helper import get_cls_from_module_str
from axolotl.utils.logging import get_logger
from axolotl.utils.schemas.enums import RLType
@@ -72,6 +73,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
else:
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
if self.cfg.trainer_cls:
# override the trainer cls
try:
trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
except (ImportError, AttributeError, ValueError) as e:
raise ValueError(
f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
) from e
return trainer_cls, trainer_cls_args
def _build_training_arguments(self, total_num_steps):

View File

@@ -0,0 +1,28 @@
"""
Helper for importing classes from dotted module-path strings
"""
import importlib
def get_cls_from_module_str(module_str: str):
    """
    Resolve a dotted ``package.module.ClassName`` string to the object it names.

    The last dot-separated component is looked up as an attribute on the module
    named by everything before it.

    Args:
        module_str: fully-qualified dotted path; surrounding whitespace is ignored.

    Returns:
        The attribute found on the imported module.

    Raises:
        ValueError: if ``module_str`` is not a non-empty string, has fewer than
            two components, or contains an empty component (leading, trailing
            or doubled dots).
        ImportError: if the module cannot be imported.
        AttributeError: if the module has no attribute with the given name.
    """
    if not isinstance(module_str, str) or not module_str.strip():
        raise ValueError("module_str must be a non-empty string")

    module_path, _, cls_name = module_str.strip().rpartition(".")
    # Reject empty components up front. A path such as "..Foo" would otherwise
    # reach importlib as a relative import and fail with a TypeError, and
    # "pkg." would produce a confusing "Class '' not found" error.
    if not module_path or not cls_name or not all(module_path.split(".")):
        raise ValueError(f"Invalid module string format: {module_str}")

    try:
        mod = importlib.import_module(module_path)
        return getattr(mod, cls_name)
    except ImportError as e:
        raise ImportError(f"Failed to import module '{module_path}': {e}") from e
    except AttributeError as e:
        raise AttributeError(
            f"Class '{cls_name}' not found in module '{module_path}': {e}"
        ) from e

View File

@@ -110,6 +110,13 @@ class AxolotlInputConfig(
},
)
trainer_cls: str | None = Field(
default=None,
json_schema_extra={
"description": "module to custom trainer class to use for training"
},
)
rl: RLType | None = Field(
default=None,
json_schema_extra={

View File

@@ -0,0 +1,37 @@
"""
test cases for axolotl.utils.import_helper
"""
import pytest
from axolotl.utils.import_helper import get_cls_from_module_str
def test_get_cls_from_module_str():
    """A valid dotted path resolves to the class it names."""
    dotted_path = "axolotl.core.trainers.base.AxolotlTrainer"
    resolved = get_cls_from_module_str(dotted_path)
    assert resolved.__name__ == "AxolotlTrainer"
def test_get_cls_from_module_str_empty_string():
    """An empty string raises ValueError."""
    empty_input = ""
    with pytest.raises(ValueError, match="module_str must be a non-empty string"):
        get_cls_from_module_str(empty_input)
def test_get_cls_from_module_str_whitespace_only():
    """A whitespace-only string raises the same ValueError as an empty one."""
    blank_input = " "
    with pytest.raises(ValueError, match="module_str must be a non-empty string"):
        get_cls_from_module_str(blank_input)
def test_get_cls_from_module_str_invalid_format():
    """A string without a dot separator raises ValueError."""
    undotted = "single_part"
    with pytest.raises(ValueError, match="Invalid module string format"):
        get_cls_from_module_str(undotted)
def test_get_cls_from_module_str_nonexistent_module():
    """A path whose module cannot be imported raises ImportError."""
    missing_module_path = "nonexistent.module.Class"
    with pytest.raises(ImportError, match="Failed to import module"):
        get_cls_from_module_str(missing_module_path)
def test_get_cls_from_module_str_nonexistent_class():
    """An importable module that lacks the named class raises AttributeError."""
    missing_class_path = "axolotl.core.trainers.base.NonExistentClass"
    with pytest.raises(AttributeError, match="Class 'NonExistentClass' not found"):
        get_cls_from_module_str(missing_class_path)