From 52e6249d2ebfdaf896287efb771a95da28209b37 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 18 Sep 2024 08:16:11 -0700 Subject: [PATCH] additional grafting config types and basic example doc --- docs/optimizers.qmd | 17 +++++++++++++++++ src/axolotl/core/trainer_builder.py | 12 +++++++++++- 2 files changed, 28 insertions(+), 1 deletion(-) create mode 100644 docs/optimizers.qmd diff --git a/docs/optimizers.qmd b/docs/optimizers.qmd new file mode 100644 index 000000000..d59ed43ab --- /dev/null +++ b/docs/optimizers.qmd @@ -0,0 +1,17 @@ +# Optimizers + +## Shampoo + +```yaml +optimizer: shampoo +optim_shampoo_betas: [0.9, 0.999] +optim_args: + epsilon: 1e-12 + max_preconditioner_dim: 8192 + precondition_frequency: 100 + use_decoupled_weight_decay: true + optim_shampoo_grafting_config_type: adam + optim_shampoo_grafting_config_kwargs: + beta2: 0.999 + epsilon: 1e-12 +``` diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index a8c9fca67..43f0256fc 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -479,10 +479,12 @@ class AxolotlTrainer(SchedulerMixin, Trainer): elif self.args.alternate_optimizer == "shampoo": from distributed_shampoo.distributed_shampoo import DistributedShampoo from distributed_shampoo.shampoo_types import ( + AdaGradGraftingConfig, AdamGraftingConfig, CommunicationDType, DDPShampooConfig, FSDPShampooConfig, + SGDGraftingConfig, ) from distributed_shampoo.utils.shampoo_fsdp_utils import ( compile_fsdp_parameter_metadata, @@ -502,10 +504,18 @@ class AxolotlTrainer(SchedulerMixin, Trainer): optim_args.get("use_decoupled_weight_decay") ) - if self.args.optim_shampoo_grafting_config_type in ["adam", "adamw"]: + if self.args.optim_shampoo_grafting_config_type == "adam": grafting_config = AdamGraftingConfig( self.args.optim_shampoo_grafting_config_kwargs ) + elif self.args.optim_shampoo_grafting_config_type == "sgd": + grafting_config = SGDGraftingConfig( + 
**self.args.optim_shampoo_grafting_config_kwargs + ) + elif self.args.optim_shampoo_grafting_config_type == "adagrad": + grafting_config = AdaGradGraftingConfig( + **self.args.optim_shampoo_grafting_config_kwargs + ) distributed_config = None if self.args.world_size > 1: