Compare commits

diff-trans...sharegpt-b

1 commit

| Author | SHA1 | Date |
|---|---|---|
|  | b4d84d56d5 |  |
```diff
@@ -12,3 +12,4 @@ generated-members=numpy.*, torch.*
 disable=missing-function-docstring, line-too-long, import-error,
         too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
         too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
+        too-many-nested-blocks,
```
```diff
@@ -5,7 +5,7 @@ import os
 from typing import List
 
 import torch
-from datasets import Dataset, IterableDataset
+from datasets import Dataset, IterableDataset, Sequence, Value
 
 from .prompt_tokenizers import PromptTokenizingStrategy
 
@@ -42,11 +42,15 @@ class TokenizedPromptDataset(Dataset):
         if self.prompt_tokenizer.supports_batched:
             map_kwargs["batched"] = True
             map_kwargs["batch_size"] = 100
-        return dataset.map(
-            self.prompt_tokenizer.tokenize_prompt,
-            num_proc=num_proc,
-            remove_columns=features,
-            **map_kwargs,
+        return (
+            dataset.map(
+                self.prompt_tokenizer.tokenize_prompt,
+                num_proc=num_proc,
+                remove_columns=features,
+                **map_kwargs,
+            )
+            .cast_column("input_ids", Sequence(feature=Value(dtype="int32", id=None)))
+            .cast_column("labels", Sequence(feature=Value(dtype="int32", id=None)))
         )
 
 
```
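As a side note on the `cast_column` change above, here is a minimal, self-contained sketch of the same pattern using the `datasets` API. The toy rows are invented for illustration; only the two `cast_column` calls mirror the hunk.

```python
from datasets import Dataset, Sequence, Value

# Toy tokenized split standing in for the output of tokenize_prompt
# (assumption: token ids are stored as 64-bit ints by default).
ds = Dataset.from_dict(
    {
        "input_ids": [[1, 2, 3], [4, 5]],
        "labels": [[-100, 2, 3], [-100, 5]],
        "attention_mask": [[1, 1, 1], [1, 1]],
    }
)

# cast_column rewrites the Arrow schema for one column; int32 halves the
# storage of int64 token ids while still covering any realistic vocab size
# and the -100 ignore label.
ds = ds.cast_column("input_ids", Sequence(feature=Value(dtype="int32")))
ds = ds.cast_column("labels", Sequence(feature=Value(dtype="int32")))

print(ds.features)
```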
```diff
@@ -24,7 +24,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     )
     field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
     field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
-    return SimpleShareGPTPromptTokenizingStrategy(
+    strat = ShareGPTPromptTokenizingStrategy(
         ShareGPTPrompterV2(
             conversation=conversation,
             role_key_model=field_model,
@@ -34,6 +34,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
         cfg.train_on_inputs,
         cfg.sequence_len,
     )
+    if ds_cfg and ds_cfg["skip"]:
+        strat.skip_invalid = True
+    return strat
 
 
 def load_role(tokenizer, cfg):
@@ -54,13 +57,38 @@ def load_guanaco(tokenizer, cfg):
     )
 
 
-class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
+def load_nous(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+    conversation = (
+        ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
+    )
+    field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
+    field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
+    return NousShareGPTPromptTokenizingStrategy(
+        ShareGPTPrompterV2(
+            conversation=conversation,
+            role_key_model=field_model,
+            role_key_human=field_human,
+        ),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+
+
+class NousShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
     """
-    basic sharegpt strategy to grab conversations from the sample row
+    basic sharegpt strategy used by nous/ldj for input/output keyed data
     """
 
-    def get_conversation_thread(self, prompt):
-        return prompt["conversations"]
+    def get_conversation_thread(self):
+        return "conversation"
+
+    def map_conversation_thread(self, conversation):
+        turns = []
+        for turn in conversation:
+            turns.append({"from": "human", "value": turn["input"]})
+            turns.append({"from": "gpt", "value": turn["output"]})
+        return turns
 
 
 class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
@@ -68,10 +96,11 @@ class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrateg
     basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from
     """
 
-    def get_conversation_thread(self, prompt):
-        conversations = prompt["conversations"]
+    def map_conversation_thread(self, conversation):
         # remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
-        turns = [{"from": t["role"], "value": t["value"]} for t in conversations]
+        turns = [
+            {"from": turn["role"], "value": turn["value"]} for turn in conversation
+        ]
         return turns
 
 
@@ -80,11 +109,11 @@ class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
     sharegpt strategy that remaps oasst data to sharegpt format
     """
 
-    def get_conversation_thread(self, prompt):
-        conversations = prompt["conversations"]
+    def map_conversation_thread(self, conversation):
         # remap role: prompter/assistant, text: ... => from: human/gpt, value: ...
         role_map = {"prompter": "human", "assistant": "gpt"}
         turns = [
-            {"from": role_map[t["role"]], "value": t["text"]} for t in conversations
+            {"from": role_map[turn["role"]], "value": turn["text"]}
+            for turn in conversation
         ]
         return turns
```
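For reference, a small sketch of the input/output keyed row shape the new Nous-style strategy targets. The sample conversation is made up; the remap function below simply mirrors `map_conversation_thread` from the hunk above. Note also that the `load()` change lets a dataset config with a truthy `skip` entry turn on `skip_invalid`.

```python
# Hypothetical sample row in the input/output keyed format handled by load_nous;
# the field names mirror the diff, the data itself is invented for illustration.
row = {
    "conversation": [
        {"input": "What is 2 + 2?", "output": "4"},
        {"input": "And doubled?", "output": "8"},
    ]
}


def to_sharegpt_turns(conversation):
    # Same remap as NousShareGPTPromptTokenizingStrategy.map_conversation_thread:
    # each input/output pair becomes a human turn followed by a gpt turn.
    turns = []
    for turn in conversation:
        turns.append({"from": "human", "value": turn["input"]})
        turns.append({"from": "gpt", "value": turn["output"]})
    return turns


print(to_sharegpt_turns(row["conversation"]))
# [{'from': 'human', 'value': 'What is 2 + 2?'}, {'from': 'gpt', 'value': '4'}, ...]
```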
```diff
@@ -4,6 +4,7 @@ import abc
 import copy
 import functools
 import logging
+from collections import defaultdict
 from typing import Dict, List, Tuple, Union
 
 from fastchat.conversation import Conversation
@@ -351,76 +352,109 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy):
     Tokenizing strategy for ShareGPT prompts.
     """
 
-    def get_conversation_thread(self, prompt):
-        return prompt["conversations"]
+    _skip_invalid = False
+
+    @property
+    def supports_batched(self):
+        return True
+
+    @property
+    def skip_invalid(self):
+        return self._skip_invalid
+
+    @skip_invalid.setter
+    def skip_invalid(self, value):
+        self._skip_invalid = value
+
+    def get_conversation_thread(self):
+        return "conversations"
+
+    def map_conversation_thread(self, conversation):
+        return conversation
 
     def tokenize_prompt(self, prompt):
-        result, current_len = tokenize_prompt_default()
-        user_token = self._get_user_token()
-        assistant_token = self._get_assistant_token()
-        conversation: Conversation = (
-            self.prompter._conversation  # pylint: disable=protected-access
-        )
-        try:
-            for _, part in enumerate(
-                self.prompter.build_prompt(self.get_conversation_thread(prompt))
-            ):
-                if isinstance(part, tuple):
-                    if conversation.roles[0] in part[0]:
-                        turn = part[0] + part[1] if not user_token else part[1]
-                        # this is still the user query, we should
-                        if not part[1].strip():
-                            LOG.warning(f"user turn has empty text: {prompt}")
-                        res = self._tokenize(
-                            turn,
-                            add_eos_token=False,
-                            strip_bos_token=True,
-                        )
-                        if user_token:
-                            res["input_ids"] = [user_token, *res["input_ids"]]
-                        # everything from this is masked out from the labels
-                        labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
-                    elif conversation.roles[1] in part[0]:
-                        # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
-                        turn = part[0] + part[1] if not assistant_token else part[1]
-                        # this should be the assistant response, should end with an eos token
-                        if not part[1].strip():
-                            LOG.warning(f"assistant turn has empty text: {prompt}")
-                        res = self._tokenize(
-                            turn,
-                            add_eos_token=True,
-                            strip_bos_token=True,
-                        )
-                        if assistant_token:
-                            res["input_ids"] = [
-                                assistant_token,
-                                *res["input_ids"],
-                            ]
-                        # not masked out from labels
-                        labels = copy.deepcopy(res["input_ids"])
-                    elif part[0] == "":
-                        turn = part[1]
-                        # this is only ever the first part, should include the bos token and the user query
-                        res = self._tokenize(
-                            turn, add_eos_token=False, strip_bos_token=False
-                        )
-                        # everything from this is masked out from the labels
-                        labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
-                    else:
-                        LOG.warning(f"unhandled role: {part[0]}")
-                        continue
-
-                    # pylint: disable=duplicate-code
-                    result, current_len = parse_tokenized_to_result(
-                        result,
-                        current_len,
-                        res,
-                        labels,
-                        pad_token_id=self.tokenizer.pad_token_id,
-                    )
-            return result
-        except (KeyError, AssertionError, IndexError) as err:
-            raise InvalidDataException(str(err)) from err
+        tokenized_res = defaultdict(lambda: [])
+        conv_field = self.get_conversation_thread()
+        for prmpt in prompt[conv_field]:
+            result, current_len = tokenize_prompt_default()
+            user_token = self._get_user_token()
+            assistant_token = self._get_assistant_token()
+            conversation: Conversation = (
+                self.prompter._conversation  # pylint: disable=protected-access
+            )
+            try:
+                for _, part in enumerate(
+                    self.prompter.build_prompt(self.map_conversation_thread(prmpt))
+                ):
+                    if isinstance(part, tuple):
+                        if conversation.roles[0] in part[0]:
+                            turn = part[0] + part[1] if not user_token else part[1]
+                            # this is still the user query, we should
+                            if not part[1].strip():
+                                err_msg = f"user turn has empty text: {prmpt}"
+                                if self.skip_invalid:
+                                    raise ValueError(err_msg)
+                                LOG.warning(err_msg)
+                            res = self._tokenize(
+                                turn,
+                                add_eos_token=False,
+                                strip_bos_token=True,
+                            )
+                            if user_token:
+                                res["input_ids"] = [user_token, *res["input_ids"]]
+                            # everything from this is masked out from the labels
+                            labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+                        elif conversation.roles[1] in part[0]:
+                            # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID
+                            turn = part[0] + part[1] if not assistant_token else part[1]
+                            # this should be the assistant response, should end with an eos token
+                            if not part[1].strip():
+                                err_msg = f"assistant turn has empty text: {prmpt}"
+                                if self.skip_invalid:
+                                    raise ValueError(err_msg)
+                                LOG.warning(err_msg)
+                            res = self._tokenize(
+                                turn,
+                                add_eos_token=True,
+                                strip_bos_token=True,
+                            )
+                            if assistant_token:
+                                res["input_ids"] = [
+                                    assistant_token,
+                                    *res["input_ids"],
+                                ]
+                            # not masked out from labels
+                            labels = copy.deepcopy(res["input_ids"])
+                        elif part[0] == "":
+                            turn = part[1]
+                            # this is only ever the first part, should include the bos token and the user query
+                            res = self._tokenize(
+                                turn, add_eos_token=False, strip_bos_token=False
+                            )
+                            # everything from this is masked out from the labels
+                            labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+                        else:
+                            err_msg = f"unhandled role: {part[0]}"
+                            if self.skip_invalid:
+                                raise ValueError(err_msg)
+                            LOG.warning(err_msg)
+                            continue
+
+                        # pylint: disable=duplicate-code
+                        result, current_len = parse_tokenized_to_result(
+                            result,
+                            current_len,
+                            res,
+                            labels,
+                            pad_token_id=self.tokenizer.pad_token_id,
+                        )
+                for key, val in sorted(result.items(), key=lambda x: x[0]):
+                    tokenized_res[key].append(val)
+            except (KeyError, AssertionError, IndexError) as err:
+                raise InvalidDataException(str(err)) from err
+            except ValueError as err:
+                LOG.warning("skipping prompt: %s", str(err))
+        return tokenized_res
 
     def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
         if not prompt.strip():
```
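The rewritten `tokenize_prompt` now builds a dict of lists so it can serve as a batched map function. A rough, self-contained sketch of that contract with the `datasets` library follows; the column name and the fake per-example "tokenization" are made up, only the accumulation pattern mirrors the hunk: with `batched=True` the function receives each column as a list and must return columns as lists, one appended entry per example.

```python
from collections import defaultdict

from datasets import Dataset

ds = Dataset.from_dict(
    {"conversations": [["hi", "hello"], ["how are you?", "fine, thanks"]]}
)


def fake_tokenize_prompt(batch):
    # `batch` is a dict of lists; the return value must also be a dict of lists.
    tokenized_res = defaultdict(lambda: [])
    for conv in batch["conversations"]:
        input_ids = [len(text) for text in conv]  # stand-in for real token ids
        tokenized_res["input_ids"].append(input_ids)
        tokenized_res["labels"].append(list(input_ids))
        tokenized_res["attention_mask"].append([1] * len(input_ids))
    return tokenized_res


tokenized = ds.map(fake_tokenize_prompt, batched=True, remove_columns=ds.column_names)
print(tokenized[0])  # {'input_ids': [2, 5], 'labels': [2, 5], 'attention_mask': [1, 1]}
```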
```diff
@@ -58,7 +58,9 @@ def train(
 
     safe_serialization = cfg.save_safetensors is True
 
-    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
+    if (
+        cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints
+    ) or cfg.resume_from_checkpoint is True:
         possible_checkpoints = [
             str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
         ]
@@ -71,7 +73,9 @@ def train(
         LOG.info(
             f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
         )
-    resume_from_checkpoint = cfg.resume_from_checkpoint
+    resume_from_checkpoint = (
+        cfg.resume_from_checkpoint if cfg.resume_from_checkpoint is not True else None
+    )
 
     trainer = setup_trainer(
         cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
```
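For context on the auto-resume path this hunk extends (setting `resume_from_checkpoint: true` now falls back to picking a checkpoint from the output directory), here is a hedged sketch of resolving the newest `checkpoint-*` directory by its numeric step suffix. The directory names are hypothetical and this is not necessarily the project's exact resolution logic.

```python
from pathlib import Path

# Hypothetical output directory layout produced by a HF Trainer run:
#   out/checkpoint-500  out/checkpoint-1000  out/checkpoint-1500
output_dir = "out"

possible_checkpoints = [str(cp) for cp in Path(output_dir).glob("checkpoint-*")]
if possible_checkpoints:
    # Sort by the numeric step suffix so "checkpoint-1500" beats "checkpoint-500";
    # a plain string sort would order these incorrectly.
    latest = sorted(possible_checkpoints, key=lambda p: int(p.split("-")[-1]))[-1]
    print(f"resuming from {latest}")
else:
    print("no checkpoints found, starting fresh")
```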
```diff
@@ -158,21 +158,23 @@ def load_tokenized_prepared_datasets(
                 token=use_auth_token,
             )
             ds_from_hub = True
-        except FileNotFoundError:
+        except (FileNotFoundError, ValueError):
             pass
 
         # prefer local dataset, even if hub exists
         local_path = Path(d.path)
         if local_path.exists():
             if local_path.is_dir():
-                # TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
-                ds = load_dataset(
-                    d.path,
-                    name=d.name,
-                    data_files=d.data_files,
-                    streaming=False,
-                    split=None,
-                )
+                if not d.type:
+                    ds = load_from_disk(d.path)
+                else:
+                    ds = load_dataset(
+                        d.path,
+                        name=d.name,
+                        data_files=d.data_files,
+                        streaming=False,
+                        split=None,
+                    )
             elif local_path.is_file():
                 ds_type = "json"
                 if d.ds_type:
```
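The new `if not d.type:` branch distinguishes a directory written by `Dataset.save_to_disk` from a directory of raw data files. A minimal sketch of the two loading paths in the `datasets` library, with a made-up directory name:

```python
from datasets import Dataset, load_dataset, load_from_disk

# Persist a tiny dataset the way save_to_disk does; the path is invented for this sketch.
Dataset.from_dict({"text": ["hello", "world"]}).save_to_disk("my_prepared_ds")

# A directory written by save_to_disk (arrow shards plus metadata) is reloaded
# with load_from_disk ...
ds = load_from_disk("my_prepared_ds")
print(ds)

# ... whereas load_dataset handles hub repos or raw data files, e.g. a local JSONL:
# ds = load_dataset("json", data_files="my_data.jsonl", split="train")
```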
```diff
@@ -382,7 +382,7 @@ def load_model(
         if model_config.model_type == "btlm":
             # don't upcast lm_head for btlm
             continue
-        if "lm_head" in name or "embed_tokens" in name:
+        if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
             if hasattr(module, "weight"):
                 module.to(torch.float32)
```
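A small sketch of what the broadened module-name match does: GPT-2 style models name their token and position embeddings `wte`/`wpe` rather than `embed_tokens`, so those now get upcast too. The toy module below is a stand-in, not the project's model-loading code.

```python
import torch
from torch import nn

# Toy stand-in for a half-precision model with GPT-2 style embedding names.
model = nn.ModuleDict(
    {
        "wte": nn.Embedding(10, 4),
        "wpe": nn.Embedding(10, 4),
        "lm_head": nn.Linear(4, 10),
        "mlp": nn.Linear(4, 4),
    }
).half()

for name, module in model.named_modules():
    # Keep embeddings and the output head in float32 for numerical stability
    # while the rest of the model stays in half precision.
    if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
        if hasattr(module, "weight"):
            module.to(torch.float32)

print({n: m.weight.dtype for n, m in model.named_modules() if hasattr(m, "weight")})
```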