optimize dataloading to use cache, fix model token embedding sizes
@@ -31,13 +31,7 @@ from axolotl.prompters import (
 )


-def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
-    max_packed_sequence_len = (
-        cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
-    )
-    max_packed_sequence_len = min(
-        max_packed_sequence_len, cfg.sequence_len
-    )  # make sure we don't accidentally set it larger than sequence_len
+def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path):
     ds_hash = str(
         md5(
             (
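This hunk renames load_prepare_datasets to load_tokenized_prepared_datasets and strips it down to tokenizing and disk caching; the packing-length logic it drops reappears in the new wrapper added in the third hunk below. The sketch that follows illustrates the hash-keyed cache pattern the function is built around, assuming a cfg whose .datasets entries carry .path and .type; build_dataset is a hypothetical stand-in for the tokenization step, not part of this commit.

    import logging
    from hashlib import md5
    from pathlib import Path

    from datasets import load_from_disk


    def load_with_cache(cfg, prepared_path, build_dataset):
        # key the cache on everything that changes the tokenized output
        ds_hash = md5(
            "|".join(sorted(f"{d.path}:{d.type}" for d in cfg.datasets)).encode("utf-8")
        ).hexdigest()
        cache_dir = Path(prepared_path) / ds_hash

        if cache_dir.exists() and any(cache_dir.glob("*")):
            logging.info(f"Loading prepared dataset from disk at {cache_dir}...")
            return load_from_disk(str(cache_dir))

        dataset = build_dataset(cfg)  # tokenize from scratch on a cache miss
        dataset.save_to_disk(str(cache_dir))
        return dataset

Because the key is derived only from the dataset configuration, any change to the dataset list invalidates the cache automatically, while repeated runs with the same config skip tokenization entirely.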
@@ -54,7 +48,7 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
     )

     if any(prepared_ds_path.glob("*")):
-        logging.info(f"Loading prepared dataset from disk ay {prepared_ds_path}...")
+        logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
         dataset = load_from_disk(str(prepared_ds_path))
         logging.info("Prepared dataset loaded from disk...")
     else:
@@ -153,14 +147,78 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
         )
         dataset.save_to_disk(prepared_ds_path)

-    constant_len_dataset = ConstantLengthDataset(
-        tokenizer,
-        [dataset],
-        seq_length=max_packed_sequence_len,
-    )
-    logging.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}")
-    dataset = Dataset.from_list([_ for _ in constant_len_dataset])

     return dataset
+
+
+def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
+    max_packed_sequence_len = (
+        cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
+    )
+    max_packed_sequence_len = min(
+        max_packed_sequence_len, cfg.sequence_len
+    )  # make sure we don't accidentally set it larger than sequence_len
+
+    if cfg.max_packed_sequence_len is not None:
+        # see if we can go ahead and load the stacked dataset
+        ds_hash = str(
+            md5(
+                (
+                    str(cfg.sequence_len)
+                    + "@"
+                    + str(max_packed_sequence_len)
+                    + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
+                ).encode("utf-8")
+            ).hexdigest()
+        )
+        prepared_ds_path = (
+            Path(cfg.dataset_prepared_path) / ds_hash
+            if cfg.dataset_prepared_path
+            else Path(default_dataset_prepared_path) / ds_hash
+        )
+
+        if any(prepared_ds_path.glob("*")):
+            logging.info(
+                f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
+            )
+            dataset = load_from_disk(str(prepared_ds_path))
+            logging.info("Prepared packed dataset loaded from disk...")
+        else:
+            dataset = load_tokenized_prepared_datasets(
+                tokenizer, cfg, default_dataset_prepared_path
+            )
+
+            constant_len_dataset = ConstantLengthDataset(
+                tokenizer,
+                [dataset],
+                seq_length=max_packed_sequence_len,
+            )
+            logging.info(
+                f"packing master dataset to len: {cfg.max_packed_sequence_len}"
+            )
+            dataset = Dataset.from_list([_ for _ in constant_len_dataset])
+
+            if cfg.local_rank == 0:
+                logging.info(
+                    f"Saving packed prepared dataset to disk... {prepared_ds_path}"
+                )
+                dataset.save_to_disk(prepared_ds_path)
+    else:
+        dataset = load_tokenized_prepared_datasets(
+            tokenizer, cfg, default_dataset_prepared_path
+        )
+
+    # filter out bad data
+    dataset = Dataset.from_list(
+        [
+            d
+            for d in dataset
+            if len(d["input_ids"]) <= cfg.sequence_len
+            and len(d["input_ids"]) > 0
+            and len(d["input_ids"]) == len(d["attention_mask"])
+            and len(d["input_ids"]) == len(d["labels"])
+        ]
+    )
+
     if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
         logging.info(
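This hunk is the heart of the commit: the rewritten load_prepare_datasets gives the packed dataset its own cache key (sequence length, packed length, and the sorted dataset list), rebuilds it through ConstantLengthDataset only on a cache miss, and writes to disk only on local_rank == 0 so concurrent distributed workers don't race on the save. The filter at the end drops rows that are empty, longer than cfg.sequence_len, or have mismatched input_ids, attention_mask, and labels lengths, and the tail of the hunk shards the dataset when cfg.dataset_shard_num and cfg.dataset_shard_idx are set. Below is a minimal sketch of that same cleanup using the datasets library's built-in filter and shard methods rather than the list comprehension above; max_len, shard_num, and shard_idx are stand-ins for the cfg fields.

    from datasets import Dataset


    def clean_and_shard(dataset: Dataset, max_len: int, shard_num=None, shard_idx=None):
        # keep only rows that are non-empty, fit the context window, and have
        # input_ids, attention_mask, and labels of equal length
        dataset = dataset.filter(
            lambda d: 0 < len(d["input_ids"]) <= max_len
            and len(d["input_ids"]) == len(d["attention_mask"])
            and len(d["input_ids"]) == len(d["labels"])
        )
        # optionally keep a single 1/shard_num slice of the data
        if shard_num and shard_idx is not None:
            dataset = dataset.shard(num_shards=shard_num, index=shard_idx)
        return dataset

Using filter keeps the Arrow-backed dataset on disk instead of materializing every row into a Python list, which matters once the packed dataset gets large.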
@@ -181,6 +181,8 @@ def load_model(
         for k, v in cfg.tokens.items():
             tokenizer.add_special_tokens({k: v})

+    model.resize_token_embeddings(len(tokenizer))
+
     if cfg.adapter and load_in_8bit and not cfg.load_4bit:
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
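The final hunk matches the second half of the commit title: once the special tokens from cfg.tokens are added, the model's token embeddings must be resized to the new vocabulary size, otherwise the fresh token ids would index past the end of the embedding matrix. A minimal self-contained sketch with Hugging Face transformers (the checkpoint and pad token here are illustrative, not from the commit):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    # adding special tokens grows the vocab past the checkpoint's embedding rows
    tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    # without the resize, the new token ids would index out of bounds in both
    # the input embeddings and the tied output head
    model.resize_token_embeddings(len(tokenizer))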