Compare commits
1 Commits
tensor-par
...
multipack
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
81d60e96f0 |
@@ -1,7 +1,6 @@
|
||||
peft @ git+https://github.com/huggingface/peft.git
|
||||
transformers @ git+https://github.com/huggingface/transformers.git
|
||||
bitsandbytes>=0.39.0
|
||||
accelerate
|
||||
addict
|
||||
fire
|
||||
PyYAML==6.0
|
||||
@@ -18,3 +17,4 @@ evaluate==0.4.0
|
||||
rouge-score==0.1.2
|
||||
scipy
|
||||
scikit-learn==1.2.2
|
||||
numba
|
||||
|
||||
173
src/axolotl/utils/sampler.py
Normal file
173
src/axolotl/utils/sampler.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# pylint: skip-file
|
||||
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import numba
|
||||
import numpy as np
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import Sampler
|
||||
|
||||
|
||||
@numba.njit
|
||||
def ffd_check(a: np.ndarray, c: int, n: int):
|
||||
# First-fit-decreasing bin packing
|
||||
# Check if a[] could fit in n bins with capacity c
|
||||
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
|
||||
|
||||
a = np.sort(a)[::-1]
|
||||
bins = np.full((n,), c, dtype=a.dtype)
|
||||
for size in a:
|
||||
not_found = True
|
||||
for idx in range(n):
|
||||
if bins[idx] >= size:
|
||||
bins[idx] -= size
|
||||
not_found = False
|
||||
break
|
||||
|
||||
if not_found:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@numba.njit
|
||||
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
|
||||
# First-fit-decreasing bin packing (with result return)
|
||||
|
||||
indices = np.argsort(a)[::-1]
|
||||
a = a[indices]
|
||||
|
||||
bins: List[int] = []
|
||||
bins_result: List[Any] = []
|
||||
for a_id, size in enumerate(a):
|
||||
add_new = True
|
||||
for idx in range(len(bins)):
|
||||
if bins[idx] >= size:
|
||||
bins[idx] -= size
|
||||
bins_result[idx].append(indices[a_id] + start_index)
|
||||
add_new = False
|
||||
break
|
||||
|
||||
if add_new:
|
||||
bins.append(c - size)
|
||||
bins_result.append([indices[a_id] + start_index])
|
||||
|
||||
return bins_result
|
||||
|
||||
|
||||
@numba.njit
|
||||
def allocate(
|
||||
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
|
||||
):
|
||||
# Dynamic batch allocator, similar to Multifit
|
||||
# https://en.wikipedia.org/wiki/Multifit_algorithm
|
||||
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
|
||||
|
||||
s = 0
|
||||
start_index = 0
|
||||
result = []
|
||||
|
||||
while True:
|
||||
# binary search [l, r)
|
||||
left = 1
|
||||
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
|
||||
|
||||
while right - left > 1:
|
||||
m = (left + right) // 2
|
||||
if ffd_check(lengths[start_index : start_index + m], c, n):
|
||||
left = m
|
||||
else:
|
||||
right = m
|
||||
|
||||
# use length l
|
||||
batch = ffd_with_result(
|
||||
lengths[start_index : start_index + left], c, start_index
|
||||
)
|
||||
assert len(batch) <= n
|
||||
if len(batch) < n:
|
||||
break
|
||||
|
||||
start_index += left
|
||||
s = lengths_cumsum[start_index - 1]
|
||||
|
||||
# add local rank
|
||||
result.append(batch[rank])
|
||||
|
||||
return result, s, len(result) * c * n
|
||||
|
||||
|
||||
class MultipackDistributedBatchSampler(Sampler):
|
||||
"""Unpadded length sampling using Multipack.
|
||||
Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
batch_max_length: int,
|
||||
lengths: List[int],
|
||||
num_replicas: Optional[int] = None,
|
||||
rank: Optional[int] = None,
|
||||
seed: int = 0,
|
||||
):
|
||||
# Get rank
|
||||
if num_replicas is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
num_replicas = dist.get_world_size()
|
||||
if rank is None:
|
||||
if not dist.is_available():
|
||||
raise RuntimeError("Requires distributed package to be available")
|
||||
rank = dist.get_rank()
|
||||
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
self.seed = seed
|
||||
|
||||
self.batch_max_length = batch_max_length
|
||||
self.lengths = lengths
|
||||
assert isinstance(self.lengths, np.ndarray)
|
||||
|
||||
self.epoch = 0
|
||||
|
||||
# statistics
|
||||
self.eff_total_used = 0
|
||||
self.eff_total_slots = 0
|
||||
|
||||
def set_epoch(self, epoch: int):
|
||||
self.epoch = epoch
|
||||
|
||||
def generate_batches(self, set_stats=False):
|
||||
indices = np.random.default_rng(seed=self.seed + self.epoch).permutation(
|
||||
len(self.lengths)
|
||||
)
|
||||
|
||||
lengths = self.lengths[indices]
|
||||
lengths_cumsum = np.cumsum(lengths)
|
||||
|
||||
batches, total_used, total_slots = allocate(
|
||||
lengths=lengths,
|
||||
lengths_cumsum=lengths_cumsum,
|
||||
rank=self.rank,
|
||||
c=self.batch_max_length,
|
||||
n=self.num_replicas,
|
||||
)
|
||||
|
||||
batches = [indices[batch] for batch in batches]
|
||||
|
||||
# statistics
|
||||
if set_stats:
|
||||
self.eff_total_used += total_used
|
||||
self.eff_total_slots += total_slots
|
||||
|
||||
return batches
|
||||
|
||||
def __iter__(self):
|
||||
batches = self.generate_batches(set_stats=True)
|
||||
return iter(batches)
|
||||
|
||||
def num_batches(self):
|
||||
batches = self.generate_batches()
|
||||
return len(batches)
|
||||
|
||||
def efficiency(self):
|
||||
return self.eff_total_used / self.eff_total_slots
|
||||
@@ -5,15 +5,17 @@ import logging
|
||||
import math
|
||||
import os
|
||||
import sys
|
||||
from dataclasses import field
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import numpy as np
|
||||
import torch.cuda
|
||||
import transformers
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
|
||||
from transformers.trainer_pt_utils import get_parameter_names
|
||||
|
||||
@@ -21,12 +23,91 @@ from axolotl.utils.callbacks import (
|
||||
SaveBetterTransformerModelCallback,
|
||||
SavePeftModelCallback,
|
||||
)
|
||||
from axolotl.utils.sampler import MultipackDistributedBatchSampler
|
||||
from axolotl.utils.schedulers import (
|
||||
InterpolatingLogScheduler,
|
||||
get_cosine_schedule_with_quadratic_warmup,
|
||||
)
|
||||
|
||||
IGNORE_LABEL_ID = -100
|
||||
|
||||
|
||||
def _find_multiple(val1, val2):
|
||||
return (-(val1 // -val2)) * val2
|
||||
|
||||
|
||||
def batch_to_tensor(batch, pad_id=0, dtype=torch.long, loss_dtype=torch.bfloat16):
|
||||
# Pad an unused item to reach multiple of 64, for faster GEMM
|
||||
pad_cur_len = sum(list(batch["length"]))
|
||||
pad_len = _find_multiple(pad_cur_len, 64) - pad_cur_len
|
||||
|
||||
if pad_len > 0:
|
||||
assert pad_len < 64
|
||||
|
||||
batch["input_ids"].append([pad_id] * pad_len)
|
||||
batch["labels"].append([pad_id] * pad_len)
|
||||
batch["attention_mask"].append([0] * pad_len)
|
||||
batch["length"].append(pad_len)
|
||||
|
||||
# seqlen
|
||||
batch_lengths = torch.tensor(list(batch["length"]), dtype=torch.int32, device="cpu")
|
||||
|
||||
max_seqlen = torch.max(batch_lengths)
|
||||
cu_seqlens = torch.nn.functional.pad(
|
||||
batch_lengths.cumsum(-1, dtype=torch.int32), (1, 0)
|
||||
)
|
||||
|
||||
# nz elements
|
||||
nz_num = cu_seqlens[-1]
|
||||
nz_input_ids = torch.zeros((nz_num,), dtype=dtype, pin_memory=True, device="cpu")
|
||||
nz_position_ids = torch.zeros((nz_num,), dtype=dtype, pin_memory=True, device="cpu")
|
||||
nz_shifted_label_ids = torch.zeros(
|
||||
(nz_num,), dtype=dtype, pin_memory=True, device="cpu"
|
||||
)
|
||||
nz_shifted_loss_weights = torch.zeros(
|
||||
(nz_num,), dtype=loss_dtype, pin_memory=True, device="cpu"
|
||||
)
|
||||
|
||||
index = 0
|
||||
for token_list, length, labels_list in zip(
|
||||
batch["input_ids"], batch["length"], batch["labels"]
|
||||
):
|
||||
tokens = torch.tensor(token_list, dtype=dtype, device="cpu")
|
||||
position_ids = torch.arange(length, dtype=dtype, device="cpu")
|
||||
|
||||
# Input IDs & shifted labels
|
||||
# shifted_label_ids = torch.where(masks, tokens, IGNORE_LABEL_ID)
|
||||
shifted_label_ids = labels_list
|
||||
shifted_label_ids = torch.nn.functional.pad(
|
||||
shifted_label_ids[1:], (0, 1), "constant", IGNORE_LABEL_ID
|
||||
)
|
||||
|
||||
nz_input_ids[index : index + length] = tokens
|
||||
nz_position_ids[index : index + length] = position_ids
|
||||
nz_shifted_label_ids[index : index + length] = shifted_label_ids
|
||||
|
||||
# Loss weights
|
||||
mask_count = sum(1 for label in labels_list[1:] if label != IGNORE_LABEL_ID)
|
||||
loss_weight = (
|
||||
1 / mask_count if mask_count > 0 else 0
|
||||
) # Avoid division by zero for paddings
|
||||
|
||||
nz_shifted_loss_weights[index : index + length] = loss_weight
|
||||
|
||||
index += length
|
||||
|
||||
# inputs
|
||||
return {
|
||||
"max_seqlen": max_seqlen,
|
||||
"cu_seqlens": cu_seqlens,
|
||||
"nz_input_ids": nz_input_ids,
|
||||
"nz_position_ids": nz_position_ids,
|
||||
"nz_shifted_label_ids": nz_shifted_label_ids,
|
||||
"nz_shifted_loss_weights": nz_shifted_loss_weights,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class AxolotlTrainingArguments(TrainingArguments):
|
||||
"""
|
||||
Extend the base TrainingArguments for axolotl helpers
|
||||
@@ -36,6 +117,14 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
default=False,
|
||||
metadata={"help": "Use quadratic warmup for cosine scheduling."},
|
||||
)
|
||||
sample_packing: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
max_seq_length: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "The maximum sequence length the model can handle"},
|
||||
)
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
@@ -73,6 +162,26 @@ class AxolotlTrainer(Trainer):
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
return self.lr_scheduler
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
lengths = np.array([len(sample["input_ids"]) for sample in self.train_dataset])
|
||||
return MultipackDistributedBatchSampler(
|
||||
batch_max_length=self.args.per_device_train_batch_size
|
||||
* self.args.max_seq_length,
|
||||
lengths=lengths,
|
||||
seed=self.args.seed,
|
||||
)
|
||||
|
||||
def _get_eval_sampler(
|
||||
self, eval_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
lengths = np.array([len(sample["input_ids"]) for sample in eval_dataset])
|
||||
return MultipackDistributedBatchSampler(
|
||||
batch_max_length=self.args.per_device_eval_batch_size
|
||||
* self.args.max_seq_length,
|
||||
lengths=lengths,
|
||||
seed=self.args.seed,
|
||||
)
|
||||
|
||||
|
||||
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
|
||||
"""
|
||||
@@ -186,7 +295,8 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
|
||||
if cfg.save_safetensors:
|
||||
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
|
||||
|
||||
training_args = AxolotlTrainingArguments(
|
||||
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
|
||||
max_steps=total_num_steps * cfg.num_epochs,
|
||||
per_device_train_batch_size=cfg.micro_batch_size,
|
||||
per_device_eval_batch_size=cfg.eval_batch_size
|
||||
if cfg.eval_batch_size is not None
|
||||
|
||||
Reference in New Issue
Block a user