Compare commits
1 Commit
fix/replac
...
optimizer-
| Author | SHA1 | Date | |
|---|---|---|---|
| | c7b095d77f | | |
```diff
@@ -424,6 +424,11 @@ class SchedulerMixin(Trainer):
 
         return self.lr_scheduler
 
+    def _load_optimizer_and_scheduler(self, checkpoint):
+        if not checkpoint and self.args.optimizer_checkpoint is not None:
+            checkpoint = self.args.optimizer_checkpoint
+        return super()._load_optimizer_and_scheduler(checkpoint)
+
 
 class AxolotlTrainer(SchedulerMixin, Trainer):
     """
```
```diff
@@ -1764,6 +1769,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             ] = self.cfg.loraplus_lr_embedding
         training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
         training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
+        if self.cfg.optimizer_checkpoint:
+            training_arguments_kwargs[
+                "optimizer_checkpoint"
+            ] = self.cfg.optimizer_checkpoint
 
         if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
             training_arguments_kwargs["lr_scheduler_type"] = "cosine"
```
```diff
@@ -603,6 +603,8 @@ class AxolotlInputConfig(
     strict: Optional[bool] = Field(default=False)
     resume_from_checkpoint: Optional[str] = None
     auto_resume_from_checkpoints: Optional[bool] = None
+    optimizer_checkpoint: Optional[str] = None
+
     resize_token_embeddings_to_32x: Optional[bool] = None
     mean_resizing_embeddings: Optional[bool] = False
 
```
Reference in New Issue
Block a user