allow custom trainer_cls to be defined as a module reference in the YAML (#3024) [skip ci]

* allow custom trainer_cls to be defined as a module reference in the YAML

* address PR feedback and add test

* add tests
This commit is contained in:
Wing Lian
2025-08-06 22:49:19 -04:00
committed by GitHub
parent d09290f2f4
commit 4bce713b39
5 changed files with 96 additions and 0 deletions

View File

@@ -0,0 +1,37 @@
"""
test cases for axolotl.utils.import_helper
"""
import pytest
from axolotl.utils.import_helper import get_cls_from_module_str
def test_get_cls_from_module_str():
cls = get_cls_from_module_str("axolotl.core.trainers.base.AxolotlTrainer")
assert cls.__name__ == "AxolotlTrainer"
def test_get_cls_from_module_str_empty_string():
with pytest.raises(ValueError, match="module_str must be a non-empty string"):
get_cls_from_module_str("")
def test_get_cls_from_module_str_whitespace_only():
with pytest.raises(ValueError, match="module_str must be a non-empty string"):
get_cls_from_module_str(" ")
def test_get_cls_from_module_str_invalid_format():
with pytest.raises(ValueError, match="Invalid module string format"):
get_cls_from_module_str("single_part")
def test_get_cls_from_module_str_nonexistent_module():
with pytest.raises(ImportError, match="Failed to import module"):
get_cls_from_module_str("nonexistent.module.Class")
def test_get_cls_from_module_str_nonexistent_class():
with pytest.raises(AttributeError, match="Class 'NonExistentClass' not found"):
get_cls_from_module_str("axolotl.core.trainers.base.NonExistentClass")