support for custom trainer_cls from config
This commit is contained in:
@@ -993,6 +993,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
return ReLoRATrainer
|
||||
if self.cfg.model_config_type == "mamba":
|
||||
return AxolotlMambaTrainer
|
||||
if self.cfg.custom_trainer_cls:
|
||||
_module, _cls = self.cfg.custom_trainer_cls.rsplit(".", 1)
|
||||
return importlib.import_module(_module, _cls)
|
||||
return AxolotlTrainer
|
||||
|
||||
def build(self, total_num_steps):
|
||||
|
||||
@@ -561,6 +561,8 @@ class AxolotlInputConfig(
|
||||
torch_compile: Optional[bool] = None
|
||||
torch_compile_backend: Optional[str] = None
|
||||
|
||||
custom_trainer_cls: Optional[str] = None
|
||||
|
||||
max_steps: Optional[int] = None
|
||||
warmup_steps: Optional[int] = None
|
||||
warmup_ratio: Optional[float] = None
|
||||
|
||||
Reference in New Issue
Block a user