diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index b237b1ef3..9f0742739 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -663,6 +663,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 
             optimizer_cls = MuonOptimizerFactory
             optimizer_kwargs.update(adam_kwargs)
+        elif self.cfg.optimizer == "soap":
+            from axolotl.utils.optimizers.soap import SOAP
+
+            optimizer_cls = SOAP
+            optimizer_kwargs.update(adam_kwargs)
         elif self.cfg.optimizer == "optimi_adamw":
             from optimi import AdamW
 
diff --git a/src/axolotl/utils/schemas/enums.py b/src/axolotl/utils/schemas/enums.py
index ad735afe0..963513cc2 100644
--- a/src/axolotl/utils/schemas/enums.py
+++ b/src/axolotl/utils/schemas/enums.py
@@ -52,3 +52,4 @@ class CustomSupportedOptimizers(str, Enum):
     ao_adamw_fp8 = "ao_adamw_fp8"  # pylint: disable=invalid-name
     adopt_adamw = "adopt_adamw"  # pylint: disable=invalid-name
     muon = "muon"  # pylint: disable=invalid-name
+    soap = "soap"  # pylint: disable=invalid-name
diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py
index 43a4735aa..a936cc40c 100644
--- a/tests/e2e/test_optimizers.py
+++ b/tests/e2e/test_optimizers.py
@@ -201,3 +201,46 @@ class TestCustomOptimizers(unittest.TestCase):
         train(cfg=cfg, dataset_meta=dataset_meta)
 
         check_model_output_exists(temp_dir, cfg)
+
+    @with_temp_dir
+    def test_soap(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "HuggingFaceTB/SmolLM-135M",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "pad_token": "<|endoftext|>",
+                },
+                "datasets": [
+                    {
+                        "path": "vicgalle/alpaca-gpt4",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "soap",
+                "adam_beta1": 0.9,
+                "adam_beta2": 0.95,
+                "lr_scheduler": "cosine",
+            }
+        )
+
+        cfg = validate_config(cfg)
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, dataset_meta=dataset_meta)
+        check_model_output_exists(temp_dir, cfg)
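
Note (not part of the patch): a minimal sketch of exercising the new optimizer directly, outside the trainer builder. The SOAP class itself is not shown in this excerpt, so this assumes it accepts Adam-style constructor keywords (lr, betas, eps, weight_decay), which is what the trainer_builder change implies by forwarding adam_kwargs through optimizer_kwargs.

# Sketch only -- assumes SOAP takes Adam-style keywords (lr, betas, eps,
# weight_decay), mirroring how adam_kwargs is forwarded in trainer_builder.py.
import torch

from axolotl.utils.optimizers.soap import SOAP

model = torch.nn.Linear(16, 16)
optimizer = SOAP(
    model.parameters(),
    lr=1e-5,
    betas=(0.9, 0.95),
    eps=1e-8,
    weight_decay=0.0,
)

# One toy optimization step to confirm the optimizer wires up end to end.
loss = model(torch.randn(4, 16)).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()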