Allow a minimal YAML config for lm-eval

This commit is contained in:
Wing Lian
2025-01-06 17:41:10 -05:00
parent 7ba701a355
commit 981ad965d0
3 changed files with 10 additions and 6 deletions

View File

@@ -28,7 +28,7 @@ class LMEvalPlugin(BasePlugin):
batch_size=cfg.lm_eval_batch_size,
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
hub_model_id=cfg.hub_model_id,
model=cfg.lm_eval_model or cfg.hub_model_id,
):
subprocess.run( # nosec
lm_eval_args,

View File

@@ -14,3 +14,4 @@ class LMEvalArgs(BaseModel):
lm_eval_tasks: List[str] = []
lm_eval_batch_size: Optional[int] = 8
lm_eval_post_train: Optional[bool] = True
lm_eval_model: Optional[str] = None

View File

@@ -7,8 +7,9 @@ from datetime import datetime
from typing import Optional
import click
import yaml
from axolotl.cli import load_cfg
from axolotl.utils.dict import DictDefault
def build_lm_eval_command(
@@ -19,7 +20,7 @@ def build_lm_eval_command(
batch_size=8,
wandb_project=None,
wandb_entity=None,
hub_model_id=None,
model=None,
):
tasks_by_num_fewshot: dict[str, list] = defaultdict(list)
for task in tasks:
@@ -34,7 +35,7 @@ def build_lm_eval_command(
tasks_str = ",".join(tasks_list)
num_fewshot_val = num_fewshot if num_fewshot != "-1" else None
pretrained = "pretrained="
pretrained += hub_model_id if hub_model_id else output_dir
pretrained += model if model else output_dir
fa2 = ",attn_implementation=flash_attention_2" if flash_attention else ""
dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16"
output_path = output_dir
@@ -81,7 +82,9 @@ def lm_eval(config: str, cloud: Optional[str] = None):
do_cli_lm_eval(cloud_config=cloud, config=config)
else:
cfg = load_cfg(config)
with open(config, encoding="utf-8") as file:
cfg: DictDefault = DictDefault(yaml.safe_load(file))
# pylint: disable=duplicate-code
for lm_eval_args in build_lm_eval_command(
cfg.lm_eval_tasks,
@@ -91,7 +94,7 @@ def lm_eval(config: str, cloud: Optional[str] = None):
batch_size=cfg.lm_eval_batch_size,
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
hub_model_id=cfg.hub_model_id,
model=cfg.lm_eval_model or cfg.hub_model_id,
):
subprocess.run( # nosec
lm_eval_args,