move list not in list logic to fn

This commit is contained in:
Wing Lian
2023-05-27 16:42:05 -04:00
parent ca1bb92337
commit cc67862dd3

View File

@@ -5,7 +5,7 @@ import random
import signal import signal
import sys import sys
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional, List, Dict, Any, Union
import fire import fire
import torch import torch
@@ -117,6 +117,10 @@ def choose_config(path: Path):
return chosen_file return chosen_file
def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
return not any(el in list2 for el in list1)
def train( def train(
config: Path = Path("configs/"), config: Path = Path("configs/"),
prepare_ds_only: bool = False, prepare_ds_only: bool = False,
@@ -169,7 +173,7 @@ def train(
cfg cfg
) )
if "inference" not in kwargs and "shard" not in kwargs: # don't need to load dataset for these if check_not_in(["inference", "shard", "merge_lora"], kwargs): # don't need to load dataset for these
train_dataset, eval_dataset = load_prepare_datasets( train_dataset, eval_dataset = load_prepare_datasets(
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
) )