wip shampoo optim support

Wing Lian
2024-09-18 08:10:52 -07:00
parent 7b9f669a3a
commit eb3eab3450
3 changed files with 91 additions and 2 deletions

View File

@@ -35,6 +35,7 @@ python-dotenv==1.0.1
autoawq>=0.2.5
triton>=2.3.0
liger-kernel==0.2.1
distributed_shampoo @ git+https://github.com/facebookresearch/optimizers.git@main
mamba-ssm==1.2.0.post1
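A minimal smoke test for the git-pinned dependency added above, assuming it installs cleanly; the import paths mirror the ones this commit uses in the trainer:

# sanity check that the pinned distributed_shampoo package resolves
from distributed_shampoo.distributed_shampoo import DistributedShampoo
from distributed_shampoo.shampoo_types import AdamGraftingConfig, DDPShampooConfig

print(DistributedShampoo.__name__, AdamGraftingConfig.__name__, DDPShampooConfig.__name__)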

View File

@@ -16,7 +16,7 @@ from collections import defaultdict
from dataclasses import dataclass, field
from functools import wraps
from pathlib import Path
from typing import Any, Dict, List, Literal, Optional, Type, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
import torch
import transformers
@@ -250,6 +250,11 @@ class AxolotlTrainingMixins:
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
},
)
optim_shampoo_grafting_config_type: Optional[
Literal["adam", "sgd", "adagrad"]
] = None
optim_shampoo_grafting_config_kwargs: Optional[Dict[str, Any]] = None
optim_shampoo_betas: Optional[Tuple[float, float]] = None
@dataclass
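For reference, a minimal sketch of the three new training-args fields in isolation; the values below are illustrative only, not defaults introduced by this commit, and the grafting kwargs keys assume the upstream AdamGraftingConfig accepts beta2/epsilon:

from dataclasses import dataclass
from typing import Any, Dict, Literal, Optional, Tuple

@dataclass
class ShampooArgsSketch:
    # stand-in for the fields added to AxolotlTrainingMixins above
    optim_shampoo_grafting_config_type: Optional[Literal["adam", "sgd", "adagrad"]] = None
    optim_shampoo_grafting_config_kwargs: Optional[Dict[str, Any]] = None
    optim_shampoo_betas: Optional[Tuple[float, float]] = None

args = ShampooArgsSketch(
    optim_shampoo_grafting_config_type="adam",
    optim_shampoo_grafting_config_kwargs={"beta2": 0.999, "epsilon": 1e-8},  # assumed grafting kwargs
    optim_shampoo_betas=(0.9, 0.999),
)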
@@ -422,7 +427,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
if (
self.args.loraplus_lr_ratio is None
and self.args.alternate_optimizer
not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
not in [
"optimi_adamw",
"ao_adamw_8bit",
"ao_adamw_4bit",
"ao_adamw_fp8",
"shampoo",
]
):
return super().create_optimizer()
@@ -465,6 +476,59 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
loraplus_lr_ratio,
loraplus_lr_embedding,
)
elif self.args.alternate_optimizer == "shampoo":
from distributed_shampoo.distributed_shampoo import DistributedShampoo
from distributed_shampoo.shampoo_types import (
AdamGraftingConfig,
CommunicationDType,
DDPShampooConfig,
FSDPShampooConfig,
)
from distributed_shampoo.utils.shampoo_fsdp_utils import (
compile_fsdp_parameter_metadata,
)
# parse args.optim_args ("key1=value1,key2=value2") into a dict
optim_args = {}
if self.args.optim_args:
for mapping in self.args.optim_args.replace(" ", "").split(","):
key, value = mapping.split("=")
optim_args[key] = value
optim_args["betas"] = self.args.optim_shampoo_betas
optim_args["lr"] = self.args.learning_rate
optim_args["weight_decay"] = self.args.weight_decay
optim_args["use_decoupled_weight_decay"] = bool(
optim_args.get("use_decoupled_weight_decay")
)
if self.args.optim_shampoo_grafting_config_type in ["adam", "adamw"]:
grafting_config = AdamGraftingConfig(
self.args.optim_shampoo_grafting_config_kwargs
)
distributed_config = None
if self.args.world_size > 1:
if self.args.fsdp_config:
distributed_config = FSDPShampooConfig(
param_to_metadata=compile_fsdp_parameter_metadata(
self.model_wrapped
)
)
else:
distributed_config = DDPShampooConfig(
communication_dtype=CommunicationDType.BFLOAT16,
num_trainers_per_group=self.args.world_size,
communicate_params=False,
)
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
DistributedShampoo(
optimizer_grouped_parameters,
grafting_config=grafting_config,
distributed_config=distributed_config,
**optim_args,
)
)
elif self.args.alternate_optimizer == "optimi_adamw":
from optimi import AdamW
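Putting the branch above together: it parses the "key1=value1,key2=value2" string from args.optim_args, builds an Adam grafting config, and constructs DistributedShampoo. A condensed single-process sketch of the same flow, assuming the distributed_shampoo package from the requirement above; hyperparameters are illustrative and no DDP/FSDP distributed_config is passed:

import torch
from distributed_shampoo.distributed_shampoo import DistributedShampoo
from distributed_shampoo.shampoo_types import AdamGraftingConfig

model = torch.nn.Linear(16, 4)

# same "key1=value1,key2=value2" convention the trainer parses from args.optim_args
optim_args_str = "use_decoupled_weight_decay=true"
optim_args = dict(kv.split("=") for kv in optim_args_str.replace(" ", "").split(","))
optim_args["use_decoupled_weight_decay"] = str(
    optim_args.get("use_decoupled_weight_decay", "")
).lower() in ("true", "1")

optimizer = DistributedShampoo(
    model.parameters(),
    lr=1e-4,  # illustrative hyperparameters, not values set by this commit
    betas=(0.9, 0.999),
    weight_decay=0.01,
    grafting_config=AdamGraftingConfig(beta2=0.999, epsilon=1e-8),
    **optim_args,
)

loss = model(torch.randn(2, 16)).sum()
loss.backward()
optimizer.step()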
@@ -1441,6 +1505,21 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs[
"optim_target_modules"
] = self.cfg.optim_target_modules
# shampoo optimizer config
if self.cfg.optim_shampoo_betas:
training_arguments_kwargs[
"optim_shampoo_betas"
] = self.cfg.optim_shampoo_betas
if self.cfg.optim_shampoo_grafting_config_type:
training_arguments_kwargs[
"optim_shampoo_grafting_config_type"
] = self.cfg.optim_shampoo_grafting_config_type
if self.cfg.optim_shampoo_grafting_config_kwargs:
training_arguments_kwargs[
"optim_shampoo_grafting_config_kwargs"
] = self.cfg.optim_shampoo_grafting_config_kwargs
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
training_arguments_kwargs[
"loraplus_lr_embedding"
@@ -1525,10 +1604,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
trainer_kwargs = {}
if self.cfg.optimizer in [
# pylint: disable=duplicate-code
"optimi_adamw",
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"shampoo",
]:
# Set default so transformers doesn't throw
training_arguments_kwargs["optim"] = "adamw_hf"

View File

@@ -372,6 +372,7 @@ class HyperparametersConfig(BaseModel):
"ao_adamw_4bit",
"ao_adamw_8bit",
"ao_adamw_fp8",
"shampoo",
],
]
] = OptimizerNames.ADAMW_HF.value
@@ -384,6 +385,12 @@ class HyperparametersConfig(BaseModel):
"help": "The target modules to optimize, i.e. the module names that you would like to train."
},
)
optim_shampoo_grafting_config_type: Optional[
Literal["adam", "sgd", "adagrad"]
] = None
optim_shampoo_grafting_config_kwargs: Optional[Dict[str, Any]] = None
optim_shampoo_betas: Optional[Tuple[float, float]] = None
torchdistx_path: Optional[str] = None
lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
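At the config-model layer, the new fields validate like any other options on the pydantic model; a trimmed-down sketch with illustrative values, assuming pydantic v2 as used by these config models:

from typing import Any, Dict, Literal, Optional, Tuple
from pydantic import BaseModel

class ShampooConfigSketch(BaseModel):
    # stand-in for the HyperparametersConfig fields added above
    optimizer: str = "adamw_hf"
    optim_shampoo_grafting_config_type: Optional[Literal["adam", "sgd", "adagrad"]] = None
    optim_shampoo_grafting_config_kwargs: Optional[Dict[str, Any]] = None
    optim_shampoo_betas: Optional[Tuple[float, float]] = None

# roughly what a user config enabling shampoo would deserialize to (illustrative values)
cfg = ShampooConfigSketch(
    optimizer="shampoo",
    optim_shampoo_grafting_config_type="adam",
    optim_shampoo_grafting_config_kwargs={"beta2": 0.999, "epsilon": 1e-8},
    optim_shampoo_betas=(0.9, 0.999),
)
print(cfg.model_dump())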