more packing and dataset optimizations and fixes

This commit is contained in:
Wing Lian
2023-08-08 00:45:24 -04:00
parent 229b9165aa
commit 21f445d763
8 changed files with 123 additions and 38 deletions

View File

@@ -1,7 +1,7 @@
peft @ git+https://github.com/huggingface/peft.git peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git transformers @ git+https://github.com/huggingface/transformers.git
bitsandbytes>=0.39.0 bitsandbytes>=0.39.0
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b accelerate @ git+https://github.com/huggingface/accelerate@b42c65b
addict addict
fire fire
PyYAML==6.0 PyYAML==6.0

View File

@@ -14,6 +14,7 @@ import torch
import yaml import yaml
# add src to the pythonpath so we don't need to pip install this # add src to the pythonpath so we don't need to pip install this
from accelerate import Accelerator
from optimum.bettertransformer import BetterTransformer from optimum.bettertransformer import BetterTransformer
from transformers import GenerationConfig, TextStreamer from transformers import GenerationConfig, TextStreamer
@@ -22,7 +23,11 @@ from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import setup_trainer from axolotl.utils.trainer import (
calculate_total_num_steps,
process_datasets_for_packing,
setup_trainer,
)
from axolotl.utils.validation import validate_config from axolotl.utils.validation import validate_config
from axolotl.utils.wandb import setup_wandb_env_vars from axolotl.utils.wandb import setup_wandb_env_vars
@@ -168,6 +173,7 @@ def train(
prepare_ds_only: bool = False, prepare_ds_only: bool = False,
**kwargs, **kwargs,
): ):
accelerator = Accelerator()
if Path(config).is_dir(): if Path(config).is_dir():
config = choose_config(config) config = choose_config(config)
@@ -237,6 +243,21 @@ def train(
train_dataset = train_dataset.with_format("torch") train_dataset = train_dataset.with_format("torch")
eval_dataset = None eval_dataset = None
if accelerator.is_local_main_process:
# process on rank 0 first so it gets cached so other ranks load from cache
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset
)
accelerator.wait_for_everyone()
if not accelerator.is_local_main_process:
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset
)
train_dataset.cleanup_cache_files()
eval_dataset.cleanup_cache_files()
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
if cfg.debug or "debug" in kwargs: if cfg.debug or "debug" in kwargs:
LOG.info("check_dataset_labels...") LOG.info("check_dataset_labels...")
check_dataset_labels( check_dataset_labels(
@@ -286,7 +307,9 @@ def train(
model.save_pretrained(cfg.output_dir) model.save_pretrained(cfg.output_dir)
return return
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer) trainer = setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
)
model.config.use_cache = False model.config.use_cache = False
@@ -345,7 +368,13 @@ def train(
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp: if cfg.fsdp:
model.save_pretrained(cfg.output_dir) with model.summon_full_params():
model.save_pretrained(
cfg.output_dir,
is_main_process=trainer.accelerator.is_main_process,
save_function=trainer.accelerator.save,
state_dict=trainer.accelerator.get_state_dict(model),
)
elif cfg.local_rank == 0: elif cfg.local_rank == 0:
if cfg.flash_optimum: if cfg.flash_optimum:
model = BetterTransformer.reverse(model) model = BetterTransformer.reverse(model)

View File

@@ -128,8 +128,8 @@ def xformers_forward(
query_states, query_states,
key_states, key_states,
value_states, value_states,
attn_bias=attention_mask, # attn_bias=attention_mask,
# attn_bias=xformers.ops.LowerTriangularMask(), attn_bias=xformers.ops.LowerTriangularMask(),
) )
attn_weights = None attn_weights = None
else: else:

View File

@@ -109,13 +109,16 @@ def load_tokenized_prepared_datasets(
local_path = Path(d.path) local_path = Path(d.path)
if local_path.exists(): if local_path.exists():
if local_path.is_dir(): if local_path.is_dir():
ds = load_dataset( try:
d.path, ds = load_from_disk(d.path)
name=d.name, except FileNotFoundError:
data_files=d.data_files, ds = load_dataset(
streaming=False, d.path,
split=None, name=d.name,
) data_files=d.data_files,
streaming=False,
split=None,
)
elif local_path.is_file(): elif local_path.is_file():
ds = load_dataset( ds = load_dataset(
"json", "json",

View File

@@ -2,7 +2,6 @@
import itertools import itertools
import logging import logging
import math import math
import os
from typing import Any, Callable, List, Union from typing import Any, Callable, List, Union
import numba import numba
@@ -63,6 +62,14 @@ def ffd_with_result(a: np.ndarray, c: int, start_index: int):
def allocate( def allocate(
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
): ):
"""
:param lengths: array of lengths of each sample
:param lengths_cumsum: cumulative sum of consecutive lengths
:param rank: rank for this process
:param c: length of tokens per batch
:param n: number of ranks
:return:
"""
# Dynamic batch allocator, similar to Multifit # Dynamic batch allocator, similar to Multifit
# https://en.wikipedia.org/wiki/Multifit_algorithm # https://en.wikipedia.org/wiki/Multifit_algorithm
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len) # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
@@ -98,7 +105,7 @@ def allocate(
result.append(batch[rank]) result.append(batch[rank])
# add total seqs for all ranks # add total seqs for all ranks
result_totseqs.append(tot_seqs) result_totseqs.append(tot_seqs)
# yield batch[rank], tot_seqs, s, len(result) * c * n
return result, result_totseqs, s, len(result) * c * n return result, result_totseqs, s, len(result) * c * n
@@ -129,6 +136,7 @@ class MultipackDistributedDataloader:
sampler: Union[Sampler, DistributedSampler] = None, sampler: Union[Sampler, DistributedSampler] = None,
packing_efficiency_estimate: float = 1.0, packing_efficiency_estimate: float = 1.0,
sample_packing_seq_len_multiplier: int = 1, sample_packing_seq_len_multiplier: int = 1,
device_count: int = 1,
): ):
# Dataset # Dataset
self.dataset = dataset self.dataset = dataset
@@ -152,6 +160,7 @@ class MultipackDistributedDataloader:
self.eff_total_used = 0 self.eff_total_used = 0
self.eff_total_slots = 0 self.eff_total_slots = 0
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0 self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.device_count = device_count
def generate_batches(self, set_stats=False): def generate_batches(self, set_stats=False):
LOG.info("generating packed batches") LOG.info("generating packed batches")
@@ -233,7 +242,7 @@ class MultipackDistributedDataloader:
indices = range(0, len(self.dataset)) indices = range(0, len(self.dataset))
lengths = self.lengths[indices] lengths = self.lengths[indices]
lengths_sum = np.cumsum(lengths)[-1] lengths_sum = np.cumsum(lengths)[-1]
lengths_sum_per_device = lengths_sum // int(os.environ.get("WORLD_SIZE", 1)) lengths_sum_per_device = lengths_sum // self.device_count
LOG.info( LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} " f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"total_num_tokens per device: {lengths_sum_per_device}" f"total_num_tokens per device: {lengths_sum_per_device}"

View File

@@ -238,7 +238,6 @@ def load_model(
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
device_map="auto" if cfg.world_size == 1 else cfg.device_map,
**model_kwargs, **model_kwargs,
) )
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention: # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
@@ -273,7 +272,6 @@ def load_model(
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
device_map=cfg.device_map,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs, **model_kwargs,
) )
@@ -304,7 +302,6 @@ def load_model(
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
device_map=cfg.device_map,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs, **model_kwargs,
) )
@@ -318,7 +315,6 @@ def load_model(
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None, load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None, load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype, torch_dtype=torch_dtype,
device_map=cfg.device_map,
trust_remote_code=cfg.trust_remote_code or False, trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs, **model_kwargs,
) )

