add lion-pytorch optimizer (#1299) [skip ci]

* add lion-pytorch optimizer

* update pydantic to support lion optimizer

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
Author:    Maxime
Committer: GitHub
Date:      2024-02-27 00:45:14 +01:00
Commit:    16482796b0 (parent f30d062b48)

3 changed files with 35 additions and 7 deletions
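
This commit wires the Lion optimizer from the lion-pytorch package into axolotl's trainer builder. As a minimal standalone sketch of what the new dependency itself provides (not part of this commit; the toy model and hyperparameter values below are illustrative assumptions), Lion is a drop-in torch optimizer:

    # Hedged sketch: standalone use of lion-pytorch (values are assumptions)
    import torch
    from lion_pytorch import Lion

    model = torch.nn.Linear(16, 4)
    # Lion takes lr/weight_decay plus a betas pair, like Adam-family optimizers
    optimizer = Lion(model.parameters(), lr=1e-4, weight_decay=1e-2, betas=(0.9, 0.99))

    loss = model(torch.randn(8, 16)).sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()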


@@ -18,6 +18,7 @@ def parse_requirements():
                 or "flash-attention" in line
                 or "deepspeed" in line
                 or "mamba-ssm" in line
+                or "lion-pytorch" in line
             )
             if line.startswith("--extra-index-url"):
                 # Handle custom index URLs
@@ -85,5 +86,8 @@ setup(
         "mlflow": [
             "mlflow",
         ],
+        "lion-pytorch": [
+            "lion-pytorch==0.1.2",
+        ],
     },
 )
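
With lion-pytorch flagged as an extras line in parse_requirements() and pinned under extras_require, the package stays out of the base install and is only pulled in when the extra is requested, e.g. pip install -e '.[lion-pytorch]'. A small hedged check, assuming the extra has been installed:

    # Hedged sketch: confirm the optional dependency resolved to the pinned version
    from importlib.metadata import version

    assert version("lion-pytorch") == "0.1.2"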


@@ -970,18 +970,42 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
                 "neftune_noise_alpha"
             ] = self.cfg.neftune_noise_alpha
 
+        trainer_kwargs = {}
+
+        if self.cfg.optimizer == "lion_pytorch":
+            from lion_pytorch import Lion
+
+            lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
+
+            if "weight_decay" in training_arguments_kwargs:
+                lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
+
+            if (
+                "adam_beta1" in training_arguments_kwargs
+                and "adam_beta2" in training_arguments_kwargs
+            ):
+                lion_kwargs["betas"] = (
+                    training_arguments_kwargs["adam_beta1"],
+                    training_arguments_kwargs["adam_beta2"],
+                )
+
+            trainer_kwargs["optimizers"] = (
+                Lion(params=self.model.parameters(), **lion_kwargs),
+                None,
+            )
+            # Set default so transformers doesn't throw
+            training_arguments_kwargs["optim"] = "adamw_hf"
+
+        if self.cfg.optimizer == "adamw_anyprecision":
+            if Path(self.cfg.torchdistx_path).exists():
+                sys.path.append(self.cfg.torchdistx_path)
+                importlib.import_module("torchdistx")
+
         training_args = (
             AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
                 **training_arguments_kwargs,
             )
         )
         training_args = self.hook_post_create_training_args(training_args)
-        trainer_kwargs = {}
-
-        if self.cfg.optimizer == "adamw_anyprecision":
-            if Path(self.cfg.torchdistx_path).exists():
-                sys.path.append(self.cfg.torchdistx_path)
-                importlib.import_module("torchdistx")
 
         data_collator_kwargs = {
             "padding": True,  # True/"longest" is the default


@@ -263,7 +263,7 @@ class HyperparametersConfig(BaseModel):
     learning_rate: Union[str, float]
     weight_decay: Optional[float] = None
-    optimizer: Optional[OptimizerNames] = None
+    optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch"]]] = None
     torchdistx_path: Optional[str] = None
     lr_scheduler: Optional[SchedulerType] = None
     lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
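
Because "lion_pytorch" is not a member of transformers' OptimizerNames enum, the annotation is widened with a Literal so pydantic still accepts it. A small hedged sketch of the resulting validation behaviour (the Demo model below is illustrative, not the real HyperparametersConfig):

    # Hedged sketch: the widened Union accepts both enum members and the new literal
    from typing import Literal, Optional, Union
    from pydantic import BaseModel
    from transformers.training_args import OptimizerNames

    class Demo(BaseModel):
        optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch"]]] = None

    print(Demo(optimizer="adamw_torch").optimizer)   # coerced to an OptimizerNames member
    print(Demo(optimizer="lion_pytorch").optimizer)  # kept as the plain string "lion_pytorch"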