From 530bf77cf928e1d0021099a7872c638c2b1c019d Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sun, 12 Jan 2025 05:17:03 -0500 Subject: [PATCH] revision support --- src/axolotl/integrations/lm_eval/cli.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/axolotl/integrations/lm_eval/cli.py b/src/axolotl/integrations/lm_eval/cli.py index 2b1fd5654..23aa1845f 100644 --- a/src/axolotl/integrations/lm_eval/cli.py +++ b/src/axolotl/integrations/lm_eval/cli.py @@ -21,6 +21,7 @@ def build_lm_eval_command( wandb_project=None, wandb_entity=None, model=None, + revision=None, ): tasks_by_num_fewshot: dict[str, list] = defaultdict(list) for task in tasks: @@ -38,6 +39,7 @@ def build_lm_eval_command( pretrained += model if model else output_dir fa2 = ",attn_implementation=flash_attention_2" if flash_attention else "" dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16" + revision = f",revision={revision}" if revision else "" output_path = output_dir output_path += "" if output_dir.endswith("/") else "/" output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S") @@ -46,7 +48,7 @@ def build_lm_eval_command( "--model", "hf", "--model_args", - f"{pretrained}{fa2}{dtype}", + f"{pretrained}{fa2}{dtype}{revision}", "--tasks", tasks_str, "--batch_size", @@ -97,6 +99,7 @@ def lm_eval(config: str, cloud: Optional[str] = None): wandb_project=cfg.wandb_project, wandb_entity=cfg.wandb_entity, model=cfg.lm_eval_model or cfg.hub_model_id, + revision=cfg.revision, ): subprocess.run( # nosec lm_eval_args,