This commit is contained in:
Wing Lian
2024-09-18 10:38:59 -07:00
parent 52e6249d2e
commit 5334d0fc01
3 changed files with 10 additions and 7 deletions

View File

@@ -10,8 +10,8 @@ optim_args:
max_preconditioner_dim: 8192
precondition_frequency: 100
use_decoupled_weight_decay: true
-optim_shampoo_grafting_config_type: adam
-optim_shampoo_grafting_config_kwargs:
-beta2: 0.999
-epsilon: 1e-12
+optim_shampoo_grafting_config_type: adam
+optim_shampoo_grafting_config_kwargs:
+beta2: 0.999
+epsilon: 1e-12
```

View File

@@ -498,6 +498,8 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
optim_args[key] = value
optim_args["betas"] = self.args.optim_shampoo_betas
+if isinstance(optim_args["epsilon"], str):
+optim_args["epsilon"] = float(optim_args["epsilon"])
optim_args["lr"] = self.args.learning_rate
optim_args["weight_decay"] = self.args.weight_decay
optim_args["use_decoupled_weight_decay"] = bool(
@@ -506,15 +508,15 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
if self.args.optim_shampoo_grafting_config_type == "adam":
grafting_config = AdamGraftingConfig(
-self.args.optim_shampoo_grafting_config_kwargs
+**self.args.optim_shampoo_grafting_config_kwargs
)
elif self.args.optim_shampoo_grafting_config_type == "sgd":
grafting_config = SGDGraftingConfig(
-self.args.optim_shampoo_grafting_config_kwargs
+**self.args.optim_shampoo_grafting_config_kwargs
)
elif self.args.optim_shampoo_grafting_config_type == "adagrad":
grafting_config = AdaGradGraftingConfig(
-self.args.optim_shampoo_grafting_config_kwargs
+**self.args.optim_shampoo_grafting_config_kwargs
)
distributed_config = None

View File

@@ -776,6 +776,7 @@ class AxolotlInputConfig(
data["accelerator_config"]["dispatch_batches"] = False
return data
@model_validator(mode="before")
@model_validator(mode="before")
@classmethod
def check_gptq_w_revision(cls, data):