add lion-pytorch optimizer (#1299) [skip ci]
* add lion-pytorch optimizer
* update pydantic to support lion optimizer

---------

Co-authored-by: Wing Lian <wing.lian@gmail.com>
This commit is contained in:
4
setup.py
4
setup.py
@@ -18,6 +18,7 @@ def parse_requirements():
|
|||||||
or "flash-attention" in line
|
or "flash-attention" in line
|
||||||
or "deepspeed" in line
|
or "deepspeed" in line
|
||||||
or "mamba-ssm" in line
|
or "mamba-ssm" in line
|
||||||
|
or "lion-pytorch" in line
|
||||||
)
|
)
|
||||||
if line.startswith("--extra-index-url"):
|
if line.startswith("--extra-index-url"):
|
||||||
# Handle custom index URLs
|
# Handle custom index URLs
|
||||||
@@ -85,5 +86,8 @@ setup(
|
|||||||
"mlflow": [
|
"mlflow": [
|
||||||
"mlflow",
|
"mlflow",
|
||||||
],
|
],
|
||||||
|
"lion-pytorch": [
|
||||||
|
"lion-pytorch==0.1.2",
|
||||||
|
],
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -970,18 +970,42 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
|||||||
"neftune_noise_alpha"
|
"neftune_noise_alpha"
|
||||||
] = self.cfg.neftune_noise_alpha
|
] = self.cfg.neftune_noise_alpha
|
||||||
|
|
||||||
|
trainer_kwargs = {}
|
||||||
|
|
||||||
|
if self.cfg.optimizer == "lion_pytorch":
|
||||||
|
from lion_pytorch import Lion
|
||||||
|
|
||||||
|
lion_kwargs = {"lr": training_arguments_kwargs["learning_rate"]}
|
||||||
|
if "weight_decay" in training_arguments_kwargs:
|
||||||
|
lion_kwargs["weight_decay"] = training_arguments_kwargs["weight_decay"]
|
||||||
|
|
||||||
|
if (
|
||||||
|
"adam_beta1" in training_arguments_kwargs
|
||||||
|
and "adam_beta2" in training_arguments_kwargs
|
||||||
|
):
|
||||||
|
lion_kwargs["betas"] = (
|
||||||
|
training_arguments_kwargs["adam_beta1"],
|
||||||
|
training_arguments_kwargs["adam_beta2"],
|
||||||
|
)
|
||||||
|
|
||||||
|
trainer_kwargs["optimizers"] = (
|
||||||
|
Lion(params=self.model.parameters(), **lion_kwargs),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
# Set default so transformers doesn't throw
|
||||||
|
training_arguments_kwargs["optim"] = "adamw_hf"
|
||||||
|
|
||||||
|
if self.cfg.optimizer == "adamw_anyprecision":
|
||||||
|
if Path(self.cfg.torchdistx_path).exists():
|
||||||
|
sys.path.append(self.cfg.torchdistx_path)
|
||||||
|
importlib.import_module("torchdistx")
|
||||||
|
|
||||||
training_args = (
|
training_args = (
|
||||||
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||||
**training_arguments_kwargs,
|
**training_arguments_kwargs,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
training_args = self.hook_post_create_training_args(training_args)
|
training_args = self.hook_post_create_training_args(training_args)
|
||||||
trainer_kwargs = {}
|
|
||||||
|
|
||||||
if self.cfg.optimizer == "adamw_anyprecision":
|
|
||||||
if Path(self.cfg.torchdistx_path).exists():
|
|
||||||
sys.path.append(self.cfg.torchdistx_path)
|
|
||||||
importlib.import_module("torchdistx")
|
|
||||||
|
|
||||||
data_collator_kwargs = {
|
data_collator_kwargs = {
|
||||||
"padding": True, # True/"longest" is the default
|
"padding": True, # True/"longest" is the default
|
||||||
|
|||||||
@@ -263,7 +263,7 @@ class HyperparametersConfig(BaseModel):
|
|||||||
|
|
||||||
learning_rate: Union[str, float]
|
learning_rate: Union[str, float]
|
||||||
weight_decay: Optional[float] = None
|
weight_decay: Optional[float] = None
|
||||||
optimizer: Optional[OptimizerNames] = None
|
optimizer: Optional[Union[OptimizerNames, Literal["lion_pytorch"]]] = None
|
||||||
torchdistx_path: Optional[str] = None
|
torchdistx_path: Optional[str] = None
|
||||||
lr_scheduler: Optional[SchedulerType] = None
|
lr_scheduler: Optional[SchedulerType] = None
|
||||||
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
|
lr_scheduler_kwargs: Optional[Dict[str, Any]] = None
|
||||||
|
|||||||
Reference in New Issue
Block a user