diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml
index cf5c1d45d..01b898310 100644
--- a/.github/workflows/base.yml
+++ b/.github/workflows/base.yml
@@ -40,6 +40,12 @@ jobs:
             python_version: "3.11"
             pytorch: 2.6.0
             torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+          - cuda: "126"
+            cuda_version: 12.6.3
+            cudnn_version: ""
+            python_version: "3.11"
+            pytorch: 2.6.0
+            torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
           - cuda: "128"
             cuda_version: 12.8.1
             cudnn_version: ""
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 4a7112041..350b04cca 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -25,12 +25,12 @@ jobs:
             python_version: "3.11"
             pytorch: 2.5.1
             axolotl_extras: vllm
-            is_latest: true
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
             pytorch: 2.6.0
             axolotl_extras:
+            is_latest: true
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
@@ -87,12 +87,12 @@ jobs:
             python_version: "3.11"
             pytorch: 2.5.1
             axolotl_extras:
-            is_latest: true
           - cuda: 124
             cuda_version: 12.4.1
             python_version: "3.11"
             pytorch: 2.6.0
             axolotl_extras:
+            is_latest: true
     runs-on: axolotl-gpu-runner
     steps:
       - name: Checkout
diff --git a/.github/workflows/tests-nightly.yml b/.github/workflows/tests-nightly.yml
index 0b91d0c01..9d3b056e1 100644
--- a/.github/workflows/tests-nightly.yml
+++ b/.github/workflows/tests-nightly.yml
@@ -33,6 +33,15 @@ jobs:
       - name: Check out repository code
         uses: actions/checkout@v4

+      - name: Restore HF cache
+        id: hf-cache-restore
+        uses: actions/cache/restore@v4
+        with:
+          path: |
+            /home/runner/.cache/huggingface/hub/datasets--*
+            /home/runner/.cache/huggingface/hub/models--*
+          key: ${{ runner.os }}-hf-hub-cache-v2
+
       - name: Setup Python
         uses: actions/setup-python@v5
         with:
@@ -46,7 +55,7 @@ jobs:

       - name: Install PyTorch
         run: |
-          pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
+          pip3 install torch==${{ matrix.pytorch_version }}

       - name: Update requirements.txt
         run: |
@@ -58,8 +67,7 @@ jobs:

       - name: Install dependencies
         run: |
-          pip3 install --upgrade pip
-          pip3 install --upgrade packaging==23.2
+          pip3 show torch
           pip3 install --no-build-isolation -U -e .
           python scripts/unsloth_install.py | sh
           python scripts/cutcrossentropy_install.py | sh
@@ -73,10 +81,15 @@ jobs:
         run: |
           axolotl --help

+      - name: Pre-Download dataset fixture
+        run: |
+          huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
+
       - name: Run tests
         run: |
-          pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
-          pytest tests/patched/
+          pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
+          pytest -v tests/patched/
+          pytest -v tests/cli/

       - name: cleanup pip cache
         run: |
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index a1a7214ec..434803d2c 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -96,6 +96,10 @@ jobs:
         run: |
           axolotl --help

+      - name: Pre-Download dataset fixture
+        run: |
+          huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
+
       - name: Run tests
         run: |
           pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
diff --git a/requirements.txt b/requirements.txt
index 9aff0ccfe..567b446dd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,7 +12,7 @@
 liger-kernel==0.5.5
 packaging==23.2
 peft==0.15.0
-transformers==4.50.0
+transformers==4.50.3
 tokenizers>=0.21.1
 accelerate==1.5.2
 datasets==3.5.0
diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index f5679431a..436d3a073 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -28,6 +28,7 @@ from typing_extensions import override

 from axolotl.core.trainers.mixins import (
     OptimizerMixin,
+    RngLoaderMixin,
     SchedulerMixin,
     SequenceParallelMixin,
 )
@@ -40,7 +41,9 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
 LOG = logging.getLogger(__name__)


