let transformers handle adamw_bnb_8bit

This commit is contained in:
Aman Karmani
2023-08-26 21:40:12 +00:00
parent 17605b85d8
commit 868530c39c

View File

@@ -10,19 +10,13 @@ from functools import partial
from pathlib import Path from pathlib import Path
from typing import Optional, Union from typing import Optional, Union
import bitsandbytes as bnb
import numpy as np import numpy as np
import torch.cuda import torch.cuda
import transformers
from datasets import Dataset, set_caching_enabled from datasets import Dataset, set_caching_enabled
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import ( from transformers.trainer_pt_utils import SequentialDistributedSampler
SequentialDistributedSampler,
get_parameter_names,
)
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.utils.callbacks import ( from axolotl.utils.callbacks import (
@@ -32,10 +26,7 @@ from axolotl.utils.callbacks import (
) )
from axolotl.utils.collators import DataCollatorForSeq2Seq from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader from axolotl.utils.dataloader import MultipackDistributedDataloader
from axolotl.utils.schedulers import ( from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
InterpolatingLogScheduler,
get_cosine_schedule_with_quadratic_warmup,
)
LOG = logging.getLogger("axolotl") LOG = logging.getLogger("axolotl")
@@ -570,66 +561,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
if Path(cfg.torchdistx_path).exists(): if Path(cfg.torchdistx_path).exists():
sys.path.append(cfg.torchdistx_path) sys.path.append(cfg.torchdistx_path)
importlib.import_module("torchdistx") importlib.import_module("torchdistx")
if (
cfg.optimizer == "adamw_bnb_8bit"
and not cfg.gptq
and "deepspeed" not in training_arguments_kwargs
and not cfg.fsdp
):
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
optimizer_grouped_parameters = [
{
"params": [
p
for n, p in model.named_parameters()
if (n in decay_parameters and p.requires_grad)
],
"weight_decay": training_args.weight_decay,
},
{
"params": [
p
for n, p in model.named_parameters()
if (n not in decay_parameters and p.requires_grad)
],
"weight_decay": 0.0,
},
]
optimizer = bnb.optim.Adam8bit(
optimizer_grouped_parameters,
betas=(training_args.adam_beta1, training_args.adam_beta2),
eps=training_args.adam_epsilon,
lr=training_args.learning_rate,
)
if cfg.lr_scheduler == "one_cycle":
lr_scheduler_kwargs = (
cfg.lr_scheduler_kwargs if cfg.lr_scheduler_kwargs else {}
)
lr_scheduler = OneCycleLR(
optimizer,
cfg.learning_rate,
total_steps=total_num_steps,
epochs=cfg.num_epochs,
div_factor=cfg.lr_div_factor if cfg.lr_div_factor else 6,
**lr_scheduler_kwargs,
)
elif cfg.lr_scheduler == "log_sweep":
lr_scheduler = InterpolatingLogScheduler(
optimizer,
cfg.warmup_steps,
cfg.log_sweep_min_lr if cfg.log_sweep_min_lr else 1e-10,
cfg.log_sweep_max_lr if cfg.log_sweep_max_lr else 10,
)
else:
lr_scheduler = transformers.get_cosine_schedule_with_warmup(
optimizer,
training_args.warmup_steps,
total_num_steps,
)
trainer_kwargs["optimizers"] = (optimizer, lr_scheduler)
callbacks = [] callbacks = []
callbacks.append(GPUStatsCallback(cfg)) callbacks.append(GPUStatsCallback(cfg))