diff --git a/src/axolotl/integrations/lm_eval/cli.py b/src/axolotl/integrations/lm_eval/cli.py index 23aa1845f..7e2b4a1e6 100644 --- a/src/axolotl/integrations/lm_eval/cli.py +++ b/src/axolotl/integrations/lm_eval/cli.py @@ -22,6 +22,8 @@ def build_lm_eval_command( wandb_entity=None, model=None, revision=None, + apply_chat_template=None, + fewshot_as_multiturn=None, ): tasks_by_num_fewshot: dict[str, list] = defaultdict(list) for task in tasks: @@ -55,7 +57,6 @@ def build_lm_eval_command( str(batch_size), "--output_path", output_path, - "--apply_chat_template", ] wandb_args = [] if wandb_project: @@ -65,10 +66,13 @@ def build_lm_eval_command( if wandb_args: lm_eval_args.append("--wandb_args") lm_eval_args.append(",".join(wandb_args)) + if apply_chat_template: + lm_eval_args.append("--apply_chat_template") if num_fewshot_val: lm_eval_args.append("--num_fewshot") lm_eval_args.append(str(num_fewshot_val)) - # lm_eval_args.append("--fewshot_as_multiturn") + if apply_chat_template and fewshot_as_multiturn: + lm_eval_args.append("--fewshot_as_multiturn") yield lm_eval_args @@ -100,6 +104,8 @@ def lm_eval(config: str, cloud: Optional[str] = None): wandb_entity=cfg.wandb_entity, model=cfg.lm_eval_model or cfg.hub_model_id, revision=cfg.revision, + apply_chat_template=cfg.apply_chat_template, + fewshot_as_multiturn=cfg.fewshot_as_multiturn, ): subprocess.run( # nosec lm_eval_args,