-class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer):
+class AxolotlTrainer(
+    SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
+):
     """Extend the base Trainer for axolotl helpers"""

     args = None  # type: "AxolotlTrainingArguments"  # type: ignore[name-defined]
diff --git a/src/axolotl/core/trainers/dpo/trainer.py b/src/axolotl/core/trainers/dpo/trainer.py
index 9eb870a3a..89c77dca4 100644
--- a/src/axolotl/core/trainers/dpo/trainer.py
+++ b/src/axolotl/core/trainers/dpo/trainer.py
@@ -13,7 +13,7 @@ from transformers import Trainer
 from transformers.utils import is_sagemaker_mp_enabled
 from trl import DPOTrainer

-from axolotl.core.trainers.mixins import SchedulerMixin
+from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
 from axolotl.core.trainers.utils import (
     sanitize_kwargs_for_ds_tagging,
     sanitize_kwargs_for_tagging,
@@ -23,7 +23,7 @@ if is_sagemaker_mp_enabled():
     import smdistributed.modelparallel.torch as smp


-class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
+class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
     """
     Extend the base DPOTrainer for axolotl helpers
     """
diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py
index e8a142945..25aafa6a7 100644
--- a/src/axolotl/core/trainers/grpo/trainer.py
+++ b/src/axolotl/core/trainers/grpo/trainer.py
@@ -8,13 +8,13 @@ from accelerate.utils import is_deepspeed_available, is_peft_model
 from trl import GRPOTrainer
 from trl.extras.profiling import profiling_decorator

-from axolotl.core.trainers.base import SchedulerMixin
+from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin

 if is_deepspeed_available():
     import deepspeed


-class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
+class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
     """
     Extend the base GRPOTrainer for axolotl helpers
     """
diff --git a/src/axolotl/core/trainers/mixins/__init__.py b/src/axolotl/core/trainers/mixins/__init__.py
index 12c8277fc..44751b465 100644
--- a/src/axolotl/core/trainers/mixins/__init__.py
+++ b/src/axolotl/core/trainers/mixins/__init__.py
@@ -4,5 +4,6 @@

 # flake8: noqa
 from .optimizer import OptimizerMixin
+from .rng_state_loader import RngLoaderMixin
 from .scheduler import SchedulerMixin
 from .sequence_parallel import SequenceParallelMixin
diff --git a/src/axolotl/core/trainers/mixins/rng_state_loader.py b/src/axolotl/core/trainers/mixins/rng_state_loader.py
new file mode 100644
index 000000000..0e101dabb
--- /dev/null
+++ b/src/axolotl/core/trainers/mixins/rng_state_loader.py
@@ -0,0 +1,67 @@
+"""
+Temporary fix/override for a bug in resume from checkpoint.
+
+See https://github.com/huggingface/transformers/pull/37162
+
+TODO: Remove once the upstream PR is included in a release
+"""
+
+import logging
+import os
+import random
+
+import numpy as np
+import torch
+from transformers import Trainer, is_torch_npu_available
+from transformers.trainer import safe_globals
+from transformers.trainer_pt_utils import set_rng_state_for_device
+from transformers.training_args import ParallelMode
+
+LOG = logging.getLogger(__name__)
+
+
+class RngLoaderMixin(Trainer):
+    """
+    Mixin overriding `Trainer._load_rng_state` to load RNG states from a checkpoint
+    """
+
+    def _load_rng_state(self, checkpoint):
+        # Load RNG states from `checkpoint`
+        if checkpoint is None:
+            return
+
+        if self.args.world_size > 1:
+            process_index = self.args.process_index
+            rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
+            if not os.path.isfile(rng_file):
+                LOG.info(
+                    f"Didn't find an RNG file for process {process_index}, if you are resuming a training that "
+                    "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
+                )
+                return
+        else:
+            rng_file = os.path.join(checkpoint, "rng_state.pth")
+            if not os.path.isfile(rng_file):
+                LOG.info(
+                    "Didn't find an RNG file, if you are resuming a training that was launched in a distributed "
+                    "fashion, reproducibility is not guaranteed."
+                )
+                return
+
+        # Use safe_globals to ensure numpy RNG states can be deserialized safely under PyTorch 2.6+,
+        # which requires allowlisted classes when loading with weights_only=True.
+        with safe_globals():
+            checkpoint_rng_state = torch.load(rng_file)  # nosec B614
+
+        random.setstate(checkpoint_rng_state["python"])
+        np.random.set_state(checkpoint_rng_state["numpy"])
+        torch.random.set_rng_state(checkpoint_rng_state["cpu"])
+
+        is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED
+        if torch.cuda.is_available():
+            set_rng_state_for_device(
+                "CUDA", torch.cuda, checkpoint_rng_state, is_distributed
+            )
+        if is_torch_npu_available():
+            set_rng_state_for_device(
+                "NPU", torch.npu, checkpoint_rng_state, is_distributed
+            )
diff --git a/src/axolotl/core/trainers/trl.py b/src/axolotl/core/trainers/trl.py
index ebe46f11d..b2c5c54ca 100644
--- a/src/axolotl/core/trainers/trl.py
+++ b/src/axolotl/core/trainers/trl.py
@@ -13,6 +13,7 @@ from trl import (
     RewardTrainer,
 )

+from axolotl.core.trainers.mixins import RngLoaderMixin
 from axolotl.core.trainers.mixins.scheduler import SchedulerMixin


@@ -74,7 +75,7 @@ class TRLPPOTrainer(PPOTrainer):
     )


