diff --git a/requirements.txt b/requirements.txt
index 32a9e0e01..b2d178183 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -35,6 +35,7 @@ python-dotenv==1.0.1
 autoawq>=0.2.5
 triton>=2.3.0
 liger-kernel==0.2.1
+distributed_shampoo @ git+https://github.com/facebookresearch/optimizers.git@main
 mamba-ssm==1.2.0.post1
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index f4cd25783..a8c9fca67 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -16,7 +16,7 @@ from collections import defaultdict
 from dataclasses import dataclass, field
 from functools import wraps
 from pathlib import Path
-from typing import Any, Dict, List, Literal, Optional, Type, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union

 import torch
 import transformers
@@ -250,6 +250,11 @@ class AxolotlTrainingMixins:
             "help": "workaround to pass an alternate lr scheduler to the HF trainer"
         },
     )
+    optim_shampoo_grafting_config_type: Optional[
+        Literal["adam", "sgd", "adagrad"]
+    ] = None
+    optim_shampoo_grafting_config_kwargs: Optional[Dict[str, Any]] = None
+    optim_shampoo_betas: Optional[Tuple[float, float]] = None


 @dataclass
@@ -422,7 +427,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         if (
             self.args.loraplus_lr_ratio is None
             and self.args.alternate_optimizer
-            not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
+            not in [
+                "optimi_adamw",
+                "ao_adamw_8bit",
+                "ao_adamw_4bit",
+                "ao_adamw_fp8",
+                "shampoo",
+            ]
         ):
             return super().create_optimizer()

@@ -465,6 +476,59 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
                     loraplus_lr_ratio,
                     loraplus_lr_embedding,
                 )
+            elif self.args.alternate_optimizer == "shampoo":
+                from distributed_shampoo.distributed_shampoo import DistributedShampoo
+                from distributed_shampoo.shampoo_types import (
+                    AdamGraftingConfig,
+                    CommunicationDType,
+                    DDPShampooConfig,
+                    FSDPShampooConfig,
+                )
+                from distributed_shampoo.utils.shampoo_fsdp_utils import (
+                    compile_fsdp_parameter_metadata,
+                )
+
+                # parse args.optim_args
+                optim_args = {}
+                if self.args.optim_args:
+                    for mapping in self.args.optim_args.replace(" ", "").split(","):
+                        key, value = mapping.split("=")
+                        optim_args[key] = value
+
+                optim_args["betas"] = self.args.optim_shampoo_betas
+                optim_args["lr"] = self.args.learning_rate
+                optim_args["weight_decay"] = self.args.weight_decay
+                optim_args["use_decoupled_weight_decay"] = bool(
+                    optim_args.get("use_decoupled_weight_decay")
+                )
+
+                if self.args.optim_shampoo_grafting_config_type in ["adam", "adamw"]:
+                    grafting_config = AdamGraftingConfig(
+                        self.args.optim_shampoo_grafting_config_kwargs
+                    )
+
+                distributed_config = None
+                if self.args.world_size > 1:
+                    if self.args.fsdp_config:
+                        distributed_config = FSDPShampooConfig(
+                            param_to_metadata=compile_fsdp_parameter_metadata(
+                                self.model_wrapped
+                            )
+                        )
+                    else:
+                        distributed_config = DDPShampooConfig(
+                            communication_dtype=CommunicationDType.BFLOAT16,
+                            num_trainers_per_group=self.args.world_size,
+                            communicate_params=False,
+                        )
+                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
+                    DistributedShampoo(
+                        optimizer_grouped_parameters,
+                        grafting_config=grafting_config,
+                        distributed_config=distributed_config,
+                        **optim_args,
+                    )
+                )
             elif self.args.alternate_optimizer == "optimi_adamw":
                 from optimi import AdamW

@@ -1441,6 +1505,21 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs[
                 "optim_target_modules"
             ] = self.cfg.optim_target_modules
+
+        # shampoo optimizer config
+        if self.cfg.optim_shampoo_betas:
+            training_arguments_kwargs[
+                "optim_shampoo_betas"
+            ] = self.cfg.optim_shampoo_betas
+        if self.cfg.optim_shampoo_grafting_config_type:
+            training_arguments_kwargs[
+                "optim_shampoo_grafting_config_type"
+            ] = self.cfg.optim_shampoo_grafting_config_type
+        if self.cfg.optim_shampoo_grafting_config_kwargs:
+            training_arguments_kwargs[
+                "optim_shampoo_grafting_config_kwargs"
+            ] = self.cfg.optim_shampoo_grafting_config_kwargs
+
         training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
         training_arguments_kwargs[
             "loraplus_lr_embedding"
@@ -1525,10 +1604,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         trainer_kwargs = {}

         if self.cfg.optimizer in [
+            # pylint: disable=duplicate-code
            "optimi_adamw",
            "ao_adamw_4bit",
            "ao_adamw_8bit",
            "ao_adamw_fp8",
+           "shampoo",
         ]:
             # Set default so transformers doesn't throw
             training_arguments_kwargs["optim"] = "adamw_hf"
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 458bacdb1..9154b3230 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -372,6 +372,7 @@ class HyperparametersConfig(BaseModel):
                 "ao_adamw_4bit",
                 "ao_adamw_8bit",
                 "ao_adamw_fp8",
+                "shampoo",
             ],
         ]
     ] = OptimizerNames.ADAMW_HF.value
@@ -384,6 +385,12 @@ class HyperparametersConfig(BaseModel):
             "help": "The target modules to optimize, i.e. the module names that you would like to train."
         },
     )
+    optim_shampoo_grafting_config_type: Optional[
+        Literal["adam", "sgd", "adagrad"]
+    ] = None
+    optim_shampoo_grafting_config_kwargs: Optional[Dict[str, Any]] = None
+    optim_shampoo_betas: Optional[Tuple[float, float]] = None
+
     torchdistx_path: Optional[str] = None
     lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
     lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
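For reference, here is a minimal single-process sketch of the optimizer that the new `shampoo` branch ends up constructing. The import paths and the keyword arguments `lr`, `betas`, `weight_decay`, `use_decoupled_weight_decay`, `grafting_config`, and `distributed_config` are taken directly from the diff above; the `AdamGraftingConfig` fields (`beta2`, `epsilon`) and the toy model are assumptions based on the upstream facebookresearch/optimizers API and are not part of this change.

```python
import torch

# import paths match those used in the new "shampoo" branch above
from distributed_shampoo.distributed_shampoo import DistributedShampoo
from distributed_shampoo.shampoo_types import AdamGraftingConfig

# hypothetical toy model standing in for the wrapped training model
model = torch.nn.Linear(512, 512)

optimizer = DistributedShampoo(
    model.parameters(),
    lr=1e-4,                # corresponds to args.learning_rate
    betas=(0.9, 0.999),     # corresponds to optim_shampoo_betas
    weight_decay=0.01,      # corresponds to args.weight_decay
    use_decoupled_weight_decay=True,
    # grafting_config_type "adam" maps to AdamGraftingConfig; its fields
    # (beta2, epsilon) are assumed from the upstream library and would be
    # supplied via optim_shampoo_grafting_config_kwargs
    grafting_config=AdamGraftingConfig(
        beta2=0.999,
        epsilon=1e-8,
    ),
    # single process: no DDPShampooConfig / FSDPShampooConfig needed
    distributed_config=None,
)

# standard PyTorch step to confirm the optimizer drops into an existing loop
loss = model(torch.randn(8, 512)).sum()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```

In an axolotl config, the same path is selected with `optimizer: shampoo` together with the new `optim_shampoo_betas`, `optim_shampoo_grafting_config_type`, and `optim_shampoo_grafting_config_kwargs` keys added to `HyperparametersConfig` above.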