This commit is contained in:
Wing Lian
2023-04-17 18:23:55 -04:00
parent 87e073d0de
commit 120e7df7df

View File

@@ -427,9 +427,10 @@ def train(
max_packed_sequence_len = min(max_packed_sequence_len, cfg.sequence_len) # make sure we don't accidentally set it larger than sequence_len
ds_hash = str(md5((str(max_packed_sequence_len) + "@" + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))).encode('utf-8')).hexdigest())
prepared_ds_path = Path(cfg.dataset_prepared_path) / ds_hash if cfg.dataset_prepared_path else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
if any(prepared_ds_path.glob("*")):
logging.info("Loading prepared dataset from disk...")
dataset = load_from_disk(cfg.dataset_prepared_path)
dataset = load_from_disk(str(prepared_ds_path))
logging.info("Prepared dataset loaded from disk...")
else:
logging.info("Loading raw datasets...")
@@ -437,7 +438,7 @@ def train(
for d in cfg.datasets:
ds_from_hub = False
try:
ds = load_dataset(d.path, streaming=True)
load_dataset(d.path, streaming=True)
ds_from_hub = True
except FileNotFoundError:
pass