-class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
+class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
     """
     Extend the base ORPOTrainer for axolotl helpers
     """
@@ -154,7 +155,7 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
         return loss, metrics


-class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
+class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
     """
     Extend the base KTOTrainer for axolotl helpers
     """
@@ -162,7 +163,7 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
     tag_names = ["axolotl", "kto"]


-class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
+class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
     """
     Extend the base CPOTrainer for axolotl helpers
     """
@@ -244,7 +245,7 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
         return loss, metrics


-class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
+class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
     """
     Extend the base RewardTrainer for axolotl helpers
     """
@@ -252,7 +253,7 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
     tag_names = ["axolotl", "reward"]


-class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
+class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer):
     """
     Extend the base trl.PRMTrainer for axolotl helpers
     """
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 837d2ca69..3e072b6b4 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -1270,3 +1270,12 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
             if data["beta"] != data["trl"]["beta"]:
                 raise ValueError("beta and trl.beta must match or one must be removed")
         return data
+
+    @model_validator(mode="after")
+    def check_min_torch_version(self):
+        if self.env_capabilities and self.env_capabilities.torch_version:
+            torch_version = self.env_capabilities.torch_version
+            if version.parse(torch_version) < version.parse("2.5.1"):
+                LOG.warning(
+                    f"torch=={torch_version} may not be supported in future versions. Please consider upgrading to torch>=2.5.1."
+                )
diff --git a/tests/conftest.py b/tests/conftest.py
index b86b714af..4d05d3a26 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -8,11 +8,13 @@ import shutil
 import sys
 import tempfile
 import time
+from pathlib import Path

+import datasets
 import pytest
 import requests
-from datasets import load_dataset
 from huggingface_hub import snapshot_download
+from tokenizers import AddedToken
 from transformers import AutoTokenizer

 from tests.hf_offline_utils import disable_hf_offline, enable_hf_offline
@@ -48,6 +50,14 @@ def snapshot_download_w_retry(*args, **kwargs):
     return snapshot_download(*args, **kwargs)


+@pytest.fixture(scope="session", autouse=True)
+def download_ds_fixture_bundle():
+    ds_dir = snapshot_download_w_retry(
+        "axolotl-ai-internal/axolotl-oss-dataset-fixtures", repo_type="dataset"
+    )
+    return Path(ds_dir)
+
+
 @pytest.fixture(scope="session", autouse=True)
 def download_smollm2_135m_model():
     # download the model
