Compare commits

...

56 Commits

Author SHA1 Message Date
Wing Lian
64af21bcb2 set env vars trainer needs for FSDP
Some checks failed
pre-commit / pre-commit (push) Has been cancelled
PyTest / test (3.10) (push) Has been cancelled
PyTest / test (3.9) (push) Has been cancelled
2023-08-11 08:46:26 -04:00
Wing Lian
6b5cf8b5ea optimize length reducer from 9m -> <5sec 2023-08-11 08:30:30 -04:00
Wing Lian
79500f358a need to pass total num tokens to trainer too 2023-08-10 19:08:23 -04:00
Wing Lian
7e977a9b68 optimization if total_num_tokens is already known 2023-08-10 19:02:28 -04:00
Wing Lian
ac4b700daa optimization if total_num_tokens is already known 2023-08-10 19:01:17 -04:00
Wing Lian
2565c2f259 async batching for multipack 2023-08-10 18:28:15 -04:00
Wing Lian
a07f432d9c calculate cum seq lens with pos_ids instead of mask, simplify packing params, fix distributed barrier 2023-08-10 17:16:01 -04:00
Wing Lian
57d9bf711c let's not cleanup the cached datasets 2023-08-08 21:27:55 -04:00
Wing Lian
26983a1974 fix sampler to prevent overfit w new epochs 2023-08-08 15:34:18 -04:00
Wing Lian
1b8747e319 use custom distributed checks 2023-08-08 13:35:04 -04:00
Wing Lian
035b3c760c add numba to requirements. 2023-08-08 10:55:29 -04:00
Wing Lian
17abbd59e1 previous accelerate is still most performant 2023-08-08 09:46:01 -04:00
Wing Lian
6ec76ddb4c fix steps calculation 2023-08-08 05:13:21 -04:00
Wing Lian
21d307b15b fix counts by accounting for num devices 2023-08-08 04:13:10 -04:00
Wing Lian
58e9dee204 fixes and go back to distributed sampler since batch sampler won't work 2023-08-08 03:49:29 -04:00
Wing Lian
4f7c04bae0 more fixes and optimizations 2023-08-08 03:16:00 -04:00
Wing Lian
1162b93b6b filter w multiple cpus 2023-08-08 00:50:56 -04:00
Wing Lian
21f445d763 more packing and dataset optimizations and fixes 2023-08-08 00:45:24 -04:00
Wing Lian
229b9165aa fix test and pylint checks 2023-08-07 09:38:05 -04:00
Wing Lian
394a65f11f add unit tests for cum seq lens, add ability to build cu_seq_lens from positional ids, fix prompt test 2023-08-07 09:38:04 -04:00
Wing Lian
c70dae63cc add chatml 2023-08-07 09:38:04 -04:00
Wing Lian
7712955b35 fix chatml system prompt for openorca, legacy tokenizer opts 2023-08-07 09:38:04 -04:00
Wing Lian
f93f0017cd fix flash-attn, xformers, packing, support chatml 2023-08-07 09:38:04 -04:00
Wing Lian
0b01da0713 properly calculate max len 2023-08-07 09:38:04 -04:00
Wing Lian
b2f7bc7ccd use cumulative seq len with var len flash attn v2 w packing 2023-08-07 09:38:04 -04:00
Wing Lian
b8905e2a91 sample_packing_seq_len_multiplier config 2023-08-07 09:38:04 -04:00
Wing Lian
7e1edc662a make sure the chunk size is an int 2023-08-07 09:38:04 -04:00
Wing Lian
98c9bc69de seq_len_multiple for packing 2023-08-07 09:38:04 -04:00
Wing Lian
8378335dc9 limit packing to sequences of max seq len 2023-08-07 09:38:04 -04:00
Wing Lian
bdd34c7400 weighted CEL fixes 2023-08-07 09:38:04 -04:00
Wing Lian
c6cc54c7d9 weighted CE losses 2023-08-07 09:38:04 -04:00
Wing Lian
83f7362480 don't split batches when packing 2023-08-07 09:38:04 -04:00
Wing Lian
958d423e7c only process eval dataset for packing if not None 2023-08-07 09:38:04 -04:00
Wing Lian
e74eab6e73 add a test for the mask expansion for sequence packing 2023-08-07 09:38:04 -04:00
Wing Lian
487abfc769 pass sample packing efficiency to training args 2023-08-07 09:38:04 -04:00
Wing Lian
2bee646e85 fix step calc for packing 2023-08-07 09:38:04 -04:00
Wing Lian
945f2e5029 better handling so that all devices have the same dataloader len 2023-08-07 09:38:04 -04:00
Wing Lian
daed942fe9 fix rounding of len of batches to int 2023-08-07 09:38:04 -04:00
Wing Lian
df3eb645da better handling of variance in multipack dataloader length and trainer hanging when it runs out of data 2023-08-07 09:38:04 -04:00
Wing Lian
32fed7039d optimized expand mask fn 2023-08-07 09:38:04 -04:00
Wing Lian
7d7b5ebd71 more fixes for 4k and optimizations 2023-08-07 09:38:03 -04:00
Wing Lian
4b7ad9927f validation for sample packing and doc 2023-08-07 09:38:03 -04:00
Wing Lian
fedcf5a089 Update src/axolotl/utils/dataloader.py 2023-08-07 09:38:03 -04:00
Wing Lian
2f2974196d fix for position_ids w packing 2023-08-07 09:38:03 -04:00
Wing Lian
2e295c9f94 use accelerator prepare for dataloader 2023-08-07 09:38:03 -04:00
Wing Lian
4ab9ab79fd use distributed sampler, avoid accelerate prepare 2023-08-07 09:38:03 -04:00
Wing Lian
b02484a83e more fixes for sample packing 2023-08-07 09:38:03 -04:00
Wing Lian
58045f0816 more fixes, position_ids seems broken 2023-08-07 09:38:03 -04:00
Wing Lian
66774011c4 est total tokens, fix field loop 2023-08-07 09:38:03 -04:00
Wing Lian
41d4992029 more fixes for dataloader integration 2023-08-07 09:38:03 -04:00
Wing Lian
762f1b08db add position_ids back 2023-08-07 09:38:03 -04:00
Wing Lian
3aba4c5d7c use multi pack dataloader w random sampler 2023-08-07 09:38:03 -04:00
Wing Lian
ffd96839cf don't move masks to cpu 2023-08-07 09:38:03 -04:00
Wing Lian
ef9bf7ad73 fix expand mask for multiple batch items, make sure we pad position_ids 2023-08-07 09:38:03 -04:00
Wing Lian
4964b0d345 set position ids and use block diagonal attn mask 2023-08-07 09:38:03 -04:00
Wing Lian
36b0e30a9d fix attetion mask with packing 2023-08-07 09:38:03 -04:00
23 changed files with 1208 additions and 69 deletions

