Compare commits
quantize-p...packing-at
56 commits:

64af21bcb2  6b5cf8b5ea  79500f358a  7e977a9b68  ac4b700daa  2565c2f259  a07f432d9c
57d9bf711c  26983a1974  1b8747e319  035b3c760c  17abbd59e1  6ec76ddb4c  21d307b15b
58e9dee204  4f7c04bae0  1162b93b6b  21f445d763  229b9165aa  394a65f11f  c70dae63cc
7712955b35  f93f0017cd  0b01da0713  b2f7bc7ccd  b8905e2a91  7e1edc662a  98c9bc69de
8378335dc9  bdd34c7400  c6cc54c7d9  83f7362480  958d423e7c  e74eab6e73  487abfc769
2bee646e85  945f2e5029  daed942fe9  df3eb645da  32fed7039d  7d7b5ebd71  4b7ad9927f
fedcf5a089  2f2974196d  2e295c9f94  4ab9ab79fd  b02484a83e  58045f0816  66774011c4
41d4992029  762f1b08db  3aba4c5d7c  ffd96839cf  ef9bf7ad73  4964b0d345  36b0e30a9d
@@ -375,7 +375,10 @@ dataset_shard_idx:
 sequence_len: 2048
 # max sequence length to concatenate training samples together up to
 # inspired by StackLLaMA. see https://huggingface.co/blog/stackllama#supervised-fine-tuning
+# soon to be DEPRECATED
 max_packed_sequence_len: 1024
+# use efficient multi-packing with block diagonal attention and per sequence position_ids
+sample_packing:

 # if you want to use 'lora' or 'qlora' or leave blank to train all parameters in original model
 adapter: lora
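For reference, a minimal config fragment using the new option might look like the following sketch. This assumes `sample_packing` takes a boolean, which matches the `sample_packing: bool` training argument added later in this diff:

```yaml
sequence_len: 2048
# enable multipack; leave max_packed_sequence_len unset, since it is marked soon to be deprecated
sample_packing: true
```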
@@ -13,6 +13,8 @@ einops
 xformers
 optimum
 hf_transfer
+numba
+numpy==1.24.4
 # qlora things
 bert-score==0.3.13
 evaluate==0.4.0
@@ -20,9 +20,14 @@ from transformers import GenerationConfig, TextStreamer
 from axolotl.logging_config import configure_logging
 from axolotl.utils.data import load_prepare_datasets, load_pretraining_dataset
 from axolotl.utils.dict import DictDefault
+from axolotl.utils.distributed import barrier, is_main_process
 from axolotl.utils.models import load_model, load_tokenizer
 from axolotl.utils.tokenization import check_dataset_labels
-from axolotl.utils.trainer import setup_trainer
+from axolotl.utils.trainer import (
+    calculate_total_num_steps,
+    process_datasets_for_packing,
+    setup_trainer,
+)
 from axolotl.utils.validation import validate_config
 from axolotl.utils.wandb import setup_wandb_env_vars
@@ -231,12 +236,25 @@ def train(
             cfg.pretraining_dataset,
             tokenizer,
             max_tokens=cfg.sequence_len,
-            seed=cfg.seed,
+            seed=cfg.seed or 42,
         )
         # https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
        train_dataset = train_dataset.with_format("torch")
        eval_dataset = None

+    if is_main_process():
+        # process on rank 0 first so it gets cached so other ranks load from cache
+        train_dataset, eval_dataset = process_datasets_for_packing(
+            cfg, train_dataset, eval_dataset
+        )
+    barrier()
+    if not is_main_process():
+        train_dataset, eval_dataset = process_datasets_for_packing(
+            cfg, train_dataset, eval_dataset
+        )
+    barrier()
+    total_num_steps = calculate_total_num_steps(cfg, train_dataset, tokenizer)
+
     if cfg.debug or "debug" in kwargs:
         LOG.info("check_dataset_labels...")
         check_dataset_labels(
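The rank-0-first idiom above is a standard trick for distributed preprocessing: rank 0 runs the (deterministic) map and populates the HF datasets cache, then the remaining ranks repeat the same call and load from cache instead of recomputing. A minimal sketch of the pattern, using the helpers from the new `axolotl.utils.distributed` module shown later in this diff:

```python
from axolotl.utils.distributed import barrier, is_main_process

def prepare_cached(fn, *args):
    # rank 0 does the work first and writes the cache
    if is_main_process():
        result = fn(*args)
    barrier()  # everyone waits until rank 0 has finished
    # the remaining ranks now run the same call and hit the cache
    if not is_main_process():
        result = fn(*args)
    barrier()
    return result
```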
@@ -286,7 +304,9 @@ def train(
         model.save_pretrained(cfg.output_dir)
         return

-    trainer = setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer)
+    trainer = setup_trainer(
+        cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps
+    )

     model.config.use_cache = False

@@ -345,14 +365,12 @@ def train(
     # TODO do we need this fix? https://huggingface.co/docs/accelerate/usage_guides/fsdp#saving-and-loading
     # only save on rank 0, otherwise it corrupts output on multi-GPU when multiple processes attempt to write the same file
     if cfg.fsdp:
-        model.save_pretrained(cfg.output_dir)
+        trainer.save_model(cfg.output_dir)
     elif cfg.local_rank == 0:
         if cfg.flash_optimum:
             model = BetterTransformer.reverse(model)
         model.save_pretrained(cfg.output_dir)

-    # trainer.save_model(cfg.output_dir)  # TODO this may be needed for deepspeed to work? need to review another time
-

 if __name__ == "__main__":
     fire.Fire(train)
@@ -77,14 +77,21 @@ class ConstantLengthDataset(IterableDataset):
             self.tokens_dtype = torch.int64

     def __iter__(self):
-        buffer = {"input_ids": [], "attention_mask": [], "labels": []}
+        buffer = {
+            "input_ids": [],
+            "attention_mask": [],
+            "labels": [],
+            "position_ids": [],
+        }
         buffer_len = 0
         for dataset in self.datasets:
+            idx = 0
             iterator = iter(dataset)
             more_examples = True
             while more_examples:
                 try:
                     example = next(iterator)
+                    idx += 1
                 except StopIteration:
                     more_examples = False
                     example = None
@@ -106,6 +113,9 @@ class ConstantLengthDataset(IterableDataset):
                     attention_mask = torch.cat(buffer["attention_mask"], dim=-1)[
                         : self.seq_length
                     ]
+                    position_ids = torch.cat(buffer["position_ids"], dim=-1)[
+                        : self.seq_length
+                    ]
                     labels = torch.cat(buffer["labels"], dim=-1)[: self.seq_length]
                     if labels.size() == input_ids.size() and (
                         attention_mask.size() == input_ids.size()
@@ -114,6 +124,7 @@ class ConstantLengthDataset(IterableDataset):
                             "input_ids": input_ids,
                             "labels": labels,
                             "attention_mask": attention_mask,
+                            "position_ids": position_ids,
                         }
                     else:
                         LOG.warning(
@@ -123,8 +134,10 @@ class ConstantLengthDataset(IterableDataset):
                         "input_ids": [],
                         "attention_mask": [],
                         "labels": [],
+                        "position_ids": [],
                     }
                     buffer_len = 0
+                    idx = 1

                 if example:
                     # FIXME
@@ -133,11 +146,6 @@ class ConstantLengthDataset(IterableDataset):
                     input_ids = example["input_ids"]
                     attention_mask = example["attention_mask"]
                     labels = example["labels"]
-                    if (
-                        buffer["input_ids"]
-                        and input_ids[0] == self.tokenizer.bos_token_id
-                    ):
-                        attention_mask[0] = 0

                     if add_concat_token:
                         input_ids.append(self.concat_token_id)
@@ -148,13 +156,17 @@ class ConstantLengthDataset(IterableDataset):
                         input_ids, dtype=self.tokens_dtype
                     )
                     attention_mask_with_concat = torch.tensor(
-                        attention_mask, dtype=self.tokens_dtype
+                        [idx * m for m in attention_mask], dtype=torch.int16
                     )
                     labels_with_concat = torch.tensor(
                         labels, dtype=self.tokens_dtype
                     )
+                    position_ids = torch.arange(
+                        len(input_ids), dtype=self.tokens_dtype
+                    )

                     buffer["input_ids"].append(input_ids_with_concat)
                     buffer["attention_mask"].append(attention_mask_with_concat)
                     buffer["labels"].append(labels_with_concat)
+                    buffer["position_ids"].append(position_ids)
                     buffer_len += len(input_ids)
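The `idx * m` multiplier is what turns the binary attention mask into a per-sequence integer mask: the first packed sample keeps mask value 1, the second becomes 2, and so on, while `position_ids` restart at 0 for every sample. As a small illustration (values are hypothetical, not part of the diff):

```python
# packing a 3-token sample followed by a 2-token sample yields roughly:
attention_mask = [1, 1, 1, 2, 2]  # per-sequence index instead of plain 0/1
position_ids = [0, 1, 2, 0, 1]    # positions restart for each packed sample
```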
@@ -7,10 +7,18 @@ from typing import Optional, Tuple
 import torch
 import transformers
 from einops import rearrange
-from flash_attn.bert_padding import pad_input, unpad_input
-from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+
+try:
+    from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+except ImportError:
+    from flash_attn.flash_attn_interface import (
+        flash_attn_unpadded_qkvpacked_func as flash_attn_varlen_qkvpacked_func,
+    )
+
 from transformers.models.llama.modeling_llama import apply_rotary_pos_emb

+from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
+

 def forward(
     self,
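The try/except import appears to guard against a flash-attn rename: newer releases expose `flash_attn_varlen_qkvpacked_func`, while older ones only shipped `flash_attn_unpadded_qkvpacked_func`, so aliasing the old name keeps the monkeypatch working against either version.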
@@ -84,35 +92,15 @@ def forward(
         )
         output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
     else:
-        nheads = qkv.shape[-2]
+        qkv = rearrange(qkv, "b s ... -> (b s) ...")
+        cu_q_lens, max_s = get_cu_seqlens_from_pos_ids(position_ids)
+        cu_q_lens = cu_q_lens.squeeze()

-        # pylint: disable=invalid-name
-        x = rearrange(qkv, "b s three h d -> b s (three h d)")
-        x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
-        x_unpad = rearrange(
-            x_unpad,
-            "nnz (three h d) -> nnz three h d",
-            three=3,
-            h=nheads,
-        )
-        output_unpad = flash_attn_varlen_qkvpacked_func(
-            x_unpad,
-            cu_q_lens,
-            max_s,
-            0.0,
-            softmax_scale=None,
-            causal=True,
-        )
-        output = rearrange(
-            pad_input(
-                rearrange(output_unpad, "nnz h d -> nnz (h d)"),
-                indices,
-                bsz,
-                q_len,
-            ),
-            "b s (h d) -> b s h d",
-            h=nheads,
+        output = flash_attn_varlen_qkvpacked_func(
+            qkv, cu_q_lens, max_s, 0.0, softmax_scale=None, causal=True
         )
+        output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
+
     return (
         self.o_proj(rearrange(output, "b s h d -> b s (h d)")),
         None,
@@ -128,6 +128,7 @@ def xformers_forward(
             query_states,
             key_states,
             value_states,
+            # attn_bias=attention_mask,
             attn_bias=xformers.ops.LowerTriangularMask(),
         )
         attn_weights = None
src/axolotl/monkeypatch/llama_expand_mask.py (new file, 52 lines)

@@ -0,0 +1,52 @@
+"""
+expands the binary attention mask per 3.2.2 of https://arxiv.org/pdf/2107.02027.pdf
+"""
+from typing import Optional
+
+import torch
+
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    This expansion handles packed sequences so that sequences share the same attention mask integer value
+    when they attend to each other within that sequence.
+    This expansion transforms the mask to lower triangular form to prevent future peeking.
+    """
+    bsz, src_len = mask.size()
+    tgt_len = tgt_len if tgt_len is not None else src_len
+
+    mask = mask.unsqueeze(1).unsqueeze(2)
+    mask = mask.expand(bsz, 1, tgt_len, src_len)
+
+    # Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
+    binary_mask = torch.where(
+        mask != 0,
+        torch.tensor(1).to(dtype),
+        torch.tensor(0).to(dtype),
+    )
+
+    # Create a block-diagonal mask.
+    # we multiply by the binary mask so that 0's in the original mask are correctly excluded
+    zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
+
+    # Now let's create a lower triangular mask of ones that will zero out the upper triangular part
+    lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
+        mask.device
+    )
+
+    # Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
+    masked_zero_one_mask = zero_one_mask * lower_triangular_ones
+    inverted_mask = 1.0 - masked_zero_one_mask
+
+    return inverted_mask.masked_fill(
+        inverted_mask.to(torch.bool), torch.finfo(dtype).min
+    )
+
+
+def hijack_expand_mask():
+    import transformers
+
+    transformers.models.llama.modeling_llama._expand_mask = (  # pylint: disable=protected-access
+        _expand_mask
+    )
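To make the behavior concrete, here is a small sanity check (not part of the diff) for one row packing two sequences, encoded in the mask as the integers 1 and 2:

```python
import torch
from axolotl.monkeypatch.llama_expand_mask import _expand_mask

# tokens 0-1 belong to sequence 1, tokens 2-3 to sequence 2
mask = torch.tensor([[1, 1, 2, 2]])
expanded = _expand_mask(mask, torch.float32)

# allowed positions come out as 0.0, disallowed ones as the dtype's most negative value;
# each token attends only within its own block, and only causally
allowed = (expanded[0, 0] == 0.0).int()
print(allowed)
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [0, 0, 1, 0],
#         [0, 0, 1, 1]])
```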
src/axolotl/monkeypatch/utils.py (new file, 103 lines)

@@ -0,0 +1,103 @@
+"""
+Shared utils for the monkeypatches
+"""
+import torch
+
+
+def get_cu_seqlens(attn_mask):
+    """generate a cumulative sequence length mask for flash attention using attn mask"""
+    if len(attn_mask.shape) == 1:
+        attn_mask = attn_mask.unsqueeze(0)
+
+    device = attn_mask.device
+    results = []
+    max_seq_lens = []
+
+    for row in attn_mask:
+        # Exclude zeros to avoid adding their positions to the mask
+        t_non_zeros = row[row != 0]
+        # Find where the sequence number changes (including the first position)
+        seq_change = torch.cat(
+            [
+                torch.tensor([1], dtype=torch.int32, device=device),
+                t_non_zeros[1:] != t_non_zeros[:-1],
+            ]
+        )
+        # Get the indices where the sequence changes
+        change_indices = torch.cat(
+            [
+                (seq_change == 1).nonzero(as_tuple=True)[0],
+                torch.tensor([len(t_non_zeros)], dtype=torch.int32, device=device),
+            ]
+        )
+        # Calculate the sequence lengths
+        seq_lengths = change_indices[1:] - change_indices[:-1]
+        # Calculate the length of the final sequence or padding
+        final_seq_length = len(row) - change_indices[-1]
+        # Append the length of the final sequence or padding to seq_lengths
+        if final_seq_length.item():
+            seq_lengths = torch.cat(
+                [
+                    seq_lengths,
+                    torch.tensor(
+                        [final_seq_length.item()], dtype=torch.int32, device=device
+                    ),
+                ]
+            )
+        # Calculate the cumulative sequence lengths
+        cu_seqlens = torch.cat(
+            [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
+        )
+        max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+        results.append(cu_seqlens)
+        max_seq_lens.append(max_seq_len)
+
+    return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
+
+
+def get_cu_seqlens_from_pos_ids(position_ids):
+    """generate a cumulative sequence length mask for flash attention using pos ids"""
+    if len(position_ids.shape) == 1:
+        position_ids = position_ids.unsqueeze(0)
+
+    device = position_ids.device
+    results = []
+    max_seq_lens = []
+
+    for row in position_ids:
+        # Count the number of consecutive zeros from the right side
+        padding_length = (row == 0).int().flip(dims=[0]).cumprod(dim=0).sum().item()
+
+        # Adjust the row to exclude padding
+        adjusted_row = row[:-padding_length] if padding_length else row.clone()
+
+        # Find where the position resets to 0 (indicating a new sequence)
+        seq_starts = torch.cat(
+            [
+                torch.tensor([True], dtype=torch.bool, device=device),
+                adjusted_row[1:] == 0,
+            ]
+        )
+        # Get the indices where the sequence starts
+        start_indices = torch.cat(
+            [
+                (seq_starts).nonzero(as_tuple=True)[0],
+                torch.tensor([len(adjusted_row)], dtype=torch.int32, device=device),
+            ]
+        )
+        # Calculate the sequence lengths
+        seq_lengths = start_indices[1:] - start_indices[:-1]
+        # Calculate the cumulative sequence lengths
+        cu_seqlens = torch.cat(
+            [torch.tensor([0], dtype=torch.int32, device=device), seq_lengths.cumsum(0)]
+        )
+        # Append the padding length to the cumulative sequence lengths
+        if padding_length:
+            cu_seqlens = torch.cat(
+                [cu_seqlens, torch.tensor([len(row)], dtype=torch.int32, device=device)]
+            )
+        max_seq_len = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
+        results.append(cu_seqlens)
+        max_seq_lens.append(max_seq_len)
+
+    return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
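A worked example, derived directly from the code above: position ids restart at 0 for each packed sequence, and trailing zeros are treated as padding that gets its own final boundary.

```python
import torch
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids

# two packed sequences (lengths 3 and 2) followed by two padding tokens
position_ids = torch.tensor([[0, 1, 2, 0, 1, 0, 0]])
cu_seqlens, max_seq_lens = get_cu_seqlens_from_pos_ids(position_ids)
print(cu_seqlens)    # [[0, 3, 5, 7]]: boundaries at 0, 3, 5, plus the padded end
print(max_seq_lens)  # [3]: the longest packed sub-sequence
```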
@@ -66,7 +66,11 @@ class SystemDataPrompter(AlpacaPrompter):
     ) -> Generator[str, None, None]:
         # returns the full prompt from instruction and optional input
         # if a label (=response, =output) is provided, it's also appended.
-        formatted_sys_prompt = f"### System:\n{system}\n\n" if system else ""
+        formatted_sys_prompt = (
+            self.system_format.format(system=system)
+            if system and self.system_format
+            else ""
+        )
         if input:
             res = formatted_sys_prompt + self.turn_format.format(
                 instruction=instruction, input=input
@@ -86,12 +90,20 @@ class OpenOrcaSystemDataPrompter(SystemDataPrompter):
     """

     def match_prompt_style(self):
+        # pylint: disable=duplicate-code
         if self.prompt_style == PromptStyle.INSTRUCT.value:
             self.turn_format = "### User:\n{instruction}\n\n### Additional Context:\n{input}\n\n### Assistant:\n"
             self.turn_no_input_format = "### User:\n{instruction}\n\n### Assistant:\n"
         if self.prompt_style == PromptStyle.CHAT.value:
-            self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
-            self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
+            self.turn_format = "User: {instruction}\n{input}\nAssistant:"
+            self.turn_no_input_format = "User: {instruction}\nAssistant:"
+            self.system_format = "System: {system}\n"
+        if self.prompt_style == PromptStyle.CHATML.value:
+            self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
+            self.turn_no_input_format = (
+                "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
+            )
+            self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"


 class OpenOrcaPromptTokenizingStrategy(InstructionWSystemPromptTokenizingStrategy):
@@ -137,3 +149,12 @@ def load_open_orca(tokenizer, cfg):
         cfg.train_on_inputs,
         cfg.sequence_len,
     )
+
+
+def load_open_orca_chatml(tokenizer, cfg):
+    return OpenOrcaPromptTokenizingStrategy(
+        OpenOrcaSystemDataPrompter(PromptStyle.CHATML.value),
+        tokenizer,
+        cfg.train_on_inputs,
+        cfg.sequence_len,
+    )
@@ -16,6 +16,7 @@ class PromptStyle(Enum):

     INSTRUCT = "instruct"
     CHAT = "chat"
+    CHATML = "chatml"


 class AlpacaPrompter:
@@ -25,6 +26,7 @@ class AlpacaPrompter:

     system_prompt = "Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n\n"
     system_no_input_prompt = "Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n"
+    system_format: str
     turn_format: str
     turn_no_input_format: str
     prompt_style: Optional[PromptStyle] = None
@@ -34,14 +36,23 @@ class AlpacaPrompter:
         self.match_prompt_style()

     def match_prompt_style(self):
+        # pylint: disable=duplicate-code
         if self.prompt_style == PromptStyle.INSTRUCT.value:
             self.turn_format = "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:\n"
             self.turn_no_input_format = (
                 "### Instruction:\n{instruction}\n\n### Response:\n"
             )
+            self.system_format = "### System:\n{system}\n\n"
         if self.prompt_style == PromptStyle.CHAT.value:
             self.turn_format = "USER: {instruction}\n{input}\nASSISTANT:"
             self.turn_no_input_format = "USER: {instruction}\nASSISTANT:"
+            self.system_format = "SYSTEM: {system}\n"
+        if self.prompt_style == PromptStyle.CHATML.value:
+            self.turn_format = "<|im_start|>user\n{instruction}\n{input}<|im_end|>\n<|im_start|>assistant\n"
+            self.turn_no_input_format = (
+                "<|im_start|>user\n{instruction}<|im_end|>\n<|im_start|>assistant\n"
+            )
+            self.system_format = "<|im_start|>system\n{system}<|im_end|>\n"

     def build_prompt(
         self,
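Assembled from the format strings above, a chatml-style prompt with a system message and an input renders as follows (the message content is illustrative):

```
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Summarize the following text.
{input}<|im_end|>
<|im_start|>assistant
```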
src/axolotl/utils/collators.py (new file, 121 lines)

@@ -0,0 +1,121 @@
+"""
+DataCollator for axolotl to pad labels and position_ids for packed sequences
+"""
+from dataclasses import dataclass
+from typing import Any, Optional, Union
+
+import numpy as np
+from transformers import PreTrainedTokenizerBase
+from transformers.utils import PaddingStrategy
+
+
+@dataclass
+class DataCollatorForSeq2Seq:
+    """
+    Data collator that will dynamically pad the inputs received, as well as the labels and position_ids
+
+    Args:
+        tokenizer ([`PreTrainedTokenizer`] or [`PreTrainedTokenizerFast`]):
+            The tokenizer used for encoding the data.
+        model ([`PreTrainedModel`]):
+            The model that is being trained. If set and has the *prepare_decoder_input_ids_from_labels*, use it to
+            prepare the *decoder_input_ids*
+
+            This is useful when using *label_smoothing* to avoid calculating loss twice.
+        padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `True`):
+            Select a strategy to pad the returned sequences (according to the model's padding side and padding index)
+            among:
+
+            - `True` or `'longest'` (default): Pad to the longest sequence in the batch (or no padding if only a single
+              sequence is provided).
+            - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
+              acceptable input length for the model if that argument is not provided.
+            - `False` or `'do_not_pad'`: No padding (i.e., can output a batch with sequences of different lengths).
+        max_length (`int`, *optional*):
+            Maximum length of the returned list and optionally padding length (see above).
+        pad_to_multiple_of (`int`, *optional*):
+            If set will pad the sequence to a multiple of the provided value.
+
+            This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability >=
+            7.5 (Volta).
+        label_pad_token_id (`int`, *optional*, defaults to -100):
+            The id to use when padding the labels (-100 will be automatically ignored by PyTorch loss functions).
+        return_tensors (`str`):
+            The type of Tensor to return. Allowable values are "np", "pt" and "tf".
+    """
+
+    tokenizer: PreTrainedTokenizerBase
+    model: Optional[Any] = None
+    padding: Union[bool, str, PaddingStrategy] = True
+    max_length: Optional[int] = None
+    pad_to_multiple_of: Optional[int] = None
+    label_pad_token_id: int = -100
+    position_pad_token_id: int = 0
+    return_tensors: str = "pt"
+
+    def __call__(self, features, return_tensors=None):
+        labels = None
+        if return_tensors is None:
+            return_tensors = self.return_tensors
+
+        for feature_name, pad_token_id in [
+            ("labels", self.label_pad_token_id),
+            ("position_ids", self.position_pad_token_id),
+        ]:
+            feat = (
+                [feature[feature_name] for feature in features]
+                if feature_name in features[0].keys()
+                else None
+            )
+            labels = feat if feat and feature_name == "labels" else labels
+            # We have to pad the labels before calling `tokenizer.pad` as this method won't pad them and needs them of the
+            # same length to return tensors.
+            if feat is not None:
+                max_feature_length = max(len(l) for l in feat)  # noqa: E741
+                if self.pad_to_multiple_of is not None:
+                    max_feature_length = (
+                        (max_feature_length + self.pad_to_multiple_of - 1)
+                        // self.pad_to_multiple_of
+                        * self.pad_to_multiple_of
+                    )
+
+                padding_side = self.tokenizer.padding_side
+                for feature in features:
+                    remainder = [pad_token_id] * (
+                        max_feature_length - len(feature[feature_name])
+                    )
+                    if isinstance(feature[feature_name], list):
+                        feature[feature_name] = (
+                            feature[feature_name] + remainder
+                            if padding_side == "right"
+                            else remainder + feature[feature_name]
+                        )
+                    elif padding_side == "right":
+                        feature[feature_name] = np.concatenate(
+                            [feature[feature_name], remainder]
+                        ).astype(np.int64)
+                    else:
+                        feature[feature_name] = np.concatenate(
+                            [remainder, feature[feature_name]]
+                        ).astype(np.int64)
+
+        features = self.tokenizer.pad(
+            features,
+            padding=self.padding,
+            max_length=self.max_length,
+            pad_to_multiple_of=self.pad_to_multiple_of,
+            return_tensors=return_tensors,
+        )
+
+        # prepare decoder_input_ids
+        if (
+            labels is not None
+            and self.model is not None
+            and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
+        ):
+            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(
+                labels=features["labels"]
+            )
+            features["decoder_input_ids"] = decoder_input_ids
+
+        return features
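A rough usage sketch; the model id and token values are illustrative, not part of the diff:

```python
from transformers import AutoTokenizer
from axolotl.utils.collators import DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")  # illustrative
collator = DataCollatorForSeq2Seq(tokenizer, padding=True, pad_to_multiple_of=8)

batch = collator([
    {"input_ids": [1, 5, 6], "attention_mask": [1, 1, 1],
     "labels": [5, 6, 2], "position_ids": [0, 1, 2]},
    {"input_ids": [1, 9], "attention_mask": [1, 1],
     "labels": [9, 2], "position_ids": [0, 1]},
])
# labels are pre-padded with -100 and position_ids with 0,
# then tokenizer.pad handles input_ids and attention_mask
```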
@@ -1,5 +1,6 @@
 """Module containing data utilities"""
 import functools
+import hashlib
 import itertools
 import logging
 from hashlib import md5
@@ -35,6 +36,7 @@ from axolotl.prompters import (
     ShareGPTPrompter,
     SummarizeTLDRPrompter,
 )
+from axolotl.utils.distributed import barrier, is_main_process

 LOG = logging.getLogger("axolotl")

@@ -109,6 +111,7 @@ def load_tokenized_prepared_datasets(
             local_path = Path(d.path)
             if local_path.exists():
                 if local_path.is_dir():
+                    # TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
                     ds = load_dataset(
                         d.path,
                         name=d.name,
@@ -374,6 +377,7 @@ def load_prepare_datasets(
             dataset = Dataset.from_list(list(constant_len_dataset))

             # filter out bad data
+            # TODO convert to dataset.filter(...)
             dataset = Dataset.from_list(
                 [
                     d
@@ -413,7 +417,51 @@ def load_prepare_datasets(
     )

     if cfg.val_set_size:
-        dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
+        # ensure we end up with the same fingerprint by doing rank0 first and being able to cache
+        to_hash_train = (
+            dataset._fingerprint  # pylint: disable=protected-access
+            + "|"
+            + str(cfg.val_set_size)
+            + "|"
+            + "train"
+            + "|"
+            + str(cfg.seed or 42)
+        )
+        to_hash_test = (
+            dataset._fingerprint  # pylint: disable=protected-access
+            + "|"
+            + str(cfg.val_set_size)
+            + "|"
+            + "test"
+            + "|"
+            + str(cfg.seed or 42)
+        )
+        train_fingerprint = hashlib.md5(
+            to_hash_train.encode(), usedforsecurity=False
+        ).hexdigest()
+        test_fingerprint = hashlib.md5(
+            to_hash_test.encode(), usedforsecurity=False
+        ).hexdigest()
+
+        if is_main_process():
+            dataset = dataset.train_test_split(
+                test_size=cfg.val_set_size,
+                shuffle=False,
+                seed=cfg.seed or 42,
+                train_new_fingerprint=train_fingerprint,
+                test_new_fingerprint=test_fingerprint,
+            )
+        barrier()
+        if not is_main_process():
+            dataset = dataset.train_test_split(
+                test_size=cfg.val_set_size,
+                shuffle=False,
+                seed=cfg.seed or 42,
+                train_new_fingerprint=train_fingerprint,
+                test_new_fingerprint=test_fingerprint,
+            )
+        barrier()

         train_dataset = dataset["train"]
         eval_dataset = dataset["test"]
     else:
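Passing explicit `train_new_fingerprint`/`test_new_fingerprint` keeps the datasets cache keys identical across ranks, so non-zero ranks reload the split rank 0 already wrote instead of recomputing it. The derivation above boils down to a deterministic cache key; equivalently (the helper name is illustrative, and note that `usedforsecurity` requires Python 3.9+):

```python
import hashlib

def split_fingerprint(parent_fingerprint: str, val_set_size, split: str, seed: int) -> str:
    # same inputs on every rank -> same fingerprint -> shared cache entry
    to_hash = f"{parent_fingerprint}|{val_set_size}|{split}|{seed}"
    return hashlib.md5(to_hash.encode(), usedforsecurity=False).hexdigest()
```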
src/axolotl/utils/dataloader.py (new file, 310 lines)

@@ -0,0 +1,310 @@
+# pylint: skip-file
+import hashlib
+import itertools
+import logging
+import math
+import queue
+import threading
+from typing import Any, Callable, List, Optional, Union
+
+import numba
+import numpy as np
+from torch.utils.data import DistributedSampler, Sampler
+
+LOG = logging.getLogger("axolotl.utils.dataloader")
+
+
+@numba.njit
+def ffd_check(a: np.ndarray, c: int, n: int):
+    # First-fit-decreasing bin packing
+    # Check if a[] could fit in n bins with capacity c
+    # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
+
+    a = np.sort(a)[::-1]
+    bins = np.full((n,), c, dtype=a.dtype)
+    for size in a:
+        not_found = True
+        for idx in range(n):
+            if bins[idx] >= size:
+                bins[idx] -= size
+                not_found = False
+                break
+
+        if not_found:
+            return False
+
+    return True
+
+
+@numba.njit
+def ffd_with_result(a: np.ndarray, c: int, start_index: int):
+    # First-fit-decreasing bin packing (with result return)
+
+    indices = np.argsort(a)[::-1]
+    a = a[indices]
+
+    bins: List[Any] = []
+    bins_result: List[Any] = []
+    for a_id, size in enumerate(a):
+        add_new = True
+        for idx in range(len(bins)):
+            if bins[idx] >= size:
+                bins[idx] -= size
+                bins_result[idx].append(indices[a_id] + start_index)
+                add_new = False
+                break
+
+        if add_new:
+            bins.append(c - size)
+            bins_result.append([indices[a_id] + start_index])
+
+    return bins_result, len(a)
+
+
+@numba.njit
+def allocate(
+    lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
+):
+    """
+    :param lengths: array of lengths of each sample
+    :param lengths_cumsum: cumulative sum of consecutive lengths
+    :param rank: rank for this process
+    :param c: length of tokens per batch
+    :param n: number of ranks
+    :return:
+    """
+    # Dynamic batch allocator, similar to Multifit
+    # https://en.wikipedia.org/wiki/Multifit_algorithm
+    # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
+
+    s = 0
+    start_index = 0
+    result = []
+
+    while True:
+        # binary search [left, right)
+        left = 1
+        right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
+
+        while right - left > 1:
+            mid = (left + right) // 2
+            if ffd_check(lengths[start_index : start_index + mid], c, n):
+                left = mid
+            else:
+                right = mid
+
+        # use length left
+        batch, tot_seqs = ffd_with_result(
+            lengths[start_index : start_index + left], c, start_index
+        )
+        if len(batch) < n:
+            break
+
+        start_index += left
+        s = lengths_cumsum[start_index - 1]
+
+        # add local rank
+        result.append(batch[rank])
+
+        yield batch[rank], tot_seqs, s, len(result) * c * n
+
+
+def chunk(iterable, n):
+    """
+    Chunk data into tuples of length n
+    """
+    # batched('ABCDEFG', 3) --> ABC DEF G
+    if n < 1:
+        raise ValueError("n must be at least one")
+    it = iter(iterable)
+    while batch := tuple(itertools.islice(it, n)):
+        yield batch
+
+
+def hash_indices(lst: List[int]) -> str:
+    # Convert the list of integers to a string representation
+    concatenated = ",".join(map(str, lst))
+
+    # Generate the hash
+    sha256 = hashlib.sha256()
+    sha256.update(concatenated.encode())
+
+    return sha256.hexdigest()
+
+
+class MultipackDistributedDataloader:
+    """Unpadded data loading using Multipack.
+    Adapted from https://github.com/imoneoi/openchat/blob/v3_fix_mle_loss/ochat/training_deepspeed/multipack_dataloader.py
+    Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
+    """
+
+    def __init__(
+        self,
+        dataset: Any,
+        collate_fn: Callable,
+        seq_max_length: int = 2048,
+        batch_size: int = 1,
+        sampler: Union[Sampler, DistributedSampler] = None,
+        packing_efficiency_estimate: float = 1.0,
+        sample_packing_seq_len_multiplier: int = 1,
+        device_count: int = 1,
+        total_num_tokens: Optional[int] = None,
+    ):
+        # Dataset
+        self.dataset = dataset
+        lengths_series = (
+            dataset.data.column("position_ids").to_pandas().apply(lambda x: x[-1] + 1)
+        )
+        self.lengths: np.ndarray = lengths_series.values
+        assert isinstance(self.lengths, np.ndarray)
+        assert batch_size % sample_packing_seq_len_multiplier == 0
+        assert batch_size >= sample_packing_seq_len_multiplier
+        self.sampler = sampler
+        self.batch_size = batch_size
+        self.sample_packing_seq_len_multiplier = sample_packing_seq_len_multiplier
+        self.seq_max_length = seq_max_length
+        self.batch_max_length = batch_size * seq_max_length
+        self.collate_fn = collate_fn
+
+        self.num_replicas = 1
+        self.rank = 0
+
+        # statistics
+        self.total_num_tokens = total_num_tokens
+        self.eff_total_used = 0
+        self.eff_total_slots = 0
+        self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
+        self.device_count = device_count
+
+        # for non-blocking batch creation
+        self.batch_queue: queue.Queue = queue.Queue(
+            maxsize=10
+        )  # Adjust maxsize as needed
+
+    def generate_batches(self, set_stats=False):
+        LOG.info("generating packed batches")
+        if self.sampler:
+            indices = [idx for idx in self.sampler]
+        else:
+            indices = range(0, len(self.dataset))
+
+        LOG.info(hash_indices(indices))
+        lengths = self.lengths[indices]
+        lengths_cumsum = np.cumsum(lengths)
+
+        alloc_iter = iter(
+            allocate(
+                lengths=lengths,
+                lengths_cumsum=lengths_cumsum,
+                rank=self.rank,
+                # c=self.batch_max_length,
+                c=self.seq_max_length * self.sample_packing_seq_len_multiplier,
+                n=self.num_replicas,
+            )
+        )
+
+        for batch, tot_seqs, total_used, total_slots in alloc_iter:
+            self.batch_queue.put([indices[b_idx] for b_idx in batch])
+            # statistics
+            if set_stats:
+                self.eff_total_used = total_used
+                self.eff_total_slots = total_slots
+        self.batch_queue.put(None)  # Signal the end of batch generation
+
+    def _generate_batches_thread(self):
+        try:
+            self.generate_batches(set_stats=True)
+        except Exception as e:
+            LOG.error(f"Error in batch generation thread: {e}")
+            self.batch_queue.put(
+                None
+            )  # Signal the end of batch generation in case of error
+
+    def __iter__(self):
+        if hasattr(self.sampler, "set_epoch"):
+            new_epoch = self.sampler.epoch + 1
+            self.sampler.set_epoch(new_epoch)
+            LOG.info(f"calling sampler.set_epoch({new_epoch})")
+        # Start the batch generation in a separate thread
+        batch_gen_thread = threading.Thread(target=self._generate_batches_thread)
+        batch_gen_thread.start()
+
+        features = self.dataset.features.keys()
+        len_remaining = self._len_est()
+        while True:
+            batch = self.batch_queue.get()
+            if batch is None:  # Sentinel value received, stop iteration
+                break
+            chunked_data = []
+            attn_mask_cum_idx = 0
+            concatenated = {}
+            batched_data = [self.dataset[batch_idx] for batch_idx in batch]
+            for feature in features:
+                if feature == "attention_mask":
+                    arrays = [
+                        (attn_mask_cum_idx + idx + 1) * np.array(item[feature])
+                        for idx, item in enumerate(batched_data)
+                        if feature in item
+                    ]
+                    attn_mask_cum_idx += len(batched_data)
+                    concatenated[feature] = np.concatenate(arrays)
+                else:
+                    arrays = [
+                        np.array(item[feature])
+                        for item in batched_data
+                        if feature in item
+                    ]
+                    concatenated[feature] = np.concatenate(arrays)
+            chunked_data.append(concatenated)
+
+            yield self.collate_fn(chunked_data)
+            len_remaining -= 1
+            if not len_remaining:
+                break
+        # Wait for the batch generation thread to finish
+        batch_gen_thread.join(timeout=5)
+        LOG.info(f"actual packing efficiency: {self.efficiency()}")
+
+    def _len_est(self):
+        if not self.total_num_tokens:
+            self.total_num_tokens = np.sum(self.lengths)
+        lengths_sum_per_device = self.total_num_tokens // self.device_count
+        LOG.info(
+            f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
+            f"total_num_tokens per device: {lengths_sum_per_device}"
+        )
+
+        # shave off 1% + 1 for dealing with variance in packing from random sampler to sampler
+        return (
+            math.floor(
+                0.99
+                * lengths_sum_per_device
+                / self.packing_efficiency_estimate
+                // self.seq_max_length
+                // self.batch_size
+            )
+            - 1
+        )
+
+    def __len__(self):
+        # this doesn't return the actual length b/c with distributed samplers, not all dataloaders get
+        # the same share of total tokens
+        # if not self.eff_total_used:
+        #     batches, _ = self.generate_batches(set_stats=True)
+        #     LOG.info(
+        #         f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
+        #         f"actual packing efficiency: {self.efficiency()}"
+        #     )
+        return max(1, self._len_est())
+
+    def len_w_stats(self):
+        if not self.eff_total_used:
+            batches, _ = self.generate_batches(set_stats=True)
+            LOG.info(
+                f"packing_efficiency_estimate: {self.packing_efficiency_estimate} "
+                f"actual packing efficiency: {self.efficiency()}"
+            )
+        return max(1, self._len_est())
+
+    def efficiency(self):
+        return self.eff_total_used / self.eff_total_slots
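For intuition, `ffd_check` asks whether the given sample lengths fit into `n` bins of token capacity `c` under first-fit-decreasing, and `allocate` binary-searches the longest prefix of samples that still packs. A toy check (not part of the diff):

```python
import numpy as np
from axolotl.utils.dataloader import ffd_check

lengths = np.array([5, 4, 3, 2])
# descending order 5, 4, 3, 2: bin one takes 5 then 2, bin two takes 4 then 3
print(ffd_check(lengths, 7, 2))  # True
print(ffd_check(lengths, 7, 1))  # False: 14 tokens exceed a single bin of 7
```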
src/axolotl/utils/distributed.py (new file, 41 lines)

@@ -0,0 +1,41 @@
+"""
+utility helpers for distributed checks
+"""
+import torch.distributed as dist
+from accelerate import Accelerator
+
+accelerate = None  # pylint: disable=invalid-name
+
+
+def load_accelerate():
+    global accelerate  # pylint: disable=global-statement
+    accelerate = Accelerator()
+
+
+def is_distributed():
+    """
+    Check if distributed training is initialized.
+    """
+    global accelerate  # pylint: disable=global-statement
+    if not accelerate:
+        accelerate = Accelerator()
+    return dist.is_available() and dist.is_initialized()
+
+
+def barrier():
+    """
+    Acts as a barrier to wait for all processes. This ensures that all processes
+    reach the barrier before proceeding further.
+    """
+    if is_distributed():
+        dist.barrier()
+
+
+def is_main_process():
+    """
+    Check if the current process is the main process.
+    If not in distributed mode, always return True.
+    """
+    if not is_distributed():
+        return True
+    return dist.get_rank() == 0
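Because both helpers degrade gracefully outside of distributed runs, callers can use them unconditionally:

```python
from axolotl.utils.distributed import barrier, is_main_process

if is_main_process():
    print("only rank 0 logs this")  # always true in single-process runs
barrier()  # all ranks sync here; a no-op when torch.distributed is not initialized
```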
@@ -36,20 +36,26 @@ def load_tokenizer(
     tokenizer_type,
     cfg,
 ):
+    tokenizer_kwargs = {}
     use_fast = True  # this is the default
     if cfg.tokenizer_use_fast is not None:
         use_fast = cfg.tokenizer_use_fast
+    if cfg.tokenizer_legacy is not None:
+        # True is the default w/ https://github.com/huggingface/transformers/pull/25224
+        tokenizer_kwargs["legacy"] = cfg.tokenizer_legacy
     if tokenizer_type:
         tokenizer = getattr(transformers, tokenizer_type).from_pretrained(
             tokenizer_config,
             trust_remote_code=cfg.trust_remote_code or False,
             use_fast=use_fast,
+            **tokenizer_kwargs,
         )
     else:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_config,
             trust_remote_code=cfg.trust_remote_code or False,
             use_fast=use_fast,
+            **tokenizer_kwargs,
         )

     LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
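In config terms, the new option would be set like this sketch (assuming the usual axolotl YAML keys; leaving `tokenizer_legacy` unset defers to the transformers default):

```yaml
tokenizer_use_fast: true
# opt out of the legacy slow-tokenizer behavior; true is the default per transformers#25224
tokenizer_legacy: false
```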
@@ -86,8 +92,10 @@ def load_model(

     # TODO refactor as a kwarg
     load_in_8bit = cfg.load_in_8bit
-    cfg.is_llama_derived_model = "llama" in base_model or (
-        cfg.model_type and "llama" in cfg.model_type.lower()
+    cfg.is_llama_derived_model = (
+        "llama" in base_model
+        or (cfg.model_type and "llama" in cfg.model_type.lower())
+        or cfg.is_llama_derived_model is True
     )

     if cfg.is_llama_derived_model and cfg.flash_attention:
@@ -132,6 +140,14 @@ def load_model(
         LOG.info("patching with xpos rope")
         replace_llama_rope_with_xpos_rope()

+    if cfg.is_llama_derived_model and (
+        cfg.max_packed_sequence_len or cfg.sample_packing
+    ):
+        from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
+
+        LOG.info("patching _expand_mask")
+        hijack_expand_mask()
+
     if cfg.bf16 or cfg.bfloat16:
         torch_dtype = torch.bfloat16
     elif cfg.load_in_8bit or cfg.fp16 or cfg.float16:
@@ -222,7 +238,6 @@ def load_model(
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
-            device_map="auto" if cfg.world_size == 1 else cfg.device_map,
             **model_kwargs,
         )
     # elif model_type == "GPTNeoXForCausalLM" and cfg.flash_attention:
@@ -257,7 +272,6 @@ def load_model(
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
-            device_map=cfg.device_map,
             trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )
@@ -288,7 +302,6 @@ def load_model(
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
-            device_map=cfg.device_map,
             trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )
@@ -302,7 +315,6 @@ def load_model(
             load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
             load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
             torch_dtype=torch_dtype,
-            device_map=cfg.device_map,
             trust_remote_code=cfg.trust_remote_code or False,
             **model_kwargs,
         )
@@ -1,19 +1,22 @@
 """Module containing the Trainer class and related functions"""

 import importlib
 import logging
 import math
 import os
 import sys
+from contextlib import contextmanager
 from dataclasses import dataclass, field
+from functools import partial
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Union

 import bitsandbytes as bnb
 import torch.cuda
 import transformers
+from datasets import Dataset, set_caching_enabled
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
+from torch.utils.data import DataLoader, DistributedSampler, RandomSampler
 from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
 from transformers.trainer_pt_utils import get_parameter_names
@@ -21,6 +24,8 @@ from axolotl.utils.callbacks import (
     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
 )
+from axolotl.utils.collators import DataCollatorForSeq2Seq
+from axolotl.utils.dataloader import MultipackDistributedDataloader
 from axolotl.utils.schedulers import (
     InterpolatingLogScheduler,
     get_cosine_schedule_with_quadratic_warmup,
@@ -29,6 +34,68 @@ from axolotl.utils.schedulers import (
 LOG = logging.getLogger("axolotl")


+@torch.jit.script
+def weighted_cross_entropy(
+    logits: torch.Tensor, labels: torch.Tensor, weights: torch.Tensor
+):
+    # Flatten the logits, labels, and weights tensors
+    logits = logits.view(
+        -1, logits.size(-1)
+    )  # logits becomes of shape [batch_size*sequence_length, vocab_size]
+    labels = labels.view(-1)  # labels becomes of shape [batch_size*sequence_length]
+    weights = weights.view(-1)  # weights becomes of shape [batch_size*sequence_length]
+
+    # Compute the unweighted cross entropy loss
+    losses = torch.nn.functional.cross_entropy(logits, labels, reduction="none")
+
+    # Apply the weights to the losses and compute their sum
+    return (weights * losses).sum()
+
+
+@torch.jit.script
+def create_weighted_mask(labels: torch.Tensor):
+    # Check if the tensor is 2D. If not, unsqueeze it to make it 2D
+    if len(labels.shape) == 1:
+        labels = labels.unsqueeze(0)
+
+    weights = torch.zeros_like(labels).float()
+    for i in range(labels.shape[0]):
+        mask = labels[i] != -100
+
+        # Create a tensor to track group ids
+        group_ids = torch.zeros_like(labels[i]).int()
+        curr_group_id = 0
+
+        for j in range(1, len(labels[i])):
+            if mask[j] and not mask[j - 1]:  # switch from masked to unmasked label
+                curr_group_id += 1  # start new group
+            group_ids[j] = (
+                curr_group_id if mask[j] else 0
+            )  # assign group id if unmasked label
+
+        # Count only unmasked labels in each group
+        group_counts = torch.bincount(group_ids[mask])
+
+        mask_weights = torch.zeros_like(labels[i]).float()
+        mask_weights[mask] = 1.0 / group_counts[group_ids[mask]]
+
+        weights[i] = mask_weights
+
+    return weights.squeeze()  # squeeze the output to match the input dimension
+
+
+def trainer_weighted_loss(model_output, labels, shift_labels=True):
+    logits = (
+        model_output["logits"] if isinstance(model_output, dict) else model_output[0]
+    )
+    if shift_labels:
+        logits = logits[..., :-1, :].contiguous()
+        labels = labels[..., 1:].contiguous()
+
+    weights = create_weighted_mask(labels)
+    return weighted_cross_entropy(logits, labels, weights)
+
+
 @dataclass
 class AxolotlTrainingArguments(TrainingArguments):
     """
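For intuition, a minimal sketch of how the helpers added above behave on a toy batch; the tensor values are illustrative and not taken from the repo:

import torch

# Positions labeled -100 are prompt tokens and receive zero weight; each
# unmasked span is weighted by 1/len(span), so a short completion and a
# long one contribute equally to the summed loss.
labels = torch.tensor([[-100, 5, 6, -100, 7, 8, 9]])
weights = create_weighted_mask(labels)
# weights -> [0.0, 0.5, 0.5, 0.0, 0.333, 0.333, 0.333]

logits = torch.randn(1, 7, 32000)  # [batch, seq_len, vocab_size]
loss = weighted_cross_entropy(logits, labels, weights)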
@@ -39,6 +106,26 @@ class AxolotlTrainingArguments(TrainingArguments):
         default=False,
         metadata={"help": "Use quadratic warmup for cosine scheduling."},
     )
+    sample_packing: bool = field(
+        default=False,
+        metadata={"help": "Use sample packing for efficient training."},
+    )
+    sample_packing_efficiency: float = field(
+        default=1.0,
+        metadata={"help": "Sample packing efficiency for calculating batch length."},
+    )
+    max_seq_length: int = field(
+        default=2048,
+        metadata={"help": "The maximum sequence length the model can handle"},
+    )
+    sample_packing_seq_len_multiplier: int = field(
+        default=1,
+        metadata={"help": "the multiplier for the max len for packed sequences"},
+    )
+    train_data_total_num_tokens: Optional[int] = field(
+        default=None,
+        metadata={"help": "the total number of tokens in the train dataset"},
+    )


 class AxolotlTrainer(Trainer):
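A rough sketch of how these fields end up populated when packing is enabled; all values here are hypothetical (setup_trainer below wires them in from the cfg):

args = AxolotlTrainingArguments(
    output_dir="./out",                      # standard TrainingArguments field
    sample_packing=True,
    sample_packing_efficiency=0.94,          # e.g. a measured sample_packing_eff_est
    max_seq_length=2048,                     # cfg.sequence_len
    sample_packing_seq_len_multiplier=2,     # cfg.micro_batch_size
    train_data_total_num_tokens=12_345_678,  # cfg.total_num_tokens
)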
@@ -76,6 +163,66 @@ class AxolotlTrainer(Trainer):
             return super().create_scheduler(num_training_steps, optimizer)
         return self.lr_scheduler

+    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
+        if self.args.world_size > 1 and self.args.sample_packing:
+            return DistributedSampler(
+                self.train_dataset,
+                num_replicas=self.args.world_size,
+                rank=self.args.process_index,
+                seed=self.args.seed,
+            )
+        return super()._get_train_sampler()
+
+    def get_train_dataloader(self) -> Union[DataLoader, MultipackDistributedDataloader]:
+        if self.args.sample_packing:
+            train_sampler = self._get_train_sampler()
+            return self.accelerator.prepare(
+                MultipackDistributedDataloader(
+                    self.train_dataset,
+                    batch_size=self._train_batch_size,
+                    seq_max_length=self.args.max_seq_length,
+                    collate_fn=self.data_collator,
+                    sampler=train_sampler,
+                    packing_efficiency_estimate=self.args.sample_packing_efficiency,
+                    sample_packing_seq_len_multiplier=self.args.sample_packing_seq_len_multiplier,
+                    device_count=int(os.environ.get("WORLD_SIZE", 1)),
+                    total_num_tokens=self.args.train_data_total_num_tokens,
+                )
+            )
+        return super().get_train_dataloader()
+
+    def get_eval_dataloader(
+        self, eval_dataset: Optional[Dataset] = None
+    ) -> Union[DataLoader, MultipackDistributedDataloader]:
+        if self.args.sample_packing:
+            eval_dataset = (
+                eval_dataset if eval_dataset is not None else self.eval_dataset
+            )
+            eval_sampler = self._get_eval_sampler(eval_dataset)
+            return self.accelerator.prepare(
+                MultipackDistributedDataloader(
+                    eval_dataset,
+                    batch_size=self.args.eval_batch_size,
+                    seq_max_length=self.args.max_seq_length,
+                    collate_fn=self.data_collator,
+                    sampler=eval_sampler,
+                    packing_efficiency_estimate=self.args.sample_packing_efficiency,
+                    sample_packing_seq_len_multiplier=self.args.eval_batch_size,
+                    device_count=int(os.environ.get("WORLD_SIZE", 1)),
+                    total_num_tokens=None,
+                )
+            )
+        return super().get_eval_dataloader(eval_dataset)
+
+    def compute_loss(self, model, inputs, return_outputs=False):
+        # use one's weighted cross entropy loss calc
+        # if self.args.sample_packing:
+        #     labels = inputs.pop("labels")
+        #     outputs = model(**inputs)
+        #     loss = trainer_weighted_loss(outputs, labels, shift_labels=True)
+        #     return (loss, outputs) if return_outputs else loss
+        return super().compute_loss(model, inputs, return_outputs=return_outputs)
+
+
 class OneCycleLRSchedulerTrainer(AxolotlTrainer):
     """
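The multipack dataloader relies on per-sequence position_ids rather than a flat 0/1 attention mask; a toy illustration of one packed row (token ids are made up):

import torch

# Three samples packed into a single row. position_ids restart at 0 at
# each sample boundary; the flash-attention patch derives cu_seqlens from
# these resets to enforce block-diagonal attention across the pack.
input_ids    = torch.tensor([[11, 12, 13, 21, 22, 31, 32, 33, 34]])
position_ids = torch.tensor([[ 0,  1,  2,  0,  1,  0,  1,  2,  3]])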
@@ -106,10 +253,117 @@ class OneCycleLRSchedulerTrainer(AxolotlTrainer):
         return self.lr_scheduler


-def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
-    total_num_steps = int(
-        math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
-    )
+def add_position_ids(sample):
+    sample["position_ids"] = torch.arange(len(sample["input_ids"]))
+    return sample
+
+
+def drop_long_seq(sample, sequence_len=2048):
+    return len(sample["input_ids"]) <= sequence_len
+
+
+@contextmanager
+def disable_datasets_caching():
+    try:
+        set_caching_enabled(False)
+        yield
+    finally:
+        set_caching_enabled(True)
+
+
+def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
+    if cfg.sample_packing:
+        drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
+        train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
+            add_position_ids, num_proc=os.cpu_count()
+        )
+        if eval_dataset:
+            eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count()).map(
+                add_position_ids, num_proc=os.cpu_count()
+            )
+    return train_dataset, eval_dataset
+
+
+def calculate_total_num_steps(cfg, train_dataset, tokenizer):
+    if cfg.sample_packing:
+        # we have to drop anything longer than sequence len otherwise
+        # flash attention with position ids fails
+        total_num_tokens = (
+            cfg.total_num_tokens
+            if cfg.total_num_tokens
+            else sum(len(s["input_ids"]) for s in train_dataset)
+        )
+        if not cfg.total_num_tokens:
+            LOG.info(f"📝 UPDATE CONFIG WITH: `total_num_tokens: {total_num_tokens}`")
+
+        if cfg.sample_packing_eff_est:
+            total_num_steps = (
+                # match count to len est in dataloader
+                (
+                    math.floor(
+                        0.99
+                        * total_num_tokens
+                        / cfg.sample_packing_eff_est
+                        / 2048
+                        // cfg.batch_size
+                        // int(os.environ.get("WORLD_SIZE", 1))
+                    )
+                    - 1
+                )
+                * cfg.num_epochs
+            )
+            LOG.info(
+                f"total_num_tokens: {total_num_tokens}, total_num_steps: {total_num_steps}"
+            )
+        else:
+            sampler = RandomSampler(train_dataset)
+            data_loader = MultipackDistributedDataloader(
+                train_dataset,
+                batch_size=cfg.micro_batch_size,
+                seq_max_length=cfg.max_packed_sequence_len or cfg.sequence_len,
+                collate_fn=DataCollatorForSeq2Seq(
+                    tokenizer,
+                    return_tensors="pt",
+                    padding="longest",
+                ),
+                sampler=sampler,
+                packing_efficiency_estimate=cfg.sample_packing_eff_est,
+                sample_packing_seq_len_multiplier=cfg.micro_batch_size,
+                device_count=int(os.environ.get("WORLD_SIZE", 1)),
+            )
+            data_loader_len = data_loader.len_w_stats()
+            actual_eff = data_loader.efficiency()
+            LOG.info(f"data_loader_len: {data_loader_len}")
+            total_num_steps = int(
+                math.floor(
+                    data_loader_len
+                    * cfg.micro_batch_size
+                    * cfg.num_epochs
+                    // cfg.batch_size
+                )
+            )
+            LOG.info(
+                f"📝 UPDATE CONFIG WITH: `sample_packing_eff_est: {math.ceil(actual_eff * 100.0) / 100.0}`"
+            )
+    else:
+        total_num_steps = int(
+            math.ceil(len(train_dataset) * cfg.num_epochs / cfg.batch_size)
+        )
+    LOG.info(f"total_num_steps: {total_num_steps}")
+    return total_num_steps
+
+
+def setup_fsdp_envs(cfg):
+    os.environ["ACCELERATE_USE_FSDP"] = "true"
+    if cfg.fsdp_config.fsdp_sync_module_states:
+        os.environ["FSDP_SYNC_MODULE_STATES"] = "true"
+    if cfg.fsdp_config.fsdp_state_dict_type:
+        os.environ["FSDP_STATE_DICT_TYPE"] = cfg.fsdp_config.fsdp_state_dict_type
+
+
+def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_steps):
+    if cfg.fsdp:
+        setup_fsdp_envs(cfg)
     warmup_steps = (
         cfg.warmup_steps
         if cfg.warmup_steps is not None
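As a worked example of the packed step estimate above, with all numbers hypothetical (10M tokens, a 0.9 efficiency estimate, batch_size 8, two GPUs, three epochs):

import math

total_num_tokens = 10_000_000
sample_packing_eff_est = 0.9
batch_size = 8   # cfg.batch_size
world_size = 2   # WORLD_SIZE
num_epochs = 3

# mirrors the expression in calculate_total_num_steps; note the estimate
# divides by a fixed 2048-token window rather than cfg.sequence_len
steps = (
    math.floor(
        0.99 * total_num_tokens / sample_packing_eff_est / 2048
        // batch_size
        // world_size
    )
    - 1
) * num_epochs
# 0.99 * 10e6 / 0.9 / 2048 ≈ 5371; // 8 // 2 -> 335; minus 1 -> 334; * 3 = 1002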
@@ -189,7 +443,14 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.save_safetensors:
         training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors

+    if cfg.sample_packing_eff_est:
+        training_arguments_kwargs[
+            "sample_packing_efficiency"
+        ] = cfg.sample_packing_eff_est
+
     training_args = AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
+        # max_steps=total_num_steps,  # this is helpful in case we don't actually know total # of steps
+        max_seq_length=cfg.sequence_len,
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size
         if cfg.eval_batch_size is not None
@@ -203,7 +464,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
         save_steps=cfg.save_steps,
         output_dir=cfg.output_dir,
-        save_total_limit=3,
+        save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
         load_best_model_at_end=(
             cfg.load_best_model_at_end is not False
             and cfg.val_set_size > 0
@@ -221,6 +482,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         if cfg.lr_scheduler and cfg.lr_scheduler not in ("one_cycle", "log_sweep")
         else "cosine",
         weight_decay=cfg.weight_decay if cfg.weight_decay is not None else 0.0,
+        sample_packing=cfg.sample_packing if cfg.sample_packing else False,
+        sample_packing_seq_len_multiplier=cfg.micro_batch_size or 1,
+        train_data_total_num_tokens=cfg.total_num_tokens,
         **training_arguments_kwargs,
     )

@@ -314,11 +578,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.collator_pad_to_longest:
         data_collator_kwargs["padding"] = "longest"
     else:
-        data_collator_kwargs["pad_to_multiple_of"] = 8
+        # A100 is best at 64, while others at 8. Let's use the larger so we don't have to check
+        # https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html
+        data_collator_kwargs["pad_to_multiple_of"] = 64

     if cfg.is_llama_derived_model and cfg.landmark_attention:
-        from functools import partial
-
         from axolotl.monkeypatch.llama_landmark_attn import (
             add_mem_tokens,
             get_mem_id,
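Padding to a multiple of 64 lines batch shapes up with A100 tensor-core tile sizes. A quick sketch of the rounding this setting implies (lengths are hypothetical; the actual padding happens inside the collator):

def pad_target(length: int, multiple: int = 64) -> int:
    # round a sequence length up to the next multiple, e.g. 1000 -> 1024
    return ((length + multiple - 1) // multiple) * multiple

assert pad_target(1000) == 1024
assert pad_target(64) == 64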
@@ -346,7 +610,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         args=training_args,
-        data_collator=transformers.DataCollatorForSeq2Seq(
+        data_collator=DataCollatorForSeq2Seq(
             tokenizer,
             return_tensors="pt",
             **data_collator_kwargs,
@@ -8,6 +8,19 @@ LOG = logging.getLogger("axolotl")


 def validate_config(cfg):
+    if cfg.max_packed_sequence_len and cfg.sample_packing:
+        raise ValueError(
+            "please set only one of max_packed_sequence_len (deprecated soon) or sample_packing"
+        )
+    if cfg.max_packed_sequence_len:
+        LOG.warning(
+            str(
+                PendingDeprecationWarning(
+                    "max_packed_sequence_len will be deprecated in favor of sample_packing"
+                )
+            )
+        )
+
     if cfg.gradient_accumulation_steps and cfg.batch_size:
         raise ValueError(
             "please set only one of gradient_accumulation_steps or batch_size"
@@ -97,6 +110,17 @@ def validate_config(cfg):
             "push_to_hub_model_id is deprecated. Please use hub_model_id instead."
         )

+    if cfg.sample_packing and cfg.sdp_attention:
+        # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
+        raise ValueError(
+            "sample_packing not compatible with sdp_attention. Use flash_attention"
+        )
+
+    if cfg.sample_packing and cfg.xformers_attention:
+        raise ValueError(
+            "sample_packing not compatible with xformers_attention. Use flash_attention"
+        )
+
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
tests/monkeypatch/test_llama_attn_hijack_flash.py (Normal file, 30 lines added)
@@ -0,0 +1,30 @@
+"""
+Unit tests for the monkeypatch utils
+"""
+import unittest
+
+import torch
+
+from axolotl.monkeypatch.utils import get_cu_seqlens, get_cu_seqlens_from_pos_ids
+
+
+class TestMonkeyPatchUtils(unittest.TestCase):
+    """
+    Unit test class for monkeypatch utils
+    """
+
+    def test_get_cu_seqlens_1d(self):
+        attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
+        target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)
+        self.assertTrue(torch.allclose(get_cu_seqlens(attn_mask)[0], target_res))
+
+    def test_get_cu_seqlens_from_pos_ids_1d(self):
+        position_ids = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]])
+        target_res = torch.tensor([0, 4, 7, 12, 14, 16], dtype=torch.int32)
+        self.assertTrue(
+            torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
+        )
+
+
+if __name__ == "__main__":
+    unittest.main()
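For intuition about what these tests pin down: cumulative sequence lengths mark where each packed sample begins and ends. Below is a minimal reimplementation of the position-ids variant, consistent with the expected values above; it is a sketch, not the actual code in axolotl.monkeypatch.utils:

import torch

def cu_seqlens_from_pos_ids_sketch(position_ids: torch.Tensor) -> torch.Tensor:
    # A packed sequence starts wherever position_ids resets to 0; consecutive
    # zeros (right padding) fold into a single trailing block.
    pos = position_ids.view(-1)
    prev_nonzero = torch.cat([torch.tensor([True]), pos[:-1] != 0])
    starts = torch.nonzero((pos == 0) & prev_nonzero).view(-1)
    return torch.cat([starts, torch.tensor([pos.numel()])]).to(torch.int32)

pos = torch.tensor([[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0]])
assert cu_seqlens_from_pos_ids_sketch(pos).tolist() == [0, 4, 7, 12, 14, 16]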
tests/test_expand_mask.py (Normal file, 44 lines added)
@@ -0,0 +1,44 @@
+"""
+Unit tests for the monkey patch for expand mask to handle packed sequences
+"""
+import unittest
+
+import torch
+
+from axolotl.monkeypatch.llama_expand_mask import _expand_mask
+
+
+class TestExpandMask(unittest.TestCase):
+    """
+    Test class for attention mask expansion for packed sequences
+    """
+
+    def test_output(self):
+        mask = torch.tensor([[1, 1, 1, 2], [2, 3, 3, 0]])
+        dtype = torch.float32
+        expected_output = torch.tensor(
+            [
+                [
+                    [
+                        [0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
+                        [0.0000e00, 0.0000e00, -3.4028e38, -3.4028e38],
+                        [0.0000e00, 0.0000e00, 0.0000e00, -3.4028e38],
+                        [-3.4028e38, -3.4028e38, -3.4028e38, 0.0000e00],
+                    ]
+                ],
+                [
+                    [
+                        [0.0000e00, -3.4028e38, -3.4028e38, -3.4028e38],
+                        [-3.4028e38, 0.0000e00, -3.4028e38, -3.4028e38],
+                        [-3.4028e38, 0.0000e00, 0.0000e00, -3.4028e38],
+                        [-3.4028e38, -3.4028e38, -3.4028e38, -3.4028e38],
+                    ]
+                ],
+            ]
+        )
+        # Check that the output matches the expected output
+        self.assertTrue(torch.allclose(_expand_mask(mask, dtype), expected_output))
+
+
+if __name__ == "__main__":
+    unittest.main()
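To make the expected tensor above less opaque: attention is allowed only within the same mask value, causally, and never for padding (mask value 0). A behavioral sketch consistent with the test, not the patched _expand_mask itself:

import torch

def expand_mask_sketch(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # mask: [bsz, seq_len] of per-sample indices (0 = padding). Returns
    # [bsz, 1, seq_len, seq_len] with 0 where attention is allowed (same
    # sample, causal, not padding) and dtype.min everywhere else.
    bsz, seq_len = mask.shape
    same_sample = mask.unsqueeze(1) == mask.unsqueeze(2)   # block-diagonal structure
    not_padding = (mask != 0).unsqueeze(1)                 # blocks padded keys/queries
    causal = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    allowed = same_sample & not_padding & causal
    out = torch.full((bsz, seq_len, seq_len), torch.finfo(dtype).min, dtype=dtype)
    out[allowed] = 0.0
    return out.unsqueeze(1)

# reproduces the expected_output tensor in the test above
mask = torch.tensor([[1, 1, 1, 2], [2, 3, 3, 0]])
print(expand_mask_sketch(mask, torch.float32))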
@@ -27,7 +27,7 @@ class TestPacking(unittest.TestCase):
             }
         )

-    def test_resets_attention(self):
+    def test_increments_attention(self):
         prompter = AlpacaPrompter("chat")
         strat = AlpacaPromptTokenizingStrategy(
             prompter,
@@ -55,10 +55,14 @@ class TestPacking(unittest.TestCase):
         # first example doesn't have mask reset
         assert example["input_ids"][0] == self.tokenizer.bos_token_id
         assert example["attention_mask"][0] == 1
+        assert example["position_ids"][0] == 0
+        assert example["position_ids"][1] == 1

         # but subsequent one does
         assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
-        assert example["attention_mask"][next_bos_index] == 0
+        assert example["attention_mask"][next_bos_index] == 2
+        assert example["position_ids"][next_bos_index] == 0
+        assert example["position_ids"][next_bos_index + 1] == 1


 if __name__ == "__main__":
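These assertions capture the new packing convention: attention_mask now carries a per-sample index instead of resetting to 0, and position_ids restart at each sample boundary. A toy packed row (values illustrative):

# two samples packed into one row; the mask value identifies the sample,
# 0 would mark padding, and position_ids restart per sample
attention_mask = [1, 1, 1, 1, 2, 2, 2]
position_ids   = [0, 1, 2, 3, 0, 1, 2]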
@@ -134,9 +134,15 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
             "output": "Hi! How can I help?",
         }
         example = strat.tokenize_prompt(sample)
-        assert example["input_ids"][0:4] == [1, 835, 2184, 29901]  # "<s>### System:"
-        assert example["input_ids"][5:7] == [1509, 20118]  # "use cot"
-        assert example["input_ids"][9] == 11889  # USER
+        assert example["input_ids"][0:5] == [
+            1,
+            28962,
+            1254,
+            12665,
+            29901,
+        ]  # "<s>SYSTEM:"
+        assert example["input_ids"][5:7] == [671, 20118]  # " use cot"
+        assert example["input_ids"][8] == 11889  # USER


 class Llama2ChatTokenizationTest(unittest.TestCase):
@@ -70,7 +70,7 @@ class AlpacaPrompterTest(unittest.TestCase):
             )
         )
         assert "use cot" in res
-        assert res.startswith("### System:")
+        assert res.startswith("SYSTEM:")
         assert "### Instruction:" not in res
         assert "### Input:" not in res
         assert "alpacas" in res
@@ -313,3 +313,27 @@ class ValidationTest(unittest.TestCase):
         )

         validate_config(cfg)
+
+    def test_packing(self):
+        cfg = DictDefault(
+            {
+                "max_packed_sequence_len": 2048,
+            }
+        )
+        with self._caplog.at_level(logging.WARNING):
+            validate_config(cfg)
+            assert any(
+                "max_packed_sequence_len will be deprecated in favor of sample_packing"
+                in record.message
+                for record in self._caplog.records
+            )
+
+        cfg = DictDefault(
+            {
+                "max_packed_sequence_len": 2048,
+                "sample_packing": True,
+            }
+        )
+        regex_exp = r".*set only one of max_packed_sequence_len \(deprecated soon\) or sample_packing.*"
+        with pytest.raises(ValueError, match=regex_exp):
+            validate_config(cfg)