From d33a975747ad6578c2225b464691d06b810e8dc3 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 15 Apr 2023 14:24:37 -0400 Subject: [PATCH] configure log level, add llama 7b config --- scripts/finetune.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/scripts/finetune.py b/scripts/finetune.py index b353be67d..77f1b3c78 100644 --- a/scripts/finetune.py +++ b/scripts/finetune.py @@ -39,6 +39,7 @@ from axolotl.prompt_tokenizers import ( from axolotl.prompters import AlpacaPrompter, GPTeacherPrompter, ShareGPTPrompter logger = logging.getLogger(__name__) +logger.setLevel(os.getenv("LOG_LEVEL", "INFO")) DEFAULT_DATASET_PREPARED_PATH = "data/last_run" @@ -353,9 +354,15 @@ def train( else: datasets = [] for d in cfg.datasets: - ds: IterableDataset = load_dataset( - "json", data_files=d.path, streaming=True, split=None - ) + if Path(d.path).exists(): + ds: IterableDataset = load_dataset( + "json", data_files=d.path, streaming=True, split=None + ) + # elif d.name and d.path: + # # TODO load from huggingface hub, but it only seems to support arrow or parquet atm + # ds = load_dataset(d.path, split=None, data_files=d.name) + else: + raise Exception("unhandled dataset load") if d.type == "alpaca": ds_strategy = AlpacaPromptTokenizingStrategy(