revision support

This commit is contained in:
Wing Lian
2025-01-12 05:17:03 -05:00
parent bfc91a91ca
commit 530bf77cf9

View File

@@ -21,6 +21,7 @@ def build_lm_eval_command(
wandb_project=None,
wandb_entity=None,
model=None,
revision=None,
):
tasks_by_num_fewshot: dict[str, list] = defaultdict(list)
for task in tasks:
@@ -38,6 +39,7 @@ def build_lm_eval_command(
pretrained += model if model else output_dir
fa2 = ",attn_implementation=flash_attention_2" if flash_attention else ""
dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16"
revision = f",revision={revision}" if revision else ""
output_path = output_dir
output_path += "" if output_dir.endswith("/") else "/"
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
@@ -46,7 +48,7 @@ def build_lm_eval_command(
"--model",
"hf",
"--model_args",
f"{pretrained}{fa2}{dtype}",
f"{pretrained}{fa2}{dtype}{revision}",
"--tasks",
tasks_str,
"--batch_size",
@@ -97,6 +99,7 @@ def lm_eval(config: str, cloud: Optional[str] = None):
wandb_project=cfg.wandb_project,
wandb_entity=cfg.wandb_entity,
model=cfg.lm_eval_model or cfg.hub_model_id,
revision=cfg.revision,
):
subprocess.run( # nosec
lm_eval_args,