bugfixes
This commit is contained in:
@@ -427,9 +427,10 @@ def train(
|
|||||||
max_packed_sequence_len = min(max_packed_sequence_len, cfg.sequence_len) # make sure we don't accidentally set it larger than sequence_len
|
max_packed_sequence_len = min(max_packed_sequence_len, cfg.sequence_len) # make sure we don't accidentally set it larger than sequence_len
|
||||||
ds_hash = str(md5((str(max_packed_sequence_len) + "@" + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))).encode('utf-8')).hexdigest())
|
ds_hash = str(md5((str(max_packed_sequence_len) + "@" + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))).encode('utf-8')).hexdigest())
|
||||||
prepared_ds_path = Path(cfg.dataset_prepared_path) / ds_hash if cfg.dataset_prepared_path else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
|
prepared_ds_path = Path(cfg.dataset_prepared_path) / ds_hash if cfg.dataset_prepared_path else Path(DEFAULT_DATASET_PREPARED_PATH) / ds_hash
|
||||||
|
|
||||||
if any(prepared_ds_path.glob("*")):
|
if any(prepared_ds_path.glob("*")):
|
||||||
logging.info("Loading prepared dataset from disk...")
|
logging.info("Loading prepared dataset from disk...")
|
||||||
dataset = load_from_disk(cfg.dataset_prepared_path)
|
dataset = load_from_disk(str(prepared_ds_path))
|
||||||
logging.info("Prepared dataset loaded from disk...")
|
logging.info("Prepared dataset loaded from disk...")
|
||||||
else:
|
else:
|
||||||
logging.info("Loading raw datasets...")
|
logging.info("Loading raw datasets...")
|
||||||
@@ -437,7 +438,7 @@ def train(
|
|||||||
for d in cfg.datasets:
|
for d in cfg.datasets:
|
||||||
ds_from_hub = False
|
ds_from_hub = False
|
||||||
try:
|
try:
|
||||||
ds = load_dataset(d.path, streaming=True)
|
load_dataset(d.path, streaming=True)
|
||||||
ds_from_hub = True
|
ds_from_hub = True
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
pass
|
pass
|
||||||
|
|||||||
Reference in New Issue
Block a user