optimize dataloading to use cache, fix model token embedding sizes

This commit is contained in:
Wing Lian
2023-05-12 13:53:27 -04:00
parent f6d1fa4a85
commit aa3c3f97ae
2 changed files with 74 additions and 14 deletions

View File

@@ -31,13 +31,7 @@ from axolotl.prompters import (
) )
def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path): def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path):
max_packed_sequence_len = (
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
)
max_packed_sequence_len = min(
max_packed_sequence_len, cfg.sequence_len
) # make sure we don't accidentally set it larger than sequence_len
ds_hash = str( ds_hash = str(
md5( md5(
( (
@@ -54,7 +48,7 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
) )
if any(prepared_ds_path.glob("*")): if any(prepared_ds_path.glob("*")):
logging.info(f"Loading prepared dataset from disk ay {prepared_ds_path}...") logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
dataset = load_from_disk(str(prepared_ds_path)) dataset = load_from_disk(str(prepared_ds_path))
logging.info("Prepared dataset loaded from disk...") logging.info("Prepared dataset loaded from disk...")
else: else:
@@ -153,14 +147,78 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
) )
dataset.save_to_disk(prepared_ds_path) dataset.save_to_disk(prepared_ds_path)
return dataset
def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
max_packed_sequence_len = (
cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
)
max_packed_sequence_len = min(
max_packed_sequence_len, cfg.sequence_len
) # make sure we don't accidentally set it larger than sequence_len
if cfg.max_packed_sequence_len is not None: if cfg.max_packed_sequence_len is not None:
constant_len_dataset = ConstantLengthDataset( # see if we can go ahead and load the stacked dataset
tokenizer,
[dataset], ds_hash = str(
seq_length=max_packed_sequence_len, md5(
(
str(cfg.sequence_len)
+ "@"
+ str(max_packed_sequence_len)
+ "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
).encode("utf-8")
).hexdigest()
) )
logging.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}") prepared_ds_path = (
dataset = Dataset.from_list([_ for _ in constant_len_dataset]) Path(cfg.dataset_prepared_path) / ds_hash
if cfg.dataset_prepared_path
else Path(default_dataset_prepared_path) / ds_hash
)
if any(prepared_ds_path.glob("*")):
logging.info(
f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
)
dataset = load_from_disk(str(prepared_ds_path))
logging.info("Prepared packed dataset loaded from disk...")
else:
dataset = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path
)
constant_len_dataset = ConstantLengthDataset(
tokenizer,
[dataset],
seq_length=max_packed_sequence_len,
)
logging.info(
f"packing master dataset to len: {cfg.max_packed_sequence_len}"
)
dataset = Dataset.from_list([_ for _ in constant_len_dataset])
if cfg.local_rank == 0:
logging.info(
f"Saving packed prepared dataset to disk... {prepared_ds_path}"
)
dataset.save_to_disk(prepared_ds_path)
else:
dataset = load_tokenized_prepared_datasets(
tokenizer, cfg, default_dataset_prepared_path
)
# Filter out bad data: keep only rows that are non-empty, whose
# input_ids / attention_mask / labels all have the same length, and that
# fit within cfg.sequence_len.
# The length test was previously `> cfg.sequence_len`, which kept only
# over-length rows and discarded every row that actually fits.
# NOTE(review): rows of exactly cfg.sequence_len are kept on the assumption
# that they fit the model context; tighten to `<` if rows at the limit are
# considered truncated and unwanted.
dataset = Dataset.from_list(
    [
        d
        for d in dataset
        if 0 < len(d["input_ids"]) <= cfg.sequence_len
        and len(d["input_ids"]) == len(d["attention_mask"])
        and len(d["input_ids"]) == len(d["labels"])
    ]
)
if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None: if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
logging.info( logging.info(

View File

@@ -181,6 +181,8 @@ def load_model(
for k, v in cfg.tokens.items(): for k, v in cfg.tokens.items():
tokenizer.add_special_tokens({k: v}) tokenizer.add_special_tokens({k: v})
model.resize_token_embeddings(len(tokenizer))
if cfg.adapter and load_in_8bit and not cfg.load_4bit: if cfg.adapter and load_in_8bit and not cfg.load_4bit:
logging.info("converting PEFT model w/ prepare_model_for_int8_training") logging.info("converting PEFT model w/ prepare_model_for_int8_training")
model = prepare_model_for_int8_training(model) model = prepare_model_for_int8_training(model)