diff --git a/configs/llama_65B_alpaca.yml b/configs/llama_65B_alpaca.yml
index 556c09157..54fe0786f 100644
--- a/configs/llama_65B_alpaca.yml
+++ b/configs/llama_65B_alpaca.yml
@@ -1,4 +1,4 @@
-base_model: decapoda-research/llama-65b-hf
+base_model: huggyllama/llama-7b
 model_type: LlamaForCausalLM
 tokenizer_type: LlamaTokenizer
 load_in_8bit: true
@@ -33,8 +33,8 @@ num_epochs: 5
 learning_rate: 0.00003
 train_on_inputs: false
 group_by_length: false
-bf16: True
-tf32: True
+bf16: true
+tf32: true
 resume_from_checkpoint:
 local_rank:
 deepspeed:
diff --git a/requirements.txt b/requirements.txt
index 048936baf..8e12e8b7b 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,3 +10,6 @@ accelerate
 sentencepiece
 wandb
 flash-attn
+deepspeed
+einops
+
diff --git a/scripts/finetune.py b/scripts/finetune.py
index c5d467c6f..33d3f1a51 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -11,7 +11,7 @@ import torch
 import transformers
 import yaml
 from attrdict import AttrDefault
-from datasets import load_dataset, IterableDataset, Dataset
+from datasets import load_dataset, IterableDataset, Dataset, load_from_disk
 from peft import (
     LoraConfig,
     get_peft_model,
@@ -52,8 +52,9 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
     if adapter != "lora":
         raise NotImplementedError(f"{adapter} peft adapter not available")
     if "llama" in base_model:
-        from axolotl.flash_attn import replace_llama_attn_with_flash_attn
-        replace_llama_attn_with_flash_attn()
+        if cfg.device not in ["mps", "cpu"]:
+            from axolotl.flash_attn import replace_llama_attn_with_flash_attn
+            replace_llama_attn_with_flash_attn()

     try:
         if "llama" in base_model:
@@ -86,7 +87,7 @@ def load_model(base_model, model_type, tokenizer_type, cfg, adapter="lora"):
     except:
         tokenizer = AutoTokenizer.from_pretrained(base_model)

-    if tokenizer.__class__.__name__ == "LlamaTokenizer":
+    if tokenizer.__class__.__name__ in ["LlamaTokenizer", "LlamaTokenizerFast"]:
         tokenizer.pad_token = LLAMA_DEFAULT_PAD_TOKEN

     if tokenizer.__class__.__name__ == "GPTNeoXTokenizerFast":
@@ -255,8 +256,9 @@ def train(
         return

     datasets = []
-    if len(cfg.datasets) == 1 and cfg.datasets[0].type == "arrow":
-        dataset = load_dataset(cfg.datasets[0].path, split="train")
+    if not isinstance(cfg.datasets, list) and isinstance(cfg.datasets, str):
+        # assumption that we are loading a previously saved/cached dataset
+        dataset = load_from_disk(cfg.datasets)
     else:
         for d in cfg.datasets:
             ds: IterableDataset = load_dataset(
@@ -288,7 +290,6 @@ def train(
         [_ for _ in constant_len_dataset]
     ).train_test_split(test_size=cfg.val_set_size, shuffle=True, seed=42)
     dataset.save_to_disk("data/last_run")
-    print(dataset)

     train_dataset = dataset["train"]
     eval_dataset = dataset["test"]
diff --git a/setup.cfg b/setup.cfg
index 8f0ba619a..0a41a6876 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -23,6 +23,7 @@ install_requires =
     sentencepiece
     wandb
     flash-attn
+    einops

 [options.packages.find]
 where = src
diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py
index 6d9902106..014fdbc3a 100644
--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -93,22 +93,24 @@ class ConstantLengthDataset(IterableDataset):
                     buffer_len = 0

             if example:
-                input_ids = example["input_ids"]
-                attention_mask = example["attention_mask"]
-                labels = example["labels"]
+                # just going to drop data points that are too long
+                if len(example["input_ids"]) <= self.seq_length:
+                    input_ids = example["input_ids"]
+                    attention_mask = example["attention_mask"]
+                    labels = example["labels"]

-                if add_concat_token:
-                    input_ids.append(self.concat_token_id)
-                    attention_mask.append(1)
-                    labels.append(self.concat_token_id)
+                    if add_concat_token:
+                        input_ids.append(self.concat_token_id)
+                        attention_mask.append(1)
+                        labels.append(self.concat_token_id)

-                input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
-                attention_mask_with_concat = torch.tensor(
-                    attention_mask, dtype=torch.long
-                )
-                labels_with_concat = torch.tensor(labels, dtype=torch.long)
+                    input_ids_with_concat = torch.tensor(input_ids, dtype=torch.long)
+                    attention_mask_with_concat = torch.tensor(
+                        attention_mask, dtype=torch.long
+                    )
+                    labels_with_concat = torch.tensor(labels, dtype=torch.long)

-                buffer["input_ids"].append(input_ids_with_concat)
-                buffer["attention_mask"].append(attention_mask_with_concat)
-                buffer["labels"].append(labels_with_concat)
-                buffer_len += len(input_ids)
+                    buffer["input_ids"].append(input_ids_with_concat)
+                    buffer["attention_mask"].append(attention_mask_with_concat)
+                    buffer["labels"].append(labels_with_concat)
+                    buffer_len += len(input_ids)