Compare commits
1 Commits
version-de
...
custom-tra
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e9a1f288cf |
@@ -993,6 +993,9 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
return ReLoRATrainer
|
return ReLoRATrainer
|
||||||
if self.cfg.model_config_type == "mamba":
|
if self.cfg.model_config_type == "mamba":
|
||||||
return AxolotlMambaTrainer
|
return AxolotlMambaTrainer
|
||||||
|
if self.cfg.custom_trainer_cls:
|
||||||
|
_module, _cls = self.cfg.custom_trainer_cls.rsplit(".", 1)
|
||||||
|
return importlib.import_module(_module, _cls)
|
||||||
return AxolotlTrainer
|
return AxolotlTrainer
|
||||||
|
|
||||||
def build(self, total_num_steps):
|
def build(self, total_num_steps):
|
||||||
|
|||||||
@@ -561,6 +561,8 @@ class AxolotlInputConfig(
|
|||||||
torch_compile: Optional[bool] = None
|
torch_compile: Optional[bool] = None
|
||||||
torch_compile_backend: Optional[str] = None
|
torch_compile_backend: Optional[str] = None
|
||||||
|
|
||||||
|
custom_trainer_cls: Optional[str] = None
|
||||||
|
|
||||||
max_steps: Optional[int] = None
|
max_steps: Optional[int] = None
|
||||||
warmup_steps: Optional[int] = None
|
warmup_steps: Optional[int] = None
|
||||||
warmup_ratio: Optional[float] = None
|
warmup_ratio: Optional[float] = None
|
||||||
|
|||||||
Reference in New Issue
Block a user