Release update 20250331 (#2460) [skip ci]

* make torch 2.6.0 the default image

* fix tests against upstream main

* fix attribute access

* use fixture dataset

* fix dataset load

* correct the fixtures + tests

* more fixtures

* add accidentally removed shakespeare fixture

* fix conversion from unittest to pytest class

* nightly main ci caches

* build 12.6.3 cuda base image

* override for fix from huggingface/transformers#37162

* address PR feedback
This commit is contained in:
Wing Lian
2025-04-01 08:47:50 -04:00
committed by GitHub
parent 328d598114
commit e0aba74dd0
17 changed files with 347 additions and 169 deletions

View File

@@ -40,6 +40,12 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX" torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "126"
cuda_version: 12.6.3
cudnn_version: ""
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128" - cuda: "128"
cuda_version: 12.8.1 cuda_version: 12.8.1
cudnn_version: "" cudnn_version: ""

View File

@@ -25,12 +25,12 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.5.1 pytorch: 2.5.1
axolotl_extras: vllm axolotl_extras: vllm
is_latest: true
- cuda: 124 - cuda: 124
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
axolotl_extras: axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout
@@ -87,12 +87,12 @@ jobs:
python_version: "3.11" python_version: "3.11"
pytorch: 2.5.1 pytorch: 2.5.1
axolotl_extras: axolotl_extras:
is_latest: true
- cuda: 124 - cuda: 124
cuda_version: 12.4.1 cuda_version: 12.4.1
python_version: "3.11" python_version: "3.11"
pytorch: 2.6.0 pytorch: 2.6.0
axolotl_extras: axolotl_extras:
is_latest: true
runs-on: axolotl-gpu-runner runs-on: axolotl-gpu-runner
steps: steps:
- name: Checkout - name: Checkout

View File

@@ -33,6 +33,15 @@ jobs:
- name: Check out repository code - name: Check out repository code
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Restore HF cache
id: hf-cache-restore
uses: actions/cache/restore@v4
with:
path: |
/home/runner/.cache/huggingface/hub/datasets--*
/home/runner/.cache/huggingface/hub/models--*
key: ${{ runner.os }}-hf-hub-cache-v2
- name: Setup Python - name: Setup Python
uses: actions/setup-python@v5 uses: actions/setup-python@v5
with: with:
@@ -46,7 +55,7 @@ jobs:
- name: Install PyTorch - name: Install PyTorch
run: | run: |
pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu pip3 install torch==${{ matrix.pytorch_version }}
- name: Update requirements.txt - name: Update requirements.txt
run: | run: |
@@ -58,8 +67,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
run: | run: |
pip3 install --upgrade pip pip3 show torch
pip3 install --upgrade packaging==23.2
pip3 install --no-build-isolation -U -e . pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh python scripts/cutcrossentropy_install.py | sh
@@ -73,10 +81,15 @@ jobs:
run: | run: |
axolotl --help axolotl --help
- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Run tests - name: Run tests
run: | run: |
pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/ pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
pytest tests/patched/ pytest -v tests/patched/
pytest -v tests/cli/
- name: cleanup pip cache - name: cleanup pip cache
run: | run: |

View File

@@ -96,6 +96,10 @@ jobs:
run: | run: |
axolotl --help axolotl --help
- name: Pre-Download dataset fixture
run: |
huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
- name: Run tests - name: Run tests
run: | run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/ pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/

View File

@@ -12,7 +12,7 @@ liger-kernel==0.5.5
packaging==23.2 packaging==23.2
peft==0.15.0 peft==0.15.0
transformers==4.50.0 transformers==4.50.3
tokenizers>=0.21.1 tokenizers>=0.21.1
accelerate==1.5.2 accelerate==1.5.2
datasets==3.5.0 datasets==3.5.0

View File

@@ -28,6 +28,7 @@ from typing_extensions import override
from axolotl.core.trainers.mixins import ( from axolotl.core.trainers.mixins import (
OptimizerMixin, OptimizerMixin,
RngLoaderMixin,
SchedulerMixin, SchedulerMixin,
SequenceParallelMixin, SequenceParallelMixin,
) )
@@ -40,7 +41,9 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer): class AxolotlTrainer(
SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
):
"""Extend the base Trainer for axolotl helpers""" """Extend the base Trainer for axolotl helpers"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined] args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]

View File

@@ -13,7 +13,7 @@ from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer from trl import DPOTrainer
from axolotl.core.trainers.mixins import SchedulerMixin from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.utils import ( from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging, sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging, sanitize_kwargs_for_tagging,
@@ -23,7 +23,7 @@ if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp import smdistributed.modelparallel.torch as smp
class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer): class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
""" """
Extend the base DPOTrainer for axolotl helpers Extend the base DPOTrainer for axolotl helpers
""" """

View File

@@ -8,13 +8,13 @@ from accelerate.utils import is_deepspeed_available, is_peft_model
from trl import GRPOTrainer from trl import GRPOTrainer
from trl.extras.profiling import profiling_decorator from trl.extras.profiling import profiling_decorator
from axolotl.core.trainers.base import SchedulerMixin from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
if is_deepspeed_available(): if is_deepspeed_available():
import deepspeed import deepspeed
class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer): class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
""" """
Extend the base GRPOTrainer for axolotl helpers Extend the base GRPOTrainer for axolotl helpers
""" """

View File

@@ -4,5 +4,6 @@
# flake8: noqa # flake8: noqa
from .optimizer import OptimizerMixin from .optimizer import OptimizerMixin
from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelMixin from .sequence_parallel import SequenceParallelMixin

View File

@@ -0,0 +1,67 @@
"""
Temporary fix/override for bug in resume from checkpoint
See https://github.com/huggingface/transformers/pull/37162
TODO: Remove once the upstream PR is included in a transformers release
"""
import logging
import os
import random
import numpy as np
import torch
from transformers import Trainer, is_torch_npu_available
from transformers.trainer import safe_globals
from transformers.trainer_pt_utils import set_rng_state_for_device
from transformers.training_args import ParallelMode
LOG = logging.getLogger(__name__)
class RngLoaderMixin(Trainer):
    """
    Trainer mixin that overrides `_load_rng_state` so the RNG states stored in
    a checkpoint directory can be restored when resuming.
    """

    def _load_rng_state(self, checkpoint):
        """
        Restore the Python, NumPy and torch CPU RNG states (plus CUDA / NPU
        when available) from the RNG file found under `checkpoint`.

        Returns early without touching any state when `checkpoint` is None or
        when the expected RNG file does not exist; the latter case is logged
        at info level.
        """
        if checkpoint is None:
            return

        if self.args.world_size > 1:
            # the file name carries the process index when world_size > 1
            process_index = self.args.process_index
            rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
            missing_message = (
                f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
                "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
            )
        else:
            rng_file = os.path.join(checkpoint, "rng_state.pth")
            missing_message = (
                "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
                "fashion, reproducibility is not guaranteed."
            )

        if not os.path.isfile(rng_file):
            LOG.info(missing_message)
            return

        # Load inside safe_globals so the numpy RNG state can be deserialized
        # under PyTorch 2.6+, which requires allowlisted classes when loading
        # with weights_only=True.
        with safe_globals():
            saved_states = torch.load(rng_file)  # nosec B614

        random.setstate(saved_states["python"])
        np.random.set_state(saved_states["numpy"])
        torch.random.set_rng_state(saved_states["cpu"])

        is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED

        if torch.cuda.is_available():
            set_rng_state_for_device("CUDA", torch.cuda, saved_states, is_distributed)

        # torch.npu is only referenced once NPU support has been confirmed
        if is_torch_npu_available():
            set_rng_state_for_device("NPU", torch.npu, saved_states, is_distributed)

View File

@@ -13,6 +13,7 @@ from trl import (
RewardTrainer, RewardTrainer,
) )
from axolotl.core.trainers.mixins import RngLoaderMixin
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
@@ -74,7 +75,7 @@ class TRLPPOTrainer(PPOTrainer):
) )
class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer): class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
""" """
Extend the base ORPOTrainer for axolotl helpers Extend the base ORPOTrainer for axolotl helpers
""" """
@@ -154,7 +155,7 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
return loss, metrics return loss, metrics
class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer): class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
""" """
Extend the base KTOTrainer for axolotl helpers Extend the base KTOTrainer for axolotl helpers
""" """
@@ -162,7 +163,7 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
tag_names = ["axolotl", "kto"] tag_names = ["axolotl", "kto"]
class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer): class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
""" """
Extend the base CPOTrainer for axolotl helpers Extend the base CPOTrainer for axolotl helpers
""" """
@@ -244,7 +245,7 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
return loss, metrics return loss, metrics
class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer): class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
""" """
Extend the base RewardTrainer for axolotl helpers Extend the base RewardTrainer for axolotl helpers
""" """
@@ -252,7 +253,7 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
tag_names = ["axolotl", "reward"] tag_names = ["axolotl", "reward"]
class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer): class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer):
""" """
Extend the base trl.PRMTrainer for axolotl helpers Extend the base trl.PRMTrainer for axolotl helpers
""" """

View File

@@ -1270,3 +1270,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data["beta"] != data["trl"]["beta"]: if data["beta"] != data["trl"]["beta"]:
raise ValueError("beta and trl.beta must match or one must be removed") raise ValueError("beta and trl.beta must match or one must be removed")
return data return data
@model_validator(mode="after")
def check_min_torch_version(self):
    """
    Warn when the detected torch version is older than 2.5.1.

    Only acts when `self.env_capabilities` and its `torch_version` are set.
    The check never rejects the config; it just logs a warning.

    Returns:
        The model instance itself, as pydantic expects from an "after"
        model validator.
    """
    if self.env_capabilities and self.env_capabilities.torch_version:
        torch_version = self.env_capabilities.torch_version
        if version.parse(torch_version) < version.parse("2.5.1"):
            LOG.warning(
                f"torch=={torch_version} may not be supported in future versions. Please consider upgrading to torch>=2.5.1."
            )

    # An "after" validator must hand the instance back; without this the
    # method implicitly returns None.
    return self

View File

@@ -8,11 +8,13 @@ import shutil
import sys import sys
import tempfile import tempfile
import time import time
from pathlib import Path
import datasets
import pytest import pytest
import requests import requests
from datasets import load_dataset
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from tokenizers import AddedToken
from transformers import AutoTokenizer from transformers import AutoTokenizer
from tests.hf_offline_utils import disable_hf_offline, enable_hf_offline from tests.hf_offline_utils import disable_hf_offline, enable_hf_offline
@@ -48,6 +50,14 @@ def snapshot_download_w_retry(*args, **kwargs):
return snapshot_download(*args, **kwargs) return snapshot_download(*args, **kwargs)
@pytest.fixture(scope="session", autouse=True)
def download_ds_fixture_bundle():
    """
    Fetch the dataset fixture bundle repo through `snapshot_download_w_retry`
    and return the resulting directory as a `Path`.
    """
    return Path(
        snapshot_download_w_retry(
            "axolotl-ai-internal/axolotl-oss-dataset-fixtures", repo_type="dataset"
        )
    )
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_model(): def download_smollm2_135m_model():
# download the model # download the model
@@ -108,43 +118,43 @@ def download_argilla_distilabel_intel_orca_dpo_dataset():
) )
@pytest.fixture(scope="session", autouse=True) # @pytest.fixture(scope="session", autouse=True)
def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset(): # def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
# download the dataset # # download the dataset
snapshot_download_w_retry( # snapshot_download_w_retry(
"argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset" # "argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
) # )
@pytest.fixture(scope="session", autouse=True) # @pytest.fixture(scope="session", autouse=True)
def download_fozzie_alpaca_dpo_dataset(): # def download_fozzie_alpaca_dpo_dataset():
# download the dataset # # download the dataset
snapshot_download_w_retry( # snapshot_download_w_retry(
"fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset" # "fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
) # )
snapshot_download_w_retry( # snapshot_download_w_retry(
"fozziethebeat/alpaca_messages_2k_dpo_test", # "fozziethebeat/alpaca_messages_2k_dpo_test",
repo_type="dataset", # repo_type="dataset",
revision="ea82cff", # revision="ea82cff",
) # )
@pytest.fixture(scope="session") # @pytest.fixture(scope="session")
@disable_hf_offline # @disable_hf_offline
def dataset_fozzie_alpaca_dpo_dataset( # def dataset_fozzie_alpaca_dpo_dataset(
download_fozzie_alpaca_dpo_dataset, # download_fozzie_alpaca_dpo_dataset,
): # pylint: disable=unused-argument,redefined-outer-name # ): # pylint: disable=unused-argument,redefined-outer-name
return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train") # return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
#
#
@pytest.fixture(scope="session") # @pytest.fixture(scope="session")
@disable_hf_offline # @disable_hf_offline
def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff( # def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
download_fozzie_alpaca_dpo_dataset, # download_fozzie_alpaca_dpo_dataset,
): # pylint: disable=unused-argument,redefined-outer-name # ): # pylint: disable=unused-argument,redefined-outer-name
return load_dataset( # return load_dataset(
"fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff" # "fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
) # )
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
@@ -271,7 +281,7 @@ def download_mlx_mistral_7b_model_fixture():
) )
@pytest.fixture(scope="session", autouse=True) @pytest.fixture
def download_llama2_model_fixture(): def download_llama2_model_fixture():
# download the tokenizer only # download the tokenizer only
snapshot_download_w_retry( snapshot_download_w_retry(
@@ -281,7 +291,7 @@ def download_llama2_model_fixture():
) )
@pytest.fixture(scope="session", autouse=True) @pytest.fixture
@enable_hf_offline @enable_hf_offline
def tokenizer_huggyllama( def tokenizer_huggyllama(
download_huggyllama_model_fixture, download_huggyllama_model_fixture,
@@ -292,6 +302,57 @@ def tokenizer_huggyllama(
return tokenizer return tokenizer
@pytest.fixture
@enable_hf_offline
def tokenizer_huggyllama_w_special_tokens(
    tokenizer_huggyllama,
):  # pylint: disable=redefined-outer-name
    """
    Return the `tokenizer_huggyllama` fixture after registering its bos, eos
    and unk special tokens. The incoming tokenizer object is modified in place.
    """
    special_tokens = {
        "bos_token": "<s>",
        "eos_token": "</s>",
        "unk_token": "<unk>",
    }
    tokenizer_huggyllama.add_special_tokens(special_tokens)
    return tokenizer_huggyllama
@pytest.fixture
@enable_hf_offline
def tokenizer_llama2_7b(
    download_llama2_model_fixture,
):  # pylint: disable=unused-argument,redefined-outer-name
    """
    Load the tokenizer for "NousResearch/Llama-2-7b-hf".

    `download_llama2_model_fixture` is requested only so that it runs first.
    """
    return AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
@pytest.fixture
@enable_hf_offline
def tokenizer_mistral_7b_instruct(
    download_mlx_mistral_7b_model_fixture,
):  # pylint: disable=unused-argument,redefined-outer-name
    """
    Load the tokenizer for "casperhansen/mistral-7b-instruct-v0.1-awq".

    `download_mlx_mistral_7b_model_fixture` is requested only so that it runs
    first.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        "casperhansen/mistral-7b-instruct-v0.1-awq"
    )
    return tokenizer
