Compare commits

...

1 Commits

Author SHA1 Message Date
Wing Lian
c7b095d77f resume from optimizer checkpoint only 2024-12-30 07:51:44 -05:00
2 changed files with 11 additions and 0 deletions

View File

@@ -424,6 +424,11 @@ class SchedulerMixin(Trainer):
        return self.lr_scheduler
    def _load_optimizer_and_scheduler(self, checkpoint):
        """Load optimizer/scheduler state, with a dedicated optimizer-checkpoint fallback.

        If the Trainer was not given a resume checkpoint but the user set
        ``args.optimizer_checkpoint``, that path is used instead — letting a run
        restore optimizer/scheduler state only (per the commit intent, "resume
        from optimizer checkpoint only") without resuming everything else.

        :param checkpoint: checkpoint path passed by the Trainer; may be None/empty.
        :return: whatever the parent class's ``_load_optimizer_and_scheduler`` returns.
        """
        # Only substitute when `checkpoint` is falsy, so an explicitly supplied
        # resume checkpoint always takes precedence over optimizer_checkpoint.
        if not checkpoint and self.args.optimizer_checkpoint is not None:
            checkpoint = self.args.optimizer_checkpoint
        # Delegate the actual state loading to the parent Trainer implementation.
        return super()._load_optimizer_and_scheduler(checkpoint)
class AxolotlTrainer(SchedulerMixin, Trainer):
    """
@@ -1764,6 +1769,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
            ] = self.cfg.loraplus_lr_embedding
            training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
            training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
if self.cfg.optimizer_checkpoint:
training_arguments_kwargs[
"optimizer_checkpoint"
] = self.cfg.optimizer_checkpoint
        if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
            training_arguments_kwargs["lr_scheduler_type"] = "cosine"

View File

@@ -603,6 +603,8 @@ class AxolotlInputConfig(
    strict: Optional[bool] = Field(default=False)
    resume_from_checkpoint: Optional[str] = None
    auto_resume_from_checkpoints: Optional[bool] = None
optimizer_checkpoint: Optional[str] = None
    resize_token_embeddings_to_32x: Optional[bool] = None
    mean_resizing_embeddings: Optional[bool] = False