additional grafting config types and basic example doc

This commit is contained in:
Wing Lian
2024-09-18 08:16:11 -07:00
parent eb3eab3450
commit 52e6249d2e
2 changed files with 28 additions and 1 deletions

17
docs/optimizers.qmd Normal file
View File

@@ -0,0 +1,17 @@
# Optimizers
## Shampoo
```yaml
optimizer: shampoo
optim_shampoo_betas: [0.9, 0.999]
optim_args:
  epsilon: 1e-12
  max_preconditioner_dim: 8192
  precondition_frequency: 100
  use_decoupled_weight_decay: true
optim_shampoo_grafting_config_type: adam
optim_shampoo_grafting_config_kwargs:
  beta2: 0.999
  epsilon: 1e-12
```

View File

@@ -479,10 +479,12 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
elif self.args.alternate_optimizer == "shampoo":
from distributed_shampoo.distributed_shampoo import DistributedShampoo
from distributed_shampoo.shampoo_types import (
AdaGradGraftingConfig,
AdamGraftingConfig,
CommunicationDType,
DDPShampooConfig,
FSDPShampooConfig,
SGDGraftingConfig,
)
from distributed_shampoo.utils.shampoo_fsdp_utils import (
compile_fsdp_parameter_metadata,
@@ -502,10 +504,18 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
optim_args.get("use_decoupled_weight_decay")
)
if self.args.optim_shampoo_grafting_config_type in ["adam", "adamw"]:
if self.args.optim_shampoo_grafting_config_type == "adam":
grafting_config = AdamGraftingConfig(
self.args.optim_shampoo_grafting_config_kwargs
)
elif self.args.optim_shampoo_grafting_config_type == "sgd":
grafting_config = SGDGraftingConfig(
self.args.optim_shampoo_grafting_config_kwargs
)
elif self.args.optim_shampoo_grafting_config_type == "adagrad":
grafting_config = AdaGradGraftingConfig(
self.args.optim_shampoo_grafting_config_kwargs
)
distributed_config = None
if self.args.world_size > 1: