From 798b5f5cfdc3478b51cd48d53c38a3b5a0d387f2 Mon Sep 17 00:00:00 2001 From: NanoCode012 Date: Thu, 22 May 2025 19:19:12 +0700 Subject: [PATCH] fix(RL): address plugin rl overwriting trainer_cls (#2697) [skip ci] * fix: plugin rl overwrite trainer_cls * feat(test): add test to catch trainer_cls is not None --- src/axolotl/core/trainer_builder.py | 4 +++- tests/core/test_trainer_builder.py | 25 +++++++++++++++++++++++++ 2 files changed, 28 insertions(+), 1 deletion(-) diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 878dd176a..863b065e6 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1195,7 +1195,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.plugins: plugin_manager = PluginManager.get_instance() - trainer_cls = plugin_manager.get_trainer_cls(self.cfg) + temp_trainer_cls = plugin_manager.get_trainer_cls(self.cfg) + if temp_trainer_cls is not None: + trainer_cls = temp_trainer_cls sig = inspect.signature(trainer_cls) if "tokenizer" in sig.parameters.keys(): diff --git a/tests/core/test_trainer_builder.py b/tests/core/test_trainer_builder.py index fbfd7a87c..d1ad273ea 100644 --- a/tests/core/test_trainer_builder.py +++ b/tests/core/test_trainer_builder.py @@ -8,6 +8,7 @@ from axolotl.core.trainer_builder import HFRLTrainerBuilder from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model, load_tokenizer +from axolotl.utils.schemas.enums import RLType @pytest.fixture(name="cfg") @@ -65,3 +66,27 @@ class TestHFRLTrainerBuilder: assert training_arguments.adam_epsilon == 0.00001 assert training_arguments.dataloader_num_workers == 1 assert training_arguments.dataloader_pin_memory is True + + +class TestTrainerClsPlugin: + """ + TestCase class for trainer builder with plugin + """ + + def test_trainer_cls_is_not_none_with_plugin(self, cfg, model, tokenizer): + """ + Test that the trainer cls is not none with plugin + + Fixes #2693 + """ + cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"] + cfg.rl = RLType.KTO + + # Expected AttributeError as we don't pass regular model configs to RL trainer builder + # If it throws `TypeError: None is not a callable object`, trainer_cls could be None + with pytest.raises( + AttributeError, match=r".*'tuple' object has no attribute 'config'.*" + ): + builder = HFRLTrainerBuilder(cfg, model, tokenizer) + + builder.build(100)