Compare commits

...

2 Commits

Author SHA1 Message Date
Wing Lian
e6b78c1fca override the entire create_optimzier method 2024-03-19 23:19:56 -04:00
Wing Lian
a236f5eab5 add support for 4bit optimizers 2024-03-19 22:57:40 -04:00
4 changed files with 173 additions and 16 deletions

29
docs/optimizers.md Normal file
View File

@@ -0,0 +1,29 @@
# Optimizers
Optimizers are an important component when training LLMs: they update the model's weights (parameters) based on the gradients computed during backpropagation, with the goal of minimizing the loss function.
### Adam/AdamW Optimizers
```yaml
adam_beta1: 0.9
adam_beta2: 0.999
adam_epsilon: 1e-8
weight_decay: 0.0
```
### GaLore Optimizer
https://huggingface.co/papers/2403.03507
```yaml
optimizer: galore_adamw | galore_adamw_8bit | galore_adafactor
optim_args:
rank: 128
update_proj_gap: 200
scale: 0.25
proj_type: std
optim_target_modules:
- mlp
- attn
```

View File

@@ -41,3 +41,6 @@ gcsfs
trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda500348226cdc97d90 trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda500348226cdc97d90
fastcore>=1.5.29 fastcore>=1.5.29
lpmm @ git+https://github.com/thu-ml/low-bit-optimizers.git@main
yacs

View File

