lm_eval option to not post eval, and append not extend

This commit is contained in:
Wing Lian
2025-01-06 11:52:07 -05:00
parent 2741d8de23
commit 0390bce7aa
3 changed files with 18 additions and 16 deletions

View File

@@ -18,18 +18,19 @@ class LMEvalPlugin(BasePlugin):
return "axolotl.integrations.lm_eval.LMEvalArgs" return "axolotl.integrations.lm_eval.LMEvalArgs"
def post_train_unload(self, cfg): def post_train_unload(self, cfg):
# pylint: disable=duplicate-code if cfg.lm_eval_post_train:
for lm_eval_args in build_lm_eval_command( # pylint: disable=duplicate-code
cfg.lm_eval_tasks, for lm_eval_args in build_lm_eval_command(
bfloat16=cfg.bfloat16 or cfg.bf16, cfg.lm_eval_tasks,
flash_attention=cfg.flash_attention, bfloat16=cfg.bfloat16 or cfg.bf16,
output_dir=cfg.output_dir, flash_attention=cfg.flash_attention,
batch_size=cfg.lm_eval_batch_size, output_dir=cfg.output_dir,
wandb_project=cfg.wandb_project, batch_size=cfg.lm_eval_batch_size,
wandb_entity=cfg.wandb_entity, wandb_project=cfg.wandb_project,
hub_model_id=cfg.hub_model_id, wandb_entity=cfg.wandb_entity,
): hub_model_id=cfg.hub_model_id,
subprocess.run( # nosec ):
lm_eval_args, subprocess.run( # nosec
check=True, lm_eval_args,
) check=True,
)

View File

@@ -13,3 +13,4 @@ class LMEvalArgs(BaseModel):
lm_eval_tasks: List[str] = [] lm_eval_tasks: List[str] = []
lm_eval_batch_size: Optional[int] = 8 lm_eval_batch_size: Optional[int] = 8
lm_eval_post_train: Optional[bool] = True

View File

@@ -60,7 +60,7 @@ def build_lm_eval_command(
wandb_args.append(f"entity={wandb_entity}") wandb_args.append(f"entity={wandb_entity}")
if wandb_args: if wandb_args:
lm_eval_args.append("--wandb_args") lm_eval_args.append("--wandb_args")
lm_eval_args.extend(",".join(wandb_args)) lm_eval_args.append(",".join(wandb_args))
if num_fewshot_val: if num_fewshot_val:
lm_eval_args.append("--num_fewshot") lm_eval_args.append("--num_fewshot")
lm_eval_args.append(str(num_fewshot_val)) lm_eval_args.append(str(num_fewshot_val))