View File

@@ -375,7 +375,10 @@ dataset_shard_idx:
sequence_len: 2048
# max sequence length to concatenate training samples together up to
# inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
# soon to be DEPRECATED
max_packed_sequence_len: 1024
# use efficient multi-packing with block diagonal attention and per sequence position_ids
sample_packing:
# if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
adapter: lora

View File

@@ -13,6 +13,8 @@ einops
xformers
optimum
hf_transfer
numba
numpy==1.24.4
# qlora things
bert-score==0.3.13
evaluate==0.4.0

View File

@@ -20,9 +20,14 @@ from transformers import GenerationConfig, TextStreamer
from axolotl.logging_config import configure_logging
from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import barrier, is_main_process
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.tokenization import check_dataset_labels
from axolotl.utils.trainer import setup_trainer
from axolotl.utils.trainer import (
calculate_total_num_steps,
process_datasets_for_packing,
setup_trainer,
)
from axolotl.utils.validation import validate_config
from axolotl.utils.wandb import setup_wandb_env_vars
@@ -231,12 +236,25 @@ def train(
cfg.pretraining_dataset,
tokenizer,
max_tokens=cfg.sequence_len,
seed=cfg.seed,
seed=cfg.seed or 42,
)
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
train_dataset = train_dataset.with_format("torch")
eval_dataset = None
if is_main_process():
# process on rank 0 first so it gets cached so other ranks load from cache
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset
)
barrier()
if not is_main_process():
train_dataset, eval_dataset = process_datasets_for_packing(
cfg, train_dataset, eval_dataset
)
barrier()
total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
if cfg.debug or "debug" in kwargs:
LOG.info("check_dataset_labels...")
check_dataset_labels(
@@ -286,7 +304,9 @@ def train(
model.save_pretrained(cfg.output_dir)
return
trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
trainer = setup_trainer(
cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
)
model.config.use_cache = False
@@ -345,14 +365,12 @@ def train(
# TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
# only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
if cfg.fsdp:
model.save_pretrained(cfg.output_dir)
trainer.save_model(cfg.output_dir)
elif cfg.local_rank == 0:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir)
# trainer.save_model(cfg.output_dir) # TODO this may be needed for deepspeed to work? need to review another time
if __name__ == "__main__":
fire.Fire(train)

View File

@@ -77,14 +77,21 @@ class ConstantLengthDataset(IterableDataset):
self.tokens_dtype = torch.int64
def __iter__(self):
buffer = {"input_ids": [], "attention_mask": [], "labels": []}
buffer = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
for dataset in self.datasets:
idx = 0
iterator = iter(dataset)
more_examples = True
while more_examples:
try:
example = next(iterator)
idx += 1
except StopIteration:
more_examples = False
example = None
@@ -106,6 +113,9 @@ class ConstantLengthDataset(IterableDataset):
attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
: self.seq_length
]
position_ids = torch.cat(buffer["position_ids"], dim=-1)[
: self.seq_length
]
labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
if labels.size() == input_ids.size() and (
attention_mask.size() == input_ids.size()
@@ -114,6 +124,7 @@ class ConstantLengthDataset(IterableDataset):
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
"position_ids": position_ids,
}
else:
LOG.warning(
@@ -123,8 +134,10 @@ class ConstantLengthDataset(IterableDataset):
"input_ids": [],
"attention_mask": [],
"labels": [],
"position_ids": [],
}
buffer_len = 0
idx = 1
if example:
# FIXME
@@ -133,11 +146,6 @@ class ConstantLengthDataset(IterableDataset):
input_ids = example["input_ids"]
attention_mask = example["attention_mask"]
labels = example["labels"]
if (
buffer["input_ids"]
and input_ids[0] == self.tokenizer.bos_token_id
):
attention_mask[0] = 0
if add_concat_token:
input_ids.append(self.concat_token_id)
@@ -148,13 +156,17 @@ class ConstantLengthDataset(IterableDataset):
input_ids, dtype=self.tokens_dtype
)
attention_mask_with_concat = torch.tensor(
attention_mask, dtype=self.tokens_dtype
[idx * m for m in attention_mask], dtype=torch.int16
)
labels_with_concat = torch.tensor(
labels, dtype=self.tokens_dtype
)
position_ids = torch.arange(
len(input_ids), dtype=self.tokens_dtype
)
buffer["input_ids"].append(input_ids_with_concat)
buffer["attention_mask"].append(attention_mask_with_concat)
buffer["labels"].append(labels_with_concat)
buffer["position_ids"].append(position_ids)
buffer_len += len(input_ids)

View File

