diff --git a/scripts/finetune.py b/scripts/finetune.py
index 3d72fb1d9..8d7a18a4a 100644
--- a/scripts/finetune.py
+++ b/scripts/finetune.py
@@ -5,7 +5,7 @@ import random
 import signal
 import sys
 from pathlib import Path
-from typing import Optional
+from typing import Optional, List, Dict, Any, Union
 
 import fire
 import torch
@@ -117,6 +117,10 @@ def choose_config(path: Path):
     return chosen_file
 
 
+def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
+    return not any(el in list2 for el in list1)
+
+
 def train(
     config: Path = Path("configs/"),
     prepare_ds_only: bool = False,
@@ -169,7 +173,7 @@ def train(
         cfg
     )
 
-    if "inference" not in kwargs and "shard" not in kwargs: # don't need to load dataset for these
+    if check_not_in(["inference", "shard", "merge_lora"], kwargs): # don't need to load dataset for these
        train_dataset, eval_dataset = load_prepare_datasets(
            tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
        )
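
For reference, a standalone sanity check of the new helper (a minimal sketch; the kwargs dict below is hypothetical, standing in for the extra CLI arguments that fire passes through to train):

from typing import Any, Dict, List, Union

def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
    # True only when none of the elements of list1 appear in list2;
    # membership testing on a dict checks its keys, so passing kwargs works.
    return not any(el in list2 for el in list1)

# Hypothetical kwargs as fire might pass them to train()
kwargs = {"merge_lora": True, "lora_model_dir": "out/lora"}
assert check_not_in(["inference", "shard", "merge_lora"], kwargs) is False
assert check_not_in(["inference", "shard"], kwargs) is True

Factoring the chained `not in` checks into check_not_in means dataset loading is skipped for merge_lora as well, and adding future dataset-free modes only requires extending the list rather than growing the boolean expression.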