diff --git a/cicd/Dockerfile.jinja b/cicd/Dockerfile.jinja
index 287d563c1..263f4a661 100644
--- a/cicd/Dockerfile.jinja
+++ b/cicd/Dockerfile.jinja
@@ -24,9 +24,9 @@ RUN git fetch origin +$GITHUB_REF && \
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN pip install causal_conv1d
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
     fi
 
 # So we can test the Docker image
diff --git a/docker/Dockerfile b/docker/Dockerfile
index cdb6d177a..be58d0354 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -22,9 +22,9 @@ WORKDIR /workspace/axolotl
 # If AXOLOTL_EXTRAS is set, append it in brackets
 RUN pip install causal_conv1d
 RUN if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers,$AXOLOTL_EXTRAS] $AXOLOTL_ARGS; \
     else \
-        pip install -e .[deepspeed,flash-attn,mamba-ssm,galore] $AXOLOTL_ARGS; \
+        pip install -e .[deepspeed,flash-attn,mamba-ssm,optimizers] $AXOLOTL_ARGS; \
     fi
 
 # So we can test the Docker image
diff --git a/setup.py b/setup.py
index 82e652241..9e6f34ad8 100644
--- a/setup.py
+++ b/setup.py
@@ -104,5 +104,11 @@ setup(
         "galore": [
             "galore_torch",
         ],
+        "optimizers": [
+            "galore_torch",
+            "lion-pytorch==0.1.2",
+            "lomo-optim==0.1.1",
+            "torch-optimi==0.2.1",
+        ],
     },
 )
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index ec175454e..5391904fc 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -226,6 +226,12 @@ class AxolotlTrainingMixins:
         default=None,
         metadata={"help": "whether to use sequential sampling for curriculum learning"},
     )
+    alternate_optimizer: Optional[str] = field(
+        default=None,
+        metadata={
+            "help": "workaround to pass an alternate optimizer to the HF trainer"
+        },
+    )
 
 
 @dataclass
@@ -285,25 +291,59 @@ class AxolotlTrainer(Trainer):
         self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
 
     def create_optimizer(self):
-        if self.args.loraplus_lr_ratio is None:
+        if (
+            self.args.loraplus_lr_ratio is None
+            and self.args.alternate_optimizer != "optimi_adamw"
+        ):
             return super().create_optimizer()
 
         opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
         if self.optimizer is None:  # pylint: disable=access-member-before-definition
+            decay_parameters = self.get_decay_parameter_names(opt_model)
+            optimizer_grouped_parameters = [
+                {
+                    "params": [
+                        p
+                        for n, p in opt_model.named_parameters()
+                        if (n in decay_parameters and p.requires_grad)
+                    ],
+                    "weight_decay": self.args.weight_decay,
+                },
+                {
+                    "params": [
+                        p
+                        for n, p in opt_model.named_parameters()
+                        if (n not in decay_parameters and p.requires_grad)
+                    ],
+                    "weight_decay": 0.0,
+                },
+            ]
+
             optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
                 self.args,
                 opt_model,
             )
 
-            loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
-            loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
-            self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
-                opt_model,
-                optimizer_cls,
-                optimizer_kwargs,
-                loraplus_lr_ratio,
-                loraplus_lr_embedding,
-            )
+            if self.args.loraplus_lr_ratio is not None:
+                loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
+                loraplus_lr_embedding = getattr(
+                    self.args, "loraplus_lr_embedding", None
+                )
+                self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
+                    opt_model,
+                    optimizer_cls,
+                    optimizer_kwargs,
+                    loraplus_lr_ratio,
+                    loraplus_lr_embedding,
+                )
+            elif self.args.alternate_optimizer == "optimi_adamw":
+                from optimi import AdamW
+
+                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
+                    AdamW(
+                        optimizer_grouped_parameters, foreach=False, **optimizer_kwargs
+                    )
+                )
 
         if is_sagemaker_mp_enabled():
             self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
@@ -1396,6 +1436,11 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
 
         trainer_kwargs = {}
 
