diff --git a/requirements.txt b/requirements.txt index 41bfdfbeb..4323c76ce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -46,3 +46,9 @@ gcsfs>=2024.5.0 trl==0.9.6 zstandard==0.22.0 fastcore + +# lm eval harness +lm_eval==0.4.4 +langdetect==1.0.9 +immutabledict==4.2.0 +antlr4-python3-runtime==4.13.2 diff --git a/src/axolotl/cli/__init__.py b/src/axolotl/cli/__init__.py index c757eca42..fe6df8694 100644 --- a/src/axolotl/cli/__init__.py +++ b/src/axolotl/cli/__init__.py @@ -55,8 +55,22 @@ LOG = logging.getLogger("axolotl.scripts") os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1" +AXOLOTL_LOGO = """ + #@@ #@@ @@# @@# + @@ @@ @@ @@ =@@# @@ #@ =@@#. + @@ #@@@@@@@@@ @@ #@#@= @@ #@ .=@@ + #@@@@@@@@@@@@@@@@@ =@# @# ##= ## =####=+ @@ =#####+ =#@@###. @@ + @@@@@@@@@@/ +@@/ +@@ #@ =@= #@= @@ =@#+ +#@# @@ =@#+ +#@# #@. @@ + @@@@@@@@@@ ##@@ ##@@ =@# @# =@# @# @@ @@ @@ @@ #@ #@ @@ + @@@@@@@@@@@@@@@@@@@@ #@=+++#@= =@@# @@ @@ @@ @@ #@ #@ @@ + =@#=====@@ =@# @# @@ @@ @@ @@ #@ #@ @@ + @@@@@@@@@@@@@@@@ @@@@ #@ #@= #@= +@@ #@# =@# @@. =@# =@# #@. 
@@ + =@# @# #@= #@ =#@@@@#= +#@@= +#@@@@#= .##@@+ @@ + @@@@ @@@@@@@@@@@@@@@@ +""" -def print_axolotl_text_art(suffix=None): + +def print_legacy_axolotl_text_art(suffix=None): font = "nancyj" ascii_text = " axolotl" if suffix: @@ -69,6 +83,13 @@ def print_axolotl_text_art(suffix=None): print_dep_versions() +def print_axolotl_text_art( + **kwargs, # pylint: disable=unused-argument +): + if is_main_process(): + print(AXOLOTL_LOGO) + + def print_dep_versions(): packages = ["accelerate", "peft", "transformers", "trl", "torch", "bitsandbytes"] max_len = max(len(pkg) for pkg in packages) diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 050f18a05..16d66a82f 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -3,13 +3,11 @@ CLI to run training on a model """ import logging from pathlib import Path -from typing import Tuple, Union +from typing import Union import fire from dotenv import load_dotenv from transformers.hf_argparser import HfArgumentParser -from transformers.modeling_utils import PreTrainedModel -from transformers.tokenization_utils import PreTrainedTokenizer from axolotl.cli import ( check_accelerate_default_config, @@ -20,6 +18,7 @@ from axolotl.cli import ( print_axolotl_text_art, ) from axolotl.common.cli import TrainerCliArgs +from axolotl.integrations.base import PluginManager from axolotl.prompt_strategies.sharegpt import ( register_chatml_template, register_llama3_template, @@ -39,7 +38,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): return do_train(parsed_cfg, parsed_cli_args) -def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: +def do_train(cfg, cli_args) -> None: print_axolotl_text_art() check_accelerate_default_config() check_user_token() @@ -64,7 +63,13 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: else: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - return train(cfg=cfg, cli_args=cli_args, 
dataset_meta=dataset_meta) + model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + plugin_manager = PluginManager.get_instance() + + del model + del tokenizer + + plugin_manager.post_train_unload(cfg) if __name__ == "__main__": diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index d26eed90f..e2bd79bc4 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -159,6 +159,29 @@ class BasePlugin: List[callable]: A list of callback functions to be added to the TrainingArgs """ + def post_train(self, cfg, model): + """ + Performs actions after training is complete. + + Parameters: + cfg (dict): The axolotl configuration + model (object): The loaded model. + + Returns: + None + """ + + def post_train_unload(self, cfg): + """ + Performs actions after training is complete and the model is unloaded. + + Parameters: + cfg (dict): The configuration for the plugin. + + Returns: + None + """ + def load_plugin(plugin_name: str) -> BasePlugin: """ @@ -381,3 +404,16 @@ class PluginManager: for plugin in self.plugins: callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer)) return callbacks + + def post_train_unload(self, cfg): + """ + Calls the post_train_unload method of all registered plugins. + + Parameters: + cfg (dict): The configuration for the plugins. 
+ + Returns: + None + """ + for plugin in self.plugins: + plugin.post_train_unload(cfg) diff --git a/src/axolotl/integrations/lm_eval/README.md b/src/axolotl/integrations/lm_eval/README.md new file mode 100644 index 000000000..3724c49cc --- /dev/null +++ b/src/axolotl/integrations/lm_eval/README.md @@ -0,0 +1,13 @@ +# LM Eval Harness + +### Usage + +```yaml +plugins: + - axolotl.integrations.lm_eval.LMEvalPlugin + +lm_eval_tasks: + - gsm8k + - hellaswag + - arc_easy +``` diff --git a/src/axolotl/integrations/lm_eval/__init__.py b/src/axolotl/integrations/lm_eval/__init__.py new file mode 100644 index 000000000..f1daa2000 --- /dev/null +++ b/src/axolotl/integrations/lm_eval/__init__.py @@ -0,0 +1,42 @@ +""" +Module for the Plugin for LM Eval Harness +""" +import subprocess # nosec +from datetime import datetime + +from axolotl.integrations.base import BasePlugin + +from .args import LMEvalArgs # pylint: disable=unused-import # noqa: F401 + + +class LMEvalPlugin(BasePlugin): + """ + Plugin for LM Evaluation Harness integration with Axolotl. 
+ """ + + def get_input_args(self): + return "axolotl.integrations.lm_eval.LMEvalArgs" + + def post_train_unload(self, cfg): + tasks = ",".join(cfg.lm_eval_tasks) + fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else "" + dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16" + output_path = cfg.output_dir + output_path += "" if cfg.output_dir.endswith("/") else "/" + output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S") + subprocess.run( # nosec + [ + "lm_eval", + "--model", + "hf", + "--model_args", + f"pretrained={cfg.output_dir}{fa2}{dtype}", + "--tasks", + tasks, + "--batch_size", + str(cfg.lm_eval_batch_size), + "--output_path", + output_path, + ], + check=True, + ) diff --git a/src/axolotl/integrations/lm_eval/args.py b/src/axolotl/integrations/lm_eval/args.py new file mode 100644 index 000000000..f58e6a6e3 --- /dev/null +++ b/src/axolotl/integrations/lm_eval/args.py @@ -0,0 +1,15 @@ +""" +Module for handling lm eval harness input arguments. 
+""" +from typing import List, Optional + +from pydantic import BaseModel + + +class LMEvalArgs(BaseModel): + """ + Input args for lm eval harness + """ + + lm_eval_tasks: List[str] = [] + lm_eval_batch_size: Optional[int] = 8 diff --git a/src/axolotl/utils/chat_templates.py b/src/axolotl/utils/chat_templates.py index 620098ae0..e3b92ef25 100644 --- a/src/axolotl/utils/chat_templates.py +++ b/src/axolotl/utils/chat_templates.py @@ -30,7 +30,9 @@ _CHAT_TEMPLATES = { CHAT_TEMPLATES = { "alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}", - "inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral. + "mistral_v1": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ ' [INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # Mistral 7B V1, Mistral 7B V2, Mixtral 8x7B V1... 
+ "mistral_v2v3": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ ' ' + message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3: Mistral 7B V3, Small, Large... + "mistral_v3_tekken": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST]' + message['content'] + '[/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # V3-Tekken: Nemo, Pixtral... 
"chatml": "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}", "gemma": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{{ raise_exception('System role not supported') }}{% endif %}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if (message['role'] == 'assistant') %}{% set role = 'model' %}{% else %}{% set role = message['role'] %}{% endif %}{{ '' + role + '\n' + message['content'] | trim + '\n' }}{% endfor %}{% if add_generation_prompt %}{{'model\n'}}{% endif %}", "cohere": "{{ bos_token }}{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% elif false == true %}{% set loop_messages = messages %}{% set system_message = 'You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by providing thorough responses. You are trained by Cohere.' 
%}{% else %}{% set loop_messages = messages %}{% set system_message = false %}{% endif %}{% if system_message != false %}{{ '<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>' + system_message + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% for message in loop_messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% set content = message['content'] %}{% if message['role'] == 'user' %}{{ '<|START_OF_TURN_TOKEN|><|USER_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% elif message['role'] == 'assistant' %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' + content.strip() + '<|END_OF_TURN_TOKEN|>' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ '<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>' }}{% endif %}", diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 7290b948d..1a8154159 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -1051,6 +1051,26 @@ class AxolotlInputConfig( "evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch." ) + if data.get("do_bench_eval") and not ( + data.get("evals_per_epoch") or data.get("eval_steps") + ): + raise ValueError( + "do_bench_eval requires evals_per_epoch or eval_steps to be set." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_test_datasets_bench(cls, data): + if ( + data.get("do_bench_eval") + and not data.get("test_datasets") + and not data.get("val_set_size") + ): + LOG.warning( + "`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset." + ) + data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}] return data @model_validator(mode="before")