diff --git a/src/axolotl/core/builders/causal.py b/src/axolotl/core/builders/causal.py index 8babf6a65..e2ad1f579 100644 --- a/src/axolotl/core/builders/causal.py +++ b/src/axolotl/core/builders/causal.py @@ -130,6 +130,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): return callbacks def _get_trainer_cls(self): + """ + Gets the trainer class for the given configuration. + """ if self.cfg.plugins: plugin_manager = PluginManager.get_instance() trainer_cls = plugin_manager.get_trainer_cls(self.cfg)