From 981ad965d0dac6aaa17fdacec870aed2995e07db Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Mon, 6 Jan 2025 17:41:10 -0500 Subject: [PATCH] allow minimal yaml for lm eval --- src/axolotl/integrations/lm_eval/__init__.py | 2 +- src/axolotl/integrations/lm_eval/args.py | 1 + src/axolotl/integrations/lm_eval/cli.py | 13 ++++++++----- 3 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/axolotl/integrations/lm_eval/__init__.py b/src/axolotl/integrations/lm_eval/__init__.py index 23a66ddca..f039cd727 100644 --- a/src/axolotl/integrations/lm_eval/__init__.py +++ b/src/axolotl/integrations/lm_eval/__init__.py @@ -28,7 +28,7 @@ class LMEvalPlugin(BasePlugin): batch_size=cfg.lm_eval_batch_size, wandb_project=cfg.wandb_project, wandb_entity=cfg.wandb_entity, - hub_model_id=cfg.hub_model_id, + model=cfg.lm_eval_model or cfg.hub_model_id, ): subprocess.run( # nosec lm_eval_args, diff --git a/src/axolotl/integrations/lm_eval/args.py b/src/axolotl/integrations/lm_eval/args.py index 3bce5c238..721f560e3 100644 --- a/src/axolotl/integrations/lm_eval/args.py +++ b/src/axolotl/integrations/lm_eval/args.py @@ -14,3 +14,4 @@ class LMEvalArgs(BaseModel): lm_eval_tasks: List[str] = [] lm_eval_batch_size: Optional[int] = 8 lm_eval_post_train: Optional[bool] = True + lm_eval_model: Optional[str] = None diff --git a/src/axolotl/integrations/lm_eval/cli.py b/src/axolotl/integrations/lm_eval/cli.py index f979e683e..5e7821a68 100644 --- a/src/axolotl/integrations/lm_eval/cli.py +++ b/src/axolotl/integrations/lm_eval/cli.py @@ -7,8 +7,9 @@ from datetime import datetime from typing import Optional import click +import yaml -from axolotl.cli import load_cfg +from axolotl.utils.dict import DictDefault def build_lm_eval_command( @@ -19,7 +20,7 @@ def build_lm_eval_command( batch_size=8, wandb_project=None, wandb_entity=None, - hub_model_id=None, + model=None, ): tasks_by_num_fewshot: dict[str, list] = defaultdict(list) for task in tasks: @@ -34,7 +35,7 @@ def build_lm_eval_command( tasks_str = ",".join(tasks_list) num_fewshot_val = num_fewshot if num_fewshot != "-1" else None pretrained = "pretrained=" - pretrained += hub_model_id if hub_model_id else output_dir + pretrained += model if model else output_dir fa2 = ",attn_implementation=flash_attention_2" if flash_attention else "" dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16" output_path = output_dir @@ -81,7 +82,9 @@ def lm_eval(config: str, cloud: Optional[str] = None): do_cli_lm_eval(cloud_config=cloud, config=config) else: - cfg = load_cfg(config) + with open(config, encoding="utf-8") as file: + cfg: DictDefault = DictDefault(yaml.safe_load(file)) + # pylint: disable=duplicate-code for lm_eval_args in build_lm_eval_command( cfg.lm_eval_tasks, @@ -91,7 +94,7 @@ def lm_eval(config: str, cloud: Optional[str] = None): batch_size=cfg.lm_eval_batch_size, wandb_project=cfg.wandb_project, wandb_entity=cfg.wandb_entity, - hub_model_id=cfg.hub_model_id, + model=cfg.lm_eval_model or cfg.hub_model_id, ): subprocess.run( # nosec lm_eval_args,