@@ -7,10 +7,18 @@ from typing import Optional, Tuple
import torch
import transformers
from einops import rearrange
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
try:
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
except ImportError:
from flash_attn.flash_attn_interface import (
flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
)
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
def forward(
self,
@@ -84,35 +92,15 @@ def forward(
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
else:
nheads = qkv.shape[-2]
qkv = rearrange(qkv, "b s ... -> (b s) ...")
cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
cu_q_lens = cu_q_lens.squeeze()
# pylint: disable=invalid-name
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(
x_unpad,
"nnz (three h d) -> nnz three h d",
three=3,
h=nheads,
)
output_unpad = flash_attn_varlen_qkvpacked_func(
x_unpad,
cu_q_lens,
max_s,
0.0,
softmax_scale=None,
causal=True,
)
output = rearrange(
pad_input(
rearrange(output_unpad, "nnz h d -> nnz (h d)"),
indices,
bsz,
q_len,
),
"b s (h d) -> b s h d",
h=nheads,
output = flash_attn_varlen_qkvpacked_func(
qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
return (
self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
None,

View File

@@ -128,6 +128,7 @@ def xformers_forward(
query_states,
key_states,
value_states,
# attn_bias=attention_mask,
attn_bias=xformers.ops.LowerTriangularMask(),
)
attn_weights = None

View File

@@ -0,0 +1,52 @@
"""
expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
"""
from typing import Optional
import torch
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
This expansion handles packed sequences so that sequences share the same attention mask integer value
when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
mask = mask.unsqueeze(1).unsqueeze(2)
mask = mask.expand(bsz, 1, tgt_len, src_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
binary_mask = torch.where(
mask != 0,
torch.tensor(1).to(dtype),
torch.tensor(0).to(dtype),
)
# Create a block-diagonal mask.
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
mask.device
)
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
inverted_mask = 1.0 - masked_zero_one_mask
return inverted_mask.masked_fill(
inverted_mask.to(torch.bool), torch.finfo(dtype).min
)
def hijack_expand_mask():
import transformers
transformers.models.llama.modeling_llama._expand_mask = ( # pylint: disable=protected-access
_expand_mask
)

View File

@@ -0,0 +1,103 @@
"""
Shared utils for the monkeypatches
"""
import torch
def get_cu_seqlens(attn_mask):
"""generate a cumulative sequence length mask for flash attention using attn mask"""
if len(attn_mask.shape) == 1:
attn_mask = attn_mask.unsqueeze(0)
device = attn_mask.device
results = []
max_seq_lens = []
for row in attn_mask:
# Exclude zeros to avoid adding their positions to the mask
t_non_zeros = row[row != 0]
# Find where the sequence number changes (including the first position)
seq_change = torch.cat(
[
torch.tensor([1], dtype=torch.int32, device=device),
t_non_zeros[1:] != t_non_zeros[:-1],
]
)
# Get the indices where the sequence changes
change_indices = torch.cat(
[
(seq_change == 1).nonzero(as_tuple=True)[0],
torch.tensor([len(t_non_zeros)], dtype=torch.int32, device=device),
]
)
# Calculate the sequence lengths
seq_lengths = change_indices[1:] - change_indices[:-1]
# Calculate the length of the final sequence or padding
final_seq_length = len(row) - change_indices[-1]
# Append the length of the final sequence or padding to seq_lengths
if final_seq_length.item():
seq_lengths = torch.cat(
[
seq_lengths,
torch.tensor(
[final_seq_length.item()], dtype=torch.int32, device=device
),
]
)
# Calculate the cumulative sequence lengths
cu_seqlens = torch.cat(
[torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
)
max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
results.append(cu_seqlens)
max_seq_lens.append(max_seq_len)
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
def get_cu_seqlens_from_pos_ids(position_ids):
"""generate a cumulative sequence length mask for flash attention using pos ids"""
if len(position_ids.shape) == 1:
position_ids = position_ids.unsqueeze(0)
device = position_ids.device
results = []
max_seq_lens = []
for row in position_ids:
# Count the number of consecutive zeros from the right side
padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
# Adjust the row to exclude padding
adjusted_row = row[:-padding_length] if padding_length else row.clone()
# Find where the position resets to 0 (indicating a new sequence)
seq_starts = torch.cat(
[
torch.tensor([True], dtype=torch.bool, device=device),
adjusted_row[1:] == 0,
]
)
# Get the indices where the sequence starts
start_indices = torch.cat(
[
(seq_starts).nonzero(as_tuple=True)[0],
torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
]
)
# Calculate the sequence lengths
seq_lengths = start_indices[1:] - start_indices[:-1]
# Calculate the cumulative sequence lengths
cu_seqlens = torch.cat(
[torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
)
# Append the padding length to the cumulative sequence lengths
if padding_length:
cu_seqlens = torch.cat(
[cu_seqlens, torch.tensor([len(row)], dtype=torch.int32, device=device)]
)
max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
results.append(cu_seqlens)
max_seq_lens.append(max_seq_len)
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)

View File

@@ -66,7 +66,11 @@ class SystemDataPrompter(AlpacaPrompter):
) -> Generator[str, None, None]:
# returns the full prompt from instruction and optional input
# if a label (=response, =output) is provided, it's also appended.
formatted_sys_prompt = f"### System:\n{system}\n\n" if system else ""
formatted_sys_prompt = (
self.system_format.format(system=system)
if system and self.system_format
else ""
)
if input:
res = formatted_sys_prompt + self.turn_format.format(
instruction=instruction, input=input
@@ -86,12 +90,20 @@ class OpenOrcaSystemDataPrompter(SystemDataPrompter):
"""
def match_prompt_style(self):
# pylint: disable=duplicate-code
if self.prompt_style == PromptStyle.INSTRUCT.value:
self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
if self.prompt_style == PromptStyle.CHAT.value:
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
self.turn_format = "User: {instruction}\n{input}\nAssistant:"
self.turn_no_input_format = "User: {instruction}\nAssistant:"
self.system_format = "System: {system}\n"
if self.prompt_style == PromptStyle.CHATML.value:
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
self.turn_no_input_format = (
"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
)
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
@@ -137,3 +149,12 @@ def load_open_orca(tokenizer, cfg):
cfg.train_on_inputs,
cfg.sequence_len,
)
def load_open_orca_chatml(tokenizer, cfg):
return OpenOrcaPromptTokenizingStrategy(
OpenOrcaSystemDataPrompter(PromptStyle.CHATML.value),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)

View File

@@ -16,6 +16,7 @@ class PromptStyle(Enum):
INSTRUCT = "instruct"
CHAT = "chat"
CHATML = "chatml"
class AlpacaPrompter:
@@ -25,6 +26,7 @@ class AlpacaPrompter:
system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
system_format: str
turn_format: str
turn_no_input_format: str
prompt_style: Optional[PromptStyle] = None
@@ -34,14 +36,23 @@ class AlpacaPrompter:
self.match_prompt_style()
def match_prompt_style(self):
# pylint: disable=duplicate-code
if self.prompt_style == PromptStyle.INSTRUCT.value:
self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
self.turn_no_input_format = (
"### Instruction:\n{instruction}\n\n### Response:\n"
)
self.system_format = "### System:\n{system}\n\n"
if self.prompt_style == PromptStyle.CHAT.value:
self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
self.system_format = "SYSTEM: {system}\n"
if self.prompt_style == PromptStyle.CHATML.value:
self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
self.turn_no_input_format = (
"<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
)
self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"
def build_prompt(
self,

View File

@@ -0,0 +1,121 @@
"""
DataCollator for axolotl to pad labels and position_ids for packed sequences
"""
from dataclasses import dataclass
from typing import Any, Optional, Union
import numpy as np
from transformers import PreTrainedTokenizerBase
from transformers.utils import PaddingStrategy
@dataclass
class DataCollatorForSeq2Seq:
"""
Data collator that will dynamically pad the inputs received, as well as the labels and position_ids
Args:
tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
The tokenizer used for encoding the data.
model ([`PreTrainedModel`]):
The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
prepare the *decoder_input_ids*
This is useful when using *label_smoothing* to avoid calculating loss twice.
padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
among:
- `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
sequence is provided).
- `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
acceptable input length for the model if that argument is not provided.
- `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
max_length (`int`, *optional*):
Maximum length of the returned list and optionally padding length (see above).
pad_to_multiple_of (`int`, *optional*):
If set will pad the sequence to a multiple of the provided value.
This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
7.5 (Volta).
label_pad_token_id (`int`, *optional*, defaults to -100):
The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
return_tensors (`str`):
The type of Tensor to return. Allowable values are "np", "pt" and "tf".
"""
tokenizer: PreTrainedTokenizerBase
model: Optional[Any] = None
padding: Union[bool, str, PaddingStrategy] = True
max_length: Optional[int] = None
pad_to_multiple_of: Optional[int] = None
label_pad_token_id: int = -100
position_pad_token_id: int = 0
return_tensors: str = "pt"
def __call__(self, features, return_tensors=None):
labels = None
if return_tensors is None:
return_tensors = self.return_tensors
for feature_name, pad_token_id in [
("labels", self.label_pad_token_id),
("position_ids", self.position_pad_token_id),
]:
feat = (
[feature[feature_name] for feature in features]
if feature_name in features[0].keys()
else None
)
labels = feat if feat and feature_name == "labels" else labels
# We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
# same length to return tensors.
if feat is not None:
max_feature_length = max(len(l) for l in feat) # noqa: E741
if self.pad_to_multiple_of is not None:
max_feature_length = (
(max_feature_length + self.pad_to_multiple_of - 1)
// self.pad_to_multiple_of
* self.pad_to_multiple_of
)
padding_side = self.tokenizer.padding_side
for feature in features:
remainder = [pad_token_id] * (
max_feature_length - len(feature[feature_name])
)
if isinstance(feature[feature_name], list):
feature[feature_name] = (
feature[feature_name] + remainder
if padding_side == "right"
else remainder + feature[feature_name]
)
elif padding_side == "right":
feature[feature_name] = np.concatenate(
[feature[feature_name], remainder]
).astype(np.int64)
else:
feature[feature_name] = np.concatenate(
[remainder, feature[feature_name]]
).astype(np.int64)
features = self.tokenizer.pad(
features,
padding=self.padding,
max_length=self.max_length,
pad_to_multiple_of=self.pad_to_multiple_of,
return_tensors=return_tensors,
)
# prepare decoder_input_ids
if (
labels is not None
and self.model is not None
and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
):
decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(
labels=features["labels"]
)
features["decoder_input_ids"] = decoder_input_ids
return features

View File

@@ -1,5 +1,6 @@
"""Module containing data utilities"""
import functools
import hashlib
import itertools
import logging
from hashlib import md5
@@ -35,6 +36,7 @@ from axolotl.prompters import (
ShareGPTPrompter,
SummarizeTLDRPrompter,
)
from axolotl.utils.distributed import barrier, is_main_process
LOG = logging.getLogger("axolotl")
@@ -109,6 +111,7 @@ def load_tokenized_prepared_datasets(
local_path = Path(d.path)
if local_path.exists():
if local_path.is_dir():
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
ds = load_dataset(
d.path,
name=d.name,
@@ -374,6 +377,7 @@ def load_prepare_datasets(
dataset = Dataset.from_list(list(constant_len_dataset))
# filter out bad data
# TODO convert to dataset.filter(...)
dataset = Dataset.from_list(
[
d
@@ -413,7 +417,51 @@ def load_prepare_datasets(
)
if cfg.val_set_size:
dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
# ensure we end up with the same fingerprint by doing rank0 first and being able to cache
to_hash_train = (
dataset._fingerprint # pylint: disable=protected-access
+ "|"
+ str(cfg.val_set_size)
+ "|"
+ "train"
+ "|"
+ str(cfg.seed or 42)
)
to_hash_test = (
dataset._fingerprint # pylint: disable=protected-access
+ "|"
+ str(cfg.val_set_size)
+ "|"
+ "test"
+ "|"
+ str(cfg.seed or 42)
)
train_fingerprint = hashlib.md5(
to_hash_train.encode(), usedforsecurity=False
).hexdigest()
test_fingerprint = hashlib.md5(
to_hash_test.encode(), usedforsecurity=False
).hexdigest()
if is_main_process():
dataset = dataset.train_test_split(
test_size=cfg.val_set_size,
shuffle=False,
seed=cfg.seed or 42,
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,
)
barrier()
if not is_main_process():
dataset = dataset.train_test_split(
test_size=cfg.val_set_size,
shuffle=False,
seed=cfg.seed or 42,
train_new_fingerprint=train_fingerprint,
test_new_fingerprint=test_fingerprint,
)
barrier()
train_dataset = dataset["train"]
eval_dataset = dataset["test"]
else:

View File

@@ -0,0 +1,310 @@
# pylint: skip-file
import hashlib
import itertools
import logging
import math
import queue
import threading
from typing import Any, Callable, List, Optional, Union
import numba
import numpy as np
from torch.utils.data import DistributedSampler, Sampler
LOG = logging.getLogger("axolotl.utils.dataloader")
@numba.njit
def ffd_check(a: np.ndarray, c: int, n: int):
# First-fit-decreasing bin packing
# Check if a[] could fit in n bins with capacity c
# https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
a = np.sort(a)[::-1]
bins = np.full((n,), c, dtype=a.dtype)
for size in a:
not_found = True
for idx in range(n):
if bins[idx] >= size:
bins[idx] -= size
not_found = False
break
if not_found:
return False
return True
@numba.njit
def ffd_with_result(a: np.ndarray, c: int, start_index: int):
# First-fit-decreasing bin packing (with result return)
indices = np.argsort(a)[::-1]
a = a[indices]
bins: List[Any] = []
bins_result: List[Any] = []
for a_id, size in enumerate(a):
add_new = True
for idx in range(len(bins)):
if bins[idx] >= size:
bins[idx] -= size
bins_result[idx].append(indices[a_id] + start_index)
add_new = False
break
if add_new:
bins.append(c - size)
bins_result.append([indices[a_id] + start_index])
return bins_result, len(a)
@numba.njit
def allocate(
lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
):
"""
:param lengths: array of lengths of each sample
:param lengths_cumsum: cumulative sum of consecutive lengths
:param rank: rank for this process
:param c: length of tokens per batch
:param n: number of ranks
:return:
"""
# Dynamic batch allocator, similar to Multifit
# https://en.wikipedia.org/wiki/Multifit_algorithm
# ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
s = 0
start_index = 0
result = []
while True:
# binary search [left, right)
left = 1
right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
while right - left > 1:
mid = (left + right) // 2
if ffd_check(lengths[start_index : start_index + mid], c, n):
left = mid
else:
right = mid
# use length left
batch, tot_seqs = ffd_with_result(
lengths[start_index : start_index + left], c, start_index
)
if len(batch) < n:
break
start_index += left
s = lengths_cumsum[start_index - 1]
# add local rank
result.append(batch[rank])
yield batch[rank], tot_seqs, s, len(result) * c * n
def chunk(iterable, n):
"""
Chunk data into tuples of length n
"""
# batched('ABCDEFG', 3) --> ABC DEF G
if n < 1:
raise ValueError("n must be at least one")
it = iter(iterable)
while batch := tuple(itertools.islice(it, n)):
yield batch
def hash_indices(lst: List[int]) -> str:
# Convert the list of integers to a string representation
concatenated = ",".join(map(str, lst))
# Generate the hash
sha256 = hashlib.sha256()
sha256.update(concatenated.encode())
return sha256.hexdigest()
class MultipackDistributedDataloader:
"""Unpadded data loading using Multipack.
Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
"""
def __init__(
self,
dataset: Any,
collate_fn: Callable,
seq_max_length: int = 2048,
batch_size: int = 1,
sampler: Union[Sampler, DistributedSampler] = None,
packing_efficiency_estimate: float = 1.0,
sample_packing_seq_len_multiplier: int = 1,
device_count: int = 1,
total_num_tokens: Optional[int] = None,
):
# Dataset
self.dataset = dataset
lengths_series = (
dataset.data.column("position_ids").to_pandas().apply(lambda x: x[-1] + 1)
)
self.lengths: np.ndarray = lengths_series.values
assert isinstance(self.lengths, np.ndarray)
assert batch_size % sample_packing_seq_len_multiplier == 0
assert batch_size >= sample_packing_seq_len_multiplier
self.sampler = sampler
self.batch_size = batch_size
self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier
self.seq_max_length = seq_max_length
self.batch_max_length = batch_size * seq_max_length
self.collate_fn = collate_fn
self.num_replicas = 1
self.rank = 0
# statistics
self.total_num_tokens = total_num_tokens
self.eff_total_used = 0
self.eff_total_slots = 0
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
self.device_count = device_count
# for non-blocking batch creation
self.batch_queue: queue.Queue = queue.Queue(
maxsize=10
) # Adjust maxsize as needed
def generate_batches(self, set_stats=False):
LOG.info("generating packed batches")
if self.sampler:
indices = [idx for idx in self.sampler]
else:
indices = range(0, len(self.dataset))
LOG.info(hash_indices(indices))
lengths = self.lengths[indices]
lengths_cumsum = np.cumsum(lengths)
alloc_iter = iter(
allocate(
lengths=lengths,
lengths_cumsum=lengths_cumsum,
rank=self.rank,
# c=self.batch_max_length,
c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
n=self.num_replicas,
)
)
for batch, tot_seqs, total_used, total_slots in alloc_iter:
self.batch_queue.put([indices[b_idx] for b_idx in batch])
# statistics
if set_stats:
self.eff_total_used = total_used
self.eff_total_slots = total_slots
self.batch_queue.put(None) # Signal the end of batch generation
def _generate_batches_thread(self):
try:
self.generate_batches(set_stats=True)
except Exception as e:
LOG.error(f"Error in batch generation thread: {e}")
self.batch_queue.put(
None
) # Signal the end of batch generation in case of error
def __iter__(self):
if hasattr(self.sampler, "set_epoch"):
new_epoch = self.sampler.epoch + 1
self.sampler.set_epoch(new_epoch)
LOG.info(f"calling sampler.set_epoch({new_epoch})")
# Start the batch generation in a separate thread
batch_gen_thread = threading.Thread(target=self._generate_batches_thread)
batch_gen_thread.start()
features = self.dataset.features.keys()
len_remaining = self._len_est()
while True:
batch = self.batch_queue.get()
if batch is None: # Sentinel value received, stop iteration
break
chunked_data = []
attn_mask_cum_idx = 0
concatenated = {}
batched_data = [self.dataset[batch_idx] for batch_idx in batch]
for feature in features:
if feature == "attention_mask":
arrays = [
(attn_mask_cum_idx + idx + 1) * np.array(item[feature])
for idx, item in enumerate(batched_data)
if feature in item
]
attn_mask_cum_idx += len(batched_data)
concatenated[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature])
for item in batched_data
if feature in item
]
concatenated[feature] = np.concatenate(arrays)
chunked_data.append(concatenated)
yield self.collate_fn(chunked_data)
len_remaining -= 1
if not len_remaining:
break
# Wait for the batch generation thread to finish
batch_gen_thread.join(timeout=5)
LOG.info(f"actual packing efficiency: {self.efficiency()}")
def _len_est(self):
if not self.total_num_tokens:
self.total_num_tokens = np.sum(self.lengths)
lengths_sum_per_device = self.total_num_tokens // self.device_count
LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"total_num_tokens per device: {lengths_sum_per_device}"
)
# shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
return (
math.floor(
0.99
* lengths_sum_per_device
/ self.packing_efficiency_estimate
// self.seq_max_length
// self.batch_size
)
- 1
)
def __len__(self):
# this doesn't return the actual length b/c with distributed samplers, not all dataloaders get
# the same share of total tokens
# if not self.eff_total_used:
# batches, _ = self.generate_batches(set_stats=True)
# LOG.info(
# f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
# f"actual packing efficiency: {self.efficiency()}"
# )
return max(1, self._len_est())
def len_w_stats(self):
if not self.eff_total_used:
batches, _ = self.generate_batches(set_stats=True)
LOG.info(
f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
f"actual packing efficiency: {self.efficiency()}"
)
return max(1, self._len_est())
def efficiency(self):
return self.eff_total_used / self.eff_total_slots

View File

@@ -0,0 +1,41 @@
"""
utility helpers for distributed checks
"""
import torch.distributed as dist
from accelerate import Accelerator
accelerate = None # pylint: disable=invalid-name
def load_accelerate():
global accelerate # pylint: disable=global-statement
accelerate = Accelerator()
def is_distributed():
"""
Check if distributed training is initialized.
"""
global accelerate # pylint: disable=global-statement
if not accelerate:
accelerate = Accelerator()
return dist.is_available() and dist.is_initialized()
def barrier():
"""
Acts as a barrier to wait for all processes. This ensures that all processes
reach the barrier before proceeding further.
"""
if is_distributed():
dist.barrier()
def is_main_process():
"""
Check if the current process is the main process.
If not in distributed mode, always return True.
"""
if not is_distributed():
return True
return dist.get_rank() == 0

View File

@@ -36,20 +36,26 @@ def load_tokenizer(
tokenizer_type,
cfg,
):
tokenizer_kwargs = {}
use_fast = True # this is the default
if cfg.tokenizer_use_fast is not None:
use_fast = cfg.tokenizer_use_fast
if cfg.tokenizer_legacy is not None:
# True is the default w/ https://github.com/huggingface/transformers/pull/25224
tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
if tokenizer_type:
tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
)
else:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_config,
trust_remote_code=cfg.trust_remote_code or False,
use_fast=use_fast,
**tokenizer_kwargs,
)
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
@@ -86,8 +92,10 @@ def load_model(
# TODO refactor as a kwarg
load_in_8bit = cfg.load_in_8bit
cfg.is_llama_derived_model = "llama" in base_model or (
cfg.model_type and "llama" in cfg.model_type.lower()
cfg.is_llama_derived_model = (
"llama" in base_model
or (cfg.model_type and "llama" in cfg.model_type.lower())
or cfg.is_llama_derived_model is True
)
if cfg.is_llama_derived_model and cfg.flash_attention:
@@ -132,6 +140,14 @@ def load_model(
LOG.info("patching with xpos rope")
replace_llama_rope_with_xpos_rope()
if cfg.is_llama_derived_model and (
cfg.max_packed_sequence_len or cfg.sample_packing
):
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
LOG.info("patching _expand_mask")
hijack_expand_mask()
if cfg.bf16 or cfg.bfloat16:
torch_dtype = torch.bfloat16
elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
@@ -222,7 +238,6 @@ def load_model(
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map="auto" if cfg.world_size == 1 else cfg.device_map,
**model_kwargs,
)
# elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
@@ -257,7 +272,6 @@ def load_model(
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
@@ -288,7 +302,6 @@ def load_model(
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
@@ -302,7 +315,6 @@ def load_model(
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
torch_dtype=torch_dtype,
device_map=cfg.device_map,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)

View File

@@ -1,19 +1,22 @@
"""Module containing the Trainer class and related functions"""
import importlib
import logging
import math
import os
import sys
from contextlib import contextmanager
from dataclasses import dataclass, field
from functools import partial
from pathlib import Path
from typing import Optional
from typing import Optional, Union
import bitsandbytes as bnb
import torch.cuda
import transformers
from datasets import Dataset, set_caching_enabled
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
from transformers.trainer_pt_utils import get_parameter_names
@@ -21,6 +24,8 @@ from axolotl.utils.callbacks import (
SaveBetterTransformerModelCallback,
SavePeftModelCallback,
)
from axolotl.utils.collators import DataCollatorForSeq2Seq
from axolotl.utils.dataloader import MultipackDistributedDataloader
from axolotl.utils.schedulers import (
InterpolatingLogScheduler,
get_cosine_schedule_with_quadratic_warmup,
@@ -29,6 +34,68 @@ from axolotl.utils.schedulers import (
LOG = logging.getLogger("axolotl")
@torch.jit.script
def weighted_cross_entropy(
logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
):
# Flatten the logits, labels, and weights tensors
logits = logits.view(
-1, logits.size(-1)
) # logits becomes of shape [batch_size*sequence_length, vocab_size]
labels = labels.view(-1) # labels becomes of shape [batch_size*sequence_length]
weights = weights.view(-1) # weights becomes of shape [batch_size*sequence_length]
# Compute the unweighted cross entropy loss
losses = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
# Apply the weights to the losses and compute their sum
return (weights * losses).sum()
@torch.jit.script
def create_weighted_mask(labels: torch.Tensor):
# Check if the tensor is 2D. If not, unsqueeze it to make it 2D
if len(labels.shape) == 1:
labels = labels.unsqueeze(0)
weights = torch.zeros_like(labels).float()
for i in range(labels.shape[0]):
mask = labels[i] != -100
# Create a tensor to track group ids
group_ids = torch.zeros_like(labels[i]).int()
curr_group_id = 0
for j in range(1, len(labels[i])):
if mask[j] and not mask[j - 1]: # switch from masked to unmasked label
curr_group_id += 1 # start new group
group_ids[j] = (
curr_group_id if mask[j] else 0
) # assign group id if unmasked label
# Count only unmasked labels in each group
group_counts = torch.bincount(group_ids[mask])
mask_weights = torch.zeros_like(labels[i]).float()
mask_weights[mask] = 1.0 / group_counts[group_ids[mask]]
weights[i] = mask_weights
return weights.squeeze() # squeeze the output to match the input dimension
def trainer_weighted_loss(model_output, labels, shift_labels=True):
logits = (
model_output["logits"] if isinstance(model_output, dict) else model_output[0]
)
if shift_labels:
logits = logits[..., :-1, :].contiguous()
labels = labels[..., 1:].contiguous()
weights = create_weighted_mask(labels)
return weighted_cross_entropy(logits, labels, weights)
@dataclass
class AxolotlTrainingArguments(TrainingArguments):
"""
@@ -39,6 +106,26 @@ class AxolotlTrainingArguments(TrainingArguments):
default=False,
metadata={"help": "Use quadratic warmup for cosine scheduling."},
)
sample_packing: bool = field(
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
sample_packing_efficiency: float = field(
default=1.0,
metadata={"help": "Sample packing efficiency for calculating batch length."},
)
max_seq_length: int = field(
default=2048,
metadata={"help": "The maximum sequence length the model can handle"},
)
sample_packing_seq_len_multiplier: int = field(
default=1,
metadata={"help": "the multiplier for the max len for packed sequences"},
)
train_data_total_num_tokens: Optional[int] = field(
default=None,
metadata={"help": "the total number of tokens in the train dataset"},
)
class AxolotlTrainer(Trainer):
@@ -76,6 +163,66 @@ class AxolotlTrainer(Trainer):
return super().create_scheduler(num_training_steps, optimizer)
return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size > 1 and self.args.sample_packing:
return DistributedSampler(
self.train_dataset,
num_replicas=self.args.world_size,
rank=self.args.process_index,
seed=self.args.seed,
)
return super()._get_train_sampler()
def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing:
train_sampler = self._get_train_sampler()
return self.accelerator.prepare(
MultipackDistributedDataloader(
self.train_dataset,
batch_size=self._train_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=train_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
total_num_tokens=self.args.train_data_total_num_tokens,
)
)
return super().get_train_dataloader()
def get_eval_dataloader(
self, eval_dataset: Optional[Dataset] = None
) -> Union[DataLoader, MultipackDistributedDataloader]:
if self.args.sample_packing:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset)
return self.accelerator.prepare(
MultipackDistributedDataloader(
eval_dataset,
batch_size=self.args.eval_batch_size,
seq_max_length=self.args.max_seq_length,
collate_fn=self.data_collator,
sampler=eval_sampler,
packing_efficiency_estimate=self.args.sample_packing_efficiency,
sample_packing_seq_len_multiplier=self.args.eval_batch_size,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
total_num_tokens=None,
)
)
return super().get_eval_dataloader(eval_dataset)
def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc
# if self.args.sample_packing:
# labels = inputs.pop("labels")
# outputs = model(**inputs)
# loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
# return (loss, outputs) if return_outputs else loss
return super().compute_loss(model, inputs, return_outputs=return_outputs)
class OneCycleLRSchedulerTrainer(AxolotlTrainer):
"""
@@ -106,10 +253,117 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
return self.lr_scheduler
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
def add_position_ids(sample):
sample["position_ids"] = torch.arange(len(sample["input_ids"]))
return sample
def drop_long_seq(sample, sequence_len=2048):
return len(sample["input_ids"]) <= sequence_len
@contextmanager
def disable_datasets_caching():
try:
set_caching_enabled(False)
yield
finally:
set_caching_enabled(True)
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
if cfg.sample_packing:
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
add_position_ids, num_proc=os.cpu_count()
)
if eval_dataset:
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
add_position_ids, num_proc=os.cpu_count()
)
return train_dataset, eval_dataset
def calculate_total_num_steps(cfg, train_dataset, tokenizer):
if cfg.sample_packing:
# we have to drop anything longer then sequence len otherwise
# flash attention with position ids fails
total_num_tokens = (
cfg.total_num_tokens
if cfg.total_num_tokens
else sum(len(s["input_ids"]) for s in train_dataset)
)
if not cfg.total_num_tokens:
LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
if cfg.sample_packing_eff_est:
total_num_steps = (
# match count to len est in dataloader
(
math.floor(
0.99
* total_num_tokens
/ cfg.sample_packing_eff_est
/ 2048
// cfg.batch_size
// int(os.environ.get("WORLD_SIZE", 1))
)
- 1
)
* cfg.num_epochs
)
LOG.info(
f"total_num_tokens: {total_num_tokens}, total_num_steps: {total_num_steps}"
)
else:
sampler = RandomSampler(train_dataset)
data_loader = MultipackDistributedDataloader(
train_dataset,
batch_size=cfg.micro_batch_size,
seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
collate_fn=DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
padding="longest",
),
sampler=sampler,
packing_efficiency_estimate=cfg.sample_packing_eff_est,
sample_packing_seq_len_multiplier=cfg.micro_batch_size,
device_count=int(os.environ.get("WORLD_SIZE", 1)),
)
data_loader_len = data_loader.len_w_stats()
actual_eff = data_loader.efficiency()
LOG.info(f"data_loader_len: {data_loader_len}")
total_num_steps = int(
math.floor(
data_loader_len
* cfg.micro_batch_size
* cfg.num_epochs
// cfg.batch_size
)
)
LOG.info(
f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`"
)
else:
total_num_steps = int(
math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
)
LOG.info(f"total_num_steps: {total_num_steps}")
return total_num_steps
def setup_fsdp_envs(cfg):
os.environ["ACCELERATE_USE_FSDP"] = "true"
if cfg.fsdp_config.fsdp_sync_module_states:
os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
if cfg.fsdp_config.fsdp_state_dict_type:
os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type
def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
if cfg.fsdp:
setup_fsdp_envs(cfg)
warmup_steps = (
cfg.warmup_steps
if cfg.warmup_steps is not None
@@ -189,7 +443,14 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.save_safetensors:
training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
if cfg.sample_packing_eff_est:
training_arguments_kwargs[
"sample_packing_efficiency"
] = cfg.sample_packing_eff_est
training_args = AxolotlTrainingArguments( # pylint: disable=unexpected-keyword-arg
# max_steps=total_num_steps, # this is helpful in case we don't actually know total # of steps
max_seq_length=cfg.sequence_len,
per_device_train_batch_size=cfg.micro_batch_size,
per_device_eval_batch_size=cfg.eval_batch_size
if cfg.eval_batch_size is not None
@@ -203,7 +464,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
save_steps=cfg.save_steps,
output_dir=cfg.output_dir,
save_total_limit=3,
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
load_best_model_at_end=(
cfg.load_best_model_at_end is not False
and cfg.val_set_size > 0
@@ -221,6 +482,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
else "cosine",
weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
sample_packing=cfg.sample_packing if cfg.sample_packing else False,
sample_packing_seq_len_multiplier=cfg.micro_batch_size or 1,
train_data_total_num_tokens=cfg.total_num_tokens,
**training_arguments_kwargs,
)
@@ -314,11 +578,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
if cfg.collator_pad_to_longest:
data_collator_kwargs["padding"] = "longest"
else:
data_collator_kwargs["pad_to_multiple_of"] = 8
# A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
# https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
data_collator_kwargs["pad_to_multiple_of"] = 64
if cfg.is_llama_derived_model and cfg.landmark_attention:
from functools import partial
from axolotl.monkeypatch.llama_landmark_attn import (
add_mem_tokens,
get_mem_id,
@@ -346,7 +610,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
train_dataset=train_dataset,
eval_dataset=eval_dataset,
args=training_args,
data_collator=transformers.DataCollatorForSeq2Seq(
data_collator=DataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
**data_collator_kwargs,

View File

@@ -8,6 +8,19 @@ LOG = logging.getLogger("axolotl")
def validate_config(cfg):
if cfg.max_packed_sequence_len and cfg.sample_packing:
raise ValueError(
"please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"
)
if cfg.max_packed_sequence_len:
LOG.warning(
str(
PendingDeprecationWarning(
"max_packed_sequence_len will be deprecated in favor of sample_packing"
)
)
)
if cfg.gradient_accumulation_steps and cfg.batch_size:
raise ValueError(
"please set only one of gradient_accumulation_steps or batch_size"
@@ -97,6 +110,17 @@ def validate_config(cfg):
"push_to_hub_model_id is deprecated. Please use hub_model_id instead."
)
if cfg.sample_packing and cfg.sdp_attention:
# incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
raise ValueError(
"sample_packing not compatible with sdp_attention. Use flash_attention"
)
if cfg.sample_packing and cfg.xformers_attention:
raise ValueError(
"sample_packing not compatible with xformers_attention. Use flash_attention"
)
# TODO
# MPT 7b
# https://github.com/facebookresearch/bitsandbytes/issues/25

View File

@@ -0,0 +1,30 @@
"""
Unit tests for the monkeypatch utils
"""
import unittest
import torch
from axolotl.monkeypatch.utils import get_cu_seqlens, get_cu_seqlens_from_pos_ids
class TestMonkeyPatchUtils(unittest.TestCase):
"""
Unit test class for monkeypatch utils
"""
def test_get_cu_seqlens_1d(self):
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)
self.assertTrue(torch.allclose(get_cu_seqlens(attn_mask)[0], target_res))
def test_get_cu_seqlens_from_pos_ids_1d(self):
position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]])
target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)
self.assertTrue(
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
)
if __name__ == "__main__":
unittest.main()

44
tests/test_expand_mask.py Normal file
View File

@@ -0,0 +1,44 @@
"""
Unit tests for the monkey patch for expand mask to handle packed sequences
"""
import unittest
import torch
from axolotl.monkeypatch.llama_expand_mask import _expand_mask
class TestExpandMask(unittest.TestCase):
"""
Test class for attention mask expansion for packed sequences
"""
def test_output(self):
mask = torch.tensor([[1, 1, 1, 2], [2, 3, 3, 0]])
dtype = torch.float32
expected_output = torch.tensor(
[
[
[
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
[0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
[0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
[-3.4028e38, -3.4028e38, -3.4028e38, 0.0000e00],
]
],
[
[
[0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
[-3.4028e38, 0.0000e00, -3.4028e38, -3.4028e38],
[-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38],
[-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
]
],
]
)
# Check that the output matches the expected output
self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
if __name__ == "__main__":
unittest.main()

View File

@@ -27,7 +27,7 @@ class TestPacking(unittest.TestCase):
}
)
def test_resets_attention(self):
def test_increments_attention(self):
prompter = AlpacaPrompter("chat")
strat = AlpacaPromptTokenizingStrategy(
prompter,
@@ -55,10 +55,14 @@ class TestPacking(unittest.TestCase):
# first example doesn't have mask reset
assert example["input_ids"][0] == self.tokenizer.bos_token_id
assert example["attention_mask"][0] == 1
assert example["position_ids"][0] == 0
assert example["position_ids"][1] == 1
# but subsequent one does
assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
assert example["attention_mask"][next_bos_index] == 0
assert example["attention_mask"][next_bos_index] == 2
assert example["position_ids"][next_bos_index] == 0
assert example["position_ids"][next_bos_index + 1] == 1
if __name__ == "__main__":

View File

@@ -134,9 +134,15 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
"output": "Hi! How can I help?",
}
example = strat.tokenize_prompt(sample)
assert example["input_ids"][0:4] == [1, 835, 2184, 29901] # "<s>### System:"
assert example["input_ids"][5:7] == [1509, 20118] # "use cot"
assert example["input_ids"][9] == 11889 # USER
assert example["input_ids"][0:5] == [
1,
28962,
1254,
12665,
29901,
] # "<s>SYSTEM:"
assert example["input_ids"][5:7] == [671, 20118] # " use cot"
assert example["input_ids"][8] == 11889 # USER
class Llama2ChatTokenizationTest(unittest.TestCase):

View File

@@ -70,7 +70,7 @@ class AlpacaPrompterTest(unittest.TestCase):
)
)
assert "use cot" in res
assert res.startswith("### System:")
assert res.startswith("SYSTEM:")
assert "### Instruction:" not in res
assert "### Input:" not in res
assert "alpacas" in res

View File

@@ -313,3 +313,27 @@ class ValidationTest(unittest.TestCase):
)
validate_config(cfg)
def test_packing(self):
cfg = DictDefault(
{
"max_packed_sequence_len": 2048,
}
)
with self._caplog.at_level(logging.WARNING):
validate_config(cfg)
assert any(
"max_packed_sequence_len will be deprecated in favor of sample_packing"
in record.message
for record in self._caplog.records
)
cfg = DictDefault(
{
"max_packed_sequence_len": 2048,
"sample_packing": True,
}
)
regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
with pytest.raises(ValueError, match=regex_exp):
validate_config(cfg)