fix(RL): address plugin rl overwriting trainer_cls (#2697) [skip ci]
* fix: plugin rl overwrite trainer_cls * feat(test): add test to catch trainer_cls is not None
This commit is contained in:
@@ -1195,7 +1195,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
|||||||
|
|
||||||
if self.cfg.plugins:
|
if self.cfg.plugins:
|
||||||
plugin_manager = PluginManager.get_instance()
|
plugin_manager = PluginManager.get_instance()
|
||||||
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
temp_trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||||
|
if temp_trainer_cls is not None:
|
||||||
|
trainer_cls = temp_trainer_cls
|
||||||
|
|
||||||
sig = inspect.signature(trainer_cls)
|
sig = inspect.signature(trainer_cls)
|
||||||
if "tokenizer" in sig.parameters.keys():
|
if "tokenizer" in sig.parameters.keys():
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from axolotl.core.trainer_builder import HFRLTrainerBuilder
|
|||||||
from axolotl.utils.config import normalize_config
|
from axolotl.utils.config import normalize_config
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.models import load_model, load_tokenizer
|
from axolotl.utils.models import load_model, load_tokenizer
|
||||||
|
from axolotl.utils.schemas.enums import RLType
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="cfg")
|
@pytest.fixture(name="cfg")
|
||||||
@@ -65,3 +66,27 @@ class TestHFRLTrainerBuilder:
|
|||||||
assert training_arguments.adam_epsilon == 0.00001
|
assert training_arguments.adam_epsilon == 0.00001
|
||||||
assert training_arguments.dataloader_num_workers == 1
|
assert training_arguments.dataloader_num_workers == 1
|
||||||
assert training_arguments.dataloader_pin_memory is True
|
assert training_arguments.dataloader_pin_memory is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestTrainerClsPlugin:
|
||||||
|
"""
|
||||||
|
TestCase class for trainer builder with plugin
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_trainer_cls_is_not_none_with_plugin(self, cfg, model, tokenizer):
|
||||||
|
"""
|
||||||
|
Test that the trainer cls is not none with plugin
|
||||||
|
|
||||||
|
Fixes #2693
|
||||||
|
"""
|
||||||
|
cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"]
|
||||||
|
cfg.rl = RLType.KTO
|
||||||
|
|
||||||
|
# Expected AttributeError as we don't pass regular model configs to RL trainer builder
|
||||||
|
# If it throws `TypeError: None is not a callable object`, trainer_cls could be None
|
||||||
|
with pytest.raises(
|
||||||
|
AttributeError, match=r".*'tuple' object has no attribute 'config'.*"
|
||||||
|
):
|
||||||
|
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
|
||||||
|
|
||||||
|
builder.build(100)
|
||||||
|
|||||||
Reference in New Issue
Block a user