refactor utils.data module for line count linter (#1476)
This commit is contained in:
15
src/axolotl/utils/data/__init__.py
Normal file
15
src/axolotl/utils/data/__init__.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
"""
|
||||||
|
Data processing modules
|
||||||
|
"""
|
||||||
|
from axolotl.utils.data.dpo import load_prepare_dpo_datasets # noqa: F401
|
||||||
|
from axolotl.utils.data.pretraining import ( # noqa: F401
|
||||||
|
encode_pretraining,
|
||||||
|
wrap_pretraining_dataset,
|
||||||
|
)
|
||||||
|
from axolotl.utils.data.sft import ( # noqa: F401
|
||||||
|
get_dataset_wrapper,
|
||||||
|
load_prepare_datasets,
|
||||||
|
load_tokenized_prepared_datasets,
|
||||||
|
prepare_dataset,
|
||||||
|
)
|
||||||
|
from axolotl.utils.data.utils import md5 # noqa: F401
|
||||||
114
src/axolotl/utils/data/dpo.py
Normal file
114
src/axolotl/utils/data/dpo.py
Normal file
@@ -0,0 +1,114 @@
|
|||||||
|
"""data handling specific to DPO"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, List
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from datasets import concatenate_datasets, load_dataset, load_from_disk
|
||||||
|
|
||||||
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
|
from axolotl.prompt_strategies.dpo import load as load_dpo
|
||||||
|
from axolotl.utils.data.utils import md5
|
||||||
|
from axolotl.utils.dict import DictDefault
|
||||||
|
from axolotl.utils.distributed import is_main_process, zero_first
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_path(ds_hash, cfg):
|
||||||
|
prepared_ds_path = (
|
||||||
|
Path(cfg.dataset_prepared_path) / ds_hash
|
||||||
|
if cfg.dataset_prepared_path
|
||||||
|
else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
|
||||||
|
)
|
||||||
|
|
||||||
|
return prepared_ds_path
|
||||||
|
|
||||||
|
|
||||||
|
def _load_preprocessed_ds(cfg, sub_cfg):
    """Load a previously prepared dataset from disk, if one is usable.

    The cache location is derived from the md5 of the YAML dump of
    ``sub_cfg`` (see ``_get_path``). A dataset is loaded only when all of the
    following hold:

    * ``cfg.dataset_prepared_path`` is set,
    * the cache directory contains at least one entry,
    * ``cfg.is_preprocess`` is falsy.

    Returns the loaded dataset, or ``None`` when any condition fails.
    """
    cache_key = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
    cache_dir = _get_path(cache_key, cfg)

    # pylint: disable=duplicate-code
    # Guard clauses are checked in the same order as the conditions listed above.
    if not cfg.dataset_prepared_path:
        return None
    if not any(cache_dir.glob("*")):
        return None
    if cfg.is_preprocess:
        return None

    LOG.info(f"Loading prepared dataset from disk at {cache_dir}...")
    return load_from_disk(str(cache_dir))
|
||||||
|
|
||||||
|
|
||||||
|
def _save_preprocessed_ds(cfg, sub_cfg, dataset):
    """Persist ``dataset`` to the prepared-dataset cache directory.

    The target directory is derived from the md5 of the YAML dump of
    ``sub_cfg`` (see ``_get_path``). Nothing is written unless
    ``cfg.is_preprocess`` is truthy and ``is_main_process()`` returns true.

    Args:
        cfg: run configuration; ``is_preprocess`` and the prepared-path
            settings are read from it.
        sub_cfg: the dataset sub-configuration used as the cache key.
        dataset: object exposing ``save_to_disk(path)``.
    """
    ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
    prepared_ds_path = _get_path(ds_hash, cfg)

    if cfg.is_preprocess and is_main_process():
        # The message previously said "Loading ... from disk", which was
        # misleading because this branch writes the dataset.
        LOG.info(f"Saving prepared dataset to disk at {prepared_ds_path}...")
        dataset.save_to_disk(str(prepared_ds_path))
|
||||||
|
|
||||||
|
|
||||||
|
def load_prepare_dpo_datasets(cfg):
    """Load (or restore from cache) the DPO train and eval datasets.

    For the train split (``cfg.datasets``) and, when configured, the eval
    split (``cfg.test_datasets``), a prepared copy is loaded via
    ``_load_preprocessed_ds`` if available; otherwise the split is built from
    its dataset configs and handed to ``_save_preprocessed_ds``.

    Returns:
        ``(train_dataset, eval_dataset)``; ``eval_dataset`` is ``None`` when
        no test datasets are configured or the resulting dataset is falsy.
    """

    def load_split(dataset_cfgs, _cfg):
        """Load every dataset in ``dataset_cfgs``, transform, and concatenate."""
        # Each loaded dataset is stored with the index of the config that
        # produced it. A json config may list several data files, so the
        # position in this list is not a valid index into ``dataset_cfgs``.
        loaded: List[Any] = []
        for cfg_idx, ds_cfg in enumerate(dataset_cfgs):
            if ds_cfg["ds_type"] == "json":
                for data_file in ds_cfg["data_files"]:
                    data_files = {ds_cfg["split"]: data_file}
                    ds = load_dataset(  # pylint: disable=invalid-name
                        "json",
                        data_files=data_files,
                        split=ds_cfg["split"],
                    )
                    loaded.append((cfg_idx, ds))
            else:
                ds = load_dataset(  # pylint: disable=invalid-name
                    ds_cfg["path"],
                    split=ds_cfg["split"],
                )
                loaded.append((cfg_idx, ds))

        split_datasets: List[Any] = []
        for cfg_idx, data_set in loaded:
            _type = dataset_cfgs[cfg_idx]["type"]
            if _type:
                if isinstance(_type, DictDefault):
                    _type = "user_defined.default"
                ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=cfg_idx)
                data_set = data_set.map(
                    ds_transform_fn,
                    desc="Mapping RL Dataset",
                )
            # If no `type` is provided, assume the dataset is already in the
            # expected format with "prompt", "chosen" and "rejected" already
            # preprocessed, and keep it unchanged.
            split_datasets.append(data_set)

        return concatenate_datasets(split_datasets)

    # NOTE(review): zero_first presumably lets the main process run this body
    # before the other ranks - confirm against axolotl.utils.distributed.
    with zero_first(is_main_process()):
        train_is_preprocessed = False
        eval_is_preprocessed = False
        if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
            train_is_preprocessed = True
        else:
            train_dataset = load_split(cfg.datasets, cfg)

        eval_dataset = None
        if cfg.test_datasets:
            if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
                eval_is_preprocessed = True
            else:
                eval_dataset = load_split(cfg.test_datasets, cfg)
        # Normalize any falsy eval dataset to None for callers.
        if not eval_dataset:
            eval_dataset = None

        # Only freshly built splits are saved; cached ones are already on disk.
        if not train_is_preprocessed:
            _save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
        if eval_dataset and not eval_is_preprocessed:
            _save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)

    return train_dataset, eval_dataset
|
||||||
232
src/axolotl/utils/data/pretraining.py
Normal file
232
src/axolotl/utils/data/pretraining.py
Normal file
@@ -0,0 +1,232 @@
|
|||||||
|
"""data handling specific to pretraining"""
|
||||||
|
|
||||||
|
import functools
|
||||||
|
import logging
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Callable, Dict, List, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from datasets import Dataset
|
||||||
|
from torch.utils.data import RandomSampler
|
||||||
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
|
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
|
||||||
|
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||||
|
from axolotl.utils.trainer import process_pretraining_datasets_for_packing
|
||||||
|
|
||||||
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
|
def encode_pretraining(
    tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
) -> Dict[str, List]:
    """Tokenize ``examples`` and greedily pack them into rows of ``max_tokens``.

    Each example is tokenized with truncation to ``max_tokens - 2`` and then
    has ``[eos_token_id, pad_token_id]`` appended (attention mask ``[1, 0]``).
    Consecutive examples are concatenated into a buffer while they fit. When
    the next example would overflow, the buffer is right-padded with
    ``pad_token_id`` (mask ``0``) to ``max_tokens`` and emitted as one row.
    Leftover tokens form a final padded row.

    Args:
        tokenizer: callable tokenizer exposing ``eos_token_id`` and
            ``pad_token_id``.
        max_tokens: target length of every packed row.
        examples: raw texts to tokenize.

    Returns:
        Dict with ``input_ids``, ``labels`` (a copy of ``input_ids``) and
        ``attention_mask``, each a list of python lists.
    """
    res = tokenizer(
        examples,
        truncation=True,
        max_length=max_tokens - 2,  # leave room for the EOS + PAD appended below
        add_special_tokens=True,
    )

    # dtype is pinned because torch.tensor([]) would infer float32 for an
    # empty sequence, and concatenation would then promote the row to float.
    eos_pad_ids = torch.tensor(
        [tokenizer.eos_token_id, tokenizer.pad_token_id], dtype=torch.long
    )
    eos_pad_mask = torch.tensor([1, 0], dtype=torch.long)

    # Append EOS and PAD tokens to input_ids, and extend attention_mask to match
    input_ids = [
        torch.cat((torch.tensor(seq, dtype=torch.long), eos_pad_ids), dim=0)
        for seq in res["input_ids"]
    ]
    attention_mask = [
        torch.cat((torch.tensor(seq, dtype=torch.long), eos_pad_mask), dim=0)
        for seq in res["attention_mask"]
    ]

    def _pad_to_max(buf, fill_value):
        """Right-pad ``buf`` with ``fill_value`` up to ``max_tokens`` elements."""
        deficit = max_tokens - buf.numel()
        if deficit <= 0:
            return buf
        filler = torch.full((deficit,), fill_value, dtype=torch.long)
        return torch.cat((buf, filler), dim=0)

    new_input_ids = []
    new_attention_mask = []

    # Concatenate tokens so that row lengths do not exceed max_tokens
    buffer_input_ids = torch.empty(0, dtype=torch.long)
    buffer_attention_mask = torch.empty(0, dtype=torch.long)

    for ids, mask in zip(input_ids, attention_mask):
        # Flush when the next example does not fit. A buffer that is already
        # exactly full also lands here, since every example has >= 2 tokens.
        if buffer_input_ids.numel() + ids.numel() > max_tokens:
            new_input_ids.append(
                _pad_to_max(buffer_input_ids, tokenizer.pad_token_id)
            )
            new_attention_mask.append(_pad_to_max(buffer_attention_mask, 0))
            buffer_input_ids = torch.empty(0, dtype=torch.long)
            buffer_attention_mask = torch.empty(0, dtype=torch.long)

        buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
        buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)

    if buffer_input_ids.numel() > 0:  # for any leftover tokens
        # pad so that all sequences are equal in size
        new_input_ids.append(_pad_to_max(buffer_input_ids, tokenizer.pad_token_id))
        new_attention_mask.append(_pad_to_max(buffer_attention_mask, 0))

    ret = {
        "input_ids": [seq.tolist() for seq in new_input_ids],
        "labels": [seq.tolist() for seq in new_input_ids],
        "attention_mask": [seq.tolist() for seq in new_attention_mask],
    }

    LOG.debug(len(ret["input_ids"]))
    return ret
|
||||||
|
|
||||||
|
|
||||||
|
def wrap_pretraining_dataset(
    dataset,
    tokenizer,
    cfg,
    ds_wrapper_fn,
    max_tokens=2048,
    batch_size=1,
    seed=42,
    buffer_size=10_000,
):
    """Attach the pretraining encode step to ``dataset`` via ``dataset.map``.

    With ``cfg.sample_packing`` set, rows are encoded by
    ``encode_packed_pretraining`` using a
    ``PretrainingBatchSamplerDataCollatorForSeq2Seq`` collator, and
    ``cfg.micro_batch_size`` is overwritten with ``1``. Otherwise rows are
    encoded by ``encode_pretraining``. The dataset is shuffled first when
    ``cfg.shuffle_merged_datasets`` is set.

    Args:
        dataset: dataset exposing ``features``, ``shuffle`` and ``map``.
        tokenizer: tokenizer passed to the collator / encoder.
        cfg: run configuration (read, and mutated in the packing case).
        ds_wrapper_fn: callable forwarded to ``encode_packed_pretraining``.
        max_tokens: sequence length used for encoding and padding.
        batch_size: number of sequences packed per batch.
        seed: shuffle seed.
        buffer_size: shuffle buffer size, also used as the ``map`` batch size.

    Returns:
        The mapped dataset with all original columns removed.
    """
    if cfg.sample_packing:
        collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
            tokenizer,
            return_tensors="pt",
            padding=True,
            pad_to_multiple_of=max_tokens * batch_size,
            multipack_attn=cfg.pretrain_multipack_attn,
        )
        encode = functools.partial(
            encode_packed_pretraining,
            collate_fn,
            ds_wrapper_fn,
            max_seq_length=max_tokens,
            batch_size=batch_size,
            multipack_attn=cfg.pretrain_multipack_attn,
        )
        # set this to 1 so downstream data_loader doesn't try to increase the batch again
        cfg.micro_batch_size = 1
    else:
        encode = functools.partial(encode_pretraining, tokenizer, max_tokens)

    if cfg.shuffle_merged_datasets:
        dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
    else:
        LOG.debug("NOT shuffling merged pretraining datasets")

    # remove all the existing columns after mapping since they end up having
    # a different length than the encoded/tokenized column
    # this is empty during streaming/pretraining
    if dataset.features is None:
        # No declared schema: peek at the first row (if any) for column names.
        first_row = next(iter(dataset), None)
        remove_columns = list(first_row.keys()) if first_row is not None else []
    else:
        # Snapshot into a list so map() receives a concrete list of names
        # rather than a live dict view.
        remove_columns = list(dataset.features.keys())

    dataset = dataset.map(
        encode,
        batched=True,
        batch_size=buffer_size,
        # input_columns="text",
        remove_columns=remove_columns,
    )
    return dataset
|
||||||
|
|
||||||
|
|
||||||
|
def encode_packed_pretraining(
    collate_fn,
    ds_wrapper: Callable,
    examples: Dict[str, List],
    max_seq_length: int = 2048,
    batch_size: int = 4,
    multipack_attn: Optional[bool] = False,
) -> Dict[str, List]:
    """Tokenize a batch of examples and regroup them into packed rows.

    ``examples`` is wrapped in a ``Dataset``, tokenized through
    ``ds_wrapper`` (its first return value is used), and passed through
    ``process_pretraining_datasets_for_packing``. The result is sampled with
    a ``MultipackBatchSampler`` whose ``batch_max_len`` is
    ``batch_size * max_seq_length``. Each sampled group is collated with
    ``collate_fn`` and appended per feature to the output.

    Args:
        collate_fn: collator called on the features of each sampled group.
        ds_wrapper: callable that turns a ``Dataset`` into a tokenized
            dataset, returned as the first element of a sequence.
        examples: column-oriented batch as supplied by ``Dataset.map``.
        max_seq_length: sequence length used for packing.
        batch_size: number of sequences' worth of tokens per packed row.
        multipack_attn: when falsy, position ids are skipped during
            processing (passed as ``skip_position_ids=not multipack_attn``).

    Returns:
        Mapping from feature name to a list of collated tensors, excluding
        the ``"length"`` feature.
    """
    # pylint: disable=duplicate-code
    # tokenize all the examples
    # rows get split with stride (overlap)
    train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]

    train_dataset = process_pretraining_datasets_for_packing(
        train_dataset,
        max_seq_length,
        skip_position_ids=not multipack_attn,
    )

    sampler = MultipackBatchSampler(
        RandomSampler(train_dataset),
        batch_size=1,
        drop_last=True,
        batch_max_len=batch_size * max_seq_length,
        lengths=get_dataset_lengths(train_dataset),
    )

    chunked_data = defaultdict(list)

    # Tokenizer bookkeeping columns are dropped before collating.
    # (The original checked "num_truncated_tokens" twice; once is enough.)
    drop_keys = ("num_truncated_tokens", "overflow_to_sample_mapping")

    for batch in sampler:
        for data in batch:
            features = train_dataset[data]
            for key in drop_keys:
                if key in features:
                    del features[key]
            if "labels" not in features:
                features["labels"] = features["input_ids"].copy()
            collated_features = collate_fn(features)

            for feature in features:
                if feature == "length":
                    continue
                chunked_data[feature].append(collated_features[feature].squeeze(0))

    return chunked_data
|
||||||
@@ -1,14 +1,10 @@
|
|||||||
"""Module containing data utilities"""
|
"""data handling specific to SFT"""
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import hashlib
|
|
||||||
import logging
|
import logging
|
||||||
from collections import defaultdict
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
from typing import List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
|
||||||
import yaml
|
|
||||||
from datasets import (
|
from datasets import (
|
||||||
Dataset,
|
Dataset,
|
||||||
DatasetDict,
|
DatasetDict,
|
||||||
@@ -18,13 +14,11 @@ from datasets import (
|
|||||||
)
|
)
|
||||||
from huggingface_hub import hf_hub_download
|
from huggingface_hub import hf_hub_download
|
||||||
from huggingface_hub.utils import HFValidationError
|
from huggingface_hub.utils import HFValidationError
|
||||||
from torch.utils.data import RandomSampler
|
|
||||||
from transformers import PreTrainedTokenizerBase
|
from transformers import PreTrainedTokenizerBase
|
||||||
|
|
||||||
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
from axolotl.common.const import DEFAULT_DATASET_PREPARED_PATH
|
||||||
from axolotl.datasets import TokenizedPromptDataset
|
from axolotl.datasets import TokenizedPromptDataset
|
||||||
from axolotl.prompt_strategies import load
|
from axolotl.prompt_strategies import load
|
||||||
from axolotl.prompt_strategies.dpo import load as load_dpo
|
|
||||||
from axolotl.prompt_tokenizers import (
|
from axolotl.prompt_tokenizers import (
|
||||||
AlpacaMultipleChoicePromptTokenizingStrategy,
|
AlpacaMultipleChoicePromptTokenizingStrategy,
|
||||||
AlpacaPromptTokenizingStrategy,
|
AlpacaPromptTokenizingStrategy,
|
||||||
@@ -45,26 +39,18 @@ from axolotl.prompters import (
|
|||||||
SummarizeTLDRPrompter,
|
SummarizeTLDRPrompter,
|
||||||
UnsupportedPrompter,
|
UnsupportedPrompter,
|
||||||
)
|
)
|
||||||
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
|
from axolotl.utils.data.pretraining import wrap_pretraining_dataset
|
||||||
|
from axolotl.utils.data.utils import md5
|
||||||
from axolotl.utils.dict import DictDefault
|
from axolotl.utils.dict import DictDefault
|
||||||
from axolotl.utils.distributed import is_main_process, zero_first
|
from axolotl.utils.distributed import is_main_process, zero_first
|
||||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
|
||||||
from axolotl.utils.trainer import (
|
from axolotl.utils.trainer import (
|
||||||
calculate_total_num_steps,
|
calculate_total_num_steps,
|
||||||
process_datasets_for_packing,
|
process_datasets_for_packing,
|
||||||
process_pretraining_datasets_for_packing,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
|
|
||||||
|
|
||||||
def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
|
||||||
try:
|
|
||||||
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
|
|
||||||
except TypeError:
|
|
||||||
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
|
|
||||||
|
|
||||||
|
|
||||||
def prepare_dataset(cfg, tokenizer):
|
def prepare_dataset(cfg, tokenizer):
|
||||||
prompters = []
|
prompters = []
|
||||||
if not cfg.pretraining_dataset:
|
if not cfg.pretraining_dataset:
|
||||||
@@ -182,6 +168,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
except Exception: # pylint: disable=broad-except # nosec
|
except Exception: # pylint: disable=broad-except # nosec
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
if dataset:
|
if dataset:
|
||||||
...
|
...
|
||||||
elif (
|
elif (
|
||||||
@@ -691,315 +678,3 @@ def get_dataset_wrapper(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return dataset_wrapper, dataset_prompter
|
return dataset_wrapper, dataset_prompter
|
||||||
|
|
||||||
|
|
||||||
def encode_pretraining(
|
|
||||||
tokenizer: PreTrainedTokenizerBase, max_tokens: int, examples: List[str]
|
|
||||||
) -> Dict[str, List]:
|
|
||||||
res = tokenizer(
|
|
||||||
examples,
|
|
||||||
truncation=True,
|
|
||||||
max_length=max_tokens - 2,
|
|
||||||
add_special_tokens=True,
|
|
||||||
)
|
|
||||||
# Convert to PyTorch tensors
|
|
||||||
input_ids = [torch.tensor(seq) for seq in res["input_ids"]]
|
|
||||||
attention_mask = [torch.tensor(seq) for seq in res["attention_mask"]]
|
|
||||||
new_input_ids = []
|
|
||||||
new_attention_mask = []
|
|
||||||
# Append EOS and PAD tokens to input_ids, and correct attention_mask
|
|
||||||
for i, _ in enumerate(input_ids):
|
|
||||||
input_ids[i] = torch.cat(
|
|
||||||
(
|
|
||||||
input_ids[i],
|
|
||||||
torch.tensor([tokenizer.eos_token_id, tokenizer.pad_token_id]),
|
|
||||||
),
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
attention_mask[i] = torch.cat((attention_mask[i], torch.tensor([1, 0])), dim=0)
|
|
||||||
|
|
||||||
# Concatenate tokens so that their lengths are less than max_tokens
|
|
||||||
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
|
||||||
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
|
||||||
|
|
||||||
for ids, mask in zip(input_ids, attention_mask):
|
|
||||||
if buffer_input_ids.numel() == max_tokens:
|
|
||||||
new_input_ids.append(buffer_input_ids)
|
|
||||||
new_attention_mask.append(buffer_attention_mask)
|
|
||||||
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
|
||||||
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
|
||||||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
|
||||||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
|
||||||
elif buffer_input_ids.numel() + ids.numel() <= max_tokens:
|
|
||||||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
|
||||||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
|
||||||
else:
|
|
||||||
buffer_input_ids = torch.cat(
|
|
||||||
(
|
|
||||||
buffer_input_ids,
|
|
||||||
torch.full(
|
|
||||||
(max_tokens - buffer_input_ids.numel(),),
|
|
||||||
tokenizer.pad_token_id,
|
|
||||||
dtype=torch.long,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
buffer_attention_mask = torch.cat(
|
|
||||||
(
|
|
||||||
buffer_attention_mask,
|
|
||||||
torch.full(
|
|
||||||
(max_tokens - buffer_attention_mask.numel(),),
|
|
||||||
0,
|
|
||||||
dtype=torch.long,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
new_input_ids.append(buffer_input_ids)
|
|
||||||
new_attention_mask.append(buffer_attention_mask)
|
|
||||||
buffer_input_ids = torch.tensor([], dtype=torch.long)
|
|
||||||
buffer_attention_mask = torch.tensor([], dtype=torch.long)
|
|
||||||
|
|
||||||
buffer_input_ids = torch.cat((buffer_input_ids, ids), dim=0)
|
|
||||||
buffer_attention_mask = torch.cat((buffer_attention_mask, mask), dim=0)
|
|
||||||
|
|
||||||
if buffer_input_ids.numel() > 0: # for any leftover tokens
|
|
||||||
while buffer_input_ids.numel() < max_tokens: # make all sequences equal in size
|
|
||||||
buffer_input_ids = torch.cat(
|
|
||||||
(
|
|
||||||
buffer_input_ids,
|
|
||||||
torch.full(
|
|
||||||
(max_tokens - buffer_input_ids.numel(),),
|
|
||||||
tokenizer.pad_token_id,
|
|
||||||
dtype=torch.long,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
buffer_attention_mask = torch.cat(
|
|
||||||
(
|
|
||||||
buffer_attention_mask,
|
|
||||||
torch.full(
|
|
||||||
(max_tokens - buffer_attention_mask.numel(),),
|
|
||||||
0,
|
|
||||||
dtype=torch.long,
|
|
||||||
),
|
|
||||||
),
|
|
||||||
dim=0,
|
|
||||||
)
|
|
||||||
new_input_ids.append(buffer_input_ids)
|
|
||||||
new_attention_mask.append(buffer_attention_mask)
|
|
||||||
|
|
||||||
ret = {
|
|
||||||
"input_ids": [seq.tolist() for seq in new_input_ids],
|
|
||||||
"labels": [seq.tolist() for seq in new_input_ids],
|
|
||||||
"attention_mask": [seq.tolist() for seq in new_attention_mask],
|
|
||||||
}
|
|
||||||
|
|
||||||
LOG.debug(len(ret["input_ids"]))
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
def wrap_pretraining_dataset(
|
|
||||||
dataset,
|
|
||||||
tokenizer,
|
|
||||||
cfg,
|
|
||||||
ds_wrapper_fn,
|
|
||||||
max_tokens=2048,
|
|
||||||
batch_size=1,
|
|
||||||
seed=42,
|
|
||||||
buffer_size=10_000,
|
|
||||||
):
|
|
||||||
if cfg.sample_packing:
|
|
||||||
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
|
|
||||||
tokenizer,
|
|
||||||
return_tensors="pt",
|
|
||||||
padding=True,
|
|
||||||
pad_to_multiple_of=max_tokens * batch_size,
|
|
||||||
multipack_attn=cfg.pretrain_multipack_attn,
|
|
||||||
)
|
|
||||||
encode = functools.partial(
|
|
||||||
encode_packed_pretraining,
|
|
||||||
collate_fn,
|
|
||||||
ds_wrapper_fn,
|
|
||||||
max_seq_length=max_tokens,
|
|
||||||
batch_size=batch_size,
|
|
||||||
multipack_attn=cfg.pretrain_multipack_attn,
|
|
||||||
)
|
|
||||||
# set this to 1 so downstream data_loader doesn't try to increase the batch again
|
|
||||||
cfg.micro_batch_size = 1
|
|
||||||
else:
|
|
||||||
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
|
||||||
|
|
||||||
if cfg.shuffle_merged_datasets:
|
|
||||||
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
|
|
||||||
else:
|
|
||||||
LOG.debug("NOT shuffling merged pretraining datasets")
|
|
||||||
|
|
||||||
# remove all the existing columns after mapping since they end up having
|
|
||||||
# a different length than the encoded/tokenized column
|
|
||||||
# this is empty during streaming/pretraining
|
|
||||||
remove_columns = []
|
|
||||||
if dataset.features is None:
|
|
||||||
for first_row in dataset:
|
|
||||||
remove_columns = first_row.keys()
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
remove_columns = dataset.features.keys()
|
|
||||||
|
|
||||||
dataset = dataset.map(
|
|
||||||
encode,
|
|
||||||
batched=True,
|
|
||||||
batch_size=buffer_size,
|
|
||||||
# input_columns="text",
|
|
||||||
remove_columns=remove_columns,
|
|
||||||
)
|
|
||||||
return dataset
|
|
||||||
|
|
||||||
|
|
||||||
def encode_packed_pretraining(
|
|
||||||
collate_fn,
|
|
||||||
ds_wrapper: Callable,
|
|
||||||
examples: Dict[str, List],
|
|
||||||
max_seq_length: int = 2048,
|
|
||||||
batch_size: int = 4,
|
|
||||||
multipack_attn: Optional[bool] = False,
|
|
||||||
) -> Dict[str, List]:
|
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
# tokenize all the examples
|
|
||||||
# rows get split with stride (overlap)
|
|
||||||
train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
|
|
||||||
|
|
||||||
train_dataset = process_pretraining_datasets_for_packing(
|
|
||||||
train_dataset,
|
|
||||||
max_seq_length,
|
|
||||||
skip_position_ids=not multipack_attn,
|
|
||||||
)
|
|
||||||
|
|
||||||
sampler = MultipackBatchSampler(
|
|
||||||
RandomSampler(train_dataset),
|
|
||||||
batch_size=1,
|
|
||||||
drop_last=True,
|
|
||||||
batch_max_len=batch_size * max_seq_length,
|
|
||||||
lengths=get_dataset_lengths(train_dataset),
|
|
||||||
)
|
|
||||||
|
|
||||||
chunked_data = defaultdict(list)
|
|
||||||
|
|
||||||
for batch in sampler:
|
|
||||||
for data in batch:
|
|
||||||
features = train_dataset[data]
|
|
||||||
if "num_truncated_tokens" in features:
|
|
||||||
del features["num_truncated_tokens"]
|
|
||||||
if "num_truncated_tokens" in features:
|
|
||||||
del features["num_truncated_tokens"]
|
|
||||||
if "overflow_to_sample_mapping" in features:
|
|
||||||
del features["overflow_to_sample_mapping"]
|
|
||||||
if "labels" not in features:
|
|
||||||
features["labels"] = features["input_ids"].copy()
|
|
||||||
collated_features = collate_fn(features)
|
|
||||||
|
|
||||||
for feature in features.keys():
|
|
||||||
if feature == "length":
|
|
||||||
continue
|
|
||||||
chunked_data[feature].append(collated_features[feature].squeeze(0))
|
|
||||||
|
|
||||||
return chunked_data
|
|
||||||
|
|
||||||
|
|
||||||
def _get_path(ds_hash, cfg):
|
|
||||||
prepared_ds_path = (
|
|
||||||
Path(cfg.dataset_prepared_path) / ds_hash
|
|
||||||
if cfg.dataset_prepared_path
|
|
||||||
else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
|
|
||||||
)
|
|
||||||
|
|
||||||
return prepared_ds_path
|
|
||||||
|
|
||||||
|
|
||||||
def _load_preprocessed_ds(cfg, sub_cfg):
|
|
||||||
ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
|
|
||||||
prepared_ds_path = _get_path(ds_hash, cfg)
|
|
||||||
dataset = None
|
|
||||||
|
|
||||||
if (
|
|
||||||
cfg.dataset_prepared_path
|
|
||||||
and any(prepared_ds_path.glob("*"))
|
|
||||||
and not cfg.is_preprocess
|
|
||||||
):
|
|
||||||
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
|
|
||||||
dataset = load_from_disk(str(prepared_ds_path))
|
|
||||||
|
|
||||||
return dataset
|
|
||||||
|
|
||||||
|
|
||||||
def _save_preprocessed_ds(cfg, sub_cfg, dataset):
|
|
||||||
ds_hash = md5(yaml.dump(sub_cfg, Dumper=yaml.Dumper))
|
|
||||||
prepared_ds_path = _get_path(ds_hash, cfg)
|
|
||||||
|
|
||||||
if cfg.is_preprocess and is_main_process():
|
|
||||||
LOG.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
|
|
||||||
dataset.save_to_disk(str(prepared_ds_path))
|
|
||||||
|
|
||||||
|
|
||||||
def load_prepare_dpo_datasets(cfg):
|
|
||||||
def load_split(dataset_cfgs, _cfg):
|
|
||||||
split_datasets: List[Any] = []
|
|
||||||
for i, ds_cfg in enumerate(dataset_cfgs):
|
|
||||||
if ds_cfg["ds_type"] == "json":
|
|
||||||
for data_file in ds_cfg["data_files"]:
|
|
||||||
data_files = {ds_cfg["split"]: data_file}
|
|
||||||
ds = load_dataset( # pylint: disable=invalid-name
|
|
||||||
"json",
|
|
||||||
data_files=data_files,
|
|
||||||
split=ds_cfg["split"],
|
|
||||||
)
|
|
||||||
split_datasets.insert(i, ds)
|
|
||||||
else:
|
|
||||||
ds = load_dataset( # pylint: disable=invalid-name
|
|
||||||
ds_cfg["path"],
|
|
||||||
split=ds_cfg["split"],
|
|
||||||
)
|
|
||||||
split_datasets.insert(i, ds)
|
|
||||||
|
|
||||||
for i, data_set in enumerate(split_datasets):
|
|
||||||
_type = dataset_cfgs[i]["type"]
|
|
||||||
if _type:
|
|
||||||
if isinstance(_type, DictDefault):
|
|
||||||
_type = "user_defined.default"
|
|
||||||
ds_transform_fn = load_dpo(_type, _cfg, dataset_idx=i)
|
|
||||||
split_datasets[i] = data_set.map(
|
|
||||||
ds_transform_fn,
|
|
||||||
desc="Mapping RL Dataset",
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# If no `type` is provided, assume the dataset is already in the expected format with
|
|
||||||
# "prompt", "chosen" and "rejected" already preprocessed
|
|
||||||
split_datasets[i] = data_set
|
|
||||||
|
|
||||||
return concatenate_datasets(split_datasets)
|
|
||||||
|
|
||||||
with zero_first(is_main_process()):
|
|
||||||
train_is_preprocessed = False
|
|
||||||
eval_is_preprocessed = False
|
|
||||||
if train_dataset := _load_preprocessed_ds(cfg, cfg.datasets):
|
|
||||||
train_is_preprocessed = True
|
|
||||||
else:
|
|
||||||
train_dataset = load_split(cfg.datasets, cfg)
|
|
||||||
|
|
||||||
eval_dataset = None
|
|
||||||
if cfg.test_datasets:
|
|
||||||
if eval_dataset := _load_preprocessed_ds(cfg, cfg.test_datasets):
|
|
||||||
eval_is_preprocessed = True
|
|
||||||
else:
|
|
||||||
eval_dataset = load_split(cfg.test_datasets, cfg)
|
|
||||||
if not eval_dataset:
|
|
||||||
eval_dataset = None
|
|
||||||
|
|
||||||
if not train_is_preprocessed:
|
|
||||||
_save_preprocessed_ds(cfg, cfg.datasets, train_dataset)
|
|
||||||
if eval_dataset and not eval_is_preprocessed:
|
|
||||||
_save_preprocessed_ds(cfg, cfg.test_datasets, eval_dataset)
|
|
||||||
|
|
||||||
return train_dataset, eval_dataset
|
|
||||||
10
src/axolotl/utils/data/utils.py
Normal file
10
src/axolotl/utils/data/utils.py
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
"""data handling helpers"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
|
||||||
|
def md5(to_hash: str, encoding: str = "utf-8") -> str:
    """Return the hexadecimal MD5 digest of ``to_hash``.

    The string is encoded with ``encoding`` before hashing. The digest is
    requested with ``usedforsecurity=False`` first; if the running
    ``hashlib`` rejects that keyword with ``TypeError``, the plain call is
    used instead.
    """
    payload = to_hash.encode(encoding)
    try:
        digest = hashlib.md5(payload, usedforsecurity=False)
    except TypeError:
        digest = hashlib.md5(payload)  # nosec
    return digest.hexdigest()
|
||||||
Reference in New Issue
Block a user