improve iterable support

This commit is contained in:
Wing Lian
2025-01-02 13:50:35 -05:00
parent e659c01646
commit 01896b1bde
11 changed files with 265 additions and 78 deletions

View File

@@ -13,6 +13,7 @@ class PreprocessCliArgs:
debug_num_examples: int = field(default=1) debug_num_examples: int = field(default=1)
prompter: Optional[str] = field(default=None) prompter: Optional[str] = field(default=None)
download: Optional[bool] = field(default=True) download: Optional[bool] = field(default=True)
iterable: Optional[bool] = field(default=None, metadata={"help": "Use IterableDataset for streaming processing of large datasets"})
@dataclass @dataclass

View File

@@ -39,6 +39,8 @@ def preprocess(config: str, **kwargs) -> None:
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options. config options.
""" """
kwargs = {k: v for k, v in kwargs.items() if v is not None}
from axolotl.cli.preprocess import do_cli from axolotl.cli.preprocess import do_cli
do_cli(config=config, **kwargs) do_cli(config=config, **kwargs)

View File

@@ -3,7 +3,7 @@
import logging import logging
import warnings import warnings
from pathlib import Path from pathlib import Path
from typing import Union from typing import Optional, Union
import fire import fire
import transformers import transformers
@@ -75,7 +75,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
) )
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: def do_cli(
config: Union[Path, str] = Path("examples/"),
**kwargs,
) -> None:
""" """
Parses `axolotl` config, CLI args, and calls `do_preprocess`. Parses `axolotl` config, CLI args, and calls `do_preprocess`.

View File

