lm_eval option to not post eval, and append not extend
This commit is contained in:
@@ -18,18 +18,19 @@ class LMEvalPlugin(BasePlugin):
|
|||||||
return "axolotl.integrations.lm_eval.LMEvalArgs"
|
return "axolotl.integrations.lm_eval.LMEvalArgs"
|
||||||
|
|
||||||
def post_train_unload(self, cfg):
|
def post_train_unload(self, cfg):
|
||||||
# pylint: disable=duplicate-code
|
if cfg.lm_eval_post_train:
|
||||||
for lm_eval_args in build_lm_eval_command(
|
# pylint: disable=duplicate-code
|
||||||
cfg.lm_eval_tasks,
|
for lm_eval_args in build_lm_eval_command(
|
||||||
bfloat16=cfg.bfloat16 or cfg.bf16,
|
cfg.lm_eval_tasks,
|
||||||
flash_attention=cfg.flash_attention,
|
bfloat16=cfg.bfloat16 or cfg.bf16,
|
||||||
output_dir=cfg.output_dir,
|
flash_attention=cfg.flash_attention,
|
||||||
batch_size=cfg.lm_eval_batch_size,
|
output_dir=cfg.output_dir,
|
||||||
wandb_project=cfg.wandb_project,
|
batch_size=cfg.lm_eval_batch_size,
|
||||||
wandb_entity=cfg.wandb_entity,
|
wandb_project=cfg.wandb_project,
|
||||||
hub_model_id=cfg.hub_model_id,
|
wandb_entity=cfg.wandb_entity,
|
||||||
):
|
hub_model_id=cfg.hub_model_id,
|
||||||
subprocess.run( # nosec
|
):
|
||||||
lm_eval_args,
|
subprocess.run( # nosec
|
||||||
check=True,
|
lm_eval_args,
|
||||||
)
|
check=True,
|
||||||
|
)
|
||||||
|
|||||||
@@ -13,3 +13,4 @@ class LMEvalArgs(BaseModel):
|
|||||||
|
|
||||||
lm_eval_tasks: List[str] = []
|
lm_eval_tasks: List[str] = []
|
||||||
lm_eval_batch_size: Optional[int] = 8
|
lm_eval_batch_size: Optional[int] = 8
|
||||||
|
lm_eval_post_train: Optional[bool] = True
|
||||||
|
|||||||
@@ -60,7 +60,7 @@ def build_lm_eval_command(
|
|||||||
wandb_args.append(f"entity={wandb_entity}")
|
wandb_args.append(f"entity={wandb_entity}")
|
||||||
if wandb_args:
|
if wandb_args:
|
||||||
lm_eval_args.append("--wandb_args")
|
lm_eval_args.append("--wandb_args")
|
||||||
lm_eval_args.extend(",".join(wandb_args))
|
lm_eval_args.append(",".join(wandb_args))
|
||||||
if num_fewshot_val:
|
if num_fewshot_val:
|
||||||
lm_eval_args.append("--num_fewshot")
|
lm_eval_args.append("--num_fewshot")
|
||||||
lm_eval_args.append(str(num_fewshot_val))
|
lm_eval_args.append(str(num_fewshot_val))
|
||||||
|
|||||||
Reference in New Issue
Block a user