@@ -108,43 +118,43 @@ def download_argilla_distilabel_intel_orca_dpo_dataset():
     )


-@pytest.fixture(scope="session", autouse=True)
-def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
-    # download the dataset
-    snapshot_download_w_retry(
-        "argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
-    )
+# @pytest.fixture(scope="session", autouse=True)
+# def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
+#     # download the dataset
+#     snapshot_download_w_retry(
+#         "argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
+#     )


-@pytest.fixture(scope="session", autouse=True)
-def download_fozzie_alpaca_dpo_dataset():
-    # download the dataset
-    snapshot_download_w_retry(
-        "fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
-    )
-    snapshot_download_w_retry(
-        "fozziethebeat/alpaca_messages_2k_dpo_test",
-        repo_type="dataset",
-        revision="ea82cff",
-    )
+# @pytest.fixture(scope="session", autouse=True)
+# def download_fozzie_alpaca_dpo_dataset():
+#     # download the dataset
+#     snapshot_download_w_retry(
+#         "fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
+#     )
+#     snapshot_download_w_retry(
+#         "fozziethebeat/alpaca_messages_2k_dpo_test",
+#         repo_type="dataset",
+#         revision="ea82cff",
+#     )


-@pytest.fixture(scope="session")
-@disable_hf_offline
-def dataset_fozzie_alpaca_dpo_dataset(
-    download_fozzie_alpaca_dpo_dataset,
-):  # pylint: disable=unused-argument,redefined-outer-name
-    return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
-
-
-@pytest.fixture(scope="session")
-@disable_hf_offline
-def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
-    download_fozzie_alpaca_dpo_dataset,
-):  # pylint: disable=unused-argument,redefined-outer-name
-    return load_dataset(
-        "fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
-    )
+# @pytest.fixture(scope="session")
+# @disable_hf_offline
+# def dataset_fozzie_alpaca_dpo_dataset(
+#     download_fozzie_alpaca_dpo_dataset,
+# ):  # pylint: disable=unused-argument,redefined-outer-name
+#     return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
+#
+#
+# @pytest.fixture(scope="session")
+# @disable_hf_offline
+# def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
+#     download_fozzie_alpaca_dpo_dataset,
+# ):  # pylint: disable=unused-argument,redefined-outer-name
+#     return load_dataset(
+#         "fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
+#     )


 @pytest.fixture(scope="session", autouse=True)
@@ -271,7 +281,7 @@ def download_mlx_mistral_7b_model_fixture():
     )


-@pytest.fixture(scope="session", autouse=True)
+@pytest.fixture
 def download_llama2_model_fixture():
     # download the tokenizer only
     snapshot_download_w_retry(
@@ -281,7 +291,7 @@ def download_llama2_model_fixture():
     )


