diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index e81740399..ab0c4e49a 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -424,6 +424,11 @@ class SchedulerMixin(Trainer):
 
         return self.lr_scheduler
 
+    def _load_optimizer_and_scheduler(self, checkpoint):
+        if not checkpoint and self.args.optimizer_checkpoint is not None:
+            checkpoint = self.args.optimizer_checkpoint
+        return super()._load_optimizer_and_scheduler(checkpoint)
+
 
 class AxolotlTrainer(SchedulerMixin, Trainer):
     """
@@ -1764,6 +1769,10 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             ] = self.cfg.loraplus_lr_embedding
         training_arguments_kwargs["embedding_lr"] = self.cfg.embedding_lr
         training_arguments_kwargs["embedding_lr_scale"] = self.cfg.embedding_lr_scale
+        if self.cfg.optimizer_checkpoint:
+            training_arguments_kwargs[
+                "optimizer_checkpoint"
+            ] = self.cfg.optimizer_checkpoint
 
         if self.cfg.lr_scheduler in ["one_cycle", "log_sweep"]:
             training_arguments_kwargs["lr_scheduler_type"] = "cosine"
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 0781c6798..71d061bc4 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -603,6 +603,8 @@ class AxolotlInputConfig(
     strict: Optional[bool] = Field(default=False)
     resume_from_checkpoint: Optional[str] = None
     auto_resume_from_checkpoints: Optional[bool] = None
+    optimizer_checkpoint: Optional[str] = None
+
     resize_token_embeddings_to_32x: Optional[bool] = None
     mean_resizing_embeddings: Optional[bool] = False
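
Note on the behavior introduced above: the SchedulerMixin override falls back to a separately configured checkpoint directory for optimizer/scheduler state when no resume checkpoint was passed to the trainer, and the builder change forwards cfg.optimizer_checkpoint into the training arguments so the trainer can read it as self.args.optimizer_checkpoint. The snippet below is a minimal, self-contained sketch of that fallback logic only, not the axolotl implementation; Args, BaseTrainer, and CheckpointFallbackMixin are hypothetical stand-ins for TrainingArguments, transformers.Trainer, and the mixin in this diff.

# Minimal sketch (assumed names, not axolotl code) of the optimizer-checkpoint
# fallback added in the diff above.
from dataclasses import dataclass
from typing import Optional


@dataclass
class Args:
    # Stand-in for the new `optimizer_checkpoint` training argument.
    optimizer_checkpoint: Optional[str] = None


class BaseTrainer:
    def _load_optimizer_and_scheduler(self, checkpoint: Optional[str]) -> Optional[str]:
        # Stand-in for the upstream Trainer method: returns the directory it
        # would read optimizer/scheduler state from (None means fresh state).
        return checkpoint


class CheckpointFallbackMixin(BaseTrainer):
    def __init__(self, args: Args):
        self.args = args

    def _load_optimizer_and_scheduler(self, checkpoint: Optional[str]) -> Optional[str]:
        # Only fall back to the configured optimizer checkpoint when no
        # resume checkpoint was supplied; an explicit checkpoint still wins.
        if not checkpoint and self.args.optimizer_checkpoint is not None:
            checkpoint = self.args.optimizer_checkpoint
        return super()._load_optimizer_and_scheduler(checkpoint)


if __name__ == "__main__":
    trainer = CheckpointFallbackMixin(Args(optimizer_checkpoint="./prior_run/checkpoint-500"))
    # No resume checkpoint -> optimizer state is seeded from the configured path.
    print(trainer._load_optimizer_and_scheduler(None))                    # ./prior_run/checkpoint-500
    # Explicit resume checkpoint takes precedence over the configured path.
    print(trainer._load_optimizer_and_scheduler("./out/checkpoint-900"))  # ./out/checkpoint-900

Because the diff also adds optimizer_checkpoint to AxolotlInputConfig, the fallback would presumably be driven by setting that key to a checkpoint path in the run config, alongside the existing resume_from_checkpoint and auto_resume_from_checkpoints options.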