diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml index 134ffb7d5..449adec35 100644 --- a/.github/workflows/base.yml +++ b/.github/workflows/base.yml @@ -11,6 +11,15 @@ jobs: if: github.repository_owner == 'OpenAccess-AI-Collective' # this job needs to be run on self-hosted GPU runners... runs-on: self-hosted + strategy: + matrix: + include: + - cuda: cu118 + cuda_version: 11.8.0 + pytorch: 2.0.0 + - cuda: cu117 + cuda_version: 11.7.0 + pytorch: 1.13.1 steps: - name: Checkout uses: actions/checkout@v3 @@ -32,7 +41,11 @@ jobs: context: . file: ./docker/Dockerfile-base push: ${{ github.event_name != 'pull_request' }} - tags: ${{ steps.metadata.outputs.tags }} + tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }} labels: ${{ steps.metadata.outputs.labels }} cache-from: type=gha cache-to: type=gha,mode=max + build-args: | + CUDA_VERSION=${{ matrix.cuda_version }} + CUDA=${{ matrix.cuda }} + PYTORCH_VERSION=${{ matrix.pytorch }} diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 8919a8825..6e51fef3c 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -10,6 +10,15 @@ jobs: build-axolotl: if: github.repository_owner == 'OpenAccess-AI-Collective' # this job needs to be run on self-hosted GPU runners... + strategy: + matrix: + include: + - cuda: cu118 + cuda_version: 11.8.0 + pytorch: 2.0.0 + - cuda: cu117 + cuda_version: 11.7.0 + pytorch: 1.13.1 runs-on: self-hosted steps: - name: Checkout @@ -31,10 +40,10 @@ jobs: with: context: . build-args: | - BASE_TAG=${{ github.ref_name }}-base + BASE_TAG=${{ github.ref_name }}-base-${{ matrix.cuda }}-${{ matrix.pytorch }} file: ./docker/Dockerfile push: ${{ github.event_name != 'pull_request' }} - tags: ${{ steps.metadata.outputs.tags }} + tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }} labels: ${{ steps.metadata.outputs.labels }} cache-from: type=gha cache-to: type=gha,mode=max @@ -42,6 +51,15 @@ jobs: needs: build-axolotl if: github.repository_owner == 'OpenAccess-AI-Collective' # this job needs to be run on self-hosted GPU runners... + strategy: + matrix: + include: + - cuda: cu118 + cuda_version: 11.8.0 + pytorch: 2.0.0 + - cuda: cu117 + cuda_version: 11.7.0 + pytorch: 1.13.1 runs-on: self-hosted steps: - name: Checkout @@ -63,10 +81,10 @@ jobs: with: context: . build-args: | - BASE_TAG=${{ github.ref_name }} + BASE_TAG=${{ github.ref_name }}-${{ matrix.cuda }}-${{ matrix.pytorch }} file: ./docker/Dockerfile-runpod push: ${{ github.event_name != 'pull_request' }} - tags: ${{ steps.metadata.outputs.tags }} + tags: ${{ steps.metadata.outputs.tags }}-${{ matrix.cuda }}-${{ matrix.pytorch }} labels: ${{ steps.metadata.outputs.labels }} cache-from: type=gha cache-to: type=gha,mode=max diff --git a/docker/Dockerfile-base b/docker/Dockerfile-base index 54738ddb8..943bae3b0 100644 --- a/docker/Dockerfile-base +++ b/docker/Dockerfile-base @@ -1,6 +1,7 @@ ARG CUDA_VERSION="11.8.0" ARG CUDNN_VERSION="8" ARG UBUNTU_VERSION="22.04" +ARG MAX_JOBS=4 FROM nvidia/cuda:$CUDA_VERSION-cudnn$CUDNN_VERSION-devel-ubuntu$UBUNTU_VERSION as base-builder @@ -39,6 +40,14 @@ ARG TORCH_CUDA_ARCH_LIST="7.0 7.5 8.0 8.6+PTX" RUN git clone https://github.com/HazyResearch/flash-attention.git && \ cd flash-attention && \ + python3 setup.py bdist_wheel && \ + cd csrc/fused_dense_lib && \ + python3 setup.py bdist_wheel && \ + cd csrc/xentropy && \ + python3 setup.py bdist_wheel && \ + cd csrc/rotary && \ + python3 setup.py bdist_wheel && \ + cd csrc/layer_norm && \ python3 setup.py bdist_wheel FROM base-builder AS deepspeed-builder @@ -60,8 +69,12 @@ RUN cd apex && MAX_JOBS=1 python3 -m pip install --global-option="--cpp_ext" --g RUN mkdir /workspace/wheels COPY --from=deepspeed-builder /workspace/DeepSpeed/dist/deepspeed-*.whl wheels COPY --from=flash-attn-builder /workspace/flash-attention/dist/flash_attn-*.whl wheels +COPY --from=flash-attn-builder /workspace/flash-attention/csrc/fused_dense_lib/dist/fused_dense_lib-*.whl wheels +COPY --from=flash-attn-builder /workspace/flash-attention/csrc/xentropy/dist/xentropy-*.whl wheels +COPY --from=flash-attn-builder /workspace/flash-attention/csrc/rotary/dist/rotary-*.whl wheels +COPY --from=flash-attn-builder /workspace/flash-attention/csrc/layer_norm/dist/dropout_layer_norm-*.whl wheels -RUN pip3 install wheels/deepspeed-*.whl wheels/flash_attn-*.whl +RUN pip3 install wheels/deepspeed-*.whl wheels/flash_attn-*.whl wheels/fused_dense_lib-*.whl wheels/xeontropy-*.whl wheels/rotary-*.whl wheels/dropout_layer_norm-*.whl RUN git lfs install --skip-repo RUN pip3 install "peft @ git+https://github.com/huggingface/peft.git@main" \ "accelerate @ git+https://github.com/huggingface/accelerate.git@main" \ diff --git a/scripts/finetune.py b/scripts/finetune.py index 5fb38b6f6..cb9d7e94e 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -31,7 +31,7 @@ DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared" def choose_device(cfg): def get_device(): if torch.cuda.is_available(): - return "cuda" + return f"cuda:{cfg.local_rank}" else: try: if torch.backends.mps.is_available(): @@ -131,7 +131,8 @@ def train( # then overwrite the value cfg_keys = dict(cfg).keys() for k in kwargs: - if k in cfg_keys: + # if not strict, allow writing to cfg even if it's not in the yml already + if k in cfg_keys or cfg.strict is False: # handle booleans if isinstance(cfg[k], bool): cfg[k] = bool(kwargs[k]) @@ -169,6 +170,15 @@ def train( inference=("inference" in kwargs), ) + if "merge_lora" in kwargs and cfg.adapter is not None: + logging.info("running merge of LoRA with base model") + model = model.merge_and_unload() + + if cfg.local_rank == 0: + logging.info("saving merged model") + model.save_pretrained(str(Path(cfg.output_dir) / "merged")) + return + if "inference" in kwargs: logging.info("calling do_inference function") do_inference(cfg, model, tokenizer) @@ -216,6 +226,8 @@ def train( ) logging.info("Starting trainer...") + if cfg.group_by_length: + logging.info("hang tight... sorting dataset for group_by_length") resume_from_checkpoint = cfg.resume_from_checkpoint if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints: possible_checkpoints = [ diff --git a/src/axolotl/datasets.py b/src/axolotl/datasets.py index d9acf5715..0e166f6f0 100644 --- a/src/axolotl/datasets.py +++ b/src/axolotl/datasets.py @@ -106,7 +106,7 @@ class ConstantLengthDataset(IterableDataset): } else: logging.warning( - "dropping batch due to tensor size mismatch" + f"dropping batch due to tensor size mismatch input_ids: {input_ids.size()}, labels: {labels.size()}, attention_mask: {attention_mask.size()}" ) buffer = {"input_ids": [], "attention_mask": [], "labels": []} buffer_len = 0 diff --git a/src/axolotl/prompt_strategies/__init__.py b/src/axolotl/prompt_strategies/__init__.py new file mode 100644 index 000000000..dcdc4315f --- /dev/null +++ b/src/axolotl/prompt_strategies/__init__.py @@ -0,0 +1,13 @@ +import importlib + +def load(strategy, tokenizer, cfg): + try: + load_fn = "load" + if strategy.split(".")[-1].startswith("load_"): + load_fn = strategy.split(".")[-1] + strategy = ".".join(strategy.split(".")[:-1]) + m = importlib.import_module(f".{strategy}", "axolotl.prompt_strategies") + fn = getattr(m, load_fn) + return fn(tokenizer, cfg) + except: + pass diff --git a/src/axolotl/prompt_strategies/alpaca_chat.py b/src/axolotl/prompt_strategies/alpaca_chat.py new file mode 100644 index 000000000..1cd99bd9f --- /dev/null +++ b/src/axolotl/prompt_strategies/alpaca_chat.py @@ -0,0 +1,8 @@ +from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy +from axolotl.prompters import AlpacaPrompter, PromptStyle + + +def load(tokenizer, cfg): + return AlpacaPromptTokenizingStrategy( + AlpacaPrompter(PromptStyle.chat), tokenizer, cfg.train_on_inputs, cfg.sequence_len + ) diff --git a/src/axolotl/prompt_strategies/alpaca_instruct.py b/src/axolotl/prompt_strategies/alpaca_instruct.py new file mode 100644 index 000000000..8f09407ad --- /dev/null +++ b/src/axolotl/prompt_strategies/alpaca_instruct.py @@ -0,0 +1,8 @@ +from axolotl.prompt_tokenizers import AlpacaPromptTokenizingStrategy +from axolotl.prompters import AlpacaPrompter, PromptStyle + + +def load(tokenizer, cfg): + return AlpacaPromptTokenizingStrategy( + AlpacaPrompter(PromptStyle.instruct), tokenizer, cfg.train_on_inputs, cfg.sequence_len + ) diff --git a/src/axolotl/prompt_strategies/creative_acr.py b/src/axolotl/prompt_strategies/creative_acr.py new file mode 100644 index 000000000..e68e206de --- /dev/null +++ b/src/axolotl/prompt_strategies/creative_acr.py @@ -0,0 +1,137 @@ +from typing import Union, Generator + +import yaml +from axolotl.prompt_tokenizers import InstructionPromptTokenizingStrategy + + +class CreativeAnsweringPromptTokenizingStrategy(InstructionPromptTokenizingStrategy): + def parse_instruction_fields(self, prompt) -> (str, str, str): + question = prompt["instruction"] + answer = prompt["revision"] # don't use prompt[answer], that's data we don't want in the dataset + return ( + question, + "", + answer, + ) + + +class CreativeCritiquePromptTokenizingStrategy(InstructionPromptTokenizingStrategy): + user_prompt = """Given the following Question and Response, critique the Response on a scale of 1-10. You should critique the answer in the following criteria: +refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question. +prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means there is prescriptive bias. +creativity: score how creative the Response is. 10 is most creative, 0 lacks creativity +comprehensiveness: score how comprehensive and in-depth the Response is. 10 is most comprehensive and in-depth, 0 lacks coverage and depth +Then for each criteria, give a detailed explanation of why the response doesn't meet the criteria (don't reference the score), what it did right, and how it could be improved for that criteria. + +Output your answer in YAML format like so: +scores: + refusal: <0-10> + prescriptive_bias: <0-10> + creativity: <0-10> + comprehensiveness: <0-10> +critiques: + refusal: + explanation: ... + improvements: ... + prescriptive_bias: + explanation: ... + improvements: ... + creativity: + explanation: ... + improvements: ... + comprehensiveness: + explanation: ... + improvements: ... + +Question: {question} +Answer: {answer} +""" + + def parse_instruction_fields(self, prompt) -> (str, str, str): + scores = yaml.dump(prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper) + critiques = yaml.dump(prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper) + evaluation = scores + critiques + question = prompt["instruction"] + answer = prompt["answer"] + return ( + self.user_prompt.format(question=question, answer=answer), + "", + evaluation, + ) + + +class CreativeRevisePromptTokenizingStrategy(InstructionPromptTokenizingStrategy): + user_prompt = """Definitions: +refusal: whether the Response refuses to answer the Question. Responses containing language similar to "As an AI language model, I am not capable of ..." is a form of refusal and should have a low refusal score. 10 is no refusal, 0 is refuses to answer the question. +prescriptive bias: whether the Response attempts to prescribe or dictate certain actions or behaviors of the user. 10 no prescriptive bias present, 0 means their is prescriptive bias. +creativity: score how creative the Response is. 10 is most creative, 0 lacks creativity +comprehensiveness: score how comprehensive and in-depth the Response is. 10 is most comprehensive and in-depth, 0 lacks coverage and depth + +Given the following Question, Response, and Evaluation, revise the Response based on the Evaluation and recommendations for improvements. Reply only with the revised response. + +Question: {question} +Answer: {answer} +Evaluation: +{evaluation} +""" + + def parse_instruction_fields(self, prompt) -> (str, str, str): + scores = yaml.dump(prompt["scores"], default_flow_style=False, Dumper=yaml.Dumper) + critiques = yaml.dump(prompt["critiques"], default_flow_style=False, Dumper=yaml.Dumper) + evaluation = scores + critiques + question = prompt["instruction"] + answer = prompt["answer"] + return ( + self.user_prompt.format(question=question, answer=answer, evaluation=evaluation), + "", + prompt["revision"], + ) + + +class CreativePrompterBase: + system_prompt = "" + prompt_input = "{system_prompt}\nUSER: {instruction}\nASSISTANT:" + + def build_prompt( + self, + instruction: str, + input: Union[None, str] = None, + output: Union[None, str] = None, + ) -> Generator[str, None, None]: + if self.system_prompt: + res = f"{self.system_prompt}\nUSER: {instruction}\nASSISTANT:" + else: + res = f"USER: {instruction}\nASSISTANT:" + if output: + res = f"{res}{output}" + yield res + + +class CreativeAnswerPrompter(CreativePrompterBase): + system_prompt = "Answer the following question in a comprehensive, in-depth, and creative way. Additionally your response should be relevant, accurate, and free of any ambiguity." + + +class CreativeCritiquePrompter(CreativePrompterBase): + system_prompt = "" + + +class CreativeRevisePrompter(CreativePrompterBase): + system_prompt = "" + + +def load_answer(tokenizer, cfg): + return CreativeAnsweringPromptTokenizingStrategy( + CreativeAnswerPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len + ) + + +def load_critique(tokenizer, cfg): + return CreativeCritiquePromptTokenizingStrategy( + CreativeCritiquePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len + ) + + +def load_revise(tokenizer, cfg): + return CreativeRevisePromptTokenizingStrategy( + CreativeRevisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len + ) diff --git a/src/axolotl/prompt_strategies/pygmalion.py b/src/axolotl/prompt_strategies/pygmalion.py new file mode 100644 index 000000000..bd70c73d5 --- /dev/null +++ b/src/axolotl/prompt_strategies/pygmalion.py @@ -0,0 +1,100 @@ +import copy +import logging +from collections import defaultdict +from typing import Generator + +from axolotl.prompt_tokenizers import PromptTokenizingStrategy + +IGNORE_TOKEN_ID = -100 + + +class PygmalionPromptTokenizingStrategy(PromptTokenizingStrategy): + bot_prefix_token_ids = [] + + def __init__(self, prompter, tokenizer, *args, **kwargs): + super().__init__(prompter, tokenizer) + res = self._tokenize("<|model|>", add_eos_token=False, strip_bos_token=True) + self.bot_prefix_token_ids = res["input_ids"] + + def tokenize_prompt(self, prompt): + result = { + "input_ids": [], + "attention_mask": [], + "labels": [], + } + current_len = 0 + for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"])): + role, message = part + if role == "system": + prefix = "<|system|>" + # this should include a bos token, no eos token, strip trailing "\n" + if message.endswith("\n"): + message = message[:-8] + res = self._tokenize(prefix + "Persona: " + message.strip(), add_eos_token=False, strip_bos_token=False) + # everything from this is masked out from the labels + labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"]) + elif role == "human": + prefix = "<|user|>" + res = self._tokenize(prefix + " " + message.strip(), add_eos_token=False, strip_bos_token=True) + # everything from this is masked out from the labels + labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"]) + elif role == "bot": + prefix = "<|model|>" + res = self._tokenize(prefix + " " + message.strip(), add_eos_token=True, strip_bos_token=True) + # mask out the prefix token, rest is not masked out from labels + # make sure we create the labels first, otherwise we get incorrect lengths + labels = [ IGNORE_TOKEN_ID ] * len(self.bot_prefix_token_ids) + [*copy.deepcopy(res["input_ids"])][len(self.bot_prefix_token_ids):] + else: + logging.warning(f"unknown role in conversation: {role}") + res = defaultdict(lambda: []) + input_ids = res["input_ids"] + input_len = len(input_ids) + result["input_ids"][current_len : current_len + input_len] = input_ids + result["attention_mask"][current_len : current_len + input_len] = [ + 1 if x != self.tokenizer.pad_token_id else 0 + for x in input_ids + ] + result["labels"][current_len : current_len + input_len] = labels + current_len += input_len + return result + + def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False): + result = self.tokenizer( + prompt, + truncation=True, + max_length=self.sequence_len, + padding=False, + return_tensors=None, + ) + if ( + result["input_ids"][-1] != self.tokenizer.eos_token_id + and len(result["input_ids"]) < self.sequence_len + and add_eos_token + ): + result["input_ids"].append(self.tokenizer.eos_token_id) + result["attention_mask"].append(1) + + if ( + result["input_ids"][0] == self.tokenizer.bos_token_id + and strip_bos_token + ): + result["input_ids"] = result["input_ids"][1:] + result["attention_mask"] = result["attention_mask"][1:] + + result["labels"] = result["input_ids"].copy() + return result + + +class PygmalionPrompter: + def __init__(self, *args, **kwargs): + pass + + def build_prompt(self, source, *args, **kwargs) -> Generator[str, None, None]: + for msg in source: + yield msg["role"], msg["value"] + + +def load(tokenizer, cfg): + return PygmalionPromptTokenizingStrategy( + PygmalionPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len + ) diff --git a/src/axolotl/prompt_tokenizers.py b/src/axolotl/prompt_tokenizers.py index 7f79ef192..6c20e7729 100644 --- a/src/axolotl/prompt_tokenizers.py +++ b/src/axolotl/prompt_tokenizers.py @@ -1,5 +1,7 @@ import abc import copy +import functools +import logging from transformers import PreTrainedTokenizer @@ -33,6 +35,20 @@ class PromptTokenizingStrategy(abc.ABC): def tokenize_prompt(self, prompt): pass + @functools.cache + def _get_user_token(self): + id_or_ids = self.tokenizer.convert_tokens_to_ids("<|USER|>") + if isinstance(id_or_ids, (int,)): + return id_or_ids + return False + + @functools.cache + def _get_assistant_token(self): + id_or_ids = self.tokenizer.convert_tokens_to_ids("<|ASSISTANT|>") + if isinstance(id_or_ids, (int,)): + return id_or_ids + return False + class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy): def parse_instruction_fields(self, prompt) -> (str, str, str): @@ -63,7 +79,7 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy): response, ))) - def _tokenize(self, prompt, add_eos_token=True): + def _tokenize(self, prompt, add_eos_token=True, strip_bos_token=False): result = self.tokenizer( prompt, truncation=True, @@ -79,6 +95,13 @@ class InstructionPromptTokenizingStrategy(PromptTokenizingStrategy): result["input_ids"].append(self.tokenizer.eos_token_id) result["attention_mask"].append(1) + if ( + result["input_ids"][0] == self.tokenizer.bos_token_id + and strip_bos_token + ): + result["input_ids"] = result["input_ids"][1:] + result["attention_mask"] = result["attention_mask"][1:] + result["labels"] = result["input_ids"].copy() return result @@ -239,23 +262,35 @@ class ShareGPTPromptTokenizingStrategy(PromptTokenizingStrategy): "labels": [], } current_len = 0 + user_token = self._get_user_token() + assistant_token = self._get_assistant_token() try: - for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"], self.tokenizer)): - if i == 0: + for i, part in enumerate(self.prompter.build_prompt(prompt["conversations"])): + if isinstance(part, tuple): + if part[0] == "USER:": + part = part[0] + part[1] if not user_token else part[1] + # this is still the user query, we should + res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=True) + if user_token: + res["input_ids"] = [user_token, *res["input_ids"]] + # everything from this is masked out from the labels + labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"]) + elif part[0] == "ASSISTANT:": + # TODO label assistant token/tokens w/ IGNORE_TOKEN_ID + part = part[0] + part[1] if not assistant_token else part[1] + # this should be the assistent response, should end with an eos token + res = self._tokenize(part.strip(), add_eos_token=True, strip_bos_token=True) + if assistant_token: + res["input_ids"] = [assistant_token, *res["input_ids"]] + # not masked out from labels + labels = copy.deepcopy(res["input_ids"]) + else: + logging.warning("unhandled role: " + part[0]) + else: # this is only ever the first part, should include the bos token and the user query res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=False) # everything from this is masked out from the labels labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"]) - elif i % 2 == 0: - # this is still the user query, we should - res = self._tokenize(part.strip(), add_eos_token=False, strip_bos_token=True) - # everything from this is masked out from the labels - labels = [ IGNORE_TOKEN_ID ] * len(res["input_ids"]) - else: - # this should be the assistent response, should end with an eos token - res = self._tokenize(part.strip(), add_eos_token=True, strip_bos_token=True) - # not masked out from labels - labels = copy.deepcopy(res["input_ids"]) input_ids = res["input_ids"] input_len = len(input_ids) result["input_ids"][current_len : current_len + input_len] = input_ids diff --git a/src/axolotl/prompters.py b/src/axolotl/prompters.py index 8a8cfa247..3ae0a0bd4 100644 --- a/src/axolotl/prompters.py +++ b/src/axolotl/prompters.py @@ -1,15 +1,34 @@ import copy import dataclasses +import logging from enum import auto, Enum from typing import List, Tuple, Any, Union, Generator IGNORE_TOKEN_ID = -100 +class PromptStyle(Enum): + instruct = "instruct" + chat = "chat" + class AlpacaPrompter: - prompt_input = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" - prompt_no_input = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n### Instruction:\n{instruction}\n\n### Response:\n" - response_split = "### Response:" + system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n" + system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n" + prompt_style = None + + def __init__(self, prompt_style="instruct"): + self.prompt_style = prompt_style + self.match_prompt_style() + + def match_prompt_style(self): + if self.prompt_style == PromptStyle.instruct.value: + self.prompt_input = self.system_prompt + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" + self.prompt_no_input = self.system_no_input_prompt + "### Instruction:\n{instruction}\n\n### Response:\n" + self.response_split = "### Response:" + if self.prompt_style == PromptStyle.chat.value: + self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:" + self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:" + self.response_split = "ASSISTANT:" def build_prompt( self, @@ -36,7 +55,7 @@ class JeopardyPrompter(AlpacaPrompter): class MultipleChoiceExplainPrompter(AlpacaPrompter): - prompt_input = "Choose the answer that best answers the question. Explain your reasoning.\n\n### Question:\n{instruction}\n\n### Choices:\n{input}\n\n### Response:\n" + system_prompt = "Choose the answer that best answers the question. Explain your reasoning." class MultipleChoiceConcisePrompter(AlpacaPrompter): @@ -64,11 +83,30 @@ class NomicGPT4AllPrompter(AlpacaPrompter): class ReflectAlpacaPrompter: - prompt_input = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" - prompt_no_input = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n### Instruction:\n{instruction}\n\n### Response:\n" - agent_label = "{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}" + system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n" + system_no_input_prompt = "Below is an instruction that describes a task. You, the Assistant, should generate a response as if it were an abstract for an academic or technical paper on the query along with a methodology. Then generate an Agent Reflection where you create a long form response as if from subject matter expert, be verbose, diligent, and creative in your application of knowledge, apply it through the lens of the response generated by the assistant. Look for flawed reasoning, faulty logic, or other mistakes in the method. Finally, generate a final response and method for the user with the Assistant abstract and Reflection analysis as augmentations to the generation\n\n" + + prompt_input = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" + prompt_no_input = "### Instruction:\n{instruction}\n\n### Response:\n" + agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}" response_split = "### Response:" + def __init__(self, prompt_style="instruct"): + self.prompt_style = prompt_style + self.match_prompt_style() + + def match_prompt_style(self): + if self.prompt_style == PromptStyle.instruct.value: + self.prompt_input = self.system_prompt + "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n" + self.prompt_no_input = self.system_no_input_prompt + "### Instruction:\n{instruction}\n\n### Response:\n" + self.agent_label = "### Thought:\n{output}\n\n### Agent Reflection:\n{reflection}\n\n### Final Response:\n{corrected}" + self.response_split = "### Final Response:" + if self.prompt_style == PromptStyle.chat.value: + self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:" + self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:" + self.agent_label = "\nTHOUGHT: {output}\nASSISTANT REFLECTION: {reflection}\nASSISTANT:" + self.response_split = "ASSISTANT:" + def build_prompt( self, instruction: str, @@ -118,13 +156,13 @@ class Conversation: def get_prompt(self) -> Generator[str, None, None]: seps = [self.sep, self.sep2] preamble = self.system + seps[0] + yield preamble for i, (role, message) in enumerate(self.messages): if message: - yield preamble + role + ": " + message + seps[i % 2] + yield (role + ":", " " + message) else: - yield role + ":" - if i == 0: - preamble = "" + logging.warning("role with empty message: " + role) + yield (role + ":", ) def copy(self): return Conversation( @@ -154,7 +192,17 @@ conv_vicuna_v1_1 = Conversation( class ShareGPTPrompter: - def build_prompt(self, source, tokenizer, sequence_len=2048) -> Generator[str, None, None]: + def __init__(self, prompt_style=None): + if prompt_style != PromptStyle.chat.value: + raise Exception(f"unsupported prompt_style for ShareGPTPrompter({prompt_style})") + + # def match_prompt_style(self): + # if self.prompt_style == PromptStyle.chat.value: + # self.prompt_input = self.system_prompt + "USER: {instruction}\n{input}\nASSISTANT:" + # self.prompt_no_input = self.system_no_input_prompt + "USER: {instruction}\nASSISTANT:" + # self.response_split = "ASSISTANT:" + + def build_prompt(self, source, *args, **kwargs) -> Generator[str, None, None]: # ignore the system prompt if provided if source[0]["from"] == "system": source.pop(0) diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 28b6ee072..2ceaa4d99 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -7,11 +7,13 @@ from datasets import ( load_dataset, IterableDataset, Dataset, - concatenate_datasets, + concatenate_datasets, DatasetDict, ) from huggingface_hub import hf_hub_download +from transformers import PreTrainedTokenizerBase from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset +from axolotl.prompt_strategies import load from axolotl.prompt_tokenizers import ( AlpacaPromptTokenizingStrategy, GPTeacherPromptTokenizingStrategy, @@ -35,13 +37,15 @@ from axolotl.prompters import ( ) -def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path): +def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path) -> DatasetDict: + tokenizer_name = tokenizer.__class__.__name__ ds_hash = str( md5( ( str(cfg.sequence_len) + "@" + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets])) + + "|" + tokenizer_name ).encode("utf-8") ).hexdigest() ) @@ -50,8 +54,17 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa if cfg.dataset_prepared_path else Path(default_dataset_prepared_path) / ds_hash ) + dataset = None + try: + if cfg.push_dataset_to_hub: + dataset = load_dataset(f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True) + dataset = dataset["train"] + except: + pass - if any(prepared_ds_path.glob("*")): + if dataset: + ... + elif any(prepared_ds_path.glob("*")): logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...") dataset = load_from_disk(str(prepared_ds_path)) logging.info("Prepared dataset loaded from disk...") @@ -63,7 +76,7 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa ds = None ds_from_hub = False try: - load_dataset(d.path, streaming=True) + load_dataset(d.path, streaming=True, use_auth_token=True) ds_from_hub = True except FileNotFoundError: pass @@ -71,82 +84,88 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa # prefer local dataset, even if hub exists if Path(d.path).exists(): ds: IterableDataset = load_dataset( - "json", data_files=d.path, streaming=True, split=None + "json", data_files=d.path, streaming=False, split=None ) elif ds_from_hub: if d.data_files: - ds = load_dataset(d.path, streaming=True, data_files=d.data_files) + ds = load_dataset(d.path, streaming=False, data_files=d.data_files, use_auth_token=True) else: - ds = load_dataset(d.path, streaming=True) + ds = load_dataset(d.path, streaming=False, use_auth_token=True) else: fp = hf_hub_download( repo_id=d.path, repo_type="dataset", filename=d.data_files ) - ds = load_dataset("json", data_files=fp, streaming=True, split=None) + ds = load_dataset("json", data_files=fp, streaming=False, split=None) if not ds: raise Exception("unhandled dataset load") - - if d.type == "alpaca": + d_type = d.type + d_type_split = d_type.split(":") + d_base_type = d_type_split[0] + d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None + if (ds_strategy := load(d.type, tokenizer, cfg)): + ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) + datasets.append(ds_wrapper) + elif d_base_type == "alpaca": ds_strategy = AlpacaPromptTokenizingStrategy( - AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len + AlpacaPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len ) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) datasets.append(ds_wrapper) - elif d.type == "explainchoice": + elif d_base_type == "explainchoice": ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy( - MultipleChoiceExplainPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len + MultipleChoiceExplainPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len ) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) datasets.append(ds_wrapper) - elif d.type == "concisechoice": + elif d_base_type == "concisechoice": ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy( - MultipleChoiceConcisePrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len + MultipleChoiceConcisePrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len ) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) datasets.append(ds_wrapper) - elif d.type == "summarizetldr": + elif d_base_type == "summarizetldr": ds_strategy = SummarizeTLDRPromptTokenizingStrategy( - SummarizeTLDRPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len + SummarizeTLDRPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len ) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) datasets.append(ds_wrapper) - elif d.type == "jeopardy": + elif d_base_type == "jeopardy": ds_strategy = JeopardyPromptTokenizingStrategy( - JeopardyPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len + JeopardyPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len ) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) datasets.append(ds_wrapper) - elif d.type == "oasst": + elif d_base_type == "oasst": ds_strategy = OpenAssistantPromptTokenizingStrategy( - AlpacaPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len + AlpacaPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len ) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) datasets.append(ds_wrapper) - elif d.type == "gpteacher": + elif d_base_type == "gpteacher": ds_strategy = GPTeacherPromptTokenizingStrategy( - GPTeacherPrompter(), + GPTeacherPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len, ) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) datasets.append(ds_wrapper) - elif d.type == "reflection": + elif d_base_type == "reflection": ds_strategy = AlpacaReflectionPTStrategy( - ReflectAlpacaPrompter(), + ReflectAlpacaPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len, ) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) datasets.append(ds_wrapper) - elif d.type == "sharegpt": + elif d_base_type == "sharegpt": ds_strategy = ShareGPTPromptTokenizingStrategy( - ShareGPTPrompter(), tokenizer, cfg.train_on_inputs, cfg.sequence_len + ShareGPTPrompter(d_prompt_style), tokenizer, cfg.train_on_inputs, cfg.sequence_len ) ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"]) datasets.append(ds_wrapper) - elif d.type == "completion": + elif d_base_type == "completion": ds_strategy = CompletionPromptTokenizingStrategy( CompletionPrompter(), tokenizer, @@ -168,11 +187,16 @@ def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_pa f"Saving merged prepared dataset to disk... {prepared_ds_path}" ) dataset.save_to_disk(prepared_ds_path) + if cfg.push_dataset_to_hub: + logging.info( + f"Saving merged prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}" + ) + dataset.push_to_hub(f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True) return dataset -def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path): +def load_prepare_datasets(tokenizer: PreTrainedTokenizerBase, cfg, default_dataset_prepared_path) -> (Dataset, Dataset): max_packed_sequence_len = ( cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len ) @@ -180,16 +204,19 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path): max_packed_sequence_len, cfg.sequence_len ) # make sure we don't accidentally set it larger than sequence_len + tokenizer_name = tokenizer.__class__.__name__ if cfg.max_packed_sequence_len is not None: # see if we can go ahead and load the stacked dataset - + seed = f"@{str(cfg.seed)}" if cfg.seed else "" ds_hash = str( md5( ( str(cfg.sequence_len) + "@" + str(max_packed_sequence_len) + + seed + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets])) + + "|" + tokenizer_name ).encode("utf-8") ).hexdigest() ) @@ -199,17 +226,38 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path): else Path(default_dataset_prepared_path) / ds_hash ) - if any(prepared_ds_path.glob("*")): + dataset = None + try: + if cfg.push_dataset_to_hub: + logging.info( + f"Checking for packed prepared dataset from hub... {cfg.push_dataset_to_hub}/{ds_hash}" + ) + dataset = load_dataset(f"{cfg.push_dataset_to_hub}/{ds_hash}", use_auth_token=True) + dataset = dataset["train"] + except: + pass + + if dataset: + ... + elif any(prepared_ds_path.glob("*")): logging.info( f"Loading prepared packed dataset from disk at {prepared_ds_path}..." ) dataset = load_from_disk(str(prepared_ds_path)) logging.info("Prepared packed dataset loaded from disk...") + if cfg.push_dataset_to_hub: + logging.info( + f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}" + ) + dataset.push_to_hub(f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True) else: dataset = load_tokenized_prepared_datasets( tokenizer, cfg, default_dataset_prepared_path ) + if cfg.seed: + dataset = dataset.shuffle(seed=cfg.seed) + constant_len_dataset = ConstantLengthDataset( tokenizer, [dataset], @@ -237,6 +285,11 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path): f"Saving packed prepared dataset to disk... {prepared_ds_path}" ) dataset.save_to_disk(prepared_ds_path) + if cfg.push_dataset_to_hub: + logging.info( + f"Saving packed prepared dataset with push_to_hub... {cfg.push_dataset_to_hub}/{ds_hash}" + ) + dataset.push_to_hub(f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True) else: dataset = load_tokenized_prepared_datasets( tokenizer, cfg, default_dataset_prepared_path diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index d93d859b7..934f2f74c 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -126,6 +126,32 @@ def load_model( torch_dtype=torch_dtype, device_map=cfg.device_map, ) + # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention: + # This is a WIP, still an issue with the backward pass + # RuntimeError: grad can be implicitly created only for scalar outputs + # TODO: try config.sequence_parallel = False + # # https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/tests/models/test_gpt_neox.py#L12 + # # https://github.com/HazyResearch/flash-attention/tree/main/training#model-components + # # add `**kwargs` to https://github.com/HazyResearch/flash-attention/blob/40a25c8ee7465cf547b929cfa2937034e37bfce9/flash_attn/models/gpt.py#L442 + # from flash_attn.utils.pretrained import state_dict_from_pretrained + # from flash_attn.models.gpt import GPTLMHeadModel + # from flash_attn.models.gpt_neox import remap_state_dict_hf_gpt_neox, gpt_neox_config_to_gpt2_config + # from transformers import GPTNeoXConfig + # config = gpt_neox_config_to_gpt2_config(GPTNeoXConfig.from_pretrained(base_model)) + # config.use_flash_attn = True + # config.fused_bias_fc = True + # config.fused_mlp = True # GPT-NeoX-20B uses "gelu_fast" + # config.activation_function = "gelu_fast" + # config.fused_dropout_add_ln = True + # # config.residual_in_fp32 = True + # + # model: GPTLMHeadModel = GPTLMHeadModel.from_pretrained( + # base_model, + # config, + # dtype=torch_dtype, + # device=cfg.device, + # ) + # model.train() # sets to train instead of eval mode elif model_type: model = getattr(transformers, model_type).from_pretrained( base_model, @@ -194,7 +220,7 @@ def load_model( for k, v in cfg.special_tokens.items(): tokenizer.add_special_tokens({k: v}) if cfg.tokens: - tokenizer.add_tokens(cfg.tokens) + tokenizer.add_tokens(list(cfg.tokens)) embeddings_len = math.ceil(len(tokenizer) / 32) * 32 model.resize_token_embeddings(embeddings_len) @@ -266,7 +292,8 @@ def load_llama_adapter(model, cfg): task_type="CAUSAL_LM", ) - if cfg.peft_model_dir: + if cfg.lora_model_dir: + logging.info("Loading pretained LORA") model = PeftModel.from_pretrained( model, cfg.lora_model_dir, @@ -307,7 +334,7 @@ def load_lora(model, cfg): model, cfg.lora_model_dir, device_map=cfg.device_map, - torch_dtype=torch.float16, + # torch_dtype=torch.float16, ) else: model = get_peft_model(model, lora_config) diff --git a/src/axolotl/utils/trainer.py b/src/axolotl/utils/trainer.py index cd9f94229..4336f740c 100644 --- a/src/axolotl/utils/trainer.py +++ b/src/axolotl/utils/trainer.py @@ -9,13 +9,31 @@ import torch.cuda import transformers from torch import nn from torch.optim.lr_scheduler import OneCycleLR -from transformers import EarlyStoppingCallback +from transformers import EarlyStoppingCallback, Trainer from transformers.trainer_pt_utils import get_parameter_names from axolotl.utils.schedulers import InterpolatingLogScheduler from axolotl.utils.callbacks import SavePeftModelCallback +class OneCycleLRSchedulerTrainer(Trainer): + def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None): + optimizer=self.optimizer if optimizer is None else optimizer + num_warmup_steps=self.args.get_warmup_steps(num_training_steps) + num_training_steps=num_training_steps + pct_start = num_warmup_steps / num_training_steps + + lr_scheduler = OneCycleLR( + optimizer, + max_lr=self.args.learning_rate, + total_steps=num_training_steps, + pct_start=pct_start, + div_factor=6, + ) + + return lr_scheduler + + def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): total_num_steps = int( math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) @@ -119,6 +137,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): cfg.optimizer == "adamw_bnb_8bit" and not cfg.load_4bit and not "deepspeed" in training_arguments_kwargs + and not cfg.fsdp ): decay_parameters = get_parameter_names(model, [nn.LayerNorm]) decay_parameters = [name for name in decay_parameters if "bias" not in name] @@ -157,7 +176,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): cfg.learning_rate, total_steps=total_num_steps, epochs=cfg.num_epochs, - div_factor=10, + div_factor=cfg.lr_div_factor if cfg.lr_div_factor else 6, **lr_scheduler_kwargs, ) elif cfg.lr_scheduler == "log_sweep": @@ -182,7 +201,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): cfg.early_stopping_patience, ) callbacks.append(early_stop_cb) - + if cfg.local_rank == 0 and cfg.adapter == 'lora': # only save in rank 0 callbacks.append(SavePeftModelCallback) @@ -194,7 +213,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): else: data_collator_kwargs["pad_to_multiple_of"] = 8 - trainer = transformers.Trainer( + trainer_cls = OneCycleLRSchedulerTrainer if cfg.lr_scheduler == "one_cycle" and cfg.fsdp else transformers.Trainer + trainer = trainer_cls( model=model, train_dataset=train_dataset, eval_dataset=eval_dataset,