Compare commits
17 Commits
datasets-r
...
v0.3.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
772cd870d4 | ||
|
|
6c5fbe6223 | ||
|
|
bcbc9597e9 | ||
|
|
6d57f2f0f0 | ||
|
|
20ed4c1f9e | ||
|
|
c5dedb17ad | ||
|
|
b56503d423 | ||
|
|
a94f9cb99e | ||
|
|
c1921c9acb | ||
|
|
0b4cf5bc8c | ||
|
|
78ee2cdab2 | ||
|
|
34c0a86a11 | ||
|
|
5e2d8a42d9 | ||
|
|
e30f1e3cf7 | ||
|
|
343714972b | ||
|
|
245c5c41e2 | ||
|
|
a546ca2813 |
45
.github/workflows/pypi.yml
vendored
Normal file
45
.github/workflows/pypi.yml
vendored
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
name: publish pypi
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- '*'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
pypi-publish:
|
||||||
|
name: Upload release to PyPI
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
environment:
|
||||||
|
name: pypi
|
||||||
|
url: https://pypi.org/p/axolotl
|
||||||
|
permissions:
|
||||||
|
id-token: write # IMPORTANT: this permission is mandatory for trusted publishing
|
||||||
|
steps:
|
||||||
|
- name: Check out repository code
|
||||||
|
uses: actions/checkout@v3
|
||||||
|
|
||||||
|
- name: Setup Python
|
||||||
|
uses: actions/setup-python@v4
|
||||||
|
with:
|
||||||
|
python-version: "3.10"
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
pip3 install wheel
|
||||||
|
pip3 install -e .
|
||||||
|
pip3 install -r requirements-tests.txt
|
||||||
|
|
||||||
|
- name: Extract tag name
|
||||||
|
id: tag
|
||||||
|
run: echo ::set-output name=TAG_NAME::$(echo $GITHUB_REF | cut -d / -f 3)
|
||||||
|
|
||||||
|
- name: Update version in setup.py
|
||||||
|
run: >-
|
||||||
|
sed -i -E 's/version="([0-9.]+)",/version="${{ steps.tag.outputs.TAG_NAME }}",/g' setup.py
|
||||||
|
|
||||||
|
- name: Build a binary wheel
|
||||||
|
run: >-
|
||||||
|
python setup.py sdist bdist_wheel
|
||||||
|
|
||||||
|
- name: Publish package distributions to PyPI
|
||||||
|
uses: pypa/gh-action-pypi-publish@release/v1
|
||||||
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -24,8 +24,8 @@ jobs:
|
|||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip install -e .
|
pip3 install -e .
|
||||||
pip install -r requirements-tests.txt
|
pip3 install -r requirements-tests.txt
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
48
README.md
48
README.md
@@ -90,8 +90,7 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
|
|||||||
```bash
|
```bash
|
||||||
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
|
docker run --gpus '"all"' --rm -it winglian/axolotl:main-py3.10-cu118-2.0.1
|
||||||
```
|
```
|
||||||
- `winglian/axolotl-runpod:main-py3.10-cu118-2.0.1`: for runpod
|
- `winglian/axolotl-runpod:main-latest`: for runpod or use this [direct link](https://runpod.io/gsc?template=v2ickqhz9s&ref=6i7fkpdz)
|
||||||
- `winglian/axolotl-runpod:main-py3.9-cu118-2.0.1-gptq`: for gptq
|
|
||||||
|
|
||||||
Or run on the current files for development:
|
Or run on the current files for development:
|
||||||
|
|
||||||
@@ -104,19 +103,9 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
|
|||||||
|
|
||||||
2. Install pytorch stable https://pytorch.org/get-started/locally/
|
2. Install pytorch stable https://pytorch.org/get-started/locally/
|
||||||
|
|
||||||
3. Install python dependencies with ONE of the following:
|
3. Install axolotl along with python dependencies
|
||||||
- Recommended, supports QLoRA, NO gptq/int4 support
|
|
||||||
```bash
|
```bash
|
||||||
pip3 install -e .
|
pip3 install -e .[flash-attn]
|
||||||
pip3 install -U git+https://github.com/huggingface/peft.git
|
|
||||||
```
|
|
||||||
- gptq/int4 support, NO QLoRA
|
|
||||||
```bash
|
|
||||||
pip3 install -e .[gptq]
|
|
||||||
```
|
|
||||||
- same as above but not recommended
|
|
||||||
```bash
|
|
||||||
pip3 install -e .[gptq_triton]
|
|
||||||
```
|
```
|
||||||
|
|
||||||
- LambdaLabs
|
- LambdaLabs
|
||||||
@@ -151,10 +140,9 @@ accelerate launch scripts/finetune.py examples/openllama-3b/lora.yml \
|
|||||||
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
git clone https://github.com/OpenAccess-AI-Collective/axolotl
|
||||||
cd axolotl
|
cd axolotl
|
||||||
|
|
||||||
pip3 install -e . # change depend on needs
|
pip3 install -e .
|
||||||
pip3 install protobuf==3.20.3
|
pip3 install protobuf==3.20.3
|
||||||
pip3 install -U --ignore-installed requests Pillow psutil scipy
|
pip3 install -U --ignore-installed requests Pillow psutil scipy
|
||||||
pip3 install git+https://github.com/huggingface/peft.git # not for gptq
|
|
||||||
```
|
```
|
||||||
|
|
||||||
5. Set path
|
5. Set path
|
||||||
@@ -572,6 +560,30 @@ log_sweep_min_lr:
|
|||||||
log_sweep_max_lr:
|
log_sweep_max_lr:
|
||||||
|
|
||||||
# specify optimizer
|
# specify optimizer
|
||||||
|
# Valid values are driven by the Transformers OptimizerNames class, see:
|
||||||
|
# https://github.com/huggingface/transformers/blob/95b374952dc27d8511541d6f5a4e22c9ec11fb24/src/transformers/training_args.py#L134
|
||||||
|
#
|
||||||
|
# Note that not all optimizers may be available in your environment, ex: 'adamw_anyprecision' is part of
|
||||||
|
# torchdistx, 'adamw_bnb_8bit' is part of bnb.optim.Adam8bit, etc. When in doubt, it is recommended to start with the optimizer used
|
||||||
|
# in the examples/ for your model and fine-tuning use case.
|
||||||
|
#
|
||||||
|
# Valid values for 'optimizer' include:
|
||||||
|
# - adamw_hf
|
||||||
|
# - adamw_torch
|
||||||
|
# - adamw_torch_fused
|
||||||
|
# - adamw_torch_xla
|
||||||
|
# - adamw_apex_fused
|
||||||
|
# - adafactor
|
||||||
|
# - adamw_anyprecision
|
||||||
|
# - sgd
|
||||||
|
# - adagrad
|
||||||
|
# - adamw_bnb_8bit
|
||||||
|
# - lion_8bit
|
||||||
|
# - lion_32bit
|
||||||
|
# - paged_adamw_32bit
|
||||||
|
# - paged_adamw_8bit
|
||||||
|
# - paged_lion_32bit
|
||||||
|
# - paged_lion_8bit
|
||||||
optimizer:
|
optimizer:
|
||||||
# specify weight decay
|
# specify weight decay
|
||||||
weight_decay:
|
weight_decay:
|
||||||
@@ -752,6 +764,10 @@ Try to turn off xformers.
|
|||||||
|
|
||||||
It's safe to ignore it.
|
It's safe to ignore it.
|
||||||
|
|
||||||
|
> NCCL Timeouts during training
|
||||||
|
|
||||||
|
See the [NCCL](docs/nccl.md) guide.
|
||||||
|
|
||||||
## Need help? 🙋♂️
|
## Need help? 🙋♂️
|
||||||
|
|
||||||
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
|
Join our [Discord server](https://discord.gg/HhrNrHJPRb) where we can help you
|
||||||
|
|||||||
@@ -9,6 +9,11 @@ services:
|
|||||||
- ~/.cache/huggingface/:/root/.cache/huggingface/
|
- ~/.cache/huggingface/:/root/.cache/huggingface/
|
||||||
# set environment variables
|
# set environment variables
|
||||||
environment:
|
environment:
|
||||||
|
# Set environment variables
|
||||||
|
- GIT_AUTHOR_NAME=${GIT_AUTHOR_NAME}
|
||||||
|
- GIT_AUTHOR_EMAIL=${GIT_AUTHOR_EMAIL}
|
||||||
|
- GIT_COMMITTER_NAME=${GIT_COMMITTER_NAME}
|
||||||
|
- GIT_COMMITTER_EMAIL=${GIT_COMMITTER_EMAIL}
|
||||||
- WANDB_API_KEY=${WANDB_API_KEY}
|
- WANDB_API_KEY=${WANDB_API_KEY}
|
||||||
deploy:
|
deploy:
|
||||||
resources:
|
resources:
|
||||||
|
|||||||
@@ -15,9 +15,9 @@ RUN git clone --depth=1 https://github.com/OpenAccess-AI-Collective/axolotl.git
|
|||||||
# If AXOLOTL_EXTRAS is set, append it in brackets
|
# If AXOLOTL_EXTRAS is set, append it in brackets
|
||||||
RUN cd axolotl && \
|
RUN cd axolotl && \
|
||||||
if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
if [ "$AXOLOTL_EXTRAS" != "" ] ; then \
|
||||||
pip install -e .[flash-attn,gptq,$AXOLOTL_EXTRAS]; \
|
pip install -e .[flash-attn,$AXOLOTL_EXTRAS]; \
|
||||||
else \
|
else \
|
||||||
pip install -e .[flash-attn,gptq]; \
|
pip install -e .[flash-attn]; \
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# fix so that git fetch/pull from remote works
|
# fix so that git fetch/pull from remote works
|
||||||
|
|||||||
46
docs/nccl.md
Normal file
46
docs/nccl.md
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
# NCCL
|
||||||
|
|
||||||
|
NVIDIA NCCL is a library to facilitate and optimize multi-GPU communication operations, such as broadcast, all-gather, reduce, all-reduce, etc. Broadly, NCCL configuration is highly environment-specific and is configured via several [environment variables](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html). A common NCCL-related problem occurs when a long-running operation times out causing the training process to abort:
|
||||||
|
|
||||||
|
```text
|
||||||
|
Watchdog caught collective operation timeout: WorkNCCL(SeqNum=42, OpType=ALLGATHER, Timeout(ms)=1800000) ran for 1806948 milliseconds before timing out.
|
||||||
|
```
|
||||||
|
|
||||||
|
Often, this timeout will happen after 30 minutes (the default setting) and is accompanied by below-average power consumption with near 100% GPU utilization before the error is raised. Nvidia recommends [disabling PCI access control services (ACS)](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/troubleshooting.html#pci-access-control-services-acs) as a possible solution if this is available to you.
|
||||||
|
|
||||||
|
Forcing cross-GPU communication via [NVLink](https://en.wikipedia.org/wiki/NVLink) may help without increasing timeouts. To verify that your configuration is leveraging NVLink run the following command:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
nvidia-smi nvlink --status
|
||||||
|
```
|
||||||
|
|
||||||
|
To force NCCL to use NVLink, simply set this in the environment:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
export NCCL_P2P_LEVEL=NVL
|
||||||
|
```
|
||||||
|
|
||||||
|
If NVLink is not available in your environment there are other options for ``NCCL_P2P_LEVEL`` in the table below:
|
||||||
|
|
||||||
|
| NCCL_P2P_LEVEL | Description |
|
||||||
|
| -------------- | ----------- |
|
||||||
|
| PIX | P2P data transfers through no more than a single PCIe bridge. Faster data transfer rates vs to paths involving multiple bridges, but slower compared to direct GPU-to-GPU communication. |
|
||||||
|
| PXB | P2P data transfers through multiple PCIe bridges but not going through the PCIe Host Bridge; this path involves a complex routing process, potentially incurring a moderate level of latency. |
|
||||||
|
| PHB | P2P data transfers occur over the PCIe and through a PCIe Host Bridge, typically involving the CPU, which can facilitate direct memory access but might introduce additional latency compared to more direct paths (ex PIX, NVL) |
|
||||||
|
|
||||||
|
To validate that acceptable data transfer speeds exist for your training job, running [NCCL Tests](https://github.com/NVIDIA/nccl-tests/blob/master/README.md) can help pinpoint bottlenecks, for example:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
./build/all_reduce_perf -b 8 -e 128M -f 2 -g 3
|
||||||
|
```
|
||||||
|
|
||||||
|
It can be useful when debugging NCCL communication timeouts to activate additional logging in both PyTorch and NCCL:
|
||||||
|
|
||||||
|
```shell
|
||||||
|
export NCCL_DEBUG=INFO
|
||||||
|
export NCCL_DEBUG_SUBSYS=ALL
|
||||||
|
export TORCH_DISTRIBUTED_DEBUG=INFO
|
||||||
|
export TORCHELASTIC_ERROR_FILE=/PATH/TO/torcherror.log
|
||||||
|
```
|
||||||
|
|
||||||
|
Finally, if you believe your training job needs more time you can increase the timeout past 30 minutes by setting the ``ddp_timeout`` value in the Axolotl configuration. See [PyTorch init_process_group](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for documentation on this value.
|
||||||
@@ -17,6 +17,7 @@ output_dir: ./lora-out
|
|||||||
|
|
||||||
sequence_len: 100000
|
sequence_len: 100000
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ lora_model_dir:
|
|||||||
|
|
||||||
sequence_len: 100000
|
sequence_len: 100000
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ output_dir: ./lora-out
|
|||||||
|
|
||||||
sequence_len: 100000
|
sequence_len: 100000
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ lora_model_dir:
|
|||||||
|
|
||||||
sequence_len: 100000
|
sequence_len: 100000
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ output_dir: ./lora-out
|
|||||||
|
|
||||||
sequence_len: 100000
|
sequence_len: 100000
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ lora_model_dir:
|
|||||||
|
|
||||||
sequence_len: 100000
|
sequence_len: 100000
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
|
|||||||
@@ -17,6 +17,7 @@ output_dir: ./lora-out
|
|||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
adapter: lora
|
adapter: lora
|
||||||
lora_model_dir:
|
lora_model_dir:
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ lora_model_dir:
|
|||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
lora_r: 32
|
lora_r: 32
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ lora_model_dir:
|
|||||||
|
|
||||||
sequence_len: 4096
|
sequence_len: 4096
|
||||||
sample_packing: true
|
sample_packing: true
|
||||||
|
pad_to_sequence_len: true
|
||||||
|
|
||||||
lora_r: 8
|
lora_r: 8
|
||||||
lora_alpha: 16
|
lora_alpha: 16
|
||||||
|
|||||||
@@ -6,13 +6,13 @@ packaging
|
|||||||
peft @ git+https://github.com/huggingface/peft.git
|
peft @ git+https://github.com/huggingface/peft.git
|
||||||
transformers @ git+https://github.com/huggingface/transformers.git
|
transformers @ git+https://github.com/huggingface/transformers.git
|
||||||
bitsandbytes>=0.41.1
|
bitsandbytes>=0.41.1
|
||||||
accelerate @ git+https://github.com/huggingface/accelerate@2a289f6108e77a77a4efffb3f6316bc98538413b
|
accelerate @ git+https://github.com/huggingface/accelerate
|
||||||
addict
|
addict
|
||||||
evaluate
|
evaluate
|
||||||
fire
|
fire
|
||||||
PyYAML>=6.0
|
PyYAML>=6.0
|
||||||
datasets
|
datasets
|
||||||
flash-attn>=2.0.8
|
flash-attn>=2.2.1
|
||||||
sentencepiece
|
sentencepiece
|
||||||
wandb
|
wandb
|
||||||
einops
|
einops
|
||||||
|
|||||||
14
setup.py
14
setup.py
@@ -7,9 +7,7 @@ def parse_requirements():
|
|||||||
_install_requires = []
|
_install_requires = []
|
||||||
_dependency_links = []
|
_dependency_links = []
|
||||||
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
with open("./requirements.txt", encoding="utf-8") as requirements_file:
|
||||||
lines = [
|
lines = [r.strip() for r in requirements_file.readlines()]
|
||||||
r.strip() for r in requirements_file.readlines() if "auto-gptq" not in r
|
|
||||||
]
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if line.startswith("--extra-index-url"):
|
if line.startswith("--extra-index-url"):
|
||||||
# Handle custom index URLs
|
# Handle custom index URLs
|
||||||
@@ -26,18 +24,16 @@ install_requires, dependency_links = parse_requirements()
|
|||||||
|
|
||||||
setup(
|
setup(
|
||||||
name="axolotl",
|
name="axolotl",
|
||||||
version="0.1",
|
version="0.3.0",
|
||||||
description="You know you're going to axolotl questions",
|
description="LLM Trainer",
|
||||||
|
long_description="Axolotl is a tool designed to streamline the fine-tuning of various AI models, offering support for multiple configurations and architectures.",
|
||||||
package_dir={"": "src"},
|
package_dir={"": "src"},
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
install_requires=install_requires,
|
install_requires=install_requires,
|
||||||
dependency_links=dependency_links,
|
dependency_links=dependency_links,
|
||||||
extras_require={
|
extras_require={
|
||||||
"gptq": [
|
|
||||||
"auto-gptq",
|
|
||||||
],
|
|
||||||
"flash-attn": [
|
"flash-attn": [
|
||||||
"flash-attn==2.0.8",
|
"flash-attn>=2.2.1",
|
||||||
],
|
],
|
||||||
"extras": [
|
"extras": [
|
||||||
"deepspeed",
|
"deepspeed",
|
||||||
|
|||||||
@@ -1,144 +0,0 @@
|
|||||||
import logging
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from enum import Enum
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Dict, Generator, List, Optional, Union
|
|
||||||
|
|
||||||
from datasets import Dataset as Dataset_ds
|
|
||||||
from datasets import DatasetDict, IterableDataset, load_dataset, load_from_disk
|
|
||||||
from huggingface_hub import hf_hub_download
|
|
||||||
|
|
||||||
logger = logging.getLogger("axolotl")
|
|
||||||
|
|
||||||
|
|
||||||
class DsType(Enum):
|
|
||||||
JSON = "json"
|
|
||||||
ARROW = "arrow"
|
|
||||||
PARQUET = "parquet"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class DatasetConfiguration:
|
|
||||||
path: str
|
|
||||||
type: str
|
|
||||||
name: Optional[str] = field(
|
|
||||||
default=None,
|
|
||||||
metadata={"help": "the name of the dataset configuration to load."},
|
|
||||||
)
|
|
||||||
ds_type: Optional[DsType] = None
|
|
||||||
data_files: Optional[Union[str, List[str]]] = None
|
|
||||||
shards: Optional[int] = None
|
|
||||||
test_size: Optional[float] = None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_dict(d: Dict[str, Any]) -> Generator["DatasetConfiguration", None, None]:
|
|
||||||
if "name" in d and isinstance(d["name"], list):
|
|
||||||
name = d.pop("name")
|
|
||||||
for n in name:
|
|
||||||
yield DatasetConfiguration(
|
|
||||||
**d,
|
|
||||||
name=n,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def load_dataset_from_local(config: DatasetConfiguration) -> Optional[Dataset_ds]:
|
|
||||||
local_path = Path(config.path)
|
|
||||||
if not local_path.exists():
|
|
||||||
return None
|
|
||||||
ds = None
|
|
||||||
if local_path.is_dir():
|
|
||||||
if config.ds_type:
|
|
||||||
# TODO dirs with arrow or parquet files could be loaded with `load_from_disk`
|
|
||||||
ds = load_from_disk(config.path)
|
|
||||||
else:
|
|
||||||
ds = load_dataset(
|
|
||||||
config.path,
|
|
||||||
name=config.name,
|
|
||||||
data_files=config.data_files,
|
|
||||||
streaming=False,
|
|
||||||
split=None,
|
|
||||||
)
|
|
||||||
elif local_path.is_file():
|
|
||||||
ds_type = "json"
|
|
||||||
if config.ds_type:
|
|
||||||
ds_type = config.ds_type.value
|
|
||||||
elif "parquet" in config.path:
|
|
||||||
ds_type = "parquet"
|
|
||||||
elif "arrow" in config.path:
|
|
||||||
ds_type = "arrow"
|
|
||||||
ds = load_dataset(
|
|
||||||
ds_type,
|
|
||||||
name=config.name,
|
|
||||||
data_files=config.path,
|
|
||||||
streaming=False,
|
|
||||||
split=None, # is this correct?
|
|
||||||
)
|
|
||||||
if not ds:
|
|
||||||
raise ValueError(
|
|
||||||
"unhandled dataset load: local path exists, but is neither a directory or a file"
|
|
||||||
)
|
|
||||||
return ds
|
|
||||||
|
|
||||||
|
|
||||||
# TODO should this be a DatasetDict?
|
|
||||||
class Dataset(Dataset_ds):
|
|
||||||
_config: DatasetConfiguration
|
|
||||||
|
|
||||||
def __init__(self, *args, config: DatasetConfiguration = None, **kwargs):
|
|
||||||
self._config = config
|
|
||||||
super().__init__(*args, **kwargs)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_config(
|
|
||||||
config: DatasetConfiguration,
|
|
||||||
token: bool = False,
|
|
||||||
default_test_size: float = 0.1,
|
|
||||||
):
|
|
||||||
ds = load_dataset_from_local(config)
|
|
||||||
if not ds:
|
|
||||||
try:
|
|
||||||
ds = load_dataset(
|
|
||||||
config.path,
|
|
||||||
name=config.name,
|
|
||||||
data_files=config.data_files,
|
|
||||||
token=token,
|
|
||||||
)
|
|
||||||
except FileNotFoundError:
|
|
||||||
pass
|
|
||||||
if not ds:
|
|
||||||
fp = hf_hub_download(
|
|
||||||
repo_id=config.path,
|
|
||||||
repo_type="dataset",
|
|
||||||
filename=config.data_files,
|
|
||||||
token=token,
|
|
||||||
)
|
|
||||||
ds = load_dataset(
|
|
||||||
"json", name=config.name, data_files=fp, streaming=False, split=None
|
|
||||||
)
|
|
||||||
if not ds:
|
|
||||||
raise ValueError("unhandled dataset load")
|
|
||||||
test_size = config.test_size if config.test_size else default_test_size
|
|
||||||
# determine if the dataset is pre-tokenized
|
|
||||||
check_ds = ds["train"] if isinstance(ds, DatasetDict) and "train" in ds else ds
|
|
||||||
is_ds_tokenized = False
|
|
||||||
if "input_ids" in check_ds.features:
|
|
||||||
is_ds_tokenized = True
|
|
||||||
if "attention_mask" not in check_ds.features:
|
|
||||||
logger.warning("`attention_mask` missing from pre-tokenized dataset")
|
|
||||||
if "labels" not in check_ds.features:
|
|
||||||
logger.warning("`labels` missing from pre-tokenized dataset")
|
|
||||||
if test_size and (not isinstance(ds, DatasetDict) or "test" not in ds):
|
|
||||||
ds.train_test_split(test_size=test_size, shuffle=False)
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class DatasetCollection:
|
|
||||||
datasets: List[Dataset] = []
|
|
||||||
|
|
||||||
def __init__(self, datasets: Union[Dataset, List[Dataset]]):
|
|
||||||
self.datasets = datasets if isinstance(datasets, list) else [datasets]
|
|
||||||
|
|
||||||
def __iter__(self):
|
|
||||||
for ds in self.datasets:
|
|
||||||
for d in ds:
|
|
||||||
yield d
|
|
||||||
@@ -23,6 +23,7 @@ class ColorfulFormatter(Formatter):
|
|||||||
}
|
}
|
||||||
|
|
||||||
def format(self, record):
|
def format(self, record):
|
||||||
|
record.rank = int(os.getenv("LOCAL_RANK", "0"))
|
||||||
log_message = super().format(record)
|
log_message = super().format(record)
|
||||||
return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
|
return self.COLORS.get(record.levelname, "") + log_message + Fore.RESET
|
||||||
|
|
||||||
@@ -35,7 +36,7 @@ DEFAULT_LOGGING_CONFIG: Dict[str, Any] = {
|
|||||||
},
|
},
|
||||||
"colorful": {
|
"colorful": {
|
||||||
"()": ColorfulFormatter,
|
"()": ColorfulFormatter,
|
||||||
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] %(message)s",
|
"format": "[%(asctime)s] [%(levelname)s] [%(name)s.%(funcName)s:%(lineno)d] [PID:%(process)d] [RANK:%(rank)d] %(message)s",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
"filters": {},
|
"filters": {},
|
||||||
|
|||||||
@@ -88,6 +88,11 @@ def train(
|
|||||||
if peft_config:
|
if peft_config:
|
||||||
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
LOG.info(f"Pre-saving adapter config to {cfg.output_dir}")
|
||||||
peft_config.save_pretrained(cfg.output_dir)
|
peft_config.save_pretrained(cfg.output_dir)
|
||||||
|
# additionally presave the tokenizer and model configs
|
||||||
|
if not Path(cfg.output_dir).is_dir():
|
||||||
|
os.makedirs(cfg.output_dir, exist_ok=True)
|
||||||
|
tokenizer.save_pretrained(str(Path(cfg.output_dir)))
|
||||||
|
model.config.save_pretrained(str(Path(cfg.output_dir)))
|
||||||
|
|
||||||
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
# In case we want to stop early with ctrl+c, this is a nice to have to save the pretrained model
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
@@ -106,9 +111,6 @@ def train(
|
|||||||
if cfg.group_by_length:
|
if cfg.group_by_length:
|
||||||
LOG.info("hang tight... sorting dataset for group_by_length")
|
LOG.info("hang tight... sorting dataset for group_by_length")
|
||||||
|
|
||||||
if not Path(cfg.output_dir).is_dir():
|
|
||||||
os.makedirs(cfg.output_dir, exist_ok=True)
|
|
||||||
tokenizer.save_pretrained(cfg.output_dir)
|
|
||||||
if cfg.flash_optimum:
|
if cfg.flash_optimum:
|
||||||
with torch.backends.cuda.sdp_kernel(
|
with torch.backends.cuda.sdp_kernel(
|
||||||
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
enable_flash=True, enable_math=True, enable_mem_efficient=True
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, IntervalStrategy
|
|||||||
from axolotl.utils.bench import log_gpu_memory_usage
|
from axolotl.utils.bench import log_gpu_memory_usage
|
||||||
from axolotl.utils.distributed import (
|
from axolotl.utils.distributed import (
|
||||||
barrier,
|
barrier,
|
||||||
|
broadcast_dict,
|
||||||
gather_scalar_from_all_ranks,
|
gather_scalar_from_all_ranks,
|
||||||
get_world_size,
|
get_world_size,
|
||||||
is_distributed,
|
is_distributed,
|
||||||
@@ -271,6 +272,7 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
lambda: len(data_loader), get_world_size()
|
lambda: len(data_loader), get_world_size()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
results = {}
|
||||||
if is_distributed() and not is_main_process():
|
if is_distributed() and not is_main_process():
|
||||||
dist.gather_object(local_bench_names, dst=0)
|
dist.gather_object(local_bench_names, dst=0)
|
||||||
else:
|
else:
|
||||||
@@ -316,4 +318,8 @@ def bench_eval_callback_factory(trainer, tokenizer):
|
|||||||
)["accuracy"]
|
)["accuracy"]
|
||||||
trainer.log(results)
|
trainer.log(results)
|
||||||
|
|
||||||
|
results = broadcast_dict(results)
|
||||||
|
for key, val in results.items():
|
||||||
|
metrics[key] = val
|
||||||
|
|
||||||
return BenchEvalCallback
|
return BenchEvalCallback
|
||||||
|
|||||||
@@ -97,6 +97,11 @@ def validate_config(cfg):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.sample_packing and not cfg.pad_to_sequence_len:
|
||||||
|
LOG.warning(
|
||||||
|
"`pad_to_sequence_len: true` is recommended when using sample_packing"
|
||||||
|
)
|
||||||
|
|
||||||
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
if cfg.gradient_accumulation_steps and cfg.batch_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"please set only one of gradient_accumulation_steps or batch_size"
|
"please set only one of gradient_accumulation_steps or batch_size"
|
||||||
@@ -215,6 +220,15 @@ def validate_config(cfg):
|
|||||||
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
"sample_packing not compatible with xformers_attention. Use flash_attention"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if cfg.early_stopping_patience:
|
||||||
|
if not cfg.save_steps or not cfg.eval_steps:
|
||||||
|
raise ValueError(
|
||||||
|
"`early_stopping_patience` requires save_steps and eval_steps to be set. eval_steps should evenly divide save_steps."
|
||||||
|
)
|
||||||
|
if cfg.save_steps % cfg.eval_steps != 0:
|
||||||
|
raise ValueError(
|
||||||
|
"`early_stopping_patience` requires that eval_steps should evenly divide save_steps."
|
||||||
|
)
|
||||||
# TODO
|
# TODO
|
||||||
# MPT 7b
|
# MPT 7b
|
||||||
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
# https://github.com/facebookresearch/bitsandbytes/issues/25
|
||||||
|
|||||||
@@ -2,7 +2,6 @@
|
|||||||
import functools
|
import functools
|
||||||
import hashlib
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
from hashlib import md5
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Tuple, Union
|
from typing import Tuple, Union
|
||||||
|
|
||||||
@@ -52,6 +51,13 @@ LOG = logging.getLogger("axolotl")
|
|||||||
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
DEFAULT_DATASET_PREPARED_PATH = "last_run_prepared"
|
||||||
|
|
||||||
|
|
||||||
|
def md5(to_hash: str, encoding: str = "utf-8") -> str:
|
||||||
|
try:
|
||||||
|
return hashlib.md5(to_hash.encode(encoding), usedforsecurity=False).hexdigest()
|
||||||
|
except TypeError:
|
||||||
|
return hashlib.md5(to_hash.encode(encoding)).hexdigest() # nosec
|
||||||
|
|
||||||
|
|
||||||
def prepare_dataset(cfg, tokenizer):
|
def prepare_dataset(cfg, tokenizer):
|
||||||
if not cfg.pretraining_dataset:
|
if not cfg.pretraining_dataset:
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
@@ -88,7 +94,7 @@ def load_tokenized_prepared_datasets(
|
|||||||
) -> DatasetDict:
|
) -> DatasetDict:
|
||||||
tokenizer_name = tokenizer.__class__.__name__
|
tokenizer_name = tokenizer.__class__.__name__
|
||||||
ds_hash = str(
|
ds_hash = str(
|
||||||
md5( # nosec
|
md5(
|
||||||
(
|
(
|
||||||
str(cfg.sequence_len)
|
str(cfg.sequence_len)
|
||||||
+ "@"
|
+ "@"
|
||||||
@@ -97,8 +103,8 @@ def load_tokenized_prepared_datasets(
|
|||||||
)
|
)
|
||||||
+ "|"
|
+ "|"
|
||||||
+ tokenizer_name
|
+ tokenizer_name
|
||||||
).encode("utf-8")
|
)
|
||||||
).hexdigest()
|
)
|
||||||
)
|
)
|
||||||
prepared_ds_path = (
|
prepared_ds_path = (
|
||||||
Path(cfg.dataset_prepared_path) / ds_hash
|
Path(cfg.dataset_prepared_path) / ds_hash
|
||||||
@@ -374,7 +380,7 @@ def load_prepare_datasets(
|
|||||||
# see if we can go ahead and load the stacked dataset
|
# see if we can go ahead and load the stacked dataset
|
||||||
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
seed = f"@{str(cfg.seed)}" if cfg.seed else ""
|
||||||
ds_hash = str(
|
ds_hash = str(
|
||||||
md5( # nosec
|
md5(
|
||||||
(
|
(
|
||||||
str(cfg.sequence_len)
|
str(cfg.sequence_len)
|
||||||
+ "@"
|
+ "@"
|
||||||
@@ -385,8 +391,8 @@ def load_prepare_datasets(
|
|||||||
)
|
)
|
||||||
+ "|"
|
+ "|"
|
||||||
+ tokenizer_name
|
+ tokenizer_name
|
||||||
).encode("utf-8")
|
)
|
||||||
).hexdigest()
|
)
|
||||||
)
|
)
|
||||||
prepared_ds_path = (
|
prepared_ds_path = (
|
||||||
Path(cfg.dataset_prepared_path) / ds_hash
|
Path(cfg.dataset_prepared_path) / ds_hash
|
||||||
@@ -500,12 +506,8 @@ def load_prepare_datasets(
|
|||||||
+ "|"
|
+ "|"
|
||||||
+ str(cfg.seed or 42)
|
+ str(cfg.seed or 42)
|
||||||
)
|
)
|
||||||
train_fingerprint = hashlib.md5(
|
train_fingerprint = md5(to_hash_train)
|
||||||
to_hash_train.encode(), usedforsecurity=False
|
test_fingerprint = md5(to_hash_test)
|
||||||
).hexdigest()
|
|
||||||
test_fingerprint = hashlib.md5(
|
|
||||||
to_hash_test.encode(), usedforsecurity=False
|
|
||||||
).hexdigest()
|
|
||||||
|
|
||||||
with zero_first(is_main_process()):
|
with zero_first(is_main_process()):
|
||||||
dataset = dataset.train_test_split(
|
dataset = dataset.train_test_split(
|
||||||
|
|||||||
@@ -2,6 +2,7 @@
|
|||||||
utility helpers for distributed checks
|
utility helpers for distributed checks
|
||||||
"""
|
"""
|
||||||
import os
|
import os
|
||||||
|
import pickle # nosec
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -93,3 +94,30 @@ def gather_scalar_from_all_ranks(fn, world_size=1): # pylint: disable=invalid-n
|
|||||||
gathered_values.append(float(tensor.item()))
|
gathered_values.append(float(tensor.item()))
|
||||||
return gathered_values
|
return gathered_values
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def broadcast_dict(vals: dict):
|
||||||
|
if not is_distributed():
|
||||||
|
return vals
|
||||||
|
|
||||||
|
if is_main_process():
|
||||||
|
data_byte = pickle.dumps(vals)
|
||||||
|
data_tensor = torch.ByteTensor(list(data_byte)).to("cuda")
|
||||||
|
data_size = torch.IntTensor([len(data_byte)]).to("cuda")
|
||||||
|
else:
|
||||||
|
data_tensor = torch.empty([1024], dtype=torch.uint8, device="cuda")
|
||||||
|
data_size = torch.IntTensor([0]).to("cuda")
|
||||||
|
|
||||||
|
dist.broadcast(data_size, 0)
|
||||||
|
if not is_main_process():
|
||||||
|
# resize
|
||||||
|
data_tensor = data_tensor.new_empty([data_size.item()])
|
||||||
|
|
||||||
|
dist.broadcast(data_tensor, 0)
|
||||||
|
|
||||||
|
if not is_main_process():
|
||||||
|
data_list = data_tensor.cpu().tolist()
|
||||||
|
data_byte = bytes(data_list[: data_size.item()])
|
||||||
|
vals = pickle.loads(data_byte) # nosec
|
||||||
|
|
||||||
|
return vals
|
||||||
|
|||||||
@@ -160,7 +160,7 @@ def load_model(
|
|||||||
model_kwargs["revision"] = cfg.model_revision
|
model_kwargs["revision"] = cfg.model_revision
|
||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
model_config = load_model_config(cfg)
|
model_config = load_model_config(cfg)
|
||||||
if hasattr(model_config, "quantization_config"):
|
if not hasattr(model_config, "quantization_config"):
|
||||||
LOG.warning("model config does not contain quantization_config information")
|
LOG.warning("model config does not contain quantization_config information")
|
||||||
else:
|
else:
|
||||||
model_kwargs["quantization_config"] = GPTQConfig(
|
model_kwargs["quantization_config"] = GPTQConfig(
|
||||||
|
|||||||
@@ -33,6 +33,7 @@ from axolotl.utils.callbacks import (
|
|||||||
)
|
)
|
||||||
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
from axolotl.utils.collators import DataCollatorForSeq2Seq
|
||||||
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
from axolotl.utils.dataloader import MultipackDistributedDataloader
|
||||||
|
from axolotl.utils.distributed import is_main_process, zero_first
|
||||||
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
from axolotl.utils.schedulers import get_cosine_schedule_with_quadratic_warmup
|
||||||
|
|
||||||
LOG = logging.getLogger("axolotl")
|
LOG = logging.getLogger("axolotl")
|
||||||
@@ -375,14 +376,17 @@ def disable_datasets_caching():
|
|||||||
|
|
||||||
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
def process_datasets_for_packing(cfg, train_dataset, eval_dataset):
|
||||||
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
|
drop_long = partial(drop_long_seq, sequence_len=cfg.sequence_len)
|
||||||
train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count())
|
with zero_first(is_main_process()):
|
||||||
if eval_dataset:
|
train_dataset = train_dataset.filter(drop_long, num_proc=os.cpu_count())
|
||||||
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
|
|
||||||
|
|
||||||
if cfg.sample_packing:
|
|
||||||
train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
|
|
||||||
if eval_dataset:
|
if eval_dataset:
|
||||||
eval_dataset = eval_dataset.map(add_position_ids, num_proc=os.cpu_count())
|
eval_dataset = eval_dataset.filter(drop_long, num_proc=os.cpu_count())
|
||||||
|
|
||||||
|
if cfg.sample_packing:
|
||||||
|
train_dataset = train_dataset.map(add_position_ids, num_proc=os.cpu_count())
|
||||||
|
if eval_dataset:
|
||||||
|
eval_dataset = eval_dataset.map(
|
||||||
|
add_position_ids, num_proc=os.cpu_count()
|
||||||
|
)
|
||||||
return train_dataset, eval_dataset
|
return train_dataset, eval_dataset
|
||||||
|
|
||||||
|
|
||||||
@@ -572,6 +576,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
|
training_arguments_kwargs["do_bench_eval"] = cfg.do_bench_eval
|
||||||
if cfg.bench_dataset:
|
if cfg.bench_dataset:
|
||||||
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
|
training_arguments_kwargs["bench_dataset"] = cfg.bench_dataset
|
||||||
|
if cfg.metric_for_best_model:
|
||||||
|
training_arguments_kwargs["metric_for_best_model"] = cfg.metric_for_best_model
|
||||||
|
if cfg.greater_is_better:
|
||||||
|
training_arguments_kwargs["greater_is_better"] = cfg.greater_is_better
|
||||||
|
|
||||||
# DDP Config
|
# DDP Config
|
||||||
if cfg.ddp_timeout:
|
if cfg.ddp_timeout:
|
||||||
@@ -597,11 +605,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
output_dir=cfg.output_dir,
|
output_dir=cfg.output_dir,
|
||||||
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
|
save_total_limit=cfg.save_total_limit if cfg.save_total_limit else 4,
|
||||||
load_best_model_at_end=(
|
load_best_model_at_end=(
|
||||||
cfg.load_best_model_at_end is not False
|
(cfg.load_best_model_at_end is not False or cfg.early_stopping_patience)
|
||||||
and cfg.val_set_size > 0
|
and cfg.val_set_size > 0
|
||||||
and cfg.save_steps
|
and cfg.save_steps
|
||||||
and cfg.save_steps % cfg.eval_steps == 0
|
and cfg.save_steps % cfg.eval_steps == 0
|
||||||
and cfg.load_in_8bit is not True
|
|
||||||
)
|
)
|
||||||
or False,
|
or False,
|
||||||
ddp_find_unused_parameters=False if cfg.ddp else None,
|
ddp_find_unused_parameters=False if cfg.ddp else None,
|
||||||
@@ -633,13 +640,6 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
if cfg.relora_steps:
|
if cfg.relora_steps:
|
||||||
callbacks.append(ReLoRACallback(cfg))
|
callbacks.append(ReLoRACallback(cfg))
|
||||||
|
|
||||||
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
|
||||||
if cfg.early_stopping_patience:
|
|
||||||
early_stop_cb = EarlyStoppingCallback(
|
|
||||||
cfg.early_stopping_patience,
|
|
||||||
)
|
|
||||||
callbacks.append(early_stop_cb)
|
|
||||||
|
|
||||||
if cfg.local_rank == 0 and cfg.adapter in [
|
if cfg.local_rank == 0 and cfg.adapter in [
|
||||||
"lora",
|
"lora",
|
||||||
"qlora",
|
"qlora",
|
||||||
@@ -706,4 +706,11 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer, total_num_
|
|||||||
if cfg.do_bench_eval:
|
if cfg.do_bench_eval:
|
||||||
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
|
trainer.add_callback(bench_eval_callback_factory(trainer, tokenizer))
|
||||||
|
|
||||||
|
# TODO on_save callback to sync checkpoints to GCP/AWS in background
|
||||||
|
if cfg.early_stopping_patience:
|
||||||
|
early_stop_cb = EarlyStoppingCallback(
|
||||||
|
cfg.early_stopping_patience,
|
||||||
|
)
|
||||||
|
trainer.add_callback(early_stop_cb)
|
||||||
|
|
||||||
return trainer
|
return trainer
|
||||||
|
|||||||
64
tests/test_data.py
Normal file
64
tests/test_data.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
"""
|
||||||
|
test module for the axolotl.utis.data module
|
||||||
|
"""
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from transformers import LlamaTokenizer
|
||||||
|
|
||||||
|
from axolotl.utils.data import encode_pretraining, md5
|
||||||
|
|
||||||
|
|
||||||
|
class TestEncodePretraining(unittest.TestCase):
|
||||||
|
"""
|
||||||
|
test class for encode pretraining and md5 helper
|
||||||
|
"""
|
||||||
|
|
||||||
|
def setUp(self):
|
||||||
|
self.tokenizer = LlamaTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||||
|
self.tokenizer.add_special_tokens(
|
||||||
|
{
|
||||||
|
"eos_token": "</s>",
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
"pad_token": "<pad>",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
self.max_tokens = 15 # set a small number for easy inspection
|
||||||
|
|
||||||
|
def test_encode_pretraining(self):
|
||||||
|
examples = {
|
||||||
|
"text": [
|
||||||
|
"Hello, world!",
|
||||||
|
"Nice to meet you.",
|
||||||
|
"lorem ipsum dolor sit amet.",
|
||||||
|
"Nice to meet you again!.",
|
||||||
|
"hello, hello",
|
||||||
|
]
|
||||||
|
}
|
||||||
|
result = encode_pretraining(self.tokenizer, self.max_tokens, examples)
|
||||||
|
|
||||||
|
self.assertEqual(len(result["input_ids"]), 3)
|
||||||
|
|
||||||
|
# Assert the length of input_ids and attention_mask is correct
|
||||||
|
self.assertEqual(len(result["input_ids"][0]), self.max_tokens)
|
||||||
|
self.assertEqual(len(result["attention_mask"][0]), self.max_tokens)
|
||||||
|
|
||||||
|
# Assert EOS and PAD tokens are correctly added
|
||||||
|
# hello world! is 4 tokens
|
||||||
|
self.assertEqual(result["input_ids"][0][0], self.tokenizer.bos_token_id)
|
||||||
|
self.assertEqual(result["input_ids"][0][5], self.tokenizer.eos_token_id)
|
||||||
|
self.assertEqual(result["input_ids"][0][6], self.tokenizer.pad_token_id)
|
||||||
|
# second part, 5 tokens
|
||||||
|
self.assertEqual(result["input_ids"][0][7], self.tokenizer.bos_token_id)
|
||||||
|
self.assertEqual(result["input_ids"][0][13], self.tokenizer.eos_token_id)
|
||||||
|
self.assertEqual(result["input_ids"][0][14], self.tokenizer.pad_token_id)
|
||||||
|
|
||||||
|
def test_md5(self):
|
||||||
|
self.assertEqual(md5("hello world"), "5eb63bbbe01eeed093cb22bb8f5acdc3")
|
||||||
|
self.assertEqual(
|
||||||
|
md5("hello world", "utf-8"), "5eb63bbbe01eeed093cb22bb8f5acdc3"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
@@ -328,6 +328,20 @@ class ValidationTest(unittest.TestCase):
|
|||||||
for record in self._caplog.records
|
for record in self._caplog.records
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cfg = DictDefault(
|
||||||
|
{
|
||||||
|
"sample_packing": True,
|
||||||
|
"pad_to_sequence_len": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
with self._caplog.at_level(logging.WARNING):
|
||||||
|
validate_config(cfg)
|
||||||
|
assert any(
|
||||||
|
"`pad_to_sequence_len: true` is recommended when using sample_packing"
|
||||||
|
in record.message
|
||||||
|
for record in self._caplog.records
|
||||||
|
)
|
||||||
|
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"max_packed_sequence_len": 2048,
|
"max_packed_sequence_len": 2048,
|
||||||
|
|||||||
Reference in New Issue
Block a user