diff --git a/README.md b/README.md
index de929f237..2bc55732d 100644
--- a/README.md
+++ b/README.md
@@ -410,6 +410,8 @@ optimizer:
 # specify weight decay
 weight_decay:
 
+# whether to use bettertransformers
+flash_optimum:
 # whether to use xformers attention patch https://github.com/facebookresearch/xformers:
 xformers_attention:
 # whether to use flash attention patch https://github.com/HazyResearch/flash-attention:
diff --git a/scripts/finetune.py b/scripts/finetune.py
index 9bed61ca4..ab226f68f 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -14,7 +14,6 @@ import torch
 import yaml
 
 # add src to the pythonpath so we don't need to pip install this
-from datasets import Dataset
 from optimum.bettertransformer import BetterTransformer
 from transformers import GenerationConfig, TextStreamer
 
@@ -208,14 +207,11 @@ def train(
             tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
         )
     else:
-        if cfg.pretraining_dataset is True:
-            pretraining_dataset = "togethercomputer/RedPajama-Data-1T"
-        else:
-            pretraining_dataset = cfg.pretraining_dataset
         train_dataset = load_pretraining_dataset(
-            pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len
+            cfg.pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len
         )
-        train_dataset = Dataset.from_list(list(train_dataset))
+        # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
+        train_dataset = train_dataset.with_format("torch")
         eval_dataset = None
 
     if cfg.debug or "debug" in kwargs:
@@ -262,19 +258,6 @@ def train(
         model.save_pretrained(cfg.output_dir)
         return
 
-    if cfg.debug:
-        logging.info("check_dataset_labels...")
-        check_dataset_labels(
-            train_dataset.select(
-                [random.randrange(0, len(train_dataset) - 1) for i in range(5)]  # nosec
-            ),
-            tokenizer,
-        )
-
-    if prepare_ds_only:
-        logging.info("Finished preparing dataset. Exiting...")
-        return
-
     model.train()
 
     trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 13ad7c75d..492d8059b 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -1,12 +1,12 @@
 """Module containing data utilities"""
-
+import functools
 import logging
 from hashlib import md5
 from pathlib import Path
 from typing import List, Tuple, Union
 
 import torch
-from datasets import Dataset, DatasetDict, IterableDataset, load_dataset, load_from_disk
+from datasets import Dataset, DatasetDict, load_dataset, load_from_disk
 from huggingface_hub import hf_hub_download
 from transformers import PreTrainedTokenizerBase
 
@@ -399,32 +399,116 @@ def load_prepare_datasets(
     return train_dataset, eval_dataset
 
 
-class PretrainingDatasetWrapper(IterableDataset):
-    """
-    Wrapper for pretraining dataset that avoids loading the dataset into memory
-    """
+def encode_pretraining(tokenizer, max_tokens, examples):
+    res = tokenizer(
+        examples["text"],
+        truncation=True,
+        max_length=max_tokens - 2,
+        add_special_tokens=True,
+    )
+    # Convert to PyTorch tensors
+    input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
+    attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
+    new_input_ids = []
+    new_attention_mask = []
+    # Append EOS and PAD tokens to input_ids, and correct attention_mask
+    for i, _ in enumerate(input_ids):
+        input_ids[i] = torch.cat(
+            (
+                input_ids[i],
+                torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
+            ),
+            dim=0,
+        )
+        attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
 
-    def __init__(self, tokenizer, dataset_path, max_tokens=2048):
-        self.tokenizer = tokenizer
-        self.dataset_path = dataset_path
-        self.max_tokens = max_tokens
+    # Concatenate tokens so that their lengths are less than max_tokens
+    buffer_input_ids = torch.tensor([], dtype=torch.long)
+    buffer_attention_mask = torch.tensor([], dtype=torch.long)
 
-    def __iter__(self):
-        buffer = []
-        for sample in load_dataset(
-            self.dataset_path,
-        )["train"].shuffle():
-            buffer += self.tokenizer(sample["text"])["input_ids"]
-            buffer += [self.tokenizer.eos_token_id]
-            while len(buffer) > self.max_tokens:
-                input_ids = torch.tensor(buffer[: self.max_tokens])
-                yield {
-                    "input_ids": input_ids,
-                    "attention_mask": torch.ones(input_ids.size()),
-                    "labels": input_ids,
-                }
-                buffer = buffer[self.max_tokens :]
+    for ids, mask in zip(input_ids, attention_mask):
+        if buffer_input_ids.numel() == max_tokens:
+            new_input_ids.append(buffer_input_ids)
+            new_attention_mask.append(buffer_attention_mask)
+            buffer_input_ids = torch.tensor([], dtype=torch.long)
+            buffer_attention_mask = torch.tensor([], dtype=torch.long)
+            buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
+            buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
+        elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
+            buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
+            buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
+        else:
+            buffer_input_ids = torch.cat(
+                (
+                    buffer_input_ids,
+                    torch.full(
+                        (max_tokens - buffer_input_ids.numel(),),
+                        tokenizer.pad_token_id,
+                        dtype=torch.long,
+                    ),
+                ),
+                dim=0,
+            )
+            buffer_attention_mask = torch.cat(
+                (
+                    buffer_attention_mask,
+                    torch.full(
+                        (max_tokens - buffer_attention_mask.numel(),),
+                        0,
+                        dtype=torch.long,
+                    ),
+                ),
+                dim=0,
+            )
+            new_input_ids.append(buffer_input_ids)
+            new_attention_mask.append(buffer_attention_mask)
+            buffer_input_ids = torch.tensor([], dtype=torch.long)
+            buffer_attention_mask = torch.tensor([], dtype=torch.long)
+
+            buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
+            buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
+
+    if buffer_input_ids.numel() > 0:  # for any leftover tokens
+        while buffer_input_ids.numel() < max_tokens:  # make all sequences equal in size
+            buffer_input_ids = torch.cat(
+                (
+                    buffer_input_ids,
+                    torch.full(
+                        (max_tokens - buffer_input_ids.numel(),),
+                        tokenizer.pad_token_id,
+                        dtype=torch.long,
+                    ),
+                ),
+                dim=0,
+            )
+            buffer_attention_mask = torch.cat(
+                (
+                    buffer_attention_mask,
+                    torch.full(
+                        (max_tokens - buffer_attention_mask.numel(),),
+                        0,
+                        dtype=torch.long,
+                    ),
+                ),
+                dim=0,
+            )
+        new_input_ids.append(buffer_input_ids)
+        new_attention_mask.append(buffer_attention_mask)
+
+    ret = {
+        "input_ids": [seq.tolist() for seq in new_input_ids],
+        "labels": [seq.tolist() for seq in new_input_ids],
+        "attention_mask": [seq.tolist() for seq in new_attention_mask],
+    }
+
+    logging.debug(len(ret["input_ids"]))
+    return ret
 
 
 def load_pretraining_dataset(path, tokenizer, max_tokens=2048):
-    return PretrainingDatasetWrapper(tokenizer, path, max_tokens=max_tokens)
+    encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
+    dataset = load_dataset(path, streaming=True, split="train")
+    dataset = dataset.shuffle(seed=42, buffer_size=10_000)
+    # TODO dynamically figure out which columns/features to remove
+    dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
+    return dataset
diff --git a/src/axolotl/utils/validation.py b/src/axolotl/utils/validation.py
index 2e2450fba..603afbfee 100644
--- a/src/axolotl/utils/validation.py
+++ b/src/axolotl/utils/validation.py
@@ -77,6 +77,11 @@ def validate_config(cfg):
                 f"flash_optimum for BetterTransformers may not be used with {torch.__version__}"
             )
 
+    if cfg.pretraining_dataset and cfg.group_by_length:
+        logging.warning(
+            "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
diff --git a/tests/test_validation.py b/tests/test_validation.py
index 50bdf37e6..575392ab4 100644
--- a/tests/test_validation.py
+++ b/tests/test_validation.py
@@ -198,3 +198,54 @@ class ValidationTest(unittest.TestCase):
         )
 
         validate_config(cfg)
+
+    def test_flash_optimum(self):
+        cfg = DictDefault(
+            {
+                "flash_optimum": True,
+                "adapter": "lora",
+            }
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "BetterTransformers probably doesn't work with PEFT adapters"
+                in record.message
+                for record in self._caplog.records
+            )
+
+        cfg = DictDefault(
+            {
+                "flash_optimum": True,
+            }
+        )
+
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "probably set bfloat16 or float16" in record.message
+                for record in self._caplog.records
+            )
+
+        cfg = DictDefault(
+            {
+                "flash_optimum": True,
+                "fp16": True,
+            }
+        )
+        regex_exp = r".*AMP is not supported.*"
+
+        with pytest.raises(ValueError, match=regex_exp):
+            validate_config(cfg)
+
+        cfg = DictDefault(
+            {
+                "flash_optimum": True,
+                "bf16": True,
+            }
+        )
+        regex_exp = r".*AMP is not supported.*"
+
+        with pytest.raises(ValueError, match=regex_exp):
+            validate_config(cfg)
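
As a quick sanity check of the new packing logic, the sketch below runs encode_pretraining on a tiny batch of raw text records. It is an illustrative example only, not part of the change: the gpt2 tokenizer, the max_tokens=32 value, and the sample strings are assumptions; the only real requirement is a tokenizer with both eos_token_id and pad_token_id set, since the function appends an EOS and a PAD token to every sequence before packing.

# Hypothetical usage sketch (not part of the diff): exercise encode_pretraining
# on a tiny batch to see sequences packed and padded to exactly max_tokens.
from transformers import AutoTokenizer

from axolotl.utils.data import encode_pretraining

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # assumption: any HF tokenizer works
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default

examples = {"text": ["first tiny pretraining document", "a second, slightly longer one"]}
batch = encode_pretraining(tokenizer, 32, examples)  # max_tokens=32 for illustration

# Every packed sequence comes back exactly max_tokens long, with labels mirroring
# input_ids and attention_mask zeroed over the padding.
print([len(seq) for seq in batch["input_ids"]])  # e.g. [32]
print(batch["attention_mask"][0])

load_pretraining_dataset applies this same function through dataset.map(encode, batched=True) over a streamed split, so nothing is materialised up front; the Trainer consumes the result lazily once it is wrapped with with_format("torch") in scripts/finetune.py.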