revision support
This commit is contained in:
@@ -21,6 +21,7 @@ def build_lm_eval_command(
|
|||||||
wandb_project=None,
|
wandb_project=None,
|
||||||
wandb_entity=None,
|
wandb_entity=None,
|
||||||
model=None,
|
model=None,
|
||||||
|
revision=None,
|
||||||
):
|
):
|
||||||
tasks_by_num_fewshot: dict[str, list] = defaultdict(list)
|
tasks_by_num_fewshot: dict[str, list] = defaultdict(list)
|
||||||
for task in tasks:
|
for task in tasks:
|
||||||
@@ -38,6 +39,7 @@ def build_lm_eval_command(
|
|||||||
pretrained += model if model else output_dir
|
pretrained += model if model else output_dir
|
||||||
fa2 = ",attn_implementation=flash_attention_2" if flash_attention else ""
|
fa2 = ",attn_implementation=flash_attention_2" if flash_attention else ""
|
||||||
dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16"
|
dtype = ",dtype=bfloat16" if bfloat16 else ",dtype=float16"
|
||||||
|
revision = f",revision={revision}" if revision else ""
|
||||||
output_path = output_dir
|
output_path = output_dir
|
||||||
output_path += "" if output_dir.endswith("/") else "/"
|
output_path += "" if output_dir.endswith("/") else "/"
|
||||||
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
|
output_path += "lm_eval_results/" + datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||||
@@ -46,7 +48,7 @@ def build_lm_eval_command(
|
|||||||
"--model",
|
"--model",
|
||||||
"hf",
|
"hf",
|
||||||
"--model_args",
|
"--model_args",
|
||||||
f"{pretrained}{fa2}{dtype}",
|
f"{pretrained}{fa2}{dtype}{revision}",
|
||||||
"--tasks",
|
"--tasks",
|
||||||
tasks_str,
|
tasks_str,
|
||||||
"--batch_size",
|
"--batch_size",
|
||||||
@@ -97,6 +99,7 @@ def lm_eval(config: str, cloud: Optional[str] = None):
|
|||||||
wandb_project=cfg.wandb_project,
|
wandb_project=cfg.wandb_project,
|
||||||
wandb_entity=cfg.wandb_entity,
|
wandb_entity=cfg.wandb_entity,
|
||||||
model=cfg.lm_eval_model or cfg.hub_model_id,
|
model=cfg.lm_eval_model or cfg.hub_model_id,
|
||||||
|
revision=cfg.revision,
|
||||||
):
|
):
|
||||||
subprocess.run( # nosec
|
subprocess.run( # nosec
|
||||||
lm_eval_args,
|
lm_eval_args,
|
||||||
|
|||||||
Reference in New Issue
Block a user