diff --git a/src/axolotl/cli/args.py b/src/axolotl/cli/args.py
index 0618e07f1..f3a0d539b 100644
--- a/src/axolotl/cli/args.py
+++ b/src/axolotl/cli/args.py
@@ -13,6 +13,7 @@ class PreprocessCliArgs:
     debug_num_examples: int = field(default=1)
     prompter: Optional[str] = field(default=None)
     download: Optional[bool] = field(default=True)
+    iterable: Optional[bool] = field(default=None, metadata={"help": "Use IterableDataset for streaming processing of large datasets"})


 @dataclass
diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py
index 43e2de3db..d836d6d74 100644
--- a/src/axolotl/cli/main.py
+++ b/src/axolotl/cli/main.py
@@ -39,6 +39,8 @@ def preprocess(config: str, **kwargs) -> None:
         kwargs: Additional keyword arguments which correspond to CLI args or
             `axolotl` config options.
     """
+    kwargs = {k: v for k, v in kwargs.items() if v is not None}
+
     from axolotl.cli.preprocess import do_cli

     do_cli(config=config, **kwargs)
diff --git a/src/axolotl/cli/preprocess.py b/src/axolotl/cli/preprocess.py
index 760fe76fa..627a95f8f 100644
--- a/src/axolotl/cli/preprocess.py
+++ b/src/axolotl/cli/preprocess.py
@@ -3,7 +3,7 @@
 import logging
 import warnings
 from pathlib import Path
-from typing import Union
+from typing import Optional, Union

 import fire
 import transformers
@@ -75,7 +75,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
     )


-def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
+def do_cli(
+    config: Union[Path, str] = Path("examples/"),
+    **kwargs,
+) -> None:
     """
     Parses `axolotl` config, CLI args, and calls `do_preprocess`.