@pytest.fixture
def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):
    """
    Return the `tokenizer_mistral_7b_instruct` fixture with "<|im_end|>" set
    as its eos token and "<|im_start|>" added as an extra token. The incoming
    tokenizer object is modified in place.
    """
    # both tokens are created with identical AddedToken flags
    token_flags = {"rstrip": False, "lstrip": False, "normalized": False}
    im_end = AddedToken("<|im_end|>", **token_flags)
    im_start = AddedToken("<|im_start|>", **token_flags)

    tokenizer_mistral_7b_instruct.add_special_tokens({"eos_token": im_end})
    tokenizer_mistral_7b_instruct.add_tokens([im_start])
    return tokenizer_mistral_7b_instruct
@pytest.fixture @pytest.fixture
def temp_dir(): def temp_dir():
# Create a temporary directory # Create a temporary directory
@@ -357,6 +418,60 @@ def cleanup_monkeypatches():
globals().pop(module_global, None) globals().pop(module_global, None)
@pytest.fixture
def dataset_winglian_tiny_shakespeare(
    download_ds_fixture_bundle: Path,
):  # pylint: disable=redefined-outer-name
    """
    Load the "winglian__tiny-shakespeare" directory of the fixture bundle with
    `datasets.load_from_disk` and return the loaded object as-is.
    """
    return datasets.load_from_disk(
        download_ds_fixture_bundle / "winglian__tiny-shakespeare"
    )
