more packing and dataset optimizations and fixes
@@ -1,7 +1,7 @@
 peft @ git+https://github.com/huggingface/peft.git
 transformers @ git+https://github.com/huggingface/transformers.git
 bitsandbytes>=0.39.0
-accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
+accelerate @ git+https://github.com/huggingface/accelerate@b42c65b
 addict
 fire
 PyYAML==6.0

@@ -14,6 +14,7 @@ import torch
 import yaml
 
 # add src to the pythonpath so we don't need to pip install this
+from accelerate import Accelerator
 from optimum.bettertransformer import BetterTransformer
 from transformers import GenerationConfig, TextStreamer

@@ -22,7 +23,11 @@ from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
 from axolotl.utils.dict import DictDefault
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
-from axolotl.utils.trainer import setup_trainer
+from axolotl.utils.trainer import (
+    calculate_total_num_steps,
+    process_datasets_for_packing,
+    setup_trainer,
+)
 from axolotl.utils.validation import validate_config
 from axolotl.utils.wandb import setup_wandb_env_vars

@@ -168,6 +173,7 @@ def train(
     prepare_ds_only: bool = False,
     **kwargs,
 ):
+    accelerator = Accelerator()
     if Path(config).is_dir():
         config = choose_config(config)

@@ -237,6 +243,21 @@ def train(
         train_dataset = train_dataset.with_format("torch")
         eval_dataset = None
 
+    if accelerator.is_local_main_process:
+        # process on rank 0 first so it gets cached so other ranks load from cache
+        train_dataset, eval_dataset = process_datasets_for_packing(
+            cfg, train_dataset, eval_dataset
+        )
+    accelerator.wait_for_everyone()
+    if not accelerator.is_local_main_process:
+        train_dataset, eval_dataset = process_datasets_for_packing(
+            cfg, train_dataset, eval_dataset
+        )
+
+    train_dataset.cleanup_cache_files()
+    eval_dataset.cleanup_cache_files()
+    total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
+
     if cfg.debug or "debug" in kwargs:
         LOG.info("check_dataset_labels...")
         check_dataset_labels(

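The block above makes only the local main process run `process_datasets_for_packing` first, so the 🤗 datasets cache is written once and the other ranks pick it up after the barrier. A minimal, self-contained sketch of that pattern (the `prepare` function and the toy dataset are placeholders, not code from this repo):

```python
from accelerate import Accelerator
from datasets import Dataset

accelerator = Accelerator()
dataset = Dataset.from_dict({"text": ["a", "bb", "ccc"]})

def prepare(ds):
    # stand-in for the real work (drop_long filter, add_position_ids, ...)
    return ds.map(lambda example: {"length": len(example["text"])})

if accelerator.is_local_main_process:
    # local rank 0 runs the transform first and populates the cache
    dataset = prepare(dataset)
accelerator.wait_for_everyone()
if not accelerator.is_local_main_process:
    # the remaining ranks run the same transform after the barrier,
    # reusing the cached result when the dataset is backed by disk
    dataset = prepare(dataset)
```

Accelerate also provides an `accelerator.main_process_first()` context manager that expresses the same idea more compactly.
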
@@ -286,7 +307,9 @@ def train(
         model.save_pretrained(cfg.output_dir)
         return
 
-    trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
+    trainer = setup_trainer(
+        cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
+    )
 
     model.config.use_cache = False

@@ -345,7 +368,13 @@ def train(
     # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
     if cfg.fsdp:
-        model.save_pretrained(cfg.output_dir)
+        with model.summon_full_params():
+            model.save_pretrained(
+                cfg.output_dir,
+                is_main_process=trainer.accelerator.is_main_process,
+                save_function=trainer.accelerator.save,
+                state_dict=trainer.accelerator.get_state_dict(model),
+            )
     elif cfg.local_rank == 0:
         if cfg.flash_optimum:
             model = BetterTransformer.reverse(model)

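For comparison, a rank-0-only save can also be expressed with plain Accelerate helpers; this is an illustrative variant, not the `summon_full_params` path used above, and it assumes `accelerator`, `model`, and `output_dir` already exist in the caller's scope:

```python
# Illustrative sketch: gather a consolidated state dict and let only rank 0 write files.
unwrapped_model = accelerator.unwrap_model(model)
unwrapped_model.save_pretrained(
    output_dir,
    is_main_process=accelerator.is_main_process,   # only the main process writes
    save_function=accelerator.save,                # Accelerate's save wrapper
    state_dict=accelerator.get_state_dict(model),  # consolidated weights
)
```
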
@@ -128,8 +128,8 @@ def xformers_forward(
             query_states,
             key_states,
             value_states,
-            attn_bias=attention_mask,
-            # attn_bias=xformers.ops.LowerTriangularMask(),
+            # attn_bias=attention_mask,
+            attn_bias=xformers.ops.LowerTriangularMask(),
         )
         attn_weights = None
     else:

@@ -109,13 +109,16 @@ def load_tokenized_prepared_datasets(
             local_path = Path(d.path)
             if local_path.exists():
                 if local_path.is_dir():
-                    ds = load_dataset(
-                        d.path,
-                        name=d.name,
-                        data_files=d.data_files,
-                        streaming=False,
-                        split=None,
-                    )
+                    try:
+                        ds = load_from_disk(d.path)
+                    except FileNotFoundError:
+                        ds = load_dataset(
+                            d.path,
+                            name=d.name,
+                            data_files=d.data_files,
+                            streaming=False,
+                            split=None,
+                        )
                 elif local_path.is_file():
                     ds = load_dataset(
                         "json",

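The new `try` branch prefers datasets that were written with `Dataset.save_to_disk`, falling back to `load_dataset` only when no Arrow directory is found. A small round-trip illustration (the path is made up):

```python
from datasets import Dataset, load_from_disk

ds = Dataset.from_dict({"input_ids": [[1, 2, 3], [4, 5]]})
ds.save_to_disk("/tmp/prepared_ds")   # writes the directory layout load_from_disk expects

reloaded = load_from_disk("/tmp/prepared_ds")
print(len(reloaded))                  # 2
```
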
@@ -2,7 +2,6 @@
 import itertools
 import logging
 import math
-import os
 from typing import Any, Callable, List, Union
 
 import numba

@@ -63,6 +62,14 @@ def ffd_with_result(a: np.ndarray, c: int, start_index: int):
 def allocate(
     lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
 ):
+    """
+    :param lengths: array of lengths of each sample
+    :param lengths_cumsum: cumulative sum of consecutive lengths
+    :param rank: rank for this process
+    :param c: length of tokens per batch
+    :param n: number of ranks
+    """
     # Dynamic batch allocator, similar to Multifit
     # https://en.wikipedia.org/wiki/Multifit_algorithm
     # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)

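The allocator distributes whole samples across ranks so that each pack stays under the fixed token budget `c`, using a first-fit-decreasing heuristic in the spirit of the Multifit reference. A deliberately simplified, pure-Python sketch of that bin-packing idea (not the numba-accelerated `ffd_with_result` in this file):

```python
import numpy as np

def ffd_pack(lengths: np.ndarray, budget: int) -> list:
    """Greedy first-fit-decreasing: place each sample (longest first)
    into the first bin that still has room for it."""
    order = np.argsort(lengths)[::-1]          # longest first
    bins, remaining = [], []
    for idx in order:
        length = int(lengths[idx])
        for b, room in enumerate(remaining):
            if length <= room:
                bins[b].append(int(idx))
                remaining[b] -= length
                break
        else:
            bins.append([int(idx)])
            remaining.append(budget - length)
    return bins

lengths = np.array([1800, 1500, 900, 700, 400, 300, 120])
print(ffd_pack(lengths, budget=2048))
```

Sorting longest-first is what keeps the reported packing efficiency high: large samples claim fresh bins early and the small ones fill the leftover space.
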
@@ -98,7 +105,7 @@ def allocate(
         result.append(batch[rank])
         # add total seqs for all ranks
         result_totseqs.append(tot_seqs)
 
-    # yield batch[rank], tot_seqs, s, len(result) * c * n
+    return result, result_totseqs, s, len(result) * c * n

@@ -129,6 +136,7 @@ class MultipackDistributedDataloader:
         sampler: Union[Sampler, DistributedSampler] = None,
         packing_efficiency_estimate: float = 1.0,
         sample_packing_seq_len_multiplier: int = 1,
+        device_count: int = 1,
     ):
         # Dataset
         self.dataset = dataset

@@ -152,6 +160,7 @@ class MultipackDistributedDataloader:
         self.eff_total_used = 0
         self.eff_total_slots = 0
         self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
+        self.device_count = device_count
 
     def generate_batches(self, set_stats=False):
         LOG.info("generating packed batches")

@@ -233,7 +242,7 @@ class MultipackDistributedDataloader:
         indices = range(0, len(self.dataset))
         lengths = self.lengths[indices]
         lengths_sum = np.cumsum(lengths)[-1]
-        lengths_sum_per_device = lengths_sum // int(os.environ.get("WORLD_SIZE", 1))
+        lengths_sum_per_device = lengths_sum // self.device_count
         LOG.info(
             f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
             f"total_num_tokens per device: {lengths_sum_per_device}"

@@ -238,7 +238,6 @@ def load_model(
            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
            load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
            torch_dtype=torch_dtype,
            device_map="auto" if cfg.world_size == 1 else cfg.device_map,
            **model_kwargs,
        )
    # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:

@@ -273,7 +272,6 @@ def load_model(
            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
            load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
            torch_dtype=torch_dtype,
            device_map=cfg.device_map,
            trust_remote_code=cfg.trust_remote_code or False,
            **model_kwargs,
        )

@@ -304,7 +302,6 @@ def load_model(
            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
            load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
            torch_dtype=torch_dtype,
            device_map=cfg.device_map,
            trust_remote_code=cfg.trust_remote_code or False,
            **model_kwargs,
        )

@@ -318,7 +315,6 @@ def load_model(
            load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
            load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
            torch_dtype=torch_dtype,
            device_map=cfg.device_map,
            trust_remote_code=cfg.trust_remote_code or False,
            **model_kwargs,
        )

@@ -4,6 +4,7 @@ import logging
 import math
 import os
 import sys
+from contextlib import contextmanager
 from dataclasses import dataclass, field
 from functools import partial
 from pathlib import Path

@@ -12,10 +13,10 @@ from typing import Optional, Union
 import bitsandbytes as bnb
 import torch.cuda
 import transformers
-from datasets import Dataset
+from datasets import Dataset, set_caching_enabled
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
-from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
+from torch.utils.data import DataLoader, RandomSampler
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import get_parameter_names

@@ -158,20 +159,22 @@ class AxolotlTrainer(Trainer):
             return super().create_scheduler(num_training_steps, optimizer)
         return self.lr_scheduler
 
-    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
-        if self.args.world_size > 1 and self.args.sample_packing:
-            return DistributedSampler(
-                self.train_dataset,
-                num_replicas=self.args.world_size,
-                rank=self.args.process_index,
-                seed=self.args.seed,
-            )
-        return super()._get_train_sampler()
+    # def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+    #     if self.args.world_size > 1 and self.args.sample_packing:
+    #         return DistributedSampler(
+    #             self.train_dataset,
+    #             num_replicas=self.args.world_size,
+    #             rank=self.args.process_index,
+    #             seed=self.args.seed,
+    #         )
+    #     return super()._get_train_sampler()
 
     def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
         if self.args.sample_packing:
             train_sampler = self._get_train_sampler()
+
+            # If set to True, the dataloader prepared is only iterated through on the
+            # main process and then the batches are split and broadcast to each process
+            self.accelerator.dispatch_batches = True
             return self.accelerator.prepare(
                 MultipackDistributedDataloader(
                     self.train_dataset,

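Setting `dispatch_batches` before `accelerator.prepare` is what lets a single, length-aware packed dataloader drive all ranks: per the comment above, only the main process iterates it, and each batch is then sliced and broadcast. A tiny standalone sketch; the constructor argument shown here is how the Accelerate version pinned by this commit exposed the flag, later releases moved it into `DataLoaderConfiguration`:

```python
import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader, TensorDataset

accelerator = Accelerator(dispatch_batches=True)

dataset = TensorDataset(torch.arange(64).reshape(16, 4))
loader = accelerator.prepare(DataLoader(dataset, batch_size=8))

for (batch,) in loader:
    # under a multi-process launch, rank 0 reads the batch and every rank
    # receives its slice; on a single process this is just a plain loop
    print(batch.shape)
```
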
@@ -181,6 +184,7 @@ class AxolotlTrainer(Trainer):
                     sampler=train_sampler,
                     packing_efficiency_estimate=self.args.sample_packing_efficiency,
                     sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
+                    # device_count=int(os.environ.get("WORLD_SIZE", 1)),
                 )
             )
         return super().get_train_dataloader()

@@ -193,6 +197,9 @@ class AxolotlTrainer(Trainer):
                 eval_dataset if eval_dataset is not None else self.eval_dataset
             )
             eval_sampler = self._get_eval_sampler(eval_dataset)
+            # If set to True, the dataloader prepared is only iterated through on the
+            # main process and then the batches are split and broadcast to each process
+            self.accelerator.dispatch_batches = True
             return self.accelerator.prepare(
                 MultipackDistributedDataloader(
                     eval_dataset,

@@ -202,6 +209,7 @@ class AxolotlTrainer(Trainer):
                     sampler=eval_sampler,
                     packing_efficiency_estimate=self.args.sample_packing_efficiency,
                     sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
+                    # device_count=int(os.environ.get("WORLD_SIZE", 1)),
                 )
             )
         return super().get_eval_dataloader(eval_dataset)

@@ -253,7 +261,16 @@ def drop_long_seq(sample, sequence_len=2048):
     return len(sample["input_ids"]) <= sequence_len
 
 
-def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
+@contextmanager
+def disable_datasets_caching():
+    try:
+        set_caching_enabled(False)
+        yield
+    finally:
+        set_caching_enabled(True)
+
+
+def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
     if cfg.sample_packing:
         drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
         train_dataset = train_dataset.filter(drop_long).map(

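A usage sketch for the new `disable_datasets_caching` helper: wrap a `map`/`filter` so 🤗 datasets recomputes the intermediate result instead of persisting a cache file. (`set_caching_enabled` is the API of the `datasets` releases of this era; newer versions expose `disable_caching`/`enable_caching` instead.)

```python
from contextlib import contextmanager

from datasets import Dataset, set_caching_enabled

@contextmanager
def disable_datasets_caching():
    try:
        set_caching_enabled(False)
        yield
    finally:
        set_caching_enabled(True)

ds = Dataset.from_dict({"input_ids": [[1, 2], [3, 4, 5]]})
with disable_datasets_caching():
    # transforms inside the block recompute rather than writing cache files
    ds = ds.map(lambda example: {"length": len(example["input_ids"])})
print(ds[0])
```
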
@@ -263,12 +280,22 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
        eval_dataset = eval_dataset.filter(drop_long).map(
            add_position_ids, num_proc=os.cpu_count()
        )
    return train_dataset, eval_dataset


def calculate_total_num_steps(cfg, train_dataset, tokenizer):
    if cfg.sample_packing:
        # we have to drop anything longer than sequence len otherwise
        # flash attention with position ids fails
        total_num_tokens = (
            cfg.total_num_tokens
            if cfg.total_num_tokens
            else sum(len(s["input_ids"]) for s in train_dataset)
        )
        if not cfg.total_num_tokens:
            LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")

        if cfg.sample_packing_eff_est:
            total_num_tokens = (
                cfg.total_num_tokens
                if cfg.total_num_tokens
                else sum(len(s["input_ids"]) for s in train_dataset)
            )
            total_num_steps = (
                # match count to len est in dataloader
                (

@@ -300,8 +327,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
                 sampler=sampler,
                 packing_efficiency_estimate=cfg.sample_packing_eff_est,
                 sample_packing_seq_len_multiplier=cfg.sample_packing_seq_len_multiplier,
+                device_count=int(os.environ.get("WORLD_SIZE", 1)),
             )
             data_loader_len = len(data_loader)
+            actual_eff = data_loader.efficiency()
             LOG.info(f"data_loader_len: {data_loader_len}")
             total_num_steps = int(
                 math.ceil(

@@ -311,10 +340,18 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
                     / cfg.batch_size
                 )
             )
+            LOG.info(
+                f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`"
+            )
     else:
         total_num_steps = int(
             math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
         )
     LOG.info(f"total_num_steps: {total_num_steps}")
+    return total_num_steps
+
+
+def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
     warmup_steps = (
         cfg.warmup_steps
         if cfg.warmup_steps is not None

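The two `UPDATE CONFIG WITH` log lines above exist so later runs can skip both the expensive token count and the trial dataloader. A back-of-envelope sketch of how those two numbers translate into a step count; the variable names and the exact rounding here are illustrative assumptions, not the formula in this file:

```python
import math

# assumed example values, in the spirit of the logged config hints
total_num_tokens = 250_000_000      # `total_num_tokens` from the first log line
sample_packing_eff_est = 0.97       # `sample_packing_eff_est` from the second log line
sequence_len = 2048
micro_batch_size = 4
world_size = 8
num_epochs = 3

tokens_per_optimizer_step = sequence_len * micro_batch_size * world_size
steps_per_epoch = math.ceil(
    total_num_tokens / (sample_packing_eff_est * tokens_per_optimizer_step)
)
print(steps_per_epoch * num_epochs)
```
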
@@ -110,6 +110,17 @@ def validate_config(cfg):
             "push_to_hub_model_id is deprecated. Please use hub_model_id instead."
         )
 
+    if cfg.sample_packing and cfg.sdp_attention:
+        # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
+        raise ValueError(
+            "sample_packing not compatible with sdp_attention. Use flash_attention"
+        )
+
+    if cfg.sample_packing and cfg.xformers_attention:
+        raise ValueError(
+            "sample_packing not compatible with xformers_attention. Use flash_attention"
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25

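A rough sketch of how the new checks surface to a user, assuming `DictDefault` accepts a plain dict (the project depends on `addict`) and that no earlier validation in `validate_config` fires first:

```python
from axolotl.utils.dict import DictDefault
from axolotl.utils.validation import validate_config

cfg = DictDefault({"sample_packing": True, "xformers_attention": True})
try:
    validate_config(cfg)
except ValueError as err:
    # expected: "sample_packing not compatible with xformers_attention. Use flash_attention"
    print(err)
```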