Compare commits


3 Commits

Author      SHA1        Message       Date
Wing Lian   d46d7dfe30  wip           2024-02-01 00:28:16 -05:00
Wing Lian   047d9e1d5b  helper utils  2024-01-31 12:49:29 -05:00
Wing Lian   88a0c05d2c  wip           2024-01-31 12:07:39 -05:00
41 changed files with 641 additions and 1282 deletions

View File

@@ -37,9 +37,6 @@ Features:
- [Inference](#inference)
- [Merge LORA to Base](#merge-lora-to-base)
- [Special Tokens](#special-tokens)
- Advanced Topics
- [Multipack](./docs/multipack.md)
- [RLHF & DPO](./docs/rlhf.md)
- [Common Errors](#common-errors-)
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
- [Debugging Axolotl](#debugging-axolotl)
@@ -1155,11 +1152,9 @@ Having misalignment between your prompts during training and inference can cause
See [this debugging guide](docs/debugging.md) for tips on debugging Axolotl, along with an example configuration for debugging with VSCode.
## Need help? 🙋
## Need help? 🙋‍♂️
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where our community members can help you.
Need dedicated support? Please contact us at [✉wing@openaccessaicollective.org](mailto:wing@openaccessaicollective.org) for dedicated support options.
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
## Badge ❤🏷️

View File

@@ -11,7 +11,6 @@ EXPOSE 8888
EXPOSE 22
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
COPY scripts/motd /etc/motd
RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
@@ -19,7 +18,6 @@ RUN apt install --yes --no-install-recommends openssh-server tmux && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
chmod +x /root/cloud-entrypoint.sh

Binary file not shown (deleted image, 239 KiB).

View File

@@ -1,11 +1,4 @@
# Multipack (Sample Packing)
## Visualization of Multipack with Flash Attention
Because Flash Attention simply drops the attention mask, we do not need to
construct a 4d attention mask. We only need to concatenate the sequences into
a single batch and let flash attention know where each new sequence begins.
# Multipack
4k context, bsz = 4,
each character represents 256 tokens
@@ -56,18 +49,3 @@ w packing ( note it's the same effective number of tokens per step, but a true b
E E E E F F F F F G G G H H H H
I I I J J J J K K K K K L L L X ]]
```
cu_seqlens:
[[ 0, 11, 17, 24, 28, 36, 41, 44, 48, 51, 55, 60, 64]]
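For reference, these boundaries can be recovered from the packed position ids; a minimal sketch, assuming each sequence restarts its position ids at 0 (the helper name below is illustrative, not axolotl's API):
```python
import torch

def cu_seqlens_from_pos_ids(position_ids: torch.Tensor) -> torch.Tensor:
    """Every 0 in the packed position ids marks the start of a new sequence."""
    starts = (position_ids == 0).nonzero(as_tuple=True)[0]
    end = torch.tensor([position_ids.numel()])
    return torch.cat([starts, end]).to(torch.int32)

# three sequences of lengths 3, 2, and 4 packed into one row
pos_ids = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])
print(cu_seqlens_from_pos_ids(pos_ids))  # tensor([0, 3, 5, 9], dtype=torch.int32)
```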
## Multipack without Flash Attention
Multipack can still be achieved without Flash Attention, but with lower packing
efficiency, since without Flash Attention we cannot join multiple batches into a
single batch due to context-length limits. We can use either PyTorch's Scaled
Dot Product Attention implementation or the native PyTorch attention implementation
along with [4d attention masks](https://github.com/huggingface/transformers/pull/27539)
to pack sequences together and avoid cross-attention.
<img src="./images/4d-mask.png" alt="axolotl" width="800">
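A minimal sketch of the block-diagonal masking described above, assuming packed samples are tagged with distinct integers in the 2d attention mask (this mirrors the `mask_2d_to_4d` helper that appears later in this changeset):
```python
import torch

# two packed samples tagged 1 and 2, plus one padding position tagged 0
mask_2d = torch.tensor([[1, 1, 1, 2, 2, 0]])
m = mask_2d.unsqueeze(1).unsqueeze(2).expand(1, 1, 6, 6)
binary = (m != 0).int()                                  # padding stays masked
block_diag = (m == m.transpose(-1, -2)).int() * binary   # same tag => may attend
causal = torch.tril(torch.ones(6, 6, dtype=torch.int64)) # no future peeking
mask_4d = block_diag * causal
```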

View File

@@ -12,8 +12,8 @@ feedback. Various methods include, but are not limited to:
### RLHF using Axolotl
>[!IMPORTANT]
>This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
[!IMPORTANT]
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
The various RL training methods are implemented in trl and wrapped via axolotl. Below are examples of how you can use various preference datasets to train models that use ChatML.
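Under the hood this bottoms out in `trl`'s `DPOTrainer`; a minimal, hedged sketch of that call (the tiny model and toy dataset are placeholders, and the exact keyword set varies across trl versions):
```python
from datasets import Dataset
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments
from trl import DPOTrainer

model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
tokenizer = AutoTokenizer.from_pretrained("sshleifer/tiny-gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2-style tokenizers ship no pad token

# DPO expects prompt/chosen/rejected triples
train_dataset = Dataset.from_dict(
    {"prompt": ["What is 2+2?"], "chosen": ["4"], "rejected": ["5"]}
)

trainer = DPOTrainer(
    model,
    None,  # ref_model=None: trl falls back to a frozen copy of the policy
    args=TrainingArguments(output_dir="dpo-out", per_device_train_batch_size=1,
                           remove_unused_columns=False),
    beta=0.1,  # strength of the KL penalty against the reference model
    train_dataset=train_dataset,
    tokenizer=tokenizer,
)
trainer.train()
```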

View File

@@ -43,7 +43,6 @@
},
"outputs": [],
"source": [
"!pip install torch==\"2.1.2\"\n",
"!pip install -e git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl\n",
"!pip install flash-attn==\"2.5.0\"\n",
"!pip install deepspeed==\"0.13.1\""

View File

@@ -12,7 +12,6 @@ max_steps: 200
pretraining_dataset:
path: c4
name: en
type: pretrain
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./model-out

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git@bebeeee01275c32fccec3fa36d8b148d3813a7dc
transformers==4.37.0
tokenizers==0.15.0
bitsandbytes>=0.41.1
accelerate==0.26.1

View File

@@ -1,17 +0,0 @@
dP dP dP
88 88 88
.d8888b. dP. .dP .d8888b. 88 .d8888b. d8888P 88
88' `88 `8bd8' 88' `88 88 88' `88 88 88
88. .88 .d88b. 88. .88 88 88. .88 88 88
`88888P8 dP' `dP `88888P' dP `88888P' dP dP
Welcome to the axolotl cloud image! If you've mounted a disk to /workspace and the axolotl directory is empty, run the following commands:
```
cd /workspace
rm -rf /workspace/axolotl
git clone https://github.com/OpenAccess-AI-Collective/axolotl.git
cd axolotl
pip install --no-deps -e .
```

View File

@@ -6,9 +6,8 @@ from pathlib import Path
from typing import Tuple
import fire
from transformers.hf_argparser import HfArgumentParser
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
import transformers
from transformers import PreTrainedModel, PreTrainedTokenizer
from axolotl.cli import (
check_accelerate_default_config,
@@ -28,7 +27,7 @@ LOG = logging.getLogger("axolotl.cli.train")
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser((TrainerCliArgs))
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)

View File

@@ -6,7 +6,6 @@ import logging
from dataclasses import dataclass, field
from typing import Optional
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer

View File

@@ -8,15 +8,17 @@ import importlib
import logging
import math
import sys
import typing
from abc import abstractmethod
from dataclasses import dataclass, field
from functools import wraps
from functools import wraps, partial
from pathlib import Path
from typing import List, Optional, Type, Union
from typing import Dict, List, Optional, Tuple, Type, Union
import torch
import transformers
from datasets import Dataset
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
@@ -29,6 +31,7 @@ from transformers.trainer_utils import seed_worker
from trl import DPOTrainer
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
@@ -50,12 +53,20 @@ from axolotl.utils.schedulers import (
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
)
from axolotl.utils.tensors import keep_unpacked_data, split_and_pad_packed
try:
import torch._dynamo # pylint: disable=ungrouped-imports
except ImportError:
pass
if typing.TYPE_CHECKING:
# hacky, but recommended per https://github.com/python/mypy/issues/5837
_MixinTrainerBase = Trainer
else:
_MixinTrainerBase = object
LOG = logging.getLogger("axolotl.core.trainer_builder")
@@ -98,10 +109,6 @@ class AxolotlTrainingArguments(TrainingArguments):
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
multipack_real_batches: bool = field(
default=False,
metadata={"help": "Use real batches for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
@@ -126,10 +133,6 @@ class AxolotlTrainingArguments(TrainingArguments):
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_anneal_steps: Optional[int] = field(
default=None,
metadata={"help": "how many anneal steps to take before each ReLoRA reset"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
@@ -161,7 +164,142 @@ class AxolotlTrainingArguments(TrainingArguments):
)
class AxolotlTrainer(Trainer):
class AxolotlMultiPackTrainerMixin(_MixinTrainerBase): # type: ignore
"""Trainer Mixin class for dataloaders and samplers"""
args = None # type: AxolotlTrainingArguments
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining:
return MultipackBatchSampler(
RandomSampler(self.train_dataset),
self.args.train_batch_size,
drop_last=True,
batch_max_len=self._train_batch_size * self.args.max_seq_length,
lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
return super()._get_train_sampler()
def get_train_dataloader(self) -> DataLoader:
if self.args.sample_packing and not self.args.pretraining:
train_dataset = self.train_dataset
if "length" in train_dataset.features.keys():
train_dataset = train_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
sampler = self._get_train_sampler()
if isinstance(sampler, BatchSampler):
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(train_dataset, **dataloader_params)
)
return super().get_train_dataloader()
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
return MultipackBatchSampler(
SequentialSampler(eval_dataset),
self.args.per_device_eval_batch_size,
drop_last=True,
batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
lengths=get_dataset_lengths(eval_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
return super()._get_eval_sampler(eval_dataset)
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
if self.args.sample_packing and self.args.eval_sample_packing is False:
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
)
dataloader = super().get_eval_dataloader(eval_dataset)
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.train_data_collator
)
return dataloader
if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset)
eval_dataset = eval_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
if isinstance(eval_sampler, BatchSampler):
dataloader_params["batch_sampler"] = eval_sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = eval_sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(eval_dataset, **dataloader_params)
)
return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
def get_bench_dataloader(
self,
bench_dataset: Dataset,
) -> DataLoader:
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": self.bench_data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
class AxolotlTrainer(AxolotlMultiPackTrainerMixin, Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
@@ -235,151 +373,6 @@ class AxolotlTrainer(Trainer):
return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining:
if self.args.multipack_real_batches:
batch_size = self.args.per_device_train_batch_size
batch_max_len = self.args.max_seq_length
else:
batch_size = 1
batch_max_len = (
self.args.per_device_train_batch_size * self.args.max_seq_length
)
return MultipackBatchSampler(
RandomSampler(self.train_dataset),
batch_size=batch_size,
drop_last=True,
batch_max_len=batch_max_len,
lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
return super()._get_train_sampler()
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
if self.args.multipack_real_batches:
batch_size = self.args.per_device_eval_batch_size
batch_max_len = self.args.max_seq_length
else:
batch_size = 1
batch_max_len = (
self.args.per_device_eval_batch_size * self.args.max_seq_length
)
return MultipackBatchSampler(
SequentialSampler(eval_dataset),
batch_size=batch_size,
drop_last=True,
batch_max_len=batch_max_len,
lengths=get_dataset_lengths(eval_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> DataLoader:
if self.args.sample_packing and not self.args.pretraining:
train_dataset = self.train_dataset
if "length" in train_dataset.features.keys():
train_dataset = train_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
sampler = self._get_train_sampler()
if isinstance(sampler, BatchSampler):
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(train_dataset, **dataloader_params)
)
return super().get_train_dataloader()
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
if self.args.sample_packing and self.args.eval_sample_packing is False:
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
)
dataloader = super().get_eval_dataloader(eval_dataset)
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.train_data_collator
)
return dataloader
if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset)
eval_dataset = eval_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
if isinstance(eval_sampler, BatchSampler):
dataloader_params["batch_sampler"] = eval_sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = eval_sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(eval_dataset, **dataloader_params)
)
return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
def get_bench_dataloader(
self,
bench_dataset: Dataset,
) -> DataLoader:
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": self.bench_data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc
# if self.args.sample_packing:
@@ -482,14 +475,10 @@ class ReLoRATrainer(AxolotlTrainer):
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
anneal_steps = (
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
anneal_steps,
warmup_steps,
)
else:
@@ -498,7 +487,7 @@ class ReLoRATrainer(AxolotlTrainer):
return self.lr_scheduler
class AxolotlDPOTrainer(DPOTrainer):
class AxolotlDPOTrainer(AxolotlMultiPackTrainerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
@@ -515,6 +504,59 @@ class AxolotlDPOTrainer(DPOTrainer):
return super().push_to_hub(*args, **kwargs)
def tokenize_row(self, feature, *args, **kwargs) -> Dict:
# check if dataset is already tokenized
if not self.is_encoder_decoder:
keys = [
"chosen_input_ids",
"chosen_attention_mask",
"chosen_labels",
"rejected_input_ids",
"rejected_attention_mask",
"rejected_labels",
]
if all(k in feature.keys() for k in keys):
return feature
else:
keys = [
"chosen_labels",
"rejected_labels",
"prompt_input_ids",
"prompt_attention_mask",
]
if all(k in feature.keys() for k in keys):
return feature
return super().tokenize_row(feature, *args, **kwargs)
def concatenated_forward(
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
) -> Tuple[
torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor
]:
all_logits = model(
batch["input_ids"],
attention_mask=batch["attention_mask"],
position_ids=batch["position_ids"],
).logits
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(batch["position_ids"])
logits_keep_fn = partial(keep_unpacked_data, pad_val=None, pairs=True)
unpacked_logits = split_and_pad_packed(all_logits, cu_seqlens, max_seqlen, logits_keep_fn)
labels_keep_fn = partial(keep_unpacked_data, pad_val=-100, pairs=True)
unpacked_labels = split_and_pad_packed(batch["labels"], cu_seqlens, max_seqlen, labels_keep_fn)
unpacked_logps = self.get_batch_logps(
unpacked_logits,
unpacked_labels,
average_log_prob=self.loss_type == "ipo",
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
chosen_logps = unpacked_logps[::2]
rejected_logps = unpacked_logps[1::2]
chosen_logits = unpacked_logits[::2]
rejected_logits = unpacked_logits[1::2]
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
class TrainerBuilderBase(abc.ABC):
"""
@@ -888,9 +930,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["sample_packing"] = (
self.cfg.sample_packing if self.cfg.sample_packing else False
)
training_arguments_kwargs["multipack_real_batches"] = (
self.cfg.flash_attention is not True
)
training_arguments_kwargs["eval_sample_packing"] = (
self.cfg.sample_packing
if self.cfg.eval_sample_packing is not False
@@ -901,7 +940,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
] = self.cfg.micro_batch_size
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
training_arguments_kwargs["relora_anneal_steps"] = self.cfg.relora_anneal_steps
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
)
@@ -996,11 +1034,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
if use_batch_sampler_collator:
if self.cfg.model_config_type in ["mixtral", "qwen2", "falcon", "phi"]:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (
self.cfg.model_config_type in ["llama"]
and self.cfg.flash_attention is not True
):
collator = V2BatchSamplerDataCollatorForSeq2Seq
else:
collator = BatchSamplerDataCollatorForSeq2Seq
else:
@@ -1096,21 +1129,13 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
"use_reentrant": False
}
# set save_strategy and save_steps
if self.cfg.save_steps:
training_args_kwargs["save_strategy"] = "steps"
training_args_kwargs["save_steps"] = self.cfg.save_steps
elif self.cfg.save_strategy:
training_args_kwargs["save_strategy"] = self.cfg.save_strategy
else:
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
training_args = TrainingArguments(
per_device_train_batch_size=self.cfg.micro_batch_size,
max_steps=self.cfg.max_steps or total_num_steps,
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
learning_rate=self.cfg.learning_rate,
save_strategy="steps",
save_steps=self.cfg.save_steps,
output_dir=self.cfg.output_dir,
warmup_steps=self.cfg.warmup_steps,
logging_first_step=True,
@@ -1153,6 +1178,7 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
callbacks=self.get_callbacks(),
**dpo_trainer_kwargs,
)
setattr(dpo_trainer, "use_dpo_data_collator", True)
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
dpo_trainer.add_callback(callback)
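A toy illustration of the packed-pair unpacking done in `concatenated_forward` above: chosen/rejected sequences are packed back to back, so after `split_and_pad_packed` separates them at the `cu_seqlens` boundaries, even indices hold chosen and odd indices hold rejected samples:
```python
import torch

# per-sequence log-probs for three (chosen, rejected) pairs, in packing order
unpacked_logps = torch.tensor([-1.2, -2.3, -0.8, -1.9, -1.1, -2.7])
chosen_logps = unpacked_logps[::2]     # tensor([-1.2, -0.8, -1.1])
rejected_logps = unpacked_logps[1::2]  # tensor([-2.3, -1.9, -2.7])
```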

View File

@@ -31,7 +31,7 @@ class TokenizedPromptDataset(Dataset):
def __init__( # pylint: disable=super-init-not-called
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: Dataset,
dataset: IterableDataset,
process_count: Optional[int] = None,
keep_in_memory: Optional[bool] = False,
**kwargs,

View File

@@ -1,46 +0,0 @@
"""monkey patches for the dataset fetcher to handle batches of packed indexes"""
# pylint: disable=protected-access
import torch
from torch.utils.data._utils.fetch import _BaseDatasetFetcher
from torch.utils.data._utils.worker import _worker_loop
class _MapDatasetFetcher(_BaseDatasetFetcher):
def fetch(self, possibly_batched_index):
if isinstance(possibly_batched_index[0], list):
data = [None for i in possibly_batched_index]
for i, possibly_batched_index_ in enumerate(possibly_batched_index):
if self.auto_collation:
if (
hasattr(self.dataset, "__getitems__")
and self.dataset.__getitems__
):
data[i] = self.dataset.__getitems__(possibly_batched_index_)
else:
data[i] = [self.dataset[idx] for idx in possibly_batched_index_]
else:
data[i] = self.dataset[possibly_batched_index_]
else:
if self.auto_collation:
if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
data = self.dataset.__getitems__(possibly_batched_index)
else:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
def patch_fetchers():
torch.utils.data._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
def patched_worker_loop(*args, **kwargs):
patch_fetchers()
return _worker_loop(*args, **kwargs)
torch.utils.data._utils.worker._worker_loop = patched_worker_loop
patch_fetchers()
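For context on the removed patch: `MultipackBatchSampler` yields batches of index *lists* (one list per packed bin), which the stock fetcher cannot index directly. A minimal illustration of the nested-index fetch it enabled:
```python
# a dataloader batch where each element is itself a packed bin of indexes
possibly_batched_index = [[0, 3, 7], [1, 2]]
dataset = [f"sample-{i}" for i in range(10)]

# the patched fetcher resolved each bin separately before collation
data = [[dataset[idx] for idx in bin_] for bin_ in possibly_batched_index]
# data == [['sample-0', 'sample-3', 'sample-7'], ['sample-1', 'sample-2']]
```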

View File

@@ -0,0 +1,142 @@
"""
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
"""
import warnings
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import transformers.models.llama.modeling_llama
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
def hijack_llama_sdp_attention():
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
sdp_attention_forward
)
def sdp_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"):
self.pretraining_tp = 1
if self.pretraining_tp > 1:
key_value_slicing = (
self.num_key_value_heads * self.head_dim
) // self.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
#
# sdp-attn start
#
with torch.backends.cuda.sdp_kernel():
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=False,
)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
#
# sdp-attn end
#
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1
)
attn_output = sum(
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
)
else:
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value
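For a standalone picture of the core replacement above, a minimal sketch of calling PyTorch's SDPA directly (random tensors, causal masking; the `sdp_kernel` context manager matches what the patch uses):
```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)  # [bsz, num_heads, q_len, head_dim]
k = torch.randn_like(q)
v = torch.randn_like(q)

with torch.backends.cuda.sdp_kernel():  # default flags: all backends allowed
    out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

assert out.shape == (1, 8, 16, 64)
```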

View File

@@ -5,11 +5,38 @@ from typing import Optional
import torch
from axolotl.monkeypatch.utils import mask_2d_to_4d
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
masked_zero_one_mask = mask_2d_to_4d(mask, dtype, tgt_len)
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
This expansion handles packed sequences so that sequences share the same attention mask integer value
when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
mask = mask.unsqueeze(1).unsqueeze(2)
mask = mask.expand(bsz, 1, tgt_len, src_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
binary_mask = torch.where(
mask != 0,
torch.tensor(1).to(dtype),
torch.tensor(0).to(dtype),
)
# Create a block-diagonal mask.
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
mask.device
)
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
inverted_mask = 1.0 - masked_zero_one_mask
return inverted_mask.masked_fill(

View File

@@ -1,26 +0,0 @@
"""
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
"""
from axolotl.monkeypatch.utils import (
patched_prepare_4d_causal_attention_mask,
patched_prepare_4d_causal_attention_mask_for_sdpa,
)
def hijack_llama_prepare_4d_mask():
import transformers.modeling_attn_mask_utils
import transformers.models.llama.modeling_llama
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask_for_sdpa
)
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask_for_sdpa
)
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask
)
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask
)

View File

@@ -4,16 +4,14 @@ import json
import logging
import os.path
import shutil
from functools import partial
from pathlib import Path
from typing import Dict, List, Sequence, Union
from typing import Dict, List, Sequence
import bitsandbytes as bnb
import peft
import safetensors.torch as st
import torch
from huggingface_hub import snapshot_download
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
from transformers import (
@@ -25,50 +23,23 @@ from transformers import (
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import barrier, is_main_process
from axolotl.utils.distributed import is_main_process
LOG = logging.getLogger("axolotl.relora")
@torch.no_grad()
def magnitude_pruning_(tensor, prune_ratio):
tensor_magnitude = torch.abs(tensor)
threshold = torch.quantile(
tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio
).to(dtype=tensor.dtype)
def reset_optimizer(optimizer: torch.optim.Optimizer):
for group in optimizer.param_groups:
for param in group["params"]:
param_state = optimizer.state[param]
for key in param_state:
if "qmap" in key:
continue
mask = tensor_magnitude > threshold
tensor.mul_(mask.to(dtype=tensor.dtype))
def reset_optimizer(
optimizer: torch.optim.Optimizer,
*,
reset_params: list[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: list[str],
):
pruning_fn = partial(magnitude_pruning_, prune_ratio=0.9)
n_zeros = 0
n_total = 0
optimizer_state = optimizer.state
if isinstance(optimizer, ZeroRedundancyOptimizer):
optimizer_state = optimizer.optim.state
for param in reset_params:
param_state = optimizer_state[param]
if len(param_state) == 0: # no state for this param, happens for ZeRo optimizer
continue
for key in optimizer_state_keys:
pruning_fn(
param_state[key]
) # pruning fn has to be inplace to keep the same keys in the dict
n_total += param_state[key].numel()
n_zeros += torch.sum(param_state[key] == 0).item()
_zeroed = n_zeros / (1e-7 + n_total) * 100
LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
LOG.info(f"absolute n of optimizer states zeroed: {n_zeros}")
if key == "step" and isinstance(param_state[key], int):
param_state[key] = 0
else:
param_state[key] = torch.zeros_like(param_state[key])
class ReLoRACallback(TrainerCallback):
@@ -126,25 +97,6 @@ class ReLoRACallback(TrainerCallback):
"relora",
)
if "adam" in args.optim.lower():
optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
else:
raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
lora_params = [
n
for n, p in model.named_parameters()
if p.requires_grad and "lora_" in n
]
model.save_pretrained(
os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
"adapter",
),
safe_serialization=True,
)
with torch.no_grad():
merge_and_save(
model,
@@ -155,11 +107,7 @@ class ReLoRACallback(TrainerCallback):
actually_save=is_main_process(),
cpu_offload=self.cpu_offload,
)
reset_optimizer(
optimizer,
reset_params=lora_params,
optimizer_state_keys=optimizer_state_keys,
)
reset_optimizer(optimizer)
if self.quantized:
self.last_full_model = checkpoint_folder
@@ -249,13 +197,11 @@ class ReLoRAScheduler(LRScheduler):
inner_schedule: LRScheduler,
relora_steps: int,
warmup_steps: int,
anneal_steps: int = 1,
min_lr_scale: float = 0.001,
) -> None:
self.inner_schedule = inner_schedule
self.relora_steps = relora_steps
self.warmup_steps = warmup_steps
self.anneal_steps = anneal_steps
self.min_lr_scale = min_lr_scale
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
@@ -264,20 +210,10 @@ class ReLoRAScheduler(LRScheduler):
original = self.inner_schedule.get_lr()
step = self.last_epoch
if step < self.relora_steps:
scale = 1
else:
per_relora_progress = step % self.relora_steps
if per_relora_progress < self.warmup_steps:
cycle_t = min(1.0, (per_relora_progress) / self.warmup_steps)
elif per_relora_progress > (self.relora_steps - self.anneal_steps):
cycle_t = min(
1.0,
(self.relora_steps - per_relora_progress) / self.anneal_steps,
)
else:
cycle_t = 1
cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
if isinstance(original, Sequence):
@@ -302,11 +238,7 @@ def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor:
if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
adapter: Union[List[str], str] = layer.active_adapter
if isinstance(adapter, list):
if len(adapter) > 1:
raise ValueError("unhandled relora for multiple adapters")
adapter = adapter[0]
adapter = layer.active_adapter
return (
peft.utils.transpose(
layer.lora_B[adapter].weight.detach().to(device)
@@ -316,7 +248,7 @@ def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor
* layer.scaling[adapter]
)
raise ValueError("unhandled lora layer type")
return layer.get_delta_weight().to(device)
def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]:
@@ -341,9 +273,9 @@ def update_weights(
):
if reinit:
for adapter_name in target.lora_A:
target.reset_lora_parameters(adapter_name, True)
target.reset_lora_parameters(adapter_name)
for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name, True)
target.reset_lora_parameters(adapter_name)
if isinstance(target, peft.tuners.lora.Linear4bit):
# This could be faster, but the quantization of Linear4bit weights occurs
@@ -354,9 +286,7 @@ def update_weights(
target.weight.data = new_weight.cpu()
target.to(device)
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
target.weight.data = (
bnb.nn.Int8Params(new_weight, requires_grad=False).to(device).data
)
target.weight = bnb.nn.Int8Params(new_weight, requires_grad=False).to(device)
else:
target.weight.data = new_weight.to(device)
@@ -374,17 +304,14 @@ def merge_and_save(
if not quantized:
for module_name, target in modules.items():
active_adapter = target.active_adapter
if isinstance(active_adapter, list):
active_adapter = active_adapter[0]
update = target.get_delta_weight(active_adapter).detach()
update = target.get_delta_weight(target.active_adapter).detach()
target.weight.data += update
if reinit:
for adapter_name in target.lora_A:
target.reset_lora_parameters(adapter_name, True)
target.reset_lora_parameters(adapter_name)
for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name, True)
target.reset_lora_parameters(adapter_name)
return
os.makedirs(model_dst, exist_ok=True)
@@ -436,7 +363,6 @@ def merge_and_save(
LOG.info(f"saving tensors to {shard_fn}")
st.save_file(out_tensors, shard_fn, metadata={"format": "pt"})
barrier()
del in_tensors
del out_tensors
torch.cuda.empty_cache()
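The net effect of the simplified scheduler above: after every reset the learning rate re-warms linearly over `warmup_steps`, with the separate anneal phase removed. A minimal sketch of the scale it applies to the inner schedule's LR:
```python
def relora_lr_scale(step, relora_steps=200, warmup_steps=10, min_lr_scale=0.001):
    """Multiplier applied to the inner schedule's LR at a given global step."""
    if step < relora_steps:
        return 1.0  # the first cycle runs the inner schedule untouched
    cycle_t = min(1.0, (step % relora_steps) / warmup_steps)
    return cycle_t * (1 - min_lr_scale) + min_lr_scale

print(relora_lr_scale(201))  # 0.1009: one step into the post-reset re-warmup
```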

View File

@@ -1,15 +1,8 @@
"""
Shared utils for the monkeypatches
"""
from typing import Optional
import torch
import torch.nn.functional as F
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.utils import is_torch_bf16_gpu_available
@torch.jit.script
@@ -96,6 +89,7 @@ def get_cu_seqlens(attn_mask):
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
@torch.jit.script
def get_cu_seqlens_from_pos_ids(position_ids):
"""generate a cumulative sequence length mask for flash attention using pos ids"""
if len(position_ids.shape) == 1:
@@ -141,18 +135,7 @@ def get_cu_seqlens_from_pos_ids(position_ids):
results.append(cu_seqlens)
max_seq_lens.append(max_seq_len)
# Find the maximum value across all tensors
max_value = max(t.max() for t in results)
# Find the length of the longest tensor
max_length = max(t.size(0) for t in results)
# Pad each tensor to the same length and collect them in a list
padded_results = [
F.pad(t, (0, max_length - t.size(0)), "constant", max_value) for t in results
]
return torch.stack(padded_results).to(dtype=torch.int32), torch.stack(max_seq_lens)
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
def set_module_name(model, name, value):
@@ -166,62 +149,3 @@ def set_module_name(model, name, value):
child_name = name
setattr(parent, child_name, value)
def mask_2d_to_4d(
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
This expansion handles packed sequences so that sequences share the same attention mask integer value
when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
mask = mask.unsqueeze(1).unsqueeze(2)
mask = mask.expand(bsz, 1, tgt_len, src_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
binary_mask = torch.where(
mask != 0,
torch.tensor(1).to(dtype),
torch.tensor(0).to(dtype),
)
# Create a block-diagonal mask.
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
mask.device
)
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
return masked_zero_one_mask
def patched_prepare_4d_causal_attention_mask(
attention_mask: Optional[torch.Tensor],
*args,
):
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
return _prepare_4d_causal_attention_mask(
mask_2d_to_4d(attention_mask, dtype=dtype),
*args,
)
def patched_prepare_4d_causal_attention_mask_for_sdpa(
attention_mask: Optional[torch.Tensor],
*args,
):
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
return _prepare_4d_causal_attention_mask_for_sdpa(
mask_2d_to_4d(attention_mask, dtype=dtype),
*args,
)

View File

@@ -1,33 +0,0 @@
"""Module containing the InstructShareGPTPromptTokenizingStrategy class"""
from typing import Any, Dict, Optional
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
strategy = InstructShareGPTPromptTokenizingStrategy(
# pylint: disable=duplicate-code
ShareGPTPrompterV2(
conversation=conversation,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
return strategy
class InstructShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row
"""
def get_conversation_thread(self, prompt):
return [
{"from": "human", "value": prompt["instruction"]},
{"from": "gpt", "value": prompt["output"]},
]

View File

@@ -1,58 +0,0 @@
"""pretraining prompt strategies"""
from typing import Generator
from transformers import BatchEncoding
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
class PretrainTokenizer:
"""basic tokenization class for pretraining"""
def build_prompt(self, prompt) -> Generator[str, None, None]:
yield prompt
class PretrainTokenizationStrategy(PromptTokenizingStrategy):
"""handles tokenization for pretraining with strides"""
@property
def supports_batched(self):
return True
def __init__(self, *args, max_length=None, **kwargs):
super().__init__(*args, **kwargs)
if max_length:
self.max_length = max_length
def _tokenize(
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
) -> BatchEncoding:
res = self.tokenizer(
prompt,
truncation=True,
max_length=self.max_length - 1,
add_special_tokens=True,
return_overflowing_tokens=True,
stride=256,
)
res["input_ids"] = [
seq + [self.tokenizer.eos_token_id] for seq in res["input_ids"]
]
res["attention_mask"] = [seq + [1] for seq in res["attention_mask"]]
return res
def tokenize_prompt(self, prompt):
return self._tokenize(prompt["text"])
def load(tokenizer, cfg):
strat = PretrainTokenizationStrategy(
PretrainTokenizer(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
max_length=cfg.sequence_len * 64,
)
return strat

View File

@@ -11,6 +11,7 @@ import torch
import transformers.modelcard
from accelerate.logging import get_logger
from datasets import Dataset
from optimum.bettertransformer import BetterTransformer
from peft import PeftModel
from pkg_resources import get_distribution # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer
@@ -23,11 +24,6 @@ from axolotl.utils.freeze import freeze_parameters_except
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
try:
from optimum.bettertransformer import BetterTransformer
except ImportError:
BetterTransformer = None
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
@@ -61,21 +57,6 @@ def train(
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
]
if len(possible_checkpoints) > 0:
sorted_paths = sorted(
possible_checkpoints,
key=lambda path: int(path.split("-")[-1]),
)
cfg.resume_from_checkpoint = sorted_paths[-1]
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
)
resume_from_checkpoint = cfg.resume_from_checkpoint
# Load the model and tokenizer
msg = "loading model"
if cfg.adapter:
@@ -98,6 +79,21 @@ def train(
safe_serialization = cfg.save_safetensors is True
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
]
if len(possible_checkpoints) > 0:
sorted_paths = sorted(
possible_checkpoints,
key=lambda path: int(path.split("-")[-1]),
)
cfg.resume_from_checkpoint = sorted_paths[-1]
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
)
resume_from_checkpoint = cfg.resume_from_checkpoint
if cfg.unfrozen_parameters:
freeze_parameters_except(model, cfg.unfrozen_parameters)
@@ -128,7 +124,7 @@ def train(
if cfg.local_rank == 0:
def terminate_handler(_, __, model):
if cfg.flash_optimum and BetterTransformer:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
sys.exit(0)
@@ -153,10 +149,7 @@ def train(
pretrain_hooks(cfg, trainer)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
enable_flash=True,
enable_math=True,
enable_mem_efficient=True,
enable_flash=True, enable_math=True, enable_mem_efficient=True
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
@@ -202,7 +195,7 @@ def train(
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
)
elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)

View File

@@ -19,7 +19,6 @@ def chat_templates(user_choice: str):
"""
templates = {
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
"chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
}
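A minimal sketch of exercising one of these templates through `tokenizer.apply_chat_template` (the base tokenizer is arbitrary; only the manually assigned template matters here):
```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
tok.chat_template = (
    "{% for message in messages %}"
    "{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}"
    "{% endfor %}"
    "{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
)
text = tok.apply_chat_template(
    [{"role": "user", "content": "hello"}],
    tokenize=False,
    add_generation_prompt=True,
)
# text == "<|im_start|>user\nhello<|im_end|>\n<|im_start|>assistant\n"
```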

View File

@@ -132,26 +132,24 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""
def __call__(self, features, return_tensors=None):
if not isinstance(features[0], list):
features = [features]
out_features = [{} for _ in features]
for i, features_ in enumerate(features):
for feature in features_[0].keys():
if feature == "length":
continue
if feature == "attention_mask":
arrays = [
(1) * np.array(item[feature])
for i, item in enumerate(features_)
if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature]) for item in features_ if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
return super().__call__(out_features, return_tensors=return_tensors)
chunked_data = {}
for feature in features[0].keys():
if feature == "length":
continue
if feature == "attention_mask":
arrays = [
(1) * np.array(item[feature])
for item in features
if feature in item
]
chunked_data[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature]) for item in features if feature in item
]
chunked_data[feature] = np.concatenate(arrays)
features = [chunked_data]
return super().__call__(features, return_tensors=return_tensors)
@dataclass
@@ -161,27 +159,28 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""
def __call__(self, features, return_tensors=None):
if not isinstance(features[0], list):
features = [features]
out_features = [{} for _ in features]
for i, features_ in enumerate(features):
for feature in features_[0].keys():
if feature == "length":
continue
if feature == "attention_mask":
arrays = [
(i + 1) * np.array(item[feature])
for i, item in enumerate(features_)
if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature]) for item in features_ if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
return super().__call__(out_features, return_tensors=return_tensors)
chunked_data = {}
for feature in features[0].keys():
if feature == "length":
continue
if feature == "attention_mask":
arrays = [
(i + 1) * np.array(item[feature])
for i, item in enumerate(features)
if feature in item
]
chunked_data[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature]) for item in features if feature in item
]
chunked_data[feature] = np.concatenate(arrays)
features = [chunked_data]
return super().__call__(features, return_tensors=return_tensors)
@dataclass
class BatchSamplerDPODataCollatorWithPadding:
@dataclass
class MambaDataCollator:
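A small illustration of the attention-mask tagging the V2 collator performs above: each packed sample's mask is multiplied by a distinct integer `i + 1`, so downstream mask expansion can tell packed samples apart (zero still means padding):
```python
import numpy as np

# two samples packed into one row by the batch sampler
masks = [np.array([1, 1, 1]), np.array([1, 1])]
packed = np.concatenate([(i + 1) * m for i, m in enumerate(masks)])
print(packed)  # [1 1 1 2 2]
```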

View File

@@ -202,20 +202,6 @@ def validate_config(cfg):
raise ValueError(
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
)
if (
# pylint: disable=too-many-boolean-expressions
not (cfg.bf16 or cfg.bfloat16)
and (cfg.fp16 or cfg.float16)
and not cfg.adapter
and not cfg.flash_attention
and cfg.sample_packing
):
LOG.warning(
"Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA."
)
# ValueError: Attempting to unscale FP16 gradients.
# OR
# RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
if cfg.max_packed_sequence_len:
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
@@ -322,7 +308,7 @@ def validate_config(cfg):
LOG.warning("BetterTransformers probably doesn't work with PEFT adapters")
if cfg.fp16 or cfg.bf16:
raise ValueError("AMP is not supported with BetterTransformer")
if cfg.float16 is not True and cfg.bfloat16 is not True:
LOG.warning(
"You should probably set bfloat16 or float16 to true to "
"load the model in float16 for BetterTransformers"
@@ -364,24 +350,17 @@ def validate_config(cfg):
+ "point to its path, and remove model_revision from the config."
)
# if cfg.sample_packing and cfg.sdp_attention:
# # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
# raise ValueError(
# "sample_packing not compatible with sdp_attention. Use flash_attention"
# )
if cfg.sample_packing and cfg.sdp_attention:
# incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
raise ValueError(
"sample_packing not compatible with sdp_attention. Use flash_attention"
)
if cfg.sample_packing and cfg.xformers_attention:
raise ValueError(
"sample_packing not compatible with xformers_attention. Use flash_attention"
)
if cfg.sample_packing and cfg.sdp_attention and (cfg.bfloat16 or cfg.bf16):
# https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
LOG.warning(
"sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
"This may work on H100s."
)
if cfg.early_stopping_patience:
if not cfg.save_steps or not cfg.eval_steps:
raise ValueError(
@@ -447,11 +426,7 @@ def validate_config(cfg):
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
)
if (
cfg.val_set_size == 0
and (cfg.eval_steps or cfg.evaluation_strategy)
and not cfg.test_datasets
):
if cfg.val_set_size == 0 and (cfg.eval_steps or cfg.evaluation_strategy):
raise ValueError(
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
)

View File

@@ -4,7 +4,7 @@ import hashlib
import logging
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import yaml
@@ -88,21 +88,12 @@ def prepare_dataset(cfg, tokenizer):
path = cfg.pretraining_dataset[0]["path"]
name = cfg.pretraining_dataset[0]["name"]
ds_wrapper_partial = functools.partial(
get_dataset_wrapper,
cfg.pretraining_dataset[0],
train_dataset = load_pretraining_dataset(
path,
tokenizer,
cfg,
cfg.pretraining_dataset[0]["type"] or "pretrain",
)
train_dataset = wrap_pretraining_dataset(
load_dataset(path, streaming=True, split="train", name=name),
tokenizer,
cfg,
ds_wrapper_partial,
name=name,
max_tokens=cfg.sequence_len,
batch_size=cfg.micro_batch_size,
seed=cfg.seed or 42,
)
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
@@ -149,7 +140,7 @@ def load_tokenized_prepared_datasets(
+ "|".join(
sorted(
[
f"{d.path}:{d.type}:{d.shards}:{d.conversation}{d.split}"
f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
for d in cfg_datasets
]
)
@@ -392,9 +383,9 @@ def load_tokenized_prepared_datasets(
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
config_dataset=config_dataset,
dataset=ds,
tokenizer=tokenizer,
cfg=cfg,
dataset=ds,
d_base_type=d_base_type,
d_prompt_style=d_prompt_style,
)
@@ -505,12 +496,7 @@ def load_prepare_datasets(
def get_dataset_wrapper(
config_dataset,
tokenizer,
cfg,
d_base_type,
dataset,
d_prompt_style=None,
config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
):
dataset_wrapper = None
dataset_prompter = None
@@ -521,8 +507,7 @@ def get_dataset_wrapper(
}
if (
isinstance(dataset, Dataset)
and "input_ids" in dataset.features
"input_ids" in dataset.features
and "attention_mask" in dataset.features
and "labels" in dataset.features
):
@@ -780,67 +765,76 @@ def encode_pretraining(
return ret
def wrap_pretraining_dataset(
dataset,
tokenizer,
cfg,
ds_wrapper_fn,
max_tokens=2048,
batch_size=1,
seed=42,
buffer_size=10_000,
):
def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42):
if cfg.sample_packing:
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
padding=True,
pad_to_multiple_of=max_tokens * batch_size,
pad_to_multiple_of=max_tokens * cfg.micro_batch_size,
)
encode = functools.partial(
encode_packed_pretraining,
tokenizer,
collate_fn,
ds_wrapper_fn,
max_seq_length=max_tokens,
batch_size=batch_size,
batch_size=cfg.micro_batch_size,
)
# set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1
else:
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
dataset = load_dataset(path, streaming=True, split="train", name=name)
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
dataset = dataset.map(
encode,
batched=True,
batch_size=buffer_size,
# input_columns="text",
batch_size=10_000,
input_columns="text",
# remove all the existing columns after mapping since they end up having
# a different length than the encoded/tokenized column
remove_columns=dataset.features.keys(),
desc="Encoding Pretraining",
)
return dataset
def encode_packed_pretraining(
tokenizer: PreTrainedTokenizerBase,
collate_fn,
ds_wrapper: Callable,
examples: Dict[str, List],
examples: List[str],
max_seq_length: int = 2048,
batch_size: int = 4,
) -> Dict[str, List]:
# pylint: disable=duplicate-code
# tokenize all the examples
# rows get split with stride (overlap)
train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
res = tokenizer(
examples,
truncation=True,
max_length=max_seq_length - 1,
add_special_tokens=True,
return_overflowing_tokens=True,
stride=256,
)
input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]]
attention_mask = [seq + [1] for seq in res["attention_mask"]]
tokenized_examples = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
train_dataset = Dataset.from_dict(tokenized_examples)
train_dataset = process_pretraining_datasets_for_packing(
train_dataset, max_seq_length
)
sampler = MultipackBatchSampler(
RandomSampler(train_dataset),
batch_size=1,
batch_size=batch_size,
drop_last=True,
batch_max_len=batch_size * max_seq_length,
lengths=get_dataset_lengths(train_dataset),
@@ -848,23 +842,15 @@ def encode_packed_pretraining(
chunked_data = defaultdict(list)
for batch in sampler:
for data in batch:
features = train_dataset[data]
if "num_truncated_tokens" in features:
del features["num_truncated_tokens"]
if "num_truncated_tokens" in features:
del features["num_truncated_tokens"]
if "overflow_to_sample_mapping" in features:
del features["overflow_to_sample_mapping"]
if "labels" not in features:
features["labels"] = features["input_ids"].copy()
collated_features = collate_fn(features)
for data in sampler:
features = train_dataset[data]
features["labels"] = features["input_ids"].copy()
collated_features = collate_fn(features)
for feature in features.keys():
if feature == "length":
continue
chunked_data[feature].append(collated_features[feature].squeeze(0))
for feature in features.keys():
if feature == "length":
continue
chunked_data[feature].append(collated_features[feature].squeeze(0))
return chunked_data

View File

@@ -8,13 +8,8 @@ import addict
import bitsandbytes as bnb
import torch
import transformers
from peft import (
LoftQConfig,
PeftConfig,
PeftModel,
PeftModelForCausalLM,
prepare_model_for_kbit_training,
)
from optimum.bettertransformer import BetterTransformer
from peft import LoftQConfig, PeftConfig, prepare_model_for_kbit_training
from peft.tuners.lora import QuantLinear
from transformers import ( # noqa: F401
AddedToken,
@@ -166,20 +161,15 @@ def load_tokenizer(cfg):
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, "<|endoftext|>")
additional_special_tokens = None
if cfg.special_tokens:
special_tokens = cfg.special_tokens.to_dict()
additional_special_tokens = special_tokens.pop(
"additional_special_tokens", None
)
lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
for k, val in special_tokens.items():
for k, val in cfg.special_tokens.items():
# check if new special token is not already in tokenizer and
# is adapter training to make sure lora_modules_to_save is set
# pylint: disable=too-many-boolean-expressions
if (
(getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
and (len(tokenizer.encode(val, add_special_tokens=False)) > 2)
and (len(tokenizer.encode(val)) > 1)
and cfg.adapter
and (
not cfg.lora_modules_to_save
@@ -223,21 +213,6 @@ def load_tokenizer(cfg):
]
)
# Additional special tokens are a List, and need to be treated differently than regular special
# tokens. We add them after we have called `add_tokens` in case these additional special tokens
# are new tokens.
#
# Usage:
#
# ```py
# special_tokens:
# additional_special_tokens: ["<|im_start|>", "<|im_end|>"]
# ```
if additional_special_tokens is not None:
tokenizer.add_special_tokens(
{"additional_special_tokens": additional_special_tokens}
)
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
@@ -329,13 +304,13 @@ def load_model(
LOG.info("patching with xformers attention")
hijack_llama_attention()
elif cfg.sample_packing:
from axolotl.monkeypatch.llama_patch_multipack import (
hijack_llama_prepare_4d_mask,
elif cfg.sdp_attention:
from axolotl.monkeypatch.llama_attn_hijack_sdp import (
hijack_llama_sdp_attention,
)
LOG.info("patching llama _prepare_4d_causal_attention_mask*")
hijack_llama_prepare_4d_mask()
LOG.info("patching with sdp attention")
hijack_llama_sdp_attention()
elif cfg.s2_attention:
raise NotImplementedError(
"Shifted-sparse attention not currently implemented without flash attention."
@@ -478,18 +453,6 @@ def load_model(
**bnb_config,
)
if cfg.load_in_8bit and cfg.adapter is not None:
model_kwargs["load_in_8bit"] = True
if cfg.load_in_4bit and cfg.adapter is not None:
model_kwargs["load_in_4bit"] = True
# no longer needed per https://github.com/huggingface/transformers/pull/26610
if "quantization_config" in model_kwargs or cfg.gptq:
if "load_in_8bit" in model_kwargs:
del model_kwargs["load_in_8bit"]
if "load_in_4bit" in model_kwargs:
del model_kwargs["load_in_4bit"]
# sample packing uses custom FA2 patch
if cfg.flash_attention:
if not cfg.sample_packing:
@@ -511,12 +474,6 @@ def load_model(
model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)
elif cfg.sdp_attention:
model_kwargs["attn_implementation"] = "sdpa"
model_config._attn_implementation = "sdpa" # pylint: disable=protected-access
elif cfg.eager_attention:
model_kwargs["attn_implementation"] = "eager"
model_config._attn_implementation = "eager" # pylint: disable=protected-access
try:
if (
@@ -529,6 +486,8 @@ def load_model(
model = LlamaForCausalLM.from_pretrained(
base_model,
config=model_config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs,
)
@@ -596,6 +555,8 @@ def load_model(
model = getattr(transformers, model_type).from_pretrained(
base_model,
config=model_config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
@@ -627,6 +588,8 @@ def load_model(
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
@@ -634,9 +597,6 @@ def load_model(
LOG.exception(err)
raise err
if isinstance(model, (PeftModel, PeftModelForCausalLM)):
model = model.merge_and_unload()
embeddings_len = (
math.ceil(len(tokenizer) / 32) * 32
if cfg.resize_token_embeddings_to_32x
@@ -685,7 +645,7 @@ def load_model(
if not cfg.fsdp:
# FSDP doesn't like mixed Float and BFloat16
for name, module in model.named_modules():
if "norm" in name or name.endswith(".gate"):
if any(m in name for m in ["norm", "gate"]):
module.to(torch.float32)
if model_config.model_type == "btlm":
# don't upcast lm_head for btlm
@@ -698,9 +658,7 @@ def load_model(
skip_prepare_model_for_kbit_training = False
if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
from deepspeed.utils import ( # pylint: disable=no-name-in-module
set_z3_leaf_modules,
)
from deepspeed.utils import set_z3_leaf_modules
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
@@ -763,8 +721,6 @@ def load_model(
model.config.use_cache = False
if cfg.flash_optimum:
from optimum.bettertransformer import BetterTransformer
model = BetterTransformer.transform(model)
if cfg.adapter is not None:
@@ -791,7 +747,7 @@ def load_adapter(model, cfg, adapter, inference=False):
def load_llama_adapter(model, cfg):
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
from peft import AdaptionPromptConfig, get_peft_model
from peft import AdaptionPromptConfig, PeftModel, get_peft_model
peft_config = AdaptionPromptConfig(
adapter_layers=cfg.peft_adapter.layers, # layers (L)
@@ -837,7 +793,7 @@ def find_all_linear_names(model):
def load_lora(model, cfg, inference=False, config_only=False):
# type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]
from peft import LoraConfig, get_peft_model
from peft import LoraConfig, PeftModel, get_peft_model
lora_target_modules = list(cfg.lora_target_modules or [])

View File

@@ -117,7 +117,7 @@ class MultipackBatchSampler(BatchSampler):
packing_efficiency_estimate: float = 1.0,
):
super().__init__(sampler, batch_size, drop_last)
self.batch_size = batch_size
self.batch_size = None
self.batch_max_len = batch_max_len
self.lengths: np.ndarray = lengths
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
@@ -147,13 +147,7 @@ class MultipackBatchSampler(BatchSampler):
n=1,
)
batches = [
[
[indices[b_idx] for b_idx in batch]
for batch in batches[i : i + self.batch_size]
]
for i in range(0, len(batches), self.batch_size)
]
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
# statistics
if set_stats:
@@ -195,7 +189,7 @@ class MultipackBatchSampler(BatchSampler):
0.99
* lengths_sum_per_device
/ self.packing_efficiency_estimate
// (self.batch_max_len * self.batch_size)
// self.batch_max_len
)
- 1
),
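To make the regrouping above concrete: each element the sampler now yields is a list of batch_size packs, and each pack is a list of dataset indices whose lengths fit within batch_max_len. A hedged illustration with made-up indices:

# illustrative only: batch_size=2, batch_max_len=8
batch = [
    [3, 17, 42],  # pack 1: e.g. lengths 4 + 2 + 2 <= 8
    [8, 25],      # pack 2: e.g. lengths 5 + 3 <= 8
]
# the collator pads each pack to batch_max_len and stacks the packs into a
# [batch_size, batch_max_len] tensor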

View File

@@ -0,0 +1,61 @@
"""Helpers for splitting packed (multipack) batches back into padded per-sequence tensors."""
import torch
import torch.nn.functional as F


def keep_unpacked_data(data: torch.Tensor, index=None, nonzero_total=None, pad_val=None, pairs=False):
    # pad_val could be the padding token (input_ids), -100 (labels), or 0
    # (attention_mask), so test against None explicitly; a truthiness check
    # would silently skip the all-padding test when pad_val == 0
    if index >= nonzero_total:
        return False
    if pairs and (index // 2) >= (nonzero_total // 2):
        return False
    if pad_val is not None and (data == pad_val).all():
        return False
    return True


def split_and_pad_packed(tensor, cu_seqlens, max_seqlen, keep_fn=None):
    split_tensors = []
    counts = count_nonzero_sequences(cu_seqlens)
    # Iterate over each batch row
    for i in range(tensor.size(0)):
        seq_lens = cu_seqlens[i]
        start_idx = 0
        # Iterate over the cumulative sequence lengths
        for j, end_idx in enumerate(seq_lens[1:]):
            if end_idx == start_idx:
                # cu_seqlens rows are right-padded by repeating the last
                # offset, so the remaining entries carry no sequences
                break
            # Extract the current sequence
            current_seq = tensor[i, start_idx:end_idx]
            keep = True
            if keep_fn:
                keep = keep_fn(current_seq, index=j, nonzero_total=counts[i])
            if not keep:
                # still advance past the skipped sequence
                start_idx = end_idx
                continue
            # Pad the sequence dimension (dim 0) up to max_seqlen. F.pad reads
            # pad pairs from the last dimension backwards, so the pair for
            # dim 0 goes at the end of the tuple.
            padding_size = max_seqlen - current_seq.size(0)
            padded_seq = F.pad(current_seq, (0, 0) * (current_seq.dim() - 1) + (0, padding_size))
            # Append the padded sequence to the list
            split_tensors.append(padded_seq)
            # Update start index for the next sequence
            start_idx = end_idx
    # Stack the padded tensors
    return torch.stack(split_tensors, dim=0)


def count_nonzero_sequences(cu_seqlens: torch.Tensor) -> torch.LongTensor:
    # prepend a zero column so the first sequence length is measured from 0;
    # match dtype and device so this also works on GPU tensors
    diffs = torch.diff(
        cu_seqlens,
        dim=1,
        prepend=torch.zeros(cu_seqlens.shape[0], 1, dtype=cu_seqlens.dtype, device=cu_seqlens.device),
    )
    valid_lengths = diffs != 0
    counts = valid_lengths.sum(dim=1).long()
    return counts
# Example usage with a tensor of shape [batch_size, seq_len, hidden_dim]:
# example_tensor = torch.randn(batch_size, seq_len, hidden_dim)
# cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(batch["position_ids"])
# split_padded_tensor = split_and_pad_packed(example_tensor, cu_seqlens, max_seqlen)
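A hedged, runnable demo of split_and_pad_packed on synthetic data (assumes the helpers above are in scope):

import torch

# two packed rows of 8 tokens each; cu_seqlens rows are right-padded by
# repeating the final offset
packed = torch.arange(2 * 8, dtype=torch.float32).reshape(2, 8, 1)  # [batch, packed_len, dim]
cu_seqlens = torch.tensor([[0, 3, 8, 8], [0, 5, 8, 8]])
unpacked = split_and_pad_packed(packed, cu_seqlens, max_seqlen=5)
print(unpacked.shape)  # torch.Size([4, 5, 1]): four sequences, each padded to length 5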

View File

@@ -237,17 +237,11 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
main_process_only=True,
)
else:
if cfg.flash_attention:
batch_size = 1
batch_max_len = cfg.micro_batch_size * cfg.sequence_len
else:
batch_size = cfg.micro_batch_size
batch_max_len = cfg.sequence_len
sampler = MultipackBatchSampler(
sampler=RandomSampler(train_dataset),
batch_size=batch_size,
batch_size=cfg.micro_batch_size,
drop_last=True,
batch_max_len=batch_max_len,
batch_max_len=cfg.micro_batch_size * cfg.sequence_len,
lengths=get_dataset_lengths(train_dataset),
)
@@ -255,7 +249,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
train_dataset.remove_columns(["length"]),
batch_sampler=sampler,
)
data_loader_len = len(data_loader) // batch_size
data_loader_len = len(data_loader)
actual_eff = sampler.efficiency()
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
# FIXME: is there a bug here somewhere? the total num steps depends
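A hedged arithmetic check of the two sampler configurations above, with illustrative values:

micro_batch_size, sequence_len = 2, 4096
flash = {"batch_size": 1, "batch_max_len": micro_batch_size * sequence_len}  # one 8192-token row per step
no_flash = {"batch_size": micro_batch_size, "batch_max_len": sequence_len}   # two 4096-token rows per step
# the effective tokens per step match; only the row layout differs
assert flash["batch_size"] * flash["batch_max_len"] == no_flash["batch_size"] * no_flash["batch_max_len"]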

View File

@@ -1,114 +0,0 @@
"""
E2E tests for multipack fft llama using 4d attention masks
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import require_torch_2_1_1, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class Test4dMultipackLlama(unittest.TestCase):
"""
Test case for Llama models using 4d attention with multipack
"""
@require_torch_2_1_1
@with_temp_dir
def test_sdp_lora_packing(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"flash_attention": False,
"sdp_attention": True,
"sample_packing": True,
"pad_to_sequence_len": True,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"fp16": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_torch_lora_packing(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"flash_attention": False,
"sdp_attention": False,
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"fp16": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

View File

@@ -33,7 +33,6 @@ class TestFusedLlama(unittest.TestCase):
{
"base_model": "JackFram/llama-68m",
"flash_attention": True,
"pad_to_sequence_len": True,
"flash_attn_fuse_qkv": True,
"flash_attn_fuse_mlp": True,
"sample_packing": True,

View File

@@ -7,6 +7,8 @@ import os
import unittest
from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
@@ -61,7 +63,6 @@ class TestMistral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
}
)
normalize_config(cfg)
@@ -102,9 +103,12 @@ class TestMistral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -1,68 +0,0 @@
"""
E2E tests for relora llama
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestReLoraLlama(unittest.TestCase):
"""
Test case for Llama models using LoRA
"""
@with_temp_dir
def test_relora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_modules": ["q_proj", "v_proj"],
"relora_steps": 25,
"relora_warmup_steps": 5,
"relora_anneal_steps": 5,
"relora_cpu_offload": True,
"val_set_size": 0.0,
"special_tokens": {},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"warmup_steps": 15,
"num_epochs": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()

View File

@@ -4,9 +4,7 @@ helper utils for tests
import os
import shutil
import tempfile
import unittest
from functools import wraps
from importlib.metadata import version
from pathlib import Path
@@ -33,15 +31,3 @@ def most_recent_subdir(path):
subdir = max(subdirectories, key=os.path.getctime)
return subdir
def require_torch_2_1_1(test_case):
"""
Decorator marking a test that requires torch >= 2.1.1
"""
def is_min_2_1_1():
torch_version = version("torch")
return torch_version >= "2.1.1"
return unittest.skipUnless(is_min_2_1_1(), "test torch 2.1.1")(test_case)

View File

@@ -30,20 +30,6 @@ class TestMonkeyPatchUtils(unittest.TestCase):
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
)
def test_get_cu_seqlens_from_pos_ids_2d(self):
position_ids = torch.tensor(
[
[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0],
[0, 1, 2, 3, 4, 0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 0],
]
)
target_res = torch.tensor(
[[0, 4, 7, 12, 14, 16], [0, 5, 8, 15, 16, 16]], dtype=torch.int32
)
self.assertTrue(
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
)
def test_get_max_seqlen_in_batch(self):
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
target_res = torch.tensor([4, 3, 5, 2], dtype=torch.int32)

View File

@@ -1,99 +0,0 @@
"""Module for testing streaming dataset sequence packing"""
import pytest
from datasets import concatenate_datasets, load_dataset
from torch.utils.data import DataLoader, RandomSampler
from transformers import AutoTokenizer
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.completion import load
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer.pad_token = "</s>"
return tokenizer
@pytest.fixture(name="max_seq_length")
def fixture_max_seq_length():
return 4096
class TestBatchedSamplerPacking:
"""
Test class for packing streaming dataset sequences
"""
@pytest.mark.parametrize(
"batch_size, num_workers",
[
(1, 0),
(2, 0),
(1, 2),
(2, 2),
],
)
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
dataset = load_dataset(
"Trelis/tiny-shakespeare",
split="train",
)
cfg = DictDefault(
{
"train_on_inputs": True,
"sequence_len": max_seq_length,
}
)
ds_cfg = DictDefault(
{
"field": "Text",
}
)
completion_strategy = load(tokenizer, cfg, ds_cfg)
dataset_wrapper = TokenizedPromptDataset(
completion_strategy,
dataset,
)
train_dataset = concatenate_datasets([dataset_wrapper])
batch_sampler = MultipackBatchSampler(
sampler=RandomSampler(train_dataset),
batch_size=batch_size,
drop_last=True,
batch_max_len=max_seq_length,
lengths=get_dataset_lengths(train_dataset),
)
loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
collate_fn=V2BatchSamplerDataCollatorForSeq2Seq( # pylint: disable=unexpected-keyword-arg
tokenizer=tokenizer,
padding=True,
pad_to_multiple_of=max_seq_length,
return_tensors="pt",
),
num_workers=num_workers,
)
inputs = next(iter(loader))
assert inputs["input_ids"].shape == (batch_size, max_seq_length)
assert inputs["labels"].shape == (batch_size, max_seq_length)
assert inputs["attention_mask"].shape == (batch_size, max_seq_length)
assert inputs["input_ids"].tolist()[0][0] == 2
assert inputs["labels"].tolist()[0][0] == -100
assert inputs["attention_mask"].tolist()[0][0] == 0
assert inputs["attention_mask"].tolist()[0][-1] > 1
if batch_size >= 2:
assert inputs["input_ids"].tolist()[1][0] == 2
assert inputs["labels"].tolist()[1][0] == -100
assert inputs["attention_mask"].tolist()[1][0] == 0
assert inputs["attention_mask"].tolist()[1][-1] > 1

View File

@@ -1,17 +1,17 @@
"""Module for testing streaming dataset sequence packing"""
import functools
import unittest
from functools import partial
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data import encode_packed_pretraining
class TestPretrainingPacking(unittest.TestCase):
class TestPacking(unittest.TestCase):
"""
Test class for packing streaming dataset sequences
"""
@@ -20,6 +20,8 @@ class TestPretrainingPacking(unittest.TestCase):
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.pad_token = "</s>"
self.max_seq_length = 2048
self.batch_size = 2
def test_packing_stream_dataset(self):
# pylint: disable=duplicate-code
@@ -29,43 +31,30 @@ class TestPretrainingPacking(unittest.TestCase):
streaming=True,
)["train"]
cfg = DictDefault(
{
"pretraining_dataset": [
{
"path": "c4",
"name": "en",
"type": "pretrain",
}
],
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"micro_batch_size": 2,
}
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
padding=True,
pad_to_multiple_of=self.max_seq_length,
)
ds_wrapper_partial = functools.partial(
get_dataset_wrapper,
cfg.pretraining_dataset[0],
encode = partial(
encode_packed_pretraining,
self.tokenizer,
cfg,
cfg.pretraining_dataset[0]["type"] or "pretrain",
collate_fn,
max_seq_length=self.max_seq_length,
batch_size=self.batch_size,
)
original_bsz = cfg.micro_batch_size
train_dataset = wrap_pretraining_dataset(
dataset,
self.tokenizer,
cfg,
ds_wrapper_partial,
max_tokens=cfg.sequence_len,
batch_size=cfg.micro_batch_size,
seed=cfg.seed or 42,
dataset = dataset.map(
encode,
batched=True,
input_columns="text",
remove_columns=dataset.features.keys(),
)
trainer_loader = DataLoader(
train_dataset,
dataset,
batch_size=1,
collate_fn=None,
drop_last=True,
@@ -75,16 +64,16 @@ class TestPretrainingPacking(unittest.TestCase):
if idx > 10:
break
assert data["input_ids"].shape == torch.Size(
[1, original_bsz * cfg.sequence_len]
[1, self.batch_size * self.max_seq_length]
)
assert data["position_ids"].shape == torch.Size(
[1, original_bsz * cfg.sequence_len]
[1, self.batch_size * self.max_seq_length]
)
assert data["labels"].shape == torch.Size(
[1, original_bsz * cfg.sequence_len]
[1, self.batch_size * self.max_seq_length]
)
assert data["attention_mask"].shape == torch.Size(
[1, original_bsz * cfg.sequence_len]
[1, self.batch_size * self.max_seq_length]
)
idx += 1

View File

@@ -67,21 +67,6 @@ class TestTokenizers(unittest.TestCase):
)
load_tokenizer(cfg)
def test_add_additional_special_tokens(self):
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"special_tokens": {"additional_special_tokens": ["<|im_start|>"]},
}
)
tokenizer = load_tokenizer(cfg)
self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404])
self.assertEqual(len(tokenizer), 32001)
# ensure reloading the tokenizer again from cfg results in same vocab length
tokenizer = load_tokenizer(cfg)
self.assertEqual(len(tokenizer), 32001)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,98 +0,0 @@
"""
This module is used to launch Axolotl with user defined configurations.
"""
import gradio as gr
import yaml
def config(
base_model,
dataset,
dataset_type,
learn_rate,
gradient_accumulation_steps,
micro_batch_size,
seq_length,
num_epochs,
output_dir,
val_size,
):
"""
This function generates a configuration dictionary and saves it as a yaml file.
"""
config_dict = {
"base_model": base_model,
"datasets": [{"path": dataset, "type": dataset_type}],
"learning_rate": learn_rate,
"gradient_accumulation_steps": gradient_accumulation_steps,
"micro_batch_size": micro_batch_size,
"sequence_len": seq_length,
"num_epochs": num_epochs,
"output_dir": output_dir,
"val_set_size": val_size,
}
with open("config.yml", "w", encoding="utf-8") as file:
yaml.dump(config_dict, file)
print(config_dict)
return yaml.dump(config_dict)
with gr.Blocks(title="Axolotl Launcher") as demo:
gr.Markdown(
"""
# Axolotl Launcher
Fill out the required fields below to create a training run.
"""
)
with gr.Row():
base_model_name = gr.Textbox(
"TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", label="Base model"
)
mode = gr.Radio(
choices=["Full finetune", "QLoRA", "LoRA"],
label="Training mode",
info="FFT = 16 bit, Qlora = 4 bit, Lora = 8 bit",
)
with gr.Row():
dataset_path = gr.Textbox("mhenrichsen/alpaca_2k_test", label="Dataset")
dataset_type_name = gr.Dropdown(
choices=["alpaca", "sharegpt"], label="Dataset type", value="alpaca"
)
with gr.Accordion("Hyperparameters", open=False):
gr.Markdown("Choose hyperparameters")
with gr.Row():
learning_rate = gr.Number(0.000001, label="Learning rate")
gradient_accumulation_steps_count = gr.Number(
1, label="Gradient accumulation steps"
)
val_set_size_count = gr.Number(0, label="Validation size")
with gr.Row():
micro_batch_size_count = gr.Number(1, label="Micro batch size")
sequence_length = gr.Number(1024, label="Sequence length")
num_epochs_count = gr.Number(1, label="Epochs")
output_dir_path = gr.Textbox("./model-out", label="Output directory")
create_config = gr.Button("Create config")
output = gr.TextArea(label="Generated config")
create_config.click(
config,
inputs=[
base_model_name,
dataset_path,
dataset_type_name,
learning_rate,
gradient_accumulation_steps_count,
micro_batch_size_count,
sequence_length,
num_epochs_count,
output_dir_path,
val_set_size_count,
],
outputs=output,
)
demo.launch(debug=True, server_name="0.0.0.0", server_port=7860)