diff --git a/src/axolotl/utils/data.py b/src/axolotl/utils/data.py
index 0375cf9db..12b4f74a0 100644
--- a/src/axolotl/utils/data.py
+++ b/src/axolotl/utils/data.py
@@ -48,7 +48,7 @@ def load_tokenized_prepared_datasets(
             (
                 str(cfg.sequence_len)
                 + "@"
-                + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
+                + "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
                 + "|"
                 + tokenizer_name
             ).encode("utf-8")
@@ -112,13 +112,19 @@ def load_tokenized_prepared_datasets(
                 raise Exception("unhandled dataset load")
             # support for using a subset of the data
             if d.shards:
-                ds = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
+                if "train" in ds:
+                    ds = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
+                else:
+                    ds = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
             d_type = d.type
             d_type_split = d_type.split(":")
             d_base_type = d_type_split[0]
             d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
+            # take the "train" split when ds is still keyed by split name
+            if "train" in ds:
+                ds = ds["train"]
             if ds_strategy := load(d.type, tokenizer, cfg):
-                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
                 datasets.append(ds_wrapper)
             elif d_base_type == "alpaca":
                 ds_strategy = AlpacaPromptTokenizingStrategy(
@@ -127,7 +133,7 @@ def load_tokenized_prepared_datasets(
                     cfg.train_on_inputs,
                     cfg.sequence_len,
                 )
-                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
                 datasets.append(ds_wrapper)
             elif d_base_type == "explainchoice":
                 ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
@@ -136,7 +142,7 @@ def load_tokenized_prepared_datasets(
                     cfg.train_on_inputs,
                     cfg.sequence_len,
                 )
-                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
                 datasets.append(ds_wrapper)
             elif d_base_type == "concisechoice":
                 ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
@@ -145,7 +151,7 @@ def load_tokenized_prepared_datasets(
                     cfg.train_on_inputs,
                     cfg.sequence_len,
                 )
-                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
                 datasets.append(ds_wrapper)
             elif d_base_type == "summarizetldr":
                 ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
@@ -154,7 +160,7 @@ def load_tokenized_prepared_datasets(
                     cfg.train_on_inputs,
                     cfg.sequence_len,
                 )
-                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
                 datasets.append(ds_wrapper)
             elif d_base_type == "jeopardy":
                 ds_strategy = JeopardyPromptTokenizingStrategy(
@@ -163,7 +169,7 @@ def load_tokenized_prepared_datasets(
                     cfg.train_on_inputs,
                     cfg.sequence_len,
                 )
-                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
                 datasets.append(ds_wrapper)
             elif d_base_type == "oasst":
                 ds_strategy = OpenAssistantPromptTokenizingStrategy(
@@ -172,7 +178,7 @@ def load_tokenized_prepared_datasets(
                     cfg.train_on_inputs,
                     cfg.sequence_len,
                 )
-                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
                 datasets.append(ds_wrapper)
             elif d_base_type == "gpteacher":
                 ds_strategy = GPTeacherPromptTokenizingStrategy(
@@ -181,7 +187,7 @@ def load_tokenized_prepared_datasets(
                     cfg.train_on_inputs,
                     cfg.sequence_len,
                 )
-                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
                 datasets.append(ds_wrapper)
             elif d_base_type == "reflection":
                 ds_strategy = AlpacaReflectionPTStrategy(
@@ -190,7 +196,7 @@ def load_tokenized_prepared_datasets(
                     cfg.train_on_inputs,
                     cfg.sequence_len,
                 )
-                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
                 datasets.append(ds_wrapper)
             elif d_base_type == "sharegpt":
                 ds_strategy = ShareGPTPromptTokenizingStrategy(
@@ -199,7 +205,7 @@ def load_tokenized_prepared_datasets(
                     cfg.train_on_inputs,
                     cfg.sequence_len,
                 )
-                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
                 datasets.append(ds_wrapper)
             elif d_base_type == "completion":
                 ds_strategy = CompletionPromptTokenizingStrategy(
@@ -208,7 +214,7 @@ def load_tokenized_prepared_datasets(
                     cfg.train_on_inputs,
                     cfg.sequence_len,
                 )
-                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
+                ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
                 datasets.append(ds_wrapper)
             else:
                 logging.error(f"unhandled prompt tokenization strategy: {d.type}")
@@ -255,7 +261,7 @@ def load_prepare_datasets(
                     + "@"
                     + str(max_packed_sequence_len)
                     + seed
-                    + "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
+                    + "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
                     + "|"
                     + tokenizer_name
                 ).encode("utf-8")