This commit is contained in:
Wing Lian
2024-09-18 10:38:59 -07:00
parent 52e6249d2e
commit 5334d0fc01
3 changed files with 10 additions and 7 deletions

View File

@@ -10,8 +10,8 @@ optim_args:
max_preconditioner_dim: 8192 max_preconditioner_dim: 8192
precondition_frequency: 100 precondition_frequency: 100
use_decoupled_weight_decay: true use_decoupled_weight_decay: true
optim_shampoo_grafting_config_type: adam optim_shampoo_grafting_config_type: adam
optim_shampoo_grafting_config_kwargs: optim_shampoo_grafting_config_kwargs:
beta2: 0.999 beta2: 0.999
epsilon: 1e-12 epsilon: 1e-12
``` ```

View File

@@ -498,6 +498,8 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
optim_args[key] = value optim_args[key] = value
optim_args["betas"] = self.args.optim_shampoo_betas optim_args["betas"] = self.args.optim_shampoo_betas
if isinstance(optim_args["epsilon"], str):
optim_args["epsilon"] = float(optim_args["epsilon"])
optim_args["lr"] = self.args.learning_rate optim_args["lr"] = self.args.learning_rate
optim_args["weight_decay"] = self.args.weight_decay optim_args["weight_decay"] = self.args.weight_decay
optim_args["use_decoupled_weight_decay"] = bool( optim_args["use_decoupled_weight_decay"] = bool(
@@ -506,15 +508,15 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
if self.args.optim_shampoo_grafting_config_type == "adam": if self.args.optim_shampoo_grafting_config_type == "adam":
grafting_config = AdamGraftingConfig( grafting_config = AdamGraftingConfig(
self.args.optim_shampoo_grafting_config_kwargs **self.args.optim_shampoo_grafting_config_kwargs
) )
elif self.args.optim_shampoo_grafting_config_type == "sgd": elif self.args.optim_shampoo_grafting_config_type == "sgd":
grafting_config = SGDGraftingConfig( grafting_config = SGDGraftingConfig(
self.args.optim_shampoo_grafting_config_kwargs **self.args.optim_shampoo_grafting_config_kwargs
) )
elif self.args.optim_shampoo_grafting_config_type == "adagrad": elif self.args.optim_shampoo_grafting_config_type == "adagrad":
grafting_config = AdaGradGraftingConfig( grafting_config = AdaGradGraftingConfig(
self.args.optim_shampoo_grafting_config_kwargs **self.args.optim_shampoo_grafting_config_kwargs
) )
distributed_config = None distributed_config = None

View File

@@ -776,6 +776,7 @@ class AxolotlInputConfig(
data["accelerator_config"]["dispatch_batches"] = False data["accelerator_config"]["dispatch_batches"] = False
return data return data
@model_validator(mode="before")
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_gptq_w_revision(cls, data): def check_gptq_w_revision(cls, data):