diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 581b48a88..a168c5247 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -31,13 +31,7 @@ from axolotl.prompters import (
 )
 
 
-def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
-    max_packed_sequence_len = (
-        cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
-    )
-    max_packed_sequence_len = min(
-        max_packed_sequence_len, cfg.sequence_len
-    )  # make sure we don't accidentally set it larger than sequence_len
+def load_tokenized_prepared_datasets(tokenizer, cfg, default_dataset_prepared_path):
     ds_hash = str(
         md5(
             (
@@ -54,7 +48,7 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
     )
 
     if any(prepared_ds_path.glob("*")):
-        logging.info(f"Loading prepared dataset from disk ay {prepared_ds_path}...")
+        logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
         dataset = load_from_disk(str(prepared_ds_path))
         logging.info("Prepared dataset loaded from disk...")
     else:
@@ -153,14 +147,78 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
             )
             dataset.save_to_disk(prepared_ds_path)
 
+    return dataset
+
+
+def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
+    max_packed_sequence_len = (
+        cfg.max_packed_sequence_len if cfg.max_packed_sequence_len else cfg.sequence_len
+    )
+    max_packed_sequence_len = min(
+        max_packed_sequence_len, cfg.sequence_len
+    )  # make sure we don't accidentally set it larger than sequence_len
+
     if cfg.max_packed_sequence_len is not None:
-        constant_len_dataset = ConstantLengthDataset(
-            tokenizer,
-            [dataset],
-            seq_length=max_packed_sequence_len,
+        # see if we can go ahead and load the stacked dataset
+
+        ds_hash = str(
+            md5(
+                (
+                    str(cfg.sequence_len)
+                    + "@"
+                    + str(max_packed_sequence_len)
+                    + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
+                ).encode("utf-8")
+            ).hexdigest()
         )
-        logging.info(f"packing master dataset to len: {cfg.max_packed_sequence_len}")
-        dataset = Dataset.from_list([_ for _ in constant_len_dataset])
+        prepared_ds_path = (
+            Path(cfg.dataset_prepared_path) / ds_hash
+            if cfg.dataset_prepared_path
+            else Path(default_dataset_prepared_path) / ds_hash
+        )
+
+        if any(prepared_ds_path.glob("*")):
+            logging.info(
+                f"Loading prepared packed dataset from disk at {prepared_ds_path}..."
+            )
+            dataset = load_from_disk(str(prepared_ds_path))
+            logging.info("Prepared packed dataset loaded from disk...")
+        else:
+            dataset = load_tokenized_prepared_datasets(
+                tokenizer, cfg, default_dataset_prepared_path
+            )
+
+            constant_len_dataset = ConstantLengthDataset(
+                tokenizer,
+                [dataset],
+                seq_length=max_packed_sequence_len,
+            )
+            logging.info(
+                f"packing master dataset to len: {cfg.max_packed_sequence_len}"
+            )
+            dataset = Dataset.from_list([_ for _ in constant_len_dataset])
+
+            if cfg.local_rank == 0:
+                logging.info(
+                    f"Saving packed prepared dataset to disk... {prepared_ds_path}"
+                )
+                dataset.save_to_disk(prepared_ds_path)
+    else:
+        dataset = load_tokenized_prepared_datasets(
+            tokenizer, cfg, default_dataset_prepared_path
+        )
+
+    # filter out bad data
+    dataset = Dataset.from_list(
+        [
+            d
+            for d in dataset
+            if len(d["input_ids"]) < cfg.sequence_len
+            and len(d["input_ids"]) > 0
+            and len(d["input_ids"]) == len(d["attention_mask"])
+            and len(d["input_ids"]) == len(d["labels"])
+        ]
+    )
 
     if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
         logging.info(
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 2ca84b795..bec6d8194 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -181,6 +181,8 @@ def load_model(
         for k, v in cfg.tokens.items():
             tokenizer.add_special_tokens({k: v})
 
+    model.resize_token_embeddings(len(tokenizer))
+
     if cfg.adapter and load_in_8bit and not cfg.load_4bit:
         logging.info("converting PEFT model w/ prepare_model_for_int8_training")
         model = prepare_model_for_int8_training(model)
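
The packed-dataset branch builds its own cache key, separate from the tokenized-dataset cache, so changing `sequence_len` or `max_packed_sequence_len` invalidates only the packed artifacts. A minimal sketch of the same keying scheme; the helper name `packed_ds_cache_path` and the example arguments are illustrative, not axolotl APIs:

```python
from hashlib import md5
from pathlib import Path

def packed_ds_cache_path(base_dir, sequence_len, max_packed_sequence_len, datasets):
    # datasets: iterable of (path, type) pairs; sorting makes the key
    # independent of dataset ordering, mirroring the sorted() call above
    key = (
        str(sequence_len)
        + "@"
        + str(max_packed_sequence_len)
        + "|".join(sorted(f"{path}:{dtype}" for path, dtype in datasets))
    )
    return Path(base_dir) / md5(key.encode("utf-8")).hexdigest()

print(packed_ds_cache_path("prepared", 2048, 2048, [("some/dataset", "alpaca")]))
```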
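
`ConstantLengthDataset`'s internals are not part of this diff; as a rough illustration of what "packing master dataset to len" means, the sketch below concatenates tokenized examples and emits fixed-length rows. This is a simplification assuming plain concatenation, not axolotl's implementation, which may handle example boundaries and masking differently:

```python
def pack_examples(examples, seq_length):
    # Illustrative packing: concatenate token ids across examples and yield
    # fixed-size chunks; a leftover tail shorter than seq_length is dropped.
    buffer = []
    for ex in examples:
        buffer.extend(ex["input_ids"])
        while len(buffer) >= seq_length:
            chunk, buffer = buffer[:seq_length], buffer[seq_length:]
            yield {
                "input_ids": chunk,
                "attention_mask": [1] * seq_length,
                "labels": list(chunk),
            }

rows = list(pack_examples([{"input_ids": [1, 2, 3]}, {"input_ids": [4, 5]}], 4))
# -> one row with input_ids [1, 2, 3, 4]; the trailing [5] stays in the buffer
```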
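
The "filter out bad data" step keeps only rows that are non-empty, fit within `sequence_len`, and have `input_ids`, `attention_mask`, and `labels` of equal length. The same checks pulled out as a standalone predicate (the helper name is ours, for readability and testing only):

```python
def is_valid_row(d, sequence_len):
    n = len(d["input_ids"])
    return (
        0 < n < sequence_len  # non-empty and short enough to fit the context
        and n == len(d["attention_mask"])  # parallel fields must stay aligned
        and n == len(d["labels"])
    )

assert is_valid_row(
    {"input_ids": [1, 2], "attention_mask": [1, 1], "labels": [1, 2]}, 2048
)
assert not is_valid_row(
    {"input_ids": [1, 2], "attention_mask": [1], "labels": [1, 2]}, 2048
)
```

Rows failing any of these checks would otherwise break collation or attention downstream.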
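
On the models.py side, calling `resize_token_embeddings(len(tokenizer))` after `tokenizer.add_special_tokens(...)` grows the embedding matrix (and any tied output head) so that newly added token ids have rows to look up; without it, a new id would index past the end of the embedding table at runtime. A minimal standalone example using the standard transformers API; the gpt2 checkpoint and `[PAD]` token are just for illustration:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# gpt2 ships without a pad token; adding one extends the vocab by 1
num_added = tokenizer.add_special_tokens({"pad_token": "[PAD]"})
if num_added > 0:
    # grow the input embeddings (and tied lm_head) to the new vocab size
    model.resize_token_embeddings(len(tokenizer))

assert model.get_input_embeddings().num_embeddings == len(tokenizer)
```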