From a4f12415a0c58f943449ad45fa5b80063950c2ae Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Thu, 25 May 2023 23:10:11 -0400 Subject: [PATCH] update readme and add typehints --- README.md | 8 +------- src/axolotl/utils/data.py | 15 ++++++++------- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index af32eb64e..b33235170 100644 --- a/README.md +++ b/README.md @@ -363,13 +363,7 @@ Pass the appropriate flag to the train command: ### Merge LORA to base -Add below flag to train command above (and using LoRA) - -```bash ---merge_lora --lora_model_dir="./completed-model" -``` - -Add below flag to train command above (and using QLoRA) +Add below flag to train command above ```bash --merge_lora --lora_model_dir="./completed-model" --load_in_8bit=False --load_in_4bit=False diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py index 8e333ca8b..b2045c229 100644 --- a/src/axolotl/utils/data.py +++ b/src/axolotl/utils/data.py @@ -1,6 +1,7 @@ import logging from hashlib import md5 from pathlib import Path +from typing import Union from datasets import ( load_from_disk, @@ -80,7 +81,7 @@ def load_tokenized_prepared_datasets( logging.info("Loading raw datasets...") datasets = [] for d in cfg.datasets: - ds = None + ds: Union[Dataset, DatasetDict] = None ds_from_hub = False try: load_dataset(d.path, streaming=True, use_auth_token=True) @@ -90,32 +91,32 @@ def load_tokenized_prepared_datasets( # prefer local dataset, even if hub exists if Path(d.path).exists(): - ds: IterableDataset = load_dataset( + ds: Dataset = load_dataset( "json", data_files=d.path, streaming=False, split=None ) elif ds_from_hub: if d.data_files: - ds = load_dataset( + ds: Dataset = load_dataset( d.path, streaming=False, data_files=d.data_files, use_auth_token=True, ) else: - ds = load_dataset(d.path, streaming=False, use_auth_token=True) + ds: Dataset = load_dataset(d.path, streaming=False, use_auth_token=True) else: fp = hf_hub_download( repo_id=d.path, repo_type="dataset", filename=d.data_files ) - ds = load_dataset("json", data_files=fp, streaming=False, split=None) + ds: Dataset = load_dataset("json", data_files=fp, streaming=False, split=None) if not ds: raise Exception("unhandled dataset load") # support for using a subset of the data if d.shards: if "train" in ds: - ds = ds.shuffle(seed=42)["train"].shard(num_shards=cfg.shards, index=0) + ds: DatasetDict = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0) else: - ds = ds.shuffle(seed=42).shard(num_shards=cfg.shards, index=0) + ds: Dataset = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0) d_type = d.type d_type_split = d_type.split(":") d_base_type = d_type_split[0]