fixes
This commit is contained in:
@@ -10,8 +10,8 @@ optim_args:
|
||||
max_preconditioner_dim: 8192
|
||||
precondition_frequency: 100
|
||||
use_decoupled_weight_decay: true
|
||||
optim_shampoo_grafting_config_type: adam
|
||||
optim_shampoo_grafting_config_kwargs:
|
||||
beta2: 0.999
|
||||
epsilon: 1e-12
|
||||
optim_shampoo_grafting_config_type: adam
|
||||
optim_shampoo_grafting_config_kwargs:
|
||||
beta2: 0.999
|
||||
epsilon: 1e-12
|
||||
```
|
||||
|
||||
@@ -498,6 +498,8 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
optim_args[key] = value
|
||||
|
||||
optim_args["betas"] = self.args.optim_shampoo_betas
|
||||
if isinstance(optim_args["epsilon"], str):
|
||||
optim_args["epsilon"] = float(optim_args["epsilon"])
|
||||
optim_args["lr"] = self.args.learning_rate
|
||||
optim_args["weight_decay"] = self.args.weight_decay
|
||||
optim_args["use_decoupled_weight_decay"] = bool(
|
||||
@@ -506,15 +508,15 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
||||
|
||||
if self.args.optim_shampoo_grafting_config_type == "adam":
|
||||
grafting_config = AdamGraftingConfig(
|
||||
self.args.optim_shampoo_grafting_config_kwargs
|
||||
**self.args.optim_shampoo_grafting_config_kwargs
|
||||
)
|
||||
elif self.args.optim_shampoo_grafting_config_type == "sgd":
|
||||
grafting_config = SGDGraftingConfig(
|
||||
self.args.optim_shampoo_grafting_config_kwargs
|
||||
**self.args.optim_shampoo_grafting_config_kwargs
|
||||
)
|
||||
elif self.args.optim_shampoo_grafting_config_type == "adagrad":
|
||||
grafting_config = AdaGradGraftingConfig(
|
||||
self.args.optim_shampoo_grafting_config_kwargs
|
||||
**self.args.optim_shampoo_grafting_config_kwargs
|
||||
)
|
||||
|
||||
distributed_config = None
|
||||
|
||||
@@ -776,6 +776,7 @@ class AxolotlInputConfig(
|
||||
data["accelerator_config"]["dispatch_batches"] = False
|
||||
return data
|
||||
|
||||
@model_validator(mode="before")
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_gptq_w_revision(cls, data):
|
||||
|
||||
Reference in New Issue
Block a user