Compare commits

...

1 Commits

Author SHA1 Message Date
Wing Lian
e9a1f288cf support for custom trainer_cls from config 2024-05-14 18:57:53 -04:00
2 changed files with 5 additions and 0 deletions

View File

@@ -993,6 +993,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
return ReLoRATrainer
if self.cfg.model_config_type == "mamba":
return AxolotlMambaTrainer
if self.cfg.custom_trainer_cls:
_module, _cls = self.cfg.custom_trainer_cls.rsplit(".", 1)
return importlib.import_module(_module, _cls)
return AxolotlTrainer
def build(self, total_num_steps):

View File

@@ -561,6 +561,8 @@ class AxolotlInputConfig(
torch_compile: Optional[bool] = None
torch_compile_backend: Optional[str] = None
custom_trainer_cls: Optional[str] = None
max_steps: Optional[int] = None
warmup_steps: Optional[int] = None
warmup_ratio: Optional[float] = None