From 9493b1b1377497c89cf0a175e0742eb22e179fe0 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 22 May 2023 09:00:49 -0400 Subject: [PATCH] be able to use adam bnb 8bit and one cycle scheduler w fsdp --- src/axolotl/utils/data.py | 6 +++--- src/axolotl/utils/trainer.py | 27 +++++++++++++++++++++++++-- 2 files changed, 28 insertions(+), 5 deletions(-) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 4d4f3c1b2..2ceaa4d99 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -7,7 +7,7 @@ from datasets import ( load_dataset, IterableDataset, Dataset, - concatenate_datasets, + concatenate_datasets, DatasetDict, ) from huggingface_hub import hf_hub_download from transformers import PreTrainedTokenizerBase @@ -37,7 +37,7 @@ from axolotl.prompters import ( ) -def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path): +def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path) -> DatasetDict: tokenizer_name = tokenizer.__class__.__name__ ds_hash = str( md5( @@ -196,7 +196,7 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa return dataset -def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path): +def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path) -> (Dataset, Dataset): max_packed_sequence_len = ( cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len ) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 7e1109708..4c6eb7626 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -9,13 +9,31 @@ import torch.cuda import transformers from torch import nn from torch.optim.lr_scheduler import OneCycleLR -from transformers import EarlyStoppingCallback +from transformers import EarlyStoppingCallback, Trainer from transformers.trainer_pt_utils import get_parameter_names from axolotl.utils.schedulers import InterpolatingLogScheduler from axolotl.utils.callbacks import SavePeftModelCallback +class OneCycleLRSchedulerTrainer(Trainer): + def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None): + optimizer=self.optimizer if optimizer is None else optimizer + num_warmup_steps=self.args.get_warmup_steps(num_training_steps) + num_training_steps=num_training_steps + pct_start = num_warmup_steps / num_training_steps + + lr_scheduler = OneCycleLR( + optimizer, + max_lr=self.args.learning_rate, + total_steps=num_training_steps, + pct_start=pct_start, + div_factor=6, + ) + + return lr_scheduler + + def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): total_num_steps = int( math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) @@ -63,6 +81,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): training_arguments_kwargs["fsdp"] = cfg.fsdp if cfg.fsdp_config: training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config) + # can't set optimizers directly on trainer when using fsdp, so set them here + if cfg.optimizer: + training_arguments_kwargs["optim"] = cfg.optimizer # deepspeed if ( @@ -119,6 +140,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): cfg.optimizer == "adamw_bnb_8bit" and not cfg.load_4bit and not "deepspeed" in training_arguments_kwargs + and not cfg.fsdp ): decay_parameters = get_parameter_names(model, [nn.LayerNorm]) decay_parameters = [name for name in decay_parameters if "bias" not in name] @@ -194,7 +216,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): else: data_collator_kwargs["pad_to_multiple_of"] = 8 - trainer = transformers.Trainer( + trainer_cls = OneCycleLRSchedulerTrainer if cfg.lr_scheduler == "one_cycle" and cfg.fsdp else transformers.Trainer + trainer = trainer_cls( model=model, train_dataset=train_dataset, eval_dataset=eval_dataset,