Compare commits


3 Commits

Author     SHA1        Message                          Date
Wing Lian  8028652b8f  fix attention mask with packing  2023-07-15 10:38:01 -04:00
Wing Lian  33814cc94e  make sure we eval for openorca   2023-07-02 17:59:10 -04:00
Wing Lian  50254a7ccc  handle orca splits               2023-07-01 07:20:23 -04:00
12 changed files with 86 additions and 490 deletions

View File

@@ -237,7 +237,7 @@ Have dataset(s) in one of the following format (JSONL recommended):
 #### How to add custom prompts
 1. Add your method to a file in [prompt_strategies](src/axolotl/prompt_strategies). Please see other files as example.
-2. Use your custom file name as the dataset type `<prompt_strategies_file>.load_<load_fn>`.
+2. Use your custom file name as the dataset type.
 Optionally, download some datasets, see [data/README.md](data/README.md)
@@ -255,18 +255,10 @@ See sample configs in [configs](configs) folder or [examples](examples) for quic
 - dataset
   ```yaml
-  sequence_len: 2048 # max token length for prompt
-  # huggingface repo
   datasets:
-    - path: vicgalle/alpaca-gpt4
-      type: alpaca # format from earlier
-  # local
-  datasets:
-    - path: json
-      data_files: data.jsonl # or json
+    - path: vicgalle/alpaca-gpt4 # local or huggingface repo
       type: alpaca # format from earlier
+  sequence_len: 2048 # max token length / prompt
   ```
 - loading
@@ -305,8 +297,6 @@ base_model_ignore_patterns:
 # if the base_model repo on hf hub doesn't include configuration .json files,
 # you can set that here, or leave this empty to default to base_model
 base_model_config: ./llama-7b-hf
-# you can specify to choose a specific model revision from huggingface hub
-model_revision:
 # Optional tokenizer configuration override in case you want to use a different tokenizer
 # than the one defined in the base model
 tokenizer_config:
@@ -338,10 +328,10 @@ tf32: true # require >=ampere
 # a list of one or more datasets to finetune the model with
 datasets:
-  # hf dataset repo | "json" for local dataset, make sure to fill data_files
+  # this can be either a hf dataset, or relative path
   - path: vicgalle/alpaca-gpt4
     # The type of prompt to use for training. [alpaca, sharegpt, gpteacher, oasst, reflection]
-    type: alpaca # format | format:<prompt_style> (chat/instruct) | <prompt_strategies>.load_<load_fn>
+    type: alpaca # format OR format:prompt_style (chat/instruct)
     data_files: # path to source data files
     shards: # number of shards to split data into
@@ -351,7 +341,7 @@ dataset_prepared_path: data/last_run_prepared
 # push prepared dataset to hub
 push_dataset_to_hub: # repo path
 # push checkpoints to hub
-hub_model_id: # repo path
+push_to_hub_model_id: # repo path
 # whether to use hf `use_auth_token` for loading datasets. Useful for fetching private datasets
 # required to be true when used in combination with `push_dataset_to_hub`
 hf_use_auth_token: # boolean
@@ -413,9 +403,6 @@ logging_steps:
 save_steps:
 eval_steps:
-# save model as safetensors (require safetensors package)
-save_safetensors:
 # whether to mask out or include the human's prompt from the training labels
 train_on_inputs: false
 # don't use this, leads to wonky training (according to someone on the internet)

View File

@@ -97,4 +97,4 @@ RUN cd /workspace/builds/bitsandbytes && python3 setup.py install
 RUN git lfs install --skip-repo
 RUN pip3 install awscli && \
     # The base image ships with `pydantic==1.8.2` which is not working
-    pip3 install -U --no-cache-dir pydantic==1.10.10
+    pip3 install -U --no-cache-dir pydantic

View File

@@ -1,6 +1,7 @@
 peft @ git+https://github.com/huggingface/peft.git
 transformers @ git+https://github.com/huggingface/transformers.git
 bitsandbytes>=0.39.0
+accelerate
 addict
 fire
 PyYAML==6.0
@@ -17,4 +18,3 @@ evaluate==0.4.0
 rouge-score==0.1.2
 scipy
 scikit-learn==1.2.2
-numba

View File

@@ -79,11 +79,13 @@ class ConstantLengthDataset(IterableDataset):
         buffer = {"input_ids": [], "attention_mask": [], "labels": []}
         buffer_len = 0
         for dataset in self.datasets:
+            idx = 0
             iterator = iter(dataset)
             more_examples = True
             while more_examples:
                 try:
                     example = next(iterator)
+                    idx += 1
                 except StopIteration:
                     more_examples = False
                     example = None
@@ -124,6 +126,7 @@
                         "labels": [],
                     }
                     buffer_len = 0
+                    idx = 1
                 if example:
                     # FIXME
@@ -132,11 +135,6 @@ class ConstantLengthDataset(IterableDataset):
                     input_ids = example["input_ids"]
                     attention_mask = example["attention_mask"]
                     labels = example["labels"]
-                    if (
-                        buffer["input_ids"]
-                        and input_ids[0] == self.tokenizer.bos_token_id
-                    ):
-                        attention_mask[0] = 0
                     if add_concat_token:
                         input_ids.append(self.concat_token_id)
@@ -147,7 +145,7 @@ class ConstantLengthDataset(IterableDataset):
                         input_ids, dtype=self.tokens_dtype
                     )
                     attention_mask_with_concat = torch.tensor(
-                        attention_mask, dtype=self.tokens_dtype
+                        [idx * m for m in attention_mask], dtype=torch.int16
                     )
                     labels_with_concat = torch.tensor(
                         labels, dtype=self.tokens_dtype
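
For context on the change above: instead of zeroing the attention mask at packed-sample boundaries, each packed sample's mask is now multiplied by its 1-based sample index. A minimal sketch of the idea (illustrative only, not axolotl's actual class):

```python
# Sketch: give each packed sample a distinct positive attention-mask value
# (1, 2, 3, ...) instead of all 1s, so samples inside one packed sequence
# remain distinguishable downstream. Padding (mask == 0) stays 0.
from typing import List


def pack_attention_masks(sample_masks: List[List[int]]) -> List[int]:
    packed: List[int] = []
    for idx, mask in enumerate(sample_masks, start=1):
        # idx * m keeps zeros at 0 and tags real tokens with the sample index
        packed.extend(idx * m for m in mask)
    return packed


# Two packed samples: the second sample's tokens carry mask value 2, which is
# what the updated test below asserts for the token at the next BOS position.
print(pack_attention_masks([[1, 1, 1], [1, 1]]))  # [1, 1, 1, 2, 2]
```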

View File

@@ -37,7 +37,7 @@ from axolotl.prompters import (
 def load_tokenized_prepared_datasets(
-    tokenizer, cfg, default_dataset_prepared_path
+    split, tokenizer, cfg, default_dataset_prepared_path
 ) -> DatasetDict:
     tokenizer_name = tokenizer.__class__.__name__
     ds_hash = str(
@@ -49,6 +49,8 @@ def load_tokenized_prepared_datasets(
             sorted([f"{d.path}:{d.type}:{d.shards}" for d in cfg.datasets])
         )
         + "|"
+        + split
+        + "|"
         + tokenizer_name
     ).encode("utf-8")
 ).hexdigest()
@@ -66,7 +68,7 @@
                 f"{cfg.push_dataset_to_hub}/{ds_hash}",
                 use_auth_token=use_auth_token,
             )
-            dataset = dataset["train"]
+            dataset = dataset[split]
         except Exception:  # pylint: disable=broad-except # nosec
             pass
@@ -102,26 +104,13 @@ def load_tokenized_prepared_datasets(
                 pass
         # prefer local dataset, even if hub exists
-        local_path = Path(d.path)
-        if local_path.exists():
-            if local_path.is_dir():
-                ds = load_dataset(
-                    d.path,
-                    data_files=d.data_files,
-                    streaming=False,
-                    split=None,
-                )
-            elif local_path.is_file():
-                ds = load_dataset(
-                    "json",
-                    data_files=d.path,
-                    streaming=False,
-                    split=None,
-                )
-            else:
-                raise ValueError(
-                    "unhandled dataset load: local path exists, but is neither a directory or a file"
-                )
+        if Path(d.path).exists():
+            ds = load_dataset(
+                "json",
+                data_files=d.path,
+                streaming=False,
+                split=None,
+            )
         elif ds_from_hub:
             if d.data_files:
                 ds = load_dataset(
@@ -147,8 +136,8 @@ def load_tokenized_prepared_datasets(
                 raise ValueError("unhandled dataset load")
             # support for using a subset of the data
             if d.shards:
-                if "train" in ds:
-                    ds = ds.shuffle(seed=seed)["train"].shard(
+                if split in ds:
+                    ds = ds.shuffle(seed=seed)[split].shard(
                         num_shards=d.shards, index=0
                     )
                 else:
@@ -157,8 +146,8 @@ def load_tokenized_prepared_datasets(
                 d_type_split = d_type.split(":")
                 d_base_type = d_type_split[0]
                 d_prompt_style = d_type_split[1] if len(d_type_split) > 1 else None
-                if "train" in ds:
-                    ds = ds["train"]
+                if split in ds:
+                    ds = ds[split]
                 if ds_strategy := load(d.type, tokenizer, cfg):
                     ds_wrapper = TokenizedPromptDataset(ds_strategy, ds)
                     datasets.append(ds_wrapper)
@@ -332,7 +321,6 @@ def load_prepare_datasets(
                 f"{cfg.push_dataset_to_hub}/{ds_hash}",
                 use_auth_token=use_auth_token,
             )
-            dataset = dataset["train"]
         except Exception:  # pylint: disable=broad-except # nosec
             pass
@@ -352,28 +340,37 @@
                     f"{cfg.push_dataset_to_hub}/{ds_hash}", private=True
                 )
         else:
-            dataset = load_tokenized_prepared_datasets(
-                tokenizer, cfg, default_dataset_prepared_path
+            dataset_train = load_tokenized_prepared_datasets(
+                "train", tokenizer, cfg, default_dataset_prepared_path
             )
+            dataset_test = load_tokenized_prepared_datasets(
+                "test", tokenizer, cfg, default_dataset_prepared_path
+            )
+            dataset = DatasetDict({"train": dataset_train, "test": dataset_test})
             if cfg.seed:
                 dataset = dataset.shuffle(seed=cfg.seed)
-            constant_len_dataset = ConstantLengthDataset(
+            constant_len_dataset_train = ConstantLengthDataset(
                 tokenizer,
-                [dataset],
+                [dataset["train"]],
+                seq_length=max_packed_sequence_len,
+            )
+            constant_len_dataset_test = ConstantLengthDataset(
+                tokenizer,
+                [dataset["test"]],
                 seq_length=max_packed_sequence_len,
             )
             logging.info(
                 f"packing master dataset to len: {cfg.max_packed_sequence_len}"
             )
-            dataset = Dataset.from_list(list(constant_len_dataset))
+            dataset_train = Dataset.from_list(list(constant_len_dataset_train))
+            dataset_test = Dataset.from_list(list(constant_len_dataset_test))
             # filter out bad data
-            dataset = Dataset.from_list(
+            dataset_train = Dataset.from_list(
                 [
                     d
-                    for d in dataset
+                    for d in dataset_train
                     if len(d["input_ids"]) < cfg.sequence_len
                     and len(d["input_ids"]) > 0
                     and len(d["input_ids"]) == len(d["attention_mask"])
@@ -381,6 +378,19 @@ def load_prepare_datasets(
                 ]
             )
+            # filter out bad data
+            dataset_test = Dataset.from_list(
+                [
+                    d
+                    for d in dataset_test
+                    if len(d["input_ids"]) < cfg.sequence_len
+                    and len(d["input_ids"]) > 0
+                    and len(d["input_ids"]) == len(d["attention_mask"])
+                    and len(d["input_ids"]) == len(d["labels"])
+                ]
+            )
+            dataset = DatasetDict({"train": dataset_train, "test": dataset_test})
             if cfg.local_rank == 0:
                 logging.info(
                     f"Saving packed prepared dataset to disk... {prepared_ds_path}"
@@ -395,9 +405,13 @@ def load_prepare_datasets(
                         private=True,
                     )
     else:
-        dataset = load_tokenized_prepared_datasets(
-            tokenizer, cfg, default_dataset_prepared_path
+        dataset_train = load_tokenized_prepared_datasets(
+            "train", tokenizer, cfg, default_dataset_prepared_path
         )
+        dataset_test = load_tokenized_prepared_datasets(
+            "test", tokenizer, cfg, default_dataset_prepared_path
+        )
+        dataset = DatasetDict({"train": dataset_train, "test": dataset_test})
     if cfg.dataset_shard_num and cfg.dataset_shard_idx is not None:
         logging.info(
@@ -412,6 +426,9 @@ def load_prepare_datasets(
         dataset = dataset.train_test_split(test_size=cfg.val_set_size, shuffle=False)
         train_dataset = dataset["train"]
         eval_dataset = dataset["test"]
+    elif "train" in dataset:
+        train_dataset = dataset["train"]
+        eval_dataset = dataset["test"]
     else:
         train_dataset = dataset
         eval_dataset = None
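
Two details worth noting in the hunks above: the requested split is threaded through to dataset loading and indexing, and it is also mixed into the prepared-dataset hash so that "train" and "test" preparations get distinct cache entries. A minimal sketch of that cache-key idea (hypothetical helper, not axolotl's actual API):

```python
# Sketch: include the split name in the cache key so train/test preparations
# of the same datasets + tokenizer never collide on disk or on the hub.
import hashlib


def prepared_ds_hash(dataset_specs, split, tokenizer_name):
    key = "|".join(sorted(dataset_specs)) + "|" + split + "|" + tokenizer_name
    return hashlib.md5(key.encode("utf-8")).hexdigest()  # nosec


specs = ["vicgalle/alpaca-gpt4:alpaca:None"]
print(prepared_ds_hash(specs, "train", "LlamaTokenizer"))
print(prepared_ds_hash(specs, "test", "LlamaTokenizer"))  # different hash
```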

View File

@@ -154,8 +154,6 @@ def load_model(
     )
     model_kwargs = {}
-    if cfg.model_revision:
-        model_kwargs["revision"] = cfg.model_revision
     if cfg.adapter == "qlora" and cfg.load_in_4bit:
         model_kwargs["quantization_config"] = BitsAndBytesConfig(
             load_in_4bit=True,
@@ -204,7 +202,7 @@ def load_model(
                 else True,
             )
             load_in_8bit = False
-        elif cfg.is_llama_derived_model and not cfg.trust_remote_code:
+        elif cfg.is_llama_derived_model:
             from transformers import LlamaForCausalLM
             config = LlamaConfig.from_pretrained(base_model_config)
@@ -243,7 +241,7 @@ def load_model(
             # device=cfg.device,
             # )
             # model.train() # sets to train instead of eval mode
-        elif model_type and not cfg.trust_remote_code:
+        elif model_type:
             model = getattr(transformers, model_type).from_pretrained(
                 base_model,
                 load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,

View File

@@ -1,173 +0,0 @@
-# pylint: skip-file
-from typing import Any, List, Optional
-import numba
-import numpy as np
-import torch.distributed as dist
-from torch.utils.data import Sampler
-@numba.njit
-def ffd_check(a: np.ndarray, c: int, n: int):
-    # First-fit-decreasing bin packing
-    # Check if a[] could fit in n bins with capacity c
-    # https://en.wikipedia.org/wiki/First-fit-decreasing_bin_packing
-    a = np.sort(a)[::-1]
-    bins = np.full((n,), c, dtype=a.dtype)
-    for size in a:
-        not_found = True
-        for idx in range(n):
-            if bins[idx] >= size:
-                bins[idx] -= size
-                not_found = False
-                break
-        if not_found:
-            return False
-    return True
-@numba.njit
-def ffd_with_result(a: np.ndarray, c: int, start_index: int):
-    # First-fit-decreasing bin packing (with result return)
-    indices = np.argsort(a)[::-1]
-    a = a[indices]
-    bins: List[int] = []
-    bins_result: List[Any] = []
-    for a_id, size in enumerate(a):
-        add_new = True
-        for idx in range(len(bins)):
-            if bins[idx] >= size:
-                bins[idx] -= size
-                bins_result[idx].append(indices[a_id] + start_index)
-                add_new = False
-                break
-        if add_new:
-            bins.append(c - size)
-            bins_result.append([indices[a_id] + start_index])
-    return bins_result
-@numba.njit
-def allocate(
-    lengths: np.ndarray, lengths_cumsum: np.ndarray, rank: int, c: int, n: int
-):
-    # Dynamic batch allocator, similar to Multifit
-    # https://en.wikipedia.org/wiki/Multifit_algorithm
-    # ~99.5% efficiency on OpenChat training set (12 * 2048 ctx len)
-    s = 0
-    start_index = 0
-    result = []
-    while True:
-        # binary search [l, r)
-        left = 1
-        right = 1 + np.searchsorted(lengths_cumsum[start_index:], s + c * n, "right")
-        while right - left > 1:
-            m = (left + right) // 2
-            if ffd_check(lengths[start_index : start_index + m], c, n):
-                left = m
-            else:
-                right = m
-        # use length l
-        batch = ffd_with_result(
-            lengths[start_index : start_index + left], c, start_index
-        )
-        assert len(batch) <= n
-        if len(batch) < n:
-            break
-        start_index += left
-        s = lengths_cumsum[start_index - 1]
-        # add local rank
-        result.append(batch[rank])
-    return result, s, len(result) * c * n
-class MultipackDistributedBatchSampler(Sampler):
-    """Unpadded length sampling using Multipack.
-    Approximate (at most ~1.22x) the optimal solution of the identical-machines scheduling problem, which is NP-hard.
-    """
-    def __init__(
-        self,
-        batch_max_length: int,
-        lengths: List[int],
-        num_replicas: Optional[int] = None,
-        rank: Optional[int] = None,
-        seed: int = 0,
-    ):
-        # Get rank
-        if num_replicas is None:
-            if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available")
-            num_replicas = dist.get_world_size()
-        if rank is None:
-            if not dist.is_available():
-                raise RuntimeError("Requires distributed package to be available")
-            rank = dist.get_rank()
-        self.num_replicas = num_replicas
-        self.rank = rank
-        self.seed = seed
-        self.batch_max_length = batch_max_length
-        self.lengths = lengths
-        assert isinstance(self.lengths, np.ndarray)
-        self.epoch = 0
-        # statistics
-        self.eff_total_used = 0
-        self.eff_total_slots = 0
-    def set_epoch(self, epoch: int):
-        self.epoch = epoch
-    def generate_batches(self, set_stats=False):
-        indices = np.random.default_rng(seed=self.seed + self.epoch).permutation(
-            len(self.lengths)
-        )
-        lengths = self.lengths[indices]
-        lengths_cumsum = np.cumsum(lengths)
-        batches, total_used, total_slots = allocate(
-            lengths=lengths,
-            lengths_cumsum=lengths_cumsum,
-            rank=self.rank,
-            c=self.batch_max_length,
-            n=self.num_replicas,
-        )
-        batches = [indices[batch] for batch in batches]
-        # statistics
-        if set_stats:
-            self.eff_total_used += total_used
-            self.eff_total_slots += total_slots
-        return batches
-    def __iter__(self):
-        batches = self.generate_batches(set_stats=True)
-        return iter(batches)
-    def num_batches(self):
-        batches = self.generate_batches()
-        return len(batches)
-    def efficiency(self):
-        return self.eff_total_used / self.eff_total_slots
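
The deleted sampler packs variable-length sequences into fixed-capacity batches with first-fit-decreasing (FFD). A pure-Python illustration of the feasibility check it JIT-compiles (for exposition only; the real code above uses numba and NumPy):

```python
# Sketch: can these sequence lengths be packed into n_bins of `capacity`
# tokens each, placing longest-first into the first bin with room?
def ffd_fits(sizes, capacity, n_bins):
    bins = [capacity] * n_bins  # remaining room per bin
    for size in sorted(sizes, reverse=True):
        for i, free in enumerate(bins):
            if free >= size:
                bins[i] -= size
                break
        else:
            return False  # no bin had room for this sequence
    return True


# Five sequences totalling 4000 tokens fit into 2 bins of 2048 tokens:
print(ffd_fits([1500, 900, 700, 500, 400], capacity=2048, n_bins=2))  # True
```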

View File

@@ -1,9 +1,6 @@
 """Module for custom LRScheduler class"""
-import math
-from functools import partial
-from torch.optim import Optimizer
-from torch.optim.lr_scheduler import LambdaLR, LRScheduler
+from torch.optim.lr_scheduler import LRScheduler
 class InterpolatingLogScheduler(LRScheduler):
@@ -45,58 +42,3 @@ class InterpolatingLogScheduler(LRScheduler):
             lrs = [self.max_lr for base_lr in self.base_lrs]
         return lrs
-def _get_cosine_schedule_with_quadratic_warmup_lr_lambda(
-    current_step: int,
-    *,
-    num_warmup_steps: int,
-    num_training_steps: int,
-    num_cycles: float
-):
-    if current_step < num_warmup_steps:
-        return (float(current_step) / float(max(1, num_warmup_steps))) ** 2
-    progress = float(current_step - num_warmup_steps) / float(
-        max(1, num_training_steps - num_warmup_steps)
-    )
-    return max(
-        0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress))
-    )
-def get_cosine_schedule_with_quadratic_warmup(
-    optimizer: Optimizer,
-    num_warmup_steps: int,
-    num_training_steps: int,
-    num_cycles: float = 0.5,
-    last_epoch: int = -1,
-):
-    """
-    Create a schedule with a learning rate that decreases following the values of the cosine function between the
-    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
-    initial lr set in the optimizer.
-    Args:
-        optimizer ([`~torch.optim.Optimizer`]):
-            The optimizer for which to schedule the learning rate.
-        num_warmup_steps (`int`):
-            The number of steps for the warmup phase.
-        num_training_steps (`int`):
-            The total number of training steps.
-        num_cycles (`float`, *optional*, defaults to 0.5):
-            The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
-            following a half-cosine).
-        last_epoch (`int`, *optional*, defaults to -1):
-            The index of the last epoch when resuming training.
-    Return:
-        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
-    """
-    lr_lambda = partial(
-        _get_cosine_schedule_with_quadratic_warmup_lr_lambda,
-        num_warmup_steps=num_warmup_steps,
-        num_training_steps=num_training_steps,
-        num_cycles=num_cycles,
-    )
-    return LambdaLR(optimizer, lr_lambda, last_epoch)
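
For reference, the removed schedule scales the learning rate by `(step / warmup)²` during warmup and then follows a half-cosine decay. A quick numeric check of that lambda (standalone restatement, not an import from axolotl):

```python
import math


def lr_factor(step, warmup, total, cycles=0.5):
    if step < warmup:
        return (step / max(1, warmup)) ** 2  # quadratic ramp from 0 to 1
    progress = (step - warmup) / max(1, total - warmup)
    return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycles * 2.0 * progress)))


print(lr_factor(50, 100, 1000))    # 0.25 -> halfway through warmup gives 0.5**2
print(lr_factor(100, 100, 1000))   # 1.0  -> warmup complete
print(lr_factor(1000, 100, 1000))  # 0.0  -> end of the cosine decay
```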

View File

@@ -5,185 +5,25 @@ import logging
 import math
 import os
 import sys
-from dataclasses import dataclass, field
 from pathlib import Path
 from typing import Optional
 import bitsandbytes as bnb
-import numpy as np
 import torch.cuda
 import transformers
 from torch import nn
 from torch.optim.lr_scheduler import OneCycleLR
-from torch.utils.data import Dataset
-from transformers import EarlyStoppingCallback, Trainer, TrainingArguments
+from transformers import EarlyStoppingCallback, Trainer
 from transformers.trainer_pt_utils import get_parameter_names
 from axolotl.utils.callbacks import (
     SaveBetterTransformerModelCallback,
     SavePeftModelCallback,
 )
-from axolotl.utils.sampler import MultipackDistributedBatchSampler
-from axolotl.utils.schedulers import (
-    InterpolatingLogScheduler,
-    get_cosine_schedule_with_quadratic_warmup,
-)
+from axolotl.utils.schedulers import InterpolatingLogScheduler
-IGNORE_LABEL_ID = -100
-def _find_multiple(val1, val2):
-    return (-(val1 // -val2)) * val2
-def batch_to_tensor(batch, pad_id=0, dtype=torch.long, loss_dtype=torch.bfloat16):
-    # Pad an unused item to reach multiple of 64, for faster GEMM
-    pad_cur_len = sum(list(batch["length"]))
-    pad_len = _find_multiple(pad_cur_len, 64) - pad_cur_len
-    if pad_len > 0:
-        assert pad_len < 64
-        batch["input_ids"].append([pad_id] * pad_len)
-        batch["labels"].append([pad_id] * pad_len)
-        batch["attention_mask"].append([0] * pad_len)
-        batch["length"].append(pad_len)
-    # seqlen
-    batch_lengths = torch.tensor(list(batch["length"]), dtype=torch.int32, device="cpu")
-    max_seqlen = torch.max(batch_lengths)
-    cu_seqlens = torch.nn.functional.pad(
-        batch_lengths.cumsum(-1, dtype=torch.int32), (1, 0)
-    )
-    # nz elements
-    nz_num = cu_seqlens[-1]
-    nz_input_ids = torch.zeros((nz_num,), dtype=dtype, pin_memory=True, device="cpu")
-    nz_position_ids = torch.zeros((nz_num,), dtype=dtype, pin_memory=True, device="cpu")
-    nz_shifted_label_ids = torch.zeros(
-        (nz_num,), dtype=dtype, pin_memory=True, device="cpu"
-    )
-    nz_shifted_loss_weights = torch.zeros(
-        (nz_num,), dtype=loss_dtype, pin_memory=True, device="cpu"
-    )
-    index = 0
-    for token_list, length, labels_list in zip(
-        batch["input_ids"], batch["length"], batch["labels"]
-    ):
-        tokens = torch.tensor(token_list, dtype=dtype, device="cpu")
-        position_ids = torch.arange(length, dtype=dtype, device="cpu")
-        # Input IDs & shifted labels
-        # shifted_label_ids = torch.where(masks, tokens, IGNORE_LABEL_ID)
-        shifted_label_ids = labels_list
-        shifted_label_ids = torch.nn.functional.pad(
-            shifted_label_ids[1:], (0, 1), "constant", IGNORE_LABEL_ID
-        )
-        nz_input_ids[index : index + length] = tokens
-        nz_position_ids[index : index + length] = position_ids
-        nz_shifted_label_ids[index : index + length] = shifted_label_ids
-        # Loss weights
-        mask_count = sum(1 for label in labels_list[1:] if label != IGNORE_LABEL_ID)
-        loss_weight = (
-            1 / mask_count if mask_count > 0 else 0
-        )  # Avoid division by zero for paddings
-        nz_shifted_loss_weights[index : index + length] = loss_weight
-        index += length
-    # inputs
-    return {
-        "max_seqlen": max_seqlen,
-        "cu_seqlens": cu_seqlens,
-        "nz_input_ids": nz_input_ids,
-        "nz_position_ids": nz_position_ids,
-        "nz_shifted_label_ids": nz_shifted_label_ids,
-        "nz_shifted_loss_weights": nz_shifted_loss_weights,
-    }
-@dataclass
-class AxolotlTrainingArguments(TrainingArguments):
-    """
-    Extend the base TrainingArguments for axolotl helpers
-    """
-    lr_quadratic_warmup: bool = field(
-        default=False,
-        metadata={"help": "Use quadratic warmup for cosine scheduling."},
-    )
-    sample_packing: bool = field(
-        default=True,
-        metadata={"help": "Use sample packing for efficient training."},
-    )
-    max_seq_length: int = field(
-        default=2048,
-        metadata={"help": "The maximum sequence length the model can handle"},
-    )
-class AxolotlTrainer(Trainer):
-    """
-    Extend the base Trainer for axolotl helpers
-    """
-    args = None  # type: AxolotlTrainingArguments
-    def create_scheduler(
-        self, num_training_steps: int, optimizer: torch.optim.Optimizer = None
-    ):
-        """
-        Setup the scheduler. The optimizer of the trainer must have been set up either before this method is called or
-        passed as an argument.
-        Args:
-            num_training_steps (int): The number of training steps to do.
-            optimizer (torch.optim.Optimizer): The training optimizer
-        """
-        # fmt: off
-        if self.lr_scheduler is None:  # type: ignore # pylint: disable=access-member-before-definition
-            # fmt: on
-            if (
-                self.args.lr_scheduler_type == "cosine"
-                and self.args.lr_quadratic_warmup is True
-            ):
-                self.lr_scheduler = get_cosine_schedule_with_quadratic_warmup(  # pylint: disable=attribute-defined-outside-init
-                    optimizer,
-                    num_warmup_steps=self.args.get_warmup_steps(num_training_steps),
-                    num_training_steps=num_training_steps,
-                )
-            else:
-                return super().create_scheduler(num_training_steps, optimizer)
-        return self.lr_scheduler
-    def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
-        lengths = np.array([len(sample["input_ids"]) for sample in self.train_dataset])
-        return MultipackDistributedBatchSampler(
-            batch_max_length=self.args.per_device_train_batch_size
-            * self.args.max_seq_length,
-            lengths=lengths,
-            seed=self.args.seed,
-        )
-    def _get_eval_sampler(
-        self, eval_dataset: Dataset
-    ) -> Optional[torch.utils.data.Sampler]:
-        lengths = np.array([len(sample["input_ids"]) for sample in eval_dataset])
-        return MultipackDistributedBatchSampler(
-            batch_max_length=self.args.per_device_eval_batch_size
-            * self.args.max_seq_length,
-            lengths=lengths,
-            seed=self.args.seed,
-        )
-class OneCycleLRSchedulerTrainer(AxolotlTrainer):
+class OneCycleLRSchedulerTrainer(Trainer):
     """
     Trainer subclass that uses the OneCycleLR scheduler
     """
@@ -263,9 +103,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.fsdp_config:
         training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)
-    if cfg.lr_quadratic_warmup is not None:
-        training_arguments_kwargs["lr_quadratic_warmup"] = cfg.lr_quadratic_warmup
     # deepspeed
     if (
         os.environ.get("ACCELERATE_USE_DEEPSPEED") == "true"
@@ -287,16 +124,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     if cfg.max_grad_norm:
         training_arguments_kwargs["max_grad_norm"] = cfg.max_grad_norm
-    if cfg.hub_model_id:
-        training_arguments_kwargs["hub_model_id"] = cfg.hub_model_id
+    if cfg.push_to_hub_model_id:
+        training_arguments_kwargs["push_to_hub_model_id"] = cfg.push_to_hub_model_id
         training_arguments_kwargs["push_to_hub"] = True
-        training_arguments_kwargs["hub_private_repo"] = True
-    if cfg.save_safetensors:
-        training_arguments_kwargs["save_safetensors"] = cfg.save_safetensors
-    training_args = AxolotlTrainingArguments(  # pylint: disable=unexpected-keyword-arg
-        max_steps=total_num_steps * cfg.num_epochs,
+    training_args = transformers.TrainingArguments(
         per_device_train_batch_size=cfg.micro_batch_size,
         per_device_eval_batch_size=cfg.eval_batch_size
         if cfg.eval_batch_size is not None
@@ -305,9 +137,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
         eval_accumulation_steps=cfg.gradient_accumulation_steps,
         num_train_epochs=cfg.num_epochs,
         learning_rate=cfg.learning_rate,
-        evaluation_strategy="steps" if cfg.val_set_size > 0 else "no",
+        evaluation_strategy="steps",
         save_strategy="steps" if cfg.save_steps else "epoch",
-        eval_steps=cfg.eval_steps if cfg.val_set_size > 0 else None,
+        eval_steps=cfg.eval_steps,
         save_steps=cfg.save_steps,
         output_dir=cfg.output_dir,
         save_total_limit=3,
@@ -446,7 +278,7 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
     trainer_cls = (
         OneCycleLRSchedulerTrainer
         if cfg.lr_scheduler == "one_cycle" and (cfg.fsdp or cfg.adapter == "qlora")
-        else AxolotlTrainer
+        else transformers.Trainer
     )
     trainer = trainer_cls(
         model=model,
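
One removed detail worth a worked example: `batch_to_tensor` padded each packed batch up to a multiple of 64 tokens for faster GEMM, using the ceiling-division trick `-(a // -b)`. A standalone restatement of that helper:

```python
# -(a // -b) is ceil(a / b) with Python's floor division, so
# (-(a // -b)) * b rounds a up to the next multiple of b.
def find_multiple(val, multiple):
    return (-(val // -multiple)) * multiple


print(find_multiple(100, 64))  # 128 -> pad a 100-token batch up to 128
print(find_multiple(128, 64))  # 128 -> already aligned, no padding needed
```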

View File

@@ -87,16 +87,11 @@ def validate_config(cfg):
             "You probably want to disable group_by_length as it will force a streamed dataset to download completely."
         )
-    if any([cfg.adam_beta1, cfg.adam_beta2, cfg.adam_epsilon]) and (
+    if any([cfg.adamw_beta1, cfg.adamw_beta2, cfg.adamw_epsilon]) and (
         not cfg.optimizer or "adamw" not in cfg.optimizer
     ):
         logging.warning("adamw hyperparameters found, but no adamw optimizer set")
-    if cfg.push_to_hub_model_id:
-        raise ValueError(
-            "push_to_hub_model_id is deprecated. Please use hub_model_id instead."
-        )
     # TODO
     # MPT 7b
     # https://github.com/facebookresearch/bitsandbytes/issues/25
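
The renamed check above only warns; a minimal restatement of its logic (illustrative, using a plain dict instead of axolotl's DictDefault):

```python
# Warn when adamw_* hyperparameters are set but the optimizer is not AdamW.
import logging


def check_adamw_hparams(cfg: dict):
    has_adamw_hparams = any(
        cfg.get(k) for k in ("adamw_beta1", "adamw_beta2", "adamw_epsilon")
    )
    if has_adamw_hparams and (
        not cfg.get("optimizer") or "adamw" not in cfg["optimizer"]
    ):
        logging.warning("adamw hyperparameters found, but no adamw optimizer set")


check_adamw_hparams({"optimizer": "adafactor", "adamw_beta1": 0.0001})  # warns
```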

View File

@@ -27,7 +27,7 @@ class TestPacking(unittest.TestCase):
             }
         )
-    def test_resets_attention(self):
+    def test_increments_attention(self):
         prompter = AlpacaPrompter("chat")
         strat = AlpacaPromptTokenizingStrategy(
             prompter,
@@ -58,7 +58,7 @@ class TestPacking(unittest.TestCase):
         # but subsequent one does
         assert example["input_ids"][next_bos_index] == self.tokenizer.bos_token_id
-        assert example["attention_mask"][next_bos_index] == 0
+        assert example["attention_mask"][next_bos_index] == 2
 if __name__ == "__main__":

View File

@@ -268,7 +268,7 @@ class ValidationTest(unittest.TestCase):
         cfg = DictDefault(
             {
                 "optimizer": None,
-                "adam_epsilon": 0.0001,
+                "adamw_epsilon": 0.0001,
             }
         )
@@ -283,7 +283,7 @@ class ValidationTest(unittest.TestCase):
         cfg = DictDefault(
             {
                 "optimizer": "adafactor",
-                "adam_beta1": 0.0001,
+                "adamw_beta1": 0.0001,
             }
         )
@@ -298,9 +298,9 @@ class ValidationTest(unittest.TestCase):
         cfg = DictDefault(
             {
                 "optimizer": "adamw_bnb_8bit",
-                "adam_beta1": 0.9,
-                "adam_beta2": 0.99,
-                "adam_epsilon": 0.0001,
+                "adamw_beta1": 0.0001,
+                "adamw_beta2": 0.0001,
+                "adamw_epsilon": 0.0001,
             }
         )