Compare commits

...

9 Commits

Author      SHA1        Message                                                  Date
Wing Lian   17330c05a3  shampoo checkpoint save workaround                       2024-09-23 15:21:00 -04:00
Wing Lian   992ea517b7  setup precision config for bf16                          2024-09-18 12:01:38 -07:00
Wing Lian   beaee36191  ddp shampoo                                              2024-09-18 10:50:46 -07:00
Wing Lian   69a29382e1  fix casting of optim args                                2024-09-18 10:48:15 -07:00
Wing Lian   84dad0bd12  ensure epsilon is cast to float                          2024-09-18 10:42:26 -07:00
Wing Lian   05f61a0ea5  remove accidental duplicated line                        2024-09-18 10:41:03 -07:00
Wing Lian   5334d0fc01  fixes                                                    2024-09-18 10:38:59 -07:00
Wing Lian   52e6249d2e  additional grafting config types and basic example doc   2024-09-18 08:16:11 -07:00
Wing Lian   eb3eab3450  wip shampoo optim support                                2024-09-18 08:10:52 -07:00
4 changed files with 156 additions and 3 deletions

docs/optimizers.qmd (new file, +17)

@@ -0,0 +1,17 @@
# Optimizers

## Shampoo

```yaml
optimizer: shampoo
optim_shampoo_betas: [0.9, 0.999]
optim_args:
epsilon: 1e-12
max_preconditioner_dim: 8192
precondition_frequency: 100
use_decoupled_weight_decay: true
optim_shampoo_grafting_config_type: adam
optim_shampoo_grafting_config_kwargs:
beta2: 0.999
epsilon: 1e-12
```
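
For orientation, a minimal standalone sketch of roughly what this configuration resolves to at optimizer-construction time, following the `create_optimizer` changes further down. The model, `lr`, and `weight_decay` here are placeholders taken from the trainer's own arguments, not from this YAML:

```python
# Sketch only; mirrors the YAML example above against the distributed_shampoo
# API that the trainer imports below.
import torch
from distributed_shampoo.distributed_shampoo import DistributedShampoo
from distributed_shampoo.shampoo_types import AdamGraftingConfig

model = torch.nn.Linear(16, 4)  # placeholder model
optimizer = DistributedShampoo(
    model.parameters(),
    lr=1e-4,                      # trainer learning_rate (placeholder)
    weight_decay=0.0,             # trainer weight_decay (placeholder)
    betas=(0.9, 0.999),           # optim_shampoo_betas
    epsilon=1e-12,                # optim_args: epsilon
    max_preconditioner_dim=8192,  # optim_args: max_preconditioner_dim
    precondition_frequency=100,   # optim_args: precondition_frequency
    use_decoupled_weight_decay=True,
    grafting_config=AdamGraftingConfig(beta2=0.999, epsilon=1e-12),
)
```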

requirements.txt

@@ -35,6 +35,7 @@ python-dotenv==1.0.1
 autoawq>=0.2.5
 triton>=2.3.0
 liger-kernel==0.2.1
+distributed_shampoo @ git+https://github.com/facebookresearch/optimizers.git@main
 mamba-ssm==1.2.0.post1

src/axolotl/core/trainer_builder.py

@@ -16,7 +16,7 @@ from collections import defaultdict
 from dataclasses import dataclass, field
 from functools import wraps
 from pathlib import Path
-from typing import Any, Dict, List, Literal, Optional, Type, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union

 import torch
 import transformers
@@ -250,6 +250,11 @@ class AxolotlTrainingMixins:
            "help": "workaround to pass an alternate lr scheduler to the HF trainer"
        },
    )
+    optim_shampoo_grafting_config_type: Optional[
+        Literal["adam", "sgd", "adagrad"]
+    ] = None
+    optim_shampoo_grafting_config_kwargs: Optional[Dict[str, Any]] = None
+    optim_shampoo_betas: Optional[Tuple[float, float]] = None


 @dataclass
@@ -422,7 +427,13 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         if (
             self.args.loraplus_lr_ratio is None
             and self.args.alternate_optimizer
-            not in ["optimi_adamw", "ao_adamw_8bit", "ao_adamw_4bit", "ao_adamw_fp8"]
+            not in [
+                "optimi_adamw",
+                "ao_adamw_8bit",
+                "ao_adamw_4bit",
+                "ao_adamw_fp8",
+                "shampoo",
+            ]
         ):
             return super().create_optimizer()
@@ -465,6 +476,102 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
                 loraplus_lr_ratio,
                 loraplus_lr_embedding,
             )
+        elif self.args.alternate_optimizer == "shampoo":
+            from distributed_shampoo.distributed_shampoo import DistributedShampoo
+            from distributed_shampoo.shampoo_types import (
+                AdaGradGraftingConfig,
+                AdamGraftingConfig,
+                CommunicationDType,
+                DDPShampooConfig,
+                FSDPShampooConfig,
+                PrecisionConfig,
+                SGDGraftingConfig,
+            )
+            from distributed_shampoo.utils.shampoo_fsdp_utils import (
+                compile_fsdp_parameter_metadata,
+            )
+
+            # parse args.optim_args
+            optim_args = {}
+            if self.args.optim_args:
+                for mapping in self.args.optim_args.replace(" ", "").split(","):
+                    key, value = mapping.split("=")
+                    optim_args[key] = value
+            optim_args["betas"] = self.args.optim_shampoo_betas
if "max_preconditioner_dim" in optim_args:
optim_args["max_preconditioner_dim"] = int(
optim_args["max_preconditioner_dim"]
)
if "precondition_frequency" in optim_args:
optim_args["precondition_frequency"] = int(
optim_args["precondition_frequency"]
)
if "use_decoupled_weight_decay" in optim_args:
optim_args["use_decoupled_weight_decay"] = bool(
optim_args["use_decoupled_weight_decay"]
)
if isinstance(optim_args["epsilon"], str):
optim_args["epsilon"] = float(optim_args["epsilon"])
optim_args["lr"] = self.args.learning_rate
optim_args["weight_decay"] = self.args.weight_decay
if "epsilon" in self.args.optim_shampoo_grafting_config_kwargs:
if isinstance(
self.args.optim_shampoo_grafting_config_kwargs["epsilon"], str
):
self.args.optim_shampoo_grafting_config_kwargs[
"epsilon"
] = float(
self.args.optim_shampoo_grafting_config_kwargs["epsilon"]
)
+            grafting_config = None  # grafting type may be unset in the config
+            if self.args.optim_shampoo_grafting_config_type == "adam":
+                grafting_config = AdamGraftingConfig(
+                    **(self.args.optim_shampoo_grafting_config_kwargs or {})
+                )
+            elif self.args.optim_shampoo_grafting_config_type == "sgd":
+                grafting_config = SGDGraftingConfig(
+                    **(self.args.optim_shampoo_grafting_config_kwargs or {})
+                )
+            elif self.args.optim_shampoo_grafting_config_type == "adagrad":
+                grafting_config = AdaGradGraftingConfig(
+                    **(self.args.optim_shampoo_grafting_config_kwargs or {})
+                )
+            distributed_config = None
+            if self.args.world_size > 1:
+                if self.args.fsdp and self.args.fsdp_config:
+                    distributed_config = FSDPShampooConfig(
+                        param_to_metadata=compile_fsdp_parameter_metadata(
+                            self.model_wrapped
+                        )
+                    )
+                else:
+                    distributed_config = DDPShampooConfig(
+                        communication_dtype=CommunicationDType.BF16,
+                        num_trainers_per_group=self.args.world_size,
+                        communicate_params=False,
+                    )
+
+            precision_config = None
+            if self.args.bf16:
+                precision_config = PrecisionConfig(
+                    computation_dtype=torch.bfloat16,
+                    factor_matrix_dtype=torch.bfloat16,
+                    inv_factor_matrix_dtype=torch.bfloat16,
+                    filtered_grad_dtype=torch.bfloat16,
+                    momentum_dtype=torch.bfloat16,
+                    grafting_state_dtype=torch.bfloat16,
+                )
+
+            self.optimizer = (  # pylint: disable=attribute-defined-outside-init
+                DistributedShampoo(
+                    optimizer_grouped_parameters,
+                    grafting_config=grafting_config,
+                    distributed_config=distributed_config,
+                    precision_config=precision_config,
+                    **optim_args,
+                )
+            )
         elif self.args.alternate_optimizer == "optimi_adamw":
             from optimi import AdamW
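
Design note on the branches above: `num_trainers_per_group=self.args.world_size` places every rank in a single Shampoo group, and `CommunicationDType.BF16` with `communicate_params=False` presumably trades some communication precision for bandwidth by exchanging bf16 gradient/search-direction tensors rather than full updated parameters; the FSDP branch instead appears to need the per-parameter metadata from `compile_fsdp_parameter_metadata` to work with FSDP's flattened parameter shards.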
@@ -870,7 +977,11 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
         run_dir = self._get_output_dir(trial=trial)
         output_dir = os.path.join(run_dir, checkpoint_folder)
         os.makedirs(output_dir, exist_ok=True)
-        return super()._save_checkpoint(model, trial, metrics=metrics)
+        try:
+            return super()._save_checkpoint(model, trial, metrics=metrics)
+        except NotImplementedError as exc:
+            LOG.warning(f"Failed to save checkpoint: {exc}")
+            return None


 class AxolotlMambaTrainer(AxolotlTrainer):
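
This try/except pairs with commit 17330c05a3 ("shampoo checkpoint save workaround"): DistributedShampoo does not appear to support the standard optimizer `state_dict` path that `Trainer._save_checkpoint` relies on (it ships its own distributed state-dict API instead), so without the guard a checkpoint save would abort training with `NotImplementedError`; the workaround logs a warning and skips that save.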
@@ -1441,6 +1552,21 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
             training_arguments_kwargs[
                 "optim_target_modules"
             ] = self.cfg.optim_target_modules
+        # shampoo optimizer config
+        if self.cfg.optim_shampoo_betas:
+            training_arguments_kwargs[
+                "optim_shampoo_betas"
+            ] = self.cfg.optim_shampoo_betas
+        if self.cfg.optim_shampoo_grafting_config_type:
+            training_arguments_kwargs[
+                "optim_shampoo_grafting_config_type"
+            ] = self.cfg.optim_shampoo_grafting_config_type
+        if self.cfg.optim_shampoo_grafting_config_kwargs:
+            training_arguments_kwargs[
+                "optim_shampoo_grafting_config_kwargs"
+            ] = self.cfg.optim_shampoo_grafting_config_kwargs
+
         training_arguments_kwargs["loraplus_lr_ratio"] = self.cfg.loraplus_lr_ratio
         training_arguments_kwargs[
             "loraplus_lr_embedding"
@@ -1525,10 +1651,12 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
         trainer_kwargs = {}

         if self.cfg.optimizer in [
+            # pylint: disable=duplicate-code
             "optimi_adamw",
             "ao_adamw_4bit",
             "ao_adamw_8bit",
             "ao_adamw_fp8",
+            "shampoo",
         ]:
             # Set default so transformers doesn't throw
             training_arguments_kwargs["optim"] = "adamw_hf"

src/axolotl/utils/config/models/input/v0_4_1/__init__.py

@@ -372,6 +372,7 @@ class HyperparametersConfig(BaseModel):
                "ao_adamw_4bit",
                "ao_adamw_8bit",
                "ao_adamw_fp8",
+                "shampoo",
            ],
        ]
    ] = OptimizerNames.ADAMW_HF.value
@@ -384,6 +385,12 @@ class HyperparametersConfig(BaseModel):
            "help": "The target modules to optimize, i.e. the module names that you would like to train."
        },
    )
+    optim_shampoo_grafting_config_type: Optional[
+        Literal["adam", "sgd", "adagrad"]
+    ] = None
+    optim_shampoo_grafting_config_kwargs: Optional[Dict[str, Any]] = None
+    optim_shampoo_betas: Optional[Tuple[float, float]] = None
+
    torchdistx_path: Optional[str] = None
    lr_scheduler: Optional[Union[SchedulerType, Literal["one_cycle"]]] = "cosine"
    lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
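
To see how the three new `HyperparametersConfig` fields validate a user config, a self-contained sketch; the class name here is a stand-in for illustration, while the field names and types match the diff above:

```python
from typing import Any, Dict, Literal, Optional, Tuple

from pydantic import BaseModel


class ShampooFieldsSketch(BaseModel):
    """Stand-in model carrying only the fields added in this diff."""

    optim_shampoo_grafting_config_type: Optional[
        Literal["adam", "sgd", "adagrad"]
    ] = None
    optim_shampoo_grafting_config_kwargs: Optional[Dict[str, Any]] = None
    optim_shampoo_betas: Optional[Tuple[float, float]] = None


cfg = ShampooFieldsSketch(
    optim_shampoo_grafting_config_type="adam",
    optim_shampoo_grafting_config_kwargs={"beta2": 0.999, "epsilon": 1e-12},
    optim_shampoo_betas=(0.9, 0.999),
)
# An unknown grafting type such as "rmsprop" would raise a ValidationError.
```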