diff --git a/configs/gpt_neox_20b.yml b/configs/gpt_neox_20b.yml new file mode 100644 index 000000000..91698ffaa --- /dev/null +++ b/configs/gpt_neox_20b.yml @@ -0,0 +1,39 @@ +base_model: EleutherAI/gpt-neox-20b +base_model_ignore_patterns: pytorch* # prefer safetensors +model_type: GPTNeoXForCausalLM +tokenizer_type: AutoTokenizer +load_in_8bit: true +datasets: + - path: nomic-ai/gpt4all-j-prompt-generations + type: alpaca + shards: 4 + shards_index: 0 +dataset_prepared_path: last_run_prepared +val_set_size: 0.05 +adapter: lora +lora_model_dir: +sequence_len: 2048 +max_packed_sequence_len: 2048 +lora_r: 8 +lora_alpha: 32 +lora_dropout: 0.05 +lora_target_modules: + - query_key_value +lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific +wandb_project: gpt4all-neox-20b +wandb_watch: +wandb_run_id: +wandb_log_model: checkpoint +output_dir: ./gpt4all-neox-20b +batch_size: 48 +micro_batch_size: 4 +num_epochs: 5 +learning_rate: 0.00003 +lr_scheduler: one_cycle +train_on_inputs: false +group_by_length: false +bf16: True +tf32: True +early_stopping_patience: +resume_from_checkpoint: +local_rank: diff --git a/scripts/finetune.py b/scripts/finetune.py index 576681cb5..e2c813416 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -1,223 +1,29 @@ import logging -import math import os import random import signal import sys -from hashlib import md5 from pathlib import Path -import bitsandbytes as bnb import fire import torch -import transformers import yaml from attrdict import AttrDefault -from datasets import load_dataset, IterableDataset, Dataset, load_from_disk -from torch import nn -from transformers import ( - AutoModelForCausalLM, - AutoTokenizer, - LlamaForCausalLM, - LlamaTokenizer, - EarlyStoppingCallback, - GenerationConfig, -) # add src to the pythonpath so we don't need to pip install this -from transformers.trainer_pt_utils import get_parameter_names - project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) src_dir = os.path.join(project_root, "src") sys.path.insert(0, src_dir) -from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset -from axolotl.prompt_tokenizers import ( - AlpacaPromptTokenizingStrategy, - ShareGPTPromptTokenizingStrategy, - LLAMA_DEFAULT_PAD_TOKEN, - GPTeacherPromptTokenizingStrategy, - OpenAssistantPromptTokenizingStrategy, AlpacaReflectionPTStrategy, -) -from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter, ReflectAlpacaPrompter +from axolotl.utils.data import load_prepare_datasets +from axolotl.utils.models import load_model +from axolotl.utils.trainer import setup_trainer +from axolotl.utils.wandb import setup_wandb_env_vars logging.basicConfig(level=os.getenv("LOG_LEVEL", "INFO")) DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" -def setup_wandb_env_vars(cfg): - if cfg.wandb_project and len(cfg.wandb_project) > 0: - os.environ["WANDB_PROJECT"] = cfg.wandb_project - cfg.use_wandb = True - if cfg.wandb_watch and len(cfg.wandb_watch) > 0: - os.environ["WANDB_WATCH"] = cfg.wandb_watch - if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0: - os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model - if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0: - os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id - - -def load_model( - base_model, - base_model_config, - model_type, - tokenizer_type, - cfg, - adapter="lora", - inference: bool = False, -): - # TODO refactor as a kwarg - load_in_8bit = cfg.load_in_8bit - tokenizer = None - is_llama_derived_model = "llama" in base_model or "llama" in cfg.model_type.lower() - - if adapter != "lora": - raise NotImplementedError(f"{adapter} peft adapter not available") - if is_llama_derived_model and cfg.flash_attention: - if cfg.device not in ["mps", "cpu"] and inference is False: - from axolotl.flash_attn import replace_llama_attn_with_flash_attn - - logging.info("patching with flash attention") - replace_llama_attn_with_flash_attn() - - torch_dtype = (torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,) - try: - if cfg.load_4bit: - from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import ( - replace_peft_model_with_int4_lora_model, - ) - - replace_peft_model_with_int4_lora_model() - - from peft import ( - LoraConfig, - get_peft_model, - prepare_model_for_int8_training, - PeftModel, - ) - except Exception as e: - logging.exception(e) - raise e - - try: - if cfg.load_4bit and is_llama_derived_model: - from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram - from huggingface_hub import snapshot_download - - cache_model_path = Path(snapshot_download(base_model)) - files = ( - list(cache_model_path.glob("*.pt")) - + list(cache_model_path.glob("*.safetensors")) - + list(cache_model_path.glob("*.bin")) - ) - if len(files) > 0: - model_path = str(files[0]) - else: - logging.warning( - "unable to find a cached model file, this will likely fail..." - ) - model_path = str(cache_model_path) - model, tokenizer = load_llama_model_4bit_low_ram( - base_model_config if base_model_config else base_model, - model_path, - device_map=cfg.device_map, - groupsize=cfg.gptq_groupsize if cfg.gptq_groupsize else -1, - is_v1_model=cfg.gptq_model_v1 - if cfg.gptq_model_v1 is not None - else True, - ) - load_in_8bit = False - elif is_llama_derived_model: - model = LlamaForCausalLM.from_pretrained( - base_model, - load_in_8bit=cfg.load_in_8bit, - torch_dtype=torch_dtype, - device_map=cfg.device_map, - ) - else: - model = getattr(transformers, model_type).from_pretrained( - base_model, - load_in_8bit=cfg.load_in_8bit, - torch_dtype=torch_dtype, - device_map=cfg.device_map, - ) - except Exception as e: - logging.error( - "Exception raised attempting to load model, retrying with AutoModelForCausalLM" - ) - logging.exception(e) - model = AutoModelForCausalLM.from_pretrained( - base_model, - load_in_8bit=cfg.load_in_8bit, - torch_dtype=torch_dtype, - device_map=cfg.device_map, - ) - - if not tokenizer: - try: - if is_llama_derived_model: - tokenizer = LlamaTokenizer.from_pretrained(model) - else: - tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model) - except: - tokenizer = AutoTokenizer.from_pretrained(base_model) - - logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") - logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}") - logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") - logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") - - if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]: - tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN - - if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast": - tokenizer.add_special_tokens({"pad_token": "[PAD]"}) - os.environ["TOKENIZERS_PARALLELISM"] = "false" - - if load_in_8bit and not cfg.load_4bit: - logging.info("converting model w/ prepare_model_for_int8_training") - model = prepare_model_for_int8_training(model) - - lora_config = LoraConfig( - r=cfg.lora_r, - lora_alpha=cfg.lora_alpha, - target_modules=cfg.lora_target_modules, - lora_dropout=cfg.lora_dropout, - fan_in_fan_out=cfg.lora_fan_in_fan_out, - bias="none", - task_type="CAUSAL_LM", - ) - - if cfg.lora_model_dir: - model = PeftModel.from_pretrained( - model, - cfg.lora_model_dir, - device_map=cfg.device_map, - torch_dtype=torch.float16, - ) - else: - model = get_peft_model(model, lora_config) - - if cfg.ddp: - model.to(f"cuda:{cfg.local_rank}") - - if cfg.load_4bit: - # Scales to half - logging.info("Fitting 4bit scales and zeros to half") - for n, m in model.named_modules(): - if "Autograd4bitQuantLinear" in str(type(m)) or "Linear4bitLt" in str( - type(m) - ): - if hasattr(m, "is_v1_model") and m.is_v1_model: - m.zeros = m.zeros.half() - m.scales = m.scales.half() - m.bias = m.bias.half() - - # TODO resume_from_checkpoint handling - model.print_trainable_parameters() - return model, tokenizer, lora_config - - def choose_device(cfg): def get_device(): if torch.cuda.is_available(): @@ -271,11 +77,13 @@ def do_inference(cfg, model, tokenizer): tokenizer.add_special_tokens({"bos_token": ""}) tokenizer.add_special_tokens({"eos_token": ""}) - instruction = "Tell me a joke about dromedaries." - input = "" - prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n".format( - instruction=instruction, input=input + from axolotl.prompters import ReflectAlpacaPrompter + + instruction = str(input("Give me an instruction: ")) + instruction = ( + instruction if not instruction else "Tell me a joke about dromedaries." ) + prompt = ReflectAlpacaPrompter().build_prompt(instruction=instruction) batch = tokenizer(prompt, return_tensors="pt", add_special_tokens=True) model.eval() @@ -324,98 +132,6 @@ def choose_config(path: Path): return chosen_file -def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): - total_num_steps = int( - math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) - ) - warmup_steps = min(int(0.03 * total_num_steps), 100) - logging_steps = max(min(int(0.005 * total_num_steps), 10), 1) - save_steps = eval_steps = min(int(0.05 * total_num_steps), 200) - - training_arguments_kwargs = {} - if cfg.bf16 == "full": - training_arguments_kwargs["bf16_full_eval"] = True - else: - training_arguments_kwargs["bf16"] = cfg.bf16 - training_arguments_kwargs["tf32"] = cfg.tf32 - training_arguments_kwargs["warmup_steps"] = warmup_steps - training_arguments_kwargs["logging_steps"] = logging_steps - if cfg.gradient_checkpointing is not None: - training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing - - training_args = transformers.TrainingArguments( - per_device_train_batch_size=cfg.micro_batch_size, - gradient_accumulation_steps=cfg.gradient_accumulation_steps, - num_train_epochs=cfg.num_epochs, - learning_rate=cfg.learning_rate, - evaluation_strategy="steps" if cfg.val_set_size > 0 else "no", - save_strategy="steps", - eval_steps=eval_steps if cfg.val_set_size > 0 else None, - save_steps=save_steps, - output_dir=cfg.output_dir, - save_total_limit=3, - load_best_model_at_end=True if cfg.val_set_size > 0 else False, - ddp_find_unused_parameters=False if cfg.ddp else None, - group_by_length=cfg.group_by_length, - report_to="wandb" if cfg.use_wandb else None, - run_name=cfg.wandb_run_id if cfg.use_wandb else None, - **training_arguments_kwargs, - ) - - decay_parameters = get_parameter_names(model, [nn.LayerNorm]) - decay_parameters = [name for name in decay_parameters if "bias" not in name] - optimizer_grouped_parameters = [ - { - "params": [p for n, p in model.named_parameters() if n in decay_parameters], - "weight_decay": training_args.weight_decay, - }, - { - "params": [ - p for n, p in model.named_parameters() if n not in decay_parameters - ], - "weight_decay": 0.0, - }, - ] - - trainer_kwargs = {} - - if cfg.load_in_8bit and not cfg.load_4bit: - adam_bnb_optim = bnb.optim.Adam8bit( - optimizer_grouped_parameters, - betas=(training_args.adam_beta1, training_args.adam_beta2), - eps=training_args.adam_epsilon, - lr=training_args.learning_rate, - ) - - # TODO optionally use torch.optim.OneCycleLR - lr_scheduler = transformers.get_cosine_schedule_with_warmup( - adam_bnb_optim, - training_args.warmup_steps, - total_num_steps, - ) - trainer_kwargs["optimizers"] = (adam_bnb_optim, lr_scheduler) - - # TODO on_save callback to sync checkpoints to GCP/AWS in background - if cfg.early_stopping_patience: - early_stop_cb = EarlyStoppingCallback( - cfg.early_stopping_patience, - ) - trainer_kwargs["callbacks"] = [early_stop_cb] - - trainer = transformers.Trainer( - model=model, - train_dataset=train_dataset, - eval_dataset=eval_dataset, - args=training_args, - data_collator=transformers.DataCollatorForSeq2Seq( - tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True - ), - **trainer_kwargs, - ) - - return trainer - - def train( config: Path = Path("configs/"), prepare_ds_only: bool = False, @@ -474,110 +190,13 @@ def train( do_inference(cfg, model, tokenizer) return - max_packed_sequence_len = ( - cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len - ) - max_packed_sequence_len = min( - max_packed_sequence_len, cfg.sequence_len - ) # make sure we don't accidentally set it larger than sequence_len - ds_hash = str( - md5( - ( - str(max_packed_sequence_len) - + "@" - + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets])) - ).encode("utf-8") - ).hexdigest() - ) - prepared_ds_path = ( - Path(cfg.dataset_prepared_path) / ds_hash - if cfg.dataset_prepared_path - else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash + train_dataset, eval_dataset = load_prepare_datasets( + tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH ) - if any(prepared_ds_path.glob("*")): - logging.info("Loading prepared dataset from disk...") - dataset = load_from_disk(str(prepared_ds_path)) - logging.info("Prepared dataset loaded from disk...") - else: - logging.info("Loading raw datasets...") - datasets = [] - for d in cfg.datasets: - ds_from_hub = False - try: - load_dataset(d.path, streaming=True) - ds_from_hub = True - except FileNotFoundError: - pass - - # prefer local dataset, even if hub exists - if Path(d.path).exists(): - ds: IterableDataset = load_dataset( - "json", data_files=d.path, streaming=True, split=None - ) - elif ds_from_hub: - ds = load_dataset(d.path, streaming=True) - else: - raise Exception("unhandled dataset load") - - if d.type == "alpaca": - ds_strategy = AlpacaPromptTokenizingStrategy( - AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len - ) - ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) - datasets.append(ds_wrapper) - elif d.type == "oasst": - ds_strategy = OpenAssistantPromptTokenizingStrategy( - AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len - ) - ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) - datasets.append(ds_wrapper) - elif d.type == "gpteacher": - ds_strategy = GPTeacherPromptTokenizingStrategy( - GPTeacherPrompter(), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) - datasets.append(ds_wrapper) - elif d.type == "reflection": - ds_strategy = AlpacaReflectionPTStrategy( - ReflectAlpacaPrompter(), - tokenizer, - cfg.train_on_inputs, - cfg.sequence_len, - ) - ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) - datasets.append(ds_wrapper) - elif d.type == "sharegpt": - ds_strategy = ShareGPTPromptTokenizingStrategy( - ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len - ) - ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) - datasets.append(ds_wrapper) - else: - logging.error(f"unhandled prompt tokenization strategy: {d.type}") - constant_len_dataset = ConstantLengthDataset( - tokenizer, - datasets, - seq_length=max_packed_sequence_len, - ) - logging.info("merging, packing, shuffling, and splitting master dataset") - dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split( - test_size=cfg.val_set_size, shuffle=True, seed=42 - ) - - if cfg.local_rank == 0: - logging.info(f"Saving prepared dataset to disk... {prepared_ds_path}") - dataset.save_to_disk(prepared_ds_path) - - if prepare_ds_only: - logging.info("Finished preparing dataset. Exiting...") - return - - train_dataset = dataset["train"] - eval_dataset = dataset["test"] + if prepare_ds_only: + logging.info("Finished preparing dataset. Exiting...") + return if cfg.debug: check_dataset_labels( @@ -594,8 +213,9 @@ def train( model = torch.compile(model) # go ahead and presave, so we have the adapter config available to inspect - logging.info(f"Pre-saving adapter config to {cfg.output_dir}") - lora_config.save_pretrained(cfg.output_dir) + if lora_config: + logging.info(f"Pre-saving adapter config to {cfg.output_dir}") + lora_config.save_pretrained(cfg.output_dir) # In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model if cfg.local_rank == 0: diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index ac00f5d6b..1909ec289 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -107,6 +107,15 @@ class GPTeacherPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): ) +class NomicGPT4AllPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): + def parse_instruction_fields(self, prompt) -> (str, str, str): + return ( + prompt["prompt"], + "", + prompt["response"], + ) + + class ReflectionPromptTokenizingStrategy(PromptTokenizingStrategy): def parse_instruction_fields(self, prompt) -> (str, str, str, str, str): raise NotImplementedError @@ -168,6 +177,7 @@ class AlpacaReflectionPTStrategy(ReflectionPromptTokenizingStrategy): prompt["corrected"], ) + class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): def tokenize_prompt(self, prompt): try: diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 4a991a8ec..070f10acb 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -35,6 +35,10 @@ class GPTeacherPrompter(AlpacaPrompter): ... +class NomicGPT4AllPrompter(AlpacaPrompter): + ... + + class ReflectAlpacaPrompter: prompt_input = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" prompt_no_input = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n### Instruction:\n{instruction}\n\n### Response:\n" diff --git a/src/axolotl/utils/__init__.py b/src/axolotl/utils/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py new file mode 100644 index 000000000..4e064a881 --- /dev/null +++ b/src/axolotl/utils/data.py @@ -0,0 +1,125 @@ +import logging +from hashlib import md5 +from pathlib import Path + +from datasets import load_from_disk, load_dataset, IterableDataset, Dataset + +from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset +from axolotl.prompt_tokenizers import ( + AlpacaPromptTokenizingStrategy, + GPTeacherPromptTokenizingStrategy, + OpenAssistantPromptTokenizingStrategy, + AlpacaReflectionPTStrategy, + ShareGPTPromptTokenizingStrategy, +) +from axolotl.prompters import ( + AlpacaPrompter, + GPTeacherPrompter, + ReflectAlpacaPrompter, + ShareGPTPrompter, +) + + +def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path): + max_packed_sequence_len = ( + cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len + ) + max_packed_sequence_len = min( + max_packed_sequence_len, cfg.sequence_len + ) # make sure we don't accidentally set it larger than sequence_len + ds_hash = str( + md5( + ( + str(max_packed_sequence_len) + + "@" + + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets])) + ).encode("utf-8") + ).hexdigest() + ) + prepared_ds_path = ( + Path(cfg.dataset_prepared_path) / ds_hash + if cfg.dataset_prepared_path + else Path(default_dataset_prepared_path) / ds_hash + ) + + if any(prepared_ds_path.glob("*")): + logging.info("Loading prepared dataset from disk...") + dataset = load_from_disk(str(prepared_ds_path)) + logging.info("Prepared dataset loaded from disk...") + else: + logging.info("Loading raw datasets...") + datasets = [] + for d in cfg.datasets: + ds_from_hub = False + try: + load_dataset(d.path, streaming=True) + ds_from_hub = True + except FileNotFoundError: + pass + + # prefer local dataset, even if hub exists + if Path(d.path).exists(): + ds: IterableDataset = load_dataset( + "json", data_files=d.path, streaming=True, split=None + ) + elif ds_from_hub: + ds = load_dataset(d.path, streaming=True) + else: + raise Exception("unhandled dataset load") + + if d.type == "alpaca": + ds_strategy = AlpacaPromptTokenizingStrategy( + AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len + ) + ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) + datasets.append(ds_wrapper) + elif d.type == "oasst": + ds_strategy = OpenAssistantPromptTokenizingStrategy( + AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len + ) + ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) + datasets.append(ds_wrapper) + elif d.type == "gpteacher": + ds_strategy = GPTeacherPromptTokenizingStrategy( + GPTeacherPrompter(), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) + datasets.append(ds_wrapper) + elif d.type == "reflection": + ds_strategy = AlpacaReflectionPTStrategy( + ReflectAlpacaPrompter(), + tokenizer, + cfg.train_on_inputs, + cfg.sequence_len, + ) + ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) + datasets.append(ds_wrapper) + elif d.type == "sharegpt": + ds_strategy = ShareGPTPromptTokenizingStrategy( + ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len + ) + ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) + datasets.append(ds_wrapper) + else: + logging.error(f"unhandled prompt tokenization strategy: {d.type}") + constant_len_dataset = ConstantLengthDataset( + tokenizer, + datasets, + seq_length=max_packed_sequence_len, + ) + logging.info("merging, packing, shuffling, and splitting master dataset") + dataset = Dataset.from_list([_ for _ in constant_len_dataset]).train_test_split( + test_size=cfg.val_set_size, shuffle=True, seed=42 + ) + + if cfg.local_rank == 0: + logging.info(f"Saving prepared dataset to disk... {prepared_ds_path}") + dataset.save_to_disk(prepared_ds_path) + + train_dataset = dataset["train"] + eval_dataset = dataset["test"] + + return train_dataset, eval_dataset diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py new file mode 100644 index 000000000..c51d0cd53 --- /dev/null +++ b/src/axolotl/utils/models.py @@ -0,0 +1,206 @@ +import logging +import os +from pathlib import Path +from typing import Optional, Tuple, TYPE_CHECKING + +import torch +import transformers +from transformers import ( + AutoModelForCausalLM, + LlamaForCausalLM, + LlamaTokenizer, + AutoTokenizer, + PreTrainedModel, +) + +from axolotl.prompt_tokenizers import LLAMA_DEFAULT_PAD_TOKEN + +if TYPE_CHECKING: + from peft import PeftModel, PeftConfig + from attrdict import AttrDefault + from transformers import PreTrainedTokenizer + + +def load_model( + base_model, + base_model_config, + model_type, + tokenizer_type, + cfg, + adapter="lora", + inference=False, +): + # type: (str, str, str, str, AttrDefault, Optional[str], bool) -> Tuple[PreTrainedModel, PreTrainedTokenizer, Optional[PeftConfig]] + + # TODO refactor as a kwarg + load_in_8bit = cfg.load_in_8bit + tokenizer = None + is_llama_derived_model = "llama" in base_model or "llama" in cfg.model_type.lower() + + if is_llama_derived_model and cfg.flash_attention: + if cfg.device not in ["mps", "cpu"] and inference is False: + from axolotl.flash_attn import replace_llama_attn_with_flash_attn + + logging.info("patching with flash attention") + replace_llama_attn_with_flash_attn() + + torch_dtype = (torch.float16 if cfg.load_in_8bit or cfg.fp16 else torch.float32,) + try: + if cfg.load_4bit: + from alpaca_lora_4bit.monkeypatch.peft_tuners_lora_monkey_patch import ( + replace_peft_model_with_int4_lora_model, + ) + + replace_peft_model_with_int4_lora_model() + from peft import prepare_model_for_int8_training + except Exception as e: + logging.exception(e) + raise e + + try: + if cfg.load_4bit and is_llama_derived_model: + from alpaca_lora_4bit.autograd_4bit import load_llama_model_4bit_low_ram + from huggingface_hub import snapshot_download + + cache_model_path = Path(snapshot_download(base_model)) + files = ( + list(cache_model_path.glob("*.pt")) + + list(cache_model_path.glob("*.safetensors")) + + list(cache_model_path.glob("*.bin")) + ) + if len(files) > 0: + model_path = str(files[0]) + else: + logging.warning( + "unable to find a cached model file, this will likely fail..." + ) + model_path = str(cache_model_path) + model, tokenizer = load_llama_model_4bit_low_ram( + base_model_config if base_model_config else base_model, + model_path, + device_map=cfg.device_map, + groupsize=cfg.gptq_groupsize if cfg.gptq_groupsize else -1, + is_v1_model=cfg.gptq_model_v1 + if cfg.gptq_model_v1 is not None + else True, + ) + load_in_8bit = False + elif is_llama_derived_model: + model = LlamaForCausalLM.from_pretrained( + base_model, + load_in_8bit=cfg.load_in_8bit, + torch_dtype=torch_dtype, + device_map=cfg.device_map, + ) + else: + model = getattr(transformers, model_type).from_pretrained( + base_model, + load_in_8bit=cfg.load_in_8bit, + torch_dtype=torch_dtype, + device_map=cfg.device_map, + ) + except Exception as e: + logging.error( + "Exception raised attempting to load model, retrying with AutoModelForCausalLM" + ) + logging.exception(e) + model = AutoModelForCausalLM.from_pretrained( + base_model, + load_in_8bit=cfg.load_in_8bit, + torch_dtype=torch_dtype, + device_map=cfg.device_map, + ) + + if not tokenizer: + try: + if is_llama_derived_model: + tokenizer = LlamaTokenizer.from_pretrained(model) + else: + tokenizer = getattr(transformers, tokenizer_type).from_pretrained(model) + except: + tokenizer = AutoTokenizer.from_pretrained(base_model) + + logging.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}") + logging.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}") + logging.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}") + logging.debug(f"UNK: {tokenizer.unk_token_id} / {tokenizer.unk_token}") + + if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]: + tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN + + if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast": + tokenizer.add_special_tokens({"pad_token": "[PAD]"}) + os.environ["TOKENIZERS_PARALLELISM"] = "false" + + if load_in_8bit and not cfg.load_4bit: + logging.info("converting model w/ prepare_model_for_int8_training") + model = prepare_model_for_int8_training(model) + + model, lora_config = load_adapter(model, cfg, adapter) + + if cfg.ddp: + model.to(f"cuda:{cfg.local_rank}") + + if cfg.load_4bit: + # Scales to half + logging.info("Fitting 4bit scales and zeros to half") + for n, m in model.named_modules(): + if "Autograd4bitQuantLinear" in str(type(m)) or "Linear4bitLt" in str( + type(m) + ): + if hasattr(m, "is_v1_model") and m.is_v1_model: + m.zeros = m.zeros.half() + m.scales = m.scales.half() + m.bias = m.bias.half() + + # TODO resume_from_checkpoint handling + return model, tokenizer, lora_config + + +def load_adapter(model, cfg, adapter): + # type: (PreTrainedModel, AttrDefault, Optional[str]) -> Tuple[PreTrainedModel, Optional[PeftConfig]] + + if adapter is None: + return model, None + if adapter == "lora": + return load_lora(model, cfg) + # TODO support Llama-Adapter once merged into peft https://github.com/huggingface/peft/pulls + + raise NotImplementedError(f"{adapter} peft adapter not available") + + +def load_lora(model, cfg): + # type: (PreTrainedModel, AttrDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]] + + from peft import ( + LoraConfig, + get_peft_model, + PeftModel, + ) + + lora_config = None + + if cfg.adapter == "lora": + lora_config = LoraConfig( + r=cfg.lora_r, + lora_alpha=cfg.lora_alpha, + target_modules=cfg.lora_target_modules, + lora_dropout=cfg.lora_dropout, + fan_in_fan_out=cfg.lora_fan_in_fan_out, + bias="none", + task_type="CAUSAL_LM", + ) + + if cfg.lora_model_dir: + model = PeftModel.from_pretrained( + model, + cfg.lora_model_dir, + device_map=cfg.device_map, + torch_dtype=torch.float16, + ) + else: + model = get_peft_model(model, lora_config) + + model.print_trainable_parameters() + + return model, lora_config diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py new file mode 100644 index 000000000..f1c357803 --- /dev/null +++ b/src/axolotl/utils/trainer.py @@ -0,0 +1,109 @@ +import math +import bitsandbytes as bnb +import transformers +from torch import nn +from torch.optim.lr_scheduler import OneCycleLR +from transformers import EarlyStoppingCallback +from transformers.trainer_pt_utils import get_parameter_names + + +def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): + total_num_steps = int( + math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) + ) + warmup_steps = min(int(0.03 * total_num_steps), 100) + logging_steps = max(min(int(0.005 * total_num_steps), 10), 1) + save_steps = eval_steps = min(int(0.05 * total_num_steps), 200) + + training_arguments_kwargs = {} + if cfg.bf16 == "full": + training_arguments_kwargs["bf16_full_eval"] = True + else: + training_arguments_kwargs["bf16"] = cfg.bf16 + training_arguments_kwargs["tf32"] = cfg.tf32 + training_arguments_kwargs["warmup_steps"] = warmup_steps + training_arguments_kwargs["logging_steps"] = logging_steps + if cfg.gradient_checkpointing is not None: + training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing + + training_args = transformers.TrainingArguments( + per_device_train_batch_size=cfg.micro_batch_size, + gradient_accumulation_steps=cfg.gradient_accumulation_steps, + num_train_epochs=cfg.num_epochs, + learning_rate=cfg.learning_rate, + evaluation_strategy="steps" if cfg.val_set_size > 0 else "no", + save_strategy="steps", + eval_steps=eval_steps if cfg.val_set_size > 0 else None, + save_steps=save_steps, + output_dir=cfg.output_dir, + save_total_limit=3, + load_best_model_at_end=True if cfg.val_set_size > 0 else False, + ddp_find_unused_parameters=False if cfg.ddp else None, + group_by_length=cfg.group_by_length, + report_to="wandb" if cfg.use_wandb else None, + run_name=cfg.wandb_run_id if cfg.use_wandb else None, + **training_arguments_kwargs, + ) + + decay_parameters = get_parameter_names(model, [nn.LayerNorm]) + decay_parameters = [name for name in decay_parameters if "bias" not in name] + optimizer_grouped_parameters = [ + { + "params": [p for n, p in model.named_parameters() if n in decay_parameters], + "weight_decay": training_args.weight_decay, + }, + { + "params": [ + p for n, p in model.named_parameters() if n not in decay_parameters + ], + "weight_decay": 0.0, + }, + ] + + trainer_kwargs = {} + + if cfg.load_in_8bit and not cfg.load_4bit: + optimizer = bnb.optim.Adam8bit( + optimizer_grouped_parameters, + betas=(training_args.adam_beta1, training_args.adam_beta2), + eps=training_args.adam_epsilon, + lr=training_args.learning_rate, + ) + + if cfg.lr_scheduler == "one_cycle": + lr_scheduler_kwargs = ( + cfg.lr_scheduler_kwargs if cfg.lr_scheduler_kwargs else {} + ) + lr_scheduler = OneCycleLR( + optimizer, + cfg.learning_rate, + total_steps=total_num_steps, + **lr_scheduler_kwargs, + ) + else: + lr_scheduler = transformers.get_cosine_schedule_with_warmup( + optimizer, + training_args.warmup_steps, + total_num_steps, + ) + trainer_kwargs["optimizers"] = (optimizer, lr_scheduler) + + # TODO on_save callback to sync checkpoints to GCP/AWS in background + if cfg.early_stopping_patience: + early_stop_cb = EarlyStoppingCallback( + cfg.early_stopping_patience, + ) + trainer_kwargs["callbacks"] = [early_stop_cb] + + trainer = transformers.Trainer( + model=model, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + args=training_args, + data_collator=transformers.DataCollatorForSeq2Seq( + tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True + ), + **trainer_kwargs, + ) + + return trainer diff --git a/src/axolotl/utils/wandb.py b/src/axolotl/utils/wandb.py new file mode 100644 index 000000000..1e805c6c6 --- /dev/null +++ b/src/axolotl/utils/wandb.py @@ -0,0 +1,13 @@ +import os + + +def setup_wandb_env_vars(cfg): + if cfg.wandb_project and len(cfg.wandb_project) > 0: + os.environ["WANDB_PROJECT"] = cfg.wandb_project + cfg.use_wandb = True + if cfg.wandb_watch and len(cfg.wandb_watch) > 0: + os.environ["WANDB_WATCH"] = cfg.wandb_watch + if cfg.wandb_log_model and len(cfg.wandb_log_model) > 0: + os.environ["WANDB_LOG_MODEL"] = cfg.wandb_log_model + if cfg.wandb_run_id and len(cfg.wandb_run_id) > 0: + os.environ["WANDB_RUN_ID"] = cfg.wandb_run_id