@@ -15,18 +15,21 @@ from collections import defaultdict
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import wraps from functools import wraps
from pathlib import Path from pathlib import Path
from typing import Dict, List, Literal, Optional, Type, Union from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union
import lpmm
import torch import torch
import transformers import transformers
from accelerate import FullyShardedDataParallelPlugin from accelerate import FullyShardedDataParallelPlugin
from accelerate.utils import str_to_bool from accelerate.utils import str_to_bool
from datasets import Dataset from datasets import Dataset
from torch import nn
from torch.distributed.fsdp import MixedPrecision from torch.distributed.fsdp import MixedPrecision
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import ( from transformers import (
EarlyStoppingCallback, EarlyStoppingCallback,
PreTrainedModel,
Trainer, Trainer,
TrainerCallback, TrainerCallback,
TrainingArguments, TrainingArguments,
@@ -36,6 +39,7 @@ from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer from trl import DPOTrainer
from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
from axolotl.core.trainers import OptimizerNames
from axolotl.loraplus import create_loraplus_optimizer from axolotl.loraplus import create_loraplus_optimizer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
@@ -62,6 +66,9 @@ from axolotl.utils.schedulers import (
get_cosine_schedule_with_warmup_decay_constant, get_cosine_schedule_with_warmup_decay_constant,
) )
# monkeypatch so it accepts our custom optimizers
transformers.training_args.OptimizerNames = OptimizerNames
if is_sagemaker_mp_enabled(): if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp import smdistributed.modelparallel.torch as smp
@@ -231,26 +238,104 @@ class AxolotlTrainer(Trainer):
if self.args.orpo_alpha: if self.args.orpo_alpha:
self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none") self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
@staticmethod
def get_optimizer_cls_and_kwargs(
    args: TrainingArguments, model: Optional[PreTrainedModel] = None
) -> Tuple[Any, Any]:
    """
    Resolve the optimizer class and its constructor kwargs for ``args.optim``.

    Handles the lpmm 4-bit AdamW optimizers locally; every other optimizer
    name is delegated to ``transformers.Trainer.get_optimizer_cls_and_kwargs``.

    :param args: training arguments carrying ``optim``, ``learning_rate``,
        ``optim_args`` and the Adam hyperparameters.
    :param model: optional model, forwarded to the upstream resolver.
    :return: ``(optimizer_cls, optimizer_kwargs)`` tuple.
    """
    # Parse "key1=val1,key2=val2"-style optim_args into a dict.
    # NOTE(review): values are kept as strings and the resulting dict is
    # currently not merged into optimizer_kwargs — confirm this is intended.
    parsed_optim_args = {}
    if args.optim_args:
        parsed_optim_args = dict(
            pair.split("=")
            for pair in args.optim_args.replace(" ", "").split(",")
        )

    optimizer_kwargs = {"lr": args.learning_rate}

    lpmm_optimizers = (
        OptimizerNames.LPMM_ADAMW_4BIT,
        OptimizerNames.LPMM_ADAMW_4BIT_FUSED,
    )
    if args.optim in lpmm_optimizers:
        # lpmm provides the 4-bit AdamW implementation; pass the standard
        # Adam hyperparameters alongside the learning rate.
        optimizer_kwargs["betas"] = (args.adam_beta1, args.adam_beta2)
        optimizer_kwargs["eps"] = args.adam_epsilon
        if args.optim == OptimizerNames.LPMM_ADAMW_4BIT_FUSED:
            optimizer_kwargs["fused"] = True
        return lpmm.optim.AdamW, optimizer_kwargs

    # Fall back to the stock transformers resolver for all other names.
    return Trainer.get_optimizer_cls_and_kwargs(args, model=model)
def create_optimizer(self): def create_optimizer(self):
if self.args.loraplus_lr_ratio is None:
return super().create_optimizer()
opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
if self.optimizer is None: # pylint: disable=access-member-before-definition
optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
self.args,
opt_model,
)
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None) if self.optimizer is None: # pylint: disable=access-member-before-definition
loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None) decay_parameters = self.get_decay_parameter_names(opt_model)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init optimizer_grouped_parameters = [
opt_model, {
"params": [
p
for n, p in opt_model.named_parameters()
if (n in decay_parameters and p.requires_grad)
],
"weight_decay": self.args.weight_decay,
},
{
"params": [
p
for n, p in opt_model.named_parameters()
if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
]
(
optimizer_cls, optimizer_cls,
optimizer_kwargs, optimizer_kwargs,
loraplus_lr_ratio, ) = AxolotlTrainer.get_optimizer_cls_and_kwargs(self.args)
loraplus_lr_embedding,
) if self.args.loraplus_lr_ratio:
loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
loraplus_lr_embedding = getattr(
self.args, "loraplus_lr_embedding", None
)
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
opt_model,
optimizer_cls,
optimizer_kwargs,
loraplus_lr_ratio,
loraplus_lr_embedding,
)
else:
self.optimizer = ( # pylint: disable=attribute-defined-outside-init
optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
)
if optimizer_cls.__name__ == "Adam8bit":
import bitsandbytes
manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
skipped = 0
for module in opt_model.modules():
if isinstance(module, nn.Embedding):
skipped += sum(
{
p.data_ptr(): p.numel() for p in module.parameters()
}.values()
)
LOG.info(f"skipped {module}: {skipped/2**20}M params")
manager.register_module_override(
module, "weight", {"optim_bits": 32}
)
LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
LOG.info(f"skipped: {skipped/2**20}M params")
if is_sagemaker_mp_enabled(): if is_sagemaker_mp_enabled():
self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init self.optimizer = smp.DistributedOptimizer( # pylint: disable=attribute-defined-outside-init

View File

@@ -0,0 +1,40 @@
"""module for trainer helpers like OptimizerNames"""
from transformers.utils import ExplicitEnum
class OptimizerNames(ExplicitEnum):
    """
    Stores the acceptable string identifiers for optimizers.

    This enum is monkeypatched over ``transformers.training_args.OptimizerNames``
    so config files can also select the custom lpmm 4-bit optimizers.
    """

    ADAMW_HF = "adamw_hf"
    ADAMW_TORCH = "adamw_torch"
    ADAMW_TORCH_FUSED = "adamw_torch_fused"
    ADAMW_TORCH_XLA = "adamw_torch_xla"
    ADAMW_TORCH_NPU_FUSED = "adamw_torch_npu_fused"
    ADAMW_APEX_FUSED = "adamw_apex_fused"
    ADAFACTOR = "adafactor"
    ADAMW_ANYPRECISION = "adamw_anyprecision"
    SGD = "sgd"
    ADAGRAD = "adagrad"
    ADAMW_BNB = "adamw_bnb_8bit"
    ADAMW_8BIT = "adamw_8bit"  # just an alias for adamw_bnb_8bit
    LION_8BIT = "lion_8bit"
    LION = "lion_32bit"
    PAGED_ADAMW = "paged_adamw_32bit"
    PAGED_ADAMW_8BIT = "paged_adamw_8bit"
    PAGED_LION = "paged_lion_32bit"
    PAGED_LION_8BIT = "paged_lion_8bit"
    RMSPROP = "rmsprop"
    RMSPROP_BNB = "rmsprop_bnb"
    RMSPROP_8BIT = "rmsprop_bnb_8bit"
    RMSPROP_32BIT = "rmsprop_bnb_32bit"
    GALORE_ADAMW = "galore_adamw"
    GALORE_ADAMW_8BIT = "galore_adamw_8bit"
    GALORE_ADAFACTOR = "galore_adafactor"
    GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
    GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
    GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
    # 4-bit AdamW from lpmm (thu-ml/low-bit-optimizers).
    # fix: values were misspelled "lmpp_*" — the library and the member
    # names both use "lpmm", so the config string should too.
    LPMM_ADAMW_4BIT = "lpmm_adamw_4bit"
    LPMM_ADAMW_4BIT_FUSED = "lpmm_adamw_4bit_fused"