more packing and dataset optimizations and fixes

This commit is contained in:
Wing Lian
2023-08-08 00:45:24 -04:00
parent 229b9165aa
commit 21f445d763
8 changed files with 123 additions and 38 deletions

View File

@@ -1,7 +1,7 @@
peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git
bitsandbytes>=0.39.0
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
accelerate @ git+https://github.com/huggingface/accelerate@b42c65b
addict
fire
PyYAML==6.0

View File

@@ -14,6 +14,7 @@ import torch
import yaml
# add src to the pythonpath so we don't need to pip install this
from accelerate import Accelerator
from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer
@@ -22,7 +23,11 @@ from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.trainer import (
calculate_total_num_steps,
process_datasets_for_packing,
setup_trainer,
)
from axolotl.utils.validation import validate_config
from axolotl.utils.wandb import setup_wandb_env_vars
@@ -168,6 +173,7 @@ def train(
prepare_ds_only: bool = False,
**kwargs,
):
accelerator = Accelerator()
if Path(config).is_dir():
config = choose_config(config)
@@ -237,6 +243,21 @@ def train(
train_dataset = train_dataset.with_format("torch")
eval_dataset = None
if accelerator.is_local_main_process:
# process on rank 0 first so it gets cached so other ranks load from cache
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset
)
accelerator.wait_for_everyone()
if not accelerator.is_local_main_process:
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset
)
train_dataset.cleanup_cache_files()
eval_dataset.cleanup_cache_files()
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
if cfg.debug or "debug" in kwargs:
LOG.info("check_dataset_labels...")
check_dataset_labels(
@@ -286,7 +307,9 @@ def train(
model.save_pretrained(cfg.output_dir)
return
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
trainer = setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
)
model.config.use_cache = False
@@ -345,7 +368,13 @@ def train(
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp:
model.save_pretrained(cfg.output_dir)
with model.summon_full_params():
model.save_pretrained(
cfg.output_dir,
is_main_process=trainer.accelerator.is_main_process,
save_function=trainer.accelerator.save,
state_dict=trainer.accelerator.get_state_dict(model),
)
elif cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)

View File

@@ -128,8 +128,8 @@ def xformers_forward(
query_states,
key_states,
value_states,
attn_bias=attention_mask,
# attn_bias=xformers.ops.LowerTriangularMask(),
# attn_bias=attention_mask,
attn_bias=xformers.ops.LowerTriangularMask(),
)
attn_weights = None
else:

View File

@@ -109,13 +109,16 @@ def load_tokenized_prepared_datasets(
local_path = Path(d.path)
if local_path.exists():
if local_path.is_dir():
ds = load_dataset(
d.path,
name=d.name,
data_files=d.data_files,
streaming=False,
split=None,
)
try:
ds = load_from_disk(d.path)
except FileNotFoundError:
ds = load_dataset(
d.path,
name=d.name,
data_files=d.data_files,
streaming=False,
split=None,
)
elif local_path.is_file():
ds = load_dataset(
"json",

View File

@@ -2,7 +2,6 @@
import itertools
import logging
import math
import os
from typing import Any, Callable, List, Union
import numba
@@ -63,6 +62,14 @@ def ffd_with_result(a: np.ndarray, c: int, start_index: int):
def allocate(
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
):
"""
:param lengths: array of lengths of each sample
:param lengths_cumsum: cumulative sum of consecutive lengths
:param rank: rank for this process
:param c: length of tokens per batch
:param n: number of ranks
:return:
"""
# Dynamic batch allocator, similar to Multifit
# https://en.wikipedia.org/wiki/Multifit_algorithm
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
@@ -98,7 +105,7 @@ def allocate(
result.append(batch[rank])
# add total seqs for all ranks
result_totseqs.append(tot_seqs)
# yield batch[rank], tot_seqs, s, len(result) * c * n
return result, result_totseqs, s, len(result) * c * n
@@ -129,6 +136,7 @@ class MultipackDistributedDataloader:
sampler: Union[Sampler, DistributedSampler] = None,
packing_efficiency_estimate: float = 1.0,
sample_packing_seq_len_multiplier: int = 1,
device_count: int = 1,
):
# Dataset
self.dataset = dataset
@@ -152,6 +160,7 @@ class MultipackDistributedDataloader:
self.eff_total_used = 0
self.eff_total_slots = 0
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.device_count = device_count
def generate_batches(self, set_stats=False):
LOG.info("generating packed batches")
@@ -233,7 +242,7 @@ class MultipackDistributedDataloader:
indices = range(0, len(self.dataset))
lengths = self.lengths[indices]
lengths_sum = np.cumsum(lengths)[-1]
lengths_sum_per_device = lengths_sum // int(os.environ.get("WORLD_SIZE", 1))
lengths_sum_per_device = lengths_sum // self.device_count
LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"total_num_tokens per device: {lengths_sum_per_device}"

View File

@@ -238,7 +238,6 @@ def load_model(
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map="auto" if cfg.world_size == 1 else cfg.device_map,
**model_kwargs,
)
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
@@ -273,7 +272,6 @@ def load_model(
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
@@ -304,7 +302,6 @@ def load_model(
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
@@ -318,7 +315,6 @@ def load_model(
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)

View File

@@ -4,6 +4,7 @@ import logging
import math
import os
import sys
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
@@ -12,10 +13,10 @@ from typing import Optional, Union
import bitsandbytes as bnb
import torch.cuda
import transformers
from datasets import Dataset
from datasets import Dataset, set_caching_enabled
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from torch.utils.data import DataLoader, RandomSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names
@@ -158,20 +159,22 @@ class AxolotlTrainer(Trainer):
return super().create_scheduler(num_training_steps, optimizer)
return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size > 1 and self.args.sample_packing:
return DistributedSampler(
self.train_dataset,
num_replicas=self.args.world_size,
rank=self.args.process_index,
seed=self.args.seed,
)
return super()._get_train_sampler()
# def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
# if self.args.world_size > 1 and self.args.sample_packing:
# return DistributedSampler(
# self.train_dataset,
# num_replicas=self.args.world_size,
# rank=self.args.process_index,
# seed=self.args.seed,
# )
# return super()._get_train_sampler()
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing:
train_sampler = self._get_train_sampler()
# If set to True, the dataloader prepared is only iterated through on the
# main process and then the batches are split and broadcast to each process
self.accelerator.dispatch_batches = True
return self.accelerator.prepare(
MultipackDistributedDataloader(
self.train_dataset,
@@ -181,6 +184,7 @@ class AxolotlTrainer(Trainer):
sampler=train_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
# device_count=int(os.environ.get("WORLD_SIZE", 1)),
)
)
return super().get_train_dataloader()
@@ -193,6 +197,9 @@ class AxolotlTrainer(Trainer):
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset)
# If set to True, the datalaoder prepared is only iterated through on the
# main process and then the batches are split and broadcast to each process
self.accelerator.dispatch_batches = True
return self.accelerator.prepare(
MultipackDistributedDataloader(
eval_dataset,
@@ -202,6 +209,7 @@ class AxolotlTrainer(Trainer):
sampler=eval_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
# device_count=int(os.environ.get("WORLD_SIZE", 1)),
)
)
return super().get_eval_dataloader(eval_dataset)
@@ -253,7 +261,16 @@ def drop_long_seq(sample, sequence_len=2048):
return len(sample["input_ids"]) <= sequence_len
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
@contextmanager
def disable_datasets_caching():
try:
set_caching_enabled(False)
yield
finally:
set_caching_enabled(True)
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if cfg.sample_packing:
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
train_dataset = train_dataset.filter(drop_long).map(
@@ -263,12 +280,22 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
eval_dataset = eval_dataset.filter(drop_long).map(
add_position_ids, num_proc=os.cpu_count()
)
return train_dataset, eval_dataset
def calculate_total_num_steps(cfg, train_dataset, tokenizer):
if cfg.sample_packing:
# we have to drop anything longer then sequence len otherwise
# flash attention with position ids fails
total_num_tokens = (
cfg.total_num_tokens
if cfg.total_num_tokens
else sum(len(s["input_ids"]) for s in train_dataset)
)
if not cfg.total_num_tokens:
LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
if cfg.sample_packing_eff_est:
total_num_tokens = (
cfg.total_num_tokens
if cfg.total_num_tokens
else sum(len(s["input_ids"]) for s in train_dataset)
)
total_num_steps = (
# match count to len est in dataloader
(
@@ -300,8 +327,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
sampler=sampler,
packing_efficiency_estimate=cfg.sample_packing_eff_est,
sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
)
data_loader_len = len(data_loader)
actual_eff = data_loader.efficiency()
LOG.info(f"data_loader_len: {data_loader_len}")
total_num_steps = int(
math.ceil(
@@ -311,10 +340,18 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
/ cfg.batch_size
)
)
LOG.info(
f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`"
)
else:
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
LOG.info(f"total_num_steps: {total_num_steps}")
return total_num_steps
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
warmup_steps = (
cfg.warmup_steps
if cfg.warmup_steps is not None

View File

@@ -110,6 +110,17 @@ def validate_config(cfg):
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
)
if cfg.sample_packing and cfg.sdp_attention:
# incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
raise ValueError(
"sample_packing not compatible with sdp_attention. Use flash_attention"
)
if cfg.sample_packing and cfg.xformers_attention:
raise ValueError(
"sample_packing not compatible with xformers_attention. Use flash_attention"
)
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25