Compare commits

..

3 Commits

Author SHA1 Message Date
Wing Lian
d46d7dfe30 wip 2024-02-01 00:28:16 -05:00
Wing Lian
047d9e1d5b helper utils 2024-01-31 12:49:29 -05:00
Wing Lian
88a0c05d2c wip 2024-01-31 12:07:39 -05:00
58 changed files with 771 additions and 1490 deletions

View File

@@ -7,7 +7,7 @@ jobs:
build-base:
if: github.repository_owner == 'OpenAccess-AI-Collective'
# this job needs to be run on self-hosted GPU runners...
runs-on: axolotl-gpu-runner
runs-on: self-hosted
strategy:
fail-fast: false
matrix:

View File

@@ -9,6 +9,7 @@ on:
jobs:
build-axolotl:
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
# this job needs to be run on self-hosted GPU runners...
strategy:
fail-fast: false
matrix:
@@ -34,7 +35,7 @@ jobs:
python_version: "3.11"
pytorch: 2.1.2
axolotl_extras:
runs-on: axolotl-gpu-runner
runs-on: [self-hosted, gpu, docker]
steps:
- name: Checkout
uses: actions/checkout@v4
@@ -55,16 +56,27 @@ jobs:
uses: docker/build-push-action@v5
with:
context: .
load: true
build-args: |
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
CUDA=${{ matrix.cuda }}
PYTORCH_VERSION=${{ matrix.pytorch }}
file: ./docker/Dockerfile
push: ${{ github.event_name != 'pull_request' }}
tags: |
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
labels: ${{ steps.metadata.outputs.labels }}
- name: Unit Tests
run: |
docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
- name: Push to Docker Hub
if: github.event_name != 'pull_request'
run: |
docker push ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
latest_tag=${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
if [ -n "$latest_tag" ]; then
docker push "$latest_tag"
fi
build-axolotl-runpod:
needs: build-axolotl
@@ -94,7 +106,7 @@ jobs:
python_version: "3.11"
pytorch: 2.1.2
axolotl_extras:
runs-on: axolotl-gpu-runner
runs-on: [self-hosted, gpu, docker]
steps:
- name: Checkout
uses: actions/checkout@v4

View File

@@ -32,9 +32,6 @@ ignore_missing_imports = True
[mypy-bitsandbytes]
ignore_missing_imports = True
[mypy-requests]
ignore_missing_imports = True
[mypy-datasets]
ignore_missing_imports = True

View File

@@ -25,8 +25,8 @@ Features:
- [Installation](#installation)
- [Docker](#docker)
- [Conda/Pip venv](#condapip-venv)
- [Cloud GPU](#cloud-gpu) - Latitude.sh, RunPod
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
- [Cloud GPU](#cloud-gpu) - Runpod, Latitude
- [LambdaLabs](#lambdalabs)
- [Windows](#windows)
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
- [Dataset](#dataset)
@@ -37,9 +37,6 @@ Features:
- [Inference](#inference)
- [Merge LORA to Base](#merge-lora-to-base)
- [Special Tokens](#special-tokens)
- Advanced Topics
- [Multipack](./docs/multipack.md)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
- [RLHF & DPO](./docs/rlhf.md)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
- [Common Errors](#common-errors-)
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
- [Debugging Axolotl](#debugging-axolotl)
@@ -121,10 +118,6 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
# gradio
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
--lora_model_dir="./lora-out" --gradio
# remote yaml files - the yaml config can be hosted on a public URL
# Note: the yaml config must directly link to the **raw** yaml
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/examples/openllama-3b/lora.yml
```
## Installation
@@ -186,13 +179,9 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
- on RunPod use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
#### Bare Metal Cloud GPU
##### LambdaLabs
#### LambdaLabs
<details>
<summary>Click to Expand</summary>
@@ -472,12 +461,6 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
dataset:
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
...
# Loading Data From a Public URL
# - The file format is `json` (which includes `jsonl`) by default. For different formats, adjust the `ds_type` option accordingly.
dataset:
- path: https://some.url.com/yourdata.jsonl # The URL should be a direct link to the file you wish to load. URLs must use HTTPS protocol, not HTTP.
ds_type: json # this is the default, see other options below.
```
- loading
@@ -990,9 +973,6 @@ Run
accelerate launch -m axolotl.cli.train your_config.yml
```
> [!TIP]
> You can also reference a config file that is hosted on a public URL, for example `accelerate launch -m axolotl.cli.train https://yourdomain.com/your_config.yml`
#### Preprocess dataset
You can optionally pre-tokenize dataset with the following before finetuning.
@@ -1172,11 +1152,9 @@ Having misalignment between your prompts during training and inference can cause
See [this debugging guide](docs/debugging.md) for tips on debugging Axolotl, along with an example configuration for debugging with VSCode.
## Need help? 🙋
## Need help? 🙋♂️
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we our community members can help you.
Need dedicated support? Please contact us at [✉wing@openaccessaicollective.org](mailto:wing@openaccessaicollective.org) for dedicated support options.
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
## Badge ❤🏷️
@@ -1217,12 +1195,6 @@ pre-commit install
pytest tests/
```
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
<a href="https://github.com/openaccess-ai-collective/axolotl/graphs/contributors">
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
</a>
## Sponsors 🤝❤
OpenAccess AI Collective is run by volunteer contributors such as [winglian](https://github.com/winglian),

View File

@@ -11,7 +11,6 @@ EXPOSE 8888
EXPOSE 22
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
COPY scripts/motd /etc/motd
RUN pip install jupyterlab notebook ipywidgets && \
jupyter lab clean
@@ -19,7 +18,6 @@ RUN apt install --yes --no-install-recommends openssh-server tmux && \
mkdir -p ~/.ssh && \
chmod 700 ~/.ssh && \
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
chmod +x /root/cloud-entrypoint.sh

Binary file not shown.

Before

Width:  |  Height:  |  Size: 239 KiB

View File

@@ -1,11 +1,4 @@
# Multipack (Sample Packing)
## Visualization of Multipack with Flash Attention
Because Flash Attention simply drops the attention mask, we do not need to
construct a 4d attention mask. We only need to concatenate the sequences into
a single batch and let flash attention know where each new sequence begins.
# Multipack
4k context, bsz =4,
each character represents 256 tokens
@@ -56,18 +49,3 @@ w packing ( note it's the same effective number of tokens per step, but a true b
E E E E F F F F F G G G H H H H
I I I J J J J K K K K K L L L X ]]
```
cu_seqlens:
[[ 0, 11, 17, 24, 28, 36, 41 44, 48, 51, 55, 60, 64]]
## Multipack without Flash Attention
Multipack can still be achieved without Flash attention, but with lower packing
efficiency as we are not able to join multiple batches into a single batch due to
context length limits without flash attention. We can use either Pytorch's Scaled
Dot Product Attention implementation or native Pytorch attention implementation
along with [4d attention masks](https://github.com/huggingface/transformers/pull/27539)
to pack sequences together and avoid cross attention.
<img src="./images/4d-mask.png" alt="axolotl" width="800">

View File

@@ -12,8 +12,8 @@ feedback. Various methods include, but not limited to:
### RLHF using Axolotl
>[!IMPORTANT]
>This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
[!IMPORTANT]
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
The various RL training methods are implemented in trl and wrapped via axolotl. Below are various examples with how you can use various preference datasets to train models that use ChatML

View File

@@ -43,7 +43,6 @@
},
"outputs": [],
"source": [
"!pip install torch==\"2.1.2\"\n",
"!pip install -e git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl\n",
"!pip install flash-attn==\"2.5.0\"\n",
"!pip install deepspeed==\"0.13.1\""

View File

@@ -1,65 +0,0 @@
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
model_type: LlamaForCausalLM
tokenizer_type: LlamaTokenizer
is_llama_derived_model: true
load_in_8bit: true
load_in_4bit: false
strict: false
datasets:
- path: mhenrichsen/alpaca_2k_test
type: alpaca
dataset_prepared_path:
val_set_size: 0
output_dir: ./lora-out
sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true
eval_sample_packing: false
adapter: lora
lora_model_dir:
lora_r: 32
lora_alpha: 16
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:
gradient_accumulation_steps: 4
micro_batch_size: 2
num_epochs: 4
optimizer: adamw_torch
lr_scheduler: cosine
learning_rate: 0.0002
train_on_inputs: false
group_by_length: false
bf16: auto
fp16: false
tf32: true
gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint:
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: false
warmup_steps: 10
evals_per_epoch: 0
saves_per_epoch: 1
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:

View File

@@ -10,9 +10,8 @@ strict: false
max_steps: 200
pretraining_dataset:
- path: c4
name: en
type: pretrain
path: c4
name: en
dataset_prepared_path:
val_set_size: 0.0
output_dir: ./model-out

View File

@@ -1,4 +1,3 @@
pre-commit
black
mypy
types-requests

View File

@@ -1,7 +1,7 @@
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
packaging==23.2
peft @ git+https://github.com/huggingface/peft.git
transformers @ git+https://github.com/huggingface/transformers.git@bebeeee01275c32fccec3fa36d8b148d3813a7dc
transformers==4.37.0
tokenizers==0.15.0
bitsandbytes>=0.41.1
accelerate==0.26.1
@@ -9,7 +9,6 @@ deepspeed>=0.13.1
addict
fire
PyYAML>=6.0
requests
datasets>=2.15.0
flash-attn==2.3.3
sentencepiece

View File

@@ -1,17 +0,0 @@
dP dP dP
88 88 88
.d8888b. dP. .dP .d8888b. 88 .d8888b. d8888P 88
88' `88 `8bd8' 88' `88 88 88' `88 88 88
88. .88 .d88b. 88. .88 88 88. .88 88 88
`88888P8 dP' `dP `88888P' dP `88888P' dP dP
Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace and the axolotl directory ie empty, run the following commands:
```
cd /workspace
rm -rf /workspace/axolotl
git clone https://github.com/OpenAccess-AI-Collective/axolotl.git
cd axolotl
pip install --no-deps -e .
```

View File

@@ -1,7 +1,5 @@
"""setup.py for axolotl"""
import platform
import re
from importlib.metadata import PackageNotFoundError, version
from setuptools import find_packages, setup
@@ -28,25 +26,11 @@ def parse_requirements():
_install_requires.append(line)
try:
if "Darwin" in platform.system():
torch_version = version("torch")
_install_requires.append(f"torch=={torch_version}")
if torch_version.startswith("2.1."):
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
else:
torch_version = version("torch")
_install_requires.append(f"torch=={torch_version}")
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
if version_match:
major, minor, patch = version_match.groups()
major, minor = int(major), int(minor)
patch = (
int(patch) if patch is not None else 0
) # Default patch to 0 if not present
else:
raise ValueError("Invalid version format")
if (major, minor) >= (2, 1):
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
_install_requires.append("xformers>=0.0.23")
_install_requires.append("xformers>=0.0.23")
except PackageNotFoundError:
pass

View File

@@ -1,20 +1,16 @@
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
import importlib
import json
import logging
import math
import os
import random
import sys
import tempfile
from pathlib import Path
from threading import Thread
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse
import gradio as gr
import requests
import torch
import yaml
@@ -63,52 +59,6 @@ def print_axolotl_text_art(suffix=None):
print(ascii_art)
def check_remote_config(config: Union[str, Path]):
# Check if the config is a valid HTTPS URL to a .yml or .yaml file
if not (isinstance(config, str) and config.startswith("https://")):
return config # Return the original value if it's not a valid URL
filename = os.path.basename(urlparse(config).path)
temp_dir = tempfile.mkdtemp()
try:
response = requests.get(config, timeout=30)
response.raise_for_status() # Check for HTTP errors
content = response.content
try:
# Try parsing as JSON first to catch cases where JSON content is mistakenly considered YAML
json.loads(content)
# Log a warning but do not raise an error; JSON is technically valid YAML - this can happen when you forget to point to a raw github link
LOG.warning(
f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended."
)
except json.JSONDecodeError:
# If it's not valid JSON, verify it's valid YAML
try:
yaml.safe_load(content)
except yaml.YAMLError as err:
raise ValueError(
f"Failed to parse the content at {config} as YAML: {err}"
) from err
# Write the content to a file if it's valid YAML (or JSON treated as YAML)
output_path = Path(temp_dir) / filename
with open(output_path, "wb") as file:
file.write(content)
LOG.info(
f"Using the following config obtained from {config}:\n\n{content.decode('utf-8')}\n"
)
return output_path
except requests.RequestException as err:
# This catches all requests-related exceptions including HTTPError
raise RuntimeError(f"Failed to download {config}: {err}") from err
except Exception as err:
# Catch-all for any other exceptions
raise err
def get_multi_line_input() -> Optional[str]:
print("Give me an instruction (Ctrl + D to submit): ")
instruction = ""
@@ -320,10 +270,9 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
return not any(el in list2 for el in list1)
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
config = check_remote_config(config)
def load_cfg(config: Path = Path("examples/"), **kwargs):
if Path(config).is_dir():
config = choose_config(Path(config))
config = choose_config(config)
# load the config from the yaml file
with open(config, encoding="utf-8") as file:

View File

@@ -3,7 +3,6 @@ CLI to run training on a model
"""
import logging
from pathlib import Path
from typing import Union
import fire
import transformers
@@ -24,7 +23,7 @@ from axolotl.prompt_strategies.sharegpt import register_chatml_template
LOG = logging.getLogger("axolotl.cli.preprocess")
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)

View File

@@ -3,7 +3,6 @@ CLI to shard a trained model into 10GiB chunks
"""
import logging
from pathlib import Path
from typing import Union
import fire
import transformers
@@ -26,7 +25,7 @@ def shard(
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
print_axolotl_text_art()
parsed_cfg = load_cfg(config, **kwargs)

View File

@@ -3,12 +3,11 @@ CLI to run training on a model
"""
import logging
from pathlib import Path
from typing import Tuple, Union
from typing import Tuple
import fire
from transformers.hf_argparser import HfArgumentParser
from transformers.modeling_utils import PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
import transformers
from transformers import PreTrainedModel, PreTrainedTokenizer
from axolotl.cli import (
check_accelerate_default_config,
@@ -25,10 +24,10 @@ from axolotl.train import train
LOG = logging.getLogger("axolotl.cli.train")
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
def do_cli(config: Path = Path("examples/"), **kwargs):
# pylint: disable=duplicate-code
parsed_cfg = load_cfg(config, **kwargs)
parser = HfArgumentParser((TrainerCliArgs))
parser = transformers.HfArgumentParser((TrainerCliArgs))
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
return_remaining_strings=True
)

View File

@@ -6,7 +6,6 @@ import logging
from dataclasses import dataclass, field
from typing import Optional
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
from axolotl.logging_config import configure_logging
from axolotl.utils.dict import DictDefault
from axolotl.utils.models import load_model, load_tokenizer

View File

@@ -8,15 +8,17 @@ import importlib
import logging
import math
import sys
import typing
from abc import abstractmethod
from dataclasses import dataclass, field
from functools import wraps
from functools import wraps, partial
from pathlib import Path
from typing import List, Optional, Type, Union
from typing import Dict, List, Optional, Tuple, Type, Union
import torch
import transformers
from datasets import Dataset
from torch import nn
from torch.optim.lr_scheduler import OneCycleLR
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
from transformers import (
@@ -28,8 +30,8 @@ from transformers import (
from transformers.trainer_utils import seed_worker
from trl import DPOTrainer
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
from axolotl.utils.callbacks import (
EvalFirstStepCallback,
GPUStatsCallback,
@@ -51,12 +53,20 @@ from axolotl.utils.schedulers import (
get_cosine_schedule_with_min_lr,
get_cosine_schedule_with_quadratic_warmup,
)
from axolotl.utils.tensors import keep_unpacked_data, split_and_pad_packed
try:
import torch._dynamo # pylint: disable=ungrouped-imports
except ImportError:
pass
if typing.TYPE_CHECKING:
# hacky, but recommended per https://github.com/python/mypy/issues/5837
_MixinTrainerBase = Trainer
else:
_MixinTrainerBase = object
LOG = logging.getLogger("axolotl.core.trainer_builder")
@@ -99,10 +109,6 @@ class AxolotlTrainingArguments(TrainingArguments):
default=False,
metadata={"help": "Use sample packing for efficient training."},
)
multipack_real_batches: bool = field(
default=False,
metadata={"help": "Use real batches for efficient training."},
)
eval_sample_packing: Optional[bool] = field(
default=None,
metadata={"help": "Use sample packing for efficient evals."},
@@ -127,10 +133,6 @@ class AxolotlTrainingArguments(TrainingArguments):
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
relora_anneal_steps: Optional[int] = field(
default=None,
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
)
bench_split: Optional[str] = field(
default="eval", metadata={"help": "The benchmark split to run on"}
)
@@ -162,7 +164,142 @@ class AxolotlTrainingArguments(TrainingArguments):
)
class AxolotlTrainer(Trainer):
class AxolotlMultiPackTrainerMixin(_MixinTrainerBase): # type: ignore
"""Trainer Mixin class for dataloaders and samplers"""
args = None # type: AxolotlTrainingArguments
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining:
return MultipackBatchSampler(
RandomSampler(self.train_dataset),
self.args.train_batch_size,
drop_last=True,
batch_max_len=self._train_batch_size * self.args.max_seq_length,
lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
return super()._get_train_sampler()
def get_train_dataloader(self) -> DataLoader:
if self.args.sample_packing and not self.args.pretraining:
train_dataset = self.train_dataset
if "length" in train_dataset.features.keys():
train_dataset = train_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
sampler = self._get_train_sampler()
if isinstance(sampler, BatchSampler):
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(train_dataset, **dataloader_params)
)
return super().get_train_dataloader()
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
return MultipackBatchSampler(
SequentialSampler(eval_dataset),
self.args.per_device_eval_batch_size,
drop_last=True,
batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
lengths=get_dataset_lengths(eval_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
return super()._get_eval_sampler(eval_dataset)
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
if self.args.sample_packing and self.args.eval_sample_packing is False:
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
)
dataloader = super().get_eval_dataloader(eval_dataset)
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.train_data_collator
)
return dataloader
if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset)
eval_dataset = eval_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
if isinstance(eval_sampler, BatchSampler):
dataloader_params["batch_sampler"] = eval_sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = eval_sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(eval_dataset, **dataloader_params)
)
return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
def get_bench_dataloader(
self,
bench_dataset: Dataset,
) -> DataLoader:
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": self.bench_data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
class AxolotlTrainer(AxolotlMultiPackTrainerMixin, Trainer):
"""
Extend the base Trainer for axolotl helpers
"""
@@ -236,151 +373,6 @@ class AxolotlTrainer(Trainer):
return self.lr_scheduler
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and not self.args.pretraining:
if self.args.multipack_real_batches:
batch_size = self.args.per_device_train_batch_size
batch_max_len = self.args.max_seq_length
else:
batch_size = 1
batch_max_len = (
self.args.per_device_train_batch_size * self.args.max_seq_length
)
return MultipackBatchSampler(
RandomSampler(self.train_dataset),
batch_size=batch_size,
drop_last=True,
batch_max_len=batch_max_len,
lengths=get_dataset_lengths(self.train_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
return super()._get_train_sampler()
def _get_eval_sampler(
self, eval_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.sample_packing and self.args.eval_sample_packing is not False:
if self.args.multipack_real_batches:
batch_size = self.args.per_device_eval_batch_size
batch_max_len = self.args.max_seq_length
else:
batch_size = 1
batch_max_len = (
self.args.per_device_eval_batch_size * self.args.max_seq_length
)
return MultipackBatchSampler(
SequentialSampler(eval_dataset),
batch_size=batch_size,
drop_last=True,
batch_max_len=batch_max_len,
lengths=get_dataset_lengths(eval_dataset),
packing_efficiency_estimate=self.args.sample_packing_efficiency,
)
return super()._get_eval_sampler(eval_dataset)
def get_train_dataloader(self) -> DataLoader:
if self.args.sample_packing and not self.args.pretraining:
train_dataset = self.train_dataset
if "length" in train_dataset.features.keys():
train_dataset = train_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self._train_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
sampler = self._get_train_sampler()
if isinstance(sampler, BatchSampler):
dataloader_params["batch_sampler"] = sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
dataloader_params["worker_init_fn"] = seed_worker
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(train_dataset, **dataloader_params)
)
return super().get_train_dataloader()
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
if self.args.sample_packing and self.args.eval_sample_packing is False:
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.eval_data_collator
)
dataloader = super().get_eval_dataloader(eval_dataset)
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
self.train_data_collator
)
return dataloader
if self.args.sample_packing and self.args.eval_sample_packing is not False:
eval_dataset = (
eval_dataset if eval_dataset is not None else self.eval_dataset
)
eval_sampler = self._get_eval_sampler(eval_dataset)
eval_dataset = eval_dataset.remove_columns(["length"])
data_collator = self.data_collator
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params[
"prefetch_factor"
] = self.args.dataloader_prefetch_factor
if isinstance(eval_sampler, BatchSampler):
dataloader_params["batch_sampler"] = eval_sampler
del dataloader_params["batch_size"]
else:
dataloader_params["sampler"] = eval_sampler
dataloader_params["drop_last"] = self.args.dataloader_drop_last
self.accelerator.even_batches = False
return self.accelerator.prepare_data_loader(
DataLoader(eval_dataset, **dataloader_params)
)
return super().get_eval_dataloader(eval_dataset)
def _get_bench_sampler(
self, bench_dataset: Dataset
) -> Optional[torch.utils.data.Sampler]:
if self.args.world_size <= 1:
return SequentialSampler(bench_dataset)
return None
def get_bench_dataloader(
self,
bench_dataset: Dataset,
) -> DataLoader:
dataloader_params = {
"batch_size": self.args.eval_batch_size,
"collate_fn": self.bench_data_collator,
"num_workers": self.args.dataloader_num_workers,
"pin_memory": self.args.dataloader_pin_memory,
}
if self.args.dataloader_prefetch_factor:
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
dataloader_params["drop_last"] = self.args.dataloader_drop_last
return DataLoader(bench_dataset, **dataloader_params)
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
def compute_loss(self, model, inputs, return_outputs=False):
# use one's weighted cross entropy loss calc
# if self.args.sample_packing:
@@ -483,14 +475,10 @@ class ReLoRATrainer(AxolotlTrainer):
warmup_steps = (
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
)
anneal_steps = (
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
)
self.lr_scheduler = ReLoRAScheduler(
optimizer,
lr_scheduler,
self.args.relora_steps,
anneal_steps,
warmup_steps,
)
else:
@@ -499,7 +487,7 @@ class ReLoRATrainer(AxolotlTrainer):
return self.lr_scheduler
class AxolotlDPOTrainer(DPOTrainer):
class AxolotlDPOTrainer(AxolotlMultiPackTrainerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
@@ -516,6 +504,59 @@ class AxolotlDPOTrainer(DPOTrainer):
return super().push_to_hub(*args, **kwargs)
def tokenize_row(self, feature, *args, **kwargs) -> Dict:
# check if dataset is already tokenized
if not self.is_encoder_decoder:
keys = [
"chosen_input_ids",
"chosen_attention_mask",
"chosen_labels",
"rejected_input_ids",
"rejected_attention_mask",
"rejected_labels",
]
if all(k in feature.keys() for k in keys):
return feature
else:
keys = [
"chosen_labels",
"rejected_labels",
"prompt_input_ids",
"prompt_attention_mask",
]
if all(k in feature.keys() for k in keys):
return feature
return super().tokenize_row(feature, *args, **kwargs)
def concatenated_forward(
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
) -> Tuple[
torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor
]:
all_logits = model(
batch["input_ids"],
attention_mask=batch["attention_mask"],
position_ids=batch["position_ids"],
).logits
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(batch["position_ids"])
logits_keep_fn = partial(keep_unpacked_data, pad_val=None, pairs=True)
unpacked_logits = split_and_pad_packed(all_logits, cu_seqlens, max_seqlen, logits_keep_fn)
labels_keep_fn = partial(keep_unpacked_data, pad_val=-100, pairs=True)
unpacked_labels = split_and_pad_packed(batch["labels"], cu_seqlens, max_seqlen, labels_keep_fn)
unpacked_logps = self.get_batch_logps(
unpacked_logits,
unpacked_labels,
average_log_prob=self.loss_type == "ipo",
is_encoder_decoder=self.is_encoder_decoder,
label_pad_token_id=self.label_pad_token_id,
)
chosen_logps = unpacked_logps[::2]
rejected_logps = unpacked_logps[1::2]
chosen_logits = unpacked_logits[::2]
rejected_logits = unpacked_logits[1::2]
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
class TrainerBuilderBase(abc.ABC):
"""
@@ -889,9 +930,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
training_arguments_kwargs["sample_packing"] = (
self.cfg.sample_packing if self.cfg.sample_packing else False
)
training_arguments_kwargs["multipack_real_batches"] = (
self.cfg.flash_attention is not True
)
training_arguments_kwargs["eval_sample_packing"] = (
self.cfg.sample_packing
if self.cfg.eval_sample_packing is not False
@@ -902,7 +940,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
] = self.cfg.micro_batch_size
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
training_arguments_kwargs["relora_anneal_steps"] = self.cfg.relora_anneal_steps
training_arguments_kwargs = self.hook_pre_create_training_args(
training_arguments_kwargs
)
@@ -995,12 +1032,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
]
]
if use_batch_sampler_collator:
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
collator = V2BatchSamplerDataCollatorForSeq2Seq
elif (
self.cfg.model_config_type in ["llama"]
and self.cfg.flash_attention is not True
):
if self.cfg.model_config_type in ["mixtral", "qwen2", "falcon", "phi"]:
collator = V2BatchSamplerDataCollatorForSeq2Seq
else:
collator = BatchSamplerDataCollatorForSeq2Seq
@@ -1097,21 +1129,13 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
"use_reentrant": False
}
# set save_strategy and save_steps
if self.cfg.save_steps:
training_args_kwargs["save_strategy"] = "steps"
training_args_kwargs["save_steps"] = self.cfg.save_steps
elif self.cfg.save_strategy:
training_args_kwargs["save_strategy"] = self.cfg.save_strategy
else:
# default to saving each epoch if not defined
training_args_kwargs["save_strategy"] = "epoch"
training_args = TrainingArguments(
per_device_train_batch_size=self.cfg.micro_batch_size,
max_steps=self.cfg.max_steps or total_num_steps,
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
learning_rate=self.cfg.learning_rate,
save_strategy="steps",
save_steps=self.cfg.save_steps,
output_dir=self.cfg.output_dir,
warmup_steps=self.cfg.warmup_steps,
logging_first_step=True,
@@ -1154,6 +1178,7 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
callbacks=self.get_callbacks(),
**dpo_trainer_kwargs,
)
setattr(dpo_trainer, "use_dpo_data_collator", True)
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
dpo_trainer.add_callback(callback)

View File

@@ -31,7 +31,7 @@ class TokenizedPromptDataset(Dataset):
def __init__( # pylint: disable=super-init-not-called
self,
prompt_tokenizer: PromptTokenizingStrategy,
dataset: Dataset,
dataset: IterableDataset,
process_count: Optional[int] = None,
keep_in_memory: Optional[bool] = False,
**kwargs,

View File

@@ -1,46 +0,0 @@
"""monkey patches for the dataset fetcher to handle batches of packed indexes"""
# pylint: disable=protected-access
import torch
from torch.utils.data._utils.fetch import _BaseDatasetFetcher
from torch.utils.data._utils.worker import _worker_loop
class _MapDatasetFetcher(_BaseDatasetFetcher):
def fetch(self, possibly_batched_index):
if isinstance(possibly_batched_index[0], list):
data = [None for i in possibly_batched_index]
for i, possibly_batched_index_ in enumerate(possibly_batched_index):
if self.auto_collation:
if (
hasattr(self.dataset, "__getitems__")
and self.dataset.__getitems__
):
data[i] = self.dataset.__getitems__(possibly_batched_index_)
else:
data[i] = [self.dataset[idx] for idx in possibly_batched_index_]
else:
data[i] = self.dataset[possibly_batched_index_]
else:
if self.auto_collation:
if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
data = self.dataset.__getitems__(possibly_batched_index)
else:
data = [self.dataset[idx] for idx in possibly_batched_index]
else:
data = self.dataset[possibly_batched_index]
return self.collate_fn(data)
def patch_fetchers():
torch.utils.data._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
def patched_worker_loop(*args, **kwargs):
patch_fetchers()
return _worker_loop(*args, **kwargs)
torch.utils.data._utils.worker._worker_loop = patched_worker_loop
patch_fetchers()

View File

@@ -0,0 +1,12 @@
"""
Patches to support multipack for falcon
"""
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def replace_falcon_attn_with_multipack_flash_attn():
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -0,0 +1,142 @@
"""
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
"""
import warnings
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import transformers.models.llama.modeling_llama
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
def hijack_llama_sdp_attention():
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
sdp_attention_forward
)
def sdp_attention_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
**kwargs, # pylint: disable=unused-argument
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
if not hasattr(self, "pretraining_tp"):
self.pretraining_tp = 1
if self.pretraining_tp > 1:
key_value_slicing = (
self.num_key_value_heads * self.head_dim
) // self.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)
key_states = [
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)
value_states = [
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(
bsz, q_len, self.num_heads, self.head_dim
).transpose(1, 2)
key_states = key_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
value_states = value_states.view(
bsz, q_len, self.num_key_value_heads, self.head_dim
).transpose(1, 2)
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids
)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
if output_attentions:
warnings.warn(
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
)
#
# sdp-attn start
#
with torch.backends.cuda.sdp_kernel():
attn_output = torch.nn.functional.scaled_dot_product_attention(
query_states,
key_states,
value_states,
attn_mask=attention_mask,
is_causal=False,
)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
#
# sdp-attn end
#
if self.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(
self.hidden_size // self.pretraining_tp, dim=1
)
attn_output = sum(
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.pretraining_tp)
)
else:
attn_output = self.o_proj(attn_output)
return attn_output, None, past_key_value

View File

@@ -5,11 +5,38 @@ from typing import Optional
import torch
from axolotl.monkeypatch.utils import mask_2d_to_4d
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
masked_zero_one_mask = mask_2d_to_4d(mask, dtype, tgt_len)
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
This expansion handles packed sequences so that sequences share the same attention mask integer value
when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
mask = mask.unsqueeze(1).unsqueeze(2)
mask = mask.expand(bsz, 1, tgt_len, src_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
binary_mask = torch.where(
mask != 0,
torch.tensor(1).to(dtype),
torch.tensor(0).to(dtype),
)
# Create a block-diagonal mask.
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
mask.device
)
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
inverted_mask = 1.0 - masked_zero_one_mask
return inverted_mask.masked_fill(

View File

@@ -1,26 +0,0 @@
"""
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
"""
from axolotl.monkeypatch.utils import (
patched_prepare_4d_causal_attention_mask,
patched_prepare_4d_causal_attention_mask_for_sdpa,
)
def hijack_llama_prepare_4d_mask():
import transformers.modeling_attn_mask_utils
import transformers.models.llama.modeling_llama
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask_for_sdpa
)
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask_for_sdpa
)
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask
)
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
patched_prepare_4d_causal_attention_mask
)

View File

@@ -2,6 +2,9 @@
Patches to support multipack for mixtral
"""
import torch
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def patch_mixtral_moe_forward_zero3() -> None:
@@ -48,3 +51,11 @@ def patch_mixtral_moe_forward_zero3() -> None:
MixtralBLockSparseTop2MLP.forward = mlp_forward
MixtralSparseMoeBlock.forward = moe_forward
def replace_mixtral_attn_with_multipack_flash_attn(for_zero3=False):
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if for_zero3:
patch_mixtral_moe_forward_zero3()

View File

@@ -1,30 +0,0 @@
"""multipack patching for v2 of sample packing"""
import transformers
from transformers.integrations import is_deepspeed_zero3_enabled
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
from axolotl.monkeypatch.utils import get_unpad_data
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi"]
def patch_for_multipack(model_type):
if model_type == "mixtral":
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
if is_deepspeed_zero3_enabled():
patch_mixtral_moe_forward_zero3()
elif model_type == "qwen2":
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "falcon":
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)
elif model_type == "phi":
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -0,0 +1,12 @@
"""
Patches to support multipack for phi2
"""
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def replace_phi_attn_with_multipack_flash_attn():
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -0,0 +1,12 @@
"""
Patches to support multipack for qwen2
"""
import transformers
from axolotl.monkeypatch.utils import get_unpad_data
def replace_qwen2_attn_with_multipack_flash_attn():
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
get_unpad_data
)

View File

@@ -4,16 +4,14 @@ import json
import logging
import os.path
import shutil
from functools import partial
from pathlib import Path
from typing import Dict, List, Sequence, Union
from typing import Dict, List, Sequence
import bitsandbytes as bnb
import peft
import safetensors.torch as st
import torch
from huggingface_hub import snapshot_download
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.optim.optimizer import Optimizer
from transformers import (
@@ -25,50 +23,23 @@ from transformers import (
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
from axolotl.utils.dict import DictDefault
from axolotl.utils.distributed import barrier, is_main_process
from axolotl.utils.distributed import is_main_process
LOG = logging.getLogger("axolotl.relora")
@torch.no_grad()
def magnitude_pruning_(tensor, prune_ratio):
tensor_magnitude = torch.abs(tensor)
threshold = torch.quantile(
tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio
).to(dtype=tensor.dtype)
def reset_optimizer(optimizer: torch.optim.Optimizer):
for group in optimizer.param_groups:
for param in group["params"]:
param_state = optimizer.state[param]
for key in param_state:
if "qmap" in key:
continue
mask = tensor_magnitude > threshold
tensor.mul_(mask.to(dtype=tensor.dtype))
def reset_optimizer(
optimizer: torch.optim.Optimizer,
*,
reset_params: list[str], # where str is the key to a torch.nn.Parameter
optimizer_state_keys: list[str],
):
pruning_fn = partial(magnitude_pruning_, prune_ratio=0.9)
n_zeros = 0
n_total = 0
optimizer_state = optimizer.state
if isinstance(optimizer, ZeroRedundancyOptimizer):
optimizer_state = optimizer.optim.state
for param in reset_params:
param_state = optimizer_state[param]
if len(param_state) == 0: # no state for this param, happens for ZeRo optimizer
continue
for key in optimizer_state_keys:
pruning_fn(
param_state[key]
) # pruning fn has to be inplace to keep the same keys in the dict
n_total += param_state[key].numel()
n_zeros += torch.sum(param_state[key] == 0).item()
_zeroed = n_zeros / (1e-7 + n_total) * 100
LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
LOG.info(f"absolute n of optimizer states zeroed: {n_zeros}")
if key == "step" and isinstance(param_state[key], int):
param_state[key] = 0
else:
param_state[key] = torch.zeros_like(param_state[key])
class ReLoRACallback(TrainerCallback):
@@ -126,25 +97,6 @@ class ReLoRACallback(TrainerCallback):
"relora",
)
if "adam" in args.optim.lower():
optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
else:
raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
lora_params = [
n
for n, p in model.named_parameters()
if p.requires_grad and "lora_" in n
]
model.save_pretrained(
os.path.join(
args.output_dir,
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
"adapter",
),
safe_serialization=True,
)
with torch.no_grad():
merge_and_save(
model,
@@ -155,11 +107,7 @@ class ReLoRACallback(TrainerCallback):
actually_save=is_main_process(),
cpu_offload=self.cpu_offload,
)
reset_optimizer(
optimizer,
reset_params=lora_params,
optimizer_state_keys=optimizer_state_keys,
)
reset_optimizer(optimizer)
if self.quantized:
self.last_full_model = checkpoint_folder
@@ -249,13 +197,11 @@ class ReLoRAScheduler(LRScheduler):
inner_schedule: LRScheduler,
relora_steps: int,
warmup_steps: int,
anneal_steps: int = 1,
min_lr_scale: float = 0.001,
) -> None:
self.inner_schedule = inner_schedule
self.relora_steps = relora_steps
self.warmup_steps = warmup_steps
self.anneal_steps = anneal_steps
self.min_lr_scale = min_lr_scale
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
@@ -264,20 +210,10 @@ class ReLoRAScheduler(LRScheduler):
original = self.inner_schedule.get_lr()
step = self.last_epoch
if step < self.relora_steps:
scale = 1
else:
per_relora_progress = step % self.relora_steps
if per_relora_progress < self.warmup_steps:
cycle_t = min(1.0, (per_relora_progress) / self.warmup_steps)
elif per_relora_progress > (self.relora_steps - self.anneal_steps):
cycle_t = min(
1.0,
(self.relora_steps - per_relora_progress) / self.anneal_steps,
)
else:
cycle_t = 1
cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
if isinstance(original, Sequence):
@@ -302,11 +238,7 @@ def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor:
if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
adapter: Union[List[str], str] = layer.active_adapter
if isinstance(adapter, list):
if len(adapter) > 1:
raise ValueError("unhandled relora for multiple adapters")
adapter = adapter[0]
adapter = layer.active_adapter
return (
peft.utils.transpose(
layer.lora_B[adapter].weight.detach().to(device)
@@ -316,7 +248,7 @@ def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor
* layer.scaling[adapter]
)
raise ValueError("unhandled lora layer type")
return layer.get_delta_weight().to(device)
def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]:
@@ -341,9 +273,9 @@ def update_weights(
):
if reinit:
for adapter_name in target.lora_A:
target.reset_lora_parameters(adapter_name, True)
target.reset_lora_parameters(adapter_name)
for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name, True)
target.reset_lora_parameters(adapter_name)
if isinstance(target, peft.tuners.lora.Linear4bit):
# This could be faster, but the quantization of Linear4bit weights occurs
@@ -354,9 +286,7 @@ def update_weights(
target.weight.data = new_weight.cpu()
target.to(device)
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
target.weight.data = (
bnb.nn.Int8Params(new_weight, requires_grad=False).to(device).data
)
target.weight = bnb.nn.Int8Params(new_weight, requires_grad=False).to(device)
else:
target.weight.data = new_weight.to(device)
@@ -374,17 +304,14 @@ def merge_and_save(
if not quantized:
for module_name, target in modules.items():
active_adapter = target.active_adapter
if isinstance(active_adapter, list):
active_adapter = active_adapter[0]
update = target.get_delta_weight(active_adapter).detach()
update = target.get_delta_weight(target.active_adapter).detach()
target.weight.data += update
if reinit:
for adapter_name in target.lora_A:
target.reset_lora_parameters(adapter_name, True)
target.reset_lora_parameters(adapter_name)
for adapter_name in target.lora_embedding_A:
target.reset_lora_parameters(adapter_name, True)
target.reset_lora_parameters(adapter_name)
return
os.makedirs(model_dst, exist_ok=True)
@@ -436,7 +363,6 @@ def merge_and_save(
LOG.info(f"saving tensors to {shard_fn}")
st.save_file(out_tensors, shard_fn, metadata={"format": "pt"})
barrier()
del in_tensors
del out_tensors
torch.cuda.empty_cache()

View File

@@ -1,15 +1,8 @@
"""
Shared utils for the monkeypatches
"""
from typing import Optional
import torch
import torch.nn.functional as F
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask,
_prepare_4d_causal_attention_mask_for_sdpa,
)
from transformers.utils import is_torch_bf16_gpu_available
@torch.jit.script
@@ -96,6 +89,7 @@ def get_cu_seqlens(attn_mask):
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
@torch.jit.script
def get_cu_seqlens_from_pos_ids(position_ids):
"""generate a cumulative sequence length mask for flash attention using pos ids"""
if len(position_ids.shape) == 1:
@@ -141,18 +135,7 @@ def get_cu_seqlens_from_pos_ids(position_ids):
results.append(cu_seqlens)
max_seq_lens.append(max_seq_len)
# Find the maximum value across all tensors
max_value = max(t.max() for t in results)
# Find the length of the longest tensor
max_length = max(t.size(0) for t in results)
# Pad each tensor to the same length and collect them in a list
padded_results = [
F.pad(t, (0, max_length - t.size(0)), "constant", max_value) for t in results
]
return torch.stack(padded_results).to(dtype=torch.int32), torch.stack(max_seq_lens)
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
def set_module_name(model, name, value):
@@ -166,62 +149,3 @@ def set_module_name(model, name, value):
child_name = name
setattr(parent, child_name, value)
def mask_2d_to_4d(
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
):
"""
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
This expansion handles packed sequences so that sequences share the same attention mask integer value
when they attend to each other within that sequence.
This expansion transforms the mask to lower triangular form to prevent future peeking.
"""
bsz, src_len = mask.size()
tgt_len = tgt_len if tgt_len is not None else src_len
mask = mask.unsqueeze(1).unsqueeze(2)
mask = mask.expand(bsz, 1, tgt_len, src_len)
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
binary_mask = torch.where(
mask != 0,
torch.tensor(1, device=mask.device).to(dtype),
torch.tensor(0, device=mask.device).to(dtype),
)
# Create a block-diagonal mask.
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
mask.device
)
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
return masked_zero_one_mask
def patched_prepare_4d_causal_attention_mask(
attention_mask: Optional[torch.Tensor],
*args,
):
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
return _prepare_4d_causal_attention_mask(
mask_2d_to_4d(attention_mask, dtype=dtype),
*args,
)
def patched_prepare_4d_causal_attention_mask_for_sdpa(
attention_mask: Optional[torch.Tensor],
*args,
):
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
return _prepare_4d_causal_attention_mask_for_sdpa(
mask_2d_to_4d(attention_mask, dtype=dtype),
*args,
)

View File

@@ -1,28 +0,0 @@
import os
from typing import Callable, Generator, Tuple
import psycopg
import psycopg.conninfo
def pgsql(pgsql_table=None, id_field="id", **kwargs) -> Callable:
pgsql_conn = os.environ.get("PGSQL_CONN", None)
if not pgsql_conn:
raise ValueError("missing PGSQL_CONN environment variable")
conn_dict = psycopg.conninfo.conninfo_to_dict(pgsql_conn)
def data_generator() -> Generator[Tuple, None, None]:
with psycopg.connect(**conn_dict) as conn:
with conn.cursor() as cur:
page_size = 10
last_id = None
while True:
if last_id:
where_clause = f" WHERE {id_field} > {last_id}"
cur.execute(
f"SELECT * FROM {pgsql_table}{where_clause} ORDER BY {id_field} ASC LIMIT {page_size}"
)
for row in cur.fetchall():
yield row[id_field], dict(row)
return data_generator

View File

@@ -1,33 +0,0 @@
"""Module containing the InstructShareGPTPromptTokenizingStrategy class"""
from typing import Any, Dict, Optional
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
from axolotl.prompters import ShareGPTPrompterV2
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
conversation = (
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
)
strategy = InstructShareGPTPromptTokenizingStrategy(
# pylint: disable=duplicate-code
ShareGPTPrompterV2(
conversation=conversation,
),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
)
return strategy
class InstructShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
"""
basic sharegpt strategy to grab conversations from the sample row
"""
def get_conversation_thread(self, prompt):
return [
{"from": "human", "value": prompt["instruction"]},
{"from": "gpt", "value": prompt["output"]},
]

View File

@@ -1,58 +0,0 @@
"""pretraining prompt strategies"""
from typing import Generator
from transformers import BatchEncoding
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
class PretrainTokenizer:
"""basic tokenization class for pretraining"""
def build_prompt(self, prompt) -> Generator[str, None, None]:
yield prompt
class PretrainTokenizationStrategy(PromptTokenizingStrategy):
"""handles tokenization for pretraining with strides"""
@property
def supports_batched(self):
return True
def __init__(self, *args, max_length=None, **kwargs):
super().__init__(*args, **kwargs)
if max_length:
self.max_length = max_length
def _tokenize(
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
) -> BatchEncoding:
res = self.tokenizer(
prompt,
truncation=True,
max_length=self.max_length - 1,
add_special_tokens=True,
return_overflowing_tokens=True,
stride=256,
)
res["input_ids"] = [
seq + [self.tokenizer.eos_token_id] for seq in res["input_ids"]
]
res["attention_mask"] = [seq + [1] for seq in res["attention_mask"]]
return res
def tokenize_prompt(self, prompt):
return self._tokenize(prompt["text"])
def load(tokenizer, cfg):
strat = PretrainTokenizationStrategy(
PretrainTokenizer(),
tokenizer,
cfg.train_on_inputs,
cfg.sequence_len,
max_length=cfg.sequence_len * 64,
)
return strat

View File

@@ -11,6 +11,7 @@ import torch
import transformers.modelcard
from accelerate.logging import get_logger
from datasets import Dataset
from optimum.bettertransformer import BetterTransformer
from peft import PeftModel
from pkg_resources import get_distribution # type: ignore
from transformers import PreTrainedModel, PreTrainedTokenizer
@@ -23,11 +24,6 @@ from axolotl.utils.freeze import freeze_parameters_except
from axolotl.utils.models import load_model, load_tokenizer
from axolotl.utils.trainer import setup_trainer
try:
from optimum.bettertransformer import BetterTransformer
except ImportError:
BetterTransformer = None
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
src_dir = os.path.join(project_root, "src")
sys.path.insert(0, src_dir)
@@ -61,21 +57,6 @@ def train(
eval_dataset = dataset_meta.eval_dataset
total_num_steps = dataset_meta.total_num_steps
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
]
if len(possible_checkpoints) > 0:
sorted_paths = sorted(
possible_checkpoints,
key=lambda path: int(path.split("-")[-1]),
)
cfg.resume_from_checkpoint = sorted_paths[-1]
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
)
resume_from_checkpoint = cfg.resume_from_checkpoint
# Load the model and tokenizer
msg = "loading model"
if cfg.adapter:
@@ -98,6 +79,21 @@ def train(
safe_serialization = cfg.save_safetensors is True
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
possible_checkpoints = [
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
]
if len(possible_checkpoints) > 0:
sorted_paths = sorted(
possible_checkpoints,
key=lambda path: int(path.split("-")[-1]),
)
cfg.resume_from_checkpoint = sorted_paths[-1]
LOG.info(
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
)
resume_from_checkpoint = cfg.resume_from_checkpoint
if cfg.unfrozen_parameters:
freeze_parameters_except(model, cfg.unfrozen_parameters)
@@ -128,7 +124,7 @@ def train(
if cfg.local_rank == 0:
def terminate_handler(_, __, model):
if cfg.flash_optimum and BetterTransformer:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
sys.exit(0)
@@ -153,10 +149,7 @@ def train(
pretrain_hooks(cfg, trainer)
if cfg.flash_optimum:
with torch.backends.cuda.sdp_kernel(
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
enable_flash=True,
enable_math=True,
enable_mem_efficient=True,
enable_flash=True, enable_math=True, enable_mem_efficient=True
):
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
else:
@@ -202,16 +195,13 @@ def train(
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
)
elif cfg.local_rank == 0:
if cfg.flash_optimum and BetterTransformer:
if cfg.flash_optimum:
model = BetterTransformer.reverse(model)
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
if not cfg.hub_model_id:
try:
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
except AttributeError:
pass
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
elif cfg.hub_model_id:
# defensively push to the hub to ensure the model card is updated
trainer.push_to_hub()

View File

@@ -47,12 +47,6 @@ def gpu_memory_usage_all(device=0):
return usage, reserved - usage, max(0, smi - reserved)
def mps_memory_usage_all():
usage = torch.mps.current_allocated_memory() / 1024.0**3
reserved = torch.mps.driver_allocated_memory() / 1024.0**3
return usage, reserved - usage, 0
@check_cuda_device(0.0)
def gpu_memory_usage_smi(device=0):
if isinstance(device, torch.device):
@@ -69,10 +63,7 @@ def gpu_memory_usage_smi(device=0):
def log_gpu_memory_usage(log, msg, device):
if torch.backends.mps.is_available():
usage, cache, misc = mps_memory_usage_all()
else:
usage, cache, misc = gpu_memory_usage_all(device)
usage, cache, misc = gpu_memory_usage_all(device)
extras = []
if cache > 0:
extras.append(f"+{cache:.03f}GB cache")

View File

@@ -19,7 +19,6 @@ def chat_templates(user_choice: str):
"""
templates = {
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
"chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
}

View File

@@ -132,26 +132,24 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""
def __call__(self, features, return_tensors=None):
if not isinstance(features[0], list):
features = [features]
out_features = [{} for _ in features]
for i, features_ in enumerate(features):
for feature in features_[0].keys():
if feature == "length":
continue
if feature == "attention_mask":
arrays = [
(1) * np.array(item[feature])
for i, item in enumerate(features_)
if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature]) for item in features_ if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
return super().__call__(out_features, return_tensors=return_tensors)
chunked_data = {}
for feature in features[0].keys():
if feature == "length":
continue
if feature == "attention_mask":
arrays = [
(1) * np.array(item[feature])
for item in features
if feature in item
]
chunked_data[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature]) for item in features if feature in item
]
chunked_data[feature] = np.concatenate(arrays)
features = [chunked_data]
return super().__call__(features, return_tensors=return_tensors)
@dataclass
@@ -161,27 +159,28 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
"""
def __call__(self, features, return_tensors=None):
if not isinstance(features[0], list):
features = [features]
out_features = [{} for _ in features]
for i, features_ in enumerate(features):
for feature in features_[0].keys():
if feature == "length":
continue
if feature == "attention_mask":
arrays = [
(i + 1) * np.array(item[feature])
for i, item in enumerate(features_)
if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature]) for item in features_ if feature in item
]
out_features[i][feature] = np.concatenate(arrays)
return super().__call__(out_features, return_tensors=return_tensors)
chunked_data = {}
for feature in features[0].keys():
if feature == "length":
continue
if feature == "attention_mask":
arrays = [
(i + 1) * np.array(item[feature])
for i, item in enumerate(features)
if feature in item
]
chunked_data[feature] = np.concatenate(arrays)
else:
arrays = [
np.array(item[feature]) for item in features if feature in item
]
chunked_data[feature] = np.concatenate(arrays)
features = [chunked_data]
return super().__call__(features, return_tensors=return_tensors)
@dataclass
class BatchSamplerDPODataCollatorWithPadding:
@dataclass
class MambaDataCollator:

View File

@@ -202,20 +202,6 @@ def validate_config(cfg):
raise ValueError(
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
)
if (
# pylint: disable=too-many-boolean-expressions
not (cfg.bf16 or cfg.bfloat16)
and (cfg.fp16 or cfg.float16)
and not cfg.adapter
and not cfg.flash_attention
and cfg.sample_packing
):
LOG.warning(
"Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA."
)
# ValueError: Attempting to unscale FP16 gradients.
# OR
# RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
if cfg.max_packed_sequence_len:
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
@@ -322,7 +308,7 @@ def validate_config(cfg):
LOG.warning("BetterTransformers probably doesn't work with PEFT adapters")
if cfg.fp16 or cfg.bf16:
raise ValueError("AMP is not supported with BetterTransformer")
if cfg.float16 is not True and cfg.bfloat16 is not True:
if cfg.float16 is not True and cfg.bloat16 is not True:
LOG.warning(
"You should probably set bfloat16 or float16 to true to "
"load the model in float16 for BetterTransformers"
@@ -364,24 +350,17 @@ def validate_config(cfg):
+ "point to its path, and remove model_revision from the config."
)
# if cfg.sample_packing and cfg.sdp_attention:
# # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
# raise ValueError(
# "sample_packing not compatible with sdp_attention. Use flash_attention"
# )
if cfg.sample_packing and cfg.sdp_attention:
# incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
raise ValueError(
"sample_packing not compatible with sdp_attention. Use flash_attention"
)
if cfg.sample_packing and cfg.xformers_attention:
raise ValueError(
"sample_packing not compatible with xformers_attention. Use flash_attention"
)
if cfg.sample_packing and cfg.sdp_attention and (cfg.bfloat16 or cfg.bf16):
# https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
LOG.warning(
"sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
"This may work on H100s."
)
if cfg.early_stopping_patience:
if not cfg.save_steps or not cfg.eval_steps:
raise ValueError(
@@ -447,11 +426,7 @@ def validate_config(cfg):
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
)
if (
cfg.val_set_size == 0
and (cfg.eval_steps or cfg.evaluation_strategy)
and not cfg.test_datasets
):
if cfg.val_set_size == 0 and (cfg.eval_steps or cfg.evaluation_strategy):
raise ValueError(
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
)

View File

@@ -1,23 +1,20 @@
"""Module containing data utilities"""
import functools
import hashlib
import importlib
import logging
from collections import defaultdict
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import yaml
from datasets import (
Dataset,
DatasetDict,
IterableDataset,
concatenate_datasets,
load_dataset,
load_from_disk,
)
from datasets.iterable_dataset import ExamplesIterable
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HFValidationError
from torch.utils.data import RandomSampler
@@ -67,25 +64,6 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
def get_streaming_dataset(ds_cfg):
path = ds_cfg["path"]
func = None
try:
load_fn = path.split(".")[-1]
module_name = ".".join(load_fn.split(".")[:-1])
mod = importlib.import_module(f".{module_name}", "axolotl")
func = getattr(mod, load_fn)
except Exception:
pass
if func:
data_producer = func(**ds_cfg)
return IterableDataset(ExamplesIterable(data_producer, {}))
else:
split = ds_cfg["split"] or "train"
return load_dataset(path, streaming=True, split=split, name=ds_cfg["name"])
def prepare_dataset(cfg, tokenizer):
prompters = []
if not cfg.pretraining_dataset:
@@ -102,21 +80,20 @@ def prepare_dataset(cfg, tokenizer):
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
)
else:
ds_wrapper_partial = functools.partial(
get_dataset_wrapper,
cfg.pretraining_dataset[0],
tokenizer,
cfg,
cfg.pretraining_dataset[0]["type"] or "pretrain",
)
path = cfg.pretraining_dataset
name = None
if isinstance(cfg.pretraining_dataset, list) and isinstance(
cfg.pretraining_dataset[0], dict
):
path = cfg.pretraining_dataset[0]["path"]
name = cfg.pretraining_dataset[0]["name"]
train_dataset = wrap_pretraining_dataset(
get_streaming_dataset(cfg.pretraining_dataset[0]),
train_dataset = load_pretraining_dataset(
path,
tokenizer,
cfg,
ds_wrapper_partial,
name=name,
max_tokens=cfg.sequence_len,
batch_size=cfg.micro_batch_size,
seed=cfg.seed or 42,
)
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
@@ -163,7 +140,7 @@ def load_tokenized_prepared_datasets(
+ "|".join(
sorted(
[
f"{d.path}:{d.type}:{d.shards}:{d.conversation}{d.split}"
f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
for d in cfg_datasets
]
)
@@ -350,16 +327,6 @@ def load_tokenized_prepared_datasets(
split=None,
storage_options=storage_options,
)
elif config_dataset.path.startswith("https://"):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(
ds_type,
name=config_dataset.name,
data_files=config_dataset.path,
streaming=False,
split=None,
storage_options=storage_options,
)
else:
if isinstance(config_dataset.data_files, str):
fp = hf_hub_download(
@@ -416,9 +383,9 @@ def load_tokenized_prepared_datasets(
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
config_dataset=config_dataset,
dataset=ds,
tokenizer=tokenizer,
cfg=cfg,
dataset=ds,
d_base_type=d_base_type,
d_prompt_style=d_prompt_style,
)
@@ -529,12 +496,7 @@ def load_prepare_datasets(
def get_dataset_wrapper(
config_dataset,
tokenizer,
cfg,
d_base_type,
dataset,
d_prompt_style=None,
config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
):
dataset_wrapper = None
dataset_prompter = None
@@ -545,8 +507,7 @@ def get_dataset_wrapper(
}
if (
isinstance(dataset, Dataset)
and "input_ids" in dataset.features
"input_ids" in dataset.features
and "attention_mask" in dataset.features
and "labels" in dataset.features
):
@@ -804,67 +765,76 @@ def encode_pretraining(
return ret
def wrap_pretraining_dataset(
dataset,
tokenizer,
cfg,
ds_wrapper_fn,
max_tokens=2048,
batch_size=1,
seed=42,
buffer_size=10_000,
):
def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42):
if cfg.sample_packing:
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
tokenizer,
return_tensors="pt",
padding=True,
pad_to_multiple_of=max_tokens * batch_size,
pad_to_multiple_of=max_tokens * cfg.micro_batch_size,
)
encode = functools.partial(
encode_packed_pretraining,
tokenizer,
collate_fn,
ds_wrapper_fn,
max_seq_length=max_tokens,
batch_size=batch_size,
batch_size=cfg.micro_batch_size,
)
# set this to 1 so downstream data_loader doesn't try to increase the batch again
cfg.micro_batch_size = 1
else:
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
dataset = load_dataset(path, streaming=True, split="train", name=name)
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
dataset = dataset.map(
encode,
batched=True,
batch_size=buffer_size,
# input_columns="text",
batch_size=10_000,
input_columns="text",
# remove all the existing columns after mapping since they end up having
# a different length than the encoded/tokenized column
remove_columns=dataset.features.keys(),
desc="Encoding Pretraining",
)
return dataset
def encode_packed_pretraining(
tokenizer: PreTrainedTokenizerBase,
collate_fn,
ds_wrapper: Callable,
examples: Dict[str, List],
examples: List[str],
max_seq_length: int = 2048,
batch_size: int = 4,
) -> Dict[str, List]:
# pylint: disable=duplicate-code
# tokenize all the examples
# rows get split with stride (overlap)
train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
res = tokenizer(
examples,
truncation=True,
max_length=max_seq_length - 1,
add_special_tokens=True,
return_overflowing_tokens=True,
stride=256,
)
input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]]
attention_mask = [seq + [1] for seq in res["attention_mask"]]
tokenized_examples = {
"input_ids": input_ids,
"attention_mask": attention_mask,
}
train_dataset = Dataset.from_dict(tokenized_examples)
train_dataset = process_pretraining_datasets_for_packing(
train_dataset, max_seq_length
)
sampler = MultipackBatchSampler(
RandomSampler(train_dataset),
batch_size=1,
batch_size=batch_size,
drop_last=True,
batch_max_len=batch_size * max_seq_length,
lengths=get_dataset_lengths(train_dataset),
@@ -872,23 +842,15 @@ def encode_packed_pretraining(
chunked_data = defaultdict(list)
for batch in sampler:
for data in batch:
features = train_dataset[data]
if "num_truncated_tokens" in features:
del features["num_truncated_tokens"]
if "num_truncated_tokens" in features:
del features["num_truncated_tokens"]
if "overflow_to_sample_mapping" in features:
del features["overflow_to_sample_mapping"]
if "labels" not in features:
features["labels"] = features["input_ids"].copy()
collated_features = collate_fn(features)
for data in sampler:
features = train_dataset[data]
features["labels"] = features["input_ids"].copy()
collated_features = collate_fn(features)
for feature in features.keys():
if feature == "length":
continue
chunked_data[feature].append(collated_features[feature].squeeze(0))
for feature in features.keys():
if feature == "length":
continue
chunked_data[feature].append(collated_features[feature].squeeze(0))
return chunked_data

View File

@@ -8,13 +8,8 @@ import addict
import bitsandbytes as bnb
import torch
import transformers
from peft import (
LoftQConfig,
PeftConfig,
PeftModel,
PeftModelForCausalLM,
prepare_model_for_kbit_training,
)
from optimum.bettertransformer import BetterTransformer
from peft import LoftQConfig, PeftConfig, prepare_model_for_kbit_training
from peft.tuners.lora import QuantLinear
from transformers import ( # noqa: F401
AddedToken,
@@ -29,10 +24,6 @@ from transformers import ( # noqa: F401
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
from axolotl.models.mamba import fix_mamba_attn_for_loss
from axolotl.monkeypatch.multipack import (
SUPPORTED_MULTIPACK_MODEL_TYPES,
patch_for_multipack,
)
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
from axolotl.utils.bench import log_gpu_memory_usage
from axolotl.utils.chat_templates import chat_templates
@@ -170,20 +161,15 @@ def load_tokenizer(cfg):
if getattr(tokenizer, attr_name) is None:
setattr(tokenizer, attr_name, "<|endoftext|>")
additional_special_tokens = None
if cfg.special_tokens:
special_tokens = cfg.special_tokens.to_dict()
additional_special_tokens = special_tokens.pop(
"additional_special_tokens", None
)
lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
for k, val in special_tokens.items():
for k, val in cfg.special_tokens.items():
# check if new special token is not already in tokenizer and
# is adapter training to make sure lora_modules_to_save is set
# pylint: disable=too-many-boolean-expressions
if (
(getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
and (len(tokenizer.encode(val, add_special_tokens=False)) > 2)
and (len(tokenizer.encode(val)) > 1)
and cfg.adapter
and (
not cfg.lora_modules_to_save
@@ -227,21 +213,6 @@ def load_tokenizer(cfg):
]
)
# Additional special tokens are a List, and need to be treated differently than regular special
# tokens. We add them after we have called `add_tokens` in case these additional special tokens
# are new tokens.
#
# Usage:
#
# ```py
# special_tokens:
# additional_special_tokens: ["<|im_start|>", "<|im_end|>"]
# ```
if additional_special_tokens is not None:
tokenizer.add_special_tokens(
{"additional_special_tokens": additional_special_tokens}
)
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
@@ -303,15 +274,8 @@ def load_model(
shifted-sparse attention does not currently support sample packing."
)
if (
cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
and cfg.flash_attention
and cfg.sample_packing
):
patch_for_multipack(cfg.model_config_type)
elif cfg.is_llama_derived_model:
# Modify all llama derived models in one block
# Modify all llama derived models in one block
if cfg.is_llama_derived_model:
if cfg.flash_attention:
from axolotl.monkeypatch.llama_attn_hijack_flash import (
replace_llama_attn_with_flash_attn,
@@ -340,13 +304,13 @@ def load_model(
LOG.info("patching with xformers attention")
hijack_llama_attention()
elif cfg.sample_packing:
from axolotl.monkeypatch.llama_patch_multipack import (
hijack_llama_prepare_4d_mask,
elif cfg.sdp_attention:
from axolotl.monkeypatch.llama_attn_hijack_sdp import (
hijack_llama_sdp_attention,
)
LOG.info("patching llama _prepare_4d_causal_attention_mask*")
hijack_llama_prepare_4d_mask()
LOG.info("patching with sdp attention")
hijack_llama_sdp_attention()
elif cfg.s2_attention:
raise NotImplementedError(
"Shifted-sparse attention not currently implemented without flash attention."
@@ -365,6 +329,43 @@ def load_model(
LOG.info("patching mistral with flash attention")
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
if (
cfg.model_config_type == "mixtral"
and cfg.flash_attention
and cfg.sample_packing
):
from axolotl.monkeypatch.mixtral import (
replace_mixtral_attn_with_multipack_flash_attn,
)
LOG.info("patching mixtral with flash attention")
mixtral_patch_kwargs = {}
if is_deepspeed_zero3_enabled():
mixtral_patch_kwargs["for_zero3"] = True
replace_mixtral_attn_with_multipack_flash_attn(**mixtral_patch_kwargs)
if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.falcon import (
replace_falcon_attn_with_multipack_flash_attn,
)
LOG.info("patching falcon with flash attention")
replace_falcon_attn_with_multipack_flash_attn()
if cfg.model_config_type == "phi" and cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.phi import replace_phi_attn_with_multipack_flash_attn
LOG.info("patching phi with flash attention")
replace_phi_attn_with_multipack_flash_attn()
if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
from axolotl.monkeypatch.qwen2 import (
replace_qwen2_attn_with_multipack_flash_attn,
)
LOG.info("patching qwen2 with flash attention")
replace_qwen2_attn_with_multipack_flash_attn()
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
@@ -374,7 +375,7 @@ def load_model(
model_kwargs: Dict[str, Any] = {}
if cfg.model_kwargs:
for key, val in cfg.model_kwargs.items():
for key, val in model_kwargs.items():
model_kwargs[key] = val
max_memory = cfg.max_memory
@@ -409,10 +410,6 @@ def load_model(
model_kwargs["device_map"] = device_map
model_kwargs["torch_dtype"] = cfg.torch_dtype
if torch.backends.mps.is_available():
model_kwargs["device_map"] = "mps:0"
# TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
# if cfg.rl:
# if torch.cuda.device_count() > 1:
@@ -456,18 +453,6 @@ def load_model(
**bnb_config,
)
if cfg.load_in_8bit and cfg.adapter is not None:
model_kwargs["load_in_8bit"] = True
if cfg.load_in_4bit and cfg.adapter is not None:
model_kwargs["load_in_4bit"] = True
# no longer needed per https://github.com/huggingface/transformers/pull/26610
if "quantization_config" in model_kwargs or cfg.gptq:
if "load_in_8bit" in model_kwargs:
del model_kwargs["load_in_8bit"]
if "load_in_4bit" in model_kwargs:
del model_kwargs["load_in_4bit"]
# sample packing uses custom FA2 patch
if cfg.flash_attention:
if not cfg.sample_packing:
@@ -479,7 +464,7 @@ def load_model(
"flash_attention_2"
)
else:
if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
if model_config.model_type in ["mixtral", "qwen2", "falcon", "phi"]:
model_kwargs["attn_implementation"] = "flash_attention_2"
model_config._attn_implementation = ( # pylint: disable=protected-access
"flash_attention_2"
@@ -489,12 +474,6 @@ def load_model(
model_config._attn_implementation = ( # pylint: disable=protected-access
"eager"
)
elif cfg.sdp_attention:
model_kwargs["attn_implementation"] = "sdpa"
model_config._attn_implementation = "sdpa" # pylint: disable=protected-access
elif cfg.eager_attention:
model_kwargs["attn_implementation"] = "eager"
model_config._attn_implementation = "eager" # pylint: disable=protected-access
try:
if (
@@ -507,6 +486,8 @@ def load_model(
model = LlamaForCausalLM.from_pretrained(
base_model,
config=model_config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
**model_kwargs,
)
@@ -574,6 +555,8 @@ def load_model(
model = getattr(transformers, model_type).from_pretrained(
base_model,
config=model_config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
@@ -605,6 +588,8 @@ def load_model(
model = AutoModelForCausalLM.from_pretrained(
base_model,
config=model_config,
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
trust_remote_code=cfg.trust_remote_code or False,
**model_kwargs,
)
@@ -612,9 +597,6 @@ def load_model(
LOG.exception(err)
raise err
if isinstance(model, (PeftModel, PeftModelForCausalLM)):
model = model.merge_and_unload()
embeddings_len = (
math.ceil(len(tokenizer) / 32) * 32
if cfg.resize_token_embeddings_to_32x
@@ -655,7 +637,7 @@ def load_model(
):
model.config.eos_token_id = tokenizer.eos_token_id
if hasattr(model, "device") and model.device.type in ("cuda", "mps"):
if hasattr(model, "device") and model.device.type == "cuda":
log_gpu_memory_usage(LOG, "after model load", model.device)
# make sure these are fp32 per Ramesh et al. (2021)
@@ -663,7 +645,7 @@ def load_model(
if not cfg.fsdp:
# FSDP doesn't like mixed Float and BFloat16
for name, module in model.named_modules():
if "norm" in name or name.endswith(".gate"):
if any(m in name for m in ["norm", "gate"]):
module.to(torch.float32)
if model_config.model_type == "btlm":
# don't upcast lm_head for btlm
@@ -676,9 +658,7 @@ def load_model(
skip_prepare_model_for_kbit_training = False
if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
from deepspeed.utils import ( # pylint: disable=no-name-in-module
set_z3_leaf_modules,
)
from deepspeed.utils import set_z3_leaf_modules
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
@@ -741,8 +721,6 @@ def load_model(
model.config.use_cache = False
if cfg.flash_optimum:
from optimum.bettertransformer import BetterTransformer
model = BetterTransformer.transform(model)
if cfg.adapter is not None:
@@ -769,7 +747,7 @@ def load_adapter(model, cfg, adapter, inference=False):
def load_llama_adapter(model, cfg):
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
from peft import AdaptionPromptConfig, get_peft_model
from peft import AdaptionPromptConfig, PeftModel, get_peft_model
peft_config = AdaptionPromptConfig(
adapter_layers=cfg.peft_adapter.layers, # layers (L)
@@ -815,7 +793,7 @@ def find_all_linear_names(model):
def load_lora(model, cfg, inference=False, config_only=False):
# type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]
from peft import LoraConfig, get_peft_model
from peft import LoraConfig, PeftModel, get_peft_model
lora_target_modules = list(cfg.lora_target_modules or [])

View File

@@ -117,7 +117,7 @@ class MultipackBatchSampler(BatchSampler):
packing_efficiency_estimate: float = 1.0,
):
super().__init__(sampler, batch_size, drop_last)
self.batch_size = batch_size
self.batch_size = None
self.batch_max_len = batch_max_len
self.lengths: np.ndarray = lengths
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
@@ -147,13 +147,7 @@ class MultipackBatchSampler(BatchSampler):
n=1,
)
batches = [
[
[indices[b_idx] for b_idx in batch]
for batch in batches[i : i + self.batch_size]
]
for i in range(0, len(batches), self.batch_size)
]
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
# statistics
if set_stats:
@@ -195,7 +189,7 @@ class MultipackBatchSampler(BatchSampler):
0.99
* lengths_sum_per_device
/ self.packing_efficiency_estimate
// (self.batch_max_len * self.batch_size)
// self.batch_max_len
)
- 1
),

View File

@@ -0,0 +1,61 @@
import torch
import torch.nn.functional as F
def keep_unpacked_data(data: torch.Tensor, index=None, nonzero_total=None, pad_val= None, pairs=False):
# pad val could be padding token (input_ids), -100 (labels), or 0 (attention_mask)
if index >= nonzero_total:
return False
if pairs and (index // 2) >= (nonzero_total // 2):
return False
if pad_val and (data == pad_val).all(dim=0).all():
return False
return True
def split_and_pad_packed(tensor, cu_seqlens, max_seqlen, keep_fn=None):
split_tensors = []
counts = count_nonzero_sequences(cu_seqlens)
# Iterate over each batch
for i in range(tensor.size(0)):
seq_lens = cu_seqlens[i]
start_idx = 0
# Iterate over the cumulative sequence lengths
for j, end_idx in enumerate(seq_lens[1:]):
if end_idx == start_idx:
break
# Extract and pad the current sequence
current_seq = tensor[i, start_idx:end_idx]
keep = True
if keep_fn:
keep = keep_fn(current_seq, index=j, nonzero_total=counts[i])
if not keep:
continue
padding_size = max_seqlen - current_seq.size(0)
padded_seq = F.pad(current_seq, (0, 0) * (current_seq.dim() - 2) + (0, padding_size))
# Append the padded sequence to the list
split_tensors.append(padded_seq)
# Update start index for the next sequence
start_idx = end_idx
# Stack the padded tensors
return torch.stack(split_tensors, dim=0)
def count_nonzero_sequences(cu_seqlens: torch.Tensor) -> torch.LongTensor:
diffs = torch.diff(cu_seqlens, dim=1, prepend=torch.zeros(cu_seqlens.shape[0], 1, dtype=cu_seqlens.dtype))
valid_lengths = diffs != 0
counts = valid_lengths.sum(dim=1).long()
return counts
# Example usage
# Example tensor with dimensions [batch_size, seq_len, other_dimensions...]
# example_tensor = torch.randn(batch_size, seq_len, other_dimensions...)
# cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(batch["position_ids"])
# split_padded_tensor = split_and_pad_packed(example_tensor, cu_seqlens, max_seqlen)

View File

@@ -237,17 +237,11 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
main_process_only=True,
)
else:
if cfg.flash_attention:
batch_size = 1
batch_max_len = cfg.micro_batch_size * cfg.sequence_len
else:
batch_size = cfg.micro_batch_size
batch_max_len = cfg.sequence_len
sampler = MultipackBatchSampler(
sampler=RandomSampler(train_dataset),
batch_size=batch_size,
batch_size=cfg.micro_batch_size,
drop_last=True,
batch_max_len=batch_max_len,
batch_max_len=cfg.micro_batch_size * cfg.sequence_len,
lengths=get_dataset_lengths(train_dataset),
)
@@ -255,7 +249,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
train_dataset.remove_columns(["length"]),
batch_sampler=sampler,
)
data_loader_len = len(data_loader) // batch_size
data_loader_len = len(data_loader)
actual_eff = sampler.efficiency()
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
# FIXME: is there a bug here somewhere? the total num steps depends

View File

@@ -1,114 +0,0 @@
"""
E2E tests for multipack fft llama using 4d attention masks
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from ..utils import require_torch_2_1_1, with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class Test4dMultipackLlama(unittest.TestCase):
"""
Test case for Llama models using 4d attention with multipack
"""
@require_torch_2_1_1
@with_temp_dir
def test_sdp_lora_packing(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"flash_attention": False,
"sdp_attention": True,
"sample_packing": True,
"pad_to_sequence_len": True,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"sequence_len": 1024,
"val_set_size": 0.1,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"fp16": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()
@with_temp_dir
def test_torch_lora_packing(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"flash_attention": False,
"sdp_attention": False,
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_linear": True,
"val_set_size": 0.1,
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"num_epochs": 2,
"micro_batch_size": 2,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"fp16": True,
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "adapter_model.bin").exists()

View File

@@ -33,7 +33,6 @@ class TestFusedLlama(unittest.TestCase):
{
"base_model": "JackFram/llama-68m",
"flash_attention": True,
"pad_to_sequence_len": True,
"flash_attn_fuse_qkv": True,
"flash_attn_fuse_mlp": True,
"sample_packing": True,

View File

@@ -7,6 +7,8 @@ import os
import unittest
from pathlib import Path
from transformers.utils import is_torch_bf16_gpu_available
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
@@ -61,7 +63,6 @@ class TestMistral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
}
)
normalize_config(cfg)
@@ -102,9 +103,12 @@ class TestMistral(unittest.TestCase):
"max_steps": 20,
"save_steps": 10,
"eval_steps": 10,
"bf16": "auto",
}
)
if is_torch_bf16_gpu_available():
cfg.bf16 = True
else:
cfg.fp16 = True
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)

View File

@@ -1,68 +0,0 @@
"""
E2E tests for relora llama
"""
import logging
import os
import unittest
from pathlib import Path
from axolotl.cli import load_datasets
from axolotl.common.cli import TrainerCliArgs
from axolotl.train import train
from axolotl.utils.config import normalize_config
from axolotl.utils.dict import DictDefault
from .utils import with_temp_dir
LOG = logging.getLogger("axolotl.tests.e2e")
os.environ["WANDB_DISABLED"] = "true"
class TestReLoraLlama(unittest.TestCase):
"""
Test case for Llama models using LoRA
"""
@with_temp_dir
def test_relora(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "JackFram/llama-68m",
"tokenizer_type": "LlamaTokenizer",
"sequence_len": 1024,
"load_in_8bit": True,
"adapter": "lora",
"lora_r": 32,
"lora_alpha": 16,
"lora_dropout": 0.05,
"lora_target_modules": ["q_proj", "v_proj"],
"relora_steps": 25,
"relora_warmup_steps": 5,
"relora_anneal_steps": 5,
"relora_cpu_offload": True,
"val_set_size": 0.0,
"special_tokens": {},
"datasets": [
{
"path": "mhenrichsen/alpaca_2k_test",
"type": "alpaca",
},
],
"warmup_steps": 15,
"num_epochs": 2,
"micro_batch_size": 4,
"gradient_accumulation_steps": 1,
"output_dir": temp_dir,
"learning_rate": 0.00001,
"optimizer": "adamw_torch",
"lr_scheduler": "cosine",
}
)
normalize_config(cfg)
cli_args = TrainerCliArgs()
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
assert (Path(temp_dir) / "model.safetensors").exists()

View File

@@ -4,9 +4,7 @@ helper utils for tests
import os
import shutil
import tempfile
import unittest
from functools import wraps
from importlib.metadata import version
from pathlib import Path
@@ -33,15 +31,3 @@ def most_recent_subdir(path):
subdir = max(subdirectories, key=os.path.getctime)
return subdir
def require_torch_2_1_1(test_case):
"""
Decorator marking a test that requires torch >= 2.1.1
"""
def is_min_2_1_1():
torch_version = version("torch")
return torch_version >= "2.1.1"
return unittest.skipUnless(is_min_2_1_1(), "test torch 2.1.1")(test_case)

View File

@@ -30,20 +30,6 @@ class TestMonkeyPatchUtils(unittest.TestCase):
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
)
def test_get_cu_seqlens_from_pos_ids_2d(self):
position_ids = torch.tensor(
[
[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0],
[0, 1, 2, 3, 4, 0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 0],
]
)
target_res = torch.tensor(
[[0, 4, 7, 12, 14, 16], [0, 5, 8, 15, 16, 16]], dtype=torch.int32
)
self.assertTrue(
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
)
def test_get_max_seqlen_in_batch(self):
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
target_res = torch.tensor([4, 3, 5, 2], dtype=torch.int32)

View File

@@ -1,99 +0,0 @@
"""Module for testing streaming dataset sequence packing"""
import pytest
from datasets import concatenate_datasets, load_dataset
from torch.utils.data import DataLoader, RandomSampler
from transformers import AutoTokenizer
from axolotl.datasets import TokenizedPromptDataset
from axolotl.prompt_strategies.completion import load
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.dict import DictDefault
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
@pytest.fixture(name="tokenizer")
def fixture_tokenizer():
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
tokenizer.pad_token = "</s>"
return tokenizer
@pytest.fixture(name="max_seq_length")
def fixture_max_seq_length():
return 4096
class TestBatchedSamplerPacking:
"""
Test class for packing streaming dataset sequences
"""
@pytest.mark.parametrize(
"batch_size, num_workers",
[
(1, 0),
(2, 0),
(1, 2),
(2, 2),
],
)
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
dataset = load_dataset(
"Trelis/tiny-shakespeare",
split="train",
)
cfg = DictDefault(
{
"train_on_inputs": True,
"sequence_len": max_seq_length,
}
)
ds_cfg = DictDefault(
{
"field": "Text",
}
)
completion_strategy = load(tokenizer, cfg, ds_cfg)
dataset_wrapper = TokenizedPromptDataset(
completion_strategy,
dataset,
)
train_dataset = concatenate_datasets([dataset_wrapper])
batch_sampler = MultipackBatchSampler(
sampler=RandomSampler(train_dataset),
batch_size=batch_size,
drop_last=True,
batch_max_len=max_seq_length,
lengths=get_dataset_lengths(train_dataset),
)
loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
collate_fn=V2BatchSamplerDataCollatorForSeq2Seq( # pylint: disable=unexpected-keyword-arg
tokenizer=tokenizer,
padding=True,
pad_to_multiple_of=max_seq_length,
return_tensors="pt",
),
num_workers=num_workers,
)
inputs = next(iter(loader))
assert inputs["input_ids"].shape == (batch_size, max_seq_length)
assert inputs["labels"].shape == (batch_size, max_seq_length)
assert inputs["attention_mask"].shape == (batch_size, max_seq_length)
assert inputs["input_ids"].tolist()[0][0] == 2
assert inputs["labels"].tolist()[0][0] == -100
assert inputs["attention_mask"].tolist()[0][0] == 0
assert inputs["attention_mask"].tolist()[0][-1] > 1
if batch_size >= 2:
assert inputs["input_ids"].tolist()[1][0] == 2
assert inputs["labels"].tolist()[1][0] == -100
assert inputs["attention_mask"].tolist()[1][0] == 0
assert inputs["attention_mask"].tolist()[1][-1] > 1

View File

@@ -1,17 +1,17 @@
"""Module for testing streaming dataset sequence packing"""
import functools
import unittest
from functools import partial
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import AutoTokenizer
from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
from axolotl.utils.dict import DictDefault
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
from axolotl.utils.data import encode_packed_pretraining
class TestPretrainingPacking(unittest.TestCase):
class TestPacking(unittest.TestCase):
"""
Test class for packing streaming dataset sequences
"""
@@ -20,6 +20,8 @@ class TestPretrainingPacking(unittest.TestCase):
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.pad_token = "</s>"
self.max_seq_length = 2048
self.batch_size = 2
def test_packing_stream_dataset(self):
# pylint: disable=duplicate-code
@@ -29,43 +31,30 @@ class TestPretrainingPacking(unittest.TestCase):
streaming=True,
)["train"]
cfg = DictDefault(
{
"pretraining_dataset": [
{
"path": "c4",
"name": "en",
"type": "pretrain",
}
],
"sample_packing": True,
"pad_to_sequence_len": True,
"sequence_len": 2048,
"micro_batch_size": 2,
}
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
self.tokenizer,
return_tensors="pt",
padding=True,
pad_to_multiple_of=self.max_seq_length,
)
ds_wrapper_partial = functools.partial(
get_dataset_wrapper,
cfg.pretraining_dataset[0],
encode = partial(
encode_packed_pretraining,
self.tokenizer,
cfg,
cfg.pretraining_dataset[0]["type"] or "pretrain",
collate_fn,
max_seq_length=self.max_seq_length,
batch_size=self.batch_size,
)
original_bsz = cfg.micro_batch_size
train_dataset = wrap_pretraining_dataset(
dataset,
self.tokenizer,
cfg,
ds_wrapper_partial,
max_tokens=cfg.sequence_len,
batch_size=cfg.micro_batch_size,
seed=cfg.seed or 42,
dataset = dataset.map(
encode,
batched=True,
input_columns="text",
remove_columns=dataset.features.keys(),
)
trainer_loader = DataLoader(
train_dataset,
dataset,
batch_size=1,
collate_fn=None,
drop_last=True,
@@ -75,16 +64,16 @@ class TestPretrainingPacking(unittest.TestCase):
if idx > 10:
break
assert data["input_ids"].shape == torch.Size(
[1, original_bsz * cfg.sequence_len]
[1, self.batch_size * self.max_seq_length]
)
assert data["position_ids"].shape == torch.Size(
[1, original_bsz * cfg.sequence_len]
[1, self.batch_size * self.max_seq_length]
)
assert data["labels"].shape == torch.Size(
[1, original_bsz * cfg.sequence_len]
[1, self.batch_size * self.max_seq_length]
)
assert data["attention_mask"].shape == torch.Size(
[1, original_bsz * cfg.sequence_len]
[1, self.batch_size * self.max_seq_length]
)
idx += 1

View File

@@ -67,21 +67,6 @@ class TestTokenizers(unittest.TestCase):
)
load_tokenizer(cfg)
def test_add_additional_special_tokens(self):
cfg = DictDefault(
{
"tokenizer_config": "huggyllama/llama-7b",
"special_tokens": {"additional_special_tokens": ["<|im_start|>"]},
}
)
tokenizer = load_tokenizer(cfg)
self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404])
self.assertEqual(len(tokenizer), 32001)
# ensure reloading the tokenizer again from cfg results in same vocab length
tokenizer = load_tokenizer(cfg)
self.assertEqual(len(tokenizer), 32001)
if __name__ == "__main__":
unittest.main()