diff --git a/.pylintrc b/.pylintrc
index ed973d285..a4f7bbb5a 100644
--- a/.pylintrc
+++ b/.pylintrc
@@ -12,3 +12,4 @@ generated-members=numpy.*, torch.*
 disable=missing-function-docstring, line-too-long, import-error,
     too-many-arguments, too-many-locals, too-many-statements, too-many-branches, too-few-public-methods,
     too-many-instance-attributes, fixme, import-outside-toplevel, logging-fstring-interpolation,
+    too-many-nested-blocks,
diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py
index 32b2e0cc2..580f94267 100644
--- a/src/axolotl/datasets.py
+++ b/src/axolotl/datasets.py
@@ -5,7 +5,7 @@ import os
 from typing import List
 
 import torch
-from datasets import Dataset, IterableDataset
+from datasets import Dataset, IterableDataset, Sequence, Value
 
 from .prompt_tokenizers import PromptTokenizingStrategy
 
@@ -42,11 +42,15 @@ class TokenizedPromptDataset(Dataset):
         if self.prompt_tokenizer.supports_batched:
             map_kwargs["batched"] = True
             map_kwargs["batch_size"] = 100
-        return dataset.map(
-            self.prompt_tokenizer.tokenize_prompt,
-            num_proc=num_proc,
-            remove_columns=features,
-            **map_kwargs,
+        return (
+            dataset.map(
+                self.prompt_tokenizer.tokenize_prompt,
+                num_proc=num_proc,
+                remove_columns=features,
+                **map_kwargs,
+            )
+            .cast_column("input_ids", Sequence(feature=Value(dtype="int32", id=None)))
+            .cast_column("labels", Sequence(feature=Value(dtype="int32", id=None)))
         )
diff --git a/src/axolotl/prompt_strategies/sharegpt.py b/src/axolotl/prompt_strategies/sharegpt.py
index da36e778e..df30a72a0 100644
--- a/src/axolotl/prompt_strategies/sharegpt.py
+++ b/src/axolotl/prompt_strategies/sharegpt.py
@@ -24,7 +24,7 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
     )
     field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
     field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
-    return SimpleShareGPTPromptTokenizingStrategy(
+    strat = ShareGPTPromptTokenizingStrategy(
         ShareGPTPrompterV2(
             conversation=conversation,
             role_key_model=field_model,
@@ -34,6 +34,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
         cfg.train_on_inputs,
         cfg.sequence_len,
     )
+    if ds_cfg and "skip" in ds_cfg and ds_cfg["skip"]:
+        strat.skip_invalid = True
+    return strat
 
 
 def load_role(tokenizer, cfg):
@@ -54,13 +57,38 @@ def load_guanaco(tokenizer, cfg):
     )
 
 
-class SimpleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
+def load_nous(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
+    conversation = (
+        ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
+    )
+    field_human = ds_cfg["field_human"] if ds_cfg and "field_human" in ds_cfg else None
+    field_model = ds_cfg["field_model"] if ds_cfg and "field_model" in ds_cfg else None
+    return NousShareGPTPromptTokenizingStrategy(
+        ShareGPTPrompterV2(
+            conversation=conversation,
+            role_key_model=field_model,
+            role_key_human=field_human,
+        ),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
+
+
+class NousShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
     """
-    basic sharegpt strategy to grab conversations from the sample row
+    basic sharegpt strategy used by nous/ldj for input/output keyed data
     """
 
-    def get_conversation_thread(self, prompt):
-        return prompt["conversations"]
+    def get_conversation_thread(self):
+        return "conversation"
+
+    def map_conversation_thread(self, conversation):
+        turns = []
+        for turn in conversation:
+            turns.append({"from": "human", "value": turn["input"]})
+            turns.append({"from": "gpt", "value": turn["output"]})
turn["output"]}) + return turns class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): @@ -68,10 +96,11 @@ class SimpleRoleShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrateg basic sharegpt strategy to grab conversations from the sample row, but uses role instead of from """ - def get_conversation_thread(self, prompt): - conversations = prompt["conversations"] + def map_conversation_thread(self, conversation): # remap role: prompter/assistant, text: ... => from: human/gpt, value: ... - turns = [{"from": t["role"], "value": t["value"]} for t in conversations] + turns = [ + {"from": turn["role"], "value": turn["value"]} for turn in conversation + ] return turns @@ -80,11 +109,11 @@ class GuanacoShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy): sharegpt strategy that remaps oasst data to sharegpt format """ - def get_conversation_thread(self, prompt): - conversations = prompt["conversations"] + def map_conversation_thread(self, conversation): # remap role: prompter/assistant, text: ... => from: human/gpt, value: ... role_map = {"prompter": "human", "assistant": "gpt"} turns = [ - {"from": role_map[t["role"]], "value": t["text"]} for t in conversations + {"from": role_map[turn["role"]], "value": turn["text"]} + for turn in conversation ] return turns diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 1b3933664..69be898d6 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -4,6 +4,7 @@ import abc import copy import functools import logging +from collections import defaultdict from typing import Dict, List, Tuple, Union from fastchat.conversation import Conversation @@ -351,76 +352,109 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): Tokenizing strategy for ShareGPT prompts. 
""" - def get_conversation_thread(self, prompt): - return prompt["conversations"] + _skip_invalid = False + + @property + def supports_batched(self): + return True + + @property + def skip_invalid(self): + return self._skip_invalid + + @skip_invalid.setter + def skip_invalid(self, value): + self._skip_invalid = value + + def get_conversation_thread(self): + return "conversations" + + def map_conversation_thread(self, conversation): + return conversation def tokenize_prompt(self, prompt): - result, current_len = tokenize_prompt_default() - user_token = self._get_user_token() - assistant_token = self._get_assistant_token() - conversation: Conversation = ( - self.prompter._conversation # pylint: disable=protected-access - ) - try: - for _, part in enumerate( - self.prompter.build_prompt(self.get_conversation_thread(prompt)) - ): - if isinstance(part, tuple): - if conversation.roles[0] in part[0]: - turn = part[0] + part[1] if not user_token else part[1] - # this is still the user query, we should - if not part[1].strip(): - LOG.warning(f"user turn has empty text: {prompt}") - res = self._tokenize( - turn, - add_eos_token=False, - strip_bos_token=True, - ) - if user_token: - res["input_ids"] = [user_token, *res["input_ids"]] - # everything from this is masked out from the labels - labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) - elif conversation.roles[1] in part[0]: - # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID - turn = part[0] + part[1] if not assistant_token else part[1] - # this should be the assistant response, should end with an eos token - if not part[1].strip(): - LOG.warning(f"assistant turn has empty text: {prompt}") - res = self._tokenize( - turn, - add_eos_token=True, - strip_bos_token=True, - ) - if assistant_token: - res["input_ids"] = [ - assistant_token, - *res["input_ids"], - ] - # not masked out from labels - labels = copy.deepcopy(res["input_ids"]) - elif part[0] == "": - turn = part[1] - # this is only ever the first part, should include the bos token and the user query - res = self._tokenize( - turn, add_eos_token=False, strip_bos_token=False - ) - # everything from this is masked out from the labels - labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) - else: - LOG.warning(f"unhandled role: {part[0]}") - continue + tokenized_res = defaultdict(lambda: []) + conv_field = self.get_conversation_thread() + for prmpt in prompt[conv_field]: + result, current_len = tokenize_prompt_default() + user_token = self._get_user_token() + assistant_token = self._get_assistant_token() + conversation: Conversation = ( + self.prompter._conversation # pylint: disable=protected-access + ) + try: + for _, part in enumerate( + self.prompter.build_prompt(self.map_conversation_thread(prmpt)) + ): + if isinstance(part, tuple): + if conversation.roles[0] in part[0]: + turn = part[0] + part[1] if not user_token else part[1] + # this is still the user query, we should + if not part[1].strip(): + err_msg = f"user turn has empty text: {prmpt}" + if self.skip_invalid: + raise ValueError(err_msg) + LOG.warning(err_msg) + res = self._tokenize( + turn, + add_eos_token=False, + strip_bos_token=True, + ) + if user_token: + res["input_ids"] = [user_token, *res["input_ids"]] + # everything from this is masked out from the labels + labels = [IGNORE_TOKEN_ID] * len(res["input_ids"]) + elif conversation.roles[1] in part[0]: + # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID + turn = part[0] + part[1] if not assistant_token else part[1] + # this should be the assistant response, should end with 
+                            if not part[1].strip():
+                                err_msg = f"assistant turn has empty text: {prmpt}"
+                                if self.skip_invalid:
+                                    raise ValueError(err_msg)
+                                LOG.warning(err_msg)
+                            res = self._tokenize(
+                                turn,
+                                add_eos_token=True,
+                                strip_bos_token=True,
+                            )
+                            if assistant_token:
+                                res["input_ids"] = [
+                                    assistant_token,
+                                    *res["input_ids"],
+                                ]
+                            # not masked out from labels
+                            labels = copy.deepcopy(res["input_ids"])
+                        elif part[0] == "":
+                            turn = part[1]
+                            # this is only ever the first part, should include the bos token and the user query
+                            res = self._tokenize(
+                                turn, add_eos_token=False, strip_bos_token=False
+                            )
+                            # everything from this is masked out from the labels
+                            labels = [IGNORE_TOKEN_ID] * len(res["input_ids"])
+                        else:
+                            err_msg = f"unhandled role: {part[0]}"
+                            if self.skip_invalid:
+                                raise ValueError(err_msg)
+                            LOG.warning(err_msg)
+                            continue
 
-                    # pylint: disable=duplicate-code
-                    result, current_len = parse_tokenized_to_result(
-                        result,
-                        current_len,
-                        res,
-                        labels,
-                        pad_token_id=self.tokenizer.pad_token_id,
-                    )
-            return result
-        except (KeyError, AssertionError, IndexError) as err:
-            raise InvalidDataException(str(err)) from err
+                        # pylint: disable=duplicate-code
+                        result, current_len = parse_tokenized_to_result(
+                            result,
+                            current_len,
+                            res,
+                            labels,
+                            pad_token_id=self.tokenizer.pad_token_id,
+                        )
+                for key, val in sorted(result.items(), key=lambda x: x[0]):
+                    tokenized_res[key].append(val)
+            except (KeyError, AssertionError, IndexError) as err:
+                raise InvalidDataException(str(err)) from err
+            except ValueError as err:
+                LOG.warning("skipping prompt: %s", str(err))
+        return tokenized_res
 
     def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False):
         if not prompt.strip():
diff --git a/src/axolotl/train.py b/src/axolotl/train.py
index da98600a4..b1ac50767 100644
--- a/src/axolotl/train.py
+++ b/src/axolotl/train.py
@@ -58,7 +58,9 @@ def train(
 
     safe_serialization = cfg.save_safetensors is True
 
-    if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
+    if (
+        cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints
+    ) or cfg.resume_from_checkpoint is True:
         possible_checkpoints = [
             str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
         ]
@@ -71,7 +73,9 @@ def train(
         LOG.info(
             f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
         )
-    resume_from_checkpoint = cfg.resume_from_checkpoint
+    resume_from_checkpoint = (
+        cfg.resume_from_checkpoint if cfg.resume_from_checkpoint is not True else None
+    )
 
     trainer = setup_trainer(
         cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index bac7d96c9..019f5462d 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -158,21 +158,23 @@ def load_tokenized_prepared_datasets(
                 token=use_auth_token,
             )
             ds_from_hub = True
-        except FileNotFoundError:
+        except (FileNotFoundError, ValueError):
             pass
 
         # prefer local dataset, even if hub exists
        local_path = Path(d.path)
         if local_path.exists():
             if local_path.is_dir():
-                # TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
-                ds = load_dataset(
-                    d.path,
-                    name=d.name,
-                    data_files=d.data_files,
-                    streaming=False,
-                    split=None,
-                )
+                if not d.type:
+                    ds = load_from_disk(d.path)
+                else:
+                    ds = load_dataset(
+                        d.path,
+                        name=d.name,
+                        data_files=d.data_files,
+                        streaming=False,
+                        split=None,
+                    )
             elif local_path.is_file():
                 ds_type = "json"
                 if d.ds_type:
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 2c60f00c2..db6a6e3cb 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -382,7 +382,7 @@ def load_model(
             if model_config.model_type == "btlm":
                 # don't upcast lm_head for btlm
                 continue
-            if "lm_head" in name or "embed_tokens" in name:
+            if any(x in name for x in ["lm_head", "embed_tokens", "wte", "wpe"]):
                 if hasattr(module, "weight"):
                     module.to(torch.float32)
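
Note: a minimal sketch (not part of the patch) of what the new
NousShareGPTPromptTokenizingStrategy does with an input/output-keyed row; the
sample data below is hypothetical.

    # One conversation as read from the "conversation" field of a dataset row:
    conversation = [
        {"input": "What is 2 + 2?", "output": "2 + 2 = 4."},
    ]
    # map_conversation_thread remaps each turn into two sharegpt-style turns:
    #   [{"from": "human", "value": "What is 2 + 2?"},
    #    {"from": "gpt", "value": "2 + 2 = 4."}]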