Compare commits
1 Commits
08fc7de87e
...
optimizer-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c7b095d77f |
@@ -424,6 +424,11 @@ class SchedulerMixin(Trainer):
|
||||
|
||||
return self.lr_scheduler
|
||||
|
||||
def _load_optimizer_and_scheduler(self, checkpoint):
|
||||
if not checkpoint and self.args.optimizer_checkpoint is not None:
|
||||
checkpoint = self.args.optimizer_checkpoint
|
||||
return super()._load_optimizer_and_scheduler(checkpoint)
|
||||
|
||||
|
||||
class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
"""
|
||||
@@ -1764,6 +1769,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
] = self.cfg.loraplus_lr_embedding
|
||||
training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
|
||||
training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
|
||||
if self.cfg.optimizer_checkpoint:
|
||||
training_arguments_kwargs[
|
||||
"optimizer_checkpoint"
|
||||
] = self.cfg.optimizer_checkpoint
|
||||
|
||||
if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
|
||||
training_arguments_kwargs["lr_scheduler_type"] = "cosine"
|
||||
|
||||
@@ -603,6 +603,8 @@ class AxolotlInputConfig(
|
||||
strict: Optional[bool] = Field(default=False)
|
||||
resume_from_checkpoint: Optional[str] = None
|
||||
auto_resume_from_checkpoints: Optional[bool] = None
|
||||
optimizer_checkpoint: Optional[str] = None
|
||||
|
||||
resize_token_embeddings_to_32x: Optional[bool] = None
|
||||
mean_resizing_embeddings: Optional[bool] = False
|
||||
|
||||
|
||||
Reference in New Issue
Block a user