fix(RL): address plugin rl overwriting trainer_cls (#2697) [skip ci]
* fix: plugin rl overwrite trainer_cls * feat(test): add test to catch trainer_cls is not None
This commit is contained in:
@@ -1195,7 +1195,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
|
||||
|
||||
if self.cfg.plugins:
|
||||
plugin_manager = PluginManager.get_instance()
|
||||
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||
temp_trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
|
||||
if temp_trainer_cls is not None:
|
||||
trainer_cls = temp_trainer_cls
|
||||
|
||||
sig = inspect.signature(trainer_cls)
|
||||
if "tokenizer" in sig.parameters.keys():
|
||||
|
||||
@@ -8,6 +8,7 @@ from axolotl.core.trainer_builder import HFRLTrainerBuilder
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
|
||||
|
||||
@pytest.fixture(name="cfg")
|
||||
@@ -65,3 +66,27 @@ class TestHFRLTrainerBuilder:
|
||||
assert training_arguments.adam_epsilon == 0.00001
|
||||
assert training_arguments.dataloader_num_workers == 1
|
||||
assert training_arguments.dataloader_pin_memory is True
|
||||
|
||||
|
||||
class TestTrainerClsPlugin:
|
||||
"""
|
||||
TestCase class for trainer builder with plugin
|
||||
"""
|
||||
|
||||
def test_trainer_cls_is_not_none_with_plugin(self, cfg, model, tokenizer):
|
||||
"""
|
||||
Test that the trainer cls is not none with plugin
|
||||
|
||||
Fixes #2693
|
||||
"""
|
||||
cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"]
|
||||
cfg.rl = RLType.KTO
|
||||
|
||||
# Expected AttributeError as we don't pass regular model configs to RL trainer builder
|
||||
# If it throws `TypeError: None is not a callable object`, trainer_cls could be None
|
||||
with pytest.raises(
|
||||
AttributeError, match=r".*'tuple' object has no attribute 'config'.*"
|
||||
):
|
||||
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
|
||||
|
||||
builder.build(100)
|
||||
|
||||
Reference in New Issue
Block a user