Allow using the AdamW bnb 8-bit optimizer and the one-cycle LR scheduler with FSDP

This commit is contained in:
Wing Lian
2023-05-22 09:00:49 -04:00
parent 1b3e401241
commit 9493b1b137
2 changed files with 28 additions and 5 deletions

View File

@@ -7,7 +7,7 @@ from datasets import (
load_dataset,
IterableDataset,
Dataset,
concatenate_datasets,
concatenate_datasets, DatasetDict,
)
from huggingface_hub import hf_hub_download
from transformers import PreTrainedTokenizerBase
@@ -37,7 +37,7 @@ from axolotl.prompters import (
)
def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path):
def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path) -> DatasetDict:
tokenizer_name = tokenizer.__class__.__name__
ds_hash = str(
md5(
@@ -196,7 +196,7 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
return dataset
def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path):
def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path) -> (Dataset, Dataset):
max_packed_sequence_len = (
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
)

View File

@@ -9,13 +9,31 @@ import torch.cuda
import transformers
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from transformers import EarlyStoppingCallback
from transformers import EarlyStoppingCallback, Trainer
from transformers.trainer_pt_utils import get_parameter_names
from axolotl.utils.schedulers import InterpolatingLogScheduler
from axolotl.utils.callbacks import SavePeftModelCallback
class OneCycleLRSchedulerTrainer(Trainer):
    """Trainer variant whose learning-rate schedule is torch's ``OneCycleLR``.

    Only ``create_scheduler`` is overridden; everything else is inherited
    from ``Trainer``.
    """

    def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
        """Build the one-cycle scheduler and register it on the trainer.

        Args:
            num_training_steps: Total number of optimizer steps; used as the
                scheduler's ``total_steps``.
            optimizer: Optimizer to schedule. Falls back to ``self.optimizer``
                when ``None``.

        Returns:
            The ``OneCycleLR`` instance, which is also stored on
            ``self.lr_scheduler``.
        """
        optimizer = self.optimizer if optimizer is None else optimizer
        num_warmup_steps = self.args.get_warmup_steps(num_training_steps)
        # OneCycleLR takes the warmup as a fraction of the whole run rather
        # than as an absolute step count.
        pct_start = num_warmup_steps / num_training_steps

        # Store the scheduler on the instance: the base Trainer keeps its
        # scheduler in ``self.lr_scheduler`` and does not rely on this
        # method's return value, so returning it alone would leave the
        # trainer without a scheduler.
        self.lr_scheduler = OneCycleLR(
            optimizer,
            max_lr=self.args.learning_rate,
            total_steps=num_training_steps,
            pct_start=pct_start,
            div_factor=6,
        )
        return self.lr_scheduler
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
@@ -63,6 +81,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
training_arguments_kwargs["fsdp"] = cfg.fsdp
if cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
# can't set optimizers directly on trainer when using fsdp, so set them here
if cfg.optimizer:
training_arguments_kwargs["optim"] = cfg.optimizer
# deepspeed
if (
@@ -119,6 +140,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
cfg.optimizer == "adamw_bnb_8bit"
and not cfg.load_4bit
and not "deepspeed" in training_arguments_kwargs
and not cfg.fsdp
):
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
@@ -194,7 +216,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
else:
data_collator_kwargs["pad_to_multiple_of"] = 8
trainer = transformers.Trainer(
trainer_cls = OneCycleLRSchedulerTrainer if cfg.lr_scheduler == "one_cycle" and cfg.fsdp else transformers.Trainer
trainer = trainer_cls(
model=model,
train_dataset=train_dataset,
eval_dataset=eval_dataset,