be able to use adam bnb 8bit and one cycle scheduler w fsdp
This commit is contained in:
@@ -7,7 +7,7 @@ from datasets import (
|
||||
load_dataset,
|
||||
IterableDataset,
|
||||
Dataset,
|
||||
concatenate_datasets,
|
||||
concatenate_datasets, DatasetDict,
|
||||
)
|
||||
from huggingface_hub import hf_hub_download
|
||||
from transformers import PreTrainedTokenizerBase
|
||||
@@ -37,7 +37,7 @@ from axolotl.prompters import (
|
||||
)
|
||||
|
||||
|
||||
def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path):
|
||||
def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path) -> DatasetDict:
|
||||
tokenizer_name = tokenizer.__class__.__name__
|
||||
ds_hash = str(
|
||||
md5(
|
||||
@@ -196,7 +196,7 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa
|
||||
return dataset
|
||||
|
||||
|
||||
def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path):
|
||||
def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path) -> (Dataset, Dataset):
|
||||
max_packed_sequence_len = (
|
||||
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
|
||||
)
|
||||
|
||||
@@ -9,13 +9,31 @@ import torch.cuda
|
||||
import transformers
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from transformers import EarlyStoppingCallback
|
||||
from transformers import EarlyStoppingCallback, Trainer
|
||||
from transformers.trainer_pt_utils import get_parameter_names
|
||||
|
||||
from axolotl.utils.schedulers import InterpolatingLogScheduler
|
||||
from axolotl.utils.callbacks import SavePeftModelCallback
|
||||
|
||||
|
||||
class OneCycleLRSchedulerTrainer(Trainer):
|
||||
def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
|
||||
optimizer=self.optimizer if optimizer is None else optimizer
|
||||
num_warmup_steps=self.args.get_warmup_steps(num_training_steps)
|
||||
num_training_steps=num_training_steps
|
||||
pct_start = num_warmup_steps / num_training_steps
|
||||
|
||||
lr_scheduler = OneCycleLR(
|
||||
optimizer,
|
||||
max_lr=self.args.learning_rate,
|
||||
total_steps=num_training_steps,
|
||||
pct_start=pct_start,
|
||||
div_factor=6,
|
||||
)
|
||||
|
||||
return lr_scheduler
|
||||
|
||||
|
||||
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
total_num_steps = int(
|
||||
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
|
||||
@@ -63,6 +81,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
training_arguments_kwargs["fsdp"] = cfg.fsdp
|
||||
if cfg.fsdp_config:
|
||||
training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
|
||||
# can't set optimizers directly on trainer when using fsdp, so set them here
|
||||
if cfg.optimizer:
|
||||
training_arguments_kwargs["optim"] = cfg.optimizer
|
||||
|
||||
# deepspeed
|
||||
if (
|
||||
@@ -119,6 +140,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
cfg.optimizer == "adamw_bnb_8bit"
|
||||
and not cfg.load_4bit
|
||||
and not "deepspeed" in training_arguments_kwargs
|
||||
and not cfg.fsdp
|
||||
):
|
||||
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
|
||||
decay_parameters = [name for name in decay_parameters if "bias" not in name]
|
||||
@@ -194,7 +216,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
else:
|
||||
data_collator_kwargs["pad_to_multiple_of"] = 8
|
||||
|
||||
trainer = transformers.Trainer(
|
||||
trainer_cls = OneCycleLRSchedulerTrainer if cfg.lr_scheduler == "one_cycle" and cfg.fsdp else transformers.Trainer
|
||||
trainer = trainer_cls(
|
||||
model=model,
|
||||
train_dataset=train_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
|
||||
Reference in New Issue
Block a user