-@pytest.fixture(scope="session", autouse=True)
+@pytest.fixture
 @enable_hf_offline
 def tokenizer_huggyllama(
     download_huggyllama_model_fixture,
@@ -292,6 +302,57 @@ def tokenizer_huggyllama(
     return tokenizer


+@pytest.fixture
+@enable_hf_offline
+def tokenizer_huggyllama_w_special_tokens(
+    tokenizer_huggyllama,
+):  # pylint: disable=redefined-outer-name
+    tokenizer_huggyllama.add_special_tokens(
+        {
+            "bos_token": "<s>",
+            "eos_token": "</s>",
+            "unk_token": "<unk>",
+        }
+    )
+
+    return tokenizer_huggyllama
+
+
+@pytest.fixture
+@enable_hf_offline
+def tokenizer_llama2_7b(
+    download_llama2_model_fixture,
+):  # pylint: disable=unused-argument,redefined-outer-name
+    tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
+
+    return tokenizer
+
+
+@pytest.fixture
+@enable_hf_offline
+def tokenizer_mistral_7b_instruct(
+    download_mlx_mistral_7b_model_fixture,
+):  # pylint: disable=unused-argument,redefined-outer-name
+    return AutoTokenizer.from_pretrained("casperhansen/mistral-7b-instruct-v0.1-awq")
+
+
+@pytest.fixture
+def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):
+    tokenizer_mistral_7b_instruct.add_special_tokens(
+        {
+            "eos_token": AddedToken(
+                "<|im_end|>", rstrip=False, lstrip=False, normalized=False
+            )
+        }
+    )
+    tokenizer_mistral_7b_instruct.add_tokens(
+        [
+            AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
+        ]
+    )
+    return tokenizer_mistral_7b_instruct
+
+
 @pytest.fixture
 def temp_dir():
     # Create a temporary directory
@@ -357,6 +418,60 @@ def cleanup_monkeypatches():
             globals().pop(module_global, None)


+@pytest.fixture
+def dataset_winglian_tiny_shakespeare(
+    download_ds_fixture_bundle: Path,
+):  # pylint: disable=redefined-outer-name
+    ds_path = download_ds_fixture_bundle / "winglian__tiny-shakespeare"
+    return datasets.load_from_disk(ds_path)
+
+
+@pytest.fixture
+def dataset_tatsu_lab_alpaca(
+    download_ds_fixture_bundle: Path,
+):  # pylint: disable=redefined-outer-name
+    ds_path = download_ds_fixture_bundle / "tatsu-lab__alpaca"
+    return datasets.load_from_disk(ds_path)["train"]
+
+
+@pytest.fixture
+def dataset_mhenrichsen_alpaca_2k_test(
+    download_ds_fixture_bundle: Path,
+):  # pylint: disable=redefined-outer-name
+    ds_path = download_ds_fixture_bundle / "mhenrichsen__alpaca_2k_test"
+    return datasets.load_from_disk(ds_path)["train"]
+
+
+@pytest.fixture
+def dataset_argilla_ultrafeedback_binarized_preferences_cleaned(
+    download_ds_fixture_bundle: Path,
+):  # pylint: disable=redefined-outer-name
+    ds_path = (
+        download_ds_fixture_bundle
+        / "argilla__ultrafeedback-binarized-preferences-cleaned"
+    )
+    return datasets.load_from_disk(ds_path)["train"]
+
+
+@pytest.fixture
+def dataset_fozziethebeat_alpaca_messages_2k_dpo_test(
+    download_ds_fixture_bundle: Path,
+):  # pylint: disable=redefined-outer-name
+    ds_path = download_ds_fixture_bundle / "fozziethebeat__alpaca_messages_2k_dpo_test"
+    return datasets.load_from_disk(ds_path)["train"]
+
+
+@pytest.fixture
+def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
+    download_ds_fixture_bundle: Path,
+):  # pylint: disable=redefined-outer-name
+    ds_path = (
+        download_ds_fixture_bundle
+        / "fozziethebeat__alpaca_messages_2k_dpo_test__rev_ea82cff"
+    )
+    return datasets.load_from_disk(ds_path)["train"]
+
+
 # # pylint: disable=redefined-outer-name,unused-argument
 # def test_load_fixtures(
 #     download_smollm2_135m_model,
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index a82f2f381..ded82869f 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -324,7 +324,7 @@ class TestDatasetPreparation:

     @enable_hf_offline
     def test_load_hub_with_revision_with_dpo(
-        self, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
+        self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
     ):
         """Verify that processing dpo data from the hub works with a specific revision"""

@@ -339,12 +339,10 @@ class TestDatasetPreparation:
         )

         # pylint: disable=duplicate-code
-        with patch(
-            "axolotl.utils.data.shared.load_dataset_w_config"
-        ) as mock_load_dataset:
+        with patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset:
             # Set up the mock to return different values on successive calls
             mock_load_dataset.return_value = (
-                dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
+                dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
             )

             train_dataset, _ = load_prepare_preference_datasets(cfg)
