From 1edc30c786794ba2d57976c417378a0d27ced6eb Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 27 May 2023 17:57:29 -0400 Subject: [PATCH 01/14] add support for opimum bettertransformers --- configs/gpt_neox_20b.yml | 30 ++++++++++++++++++------------ requirements.txt | 1 + scripts/finetune.py | 15 +++++++++++---- src/axolotl/utils/models.py | 8 ++++++-- src/axolotl/utils/validation.py | 8 ++++++++ 5 files changed, 44 insertions(+), 18 deletions(-) diff --git a/configs/gpt_neox_20b.yml b/configs/gpt_neox_20b.yml index 730afb72c..25fdae53b 100644 --- a/configs/gpt_neox_20b.yml +++ b/configs/gpt_neox_20b.yml @@ -1,24 +1,25 @@ base_model: EleutherAI/gpt-neox-20b +base_model_config: EleutherAI/gpt-neox-20b base_model_ignore_patterns: pytorch* # prefer safetensors model_type: GPTNeoXForCausalLM tokenizer_type: AutoTokenizer -load_in_8bit: true +load_in_8bit: false +load_in_4bit: true +load_4bit: false datasets: - - path: nomic-ai/gpt4all-j-prompt-generations + - path: vicgalle/alpaca-gpt4 type: alpaca - shards: 4 - shards_index: 0 dataset_prepared_path: last_run_prepared val_set_size: 0.05 -adapter: lora +adapter: lora_model_dir: sequence_len: 2048 max_packed_sequence_len: 2048 -lora_r: 8 +lora_r: 64 lora_alpha: 32 -lora_dropout: 0.05 +lora_dropout: 0.0 lora_target_modules: - - query_key_value +lora_target_linear: true lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific wandb_project: gpt4all-neox-20b wandb_watch: @@ -26,14 +27,19 @@ wandb_run_id: wandb_log_model: output_dir: ./gpt4all-neox-20b gradient_accumulation_steps: 1 -micro_batch_size: 4 +micro_batch_size: 2 num_epochs: 5 learning_rate: 0.00003 -lr_scheduler: one_cycle +optimizer: paged_adamw_32bit +lr_scheduler: cosine train_on_inputs: false group_by_length: false -bf16: True -tf32: True +bf16: false +fp16: false +float16: true +tf32: true +flash_optimum: true early_stopping_patience: resume_from_checkpoint: local_rank: +gradient_checkpointing: true diff --git a/requirements.txt b/requirements.txt index c9123fce8..d1b2f4555 100644 --- a/requirements.txt +++ b/requirements.txt @@ -11,6 +11,7 @@ sentencepiece wandb einops xformers +optimum # qlora things bert-score==0.3.13 evaluate==0.4.0 diff --git a/scripts/finetune.py b/scripts/finetune.py index fa2dcf903..a5b5e7c85 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -6,6 +6,7 @@ import os import random import signal import sys +from functools import partial from pathlib import Path from typing import Any, Dict, List, Optional, Union @@ -19,6 +20,8 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.models import load_model, load_tokenizer # add src to the pythonpath so we don't need to pip install this +from optimum.bettertransformer import BetterTransformer + from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.trainer import setup_trainer from axolotl.utils.validation import validate_config @@ -264,12 +267,14 @@ def train( # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model if cfg.local_rank == 0: + def terminate_handler(signum, frame, model): + if cfg.flash_optimum: + model = BetterTransformer.reverse(model) + model.save_pretrained(cfg.output_dir) + sys.exit(0) signal.signal( signal.SIGINT, - lambda signal, frame: ( - model.save_pretrained(cfg.output_dir), - sys.exit(0), - ), + lambda signum, frame: terminate_handler(signum, frame, model) ) logging.info("Starting trainer...") @@ -299,6 +304,8 @@ def train( # TODO do we need this fix? 
https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file if cfg.local_rank == 0: + if cfg.flash_optimum: + model = BetterTransformer.reverse(model) model.save_pretrained(cfg.output_dir) # trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 1acaf6ab3..11b4629ec 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -11,7 +11,8 @@ import bitsandbytes as bnb import torch import transformers from transformers import PreTrainedModel # noqa: F401 -from transformers import ( # noqa: F401 +from optimum.bettertransformer import BetterTransformer +from transformers import ( AutoConfig, AutoModelForCausalLM, AutoTokenizer, @@ -137,7 +138,7 @@ def load_model( if cfg.bf16: torch_dtype = torch.bfloat16 - elif cfg.load_in_8bit or cfg.fp16: + elif cfg.load_in_8bit or cfg.fp16 or cfg.float16: torch_dtype = torch.float16 else: torch_dtype = torch.float32 @@ -342,6 +343,9 @@ def load_model( logging.warning("there are no parameters that require gradient updates") model.config.use_cache = False + if cfg.flash_optimum: + model = BetterTransformer.transform(model) + # TODO resume_from_checkpoint handling return model, lora_config diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index 04ffc4c1b..ba5feafe8 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -57,6 +57,14 @@ def validate_config(cfg): if (cfg.base_model and "falcon" in cfg.base_model.lower()) and cfg.fsdp: raise ValueError("FSDP is not supported for falcon models") + if cfg.flash_optimum is True: + if cfg.adapter: + logging.warning("BetterTransformers probably doesn't work with PEFT adapters") + if cfg.fp16 or cfg.bf16: + raise ValueError("AMP is not supported with BetterTransformer") + if cfg.float16 is not True: + logging.warning("You should probably set float16 to true") + # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25 From 879219979955fa2c3a2394578a8886f77e687594 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 27 May 2023 18:12:12 -0400 Subject: [PATCH 02/14] add flash attn context for efficient training and attempt setting model to train mode: --- scripts/finetune.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/scripts/finetune.py b/scripts/finetune.py index a5b5e7c85..99236b087 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -252,6 +252,24 @@ def train( model.save_pretrained(cfg.output_dir) return + if cfg.debug: + logging.info("check_dataset_labels...") + check_dataset_labels( + train_dataset.select( + [random.randrange(0, len(train_dataset) - 1) for i in range(5)] + ), + tokenizer, + ) + + if prepare_ds_only: + logging.info("Finished preparing dataset. 
Exiting...") + return + + try: + model.train() + except: + pass + trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer) model.config.use_cache = False @@ -297,7 +315,11 @@ def train( if not Path(cfg.output_dir).is_dir(): os.makedirs(cfg.output_dir, exist_ok=True) - trainer.train(resume_from_checkpoint=resume_from_checkpoint) + if cfg.flash_optimum: + with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True): + trainer.train(resume_from_checkpoint=resume_from_checkpoint) + else: + trainer.train(resume_from_checkpoint=resume_from_checkpoint) logging.info(f"Training Completed!!! Saving pre-trained model to {cfg.output_dir}") From 39619028a37f4af77dd0b89c9b8191c783d7049a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 27 May 2023 19:37:24 -0400 Subject: [PATCH 03/14] use pythia-12b, neox-20b is flaky --- examples/pythia-12b/README.md | 10 ++++++++++ .../pythia-12b/config.yml | 20 +++++++++++-------- 2 files changed, 22 insertions(+), 8 deletions(-) create mode 100644 examples/pythia-12b/README.md rename configs/gpt_neox_20b.yml => examples/pythia-12b/config.yml (72%) diff --git a/examples/pythia-12b/README.md b/examples/pythia-12b/README.md new file mode 100644 index 000000000..0953caa4e --- /dev/null +++ b/examples/pythia-12b/README.md @@ -0,0 +1,10 @@ +# Python 12B + +- Single-GPU A100 only (?) + +```shell +python scripts/finetune.py examples/pythia-12b/config.yml +``` + +⚠️ Multiple-GPU A100 - Doesn't seem to work with multi-gpu without causing OOM! ⚠️ + diff --git a/configs/gpt_neox_20b.yml b/examples/pythia-12b/config.yml similarity index 72% rename from configs/gpt_neox_20b.yml rename to examples/pythia-12b/config.yml index 25fdae53b..28e822c77 100644 --- a/configs/gpt_neox_20b.yml +++ b/examples/pythia-12b/config.yml @@ -1,11 +1,12 @@ -base_model: EleutherAI/gpt-neox-20b -base_model_config: EleutherAI/gpt-neox-20b +base_model: EleutherAI/pythia-12b-deduped +base_model_config: EleutherAI/pythia-12b-deduped base_model_ignore_patterns: pytorch* # prefer safetensors model_type: GPTNeoXForCausalLM tokenizer_type: AutoTokenizer load_in_8bit: false -load_in_4bit: true -load_4bit: false +load_in_4bit: false +gptq: false +device_map: auto datasets: - path: vicgalle/alpaca-gpt4 type: alpaca @@ -21,16 +22,16 @@ lora_dropout: 0.0 lora_target_modules: lora_target_linear: true lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific -wandb_project: gpt4all-neox-20b +wandb_project: pythia-12b wandb_watch: wandb_run_id: wandb_log_model: -output_dir: ./gpt4all-neox-20b +output_dir: ./pythia-12b gradient_accumulation_steps: 1 -micro_batch_size: 2 +micro_batch_size: 1 num_epochs: 5 learning_rate: 0.00003 -optimizer: paged_adamw_32bit +optimizer: adamw_bnb_8bit lr_scheduler: cosine train_on_inputs: false group_by_length: false @@ -43,3 +44,6 @@ early_stopping_patience: resume_from_checkpoint: local_rank: gradient_checkpointing: true +fsdp: +fsdp_transformer_layer_cls_to_wrap: +collator_pad_to_longest: true From 71a43f8479a1cef0247ceb2cc00c7c1a048ed863 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 28 May 2023 08:56:08 -0400 Subject: [PATCH 04/14] add validation/warning for bettertransformers and torch version --- src/axolotl/utils/validation.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index ba5feafe8..db19900cc 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -1,7 +1,7 @@ """Module for validating 
config files""" import logging - +import torch def validate_config(cfg): if cfg.gradient_accumulation_steps and cfg.batch_size: @@ -63,7 +63,10 @@ def validate_config(cfg): if cfg.fp16 or cfg.bf16: raise ValueError("AMP is not supported with BetterTransformer") if cfg.float16 is not True: - logging.warning("You should probably set float16 to true") + logging.warning("You should probably set float16 to true to load the model in float16 for BetterTransformers") + if torch.__version__.split(".")[0] < 2: + logging.warning("torch>=2.0.0 required") + raise ValueError(f"flash_optimum for BetterTransformers may not be used with {torch.__version__}") # TODO # MPT 7b From 488a67d75a4a6ccf7ed0862bbe913a356a473b0d Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 31 May 2023 16:51:19 -0400 Subject: [PATCH 05/14] experimental expansion of ctx len --- scripts/finetune.py | 44 +++++++++++++++++++++++---------------- src/axolotl/utils/data.py | 32 +++++++++++++++++++++++++++- 2 files changed, 57 insertions(+), 19 deletions(-) diff --git a/scripts/finetune.py b/scripts/finetune.py index 99236b087..88815dfdd 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -6,22 +6,20 @@ import os import random import signal import sys -from functools import partial from pathlib import Path from typing import Any, Dict, List, Optional, Union import fire import torch import yaml -from transformers import GenerationConfig, TextStreamer - -from axolotl.utils.data import load_prepare_datasets -from axolotl.utils.dict import DictDefault -from axolotl.utils.models import load_model, load_tokenizer # add src to the pythonpath so we don't need to pip install this from optimum.bettertransformer import BetterTransformer +from transformers import GenerationConfig, TextStreamer +from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset +from axolotl.utils.dict import DictDefault +from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.trainer import setup_trainer from axolotl.utils.validation import validate_config @@ -204,9 +202,19 @@ def train( if check_not_in( ["inference", "shard", "merge_lora"], kwargs ): # don't need to load dataset for these - train_dataset, eval_dataset = load_prepare_datasets( - tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH - ) + if not cfg.pretraining_dataset: + train_dataset, eval_dataset = load_prepare_datasets( + tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH + ) + else: + if cfg.pretraining_dataset is True: + pretraining_dataset = "togethercomputer/RedPajama-Data-1T" + else: + pretraining_dataset = cfg.pretraining_dataset + train_dataset = load_pretraining_dataset( + pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len + ) + eval_dataset = None if cfg.debug or "debug" in kwargs: logging.info("check_dataset_labels...") @@ -256,7 +264,7 @@ def train( logging.info("check_dataset_labels...") check_dataset_labels( train_dataset.select( - [random.randrange(0, len(train_dataset) - 1) for i in range(5)] + [random.randrange(0, len(train_dataset) - 1) for i in range(5)] # nosec ), tokenizer, ) @@ -265,10 +273,7 @@ def train( logging.info("Finished preparing dataset. 
Exiting...") return - try: - model.train() - except: - pass + model.train() trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer) @@ -285,14 +290,15 @@ def train( # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model if cfg.local_rank == 0: - def terminate_handler(signum, frame, model): + + def terminate_handler(_, __, model): if cfg.flash_optimum: model = BetterTransformer.reverse(model) model.save_pretrained(cfg.output_dir) sys.exit(0) + signal.signal( - signal.SIGINT, - lambda signum, frame: terminate_handler(signum, frame, model) + signal.SIGINT, lambda signum, frame: terminate_handler(signum, frame, model) ) logging.info("Starting trainer...") @@ -316,7 +322,9 @@ def train( if not Path(cfg.output_dir).is_dir(): os.makedirs(cfg.output_dir, exist_ok=True) if cfg.flash_optimum: - with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True): + with torch.backends.cuda.sdp_kernel( + enable_flash=True, enable_math=True, enable_mem_efficient=True + ): trainer.train(resume_from_checkpoint=resume_from_checkpoint) else: trainer.train(resume_from_checkpoint=resume_from_checkpoint) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index cba964076..49314372a 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -5,7 +5,8 @@ from hashlib import md5 from pathlib import Path from typing import List, Tuple, Union -from datasets import Dataset, DatasetDict, load_dataset, load_from_disk +import torch +from datasets import Dataset, DatasetDict, IterableDataset, load_dataset, load_from_disk from huggingface_hub import hf_hub_download from transformers import PreTrainedTokenizerBase @@ -392,3 +393,32 @@ def load_prepare_datasets( eval_dataset = dataset["test"] return train_dataset, eval_dataset + + +class PretrainingDatasetWrapper(IterableDataset): + """ + Wrapper for pretraining dataset that avoids loading the dataset into memory + """ + + def __init__(self, tokenizer, dataset_path, max_tokens=2048): + self.tokenizer = tokenizer + self.dataset_path = dataset_path + self.max_tokens = max_tokens + + def __iter__(self): + buffer = [] + for sample in load_dataset( + self.dataset_path, + name="all", + split="train", + streaming=True, + ).shuffle(buffer_size=10000): + buffer += self.tokenizer(sample["text"])["input_ids"] + buffer += [self.tokenizer.eos_token_id] + while len(buffer) > self.max_tokens: + yield torch.tensor(buffer[: self.max_tokens]) + buffer = buffer[self.max_tokens :] + + +def load_pretraining_dataset(path, tokenizer, max_tokens=2048): + return PretrainingDatasetWrapper(tokenizer, path, max_tokens=max_tokens) From 1210dc8fd5c494face7165338f1ed9f2981a2245 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 31 May 2023 21:59:15 -0400 Subject: [PATCH 06/14] more tweaks to do pre-training with bettertransformers --- scripts/finetune.py | 2 ++ src/axolotl/utils/callbacks.py | 24 ++++++++++++++++++++++++ src/axolotl/utils/data.py | 12 +++++++----- src/axolotl/utils/models.py | 4 ++-- src/axolotl/utils/trainer.py | 8 +++++++- src/axolotl/utils/validation.py | 16 ++++++++++++---- 6 files changed, 54 insertions(+), 12 deletions(-) diff --git a/scripts/finetune.py b/scripts/finetune.py index 88815dfdd..9bed61ca4 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -14,6 +14,7 @@ import torch import yaml # add src to the pythonpath so we don't need to pip install this +from datasets import Dataset from optimum.bettertransformer import BetterTransformer from 
transformers import GenerationConfig, TextStreamer @@ -214,6 +215,7 @@ def train( train_dataset = load_pretraining_dataset( pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len ) + train_dataset = Dataset.from_list(list(train_dataset)) eval_dataset = None if cfg.debug or "debug" in kwargs: diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index f6852249a..ab197304c 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -2,6 +2,7 @@ import os +from optimum.bettertransformer import BetterTransformer from transformers import ( TrainerCallback, TrainerControl, @@ -30,3 +31,26 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public- kwargs["model"].save_pretrained(peft_model_path) return control + + +class SaveBetterTransformerModelCallback( + TrainerCallback +): # pylint: disable=too-few-public-methods + """Callback to save the BatterTransformer wrapped model""" + + def on_save( + self, + args: TrainingArguments, + state: TrainerState, + control: TrainerControl, + **kwargs, + ): + checkpoint_folder = os.path.join( + args.output_dir, + f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", + ) + + model = BetterTransformer.reverse(kwargs["model"]) + model.save_pretrained(checkpoint_folder) + + return control diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 49314372a..164296ee2 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -409,14 +409,16 @@ class PretrainingDatasetWrapper(IterableDataset): buffer = [] for sample in load_dataset( self.dataset_path, - name="all", - split="train", - streaming=True, - ).shuffle(buffer_size=10000): + )["train"].shuffle(): buffer += self.tokenizer(sample["text"])["input_ids"] buffer += [self.tokenizer.eos_token_id] while len(buffer) > self.max_tokens: - yield torch.tensor(buffer[: self.max_tokens]) + input_ids = torch.tensor(buffer[: self.max_tokens]) + yield { + "input_ids": input_ids, + "attention_mask": torch.ones(input_ids.size()), + "labels": input_ids, + } buffer = buffer[self.max_tokens :] diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 11b4629ec..91ef96ca9 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -10,8 +10,8 @@ from typing import TYPE_CHECKING, Optional, Tuple # noqa: F401 import bitsandbytes as bnb import torch import transformers -from transformers import PreTrainedModel # noqa: F401 from optimum.bettertransformer import BetterTransformer +from transformers import PreTrainedModel # noqa: F401 from transformers import ( AutoConfig, AutoModelForCausalLM, @@ -136,7 +136,7 @@ def load_model( logging.info("patching with xpos rope") replace_llama_rope_with_xpos_rope() - if cfg.bf16: + if cfg.bf16 or cfg.bfloat16: torch_dtype = torch.bfloat16 elif cfg.load_in_8bit or cfg.fp16 or cfg.float16: torch_dtype = torch.float16 diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 9ae1e7e93..b7823fea4 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -16,7 +16,10 @@ from torch.optim.lr_scheduler import OneCycleLR from transformers import EarlyStoppingCallback, Trainer from transformers.trainer_pt_utils import get_parameter_names -from axolotl.utils.callbacks import SavePeftModelCallback +from axolotl.utils.callbacks import ( + SaveBetterTransformerModelCallback, + SavePeftModelCallback, +) from axolotl.utils.schedulers import InterpolatingLogScheduler @@ -228,6 +231,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, 
model, tokenizer): ]: # only save in rank 0 callbacks.append(SavePeftModelCallback) + if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True: + callbacks.append(SaveBetterTransformerModelCallback) + data_collator_kwargs = { "padding": True, } diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index db19900cc..abaaba8d0 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -1,8 +1,10 @@ """Module for validating config files""" import logging + import torch + def validate_config(cfg): if cfg.gradient_accumulation_steps and cfg.batch_size: raise ValueError( @@ -59,14 +61,20 @@ def validate_config(cfg): if cfg.flash_optimum is True: if cfg.adapter: - logging.warning("BetterTransformers probably doesn't work with PEFT adapters") + logging.warning( + "BetterTransformers probably doesn't work with PEFT adapters" + ) if cfg.fp16 or cfg.bf16: raise ValueError("AMP is not supported with BetterTransformer") if cfg.float16 is not True: - logging.warning("You should probably set float16 to true to load the model in float16 for BetterTransformers") - if torch.__version__.split(".")[0] < 2: + logging.warning( + "You should probably set float16 to true to load the model in float16 for BetterTransformers" + ) + if int(torch.__version__.split(".")[0]) < 2: logging.warning("torch>=2.0.0 required") - raise ValueError(f"flash_optimum for BetterTransformers may not be used with {torch.__version__}") + raise ValueError( + f"flash_optimum for BetterTransformers may not be used with {torch.__version__}" + ) # TODO # MPT 7b From 1a82082e91127fedae540cfbc9e68ce2b3ef08a4 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 1 Jun 2023 00:33:13 -0400 Subject: [PATCH 07/14] fix bettertransformers save, force it to skip after saving correctly in callback --- src/axolotl/utils/callbacks.py | 30 +++++++++++++++++++++--------- src/axolotl/utils/trainer.py | 1 + src/axolotl/utils/validation.py | 5 +++-- 3 files changed, 25 insertions(+), 11 deletions(-) diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index ab197304c..64bf48664 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -9,7 +9,7 @@ from transformers import ( TrainerState, TrainingArguments, ) -from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR +from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public-methods @@ -36,21 +36,33 @@ class SavePeftModelCallback(TrainerCallback): # pylint: disable=too-few-public- class SaveBetterTransformerModelCallback( TrainerCallback ): # pylint: disable=too-few-public-methods - """Callback to save the BatterTransformer wrapped model""" + """Callback to save the BetterTransformer wrapped model""" - def on_save( + def on_step_end( self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs, ): - checkpoint_folder = os.path.join( - args.output_dir, - f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", - ) + # Save + if ( + args.save_strategy == IntervalStrategy.STEPS + and args.save_steps > 0 + and state.global_step % args.save_steps == 0 + ): + control.should_save = True - model = BetterTransformer.reverse(kwargs["model"]) - model.save_pretrained(checkpoint_folder) + if control.should_save: + checkpoint_folder = os.path.join( + args.output_dir, + f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}", + ) + model = BetterTransformer.reverse(kwargs["model"]) + 
model.save_pretrained(checkpoint_folder) + + # since we're saving here, we don't need the trainer loop to attempt to save too b/c + # the trainer will raise an exception since it can't save a BetterTransformer wrapped model + control.should_save = False return control diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index b7823fea4..59b1dc803 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -232,6 +232,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): callbacks.append(SavePeftModelCallback) if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True: + logging.info("Setting up SaveBetterTransformerModelCallback.") callbacks.append(SaveBetterTransformerModelCallback) data_collator_kwargs = { diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index abaaba8d0..396036621 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -66,9 +66,10 @@ def validate_config(cfg): ) if cfg.fp16 or cfg.bf16: raise ValueError("AMP is not supported with BetterTransformer") - if cfg.float16 is not True: + if cfg.float16 is not True and cfg.bloat16 is not True: logging.warning( - "You should probably set float16 to true to load the model in float16 for BetterTransformers" + "You should probably set bfloat16 or float16 to true to " + "load the model in float16 for BetterTransformers" ) if int(torch.__version__.split(".")[0]) < 2: logging.warning("torch>=2.0.0 required") From ab5cd28acfd12304201c4c184aa03a5ac3885ce2 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 1 Jun 2023 08:20:08 -0400 Subject: [PATCH 08/14] more gpt-neox long ctx fixes --- src/axolotl/utils/callbacks.py | 1 + src/axolotl/utils/data.py | 10 +++++++--- src/axolotl/utils/models.py | 6 ++++++ src/axolotl/utils/validation.py | 9 ++++++++- 4 files changed, 22 insertions(+), 4 deletions(-) diff --git a/src/axolotl/utils/callbacks.py b/src/axolotl/utils/callbacks.py index 64bf48664..526121f2e 100644 --- a/src/axolotl/utils/callbacks.py +++ b/src/axolotl/utils/callbacks.py @@ -61,6 +61,7 @@ class SaveBetterTransformerModelCallback( model = BetterTransformer.reverse(kwargs["model"]) model.save_pretrained(checkpoint_folder) + # FIXME - need to cleanup old checkpoints # since we're saving here, we don't need the trainer loop to attempt to save too b/c # the trainer will raise an exception since it can't save a BetterTransformer wrapped model diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 164296ee2..13ad7c75d 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -388,9 +388,13 @@ def load_prepare_datasets( index=cfg.dataset_shard_idx, ) - dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False) - train_dataset = dataset["train"] - eval_dataset = dataset["test"] + if cfg.val_set_size: + dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False) + train_dataset = dataset["train"] + eval_dataset = dataset["test"] + else: + train_dataset = dataset + eval_dataset = None return train_dataset, eval_dataset diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 91ef96ca9..49a9b6f85 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -300,6 +300,12 @@ def load_model( embeddings_len = math.ceil(len(tokenizer) / 32) * 32 model.resize_token_embeddings(embeddings_len) + if cfg.sequence_len >= model.config.max_position_embeddings: + logging.warning( + f"increasing 
model.config.max_position_embeddings to {cfg.sequence_len}" + ) + model.config.max_position_embeddings = cfg.sequence_len + if not cfg.gptq and ( (cfg.adapter == "lora" and load_in_8bit) or (cfg.adapter == "qlora" and cfg.load_in_4bit) diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index 396036621..2e2450fba 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -80,4 +80,11 @@ def validate_config(cfg): # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25 - # no 8bit adamw w bf16 + # no 8bit adaAmw w bf16 + + # GPT-NeoX + # evals broken when extending context len + # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/transformers/models/gpt_neox/modeling_gpt_neox.py", line 162, in forward attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + # File "/root/miniconda3/envs/py3.9/lib/python3.9/site-packages/optimum/bettertransformer/models/attention.py", line 74, in gpt2_wrapped_scaled_dot_product + # attention_mask = causal_mask + attention_mask + # RuntimeError: The size of tensor a (2048) must match the size of tensor b (8132) at non-singleton dimension 3 From 1db46a9c720d60113ff2828ab6de219e1b857c79 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 8 Jun 2023 22:05:06 -0400 Subject: [PATCH 09/14] linting fix --- examples/pythia-12b/README.md | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/pythia-12b/README.md b/examples/pythia-12b/README.md index 0953caa4e..d28d5e77d 100644 --- a/examples/pythia-12b/README.md +++ b/examples/pythia-12b/README.md @@ -7,4 +7,3 @@ python scripts/finetune.py examples/pythia-12b/config.yml ``` ⚠️ Multiple-GPU A100 - Doesn't seem to work with multi-gpu without causing OOM! ⚠️ - From eea2731a5ebc113e769aa2a57af9b96effed2053 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 9 Jun 2023 20:25:38 -0400 Subject: [PATCH 10/14] add streaming dataset support for pretraining datasets --- README.md | 2 + scripts/finetune.py | 23 +----- src/axolotl/utils/data.py | 136 ++++++++++++++++++++++++++------ src/axolotl/utils/validation.py | 5 ++ tests/test_validation.py | 51 ++++++++++++ 5 files changed, 171 insertions(+), 46 deletions(-) diff --git a/README.md b/README.md index de929f237..2bc55732d 100644 --- a/README.md +++ b/README.md @@ -410,6 +410,8 @@ optimizer: # specify weight decay weight_decay: +# whether to bettertransformers +flash_optimum: # whether to use xformers attention patch https://github.com/facebookresearch/xformers: xformers_attention: # whether to use flash attention patch https://github.com/HazyResearch/flash-attention: diff --git a/scripts/finetune.py b/scripts/finetune.py index 9bed61ca4..ab226f68f 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -14,7 +14,6 @@ import torch import yaml # add src to the pythonpath so we don't need to pip install this -from datasets import Dataset from optimum.bettertransformer import BetterTransformer from transformers import GenerationConfig, TextStreamer @@ -208,14 +207,11 @@ def train( tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH ) else: - if cfg.pretraining_dataset is True: - pretraining_dataset = "togethercomputer/RedPajama-Data-1T" - else: - pretraining_dataset = cfg.pretraining_dataset train_dataset = load_pretraining_dataset( - pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len + cfg.pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len ) - train_dataset = Dataset.from_list(list(train_dataset)) + # 
https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 + train_dataset = train_dataset.with_format("torch") eval_dataset = None if cfg.debug or "debug" in kwargs: @@ -262,19 +258,6 @@ def train( model.save_pretrained(cfg.output_dir) return - if cfg.debug: - logging.info("check_dataset_labels...") - check_dataset_labels( - train_dataset.select( - [random.randrange(0, len(train_dataset) - 1) for i in range(5)] # nosec - ), - tokenizer, - ) - - if prepare_ds_only: - logging.info("Finished preparing dataset. Exiting...") - return - model.train() trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 13ad7c75d..492d8059b 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -1,12 +1,12 @@ """Module containing data utilities""" - +import functools import logging from hashlib import md5 from pathlib import Path from typing import List, Tuple, Union import torch -from datasets import Dataset, DatasetDict, IterableDataset, load_dataset, load_from_disk +from datasets import Dataset, DatasetDict, load_dataset, load_from_disk from huggingface_hub import hf_hub_download from transformers import PreTrainedTokenizerBase @@ -399,32 +399,116 @@ def load_prepare_datasets( return train_dataset, eval_dataset -class PretrainingDatasetWrapper(IterableDataset): - """ - Wrapper for pretraining dataset that avoids loading the dataset into memory - """ +def encode_pretraining(tokenizer, max_tokens, examples): + res = tokenizer( + examples["text"], + truncation=True, + max_length=max_tokens - 2, + add_special_tokens=True, + ) + # Convert to PyTorch tensors + input_ids = [torch.tensor(seq) for seq in res["input_ids"]] + attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]] + new_input_ids = [] + new_attention_mask = [] + # Append EOS and PAD tokens to input_ids, and correct attention_mask + for i, _ in enumerate(input_ids): + input_ids[i] = torch.cat( + ( + input_ids[i], + torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]), + ), + dim=0, + ) + attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0) - def __init__(self, tokenizer, dataset_path, max_tokens=2048): - self.tokenizer = tokenizer - self.dataset_path = dataset_path - self.max_tokens = max_tokens + # Concatenate tokens so that their lengths are less than max_tokens + buffer_input_ids = torch.tensor([], dtype=torch.long) + buffer_attention_mask = torch.tensor([], dtype=torch.long) - def __iter__(self): - buffer = [] - for sample in load_dataset( - self.dataset_path, - )["train"].shuffle(): - buffer += self.tokenizer(sample["text"])["input_ids"] - buffer += [self.tokenizer.eos_token_id] - while len(buffer) > self.max_tokens: - input_ids = torch.tensor(buffer[: self.max_tokens]) - yield { - "input_ids": input_ids, - "attention_mask": torch.ones(input_ids.size()), - "labels": input_ids, - } - buffer = buffer[self.max_tokens :] + for ids, mask in zip(input_ids, attention_mask): + if buffer_input_ids.numel() == max_tokens: + new_input_ids.append(buffer_input_ids) + new_attention_mask.append(buffer_attention_mask) + buffer_input_ids = torch.tensor([], dtype=torch.long) + buffer_attention_mask = torch.tensor([], dtype=torch.long) + buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0) + buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0) + elif buffer_input_ids.numel() + ids.numel() <= 
max_tokens: + buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0) + buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0) + else: + buffer_input_ids = torch.cat( + ( + buffer_input_ids, + torch.full( + (max_tokens - buffer_input_ids.numel(),), + tokenizer.pad_token_id, + dtype=torch.long, + ), + ), + dim=0, + ) + buffer_attention_mask = torch.cat( + ( + buffer_attention_mask, + torch.full( + (max_tokens - buffer_attention_mask.numel(),), + 0, + dtype=torch.long, + ), + ), + dim=0, + ) + new_input_ids.append(buffer_input_ids) + new_attention_mask.append(buffer_attention_mask) + buffer_input_ids = torch.tensor([], dtype=torch.long) + buffer_attention_mask = torch.tensor([], dtype=torch.long) + + buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0) + buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0) + + if buffer_input_ids.numel() > 0: # for any leftover tokens + while buffer_input_ids.numel() < max_tokens: # make all sequences equal in size + buffer_input_ids = torch.cat( + ( + buffer_input_ids, + torch.full( + (max_tokens - buffer_input_ids.numel(),), + tokenizer.pad_token_id, + dtype=torch.long, + ), + ), + dim=0, + ) + buffer_attention_mask = torch.cat( + ( + buffer_attention_mask, + torch.full( + (max_tokens - buffer_attention_mask.numel(),), + 0, + dtype=torch.long, + ), + ), + dim=0, + ) + new_input_ids.append(buffer_input_ids) + new_attention_mask.append(buffer_attention_mask) + + ret = { + "input_ids": [seq.tolist() for seq in new_input_ids], + "labels": [seq.tolist() for seq in new_input_ids], + "attention_mask": [seq.tolist() for seq in new_attention_mask], + } + + logging.debug(len(ret["input_ids"])) + return ret def load_pretraining_dataset(path, tokenizer, max_tokens=2048): - return PretrainingDatasetWrapper(tokenizer, path, max_tokens=max_tokens) + encode = functools.partial(encode_pretraining, tokenizer, max_tokens) + dataset = load_dataset(path, streaming=True, split="train") + dataset = dataset.shuffle(seed=42, buffer_size=10_000) + # TODO dynamically figure out which columns/features to remove + dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"]) + return dataset diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py index 2e2450fba..603afbfee 100644 --- a/src/axolotl/utils/validation.py +++ b/src/axolotl/utils/validation.py @@ -77,6 +77,11 @@ def validate_config(cfg): f"flash_optimum for BetterTransformers may not be used with {torch.__version__}" ) + if cfg.pretraining_dataset and cfg.group_by_length: + logging.warning( + "You probably want to disable group_by_length as it will force a streamed dataset to download completely." 
+ ) + # TODO # MPT 7b # https://github.com/facebookresearch/bitsandbytes/issues/25 diff --git a/tests/test_validation.py b/tests/test_validation.py index 50bdf37e6..575392ab4 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -198,3 +198,54 @@ class ValidationTest(unittest.TestCase): ) validate_config(cfg) + + def test_flash_optimum(self): + cfg = DictDefault( + { + "flash_optimum": True, + "adapter": "lora", + } + ) + + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert any( + "BetterTransformers probably doesn't work with PEFT adapters" + in record.message + for record in self._caplog.records + ) + + cfg = DictDefault( + { + "flash_optimum": True, + } + ) + + with self._caplog.at_level(logging.WARNING): + validate_config(cfg) + assert any( + "probably set bfloat16 or float16" in record.message + for record in self._caplog.records + ) + + cfg = DictDefault( + { + "flash_optimum": True, + "fp16": True, + } + ) + regex_exp = r".*AMP is not supported.*" + + with pytest.raises(ValueError, match=regex_exp): + validate_config(cfg) + + cfg = DictDefault( + { + "flash_optimum": True, + "bf16": True, + } + ) + regex_exp = r".*AMP is not supported.*" + + with pytest.raises(ValueError, match=regex_exp): + validate_config(cfg) From 0c6f928601ac289f7d4b513855feab5047cd7a5a Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 10 Jun 2023 14:21:43 -0400 Subject: [PATCH 11/14] address PR feedback --- examples/pythia-12b/README.md | 2 +- examples/pythia-12b/config.yml | 4 ++-- scripts/finetune.py | 5 ++++- src/axolotl/utils/data.py | 4 ++-- src/axolotl/utils/trainer.py | 2 -- 5 files changed, 9 insertions(+), 8 deletions(-) diff --git a/examples/pythia-12b/README.md b/examples/pythia-12b/README.md index d28d5e77d..123ffa710 100644 --- a/examples/pythia-12b/README.md +++ b/examples/pythia-12b/README.md @@ -1,4 +1,4 @@ -# Python 12B +# Pythia 12B - Single-GPU A100 only (?) 
diff --git a/examples/pythia-12b/config.yml b/examples/pythia-12b/config.yml index 28e822c77..3b3d91630 100644 --- a/examples/pythia-12b/config.yml +++ b/examples/pythia-12b/config.yml @@ -22,7 +22,7 @@ lora_dropout: 0.0 lora_target_modules: lora_target_linear: true lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific -wandb_project: pythia-12b +wandb_project: wandb_watch: wandb_run_id: wandb_log_model: @@ -45,5 +45,5 @@ resume_from_checkpoint: local_rank: gradient_checkpointing: true fsdp: -fsdp_transformer_layer_cls_to_wrap: +fsdp_config: collator_pad_to_longest: true diff --git a/scripts/finetune.py b/scripts/finetune.py index ab226f68f..47aada411 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -208,7 +208,10 @@ def train( ) else: train_dataset = load_pretraining_dataset( - cfg.pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len + cfg.pretraining_dataset, + tokenizer, + max_tokens=cfg.sequence_len, + seed=cfg.seed, ) # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 train_dataset = train_dataset.with_format("torch") diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 492d8059b..058c24bcd 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -505,10 +505,10 @@ def encode_pretraining(tokenizer, max_tokens, examples): return ret -def load_pretraining_dataset(path, tokenizer, max_tokens=2048): +def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42): encode = functools.partial(encode_pretraining, tokenizer, max_tokens) dataset = load_dataset(path, streaming=True, split="train") - dataset = dataset.shuffle(seed=42, buffer_size=10_000) + dataset = dataset.shuffle(seed=seed, buffer_size=10_000) # TODO dynamically figure out which columns/features to remove dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"]) return dataset diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 59b1dc803..57a08aa53 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -1,7 +1,6 @@ """Module containing the Trainer class and related functions""" import importlib -import logging import math import os import sys @@ -232,7 +231,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): callbacks.append(SavePeftModelCallback) if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True: - logging.info("Setting up SaveBetterTransformerModelCallback.") callbacks.append(SaveBetterTransformerModelCallback) data_collator_kwargs = { From 759e8673ce497125da5855a173fd80f57bb071b3 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 10 Jun 2023 14:25:21 -0400 Subject: [PATCH 12/14] Update scripts/finetune.py Co-authored-by: NanoCode012 --- scripts/finetune.py | 1 - 1 file changed, 1 deletion(-) diff --git a/scripts/finetune.py b/scripts/finetune.py index 47aada411..cd9234334 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -261,7 +261,6 @@ def train( model.save_pretrained(cfg.output_dir) return - model.train() trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer) From 958da703762b7759eabdaa6fd7fad231228e1ad9 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 10 Jun 2023 15:28:08 -0400 Subject: [PATCH 13/14] fix formatting --- scripts/finetune.py | 1 - src/axolotl/utils/trainer.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/finetune.py b/scripts/finetune.py index 
cd9234334..2f6bef3ef 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -261,7 +261,6 @@ def train(
         model.save_pretrained(cfg.output_dir)
         return
 
-
     trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
 
     model.config.use_cache = False
diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py
index 57a08aa53..b7823fea4 100644
--- a/src/axolotl/utils/trainer.py
+++ b/src/axolotl/utils/trainer.py
@@ -1,6 +1,7 @@
 """Module containing the Trainer class and related functions"""
 
 import importlib
+import logging
 import math
 import os
 import sys

From c9a149f9e8bacdcd59a9e6de435499b2f4a845c1 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Sun, 11 Jun 2023 10:11:17 -0400
Subject: [PATCH 14/14] add check for attr

---
 src/axolotl/utils/models.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 49a9b6f85..532fa5518 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -300,7 +300,10 @@ def load_model(
     embeddings_len = math.ceil(len(tokenizer) / 32) * 32
     model.resize_token_embeddings(embeddings_len)
 
-    if cfg.sequence_len >= model.config.max_position_embeddings:
+    if (
+        hasattr(model.config, "max_position_embeddings")
+        and cfg.sequence_len >= model.config.max_position_embeddings
+    ):
         logging.warning(
             f"increasing model.config.max_position_embeddings to {cfg.sequence_len}"
         )
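
Taken together, the series makes two features opt-in from the config: BetterTransformer support via `flash_optimum`, and streamed pretraining data via `pretraining_dataset`. The fragment below is an illustrative sketch of how those keys combine, assembled from values that already appear in these patches (the RedPajama path was the hard-coded default in an earlier revision of the series); it is not a config shipped by any of the commits.

```yaml
# Illustrative config sketch only -- combines the options introduced in this series.
flash_optimum: true        # wrap the model with optimum's BetterTransformer (requires torch>=2.0)
float16: true              # load weights in float16 directly, since AMP cannot be used
fp16: false                # validate_config raises if fp16/bf16 AMP is combined with flash_optimum
bf16: false
adapter:                   # leave empty; PEFT adapters only trigger a compatibility warning

# stream a large corpus instead of preparing a dataset on disk
pretraining_dataset: togethercomputer/RedPajama-Data-1T
sequence_len: 2048         # used as max_tokens when packing the streamed samples
group_by_length: false     # grouping by length would force the streamed dataset to download fully
```

Per the added validation, `flash_optimum` refuses to run with AMP (`fp16`/`bf16`) or with torch older than 2.0, and only warns when combined with a PEFT adapter or when `float16`/`bfloat16` is left unset.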