From 4bce713b39035f25dcbd22b11edbedae76db917b Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 6 Aug 2025 22:49:19 -0400 Subject: [PATCH] allow custom trainer_cls to be defined as a module reference in the YAML (#3024) [skip ci] * allow custom trainer_cls to be defined as a module reference in the YAML * address PR feedback and add test * add tests --- src/axolotl/core/builders/causal.py | 13 ++++++++++ src/axolotl/core/builders/rl.py | 11 +++++++++ src/axolotl/utils/import_helper.py | 28 ++++++++++++++++++++++ src/axolotl/utils/schemas/config.py | 7 ++++++ tests/utils/test_import_helper.py | 37 +++++++++++++++++++++++++++++ 5 files changed, 96 insertions(+) create mode 100644 src/axolotl/utils/import_helper.py create mode 100644 tests/utils/test_import_helper.py diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index b461e9009..db35a2412 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -43,6 +43,7 @@ from axolotl.utils.collators import ( V2BatchSamplerDataCollatorForSeq2Seq, ) from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator +from axolotl.utils.import_helper import get_cls_from_module_str from axolotl.utils.logging import get_logger LOG = get_logger(__name__) @@ -136,6 +137,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): return AxolotlRewardTrainer if self.cfg.process_reward_model: return AxolotlPRMTrainer + + if self.cfg.trainer_cls: + # override the trainer cls + try: + trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls) + LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}") + return trainer_cls + except (ImportError, AttributeError, ValueError) as e: + raise ValueError( + f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}" + ) from e + return AxolotlTrainer def build(self, total_num_steps): diff --git a/src/axolotl/core/builders/rl.py b/src/axolotl/core/builders/rl.py index 8cc6eeebf..bc7816807 100644 --- 
a/src/axolotl/core/builders/rl.py +++ b/src/axolotl/core/builders/rl.py @@ -15,6 +15,7 @@ from axolotl.core.trainers.grpo import GRPOStrategy from axolotl.integrations.base import PluginManager from axolotl.loaders.utils import ensure_dtype from axolotl.utils.callbacks.qat import QATCallback +from axolotl.utils.import_helper import get_cls_from_module_str from axolotl.utils.logging import get_logger from axolotl.utils.schemas.enums import RLType @@ -72,6 +73,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase): else: raise ValueError(f"Unsupported RL: {self.cfg.rl}") + if self.cfg.trainer_cls: + # override the trainer cls + try: + trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls) + LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}") + except (ImportError, AttributeError, ValueError) as e: + raise ValueError( + f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}" + ) from e + return trainer_cls, trainer_cls_args def _build_training_arguments(self, total_num_steps): diff --git a/src/axolotl/utils/import_helper.py b/src/axolotl/utils/import_helper.py new file mode 100644 index 000000000..f7d20099c --- /dev/null +++ b/src/axolotl/utils/import_helper.py @@ -0,0 +1,28 @@ +""" +Helper for importing modules from strings +""" + +import importlib + + +def get_cls_from_module_str(module_str: str): + # use importlib to dynamically load the class named by the dotted module path + if not isinstance(module_str, str) or not module_str.strip(): + raise ValueError("module_str must be a non-empty string") + + parts = module_str.split(".") + if len(parts) < 2: + raise ValueError(f"Invalid module string format: {module_str}") + + try: + cls_name = parts[-1] + module_path = ".".join(parts[:-1]) + mod = importlib.import_module(module_path) + mod_cls = getattr(mod, cls_name) + return mod_cls + except ImportError as e: + raise ImportError(f"Failed to import module '{module_path}': {e}") from e + except AttributeError as e: + raise AttributeError( + f"Class 
'{cls_name}' not found in module '{module_path}': {e}" + ) from e diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py index e3de6e37b..21e99c048 100644 --- a/src/axolotl/utils/schemas/config.py +++ b/src/axolotl/utils/schemas/config.py @@ -110,6 +110,13 @@ class AxolotlInputConfig( }, ) + trainer_cls: str | None = Field( + default=None, + json_schema_extra={ + "description": "module to custom trainer class to use for training" + }, + ) + rl: RLType | None = Field( default=None, json_schema_extra={ diff --git a/tests/utils/test_import_helper.py b/tests/utils/test_import_helper.py new file mode 100644 index 000000000..e1ab8bec5 --- /dev/null +++ b/tests/utils/test_import_helper.py @@ -0,0 +1,37 @@ +""" +test cases for axolotl.utils.import_helper +""" + +import pytest + +from axolotl.utils.import_helper import get_cls_from_module_str + + +def test_get_cls_from_module_str(): + cls = get_cls_from_module_str("axolotl.core.trainers.base.AxolotlTrainer") + assert cls.__name__ == "AxolotlTrainer" + + +def test_get_cls_from_module_str_empty_string(): + with pytest.raises(ValueError, match="module_str must be a non-empty string"): + get_cls_from_module_str("") + + +def test_get_cls_from_module_str_whitespace_only(): + with pytest.raises(ValueError, match="module_str must be a non-empty string"): + get_cls_from_module_str(" ") + + +def test_get_cls_from_module_str_invalid_format(): + with pytest.raises(ValueError, match="Invalid module string format"): + get_cls_from_module_str("single_part") + + +def test_get_cls_from_module_str_nonexistent_module(): + with pytest.raises(ImportError, match="Failed to import module"): + get_cls_from_module_str("nonexistent.module.Class") + + +def test_get_cls_from_module_str_nonexistent_class(): + with pytest.raises(AttributeError, match="Class 'NonExistentClass' not found"): + get_cls_from_module_str("axolotl.core.trainers.base.NonExistentClass")