@@ -354,7 +352,9 @@ class TestDatasetPreparation:

     @enable_hf_offline
     @pytest.mark.skip("datasets bug with local datasets when offline")
-    def test_load_local_hub_with_revision(self, tokenizer):
+    def test_load_local_hub_with_revision(
+        self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff, tokenizer
+    ):
         """Verify that a local copy of a hub dataset can be loaded with a specific revision"""
         with tempfile.TemporaryDirectory() as tmp_dir:
             tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
@@ -386,13 +386,23 @@ class TestDatasetPreparation:
                 }
             )

-            dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
+            with patch(
+                "axolotl.utils.data.shared.load_dataset_w_config"
+            ) as mock_load_dataset:
+                # Set up the mock to return different values on successive calls
+                mock_load_dataset.return_value = (
+                    dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
+                )

-            assert len(dataset) == 2000
-            assert "input_ids" in dataset.features
-            assert "attention_mask" in dataset.features
-            assert "labels" in dataset.features
-            shutil.rmtree(tmp_ds_path)
+                dataset, _ = load_tokenized_prepared_datasets(
+                    tokenizer, cfg, prepared_path
+                )
+
+                assert len(dataset) == 2000
+                assert "input_ids" in dataset.features
+                assert "attention_mask" in dataset.features
+                assert "labels" in dataset.features
+                shutil.rmtree(tmp_ds_path)

     @enable_hf_offline
     def test_loading_local_dataset_folder(self, tokenizer):
diff --git a/tests/test_exact_deduplication.py b/tests/test_exact_deduplication.py
index 7430352a4..a75f97f78 100644
--- a/tests/test_exact_deduplication.py
+++ b/tests/test_exact_deduplication.py
@@ -238,21 +238,22 @@ class TestDeduplicateRLDataset:

     @enable_hf_offline
     def test_load_with_deduplication(
-        self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
+        self,
+        cfg,
+        dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
+        tokenizer_huggyllama,
     ):
         """Verify that loading with deduplication removes duplicates."""
         # pylint: disable=duplicate-code
         with (
-            patch(
-                "axolotl.utils.data.shared.load_dataset_w_config"
-            ) as mock_load_dataset,
+            patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
             patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
         ):
             # Set up the mock to return different values on successive calls
             mock_load_dataset.side_effect = [
-                dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
-                dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
+                dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
+                dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
             ]
             mock_load_tokenizer.return_value = tokenizer_huggyllama
@@ -263,19 +264,20 @@ class TestDeduplicateRLDataset:

     @enable_hf_offline
     def test_load_without_deduplication(
-        self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
+        self,
+        cfg,
+        dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
+        tokenizer_huggyllama,
     ):
         # pylint: disable=duplicate-code
         with (
-            patch(
-                "axolotl.utils.data.shared.load_dataset_w_config"
-            ) as mock_load_dataset,
+            patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
             patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
         ):
             # Set up the mock to return different values on successive calls
             mock_load_dataset.side_effect = [
-                dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
-                dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
+                dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
+                dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
             ]

             mock_load_tokenizer.return_value = tokenizer_huggyllama
diff --git a/tests/test_packed_batch_sampler.py b/tests/test_packed_batch_sampler.py
index faba86931..dd0386e58 100644
--- a/tests/test_packed_batch_sampler.py
+++ b/tests/test_packed_batch_sampler.py
@@ -1,7 +1,7 @@
 """Module for testing streaming dataset sequence packing"""

 import pytest
-from datasets import concatenate_datasets, load_dataset
+from datasets import concatenate_datasets
 from torch.utils.data import DataLoader, RandomSampler
 from transformers import AutoTokenizer

