diff --git a/requirements.txt b/requirements.txt index 41bfdfbeb..4323c76ce 100644 --- a/requirements.txt +++ b/requirements.txt @@ -46,3 +46,9 @@ gcsfs>=2024.5.0 trl==0.9.6 zstandard==0.22.0 fastcore + +# lm eval harness +lm_eval==0.4.4 +langdetect==1.0.9 +immutabledict==4.2.0 +antlr4-python3-runtime==4.13.2 diff --git a/src/axolotl/cli/train.py b/src/axolotl/cli/train.py index 050f18a05..16d66a82f 100644 --- a/src/axolotl/cli/train.py +++ b/src/axolotl/cli/train.py @@ -3,13 +3,11 @@ CLI to run training on a model """ import logging from pathlib import Path -from typing import Tuple, Union +from typing import Union import fire from dotenv import load_dotenv from transformers.hf_argparser import HfArgumentParser -from transformers.modeling_utils import PreTrainedModel -from transformers.tokenization_utils import PreTrainedTokenizer from axolotl.cli import ( check_accelerate_default_config, @@ -20,6 +18,7 @@ from axolotl.cli import ( print_axolotl_text_art, ) from axolotl.common.cli import TrainerCliArgs +from axolotl.integrations.base import PluginManager from axolotl.prompt_strategies.sharegpt import ( register_chatml_template, register_llama3_template, @@ -39,7 +38,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): return do_train(parsed_cfg, parsed_cli_args) -def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: +def do_train(cfg, cli_args) -> None: print_axolotl_text_art() check_accelerate_default_config() check_user_token() @@ -64,7 +63,13 @@ def do_train(cfg, cli_args) -> Tuple[PreTrainedModel, PreTrainedTokenizer]: else: dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) - return train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + model, tokenizer = train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) + plugin_manager = PluginManager.get_instance() + + del model + del tokenizer + + plugin_manager.post_train_unload(cfg) if __name__ == "__main__": diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index d26eed90f..e2bd79bc4 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -159,6 +159,29 @@ class BasePlugin: List[callable]: A list of callback functions to be added to the TrainingArgs """ + def post_train(self, cfg, model): + """ + Performs actions after training is complete. + + Parameters: + cfg (dict): The axolotl configuration + model (object): The loaded model. + + Returns: + None + """ + + def post_train_unload(self, cfg): + """ + Performs actions after training is complete and the model is unloaded. + + Parameters: + cfg (dict): The configuration for the plugin. + + Returns: + None + """ + def load_plugin(plugin_name: str) -> BasePlugin: """ @@ -381,3 +404,17 @@ class PluginManager: for plugin in self.plugins: callbacks.extend(plugin.add_callbacks_post_trainer(cfg, trainer)) return callbacks + + def post_train_unload(self, cfg): + """ + Calls the post_train_unload method of all registered plugins. + + Parameters: + cfg (dict): The configuration for the plugins. + model (object): The loaded model. + + Returns: + None + """ + for plugin in self.plugins: + plugin.post_train_unload(cfg) diff --git a/src/axolotl/integrations/lm_eval/README.md b/src/axolotl/integrations/lm_eval/README.md new file mode 100644 index 000000000..3724c49cc --- /dev/null +++ b/src/axolotl/integrations/lm_eval/README.md @@ -0,0 +1,13 @@ +# LM Eval Harness + +### Usage + +```yaml +plugins: + - axolotl.integrations.lm_eval.LMEvalPlugin + +lm_eval_tasks: + - gsm8k + - hellaswag + - arc_easy +``` diff --git a/src/axolotl/integrations/lm_eval/__init__.py b/src/axolotl/integrations/lm_eval/__init__.py new file mode 100644 index 000000000..f1daa2000 --- /dev/null +++ b/src/axolotl/integrations/lm_eval/__init__.py @@ -0,0 +1,42 @@ +""" +Module for the Plugin for LM Eval Harness +""" +import subprocess # nosec +from datetime import datetime + +from axolotl.integrations.base import BasePlugin + +from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401 + + +class LMEvalPlugin(BasePlugin): + """ + Plugin for LM Evaluation Harness integraton with Axolotl. + """ + + def get_input_args(self): + return "axolotl.integrations.lm_eval.LMEvalArgs" + + def post_train_unload(self, cfg): + tasks = ",".join(cfg.lm_eval_tasks) + fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else "" + dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16" + output_path = cfg.output_dir + output_path += "" if cfg.output_dir.endswith("/") else "/" + output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S") + subprocess.run( # nosec + [ + "lm_eval", + "--model", + "hf", + "--model_args", + f"pretrained={cfg.output_dir}{fa2}{dtype}", + "--tasks", + tasks, + "--batch_size", + str(cfg.lm_eval_batch_size), + "--output_path", + output_path, + ], + check=True, + ) diff --git a/src/axolotl/integrations/lm_eval/args.py b/src/axolotl/integrations/lm_eval/args.py new file mode 100644 index 000000000..f58e6a6e3 --- /dev/null +++ b/src/axolotl/integrations/lm_eval/args.py @@ -0,0 +1,15 @@ +""" +Module for handling lm eval harness input arguments. +""" +from typing import List, Optional + +from pydantic import BaseModel + + +class LMEvalArgs(BaseModel): + """ + Input args for lm eval harness + """ + + lm_eval_tasks: List[str] = [] + lm_eval_batch_size: Optional[int] = 8 diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 76748191b..47796add6 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -980,6 +980,26 @@ class AxolotlInputConfig( "evaluation_strategy must be empty or set to `steps` when used with evals_per_epoch." ) + if data.get("do_bench_eval") and not ( + data.get("evals_per_epoch") or data.get("eval_steps") + ): + raise ValueError( + "do_bench_eval requires evals_per_epoch or eval_steps to be set." + ) + return data + + @model_validator(mode="before") + @classmethod + def check_test_datasets_bench(cls, data): + if ( + data.get("do_bench_eval") + and not data.get("test_datasets") + and not data.get("val_set_size") + ): + LOG.warning( + "`do_bench_eval` needs a test dataset to run evals, adding an empty test_dataset." + ) + data["test_datasets"] = [{"path": "axolotl-ai-co/empty-test-ds"}] return data @model_validator(mode="before")