+        if self.cfg.optimizer == "optimi_adamw":
+            # Set default so transformers doesn't throw
+            training_arguments_kwargs["optim"] = "adamw_hf"
+            training_arguments_kwargs["alternate_optimizer"] = self.cfg.optimizer
+
         if self.cfg.optimizer == "lion_pytorch":
             from lion_pytorch import Lion
 
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 3cac4f839..3d0b02752 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -341,7 +341,7 @@ class HyperparametersConfig(BaseModel):
     learning_rate: Union[str, float]
     weight_decay: Optional[float] = 0.0
     optimizer: Optional[
-        Union[OptimizerNames, Literal["lion_pytorch"]]
+        Union[OptimizerNames, Literal["lion_pytorch", "optimi_adamw"]]
     ] = OptimizerNames.ADAMW_HF.value
     optim_args: Optional[Union[str, Dict[str, Any]]] = Field(
         default=None, metadata={"help": "Optional arguments to supply to optimizer."}
diff --git a/tests/e2e/test_lora_llama.py b/tests/e2e/test_lora_llama.py
index c79652bef..4c6fdaaa9 100644
--- a/tests/e2e/test_lora_llama.py
+++ b/tests/e2e/test_lora_llama.py
@@ -34,8 +34,8 @@ class TestLoraLlama(unittest.TestCase):
                 "sequence_len": 1024,
                 "load_in_8bit": True,
                 "adapter": "lora",
-                "lora_r": 32,
-                "lora_alpha": 64,
+                "lora_r": 8,
+                "lora_alpha": 16,
                 "lora_dropout": 0.05,
                 "lora_target_linear": True,
                 "val_set_size": 0.1,
@@ -50,7 +50,7 @@ class TestLoraLlama(unittest.TestCase):
                         "type": "alpaca",
                     },
                 ],
-                "num_epochs": 2,
+                "num_epochs": 1,
                 "micro_batch_size": 8,
                 "gradient_accumulation_steps": 1,
                 "output_dir": temp_dir,
diff --git a/tests/e2e/test_optimizers.py b/tests/e2e/test_optimizers.py
new file mode 100644
index 000000000..119dd3d7c
--- /dev/null
+++ b/tests/e2e/test_optimizers.py
@@ -0,0 +1,67 @@
+"""
+E2E tests for custom optimizers using Llama
+"""
+
+import logging
+import os
+import unittest
+from pathlib import Path
+
+from axolotl.cli import load_datasets
+from axolotl.common.cli import TrainerCliArgs
+from axolotl.train import train
+from axolotl.utils.config import normalize_config
+from axolotl.utils.dict import DictDefault
+
+from .utils import with_temp_dir
+
+LOG = logging.getLogger("axolotl.tests.e2e")
+os.environ["WANDB_DISABLED"] = "true"
+
+
+class TestCustomOptimizers(unittest.TestCase):
+    """
+    Test case for custom optimizers using Llama
+    """
+
+    @with_temp_dir
+    def test_optimi_adamw(self, temp_dir):
+        # pylint: disable=duplicate-code
+        cfg = DictDefault(
+            {
+                "base_model": "JackFram/llama-68m",
+                "tokenizer_type": "LlamaTokenizer",
+                "sequence_len": 1024,
+                "load_in_8bit": True,
+                "adapter": "lora",
+                "lora_r": 8,
+                "lora_alpha": 16,
+                "lora_dropout": 0.05,
+                "lora_target_linear": True,
+                "val_set_size": 0.1,
+                "special_tokens": {
+                    "unk_token": "<unk>",
+                    "bos_token": "<s>",
+                    "eos_token": "</s>",
+                },
+                "datasets": [
+                    {
+                        "path": "mhenrichsen/alpaca_2k_test",
+                        "type": "alpaca",
+                    },
+                ],
+                "num_epochs": 1,
+                "micro_batch_size": 8,
+                "gradient_accumulation_steps": 1,
+                "output_dir": temp_dir,
+                "learning_rate": 0.00001,
+                "optimizer": "optimi_adamw",
+                "lr_scheduler": "cosine",
+            }
+        )
+        normalize_config(cfg)
+        cli_args = TrainerCliArgs()
+        dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
+
+        train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+        assert (Path(temp_dir) / "adapter_model.bin").exists()