@@ -63,11 +63,13 @@ def load_datasets(
""" """
tokenizer = load_tokenizer(cfg) tokenizer = load_tokenizer(cfg)
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
preprocess_iterable = hasattr(cli_args, "iterable") and cli_args.iterable is not None and cli_args.iterable
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset( train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
cfg, cfg,
tokenizer, tokenizer,
processor=processor, processor=processor,
preprocess_iterable=preprocess_iterable,
) )
if ( if (

View File

@@ -51,7 +51,7 @@ class TokenizedPromptDataset(Dataset):
map_kwargs = {} map_kwargs = {}
if self.prompt_tokenizer.supports_batched: if self.prompt_tokenizer.supports_batched:
map_kwargs["batched"] = True map_kwargs["batched"] = True
map_kwargs["batch_size"] = 100 map_kwargs["batch_size"] = 1_000
if self.prompt_tokenizer.filter_rows: if self.prompt_tokenizer.filter_rows:
dataset = dataset.filter( dataset = dataset.filter(
self.prompt_tokenizer.filter_rows, self.prompt_tokenizer.filter_rows,

View File

@@ -132,9 +132,9 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
return sample return sample
def tokenize_prompt(self, prompt): def _tokenize_single_prompt(self, prompt):
logprobs = prompt.pop(self.logprobs_field) logprobs = prompt.pop(self.logprobs_field)
tokenized_prompt = super().tokenize_prompt(prompt) tokenized_prompt = super()._tokenize_single_prompt(prompt)
tokenized_prompt[self.logprobs_field] = logprobs tokenized_prompt[self.logprobs_field] = logprobs
tokenized_prompt = self.transform_logprobs(tokenized_prompt) tokenized_prompt = self.transform_logprobs(tokenized_prompt)

View File

@@ -3,6 +3,7 @@ HF Chat Templates prompt strategy
""" """
import logging import logging
from collections import defaultdict
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from transformers import ProcessorMixin from transformers import ProcessorMixin
@@ -193,7 +194,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def __init__( def __init__(
self, self,
prompter, prompter: ChatTemplatePrompter,
tokenizer, tokenizer,
train_on_inputs, train_on_inputs,
sequence_len, sequence_len,
@@ -220,22 +221,50 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
def messages(self, messages): def messages(self, messages):
self._messages = messages self._messages = messages
@property
def supports_batched(self) -> bool:
    """
    Report that this strategy's ``tokenize_prompt`` accepts lists of examples.

    Calling code checks this flag to decide whether to pass ``batched=True``
    when mapping the tokenizer over a dataset.
    """
    return True
def tokenize_prompt(self, prompt: dict[str, Any]) -> Dict[str, List[List[int]]]:
"""
Public method that can handle either a single prompt or a batch of prompts.
"""
res = defaultdict(lambda: [])
feature_names = list(prompt.keys())
# Process each prompt individually
for row in zip(*prompt.values()):
tokenized_prompt = self._tokenize_single_prompt(
dict(zip(feature_names, row))
)
for key, val in tokenized_prompt.items():
for i in range(0, len(val), self.sequence_len):
res[key].append(val[i : i + self.sequence_len])
# If there are no examples left, return an empty dictionary
if not res:
return {}
return dict(res)
def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
# Old simple legacy behavior that works reliably. # Old simple legacy behavior that works reliably.
if ( if (
not self.roles_to_train not self.roles_to_train
and not self.train_on_eos and not self.train_on_eos
and not self.prompter.message_field_training and not self.prompter.message_field_training # type: ignore
and not self.prompter.message_field_training_detail and not self.prompter.message_field_training_detail # type: ignore
): ):
turns = self.get_conversation_thread(prompt) turns = self.get_conversation_thread(prompt)
images = self.get_images(prompt) images = self.get_images(prompt)
prompt_ids = self.prompter.build_prompt( prompt_ids = self.prompter.build_prompt( # type: ignore
turns[:-1], turns[:-1],
add_generation_prompt=True, add_generation_prompt=True,
images=images, images=images,
) )
tokenized_res = self.prompter.build_prompt(turns, images=images) tokenized_res = self.prompter.build_prompt(turns, images=images) # type: ignore
tokenized_prompt = {} tokenized_prompt = {}
if isinstance(tokenized_res, list): if isinstance(tokenized_res, list):
input_ids = prompt_ids + tokenized_res[len(prompt_ids) :] input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
@@ -256,7 +285,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
return tokenized_prompt return tokenized_prompt
turns = self.get_conversation_thread(prompt) turns = self.get_conversation_thread(prompt)
input_ids = self.prompter.build_prompt(turns) input_ids = self.prompter.build_prompt(turns) # type: ignore
labels = [IGNORE_TOKEN_ID] * len(input_ids) labels = [IGNORE_TOKEN_ID] * len(input_ids)
last_eos_idx = -1 last_eos_idx = -1
@@ -286,7 +315,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
if should_train and turn_start_idx != -1 and turn_end_idx != -1: if should_train and turn_start_idx != -1 and turn_end_idx != -1:
if train_detail: if train_detail:
token_offsets = self.prompter.get_offsets_for_train_detail( token_offsets = self.prompter.get_offsets_for_train_detail( # type: ignore
content, train_detail content, train_detail
) )
LOG.debug(f"Token offsets: {token_offsets}") LOG.debug(f"Token offsets: {token_offsets}")

View File

@@ -155,6 +155,7 @@ class SFTDataset(BaseModel):
type: Optional[Union[str, UserDefinedPrompterType]] = None type: Optional[Union[str, UserDefinedPrompterType]] = None
input_transform: Optional[str] = None input_transform: Optional[str] = None
shards: Optional[int] = None shards: Optional[int] = None
preprocess_shards: Optional[int] = None
conversation: Optional[str] = None conversation: Optional[str] = None
# Do not make this too strict or it will break the validator to choose different dataset class # Do not make this too strict or it will break the validator to choose different dataset class
chat_template: Optional[ chat_template: Optional[
@@ -809,6 +810,7 @@ class AxolotlInputConfig(
# INTERNALS - document for now, generally not set externally # INTERNALS - document for now, generally not set externally
is_preprocess: Optional[bool] = None is_preprocess: Optional[bool] = None
preprocess_iterable: Optional[bool] = None
total_num_tokens: Optional[int] = None total_num_tokens: Optional[int] = None
total_supervised_tokens: Optional[int] = None total_supervised_tokens: Optional[int] = None

View File

@@ -3,11 +3,12 @@
import functools import functools
import logging import logging
from pathlib import Path from pathlib import Path
from typing import List, Tuple, Union from typing import List, Tuple, Union, Optional
from datasets import ( from datasets import (
Dataset, Dataset,
DatasetDict, DatasetDict,
IterableDataset,
concatenate_datasets, concatenate_datasets,
load_dataset, load_dataset,
load_from_disk, load_from_disk,
@@ -57,7 +58,7 @@ LOG = logging.getLogger("axolotl")
@retry_on_request_exceptions(max_retries=3, delay=5) @retry_on_request_exceptions(max_retries=3, delay=5)
def prepare_dataset(cfg, tokenizer, processor=None): def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
prompters = [] prompters = []
if not cfg.pretraining_dataset: if not cfg.pretraining_dataset:
with zero_first(is_local_main_process()): with zero_first(is_local_main_process()):
@@ -68,6 +69,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
DEFAULT_DATASET_PREPARED_PATH, DEFAULT_DATASET_PREPARED_PATH,
split="train", split="train",
processor=processor, processor=processor,
preprocess_iterable=preprocess_iterable,
) )
_, eval_dataset, _ = load_prepare_datasets( _, eval_dataset, _ = load_prepare_datasets(
tokenizer, tokenizer,
@@ -75,6 +77,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
DEFAULT_DATASET_PREPARED_PATH, DEFAULT_DATASET_PREPARED_PATH,
split="test", split="test",
processor=processor, processor=processor,
preprocess_iterable=preprocess_iterable,
) )
else: else:
train_dataset, eval_dataset, prompters = load_prepare_datasets( train_dataset, eval_dataset, prompters = load_prepare_datasets(
@@ -82,6 +85,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
cfg, cfg,
DEFAULT_DATASET_PREPARED_PATH, DEFAULT_DATASET_PREPARED_PATH,
processor=processor, processor=processor,
preprocess_iterable=preprocess_iterable,
) )
else: else:
# Load streaming dataset if pretraining_dataset is given # Load streaming dataset if pretraining_dataset is given
@@ -137,6 +141,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
DEFAULT_DATASET_PREPARED_PATH, DEFAULT_DATASET_PREPARED_PATH,
split="test", split="test",
processor=processor, processor=processor,
preprocess_iterable=preprocess_iterable,
) )
if cfg.dataset_exact_deduplication: if cfg.dataset_exact_deduplication:
@@ -168,6 +173,7 @@ def load_tokenized_prepared_datasets(
default_dataset_prepared_path, default_dataset_prepared_path,
split="train", split="train",
processor=None, processor=None,
preprocess_iterable: Optional[bool] = None,
) -> Tuple[DatasetDict, List[Prompter]]: ) -> Tuple[DatasetDict, List[Prompter]]:
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
tokenizer_name = cfg.tokenizer_config tokenizer_name = cfg.tokenizer_config
@@ -261,13 +267,25 @@ def load_tokenized_prepared_datasets(
# at the same time for a given dataset # at the same time for a given dataset
for name in dataset.name: for name in dataset.name:
yield DictDefault({**dataset, "name": name}) yield DictDefault({**dataset, "name": name})
elif dataset.preprocess_shards and not dataset.shards:
for shard in range(dataset.preprocess_shards):
yield DictDefault(
{
**dataset,
"shards": dataset.preprocess_shards,
"shards_idx": shard,
}
)
else: else:
yield dataset yield dataset
streaming_ds = False
if preprocess_iterable:
streaming_ds = True
# pylint: disable=invalid-name # pylint: disable=invalid-name
for config_dataset in for_d_in_datasets(cfg_datasets): for config_dataset in for_d_in_datasets(cfg_datasets):
ds: Union[Dataset, DatasetDict] = load_dataset_w_config( ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
config_dataset, use_auth_token config_dataset, use_auth_token, streaming=streaming_ds
) )
d_base_type = d_prompt_style = None d_base_type = d_prompt_style = None
@@ -324,7 +342,21 @@ def load_tokenized_prepared_datasets(
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset: if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}") LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
dataset.save_to_disk(str(prepared_ds_path)) if isinstance(dataset, IterableDataset):
def gen_from_iter_ds(_ds, _=None):
yield from _ds
ds_from_iter = Dataset.from_generator(
functools.partial(gen_from_iter_ds, dataset),
features=dataset.features,
num_proc=cfg.dataset_processes,
split=split,
gen_kwargs={"_": list(range(cfg.dataset_processes))},
)
ds_from_iter.save_to_disk(str(prepared_ds_path))
else:
dataset.save_to_disk(str(prepared_ds_path))
if cfg.push_dataset_to_hub: if cfg.push_dataset_to_hub:
LOG.info( LOG.info(
f"Pushing merged prepared dataset to Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..." f"Pushing merged prepared dataset to Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..."
@@ -344,6 +376,7 @@ def load_prepare_datasets(
default_dataset_prepared_path, default_dataset_prepared_path,
split="train", split="train",
processor=None, processor=None,
preprocess_iterable: Optional[bool] = False,
) -> Tuple[Dataset, Dataset, List[Prompter]]: ) -> Tuple[Dataset, Dataset, List[Prompter]]:
dataset, prompters = load_tokenized_prepared_datasets( dataset, prompters = load_tokenized_prepared_datasets(
tokenizer, tokenizer,
@@ -351,6 +384,7 @@ def load_prepare_datasets(
default_dataset_prepared_path, default_dataset_prepared_path,
split=split, split=split,
processor=processor, processor=processor,
preprocess_iterable=preprocess_iterable,
) )
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None: if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:

View File

@@ -29,7 +29,9 @@ def get_ds_type(config_dataset: DictDefault):
return ds_type return ds_type
def load_dataset_w_config(config_dataset, auth_token): def load_dataset_w_config(
config_dataset, auth_token, streaming=False
) -> Union[Dataset, DatasetDict]:
# pylint: disable=invalid-name # pylint: disable=invalid-name
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
ds_from_hub = False ds_from_hub = False
@@ -117,7 +119,7 @@ def load_dataset_w_config(config_dataset, auth_token):
ds_type, ds_type,
name=config_dataset.name, name=config_dataset.name,
data_files=config_dataset.data_files, data_files=config_dataset.data_files,
streaming=False, streaming=streaming,
split=None, split=None,
) )
else: else:
@@ -153,7 +155,7 @@ def load_dataset_w_config(config_dataset, auth_token):
ds = load_dataset( ds = load_dataset(
config_dataset.path, config_dataset.path,
name=config_dataset.name, name=config_dataset.name,
streaming=False, streaming=streaming,
data_files=config_dataset.data_files, data_files=config_dataset.data_files,
token=auth_token, token=auth_token,
revision=config_dataset.revision, revision=config_dataset.revision,
@@ -172,7 +174,7 @@ def load_dataset_w_config(config_dataset, auth_token):
ds_type, ds_type,
name=config_dataset.name, name=config_dataset.name,
data_files=config_dataset.path, data_files=config_dataset.path,
streaming=False, streaming=streaming,
split=None, split=None,
storage_options=storage_options, storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code, trust_remote_code=config_dataset.trust_remote_code,
@@ -183,7 +185,7 @@ def load_dataset_w_config(config_dataset, auth_token):
ds_type, ds_type,
name=config_dataset.name, name=config_dataset.name,
data_files=config_dataset.path, data_files=config_dataset.path,
streaming=False, streaming=streaming,
split=None, split=None,
storage_options=storage_options, storage_options=storage_options,
trust_remote_code=config_dataset.trust_remote_code, trust_remote_code=config_dataset.trust_remote_code,
@@ -213,7 +215,7 @@ def load_dataset_w_config(config_dataset, auth_token):
"json", "json",
name=config_dataset.name, name=config_dataset.name,
data_files=fp, data_files=fp,
streaming=False, streaming=streaming,
split=None, split=None,
) )
if not ds: if not ds:

View File

@@ -11,7 +11,7 @@ import numpy as np
import torch import torch
import torch.cuda import torch.cuda
from accelerate.logging import get_logger from accelerate.logging import get_logger
from datasets import disable_caching, enable_caching from datasets import IterableDataset, disable_caching, enable_caching
from torch.utils.data import DataLoader, RandomSampler from torch.utils.data import DataLoader, RandomSampler
from transformers.utils import is_torch_bf16_gpu_available from transformers.utils import is_torch_bf16_gpu_available
@@ -95,9 +95,46 @@ def disable_datasets_caching():
def add_position_ids(sample):
    """
    Add ``position_ids`` and ``length`` entries derived from ``input_ids``.

    Handles both single-example and batched data:
    - single example: ``sample["input_ids"]`` is a flat sequence of ints
    - batched data: ``sample["input_ids"]`` is a sequence of such sequences

    For a single example ``position_ids`` is ``[0, 1, ..., n - 1]`` and
    ``length`` is ``n``; for a batch both are lists with one entry per row.
    A sample without an ``input_ids`` key, or with empty ``input_ids``, is
    returned unchanged (no keys are added).
    """
    import numbers  # function-scope import keeps this block self-contained

    if "input_ids" not in sample:
        return sample

    input_ids = sample["input_ids"]
    # Use len() instead of truthiness so array-like inputs cannot raise on an
    # ambiguous bool() conversion.
    if input_ids is None or len(input_ids) == 0:
        return sample

    # numbers.Integral also matches NumPy integer scalars, which a plain
    # ``int`` check would misread as the rows of a batch.
    if isinstance(input_ids[0], numbers.Integral):
        seq_len = len(input_ids)
        sample["position_ids"] = list(range(seq_len))
        sample["length"] = seq_len
    else:
        lengths = [len(seq) for seq in input_ids]
        sample["position_ids"] = [list(range(seq_len)) for seq_len in lengths]
        sample["length"] = lengths

    return sample
@@ -172,10 +209,31 @@ def add_length(sample):
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
    """
    Decide whether sequences are within the allowed length range.

    A sequence is kept when ``min_sequence_len <= len(seq) <= sequence_len``.
    Works for both single-example and batched input:
    - single example (flat sequence of ints): returns one ``bool``
    - batched (sequence of sequences): returns a list of ``bool``, one per row

    Empty or missing ``input_ids`` returns ``False`` (the sample is dropped).
    """
    import numbers  # function-scope import keeps this block self-contained

    input_ids = sample["input_ids"]
    # Use len() instead of truthiness so array-like inputs cannot raise on an
    # ambiguous bool() conversion.
    if input_ids is None or len(input_ids) == 0:
        return False

    # numbers.Integral also matches NumPy integer scalars, which a plain
    # ``int`` check would misread as the rows of a batch.
    if isinstance(input_ids[0], numbers.Integral):
        return min_sequence_len <= len(input_ids) <= sequence_len

    return [min_sequence_len <= len(seq) <= sequence_len for seq in input_ids]
def process_datasets_for_packing(cfg, train_dataset, eval_dataset): def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
@@ -185,10 +243,13 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
min_sequence_len=cfg.min_sample_len or 2, min_sequence_len=cfg.min_sample_len or 2,
) )
min_input_len = np.min(get_dataset_lengths(train_dataset)) try:
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True) min_input_len = np.min(get_dataset_lengths(train_dataset))
max_input_len = np.max(get_dataset_lengths(train_dataset)) LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True) max_input_len = np.max(get_dataset_lengths(train_dataset))
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
except AttributeError:
pass
if cfg.model_config_type == "mamba": if cfg.model_config_type == "mamba":
LOG.info("dropping attention_mask column") LOG.info("dropping attention_mask column")
@@ -203,60 +264,109 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if eval_dataset and "token_type_ids" in eval_dataset.column_names: if eval_dataset and "token_type_ids" in eval_dataset.column_names:
eval_dataset = eval_dataset.remove_columns("token_type_ids") eval_dataset = eval_dataset.remove_columns("token_type_ids")
prior_len = len(train_dataset) filter_map_kwargs = {}
if not isinstance(train_dataset, IterableDataset):
filter_map_kwargs["num_proc"] = cfg.dataset_processes
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
try:
prior_len = len(train_dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Dropping Long Sequences"
train_dataset = train_dataset.filter( train_dataset = train_dataset.filter(
drop_long, drop_long,
num_proc=cfg.dataset_processes, **filter_map_kwargs,
load_from_cache_file=not cfg.is_preprocess, **drop_long_kwargs,
desc="Dropping Long Sequences",
) )
dropped = prior_len - len(train_dataset) if prior_len:
if dropped: dropped = prior_len - len(train_dataset)
LOG.warning(f"Dropped {dropped} long samples from train dataset") if dropped:
LOG.warning(f"Dropped {dropped} long samples from train dataset")
if eval_dataset: if eval_dataset:
prior_len = len(eval_dataset) try:
prior_len = len(eval_dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
eval_dataset = eval_dataset.filter( eval_dataset = eval_dataset.filter(
drop_long, drop_long,
num_proc=cfg.dataset_processes, **filter_map_kwargs,
load_from_cache_file=not cfg.is_preprocess, **drop_long_kwargs,
desc="Dropping Long Sequences",
) )
dropped = prior_len - len(eval_dataset) if prior_len:
if dropped: dropped = prior_len - len(eval_dataset)
LOG.warning(f"Dropped {dropped} long samples from eval dataset") if dropped:
LOG.warning(f"Dropped {dropped} long samples from eval dataset")
# drop samples with where the number of elements with labels not equal to -100 is zero
def drop_no_trainable_tokens(sample):
    """
    Decide whether sequences contain at least one trainable token.

    A sequence is kept when at least one of its labels differs from -100.
    Works for both single-example and batched input:
    - single example (flat sequence of ints): returns one ``bool``
    - batched (sequence of sequences): returns a list of ``bool``, one per row

    Empty or missing ``labels`` returns ``True`` (the sample is kept).
    """
    import numbers  # function-scope import keeps this block self-contained

    def _has_trainable(row_labels):
        # Plain bool rather than a NumPy bool so callers get a builtin type.
        return bool(np.any(np.asarray(row_labels) != -100))

    labels = sample["labels"]
    # Use len() instead of truthiness so array-like inputs cannot raise on an
    # ambiguous bool() conversion.
    if labels is None or len(labels) == 0:
        return True

    # numbers.Integral also matches NumPy integer scalars, which a plain
    # ``int`` check would misread as the rows of a batch.
    if isinstance(labels[0], numbers.Integral):
        return _has_trainable(labels)

    # Rows may differ in length, so each row is evaluated on its own.
    return [_has_trainable(row_labels) for row_labels in labels]
try:
prior_len = len(train_dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
train_dataset = train_dataset.filter( train_dataset = train_dataset.filter(
drop_no_trainable_tokens, drop_no_trainable_tokens,
num_proc=cfg.dataset_processes, batched=True,
load_from_cache_file=not cfg.is_preprocess, **filter_map_kwargs,
desc="Drop Samples with Zero Trainable Tokens", **drop_long_kwargs,
) )
dropped = prior_len - len(train_dataset) if prior_len:
if dropped: dropped = prior_len - len(train_dataset)
LOG.warning(
f"Dropped {dropped} samples with no trainable tokens from train dataset"
)
if eval_dataset:
prior_len = len(eval_dataset)
eval_dataset = eval_dataset.filter(
drop_no_trainable_tokens,
num_proc=cfg.dataset_processes,
load_from_cache_file=not cfg.is_preprocess,
desc="Drop Samples with Zero Trainable Tokens",
)
dropped = prior_len - len(eval_dataset)
if dropped: if dropped:
LOG.warning( LOG.warning(
f"Dropped {dropped} samples with no trainable tokens from eval dataset" f"Dropped {dropped} samples with no trainable tokens from train dataset"
) )
if eval_dataset:
try:
prior_len = len(eval_dataset)
except TypeError:
# handle iterable datasets case
prior_len = None
eval_dataset = eval_dataset.filter(
drop_no_trainable_tokens,
**filter_map_kwargs,
**drop_long_kwargs,
)
if prior_len:
dropped = prior_len - len(eval_dataset)
if dropped:
LOG.warning(
f"Dropped {dropped} samples with no trainable tokens from eval dataset"
)
if cfg.group_by_length: if cfg.group_by_length:
train_dataset = train_dataset.map( train_dataset = train_dataset.map(
add_length, add_length,
@@ -291,19 +401,21 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
desc="Add position_id column (PoSE)", desc="Add position_id column (PoSE)",
) )
elif cfg.sample_packing: elif cfg.sample_packing:
drop_long_kwargs = {}
if filter_map_kwargs:
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
train_dataset = train_dataset.map( train_dataset = train_dataset.map(
add_position_ids, add_position_ids,
num_proc=cfg.dataset_processes, batched=True,
load_from_cache_file=not cfg.is_preprocess, **filter_map_kwargs,
desc="Add position_id column (Sample Packing)", **drop_long_kwargs,
) )
if cfg.eval_sample_packing is not False: if cfg.eval_sample_packing is not False:
if eval_dataset: if eval_dataset:
eval_dataset = eval_dataset.map( eval_dataset = eval_dataset.map(
add_position_ids, add_position_ids,
num_proc=cfg.dataset_processes, **filter_map_kwargs,
load_from_cache_file=not cfg.is_preprocess, **drop_long_kwargs,
desc="Add position_id column (Sample Packing)",
) )
return train_dataset, eval_dataset return train_dataset, eval_dataset
@@ -334,7 +446,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
and not cfg.reward_model and not cfg.reward_model
): ):
total_num_tokens = np.sum( total_num_tokens = np.sum(
train_dataset.data.column("input_ids") train_dataset.select_columns("input_ids")
.to_pandas() .to_pandas()
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda .apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
.values .values