View File

@@ -4,6 +4,7 @@ import logging
import math import math
import os import os
import sys import sys
from contextlib import contextmanager
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
@@ -12,10 +13,10 @@ from typing import Optional, Union
import bitsandbytes as bnb import bitsandbytes as bnb
import torch.cuda import torch.cuda
import transformers import transformers
from datasets import Dataset from datasets import Dataset, set_caching_enabled
from torch import nn from torch import nn
from torch.optim.lr_scheduler import OneCycleLR from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler from torch.utils.data import DataLoader, RandomSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names from transformers.trainer_pt_utils import get_parameter_names
@@ -158,20 +159,22 @@ class AxolotlTrainer(Trainer):
return super().create_scheduler(num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer)
return self.lr_scheduler return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]: # def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size > 1 and self.args.sample_packing: # if self.args.world_size > 1 and self.args.sample_packing:
return DistributedSampler( # return DistributedSampler(
self.train_dataset, # self.train_dataset,
num_replicas=self.args.world_size, # num_replicas=self.args.world_size,
rank=self.args.process_index, # rank=self.args.process_index,
seed=self.args.seed, # seed=self.args.seed,
) # )
return super()._get_train_sampler() # return super()._get_train_sampler()
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]: def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing: if self.args.sample_packing:
train_sampler = self._get_train_sampler() train_sampler = self._get_train_sampler()
# If set to True, the dataloader prepared is only iterated through on the
# main process and then the batches are split and broadcast to each process
self.accelerator.dispatch_batches = True
return self.accelerator.prepare( return self.accelerator.prepare(
MultipackDistributedDataloader( MultipackDistributedDataloader(
self.train_dataset, self.train_dataset,
@@ -181,6 +184,7 @@ class AxolotlTrainer(Trainer):
sampler=train_sampler, sampler=train_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency, packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier, sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
# device_count=int(os.environ.get("WORLD_SIZE", 1)),
) )
) )
return super().get_train_dataloader() return super().get_train_dataloader()
@@ -193,6 +197,9 @@ class AxolotlTrainer(Trainer):
eval_dataset if eval_dataset is not None else self.eval_dataset eval_dataset if eval_dataset is not None else self.eval_dataset
) )
eval_sampler = self._get_eval_sampler(eval_dataset) eval_sampler = self._get_eval_sampler(eval_dataset)
# If set to True, the datalaoder prepared is only iterated through on the
# main process and then the batches are split and broadcast to each process
self.accelerator.dispatch_batches = True
return self.accelerator.prepare( return self.accelerator.prepare(
MultipackDistributedDataloader( MultipackDistributedDataloader(
eval_dataset, eval_dataset,
@@ -202,6 +209,7 @@ class AxolotlTrainer(Trainer):
sampler=eval_sampler, sampler=eval_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency, packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier, sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
# device_count=int(os.environ.get("WORLD_SIZE", 1)),
) )
) )
return super().get_eval_dataloader(eval_dataset) return super().get_eval_dataloader(eval_dataset)
@@ -253,7 +261,16 @@ def drop_long_seq(sample, sequence_len=2048):
return len(sample["input_ids"]) <= sequence_len return len(sample["input_ids"]) <= sequence_len
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer): @contextmanager
def disable_datasets_caching():
try:
set_caching_enabled(False)
yield
finally:
set_caching_enabled(True)
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if cfg.sample_packing: if cfg.sample_packing:
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len) drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
train_dataset = train_dataset.filter(drop_long).map( train_dataset = train_dataset.filter(drop_long).map(
@@ -263,12 +280,22 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
eval_dataset = eval_dataset.filter(drop_long).map( eval_dataset = eval_dataset.filter(drop_long).map(
add_position_ids, num_proc=os.cpu_count() add_position_ids, num_proc=os.cpu_count()
) )
return train_dataset, eval_dataset
def calculate_total_num_steps(cfg, train_dataset, tokenizer):
if cfg.sample_packing:
# we have to drop anything longer then sequence len otherwise
# flash attention with position ids fails
total_num_tokens = (
cfg.total_num_tokens
if cfg.total_num_tokens
else sum(len(s["input_ids"]) for s in train_dataset)
)
if not cfg.total_num_tokens:
LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
if cfg.sample_packing_eff_est: if cfg.sample_packing_eff_est:
total_num_tokens = (
cfg.total_num_tokens
if cfg.total_num_tokens
else sum(len(s["input_ids"]) for s in train_dataset)
)
total_num_steps = ( total_num_steps = (
# match count to len est in dataloader # match count to len est in dataloader
( (
@@ -300,8 +327,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
sampler=sampler, sampler=sampler,
packing_efficiency_estimate=cfg.sample_packing_eff_est, packing_efficiency_estimate=cfg.sample_packing_eff_est,
sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier, sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
) )
data_loader_len = len(data_loader) data_loader_len = len(data_loader)
actual_eff = data_loader.efficiency()
LOG.info(f"data_loader_len: {data_loader_len}") LOG.info(f"data_loader_len: {data_loader_len}")
total_num_steps = int( total_num_steps = int(
math.ceil( math.ceil(
@@ -311,10 +340,18 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
/ cfg.batch_size / cfg.batch_size
) )
) )
LOG.info(
f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`"
)
else: else:
total_num_steps = int( total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size) math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
) )
LOG.info(f"total_num_steps: {total_num_steps}")
return total_num_steps
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
warmup_steps = ( warmup_steps = (
cfg.warmup_steps cfg.warmup_steps
if cfg.warmup_steps is not None if cfg.warmup_steps is not None

View File

@@ -110,6 +110,17 @@ def validate_config(cfg):
"push_to_hub_model_id is deprecated. Please use hub_model_id instead." "push_to_hub_model_id is deprecated. Please use hub_model_id instead."
) )
if cfg.sample_packing and cfg.sdp_attention:
# incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
raise ValueError(
"sample_packing not compatible with sdp_attention. Use flash_attention"
)
if cfg.sample_packing and cfg.xformers_attention:
raise ValueError(
"sample_packing not compatible with xformers_attention. Use flash_attention"
)
# TODO # TODO
# MPT 7b # MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25 # https://github.com/facebookresearch/bitsandbytes/issues/25