another fix for shard and train split

This commit is contained in:
Wing Lian
2023-05-25 17:23:57 -04:00
parent be3d3963cd
commit 2e56203b50

View File

@@ -48,7 +48,7 @@ def load_tokenized_prepared_datasets(
(
str(cfg.sequence_len)
+ "@"
+ "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
+ "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
+ "|"
+ tokenizer_name
).encode("utf-8")
@@ -112,13 +112,22 @@ def load_tokenized_prepared_datasets(
raise Exception("unhandled dataset load")
# support for using a subset of the data
if d.shards:
    # Resolved merge conflict: the conflict markers that were committed here
    # made the module unparseable. Keep the variant that copes with loaded
    # data lacking a "train" split, and take the shard count from the
    # per-dataset entry `d` (the value this guard tests) rather than from the
    # global `cfg`.
    if "train" in ds:
        # Shuffle with a fixed seed, select the "train" split, then keep
        # shard 0 out of d.shards.
        ds = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
    else:
        # No "train" key: shuffle and shard the loaded object directly.
        ds = ds.shuffle(seed=42).shard(num_shards=d.shards, index=0)
d_type = d.type
d_type_split = d_type.split(":")
d_base_type = d_type_split[0]
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
if "train" in ds:
ds = ds["train"]
if ds_strategy := load(d.type, tokenizer, cfg):
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "alpaca":
ds_strategy = AlpacaPromptTokenizingStrategy(
@@ -127,7 +136,7 @@ def load_tokenized_prepared_datasets(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "explainchoice":
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
@@ -136,7 +145,7 @@ def load_tokenized_prepared_datasets(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "concisechoice":
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
@@ -145,7 +154,7 @@ def load_tokenized_prepared_datasets(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "summarizetldr":
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
@@ -154,7 +163,7 @@ def load_tokenized_prepared_datasets(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "jeopardy":
ds_strategy = JeopardyPromptTokenizingStrategy(
@@ -163,7 +172,7 @@ def load_tokenized_prepared_datasets(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "oasst":
ds_strategy = OpenAssistantPromptTokenizingStrategy(
@@ -172,7 +181,7 @@ def load_tokenized_prepared_datasets(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "gpteacher":
ds_strategy = GPTeacherPromptTokenizingStrategy(
@@ -181,7 +190,7 @@ def load_tokenized_prepared_datasets(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "reflection":
ds_strategy = AlpacaReflectionPTStrategy(
@@ -190,7 +199,7 @@ def load_tokenized_prepared_datasets(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "sharegpt":
ds_strategy = ShareGPTPromptTokenizingStrategy(
@@ -199,7 +208,7 @@ def load_tokenized_prepared_datasets(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
elif d_base_type == "completion":
ds_strategy = CompletionPromptTokenizingStrategy(
@@ -208,7 +217,7 @@ def load_tokenized_prepared_datasets(
cfg.train_on_inputs,
cfg.sequence_len,
)
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
datasets.append(ds_wrapper)
else:
logging.error(f"unhandled prompt tokenization strategy: {d.type}")
@@ -255,7 +264,7 @@ def load_prepare_datasets(
+ "@"
+ str(max_packed_sequence_len)
+ seed
+ "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
+ "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
+ "|"
+ tokenizer_name
).encode("utf-8")