revision support
This commit is contained in:
@@ -21,6 +21,7 @@ def build_lm_eval_command(
|
||||
wandb_project=None,
|
||||
wandb_entity=None,
|
||||
model=None,
|
||||
revision=None,
|
||||
):
|
||||
tasks_by_num_fewshot: dict[str, list] = defaultdict(list)
|
||||
for task in tasks:
|
||||
@@ -38,6 +39,7 @@ def build_lm_eval_command(
|
||||
pretrained += model if model else output_dir
|
||||
fa2 = ",attn_implementation=flash_attention_2" if flash_attention else ""
|
||||
dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16"
|
||||
revision = f",revision={revision}" if revision else ""
|
||||
output_path = output_dir
|
||||
output_path += "" if output_dir.endswith("/") else "/"
|
||||
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
@@ -46,7 +48,7 @@ def build_lm_eval_command(
|
||||
"--model",
|
||||
"hf",
|
||||
"--model_args",
|
||||
f"{pretrained}{fa2}{dtype}",
|
||||
f"{pretrained}{fa2}{dtype}{revision}",
|
||||
"--tasks",
|
||||
tasks_str,
|
||||
"--batch_size",
|
||||
@@ -97,6 +99,7 @@ def lm_eval(config: str, cloud: Optional[str] = None):
|
||||
wandb_project=cfg.wandb_project,
|
||||
wandb_entity=cfg.wandb_entity,
|
||||
model=cfg.lm_eval_model or cfg.hub_model_id,
|
||||
revision=cfg.revision,
|
||||
):
|
||||
subprocess.run( # nosec
|
||||
lm_eval_args,
|
||||
|
||||
Reference in New Issue
Block a user