allow custom trainer_cls to be defined as a module reference in the YAML (#3024) [skip ci]
* allow custom trainer_cls to be defined as a module reference in the YAML * address PR feedback and add test * add tests
This commit is contained in:
37
tests/utils/test_import_helper.py
Normal file
37
tests/utils/test_import_helper.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
test cases for axolotl.utils.import_helper
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.utils.import_helper import get_cls_from_module_str
|
||||
|
||||
|
||||
def test_get_cls_from_module_str():
|
||||
cls = get_cls_from_module_str("axolotl.core.trainers.base.AxolotlTrainer")
|
||||
assert cls.__name__ == "AxolotlTrainer"
|
||||
|
||||
|
||||
def test_get_cls_from_module_str_empty_string():
|
||||
with pytest.raises(ValueError, match="module_str must be a non-empty string"):
|
||||
get_cls_from_module_str("")
|
||||
|
||||
|
||||
def test_get_cls_from_module_str_whitespace_only():
|
||||
with pytest.raises(ValueError, match="module_str must be a non-empty string"):
|
||||
get_cls_from_module_str(" ")
|
||||
|
||||
|
||||
def test_get_cls_from_module_str_invalid_format():
|
||||
with pytest.raises(ValueError, match="Invalid module string format"):
|
||||
get_cls_from_module_str("single_part")
|
||||
|
||||
|
||||
def test_get_cls_from_module_str_nonexistent_module():
|
||||
with pytest.raises(ImportError, match="Failed to import module"):
|
||||
get_cls_from_module_str("nonexistent.module.Class")
|
||||
|
||||
|
||||
def test_get_cls_from_module_str_nonexistent_class():
|
||||
with pytest.raises(AttributeError, match="Class 'NonExistentClass' not found"):
|
||||
get_cls_from_module_str("axolotl.core.trainers.base.NonExistentClass")
|
||||
Reference in New Issue
Block a user