diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index c36bfcee9..7e545b608 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -37,7 +37,7 @@ from axolotl.prompters import (
 
 
 def load_tokenized_prepared_datasets(
-    tokenizer, cfg, default_dataset_prepared_path
+    split, tokenizer, cfg, default_dataset_prepared_path
 ) -> DatasetDict:
     tokenizer_name = tokenizer.__class__.__name__
     ds_hash = str(
@@ -49,6 +49,8 @@ def load_tokenized_prepared_datasets(
                 sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
             )
             + "|"
+            + split
+            + "|"
             + tokenizer_name
         ).encode("utf-8")
     ).hexdigest()
@@ -66,7 +68,7 @@ def load_tokenized_prepared_datasets(
                 f"{cfg.push_dataset_to_hub}/{ds_hash}",
                 use_auth_token=use_auth_token,
             )
-            dataset = dataset["train"]
+            dataset = dataset[split]
         except Exception:  # pylint: disable=broad-except  # nosec
             pass
 
@@ -134,8 +136,8 @@ def load_tokenized_prepared_datasets(
                     raise ValueError("unhandled dataset load")
                 # support for using a subset of the data
                 if d.shards:
-                    if "train" in ds:
-                        ds = ds.shuffle(seed=seed)["train"].shard(
+                    if split in ds:
+                        ds = ds.shuffle(seed=seed)[split].shard(
                             num_shards=d.shards, index=0
                         )
                     else:
@@ -144,8 +146,8 @@
                 d_type_split = d_type.split(":")
                 d_base_type = d_type_split[0]
                 d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
-                if "train" in ds:
-                    ds = ds["train"]
+                if split in ds:
+                    ds = ds[split]
                 if ds_strategy := load(d.type, tokenizer, cfg):
                     ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
                     datasets.append(ds_wrapper)
@@ -319,7 +321,6 @@ def load_prepare_datasets(
                 f"{cfg.push_dataset_to_hub}/{ds_hash}",
                 use_auth_token=use_auth_token,
             )
-            dataset = dataset["train"]
         except Exception:  # pylint: disable=broad-except  # nosec
             pass
 
@@ -339,28 +340,37 @@ def load_prepare_datasets(
                     f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
                 )
         else:
-            dataset = load_tokenized_prepared_datasets(
-                tokenizer, cfg, default_dataset_prepared_path
+            dataset_train = load_tokenized_prepared_datasets(
+                "train", tokenizer, cfg, default_dataset_prepared_path
             )
-
+            dataset_test = load_tokenized_prepared_datasets(
+                "test", tokenizer, cfg, default_dataset_prepared_path
+            )
+            dataset = DatasetDict({"train": dataset_train, "test": dataset_test})
             if cfg.seed:
                 dataset = dataset.shuffle(seed=cfg.seed)
 
-            constant_len_dataset = ConstantLengthDataset(
+            constant_len_dataset_train = ConstantLengthDataset(
                 tokenizer,
-                [dataset],
+                [dataset["train"]],
+                seq_length=max_packed_sequence_len,
+            )
+            constant_len_dataset_test = ConstantLengthDataset(
+                tokenizer,
+                [dataset["test"]],
                 seq_length=max_packed_sequence_len,
             )
             logging.info(
                 f"packing master dataset to len: {cfg.max_packed_sequence_len}"
             )
-            dataset = Dataset.from_list(list(constant_len_dataset))
+            dataset_train = Dataset.from_list(list(constant_len_dataset_train))
+            dataset_test = Dataset.from_list(list(constant_len_dataset_test))
 
             # filter out bad data
-            dataset = Dataset.from_list(
+            dataset_train = Dataset.from_list(
                 [
                     d
-                    for d in dataset
+                    for d in dataset_train
                     if len(d["input_ids"]) < cfg.sequence_len
                     and len(d["input_ids"]) > 0
                     and len(d["input_ids"]) == len(d["attention_mask"])
@@ -368,6 +378,19 @@
                     and len(d["input_ids"]) == len(d["labels"])
                 ]
             )
+            # filter out bad data
+            dataset_test = Dataset.from_list(
+                [
+                    d
+                    for d in dataset_test
+                    if len(d["input_ids"]) < cfg.sequence_len
+                    and len(d["input_ids"]) > 0
+                    and len(d["input_ids"]) == len(d["attention_mask"])
+                    and len(d["input_ids"]) == len(d["labels"])
+                ]
+            )
+            dataset = DatasetDict({"train": dataset_train, "test": dataset_test})
+
             if cfg.local_rank == 0:
                 logging.info(
                     f"Saving packed prepared dataset to disk... {prepared_ds_path}"
@@ -382,9 +405,13 @@
                     private=True,
                 )
     else:
-        dataset = load_tokenized_prepared_datasets(
-            tokenizer, cfg, default_dataset_prepared_path
+        dataset_train = load_tokenized_prepared_datasets(
+            "train", tokenizer, cfg, default_dataset_prepared_path
         )
+        dataset_test = load_tokenized_prepared_datasets(
+            "test", tokenizer, cfg, default_dataset_prepared_path
+        )
+        dataset = DatasetDict({"train": dataset_train, "test": dataset_test})
 
     if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
         logging.info(
@@ -399,6 +426,9 @@
         dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
         train_dataset = dataset["train"]
         eval_dataset = dataset["test"]
+    elif "train" in dataset:
+        train_dataset = dataset["train"]
+        eval_dataset = dataset["test"]
     else:
         train_dataset = dataset
         eval_dataset = None
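
The cache-key change in the first two hunks is the crux of the split-awareness: mixing the requested split into the md5 digest gives "train" and "test" preparations distinct prepared-dataset paths instead of colliding on one cache entry. A minimal standalone sketch of that idea follows; `datasets_spec` and `tokenizer_name` are placeholder values standing in for the cfg fields and tokenizer class name the real function hashes, not axolotl APIs.

```python
# Sketch only: mirrors the hashing scheme in load_tokenized_prepared_datasets.
from hashlib import md5


def ds_cache_hash(split: str) -> str:
    # Placeholder for sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
    datasets_spec = sorted(["data/alpaca.jsonl:alpaca:None"])
    tokenizer_name = "LlamaTokenizer"  # placeholder for tokenizer.__class__.__name__
    key = "|".join(datasets_spec) + "|" + split + "|" + tokenizer_name
    return md5(key.encode("utf-8")).hexdigest()


# Each split now resolves to its own prepared-dataset cache entry.
assert ds_cache_hash("train") != ds_cache_hash("test")
```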
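The final hunk changes which branch supplies `eval_dataset`: when `cfg.val_set_size` is unset but the prepared dataset is a `DatasetDict` carrying explicit splits, the "test" split is now used directly instead of `eval_dataset` falling through to `None`. A hedged sketch of that selection logic, where `pick_splits` is a hypothetical helper written for illustration rather than part of axolotl:

```python
# Sketch only: the train/eval selection at the end of load_prepare_datasets.
from datasets import Dataset, DatasetDict


def pick_splits(dataset, val_set_size=None):
    if val_set_size:
        # train_test_split carves an eval set out of a bare Dataset
        parts = dataset.train_test_split(test_size=val_set_size, shuffle=False)
        return parts["train"], parts["test"]
    if "train" in dataset:  # the new elif: the DatasetDict already has splits
        return dataset["train"], dataset["test"]
    return dataset, None  # bare Dataset and no val_set_size: no eval set


dd = DatasetDict(
    {
        "train": Dataset.from_dict({"input_ids": [[1, 2, 3]]}),
        "test": Dataset.from_dict({"input_ids": [[4, 5]]}),
    }
)
train_dataset, eval_dataset = pick_splits(dd)  # uses the prepared "test" split
```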