update upstream deps versions and replace lora+ (#1928)
* update upstream deps versions and replace lora+ * typo transformers version
This commit is contained in:
@@ -1,9 +1,9 @@
|
|||||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||||
packaging==23.2
|
packaging==23.2
|
||||||
peft==0.12.0
|
peft==0.13.0
|
||||||
transformers @ git+https://github.com/huggingface/transformers.git@0963229e287501bed52ae1dabc17922524de6992
|
transformers==4.45.0
|
||||||
tokenizers>=0.19.1
|
tokenizers>=0.19.1
|
||||||
bitsandbytes==0.43.3
|
bitsandbytes==0.44.0
|
||||||
accelerate==0.34.2
|
accelerate==0.34.2
|
||||||
datasets==2.21.0
|
datasets==2.21.0
|
||||||
deepspeed==0.14.4
|
deepspeed==0.14.4
|
||||||
@@ -34,7 +34,7 @@ tensorboard
|
|||||||
python-dotenv==1.0.1
|
python-dotenv==1.0.1
|
||||||
autoawq>=0.2.5
|
autoawq>=0.2.5
|
||||||
triton>=2.3.0
|
triton>=2.3.0
|
||||||
liger-kernel==0.2.1
|
liger-kernel==0.3.0
|
||||||
|
|
||||||
mamba-ssm==1.2.0.post1
|
mamba-ssm==1.2.0.post1
|
||||||
|
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from typing import Any, Dict, List, Literal, Optional, Type, Union
|
|||||||
import torch
|
import torch
|
||||||
import transformers
|
import transformers
|
||||||
from datasets import Dataset
|
from datasets import Dataset
|
||||||
|
from peft.optimizers import create_loraplus_optimizer
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.optim.lr_scheduler import OneCycleLR
|
from torch.optim.lr_scheduler import OneCycleLR
|
||||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||||
@@ -45,7 +46,6 @@ from trl import (
|
|||||||
)
|
)
|
||||||
from trl.trainer.utils import pad_to_length
|
from trl.trainer.utils import pad_to_length
|
||||||
|
|
||||||
from axolotl.loraplus import create_loraplus_optimizer
|
|
||||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||||
from axolotl.utils import is_mlflow_available
|
from axolotl.utils import is_mlflow_available
|
||||||
@@ -461,9 +461,9 @@ class AxolotlTrainer(SchedulerMixin, Trainer):
|
|||||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
opt_model,
|
opt_model,
|
||||||
optimizer_cls,
|
optimizer_cls,
|
||||||
optimizer_kwargs,
|
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||||
loraplus_lr_ratio,
|
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||||
loraplus_lr_embedding,
|
**optimizer_kwargs,
|
||||||
)
|
)
|
||||||
elif self.args.alternate_optimizer == "optimi_adamw":
|
elif self.args.alternate_optimizer == "optimi_adamw":
|
||||||
from optimi import AdamW
|
from optimi import AdamW
|
||||||
@@ -969,9 +969,9 @@ class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
|||||||
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
self.optimizer = create_loraplus_optimizer( # pylint: disable=attribute-defined-outside-init
|
||||||
opt_model,
|
opt_model,
|
||||||
optimizer_cls,
|
optimizer_cls,
|
||||||
optimizer_kwargs,
|
loraplus_lr_ratio=loraplus_lr_ratio,
|
||||||
loraplus_lr_ratio,
|
loraplus_lr_embedding=loraplus_lr_embedding,
|
||||||
loraplus_lr_embedding,
|
**optimizer_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_sagemaker_mp_enabled():
|
if is_sagemaker_mp_enabled():
|
||||||
|
|||||||
@@ -1,133 +0,0 @@
|
|||||||
"""Module for LoRA+"""
|
|
||||||
|
|
||||||
# MIT License
|
|
||||||
#
|
|
||||||
# Copyright (c) 2024 nikhil-ghosh-berkeley
|
|
||||||
# https://github.com/nikhil-ghosh-berkeley/loraplus
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from functools import reduce
|
|
||||||
|
|
||||||
from peft.tuners import lora
|
|
||||||
from torch import nn
|
|
||||||
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
|
||||||
from transformers.trainer_pt_utils import get_parameter_names
|
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl.loraplus")
|
|
||||||
|
|
||||||
|
|
||||||
def get_module(name, opt_model):
    """
    Resolve the object that owns a parameter, given the parameter's dotted name.

    The trailing components of ``name`` are dropped (two when the name
    contains the substring "lora", one otherwise) and the remaining
    components are followed as attributes starting from ``opt_model``.

    Args:
        name (str): Dotted parameter name, e.g. as yielded by ``named_parameters()``.
        opt_model: Root object from which the attribute path is walked.

    Returns:
        The object reached at the end of the remaining attribute path;
        ``opt_model`` itself when no components remain.
    """
    # Names containing "lora" have one extra trailing component to strip.
    strip_count = 1
    if "lora" in name:
        strip_count = 2

    owner = opt_model
    for attr in name.split(".")[:-strip_count]:
        owner = getattr(owner, attr)
    return owner
|
|
||||||
|
|
||||||
|
|
||||||
def create_loraplus_optimizer(
    opt_model,
    optimizer_cls,
    optimizer_kwargs,
    loraplus_lr_ratio,
    loraplus_lr_embedding=None,
):
    """
    Build an optimizer whose parameter groups carry LoRA+ learning rates.

    Every trainable parameter of ``opt_model`` is placed in one of four groups:

    * ``embedding``: the owning module (see ``get_module``) is a ``lora.Embedding``;
      uses ``loraplus_lr_embedding`` as learning rate.
    * ``groupB``: the name contains "lora_B" or the parameter is 1-D, and the
      name is in the decay set; uses ``lr * loraplus_lr_ratio``.
    * ``groupB_no_decay``: same condition but the name is not in the decay set;
      uses ``lr * loraplus_lr_ratio`` with ``weight_decay`` forced to 0.0.
    * ``groupA``: everything else; uses the base ``lr``.

    Args:
        opt_model (torch.nn.Module): The model for which the optimizer is being created.
        optimizer_cls (class): The optimizer class to instantiate (e.g., torch.optim.Adam).
        optimizer_kwargs (dict): Keyword arguments for the optimizer. Must contain
            "lr"; "weight_decay" is optional and defaults to 0.0.
        loraplus_lr_ratio (float): Multiplier applied to ``lr`` for the groupB groups.
        loraplus_lr_embedding (float, optional): Learning rate for the embedding
            group; 1e-6 is used when None.

    Returns:
        An instance of ``optimizer_cls`` built from the four parameter groups and
        ``**optimizer_kwargs``.

    Raises:
        AssertionError: If ``loraplus_lr_ratio`` is None.
        KeyError: If ``optimizer_kwargs`` has no "lr" entry.
    """

    assert loraplus_lr_ratio is not None, "loraplus_lr_ratio must be provided."

    if loraplus_lr_embedding is None:
        loraplus_lr_embedding = 1e-6

    # A set keeps the per-parameter membership test in the loop below O(1);
    # a list made the grouping quadratic in the number of parameters.
    decay_parameters = {
        name
        for name in get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
        if "bias" not in name
    }
    param_groups = {
        "groupA": {},
        "groupB": {},
        "groupB_no_decay": {},
        "embedding": {},
    }

    for name, param in opt_model.named_parameters():
        if not param.requires_grad:
            continue

        module = get_module(name, opt_model)
        if isinstance(module, lora.Embedding):
            param_groups["embedding"][name] = param
        elif "lora_B" in name or param.ndim == 1:
            if name in decay_parameters:
                param_groups["groupB"][name] = param
            else:
                param_groups["groupB_no_decay"][name] = param
        else:
            param_groups["groupA"][name] = param

    # Single join instead of repeated string concatenation; same text is logged.
    assigned_param_groups = "".join(
        f"{group}\n {list(group_params.keys())}\n\n"
        for group, group_params in param_groups.items()
    )
    LOG.info(assigned_param_groups)

    lr = optimizer_kwargs["lr"]  # pylint: disable=invalid-name
    weight_decay = optimizer_kwargs.get("weight_decay", 0.0)
    boosted_lr = lr * loraplus_lr_ratio

    optimizer_grouped_parameters = [
        {
            "params": list(param_groups["groupA"].values()),
            "weight_decay": weight_decay,
            "lr": lr,
        },
        {
            "params": list(param_groups["embedding"].values()),
            "weight_decay": weight_decay,
            "lr": loraplus_lr_embedding,
        },
        {
            "params": list(param_groups["groupB"].values()),
            "weight_decay": weight_decay,
            "lr": boosted_lr,
        },
        {
            "params": list(param_groups["groupB_no_decay"].values()),
            "weight_decay": 0.0,
            "lr": boosted_lr,
        },
    ]

    optimizer = optimizer_cls(optimizer_grouped_parameters, **optimizer_kwargs)
    if optimizer_cls.__name__ == "Adam8bit":
        import bitsandbytes

        manager = bitsandbytes.optim.GlobalOptimManager.get_instance()

        # Register a 32-bit optimizer-state override for the "weight" of every
        # nn.Embedding module. `skipped` is a running total, so the per-module
        # log line reports the cumulative count so far.
        skipped = 0
        for module in opt_model.modules():
            if isinstance(module, nn.Embedding):
                # Keyed by data_ptr so parameters sharing storage count once.
                skipped += sum(
                    {p.data_ptr(): p.numel() for p in module.parameters()}.values()
                )
                LOG.info("skipped %s: %sM params", module, skipped / 2**20)
                manager.register_module_override(module, "weight", {"optim_bits": 32})
                LOG.debug("bitsandbytes: will optimize %s in fp32", module)
        LOG.info("skipped: %sM params", skipped / 2**20)

    return optimizer
|
|
||||||
Reference in New Issue
Block a user