@pytest.fixture
def dataset_tatsu_lab_alpaca(
    download_ds_fixture_bundle: Path,
):  # pylint: disable=redefined-outer-name
    """
    Load the "tatsu-lab__alpaca" directory of the fixture bundle and return
    its "train" entry.
    """
    loaded = datasets.load_from_disk(download_ds_fixture_bundle / "tatsu-lab__alpaca")
    return loaded["train"]
@pytest.fixture
def dataset_mhenrichsen_alpaca_2k_test(
    download_ds_fixture_bundle: Path,
):  # pylint: disable=redefined-outer-name
    """
    Load the "mhenrichsen__alpaca_2k_test" directory of the fixture bundle and
    return its "train" entry.
    """
    loaded = datasets.load_from_disk(
        download_ds_fixture_bundle / "mhenrichsen__alpaca_2k_test"
    )
    return loaded["train"]
@pytest.fixture
def dataset_argilla_ultrafeedback_binarized_preferences_cleaned(
    download_ds_fixture_bundle: Path,
):  # pylint: disable=redefined-outer-name
    """
    Load the "argilla__ultrafeedback-binarized-preferences-cleaned" directory
    of the fixture bundle and return its "train" entry.
    """
    subdir = "argilla__ultrafeedback-binarized-preferences-cleaned"
    loaded = datasets.load_from_disk(download_ds_fixture_bundle / subdir)
    return loaded["train"]