@@ -27,7 +27,6 @@ class TestBatchedSamplerPacking:
     Test class for packing streaming dataset sequences
     """

-    @pytest.mark.skip(reason="TODO: fix hf offline mode for CI rate limits")
     @pytest.mark.parametrize(
         "batch_size, num_workers",
         [
@@ -41,14 +40,17 @@ class TestBatchedSamplerPacking:
     @pytest.mark.parametrize("sequential", [True, False])
     @enable_hf_offline
     def test_packing(
-        self, batch_size, num_workers, tokenizer, max_seq_length, sequential
+        self,
+        dataset_winglian_tiny_shakespeare,
+        batch_size,
+        num_workers,
+        tokenizer,
+        max_seq_length,
+        sequential,
     ):
         import axolotl.monkeypatch.data.batch_dataset_fetcher  # pylint: disable=unused-import  # noqa: F401

-        dataset = load_dataset(
-            "winglian/tiny-shakespeare",
-            split="train",
-        )
+        dataset = dataset_winglian_tiny_shakespeare["train"]

         cfg = DictDefault(
             {
@@ -58,7 +60,7 @@ class TestBatchedSamplerPacking:
         )
         ds_cfg = DictDefault(
             {
-                "field": "Text",
+                "field": "text",
             }
         )
         completion_strategy = load(tokenizer, cfg, ds_cfg)
diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py
index 65eee7ddb..3f16bc917 100644
--- a/tests/test_prompt_tokenizers.py
+++ b/tests/test_prompt_tokenizers.py
@@ -2,13 +2,8 @@

 import json
 import logging
-import unittest
 from pathlib import Path

-import pytest
-from datasets import load_dataset
-from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
-
 from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
 from axolotl.prompt_strategies.alpaca_w_system import (
     InstructionWSystemPromptTokenizingStrategy,
@@ -61,24 +56,13 @@ test_data = {
 }


-class TestPromptTokenizationStrategies(unittest.TestCase):
+class TestPromptTokenizationStrategies:
     """
     Test class for prompt tokenization strategies.
