address PR feedback

Wing Lian
2023-06-10 14:21:43 -04:00
parent eea2731a5e
commit 0c6f928601
5 changed files with 9 additions and 8 deletions

@@ -208,7 +208,10 @@ def train(
         )
     else:
         train_dataset = load_pretraining_dataset(
-            cfg.pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len
+            cfg.pretraining_dataset,
+            tokenizer,
+            max_tokens=cfg.sequence_len,
+            seed=cfg.seed,
         )
         # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
         train_dataset = train_dataset.with_format("torch")
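
The hunk above threads `cfg.seed` into `load_pretraining_dataset`. The body of that function is not shown here, so what follows is only a sketch of one plausible reason to pass a seed: making the shuffle order of a streamed dataset reproducible across runs. The helper `seeded_buffer_shuffle` and its `buffer_size` parameter are hypothetical and use only the Python standard library; they are not taken from this repository or from any library.

```python
import random
from typing import Iterable, Iterator, List, Optional, TypeVar

T = TypeVar("T")


def seeded_buffer_shuffle(
    stream: Iterable[T], buffer_size: int, seed: Optional[int]
) -> Iterator[T]:
    """Approximately shuffle a stream using a fixed-size buffer.

    Hypothetical helper for illustration only. With the same seed and the
    same input order, the output order is identical on every run.
    """
    if buffer_size < 1:
        raise ValueError("buffer_size must be at least 1")
    rng = random.Random(seed)  # private RNG, so global random state is untouched
    buffer: List[T] = []
    for item in stream:
        if len(buffer) < buffer_size:
            # Fill the buffer before emitting anything.
            buffer.append(item)
            continue
        # Emit a random buffered item and put the new item in its slot.
        idx = rng.randrange(buffer_size)
        yield buffer[idx]
        buffer[idx] = item
    # Drain whatever is left once the stream ends.
    rng.shuffle(buffer)
    yield from buffer


first = list(seeded_buffer_shuffle(range(20), buffer_size=5, seed=42))
second = list(seeded_buffer_shuffle(range(20), buffer_size=5, seed=42))
print(first == second)  # same seed, same order
```

Passing `seed=None` in this sketch seeds the generator from system entropy, so the order would generally differ between runs; a fixed value taken from the config keeps runs comparable.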