lm_eval option to not post eval, and append not extend

This commit is contained in:
Wing Lian
2025-01-06 11:52:07 -05:00
parent 2741d8de23
commit 0390bce7aa
3 changed files with 18 additions and 16 deletions

View File

@@ -18,18 +18,19 @@ class LMEvalPlugin(BasePlugin):
return "axolotl.integrations.lm_eval.LMEvalArgs"
def post_train_unload(self, cfg):
# pylint: disable=duplicate-code
for lm_eval_args in build_lm_eval_command(
cfg.lm_eval_tasks,
bfloat16=cfg.bfloat16 or cfg.bf16,
flash_attention=cfg.flash_attention,
output_dir=cfg.output_dir,
batch_size=cfg.lm_eval_batch_size,
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
hub_model_id=cfg.hub_model_id,
):
subprocess.run( # nosec
lm_eval_args,
check=True,
)
if cfg.lm_eval_post_train:
# pylint: disable=duplicate-code
for lm_eval_args in build_lm_eval_command(
cfg.lm_eval_tasks,
bfloat16=cfg.bfloat16 or cfg.bf16,
flash_attention=cfg.flash_attention,
output_dir=cfg.output_dir,
batch_size=cfg.lm_eval_batch_size,
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
hub_model_id=cfg.hub_model_id,
):
subprocess.run( # nosec
lm_eval_args,
check=True,
)

View File

@@ -13,3 +13,4 @@ class LMEvalArgs(BaseModel):
lm_eval_tasks: List[str] = []
lm_eval_batch_size: Optional[int] = 8
lm_eval_post_train: Optional[bool] = True

View File

@@ -60,7 +60,7 @@ def build_lm_eval_command(
wandb_args.append(f"entity={wandb_entity}")
if wandb_args:
lm_eval_args.append("--wandb_args")
lm_eval_args.extend(",".join(wandb_args))
lm_eval_args.append(",".join(wandb_args))
if num_fewshot_val:
lm_eval_args.append("--num_fewshot")
lm_eval_args.append(str(num_fewshot_val))