"""
|
|
chat dataset module
|
|
"""
|
|
import os
|
|
from typing import Callable, Optional, Union
|
|
|
|
from datasets import Dataset
|
|
from transformers import PreTrainedTokenizer
|
|
|
|
from axolotl.core.chat.messages import ChatFormattedChats
|
|
|
|
|
|
class TokenizedChatDataset(Dataset):
|
|
"""
|
|
Tokenized chat dataset
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
data: Dataset,
|
|
model_transform: Union[PreTrainedTokenizer, Callable],
|
|
*args,
|
|
message_transform: Optional[Callable] = None,
|
|
formatter=None,
|
|
process_count: Optional[int] = None,
|
|
keep_in_memory: Optional[bool] = False,
|
|
**kwargs,
|
|
):
|
|
def map_fn(ex):
|
|
if message_transform is not None:
|
|
ex = message_transform(ex)
|
|
if formatter is not None:
|
|
ex = ChatFormattedChats(
|
|
formatter=formatter,
|
|
**ex,
|
|
)
|
|
else:
|
|
ex = ChatFormattedChats(
|
|
**ex,
|
|
)
|
|
return ex.tokenized(model_transform)
|
|
|
|
process_or_cpu_count: int = (
|
|
process_count or os.cpu_count() # type: ignore[assignment]
|
|
)
|
|
num_proc = min(64, process_or_cpu_count)
|
|
features = data.features.keys()
|
|
tokenized_data = data.map(
|
|
map_fn,
|
|
num_proc=num_proc,
|
|
keep_in_memory=keep_in_memory,
|
|
remove_columns=features,
|
|
desc="Tokenizing Chats",
|
|
)
|
|
super().__init__(tokenized_data.data, *args, **kwargs)
|
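
# Illustrative usage sketch (an assumption, not part of the upstream module): the
# dataset layout, tokenizer name, and `my_message_transform` helper below are
# hypothetical; they only show how the constructor arguments fit together.
#
#     from datasets import Dataset
#     from transformers import AutoTokenizer
#
#     tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B")
#     raw = Dataset.from_list(
#         [{"conversation": [{"role": "user", "content": "hi"}]}]
#     )
#     tokenized_ds = TokenizedChatDataset(
#         raw,
#         model_transform=tokenizer,
#         message_transform=my_message_transform,  # hypothetical mapping to chat schema
#         process_count=4,
#     )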