fix(RL): address plugin rl overwriting trainer_cls (#2697) [skip ci]

* fix: plugin rl overwrite trainer_cls

* feat(test): add test to catch trainer_cls is not None
This commit is contained in:
NanoCode012
2025-05-22 19:19:12 +07:00
committed by GitHub
parent 1c83a1a020
commit 798b5f5cfd
2 changed files with 28 additions and 1 deletions

View File

@@ -1195,7 +1195,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase):
if self.cfg.plugins:
plugin_manager = PluginManager.get_instance()
trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
temp_trainer_cls = plugin_manager.get_trainer_cls(self.cfg)
if temp_trainer_cls is not None:
trainer_cls = temp_trainer_cls
sig = inspect.signature(trainer_cls)
if "tokenizer" in sig.parameters.keys():

View File

@@ -8,6 +8,7 @@ from axolotl.core.trainer_builder import HFRLTrainerBuilder
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.schemas.enums import RLType
@pytest.fixture(name="cfg")
@@ -65,3 +66,27 @@ class TestHFRLTrainerBuilder:
assert training_arguments.adam_epsilon == 0.00001
assert training_arguments.dataloader_num_workers == 1
assert training_arguments.dataloader_pin_memory is True
class TestTrainerClsPlugin:
"""
TestCase class for trainer builder with plugin
"""
def test_trainer_cls_is_not_none_with_plugin(self, cfg, model, tokenizer):
"""
Test that the trainer cls is not none with plugin
Fixes #2693
"""
cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"]
cfg.rl = RLType.KTO
# Expected AttributeError as we don't pass regular model configs to RL trainer builder
# If it throws `TypeError: None is not a callable object`, trainer_cls could be None
with pytest.raises(
AttributeError, match=r".*'tuple' object has no attribute 'config'.*"
):
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
builder.build(100)