make sure to set alternate optimizer and set lr and eps from adam
This commit is contained in:
@@ -487,7 +487,10 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
elif self.args.alternate_optimizer == "soap":
|
elif self.args.alternate_optimizer == "soap":
|
||||||
from axolotl.utils.optimizers.soap import SOAP
|
from axolotl.utils.optimizers.soap import SOAP
|
||||||
|
|
||||||
optim_args = {}
|
optim_args = {
|
||||||
|
"lr": optimizer_kwargs.pop("lr"),
|
||||||
|
"eps": optimizer_kwargs.pop("eps"),
|
||||||
|
}
|
||||||
|
|
||||||
if self.cfg.optim_args:
|
if self.cfg.optim_args:
|
||||||
optim_args.update(self.cfg.optim_args)
|
optim_args.update(self.cfg.optim_args)
|
||||||
@@ -1639,6 +1642,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
"ao_adamw_4bit",
|
"ao_adamw_4bit",
|
||||||
"ao_adamw_8bit",
|
"ao_adamw_8bit",
|
||||||
"ao_adamw_fp8",
|
"ao_adamw_fp8",
|
||||||
|
"soap",
|
||||||
]:
|
]:
|
||||||
# Set default so transformers doesn't throw
|
# Set default so transformers doesn't throw
|
||||||
training_arguments_kwargs["optim"] = "adamw_hf"
|
training_arguments_kwargs["optim"] = "adamw_hf"
|
||||||
|
|||||||
Reference in New Issue
Block a user