Compare commits


1 Commit

Author     SHA1        Message                                                      Date
Wing Lian  b4d84d56d5  support for batched sharegpt tokenization to skip bad data  2023-10-06 15:03:07 -04:00
7 changed files with 170 additions and 96 deletions

View File

@@ -12,3 +12,4 @@ generated-members=numpy.*, torch.*
 disable=missing-function-docstring, line-too-long, import-error,
         too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
         too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
+        too-many-nested-blocks,

View File

@@ -5,7 +5,7 @@ import os
 from typing import List

 import torch
-from datasets import Dataset, IterableDataset
+from datasets import Dataset, IterableDataset, Sequence, Value

 from .prompt_tokenizers import PromptTokenizingStrategy
@@ -42,12 +42,16 @@ class TokenizedPromptDataset(Dataset):
         if self.prompt_tokenizer.supports_batched:
             map_kwargs["batched"] = True
             map_kwargs["batch_size"] = 100
-        return dataset.map(
-            self.prompt_tokenizer.tokenize_prompt,
-            num_proc=num_proc,
-            remove_columns=features,
-            **map_kwargs,
+        return (
+            dataset.map(
+                self.prompt_tokenizer.tokenize_prompt,
+                num_proc=num_proc,
+                remove_columns=features,
+                **map_kwargs,
+            )
+            .cast_column("input_ids", Sequence(feature=Value(dtype="int32", id=None)))
+            .cast_column("labels", Sequence(feature=Value(dtype="int32", id=None)))
         )

 # TODO this isn't the best since it can't interleave datasets
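To see what the batched map plus `cast_column` chain buys, here is a minimal standalone sketch (toy tokenizer invented here, not axolotl code): with `batched=True` the mapping function receives a dict of lists and may return fewer rows than it was given, which is what lets bad examples be dropped, and the explicit int32 `Sequence` casts pin the schema so shards tokenized by parallel workers concatenate cleanly.

```python
# Minimal sketch of the batched-map pattern above; the tokenizer is a stand-in.
from datasets import Dataset, Sequence, Value

def fake_tokenize(batch):
    # batched map: `batch` is a dict of lists and the output may hold
    # fewer rows than the input, which drops invalid examples entirely
    out = {"input_ids": [], "labels": []}
    for text in batch["text"]:
        if not text.strip():  # skip "bad" rows
            continue
        ids = [ord(c) for c in text]
        out["input_ids"].append(ids)
        out["labels"].append(ids)
    return out

ds = Dataset.from_dict({"text": ["hi", "", "there"]})
tokenized = (
    ds.map(fake_tokenize, batched=True, batch_size=100, remove_columns=["text"])
    .cast_column("input_ids", Sequence(feature=Value(dtype="int32")))
    .cast_column("labels", Sequence(feature=Value(dtype="int32")))
)
print(len(tokenized))  # 2 -- the empty row was dropped
```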

View File

@@ -24,7 +24,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     )
     field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
     field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
-    return SimpleShareGPTPromptTokenizingStrategy(
+    strat = ShareGPTPromptTokenizingStrategy(
         ShareGPTPrompterV2(
             conversation=conversation,
             role_key_model=field_model,
@@ -34,6 +34,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
         cfg.train_on_inputs,
         cfg.sequence_len,
     )
+    if ds_cfg and ds_cfg["skip"]:
+        strat.skip_invalid = True
+    return strat


 def load_role(tokenizer, cfg):
@@ -54,13 +57,38 @@ def load_guanaco(tokenizer, cfg):
     )


-class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
+def load_nous(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+    conversation = (
+        ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
+    )
+    field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
+    field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
+    return NousShareGPTPromptTokenizingStrategy(
+        ShareGPTPrompterV2(
+            conversation=conversation,
+            role_key_model=field_model,
+            role_key_human=field_human,
+        ),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+
+
+class NousShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
     """
-    basic sharegpt strategy to grab conversations from the sample row
+    basic sharegpt strategy used by nous/ldj for input/output keyed data
     """

-    def get_conversation_thread(self, prompt):
-        return prompt["conversations"]
+    def get_conversation_thread(self):
+        return "conversation"
+
+    def map_conversation_thread(self, conversation):
+        turns = []
+        for turn in conversation:
+            turns.append({"from": "human", "value": turn["input"]})
+            turns.append({"from": "gpt", "value": turn["output"]})
+        return turns


 class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
@@ -68,10 +96,11 @@ class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
     basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
     """

-    def get_conversation_thread(self, prompt):
-        conversations = prompt["conversations"]
+    def map_conversation_thread(self, conversation):
         # remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
-        turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
+        turns = [
+            {"from": turn["role"], "value": turn["value"]} for turn in conversation
+        ]
         return turns
@@ -80,11 +109,11 @@ class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
     sharegpt strategy that remaps oasst data to sharegpt format
     """

-    def get_conversation_thread(self, prompt):
-        conversations = prompt["conversations"]
+    def map_conversation_thread(self, conversation):
         # remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
         role_map = {"prompter": "human", "assistant": "gpt"}
         turns = [
-            {"from": role_map[t["role"]], "value": t["text"]} for t in conversations
+            {"from": role_map[turn["role"]], "value": turn["text"]}
+            for turn in conversation
         ]
         return turns
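Two things in this file are easiest to see with sample data: the new per-dataset `skip` flag that flips `skip_invalid` on the strategy, and the input/output-keyed rows that `load_nous` and `NousShareGPTPromptTokenizingStrategy` expect. A hypothetical sketch (the config values and data below are invented for illustration):

```python
# Hypothetical dataset config entry; per the diff, a truthy "skip" enables skip_invalid.
ds_cfg = {"conversation": "vicuna_v1.1", "skip": True}

# Hypothetical row in the nous/ldj input/output-keyed format.
row = {
    "conversation": [
        {"input": "What is 2 + 2?", "output": "4."},
        {"input": "And doubled?", "output": "8."},
    ]
}

# What map_conversation_thread produces: two sharegpt-style turns per pair.
turns = []
for turn in row["conversation"]:
    turns.append({"from": "human", "value": turn["input"]})
    turns.append({"from": "gpt", "value": turn["output"]})

assert turns[0] == {"from": "human", "value": "What is 2 + 2?"}
assert turns[1] == {"from": "gpt", "value": "4."}
```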

View File

@@ -4,6 +4,7 @@ import abc
 import copy
 import functools
 import logging
+from collections import defaultdict
 from typing import Dict, List, Tuple, Union

 from fastchat.conversation import Conversation
@@ -351,10 +352,30 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
     Tokenizing strategy for ShareGPT prompts.
     """

-    def get_conversation_thread(self, prompt):
-        return prompt["conversations"]
+    _skip_invalid = False
+
+    @property
+    def supports_batched(self):
+        return True
+
+    @property
+    def skip_invalid(self):
+        return self._skip_invalid
+
+    @skip_invalid.setter
+    def skip_invalid(self, value):
+        self._skip_invalid = value
+
+    def get_conversation_thread(self):
+        return "conversations"
+
+    def map_conversation_thread(self, conversation):
+        return conversation

     def tokenize_prompt(self, prompt):
+        tokenized_res = defaultdict(lambda: [])
+        conv_field = self.get_conversation_thread()
+        for prmpt in prompt[conv_field]:
             result, current_len = tokenize_prompt_default()
             user_token = self._get_user_token()
             assistant_token = self._get_assistant_token()
@@ -363,14 +384,17 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
             )
             try:
                 for _, part in enumerate(
-                    self.prompter.build_prompt(self.get_conversation_thread(prompt))
+                    self.prompter.build_prompt(self.map_conversation_thread(prmpt))
                 ):
                     if isinstance(part, tuple):
                         if conversation.roles[0] in part[0]:
                             turn = part[0] + part[1] if not user_token else part[1]
                             # this is still the user query, we should
                             if not part[1].strip():
-                                LOG.warning(f"user turn has empty text: {prompt}")
+                                err_msg = f"user turn has empty text: {prmpt}"
+                                if self.skip_invalid:
+                                    raise ValueError(err_msg)
+                                LOG.warning(err_msg)
                             res = self._tokenize(
                                 turn,
                                 add_eos_token=False,
@@ -385,7 +409,10 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                             turn = part[0] + part[1] if not assistant_token else part[1]
                             # this should be the assistant response, should end with an eos token
                             if not part[1].strip():
-                                LOG.warning(f"assistant turn has empty text: {prompt}")
+                                err_msg = f"assistant turn has empty text: {prmpt}"
+                                if self.skip_invalid:
+                                    raise ValueError(err_msg)
+                                LOG.warning(err_msg)
                             res = self._tokenize(
                                 turn,
                                 add_eos_token=True,
@@ -407,7 +434,10 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                             # everything from this is masked out from the labels
                             labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
                         else:
-                            LOG.warning(f"unhandled role: {part[0]}")
+                            err_msg = f"unhandled role: {part[0]}"
+                            if self.skip_invalid:
+                                raise ValueError(err_msg)
+                            LOG.warning(err_msg)
                             continue

                     # pylint: disable=duplicate-code
@@ -418,9 +448,13 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
                         labels,
                         pad_token_id=self.tokenizer.pad_token_id,
                     )
-            return result
+                for key, val in sorted(result.items(), key=lambda x: x[0]):
+                    tokenized_res[key].append(val)
             except (KeyError, AssertionError, IndexError) as err:
                 raise InvalidDataException(str(err)) from err
+            except ValueError as err:
+                LOG.warning("skipping prompt: %s", str(err))
+        return tokenized_res

     def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
         if not prompt.strip():
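A hedged sketch of the batched contract `tokenize_prompt` now fulfills: with `supports_batched` true, `prompt` arrives as a dict of lists (one entry per row in the batch), per-row results are accumulated into lists in `tokenized_res`, and a row that raises `ValueError` under `skip_invalid` contributes nothing, shrinking the output batch. Toy illustration only, not axolotl code:

```python
from collections import defaultdict

def tokenize_batch(batch, skip_invalid=True):
    # batch is a dict of lists, e.g. {"conversations": [conv0, conv1, ...]}
    tokenized = defaultdict(list)
    for conv in batch["conversations"]:
        try:
            if not conv and skip_invalid:
                # mirrors raising ValueError for bad rows when skip_invalid is set
                raise ValueError("empty conversation")
            ids = list(range(max(len(conv), 1)))  # stand-in for real tokenization
            tokenized["input_ids"].append(ids)
            tokenized["labels"].append(ids)
        except ValueError:
            continue  # the row simply isn't appended, so the batch shrinks
    return tokenized

out = tokenize_batch({"conversations": [[{"from": "human", "value": "hi"}], []]})
assert len(out["input_ids"]) == 1  # the empty conversation was skipped
```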

View File

@@ -58,7 +58,9 @@ def train(
     safe_serialization = cfg.save_safetensors is True

-    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
+    if (
+        cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints
+    ) or cfg.resume_from_checkpoint is True:
         possible_checkpoints = [
             str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
         ]
@@ -71,7 +73,9 @@ def train(
             LOG.info(
                 f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
             )
-    resume_from_checkpoint = cfg.resume_from_checkpoint
+    resume_from_checkpoint = (
+        cfg.resume_from_checkpoint if cfg.resume_from_checkpoint is not True else None
+    )

     trainer = setup_trainer(
         cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
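The intent this diff shows: a bare `resume_from_checkpoint: true` (a boolean rather than a path) now triggers the same latest-checkpoint scan as `auto_resume_from_checkpoints`, and the boolean is normalized to `None` afterwards so the trainer only ever receives a path or nothing. A standalone sketch of that logic (ordering checkpoints by step number is an assumption here, not taken from the diff):

```python
from pathlib import Path

def resolve_resume(resume_from_checkpoint, auto_resume, output_dir):
    # True now behaves like auto-resume: scan output_dir for the newest checkpoint
    if (resume_from_checkpoint is None and auto_resume) or resume_from_checkpoint is True:
        checkpoints = sorted(
            Path(output_dir).glob("checkpoint-*"),
            key=lambda p: int(p.name.split("-")[-1]),  # assumed step-number ordering
        )
        if checkpoints:
            resume_from_checkpoint = str(checkpoints[-1])
    # never hand a bare True to the trainer; fall back to a fresh start
    return resume_from_checkpoint if resume_from_checkpoint is not True else None
```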

View File

@@ -158,14 +158,16 @@ def load_tokenized_prepared_datasets(
                     token=use_auth_token,
                 )
                 ds_from_hub = True
-            except FileNotFoundError:
+            except (FileNotFoundError, ValueError):
                 pass

             # prefer local dataset, even if hub exists
             local_path = Path(d.path)
             if local_path.exists():
                 if local_path.is_dir():
-                    # TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
+                    if not d.type:
+                        ds = load_from_disk(d.path)
+                    else:
                         ds = load_dataset(
                             d.path,
                             name=d.name,
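With this change, a local directory whose dataset config has no `type:` is opened with `load_from_disk`, i.e. it is expected to be a dataset previously written by `save_to_disk`; directories with a `type` still go through `load_dataset`. A round-trip sketch (the path below is invented):

```python
from datasets import Dataset, load_from_disk

# a dataset saved with save_to_disk can only be reopened with load_from_disk;
# load_dataset would not understand the arrow/state files it writes
ds = Dataset.from_dict({"conversations": [[{"from": "human", "value": "hi"}]]})
ds.save_to_disk("/tmp/my_prepared_ds")  # invented path

reloaded = load_from_disk("/tmp/my_prepared_ds")
assert reloaded.num_rows == 1
```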

View File

@@ -382,7 +382,7 @@ def load_model(
             if model_config.model_type == "btlm":
                 # don't upcast lm_head for btlm
                 continue
-            if "lm_head" in name or "embed_tokens" in name:
+            if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
                 if hasattr(module, "weight"):
                     module.to(torch.float32)
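The widened membership test also catches GPT-2 style embedding names (`wte`, `wpe`), not just the Llama-style `lm_head`/`embed_tokens`. A toy check of the matching (the module layout is invented for illustration):

```python
import torch
import torch.nn as nn

model = nn.ModuleDict({
    "wte": nn.Embedding(10, 4),  # gpt2-style token embeddings: upcast
    "mlp": nn.Linear(4, 4),      # ordinary layer: stays in reduced precision
}).half()

for name, module in model.named_modules():
    # same membership test as the diff: match any precision-sensitive module name
    if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
        if hasattr(module, "weight"):
            module.to(torch.float32)

assert model["wte"].weight.dtype == torch.float32
assert model["mlp"].weight.dtype == torch.float16
```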