From 27a88f37cd1d8ffedb10d63f9072c446b80f3897 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 6 Jan 2025 11:17:14 -0500 Subject: [PATCH] do lm_eval in cloud too --- src/axolotl/cli/cloud/__init__.py | 12 +++ src/axolotl/cli/cloud/modal_.py | 22 +++++ src/axolotl/cli/main.py | 4 + src/axolotl/integrations/lm_eval/__init__.py | 39 ++++---- src/axolotl/integrations/lm_eval/cli.py | 99 ++++++++++++++++++++ 5 files changed, 153 insertions(+), 23 deletions(-) create mode 100644 src/axolotl/integrations/lm_eval/cli.py diff --git a/src/axolotl/cli/cloud/__init__.py b/src/axolotl/cli/cloud/__init__.py index ef3cbc5c9..72ce68861 100644 --- a/src/axolotl/cli/cloud/__init__.py +++ b/src/axolotl/cli/cloud/__init__.py @@ -42,3 +42,15 @@ def do_cli_train( with open(config, "r", encoding="utf-8") as file: config_yaml = file.read() cloud.train(config_yaml, accelerate=accelerate) + + +def do_cli_lm_eval( + cloud_config: Union[Path, str], + config: Union[Path, str] = Path("examples/"), +) -> None: + print_axolotl_text_art() + cloud_cfg = load_cloud_cfg(cloud_config) + cloud = ModalCloud(cloud_cfg) + with open(config, "r", encoding="utf-8") as file: + config_yaml = file.read() + cloud.lm_eval(config_yaml) diff --git a/src/axolotl/cli/cloud/modal_.py b/src/axolotl/cli/cloud/modal_.py index 724e8cb20..82d443d5b 100644 --- a/src/axolotl/cli/cloud/modal_.py +++ b/src/axolotl/cli/cloud/modal_.py @@ -197,6 +197,15 @@ class ModalCloud(Cloud): volumes={k: v[0] for k, v in self.volumes.items()}, ) + def lm_eval(self, config_yaml: str): + modal_fn = self.get_train_env()(_lm_eval) + with modal.enable_output(): + with self.app.run(detach=True): + modal_fn.remote( + config_yaml, + volumes={k: v[0] for k, v in self.volumes.items()}, + ) + def _preprocess(config_yaml: str, volumes=None): Path("/workspace/artifacts/axolotl").mkdir(parents=True, exist_ok=True) @@ -227,3 +236,16 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None): run_folder, volumes, ) + + +def 
_lm_eval(config_yaml: str, volumes=None): + with open( + "/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8" + ) as f_out: + f_out.write(config_yaml) + run_folder = "/workspace/artifacts/axolotl" + run_cmd( + "axolotl lm_eval /workspace/artifacts/axolotl/config.yaml", + run_folder, + volumes, + ) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index f3950f4b2..2e23aa440 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -13,6 +13,7 @@ from axolotl.cli.utils import ( fetch_from_github, ) from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs +from axolotl.integrations.lm_eval.cli import lm_eval from axolotl.utils import set_pytorch_cuda_alloc_conf from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig @@ -268,6 +269,9 @@ def fetch(directory: str, dest: Optional[str]): fetch_from_github(f"{directory}/", dest) +cli.add_command(lm_eval) + + def main(): cli() diff --git a/src/axolotl/integrations/lm_eval/__init__.py b/src/axolotl/integrations/lm_eval/__init__.py index f1daa2000..87d2fc3a9 100644 --- a/src/axolotl/integrations/lm_eval/__init__.py +++ b/src/axolotl/integrations/lm_eval/__init__.py @@ -2,9 +2,9 @@ Module for the Plugin for LM Eval Harness """ import subprocess # nosec -from datetime import datetime from axolotl.integrations.base import BasePlugin +from axolotl.integrations.lm_eval.cli import build_lm_eval_command from .args import LMEvalArgs # pylint: disable=unused-import. 
# noqa: F401 @@ -18,25 +18,18 @@ class LMEvalPlugin(BasePlugin): return "axolotl.integrations.lm_eval.LMEvalArgs" def post_train_unload(self, cfg): - tasks = ",".join(cfg.lm_eval_tasks) - fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else "" - dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16" - output_path = cfg.output_dir - output_path += "" if cfg.output_dir.endswith("/") else "/" - output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S") - subprocess.run( # nosec - [ - "lm_eval", - "--model", - "hf", - "--model_args", - f"pretrained={cfg.output_dir}{fa2}{dtype}", - "--tasks", - tasks, - "--batch_size", - str(cfg.lm_eval_batch_size), - "--output_path", - output_path, - ], - check=True, - ) + # pylint: disable=duplicate-code + for lm_eval_args in build_lm_eval_command( + cfg.lm_eval_tasks, + bfloat16=cfg.bfloat16 or cfg.bf16, + flash_attention=cfg.flash_attention, + output_dir=cfg.output_dir, + batch_size=cfg.lm_eval_batch_size, + wandb_project=cfg.wandb_project, + wandb_entity=cfg.wandb_entity, + hub_model_id=cfg.hub_model_id, + ): + subprocess.run( # nosec + lm_eval_args, + check=True, + ) diff --git a/src/axolotl/integrations/lm_eval/cli.py b/src/axolotl/integrations/lm_eval/cli.py new file mode 100644 index 000000000..352143a37 --- /dev/null +++ b/src/axolotl/integrations/lm_eval/cli.py @@ -0,0 +1,99 @@ +""" +axolotl CLI for running lm_eval tasks +""" +import subprocess # nosec +from collections import defaultdict +from datetime import datetime +from typing import Optional + +import click + +from axolotl.cli import load_cfg + + +def build_lm_eval_command( + tasks: list[str], + bfloat16=True, + flash_attention=False, + output_dir="./", + batch_size=8, + wandb_project=None, + wandb_entity=None, + hub_model_id=None, +): + tasks_by_num_fewshot: dict[str, list] = defaultdict(list) + for task in tasks: + num_fewshot = "-1" + task_parts = task.split(":") + task_name = task_parts[0] + if len(task_parts) == 2: + 
task_name, num_fewshot = task_parts
+        tasks_by_num_fewshot[str(num_fewshot)].append(task_name)
+
+    for num_fewshot, tasks_list in tasks_by_num_fewshot.items():
+        tasks_str = ",".join(tasks_list)
+        num_fewshot_val = num_fewshot if num_fewshot != "-1" else None
+        pretrained = "pretrained="
+        pretrained += hub_model_id if hub_model_id else output_dir
+        fa2 = ",attn_implementation=flash_attention_2" if flash_attention else ""
+        dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16"
+        output_path = output_dir
+        output_path += "" if output_dir.endswith("/") else "/"
+        output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
+        lm_eval_args = [
+            "lm_eval",
+            "--model",
+            "hf",
+            "--model_args",
+            f"{pretrained}{fa2}{dtype}",
+            "--tasks",
+            tasks_str,
+            "--batch_size",
+            str(batch_size),
+            "--output_path",
+            output_path,
+        ]
+        wandb_args = []
+        if wandb_project:
+            wandb_args.append(f"project={wandb_project}")
+        if wandb_entity:
+            wandb_args.append(f"entity={wandb_entity}")
+        if wandb_args:
+            lm_eval_args.append("--wandb_args")
+            lm_eval_args.append(",".join(wandb_args))
+        if num_fewshot_val:
+            lm_eval_args.append("--num_fewshot")
+            lm_eval_args.append(str(num_fewshot_val))
+
+        yield lm_eval_args
+
+
+@click.command()
+@click.argument("config", type=click.Path(exists=True, path_type=str))
+@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
+def lm_eval(config: str, cloud: Optional[str] = None):
+    """
+    use lm eval to evaluate a trained language model
+    """
+
+    if cloud:
+        from axolotl.cli.cloud import do_cli_lm_eval
+
+        do_cli_lm_eval(cloud_config=cloud, config=config)
+    else:
+        cfg = load_cfg(config)
+        # pylint: disable=duplicate-code
+        for lm_eval_args in build_lm_eval_command(
+            cfg.lm_eval_tasks,
+            bfloat16=cfg.bfloat16 or cfg.bf16,
+            flash_attention=cfg.flash_attention,
+            output_dir=cfg.output_dir,
+            batch_size=cfg.lm_eval_batch_size,
+            wandb_project=cfg.wandb_project,
+            wandb_entity=cfg.wandb_entity,
+            
hub_model_id=cfg.hub_model_id, + ): + subprocess.run( # nosec + lm_eval_args, + check=True, + )