move list not in list logic to fn
This commit is contained in:
@@ -5,7 +5,7 @@ import random
|
|||||||
import signal
|
import signal
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional, List, Dict, Any, Union
|
||||||
|
|
||||||
import fire
|
import fire
|
||||||
import torch
|
import torch
|
||||||
@@ -117,6 +117,10 @@ def choose_config(path: Path):
|
|||||||
return chosen_file
|
return chosen_file
|
||||||
|
|
||||||
|
|
||||||
|
def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> bool:
|
||||||
|
return not any(el in list2 for el in list1)
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
config: Path = Path("configs/"),
|
config: Path = Path("configs/"),
|
||||||
prepare_ds_only: bool = False,
|
prepare_ds_only: bool = False,
|
||||||
@@ -169,7 +173,7 @@ def train(
|
|||||||
cfg
|
cfg
|
||||||
)
|
)
|
||||||
|
|
||||||
if "inference" not in kwargs and "shard" not in kwargs: # don't need to load dataset for these
|
if check_not_in(["inference", "shard", "merge_lora"], kwargs): # don't need to load dataset for these
|
||||||
train_dataset, eval_dataset = load_prepare_datasets(
|
train_dataset, eval_dataset = load_prepare_datasets(
|
||||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user