tweaks to data loading, 8 bit adam, accelerate and deepspeed

Wing Lian
2023-04-22 16:25:23 -04:00
parent 4f2584f2dc
commit 097d367af6
4 changed files with 87 additions and 19 deletions

View File

@@ -2,7 +2,7 @@ import logging
 from hashlib import md5
 from pathlib import Path
-from datasets import load_from_disk, load_dataset, IterableDataset, Dataset
+from datasets import load_from_disk, load_dataset, IterableDataset, Dataset, concatenate_datasets
 from huggingface_hub import hf_hub_download
 from axolotl.datasets import TokenizedPromptDataset, ConstantLengthDataset
@@ -44,10 +44,11 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
     )
     if any(prepared_ds_path.glob("*")):
-        logging.info("Loading prepared dataset from disk...")
+        logging.info(f"Loading prepared dataset from disk at {prepared_ds_path}...")
         dataset = load_from_disk(str(prepared_ds_path))
         logging.info("Prepared dataset loaded from disk...")
     else:
+        logging.info(f"Unable to find prepared dataset in {prepared_ds_path}")
         logging.info("Loading raw datasets...")
         datasets = []
         for d in cfg.datasets:
@@ -113,18 +114,26 @@ def load_prepare_datasets(tokenizer, cfg, default_dataset_prepared_path):
                 datasets.append(ds_wrapper)
             else:
                 logging.error(f"unhandled prompt tokenization strategy: {d.type}")
-        constant_len_dataset = ConstantLengthDataset(
-            tokenizer,
-            datasets,
-            seq_length=max_packed_sequence_len,
-        )
-        logging.info("merging, packing, shuffling, and splitting master dataset")
-        dataset = Dataset.from_list([_ for _ in constant_len_dataset]).shuffle(seed=42)
+        logging.info("merging and shuffling master dataset")
+        dataset = concatenate_datasets(datasets).shuffle(seed=42)
         if cfg.local_rank == 0:
-            logging.info(f"Saving prepared dataset to disk... {prepared_ds_path}")
+            logging.info(f"Saving merged prepared dataset to disk... {prepared_ds_path}")
             dataset.save_to_disk(prepared_ds_path)
+
+    if cfg.max_packed_sequence_len is not None:
+        constant_len_dataset = ConstantLengthDataset(
+            tokenizer,
+            [dataset],
+            seq_length=max_packed_sequence_len,
+        )
+        logging.info("packing master dataset")
+        dataset = Dataset.from_list([_ for _ in constant_len_dataset])
+
+    if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
+        logging.info(f"Using index #{cfg.dataset_shard_idx} of {cfg.dataset_shard_num} shards")
+        dataset = dataset.shard(num_shards=cfg.dataset_shard_num, index=cfg.dataset_shard_idx)
 
     dataset = dataset.train_test_split(
         test_size=cfg.val_set_size, shuffle=False
     )
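The new flow merges the tokenized datasets up front, shuffles once with a fixed seed, saves the merged dataset, then optionally packs, shards, and splits. A minimal sketch of the same ordering using only the `datasets` library (dataset contents, the save path, and the shard/split sizes are placeholders; the packing step is axolotl-internal and omitted here):

```python
from datasets import Dataset, concatenate_datasets

# stand-ins for the per-prompt tokenized datasets built in the loop above
tokenized = [
    Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5]] * 4}),
    Dataset.from_dict({"input_ids": [[6, 7], [8, 9, 10]] * 4}),
]

# merge and shuffle once, then persist the merged dataset for reuse
dataset = concatenate_datasets(tokenized).shuffle(seed=42)
dataset.save_to_disk("/tmp/prepared_ds")  # placeholder path

# optional sharding, mirroring cfg.dataset_shard_num / cfg.dataset_shard_idx
dataset = dataset.shard(num_shards=2, index=0)

# final split; shuffle=False so the seeded shuffle above stays the only shuffle
splits = dataset.train_test_split(test_size=0.25, shuffle=False)
print(splits["train"].num_rows, splits["test"].num_rows)  # 6 2
```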

View File

@@ -101,12 +101,19 @@ def load_model(
         )
         load_in_8bit = False
     elif is_llama_derived_model and "LlamaForCausalLM" in globals():
-        model = LlamaForCausalLM.from_pretrained(
-            base_model,
-            load_in_8bit=cfg.load_in_8bit,
-            torch_dtype=torch_dtype,
-            device_map=cfg.device_map,
-        )
+        if not cfg.load_in_8bit:
+            model = LlamaForCausalLM.from_pretrained(
+                base_model,
+                device_map=cfg.device_map,
+            )
+        else:
+            model = LlamaForCausalLM.from_pretrained(
+                base_model,
+                load_in_8bit=cfg.load_in_8bit,
+                torch_dtype=torch_dtype,
+                device_map=cfg.device_map,
+            )
     elif model_type:
         model = getattr(transformers, model_type).from_pretrained(
             base_model,

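The model-loading change only passes `load_in_8bit` and `torch_dtype` when 8-bit loading is actually requested. A rough standalone equivalent (model name, dtype, and device map are placeholders; the 8-bit path additionally requires `bitsandbytes` and a CUDA device):

```python
import torch
from transformers import LlamaForCausalLM

def load_llama(base_model: str, load_in_8bit: bool, device_map="auto"):
    if not load_in_8bit:
        # plain load: let transformers use the checkpoint's default dtype
        return LlamaForCausalLM.from_pretrained(base_model, device_map=device_map)
    # 8-bit load: quantize weights on the fly via bitsandbytes
    return LlamaForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=True,
        torch_dtype=torch.float16,
        device_map=device_map,
    )
```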
View File

@@ -1,5 +1,9 @@
 import math
+import os
+from pathlib import Path
+import bitsandbytes as bnb
+import torch.cuda
 import transformers
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
@@ -12,7 +16,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
     )
     warmup_steps = cfg.warmup_steps if cfg.warmup_steps else min(int(0.03 * total_num_steps), 100)
-    logging_steps = max(min(int(0.005 * total_num_steps), 10), 1)
+    logging_steps = cfg.logging_steps if cfg.logging_steps else max(min(int(0.005 * total_num_steps), 10), 1)
     save_steps = eval_steps = cfg.save_steps if cfg.save_steps else min(int(0.05 * total_num_steps), 200)
     training_arguments_kwargs = {}
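For reference, a quick worked example of the step heuristics above with made-up numbers (10,000 examples, 3 epochs, batch size 32); `cfg.warmup_steps`, `cfg.logging_steps`, and `cfg.save_steps` override these when set:

```python
import math

# hypothetical config values, only to show the arithmetic
num_examples, num_epochs, batch_size = 10_000, 3, 32

total_num_steps = int(math.ceil(num_examples * num_epochs / batch_size))  # 938
warmup_steps = min(int(0.03 * total_num_steps), 100)                      # 28
logging_steps = max(min(int(0.005 * total_num_steps), 10), 1)             # 4
save_steps = eval_steps = min(int(0.05 * total_num_steps), 200)           # 46
print(total_num_steps, warmup_steps, logging_steps, save_steps)
```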
@@ -26,6 +30,15 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.gradient_checkpointing is not None:
         training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
+    # deepspeed
+    if os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true" and torch.cuda.device_count() > 1:
+        if cfg.deepspeed:
+            training_arguments_kwargs["deepspeed"] = cfg.deepspeed
+        else:
+            # make a guess here
+            # TODO search Path("./") for one
+            training_arguments_kwargs["deepspeed"] = "./ds_config.json"
+
     training_args = transformers.TrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
         gradient_accumulation_steps=cfg.gradient_accumulation_steps,
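The deepspeed branch above hard-codes `./ds_config.json` when `cfg.deepspeed` is unset (hence the TODO). One way that lookup could be done, purely as an illustrative sketch of the TODO, not what the commit implements:

```python
import os
from pathlib import Path
from typing import Optional

def guess_deepspeed_config(cfg_deepspeed: Optional[str]) -> Optional[str]:
    """Return a deepspeed config path, or None if deepspeed isn't in use."""
    if os.environ.get("ACCELERATE_USE_DEEPSPEED") != "true":
        return None
    if cfg_deepspeed:
        return cfg_deepspeed
    # fall back to the first ds_config*.json in the working directory
    candidates = sorted(Path(".").glob("ds_config*.json"))
    return str(candidates[0]) if candidates else "./ds_config.json"
```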
@@ -37,7 +50,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         save_steps=save_steps,
         output_dir=cfg.output_dir,
         save_total_limit=3,
-        load_best_model_at_end=True if cfg.val_set_size > 0 else False,
+        load_best_model_at_end=True if cfg.val_set_size > 0 and save_steps % eval_steps == 0 else False,
         ddp_find_unused_parameters=False if cfg.ddp else None,
         group_by_length=cfg.group_by_length,
         report_to="wandb" if cfg.use_wandb else None,
@@ -47,7 +60,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     trainer_kwargs = {}
-    if cfg.load_in_8bit and not cfg.load_4bit:
+    if cfg.optimizer == "adam8bit" and not cfg.load_4bit and not "deepspeed" in training_arguments_kwargs:
         decay_parameters = get_parameter_names(model, [nn.LayerNorm])
         decay_parameters = [name for name in decay_parameters if "bias" not in name]
         optimizer_grouped_parameters = [
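The hunk is truncated here; it is building the parameter groups handed to bitsandbytes' 8-bit Adam. A condensed, self-contained sketch of that standard pattern (learning rate and weight decay are placeholders):

```python
import bitsandbytes as bnb
from torch import nn
from transformers.trainer_pt_utils import get_parameter_names

def build_adam8bit(model, lr=2e-5, weight_decay=0.0):
    # apply weight decay to everything except LayerNorm weights and biases
    decay_parameters = get_parameter_names(model, [nn.LayerNorm])
    decay_parameters = [name for name in decay_parameters if "bias" not in name]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if n in decay_parameters],
            "weight_decay": weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if n not in decay_parameters],
            "weight_decay": 0.0,
        },
    ]
    return bnb.optim.Adam8bit(optimizer_grouped_parameters, lr=lr)
```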