address PR feedback

This commit is contained in:
Wing Lian
2023-06-10 14:21:43 -04:00
parent eea2731a5e
commit 0c6f928601
5 changed files with 9 additions and 8 deletions

View File

@@ -1,4 +1,4 @@
# Python 12B # Pythia 12B
- Single-GPU A100 only (?) - Single-GPU A100 only (?)

View File

@@ -22,7 +22,7 @@ lora_dropout: 0.0
lora_target_modules: lora_target_modules:
lora_target_linear: true lora_target_linear: true
lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific lora_fan_in_fan_out: true # pythia/GPTNeoX lora specific
wandb_project: pythia-12b wandb_project:
wandb_watch: wandb_watch:
wandb_run_id: wandb_run_id:
wandb_log_model: wandb_log_model:
@@ -45,5 +45,5 @@ resume_from_checkpoint:
local_rank: local_rank:
gradient_checkpointing: true gradient_checkpointing: true
fsdp: fsdp:
fsdp_transformer_layer_cls_to_wrap: fsdp_config:
collator_pad_to_longest: true collator_pad_to_longest: true

View File

@@ -208,7 +208,10 @@ def train(
) )
else: else:
train_dataset = load_pretraining_dataset( train_dataset = load_pretraining_dataset(
cfg.pretraining_dataset, tokenizer, max_tokens=cfg.sequence_len cfg.pretraining_dataset,
tokenizer,
max_tokens=cfg.sequence_len,
seed=cfg.seed,
) )
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230 # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch") train_dataset = train_dataset.with_format("torch")

View File

@@ -505,10 +505,10 @@ def encode_pretraining(tokenizer, max_tokens, examples):
return ret return ret
def load_pretraining_dataset(path, tokenizer, max_tokens=2048): def load_pretraining_dataset(path, tokenizer, max_tokens=2048, seed=42):
encode = functools.partial(encode_pretraining, tokenizer, max_tokens) encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
dataset = load_dataset(path, streaming=True, split="train") dataset = load_dataset(path, streaming=True, split="train")
dataset = dataset.shuffle(seed=42, buffer_size=10_000) dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
# TODO dynamically figure out which columns/features to remove # TODO dynamically figure out which columns/features to remove
dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"]) dataset = dataset.map(encode, batched=True, remove_columns=["text", "meta"])
return dataset return dataset

View File

@@ -1,7 +1,6 @@
"""Module containing the Trainer class and related functions""" """Module containing the Trainer class and related functions"""
import importlib import importlib
import logging
import math import math
import os import os
import sys import sys
@@ -232,7 +231,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
callbacks.append(SavePeftModelCallback) callbacks.append(SavePeftModelCallback)
if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True: if hasattr(model, "use_bettertransformer") and model.use_bettertransformer is True:
logging.info("Setting up SaveBetterTransformerModelCallback.")
callbacks.append(SaveBetterTransformerModelCallback) callbacks.append(SaveBetterTransformerModelCallback)
data_collator_kwargs = { data_collator_kwargs = {