diff --git a/src/axolotl/integrations/lm_eval/__init__.py b/src/axolotl/integrations/lm_eval/__init__.py index 87d2fc3a9..23a66ddca 100644 --- a/src/axolotl/integrations/lm_eval/__init__.py +++ b/src/axolotl/integrations/lm_eval/__init__.py @@ -18,18 +18,19 @@ class LMEvalPlugin(BasePlugin): return "axolotl.integrations.lm_eval.LMEvalArgs" def post_train_unload(self, cfg): - # pylint: disable=duplicate-code - for lm_eval_args in build_lm_eval_command( - cfg.lm_eval_tasks, - bfloat16=cfg.bfloat16 or cfg.bf16, - flash_attention=cfg.flash_attention, - output_dir=cfg.output_dir, - batch_size=cfg.lm_eval_batch_size, - wandb_project=cfg.wandb_project, - wandb_entity=cfg.wandb_entity, - hub_model_id=cfg.hub_model_id, - ): - subprocess.run( # nosec - lm_eval_args, - check=True, - ) + if cfg.lm_eval_post_train: + # pylint: disable=duplicate-code + for lm_eval_args in build_lm_eval_command( + cfg.lm_eval_tasks, + bfloat16=cfg.bfloat16 or cfg.bf16, + flash_attention=cfg.flash_attention, + output_dir=cfg.output_dir, + batch_size=cfg.lm_eval_batch_size, + wandb_project=cfg.wandb_project, + wandb_entity=cfg.wandb_entity, + hub_model_id=cfg.hub_model_id, + ): + subprocess.run( # nosec + lm_eval_args, + check=True, + ) diff --git a/src/axolotl/integrations/lm_eval/args.py b/src/axolotl/integrations/lm_eval/args.py index f58e6a6e3..3bce5c238 100644 --- a/src/axolotl/integrations/lm_eval/args.py +++ b/src/axolotl/integrations/lm_eval/args.py @@ -13,3 +13,4 @@ class LMEvalArgs(BaseModel): lm_eval_tasks: List[str] = [] lm_eval_batch_size: Optional[int] = 8 + lm_eval_post_train: Optional[bool] = True diff --git a/src/axolotl/integrations/lm_eval/cli.py b/src/axolotl/integrations/lm_eval/cli.py index 352143a37..f979e683e 100644 --- a/src/axolotl/integrations/lm_eval/cli.py +++ b/src/axolotl/integrations/lm_eval/cli.py @@ -60,7 +60,7 @@ def build_lm_eval_command( wandb_args.append(f"entity={wandb_entity}") if wandb_args: lm_eval_args.append("--wandb_args") - lm_eval_args.extend(",".join(wandb_args)) + lm_eval_args.append(",".join(wandb_args)) if num_fewshot_val: lm_eval_args.append("--num_fewshot") lm_eval_args.append(str(num_fewshot_val))