Compare commits
3 Commits
streaming-
...
multipack-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d46d7dfe30 | ||
|
|
047d9e1d5b | ||
|
|
88a0c05d2c |
2
.github/workflows/base.yml
vendored
2
.github/workflows/base.yml
vendored
@@ -7,7 +7,7 @@ jobs:
|
||||
build-base:
|
||||
if: github.repository_owner == 'OpenAccess-AI-Collective'
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
runs-on: axolotl-gpu-runner
|
||||
runs-on: self-hosted
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
|
||||
18
.github/workflows/main.yml
vendored
18
.github/workflows/main.yml
vendored
@@ -9,6 +9,7 @@ on:
|
||||
jobs:
|
||||
build-axolotl:
|
||||
if: ${{ ! contains(github.event.commits[0].message, '[skip docker]]') && github.repository_owner == 'OpenAccess-AI-Collective' }}
|
||||
# this job needs to be run on self-hosted GPU runners...
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
@@ -34,7 +35,7 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.1.2
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
runs-on: [self-hosted, gpu, docker]
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
@@ -55,16 +56,27 @@ jobs:
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
load: true
|
||||
build-args: |
|
||||
BASE_TAG=${{ github.ref_name }}-base-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}
|
||||
CUDA=${{ matrix.cuda }}
|
||||
PYTORCH_VERSION=${{ matrix.pytorch }}
|
||||
file: ./docker/Dockerfile
|
||||
push: ${{ github.event_name != 'pull_request' }}
|
||||
tags: |
|
||||
${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
labels: ${{ steps.metadata.outputs.labels }}
|
||||
- name: Unit Tests
|
||||
run: |
|
||||
docker run --rm ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }} pytest --ignore=tests/e2e/ /workspace/axolotl/tests/
|
||||
- name: Push to Docker Hub
|
||||
if: github.event_name != 'pull_request'
|
||||
run: |
|
||||
docker push ${{ steps.metadata.outputs.tags }}-py${{ matrix.python_version }}-cu${{ matrix.cuda }}-${{ matrix.pytorch }}${{ matrix.axolotl_extras != '' && '-' || '' }}${{ matrix.axolotl_extras }}
|
||||
latest_tag=${{ (matrix.is_latest) && format('{0}-latest', steps.metadata.outputs.tags) || '' }}
|
||||
if [ -n "$latest_tag" ]; then
|
||||
docker push "$latest_tag"
|
||||
fi
|
||||
|
||||
build-axolotl-runpod:
|
||||
needs: build-axolotl
|
||||
@@ -94,7 +106,7 @@ jobs:
|
||||
python_version: "3.11"
|
||||
pytorch: 2.1.2
|
||||
axolotl_extras:
|
||||
runs-on: axolotl-gpu-runner
|
||||
runs-on: [self-hosted, gpu, docker]
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
@@ -32,9 +32,6 @@ ignore_missing_imports = True
|
||||
[mypy-bitsandbytes]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-requests]
|
||||
ignore_missing_imports = True
|
||||
|
||||
[mypy-datasets]
|
||||
ignore_missing_imports = True
|
||||
|
||||
|
||||
38
README.md
38
README.md
@@ -25,8 +25,8 @@ Features:
|
||||
- [Installation](#installation)
|
||||
- [Docker](#docker)
|
||||
- [Conda/Pip venv](#condapip-venv)
|
||||
- [Cloud GPU](#cloud-gpu) - Latitude.sh, RunPod
|
||||
- [Bare Metal Cloud GPU](#bare-metal-cloud-gpu)
|
||||
- [Cloud GPU](#cloud-gpu) - Runpod, Latitude
|
||||
- [LambdaLabs](#lambdalabs)
|
||||
- [Windows](#windows)
|
||||
- [Launching on public clouds via SkyPilot](#launching-on-public-clouds-via-skypilot)
|
||||
- [Dataset](#dataset)
|
||||
@@ -37,9 +37,6 @@ Features:
|
||||
- [Inference](#inference)
|
||||
- [Merge LORA to Base](#merge-lora-to-base)
|
||||
- [Special Tokens](#special-tokens)
|
||||
- Advanced Topics
|
||||
- [Multipack](./docs/multipack.md)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
||||
- [RLHF & DPO](./docs/rlhf.md)<svg width="24" height="24" viewBox="0 0 24 24" xmlns="http://www.w3.org/2000/svg"><path d="M17 13.5v6H5v-12h6m3-3h6v6m0-6-9 9" class="icon_svg-stroke" stroke="#666" stroke-width="1.5" fill="none" fill-rule="evenodd" stroke-linecap="round" stroke-linejoin="round"></path></svg>
|
||||
- [Common Errors](#common-errors-)
|
||||
- [Tokenization Mismatch b/w Training & Inference](#tokenization-mismatch-bw-inference--training)
|
||||
- [Debugging Axolotl](#debugging-axolotl)
|
||||
@@ -121,10 +118,6 @@ accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
# gradio
|
||||
accelerate launch -m axolotl.cli.inference examples/openllama-3b/lora.yml \
|
||||
--lora_model_dir="./lora-out" --gradio
|
||||
|
||||
# remote yaml files - the yaml config can be hosted on a public URL
|
||||
# Note: the yaml config must directly link to the **raw** yaml
|
||||
accelerate launch -m axolotl.cli.train https://raw.githubusercontent.com/OpenAccess-AI-Collective/axolotl/main/examples/openllama-3b/lora.yml
|
||||
```
|
||||
|
||||
## Installation
|
||||
@@ -186,13 +179,9 @@ docker run --privileged --gpus '"all"' --shm-size 10g --rm -it --name axolotl --
|
||||
|
||||
For cloud GPU providers that support docker images, use [`winglian/axolotl-cloud:main-latest`](https://hub.docker.com/r/winglian/axolotl-cloud/tags)
|
||||
|
||||
- on Latitude.sh use this [direct link](https://latitude.sh/blueprint/989e0e79-3bf6-41ea-a46b-1f246e309d5c)
|
||||
- on RunPod use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||
|
||||
#### Bare Metal Cloud GPU
|
||||
|
||||
##### LambdaLabs
|
||||
|
||||
#### LambdaLabs
|
||||
<details>
|
||||
|
||||
<summary>Click to Expand</summary>
|
||||
@@ -472,12 +461,6 @@ See [examples](examples) for quick start. It is recommended to duplicate and mod
|
||||
dataset:
|
||||
- path: s3://path_to_ds # Accepts folder with arrow/parquet or file path like above. Supports s3, gcs.
|
||||
...
|
||||
|
||||
# Loading Data From a Public URL
|
||||
# - The file format is `json` (which includes `jsonl`) by default. For different formats, adjust the `ds_type` option accordingly.
|
||||
dataset:
|
||||
- path: https://some.url.com/yourdata.jsonl # The URL should be a direct link to the file you wish to load. URLs must use HTTPS protocol, not HTTP.
|
||||
ds_type: json # this is the default, see other options below.
|
||||
```
|
||||
|
||||
- loading
|
||||
@@ -990,9 +973,6 @@ Run
|
||||
accelerate launch -m axolotl.cli.train your_config.yml
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> You can also reference a config file that is hosted on a public URL, for example `accelerate launch -m axolotl.cli.train https://yourdomain.com/your_config.yml`
|
||||
|
||||
#### Preprocess dataset
|
||||
|
||||
You can optionally pre-tokenize dataset with the following before finetuning.
|
||||
@@ -1172,11 +1152,9 @@ Having misalignment between your prompts during training and inference can cause
|
||||
|
||||
See [this debugging guide](docs/debugging.md) for tips on debugging Axolotl, along with an example configuration for debugging with VSCode.
|
||||
|
||||
## Need help? 🙋
|
||||
## Need help? 🙋♂️
|
||||
|
||||
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we our community members can help you.
|
||||
|
||||
Need dedicated support? Please contact us at [✉️wing@openaccessaicollective.org](mailto:wing@openaccessaicollective.org) for dedicated support options.
|
||||
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
|
||||
|
||||
## Badge ❤🏷️
|
||||
|
||||
@@ -1217,12 +1195,6 @@ pre-commit install
|
||||
pytest tests/
|
||||
```
|
||||
|
||||
Thanks to all of our contributors to date. Help drive open source AI progress forward by contributing to Axolotl.
|
||||
|
||||
<a href="https://github.com/openaccess-ai-collective/axolotl/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=openaccess-ai-collective/axolotl" alt="contributor chart by https://contrib.rocks"/>
|
||||
</a>
|
||||
|
||||
## Sponsors 🤝❤
|
||||
|
||||
OpenAccess AI Collective is run by volunteer contributors such as [winglian](https://github.com/winglian),
|
||||
|
||||
@@ -11,7 +11,6 @@ EXPOSE 8888
|
||||
EXPOSE 22
|
||||
|
||||
COPY scripts/cloud-entrypoint.sh /root/cloud-entrypoint.sh
|
||||
COPY scripts/motd /etc/motd
|
||||
|
||||
RUN pip install jupyterlab notebook ipywidgets && \
|
||||
jupyter lab clean
|
||||
@@ -19,7 +18,6 @@ RUN apt install --yes --no-install-recommends openssh-server tmux && \
|
||||
mkdir -p ~/.ssh && \
|
||||
chmod 700 ~/.ssh && \
|
||||
printf "\n[[ -z \"\$TMUX\" ]] && { tmux attach-session -t ssh_tmux || tmux new-session -s ssh_tmux; exit; }\n" >> ~/.bashrc && \
|
||||
printf "[ ! -z \"\$TERM\" -a -r /etc/motd ] && cat /etc/motd\n" >> ~/.bashrc && \
|
||||
chmod +x /workspace/axolotl/scripts/cloud-entrypoint.sh && \
|
||||
chmod +x /root/cloud-entrypoint.sh
|
||||
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 239 KiB |
@@ -1,11 +1,4 @@
|
||||
# Multipack (Sample Packing)
|
||||
|
||||
## Visualization of Multipack with Flash Attention
|
||||
|
||||
Because Flash Attention simply drops the attention mask, we do not need to
|
||||
construct a 4d attention mask. We only need to concatenate the sequences into
|
||||
a single batch and let flash attention know where each new sequence begins.
|
||||
|
||||
# Multipack
|
||||
|
||||
4k context, bsz =4,
|
||||
each character represents 256 tokens
|
||||
@@ -56,18 +49,3 @@ w packing ( note it's the same effective number of tokens per step, but a true b
|
||||
E E E E F F F F F G G G H H H H
|
||||
I I I J J J J K K K K K L L L X ]]
|
||||
```
|
||||
|
||||
cu_seqlens:
|
||||
[[ 0, 11, 17, 24, 28, 36, 41 44, 48, 51, 55, 60, 64]]
|
||||
|
||||
|
||||
## Multipack without Flash Attention
|
||||
|
||||
Multipack can still be achieved without Flash attention, but with lower packing
|
||||
efficiency as we are not able to join multiple batches into a single batch due to
|
||||
context length limits without flash attention. We can use either Pytorch's Scaled
|
||||
Dot Product Attention implementation or native Pytorch attention implementation
|
||||
along with [4d attention masks](https://github.com/huggingface/transformers/pull/27539)
|
||||
to pack sequences together and avoid cross attention.
|
||||
|
||||
<img src="./images/4d-mask.png" alt="axolotl" width="800">
|
||||
|
||||
@@ -12,8 +12,8 @@ feedback. Various methods include, but not limited to:
|
||||
|
||||
### RLHF using Axolotl
|
||||
|
||||
>[!IMPORTANT]
|
||||
>This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
|
||||
[!IMPORTANT]
|
||||
This is a BETA feature and many features are not fully implemented. You are encouraged to open new PRs to improve the integration and functionality.
|
||||
|
||||
The various RL training methods are implemented in trl and wrapped via axolotl. Below are various examples with how you can use various preference datasets to train models that use ChatML
|
||||
|
||||
|
||||
@@ -43,7 +43,6 @@
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!pip install torch==\"2.1.2\"\n",
|
||||
"!pip install -e git+https://github.com/OpenAccess-AI-Collective/axolotl#egg=axolotl\n",
|
||||
"!pip install flash-attn==\"2.5.0\"\n",
|
||||
"!pip install deepspeed==\"0.13.1\""
|
||||
|
||||
@@ -1,65 +0,0 @@
|
||||
base_model: TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T
|
||||
model_type: LlamaForCausalLM
|
||||
tokenizer_type: LlamaTokenizer
|
||||
is_llama_derived_model: true
|
||||
|
||||
load_in_8bit: true
|
||||
load_in_4bit: false
|
||||
strict: false
|
||||
|
||||
datasets:
|
||||
- path: mhenrichsen/alpaca_2k_test
|
||||
type: alpaca
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0
|
||||
output_dir: ./lora-out
|
||||
|
||||
sequence_len: 4096
|
||||
sample_packing: true
|
||||
pad_to_sequence_len: true
|
||||
eval_sample_packing: false
|
||||
|
||||
adapter: lora
|
||||
lora_model_dir:
|
||||
lora_r: 32
|
||||
lora_alpha: 16
|
||||
lora_dropout: 0.05
|
||||
lora_target_linear: true
|
||||
lora_fan_in_fan_out:
|
||||
|
||||
wandb_project:
|
||||
wandb_entity:
|
||||
wandb_watch:
|
||||
wandb_name:
|
||||
wandb_log_model:
|
||||
|
||||
gradient_accumulation_steps: 4
|
||||
micro_batch_size: 2
|
||||
num_epochs: 4
|
||||
optimizer: adamw_torch
|
||||
lr_scheduler: cosine
|
||||
learning_rate: 0.0002
|
||||
|
||||
train_on_inputs: false
|
||||
group_by_length: false
|
||||
bf16: auto
|
||||
fp16: false
|
||||
tf32: true
|
||||
|
||||
gradient_checkpointing: true
|
||||
early_stopping_patience:
|
||||
resume_from_checkpoint:
|
||||
local_rank:
|
||||
logging_steps: 1
|
||||
xformers_attention:
|
||||
flash_attention: false
|
||||
|
||||
warmup_steps: 10
|
||||
evals_per_epoch: 0
|
||||
saves_per_epoch: 1
|
||||
debug:
|
||||
deepspeed:
|
||||
weight_decay: 0.0
|
||||
fsdp:
|
||||
fsdp_config:
|
||||
special_tokens:
|
||||
@@ -10,9 +10,8 @@ strict: false
|
||||
|
||||
max_steps: 200
|
||||
pretraining_dataset:
|
||||
- path: c4
|
||||
name: en
|
||||
type: pretrain
|
||||
path: c4
|
||||
name: en
|
||||
dataset_prepared_path:
|
||||
val_set_size: 0.0
|
||||
output_dir: ./model-out
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
pre-commit
|
||||
black
|
||||
mypy
|
||||
types-requests
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
--extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/
|
||||
packaging==23.2
|
||||
peft @ git+https://github.com/huggingface/peft.git
|
||||
transformers @ git+https://github.com/huggingface/transformers.git@bebeeee01275c32fccec3fa36d8b148d3813a7dc
|
||||
transformers==4.37.0
|
||||
tokenizers==0.15.0
|
||||
bitsandbytes>=0.41.1
|
||||
accelerate==0.26.1
|
||||
@@ -9,7 +9,6 @@ deepspeed>=0.13.1
|
||||
addict
|
||||
fire
|
||||
PyYAML>=6.0
|
||||
requests
|
||||
datasets>=2.15.0
|
||||
flash-attn==2.3.3
|
||||
sentencepiece
|
||||
|
||||
17
scripts/motd
17
scripts/motd
@@ -1,17 +0,0 @@
|
||||
|
||||
dP dP dP
|
||||
88 88 88
|
||||
.d8888b. dP. .dP .d8888b. 88 .d8888b. d8888P 88
|
||||
88' `88 `8bd8' 88' `88 88 88' `88 88 88
|
||||
88. .88 .d88b. 88. .88 88 88. .88 88 88
|
||||
`88888P8 dP' `dP `88888P' dP `88888P' dP dP
|
||||
|
||||
Welcome to the axolotl cloud image! If the you've mounted a disk to /workspace and the axolotl directory ie empty, run the following commands:
|
||||
|
||||
```
|
||||
cd /workspace
|
||||
rm -rf /workspace/axolotl
|
||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl.git
|
||||
cd axolotl
|
||||
pip install --no-deps -e .
|
||||
```
|
||||
24
setup.py
24
setup.py
@@ -1,7 +1,5 @@
|
||||
"""setup.py for axolotl"""
|
||||
|
||||
import platform
|
||||
import re
|
||||
from importlib.metadata import PackageNotFoundError, version
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
@@ -28,25 +26,11 @@ def parse_requirements():
|
||||
_install_requires.append(line)
|
||||
|
||||
try:
|
||||
if "Darwin" in platform.system():
|
||||
torch_version = version("torch")
|
||||
_install_requires.append(f"torch=={torch_version}")
|
||||
if torch_version.startswith("2.1."):
|
||||
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
|
||||
else:
|
||||
torch_version = version("torch")
|
||||
_install_requires.append(f"torch=={torch_version}")
|
||||
|
||||
version_match = re.match(r"^(\d+)\.(\d+)(?:\.(\d+))?", torch_version)
|
||||
if version_match:
|
||||
major, minor, patch = version_match.groups()
|
||||
major, minor = int(major), int(minor)
|
||||
patch = (
|
||||
int(patch) if patch is not None else 0
|
||||
) # Default patch to 0 if not present
|
||||
else:
|
||||
raise ValueError("Invalid version format")
|
||||
|
||||
if (major, minor) >= (2, 1):
|
||||
_install_requires.pop(_install_requires.index("xformers==0.0.22"))
|
||||
_install_requires.append("xformers>=0.0.23")
|
||||
_install_requires.append("xformers>=0.0.23")
|
||||
except PackageNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,20 +1,16 @@
|
||||
"""Prepare and train a model on a dataset. Can also infer from a model or merge lora"""
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from threading import Thread
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import gradio as gr
|
||||
import requests
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
@@ -63,52 +59,6 @@ def print_axolotl_text_art(suffix=None):
|
||||
print(ascii_art)
|
||||
|
||||
|
||||
def check_remote_config(config: Union[str, Path]):
|
||||
# Check if the config is a valid HTTPS URL to a .yml or .yaml file
|
||||
if not (isinstance(config, str) and config.startswith("https://")):
|
||||
return config # Return the original value if it's not a valid URL
|
||||
|
||||
filename = os.path.basename(urlparse(config).path)
|
||||
temp_dir = tempfile.mkdtemp()
|
||||
|
||||
try:
|
||||
response = requests.get(config, timeout=30)
|
||||
response.raise_for_status() # Check for HTTP errors
|
||||
|
||||
content = response.content
|
||||
try:
|
||||
# Try parsing as JSON first to catch cases where JSON content is mistakenly considered YAML
|
||||
json.loads(content)
|
||||
# Log a warning but do not raise an error; JSON is technically valid YAML - this can happen when you forget to point to a raw github link
|
||||
LOG.warning(
|
||||
f"Warning: The content of the file at {config} is JSON, which is technically valid YAML but might not be intended."
|
||||
)
|
||||
except json.JSONDecodeError:
|
||||
# If it's not valid JSON, verify it's valid YAML
|
||||
try:
|
||||
yaml.safe_load(content)
|
||||
except yaml.YAMLError as err:
|
||||
raise ValueError(
|
||||
f"Failed to parse the content at {config} as YAML: {err}"
|
||||
) from err
|
||||
|
||||
# Write the content to a file if it's valid YAML (or JSON treated as YAML)
|
||||
output_path = Path(temp_dir) / filename
|
||||
with open(output_path, "wb") as file:
|
||||
file.write(content)
|
||||
LOG.info(
|
||||
f"Using the following config obtained from {config}:\n\n{content.decode('utf-8')}\n"
|
||||
)
|
||||
return output_path
|
||||
|
||||
except requests.RequestException as err:
|
||||
# This catches all requests-related exceptions including HTTPError
|
||||
raise RuntimeError(f"Failed to download {config}: {err}") from err
|
||||
except Exception as err:
|
||||
# Catch-all for any other exceptions
|
||||
raise err
|
||||
|
||||
|
||||
def get_multi_line_input() -> Optional[str]:
|
||||
print("Give me an instruction (Ctrl + D to submit): ")
|
||||
instruction = ""
|
||||
@@ -320,10 +270,9 @@ def check_not_in(list1: List[str], list2: Union[Dict[str, Any], List[str]]) -> b
|
||||
return not any(el in list2 for el in list1)
|
||||
|
||||
|
||||
def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||
config = check_remote_config(config)
|
||||
def load_cfg(config: Path = Path("examples/"), **kwargs):
|
||||
if Path(config).is_dir():
|
||||
config = choose_config(Path(config))
|
||||
config = choose_config(config)
|
||||
|
||||
# load the config from the yaml file
|
||||
with open(config, encoding="utf-8") as file:
|
||||
|
||||
@@ -3,7 +3,6 @@ CLI to run training on a model
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
@@ -24,7 +23,7 @@ from axolotl.prompt_strategies.sharegpt import register_chatml_template
|
||||
LOG = logging.getLogger("axolotl.cli.preprocess")
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
|
||||
@@ -3,7 +3,6 @@ CLI to shard a trained model into 10GiB chunks
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
|
||||
import fire
|
||||
import transformers
|
||||
@@ -26,7 +25,7 @@ def shard(
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
print_axolotl_text_art()
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
|
||||
@@ -3,12 +3,11 @@ CLI to run training on a model
|
||||
"""
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Tuple, Union
|
||||
from typing import Tuple
|
||||
|
||||
import fire
|
||||
from transformers.hf_argparser import HfArgumentParser
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.tokenization_utils import PreTrainedTokenizer
|
||||
import transformers
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
|
||||
from axolotl.cli import (
|
||||
check_accelerate_default_config,
|
||||
@@ -25,10 +24,10 @@ from axolotl.train import train
|
||||
LOG = logging.getLogger("axolotl.cli.train")
|
||||
|
||||
|
||||
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
|
||||
def do_cli(config: Path = Path("examples/"), **kwargs):
|
||||
# pylint: disable=duplicate-code
|
||||
parsed_cfg = load_cfg(config, **kwargs)
|
||||
parser = HfArgumentParser((TrainerCliArgs))
|
||||
parser = transformers.HfArgumentParser((TrainerCliArgs))
|
||||
parsed_cli_args, _ = parser.parse_args_into_dataclasses(
|
||||
return_remaining_strings=True
|
||||
)
|
||||
|
||||
@@ -6,7 +6,6 @@ import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
|
||||
@@ -8,15 +8,17 @@ import importlib
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
import typing
|
||||
from abc import abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from functools import wraps
|
||||
from functools import wraps, partial
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Type, Union
|
||||
from typing import Dict, List, Optional, Tuple, Type, Union
|
||||
|
||||
import torch
|
||||
import transformers
|
||||
from datasets import Dataset
|
||||
from torch import nn
|
||||
from torch.optim.lr_scheduler import OneCycleLR
|
||||
from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler
|
||||
from transformers import (
|
||||
@@ -28,8 +30,8 @@ from transformers import (
|
||||
from transformers.trainer_utils import seed_worker
|
||||
from trl import DPOTrainer
|
||||
|
||||
from axolotl.monkeypatch.multipack import SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
from axolotl.monkeypatch.relora import ReLoRACallback, ReLoRAScheduler
|
||||
from axolotl.monkeypatch.utils import get_cu_seqlens_from_pos_ids
|
||||
from axolotl.utils.callbacks import (
|
||||
EvalFirstStepCallback,
|
||||
GPUStatsCallback,
|
||||
@@ -51,12 +53,20 @@ from axolotl.utils.schedulers import (
|
||||
get_cosine_schedule_with_min_lr,
|
||||
get_cosine_schedule_with_quadratic_warmup,
|
||||
)
|
||||
from axolotl.utils.tensors import keep_unpacked_data, split_and_pad_packed
|
||||
|
||||
try:
|
||||
import torch._dynamo # pylint: disable=ungrouped-imports
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
# hacky, but recommended per https://github.com/python/mypy/issues/5837
|
||||
_MixinTrainerBase = Trainer
|
||||
else:
|
||||
_MixinTrainerBase = object
|
||||
|
||||
|
||||
LOG = logging.getLogger("axolotl.core.trainer_builder")
|
||||
|
||||
|
||||
@@ -99,10 +109,6 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
default=False,
|
||||
metadata={"help": "Use sample packing for efficient training."},
|
||||
)
|
||||
multipack_real_batches: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use real batches for efficient training."},
|
||||
)
|
||||
eval_sample_packing: Optional[bool] = field(
|
||||
default=None,
|
||||
metadata={"help": "Use sample packing for efficient evals."},
|
||||
@@ -127,10 +133,6 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
relora_anneal_steps: Optional[int] = field(
|
||||
default=None,
|
||||
metadata={"help": "how many warmup steps to take after reset for ReLoRA"},
|
||||
)
|
||||
bench_split: Optional[str] = field(
|
||||
default="eval", metadata={"help": "The benchmark split to run on"}
|
||||
)
|
||||
@@ -162,7 +164,142 @@ class AxolotlTrainingArguments(TrainingArguments):
|
||||
)
|
||||
|
||||
|
||||
class AxolotlTrainer(Trainer):
|
||||
class AxolotlMultiPackTrainerMixin(_MixinTrainerBase): # type: ignore
|
||||
"""Trainer Mixin class for dataloaders and samplers"""
|
||||
|
||||
args = None # type: AxolotlTrainingArguments
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing and not self.args.pretraining:
|
||||
return MultipackBatchSampler(
|
||||
RandomSampler(self.train_dataset),
|
||||
self.args.train_batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=self._train_batch_size * self.args.max_seq_length,
|
||||
lengths=get_dataset_lengths(self.train_dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
)
|
||||
return super()._get_train_sampler()
|
||||
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
if self.args.sample_packing and not self.args.pretraining:
|
||||
train_dataset = self.train_dataset
|
||||
if "length" in train_dataset.features.keys():
|
||||
train_dataset = train_dataset.remove_columns(["length"])
|
||||
data_collator = self.data_collator
|
||||
dataloader_params = {
|
||||
"batch_size": self._train_batch_size,
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params[
|
||||
"prefetch_factor"
|
||||
] = self.args.dataloader_prefetch_factor
|
||||
|
||||
sampler = self._get_train_sampler()
|
||||
if isinstance(sampler, BatchSampler):
|
||||
dataloader_params["batch_sampler"] = sampler
|
||||
del dataloader_params["batch_size"]
|
||||
else:
|
||||
dataloader_params["sampler"] = sampler
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
dataloader_params["worker_init_fn"] = seed_worker
|
||||
|
||||
self.accelerator.even_batches = False
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(train_dataset, **dataloader_params)
|
||||
)
|
||||
return super().get_train_dataloader()
|
||||
|
||||
def _get_eval_sampler(
|
||||
self, eval_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
return MultipackBatchSampler(
|
||||
SequentialSampler(eval_dataset),
|
||||
self.args.per_device_eval_batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=self.args.eval_batch_size * self.args.max_seq_length,
|
||||
lengths=get_dataset_lengths(eval_dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
)
|
||||
return super()._get_eval_sampler(eval_dataset)
|
||||
|
||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||
self.eval_data_collator
|
||||
)
|
||||
dataloader = super().get_eval_dataloader(eval_dataset)
|
||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||
self.train_data_collator
|
||||
)
|
||||
return dataloader
|
||||
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
eval_dataset = (
|
||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
)
|
||||
|
||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||
data_collator = self.data_collator
|
||||
dataloader_params = {
|
||||
"batch_size": self.args.eval_batch_size,
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params[
|
||||
"prefetch_factor"
|
||||
] = self.args.dataloader_prefetch_factor
|
||||
|
||||
if isinstance(eval_sampler, BatchSampler):
|
||||
dataloader_params["batch_sampler"] = eval_sampler
|
||||
del dataloader_params["batch_size"]
|
||||
else:
|
||||
dataloader_params["sampler"] = eval_sampler
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
self.accelerator.even_batches = False
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(eval_dataset, **dataloader_params)
|
||||
)
|
||||
|
||||
return super().get_eval_dataloader(eval_dataset)
|
||||
|
||||
def _get_bench_sampler(
|
||||
self, bench_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.world_size <= 1:
|
||||
return SequentialSampler(bench_dataset)
|
||||
return None
|
||||
|
||||
def get_bench_dataloader(
|
||||
self,
|
||||
bench_dataset: Dataset,
|
||||
) -> DataLoader:
|
||||
dataloader_params = {
|
||||
"batch_size": self.args.eval_batch_size,
|
||||
"collate_fn": self.bench_data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||
|
||||
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
|
||||
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
return DataLoader(bench_dataset, **dataloader_params)
|
||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||
|
||||
|
||||
class AxolotlTrainer(AxolotlMultiPackTrainerMixin, Trainer):
|
||||
"""
|
||||
Extend the base Trainer for axolotl helpers
|
||||
"""
|
||||
@@ -236,151 +373,6 @@ class AxolotlTrainer(Trainer):
|
||||
|
||||
return self.lr_scheduler
|
||||
|
||||
def _get_train_sampler(self) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing and not self.args.pretraining:
|
||||
if self.args.multipack_real_batches:
|
||||
batch_size = self.args.per_device_train_batch_size
|
||||
batch_max_len = self.args.max_seq_length
|
||||
else:
|
||||
batch_size = 1
|
||||
batch_max_len = (
|
||||
self.args.per_device_train_batch_size * self.args.max_seq_length
|
||||
)
|
||||
return MultipackBatchSampler(
|
||||
RandomSampler(self.train_dataset),
|
||||
batch_size=batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=batch_max_len,
|
||||
lengths=get_dataset_lengths(self.train_dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
)
|
||||
return super()._get_train_sampler()
|
||||
|
||||
def _get_eval_sampler(
|
||||
self, eval_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
if self.args.multipack_real_batches:
|
||||
batch_size = self.args.per_device_eval_batch_size
|
||||
batch_max_len = self.args.max_seq_length
|
||||
else:
|
||||
batch_size = 1
|
||||
batch_max_len = (
|
||||
self.args.per_device_eval_batch_size * self.args.max_seq_length
|
||||
)
|
||||
return MultipackBatchSampler(
|
||||
SequentialSampler(eval_dataset),
|
||||
batch_size=batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=batch_max_len,
|
||||
lengths=get_dataset_lengths(eval_dataset),
|
||||
packing_efficiency_estimate=self.args.sample_packing_efficiency,
|
||||
)
|
||||
return super()._get_eval_sampler(eval_dataset)
|
||||
|
||||
def get_train_dataloader(self) -> DataLoader:
|
||||
if self.args.sample_packing and not self.args.pretraining:
|
||||
train_dataset = self.train_dataset
|
||||
if "length" in train_dataset.features.keys():
|
||||
train_dataset = train_dataset.remove_columns(["length"])
|
||||
data_collator = self.data_collator
|
||||
dataloader_params = {
|
||||
"batch_size": self._train_batch_size,
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params[
|
||||
"prefetch_factor"
|
||||
] = self.args.dataloader_prefetch_factor
|
||||
|
||||
sampler = self._get_train_sampler()
|
||||
if isinstance(sampler, BatchSampler):
|
||||
dataloader_params["batch_sampler"] = sampler
|
||||
del dataloader_params["batch_size"]
|
||||
else:
|
||||
dataloader_params["sampler"] = sampler
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
dataloader_params["worker_init_fn"] = seed_worker
|
||||
|
||||
self.accelerator.even_batches = False
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(train_dataset, **dataloader_params)
|
||||
)
|
||||
return super().get_train_dataloader()
|
||||
|
||||
def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is False:
|
||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||
self.eval_data_collator
|
||||
)
|
||||
dataloader = super().get_eval_dataloader(eval_dataset)
|
||||
self.data_collator = ( # pylint: disable=attribute-defined-outside-init
|
||||
self.train_data_collator
|
||||
)
|
||||
return dataloader
|
||||
|
||||
if self.args.sample_packing and self.args.eval_sample_packing is not False:
|
||||
eval_dataset = (
|
||||
eval_dataset if eval_dataset is not None else self.eval_dataset
|
||||
)
|
||||
|
||||
eval_sampler = self._get_eval_sampler(eval_dataset)
|
||||
eval_dataset = eval_dataset.remove_columns(["length"])
|
||||
data_collator = self.data_collator
|
||||
dataloader_params = {
|
||||
"batch_size": self.args.eval_batch_size,
|
||||
"collate_fn": data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params[
|
||||
"prefetch_factor"
|
||||
] = self.args.dataloader_prefetch_factor
|
||||
|
||||
if isinstance(eval_sampler, BatchSampler):
|
||||
dataloader_params["batch_sampler"] = eval_sampler
|
||||
del dataloader_params["batch_size"]
|
||||
else:
|
||||
dataloader_params["sampler"] = eval_sampler
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
self.accelerator.even_batches = False
|
||||
return self.accelerator.prepare_data_loader(
|
||||
DataLoader(eval_dataset, **dataloader_params)
|
||||
)
|
||||
|
||||
return super().get_eval_dataloader(eval_dataset)
|
||||
|
||||
def _get_bench_sampler(
|
||||
self, bench_dataset: Dataset
|
||||
) -> Optional[torch.utils.data.Sampler]:
|
||||
if self.args.world_size <= 1:
|
||||
return SequentialSampler(bench_dataset)
|
||||
return None
|
||||
|
||||
def get_bench_dataloader(
|
||||
self,
|
||||
bench_dataset: Dataset,
|
||||
) -> DataLoader:
|
||||
dataloader_params = {
|
||||
"batch_size": self.args.eval_batch_size,
|
||||
"collate_fn": self.bench_data_collator,
|
||||
"num_workers": self.args.dataloader_num_workers,
|
||||
"pin_memory": self.args.dataloader_pin_memory,
|
||||
}
|
||||
if self.args.dataloader_prefetch_factor:
|
||||
dataloader_params["prefetch_factor"] = self.args.dataloader_prefetch_factor
|
||||
|
||||
if not isinstance(bench_dataset, torch.utils.data.IterableDataset):
|
||||
dataloader_params["sampler"] = self._get_bench_sampler(bench_dataset)
|
||||
dataloader_params["drop_last"] = self.args.dataloader_drop_last
|
||||
|
||||
return DataLoader(bench_dataset, **dataloader_params)
|
||||
# return self.accelerator.prepare(DataLoader(bench_dataset, **dataloader_params))
|
||||
|
||||
def compute_loss(self, model, inputs, return_outputs=False):
|
||||
# use one's weighted cross entropy loss calc
|
||||
# if self.args.sample_packing:
|
||||
@@ -483,14 +475,10 @@ class ReLoRATrainer(AxolotlTrainer):
|
||||
warmup_steps = (
|
||||
self.args.relora_warmup_steps if self.args.relora_warmup_steps else 10
|
||||
)
|
||||
anneal_steps = (
|
||||
self.args.relora_anneal_steps if self.args.relora_anneal_steps else 1
|
||||
)
|
||||
self.lr_scheduler = ReLoRAScheduler(
|
||||
optimizer,
|
||||
lr_scheduler,
|
||||
self.args.relora_steps,
|
||||
anneal_steps,
|
||||
warmup_steps,
|
||||
)
|
||||
else:
|
||||
@@ -499,7 +487,7 @@ class ReLoRATrainer(AxolotlTrainer):
|
||||
return self.lr_scheduler
|
||||
|
||||
|
||||
class AxolotlDPOTrainer(DPOTrainer):
|
||||
class AxolotlDPOTrainer(AxolotlMultiPackTrainerMixin, DPOTrainer):
|
||||
"""
|
||||
Extend the base DPOTrainer for axolotl helpers
|
||||
"""
|
||||
@@ -516,6 +504,59 @@ class AxolotlDPOTrainer(DPOTrainer):
|
||||
|
||||
return super().push_to_hub(*args, **kwargs)
|
||||
|
||||
def tokenize_row(self, feature, *args, **kwargs) -> Dict:
|
||||
# check if dataset is already tokenized
|
||||
if not self.is_encoder_decoder:
|
||||
keys = [
|
||||
"chosen_input_ids",
|
||||
"chosen_attention_mask",
|
||||
"chosen_labels",
|
||||
"rejected_input_ids",
|
||||
"rejected_attention_mask",
|
||||
"rejected_labels",
|
||||
]
|
||||
if all(k in feature.keys() for k in keys):
|
||||
return feature
|
||||
else:
|
||||
keys = [
|
||||
"chosen_labels",
|
||||
"rejected_labels",
|
||||
"prompt_input_ids",
|
||||
"prompt_attention_mask",
|
||||
]
|
||||
if all(k in feature.keys() for k in keys):
|
||||
return feature
|
||||
return super().tokenize_row(feature, *args, **kwargs)
|
||||
|
||||
def concatenated_forward(
|
||||
self, model: nn.Module, batch: Dict[str, Union[List, torch.LongTensor]]
|
||||
) -> Tuple[
|
||||
torch.FloatTensor, torch.FloatTensor, torch.FloatTensor, torch.FloatTensor
|
||||
]:
|
||||
all_logits = model(
|
||||
batch["input_ids"],
|
||||
attention_mask=batch["attention_mask"],
|
||||
position_ids=batch["position_ids"],
|
||||
).logits
|
||||
cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(batch["position_ids"])
|
||||
logits_keep_fn = partial(keep_unpacked_data, pad_val=None, pairs=True)
|
||||
unpacked_logits = split_and_pad_packed(all_logits, cu_seqlens, max_seqlen, logits_keep_fn)
|
||||
labels_keep_fn = partial(keep_unpacked_data, pad_val=-100, pairs=True)
|
||||
unpacked_labels = split_and_pad_packed(batch["labels"], cu_seqlens, max_seqlen, labels_keep_fn)
|
||||
unpacked_logps = self.get_batch_logps(
|
||||
unpacked_logits,
|
||||
unpacked_labels,
|
||||
average_log_prob=self.loss_type == "ipo",
|
||||
is_encoder_decoder=self.is_encoder_decoder,
|
||||
label_pad_token_id=self.label_pad_token_id,
|
||||
)
|
||||
chosen_logps = unpacked_logps[::2]
|
||||
rejected_logps = unpacked_logps[1::2]
|
||||
chosen_logits = unpacked_logits[::2]
|
||||
rejected_logits = unpacked_logits[1::2]
|
||||
|
||||
return (chosen_logps, rejected_logps, chosen_logits, rejected_logits)
|
||||
|
||||
|
||||
class TrainerBuilderBase(abc.ABC):
|
||||
"""
|
||||
@@ -889,9 +930,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
training_arguments_kwargs["sample_packing"] = (
|
||||
self.cfg.sample_packing if self.cfg.sample_packing else False
|
||||
)
|
||||
training_arguments_kwargs["multipack_real_batches"] = (
|
||||
self.cfg.flash_attention is not True
|
||||
)
|
||||
training_arguments_kwargs["eval_sample_packing"] = (
|
||||
self.cfg.sample_packing
|
||||
if self.cfg.eval_sample_packing is not False
|
||||
@@ -902,7 +940,6 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
] = self.cfg.micro_batch_size
|
||||
training_arguments_kwargs["relora_steps"] = self.cfg.relora_steps
|
||||
training_arguments_kwargs["relora_warmup_steps"] = self.cfg.relora_warmup_steps
|
||||
training_arguments_kwargs["relora_anneal_steps"] = self.cfg.relora_anneal_steps
|
||||
training_arguments_kwargs = self.hook_pre_create_training_args(
|
||||
training_arguments_kwargs
|
||||
)
|
||||
@@ -995,12 +1032,7 @@ class HFCausalTrainerBuilder(TrainerBuilderBase):
|
||||
]
|
||||
]
|
||||
if use_batch_sampler_collator:
|
||||
if self.cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||
elif (
|
||||
self.cfg.model_config_type in ["llama"]
|
||||
and self.cfg.flash_attention is not True
|
||||
):
|
||||
if self.cfg.model_config_type in ["mixtral", "qwen2", "falcon", "phi"]:
|
||||
collator = V2BatchSamplerDataCollatorForSeq2Seq
|
||||
else:
|
||||
collator = BatchSamplerDataCollatorForSeq2Seq
|
||||
@@ -1097,21 +1129,13 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
|
||||
"use_reentrant": False
|
||||
}
|
||||
|
||||
# set save_strategy and save_steps
|
||||
if self.cfg.save_steps:
|
||||
training_args_kwargs["save_strategy"] = "steps"
|
||||
training_args_kwargs["save_steps"] = self.cfg.save_steps
|
||||
elif self.cfg.save_strategy:
|
||||
training_args_kwargs["save_strategy"] = self.cfg.save_strategy
|
||||
else:
|
||||
# default to saving each epoch if not defined
|
||||
training_args_kwargs["save_strategy"] = "epoch"
|
||||
|
||||
training_args = TrainingArguments(
|
||||
per_device_train_batch_size=self.cfg.micro_batch_size,
|
||||
max_steps=self.cfg.max_steps or total_num_steps,
|
||||
gradient_accumulation_steps=self.cfg.gradient_accumulation_steps,
|
||||
learning_rate=self.cfg.learning_rate,
|
||||
save_strategy="steps",
|
||||
save_steps=self.cfg.save_steps,
|
||||
output_dir=self.cfg.output_dir,
|
||||
warmup_steps=self.cfg.warmup_steps,
|
||||
logging_first_step=True,
|
||||
@@ -1154,6 +1178,7 @@ class HFDPOTrainerBuilder(TrainerBuilderBase):
|
||||
callbacks=self.get_callbacks(),
|
||||
**dpo_trainer_kwargs,
|
||||
)
|
||||
setattr(dpo_trainer, "use_dpo_data_collator", True)
|
||||
dpo_trainer = self.hook_post_create_trainer(dpo_trainer)
|
||||
for callback in self.get_post_trainer_create_callbacks(dpo_trainer):
|
||||
dpo_trainer.add_callback(callback)
|
||||
|
||||
@@ -31,7 +31,7 @@ class TokenizedPromptDataset(Dataset):
|
||||
def __init__( # pylint: disable=super-init-not-called
|
||||
self,
|
||||
prompt_tokenizer: PromptTokenizingStrategy,
|
||||
dataset: Dataset,
|
||||
dataset: IterableDataset,
|
||||
process_count: Optional[int] = None,
|
||||
keep_in_memory: Optional[bool] = False,
|
||||
**kwargs,
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
"""monkey patches for the dataset fetcher to handle batches of packed indexes"""
|
||||
# pylint: disable=protected-access
|
||||
|
||||
import torch
|
||||
from torch.utils.data._utils.fetch import _BaseDatasetFetcher
|
||||
from torch.utils.data._utils.worker import _worker_loop
|
||||
|
||||
|
||||
class _MapDatasetFetcher(_BaseDatasetFetcher):
|
||||
def fetch(self, possibly_batched_index):
|
||||
if isinstance(possibly_batched_index[0], list):
|
||||
data = [None for i in possibly_batched_index]
|
||||
for i, possibly_batched_index_ in enumerate(possibly_batched_index):
|
||||
if self.auto_collation:
|
||||
if (
|
||||
hasattr(self.dataset, "__getitems__")
|
||||
and self.dataset.__getitems__
|
||||
):
|
||||
data[i] = self.dataset.__getitems__(possibly_batched_index_)
|
||||
else:
|
||||
data[i] = [self.dataset[idx] for idx in possibly_batched_index_]
|
||||
else:
|
||||
data[i] = self.dataset[possibly_batched_index_]
|
||||
else:
|
||||
if self.auto_collation:
|
||||
if hasattr(self.dataset, "__getitems__") and self.dataset.__getitems__:
|
||||
data = self.dataset.__getitems__(possibly_batched_index)
|
||||
else:
|
||||
data = [self.dataset[idx] for idx in possibly_batched_index]
|
||||
else:
|
||||
data = self.dataset[possibly_batched_index]
|
||||
return self.collate_fn(data)
|
||||
|
||||
|
||||
def patch_fetchers():
|
||||
torch.utils.data._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
|
||||
torch.utils.data.dataloader._utils.fetch._MapDatasetFetcher = _MapDatasetFetcher
|
||||
|
||||
|
||||
def patched_worker_loop(*args, **kwargs):
|
||||
patch_fetchers()
|
||||
return _worker_loop(*args, **kwargs)
|
||||
|
||||
|
||||
torch.utils.data._utils.worker._worker_loop = patched_worker_loop
|
||||
patch_fetchers()
|
||||
12
src/axolotl/monkeypatch/falcon/__init__.py
Normal file
12
src/axolotl/monkeypatch/falcon/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Patches to support multipack for falcon
|
||||
"""
|
||||
import transformers
|
||||
|
||||
from axolotl.monkeypatch.utils import get_unpad_data
|
||||
|
||||
|
||||
def replace_falcon_attn_with_multipack_flash_attn():
|
||||
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
142
src/axolotl/monkeypatch/llama_attn_hijack_sdp.py
Normal file
142
src/axolotl/monkeypatch/llama_attn_hijack_sdp.py
Normal file
@@ -0,0 +1,142 @@
|
||||
"""
|
||||
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers.models.llama.modeling_llama
|
||||
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, repeat_kv
|
||||
|
||||
|
||||
def hijack_llama_sdp_attention():
|
||||
transformers.models.llama.modeling_llama.LlamaAttention.forward = (
|
||||
sdp_attention_forward
|
||||
)
|
||||
|
||||
|
||||
def sdp_attention_forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
position_ids: Optional[torch.LongTensor] = None,
|
||||
past_key_value: Optional[Tuple[torch.Tensor]] = None,
|
||||
output_attentions: bool = False,
|
||||
use_cache: bool = False,
|
||||
padding_mask: Optional[torch.LongTensor] = None, # pylint: disable=unused-argument
|
||||
**kwargs, # pylint: disable=unused-argument
|
||||
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
|
||||
# pylint: disable=duplicate-code
|
||||
bsz, q_len, _ = hidden_states.size()
|
||||
|
||||
if not hasattr(self, "pretraining_tp"):
|
||||
self.pretraining_tp = 1
|
||||
|
||||
if self.pretraining_tp > 1:
|
||||
key_value_slicing = (
|
||||
self.num_key_value_heads * self.head_dim
|
||||
) // self.pretraining_tp
|
||||
query_slices = self.q_proj.weight.split(
|
||||
(self.num_heads * self.head_dim) // self.pretraining_tp, dim=0
|
||||
)
|
||||
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
|
||||
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
|
||||
|
||||
query_states = [
|
||||
F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
query_states = torch.cat(query_states, dim=-1)
|
||||
|
||||
key_states = [
|
||||
F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
key_states = torch.cat(key_states, dim=-1)
|
||||
|
||||
value_states = [
|
||||
F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)
|
||||
]
|
||||
value_states = torch.cat(value_states, dim=-1)
|
||||
|
||||
else:
|
||||
query_states = self.q_proj(hidden_states)
|
||||
key_states = self.k_proj(hidden_states)
|
||||
value_states = self.v_proj(hidden_states)
|
||||
|
||||
query_states = query_states.view(
|
||||
bsz, q_len, self.num_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
key_states = key_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
value_states = value_states.view(
|
||||
bsz, q_len, self.num_key_value_heads, self.head_dim
|
||||
).transpose(1, 2)
|
||||
# [bsz, q_len, nh, hd]
|
||||
# [bsz, nh, q_len, hd]
|
||||
|
||||
kv_seq_len = key_states.shape[-2]
|
||||
if past_key_value is not None:
|
||||
kv_seq_len += past_key_value[0].shape[-2]
|
||||
|
||||
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
|
||||
query_states, key_states = apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, position_ids
|
||||
)
|
||||
# [bsz, nh, t, hd]
|
||||
|
||||
if past_key_value is not None:
|
||||
# reuse k, v, self_attention
|
||||
key_states = torch.cat([past_key_value[0], key_states], dim=2)
|
||||
value_states = torch.cat([past_key_value[1], value_states], dim=2)
|
||||
|
||||
past_key_value = (key_states, value_states) if use_cache else None
|
||||
|
||||
# repeat k/v heads if n_kv_heads < n_heads
|
||||
key_states = repeat_kv(key_states, self.num_key_value_groups)
|
||||
value_states = repeat_kv(value_states, self.num_key_value_groups)
|
||||
|
||||
if output_attentions:
|
||||
warnings.warn(
|
||||
"Output attentions is not supported for patched `LlamaAttention`, returning `None` instead."
|
||||
)
|
||||
|
||||
#
|
||||
# sdp-attn start
|
||||
#
|
||||
|
||||
with torch.backends.cuda.sdp_kernel():
|
||||
attn_output = torch.nn.functional.scaled_dot_product_attention(
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attn_mask=attention_mask,
|
||||
is_causal=False,
|
||||
)
|
||||
|
||||
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
|
||||
raise ValueError(
|
||||
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
|
||||
f" {attn_output.size()}"
|
||||
)
|
||||
attn_output = attn_output.transpose(1, 2)
|
||||
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
|
||||
|
||||
#
|
||||
# sdp-attn end
|
||||
#
|
||||
|
||||
if self.pretraining_tp > 1:
|
||||
attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
|
||||
o_proj_slices = self.o_proj.weight.split(
|
||||
self.hidden_size // self.pretraining_tp, dim=1
|
||||
)
|
||||
attn_output = sum(
|
||||
F.linear(attn_output[i], o_proj_slices[i])
|
||||
for i in range(self.pretraining_tp)
|
||||
)
|
||||
else:
|
||||
attn_output = self.o_proj(attn_output)
|
||||
|
||||
return attn_output, None, past_key_value
|
||||
@@ -5,11 +5,38 @@ from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
from axolotl.monkeypatch.utils import mask_2d_to_4d
|
||||
|
||||
|
||||
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
|
||||
masked_zero_one_mask = mask_2d_to_4d(mask, dtype, tgt_len)
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
This expansion handles packed sequences so that sequences share the same attention mask integer value
|
||||
when they attend to each other within that sequence.
|
||||
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
mask = mask.unsqueeze(1).unsqueeze(2)
|
||||
mask = mask.expand(bsz, 1, tgt_len, src_len)
|
||||
|
||||
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||
binary_mask = torch.where(
|
||||
mask != 0,
|
||||
torch.tensor(1).to(dtype),
|
||||
torch.tensor(0).to(dtype),
|
||||
)
|
||||
|
||||
# Create a block-diagonal mask.
|
||||
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
|
||||
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
|
||||
|
||||
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
|
||||
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
|
||||
mask.device
|
||||
)
|
||||
|
||||
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
|
||||
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
|
||||
inverted_mask = 1.0 - masked_zero_one_mask
|
||||
|
||||
return inverted_mask.masked_fill(
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
"""
|
||||
Patched LlamaAttention to use torch.nn.functional.scaled_dot_product_attention
|
||||
"""
|
||||
|
||||
from axolotl.monkeypatch.utils import (
|
||||
patched_prepare_4d_causal_attention_mask,
|
||||
patched_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
|
||||
|
||||
def hijack_llama_prepare_4d_mask():
|
||||
import transformers.modeling_attn_mask_utils
|
||||
import transformers.models.llama.modeling_llama
|
||||
|
||||
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
|
||||
patched_prepare_4d_causal_attention_mask_for_sdpa
|
||||
)
|
||||
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask_for_sdpa = ( # pylint: disable=protected-access
|
||||
patched_prepare_4d_causal_attention_mask_for_sdpa
|
||||
)
|
||||
transformers.models.llama.modeling_llama._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
|
||||
patched_prepare_4d_causal_attention_mask
|
||||
)
|
||||
transformers.modeling_attn_mask_utils._prepare_4d_causal_attention_mask = ( # pylint: disable=protected-access
|
||||
patched_prepare_4d_causal_attention_mask
|
||||
)
|
||||
@@ -2,6 +2,9 @@
|
||||
Patches to support multipack for mixtral
|
||||
"""
|
||||
import torch
|
||||
import transformers
|
||||
|
||||
from axolotl.monkeypatch.utils import get_unpad_data
|
||||
|
||||
|
||||
def patch_mixtral_moe_forward_zero3() -> None:
|
||||
@@ -48,3 +51,11 @@ def patch_mixtral_moe_forward_zero3() -> None:
|
||||
|
||||
MixtralBLockSparseTop2MLP.forward = mlp_forward
|
||||
MixtralSparseMoeBlock.forward = moe_forward
|
||||
|
||||
|
||||
def replace_mixtral_attn_with_multipack_flash_attn(for_zero3=False):
|
||||
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
if for_zero3:
|
||||
patch_mixtral_moe_forward_zero3()
|
||||
|
||||
@@ -1,30 +0,0 @@
|
||||
"""multipack patching for v2 of sample packing"""
|
||||
|
||||
import transformers
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
|
||||
from axolotl.monkeypatch.mixtral import patch_mixtral_moe_forward_zero3
|
||||
from axolotl.monkeypatch.utils import get_unpad_data
|
||||
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES = ["mixtral", "qwen2", "falcon", "phi"]
|
||||
|
||||
|
||||
def patch_for_multipack(model_type):
|
||||
if model_type == "mixtral":
|
||||
transformers.models.mixtral.modeling_mixtral._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
if is_deepspeed_zero3_enabled():
|
||||
patch_mixtral_moe_forward_zero3()
|
||||
elif model_type == "qwen2":
|
||||
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "falcon":
|
||||
transformers.models.falcon.modeling_falcon._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
elif model_type == "phi":
|
||||
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
12
src/axolotl/monkeypatch/phi/__init__.py
Normal file
12
src/axolotl/monkeypatch/phi/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Patches to support multipack for phi2
|
||||
"""
|
||||
import transformers
|
||||
|
||||
from axolotl.monkeypatch.utils import get_unpad_data
|
||||
|
||||
|
||||
def replace_phi_attn_with_multipack_flash_attn():
|
||||
transformers.models.phi.modeling_phi._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
12
src/axolotl/monkeypatch/qwen2/__init__.py
Normal file
12
src/axolotl/monkeypatch/qwen2/__init__.py
Normal file
@@ -0,0 +1,12 @@
|
||||
"""
|
||||
Patches to support multipack for qwen2
|
||||
"""
|
||||
import transformers
|
||||
|
||||
from axolotl.monkeypatch.utils import get_unpad_data
|
||||
|
||||
|
||||
def replace_qwen2_attn_with_multipack_flash_attn():
|
||||
transformers.models.qwen2.modeling_qwen2._get_unpad_data = ( # pylint: disable=protected-access
|
||||
get_unpad_data
|
||||
)
|
||||
@@ -4,16 +4,14 @@ import json
|
||||
import logging
|
||||
import os.path
|
||||
import shutil
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Sequence, Union
|
||||
from typing import Dict, List, Sequence
|
||||
|
||||
import bitsandbytes as bnb
|
||||
import peft
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
from huggingface_hub import snapshot_download
|
||||
from torch.distributed.optim import ZeroRedundancyOptimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
from torch.optim.optimizer import Optimizer
|
||||
from transformers import (
|
||||
@@ -25,50 +23,23 @@ from transformers import (
|
||||
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR
|
||||
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.distributed import barrier, is_main_process
|
||||
from axolotl.utils.distributed import is_main_process
|
||||
|
||||
LOG = logging.getLogger("axolotl.relora")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def magnitude_pruning_(tensor, prune_ratio):
|
||||
tensor_magnitude = torch.abs(tensor)
|
||||
threshold = torch.quantile(
|
||||
tensor_magnitude.flatten().to(dtype=torch.float32), prune_ratio
|
||||
).to(dtype=tensor.dtype)
|
||||
def reset_optimizer(optimizer: torch.optim.Optimizer):
|
||||
for group in optimizer.param_groups:
|
||||
for param in group["params"]:
|
||||
param_state = optimizer.state[param]
|
||||
for key in param_state:
|
||||
if "qmap" in key:
|
||||
continue
|
||||
|
||||
mask = tensor_magnitude > threshold
|
||||
tensor.mul_(mask.to(dtype=tensor.dtype))
|
||||
|
||||
|
||||
def reset_optimizer(
|
||||
optimizer: torch.optim.Optimizer,
|
||||
*,
|
||||
reset_params: list[str], # where str is the key to a torch.nn.Parameter
|
||||
optimizer_state_keys: list[str],
|
||||
):
|
||||
pruning_fn = partial(magnitude_pruning_, prune_ratio=0.9)
|
||||
n_zeros = 0
|
||||
n_total = 0
|
||||
|
||||
optimizer_state = optimizer.state
|
||||
if isinstance(optimizer, ZeroRedundancyOptimizer):
|
||||
optimizer_state = optimizer.optim.state
|
||||
|
||||
for param in reset_params:
|
||||
param_state = optimizer_state[param]
|
||||
if len(param_state) == 0: # no state for this param, happens for ZeRo optimizer
|
||||
continue
|
||||
for key in optimizer_state_keys:
|
||||
pruning_fn(
|
||||
param_state[key]
|
||||
) # pruning fn has to be inplace to keep the same keys in the dict
|
||||
n_total += param_state[key].numel()
|
||||
n_zeros += torch.sum(param_state[key] == 0).item()
|
||||
|
||||
_zeroed = n_zeros / (1e-7 + n_total) * 100
|
||||
LOG.info(f"Percent of optimizer states zeroed: {_zeroed:.2f}")
|
||||
LOG.info(f"absolute n of optimizer states zeroed: {n_zeros}")
|
||||
if key == "step" and isinstance(param_state[key], int):
|
||||
param_state[key] = 0
|
||||
else:
|
||||
param_state[key] = torch.zeros_like(param_state[key])
|
||||
|
||||
|
||||
class ReLoRACallback(TrainerCallback):
|
||||
@@ -126,25 +97,6 @@ class ReLoRACallback(TrainerCallback):
|
||||
"relora",
|
||||
)
|
||||
|
||||
if "adam" in args.optim.lower():
|
||||
optimizer_state_keys = ["exp_avg", "exp_avg_sq"]
|
||||
else:
|
||||
raise ValueError(f"Optimizer {args.optim} not supported with ReLoRA")
|
||||
|
||||
lora_params = [
|
||||
n
|
||||
for n, p in model.named_parameters()
|
||||
if p.requires_grad and "lora_" in n
|
||||
]
|
||||
|
||||
model.save_pretrained(
|
||||
os.path.join(
|
||||
args.output_dir,
|
||||
f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}",
|
||||
"adapter",
|
||||
),
|
||||
safe_serialization=True,
|
||||
)
|
||||
with torch.no_grad():
|
||||
merge_and_save(
|
||||
model,
|
||||
@@ -155,11 +107,7 @@ class ReLoRACallback(TrainerCallback):
|
||||
actually_save=is_main_process(),
|
||||
cpu_offload=self.cpu_offload,
|
||||
)
|
||||
reset_optimizer(
|
||||
optimizer,
|
||||
reset_params=lora_params,
|
||||
optimizer_state_keys=optimizer_state_keys,
|
||||
)
|
||||
reset_optimizer(optimizer)
|
||||
|
||||
if self.quantized:
|
||||
self.last_full_model = checkpoint_folder
|
||||
@@ -249,13 +197,11 @@ class ReLoRAScheduler(LRScheduler):
|
||||
inner_schedule: LRScheduler,
|
||||
relora_steps: int,
|
||||
warmup_steps: int,
|
||||
anneal_steps: int = 1,
|
||||
min_lr_scale: float = 0.001,
|
||||
) -> None:
|
||||
self.inner_schedule = inner_schedule
|
||||
self.relora_steps = relora_steps
|
||||
self.warmup_steps = warmup_steps
|
||||
self.anneal_steps = anneal_steps
|
||||
self.min_lr_scale = min_lr_scale
|
||||
super().__init__(optimizer, inner_schedule.last_epoch, inner_schedule.verbose)
|
||||
|
||||
@@ -264,20 +210,10 @@ class ReLoRAScheduler(LRScheduler):
|
||||
|
||||
original = self.inner_schedule.get_lr()
|
||||
step = self.last_epoch
|
||||
|
||||
if step < self.relora_steps:
|
||||
scale = 1
|
||||
else:
|
||||
per_relora_progress = step % self.relora_steps
|
||||
if per_relora_progress < self.warmup_steps:
|
||||
cycle_t = min(1.0, (per_relora_progress) / self.warmup_steps)
|
||||
elif per_relora_progress > (self.relora_steps - self.anneal_steps):
|
||||
cycle_t = min(
|
||||
1.0,
|
||||
(self.relora_steps - per_relora_progress) / self.anneal_steps,
|
||||
)
|
||||
else:
|
||||
cycle_t = 1
|
||||
cycle_t = min(1.0, (step % self.relora_steps) / self.warmup_steps)
|
||||
scale = cycle_t * (1 - self.min_lr_scale) + self.min_lr_scale
|
||||
|
||||
if isinstance(original, Sequence):
|
||||
@@ -302,11 +238,7 @@ def sharded_paths(path: str, module_names: List[str]) -> Dict[str, str]:
|
||||
|
||||
def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor:
|
||||
if isinstance(layer, (peft.tuners.lora.Linear8bitLt, peft.tuners.lora.Linear4bit)):
|
||||
adapter: Union[List[str], str] = layer.active_adapter
|
||||
if isinstance(adapter, list):
|
||||
if len(adapter) > 1:
|
||||
raise ValueError("unhandled relora for multiple adapters")
|
||||
adapter = adapter[0]
|
||||
adapter = layer.active_adapter
|
||||
return (
|
||||
peft.utils.transpose(
|
||||
layer.lora_B[adapter].weight.detach().to(device)
|
||||
@@ -316,7 +248,7 @@ def lora_delta_weight(layer: peft.tuners.lora.LoraLayer, device) -> torch.Tensor
|
||||
* layer.scaling[adapter]
|
||||
)
|
||||
|
||||
raise ValueError("unhandled lora layer type")
|
||||
return layer.get_delta_weight().to(device)
|
||||
|
||||
|
||||
def find_lora_modules(model: peft.LoraModel) -> Dict[str, peft.tuners.lora.LoraLayer]:
|
||||
@@ -341,9 +273,9 @@ def update_weights(
|
||||
):
|
||||
if reinit:
|
||||
for adapter_name in target.lora_A:
|
||||
target.reset_lora_parameters(adapter_name, True)
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
for adapter_name in target.lora_embedding_A:
|
||||
target.reset_lora_parameters(adapter_name, True)
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
|
||||
if isinstance(target, peft.tuners.lora.Linear4bit):
|
||||
# This could be faster, but the quantization of Linear4bit weights occurs
|
||||
@@ -354,9 +286,7 @@ def update_weights(
|
||||
target.weight.data = new_weight.cpu()
|
||||
target.to(device)
|
||||
elif isinstance(target, peft.tuners.lora.Linear8bitLt):
|
||||
target.weight.data = (
|
||||
bnb.nn.Int8Params(new_weight, requires_grad=False).to(device).data
|
||||
)
|
||||
target.weight = bnb.nn.Int8Params(new_weight, requires_grad=False).to(device)
|
||||
else:
|
||||
target.weight.data = new_weight.to(device)
|
||||
|
||||
@@ -374,17 +304,14 @@ def merge_and_save(
|
||||
|
||||
if not quantized:
|
||||
for module_name, target in modules.items():
|
||||
active_adapter = target.active_adapter
|
||||
if isinstance(active_adapter, list):
|
||||
active_adapter = active_adapter[0]
|
||||
update = target.get_delta_weight(active_adapter).detach()
|
||||
update = target.get_delta_weight(target.active_adapter).detach()
|
||||
target.weight.data += update
|
||||
|
||||
if reinit:
|
||||
for adapter_name in target.lora_A:
|
||||
target.reset_lora_parameters(adapter_name, True)
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
for adapter_name in target.lora_embedding_A:
|
||||
target.reset_lora_parameters(adapter_name, True)
|
||||
target.reset_lora_parameters(adapter_name)
|
||||
return
|
||||
|
||||
os.makedirs(model_dst, exist_ok=True)
|
||||
@@ -436,7 +363,6 @@ def merge_and_save(
|
||||
LOG.info(f"saving tensors to {shard_fn}")
|
||||
st.save_file(out_tensors, shard_fn, metadata={"format": "pt"})
|
||||
|
||||
barrier()
|
||||
del in_tensors
|
||||
del out_tensors
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -1,15 +1,8 @@
|
||||
"""
|
||||
Shared utils for the monkeypatches
|
||||
"""
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from transformers.modeling_attn_mask_utils import (
|
||||
_prepare_4d_causal_attention_mask,
|
||||
_prepare_4d_causal_attention_mask_for_sdpa,
|
||||
)
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
@@ -96,6 +89,7 @@ def get_cu_seqlens(attn_mask):
|
||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def get_cu_seqlens_from_pos_ids(position_ids):
|
||||
"""generate a cumulative sequence length mask for flash attention using pos ids"""
|
||||
if len(position_ids.shape) == 1:
|
||||
@@ -141,18 +135,7 @@ def get_cu_seqlens_from_pos_ids(position_ids):
|
||||
results.append(cu_seqlens)
|
||||
max_seq_lens.append(max_seq_len)
|
||||
|
||||
# Find the maximum value across all tensors
|
||||
max_value = max(t.max() for t in results)
|
||||
|
||||
# Find the length of the longest tensor
|
||||
max_length = max(t.size(0) for t in results)
|
||||
|
||||
# Pad each tensor to the same length and collect them in a list
|
||||
padded_results = [
|
||||
F.pad(t, (0, max_length - t.size(0)), "constant", max_value) for t in results
|
||||
]
|
||||
|
||||
return torch.stack(padded_results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||
return torch.stack(results).to(dtype=torch.int32), torch.stack(max_seq_lens)
|
||||
|
||||
|
||||
def set_module_name(model, name, value):
|
||||
@@ -166,62 +149,3 @@ def set_module_name(model, name, value):
|
||||
child_name = name
|
||||
|
||||
setattr(parent, child_name, value)
|
||||
|
||||
|
||||
def mask_2d_to_4d(
|
||||
mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
|
||||
):
|
||||
"""
|
||||
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
|
||||
This expansion handles packed sequences so that sequences share the same attention mask integer value
|
||||
when they attend to each other within that sequence.
|
||||
This expansion transforms the mask to lower triangular form to prevent future peeking.
|
||||
"""
|
||||
bsz, src_len = mask.size()
|
||||
tgt_len = tgt_len if tgt_len is not None else src_len
|
||||
|
||||
mask = mask.unsqueeze(1).unsqueeze(2)
|
||||
mask = mask.expand(bsz, 1, tgt_len, src_len)
|
||||
|
||||
# Create a binary mask from the original mask where zeros remain zeros and all other values are set to one
|
||||
binary_mask = torch.where(
|
||||
mask != 0,
|
||||
torch.tensor(1, device=mask.device).to(dtype),
|
||||
torch.tensor(0, device=mask.device).to(dtype),
|
||||
)
|
||||
|
||||
# Create a block-diagonal mask.
|
||||
# we multiply by the binary mask so that 0's in the original mask are correctly excluded
|
||||
zero_one_mask = torch.eq(mask, mask.transpose(-1, -2)).int() * binary_mask
|
||||
|
||||
# Now let's create a lower triangular mask of ones that will zero out the upper triangular part
|
||||
lower_triangular_ones = torch.tril(torch.ones((tgt_len, src_len), dtype=dtype)).to(
|
||||
mask.device
|
||||
)
|
||||
|
||||
# Use the lower triangular mask to zero out the upper triangular part of the zero_one_mask
|
||||
masked_zero_one_mask = zero_one_mask * lower_triangular_ones
|
||||
|
||||
return masked_zero_one_mask
|
||||
|
||||
|
||||
def patched_prepare_4d_causal_attention_mask(
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
*args,
|
||||
):
|
||||
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
|
||||
return _prepare_4d_causal_attention_mask(
|
||||
mask_2d_to_4d(attention_mask, dtype=dtype),
|
||||
*args,
|
||||
)
|
||||
|
||||
|
||||
def patched_prepare_4d_causal_attention_mask_for_sdpa(
|
||||
attention_mask: Optional[torch.Tensor],
|
||||
*args,
|
||||
):
|
||||
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float32
|
||||
return _prepare_4d_causal_attention_mask_for_sdpa(
|
||||
mask_2d_to_4d(attention_mask, dtype=dtype),
|
||||
*args,
|
||||
)
|
||||
|
||||
@@ -1,28 +0,0 @@
|
||||
import os
|
||||
from typing import Callable, Generator, Tuple
|
||||
|
||||
import psycopg
|
||||
import psycopg.conninfo
|
||||
|
||||
|
||||
def pgsql(pgsql_table=None, id_field="id", **kwargs) -> Callable:
|
||||
pgsql_conn = os.environ.get("PGSQL_CONN", None)
|
||||
if not pgsql_conn:
|
||||
raise ValueError("missing PGSQL_CONN environment variable")
|
||||
conn_dict = psycopg.conninfo.conninfo_to_dict(pgsql_conn)
|
||||
|
||||
def data_generator() -> Generator[Tuple, None, None]:
|
||||
with psycopg.connect(**conn_dict) as conn:
|
||||
with conn.cursor() as cur:
|
||||
page_size = 10
|
||||
last_id = None
|
||||
while True:
|
||||
if last_id:
|
||||
where_clause = f" WHERE {id_field} > {last_id}"
|
||||
cur.execute(
|
||||
f"SELECT * FROM {pgsql_table}{where_clause} ORDER BY {id_field} ASC LIMIT {page_size}"
|
||||
)
|
||||
for row in cur.fetchall():
|
||||
yield row[id_field], dict(row)
|
||||
|
||||
return data_generator
|
||||
@@ -1,33 +0,0 @@
|
||||
"""Module containing the InstructShareGPTPromptTokenizingStrategy class"""
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from axolotl.prompt_tokenizers import ShareGPTPromptTokenizingStrategy
|
||||
from axolotl.prompters import ShareGPTPrompterV2
|
||||
|
||||
|
||||
def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
|
||||
conversation = (
|
||||
ds_cfg["conversation"] if ds_cfg and "conversation" in ds_cfg else None
|
||||
)
|
||||
strategy = InstructShareGPTPromptTokenizingStrategy(
|
||||
# pylint: disable=duplicate-code
|
||||
ShareGPTPrompterV2(
|
||||
conversation=conversation,
|
||||
),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
)
|
||||
return strategy
|
||||
|
||||
|
||||
class InstructShareGPTPromptTokenizingStrategy(ShareGPTPromptTokenizingStrategy):
|
||||
"""
|
||||
basic sharegpt strategy to grab conversations from the sample row
|
||||
"""
|
||||
|
||||
def get_conversation_thread(self, prompt):
|
||||
return [
|
||||
{"from": "human", "value": prompt["instruction"]},
|
||||
{"from": "gpt", "value": prompt["output"]},
|
||||
]
|
||||
@@ -1,58 +0,0 @@
|
||||
"""pretraining prompt strategies"""
|
||||
from typing import Generator
|
||||
|
||||
from transformers import BatchEncoding
|
||||
|
||||
from axolotl.prompt_tokenizers import PromptTokenizingStrategy
|
||||
|
||||
|
||||
class PretrainTokenizer:
|
||||
"""basic tokenization class for pretraining"""
|
||||
|
||||
def build_prompt(self, prompt) -> Generator[str, None, None]:
|
||||
yield prompt
|
||||
|
||||
|
||||
class PretrainTokenizationStrategy(PromptTokenizingStrategy):
|
||||
"""handles tokenization for pretraining with strides"""
|
||||
|
||||
@property
|
||||
def supports_batched(self):
|
||||
return True
|
||||
|
||||
def __init__(self, *args, max_length=None, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if max_length:
|
||||
self.max_length = max_length
|
||||
|
||||
def _tokenize(
|
||||
self, prompt: str, add_eos_token: bool = True, strip_bos_token: bool = False
|
||||
) -> BatchEncoding:
|
||||
res = self.tokenizer(
|
||||
prompt,
|
||||
truncation=True,
|
||||
max_length=self.max_length - 1,
|
||||
add_special_tokens=True,
|
||||
return_overflowing_tokens=True,
|
||||
stride=256,
|
||||
)
|
||||
res["input_ids"] = [
|
||||
seq + [self.tokenizer.eos_token_id] for seq in res["input_ids"]
|
||||
]
|
||||
res["attention_mask"] = [seq + [1] for seq in res["attention_mask"]]
|
||||
|
||||
return res
|
||||
|
||||
def tokenize_prompt(self, prompt):
|
||||
return self._tokenize(prompt["text"])
|
||||
|
||||
|
||||
def load(tokenizer, cfg):
|
||||
strat = PretrainTokenizationStrategy(
|
||||
PretrainTokenizer(),
|
||||
tokenizer,
|
||||
cfg.train_on_inputs,
|
||||
cfg.sequence_len,
|
||||
max_length=cfg.sequence_len * 64,
|
||||
)
|
||||
return strat
|
||||
@@ -11,6 +11,7 @@ import torch
|
||||
import transformers.modelcard
|
||||
from accelerate.logging import get_logger
|
||||
from datasets import Dataset
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from peft import PeftModel
|
||||
from pkg_resources import get_distribution # type: ignore
|
||||
from transformers import PreTrainedModel, PreTrainedTokenizer
|
||||
@@ -23,11 +24,6 @@ from axolotl.utils.freeze import freeze_parameters_except
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.trainer import setup_trainer
|
||||
|
||||
try:
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
except ImportError:
|
||||
BetterTransformer = None
|
||||
|
||||
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
src_dir = os.path.join(project_root, "src")
|
||||
sys.path.insert(0, src_dir)
|
||||
@@ -61,21 +57,6 @@ def train(
|
||||
eval_dataset = dataset_meta.eval_dataset
|
||||
total_num_steps = dataset_meta.total_num_steps
|
||||
|
||||
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
||||
possible_checkpoints = [
|
||||
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
||||
]
|
||||
if len(possible_checkpoints) > 0:
|
||||
sorted_paths = sorted(
|
||||
possible_checkpoints,
|
||||
key=lambda path: int(path.split("-")[-1]),
|
||||
)
|
||||
cfg.resume_from_checkpoint = sorted_paths[-1]
|
||||
LOG.info(
|
||||
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
||||
)
|
||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||
|
||||
# Load the model and tokenizer
|
||||
msg = "loading model"
|
||||
if cfg.adapter:
|
||||
@@ -98,6 +79,21 @@ def train(
|
||||
|
||||
safe_serialization = cfg.save_safetensors is True
|
||||
|
||||
if cfg.resume_from_checkpoint is None and cfg.auto_resume_from_checkpoints:
|
||||
possible_checkpoints = [
|
||||
str(cp) for cp in Path(cfg.output_dir).glob("checkpoint-*")
|
||||
]
|
||||
if len(possible_checkpoints) > 0:
|
||||
sorted_paths = sorted(
|
||||
possible_checkpoints,
|
||||
key=lambda path: int(path.split("-")[-1]),
|
||||
)
|
||||
cfg.resume_from_checkpoint = sorted_paths[-1]
|
||||
LOG.info(
|
||||
f"Using Auto-resume functionality to start with checkpoint at {cfg.resume_from_checkpoint}"
|
||||
)
|
||||
resume_from_checkpoint = cfg.resume_from_checkpoint
|
||||
|
||||
if cfg.unfrozen_parameters:
|
||||
freeze_parameters_except(model, cfg.unfrozen_parameters)
|
||||
|
||||
@@ -128,7 +124,7 @@ def train(
|
||||
if cfg.local_rank == 0:
|
||||
|
||||
def terminate_handler(_, __, model):
|
||||
if cfg.flash_optimum and BetterTransformer:
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.reverse(model)
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
sys.exit(0)
|
||||
@@ -153,10 +149,7 @@ def train(
|
||||
pretrain_hooks(cfg, trainer)
|
||||
if cfg.flash_optimum:
|
||||
with torch.backends.cuda.sdp_kernel(
|
||||
# TODO configure these from the YAML w/ sdp_kernel_kwargs: ...
|
||||
enable_flash=True,
|
||||
enable_math=True,
|
||||
enable_mem_efficient=True,
|
||||
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
||||
):
|
||||
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
|
||||
else:
|
||||
@@ -202,16 +195,13 @@ def train(
|
||||
state_dict=trainer.accelerator.get_state_dict(trainer.model_wrapped),
|
||||
)
|
||||
elif cfg.local_rank == 0:
|
||||
if cfg.flash_optimum and BetterTransformer:
|
||||
if cfg.flash_optimum:
|
||||
model = BetterTransformer.reverse(model)
|
||||
|
||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||
|
||||
if not cfg.hub_model_id:
|
||||
try:
|
||||
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
||||
except AttributeError:
|
||||
pass
|
||||
trainer.create_model_card(model_name=cfg.output_dir.lstrip("./"))
|
||||
elif cfg.hub_model_id:
|
||||
# defensively push to the hub to ensure the model card is updated
|
||||
trainer.push_to_hub()
|
||||
|
||||
@@ -47,12 +47,6 @@ def gpu_memory_usage_all(device=0):
|
||||
return usage, reserved - usage, max(0, smi - reserved)
|
||||
|
||||
|
||||
def mps_memory_usage_all():
|
||||
usage = torch.mps.current_allocated_memory() / 1024.0**3
|
||||
reserved = torch.mps.driver_allocated_memory() / 1024.0**3
|
||||
return usage, reserved - usage, 0
|
||||
|
||||
|
||||
@check_cuda_device(0.0)
|
||||
def gpu_memory_usage_smi(device=0):
|
||||
if isinstance(device, torch.device):
|
||||
@@ -69,10 +63,7 @@ def gpu_memory_usage_smi(device=0):
|
||||
|
||||
|
||||
def log_gpu_memory_usage(log, msg, device):
|
||||
if torch.backends.mps.is_available():
|
||||
usage, cache, misc = mps_memory_usage_all()
|
||||
else:
|
||||
usage, cache, misc = gpu_memory_usage_all(device)
|
||||
usage, cache, misc = gpu_memory_usage_all(device)
|
||||
extras = []
|
||||
if cache > 0:
|
||||
extras.append(f"+{cache:.03f}GB cache")
|
||||
|
||||
@@ -19,7 +19,6 @@ def chat_templates(user_choice: str):
|
||||
"""
|
||||
|
||||
templates = {
|
||||
"alpaca": "{% for message in messages %}{% if message['role'] == 'user' %}{{ '### Instruction: ' + message['content'] + '\n\n' }}{% elif message['role'] == 'assistant' %}{{ '### Response: ' + message['content'] + eos_token}}{% endif %}{% endfor %}",
|
||||
"inst": "{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token}}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}", # I don't know what this one is called. Used by Mistral/Mixtral.
|
||||
"chatml": "{% if messages[0]['role'] == 'system' %}{% set loop_messages = messages[1:] %}{% set system_message = messages[0]['content'] %}{% else %}{% set loop_messages = messages %}{% set system_message = 'You are a helpful assistant.' %}{% endif %}{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in loop_messages %}{% if loop.index0 == 0 %}{{'<|im_start|>system\n' + system_message + '<|im_end|>\n'}}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}",
|
||||
}
|
||||
|
||||
@@ -132,26 +132,24 @@ class BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
"""
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
if not isinstance(features[0], list):
|
||||
features = [features]
|
||||
out_features = [{} for _ in features]
|
||||
for i, features_ in enumerate(features):
|
||||
for feature in features_[0].keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(1) * np.array(item[feature])
|
||||
for i, item in enumerate(features_)
|
||||
if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features_ if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
return super().__call__(out_features, return_tensors=return_tensors)
|
||||
chunked_data = {}
|
||||
for feature in features[0].keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(1) * np.array(item[feature])
|
||||
for item in features
|
||||
if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
features = [chunked_data]
|
||||
return super().__call__(features, return_tensors=return_tensors)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -161,27 +159,28 @@ class V2BatchSamplerDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
|
||||
"""
|
||||
|
||||
def __call__(self, features, return_tensors=None):
|
||||
if not isinstance(features[0], list):
|
||||
features = [features]
|
||||
out_features = [{} for _ in features]
|
||||
for i, features_ in enumerate(features):
|
||||
for feature in features_[0].keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(i + 1) * np.array(item[feature])
|
||||
for i, item in enumerate(features_)
|
||||
if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features_ if feature in item
|
||||
]
|
||||
out_features[i][feature] = np.concatenate(arrays)
|
||||
return super().__call__(out_features, return_tensors=return_tensors)
|
||||
chunked_data = {}
|
||||
for feature in features[0].keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
if feature == "attention_mask":
|
||||
arrays = [
|
||||
(i + 1) * np.array(item[feature])
|
||||
for i, item in enumerate(features)
|
||||
if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
else:
|
||||
arrays = [
|
||||
np.array(item[feature]) for item in features if feature in item
|
||||
]
|
||||
chunked_data[feature] = np.concatenate(arrays)
|
||||
features = [chunked_data]
|
||||
return super().__call__(features, return_tensors=return_tensors)
|
||||
|
||||
@dataclass
|
||||
class BatchSamplerDPODataCollatorWithPadding:
|
||||
|
||||
|
||||
@dataclass
|
||||
class MambaDataCollator:
|
||||
|
||||
@@ -202,20 +202,6 @@ def validate_config(cfg):
|
||||
raise ValueError(
|
||||
"bf16 requested, but AMP is not supported on this GPU. Requires Ampere series or above."
|
||||
)
|
||||
if (
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
not (cfg.bf16 or cfg.bfloat16)
|
||||
and (cfg.fp16 or cfg.float16)
|
||||
and not cfg.adapter
|
||||
and not cfg.flash_attention
|
||||
and cfg.sample_packing
|
||||
):
|
||||
LOG.warning(
|
||||
"Full fine tune w/o FA2 w/ sample packing and fp16/float16 is likely to raise errors. Try LoRA."
|
||||
)
|
||||
# ValueError: Attempting to unscale FP16 gradients.
|
||||
# OR
|
||||
# RuntimeError: expected mat1 and mat2 to have the same dtype, but got: float != c10::Half
|
||||
if cfg.max_packed_sequence_len:
|
||||
raise DeprecationWarning("`max_packed_sequence_len` is no longer supported")
|
||||
|
||||
@@ -322,7 +308,7 @@ def validate_config(cfg):
|
||||
LOG.warning("BetterTransformers probably doesn't work with PEFT adapters")
|
||||
if cfg.fp16 or cfg.bf16:
|
||||
raise ValueError("AMP is not supported with BetterTransformer")
|
||||
if cfg.float16 is not True and cfg.bfloat16 is not True:
|
||||
if cfg.float16 is not True and cfg.bloat16 is not True:
|
||||
LOG.warning(
|
||||
"You should probably set bfloat16 or float16 to true to "
|
||||
"load the model in float16 for BetterTransformers"
|
||||
@@ -364,24 +350,17 @@ def validate_config(cfg):
|
||||
+ "point to its path, and remove model_revision from the config."
|
||||
)
|
||||
|
||||
# if cfg.sample_packing and cfg.sdp_attention:
|
||||
# # incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
|
||||
# raise ValueError(
|
||||
# "sample_packing not compatible with sdp_attention. Use flash_attention"
|
||||
# )
|
||||
if cfg.sample_packing and cfg.sdp_attention:
|
||||
# incompatible due to bug w/ accelerate causing 0.0 loss when using llama2
|
||||
raise ValueError(
|
||||
"sample_packing not compatible with sdp_attention. Use flash_attention"
|
||||
)
|
||||
|
||||
if cfg.sample_packing and cfg.xformers_attention:
|
||||
raise ValueError(
|
||||
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
||||
)
|
||||
|
||||
if cfg.sample_packing and cfg.sdp_attention and (cfg.bfloat16 or cfg.bf16):
|
||||
# https://github.com/pytorch/pytorch/blob/1b03423526536b5f3d35bdfa95ccc6197556cf9b/test/test_transformers.py#L2440-L2450
|
||||
LOG.warning(
|
||||
"sample_packing & torch sdpa with bf16 is unsupported may results in 0.0 loss. "
|
||||
"This may work on H100s."
|
||||
)
|
||||
|
||||
if cfg.early_stopping_patience:
|
||||
if not cfg.save_steps or not cfg.eval_steps:
|
||||
raise ValueError(
|
||||
@@ -447,11 +426,7 @@ def validate_config(cfg):
|
||||
"evaluation_strategy and eval_steps mismatch. Please set evaluation_strategy to 'steps' or remove eval_steps."
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.val_set_size == 0
|
||||
and (cfg.eval_steps or cfg.evaluation_strategy)
|
||||
and not cfg.test_datasets
|
||||
):
|
||||
if cfg.val_set_size == 0 and (cfg.eval_steps or cfg.evaluation_strategy):
|
||||
raise ValueError(
|
||||
"eval_steps and evaluation_strategy are not supported with val_set_size == 0"
|
||||
)
|
||||
|
||||
@@ -1,23 +1,20 @@
|
||||
"""Module containing data utilities"""
|
||||
import functools
|
||||
import hashlib
|
||||
import importlib
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import yaml
|
||||
from datasets import (
|
||||
Dataset,
|
||||
DatasetDict,
|
||||
IterableDataset,
|
||||
concatenate_datasets,
|
||||
load_dataset,
|
||||
load_from_disk,
|
||||
)
|
||||
from datasets.iterable_dataset import ExamplesIterable
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import HFValidationError
|
||||
from torch.utils.data import RandomSampler
|
||||
@@ -67,25 +64,6 @@ def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
||||
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
|
||||
|
||||
|
||||
def get_streaming_dataset(ds_cfg):
|
||||
path = ds_cfg["path"]
|
||||
func = None
|
||||
try:
|
||||
load_fn = path.split(".")[-1]
|
||||
module_name = ".".join(load_fn.split(".")[:-1])
|
||||
mod = importlib.import_module(f".{module_name}", "axolotl")
|
||||
func = getattr(mod, load_fn)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if func:
|
||||
data_producer = func(**ds_cfg)
|
||||
return IterableDataset(ExamplesIterable(data_producer, {}))
|
||||
else:
|
||||
split = ds_cfg["split"] or "train"
|
||||
return load_dataset(path, streaming=True, split=split, name=ds_cfg["name"])
|
||||
|
||||
|
||||
def prepare_dataset(cfg, tokenizer):
|
||||
prompters = []
|
||||
if not cfg.pretraining_dataset:
|
||||
@@ -102,21 +80,20 @@ def prepare_dataset(cfg, tokenizer):
|
||||
tokenizer, cfg, DEFAULT_DATASET_PREPARED_PATH
|
||||
)
|
||||
else:
|
||||
ds_wrapper_partial = functools.partial(
|
||||
get_dataset_wrapper,
|
||||
cfg.pretraining_dataset[0],
|
||||
tokenizer,
|
||||
cfg,
|
||||
cfg.pretraining_dataset[0]["type"] or "pretrain",
|
||||
)
|
||||
path = cfg.pretraining_dataset
|
||||
name = None
|
||||
if isinstance(cfg.pretraining_dataset, list) and isinstance(
|
||||
cfg.pretraining_dataset[0], dict
|
||||
):
|
||||
path = cfg.pretraining_dataset[0]["path"]
|
||||
name = cfg.pretraining_dataset[0]["name"]
|
||||
|
||||
train_dataset = wrap_pretraining_dataset(
|
||||
get_streaming_dataset(cfg.pretraining_dataset[0]),
|
||||
train_dataset = load_pretraining_dataset(
|
||||
path,
|
||||
tokenizer,
|
||||
cfg,
|
||||
ds_wrapper_partial,
|
||||
name=name,
|
||||
max_tokens=cfg.sequence_len,
|
||||
batch_size=cfg.micro_batch_size,
|
||||
seed=cfg.seed or 42,
|
||||
)
|
||||
# https://discuss.huggingface.co/t/how-to-use-huggingface-trainer-streaming-datasets-without-wrapping-it-with-torchdatas-iterablewrapper/25230
|
||||
@@ -163,7 +140,7 @@ def load_tokenized_prepared_datasets(
|
||||
+ "|".join(
|
||||
sorted(
|
||||
[
|
||||
f"{d.path}:{d.type}:{d.shards}:{d.conversation}{d.split}"
|
||||
f"{d.path}:{d.type}:{d.shards}:{d.conversation}"
|
||||
for d in cfg_datasets
|
||||
]
|
||||
)
|
||||
@@ -350,16 +327,6 @@ def load_tokenized_prepared_datasets(
|
||||
split=None,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
elif config_dataset.path.startswith("https://"):
|
||||
ds_type = get_ds_type(config_dataset)
|
||||
ds = load_dataset(
|
||||
ds_type,
|
||||
name=config_dataset.name,
|
||||
data_files=config_dataset.path,
|
||||
streaming=False,
|
||||
split=None,
|
||||
storage_options=storage_options,
|
||||
)
|
||||
else:
|
||||
if isinstance(config_dataset.data_files, str):
|
||||
fp = hf_hub_download(
|
||||
@@ -416,9 +383,9 @@ def load_tokenized_prepared_datasets(
|
||||
|
||||
dataset_wrapper, dataset_prompter = get_dataset_wrapper(
|
||||
config_dataset=config_dataset,
|
||||
dataset=ds,
|
||||
tokenizer=tokenizer,
|
||||
cfg=cfg,
|
||||
dataset=ds,
|
||||
d_base_type=d_base_type,
|
||||
d_prompt_style=d_prompt_style,
|
||||
)
|
||||
@@ -529,12 +496,7 @@ def load_prepare_datasets(
|
||||
|
||||
|
||||
def get_dataset_wrapper(
|
||||
config_dataset,
|
||||
tokenizer,
|
||||
cfg,
|
||||
d_base_type,
|
||||
dataset,
|
||||
d_prompt_style=None,
|
||||
config_dataset, dataset, tokenizer, cfg, d_base_type, d_prompt_style
|
||||
):
|
||||
dataset_wrapper = None
|
||||
dataset_prompter = None
|
||||
@@ -545,8 +507,7 @@ def get_dataset_wrapper(
|
||||
}
|
||||
|
||||
if (
|
||||
isinstance(dataset, Dataset)
|
||||
and "input_ids" in dataset.features
|
||||
"input_ids" in dataset.features
|
||||
and "attention_mask" in dataset.features
|
||||
and "labels" in dataset.features
|
||||
):
|
||||
@@ -804,67 +765,76 @@ def encode_pretraining(
|
||||
return ret
|
||||
|
||||
|
||||
def wrap_pretraining_dataset(
|
||||
dataset,
|
||||
tokenizer,
|
||||
cfg,
|
||||
ds_wrapper_fn,
|
||||
max_tokens=2048,
|
||||
batch_size=1,
|
||||
seed=42,
|
||||
buffer_size=10_000,
|
||||
):
|
||||
def load_pretraining_dataset(path, tokenizer, cfg, name=None, max_tokens=2048, seed=42):
|
||||
if cfg.sample_packing:
|
||||
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
|
||||
tokenizer,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
pad_to_multiple_of=max_tokens * batch_size,
|
||||
pad_to_multiple_of=max_tokens * cfg.micro_batch_size,
|
||||
)
|
||||
encode = functools.partial(
|
||||
encode_packed_pretraining,
|
||||
tokenizer,
|
||||
collate_fn,
|
||||
ds_wrapper_fn,
|
||||
max_seq_length=max_tokens,
|
||||
batch_size=batch_size,
|
||||
batch_size=cfg.micro_batch_size,
|
||||
)
|
||||
# set this to 1 so downstream data_loader doesn't try to increase the batch again
|
||||
cfg.micro_batch_size = 1
|
||||
else:
|
||||
encode = functools.partial(encode_pretraining, tokenizer, max_tokens)
|
||||
|
||||
dataset = dataset.shuffle(seed=seed, buffer_size=buffer_size)
|
||||
dataset = load_dataset(path, streaming=True, split="train", name=name)
|
||||
dataset = dataset.shuffle(seed=seed, buffer_size=10_000)
|
||||
dataset = dataset.map(
|
||||
encode,
|
||||
batched=True,
|
||||
batch_size=buffer_size,
|
||||
# input_columns="text",
|
||||
batch_size=10_000,
|
||||
input_columns="text",
|
||||
# remove all the existing columns after mapping since they end up having
|
||||
# a different length than the encoded/tokenized column
|
||||
remove_columns=dataset.features.keys(),
|
||||
desc="Encoding Pretraining",
|
||||
)
|
||||
return dataset
|
||||
|
||||
|
||||
def encode_packed_pretraining(
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
collate_fn,
|
||||
ds_wrapper: Callable,
|
||||
examples: Dict[str, List],
|
||||
examples: List[str],
|
||||
max_seq_length: int = 2048,
|
||||
batch_size: int = 4,
|
||||
) -> Dict[str, List]:
|
||||
# pylint: disable=duplicate-code
|
||||
# tokenize all the examples
|
||||
# rows get split with stride (overlap)
|
||||
train_dataset = ds_wrapper(Dataset.from_dict(examples))[0]
|
||||
res = tokenizer(
|
||||
examples,
|
||||
truncation=True,
|
||||
max_length=max_seq_length - 1,
|
||||
add_special_tokens=True,
|
||||
return_overflowing_tokens=True,
|
||||
stride=256,
|
||||
)
|
||||
|
||||
input_ids = [seq + [tokenizer.eos_token_id] for seq in res["input_ids"]]
|
||||
attention_mask = [seq + [1] for seq in res["attention_mask"]]
|
||||
|
||||
tokenized_examples = {
|
||||
"input_ids": input_ids,
|
||||
"attention_mask": attention_mask,
|
||||
}
|
||||
|
||||
train_dataset = Dataset.from_dict(tokenized_examples)
|
||||
train_dataset = process_pretraining_datasets_for_packing(
|
||||
train_dataset, max_seq_length
|
||||
)
|
||||
|
||||
sampler = MultipackBatchSampler(
|
||||
RandomSampler(train_dataset),
|
||||
batch_size=1,
|
||||
batch_size=batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=batch_size * max_seq_length,
|
||||
lengths=get_dataset_lengths(train_dataset),
|
||||
@@ -872,23 +842,15 @@ def encode_packed_pretraining(
|
||||
|
||||
chunked_data = defaultdict(list)
|
||||
|
||||
for batch in sampler:
|
||||
for data in batch:
|
||||
features = train_dataset[data]
|
||||
if "num_truncated_tokens" in features:
|
||||
del features["num_truncated_tokens"]
|
||||
if "num_truncated_tokens" in features:
|
||||
del features["num_truncated_tokens"]
|
||||
if "overflow_to_sample_mapping" in features:
|
||||
del features["overflow_to_sample_mapping"]
|
||||
if "labels" not in features:
|
||||
features["labels"] = features["input_ids"].copy()
|
||||
collated_features = collate_fn(features)
|
||||
for data in sampler:
|
||||
features = train_dataset[data]
|
||||
features["labels"] = features["input_ids"].copy()
|
||||
collated_features = collate_fn(features)
|
||||
|
||||
for feature in features.keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
chunked_data[feature].append(collated_features[feature].squeeze(0))
|
||||
for feature in features.keys():
|
||||
if feature == "length":
|
||||
continue
|
||||
chunked_data[feature].append(collated_features[feature].squeeze(0))
|
||||
|
||||
return chunked_data
|
||||
|
||||
|
||||
@@ -8,13 +8,8 @@ import addict
|
||||
import bitsandbytes as bnb
|
||||
import torch
|
||||
import transformers
|
||||
from peft import (
|
||||
LoftQConfig,
|
||||
PeftConfig,
|
||||
PeftModel,
|
||||
PeftModelForCausalLM,
|
||||
prepare_model_for_kbit_training,
|
||||
)
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
from peft import LoftQConfig, PeftConfig, prepare_model_for_kbit_training
|
||||
from peft.tuners.lora import QuantLinear
|
||||
from transformers import ( # noqa: F401
|
||||
AddedToken,
|
||||
@@ -29,10 +24,6 @@ from transformers import ( # noqa: F401
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
|
||||
from axolotl.models.mamba import fix_mamba_attn_for_loss
|
||||
from axolotl.monkeypatch.multipack import (
|
||||
SUPPORTED_MULTIPACK_MODEL_TYPES,
|
||||
patch_for_multipack,
|
||||
)
|
||||
from axolotl.prompt_tokenizers import LLAMA_DEFAULT_EOS_TOKEN
|
||||
from axolotl.utils.bench import log_gpu_memory_usage
|
||||
from axolotl.utils.chat_templates import chat_templates
|
||||
@@ -170,20 +161,15 @@ def load_tokenizer(cfg):
|
||||
if getattr(tokenizer, attr_name) is None:
|
||||
setattr(tokenizer, attr_name, "<|endoftext|>")
|
||||
|
||||
additional_special_tokens = None
|
||||
if cfg.special_tokens:
|
||||
special_tokens = cfg.special_tokens.to_dict()
|
||||
additional_special_tokens = special_tokens.pop(
|
||||
"additional_special_tokens", None
|
||||
)
|
||||
lora_modules_to_save = get_linear_embedding_layers(model_config.model_type)
|
||||
for k, val in special_tokens.items():
|
||||
for k, val in cfg.special_tokens.items():
|
||||
# check if new special token is not already in tokenizer and
|
||||
# is adapter training to make sure lora_modules_to_save is set
|
||||
# pylint: disable=too-many-boolean-expressions
|
||||
if (
|
||||
(getattr(tokenizer, k) is None or getattr(tokenizer, k) != val)
|
||||
and (len(tokenizer.encode(val, add_special_tokens=False)) > 2)
|
||||
and (len(tokenizer.encode(val)) > 1)
|
||||
and cfg.adapter
|
||||
and (
|
||||
not cfg.lora_modules_to_save
|
||||
@@ -227,21 +213,6 @@ def load_tokenizer(cfg):
|
||||
]
|
||||
)
|
||||
|
||||
# Additional special tokens are a List, and need to be treated differently than regular special
|
||||
# tokens. We add them after we have called `add_tokens` in case these additional special tokens
|
||||
# are new tokens.
|
||||
#
|
||||
# Usage:
|
||||
#
|
||||
# ```py
|
||||
# special_tokens:
|
||||
# additional_special_tokens: ["<|im_start|>", "<|im_end|>"]
|
||||
# ```
|
||||
if additional_special_tokens is not None:
|
||||
tokenizer.add_special_tokens(
|
||||
{"additional_special_tokens": additional_special_tokens}
|
||||
)
|
||||
|
||||
LOG.debug(f"EOS: {tokenizer.eos_token_id} / {tokenizer.eos_token}")
|
||||
LOG.debug(f"BOS: {tokenizer.bos_token_id} / {tokenizer.bos_token}")
|
||||
LOG.debug(f"PAD: {tokenizer.pad_token_id} / {tokenizer.pad_token}")
|
||||
@@ -303,15 +274,8 @@ def load_model(
|
||||
shifted-sparse attention does not currently support sample packing."
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.model_config_type in SUPPORTED_MULTIPACK_MODEL_TYPES
|
||||
and cfg.flash_attention
|
||||
and cfg.sample_packing
|
||||
):
|
||||
patch_for_multipack(cfg.model_config_type)
|
||||
elif cfg.is_llama_derived_model:
|
||||
# Modify all llama derived models in one block
|
||||
|
||||
# Modify all llama derived models in one block
|
||||
if cfg.is_llama_derived_model:
|
||||
if cfg.flash_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_flash import (
|
||||
replace_llama_attn_with_flash_attn,
|
||||
@@ -340,13 +304,13 @@ def load_model(
|
||||
|
||||
LOG.info("patching with xformers attention")
|
||||
hijack_llama_attention()
|
||||
elif cfg.sample_packing:
|
||||
from axolotl.monkeypatch.llama_patch_multipack import (
|
||||
hijack_llama_prepare_4d_mask,
|
||||
elif cfg.sdp_attention:
|
||||
from axolotl.monkeypatch.llama_attn_hijack_sdp import (
|
||||
hijack_llama_sdp_attention,
|
||||
)
|
||||
|
||||
LOG.info("patching llama _prepare_4d_causal_attention_mask*")
|
||||
hijack_llama_prepare_4d_mask()
|
||||
LOG.info("patching with sdp attention")
|
||||
hijack_llama_sdp_attention()
|
||||
elif cfg.s2_attention:
|
||||
raise NotImplementedError(
|
||||
"Shifted-sparse attention not currently implemented without flash attention."
|
||||
@@ -365,6 +329,43 @@ def load_model(
|
||||
LOG.info("patching mistral with flash attention")
|
||||
replace_mistral_attn_with_flash_attn(packed=cfg.sample_packing)
|
||||
|
||||
if (
|
||||
cfg.model_config_type == "mixtral"
|
||||
and cfg.flash_attention
|
||||
and cfg.sample_packing
|
||||
):
|
||||
from axolotl.monkeypatch.mixtral import (
|
||||
replace_mixtral_attn_with_multipack_flash_attn,
|
||||
)
|
||||
|
||||
LOG.info("patching mixtral with flash attention")
|
||||
mixtral_patch_kwargs = {}
|
||||
if is_deepspeed_zero3_enabled():
|
||||
mixtral_patch_kwargs["for_zero3"] = True
|
||||
replace_mixtral_attn_with_multipack_flash_attn(**mixtral_patch_kwargs)
|
||||
|
||||
if cfg.model_config_type == "falcon" and cfg.flash_attention and cfg.sample_packing:
|
||||
from axolotl.monkeypatch.falcon import (
|
||||
replace_falcon_attn_with_multipack_flash_attn,
|
||||
)
|
||||
|
||||
LOG.info("patching falcon with flash attention")
|
||||
replace_falcon_attn_with_multipack_flash_attn()
|
||||
|
||||
if cfg.model_config_type == "phi" and cfg.flash_attention and cfg.sample_packing:
|
||||
from axolotl.monkeypatch.phi import replace_phi_attn_with_multipack_flash_attn
|
||||
|
||||
LOG.info("patching phi with flash attention")
|
||||
replace_phi_attn_with_multipack_flash_attn()
|
||||
|
||||
if cfg.model_config_type == "qwen2" and cfg.flash_attention and cfg.sample_packing:
|
||||
from axolotl.monkeypatch.qwen2 import (
|
||||
replace_qwen2_attn_with_multipack_flash_attn,
|
||||
)
|
||||
|
||||
LOG.info("patching qwen2 with flash attention")
|
||||
replace_qwen2_attn_with_multipack_flash_attn()
|
||||
|
||||
if cfg.is_llama_derived_model and cfg.sample_packing and not inference:
|
||||
from axolotl.monkeypatch.llama_expand_mask import hijack_expand_mask
|
||||
|
||||
@@ -374,7 +375,7 @@ def load_model(
|
||||
model_kwargs: Dict[str, Any] = {}
|
||||
|
||||
if cfg.model_kwargs:
|
||||
for key, val in cfg.model_kwargs.items():
|
||||
for key, val in model_kwargs.items():
|
||||
model_kwargs[key] = val
|
||||
|
||||
max_memory = cfg.max_memory
|
||||
@@ -409,10 +410,6 @@ def load_model(
|
||||
|
||||
model_kwargs["device_map"] = device_map
|
||||
model_kwargs["torch_dtype"] = cfg.torch_dtype
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
model_kwargs["device_map"] = "mps:0"
|
||||
|
||||
# TODO can we put the reference model on it's own gpu? I think we have to move logits around to calculate loss
|
||||
# if cfg.rl:
|
||||
# if torch.cuda.device_count() > 1:
|
||||
@@ -456,18 +453,6 @@ def load_model(
|
||||
**bnb_config,
|
||||
)
|
||||
|
||||
if cfg.load_in_8bit and cfg.adapter is not None:
|
||||
model_kwargs["load_in_8bit"] = True
|
||||
if cfg.load_in_4bit and cfg.adapter is not None:
|
||||
model_kwargs["load_in_4bit"] = True
|
||||
|
||||
# no longer needed per https://github.com/huggingface/transformers/pull/26610
|
||||
if "quantization_config" in model_kwargs or cfg.gptq:
|
||||
if "load_in_8bit" in model_kwargs:
|
||||
del model_kwargs["load_in_8bit"]
|
||||
if "load_in_4bit" in model_kwargs:
|
||||
del model_kwargs["load_in_4bit"]
|
||||
|
||||
# sample packing uses custom FA2 patch
|
||||
if cfg.flash_attention:
|
||||
if not cfg.sample_packing:
|
||||
@@ -479,7 +464,7 @@ def load_model(
|
||||
"flash_attention_2"
|
||||
)
|
||||
else:
|
||||
if model_config.model_type in SUPPORTED_MULTIPACK_MODEL_TYPES:
|
||||
if model_config.model_type in ["mixtral", "qwen2", "falcon", "phi"]:
|
||||
model_kwargs["attn_implementation"] = "flash_attention_2"
|
||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"flash_attention_2"
|
||||
@@ -489,12 +474,6 @@ def load_model(
|
||||
model_config._attn_implementation = ( # pylint: disable=protected-access
|
||||
"eager"
|
||||
)
|
||||
elif cfg.sdp_attention:
|
||||
model_kwargs["attn_implementation"] = "sdpa"
|
||||
model_config._attn_implementation = "sdpa" # pylint: disable=protected-access
|
||||
elif cfg.eager_attention:
|
||||
model_kwargs["attn_implementation"] = "eager"
|
||||
model_config._attn_implementation = "eager" # pylint: disable=protected-access
|
||||
|
||||
try:
|
||||
if (
|
||||
@@ -507,6 +486,8 @@ def load_model(
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=model_config,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
**model_kwargs,
|
||||
)
|
||||
|
||||
@@ -574,6 +555,8 @@ def load_model(
|
||||
model = getattr(transformers, model_type).from_pretrained(
|
||||
base_model,
|
||||
config=model_config,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
@@ -605,6 +588,8 @@ def load_model(
|
||||
model = AutoModelForCausalLM.from_pretrained(
|
||||
base_model,
|
||||
config=model_config,
|
||||
load_in_8bit=cfg.load_in_8bit and cfg.adapter is not None,
|
||||
load_in_4bit=cfg.load_in_4bit and cfg.adapter is not None,
|
||||
trust_remote_code=cfg.trust_remote_code or False,
|
||||
**model_kwargs,
|
||||
)
|
||||
@@ -612,9 +597,6 @@ def load_model(
|
||||
LOG.exception(err)
|
||||
raise err
|
||||
|
||||
if isinstance(model, (PeftModel, PeftModelForCausalLM)):
|
||||
model = model.merge_and_unload()
|
||||
|
||||
embeddings_len = (
|
||||
math.ceil(len(tokenizer) / 32) * 32
|
||||
if cfg.resize_token_embeddings_to_32x
|
||||
@@ -655,7 +637,7 @@ def load_model(
|
||||
):
|
||||
model.config.eos_token_id = tokenizer.eos_token_id
|
||||
|
||||
if hasattr(model, "device") and model.device.type in ("cuda", "mps"):
|
||||
if hasattr(model, "device") and model.device.type == "cuda":
|
||||
log_gpu_memory_usage(LOG, "after model load", model.device)
|
||||
|
||||
# make sure these are fp32 per Ramesh et al. (2021)
|
||||
@@ -663,7 +645,7 @@ def load_model(
|
||||
if not cfg.fsdp:
|
||||
# FSDP doesn't like mixed Float and BFloat16
|
||||
for name, module in model.named_modules():
|
||||
if "norm" in name or name.endswith(".gate"):
|
||||
if any(m in name for m in ["norm", "gate"]):
|
||||
module.to(torch.float32)
|
||||
if model_config.model_type == "btlm":
|
||||
# don't upcast lm_head for btlm
|
||||
@@ -676,9 +658,7 @@ def load_model(
|
||||
skip_prepare_model_for_kbit_training = False
|
||||
|
||||
if cfg.model_config_type == "mixtral" and is_deepspeed_zero3_enabled():
|
||||
from deepspeed.utils import ( # pylint: disable=no-name-in-module
|
||||
set_z3_leaf_modules,
|
||||
)
|
||||
from deepspeed.utils import set_z3_leaf_modules
|
||||
from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
|
||||
|
||||
set_z3_leaf_modules(model, [MixtralSparseMoeBlock])
|
||||
@@ -741,8 +721,6 @@ def load_model(
|
||||
model.config.use_cache = False
|
||||
|
||||
if cfg.flash_optimum:
|
||||
from optimum.bettertransformer import BetterTransformer
|
||||
|
||||
model = BetterTransformer.transform(model)
|
||||
|
||||
if cfg.adapter is not None:
|
||||
@@ -769,7 +747,7 @@ def load_adapter(model, cfg, adapter, inference=False):
|
||||
|
||||
def load_llama_adapter(model, cfg):
|
||||
# type: (PreTrainedModel, DictDefault) -> Tuple[PreTrainedModel, Optional[PeftConfig]]
|
||||
from peft import AdaptionPromptConfig, get_peft_model
|
||||
from peft import AdaptionPromptConfig, PeftModel, get_peft_model
|
||||
|
||||
peft_config = AdaptionPromptConfig(
|
||||
adapter_layers=cfg.peft_adapter.layers, # layers (L)
|
||||
@@ -815,7 +793,7 @@ def find_all_linear_names(model):
|
||||
def load_lora(model, cfg, inference=False, config_only=False):
|
||||
# type: (PreTrainedModel, DictDefault, bool, bool) -> Tuple[Optional[PreTrainedModel], Optional[PeftConfig]]
|
||||
|
||||
from peft import LoraConfig, get_peft_model
|
||||
from peft import LoraConfig, PeftModel, get_peft_model
|
||||
|
||||
lora_target_modules = list(cfg.lora_target_modules or [])
|
||||
|
||||
|
||||
@@ -117,7 +117,7 @@ class MultipackBatchSampler(BatchSampler):
|
||||
packing_efficiency_estimate: float = 1.0,
|
||||
):
|
||||
super().__init__(sampler, batch_size, drop_last)
|
||||
self.batch_size = batch_size
|
||||
self.batch_size = None
|
||||
self.batch_max_len = batch_max_len
|
||||
self.lengths: np.ndarray = lengths
|
||||
self.packing_efficiency_estimate = packing_efficiency_estimate or 1.0
|
||||
@@ -147,13 +147,7 @@ class MultipackBatchSampler(BatchSampler):
|
||||
n=1,
|
||||
)
|
||||
|
||||
batches = [
|
||||
[
|
||||
[indices[b_idx] for b_idx in batch]
|
||||
for batch in batches[i : i + self.batch_size]
|
||||
]
|
||||
for i in range(0, len(batches), self.batch_size)
|
||||
]
|
||||
batches = [[indices[b_idx] for b_idx in batch] for batch in batches]
|
||||
|
||||
# statistics
|
||||
if set_stats:
|
||||
@@ -195,7 +189,7 @@ class MultipackBatchSampler(BatchSampler):
|
||||
0.99
|
||||
* lengths_sum_per_device
|
||||
/ self.packing_efficiency_estimate
|
||||
// (self.batch_max_len * self.batch_size)
|
||||
// self.batch_max_len
|
||||
)
|
||||
- 1
|
||||
),
|
||||
|
||||
61
src/axolotl/utils/tensors.py
Normal file
61
src/axolotl/utils/tensors.py
Normal file
@@ -0,0 +1,61 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def keep_unpacked_data(data: torch.Tensor, index=None, nonzero_total=None, pad_val= None, pairs=False):
|
||||
# pad val could be padding token (input_ids), -100 (labels), or 0 (attention_mask)
|
||||
if index >= nonzero_total:
|
||||
return False
|
||||
if pairs and (index // 2) >= (nonzero_total // 2):
|
||||
return False
|
||||
if pad_val and (data == pad_val).all(dim=0).all():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def split_and_pad_packed(tensor, cu_seqlens, max_seqlen, keep_fn=None):
|
||||
split_tensors = []
|
||||
|
||||
counts = count_nonzero_sequences(cu_seqlens)
|
||||
# Iterate over each batch
|
||||
for i in range(tensor.size(0)):
|
||||
seq_lens = cu_seqlens[i]
|
||||
start_idx = 0
|
||||
|
||||
# Iterate over the cumulative sequence lengths
|
||||
for j, end_idx in enumerate(seq_lens[1:]):
|
||||
if end_idx == start_idx:
|
||||
break
|
||||
# Extract and pad the current sequence
|
||||
current_seq = tensor[i, start_idx:end_idx]
|
||||
keep = True
|
||||
if keep_fn:
|
||||
keep = keep_fn(current_seq, index=j, nonzero_total=counts[i])
|
||||
if not keep:
|
||||
continue
|
||||
padding_size = max_seqlen - current_seq.size(0)
|
||||
padded_seq = F.pad(current_seq, (0, 0) * (current_seq.dim() - 2) + (0, padding_size))
|
||||
|
||||
# Append the padded sequence to the list
|
||||
split_tensors.append(padded_seq)
|
||||
|
||||
# Update start index for the next sequence
|
||||
start_idx = end_idx
|
||||
|
||||
# Stack the padded tensors
|
||||
return torch.stack(split_tensors, dim=0)
|
||||
|
||||
|
||||
def count_nonzero_sequences(cu_seqlens: torch.Tensor) -> torch.LongTensor:
|
||||
diffs = torch.diff(cu_seqlens, dim=1, prepend=torch.zeros(cu_seqlens.shape[0], 1, dtype=cu_seqlens.dtype))
|
||||
valid_lengths = diffs != 0
|
||||
counts = valid_lengths.sum(dim=1).long()
|
||||
|
||||
return counts
|
||||
|
||||
|
||||
# Example usage
|
||||
# Example tensor with dimensions [batch_size, seq_len, other_dimensions...]
|
||||
# example_tensor = torch.randn(batch_size, seq_len, other_dimensions...)
|
||||
# cu_seqlens, max_seqlen = get_cu_seqlens_from_pos_ids(batch["position_ids"])
|
||||
# split_padded_tensor = split_and_pad_packed(example_tensor, cu_seqlens, max_seqlen)
|
||||
@@ -237,17 +237,11 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
main_process_only=True,
|
||||
)
|
||||
else:
|
||||
if cfg.flash_attention:
|
||||
batch_size = 1
|
||||
batch_max_len = cfg.micro_batch_size * cfg.sequence_len
|
||||
else:
|
||||
batch_size = cfg.micro_batch_size
|
||||
batch_max_len = cfg.sequence_len
|
||||
sampler = MultipackBatchSampler(
|
||||
sampler=RandomSampler(train_dataset),
|
||||
batch_size=batch_size,
|
||||
batch_size=cfg.micro_batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=batch_max_len,
|
||||
batch_max_len=cfg.micro_batch_size * cfg.sequence_len,
|
||||
lengths=get_dataset_lengths(train_dataset),
|
||||
)
|
||||
|
||||
@@ -255,7 +249,7 @@ def calculate_total_num_steps(cfg, train_dataset, update=True):
|
||||
train_dataset.remove_columns(["length"]),
|
||||
batch_sampler=sampler,
|
||||
)
|
||||
data_loader_len = len(data_loader) // batch_size
|
||||
data_loader_len = len(data_loader)
|
||||
actual_eff = sampler.efficiency()
|
||||
LOG.debug(f"data_loader_len: {data_loader_len}", main_process_only=True)
|
||||
# FIXME: is there a bug here somewhere? the total num steps depends
|
||||
|
||||
@@ -1,114 +0,0 @@
|
||||
"""
|
||||
E2E tests for multipack fft llama using 4d attention masks
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import require_torch_2_1_1, with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class Test4dMultipackLlama(unittest.TestCase):
|
||||
"""
|
||||
Test case for Llama models using 4d attention with multipack
|
||||
"""
|
||||
|
||||
@require_torch_2_1_1
|
||||
@with_temp_dir
|
||||
def test_sdp_lora_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"flash_attention": False,
|
||||
"sdp_attention": True,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 32,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"sequence_len": 1024,
|
||||
"val_set_size": 0.1,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"fp16": True,
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
|
||||
@with_temp_dir
|
||||
def test_torch_lora_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"flash_attention": False,
|
||||
"sdp_attention": False,
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 32,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_linear": True,
|
||||
"val_set_size": 0.1,
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 2,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"fp16": True,
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "adapter_model.bin").exists()
|
||||
@@ -33,7 +33,6 @@ class TestFusedLlama(unittest.TestCase):
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"flash_attention": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"flash_attn_fuse_qkv": True,
|
||||
"flash_attn_fuse_mlp": True,
|
||||
"sample_packing": True,
|
||||
|
||||
@@ -7,6 +7,8 @@ import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
@@ -61,7 +63,6 @@ class TestMistral(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
@@ -102,9 +103,12 @@ class TestMistral(unittest.TestCase):
|
||||
"max_steps": 20,
|
||||
"save_steps": 10,
|
||||
"eval_steps": 10,
|
||||
"bf16": "auto",
|
||||
}
|
||||
)
|
||||
if is_torch_bf16_gpu_available():
|
||||
cfg.bf16 = True
|
||||
else:
|
||||
cfg.fp16 = True
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
"""
|
||||
E2E tests for relora llama
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
|
||||
from axolotl.cli import load_datasets
|
||||
from axolotl.common.cli import TrainerCliArgs
|
||||
from axolotl.train import train
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import with_temp_dir
|
||||
|
||||
LOG = logging.getLogger("axolotl.tests.e2e")
|
||||
os.environ["WANDB_DISABLED"] = "true"
|
||||
|
||||
|
||||
class TestReLoraLlama(unittest.TestCase):
|
||||
"""
|
||||
Test case for Llama models using LoRA
|
||||
"""
|
||||
|
||||
@with_temp_dir
|
||||
def test_relora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
"tokenizer_type": "LlamaTokenizer",
|
||||
"sequence_len": 1024,
|
||||
"load_in_8bit": True,
|
||||
"adapter": "lora",
|
||||
"lora_r": 32,
|
||||
"lora_alpha": 16,
|
||||
"lora_dropout": 0.05,
|
||||
"lora_target_modules": ["q_proj", "v_proj"],
|
||||
"relora_steps": 25,
|
||||
"relora_warmup_steps": 5,
|
||||
"relora_anneal_steps": 5,
|
||||
"relora_cpu_offload": True,
|
||||
"val_set_size": 0.0,
|
||||
"special_tokens": {},
|
||||
"datasets": [
|
||||
{
|
||||
"path": "mhenrichsen/alpaca_2k_test",
|
||||
"type": "alpaca",
|
||||
},
|
||||
],
|
||||
"warmup_steps": 15,
|
||||
"num_epochs": 2,
|
||||
"micro_batch_size": 4,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"output_dir": temp_dir,
|
||||
"learning_rate": 0.00001,
|
||||
"optimizer": "adamw_torch",
|
||||
"lr_scheduler": "cosine",
|
||||
}
|
||||
)
|
||||
normalize_config(cfg)
|
||||
cli_args = TrainerCliArgs()
|
||||
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
|
||||
|
||||
train(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
|
||||
assert (Path(temp_dir) / "model.safetensors").exists()
|
||||
@@ -4,9 +4,7 @@ helper utils for tests
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from functools import wraps
|
||||
from importlib.metadata import version
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@@ -33,15 +31,3 @@ def most_recent_subdir(path):
|
||||
subdir = max(subdirectories, key=os.path.getctime)
|
||||
|
||||
return subdir
|
||||
|
||||
|
||||
def require_torch_2_1_1(test_case):
|
||||
"""
|
||||
Decorator marking a test that requires torch >= 2.1.1
|
||||
"""
|
||||
|
||||
def is_min_2_1_1():
|
||||
torch_version = version("torch")
|
||||
return torch_version >= "2.1.1"
|
||||
|
||||
return unittest.skipUnless(is_min_2_1_1(), "test torch 2.1.1")(test_case)
|
||||
|
||||
@@ -30,20 +30,6 @@ class TestMonkeyPatchUtils(unittest.TestCase):
|
||||
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
|
||||
)
|
||||
|
||||
def test_get_cu_seqlens_from_pos_ids_2d(self):
|
||||
position_ids = torch.tensor(
|
||||
[
|
||||
[0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 3, 4, 0, 1, 0, 0],
|
||||
[0, 1, 2, 3, 4, 0, 1, 2, 0, 1, 2, 3, 4, 5, 6, 0],
|
||||
]
|
||||
)
|
||||
target_res = torch.tensor(
|
||||
[[0, 4, 7, 12, 14, 16], [0, 5, 8, 15, 16, 16]], dtype=torch.int32
|
||||
)
|
||||
self.assertTrue(
|
||||
torch.allclose(get_cu_seqlens_from_pos_ids(position_ids)[0], target_res)
|
||||
)
|
||||
|
||||
def test_get_max_seqlen_in_batch(self):
|
||||
attn_mask = torch.tensor([[1, 1, 1, 1, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 0, 0]])
|
||||
target_res = torch.tensor([4, 3, 5, 2], dtype=torch.int32)
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
"""Module for testing streaming dataset sequence packing"""
|
||||
import pytest
|
||||
from datasets import concatenate_datasets, load_dataset
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.datasets import TokenizedPromptDataset
|
||||
from axolotl.prompt_strategies.completion import load
|
||||
from axolotl.utils.collators import V2BatchSamplerDataCollatorForSeq2Seq
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
||||
|
||||
|
||||
@pytest.fixture(name="tokenizer")
|
||||
def fixture_tokenizer():
|
||||
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
tokenizer.pad_token = "</s>"
|
||||
return tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="max_seq_length")
|
||||
def fixture_max_seq_length():
|
||||
return 4096
|
||||
|
||||
|
||||
class TestBatchedSamplerPacking:
|
||||
"""
|
||||
Test class for packing streaming dataset sequences
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size, num_workers",
|
||||
[
|
||||
(1, 0),
|
||||
(2, 0),
|
||||
(1, 2),
|
||||
(2, 2),
|
||||
],
|
||||
)
|
||||
def test_packing(self, batch_size, num_workers, tokenizer, max_seq_length):
|
||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||
|
||||
dataset = load_dataset(
|
||||
"Trelis/tiny-shakespeare",
|
||||
split="train",
|
||||
)
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"train_on_inputs": True,
|
||||
"sequence_len": max_seq_length,
|
||||
}
|
||||
)
|
||||
ds_cfg = DictDefault(
|
||||
{
|
||||
"field": "Text",
|
||||
}
|
||||
)
|
||||
completion_strategy = load(tokenizer, cfg, ds_cfg)
|
||||
dataset_wrapper = TokenizedPromptDataset(
|
||||
completion_strategy,
|
||||
dataset,
|
||||
)
|
||||
train_dataset = concatenate_datasets([dataset_wrapper])
|
||||
batch_sampler = MultipackBatchSampler(
|
||||
sampler=RandomSampler(train_dataset),
|
||||
batch_size=batch_size,
|
||||
drop_last=True,
|
||||
batch_max_len=max_seq_length,
|
||||
lengths=get_dataset_lengths(train_dataset),
|
||||
)
|
||||
|
||||
loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
collate_fn=V2BatchSamplerDataCollatorForSeq2Seq( # pylint: disable=unexpected-keyword-arg
|
||||
tokenizer=tokenizer,
|
||||
padding=True,
|
||||
pad_to_multiple_of=max_seq_length,
|
||||
return_tensors="pt",
|
||||
),
|
||||
num_workers=num_workers,
|
||||
)
|
||||
inputs = next(iter(loader))
|
||||
|
||||
assert inputs["input_ids"].shape == (batch_size, max_seq_length)
|
||||
assert inputs["labels"].shape == (batch_size, max_seq_length)
|
||||
assert inputs["attention_mask"].shape == (batch_size, max_seq_length)
|
||||
|
||||
assert inputs["input_ids"].tolist()[0][0] == 2
|
||||
assert inputs["labels"].tolist()[0][0] == -100
|
||||
assert inputs["attention_mask"].tolist()[0][0] == 0
|
||||
assert inputs["attention_mask"].tolist()[0][-1] > 1
|
||||
|
||||
if batch_size >= 2:
|
||||
assert inputs["input_ids"].tolist()[1][0] == 2
|
||||
assert inputs["labels"].tolist()[1][0] == -100
|
||||
assert inputs["attention_mask"].tolist()[1][0] == 0
|
||||
assert inputs["attention_mask"].tolist()[1][-1] > 1
|
||||
@@ -1,17 +1,17 @@
|
||||
"""Module for testing streaming dataset sequence packing"""
|
||||
import functools
|
||||
import unittest
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
from axolotl.utils.data import get_dataset_wrapper, wrap_pretraining_dataset
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.collators import PretrainingBatchSamplerDataCollatorForSeq2Seq
|
||||
from axolotl.utils.data import encode_packed_pretraining
|
||||
|
||||
|
||||
class TestPretrainingPacking(unittest.TestCase):
|
||||
class TestPacking(unittest.TestCase):
|
||||
"""
|
||||
Test class for packing streaming dataset sequences
|
||||
"""
|
||||
@@ -20,6 +20,8 @@ class TestPretrainingPacking(unittest.TestCase):
|
||||
# pylint: disable=duplicate-code
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
self.tokenizer.pad_token = "</s>"
|
||||
self.max_seq_length = 2048
|
||||
self.batch_size = 2
|
||||
|
||||
def test_packing_stream_dataset(self):
|
||||
# pylint: disable=duplicate-code
|
||||
@@ -29,43 +31,30 @@ class TestPretrainingPacking(unittest.TestCase):
|
||||
streaming=True,
|
||||
)["train"]
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"pretraining_dataset": [
|
||||
{
|
||||
"path": "c4",
|
||||
"name": "en",
|
||||
"type": "pretrain",
|
||||
}
|
||||
],
|
||||
"sample_packing": True,
|
||||
"pad_to_sequence_len": True,
|
||||
"sequence_len": 2048,
|
||||
"micro_batch_size": 2,
|
||||
}
|
||||
collate_fn = PretrainingBatchSamplerDataCollatorForSeq2Seq(
|
||||
self.tokenizer,
|
||||
return_tensors="pt",
|
||||
padding=True,
|
||||
pad_to_multiple_of=self.max_seq_length,
|
||||
)
|
||||
|
||||
ds_wrapper_partial = functools.partial(
|
||||
get_dataset_wrapper,
|
||||
cfg.pretraining_dataset[0],
|
||||
encode = partial(
|
||||
encode_packed_pretraining,
|
||||
self.tokenizer,
|
||||
cfg,
|
||||
cfg.pretraining_dataset[0]["type"] or "pretrain",
|
||||
collate_fn,
|
||||
max_seq_length=self.max_seq_length,
|
||||
batch_size=self.batch_size,
|
||||
)
|
||||
|
||||
original_bsz = cfg.micro_batch_size
|
||||
train_dataset = wrap_pretraining_dataset(
|
||||
dataset,
|
||||
self.tokenizer,
|
||||
cfg,
|
||||
ds_wrapper_partial,
|
||||
max_tokens=cfg.sequence_len,
|
||||
batch_size=cfg.micro_batch_size,
|
||||
seed=cfg.seed or 42,
|
||||
dataset = dataset.map(
|
||||
encode,
|
||||
batched=True,
|
||||
input_columns="text",
|
||||
remove_columns=dataset.features.keys(),
|
||||
)
|
||||
|
||||
trainer_loader = DataLoader(
|
||||
train_dataset,
|
||||
dataset,
|
||||
batch_size=1,
|
||||
collate_fn=None,
|
||||
drop_last=True,
|
||||
@@ -75,16 +64,16 @@ class TestPretrainingPacking(unittest.TestCase):
|
||||
if idx > 10:
|
||||
break
|
||||
assert data["input_ids"].shape == torch.Size(
|
||||
[1, original_bsz * cfg.sequence_len]
|
||||
[1, self.batch_size * self.max_seq_length]
|
||||
)
|
||||
assert data["position_ids"].shape == torch.Size(
|
||||
[1, original_bsz * cfg.sequence_len]
|
||||
[1, self.batch_size * self.max_seq_length]
|
||||
)
|
||||
assert data["labels"].shape == torch.Size(
|
||||
[1, original_bsz * cfg.sequence_len]
|
||||
[1, self.batch_size * self.max_seq_length]
|
||||
)
|
||||
assert data["attention_mask"].shape == torch.Size(
|
||||
[1, original_bsz * cfg.sequence_len]
|
||||
[1, self.batch_size * self.max_seq_length]
|
||||
)
|
||||
idx += 1
|
||||
|
||||
|
||||
@@ -67,21 +67,6 @@ class TestTokenizers(unittest.TestCase):
|
||||
)
|
||||
load_tokenizer(cfg)
|
||||
|
||||
def test_add_additional_special_tokens(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"tokenizer_config": "huggyllama/llama-7b",
|
||||
"special_tokens": {"additional_special_tokens": ["<|im_start|>"]},
|
||||
}
|
||||
)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
self.assertEqual(tokenizer("<|im_start|>user")["input_ids"], [1, 32000, 1404])
|
||||
self.assertEqual(len(tokenizer), 32001)
|
||||
|
||||
# ensure reloading the tokenizer again from cfg results in same vocab length
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
self.assertEqual(len(tokenizer), 32001)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user