improve iterable support
This commit is contained in:
@@ -13,6 +13,7 @@ class PreprocessCliArgs:
|
|||||||
debug_num_examples: int = field(default=1)
|
debug_num_examples: int = field(default=1)
|
||||||
prompter: Optional[str] = field(default=None)
|
prompter: Optional[str] = field(default=None)
|
||||||
download: Optional[bool] = field(default=True)
|
download: Optional[bool] = field(default=True)
|
||||||
|
iterable: Optional[bool] = field(default=None, metadata={"help": "Use IterableDataset for streaming processing of large datasets"})
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|||||||
@@ -39,6 +39,8 @@ def preprocess(config: str, **kwargs) -> None:
|
|||||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||||
config options.
|
config options.
|
||||||
"""
|
"""
|
||||||
|
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||||
|
|
||||||
from axolotl.cli.preprocess import do_cli
|
from axolotl.cli.preprocess import do_cli
|
||||||
|
|
||||||
do_cli(config=config, **kwargs)
|
do_cli(config=config, **kwargs)
|
||||||
|
|||||||
@@ -3,7 +3,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import warnings
|
import warnings
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import transformers
|
import transformers
|
||||||
@@ -75,7 +75,10 @@ def do_preprocess(cfg: DictDefault, cli_args: PreprocessCliArgs) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
|
def do_cli(
|
||||||
|
config: Union[Path, str] = Path("examples/"),
|
||||||
|
**kwargs,
|
||||||
|
) -> None:
|
||||||
"""
|
"""
|
||||||
Parses `axolotl` config, CLI args, and calls `do_preprocess`.
|
Parses `axolotl` config, CLI args, and calls `do_preprocess`.
|
||||||
|
|
||||||
|
|||||||
@@ -63,11 +63,13 @@ def load_datasets(
|
|||||||
"""
|
"""
|
||||||
tokenizer = load_tokenizer(cfg)
|
tokenizer = load_tokenizer(cfg)
|
||||||
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
|
processor = load_processor(cfg, tokenizer=tokenizer) if cfg.processor_type else None
|
||||||
|
preprocess_iterable = hasattr(cli_args, "iterable") and cli_args.iterable is not None and cli_args.iterable
|
||||||
|
|
||||||
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
|
train_dataset, eval_dataset, total_num_steps, prompters = prepare_dataset(
|
||||||
cfg,
|
cfg,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
processor=processor,
|
processor=processor,
|
||||||
|
preprocess_iterable=preprocess_iterable,
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -51,7 +51,7 @@ class TokenizedPromptDataset(Dataset):
|
|||||||
map_kwargs = {}
|
map_kwargs = {}
|
||||||
if self.prompt_tokenizer.supports_batched:
|
if self.prompt_tokenizer.supports_batched:
|
||||||
map_kwargs["batched"] = True
|
map_kwargs["batched"] = True
|
||||||
map_kwargs["batch_size"] = 100
|
map_kwargs["batch_size"] = 1_000
|
||||||
if self.prompt_tokenizer.filter_rows:
|
if self.prompt_tokenizer.filter_rows:
|
||||||
dataset = dataset.filter(
|
dataset = dataset.filter(
|
||||||
self.prompt_tokenizer.filter_rows,
|
self.prompt_tokenizer.filter_rows,
|
||||||
|
|||||||
@@ -132,9 +132,9 @@ class ChatTemplateStrategyWithKD(ChatTemplateStrategy):
|
|||||||
|
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
def _tokenize_single_prompt(self, prompt):
|
||||||
logprobs = prompt.pop(self.logprobs_field)
|
logprobs = prompt.pop(self.logprobs_field)
|
||||||
tokenized_prompt = super().tokenize_prompt(prompt)
|
tokenized_prompt = super()._tokenize_single_prompt(prompt)
|
||||||
tokenized_prompt[self.logprobs_field] = logprobs
|
tokenized_prompt[self.logprobs_field] = logprobs
|
||||||
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
|
tokenized_prompt = self.transform_logprobs(tokenized_prompt)
|
||||||
|
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ HF Chat Templates prompt strategy
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
from transformers import ProcessorMixin
|
from transformers import ProcessorMixin
|
||||||
@@ -193,7 +194,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
prompter,
|
prompter: ChatTemplatePrompter,
|
||||||
tokenizer,
|
tokenizer,
|
||||||
train_on_inputs,
|
train_on_inputs,
|
||||||
sequence_len,
|
sequence_len,
|
||||||
@@ -220,22 +221,50 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
def messages(self, messages):
|
def messages(self, messages):
|
||||||
self._messages = messages
|
self._messages = messages
|
||||||
|
|
||||||
def tokenize_prompt(self, prompt):
|
@property
|
||||||
|
def supports_batched(self) -> bool:
|
||||||
|
# Let calling code know we can handle lists of examples
|
||||||
|
return True
|
||||||
|
|
||||||
|
def tokenize_prompt(self, prompt: dict[str, Any]) -> Dict[str, List[List[int]]]:
|
||||||
|
"""
|
||||||
|
Public method that can handle either a single prompt or a batch of prompts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
res = defaultdict(lambda: [])
|
||||||
|
feature_names = list(prompt.keys())
|
||||||
|
|
||||||
|
# Process each prompt individually
|
||||||
|
for row in zip(*prompt.values()):
|
||||||
|
tokenized_prompt = self._tokenize_single_prompt(
|
||||||
|
dict(zip(feature_names, row))
|
||||||
|
)
|
||||||
|
for key, val in tokenized_prompt.items():
|
||||||
|
for i in range(0, len(val), self.sequence_len):
|
||||||
|
res[key].append(val[i : i + self.sequence_len])
|
||||||
|
|
||||||
|
# If there are no examples left, return an empty dictionary
|
||||||
|
if not res:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
return dict(res)
|
||||||
|
|
||||||
|
def _tokenize_single_prompt(self, prompt: dict) -> Dict[str, List[int]]:
|
||||||
# Old simple legacy behavior that works reliably.
|
# Old simple legacy behavior that works reliably.
|
||||||
if (
|
if (
|
||||||
not self.roles_to_train
|
not self.roles_to_train
|
||||||
and not self.train_on_eos
|
and not self.train_on_eos
|
||||||
and not self.prompter.message_field_training
|
and not self.prompter.message_field_training # type: ignore
|
||||||
and not self.prompter.message_field_training_detail
|
and not self.prompter.message_field_training_detail # type: ignore
|
||||||
):
|
):
|
||||||
turns = self.get_conversation_thread(prompt)
|
turns = self.get_conversation_thread(prompt)
|
||||||
images = self.get_images(prompt)
|
images = self.get_images(prompt)
|
||||||
prompt_ids = self.prompter.build_prompt(
|
prompt_ids = self.prompter.build_prompt( # type: ignore
|
||||||
turns[:-1],
|
turns[:-1],
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
images=images,
|
images=images,
|
||||||
)
|
)
|
||||||
tokenized_res = self.prompter.build_prompt(turns, images=images)
|
tokenized_res = self.prompter.build_prompt(turns, images=images) # type: ignore
|
||||||
tokenized_prompt = {}
|
tokenized_prompt = {}
|
||||||
if isinstance(tokenized_res, list):
|
if isinstance(tokenized_res, list):
|
||||||
input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
|
input_ids = prompt_ids + tokenized_res[len(prompt_ids) :]
|
||||||
@@ -256,7 +285,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
return tokenized_prompt
|
return tokenized_prompt
|
||||||
|
|
||||||
turns = self.get_conversation_thread(prompt)
|
turns = self.get_conversation_thread(prompt)
|
||||||
input_ids = self.prompter.build_prompt(turns)
|
input_ids = self.prompter.build_prompt(turns) # type: ignore
|
||||||
labels = [IGNORE_TOKEN_ID] * len(input_ids)
|
labels = [IGNORE_TOKEN_ID] * len(input_ids)
|
||||||
|
|
||||||
last_eos_idx = -1
|
last_eos_idx = -1
|
||||||
@@ -286,7 +315,7 @@ class ChatTemplateStrategy(PromptTokenizingStrategy):
|
|||||||
|
|
||||||
if should_train and turn_start_idx != -1 and turn_end_idx != -1:
|
if should_train and turn_start_idx != -1 and turn_end_idx != -1:
|
||||||
if train_detail:
|
if train_detail:
|
||||||
token_offsets = self.prompter.get_offsets_for_train_detail(
|
token_offsets = self.prompter.get_offsets_for_train_detail( # type: ignore
|
||||||
content, train_detail
|
content, train_detail
|
||||||
)
|
)
|
||||||
LOG.debug(f"Token offsets: {token_offsets}")
|
LOG.debug(f"Token offsets: {token_offsets}")
|
||||||
|
|||||||
@@ -155,6 +155,7 @@ class SFTDataset(BaseModel):
|
|||||||
type: Optional[Union[str, UserDefinedPrompterType]] = None
|
type: Optional[Union[str, UserDefinedPrompterType]] = None
|
||||||
input_transform: Optional[str] = None
|
input_transform: Optional[str] = None
|
||||||
shards: Optional[int] = None
|
shards: Optional[int] = None
|
||||||
|
preprocess_shards: Optional[int] = None
|
||||||
conversation: Optional[str] = None
|
conversation: Optional[str] = None
|
||||||
# Do not make this too strict or it will break the validator to choose different dataset class
|
# Do not make this too strict or it will break the validator to choose different dataset class
|
||||||
chat_template: Optional[
|
chat_template: Optional[
|
||||||
@@ -809,6 +810,7 @@ class AxolotlInputConfig(
|
|||||||
|
|
||||||
# INTERNALS - document for now, generally not set externally
|
# INTERNALS - document for now, generally not set externally
|
||||||
is_preprocess: Optional[bool] = None
|
is_preprocess: Optional[bool] = None
|
||||||
|
preprocess_iterable: Optional[bool] = None
|
||||||
|
|
||||||
total_num_tokens: Optional[int] = None
|
total_num_tokens: Optional[int] = None
|
||||||
total_supervised_tokens: Optional[int] = None
|
total_supervised_tokens: Optional[int] = None
|
||||||
|
|||||||
@@ -3,11 +3,12 @@
|
|||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Tuple, Union
|
from typing import List, Tuple, Union, Optional
|
||||||
|
|
||||||
from datasets import (
|
from datasets import (
|
||||||
Dataset,
|
Dataset,
|
||||||
DatasetDict,
|
DatasetDict,
|
||||||
|
IterableDataset,
|
||||||
concatenate_datasets,
|
concatenate_datasets,
|
||||||
load_dataset,
|
load_dataset,
|
||||||
load_from_disk,
|
load_from_disk,
|
||||||
@@ -57,7 +58,7 @@ LOG = logging.getLogger("axolotl")
|
|||||||
|
|
||||||
|
|
||||||
@retry_on_request_exceptions(max_retries=3, delay=5)
|
@retry_on_request_exceptions(max_retries=3, delay=5)
|
||||||
def prepare_dataset(cfg, tokenizer, processor=None):
|
def prepare_dataset(cfg, tokenizer, processor=None, preprocess_iterable=None):
|
||||||
prompters = []
|
prompters = []
|
||||||
if not cfg.pretraining_dataset:
|
if not cfg.pretraining_dataset:
|
||||||
with zero_first(is_local_main_process()):
|
with zero_first(is_local_main_process()):
|
||||||
@@ -68,6 +69,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
|||||||
DEFAULT_DATASET_PREPARED_PATH,
|
DEFAULT_DATASET_PREPARED_PATH,
|
||||||
split="train",
|
split="train",
|
||||||
processor=processor,
|
processor=processor,
|
||||||
|
preprocess_iterable=preprocess_iterable,
|
||||||
)
|
)
|
||||||
_, eval_dataset, _ = load_prepare_datasets(
|
_, eval_dataset, _ = load_prepare_datasets(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@@ -75,6 +77,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
|||||||
DEFAULT_DATASET_PREPARED_PATH,
|
DEFAULT_DATASET_PREPARED_PATH,
|
||||||
split="test",
|
split="test",
|
||||||
processor=processor,
|
processor=processor,
|
||||||
|
preprocess_iterable=preprocess_iterable,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
train_dataset, eval_dataset, prompters = load_prepare_datasets(
|
train_dataset, eval_dataset, prompters = load_prepare_datasets(
|
||||||
@@ -82,6 +85,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
|||||||
cfg,
|
cfg,
|
||||||
DEFAULT_DATASET_PREPARED_PATH,
|
DEFAULT_DATASET_PREPARED_PATH,
|
||||||
processor=processor,
|
processor=processor,
|
||||||
|
preprocess_iterable=preprocess_iterable,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Load streaming dataset if pretraining_dataset is given
|
# Load streaming dataset if pretraining_dataset is given
|
||||||
@@ -137,6 +141,7 @@ def prepare_dataset(cfg, tokenizer, processor=None):
|
|||||||
DEFAULT_DATASET_PREPARED_PATH,
|
DEFAULT_DATASET_PREPARED_PATH,
|
||||||
split="test",
|
split="test",
|
||||||
processor=processor,
|
processor=processor,
|
||||||
|
preprocess_iterable=preprocess_iterable,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.dataset_exact_deduplication:
|
if cfg.dataset_exact_deduplication:
|
||||||
@@ -168,6 +173,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
default_dataset_prepared_path,
|
default_dataset_prepared_path,
|
||||||
split="train",
|
split="train",
|
||||||
processor=None,
|
processor=None,
|
||||||
|
preprocess_iterable: Optional[bool] = None,
|
||||||
) -> Tuple[DatasetDict, List[Prompter]]:
|
) -> Tuple[DatasetDict, List[Prompter]]:
|
||||||
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
|
cfg_datasets = cfg.test_datasets if split == "test" else cfg.datasets
|
||||||
tokenizer_name = cfg.tokenizer_config
|
tokenizer_name = cfg.tokenizer_config
|
||||||
@@ -261,13 +267,25 @@ def load_tokenized_prepared_datasets(
|
|||||||
# at the same time for a given dataset
|
# at the same time for a given dataset
|
||||||
for name in dataset.name:
|
for name in dataset.name:
|
||||||
yield DictDefault({**dataset, "name": name})
|
yield DictDefault({**dataset, "name": name})
|
||||||
|
elif dataset.preprocess_shards and not dataset.shards:
|
||||||
|
for shard in range(dataset.preprocess_shards):
|
||||||
|
yield DictDefault(
|
||||||
|
{
|
||||||
|
**dataset,
|
||||||
|
"shards": dataset.preprocess_shards,
|
||||||
|
"shards_idx": shard,
|
||||||
|
}
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
yield dataset
|
yield dataset
|
||||||
|
|
||||||
|
streaming_ds = False
|
||||||
|
if preprocess_iterable:
|
||||||
|
streaming_ds = True
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
for config_dataset in for_d_in_datasets(cfg_datasets):
|
for config_dataset in for_d_in_datasets(cfg_datasets):
|
||||||
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
|
ds: Union[Dataset, DatasetDict] = load_dataset_w_config(
|
||||||
config_dataset, use_auth_token
|
config_dataset, use_auth_token, streaming=streaming_ds
|
||||||
)
|
)
|
||||||
|
|
||||||
d_base_type = d_prompt_style = None
|
d_base_type = d_prompt_style = None
|
||||||
@@ -324,7 +342,21 @@ def load_tokenized_prepared_datasets(
|
|||||||
|
|
||||||
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
|
if cfg.local_rank == 0 and not cfg.skip_prepare_dataset:
|
||||||
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
LOG.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
|
||||||
dataset.save_to_disk(str(prepared_ds_path))
|
if isinstance(dataset, IterableDataset):
|
||||||
|
|
||||||
|
def gen_from_iter_ds(_ds, _=None):
|
||||||
|
yield from _ds
|
||||||
|
|
||||||
|
ds_from_iter = Dataset.from_generator(
|
||||||
|
functools.partial(gen_from_iter_ds, dataset),
|
||||||
|
features=dataset.features,
|
||||||
|
num_proc=cfg.dataset_processes,
|
||||||
|
split=split,
|
||||||
|
gen_kwargs={"_": list(range(cfg.dataset_processes))},
|
||||||
|
)
|
||||||
|
ds_from_iter.save_to_disk(str(prepared_ds_path))
|
||||||
|
else:
|
||||||
|
dataset.save_to_disk(str(prepared_ds_path))
|
||||||
if cfg.push_dataset_to_hub:
|
if cfg.push_dataset_to_hub:
|
||||||
LOG.info(
|
LOG.info(
|
||||||
f"Pushing merged prepared dataset to Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..."
|
f"Pushing merged prepared dataset to Huggingface hub at {cfg.push_dataset_to_hub} (version {ds_hash})..."
|
||||||
@@ -344,6 +376,7 @@ def load_prepare_datasets(
|
|||||||
default_dataset_prepared_path,
|
default_dataset_prepared_path,
|
||||||
split="train",
|
split="train",
|
||||||
processor=None,
|
processor=None,
|
||||||
|
preprocess_iterable: Optional[bool] = False,
|
||||||
) -> Tuple[Dataset, Dataset, List[Prompter]]:
|
) -> Tuple[Dataset, Dataset, List[Prompter]]:
|
||||||
dataset, prompters = load_tokenized_prepared_datasets(
|
dataset, prompters = load_tokenized_prepared_datasets(
|
||||||
tokenizer,
|
tokenizer,
|
||||||
@@ -351,6 +384,7 @@ def load_prepare_datasets(
|
|||||||
default_dataset_prepared_path,
|
default_dataset_prepared_path,
|
||||||
split=split,
|
split=split,
|
||||||
processor=processor,
|
processor=processor,
|
||||||
|
preprocess_iterable=preprocess_iterable,
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
|
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
|
||||||
|
|||||||
@@ -29,7 +29,9 @@ def get_ds_type(config_dataset: DictDefault):
|
|||||||
return ds_type
|
return ds_type
|
||||||
|
|
||||||
|
|
||||||
def load_dataset_w_config(config_dataset, auth_token):
|
def load_dataset_w_config(
|
||||||
|
config_dataset, auth_token, streaming=False
|
||||||
|
) -> Union[Dataset, DatasetDict]:
|
||||||
# pylint: disable=invalid-name
|
# pylint: disable=invalid-name
|
||||||
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
|
ds: Optional[Union[Dataset, DatasetDict]] = None # pylint: disable=invalid-name
|
||||||
ds_from_hub = False
|
ds_from_hub = False
|
||||||
@@ -117,7 +119,7 @@ def load_dataset_w_config(config_dataset, auth_token):
|
|||||||
ds_type,
|
ds_type,
|
||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
data_files=config_dataset.data_files,
|
data_files=config_dataset.data_files,
|
||||||
streaming=False,
|
streaming=streaming,
|
||||||
split=None,
|
split=None,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -153,7 +155,7 @@ def load_dataset_w_config(config_dataset, auth_token):
|
|||||||
ds = load_dataset(
|
ds = load_dataset(
|
||||||
config_dataset.path,
|
config_dataset.path,
|
||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
streaming=False,
|
streaming=streaming,
|
||||||
data_files=config_dataset.data_files,
|
data_files=config_dataset.data_files,
|
||||||
token=auth_token,
|
token=auth_token,
|
||||||
revision=config_dataset.revision,
|
revision=config_dataset.revision,
|
||||||
@@ -172,7 +174,7 @@ def load_dataset_w_config(config_dataset, auth_token):
|
|||||||
ds_type,
|
ds_type,
|
||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
data_files=config_dataset.path,
|
data_files=config_dataset.path,
|
||||||
streaming=False,
|
streaming=streaming,
|
||||||
split=None,
|
split=None,
|
||||||
storage_options=storage_options,
|
storage_options=storage_options,
|
||||||
trust_remote_code=config_dataset.trust_remote_code,
|
trust_remote_code=config_dataset.trust_remote_code,
|
||||||
@@ -183,7 +185,7 @@ def load_dataset_w_config(config_dataset, auth_token):
|
|||||||
ds_type,
|
ds_type,
|
||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
data_files=config_dataset.path,
|
data_files=config_dataset.path,
|
||||||
streaming=False,
|
streaming=streaming,
|
||||||
split=None,
|
split=None,
|
||||||
storage_options=storage_options,
|
storage_options=storage_options,
|
||||||
trust_remote_code=config_dataset.trust_remote_code,
|
trust_remote_code=config_dataset.trust_remote_code,
|
||||||
@@ -213,7 +215,7 @@ def load_dataset_w_config(config_dataset, auth_token):
|
|||||||
"json",
|
"json",
|
||||||
name=config_dataset.name,
|
name=config_dataset.name,
|
||||||
data_files=fp,
|
data_files=fp,
|
||||||
streaming=False,
|
streaming=streaming,
|
||||||
split=None,
|
split=None,
|
||||||
)
|
)
|
||||||
if not ds:
|
if not ds:
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import numpy as np
|
|||||||
import torch
|
import torch
|
||||||
import torch.cuda
|
import torch.cuda
|
||||||
from accelerate.logging import get_logger
|
from accelerate.logging import get_logger
|
||||||
from datasets import disable_caching, enable_caching
|
from datasets import IterableDataset, disable_caching, enable_caching
|
||||||
from torch.utils.data import DataLoader, RandomSampler
|
from torch.utils.data import DataLoader, RandomSampler
|
||||||
from transformers.utils import is_torch_bf16_gpu_available
|
from transformers.utils import is_torch_bf16_gpu_available
|
||||||
|
|
||||||
@@ -95,9 +95,46 @@ def disable_datasets_caching():
|
|||||||
|
|
||||||
|
|
||||||
def add_position_ids(sample):
|
def add_position_ids(sample):
|
||||||
sample_len = len(sample["input_ids"])
|
"""
|
||||||
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
|
Handle both single-example and batched data.
|
||||||
sample["length"] = sample_len
|
- single example: sample['input_ids'] is a list[int]
|
||||||
|
- batched data: sample['input_ids'] is a list[list[int]]
|
||||||
|
"""
|
||||||
|
if "input_ids" not in sample:
|
||||||
|
# If there's no "input_ids", just return sample unchanged
|
||||||
|
return sample
|
||||||
|
|
||||||
|
input_ids = sample["input_ids"]
|
||||||
|
|
||||||
|
# Detect if it's a single example or a batch
|
||||||
|
if not input_ids:
|
||||||
|
# Edge case: empty
|
||||||
|
return sample
|
||||||
|
|
||||||
|
# If first element is an int, it’s a single example
|
||||||
|
# If first element is a list, it’s a batch
|
||||||
|
if isinstance(input_ids[0], int):
|
||||||
|
# ---- SINGLE EXAMPLE ----
|
||||||
|
seq_len = len(input_ids)
|
||||||
|
# Position IDs for a single example
|
||||||
|
# As a list
|
||||||
|
sample["position_ids"] = list(range(seq_len))
|
||||||
|
sample["length"] = seq_len
|
||||||
|
|
||||||
|
else:
|
||||||
|
# ---- BATCHED EXAMPLES ----
|
||||||
|
# input_ids is a list of lists
|
||||||
|
position_ids_batch = []
|
||||||
|
lengths_batch = []
|
||||||
|
for seq in input_ids:
|
||||||
|
seq_len = len(seq)
|
||||||
|
position_ids_batch.append(list(range(seq_len)))
|
||||||
|
lengths_batch.append(seq_len)
|
||||||
|
|
||||||
|
# Now store them back
|
||||||
|
sample["position_ids"] = position_ids_batch
|
||||||
|
sample["length"] = lengths_batch
|
||||||
|
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
|
||||||
@@ -172,10 +209,31 @@ def add_length(sample):
|
|||||||
|
|
||||||
|
|
||||||
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
def drop_long_seq(sample, sequence_len=2048, min_sequence_len=2):
|
||||||
return (
|
"""
|
||||||
len(sample["input_ids"]) <= sequence_len
|
Drop samples whose sequence length is either too long (> sequence_len)
|
||||||
and len(sample["input_ids"]) >= min_sequence_len
|
or too short (< min_sequence_len).
|
||||||
)
|
|
||||||
|
Works for both single-example (list[int]) or batched (list[list[int]]).
|
||||||
|
"""
|
||||||
|
input_ids = sample["input_ids"]
|
||||||
|
|
||||||
|
# Edge case: if input_ids is empty
|
||||||
|
if not input_ids:
|
||||||
|
# Decide if you want to drop or keep empty. Let's drop.
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Check if single example or batched by looking at the first element
|
||||||
|
if isinstance(input_ids[0], int):
|
||||||
|
# Single example (input_ids is a list of int)
|
||||||
|
length = len(input_ids)
|
||||||
|
return min_sequence_len <= length <= sequence_len
|
||||||
|
|
||||||
|
# Batched (input_ids is a list of lists)
|
||||||
|
results = []
|
||||||
|
for seq in input_ids:
|
||||||
|
length = len(seq)
|
||||||
|
results.append(min_sequence_len <= length <= sequence_len)
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||||
@@ -185,10 +243,13 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
min_sequence_len=cfg.min_sample_len or 2,
|
min_sequence_len=cfg.min_sample_len or 2,
|
||||||
)
|
)
|
||||||
|
|
||||||
min_input_len = np.min(get_dataset_lengths(train_dataset))
|
try:
|
||||||
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
|
min_input_len = np.min(get_dataset_lengths(train_dataset))
|
||||||
max_input_len = np.max(get_dataset_lengths(train_dataset))
|
LOG.debug(f"min_input_len: {min_input_len}", main_process_only=True)
|
||||||
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
|
max_input_len = np.max(get_dataset_lengths(train_dataset))
|
||||||
|
LOG.debug(f"max_input_len: {max_input_len}", main_process_only=True)
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
if cfg.model_config_type == "mamba":
|
if cfg.model_config_type == "mamba":
|
||||||
LOG.info("dropping attention_mask column")
|
LOG.info("dropping attention_mask column")
|
||||||
@@ -203,60 +264,109 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
|
if eval_dataset and "token_type_ids" in eval_dataset.column_names:
|
||||||
eval_dataset = eval_dataset.remove_columns("token_type_ids")
|
eval_dataset = eval_dataset.remove_columns("token_type_ids")
|
||||||
|
|
||||||
prior_len = len(train_dataset)
|
filter_map_kwargs = {}
|
||||||
|
if not isinstance(train_dataset, IterableDataset):
|
||||||
|
filter_map_kwargs["num_proc"] = cfg.dataset_processes
|
||||||
|
filter_map_kwargs["load_from_cache_file"] = not cfg.is_preprocess
|
||||||
|
|
||||||
|
try:
|
||||||
|
prior_len = len(train_dataset)
|
||||||
|
except TypeError:
|
||||||
|
# handle iterable datasets case
|
||||||
|
prior_len = None
|
||||||
|
drop_long_kwargs = {}
|
||||||
|
if filter_map_kwargs:
|
||||||
|
drop_long_kwargs["desc"] = "Dropping Long Sequences"
|
||||||
train_dataset = train_dataset.filter(
|
train_dataset = train_dataset.filter(
|
||||||
drop_long,
|
drop_long,
|
||||||
num_proc=cfg.dataset_processes,
|
**filter_map_kwargs,
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
**drop_long_kwargs,
|
||||||
desc="Dropping Long Sequences",
|
|
||||||
)
|
)
|
||||||
dropped = prior_len - len(train_dataset)
|
if prior_len:
|
||||||
if dropped:
|
dropped = prior_len - len(train_dataset)
|
||||||
LOG.warning(f"Dropped {dropped} long samples from train dataset")
|
if dropped:
|
||||||
|
LOG.warning(f"Dropped {dropped} long samples from train dataset")
|
||||||
|
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
prior_len = len(eval_dataset)
|
try:
|
||||||
|
prior_len = len(eval_dataset)
|
||||||
|
except TypeError:
|
||||||
|
# handle iterable datasets case
|
||||||
|
prior_len = None
|
||||||
eval_dataset = eval_dataset.filter(
|
eval_dataset = eval_dataset.filter(
|
||||||
drop_long,
|
drop_long,
|
||||||
num_proc=cfg.dataset_processes,
|
**filter_map_kwargs,
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
**drop_long_kwargs,
|
||||||
desc="Dropping Long Sequences",
|
|
||||||
)
|
)
|
||||||
dropped = prior_len - len(eval_dataset)
|
if prior_len:
|
||||||
if dropped:
|
dropped = prior_len - len(eval_dataset)
|
||||||
LOG.warning(f"Dropped {dropped} long samples from eval dataset")
|
if dropped:
|
||||||
|
LOG.warning(f"Dropped {dropped} long samples from eval dataset")
|
||||||
|
|
||||||
# drop samples with where the number of elements with labels not equal to -100 is zero
|
|
||||||
def drop_no_trainable_tokens(sample):
|
def drop_no_trainable_tokens(sample):
|
||||||
return np.sum(np.array(sample["labels"]) != -100) > 0
|
"""
|
||||||
|
Drop samples if all labels are -100 (i.e., zero trainable tokens).
|
||||||
|
Works for both single-example or batched input.
|
||||||
|
"""
|
||||||
|
labels = sample["labels"]
|
||||||
|
if not labels:
|
||||||
|
# Edge case: if labels is empty, decide if you want to keep or drop
|
||||||
|
return True # or False
|
||||||
|
|
||||||
prior_len = len(train_dataset)
|
# Check if single example or batch
|
||||||
|
# If first element is an int, we assume a single example
|
||||||
|
# If it's a list, we assume we're dealing with a batch
|
||||||
|
if isinstance(labels[0], int):
|
||||||
|
# Single example: return a single bool
|
||||||
|
return np.sum(np.array(labels) != -100) > 0
|
||||||
|
|
||||||
|
# Batched: 'labels' is a list of lists
|
||||||
|
# Return a list of booleans, one per sub-list
|
||||||
|
results = []
|
||||||
|
for row_labels in labels:
|
||||||
|
# Each row_labels is a list[int]
|
||||||
|
results.append(np.sum(np.array(row_labels) != -100) > 0)
|
||||||
|
return results
|
||||||
|
|
||||||
|
try:
|
||||||
|
prior_len = len(train_dataset)
|
||||||
|
except TypeError:
|
||||||
|
# handle iterable datasets case
|
||||||
|
prior_len = None
|
||||||
|
drop_long_kwargs = {}
|
||||||
|
if filter_map_kwargs:
|
||||||
|
drop_long_kwargs["desc"] = "Drop Samples with Zero Trainable Tokens"
|
||||||
train_dataset = train_dataset.filter(
|
train_dataset = train_dataset.filter(
|
||||||
drop_no_trainable_tokens,
|
drop_no_trainable_tokens,
|
||||||
num_proc=cfg.dataset_processes,
|
batched=True,
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
**filter_map_kwargs,
|
||||||
desc="Drop Samples with Zero Trainable Tokens",
|
**drop_long_kwargs,
|
||||||
)
|
)
|
||||||
dropped = prior_len - len(train_dataset)
|
if prior_len:
|
||||||
if dropped:
|
dropped = prior_len - len(train_dataset)
|
||||||
LOG.warning(
|
|
||||||
f"Dropped {dropped} samples with no trainable tokens from train dataset"
|
|
||||||
)
|
|
||||||
|
|
||||||
if eval_dataset:
|
|
||||||
prior_len = len(eval_dataset)
|
|
||||||
eval_dataset = eval_dataset.filter(
|
|
||||||
drop_no_trainable_tokens,
|
|
||||||
num_proc=cfg.dataset_processes,
|
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
|
||||||
desc="Drop Samples with Zero Trainable Tokens",
|
|
||||||
)
|
|
||||||
dropped = prior_len - len(eval_dataset)
|
|
||||||
if dropped:
|
if dropped:
|
||||||
LOG.warning(
|
LOG.warning(
|
||||||
f"Dropped {dropped} samples with no trainable tokens from eval dataset"
|
f"Dropped {dropped} samples with no trainable tokens from train dataset"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if eval_dataset:
|
||||||
|
try:
|
||||||
|
prior_len = len(eval_dataset)
|
||||||
|
except TypeError:
|
||||||
|
# handle iterable datasets case
|
||||||
|
prior_len = None
|
||||||
|
eval_dataset = eval_dataset.filter(
|
||||||
|
drop_no_trainable_tokens,
|
||||||
|
**filter_map_kwargs,
|
||||||
|
**drop_long_kwargs,
|
||||||
|
)
|
||||||
|
if prior_len:
|
||||||
|
dropped = prior_len - len(eval_dataset)
|
||||||
|
if dropped:
|
||||||
|
LOG.warning(
|
||||||
|
f"Dropped {dropped} samples with no trainable tokens from eval dataset"
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.group_by_length:
|
if cfg.group_by_length:
|
||||||
train_dataset = train_dataset.map(
|
train_dataset = train_dataset.map(
|
||||||
add_length,
|
add_length,
|
||||||
@@ -291,19 +401,21 @@ def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
|||||||
desc="Add position_id column (PoSE)",
|
desc="Add position_id column (PoSE)",
|
||||||
)
|
)
|
||||||
elif cfg.sample_packing:
|
elif cfg.sample_packing:
|
||||||
|
drop_long_kwargs = {}
|
||||||
|
if filter_map_kwargs:
|
||||||
|
drop_long_kwargs["desc"] = "Add position_id column (Sample Packing)"
|
||||||
train_dataset = train_dataset.map(
|
train_dataset = train_dataset.map(
|
||||||
add_position_ids,
|
add_position_ids,
|
||||||
num_proc=cfg.dataset_processes,
|
batched=True,
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
**filter_map_kwargs,
|
||||||
desc="Add position_id column (Sample Packing)",
|
**drop_long_kwargs,
|
||||||
)
|
)
|
||||||
if cfg.eval_sample_packing is not False:
|
if cfg.eval_sample_packing is not False:
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
eval_dataset = eval_dataset.map(
|
eval_dataset = eval_dataset.map(
|
||||||
add_position_ids,
|
add_position_ids,
|
||||||
num_proc=cfg.dataset_processes,
|
**filter_map_kwargs,
|
||||||
load_from_cache_file=not cfg.is_preprocess,
|
**drop_long_kwargs,
|
||||||
desc="Add position_id column (Sample Packing)",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return train_dataset, eval_dataset
|
return train_dataset, eval_dataset
|
||||||
@@ -334,7 +446,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
|||||||
and not cfg.reward_model
|
and not cfg.reward_model
|
||||||
):
|
):
|
||||||
total_num_tokens = np.sum(
|
total_num_tokens = np.sum(
|
||||||
train_dataset.data.column("input_ids")
|
train_dataset.select_columns("input_ids")
|
||||||
.to_pandas()
|
.to_pandas()
|
||||||
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
|
.apply(lambda x: len(x)) # pylint: disable=unnecessary-lambda
|
||||||
.values
|
.values
|
||||||
|
|||||||
Reference in New Issue
Block a user