diff --git a/docs/config.qmd b/docs/config.qmd index 10e5a5895..ac4c3fa4f 100644 --- a/docs/config.qmd +++ b/docs/config.qmd @@ -633,7 +633,9 @@ weight_decay: # adamw hyperparams adam_beta1: adam_beta2: +adam_beta3: # only used for CAME Optimizer adam_epsilon: +adam_epsilon2: # only used for CAME Optimizer # Gradient clipping max norm max_grad_norm: diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 25d327dcd..6bd4ef996 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -387,8 +387,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): training_arguments_kwargs["adam_beta1"] = self.cfg.adam_beta1 if self.cfg.adam_beta2: training_arguments_kwargs["adam_beta2"] = self.cfg.adam_beta2 + if self.cfg.adam_beta3: + training_arguments_kwargs["adam_beta3"] = self.cfg.adam_beta3 if self.cfg.adam_epsilon: training_arguments_kwargs["adam_epsilon"] = self.cfg.adam_epsilon + if self.cfg.adam_epsilon2: + training_arguments_kwargs["adam_epsilon2"] = self.cfg.adam_epsilon2 if self.cfg.max_grad_norm: training_arguments_kwargs["max_grad_norm"] = self.cfg.max_grad_norm @@ -713,7 +717,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase): beta1 = training_arguments_kwargs.get("adam_beta1", 0.9) beta2 = training_arguments_kwargs.get("adam_beta2", 0.999) - beta3 = training_arguments_kwargs.get("adam_beta2", 0.9999) + beta3 = training_arguments_kwargs.get("adam_beta3", 0.9999) eps1 = training_arguments_kwargs.get("adam_epsilon", 1e-30) eps2 = training_arguments_kwargs.get("adam_epsilon2", 1e-16) adam_kwargs["betas"] = (beta1, beta2, beta3) diff --git a/src/axolotl/core/training_args.py b/src/axolotl/core/training_args.py index 0b14e7661..a81c33801 100644 --- a/src/axolotl/core/training_args.py +++ b/src/axolotl/core/training_args.py @@ -227,6 +227,19 @@ class AxolotlTrainingMixins: }, ) + adam_beta3: Optional[float] = field( + default=None, + metadata={ + "help": "The beta3 hyperparameter used in some optimizers such as CAME" + }, + ) + adam_epsilon2: Optional[float] = field( + default=None, + metadata={ + "help": "The epsilon2 hyperparameter used in some optimizers such as CAME" + }, + ) + # multi-modal section image_size: int | tuple[int, int] | None = field(