Add additional grafting config types and a basic example doc
This commit is contained in:
17
docs/optimizers.qmd
Normal file
17
docs/optimizers.qmd
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
# Optimizers
|
||||||
|
|
||||||
|
## Shampoo
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
optimizer: shampoo
|
||||||
|
optim_shampoo_betas: [0.9, 0.999]
|
||||||
|
optim_args:
|
||||||
|
  epsilon: 1e-12
|
||||||
|
  max_preconditioner_dim: 8192
|
||||||
|
  precondition_frequency: 100
|
||||||
|
  use_decoupled_weight_decay: true
|
||||||
|
optim_shampoo_grafting_config_type: adam
|
||||||
|
optim_shampoo_grafting_config_kwargs:
|
||||||
|
  beta2: 0.999
|
||||||
|
  epsilon: 1e-12
|
||||||
|
```
|
||||||
@@ -479,10 +479,12 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
elif self.args.alternate_optimizer == "shampoo":
|
elif self.args.alternate_optimizer == "shampoo":
|
||||||
from distributed_shampoo.distributed_shampoo import DistributedShampoo
|
from distributed_shampoo.distributed_shampoo import DistributedShampoo
|
||||||
from distributed_shampoo.shampoo_types import (
|
from distributed_shampoo.shampoo_types import (
|
||||||
|
AdaGradGraftingConfig,
|
||||||
AdamGraftingConfig,
|
AdamGraftingConfig,
|
||||||
CommunicationDType,
|
CommunicationDType,
|
||||||
DDPShampooConfig,
|
DDPShampooConfig,
|
||||||
FSDPShampooConfig,
|
FSDPShampooConfig,
|
||||||
|
SGDGraftingConfig,
|
||||||
)
|
)
|
||||||
from distributed_shampoo.utils.shampoo_fsdp_utils import (
|
from distributed_shampoo.utils.shampoo_fsdp_utils import (
|
||||||
compile_fsdp_parameter_metadata,
|
compile_fsdp_parameter_metadata,
|
||||||
@@ -502,10 +504,18 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
optim_args.get("use_decoupled_weight_decay")
|
optim_args.get("use_decoupled_weight_decay")
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.args.optim_shampoo_grafting_config_type in ["adam", "adamw"]:
|
if self.args.optim_shampoo_grafting_config_type == "adam":
|
||||||
grafting_config = AdamGraftingConfig(
|
grafting_config = AdamGraftingConfig(
|
||||||
self.args.optim_shampoo_grafting_config_kwargs
|
self.args.optim_shampoo_grafting_config_kwargs
|
||||||
)
|
)
|
||||||
|
elif self.args.optim_shampoo_grafting_config_type == "sgd":
|
||||||
|
grafting_config = SGDGraftingConfig(
|
||||||
|
self.args.optim_shampoo_grafting_config_kwargs
|
||||||
|
)
|
||||||
|
elif self.args.optim_shampoo_grafting_config_type == "adagrad":
|
||||||
|
grafting_config = AdaGradGraftingConfig(
|
||||||
|
self.args.optim_shampoo_grafting_config_kwargs
|
||||||
|
)
|
||||||
|
|
||||||
distributed_config = None
|
distributed_config = None
|
||||||
if self.args.world_size > 1:
|
if self.args.world_size > 1:
|
||||||
|
|||||||
Reference in New Issue
Block a user