allow custom trainer_cls to be defined as a module reference in the YAML (#3024) [skip ci]
* allow custom trainer_cls to be defined as a module reference in the YAML * address PR feedback and add test * add tests
This commit is contained in:
@@ -43,6 +43,7 @@ from axolotl.utils.collators import (
|
|||||||
V2BatchSamplerDataCollatorForSeq2Seq,
|
V2BatchSamplerDataCollatorForSeq2Seq,
|
||||||
)
|
)
|
||||||
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
from axolotl.utils.collators.mm_chat import MultiModalChatDataCollator
|
||||||
|
from axolotl.utils.import_helper import get_cls_from_module_str
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
|
|
||||||
LOG = get_logger(__name__)
|
LOG = get_logger(__name__)
|
||||||
@@ -136,6 +137,18 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return AxolotlRewardTrainer
|
return AxolotlRewardTrainer
|
||||||
if self.cfg.process_reward_model:
|
if self.cfg.process_reward_model:
|
||||||
return AxolotlPRMTrainer
|
return AxolotlPRMTrainer
|
||||||
|
|
||||||
|
if self.cfg.trainer_cls:
|
||||||
|
# override the trainer cls
|
||||||
|
try:
|
||||||
|
trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
|
||||||
|
LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
|
||||||
|
return trainer_cls
|
||||||
|
except (ImportError, AttributeError, ValueError) as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
return AxolotlTrainer
|
return AxolotlTrainer
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from axolotl.core.trainers.grpo import GRPOStrategy
|
|||||||
from axolotl.integrations.base import PluginManager
|
from axolotl.integrations.base import PluginManager
|
||||||
from axolotl.loaders.utils import ensure_dtype
|
from axolotl.loaders.utils import ensure_dtype
|
||||||
from axolotl.utils.callbacks.qat import QATCallback
|
from axolotl.utils.callbacks.qat import QATCallback
|
||||||
|
from axolotl.utils.import_helper import get_cls_from_module_str
|
||||||
from axolotl.utils.logging import get_logger
|
from axolotl.utils.logging import get_logger
|
||||||
from axolotl.utils.schemas.enums import RLType
|
from axolotl.utils.schemas.enums import RLType
|
||||||
|
|
||||||
@@ -72,6 +73,16 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
raise ValueError(f"Unsupported RL: {self.cfg.rl}")
|
||||||
|
|
||||||
|
if self.cfg.trainer_cls:
|
||||||
|
# override the trainer cls
|
||||||
|
try:
|
||||||
|
trainer_cls = get_cls_from_module_str(self.cfg.trainer_cls)
|
||||||
|
LOG.debug(f"Using custom trainer class: {self.cfg.trainer_cls}")
|
||||||
|
except (ImportError, AttributeError, ValueError) as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"Failed to load custom trainer class '{self.cfg.trainer_cls}': {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
return trainer_cls, trainer_cls_args
|
return trainer_cls, trainer_cls_args
|
||||||
|
|
||||||
def _build_training_arguments(self, total_num_steps):
|
def _build_training_arguments(self, total_num_steps):
|
||||||
|
|||||||
28
src/axolotl/utils/import_helper.py
Normal file
28
src/axolotl/utils/import_helper.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""
|
||||||
|
Helper for importing modules from strings
|
||||||
|
"""
|
||||||
|
|
||||||
|
import importlib
|
||||||
|
|
||||||
|
|
||||||
|
def get_cls_from_module_str(module_str: str):
|
||||||
|
# use importlib to dynamically load the reward function from the module
|
||||||
|
if not isinstance(module_str, str) or not module_str.strip():
|
||||||
|
raise ValueError("module_str must be a non-empty string")
|
||||||
|
|
||||||
|
parts = module_str.split(".")
|
||||||
|
if len(parts) < 2:
|
||||||
|
raise ValueError(f"Invalid module string format: {module_str}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
cls_name = parts[-1]
|
||||||
|
module_path = ".".join(parts[:-1])
|
||||||
|
mod = importlib.import_module(module_path)
|
||||||
|
mod_cls = getattr(mod, cls_name)
|
||||||
|
return mod_cls
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(f"Failed to import module '{module_path}': {e}") from e
|
||||||
|
except AttributeError as e:
|
||||||
|
raise AttributeError(
|
||||||
|
f"Class '{cls_name}' not found in module '{module_path}': {e}"
|
||||||
|
) from e
|
||||||
@@ -110,6 +110,13 @@ class AxolotlInputConfig(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
trainer_cls: str | None = Field(
|
||||||
|
default=None,
|
||||||
|
json_schema_extra={
|
||||||
|
"description": "module to custom trainer class to use for training"
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
rl: RLType | None = Field(
|
rl: RLType | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
json_schema_extra={
|
json_schema_extra={
|
||||||
|
|||||||
37
tests/utils/test_import_helper.py
Normal file
37
tests/utils/test_import_helper.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
"""
|
||||||
|
test cases for axolotl.utils.import_helper
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from axolotl.utils.import_helper import get_cls_from_module_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_cls_from_module_str():
|
||||||
|
cls = get_cls_from_module_str("axolotl.core.trainers.base.AxolotlTrainer")
|
||||||
|
assert cls.__name__ == "AxolotlTrainer"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_cls_from_module_str_empty_string():
|
||||||
|
with pytest.raises(ValueError, match="module_str must be a non-empty string"):
|
||||||
|
get_cls_from_module_str("")
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_cls_from_module_str_whitespace_only():
|
||||||
|
with pytest.raises(ValueError, match="module_str must be a non-empty string"):
|
||||||
|
get_cls_from_module_str(" ")
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_cls_from_module_str_invalid_format():
|
||||||
|
with pytest.raises(ValueError, match="Invalid module string format"):
|
||||||
|
get_cls_from_module_str("single_part")
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_cls_from_module_str_nonexistent_module():
|
||||||
|
with pytest.raises(ImportError, match="Failed to import module"):
|
||||||
|
get_cls_from_module_str("nonexistent.module.Class")
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_cls_from_module_str_nonexistent_class():
|
||||||
|
with pytest.raises(AttributeError, match="Class 'NonExistentClass' not found"):
|
||||||
|
get_cls_from_module_str("axolotl.core.trainers.base.NonExistentClass")
|
||||||
Reference in New Issue
Block a user