diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 878dd176a..863b065e6 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -1195,7 +1195,9 @@ class HFRLTrainerBuilder(TrainerBuilderBase): if self.cfg.plugins: plugin_manager = PluginManager.get_instance() - trainer_cls = plugin_manager.get_trainer_cls(self.cfg) + temp_trainer_cls = plugin_manager.get_trainer_cls(self.cfg) + if temp_trainer_cls is not None: + trainer_cls = temp_trainer_cls sig = inspect.signature(trainer_cls) if "tokenizer" in sig.parameters.keys(): diff --git a/tests/core/test_trainer_builder.py b/tests/core/test_trainer_builder.py index fbfd7a87c..d1ad273ea 100644 --- a/tests/core/test_trainer_builder.py +++ b/tests/core/test_trainer_builder.py @@ -8,6 +8,7 @@ from axolotl.core.trainer_builder import HFRLTrainerBuilder from axolotl.utils.config import normalize_config from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model, load_tokenizer +from axolotl.utils.schemas.enums import RLType @pytest.fixture(name="cfg") @@ -65,3 +66,27 @@ class TestHFRLTrainerBuilder: assert training_arguments.adam_epsilon == 0.00001 assert training_arguments.dataloader_num_workers == 1 assert training_arguments.dataloader_pin_memory is True + + +class TestTrainerClsPlugin: + """ + TestCase class for trainer builder with plugin + """ + + def test_trainer_cls_is_not_none_with_plugin(self, cfg, model, tokenizer): + """ + Test that the trainer cls is not none with plugin + + Fixes #2693 + """ + cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"] + cfg.rl = RLType.KTO + + # Expected AttributeError as we don't pass regular model configs to RL trainer builder + # If it throws `TypeError: None is not a callable object`, trainer_cls could be None + with pytest.raises( + AttributeError, match=r".*'tuple' object has no attribute 'config'.*" + ): + builder = HFRLTrainerBuilder(cfg, model, tokenizer) + + builder.build(100)