allow minimal yaml for lm eval
This commit is contained in:
@@ -28,7 +28,7 @@ class LMEvalPlugin(BasePlugin):
|
|||||||
batch_size=cfg.lm_eval_batch_size,
|
batch_size=cfg.lm_eval_batch_size,
|
||||||
wandb_project=cfg.wandb_project,
|
wandb_project=cfg.wandb_project,
|
||||||
wandb_entity=cfg.wandb_entity,
|
wandb_entity=cfg.wandb_entity,
|
||||||
hub_model_id=cfg.hub_model_id,
|
model=cfg.lm_eval_model or cfg.hub_model_id,
|
||||||
):
|
):
|
||||||
subprocess.run( # nosec
|
subprocess.run( # nosec
|
||||||
lm_eval_args,
|
lm_eval_args,
|
||||||
|
|||||||
@@ -14,3 +14,4 @@ class LMEvalArgs(BaseModel):
|
|||||||
lm_eval_tasks: List[str] = []
|
lm_eval_tasks: List[str] = []
|
||||||
lm_eval_batch_size: Optional[int] = 8
|
lm_eval_batch_size: Optional[int] = 8
|
||||||
lm_eval_post_train: Optional[bool] = True
|
lm_eval_post_train: Optional[bool] = True
|
||||||
|
lm_eval_model: Optional[str] = None
|
||||||
|
|||||||
@@ -7,8 +7,9 @@ from datetime import datetime
|
|||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import click
|
import click
|
||||||
|
import yaml
|
||||||
|
|
||||||
from axolotl.cli import load_cfg
|
from axolotl.utils.dict import DictDefault
|
||||||
|
|
||||||
|
|
||||||
def build_lm_eval_command(
|
def build_lm_eval_command(
|
||||||
@@ -19,7 +20,7 @@ def build_lm_eval_command(
|
|||||||
batch_size=8,
|
batch_size=8,
|
||||||
wandb_project=None,
|
wandb_project=None,
|
||||||
wandb_entity=None,
|
wandb_entity=None,
|
||||||
hub_model_id=None,
|
model=None,
|
||||||
):
|
):
|
||||||
tasks_by_num_fewshot: dict[str, list] = defaultdict(list)
|
tasks_by_num_fewshot: dict[str, list] = defaultdict(list)
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
@@ -34,7 +35,7 @@ def build_lm_eval_command(
|
|||||||
tasks_str = ",".join(tasks_list)
|
tasks_str = ",".join(tasks_list)
|
||||||
num_fewshot_val = num_fewshot if num_fewshot != "-1" else None
|
num_fewshot_val = num_fewshot if num_fewshot != "-1" else None
|
||||||
pretrained = "pretrained="
|
pretrained = "pretrained="
|
||||||
pretrained += hub_model_id if hub_model_id else output_dir
|
pretrained += model if model else output_dir
|
||||||
fa2 = ",attn_implementation=flash_attention_2" if flash_attention else ""
|
fa2 = ",attn_implementation=flash_attention_2" if flash_attention else ""
|
||||||
dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16"
|
dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16"
|
||||||
output_path = output_dir
|
output_path = output_dir
|
||||||
@@ -81,7 +82,9 @@ def lm_eval(config: str, cloud: Optional[str] = None):
|
|||||||
|
|
||||||
do_cli_lm_eval(cloud_config=cloud, config=config)
|
do_cli_lm_eval(cloud_config=cloud, config=config)
|
||||||
else:
|
else:
|
||||||
cfg = load_cfg(config)
|
with open(config, encoding="utf-8") as file:
|
||||||
|
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
for lm_eval_args in build_lm_eval_command(
|
for lm_eval_args in build_lm_eval_command(
|
||||||
cfg.lm_eval_tasks,
|
cfg.lm_eval_tasks,
|
||||||
@@ -91,7 +94,7 @@ def lm_eval(config: str, cloud: Optional[str] = None):
|
|||||||
batch_size=cfg.lm_eval_batch_size,
|
batch_size=cfg.lm_eval_batch_size,
|
||||||
wandb_project=cfg.wandb_project,
|
wandb_project=cfg.wandb_project,
|
||||||
wandb_entity=cfg.wandb_entity,
|
wandb_entity=cfg.wandb_entity,
|
||||||
hub_model_id=cfg.hub_model_id,
|
model=cfg.lm_eval_model or cfg.hub_model_id,
|
||||||
):
|
):
|
||||||
subprocess.run( # nosec
|
subprocess.run( # nosec
|
||||||
lm_eval_args,
|
lm_eval_args,
|
||||||
|
|||||||
Reference in New Issue
Block a user