Release update 20250331 (#2460) [skip ci]
* make torch 2.6.0 the default image * fix tests against upstream main * fix attribute access * use fixture dataset * fix dataset load * correct the fixtures + tests * more fixtures * add accidentally removed shakespeare fixture * fix conversion from unittest to pytest class * nightly main ci caches * build 12.6.3 cuda base image * override for fix from huggingface/transformers#37162 * address PR feedback
This commit is contained in:
6
.github/workflows/base.yml
vendored
6
.github/workflows/base.yml
vendored
@@ -40,6 +40,12 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
|
- cuda: "126"
|
||||||
|
cuda_version: 12.6.3
|
||||||
|
cudnn_version: ""
|
||||||
|
python_version: "3.11"
|
||||||
|
pytorch: 2.6.0
|
||||||
|
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
|
||||||
- cuda: "128"
|
- cuda: "128"
|
||||||
cuda_version: 12.8.1
|
cuda_version: 12.8.1
|
||||||
cudnn_version: ""
|
cudnn_version: ""
|
||||||
|
|||||||
4
.github/workflows/main.yml
vendored
4
.github/workflows/main.yml
vendored
@@ -25,12 +25,12 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.5.1
|
||||||
axolotl_extras: vllm
|
axolotl_extras: vllm
|
||||||
is_latest: true
|
|
||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
is_latest: true
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
@@ -87,12 +87,12 @@ jobs:
|
|||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.5.1
|
pytorch: 2.5.1
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
is_latest: true
|
|
||||||
- cuda: 124
|
- cuda: 124
|
||||||
cuda_version: 12.4.1
|
cuda_version: 12.4.1
|
||||||
python_version: "3.11"
|
python_version: "3.11"
|
||||||
pytorch: 2.6.0
|
pytorch: 2.6.0
|
||||||
axolotl_extras:
|
axolotl_extras:
|
||||||
|
is_latest: true
|
||||||
runs-on: axolotl-gpu-runner
|
runs-on: axolotl-gpu-runner
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- name: Checkout
|
||||||
|
|||||||
23
.github/workflows/tests-nightly.yml
vendored
23
.github/workflows/tests-nightly.yml
vendored
@@ -33,6 +33,15 @@ jobs:
|
|||||||
- name: Check out repository code
|
- name: Check out repository code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Restore HF cache
|
||||||
|
id: hf-cache-restore
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
/home/runner/.cache/huggingface/hub/datasets--*
|
||||||
|
/home/runner/.cache/huggingface/hub/models--*
|
||||||
|
key: ${{ runner.os }}-hf-hub-cache-v2
|
||||||
|
|
||||||
- name: Setup Python
|
- name: Setup Python
|
||||||
uses: actions/setup-python@v5
|
uses: actions/setup-python@v5
|
||||||
with:
|
with:
|
||||||
@@ -46,7 +55,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Install PyTorch
|
- name: Install PyTorch
|
||||||
run: |
|
run: |
|
||||||
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
|
pip3 install torch==${{ matrix.pytorch_version }}
|
||||||
|
|
||||||
- name: Update requirements.txt
|
- name: Update requirements.txt
|
||||||
run: |
|
run: |
|
||||||
@@ -58,8 +67,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip3 install --upgrade pip
|
pip3 show torch
|
||||||
pip3 install --upgrade packaging==23.2
|
|
||||||
pip3 install --no-build-isolation -U -e .
|
pip3 install --no-build-isolation -U -e .
|
||||||
python scripts/unsloth_install.py | sh
|
python scripts/unsloth_install.py | sh
|
||||||
python scripts/cutcrossentropy_install.py | sh
|
python scripts/cutcrossentropy_install.py | sh
|
||||||
@@ -73,10 +81,15 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
axolotl --help
|
axolotl --help
|
||||||
|
|
||||||
|
- name: Pre-Download dataset fixture
|
||||||
|
run: |
|
||||||
|
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
|
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||||
pytest tests/patched/
|
pytest -v tests/patched/
|
||||||
|
pytest -v tests/cli/
|
||||||
|
|
||||||
- name: cleanup pip cache
|
- name: cleanup pip cache
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
4
.github/workflows/tests.yml
vendored
4
.github/workflows/tests.yml
vendored
@@ -96,6 +96,10 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
axolotl --help
|
axolotl --help
|
||||||
|
|
||||||
|
- name: Pre-Download dataset fixture
|
||||||
|
run: |
|
||||||
|
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: |
|
run: |
|
||||||
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ liger-kernel==0.5.5
|
|||||||
packaging==23.2
|
packaging==23.2
|
||||||
|
|
||||||
peft==0.15.0
|
peft==0.15.0
|
||||||
transformers==4.50.0
|
transformers==4.50.3
|
||||||
tokenizers>=0.21.1
|
tokenizers>=0.21.1
|
||||||
accelerate==1.5.2
|
accelerate==1.5.2
|
||||||
datasets==3.5.0
|
datasets==3.5.0
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ from typing_extensions import override
|
|||||||
|
|
||||||
from axolotl.core.trainers.mixins import (
|
from axolotl.core.trainers.mixins import (
|
||||||
OptimizerMixin,
|
OptimizerMixin,
|
||||||
|
RngLoaderMixin,
|
||||||
SchedulerMixin,
|
SchedulerMixin,
|
||||||
SequenceParallelMixin,
|
SequenceParallelMixin,
|
||||||
)
|
)
|
||||||
@@ -40,7 +41,9 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer):
|
class AxolotlTrainer(
|
||||||
|
SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
|
||||||
|
):
|
||||||
"""Extend the base Trainer for axolotl helpers"""
|
"""Extend the base Trainer for axolotl helpers"""
|
||||||
|
|
||||||
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ from transformers import Trainer
|
|||||||
from transformers.utils import is_sagemaker_mp_enabled
|
from transformers.utils import is_sagemaker_mp_enabled
|
||||||
from trl import DPOTrainer
|
from trl import DPOTrainer
|
||||||
|
|
||||||
from axolotl.core.trainers.mixins import SchedulerMixin
|
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||||
from axolotl.core.trainers.utils import (
|
from axolotl.core.trainers.utils import (
|
||||||
sanitize_kwargs_for_ds_tagging,
|
sanitize_kwargs_for_ds_tagging,
|
||||||
sanitize_kwargs_for_tagging,
|
sanitize_kwargs_for_tagging,
|
||||||
@@ -23,7 +23,7 @@ if is_sagemaker_mp_enabled():
|
|||||||
import smdistributed.modelparallel.torch as smp
|
import smdistributed.modelparallel.torch as smp
|
||||||
|
|
||||||
|
|
||||||
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
|
class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
|
||||||
"""
|
"""
|
||||||
Extend the base DPOTrainer for axolotl helpers
|
Extend the base DPOTrainer for axolotl helpers
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -8,13 +8,13 @@ from accelerate.utils import is_deepspeed_available, is_peft_model
|
|||||||
from trl import GRPOTrainer
|
from trl import GRPOTrainer
|
||||||
from trl.extras.profiling import profiling_decorator
|
from trl.extras.profiling import profiling_decorator
|
||||||
|
|
||||||
from axolotl.core.trainers.base import SchedulerMixin
|
from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
|
||||||
|
|
||||||
if is_deepspeed_available():
|
if is_deepspeed_available():
|
||||||
import deepspeed
|
import deepspeed
|
||||||
|
|
||||||
|
|
||||||
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
|
class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
|
||||||
"""
|
"""
|
||||||
Extend the base GRPOTrainer for axolotl helpers
|
Extend the base GRPOTrainer for axolotl helpers
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -4,5 +4,6 @@
|
|||||||
# flake8: noqa
|
# flake8: noqa
|
||||||
|
|
||||||
from .optimizer import OptimizerMixin
|
from .optimizer import OptimizerMixin
|
||||||
|
from .rng_state_loader import RngLoaderMixin
|
||||||
from .scheduler import SchedulerMixin
|
from .scheduler import SchedulerMixin
|
||||||
from .sequence_parallel import SequenceParallelMixin
|
from .sequence_parallel import SequenceParallelMixin
|
||||||
|
|||||||
67
src/axolotl/core/trainers/mixins/rng_state_loader.py
Normal file
67
src/axolotl/core/trainers/mixins/rng_state_loader.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
"""
|
||||||
|
Temporary fix/override for bug in resume from checkpoint
|
||||||
|
|
||||||
|
See https://github.com/huggingface/transformers/pull/37162
|
||||||
|
|
||||||
|
TODO: Remove when upstream added PR to release
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import random
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from transformers import Trainer, is_torch_npu_available
|
||||||
|
from transformers.trainer import safe_globals
|
||||||
|
from transformers.trainer_pt_utils import set_rng_state_for_device
|
||||||
|
from transformers.training_args import ParallelMode
|
||||||
|
|
||||||
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RngLoaderMixin(Trainer):
|
||||||
|
"""
|
||||||
|
mixin for method override to load RNG states from a checkpoint
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _load_rng_state(self, checkpoint):
|
||||||
|
# Load RNG states from `checkpoint`
|
||||||
|
if checkpoint is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if self.args.world_size > 1:
|
||||||
|
process_index = self.args.process_index
|
||||||
|
rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
|
||||||
|
if not os.path.isfile(rng_file):
|
||||||
|
LOG.info(
|
||||||
|
f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
|
||||||
|
"wasn't launched in a distributed fashion, reproducibility is not guaranteed."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
rng_file = os.path.join(checkpoint, "rng_state.pth")
|
||||||
|
if not os.path.isfile(rng_file):
|
||||||
|
LOG.info(
|
||||||
|
"Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
|
||||||
|
"fashion, reproducibility is not guaranteed."
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Use safe_globals to ensure numpy RNG states can be deserialized safely under PyTorch 2.6+,
|
||||||
|
# which requires allowlisted classes when loading with weights_only=True.
|
||||||
|
with safe_globals():
|
||||||
|
checkpoint_rng_state = torch.load(rng_file) # nosec B614
|
||||||
|
random.setstate(checkpoint_rng_state["python"])
|
||||||
|
np.random.set_state(checkpoint_rng_state["numpy"])
|
||||||
|
torch.random.set_rng_state(checkpoint_rng_state["cpu"])
|
||||||
|
|
||||||
|
is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED
|
||||||
|
if torch.cuda.is_available():
|
||||||
|
set_rng_state_for_device(
|
||||||
|
"CUDA", torch.cuda, checkpoint_rng_state, is_distributed
|
||||||
|
)
|
||||||
|
if is_torch_npu_available():
|
||||||
|
set_rng_state_for_device(
|
||||||
|
"NPU", torch.npu, checkpoint_rng_state, is_distributed
|
||||||
|
)
|
||||||
@@ -13,6 +13,7 @@ from trl import (
|
|||||||
RewardTrainer,
|
RewardTrainer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from axolotl.core.trainers.mixins import RngLoaderMixin
|
||||||
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
|
||||||
|
|
||||||
|
|
||||||
@@ -74,7 +75,7 @@ class TRLPPOTrainer(PPOTrainer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
|
||||||
"""
|
"""
|
||||||
Extend the base ORPOTrainer for axolotl helpers
|
Extend the base ORPOTrainer for axolotl helpers
|
||||||
"""
|
"""
|
||||||
@@ -154,7 +155,7 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
|
|||||||
return loss, metrics
|
return loss, metrics
|
||||||
|
|
||||||
|
|
||||||
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
|
||||||
"""
|
"""
|
||||||
Extend the base KTOTrainer for axolotl helpers
|
Extend the base KTOTrainer for axolotl helpers
|
||||||
"""
|
"""
|
||||||
@@ -162,7 +163,7 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
|
|||||||
tag_names = ["axolotl", "kto"]
|
tag_names = ["axolotl", "kto"]
|
||||||
|
|
||||||
|
|
||||||
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
|
||||||
"""
|
"""
|
||||||
Extend the base CPOTrainer for axolotl helpers
|
Extend the base CPOTrainer for axolotl helpers
|
||||||
"""
|
"""
|
||||||
@@ -244,7 +245,7 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
|
|||||||
return loss, metrics
|
return loss, metrics
|
||||||
|
|
||||||
|
|
||||||
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
|
||||||
"""
|
"""
|
||||||
Extend the base RewardTrainer for axolotl helpers
|
Extend the base RewardTrainer for axolotl helpers
|
||||||
"""
|
"""
|
||||||
@@ -252,7 +253,7 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
|
|||||||
tag_names = ["axolotl", "reward"]
|
tag_names = ["axolotl", "reward"]
|
||||||
|
|
||||||
|
|
||||||
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
|
class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer):
|
||||||
"""
|
"""
|
||||||
Extend the base trl.PRMTrainer for axolotl helpers
|
Extend the base trl.PRMTrainer for axolotl helpers
|
||||||
"""
|
"""
|
||||||
|
|||||||
@@ -1270,3 +1270,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
|
|||||||
if data["beta"] != data["trl"]["beta"]:
|
if data["beta"] != data["trl"]["beta"]:
|
||||||
raise ValueError("beta and trl.beta must match or one must be removed")
|
raise ValueError("beta and trl.beta must match or one must be removed")
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def check_min_torch_version(self):
|
||||||
|
if self.env_capabilities and self.env_capabilities.torch_version:
|
||||||
|
torch_version = self.env_capabilities.torch_version
|
||||||
|
if version.parse(torch_version) < version.parse("2.5.1"):
|
||||||
|
LOG.warning(
|
||||||
|
f"torch=={torch_version} may not be supported in future versions. Please consider upgrading to torch>=2.5.1."
|
||||||
|
)
|
||||||
|
|||||||
@@ -8,11 +8,13 @@ import shutil
|
|||||||
import sys
|
import sys
|
||||||
import tempfile
|
import tempfile
|
||||||
import time
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import datasets
|
||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
from datasets import load_dataset
|
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
|
from tokenizers import AddedToken
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
from tests.hf_offline_utils import disable_hf_offline, enable_hf_offline
|
from tests.hf_offline_utils import disable_hf_offline, enable_hf_offline
|
||||||
@@ -48,6 +50,14 @@ def snapshot_download_w_retry(*args, **kwargs):
|
|||||||
return snapshot_download(*args, **kwargs)
|
return snapshot_download(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
|
def download_ds_fixture_bundle():
|
||||||
|
ds_dir = snapshot_download_w_retry(
|
||||||
|
"axolotl-ai-internal/axolotl-oss-dataset-fixtures", repo_type="dataset"
|
||||||
|
)
|
||||||
|
return Path(ds_dir)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def download_smollm2_135m_model():
|
def download_smollm2_135m_model():
|
||||||
# download the model
|
# download the model
|
||||||
@@ -108,43 +118,43 @@ def download_argilla_distilabel_intel_orca_dpo_dataset():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
# @pytest.fixture(scope="session", autouse=True)
|
||||||
def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
|
# def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
|
||||||
# download the dataset
|
# # download the dataset
|
||||||
snapshot_download_w_retry(
|
# snapshot_download_w_retry(
|
||||||
"argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
|
# "argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
|
||||||
)
|
# )
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
# @pytest.fixture(scope="session", autouse=True)
|
||||||
def download_fozzie_alpaca_dpo_dataset():
|
# def download_fozzie_alpaca_dpo_dataset():
|
||||||
# download the dataset
|
# # download the dataset
|
||||||
snapshot_download_w_retry(
|
# snapshot_download_w_retry(
|
||||||
"fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
|
# "fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
|
||||||
)
|
# )
|
||||||
snapshot_download_w_retry(
|
# snapshot_download_w_retry(
|
||||||
"fozziethebeat/alpaca_messages_2k_dpo_test",
|
# "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||||
repo_type="dataset",
|
# repo_type="dataset",
|
||||||
revision="ea82cff",
|
# revision="ea82cff",
|
||||||
)
|
# )
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
# @pytest.fixture(scope="session")
|
||||||
@disable_hf_offline
|
# @disable_hf_offline
|
||||||
def dataset_fozzie_alpaca_dpo_dataset(
|
# def dataset_fozzie_alpaca_dpo_dataset(
|
||||||
download_fozzie_alpaca_dpo_dataset,
|
# download_fozzie_alpaca_dpo_dataset,
|
||||||
): # pylint: disable=unused-argument,redefined-outer-name
|
# ): # pylint: disable=unused-argument,redefined-outer-name
|
||||||
return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
|
# return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
|
||||||
|
#
|
||||||
|
#
|
||||||
@pytest.fixture(scope="session")
|
# @pytest.fixture(scope="session")
|
||||||
@disable_hf_offline
|
# @disable_hf_offline
|
||||||
def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
|
# def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
|
||||||
download_fozzie_alpaca_dpo_dataset,
|
# download_fozzie_alpaca_dpo_dataset,
|
||||||
): # pylint: disable=unused-argument,redefined-outer-name
|
# ): # pylint: disable=unused-argument,redefined-outer-name
|
||||||
return load_dataset(
|
# return load_dataset(
|
||||||
"fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
|
# "fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
|
||||||
)
|
# )
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
@@ -271,7 +281,7 @@ def download_mlx_mistral_7b_model_fixture():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture
|
||||||
def download_llama2_model_fixture():
|
def download_llama2_model_fixture():
|
||||||
# download the tokenizer only
|
# download the tokenizer only
|
||||||
snapshot_download_w_retry(
|
snapshot_download_w_retry(
|
||||||
@@ -281,7 +291,7 @@ def download_llama2_model_fixture():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def tokenizer_huggyllama(
|
def tokenizer_huggyllama(
|
||||||
download_huggyllama_model_fixture,
|
download_huggyllama_model_fixture,
|
||||||
@@ -292,6 +302,57 @@ def tokenizer_huggyllama(
|
|||||||
return tokenizer
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
@enable_hf_offline
|
||||||
|
def tokenizer_huggyllama_w_special_tokens(
|
||||||
|
tokenizer_huggyllama,
|
||||||
|
): # pylint: disable=redefined-outer-name
|
||||||
|
tokenizer_huggyllama.add_special_tokens(
|
||||||
|
{
|
||||||
|
"bos_token": "<s>",
|
||||||
|
"eos_token": "</s>",
|
||||||
|
"unk_token": "<unk>",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
return tokenizer_huggyllama
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
@enable_hf_offline
|
||||||
|
def tokenizer_llama2_7b(
|
||||||
|
download_llama2_model_fixture,
|
||||||
|
): # pylint: disable=unused-argument,redefined-outer-name
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
|
||||||
|
|
||||||
|
return tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
@enable_hf_offline
|
||||||
|
def tokenizer_mistral_7b_instruct(
|
||||||
|
download_mlx_mistral_7b_model_fixture,
|
||||||
|
): # pylint: disable=unused-argument,redefined-outer-name
|
||||||
|
return AutoTokenizer.from_pretrained("casperhansen/mistral-7b-instruct-v0.1-awq")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):
|
||||||
|
tokenizer_mistral_7b_instruct.add_special_tokens(
|
||||||
|
{
|
||||||
|
"eos_token": AddedToken(
|
||||||
|
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
|
||||||
|
)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
tokenizer_mistral_7b_instruct.add_tokens(
|
||||||
|
[
|
||||||
|
AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return tokenizer_mistral_7b_instruct
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def temp_dir():
|
def temp_dir():
|
||||||
# Create a temporary directory
|
# Create a temporary directory
|
||||||
@@ -357,6 +418,60 @@ def cleanup_monkeypatches():
|
|||||||
globals().pop(module_global, None)
|
globals().pop(module_global, None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dataset_winglian_tiny_shakespeare(
|
||||||
|
download_ds_fixture_bundle: Path,
|
||||||
|
): # pylint: disable=redefined-outer-name
|
||||||
|
ds_path = download_ds_fixture_bundle / "winglian__tiny-shakespeare"
|
||||||
|
return datasets.load_from_disk(ds_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dataset_tatsu_lab_alpaca(
|
||||||
|
download_ds_fixture_bundle: Path,
|
||||||
|
): # pylint: disable=redefined-outer-name
|
||||||
|
ds_path = download_ds_fixture_bundle / "tatsu-lab__alpaca"
|
||||||
|
return datasets.load_from_disk(ds_path)["train"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dataset_mhenrichsen_alpaca_2k_test(
|
||||||
|
download_ds_fixture_bundle: Path,
|
||||||
|
): # pylint: disable=redefined-outer-name
|
||||||
|
ds_path = download_ds_fixture_bundle / "mhenrichsen__alpaca_2k_test"
|
||||||
|
return datasets.load_from_disk(ds_path)["train"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dataset_argilla_ultrafeedback_binarized_preferences_cleaned(
|
||||||
|
download_ds_fixture_bundle: Path,
|
||||||
|
): # pylint: disable=redefined-outer-name
|
||||||
|
ds_path = (
|
||||||
|
download_ds_fixture_bundle
|
||||||
|
/ "argilla__ultrafeedback-binarized-preferences-cleaned"
|
||||||
|
)
|
||||||
|
return datasets.load_from_disk(ds_path)["train"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test(
|
||||||
|
download_ds_fixture_bundle: Path,
|
||||||
|
): # pylint: disable=redefined-outer-name
|
||||||
|
ds_path = download_ds_fixture_bundle / "fozziethebeat__alpaca_messages_2k_dpo_test"
|
||||||
|
return datasets.load_from_disk(ds_path)["train"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
|
||||||
|
download_ds_fixture_bundle: Path,
|
||||||
|
): # pylint: disable=redefined-outer-name
|
||||||
|
ds_path = (
|
||||||
|
download_ds_fixture_bundle
|
||||||
|
/ "fozziethebeat__alpaca_messages_2k_dpo_test__rev_ea82cff"
|
||||||
|
)
|
||||||
|
return datasets.load_from_disk(ds_path)["train"]
|
||||||
|
|
||||||
|
|
||||||
# # pylint: disable=redefined-outer-name,unused-argument
|
# # pylint: disable=redefined-outer-name,unused-argument
|
||||||
# def test_load_fixtures(
|
# def test_load_fixtures(
|
||||||
# download_smollm2_135m_model,
|
# download_smollm2_135m_model,
|
||||||
|
|||||||
@@ -324,7 +324,7 @@ class TestDatasetPreparation:
|
|||||||
|
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def test_load_hub_with_revision_with_dpo(
|
def test_load_hub_with_revision_with_dpo(
|
||||||
self, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
|
self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
|
||||||
):
|
):
|
||||||
"""Verify that processing dpo data from the hub works with a specific revision"""
|
"""Verify that processing dpo data from the hub works with a specific revision"""
|
||||||
|
|
||||||
@@ -339,12 +339,10 @@ class TestDatasetPreparation:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
with patch(
|
with patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset:
|
||||||
"axolotl.utils.data.shared.load_dataset_w_config"
|
|
||||||
) as mock_load_dataset:
|
|
||||||
# Set up the mock to return different values on successive calls
|
# Set up the mock to return different values on successive calls
|
||||||
mock_load_dataset.return_value = (
|
mock_load_dataset.return_value = (
|
||||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
|
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
|
||||||
)
|
)
|
||||||
|
|
||||||
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
train_dataset, _ = load_prepare_preference_datasets(cfg)
|
||||||
@@ -354,7 +352,9 @@ class TestDatasetPreparation:
|
|||||||
|
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
@pytest.mark.skip("datasets bug with local datasets when offline")
|
@pytest.mark.skip("datasets bug with local datasets when offline")
|
||||||
def test_load_local_hub_with_revision(self, tokenizer):
|
def test_load_local_hub_with_revision(
|
||||||
|
self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff, tokenizer
|
||||||
|
):
|
||||||
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
|
||||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||||
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
|
||||||
@@ -386,13 +386,23 @@ class TestDatasetPreparation:
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
|
with patch(
|
||||||
|
"axolotl.utils.data.shared.load_dataset_w_config"
|
||||||
|
) as mock_load_dataset:
|
||||||
|
# Set up the mock to return different values on successive calls
|
||||||
|
mock_load_dataset.return_value = (
|
||||||
|
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
|
||||||
|
)
|
||||||
|
|
||||||
assert len(dataset) == 2000
|
dataset, _ = load_tokenized_prepared_datasets(
|
||||||
assert "input_ids" in dataset.features
|
tokenizer, cfg, prepared_path
|
||||||
assert "attention_mask" in dataset.features
|
)
|
||||||
assert "labels" in dataset.features
|
|
||||||
shutil.rmtree(tmp_ds_path)
|
assert len(dataset) == 2000
|
||||||
|
assert "input_ids" in dataset.features
|
||||||
|
assert "attention_mask" in dataset.features
|
||||||
|
assert "labels" in dataset.features
|
||||||
|
shutil.rmtree(tmp_ds_path)
|
||||||
|
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def test_loading_local_dataset_folder(self, tokenizer):
|
def test_loading_local_dataset_folder(self, tokenizer):
|
||||||
|
|||||||
@@ -238,21 +238,22 @@ class TestDeduplicateRLDataset:
|
|||||||
|
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def test_load_with_deduplication(
|
def test_load_with_deduplication(
|
||||||
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
|
self,
|
||||||
|
cfg,
|
||||||
|
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||||
|
tokenizer_huggyllama,
|
||||||
):
|
):
|
||||||
"""Verify that loading with deduplication removes duplicates."""
|
"""Verify that loading with deduplication removes duplicates."""
|
||||||
|
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
with (
|
with (
|
||||||
patch(
|
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
|
||||||
"axolotl.utils.data.shared.load_dataset_w_config"
|
|
||||||
) as mock_load_dataset,
|
|
||||||
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
|
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
|
||||||
):
|
):
|
||||||
# Set up the mock to return different values on successive calls
|
# Set up the mock to return different values on successive calls
|
||||||
mock_load_dataset.side_effect = [
|
mock_load_dataset.side_effect = [
|
||||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||||
]
|
]
|
||||||
mock_load_tokenizer.return_value = tokenizer_huggyllama
|
mock_load_tokenizer.return_value = tokenizer_huggyllama
|
||||||
|
|
||||||
@@ -263,19 +264,20 @@ class TestDeduplicateRLDataset:
|
|||||||
|
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def test_load_without_deduplication(
|
def test_load_without_deduplication(
|
||||||
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
|
self,
|
||||||
|
cfg,
|
||||||
|
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||||
|
tokenizer_huggyllama,
|
||||||
):
|
):
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
with (
|
with (
|
||||||
patch(
|
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
|
||||||
"axolotl.utils.data.shared.load_dataset_w_config"
|
|
||||||
) as mock_load_dataset,
|
|
||||||
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
|
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
|
||||||
):
|
):
|
||||||
# Set up the mock to return different values on successive calls
|
# Set up the mock to return different values on successive calls
|
||||||
mock_load_dataset.side_effect = [
|
mock_load_dataset.side_effect = [
|
||||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||||
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
|
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||||
]
|
]
|
||||||
mock_load_tokenizer.return_value = tokenizer_huggyllama
|
mock_load_tokenizer.return_value = tokenizer_huggyllama
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Module for testing streaming dataset sequence packing"""
|
"""Module for testing streaming dataset sequence packing"""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from datasets import concatenate_datasets, load_dataset
|
from datasets import concatenate_datasets
|
||||||
from torch.utils.data import DataLoader, RandomSampler
|
from torch.utils.data import DataLoader, RandomSampler
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
|
||||||
@@ -27,7 +27,6 @@ class TestBatchedSamplerPacking:
|
|||||||
Test class for packing streaming dataset sequences
|
Test class for packing streaming dataset sequences
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@pytest.mark.skip(reason="TODO: fix hf offline mode for CI rate limits")
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"batch_size, num_workers",
|
"batch_size, num_workers",
|
||||||
[
|
[
|
||||||
@@ -41,14 +40,17 @@ class TestBatchedSamplerPacking:
|
|||||||
@pytest.mark.parametrize("sequential", [True, False])
|
@pytest.mark.parametrize("sequential", [True, False])
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def test_packing(
|
def test_packing(
|
||||||
self, batch_size, num_workers, tokenizer, max_seq_length, sequential
|
self,
|
||||||
|
dataset_winglian_tiny_shakespeare,
|
||||||
|
batch_size,
|
||||||
|
num_workers,
|
||||||
|
tokenizer,
|
||||||
|
max_seq_length,
|
||||||
|
sequential,
|
||||||
):
|
):
|
||||||
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
|
||||||
|
|
||||||
dataset = load_dataset(
|
dataset = dataset_winglian_tiny_shakespeare["train"]
|
||||||
"winglian/tiny-shakespeare",
|
|
||||||
split="train",
|
|
||||||
)
|
|
||||||
|
|
||||||
cfg = DictDefault(
|
cfg = DictDefault(
|
||||||
{
|
{
|
||||||
@@ -58,7 +60,7 @@ class TestBatchedSamplerPacking:
|
|||||||
)
|
)
|
||||||
ds_cfg = DictDefault(
|
ds_cfg = DictDefault(
|
||||||
{
|
{
|
||||||
"field": "Text",
|
"field": "text",
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
completion_strategy = load(tokenizer, cfg, ds_cfg)
|
completion_strategy = load(tokenizer, cfg, ds_cfg)
|
||||||
|
|||||||
@@ -2,13 +2,8 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import unittest
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import pytest
|
|
||||||
from datasets import load_dataset
|
|
||||||
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
|
|
||||||
|
|
||||||
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
|
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
|
||||||
from axolotl.prompt_strategies.alpaca_w_system import (
|
from axolotl.prompt_strategies.alpaca_w_system import (
|
||||||
InstructionWSystemPromptTokenizingStrategy,
|
InstructionWSystemPromptTokenizingStrategy,
|
||||||
@@ -61,24 +56,13 @@ test_data = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class TestPromptTokenizationStrategies(unittest.TestCase):
|
class TestPromptTokenizationStrategies:
|
||||||
"""
|
"""
|
||||||
Test class for prompt tokenization strategies.
|
Test class for prompt tokenization strategies.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def setUp(self) -> None:
|
def test_no_sys_prompt(self, tokenizer_huggyllama_w_special_tokens):
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
|
||||||
self.tokenizer.add_special_tokens(
|
|
||||||
{
|
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
"unk_token": "<unk>",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_no_sys_prompt(self):
|
|
||||||
"""
|
"""
|
||||||
tests the interface between the user and assistant parts
|
tests the interface between the user and assistant parts
|
||||||
"""
|
"""
|
||||||
@@ -86,7 +70,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
strat = AlpacaPromptTokenizingStrategy(
|
strat = AlpacaPromptTokenizingStrategy(
|
||||||
prompter,
|
prompter,
|
||||||
self.tokenizer,
|
tokenizer_huggyllama_w_special_tokens,
|
||||||
False,
|
False,
|
||||||
2048,
|
2048,
|
||||||
)
|
)
|
||||||
@@ -99,7 +83,8 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|||||||
assert example["labels"][world_idx] == 3186
|
assert example["labels"][world_idx] == 3186
|
||||||
assert example["labels"][world_idx - 1] == -100
|
assert example["labels"][world_idx - 1] == -100
|
||||||
|
|
||||||
def test_alpaca(self):
|
@enable_hf_offline
|
||||||
|
def test_alpaca(self, tokenizer_huggyllama_w_special_tokens):
|
||||||
"""
|
"""
|
||||||
tests the interface between the user and assistant parts
|
tests the interface between the user and assistant parts
|
||||||
"""
|
"""
|
||||||
@@ -107,7 +92,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|||||||
prompter = AlpacaPrompter()
|
prompter = AlpacaPrompter()
|
||||||
strat = AlpacaPromptTokenizingStrategy(
|
strat = AlpacaPromptTokenizingStrategy(
|
||||||
prompter,
|
prompter,
|
||||||
self.tokenizer,
|
tokenizer_huggyllama_w_special_tokens,
|
||||||
False,
|
False,
|
||||||
2048,
|
2048,
|
||||||
)
|
)
|
||||||
@@ -118,28 +103,17 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
|
|||||||
assert example["labels"][world_idx - 1] == -100
|
assert example["labels"][world_idx - 1] == -100
|
||||||
|
|
||||||
|
|
||||||
class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
|
class TestInstructionWSystemPromptTokenizingStrategy:
|
||||||
"""
|
"""
|
||||||
Test class for prompt tokenization strategies with sys prompt from the dataset
|
Test class for prompt tokenization strategies with sys prompt from the dataset
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def setUp(self) -> None:
|
def test_system_alpaca(self, tokenizer_huggyllama_w_special_tokens):
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
|
||||||
self.tokenizer.add_special_tokens(
|
|
||||||
{
|
|
||||||
"bos_token": "<s>",
|
|
||||||
"eos_token": "</s>",
|
|
||||||
"unk_token": "<unk>",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_system_alpaca(self):
|
|
||||||
prompter = SystemDataPrompter(PromptStyle.CHAT.value)
|
prompter = SystemDataPrompter(PromptStyle.CHAT.value)
|
||||||
strat = InstructionWSystemPromptTokenizingStrategy(
|
strat = InstructionWSystemPromptTokenizingStrategy(
|
||||||
prompter,
|
prompter,
|
||||||
self.tokenizer,
|
tokenizer_huggyllama_w_special_tokens,
|
||||||
False,
|
False,
|
||||||
2048,
|
2048,
|
||||||
)
|
)
|
||||||
@@ -160,18 +134,13 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
|
|||||||
assert example["input_ids"][8] == 11889 # USER
|
assert example["input_ids"][8] == 11889 # USER
|
||||||
|
|
||||||
|
|
||||||
class Llama2ChatTokenizationTest(unittest.TestCase):
|
class Llama2ChatTokenizationTest:
|
||||||
"""
|
"""
|
||||||
Test class for prompt tokenization strategies with sys prompt from the dataset
|
Test class for prompt tokenization strategies with sys prompt from the dataset
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def setUp(self) -> None:
|
def test_llama2_chat_integration(self, tokenizer_llama2_7b):
|
||||||
# pylint: disable=duplicate-code
|
|
||||||
self.tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
|
|
||||||
# woraround because official Meta repos are not open
|
|
||||||
|
|
||||||
def test_llama2_chat_integration(self):
|
|
||||||
with open(
|
with open(
|
||||||
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
|
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
|
||||||
) as fin:
|
) as fin:
|
||||||
@@ -186,16 +155,18 @@ class Llama2ChatTokenizationTest(unittest.TestCase):
|
|||||||
prompter = Llama2ChatPrompter()
|
prompter = Llama2ChatPrompter()
|
||||||
strat = LLama2ChatTokenizingStrategy(
|
strat = LLama2ChatTokenizingStrategy(
|
||||||
prompter,
|
prompter,
|
||||||
self.tokenizer,
|
tokenizer_llama2_7b,
|
||||||
False,
|
False,
|
||||||
4096,
|
4096,
|
||||||
)
|
)
|
||||||
example = strat.tokenize_prompt(conversation)
|
example = strat.tokenize_prompt(conversation)
|
||||||
for fields in ["input_ids", "attention_mask", "labels"]:
|
for fields in ["input_ids", "attention_mask", "labels"]:
|
||||||
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
|
# pytest assert equals
|
||||||
self.assertEqual(example[fields], tokenized_conversation[fields])
|
|
||||||
|
|
||||||
def compare_with_transformers_integration(self):
|
assert len(example[fields]) == len(tokenized_conversation[fields])
|
||||||
|
assert example[fields] == tokenized_conversation[fields]
|
||||||
|
|
||||||
|
def compare_with_transformers_integration(self, tokenizer_llama2_7b):
|
||||||
# this needs transformers >= v4.31.0
|
# this needs transformers >= v4.31.0
|
||||||
from transformers.models.llama.tokenization_llama import B_SYS, E_SYS
|
from transformers.models.llama.tokenization_llama import B_SYS, E_SYS
|
||||||
from transformers.pipelines.conversational import Conversation
|
from transformers.pipelines.conversational import Conversation
|
||||||
@@ -234,49 +205,27 @@ If a question does not make any sense, or is not factually coherent, explain why
|
|||||||
generated_responses=answers,
|
generated_responses=answers,
|
||||||
)
|
)
|
||||||
# pylint: disable=W0212
|
# pylint: disable=W0212
|
||||||
hf_tokens = self.tokenizer._build_conversation_input_ids(hf_conf)
|
hf_tokens = tokenizer_llama2_7b._build_conversation_input_ids(hf_conf)
|
||||||
|
|
||||||
self.assertEqual(
|
assert hf_tokens == tokenized_conversation["input_ids"][: len(hf_tokens)]
|
||||||
hf_tokens, tokenized_conversation["input_ids"][: len(hf_tokens)]
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class OrpoTokenizationTest(unittest.TestCase):
|
class OrpoTokenizationTest:
|
||||||
"""test case for the ORPO tokenization"""
|
"""test case for the ORPO tokenization"""
|
||||||
|
|
||||||
@enable_hf_offline
|
@enable_hf_offline
|
||||||
def setUp(self) -> None:
|
def test_orpo_integration(
|
||||||
# pylint: disable=duplicate-code
|
self,
|
||||||
tokenizer = LlamaTokenizer.from_pretrained(
|
tokenizer_mistral_7b_instruct_chatml,
|
||||||
"casperhansen/mistral-7b-instruct-v0.1-awq"
|
dataset_argilla_ultrafeedback_binarized_preferences_cleaned,
|
||||||
)
|
):
|
||||||
tokenizer.add_special_tokens(
|
ds = dataset_argilla_ultrafeedback_binarized_preferences_cleaned.select([0])
|
||||||
{
|
|
||||||
"eos_token": AddedToken(
|
|
||||||
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
|
|
||||||
)
|
|
||||||
}
|
|
||||||
)
|
|
||||||
tokenizer.add_tokens(
|
|
||||||
[
|
|
||||||
AddedToken(
|
|
||||||
"<|im_start|>", rstrip=False, lstrip=False, normalized=False
|
|
||||||
),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.dataset = load_dataset(
|
|
||||||
"argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
|
|
||||||
).select([0])
|
|
||||||
|
|
||||||
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
|
|
||||||
def test_orpo_integration(self):
|
|
||||||
strat = load(
|
strat = load(
|
||||||
self.tokenizer,
|
tokenizer_mistral_7b_instruct_chatml,
|
||||||
DictDefault({"train_on_inputs": False}),
|
DictDefault({"train_on_inputs": False}),
|
||||||
DictDefault({"chat_template": "chatml"}),
|
DictDefault({"chat_template": "chatml"}),
|
||||||
)
|
)
|
||||||
res = strat.tokenize_prompt(self.dataset[0])
|
res = strat.tokenize_prompt(ds[0])
|
||||||
assert "rejected_input_ids" in res
|
assert "rejected_input_ids" in res
|
||||||
assert "rejected_labels" in res
|
assert "rejected_labels" in res
|
||||||
assert "input_ids" in res
|
assert "input_ids" in res
|
||||||
@@ -295,7 +244,3 @@ class OrpoTokenizationTest(unittest.TestCase):
|
|||||||
|
|
||||||
assert res["prompt_attention_mask"][0] == 1
|
assert res["prompt_attention_mask"][0] == 1
|
||||||
assert res["prompt_attention_mask"][-1] == 0
|
assert res["prompt_attention_mask"][-1] == 0
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
unittest.main()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user