Compare commits
9 Commits
diffusion-
...
shampoo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
17330c05a3 | ||
|
|
992ea517b7 | ||
|
|
beaee36191 | ||
|
|
69a29382e1 | ||
|
|
84dad0bd12 | ||
|
|
05f61a0ea5 | ||
|
|
5334d0fc01 | ||
|
|
52e6249d2e | ||
|
|
eb3eab3450 |
17
docs/optimizers.qmd
Normal file
17
docs/optimizers.qmd
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
# Optimizers
|
||||||
|
|
||||||
|
## Shampoo
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
optimizer: shampoo
|
||||||
|
optim_shampoo_betas: [0.9, 0.999]
|
||||||
|
optim_args:
|
||||||
|
epsilon: 1e-12
|
||||||
|
max_preconditioner_dim: 8192
|
||||||
|
precondition_frequency: 100
|
||||||
|
use_decoupled_weight_decay: true
|
||||||
|
optim_shampoo_grafting_config_type: adam
|
||||||
|
optim_shampoo_grafting_config_kwargs:
|
||||||
|
beta2: 0.999
|
||||||
|
epsilon: 1e-12
|
||||||
|
```
|
||||||
@@ -35,6 +35,7 @@ python-dotenv==1.0.1
|
|||||||
autoawq>=0.2.5
|
autoawq>=0.2.5
|
||||||
triton>=2.3.0
|
triton>=2.3.0
|
||||||
liger-kernel==0.2.1
|
liger-kernel==0.2.1
|
||||||
|
distributed_shampoo @ git+https://github.com/facebookresearch/optimizers.git@main
|
||||||
|
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ from collections import defaultdict
|
|||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, List, Literal, Optional, Type, Union
|
from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
@@ -250,6 +250,11 @@ class AxolotlTrainingMixins:
|
|||||||
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
|
"help": "workaround to pass an alternate lr scheduler to the HF trainer"
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
optim_shampoo_grafting_config_type: Optional[
|
||||||
|
Literal["adam", "sgd", "adagrad"]
|
||||||
|
] = None
|
||||||
|
optim_shampoo_grafting_config_kwargs: Optional[Dict[str, Any]] = None
|
||||||
|
optim_shampoo_betas: Optional[Tuple[float, float]] = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -422,7 +427,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
if (
|
if (
|
||||||
self.args.loraplus_lr_ratio is None
|
self.args.loraplus_lr_ratio is None
|
||||||
and self.args.alternate_optimizer
|
and self.args.alternate_optimizer
|
||||||
not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
|
not in [
|
||||||
|
"optimi_adamw",
|
||||||
|
"ao_adamw_8bit",
|
||||||
|
"ao_adamw_4bit",
|
||||||
|
"ao_adamw_fp8",
|
||||||
|
"shampoo",
|
||||||
|
]
|
||||||
):
|
):
|
||||||
return super().create_optimizer()
|
return super().create_optimizer()
|
||||||
|
|
||||||
@@ -465,6 +476,102 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
loraplus_lr_ratio,
|
loraplus_lr_ratio,
|
||||||
loraplus_lr_embedding,
|
loraplus_lr_embedding,
|
||||||
)
|
)
|
||||||
|
elif self.args.alternate_optimizer == "shampoo":
|
||||||
|
from distributed_shampoo.distributed_shampoo import DistributedShampoo
|
||||||
|
from distributed_shampoo.shampoo_types import (
|
||||||
|
AdaGradGraftingConfig,
|
||||||
|
AdamGraftingConfig,
|
||||||
|
CommunicationDType,
|
||||||
|
DDPShampooConfig,
|
||||||
|
FSDPShampooConfig,
|
||||||
|
PrecisionConfig,
|
||||||
|
SGDGraftingConfig,
|
||||||
|
)
|
||||||
|
from distributed_shampoo.utils.shampoo_fsdp_utils import (
|
||||||
|
compile_fsdp_parameter_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
# parse args.optim_args
|
||||||
|
optim_args = {}
|
||||||
|
if self.args.optim_args:
|
||||||
|
for mapping in self.args.optim_args.replace(" ", "").split(","):
|
||||||
|
key, value = mapping.split("=")
|
||||||
|
optim_args[key] = value
|
||||||
|
|
||||||
|
optim_args["betas"] = self.args.optim_shampoo_betas
|
||||||
|
if "max_preconditioner_dim" in optim_args:
|
||||||
|
optim_args["max_preconditioner_dim"] = int(
|
||||||
|
optim_args["max_preconditioner_dim"]
|
||||||
|
)
|
||||||
|
if "precondition_frequency" in optim_args:
|
||||||
|
optim_args["precondition_frequency"] = int(
|
||||||
|
optim_args["precondition_frequency"]
|
||||||
|
)
|
||||||
|
if "use_decoupled_weight_decay" in optim_args:
|
||||||
|
optim_args["use_decoupled_weight_decay"] = bool(
|
||||||
|
optim_args["use_decoupled_weight_decay"]
|
||||||
|
)
|
||||||
|
if isinstance(optim_args["epsilon"], str):
|
||||||
|
optim_args["epsilon"] = float(optim_args["epsilon"])
|
||||||
|
optim_args["lr"] = self.args.learning_rate
|
||||||
|
optim_args["weight_decay"] = self.args.weight_decay
|
||||||
|
|
||||||
|
if "epsilon" in self.args.optim_shampoo_grafting_config_kwargs:
|
||||||
|
if isinstance(
|
||||||
|
self.args.optim_shampoo_grafting_config_kwargs["epsilon"], str
|
||||||
|
):
|
||||||
|
self.args.optim_shampoo_grafting_config_kwargs[
|
||||||
|
"epsilon"
|
||||||
|
] = float(
|
||||||
|
self.args.optim_shampoo_grafting_config_kwargs["epsilon"]
|
||||||
|
)
|
||||||
|
if self.args.optim_shampoo_grafting_config_type == "adam":
|
||||||
|
grafting_config = AdamGraftingConfig(
|
||||||
|
**self.args.optim_shampoo_grafting_config_kwargs
|
||||||
|
)
|
||||||
|
elif self.args.optim_shampoo_grafting_config_type == "sgd":
|
||||||
|
grafting_config = SGDGraftingConfig(
|
||||||
|
**self.args.optim_shampoo_grafting_config_kwargs
|
||||||
|
)
|
||||||
|
elif self.args.optim_shampoo_grafting_config_type == "adagrad":
|
||||||
|
grafting_config = AdaGradGraftingConfig(
|
||||||
|
**self.args.optim_shampoo_grafting_config_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
|
distributed_config = None
|
||||||
|
if self.args.world_size > 1:
|
||||||
|
if self.args.fsdp and self.args.fsdp_config:
|
||||||
|
distributed_config = FSDPShampooConfig(
|
||||||
|
param_to_metadata=compile_fsdp_parameter_metadata(
|
||||||
|
self.model_wrapped
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
distributed_config = DDPShampooConfig(
|
||||||
|
communication_dtype=CommunicationDType.BF16,
|
||||||
|
num_trainers_per_group=self.args.world_size,
|
||||||
|
communicate_params=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
precision_config = None
|
||||||
|
if self.args.bf16:
|
||||||
|
precision_config = PrecisionConfig(
|
||||||
|
computation_dtype=torch.bfloat16,
|
||||||
|
factor_matrix_dtype=torch.bfloat16,
|
||||||
|
inv_factor_matrix_dtype=torch.bfloat16,
|
||||||
|
filtered_grad_dtype=torch.bfloat16,
|
||||||
|
momentum_dtype=torch.bfloat16,
|
||||||
|
grafting_state_dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
|
||||||
|
DistributedShampoo(
|
||||||
|
optimizer_grouped_parameters,
|
||||||
|
grafting_config=grafting_config,
|
||||||
|
distributed_config=distributed_config,
|
||||||
|
precision_config=precision_config,
|
||||||
|
**optim_args,
|
||||||
|
)
|
||||||
|
)
|
||||||
elif self.args.alternate_optimizer == "optimi_adamw":
|
elif self.args.alternate_optimizer == "optimi_adamw":
|
||||||
from optimi import AdamW
|
from optimi import AdamW
|
||||||
|
|
||||||
@@ -870,7 +977,11 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
run_dir = self._get_output_dir(trial=trial)
|
run_dir = self._get_output_dir(trial=trial)
|
||||||
output_dir = os.path.join(run_dir, checkpoint_folder)
|
output_dir = os.path.join(run_dir, checkpoint_folder)
|
||||||
os.makedirs(output_dir, exist_ok=True)
|
os.makedirs(output_dir, exist_ok=True)
|
||||||
return super()._save_checkpoint(model, trial, metrics=metrics)
|
try:
|
||||||
|
return super()._save_checkpoint(model, trial, metrics=metrics)
|
||||||
|
except NotImplementedError as exc:
|
||||||
|
LOG.warning(f"Failed to save checkpoint: {exc}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class AxolotlMambaTrainer(AxolotlTrainer):
|
class AxolotlMambaTrainer(AxolotlTrainer):
|
||||||
@@ -1441,6 +1552,21 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"optim_target_modules"
|
"optim_target_modules"
|
||||||
] = self.cfg.optim_target_modules
|
] = self.cfg.optim_target_modules
|
||||||
|
|
||||||
|
# shampoo optimizer config
|
||||||
|
if self.cfg.optim_shampoo_betas:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"optim_shampoo_betas"
|
||||||
|
] = self.cfg.optim_shampoo_betas
|
||||||
|
if self.cfg.optim_shampoo_grafting_config_type:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"optim_shampoo_grafting_config_type"
|
||||||
|
] = self.cfg.optim_shampoo_grafting_config_type
|
||||||
|
if self.cfg.optim_shampoo_grafting_config_kwargs:
|
||||||
|
training_arguments_kwargs[
|
||||||
|
"optim_shampoo_grafting_config_kwargs"
|
||||||
|
] = self.cfg.optim_shampoo_grafting_config_kwargs
|
||||||
|
|
||||||
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
|
||||||
training_arguments_kwargs[
|
training_arguments_kwargs[
|
||||||
"loraplus_lr_embedding"
|
"loraplus_lr_embedding"
|
||||||
@@ -1525,10 +1651,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
trainer_kwargs = {}
|
trainer_kwargs = {}
|
||||||
|
|
||||||
if self.cfg.optimizer in [
|
if self.cfg.optimizer in [
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
"optimi_adamw",
|
"optimi_adamw",
|
||||||
"ao_adamw_4bit",
|
"ao_adamw_4bit",
|
||||||
"ao_adamw_8bit",
|
"ao_adamw_8bit",
|
||||||
"ao_adamw_fp8",
|
"ao_adamw_fp8",
|
||||||
|
"shampoo",
|
||||||
]:
|
]:
|
||||||
# Set default so transformers doesn't throw
|
# Set default so transformers doesn't throw
|
||||||
training_arguments_kwargs["optim"] = "adamw_hf"
|
training_arguments_kwargs["optim"] = "adamw_hf"
|
||||||
|
|||||||
@@ -372,6 +372,7 @@ class HyperparametersConfig(BaseModel):
|
|||||||
"ao_adamw_4bit",
|
"ao_adamw_4bit",
|
||||||
"ao_adamw_8bit",
|
"ao_adamw_8bit",
|
||||||
"ao_adamw_fp8",
|
"ao_adamw_fp8",
|
||||||
|
"shampoo",
|
||||||
],
|
],
|
||||||
]
|
]
|
||||||
] = OptimizerNames.ADAMW_HF.value
|
] = OptimizerNames.ADAMW_HF.value
|
||||||
@@ -384,6 +385,12 @@ class HyperparametersConfig(BaseModel):
|
|||||||
"help": "The target modules to optimize, i.e. the module names that you would like to train."
|
"help": "The target modules to optimize, i.e. the module names that you would like to train."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
optim_shampoo_grafting_config_type: Optional[
|
||||||
|
Literal["adam", "sgd", "adagrad"]
|
||||||
|
] = None
|
||||||
|
optim_shampoo_grafting_config_kwargs: Optional[Dict[str, Any]] = None
|
||||||
|
optim_shampoo_betas: Optional[Tuple[float, float]] = None
|
||||||
|
|
||||||
torchdistx_path: Optional[str] = None
|
torchdistx_path: Optional[str] = None
|
||||||
lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
|
lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
|
||||||
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
|
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
|
||||||
|
|||||||
Reference in New Issue
Block a user