optimize dataloading to use cache, fix model token embedding sizes

This commit is contained in:
Wing Lian
2023-05-12 13:53:27 -04:00
parent f6d1fa4a85
commit aa3c3f97ae
2 changed files with 74 additions and 14 deletions

View File

@@ -31,13 +31,7 @@ from axolotl.prompters import (
)
def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
max_packed_sequence_len = (
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
)
max_packed_sequence_len = min(
max_packed_sequence_len, cfg.sequence_len
) # make sure we don't accidentally set it larger than sequence_len
def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path):
ds_hash = str(
md5(
(
@@ -54,7 +48,7 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
)
if any(prepared_ds_path.glob("*")):
logging.info(f"Loading prepared dataset from disk ay {prepared_ds_path}...")
logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset = load_from_disk(str(prepared_ds_path))
logging.info("Prepared dataset loaded from disk...")
else:
@@ -153,14 +147,78 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
)
dataset.save_to_disk(prepared_ds_path)
return dataset
def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
max_packed_sequence_len = (
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
)
max_packed_sequence_len = min(
max_packed_sequence_len, cfg.sequence_len
) # make sure we don't accidentally set it larger than sequence_len
if cfg.max_packed_sequence_len is not None:
constant_len_dataset = ConstantLengthDataset(
tokenizer,
[dataset],
seq_length=max_packed_sequence_len,
# see if we can go ahead and load the stacked dataset
ds_hash = str(
md5(
(
str(cfg.sequence_len)
+ "@"
+ str(max_packed_sequence_len)
+ "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
).encode("utf-8")
).hexdigest()
)
logging.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}")
dataset = Dataset.from_list([_ for _ in constant_len_dataset])
prepared_ds_path = (
Path(cfg.dataset_prepared_path) / ds_hash
if cfg.dataset_prepared_path
else Path(default_dataset_prepared_path) / ds_hash
)
if any(prepared_ds_path.glob("*")):
logging.info(
f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
)
dataset = load_from_disk(str(prepared_ds_path))
logging.info("Prepared packed dataset loaded from disk...")
else:
dataset = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path
)
constant_len_dataset = ConstantLengthDataset(
tokenizer,
[dataset],
seq_length=max_packed_sequence_len,
)
logging.info(
f"packing master dataset to len: {cfg.max_packed_sequence_len}"
)
dataset = Dataset.from_list([_ for _ in constant_len_dataset])
if cfg.local_rank == 0:
logging.info(
f"Saving packed prepared dataset to disk... {prepared_ds_path}"
)
dataset.save_to_disk(prepared_ds_path)
else:
dataset = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path
)
# filter out bad data
dataset = Dataset.from_list(
[
d
for d in dataset
if len(d["input_ids"]) > cfg.sequence_len
and len(d["input_ids"]) > 0
and len(d["input_ids"]) == len(d["attention_mask"])
and len(d["input_ids"]) == len(d["labels"])
]
)
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
logging.info(

View File

@@ -181,6 +181,8 @@ def load_model(
for k, v in cfg.tokens.items():
tokenizer.add_special_tokens({k: v})
model.resize_token_embeddings(len(tokenizer))
if cfg.adapter and load_in_8bit and not cfg.load_4bit:
logging.info("converting PEFT model w/ prepare_model_for_int8_training")
model = prepare_model_for_int8_training(model)