fixes
This commit is contained in:
@@ -10,8 +10,8 @@ optim_args:
|
|||||||
max_preconditioner_dim: 8192
|
max_preconditioner_dim: 8192
|
||||||
precondition_frequency: 100
|
precondition_frequency: 100
|
||||||
use_decoupled_weight_decay: true
|
use_decoupled_weight_decay: true
|
||||||
optim_shampoo_grafting_config_type: adam
|
optim_shampoo_grafting_config_type: adam
|
||||||
optim_shampoo_grafting_config_kwargs:
|
optim_shampoo_grafting_config_kwargs:
|
||||||
beta2: 0.999
|
beta2: 0.999
|
||||||
epsilon: 1e-12
|
epsilon: 1e-12
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -498,6 +498,8 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
optim_args[key] = value
|
optim_args[key] = value
|
||||||
|
|
||||||
optim_args["betas"] = self.args.optim_shampoo_betas
|
optim_args["betas"] = self.args.optim_shampoo_betas
|
||||||
|
if isinstance(optim_args["epsilon"], str):
|
||||||
|
optim_args["epsilon"] = float(optim_args["epsilon"])
|
||||||
optim_args["lr"] = self.args.learning_rate
|
optim_args["lr"] = self.args.learning_rate
|
||||||
optim_args["weight_decay"] = self.args.weight_decay
|
optim_args["weight_decay"] = self.args.weight_decay
|
||||||
optim_args["use_decoupled_weight_decay"] = bool(
|
optim_args["use_decoupled_weight_decay"] = bool(
|
||||||
@@ -506,15 +508,15 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
|
|
||||||
if self.args.optim_shampoo_grafting_config_type == "adam":
|
if self.args.optim_shampoo_grafting_config_type == "adam":
|
||||||
grafting_config = AdamGraftingConfig(
|
grafting_config = AdamGraftingConfig(
|
||||||
self.args.optim_shampoo_grafting_config_kwargs
|
**self.args.optim_shampoo_grafting_config_kwargs
|
||||||
)
|
)
|
||||||
elif self.args.optim_shampoo_grafting_config_type == "sgd":
|
elif self.args.optim_shampoo_grafting_config_type == "sgd":
|
||||||
grafting_config = SGDGraftingConfig(
|
grafting_config = SGDGraftingConfig(
|
||||||
self.args.optim_shampoo_grafting_config_kwargs
|
**self.args.optim_shampoo_grafting_config_kwargs
|
||||||
)
|
)
|
||||||
elif self.args.optim_shampoo_grafting_config_type == "adagrad":
|
elif self.args.optim_shampoo_grafting_config_type == "adagrad":
|
||||||
grafting_config = AdaGradGraftingConfig(
|
grafting_config = AdaGradGraftingConfig(
|
||||||
self.args.optim_shampoo_grafting_config_kwargs
|
**self.args.optim_shampoo_grafting_config_kwargs
|
||||||
)
|
)
|
||||||
|
|
||||||
distributed_config = None
|
distributed_config = None
|
||||||
|
|||||||
@@ -776,6 +776,7 @@ class AxolotlInputConfig(
|
|||||||
data["accelerator_config"]["dispatch_batches"] = False
|
data["accelerator_config"]["dispatch_batches"] = False
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="before")
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_gptq_w_revision(cls, data):
|
def check_gptq_w_revision(cls, data):
|
||||||
|
|||||||
Reference in New Issue
Block a user