allow minimal yaml for lm eval
This commit is contained in:
@@ -28,7 +28,7 @@ class LMEvalPlugin(BasePlugin):
|
||||
batch_size=cfg.lm_eval_batch_size,
|
||||
wandb_project=cfg.wandb_project,
|
||||
wandb_entity=cfg.wandb_entity,
|
||||
hub_model_id=cfg.hub_model_id,
|
||||
model=cfg.lm_eval_model or cfg.hub_model_id,
|
||||
):
|
||||
subprocess.run( # nosec
|
||||
lm_eval_args,
|
||||
|
||||
@@ -14,3 +14,4 @@ class LMEvalArgs(BaseModel):
|
||||
lm_eval_tasks: List[str] = []
|
||||
lm_eval_batch_size: Optional[int] = 8
|
||||
lm_eval_post_train: Optional[bool] = True
|
||||
lm_eval_model: Optional[str] = None
|
||||
|
||||
@@ -7,8 +7,9 @@ from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import click
|
||||
import yaml
|
||||
|
||||
from axolotl.cli import load_cfg
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
def build_lm_eval_command(
|
||||
@@ -19,7 +20,7 @@ def build_lm_eval_command(
|
||||
batch_size=8,
|
||||
wandb_project=None,
|
||||
wandb_entity=None,
|
||||
hub_model_id=None,
|
||||
model=None,
|
||||
):
|
||||
tasks_by_num_fewshot: dict[str, list] = defaultdict(list)
|
||||
for task in tasks:
|
||||
@@ -34,7 +35,7 @@ def build_lm_eval_command(
|
||||
tasks_str = ",".join(tasks_list)
|
||||
num_fewshot_val = num_fewshot if num_fewshot != "-1" else None
|
||||
pretrained = "pretrained="
|
||||
pretrained += hub_model_id if hub_model_id else output_dir
|
||||
pretrained += model if model else output_dir
|
||||
fa2 = ",attn_implementation=flash_attention_2" if flash_attention else ""
|
||||
dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16"
|
||||
output_path = output_dir
|
||||
@@ -81,7 +82,9 @@ def lm_eval(config: str, cloud: Optional[str] = None):
|
||||
|
||||
do_cli_lm_eval(cloud_config=cloud, config=config)
|
||||
else:
|
||||
cfg = load_cfg(config)
|
||||
with open(config, encoding="utf-8") as file:
|
||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
for lm_eval_args in build_lm_eval_command(
|
||||
cfg.lm_eval_tasks,
|
||||
@@ -91,7 +94,7 @@ def lm_eval(config: str, cloud: Optional[str] = None):
|
||||
batch_size=cfg.lm_eval_batch_size,
|
||||
wandb_project=cfg.wandb_project,
|
||||
wandb_entity=cfg.wandb_entity,
|
||||
hub_model_id=cfg.hub_model_id,
|
||||
model=cfg.lm_eval_model or cfg.hub_model_id,
|
||||
):
|
||||
subprocess.run( # nosec
|
||||
lm_eval_args,
|
||||
|
||||
Reference in New Issue
Block a user