diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 226988037..7ed98b8b6 100644 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -11,7 +11,7 @@ from abc import abstractmethod from dataclasses import dataclass, field from functools import partial from pathlib import Path -from typing import Optional, Union +from typing import Optional import torch import transformers @@ -31,7 +31,6 @@ from axolotl.utils.callbacks import ( log_prediction_callback_factory, ) from axolotl.utils.collators import BatchSamplerDataCollatorForSeq2Seq -from axolotl.utils.dataloader import MultipackDistributedDataloader from axolotl.utils.samplers import MultipackBatchSampler from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup @@ -215,9 +214,7 @@ class AxolotlTrainer(Trainer): ) return super().get_train_dataloader() - def get_eval_dataloader( - self, eval_dataset: Optional[Dataset] = None - ) -> Union[DataLoader, MultipackDistributedDataloader]: + def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader: if self.args.sample_packing and self.args.eval_sample_packing is not False: eval_dataset = ( eval_dataset if eval_dataset is not None else self.eval_dataset @@ -260,7 +257,7 @@ class AxolotlTrainer(Trainer): def get_bench_dataloader( self, bench_dataset: Dataset, - ) -> Union[DataLoader, MultipackDistributedDataloader]: + ) -> DataLoader: dataloader_params = { "batch_size": self.args.eval_batch_size, "collate_fn": self.bench_data_collator, diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 2839c946e..e749ad4c8 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -22,7 +22,13 @@ class PromptStyle(Enum): CHATML = "chatml" -class AlpacaPrompter: +class Prompter: + """ + Base prompter class for all prompters + """ + + +class AlpacaPrompter(Prompter): """ Base class for alpaca prompters """ @@ -159,7 +165,7 @@ class NomicGPT4AllPrompter(AlpacaPrompter): """ -class ReflectAlpacaPrompter: +class ReflectAlpacaPrompter(Prompter): """ Prompter for ReflectAlpaca """ @@ -254,7 +260,7 @@ SHAREGPT_ASSERTION_FAILED_ROLE = ( ) -class ShareGPTPrompter: # pylint: disable=too-few-public-methods +class ShareGPTPrompter(Prompter): # pylint: disable=too-few-public-methods """ A prompter that generates prompts for the ShareGPT """ @@ -349,7 +355,7 @@ class ShareGPTPrompterV2(ShareGPTPrompter): ) -class UnsupportedPrompter: +class UnsupportedPrompter(Prompter): """ A dummy class for custom prompters """ diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index fd41d2c61..52002a7c3 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -3,7 +3,7 @@ import functools import hashlib import logging from pathlib import Path -from typing import Any, Dict, List, Tuple, Union +from typing import Dict, List, Tuple, Union import torch from datasets import ( @@ -34,6 +34,7 @@ from axolotl.prompters import ( JeopardyPrompter, MultipleChoiceConcisePrompter, MultipleChoiceExplainPrompter, + Prompter, ReflectAlpacaPrompter, SummarizeTLDRPrompter, UnsupportedPrompter, @@ -90,7 +91,7 @@ def prepare_dataset(cfg, tokenizer): def load_tokenized_prepared_datasets( tokenizer, cfg, default_dataset_prepared_path -) -> DatasetDict: +) -> Tuple[DatasetDict, List[Prompter]]: tokenizer_name = tokenizer.__class__.__name__ ds_hash = str( md5( @@ -302,7 +303,7 @@ def load_prepare_datasets( tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path, -) -> Tuple[Dataset, Dataset, List[Any]]: +) -> Tuple[Dataset, Dataset, List[Prompter]]: max_packed_sequence_len = ( cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len ) @@ -311,7 +312,7 @@ def load_prepare_datasets( ) # make sure we don't accidentally set it larger than sequence_len tokenizer_name = tokenizer.__class__.__name__ - prompters = [] + prompters: List[Prompter] = [] if cfg.max_packed_sequence_len is not None: # see if we can go ahead and load the stacked dataset seed = f"@{str(cfg.seed)}" if cfg.seed else "" @@ -445,14 +446,13 @@ def load_prepare_datasets( train_fingerprint = md5(to_hash_train) test_fingerprint = md5(to_hash_test) - with zero_first(is_main_process()): - dataset = dataset.train_test_split( - test_size=cfg.val_set_size, - shuffle=False, - seed=cfg.seed or 42, - train_new_fingerprint=train_fingerprint, - test_new_fingerprint=test_fingerprint, - ) + dataset = dataset.train_test_split( + test_size=cfg.val_set_size, + shuffle=False, + seed=cfg.seed or 42, + train_new_fingerprint=train_fingerprint, + test_new_fingerprint=test_fingerprint, + ) train_dataset = dataset["train"] eval_dataset = dataset["test"] diff --git a/src/axolotl/utils/dataloader.py b/src/axolotl/utils/dataloader.py deleted file mode 100644 index 54c95db78..000000000 --- a/src/axolotl/utils/dataloader.py +++ /dev/null @@ -1,342 +0,0 @@ -# pylint: skip-file -import hashlib -import itertools -import logging -import math -import time -from queue import Queue -from threading import Thread -from typing import Any, Callable, List, Union - -import numba -import numpy as np -from torch.utils.data import DistributedSampler, Sampler - -LOG = logging.getLogger("axolotl.utils.dataloader") - - -@numba.njit -def ffd_check(a: np.ndarray, c: int, n: int): - # First-fit-decreasing bin packing - # Check if a[] could fit in n bins with capacity c - # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing - - a = np.sort(a)[::-1] - bins = np.full((n,), c, dtype=a.dtype) - for size in a: - not_found = True - for idx in range(n): - if bins[idx] >= size: - bins[idx] -= size - not_found = False - break - - if not_found: - return False - - return True - - -@numba.njit -def ffd_with_result(a: np.ndarray, c: int, start_index: int): - # First-fit-decreasing bin packing (with result return) - - indices = np.argsort(a)[::-1] - a = a[indices] - - bins: List[Any] = [] - bins_result: List[Any] = [] - for a_id, size in enumerate(a): - add_new = True - for idx in range(len(bins)): - if bins[idx] >= size: - bins[idx] -= size - bins_result[idx].append(indices[a_id] + start_index) - add_new = False - break - - if add_new: - bins.append(c - size) - bins_result.append([indices[a_id] + start_index]) - - return bins_result, len(a) - - -@numba.njit -def allocate( - lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int -): - """ - :param lengths: array of lengths of each sample - :param lengths_cumsum: cumulative sum of consecutive lengths - :param rank: rank for this process - :param c: length of tokens per batch - :param n: number of ranks - :return: - """ - # Dynamic batch allocator, similar to Multifit - # https://en.wikipedia.org/wiki/Multifit_algorithm - # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len) - - s = 0 - start_index = 0 - result = [] - result_totseqs = [] - - while True: - # binary search [left, right) - left = 1 - right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right") - - while right - left > 1: - mid = (left + right) // 2 - if ffd_check(lengths[start_index : start_index + mid], c, n): - left = mid - else: - right = mid - - # use length left - batch, tot_seqs = ffd_with_result( - lengths[start_index : start_index + left], c, start_index - ) - if len(batch) < n: - break - - start_index += left - s = lengths_cumsum[start_index - 1] - - # add local rank - result.append(batch[rank]) - # add total seqs for all ranks - result_totseqs.append(tot_seqs) - # yield batch[rank], tot_seqs, s, len(result) * c * n - return result, result_totseqs, s, len(result) * c * n - - -def chunk(iterable, n): - """ - Chunk data into tuples of length n - """ - # batched('ABCDEFG', 3) --> ABC DEF G - if n < 1: - raise ValueError("n must be at least one") - it = iter(iterable) - while batch := tuple(itertools.islice(it, n)): - yield batch - - -def hash_indices(lst: List[int]) -> str: - # Convert the list of integers to a string representation - concatenated = ",".join(map(str, lst)) - - # Generate the hash - sha256 = hashlib.sha256() - sha256.update(concatenated.encode()) - - return sha256.hexdigest() - - -class MultipackDistributedDataloader: - """Unpadded data loading using Multipack. - Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py - Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard. - """ - - def __init__( - self, - dataset: Any, - collate_fn: Callable, - seq_max_length: int = 2048, - batch_size: int = 1, - sampler: Union[Sampler, DistributedSampler] = None, - packing_efficiency_estimate: float = 1.0, - sample_packing_seq_len_multiplier: int = 1, - device_count: int = 1, - prefetch_max: int = 1000, - num_epochs: int = 1, - ): - # Dataset - self.dataset = dataset - self.lengths = ( - dataset.data.column("position_ids") - .to_pandas() - .apply(lambda x: x[-1] + 1) - .values - ) - assert isinstance(self.lengths, np.ndarray) - assert batch_size % sample_packing_seq_len_multiplier == 0 - assert batch_size >= sample_packing_seq_len_multiplier - self.sampler = sampler - self.batch_size = batch_size - self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier - self.seq_max_length = seq_max_length - self.batch_max_length = batch_size * seq_max_length - self.collate_fn = collate_fn - self.num_epochs = num_epochs - - self.num_replicas = 1 - self.rank = 0 - - # statistics - self.eff_total_used = 0 - self.eff_total_slots = 0 - self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 - self.device_count = device_count - - # maxsize is maximum number of samples in queue - self.prefetch_max = prefetch_max - self.queue: Queue = Queue(maxsize=prefetch_max) - self.thread = None - - def _worker(self): - LOG.info( - f"[WORKER] Epochs: {self.num_epochs}, Samples: {self.len_w_stats()*self.batch_size}" - ) - for epoch in range(self.num_epochs): - for sample in self._internal_batch_generator(): - while True: - if self.queue.full(): - time.sleep(1) - else: - break - self.queue.put(sample) - - # stop the queue when epoch is done - self.queue.put(None) - - def __iter__(self): - if hasattr(self.sampler, "set_epoch"): - new_epoch = self.sampler.epoch + 1 - self.sampler.set_epoch(new_epoch) - LOG.info(f"calling sampler.set_epoch({new_epoch})") - - if self.thread is None: - self.thread = Thread(target=self._worker, daemon=True) - self.thread.start() - - while True: - item = self.queue.get() - - if item is None: - break - yield item - - def generate_batches(self, set_stats=False): - LOG.info("generating packed batches") - if self.sampler: - indices = [idx for idx in self.sampler] - else: - indices = range(0, len(self.dataset)) - - LOG.info(hash_indices(indices)) - lengths = self.lengths[indices] - lengths_cumsum = np.cumsum(lengths) - - batches, totseqs, total_used, total_slots = allocate( - lengths=lengths, - lengths_cumsum=lengths_cumsum, - rank=self.rank, - # c=self.batch_max_length, - c=self.seq_max_length * self.sample_packing_seq_len_multiplier, - n=self.num_replicas, - ) - - batches = [[indices[b_idx] for b_idx in batch] for batch in batches] - - # statistics - if set_stats: - self.eff_total_used += total_used - self.eff_total_slots += total_slots - - return batches, totseqs - - def _internal_batch_generator(self): - all_batches, _ = self.generate_batches(set_stats=True) - features = self.dataset.features.keys() - len_remaining = self._len_est() - for batches in chunk( - all_batches, self.batch_size // self.sample_packing_seq_len_multiplier - ): - chunked_data = [] - attn_mask_cum_idx = 0 - for batch in batches: - concatenated = {} - batched_data = [self.dataset[batch_idx] for batch_idx in batch] - for feature in features: - if feature == "length": - continue - if feature == "attention_mask": - arrays = [ - (attn_mask_cum_idx + idx + 1) * np.array(item[feature]) - for idx, item in enumerate(batched_data) - if feature in item - ] - attn_mask_cum_idx += len(batched_data) - concatenated[feature] = np.concatenate(arrays) - else: - arrays = [ - np.array(item[feature]) - for item in batched_data - if feature in item - ] - concatenated[feature] = np.concatenate(arrays) - chunked_data.append(concatenated) - yield self.collate_fn(chunked_data) - len_remaining -= 1 - if not len_remaining: - return - # yield a no-op for cases where we don't have any data left to pack - for i in range(0, len_remaining): - yield self.collate_fn( - [ - { - "input_ids": [0], - "labels": [-100], - "attention_mask": [True], - "position_ids": [0], - } - ] - ) - - def _len_est(self): - lengths_sum = np.sum(self.lengths) - lengths_sum_per_device = lengths_sum // self.device_count - LOG.info( - f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " - f"total_num_tokens per device: {lengths_sum_per_device}" - ) - - # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler - return ( - math.floor( - 0.99 - * lengths_sum_per_device - / self.packing_efficiency_estimate - // self.seq_max_length - // self.batch_size - ) - - 1 - ) - - def __len__(self): - # this doesn't return the actual length b/c with distributed samplers, not all dataloaders get - # the same share of total tokens - # if not self.eff_total_used: - # batches, _ = self.generate_batches(set_stats=True) - # LOG.info( - # f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " - # f"actual packing efficiency: {self.efficiency()}" - # ) - return max(1, self._len_est()) - - def len_w_stats(self): - if not self.eff_total_used: - batches, _ = self.generate_batches(set_stats=True) - LOG.info( - f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " - f"actual packing efficiency: {self.efficiency()}" - ) - return max(1, self._len_est()) - - def efficiency(self): - return self.eff_total_used / self.eff_total_slots