lm_eval option to not post eval, and append not extend

This commit is contained in:
Wing Lian
2025-01-06 11:52:07 -05:00
parent 2741d8de23
commit 0390bce7aa
3 changed files with 18 additions and 16 deletions

View File

@@ -18,6 +18,7 @@ class LMEvalPlugin(BasePlugin):
return "axolotl.integrations.lm_eval.LMEvalArgs" return "axolotl.integrations.lm_eval.LMEvalArgs"
def post_train_unload(self, cfg): def post_train_unload(self, cfg):
if cfg.lm_eval_post_train:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
for lm_eval_args in build_lm_eval_command( for lm_eval_args in build_lm_eval_command(
cfg.lm_eval_tasks, cfg.lm_eval_tasks,

View File

@@ -13,3 +13,4 @@ class LMEvalArgs(BaseModel):
lm_eval_tasks: List[str] = [] lm_eval_tasks: List[str] = []
lm_eval_batch_size: Optional[int] = 8 lm_eval_batch_size: Optional[int] = 8
lm_eval_post_train: Optional[bool] = True

View File

@@ -60,7 +60,7 @@ def build_lm_eval_command(
wandb_args.append(f"entity={wandb_entity}") wandb_args.append(f"entity={wandb_entity}")
if wandb_args: if wandb_args:
lm_eval_args.append("--wandb_args") lm_eval_args.append("--wandb_args")
lm_eval_args.extend(",".join(wandb_args)) lm_eval_args.append(",".join(wandb_args))
if num_fewshot_val: if num_fewshot_val:
lm_eval_args.append("--num_fewshot") lm_eval_args.append("--num_fewshot")
lm_eval_args.append(str(num_fewshot_val)) lm_eval_args.append(str(num_fewshot_val))