another fix for shard and train split
This commit is contained in:
@@ -48,7 +48,7 @@ def load_tokenized_prepared_datasets(
|
||||
(
|
||||
str(cfg.sequence_len)
|
||||
+ "@"
|
||||
+ "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
|
||||
+ "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
|
||||
+ "|"
|
||||
+ tokenizer_name
|
||||
).encode("utf-8")
|
||||
@@ -112,13 +112,22 @@ def load_tokenized_prepared_datasets(
|
||||
raise Exception("unhandled dataset load")
|
||||
# support for using a subset of the data
|
||||
if d.shards:
|
||||
<<<<<<< Updated upstream
|
||||
ds = ds.shuffle(seed=42)["train"].shard(num_shards=d.shards, index=0)
|
||||
=======
|
||||
if "train" in ds:
|
||||
ds = ds.shuffle(seed=42)["train"].shard(num_shards=cfg.shards, index=0)
|
||||
else:
|
||||
ds = ds.shuffle(seed=42).shard(num_shards=cfg.shards, index=0)
|
||||
>>>>>>> Stashed changes
|
||||
d_type = d.type
|
||||
d_type_split = d_type.split(":")
|
||||
d_base_type = d_type_split[0]
|
||||
d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
|
||||
if "train" in ds:
|
||||
ds = ds["train"]
|
||||
if ds_strategy := load(d.type, tokenizer, cfg):
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "alpaca":
|
||||
ds_strategy = AlpacaPromptTokenizingStrategy(
|
||||
@@ -127,7 +136,7 @@ def load_tokenized_prepared_datasets(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "explainchoice":
|
||||
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
||||
@@ -136,7 +145,7 @@ def load_tokenized_prepared_datasets(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "concisechoice":
|
||||
ds_strategy = AlpacaMultipleChoicePromptTokenizingStrategy(
|
||||
@@ -145,7 +154,7 @@ def load_tokenized_prepared_datasets(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "summarizetldr":
|
||||
ds_strategy = SummarizeTLDRPromptTokenizingStrategy(
|
||||
@@ -154,7 +163,7 @@ def load_tokenized_prepared_datasets(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "jeopardy":
|
||||
ds_strategy = JeopardyPromptTokenizingStrategy(
|
||||
@@ -163,7 +172,7 @@ def load_tokenized_prepared_datasets(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "oasst":
|
||||
ds_strategy = OpenAssistantPromptTokenizingStrategy(
|
||||
@@ -172,7 +181,7 @@ def load_tokenized_prepared_datasets(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "gpteacher":
|
||||
ds_strategy = GPTeacherPromptTokenizingStrategy(
|
||||
@@ -181,7 +190,7 @@ def load_tokenized_prepared_datasets(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "reflection":
|
||||
ds_strategy = AlpacaReflectionPTStrategy(
|
||||
@@ -190,7 +199,7 @@ def load_tokenized_prepared_datasets(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "sharegpt":
|
||||
ds_strategy = ShareGPTPromptTokenizingStrategy(
|
||||
@@ -199,7 +208,7 @@ def load_tokenized_prepared_datasets(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
elif d_base_type == "completion":
|
||||
ds_strategy = CompletionPromptTokenizingStrategy(
|
||||
@@ -208,7 +217,7 @@ def load_tokenized_prepared_datasets(
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds["train"])
|
||||
ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
|
||||
datasets.append(ds_wrapper)
|
||||
else:
|
||||
logging.error(f"unhandled prompt tokenization strategy: {d.type}")
|
||||
@@ -255,7 +264,7 @@ def load_prepare_datasets(
|
||||
+ "@"
|
||||
+ str(max_packed_sequence_len)
|
||||
+ seed
|
||||
+ "|".join(sorted([f"{d.path}:{d.type}" for d in cfg.datasets]))
|
||||
+ "|".join(sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets]))
|
||||
+ "|"
|
||||
+ tokenizer_name
|
||||
).encode("utf-8")
|
||||
|
||||
Reference in New Issue
Block a user