bugfixes
This commit is contained in:
@@ -427,9 +427,10 @@ def train(
|
||||
max_packed_sequence_len = min(max_packed_sequence_len, cfg.sequence_len) # make sure we don't accidentally set it larger than sequence_len
|
||||
ds_hash = str(md5((str(max_packed_sequence_len) + "@" + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))).encode('utf-8')).hexdigest())
|
||||
prepared_ds_path = Path(cfg.dataset_prepared_path) / ds_hash if cfg.dataset_prepared_path else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
|
||||
|
||||
if any(prepared_ds_path.glob("*")):
|
||||
logging.info("Loading prepared dataset from disk...")
|
||||
dataset = load_from_disk(cfg.dataset_prepared_path)
|
||||
dataset = load_from_disk(str(prepared_ds_path))
|
||||
logging.info("Prepared dataset loaded from disk...")
|
||||
else:
|
||||
logging.info("Loading raw datasets...")
|
||||
@@ -437,7 +438,7 @@ def train(
|
||||
for d in cfg.datasets:
|
||||
ds_from_hub = False
|
||||
try:
|
||||
ds = load_dataset(d.path, streaming=True)
|
||||
load_dataset(d.path, streaming=True)
|
||||
ds_from_hub = True
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user