@pytest.fixture
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test(
    download_ds_fixture_bundle: Path,
):  # pylint: disable=redefined-outer-name
    """
    Load the "fozziethebeat__alpaca_messages_2k_dpo_test" directory of the
    fixture bundle and return its "train" entry.
    """
    subdir = "fozziethebeat__alpaca_messages_2k_dpo_test"
    loaded = datasets.load_from_disk(download_ds_fixture_bundle / subdir)
    return loaded["train"]
@pytest.fixture
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
    download_ds_fixture_bundle: Path,
):  # pylint: disable=redefined-outer-name
    """
    Load the "fozziethebeat__alpaca_messages_2k_dpo_test__rev_ea82cff"
    directory of the fixture bundle and return its "train" entry.
    """
    subdir = "fozziethebeat__alpaca_messages_2k_dpo_test__rev_ea82cff"
    loaded = datasets.load_from_disk(download_ds_fixture_bundle / subdir)
    return loaded["train"]
# # pylint: disable=redefined-outer-name,unused-argument # # pylint: disable=redefined-outer-name,unused-argument
# def test_load_fixtures( # def test_load_fixtures(
# download_smollm2_135m_model, # download_smollm2_135m_model,

View File

@@ -324,7 +324,7 @@ class TestDatasetPreparation:
@enable_hf_offline @enable_hf_offline
def test_load_hub_with_revision_with_dpo( def test_load_hub_with_revision_with_dpo(
self, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
): ):
"""Verify that processing dpo data from the hub works with a specific revision""" """Verify that processing dpo data from the hub works with a specific revision"""
@@ -339,12 +339,10 @@ class TestDatasetPreparation:
) )
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
with patch( with patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset:
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset:
# Set up the mock to return different values on successive calls # Set up the mock to return different values on successive calls
mock_load_dataset.return_value = ( mock_load_dataset.return_value = (
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
) )
train_dataset, _ = load_prepare_preference_datasets(cfg) train_dataset, _ = load_prepare_preference_datasets(cfg)
@@ -354,7 +352,9 @@ class TestDatasetPreparation:
@enable_hf_offline @enable_hf_offline
@pytest.mark.skip("datasets bug with local datasets when offline") @pytest.mark.skip("datasets bug with local datasets when offline")
def test_load_local_hub_with_revision(self, tokenizer): def test_load_local_hub_with_revision(
self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff, tokenizer
):
"""Verify that a local copy of a hub dataset can be loaded with a specific revision""" """Verify that a local copy of a hub dataset can be loaded with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir: with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test" tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
@@ -386,13 +386,23 @@ class TestDatasetPreparation:
} }
) )
dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path) with patch(
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset:
# Set up the mock to return different values on successive calls
mock_load_dataset.return_value = (
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
)
assert len(dataset) == 2000 dataset, _ = load_tokenized_prepared_datasets(
assert "input_ids" in dataset.features tokenizer, cfg, prepared_path
assert "attention_mask" in dataset.features )
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path) assert len(dataset) == 2000
assert "input_ids" in dataset.features
assert "attention_mask" in dataset.features
assert "labels" in dataset.features
shutil.rmtree(tmp_ds_path)
@enable_hf_offline @enable_hf_offline
def test_loading_local_dataset_folder(self, tokenizer): def test_loading_local_dataset_folder(self, tokenizer):

View File

@@ -238,21 +238,22 @@ class TestDeduplicateRLDataset:
@enable_hf_offline @enable_hf_offline
def test_load_with_deduplication( def test_load_with_deduplication(
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama self,
cfg,
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
tokenizer_huggyllama,
): ):
"""Verify that loading with deduplication removes duplicates.""" """Verify that loading with deduplication removes duplicates."""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
with ( with (
patch( patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer, patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
): ):
# Set up the mock to return different values on successive calls # Set up the mock to return different values on successive calls
mock_load_dataset.side_effect = [ mock_load_dataset.side_effect = [
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
] ]
mock_load_tokenizer.return_value = tokenizer_huggyllama mock_load_tokenizer.return_value = tokenizer_huggyllama
@@ -263,19 +264,20 @@ class TestDeduplicateRLDataset:
@enable_hf_offline @enable_hf_offline
def test_load_without_deduplication( def test_load_without_deduplication(
self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama self,
cfg,
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
tokenizer_huggyllama,
): ):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
with ( with (
patch( patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
"axolotl.utils.data.shared.load_dataset_w_config"
) as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer, patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
): ):
# Set up the mock to return different values on successive calls # Set up the mock to return different values on successive calls
mock_load_dataset.side_effect = [ mock_load_dataset.side_effect = [
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
] ]
mock_load_tokenizer.return_value = tokenizer_huggyllama mock_load_tokenizer.return_value = tokenizer_huggyllama

View File

@@ -1,7 +1,7 @@
"""Module for testing streaming dataset sequence packing""" """Module for testing streaming dataset sequence packing"""
import pytest import pytest
from datasets import concatenate_datasets, load_dataset from datasets import concatenate_datasets
from torch.utils.data import DataLoader, RandomSampler from torch.utils.data import DataLoader, RandomSampler
from transformers import AutoTokenizer from transformers import AutoTokenizer
@@ -27,7 +27,6 @@ class TestBatchedSamplerPacking:
Test class for packing streaming dataset sequences Test class for packing streaming dataset sequences
""" """
@pytest.mark.skip(reason="TODO: fix hf offline mode for CI rate limits")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"batch_size, num_workers", "batch_size, num_workers",
[ [
@@ -41,14 +40,17 @@ class TestBatchedSamplerPacking:
@pytest.mark.parametrize("sequential", [True, False]) @pytest.mark.parametrize("sequential", [True, False])
@enable_hf_offline @enable_hf_offline
def test_packing( def test_packing(
self, batch_size, num_workers, tokenizer, max_seq_length, sequential self,
dataset_winglian_tiny_shakespeare,
batch_size,
num_workers,
tokenizer,
max_seq_length,
sequential,
): ):
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401 import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
dataset = load_dataset( dataset = dataset_winglian_tiny_shakespeare["train"]
"winglian/tiny-shakespeare",
split="train",
)
cfg = DictDefault( cfg = DictDefault(
{ {
@@ -58,7 +60,7 @@ class TestBatchedSamplerPacking:
) )
ds_cfg = DictDefault( ds_cfg = DictDefault(
{ {
"field": "Text", "field": "text",
} }
) )
completion_strategy = load(tokenizer, cfg, ds_cfg) completion_strategy = load(tokenizer, cfg, ds_cfg)

View File

@@ -2,13 +2,8 @@
import json import json
import logging import logging
import unittest
from pathlib import Path from pathlib import Path
import pytest
from datasets import load_dataset
from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
from axolotl.prompt_strategies.alpaca_w_system import ( from axolotl.prompt_strategies.alpaca_w_system import (
InstructionWSystemPromptTokenizingStrategy, InstructionWSystemPromptTokenizingStrategy,
@@ -61,24 +56,13 @@ test_data = {
} }
class TestPromptTokenizationStrategies(unittest.TestCase): class TestPromptTokenizationStrategies:
""" """
Test class for prompt tokenization strategies. Test class for prompt tokenization strategies.
""" """
@enable_hf_offline @enable_hf_offline
def setUp(self) -> None: def test_no_sys_prompt(self, tokenizer_huggyllama_w_special_tokens):
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(
{
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
}
)
def test_no_sys_prompt(self):
""" """
tests the interface between the user and assistant parts tests the interface between the user and assistant parts
""" """
@@ -86,7 +70,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
strat = AlpacaPromptTokenizingStrategy( strat = AlpacaPromptTokenizingStrategy(
prompter, prompter,
self.tokenizer, tokenizer_huggyllama_w_special_tokens,
False, False,
2048, 2048,
) )
@@ -99,7 +83,8 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
assert example["labels"][world_idx] == 3186 assert example["labels"][world_idx] == 3186
assert example["labels"][world_idx - 1] == -100 assert example["labels"][world_idx - 1] == -100
def test_alpaca(self): @enable_hf_offline
def test_alpaca(self, tokenizer_huggyllama_w_special_tokens):
""" """
tests the interface between the user and assistant parts tests the interface between the user and assistant parts
""" """
@@ -107,7 +92,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
prompter = AlpacaPrompter() prompter = AlpacaPrompter()
strat = AlpacaPromptTokenizingStrategy( strat = AlpacaPromptTokenizingStrategy(
prompter, prompter,
self.tokenizer, tokenizer_huggyllama_w_special_tokens,
False, False,
2048, 2048,
) )
@@ -118,28 +103,17 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
assert example["labels"][world_idx - 1] == -100 assert example["labels"][world_idx - 1] == -100
class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase): class TestInstructionWSystemPromptTokenizingStrategy:
""" """
Test class for prompt tokenization strategies with sys prompt from the dataset Test class for prompt tokenization strategies with sys prompt from the dataset
""" """
@enable_hf_offline @enable_hf_offline
def setUp(self) -> None: def test_system_alpaca(self, tokenizer_huggyllama_w_special_tokens):
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(
{
"bos_token": "<s>",
"eos_token": "</s>",
"unk_token": "<unk>",
}
)
def test_system_alpaca(self):
prompter = SystemDataPrompter(PromptStyle.CHAT.value) prompter = SystemDataPrompter(PromptStyle.CHAT.value)
strat = InstructionWSystemPromptTokenizingStrategy( strat = InstructionWSystemPromptTokenizingStrategy(
prompter, prompter,
self.tokenizer, tokenizer_huggyllama_w_special_tokens,
False, False,
2048, 2048,
) )
@@ -160,18 +134,13 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
assert example["input_ids"][8] == 11889 # USER assert example["input_ids"][8] == 11889 # USER
class Llama2ChatTokenizationTest(unittest.TestCase): class Llama2ChatTokenizationTest:
""" """
Test class for prompt tokenization strategies with sys prompt from the dataset Test class for prompt tokenization strategies with sys prompt from the dataset
""" """
@enable_hf_offline @enable_hf_offline
def setUp(self) -> None: def test_llama2_chat_integration(self, tokenizer_llama2_7b):
# pylint: disable=duplicate-code
self.tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
# woraround because official Meta repos are not open
def test_llama2_chat_integration(self):
with open( with open(
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8" Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
) as fin: ) as fin:
@@ -186,16 +155,18 @@ class Llama2ChatTokenizationTest(unittest.TestCase):
prompter = Llama2ChatPrompter() prompter = Llama2ChatPrompter()
strat = LLama2ChatTokenizingStrategy( strat = LLama2ChatTokenizingStrategy(
prompter, prompter,
self.tokenizer, tokenizer_llama2_7b,
False, False,
4096, 4096,
) )
example = strat.tokenize_prompt(conversation) example = strat.tokenize_prompt(conversation)
for fields in ["input_ids", "attention_mask", "labels"]: for fields in ["input_ids", "attention_mask", "labels"]:
self.assertEqual(len(example[fields]), len(tokenized_conversation[fields])) # pytest assert equals
self.assertEqual(example[fields], tokenized_conversation[fields])
def compare_with_transformers_integration(self): assert len(example[fields]) == len(tokenized_conversation[fields])
assert example[fields] == tokenized_conversation[fields]
def compare_with_transformers_integration(self, tokenizer_llama2_7b):
# this needs transformers >= v4.31.0 # this needs transformers >= v4.31.0
from transformers.models.llama.tokenization_llama import B_SYS, E_SYS from transformers.models.llama.tokenization_llama import B_SYS, E_SYS
from transformers.pipelines.conversational import Conversation from transformers.pipelines.conversational import Conversation
@@ -234,49 +205,27 @@ If a question does not make any sense, or is not factually coherent, explain why
generated_responses=answers, generated_responses=answers,
) )
# pylint: disable=W0212 # pylint: disable=W0212
hf_tokens = self.tokenizer._build_conversation_input_ids(hf_conf) hf_tokens = tokenizer_llama2_7b._build_conversation_input_ids(hf_conf)
self.assertEqual( assert hf_tokens == tokenized_conversation["input_ids"][: len(hf_tokens)]
hf_tokens, tokenized_conversation["input_ids"][: len(hf_tokens)]
)
class OrpoTokenizationTest(unittest.TestCase): class OrpoTokenizationTest:
"""test case for the ORPO tokenization""" """test case for the ORPO tokenization"""
@enable_hf_offline @enable_hf_offline
def setUp(self) -> None: def test_orpo_integration(
# pylint: disable=duplicate-code self,
tokenizer = LlamaTokenizer.from_pretrained( tokenizer_mistral_7b_instruct_chatml,
"casperhansen/mistral-7b-instruct-v0.1-awq" dataset_argilla_ultrafeedback_binarized_preferences_cleaned,
) ):
tokenizer.add_special_tokens( ds = dataset_argilla_ultrafeedback_binarized_preferences_cleaned.select([0])
{
"eos_token": AddedToken(
"<|im_end|>", rstrip=False, lstrip=False, normalized=False
)
}
)
tokenizer.add_tokens(
[
AddedToken(
"<|im_start|>", rstrip=False, lstrip=False, normalized=False
),
]
)
self.tokenizer = tokenizer
self.dataset = load_dataset(
"argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
).select([0])
@pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
def test_orpo_integration(self):
strat = load( strat = load(
self.tokenizer, tokenizer_mistral_7b_instruct_chatml,
DictDefault({"train_on_inputs": False}), DictDefault({"train_on_inputs": False}),
DictDefault({"chat_template": "chatml"}), DictDefault({"chat_template": "chatml"}),
) )
res = strat.tokenize_prompt(self.dataset[0]) res = strat.tokenize_prompt(ds[0])
assert "rejected_input_ids" in res assert "rejected_input_ids" in res
assert "rejected_labels" in res assert "rejected_labels" in res
assert "input_ids" in res assert "input_ids" in res
@@ -295,7 +244,3 @@ class OrpoTokenizationTest(unittest.TestCase):
assert res["prompt_attention_mask"][0] == 1 assert res["prompt_attention_mask"][0] == 1
assert res["prompt_attention_mask"][-1] == 0 assert res["prompt_attention_mask"][-1] == 0
if __name__ == "__main__":
unittest.main()