do lm_eval in cloud too
This commit is contained in:
@@ -42,3 +42,15 @@ def do_cli_train(
|
|||||||
with open(config, "r", encoding="utf-8") as file:
|
with open(config, "r", encoding="utf-8") as file:
|
||||||
config_yaml = file.read()
|
config_yaml = file.read()
|
||||||
cloud.train(config_yaml, accelerate=accelerate)
|
cloud.train(config_yaml, accelerate=accelerate)
|
||||||
|
|
||||||
|
|
||||||
|
def do_cli_lm_eval(
|
||||||
|
cloud_config: Union[Path, str],
|
||||||
|
config: Union[Path, str] = Path("examples/"),
|
||||||
|
) -> None:
|
||||||
|
print_axolotl_text_art()
|
||||||
|
cloud_cfg = load_cloud_cfg(cloud_config)
|
||||||
|
cloud = ModalCloud(cloud_cfg)
|
||||||
|
with open(config, "r", encoding="utf-8") as file:
|
||||||
|
config_yaml = file.read()
|
||||||
|
cloud.lm_eval(config_yaml)
|
||||||
|
|||||||
@@ -197,6 +197,15 @@ class ModalCloud(Cloud):
|
|||||||
volumes={k: v[0] for k, v in self.volumes.items()},
|
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def lm_eval(self, config_yaml: str):
|
||||||
|
modal_fn = self.get_train_env()(_lm_eval)
|
||||||
|
with modal.enable_output():
|
||||||
|
with self.app.run(detach=True):
|
||||||
|
modal_fn.remote(
|
||||||
|
config_yaml,
|
||||||
|
volumes={k: v[0] for k, v in self.volumes.items()},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _preprocess(config_yaml: str, volumes=None):
|
def _preprocess(config_yaml: str, volumes=None):
|
||||||
Path("/workspace/artifacts/axolotl").mkdir(parents=True, exist_ok=True)
|
Path("/workspace/artifacts/axolotl").mkdir(parents=True, exist_ok=True)
|
||||||
@@ -227,3 +236,16 @@ def _train(config_yaml: str, accelerate: bool = True, volumes=None):
|
|||||||
run_folder,
|
run_folder,
|
||||||
volumes,
|
volumes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _lm_eval(config_yaml: str, volumes=None):
|
||||||
|
with open(
|
||||||
|
"/workspace/artifacts/axolotl/config.yaml", "w", encoding="utf-8"
|
||||||
|
) as f_out:
|
||||||
|
f_out.write(config_yaml)
|
||||||
|
run_folder = "/workspace/artifacts/axolotl"
|
||||||
|
run_cmd(
|
||||||
|
"axolotl lm_eval /workspace/artifacts/axolotl/config.yaml",
|
||||||
|
run_folder,
|
||||||
|
volumes,
|
||||||
|
)
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from axolotl.cli.utils import (
|
|||||||
fetch_from_github,
|
fetch_from_github,
|
||||||
)
|
)
|
||||||
from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
||||||
|
from axolotl.integrations.lm_eval.cli import lm_eval
|
||||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||||
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
from axolotl.utils.config.models.input.v0_4_1 import AxolotlInputConfig
|
||||||
|
|
||||||
@@ -268,6 +269,9 @@ def fetch(directory: str, dest: Optional[str]):
|
|||||||
fetch_from_github(f"{directory}/", dest)
|
fetch_from_github(f"{directory}/", dest)
|
||||||
|
|
||||||
|
|
||||||
|
cli.add_command(lm_eval)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
cli()
|
cli()
|
||||||
|
|
||||||
|
|||||||
@@ -2,9 +2,9 @@
|
|||||||
Module for the Plugin for LM Eval Harness
|
Module for the Plugin for LM Eval Harness
|
||||||
"""
|
"""
|
||||||
import subprocess # nosec
|
import subprocess # nosec
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from axolotl.integrations.base import BasePlugin
|
from axolotl.integrations.base import BasePlugin
|
||||||
|
from axolotl.integrations.lm_eval.cli import build_lm_eval_command
|
||||||
|
|
||||||
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
|
from .args import LMEvalArgs # pylint: disable=unused-import. # noqa: F401
|
||||||
|
|
||||||
@@ -18,25 +18,18 @@ class LMEvalPlugin(BasePlugin):
|
|||||||
return "axolotl.integrations.lm_eval.LMEvalArgs"
|
return "axolotl.integrations.lm_eval.LMEvalArgs"
|
||||||
|
|
||||||
def post_train_unload(self, cfg):
|
def post_train_unload(self, cfg):
|
||||||
tasks = ",".join(cfg.lm_eval_tasks)
|
# pylint: disable=duplicate-code
|
||||||
fa2 = ",attn_implementation=flash_attention_2" if cfg.flash_attention else ""
|
for lm_eval_args in build_lm_eval_command(
|
||||||
dtype = ",dtype=bfloat16" if cfg.bf16 else ",dtype=float16"
|
cfg.lm_eval_tasks,
|
||||||
output_path = cfg.output_dir
|
bfloat16=cfg.bfloat16 or cfg.bf16,
|
||||||
output_path += "" if cfg.output_dir.endswith("/") else "/"
|
flash_attention=cfg.flash_attention,
|
||||||
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
|
output_dir=cfg.output_dir,
|
||||||
subprocess.run( # nosec
|
batch_size=cfg.lm_eval_batch_size,
|
||||||
[
|
wandb_project=cfg.wandb_project,
|
||||||
"lm_eval",
|
wandb_entity=cfg.wandb_entity,
|
||||||
"--model",
|
hub_model_id=cfg.hub_model_id,
|
||||||
"hf",
|
):
|
||||||
"--model_args",
|
subprocess.run( # nosec
|
||||||
f"pretrained={cfg.output_dir}{fa2}{dtype}",
|
lm_eval_args,
|
||||||
"--tasks",
|
check=True,
|
||||||
tasks,
|
)
|
||||||
"--batch_size",
|
|
||||||
str(cfg.lm_eval_batch_size),
|
|
||||||
"--output_path",
|
|
||||||
output_path,
|
|
||||||
],
|
|
||||||
check=True,
|
|
||||||
)
|
|
||||||
|
|||||||
99
src/axolotl/integrations/lm_eval/cli.py
Normal file
99
src/axolotl/integrations/lm_eval/cli.py
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
"""
|
||||||
|
axolotl CLI for running lm_eval tasks
|
||||||
|
"""
|
||||||
|
import subprocess # nosec
|
||||||
|
from collections import defaultdict
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import click
|
||||||
|
|
||||||
|
from axolotl.cli import load_cfg
|
||||||
|
|
||||||
|
|
||||||
|
def build_lm_eval_command(
|
||||||
|
tasks: list[str],
|
||||||
|
bfloat16=True,
|
||||||
|
flash_attention=False,
|
||||||
|
output_dir="./",
|
||||||
|
batch_size=8,
|
||||||
|
wandb_project=None,
|
||||||
|
wandb_entity=None,
|
||||||
|
hub_model_id=None,
|
||||||
|
):
|
||||||
|
tasks_by_num_fewshot: dict[str, list] = defaultdict(list)
|
||||||
|
for task in tasks:
|
||||||
|
num_fewshot = "-1"
|
||||||
|
task_parts = task.split(":")
|
||||||
|
task_name = task_parts[0]
|
||||||
|
if len(task_parts) == 2:
|
||||||
|
task_name, num_fewshot = task_parts
|
||||||
|
tasks_by_num_fewshot[str(num_fewshot)].append(task_name)
|
||||||
|
|
||||||
|
for num_fewshot, tasks_list in tasks_by_num_fewshot.items():
|
||||||
|
tasks_str = ",".join(tasks_list)
|
||||||
|
num_fewshot_val = num_fewshot if num_fewshot != "-1" else None
|
||||||
|
pretrained = "pretrained="
|
||||||
|
pretrained += hub_model_id if hub_model_id else output_dir
|
||||||
|
fa2 = ",attn_implementation=flash_attention_2" if flash_attention else ""
|
||||||
|
dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16"
|
||||||
|
output_path = output_dir
|
||||||
|
output_path += "" if output_dir.endswith("/") else "/"
|
||||||
|
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
|
lm_eval_args = [
|
||||||
|
"lm_eval",
|
||||||
|
"--model",
|
||||||
|
"hf",
|
||||||
|
"--model_args",
|
||||||
|
f"{pretrained}{fa2}{dtype}",
|
||||||
|
"--tasks",
|
||||||
|
tasks_str,
|
||||||
|
"--batch_size",
|
||||||
|
str(batch_size),
|
||||||
|
"--output_path",
|
||||||
|
output_path,
|
||||||
|
]
|
||||||
|
wandb_args = []
|
||||||
|
if wandb_project:
|
||||||
|
wandb_args.append(f"project={wandb_project}")
|
||||||
|
if wandb_entity:
|
||||||
|
wandb_args.append(f"entity={wandb_entity}")
|
||||||
|
if wandb_args:
|
||||||
|
lm_eval_args.append("--wandb_args")
|
||||||
|
lm_eval_args.extend(",".join(wandb_args))
|
||||||
|
if num_fewshot_val:
|
||||||
|
lm_eval_args.append("--num_fewshot")
|
||||||
|
lm_eval_args.append(str(num_fewshot_val))
|
||||||
|
|
||||||
|
yield lm_eval_args
|
||||||
|
|
||||||
|
|
||||||
|
@click.command()
|
||||||
|
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||||
|
@click.option("--cloud", default=None, type=click.Path(exists=True, path_type=str))
|
||||||
|
def lm_eval(config: str, cloud: Optional[str] = None):
|
||||||
|
"""
|
||||||
|
use lm eval to evaluate a trained language model
|
||||||
|
"""
|
||||||
|
|
||||||
|
if cloud:
|
||||||
|
from axolotl.cli.cloud import do_cli_lm_eval
|
||||||
|
|
||||||
|
do_cli_lm_eval(cloud_config=cloud, config=config)
|
||||||
|
else:
|
||||||
|
cfg = load_cfg(config)
|
||||||
|
# pylint: disable=duplicate-code
|
||||||
|
for lm_eval_args in build_lm_eval_command(
|
||||||
|
cfg.lm_eval_tasks,
|
||||||
|
bfloat16=cfg.bfloat16 or cfg.bf16,
|
||||||
|
flash_attention=cfg.flash_attention,
|
||||||
|
output_dir=cfg.output_dir,
|
||||||
|
batch_size=cfg.lm_eval_batch_size,
|
||||||
|
wandb_project=cfg.wandb_project,
|
||||||
|
wandb_entity=cfg.wandb_entity,
|
||||||
|
hub_model_id=cfg.hub_model_id,
|
||||||
|
):
|
||||||
|
subprocess.run( # nosec
|
||||||
|
lm_eval_args,
|
||||||
|
check=True,
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user