diff --git a/src/axolotl/integrations/lm_eval/cli.py b/src/axolotl/integrations/lm_eval/cli.py index 2b1fd5654..23aa1845f 100644 --- a/src/axolotl/integrations/lm_eval/cli.py +++ b/src/axolotl/integrations/lm_eval/cli.py @@ -21,6 +21,7 @@ def build_lm_eval_command( wandb_project=None, wandb_entity=None, model=None, + revision=None, ): tasks_by_num_fewshot: dict[str, list] = defaultdict(list) for task in tasks: @@ -38,6 +39,7 @@ def build_lm_eval_command( pretrained += model if model else output_dir fa2 = ",attn_implementation=flash_attention_2" if flash_attention else "" dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16" + revision = f",revision={revision}" if revision else "" output_path = output_dir output_path += "" if output_dir.endswith("/") else "/" output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S") @@ -46,7 +48,7 @@ def build_lm_eval_command( "--model", "hf", "--model_args", - f"{pretrained}{fa2}{dtype}", + f"{pretrained}{fa2}{dtype}{revision}", "--tasks", tasks_str, "--batch_size", @@ -97,6 +99,7 @@ def lm_eval(config: str, cloud: Optional[str] = None): wandb_project=cfg.wandb_project, wandb_entity=cfg.wandb_entity, model=cfg.lm_eval_model or cfg.hub_model_id, + revision=cfg.revision, ): subprocess.run( # nosec lm_eval_args,