diff --git a/docs/config.qmd b/docs/config.qmd
index 857d0eb03..1cff9e6f4 100644
--- a/docs/config.qmd
+++ b/docs/config.qmd
@@ -612,6 +612,7 @@ lr_div_factor: # Learning rate div factor
 # - optimi_adamw
 # - ao_adamw_8bit
 # - ao_adamw_fp8
+# - came_pytorch
 optimizer:
 # Dictionary of arguments to pass to the optimizer
 optim_args:
diff --git a/setup.py b/setup.py
index 51a97e4d9..97e7f5ff5 100644
--- a/setup.py
+++ b/setup.py
@@ -142,6 +142,7 @@ extras_require = {
         "apollo-torch",
         "lomo-optim==0.1.1",
         "torch-optimi==0.2.1",
+        "came_pytorch==0.1.3",
     ],
     "ray": [
         "ray[train]",
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index af9a43db3..5cb397b28 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -708,6 +708,20 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             optimizer_cls = ADOPT
             adam_kwargs["decouple"] = True
             optimizer_kwargs.update(adam_kwargs)
+        elif self.cfg.optimizer == "came_pytorch":
+            from came_pytorch import CAME
+
+            optimizer_cls = CAME
+
+            beta1 = training_arguments_kwargs.get("adam_beta1", 0.9)
+            beta2 = training_arguments_kwargs.get("adam_beta2", 0.999)
+            beta3 = training_arguments_kwargs.get("adam_beta3", 0.9999)
+            eps1 = training_arguments_kwargs.get("adam_epsilon", 1e-30)
+            eps2 = training_arguments_kwargs.get("adam_epsilon2", 1e-16)
+            adam_kwargs["betas"] = (beta1, beta2, beta3)
+            adam_kwargs["eps"] = (eps1, eps2)
+
+            optimizer_kwargs.update(adam_kwargs)
 
         # Parse any additional optimizer args from config
         if self.cfg.optim_args:
diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py
index 118176d34..fe5cf62ba 100644
--- a/src/axolotl/utils/schemas/enums.py
+++ b/src/axolotl/utils/schemas/enums.py
@@ -53,4 +53,5 @@ class CustomSupportedOptimizers(str, Enum):
     ao_adamw_8bit = "ao_adamw_8bit"  # pylint: disable=invalid-name
     ao_adamw_fp8 = "ao_adamw_fp8"  # pylint: disable=invalid-name
     adopt_adamw = "adopt_adamw"  # pylint: disable=invalid-name
+    came_pytorch = "came_pytorch"  # pylint: disable=invalid-name
     muon = "muon"  # pylint: disable=invalid-name
diff --git a/src/axolotl/utils/schemas/training.py b/src/axolotl/utils/schemas/training.py
index 2ab4b4286..69547c17f 100644
--- a/src/axolotl/utils/schemas/training.py
+++ b/src/axolotl/utils/schemas/training.py
@@ -75,8 +75,10 @@ class HyperparametersConfig(BaseModel):
     lr_groups: list[LrGroup] | None = None
 
     adam_epsilon: float | None = None
+    adam_epsilon2: float | None = None
     adam_beta1: float | None = None
     adam_beta2: float | None = None
+    adam_beta3: float | None = None
     max_grad_norm: float | None = None
 
     num_epochs: float = Field(default=1.0)
diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py
index d3ff27ca5..91f45b762 100644
--- a/tests/e2e/test_optimizers.py
+++ b/tests/e2e/test_optimizers.py
@@ -199,3 +199,50 @@ class TestCustomOptimizers(unittest.TestCase):
 
         train(cfg=cfg, dataset_meta=dataset_meta)
         check_model_output_exists(temp_dir, cfg)
+
+    @with_temp_dir
+    def test_came_pytorch(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
"num_epochs": 1, + "micro_batch_size": 8, + "gradient_accumulation_steps": 1, + "output_dir": temp_dir, + "learning_rate": 0.00001, + "optimizer": "came_pytorch", + "adam_beta3": 0.9999, + "adam_epsilon2": 1e-16, + "max_steps": 5, + "lr_scheduler": "cosine", + } + ) + + cfg = validate_config(cfg) + normalize_config(cfg) + cli_args = TrainerCliArgs() + dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) + + train(cfg=cfg, dataset_meta=dataset_meta) + check_model_output_exists(temp_dir, cfg)