diff --git a/src/axolotl/common/datasets.py b/src/axolotl/common/datasets.py
index d07add29b..8694f0986 100644
--- a/src/axolotl/common/datasets.py
+++ b/src/axolotl/common/datasets.py
@@ -63,11 +63,13 @@ def load_datasets(
     """
     tokenizer = load_tokenizer(cfg)
     processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
+    preprocess_iterable = hasattr(cli_args, "iterable") and cli_args.iterable is not None and cli_args.iterable

     train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
         cfg,
         tokenizer,
         processor=processor,
+        preprocess_iterable=preprocess_iterable,
     )

     if (
diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py
index 13ae13d3b..1fcee571f 100644
--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -51,7 +51,7 @@ class TokenizedPromptDataset(Dataset):
         map_kwargs = {}
         if self.prompt_tokenizer.supports_batched:
             map_kwargs["batched"] = True
-            map_kwargs["batch_size"] = 100
+            map_kwargs["batch_size"] = 1_000
         if self.prompt_tokenizer.filter_rows:
             dataset = dataset.filter(
                 self.prompt_tokenizer.filter_rows,
diff --git a/src/axolotl/integrations/kd/chat_template.py b/src/axolotl/integrations/kd/chat_template.py
index 67402a033..b03e059b9 100644
--- a/src/axolotl/integrations/kd/chat_template.py
+++ b/src/axolotl/integrations/kd/chat_template.py
@@ -132,9 +132,9 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):

         return sample

-    def tokenize_prompt(self, prompt):
+    def _tokenize_single_prompt(self, prompt):
         logprobs = prompt.pop(self.logprobs_field)
-        tokenized_prompt = super().tokenize_prompt(prompt)
+        tokenized_prompt = super()._tokenize_single_prompt(prompt)
         tokenized_prompt[self.logprobs_field] = logprobs
         tokenized_prompt = self.transform_logprobs(tokenized_prompt)
diff --git a/src/axolotl/prompt_strategies/chat_template.py b/src/axolotl/prompt_strategies/chat_template.py
index e4d6f619d..519d37ff8 100644
--- a/src/axolotl/prompt_strategies/chat_template.py
+++ b/src/axolotl/prompt_strategies/chat_template.py
@@ -3,6 +3,7 @@ HF Chat Templates prompt strategy
 """

 import logging
+from collections import defaultdict
 from typing import Any, Dict, List, Optional

 from transformers import ProcessorMixin
@@ -193,7 +194,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):

     def __init__(
         self,
-        prompter,
+        prompter: ChatTemplatePrompter,
         tokenizer,
         train_on_inputs,
         sequence_len,
@@ -220,22 +221,50 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
     def messages(self, messages):
         self._messages = messages

-    def tokenize_prompt(self, prompt):
+    @property
+    def supports_batched(self) -> bool:
+        # Let calling code know we can handle lists of examples
+        return True
+
+    def tokenize_prompt(self, prompt: dict[str, Any]) -> Dict[str, List[List[int]]]:
+        """
+        Public method that can handle either a single prompt or a batch of prompts.
+        """
+
+        res = defaultdict(lambda: [])
+        feature_names = list(prompt.keys())
+
+        # Process each prompt individually
+        for row in zip(*prompt.values()):
+            tokenized_prompt = self._tokenize_single_prompt(
+                dict(zip(feature_names, row))
+            )
+            for key, val in tokenized_prompt.items():
+                for i in range(0, len(val), self.sequence_len):
+                    res[key].append(val[i : i + self.sequence_len])
+
+        # If there are no examples left, return an empty dictionary
+        if not res:
+            return {}
+
+        return dict(res)
+
+    def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
         # Old simple legacy behavior that works reliably.
         if (
             not self.roles_to_train
             and not self.train_on_eos
-            and not self.prompter.message_field_training
-            and not self.prompter.message_field_training_detail
+            and not self.prompter.message_field_training  # type: ignore
+            and not self.prompter.message_field_training_detail  # type: ignore
         ):
             turns = self.get_conversation_thread(prompt)
             images = self.get_images(prompt)
-            prompt_ids = self.prompter.build_prompt(
+            prompt_ids = self.prompter.build_prompt(  # type: ignore
                 turns[:-1],
                 add_generation_prompt=True,
                 images=images,
             )
-            tokenized_res = self.prompter.build_prompt(turns, images=images)
+            tokenized_res = self.prompter.build_prompt(turns, images=images)  # type: ignore
             tokenized_prompt = {}
             if isinstance(tokenized_res, list):
                 input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
@@ -256,7 +285,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
             return tokenized_prompt

         turns = self.get_conversation_thread(prompt)
-        input_ids = self.prompter.build_prompt(turns)
+        input_ids = self.prompter.build_prompt(turns)  # type: ignore
         labels = [IGNORE_TOKEN_ID] * len(input_ids)

         last_eos_idx = -1
@@ -286,7 +315,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):

             if should_train and turn_start_idx != -1 and turn_end_idx != -1:
                 if train_detail:
-                    token_offsets = self.prompter.get_offsets_for_train_detail(
+                    token_offsets = self.prompter.get_offsets_for_train_detail(  # type: ignore
                         content, train_detail
                     )
                     LOG.debug(f"Token offsets: {token_offsets}")
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index ea3b7b2c9..c047468af 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -155,6 +155,7 @@ class SFTDataset(BaseModel):
     type: Optional[Union[str, UserDefinedPrompterType]] = None
     input_transform: Optional[str] = None
     shards: Optional[int] = None
+    preprocess_shards: Optional[int] = None
     conversation: Optional[str] = None
     # Do not make this too strict or it will break the validator to choose different dataset class
     chat_template: Optional[
@@ -809,6 +810,7 @@ class AxolotlInputConfig(

     # INTERNALS - document for now, generally not set externally
     is_preprocess: Optional[bool] = None
+    preprocess_iterable: Optional[bool] = None

     total_num_tokens: Optional[int] = None
     total_supervised_tokens: Optional[int] = None
diff --git a/src/axolotl/utils/data/sft.py b/src/axolotl/utils/data/sft.py
index 98a2a1af8..59d862a7f 100644
--- a/src/axolotl/utils/data/sft.py
+++ b/src/axolotl/utils/data/sft.py
@@ -3,11 +3,12 @@
 import functools
 import logging
 from pathlib import Path
-from typing import List, Tuple, Union
+from typing import List, Optional, Tuple, Union

 from datasets import (
     Dataset,
     DatasetDict,
+    IterableDataset,
     concatenate_datasets,
     load_dataset,
     load_from_disk,
@@ -57,7 +58,7 @@ LOG = logging.getLogger("axolotl")


 @retry_on_request_exceptions(max_retries=3, delay=5)
-def prepare_dataset(cfg, tokenizer, processor=None):
+def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
     prompters = []
     if not cfg.pretraining_dataset:
         with zero_first(is_local_main_process()):
@@ -68,6 +69,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
                     DEFAULT_DATASET_PREPARED_PATH,
                     split="train",
                     processor=processor,
+                    preprocess_iterable=preprocess_iterable,
                 )
                 _, eval_dataset, _ = load_prepare_datasets(
                     tokenizer,
@@ -75,6 +77,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
                     DEFAULT_DATASET_PREPARED_PATH,
                     split="test",
                     processor=processor,
+                    preprocess_iterable=preprocess_iterable,
                 )
             else:
                 train_dataset, eval_dataset, prompters = load_prepare_datasets(
@@ -82,6 +85,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
                     cfg,
                     DEFAULT_DATASET_PREPARED_PATH,
                     processor=processor,
+                    preprocess_iterable=preprocess_iterable,
                 )
     else:
         # Load streaming dataset if pretraining_dataset is given
@@ -137,6 +141,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
             DEFAULT_DATASET_PREPARED_PATH,
             split="test",
             processor=processor,
+            preprocess_iterable=preprocess_iterable,
         )

     if cfg.dataset_exact_deduplication:
@@ -168,6 +173,7 @@ def load_tokenized_prepared_datasets(
     default_dataset_prepared_path,
     split="train",
     processor=None,
+    preprocess_iterable: Optional[bool] = None,
 ) -> Tuple[DatasetDict, List[Prompter]]:
     cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
     tokenizer_name = cfg.tokenizer_config
@@ -261,13 +267,25 @@ def load_tokenized_prepared_datasets(
                 # at the same time for a given dataset
                 for name in dataset.name:
                     yield DictDefault({**dataset, "name": name})
+            elif dataset.preprocess_shards and not dataset.shards:
+                for shard in range(dataset.preprocess_shards):
+                    yield DictDefault(
+                        {
+                            **dataset,
+                            "shards": dataset.preprocess_shards,
+                            "shards_idx": shard,
+                        }
+                    )
             else:
                 yield dataset

+    streaming_ds = False
+    if preprocess_iterable:
+        streaming_ds = True  # pylint: disable=invalid-name
+
     for config_dataset in for_d_in_datasets(cfg_datasets):
         ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
-            config_dataset, use_auth_token
+            config_dataset, use_auth_token, streaming=streaming_ds
         )

         d_base_type = d_prompt_style = None
@@ -324,7 +342,21 @@ def load_tokenized_prepared_datasets(

         if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
             LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
{prepared_ds_path}") - dataset.save_to_disk(str(prepared_ds_path)) + if isinstance(dataset, IterableDataset): + + def gen_from_iter_ds(_ds, _=None): + yield from _ds + + ds_from_iter = Dataset.from_generator( + functools.partial(gen_from_iter_ds, dataset), + features=dataset.features, + num_proc=cfg.dataset_processes, + split=split, + gen_kwargs={"_": list(range(cfg.dataset_processes))}, + ) + ds_from_iter.save_to_disk(str(prepared_ds_path)) + else: + dataset.save_to_disk(str(prepared_ds_path)) if cfg.push_dataset_to_hub: LOG.info( f"Pushing merged prepared dataset to Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..." @@ -344,6 +376,7 @@ def load_prepare_datasets( default_dataset_prepared_path, split="train", processor=None, + preprocess_iterable: Optional[bool] = False, ) -> Tuple[Dataset, Dataset, List[Prompter]]: dataset, prompters = load_tokenized_prepared_datasets( tokenizer, @@ -351,6 +384,7 @@ def load_prepare_datasets( default_dataset_prepared_path, split=split, processor=processor, + preprocess_iterable=preprocess_iterable, ) if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None: diff --git a/src/axolotl/utils/data/shared.py b/src/axolotl/utils/data/shared.py index d14496d96..456a3a882 100644 --- a/src/axolotl/utils/data/shared.py +++ b/src/axolotl/utils/data/shared.py @@ -29,7 +29,9 @@ def get_ds_type(config_dataset: DictDefault): return ds_type -def load_dataset_w_config(config_dataset, auth_token): +def load_dataset_w_config( + config_dataset, auth_token, streaming=False +) -> Union[Dataset, DatasetDict]: # pylint: disable=invalid-name ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name ds_from_hub = False @@ -117,7 +119,7 @@ def load_dataset_w_config(config_dataset, auth_token): ds_type, name=config_dataset.name, data_files=config_dataset.data_files, - streaming=False, + streaming=streaming, split=None, ) else: @@ -153,7 +155,7 @@ def load_dataset_w_config(config_dataset, auth_token): ds = load_dataset( config_dataset.path, name=config_dataset.name, - streaming=False, + streaming=streaming, data_files=config_dataset.data_files, token=auth_token, revision=config_dataset.revision, @@ -172,7 +174,7 @@ def load_dataset_w_config(config_dataset, auth_token): ds_type, name=config_dataset.name, data_files=config_dataset.path, - streaming=False, + streaming=streaming, split=None, storage_options=storage_options, trust_remote_code=config_dataset.trust_remote_code, @@ -183,7 +185,7 @@ def load_dataset_w_config(config_dataset, auth_token): ds_type, name=config_dataset.name, data_files=config_dataset.path, - streaming=False, + streaming=streaming, split=None, storage_options=storage_options, trust_remote_code=config_dataset.trust_remote_code, @@ -213,7 +215,7 @@ def load_dataset_w_config(config_dataset, auth_token): "json", name=config_dataset.name, data_files=fp, - streaming=False, + streaming=streaming, split=None, ) if not ds: diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index 34b505ff1..ce4deafa9 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -11,7 +11,7 @@ import numpy as np import torch import torch.cuda from accelerate.logging import get_logger -from datasets import disable_caching, enable_caching +from datasets import IterableDataset, disable_caching, enable_caching from torch.utils.data import DataLoader, RandomSampler from transformers.utils import is_torch_bf16_gpu_available @@ -95,9 +95,46 @@ def disable_datasets_caching(): def add_position_ids(sample): - 
-    sample_len = len(sample["input_ids"])
-    sample["position_ids"] = torch.arange(len(sample["input_ids"]))
-    sample["length"] = sample_len
+    """
+    Handle both single-example and batched data.
+    - single example: sample['input_ids'] is a list[int]
+    - batched data: sample['input_ids'] is a list[list[int]]
+    """
+    if "input_ids" not in sample:
+        # If there's no "input_ids", just return sample unchanged
+        return sample
+
+    input_ids = sample["input_ids"]
+
+    # Detect if it's a single example or a batch
+    if not input_ids:
+        # Edge case: empty
+        return sample
+
+    # If first element is an int, it's a single example
+    # If first element is a list, it's a batch
+    if isinstance(input_ids[0], int):
+        # ---- SINGLE EXAMPLE ----
+        seq_len = len(input_ids)
+        # Position IDs for a single example, as a list
+        sample["position_ids"] = list(range(seq_len))
+        sample["length"] = seq_len
+
+    else:
+        # ---- BATCHED EXAMPLES ----
+        # input_ids is a list of lists
+        position_ids_batch = []
+        lengths_batch = []
+        for seq in input_ids:
+            seq_len = len(seq)
+            position_ids_batch.append(list(range(seq_len)))
+            lengths_batch.append(seq_len)
+
+        # Now store them back
+        sample["position_ids"] = position_ids_batch
+        sample["length"] = lengths_batch
+
     return sample


@@ -172,10 +209,31 @@ def add_length(sample):


 def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
-    return (
-        len(sample["input_ids"]) <= sequence_len
-        and len(sample["input_ids"]) >= min_sequence_len
-    )
+    """
+    Drop samples whose sequence length is either too long (> sequence_len)
+    or too short (< min_sequence_len).
+
+    Works for both single-example (list[int]) and batched (list[list[int]]) input.
+    """
+    input_ids = sample["input_ids"]
+
+    # Edge case: if input_ids is empty
+    if not input_ids:
+        # Decide if you want to drop or keep empty. Let's drop.
+        return False
+
+    # Check if single example or batched by looking at the first element
+    if isinstance(input_ids[0], int):
+        # Single example (input_ids is a list of int)
+        length = len(input_ids)
+        return min_sequence_len <= length <= sequence_len
+
+    # Batched (input_ids is a list of lists)
+    results = []
+    for seq in input_ids:
+        length = len(seq)
+        results.append(min_sequence_len <= length <= sequence_len)
+    return results


 def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     drop_long = functools.partial(
         drop_long_seq,
         sequence_len=cfg.sequence_len,
         min_sequence_len=cfg.min_sample_len or 2,
     )
@@ -185,10 +243,13 @@
-    min_input_len = np.min(get_dataset_lengths(train_dataset))
-    LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
-    max_input_len = np.max(get_dataset_lengths(train_dataset))
-    LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
+    try:
+        min_input_len = np.min(get_dataset_lengths(train_dataset))
+        LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
+        max_input_len = np.max(get_dataset_lengths(train_dataset))
+        LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
+    except AttributeError:
+        pass

     if cfg.model_config_type == "mamba":
         LOG.info("dropping attention_mask column")
@@ -203,60 +264,109 @@
     if eval_dataset and "token_type_ids" in eval_dataset.column_names:
         eval_dataset = eval_dataset.remove_columns("token_type_ids")

-    prior_len = len(train_dataset)
+    filter_map_kwargs = {}
+    if not isinstance(train_dataset, IterableDataset):
+        filter_map_kwargs["num_proc"] = cfg.dataset_processes
+        filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
+
+    try:
+        prior_len = len(train_dataset)
+    except TypeError:
+        # handle iterable datasets case
+        prior_len = None
+    drop_long_kwargs = {}
+    if filter_map_kwargs:
+        drop_long_kwargs["desc"] = "Dropping Long Sequences"
     train_dataset = train_dataset.filter(
         drop_long,
-        num_proc=cfg.dataset_processes,
-        load_from_cache_file=not cfg.is_preprocess,
-        desc="Dropping Long Sequences",
+        **filter_map_kwargs,
+        **drop_long_kwargs,
     )
-    dropped = prior_len - len(train_dataset)
-    if dropped:
-        LOG.warning(f"Dropped {dropped} long samples from train dataset")
+    if prior_len:
+        dropped = prior_len - len(train_dataset)
+        if dropped:
+            LOG.warning(f"Dropped {dropped} long samples from train dataset")

     if eval_dataset:
-        prior_len = len(eval_dataset)
+        try:
+            prior_len = len(eval_dataset)
+        except TypeError:
+            # handle iterable datasets case
+            prior_len = None
         eval_dataset = eval_dataset.filter(
             drop_long,
-            num_proc=cfg.dataset_processes,
-            load_from_cache_file=not cfg.is_preprocess,
-            desc="Dropping Long Sequences",
+            **filter_map_kwargs,
+            **drop_long_kwargs,
         )
-        dropped = prior_len - len(eval_dataset)
-        if dropped:
-            LOG.warning(f"Dropped {dropped} long samples from eval dataset")
+        if prior_len:
+            dropped = prior_len - len(eval_dataset)
+            if dropped:
+                LOG.warning(f"Dropped {dropped} long samples from eval dataset")

-    # drop samples with where the number of elements with labels not equal to -100 is zero
     def drop_no_trainable_tokens(sample):
-        return np.sum(np.array(sample["labels"]) != -100) > 0
+        """
+        Drop samples if all labels are -100 (i.e., zero trainable tokens).
+        Works for both single-example and batched input.
+ """ + labels = sample["labels"] + if not labels: + # Edge case: if labels is empty, decide if you want to keep or drop + return True # or False - prior_len = len(train_dataset) + # Check if single example or batch + # If first element is an int, we assume a single example + # If it's a list, we assume we're dealing with a batch + if isinstance(labels[0], int): + # Single example: return a single bool + return np.sum(np.array(labels) != -100) > 0 + + # Batched: 'labels' is a list of lists + # Return a list of booleans, one per sub-list + results = [] + for row_labels in labels: + # Each row_labels is a list[int] + results.append(np.sum(np.array(row_labels) != -100) > 0) + return results + + try: + prior_len = len(train_dataset) + except TypeError: + # handle iterable datasets case + prior_len = None + drop_long_kwargs = {} + if filter_map_kwargs: + drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens" train_dataset = train_dataset.filter( drop_no_trainable_tokens, - num_proc=cfg.dataset_processes, - load_from_cache_file=not cfg.is_preprocess, - desc="Drop Samples with Zero Trainable Tokens", + batched=True, + **filter_map_kwargs, + **drop_long_kwargs, ) - dropped = prior_len - len(train_dataset) - if dropped: - LOG.warning( - f"Dropped {dropped} samples with no trainable tokens from train dataset" - ) - - if eval_dataset: - prior_len = len(eval_dataset) - eval_dataset = eval_dataset.filter( - drop_no_trainable_tokens, - num_proc=cfg.dataset_processes, - load_from_cache_file=not cfg.is_preprocess, - desc="Drop Samples with Zero Trainable Tokens", - ) - dropped = prior_len - len(eval_dataset) + if prior_len: + dropped = prior_len - len(train_dataset) if dropped: LOG.warning( - f"Dropped {dropped} samples with no trainable tokens from eval dataset" + f"Dropped {dropped} samples with no trainable tokens from train dataset" ) + if eval_dataset: + try: + prior_len = len(eval_dataset) + except TypeError: + # handle iterable datasets case + prior_len = None + eval_dataset = eval_dataset.filter( + drop_no_trainable_tokens, + **filter_map_kwargs, + **drop_long_kwargs, + ) + if prior_len: + dropped = prior_len - len(eval_dataset) + if dropped: + LOG.warning( + f"Dropped {dropped} samples with no trainable tokens from eval dataset" + ) + if cfg.group_by_length: train_dataset = train_dataset.map( add_length, @@ -291,19 +401,21 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset): desc="Add position_id column (PoSE)", ) elif cfg.sample_packing: + drop_long_kwargs = {} + if filter_map_kwargs: + drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)" train_dataset = train_dataset.map( add_position_ids, - num_proc=cfg.dataset_processes, - load_from_cache_file=not cfg.is_preprocess, - desc="Add position_id column (Sample Packing)", + batched=True, + **filter_map_kwargs, + **drop_long_kwargs, ) if cfg.eval_sample_packing is not False: if eval_dataset: eval_dataset = eval_dataset.map( add_position_ids, - num_proc=cfg.dataset_processes, - load_from_cache_file=not cfg.is_preprocess, - desc="Add position_id column (Sample Packing)", + **filter_map_kwargs, + **drop_long_kwargs, ) return train_dataset, eval_dataset @@ -334,7 +446,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True): and not cfg.reward_model ): total_num_tokens = np.sum( - train_dataset.data.column("input_ids") + train_dataset.select_columns("input_ids") .to_pandas() .apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda .values