diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 2f38b12dc..b973e5cc3 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -993,6 +993,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): return ReLoRATrainer if self.cfg.model_config_type == "mamba": return AxolotlMambaTrainer + if self.cfg.custom_trainer_cls: + _module, _cls = self.cfg.custom_trainer_cls.rsplit(".", 1) + return importlib.import_module(_module, _cls) return AxolotlTrainer def build(self, total_num_steps): diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index f1c12b2ba..a0a2fd1e5 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -561,6 +561,8 @@ class AxolotlInputConfig( torch_compile: Optional[bool] = None torch_compile_backend: Optional[str] = None + custom_trainer_cls: Optional[str] = None + max_steps: Optional[int] = None warmup_steps: Optional[int] = None warmup_ratio: Optional[float] = None