""" @enable_hf_offline - def setUp(self) -> None: - # pylint: disable=duplicate-code - self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") - self.tokenizer.add_special_tokens( - { - "bos_token": "", - "eos_token": "", - "unk_token": "", - } - ) - - def test_no_sys_prompt(self): + def test_no_sys_prompt(self, tokenizer_huggyllama_w_special_tokens): """ tests the interface between the user and assistant parts """ @@ -86,7 +70,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase): # pylint: disable=duplicate-code strat = AlpacaPromptTokenizingStrategy( prompter, - self.tokenizer, + tokenizer_huggyllama_w_special_tokens, False, 2048, ) @@ -99,7 +83,8 @@ class TestPromptTokenizationStrategies(unittest.TestCase): assert example["labels"][world_idx] == 3186 assert example["labels"][world_idx - 1] == -100 - def test_alpaca(self): + @enable_hf_offline + def test_alpaca(self, tokenizer_huggyllama_w_special_tokens): """ tests the interface between the user and assistant parts """ @@ -107,7 +92,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase): prompter = AlpacaPrompter() strat = AlpacaPromptTokenizingStrategy( prompter, - self.tokenizer, + tokenizer_huggyllama_w_special_tokens, False, 2048, ) @@ -118,28 +103,17 @@ class TestPromptTokenizationStrategies(unittest.TestCase): assert example["labels"][world_idx - 1] == -100 -class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase): +class TestInstructionWSystemPromptTokenizingStrategy: """ Test class for prompt tokenization strategies with sys prompt from the dataset """ @enable_hf_offline - def setUp(self) -> None: - # pylint: disable=duplicate-code - self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") - self.tokenizer.add_special_tokens( - { - "bos_token": "", - "eos_token": "", - "unk_token": "", - } - ) - - def test_system_alpaca(self): + def test_system_alpaca(self, tokenizer_huggyllama_w_special_tokens): prompter = SystemDataPrompter(PromptStyle.CHAT.value) strat = InstructionWSystemPromptTokenizingStrategy( prompter, - self.tokenizer, + tokenizer_huggyllama_w_special_tokens, False, 2048, ) @@ -160,18 +134,13 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase): assert example["input_ids"][8] == 11889 # USER -class Llama2ChatTokenizationTest(unittest.TestCase): +class Llama2ChatTokenizationTest: """ Test class for prompt tokenization strategies with sys prompt from the dataset """ @enable_hf_offline - def setUp(self) -> None: - # pylint: disable=duplicate-code - self.tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf") - # woraround because official Meta repos are not open - - def test_llama2_chat_integration(self): + def test_llama2_chat_integration(self, tokenizer_llama2_7b): with open( Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8" ) as fin: @@ -186,16 +155,18 @@ class Llama2ChatTokenizationTest(unittest.TestCase): prompter = Llama2ChatPrompter() strat = LLama2ChatTokenizingStrategy( prompter, - self.tokenizer, + tokenizer_llama2_7b, False, 4096, ) example = strat.tokenize_prompt(conversation) for fields in ["input_ids", "attention_mask", "labels"]: - self.assertEqual(len(example[fields]), len(tokenized_conversation[fields])) - self.assertEqual(example[fields], tokenized_conversation[fields]) + # pytest assert equals - def compare_with_transformers_integration(self): + assert len(example[fields]) == len(tokenized_conversation[fields]) + assert example[fields] == tokenized_conversation[fields] + + def 
         # this needs transformers >= v4.31.0
         from transformers.models.llama.tokenization_llama import B_SYS, E_SYS
         from transformers.pipelines.conversational import Conversation
@@ -234,49 +205,27 @@ If a question does not make any sense, or is not factually coherent, explain why
             generated_responses=answers,
         )
         # pylint: disable=W0212
-        hf_tokens = self.tokenizer._build_conversation_input_ids(hf_conf)
+        hf_tokens = tokenizer_llama2_7b._build_conversation_input_ids(hf_conf)

-        self.assertEqual(
-            hf_tokens, tokenized_conversation["input_ids"][: len(hf_tokens)]
-        )
+        assert hf_tokens == tokenized_conversation["input_ids"][: len(hf_tokens)]


-class OrpoTokenizationTest(unittest.TestCase):
+class OrpoTokenizationTest:
     """test case for the ORPO tokenization"""

     @enable_hf_offline
-    def setUp(self) -> None:
-        # pylint: disable=duplicate-code
-        tokenizer = LlamaTokenizer.from_pretrained(
-            "casperhansen/mistral-7b-instruct-v0.1-awq"
-        )
-        tokenizer.add_special_tokens(
-            {
-                "eos_token": AddedToken(
-                    "<|im_end|>", rstrip=False, lstrip=False, normalized=False
-                )
-            }
-        )
-        tokenizer.add_tokens(
-            [
-                AddedToken(
-                    "<|im_start|>", rstrip=False, lstrip=False, normalized=False
-                ),
-            ]
-        )
-        self.tokenizer = tokenizer
-        self.dataset = load_dataset(
-            "argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
-        ).select([0])
-
-    @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
-    def test_orpo_integration(self):
+    def test_orpo_integration(
+        self,
+        tokenizer_mistral_7b_instruct_chatml,
+        dataset_argilla_ultrafeedback_binarized_preferences_cleaned,
+    ):
+        ds = dataset_argilla_ultrafeedback_binarized_preferences_cleaned.select([0])
         strat = load(
-            self.tokenizer,
+            tokenizer_mistral_7b_instruct_chatml,
             DictDefault({"train_on_inputs": False}),
             DictDefault({"chat_template": "chatml"}),
         )
-        res = strat.tokenize_prompt(self.dataset[0])
+        res = strat.tokenize_prompt(ds[0])
         assert "rejected_input_ids" in res
         assert "rejected_labels" in res
         assert "input_ids" in res
@@ -295,7 +244,3 @@ class OrpoTokenizationTest:

         assert res["prompt_attention_mask"][0] == 1
         assert res["prompt_attention_mask"][-1] == 0
-
-
-if __name__ == "__main__":
-    unittest.main()