allow custom trainer_cls to be defined as a module reference in the YAML (#3024) [skip ci]
* allow custom trainer_cls to be defined as a module reference in the YAML * address PR feedback and add test * add tests
This commit is contained in:
@@ -43,6 +43,7 @@ from axolotl.utils.collators import (
|
||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||
)
|
||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||
from axolotl.utils.import_helper import get_cls_from_module_str
|
||||
from axolotl.utils.logging import get_logger
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
@@ -136,6 +137,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return AxolotlRewardTrainer
|
||||
if self.cfg.process_reward_model:
|
||||
return AxolotlPRMTrainer
|
||||
|
||||
if self.cfg.trainer_cls:
|
||||
# override the trainer cls
|
||||
try:
|
||||
trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
|
||||
LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
|
||||
return trainer_cls
|
||||
except (ImportError, AttributeError, ValueError) as e:
|
||||
raise ValueError(
|
||||
f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
|
||||
) from e
|
||||
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
|
||||
@@ -15,6 +15,7 @@ from axolotl.core.trainers.grpo import GRPOStrategy
|
||||
from axolotl.integrations.base import PluginManager
|
||||
from axolotl.loaders.utils import ensure_dtype
|
||||
from axolotl.utils.callbacks.qat import QATCallback
|
||||
from axolotl.utils.import_helper import get_cls_from_module_str
|
||||
from axolotl.utils.logging import get_logger
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
|
||||
@@ -72,6 +73,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
else:
|
||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||
|
||||
if self.cfg.trainer_cls:
|
||||
# override the trainer cls
|
||||
try:
|
||||
trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
|
||||
LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
|
||||
except (ImportError, AttributeError, ValueError) as e:
|
||||
raise ValueError(
|
||||
f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
|
||||
) from e
|
||||
|
||||
return trainer_cls, trainer_cls_args
|
||||
|
||||
def _build_training_arguments(self, total_num_steps):
|
||||
|
||||
28
src/axolotl/utils/import_helper.py
Normal file
28
src/axolotl/utils/import_helper.py
Normal file
@@ -0,0 +1,28 @@
|
||||
"""
|
||||
Helper for importing modules from strings
|
||||
"""
|
||||
|
||||
import importlib
|
||||
|
||||
|
||||
def get_cls_from_module_str(module_str: str):
|
||||
# use importlib to dynamically load the reward function from the module
|
||||
if not isinstance(module_str, str) or not module_str.strip():
|
||||
raise ValueError("module_str must be a non-empty string")
|
||||
|
||||
parts = module_str.split(".")
|
||||
if len(parts) < 2:
|
||||
raise ValueError(f"Invalid module string format: {module_str}")
|
||||
|
||||
try:
|
||||
cls_name = parts[-1]
|
||||
module_path = ".".join(parts[:-1])
|
||||
mod = importlib.import_module(module_path)
|
||||
mod_cls = getattr(mod, cls_name)
|
||||
return mod_cls
|
||||
except ImportError as e:
|
||||
raise ImportError(f"Failed to import module '{module_path}': {e}") from e
|
||||
except AttributeError as e:
|
||||
raise AttributeError(
|
||||
f"Class '{cls_name}' not found in module '{module_path}': {e}"
|
||||
) from e
|
||||
@@ -110,6 +110,13 @@ class AxolotlInputConfig(
|
||||
},
|
||||
)
|
||||
|
||||
trainer_cls: str | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
"description": "module to custom trainer class to use for training"
|
||||
},
|
||||
)
|
||||
|
||||
rl: RLType | None = Field(
|
||||
default=None,
|
||||
json_schema_extra={
|
||||
|
||||
37
tests/utils/test_import_helper.py
Normal file
37
tests/utils/test_import_helper.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
test cases for axolotl.utils.import_helper
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.utils.import_helper import get_cls_from_module_str
|
||||
|
||||
|
||||
def test_get_cls_from_module_str():
|
||||
cls = get_cls_from_module_str("axolotl.core.trainers.base.AxolotlTrainer")
|
||||
assert cls.__name__ == "AxolotlTrainer"
|
||||
|
||||
|
||||
def test_get_cls_from_module_str_empty_string():
|
||||
with pytest.raises(ValueError, match="module_str must be a non-empty string"):
|
||||
get_cls_from_module_str("")
|
||||
|
||||
|
||||
def test_get_cls_from_module_str_whitespace_only():
|
||||
with pytest.raises(ValueError, match="module_str must be a non-empty string"):
|
||||
get_cls_from_module_str(" ")
|
||||
|
||||
|
||||
def test_get_cls_from_module_str_invalid_format():
|
||||
with pytest.raises(ValueError, match="Invalid module string format"):
|
||||
get_cls_from_module_str("single_part")
|
||||
|
||||
|
||||
def test_get_cls_from_module_str_nonexistent_module():
|
||||
with pytest.raises(ImportError, match="Failed to import module"):
|
||||
get_cls_from_module_str("nonexistent.module.Class")
|
||||
|
||||
|
||||
def test_get_cls_from_module_str_nonexistent_class():
|
||||
with pytest.raises(AttributeError, match="Class 'NonExistentClass' not found"):
|
||||
get_cls_from_module_str("axolotl.core.trainers.base.NonExistentClass")
|
||||
Reference in New Issue
Block a user