Compare commits: coderabbit ... 4bit-optim (2 commits: e6b78c1fca, a236f5eab5)

docs/optimizers.md (new file, 29 lines)
@@ -0,0 +1,29 @@
# Optimizers

Optimizers are a key component of training LLMs. They are responsible for updating the model's weights (parameters) based on the gradients computed during backpropagation, with the goal of minimizing the loss function.
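To make that role concrete, here is a minimal editorial sketch (not Axolotl code) of the update rule every optimizer builds on, plain gradient descent; `model`, `loss`, and the learning rate are assumed placeholders:

```python
import torch

def sgd_step(model: torch.nn.Module, loss: torch.Tensor, lr: float = 1e-4) -> None:
    """One plain gradient-descent update: w <- w - lr * dL/dw."""
    loss.backward()                 # backpropagation fills p.grad for every parameter
    with torch.no_grad():
        for p in model.parameters():
            if p.grad is not None:
                p -= lr * p.grad    # step each weight against its gradient
                p.grad = None       # reset gradients for the next batch
```

Adam/AdamW refine this rule by tracking running estimates of the gradient's first and second moments, which is what the hyperparameters in the next section control.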
### Adam/AdamW Optimizers

```yaml
adam_beta1: 0.9
adam_beta2: 0.999
adam_epsilon: 1e-8
weight_decay: 0.0
```
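As a rough sketch of what these keys correspond to in plain PyTorch (the `model` object and the learning rate below are illustrative placeholders, not Axolotl's actual optimizer wiring):

```python
import torch

optimizer = torch.optim.AdamW(
    model.parameters(),   # assumes `model` is an already-built nn.Module
    lr=2e-5,              # learning_rate (illustrative value)
    betas=(0.9, 0.999),   # adam_beta1, adam_beta2
    eps=1e-8,             # adam_epsilon
    weight_decay=0.0,     # weight_decay
)
```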

### GaLore Optimizer

https://huggingface.co/papers/2403.03507

```yaml
optimizer: galore_adamw | galore_adamw_8bit | galore_adafactor
optim_args:
  rank: 128
  update_proj_gap: 200
  scale: 0.25
  proj_type: std
optim_target_modules:
  - mlp
  - attn
```
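Following the GaLore paper linked above: `rank` sets the rank of the low-rank projection applied to the gradients, `update_proj_gap` is the number of steps between refreshes of that projection, `scale` scales the projected update, and `proj_type: std` selects the standard projection; `optim_target_modules` limits GaLore to parameters whose module names match the listed patterns (here the MLP and attention blocks).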
@@ -41,3 +41,6 @@ gcsfs
 trl @ git+https://github.com/huggingface/trl.git@304e208f778a5442c30cdda500348226cdc97d90
 fastcore>=1.5.29
+lpmm @ git+https://github.com/thu-ml/low-bit-optimizers.git@main
+yacs
@@ -15,18 +15,21 @@ from collections import defaultdict
 from dataclasses import dataclass, field
 from functools import wraps
 from pathlib import Path
-from typing import Dict, List, Literal, Optional, Type, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, Type, Union

+import lpmm
 import torch
 import transformers
 from accelerate import FullyShardedDataParallelPlugin
 from accelerate.utils import str_to_bool
 from datasets import Dataset
+from torch import nn
 from torch.distributed.fsdp import MixedPrecision
 from torch.optim.lr_scheduler import OneCycleLR
 from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
 from transformers import (
     EarlyStoppingCallback,
+    PreTrainedModel,
     Trainer,
     TrainerCallback,
     TrainingArguments,
@@ -36,6 +39,7 @@ from transformers.utils import is_sagemaker_mp_enabled
 from trl import DPOTrainer

 from axolotl.core.policies.auto_wrap import get_wrapping_policy_factory
+from axolotl.core.trainers import OptimizerNames
 from axolotl.loraplus import create_loraplus_optimizer
 from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
 from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
@@ -62,6 +66,9 @@ from axolotl.utils.schedulers import (
     get_cosine_schedule_with_warmup_decay_constant,
 )

+# monkeypatch so it accepts our custom optimizers
+transformers.training_args.OptimizerNames = OptimizerNames
+
 if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp
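Why this monkeypatch matters (a sketch, assuming `transformers` keeps converting the `optim` string through its module-level `OptimizerNames`): `TrainingArguments.__post_init__` does roughly `self.optim = OptimizerNames(self.optim)`, so the new 4-bit identifiers would raise a `ValueError` without swapping in the extended enum.

```python
import transformers

from axolotl.core.trainers import OptimizerNames

# Same patch as in the diff above: let TrainingArguments validate the new identifiers.
transformers.training_args.OptimizerNames = OptimizerNames

# "./out" is just a placeholder output directory for this example.
args = transformers.TrainingArguments(output_dir="./out", optim="lpmm_adamw_4bit")
assert args.optim == OptimizerNames.LPMM_ADAMW_4BIT  # resolved via the patched enum
```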
@@ -231,26 +238,104 @@ class AxolotlTrainer(Trainer):
         if self.args.orpo_alpha:
             self.loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

+    @staticmethod
+    def get_optimizer_cls_and_kwargs(
+        args: TrainingArguments, model: Optional[PreTrainedModel] = None
+    ) -> Tuple[Any, Any]:
+        optim_args = {}
+        if args.optim_args:
+            for mapping in args.optim_args.replace(" ", "").split(","):
+                key, value = mapping.split("=")
+                optim_args[key] = value
+
+        optimizer_kwargs = {"lr": args.learning_rate}
+
+        adam_kwargs = {
+            "betas": (args.adam_beta1, args.adam_beta2),
+            "eps": args.adam_epsilon,
+        }
+
+        if args.optim in [
+            OptimizerNames.LPMM_ADAMW_4BIT,
+            OptimizerNames.LPMM_ADAMW_4BIT_FUSED,
+        ]:
+            optimizer_cls = lpmm.optim.AdamW
+            optimizer_kwargs.update(adam_kwargs)
+            if args.optim == OptimizerNames.LPMM_ADAMW_4BIT_FUSED:
+                optimizer_kwargs.update({"fused": True})
+            return optimizer_cls, optimizer_kwargs
+
+        return Trainer.get_optimizer_cls_and_kwargs(
+            args,
+            model=model,
+        )
+
     def create_optimizer(self):
-        if self.args.loraplus_lr_ratio is None:
-            return super().create_optimizer()
-
         opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model
-        if self.optimizer is None:  # pylint: disable=access-member-before-definition
-            optimizer_cls, optimizer_kwargs = Trainer.get_optimizer_cls_and_kwargs(
-                self.args,
-                opt_model,
-            )
-
-            loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
-            loraplus_lr_embedding = getattr(self.args, "loraplus_lr_embedding", None)
-            self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
-                opt_model,
-                optimizer_cls,
-                optimizer_kwargs,
-                loraplus_lr_ratio,
-                loraplus_lr_embedding,
-            )
+        if self.optimizer is None:  # pylint: disable=access-member-before-definition
+            decay_parameters = self.get_decay_parameter_names(opt_model)
+            optimizer_grouped_parameters = [
+                {
+                    "params": [
+                        p
+                        for n, p in opt_model.named_parameters()
+                        if (n in decay_parameters and p.requires_grad)
+                    ],
+                    "weight_decay": self.args.weight_decay,
+                },
+                {
+                    "params": [
+                        p
+                        for n, p in opt_model.named_parameters()
+                        if (n not in decay_parameters and p.requires_grad)
+                    ],
+                    "weight_decay": 0.0,
+                },
+            ]
+
+            (
+                optimizer_cls,
+                optimizer_kwargs,
+            ) = AxolotlTrainer.get_optimizer_cls_and_kwargs(self.args)
+
+            if self.args.loraplus_lr_ratio:
+                loraplus_lr_ratio = getattr(self.args, "loraplus_lr_ratio", None)
+                loraplus_lr_embedding = getattr(
+                    self.args, "loraplus_lr_embedding", None
+                )
+                self.optimizer = create_loraplus_optimizer(  # pylint: disable=attribute-defined-outside-init
+                    opt_model,
+                    optimizer_cls,
+                    optimizer_kwargs,
+                    loraplus_lr_ratio,
+                    loraplus_lr_embedding,
+                )
+            else:
+                self.optimizer = (  # pylint: disable=attribute-defined-outside-init
+                    optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
+                )
+
+                if optimizer_cls.__name__ == "Adam8bit":
+                    import bitsandbytes
+
+                    manager = bitsandbytes.optim.GlobalOptimManager.get_instance()
+
+                    skipped = 0
+                    for module in opt_model.modules():
+                        if isinstance(module, nn.Embedding):
+                            skipped += sum(
+                                {
+                                    p.data_ptr(): p.numel() for p in module.parameters()
+                                }.values()
+                            )
+                            LOG.info(f"skipped {module}: {skipped/2**20}M params")
+                            manager.register_module_override(
+                                module, "weight", {"optim_bits": 32}
+                            )
+                            LOG.debug(f"bitsandbytes: will optimize {module} in fp32")
+                    LOG.info(f"skipped: {skipped/2**20}M params")
+
         if is_sagemaker_mp_enabled():
             self.optimizer = smp.DistributedOptimizer(  # pylint: disable=attribute-defined-outside-init
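As a usage sketch of the new branch above (editorial example, not part of the PR; `training_args` and `model` are assumed to come from an existing Axolotl run, and `AxolotlTrainer` is the class patched in this hunk): when `optim` is one of the LPMM identifiers, the resolved class is `lpmm.optim.AdamW` and the kwargs carry the learning rate plus the Adam betas/eps, with `fused=True` added for the fused variant.

```python
# training_args.optim is assumed to be OptimizerNames.LPMM_ADAMW_4BIT_FUSED here.
optimizer_cls, optimizer_kwargs = AxolotlTrainer.get_optimizer_cls_and_kwargs(training_args)
# optimizer_cls    -> lpmm.optim.AdamW
# optimizer_kwargs -> {"lr": ..., "betas": (beta1, beta2), "eps": ..., "fused": True}
optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)
```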
@@ -0,0 +1,40 @@
"""module for trainer helpers like OptimizerNames"""

from transformers.utils import ExplicitEnum


class OptimizerNames(ExplicitEnum):
    """
    Stores the acceptable string identifiers for optimizers.
    """

    ADAMW_HF = "adamw_hf"
    ADAMW_TORCH = "adamw_torch"
    ADAMW_TORCH_FUSED = "adamw_torch_fused"
    ADAMW_TORCH_XLA = "adamw_torch_xla"
    ADAMW_TORCH_NPU_FUSED = "adamw_torch_npu_fused"
    ADAMW_APEX_FUSED = "adamw_apex_fused"
    ADAFACTOR = "adafactor"
    ADAMW_ANYPRECISION = "adamw_anyprecision"
    SGD = "sgd"
    ADAGRAD = "adagrad"
    ADAMW_BNB = "adamw_bnb_8bit"
    ADAMW_8BIT = "adamw_8bit"  # just an alias for adamw_bnb_8bit
    LION_8BIT = "lion_8bit"
    LION = "lion_32bit"
    PAGED_ADAMW = "paged_adamw_32bit"
    PAGED_ADAMW_8BIT = "paged_adamw_8bit"
    PAGED_LION = "paged_lion_32bit"
    PAGED_LION_8BIT = "paged_lion_8bit"
    RMSPROP = "rmsprop"
    RMSPROP_BNB = "rmsprop_bnb"
    RMSPROP_8BIT = "rmsprop_bnb_8bit"
    RMSPROP_32BIT = "rmsprop_bnb_32bit"
    GALORE_ADAMW = "galore_adamw"
    GALORE_ADAMW_8BIT = "galore_adamw_8bit"
    GALORE_ADAFACTOR = "galore_adafactor"
    GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
    GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
    GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
    LPMM_ADAMW_4BIT = "lpmm_adamw_4bit"
    LPMM_ADAMW_4BIT_FUSED = "lpmm_adamw_4bit_fused"
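A quick sanity check of the new identifiers (a sketch; `ExplicitEnum` behaves like a standard `Enum`, just with a clearer error message for unknown values):

```python
from axolotl.core.trainers import OptimizerNames

assert OptimizerNames("lpmm_adamw_4bit") is OptimizerNames.LPMM_ADAMW_4BIT
assert OptimizerNames("galore_adamw_8bit") is OptimizerNames.GALORE_ADAMW_8BIT
```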