WIP: add Shampoo optimizer support
This commit is contained in:
@@ -35,6 +35,7 @@ python-dotenv==1.0.1
|
|||||||
autoawq>=0.2.5
|
autoawq>=0.2.5
|
||||||
triton>=2.3.0
|
triton>=2.3.0
|
||||||
liger-kernel==0.2.1
|
liger-kernel==0.2.1
|
||||||
|
distributed_shampoo @ git+https://github.com/facebookresearch/optimizers.git@main
|
||||||
|
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from collections import defaultdict
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
@@ -250,6 +250,11 @@ class AxolotlTrainingMixins:
|
|||||||
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
|
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
optim_shampoo_grafting_config_type: Optional[
|
||||||
|
Literal["adam", "sgd", "adagrad"]
|
||||||
|
] = None
|
||||||
|
optim_shampoo_grafting_config_kwargs: Optional[Dict[str, Any]] = None
|
||||||
|
optim_shampoo_betas: Optional[Tuple[float, float]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -422,7 +427,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
if (
|
if (
|
||||||
self.args.loraplus_lr_ratio is None
|
self.args.loraplus_lr_ratio is None
|
||||||
and self.args.alternate_optimizer
|
and self.args.alternate_optimizer
|
||||||
not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
|
not in [
|
||||||
|
"optimi_adamw",
|
||||||
|
"ao_adamw_8bit",
|
||||||
|
"ao_adamw_4bit",
|
||||||
|
"ao_adamw_fp8",
|
||||||
|
"shampoo",
|
||||||
|
]
|
||||||
):
|
):
|
||||||
return super().create_optimizer()
|
return super().create_optimizer()
|
||||||
|
|
||||||
@@ -465,6 +476,59 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
loraplus_lr_ratio,
|
loraplus_lr_ratio,
|
||||||
loraplus_lr_embedding,
|
loraplus_lr_embedding,
|
||||||
)
|
)
|
||||||
|
elif self.args.alternate_optimizer == "shampoo":
|
||||||
|
from distributed_shampoo.distributed_shampoo import DistributedShampoo
|
||||||
|
from distributed_shampoo.shampoo_types import (
|
||||||
|
AdamGraftingConfig,
|
||||||
|
CommunicationDType,
|
||||||
|
DDPShampooConfig,
|
||||||
|
FSDPShampooConfig,
|
||||||
|
)
|
||||||
|
from distributed_shampoo.utils.shampoo_fsdp_utils import (
|
||||||
|
compile_fsdp_parameter_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
# parse args.optim_args ("key=value,key=value" pairs); parsed values are kept as raw strings
|
||||||
|
optim_args = {}
|
||||||
|
if self.args.optim_args:
|
||||||
|
for mapping in self.args.optim_args.replace(" ", "").split(","):
|
||||||
|
key, value = mapping.split("=")
|
||||||
|
optim_args[key] = value
|
||||||
|
|
||||||
|
optim_args["betas"] = self.args.optim_shampoo_betas
|
||||||
|
optim_args["lr"] = self.args.learning_rate
|
||||||
|
optim_args["weight_decay"] = self.args.weight_decay
|
||||||
|
optim_args["use_decoupled_weight_decay"] = bool(
|
||||||
|
optim_args.get("use_decoupled_weight_decay")
|
||||||
|
)
|
||||||
|
|
||||||
|
if self.args.optim_shampoo_grafting_config_type in ["adam", "adamw"]:
|
||||||
|
grafting_config = AdamGraftingConfig(
|
||||||
|
self.args.optim_shampoo_grafting_config_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
distributed_config = None
|
||||||
|
if self.args.world_size > 1:
|
||||||
|
if self.args.fsdp_config:
|
||||||
|
distributed_config = FSDPShampooConfig(
|
||||||
|
param_to_metadata=compile_fsdp_parameter_metadata(
|
||||||
|
self.model_wrapped
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
distributed_config = DDPShampooConfig(
|
||||||
|
communication_dtype=CommunicationDType.BFLOAT16,
|
||||||
|
num_trainers_per_group=self.args.world_size,
|
||||||
|
communicate_params=False,
|
||||||
|
)
|
||||||
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
DistributedShampoo(
|
||||||
|
optimizer_grouped_parameters,
|
||||||
|
grafting_config=grafting_config,
|
||||||
|
distributed_config=distributed_config,
|
||||||
|
**optim_args,
|
||||||
|
)
|
||||||
|
)
|
||||||
elif self.args.alternate_optimizer == "optimi_adamw":
|
elif self.args.alternate_optimizer == "optimi_adamw":
|
||||||
from optimi import AdamW
|
from optimi import AdamW
|
||||||
|
|
||||||
@@ -1441,6 +1505,21 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"optim_target_modules"
|
"optim_target_modules"
|
||||||
] = self.cfg.optim_target_modules
|
] = self.cfg.optim_target_modules
|
||||||
|
|
||||||
|
# shampoo optimizer config
|
||||||
|
if self.cfg.optim_shampoo_betas:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"optim_shampoo_betas"
|
||||||
|
] = self.cfg.optim_shampoo_betas
|
||||||
|
if self.cfg.optim_shampoo_grafting_config_type:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"optim_shampoo_grafting_config_type"
|
||||||
|
] = self.cfg.optim_shampoo_grafting_config_type
|
||||||
|
if self.cfg.optim_shampoo_grafting_config_kwargs:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"optim_shampoo_grafting_config_kwargs"
|
||||||
|
] = self.cfg.optim_shampoo_grafting_config_kwargs
|
||||||
|
|
||||||
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"loraplus_lr_embedding"
|
"loraplus_lr_embedding"
|
||||||
@@ -1525,10 +1604,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
trainer_kwargs = {}
|
trainer_kwargs = {}
|
||||||
|
|
||||||
if self.cfg.optimizer in [
|
if self.cfg.optimizer in [
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
"optimi_adamw",
|
"optimi_adamw",
|
||||||
"ao_adamw_4bit",
|
"ao_adamw_4bit",
|
||||||
"ao_adamw_8bit",
|
"ao_adamw_8bit",
|
||||||
"ao_adamw_fp8",
|
"ao_adamw_fp8",
|
||||||
|
"shampoo",
|
||||||
]:
|
]:
|
||||||
# Set default so transformers doesn't throw
|
# Set default so transformers doesn't throw
|
||||||
training_arguments_kwargs["optim"] = "adamw_hf"
|
training_arguments_kwargs["optim"] = "adamw_hf"
|
||||||
|
|||||||
@@ -372,6 +372,7 @@ class HyperparametersConfig(BaseModel):
|
|||||||
"ao_adamw_4bit",
|
"ao_adamw_4bit",
|
||||||
"ao_adamw_8bit",
|
"ao_adamw_8bit",
|
||||||
"ao_adamw_fp8",
|
"ao_adamw_fp8",
|
||||||
|
"shampoo",
|
||||||
],
|
],
|
||||||
]
|
]
|
||||||
] = OptimizerNames.ADAMW_HF.value
|
] = OptimizerNames.ADAMW_HF.value
|
||||||
@@ -384,6 +385,12 @@ class HyperparametersConfig(BaseModel):
|
|||||||
"help": "The target modules to optimize, i.e. the module names that you would like to train."
|
"help": "The target modules to optimize, i.e. the module names that you would like to train."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
optim_shampoo_grafting_config_type: Optional[
|
||||||
|
Literal["adam", "sgd", "adagrad"]
|
||||||
|
] = None
|
||||||
|
optim_shampoo_grafting_config_kwargs: Optional[Dict[str, Any]] = None
|
||||||
|
optim_shampoo_betas: Optional[Tuple[float, float]] = None
|
||||||
|
|
||||||
torchdistx_path: Optional[str] = None
|
torchdistx_path: Optional[str] = None
|
||||||
lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
|
lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
|
||||||
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
|
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
|
||||||
|
|||||||
Reference in New Issue
Block a user