Compare commits
1 Commits
v0.4.0
...
sharegpt-b
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b4d84d56d5 |
@@ -12,3 +12,4 @@ generated-members=numpy.*, torch.*
|
||||
disable=missing-function-docstring, line-too-long, import-error,
|
||||
too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
|
||||
too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
|
||||
too-many-nested-blocks,
|
||||
|
||||
@@ -5,7 +5,7 @@ import os
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from datasets import Dataset, IterableDataset
|
||||
from datasets import Dataset, IterableDataset, Sequence, Value
|
||||
|
||||
from .prompt_tokenizers import PromptTokenizingStrategy
|
||||
|
||||
@@ -42,11 +42,15 @@ class TokenizedPromptDataset(Dataset):
|
||||
if self.prompt_tokenizer.supports_batched:
|
||||
map_kwargs["batched"] = True
|
||||
map_kwargs["batch_size"] = 100
|
||||
return dataset.map(
|
||||
self.prompt_tokenizer.tokenize_prompt,
|
||||
num_proc=num_proc,
|
||||
remove_columns=features,
|
||||
**map_kwargs,
|
||||
return (
|
||||
dataset.map(
|
||||
self.prompt_tokenizer.tokenize_prompt,
|
||||
num_proc=num_proc,
|
||||
remove_columns=features,
|
||||
**map_kwargs,
|
||||
)
|
||||
.cast_column("input_ids", Sequence(feature=Value(dtype="int32", id=None)))
|
||||
.cast_column("labels", Sequence(feature=Value(dtype="int32", id=None)))
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
)
|
||||
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||
return SimpleShareGPTPromptTokenizingStrategy(
|
||||
strat = ShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation=conversation,
|
||||
role_key_model=field_model,
|
||||
@@ -34,6 +34,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
if ds_cfg and ds_cfg["skip"]:
|
||||
strat.skip_invalid = True
|
||||
return strat
|
||||
|
||||
|
||||
def load_role(tokenizer, cfg):
|
||||
@@ -54,13 +57,38 @@ def load_guanaco(tokenizer, cfg):
|
||||
)
|
||||
|
||||
|
||||
class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
def load_nous(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
conversation = (
|
||||
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
||||
)
|
||||
field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
|
||||
field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
|
||||
return NousShareGPTPromptTokenizingStrategy(
|
||||
ShareGPTPrompterV2(
|
||||
conversation=conversation,
|
||||
role_key_model=field_model,
|
||||
role_key_human=field_human,
|
||||
),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
|
||||
|
||||
class NousShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
basic sharegpt strategy to grab conversations from the sample row
|
||||
basic sharegpt strategy used by nous/ldj for input/output keyed data
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
return prompt["conversations"]
|
||||
def get_conversation_thread(self):
|
||||
return "conversation"
|
||||
|
||||
def map_conversation_thread(self, conversation):
|
||||
turns = []
|
||||
for turn in conversation:
|
||||
turns.append({"from": "human", "value": turn["input"]})
|
||||
turns.append({"from": "gpt", "value": turn["output"]})
|
||||
return turns
|
||||
|
||||
|
||||
class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
@@ -68,10 +96,11 @@ class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrateg
|
||||
basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
def map_conversation_thread(self, conversation):
|
||||
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
||||
turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
|
||||
turns = [
|
||||
{"from": turn["role"], "value": turn["value"]} for turn in conversation
|
||||
]
|
||||
return turns
|
||||
|
||||
|
||||
@@ -80,11 +109,11 @@ class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
sharegpt strategy that remaps oasst data to sharegpt format
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
conversations = prompt["conversations"]
|
||||
def map_conversation_thread(self, conversation):
|
||||
# remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
|
||||
role_map = {"prompter": "human", "assistant": "gpt"}
|
||||
turns = [
|
||||
{"from": role_map[t["role"]], "value": t["text"]} for t in conversations
|
||||
{"from": role_map[turn["role"]], "value": turn["text"]}
|
||||
for turn in conversation
|
||||
]
|
||||
return turns
|
||||
|
||||
@@ -4,6 +4,7 @@ import abc
|
||||
import copy
|
||||
import functools
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Tuple, Union
|
||||
|
||||
from fastchat.conversation import Conversation
|
||||
@@ -351,76 +352,109 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
|
||||
Tokenizing strategy for ShareGPT prompts.
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
return prompt["conversations"]
|
||||
_skip_invalid = False
|
||||
|
||||
@property
|
||||
def supports_batched(self):
|
||||
return True
|
||||
|
||||
@property
|
||||
def skip_invalid(self):
|
||||
return self._skip_invalid
|
||||
|
||||
@skip_invalid.setter
|
||||
def skip_invalid(self, value):
|
||||
self._skip_invalid = value
|
||||
|
||||
def get_conversation_thread(self):
|
||||
return "conversations"
|
||||
|
||||
def map_conversation_thread(self, conversation):
|
||||
return conversation
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
result, current_len = tokenize_prompt_default()
|
||||
user_token = self._get_user_token()
|
||||
assistant_token = self._get_assistant_token()
|
||||
conversation: Conversation = (
|
||||
self.prompter._conversation # pylint: disable=protected-access
|
||||
)
|
||||
try:
|
||||
for _, part in enumerate(
|
||||
self.prompter.build_prompt(self.get_conversation_thread(prompt))
|
||||
):
|
||||
if isinstance(part, tuple):
|
||||
if conversation.roles[0] in part[0]:
|
||||
turn = part[0] + part[1] if not user_token else part[1]
|
||||
# this is still the user query, we should
|
||||
if not part[1].strip():
|
||||
LOG.warning(f"user turn has empty text: {prompt}")
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
if user_token:
|
||||
res["input_ids"] = [user_token, *res["input_ids"]]
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
elif conversation.roles[1] in part[0]:
|
||||
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
||||
turn = part[0] + part[1] if not assistant_token else part[1]
|
||||
# this should be the assistant response, should end with an eos token
|
||||
if not part[1].strip():
|
||||
LOG.warning(f"assistant turn has empty text: {prompt}")
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=True,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
if assistant_token:
|
||||
res["input_ids"] = [
|
||||
assistant_token,
|
||||
*res["input_ids"],
|
||||
]
|
||||
# not masked out from labels
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
elif part[0] == "":
|
||||
turn = part[1]
|
||||
# this is only ever the first part, should include the bos token and the user query
|
||||
res = self._tokenize(
|
||||
turn, add_eos_token=False, strip_bos_token=False
|
||||
)
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
else:
|
||||
LOG.warning(f"unhandled role: {part[0]}")
|
||||
continue
|
||||
tokenized_res = defaultdict(lambda: [])
|
||||
conv_field = self.get_conversation_thread()
|
||||
for prmpt in prompt[conv_field]:
|
||||
result, current_len = tokenize_prompt_default()
|
||||
user_token = self._get_user_token()
|
||||
assistant_token = self._get_assistant_token()
|
||||
conversation: Conversation = (
|
||||
self.prompter._conversation # pylint: disable=protected-access
|
||||
)
|
||||
try:
|
||||
for _, part in enumerate(
|
||||
self.prompter.build_prompt(self.map_conversation_thread(prmpt))
|
||||
):
|
||||
if isinstance(part, tuple):
|
||||
if conversation.roles[0] in part[0]:
|
||||
turn = part[0] + part[1] if not user_token else part[1]
|
||||
# this is still the user query, we should
|
||||
if not part[1].strip():
|
||||
err_msg = f"user turn has empty text: {prmpt}"
|
||||
if self.skip_invalid:
|
||||
raise ValueError(err_msg)
|
||||
LOG.warning(err_msg)
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=False,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
if user_token:
|
||||
res["input_ids"] = [user_token, *res["input_ids"]]
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
elif conversation.roles[1] in part[0]:
|
||||
# TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
|
||||
turn = part[0] + part[1] if not assistant_token else part[1]
|
||||
# this should be the assistant response, should end with an eos token
|
||||
if not part[1].strip():
|
||||
err_msg = f"assistant turn has empty text: {prmpt}"
|
||||
if self.skip_invalid:
|
||||
raise ValueError(err_msg)
|
||||
LOG.warning(err_msg)
|
||||
res = self._tokenize(
|
||||
turn,
|
||||
add_eos_token=True,
|
||||
strip_bos_token=True,
|
||||
)
|
||||
if assistant_token:
|
||||
res["input_ids"] = [
|
||||
assistant_token,
|
||||
*res["input_ids"],
|
||||
]
|
||||
# not masked out from labels
|
||||
labels = copy.deepcopy(res["input_ids"])
|
||||
elif part[0] == "":
|
||||
turn = part[1]
|
||||
# this is only ever the first part, should include the bos token and the user query
|
||||
res = self._tokenize(
|
||||
turn, add_eos_token=False, strip_bos_token=False
|
||||
)
|
||||
# everything from this is masked out from the labels
|
||||
labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
|
||||
else:
|
||||
err_msg = f"unhandled role: {part[0]}"
|
||||
if self.skip_invalid:
|
||||
raise ValueError(err_msg)
|
||||
LOG.warning(err_msg)
|
||||
continue
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
result, current_len = parse_tokenized_to_result(
|
||||
result,
|
||||
current_len,
|
||||
res,
|
||||
labels,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
)
|
||||
return result
|
||||
except (KeyError, AssertionError, IndexError) as err:
|
||||
raise InvalidDataException(str(err)) from err
|
||||
# pylint: disable=duplicate-code
|
||||
result, current_len = parse_tokenized_to_result(
|
||||
result,
|
||||
current_len,
|
||||
res,
|
||||
labels,
|
||||
pad_token_id=self.tokenizer.pad_token_id,
|
||||
)
|
||||
for key, val in sorted(result.items(), key=lambda x: x[0]):
|
||||
tokenized_res[key].append(val)
|
||||
except (KeyError, AssertionError, IndexError) as err:
|
||||
raise InvalidDataException(str(err)) from err
|
||||
except ValueError as err:
|
||||
LOG.warning("skipping prompt: %s", str(err))
|
||||
return tokenized_res
|
||||
|
||||
def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
|
||||
if not prompt.strip():
|
||||
|
||||
@@ -58,7 +58,9 @@ def train(
|
||||
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
||||
if (
|
||||
cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints
|
||||
) or cfg.resume_from_checkpoint is True:
|
||||
possible_checkpoints = [
|
||||
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
||||
]
|
||||
@@ -71,7 +73,9 @@ def train(
|
||||
LOG.info(
|
||||
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
||||
)
|
||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||
resume_from_checkpoint = (
|
||||
cfg.resume_from_checkpoint if cfg.resume_from_checkpoint is not True else None
|
||||
)
|
||||
|
||||
trainer = setup_trainer(
|
||||
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
|
||||
|
||||
@@ -158,21 +158,23 @@ def load_tokenized_prepared_datasets(
|
||||
token=use_auth_token,
|
||||
)
|
||||
ds_from_hub = True
|
||||
except FileNotFoundError:
|
||||
except (FileNotFoundError, ValueError):
|
||||
pass
|
||||
|
||||
# prefer local dataset, even if hub exists
|
||||
local_path = Path(d.path)
|
||||
if local_path.exists():
|
||||
if local_path.is_dir():
|
||||
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
|
||||
ds = load_dataset(
|
||||
d.path,
|
||||
name=d.name,
|
||||
data_files=d.data_files,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
if not d.type:
|
||||
ds = load_from_disk(d.path)
|
||||
else:
|
||||
ds = load_dataset(
|
||||
d.path,
|
||||
name=d.name,
|
||||
data_files=d.data_files,
|
||||
streaming=False,
|
||||
split=None,
|
||||
)
|
||||
elif local_path.is_file():
|
||||
ds_type = "json"
|
||||
if d.ds_type:
|
||||
|
||||
@@ -382,7 +382,7 @@ def load_model(
|
||||
if model_config.model_type == "btlm":
|
||||
# don't upcast lm_head for btlm
|
||||
continue
|
||||
if "lm_head" in name or "embed_tokens" in name:
|
||||
if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
|
||||
if hasattr(module, "weight"):
|
||||
module.to(torch.float32)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user