diff --git a/.github/workflows/base.yml b/.github/workflows/base.yml
index cf5c1d45d..01b898310 100644
--- a/.github/workflows/base.yml
+++ b/.github/workflows/base.yml
@@ -40,6 +40,12 @@ jobs:
python_version: "3.11"
pytorch: 2.6.0
torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
+ - cuda: "126"
+ cuda_version: 12.6.3
+ cudnn_version: ""
+ python_version: "3.11"
+ pytorch: 2.6.0
+ torch_cuda_arch_list: "7.0 7.5 8.0 8.6 8.7 8.9 9.0+PTX"
- cuda: "128"
cuda_version: 12.8.1
cudnn_version: ""
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 4a7112041..350b04cca 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -25,12 +25,12 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras: vllm
- is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
+ is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
@@ -87,12 +87,12 @@ jobs:
python_version: "3.11"
pytorch: 2.5.1
axolotl_extras:
- is_latest: true
- cuda: 124
cuda_version: 12.4.1
python_version: "3.11"
pytorch: 2.6.0
axolotl_extras:
+ is_latest: true
runs-on: axolotl-gpu-runner
steps:
- name: Checkout
diff --git a/.github/workflows/tests-nightly.yml b/.github/workflows/tests-nightly.yml
index 0b91d0c01..9d3b056e1 100644
--- a/.github/workflows/tests-nightly.yml
+++ b/.github/workflows/tests-nightly.yml
@@ -33,6 +33,16 @@ jobs:
- name: Check out repository code
uses: actions/checkout@v4
+ - name: Restore HF cache
+ id: hf-cache-restore
+ uses: actions/cache/restore@v4
+ with:
+ path: |
+ /home/runner/.cache/huggingface/hub/datasets--*
+ /home/runner/.cache/huggingface/hub/models--*
+ key: ${{ runner.os }}-hf-hub-cache-v2
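+          # restore-only variant of actions/cache; a matching save step (not shown in this diff) is assumed to seed this key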
+
- name: Setup Python
uses: actions/setup-python@v5
with:
@@ -46,7 +55,7 @@ jobs:
- name: Install PyTorch
run: |
- pip3 install torch==${{ matrix.pytorch_version }} --index-url https://download.pytorch.org/whl/cpu
+ pip3 install torch==${{ matrix.pytorch_version }}
- name: Update requirements.txt
run: |
@@ -58,8 +67,7 @@ jobs:
- name: Install dependencies
run: |
- pip3 install --upgrade pip
- pip3 install --upgrade packaging==23.2
+ pip3 show torch
pip3 install --no-build-isolation -U -e .
python scripts/unsloth_install.py | sh
python scripts/cutcrossentropy_install.py | sh
@@ -73,10 +81,16 @@
run: |
axolotl --help
+ - name: Pre-Download dataset fixture
+ run: |
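+          # warm the HF hub cache up front, which helps the test run avoid HF hub rate limits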
+ huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
+
- name: Run tests
run: |
- pytest -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ tests/
- pytest tests/patched/
+ pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
+ pytest -v tests/patched/
+ pytest -v tests/cli/
- name: cleanup pip cache
run: |
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index a1a7214ec..434803d2c 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -96,6 +96,10 @@ jobs:
run: |
axolotl --help
+ - name: Pre-Download dataset fixture
+ run: |
+ huggingface-cli download --repo-type=dataset axolotl-ai-internal/axolotl-oss-dataset-fixtures
+
- name: Run tests
run: |
pytest -v -n8 --dist loadfile --ignore=tests/e2e/ --ignore=tests/patched/ --ignore=tests/cli/ tests/
diff --git a/requirements.txt b/requirements.txt
index 9aff0ccfe..567b446dd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -12,7 +12,7 @@ liger-kernel==0.5.5
packaging==23.2
peft==0.15.0
-transformers==4.50.0
+transformers==4.50.3
tokenizers>=0.21.1
accelerate==1.5.2
datasets==3.5.0
diff --git a/src/axolotl/core/trainers/base.py b/src/axolotl/core/trainers/base.py
index f5679431a..436d3a073 100644
--- a/src/axolotl/core/trainers/base.py
+++ b/src/axolotl/core/trainers/base.py
@@ -28,6 +28,7 @@ from typing_extensions import override
from axolotl.core.trainers.mixins import (
OptimizerMixin,
+ RngLoaderMixin,
SchedulerMixin,
SequenceParallelMixin,
)
@@ -40,7 +41,10 @@ from axolotl.utils.samplers import MultipackBatchSampler, get_dataset_lengths
LOG = logging.getLogger(__name__)
-class AxolotlTrainer(SchedulerMixin, OptimizerMixin, SequenceParallelMixin, Trainer):
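+# NOTE: mixins come before Trainer in the bases so their overrides (e.g. RngLoaderMixin._load_rng_state) win in the MRO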
+class AxolotlTrainer(
+ SchedulerMixin, OptimizerMixin, RngLoaderMixin, SequenceParallelMixin, Trainer
+):
"""Extend the base Trainer for axolotl helpers"""
args = None # type: "AxolotlTrainingArguments" # type: ignore[name-defined]
diff --git a/src/axolotl/core/trainers/dpo/trainer.py b/src/axolotl/core/trainers/dpo/trainer.py
index 9eb870a3a..89c77dca4 100644
--- a/src/axolotl/core/trainers/dpo/trainer.py
+++ b/src/axolotl/core/trainers/dpo/trainer.py
@@ -13,7 +13,7 @@ from transformers import Trainer
from transformers.utils import is_sagemaker_mp_enabled
from trl import DPOTrainer
-from axolotl.core.trainers.mixins import SchedulerMixin
+from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
from axolotl.core.trainers.utils import (
sanitize_kwargs_for_ds_tagging,
sanitize_kwargs_for_tagging,
@@ -23,7 +23,7 @@ if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
-class AxolotlDPOTrainer(SchedulerMixin, DPOTrainer):
+class AxolotlDPOTrainer(RngLoaderMixin, SchedulerMixin, DPOTrainer):
"""
Extend the base DPOTrainer for axolotl helpers
"""
diff --git a/src/axolotl/core/trainers/grpo/trainer.py b/src/axolotl/core/trainers/grpo/trainer.py
index e8a142945..25aafa6a7 100644
--- a/src/axolotl/core/trainers/grpo/trainer.py
+++ b/src/axolotl/core/trainers/grpo/trainer.py
@@ -8,13 +8,13 @@ from accelerate.utils import is_deepspeed_available, is_peft_model
from trl import GRPOTrainer
from trl.extras.profiling import profiling_decorator
-from axolotl.core.trainers.base import SchedulerMixin
+from axolotl.core.trainers.mixins import RngLoaderMixin, SchedulerMixin
if is_deepspeed_available():
import deepspeed
-class AxolotlGRPOTrainer(SchedulerMixin, GRPOTrainer):
+class AxolotlGRPOTrainer(RngLoaderMixin, SchedulerMixin, GRPOTrainer):
"""
Extend the base GRPOTrainer for axolotl helpers
"""
diff --git a/src/axolotl/core/trainers/mixins/__init__.py b/src/axolotl/core/trainers/mixins/__init__.py
index 12c8277fc..44751b465 100644
--- a/src/axolotl/core/trainers/mixins/__init__.py
+++ b/src/axolotl/core/trainers/mixins/__init__.py
@@ -4,5 +4,6 @@
# flake8: noqa
from .optimizer import OptimizerMixin
+from .rng_state_loader import RngLoaderMixin
from .scheduler import SchedulerMixin
from .sequence_parallel import SequenceParallelMixin
diff --git a/src/axolotl/core/trainers/mixins/rng_state_loader.py b/src/axolotl/core/trainers/mixins/rng_state_loader.py
new file mode 100644
index 000000000..0e101dabb
--- /dev/null
+++ b/src/axolotl/core/trainers/mixins/rng_state_loader.py
@@ -0,0 +1,69 @@
+"""
+Temporary fix/override for a bug in resume-from-checkpoint.
+
+See https://github.com/huggingface/transformers/pull/37162
+
+TODO: Remove once the upstream PR is included in a release.
+"""
+
+import logging
+import os
+import random
+
+import numpy as np
+import torch
+from transformers import Trainer, is_torch_npu_available
+from transformers.trainer import safe_globals
+from transformers.trainer_pt_utils import set_rng_state_for_device
+from transformers.training_args import ParallelMode
+
+LOG = logging.getLogger(__name__)
+
+
+class RngLoaderMixin(Trainer):
+ """
+    Mixin overriding `_load_rng_state` to restore RNG states from a checkpoint.
+ """
+
+ def _load_rng_state(self, checkpoint):
+ # Load RNG states from `checkpoint`
+ if checkpoint is None:
+ return
+
+ if self.args.world_size > 1:
+ process_index = self.args.process_index
+ rng_file = os.path.join(checkpoint, f"rng_state_{process_index}.pth")
+ if not os.path.isfile(rng_file):
+ LOG.info(
+                f"Didn't find an RNG file for process {process_index}; if you are resuming a training run that "
+                "wasn't launched in a distributed fashion, reproducibility is not guaranteed."
+ )
+ return
+ else:
+ rng_file = os.path.join(checkpoint, "rng_state.pth")
+ if not os.path.isfile(rng_file):
+ LOG.info(
+                "Didn't find an RNG file; if you are resuming a training run that was launched in a distributed "
+ "fashion, reproducibility is not guaranteed."
+ )
+ return
+
+ # Use safe_globals to ensure numpy RNG states can be deserialized safely under PyTorch 2.6+,
+ # which requires allowlisted classes when loading with weights_only=True.
+ with safe_globals():
+ checkpoint_rng_state = torch.load(rng_file) # nosec B614
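+        # restore the global generators first: Python's random module, NumPy, then torch's CPU RNG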
+ random.setstate(checkpoint_rng_state["python"])
+ np.random.set_state(checkpoint_rng_state["numpy"])
+ torch.random.set_rng_state(checkpoint_rng_state["cpu"])
+
+ is_distributed = self.args.parallel_mode == ParallelMode.DISTRIBUTED
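+        # per-device generator state is restored by transformers' set_rng_state_for_device; the flag selects distributed vs. single-device restore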
+ if torch.cuda.is_available():
+ set_rng_state_for_device(
+ "CUDA", torch.cuda, checkpoint_rng_state, is_distributed
+ )
+ if is_torch_npu_available():
+ set_rng_state_for_device(
+ "NPU", torch.npu, checkpoint_rng_state, is_distributed
+ )
diff --git a/src/axolotl/core/trainers/trl.py b/src/axolotl/core/trainers/trl.py
index ebe46f11d..b2c5c54ca 100644
--- a/src/axolotl/core/trainers/trl.py
+++ b/src/axolotl/core/trainers/trl.py
@@ -13,6 +13,7 @@ from trl import (
RewardTrainer,
)
+from axolotl.core.trainers.mixins import RngLoaderMixin
from axolotl.core.trainers.mixins.scheduler import SchedulerMixin
@@ -74,7 +75,7 @@ class TRLPPOTrainer(PPOTrainer):
)
-class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
+class AxolotlORPOTrainer(RngLoaderMixin, SchedulerMixin, ORPOTrainer):
"""
Extend the base ORPOTrainer for axolotl helpers
"""
@@ -154,7 +155,7 @@ class AxolotlORPOTrainer(SchedulerMixin, ORPOTrainer):
return loss, metrics
-class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
+class AxolotlKTOTrainer(RngLoaderMixin, SchedulerMixin, KTOTrainer):
"""
Extend the base KTOTrainer for axolotl helpers
"""
@@ -162,7 +163,7 @@ class AxolotlKTOTrainer(SchedulerMixin, KTOTrainer):
tag_names = ["axolotl", "kto"]
-class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
+class AxolotlCPOTrainer(RngLoaderMixin, SchedulerMixin, CPOTrainer):
"""
Extend the base CPOTrainer for axolotl helpers
"""
@@ -244,7 +245,7 @@ class AxolotlCPOTrainer(SchedulerMixin, CPOTrainer):
return loss, metrics
-class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
+class AxolotlRewardTrainer(RngLoaderMixin, SchedulerMixin, RewardTrainer):
"""
Extend the base RewardTrainer for axolotl helpers
"""
@@ -252,7 +253,7 @@ class AxolotlRewardTrainer(SchedulerMixin, RewardTrainer):
tag_names = ["axolotl", "reward"]
-class AxolotlPRMTrainer(SchedulerMixin, PRMTrainer):
+class AxolotlPRMTrainer(RngLoaderMixin, SchedulerMixin, PRMTrainer):
"""
Extend the base trl.PRMTrainer for axolotl helpers
"""
diff --git a/src/axolotl/utils/schemas/config.py b/src/axolotl/utils/schemas/config.py
index 837d2ca69..3e072b6b4 100644
--- a/src/axolotl/utils/schemas/config.py
+++ b/src/axolotl/utils/schemas/config.py
@@ -1270,3 +1270,14 @@ class AxolotlConfigWCapabilities(AxolotlInputConfig):
if data["beta"] != data["trl"]["beta"]:
raise ValueError("beta and trl.beta must match or one must be removed")
return data
+
+ @model_validator(mode="after")
+ def check_min_torch_version(self):
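+        """Warn (rather than fail) when the environment's torch version is older than 2.5.1."""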
+ if self.env_capabilities and self.env_capabilities.torch_version:
+ torch_version = self.env_capabilities.torch_version
+ if version.parse(torch_version) < version.parse("2.5.1"):
+ LOG.warning(
+ f"torch=={torch_version} may not be supported in future versions. Please consider upgrading to torch>=2.5.1."
+            )
+        return self
diff --git a/tests/conftest.py b/tests/conftest.py
index b86b714af..4d05d3a26 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -8,11 +8,13 @@ import shutil
import sys
import tempfile
import time
+from pathlib import Path
+import datasets
import pytest
import requests
-from datasets import load_dataset
from huggingface_hub import snapshot_download
+from tokenizers import AddedToken
from transformers import AutoTokenizer
from tests.hf_offline_utils import disable_hf_offline, enable_hf_offline
@@ -48,6 +50,15 @@ def snapshot_download_w_retry(*args, **kwargs):
return snapshot_download(*args, **kwargs)
+@pytest.fixture(scope="session", autouse=True)
+def download_ds_fixture_bundle():
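+    """Fetch the consolidated dataset-fixture bundle once per test session and return its local path."""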
+ ds_dir = snapshot_download_w_retry(
+ "axolotl-ai-internal/axolotl-oss-dataset-fixtures", repo_type="dataset"
+ )
+ return Path(ds_dir)
+
+
@pytest.fixture(scope="session", autouse=True)
def download_smollm2_135m_model():
# download the model
@@ -108,43 +118,43 @@ def download_argilla_distilabel_intel_orca_dpo_dataset():
)
-@pytest.fixture(scope="session", autouse=True)
-def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
- # download the dataset
- snapshot_download_w_retry(
- "argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
- )
+# @pytest.fixture(scope="session", autouse=True)
+# def download_argilla_ultrafeedback_binarized_preferences_cleaned_dataset():
+# # download the dataset
+# snapshot_download_w_retry(
+# "argilla/ultrafeedback-binarized-preferences-cleaned", repo_type="dataset"
+# )
-@pytest.fixture(scope="session", autouse=True)
-def download_fozzie_alpaca_dpo_dataset():
- # download the dataset
- snapshot_download_w_retry(
- "fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
- )
- snapshot_download_w_retry(
- "fozziethebeat/alpaca_messages_2k_dpo_test",
- repo_type="dataset",
- revision="ea82cff",
- )
+# @pytest.fixture(scope="session", autouse=True)
+# def download_fozzie_alpaca_dpo_dataset():
+# # download the dataset
+# snapshot_download_w_retry(
+# "fozziethebeat/alpaca_messages_2k_dpo_test", repo_type="dataset"
+# )
+# snapshot_download_w_retry(
+# "fozziethebeat/alpaca_messages_2k_dpo_test",
+# repo_type="dataset",
+# revision="ea82cff",
+# )
-@pytest.fixture(scope="session")
-@disable_hf_offline
-def dataset_fozzie_alpaca_dpo_dataset(
- download_fozzie_alpaca_dpo_dataset,
-): # pylint: disable=unused-argument,redefined-outer-name
- return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
-
-
-@pytest.fixture(scope="session")
-@disable_hf_offline
-def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
- download_fozzie_alpaca_dpo_dataset,
-): # pylint: disable=unused-argument,redefined-outer-name
- return load_dataset(
- "fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
- )
+# @pytest.fixture(scope="session")
+# @disable_hf_offline
+# def dataset_fozzie_alpaca_dpo_dataset(
+# download_fozzie_alpaca_dpo_dataset,
+# ): # pylint: disable=unused-argument,redefined-outer-name
+# return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
+#
+#
+# @pytest.fixture(scope="session")
+# @disable_hf_offline
+# def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
+# download_fozzie_alpaca_dpo_dataset,
+# ): # pylint: disable=unused-argument,redefined-outer-name
+# return load_dataset(
+# "fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
+# )
@pytest.fixture(scope="session", autouse=True)
@@ -271,7 +281,7 @@ def download_mlx_mistral_7b_model_fixture():
)
-@pytest.fixture(scope="session", autouse=True)
+@pytest.fixture
def download_llama2_model_fixture():
# download the tokenizer only
snapshot_download_w_retry(
@@ -281,7 +291,7 @@ def download_llama2_model_fixture():
)
-@pytest.fixture(scope="session", autouse=True)
+@pytest.fixture
@enable_hf_offline
def tokenizer_huggyllama(
download_huggyllama_model_fixture,
@@ -292,6 +302,58 @@ def tokenizer_huggyllama(
return tokenizer
+@pytest.fixture
+@enable_hf_offline
+def tokenizer_huggyllama_w_special_tokens(
+ tokenizer_huggyllama,
+): # pylint: disable=redefined-outer-name
+ tokenizer_huggyllama.add_special_tokens(
+ {
+            "bos_token": "<s>",
+            "eos_token": "</s>",
+            "unk_token": "<unk>",
+ }
+ )
+
+ return tokenizer_huggyllama
+
+
+@pytest.fixture
+@enable_hf_offline
+def tokenizer_llama2_7b(
+ download_llama2_model_fixture,
+): # pylint: disable=unused-argument,redefined-outer-name
+ tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
+
+ return tokenizer
+
+
+@pytest.fixture
+@enable_hf_offline
+def tokenizer_mistral_7b_instruct(
+ download_mlx_mistral_7b_model_fixture,
+): # pylint: disable=unused-argument,redefined-outer-name
+ return AutoTokenizer.from_pretrained("casperhansen/mistral-7b-instruct-v0.1-awq")
+
+
+@pytest.fixture
+def tokenizer_mistral_7b_instruct_chatml(tokenizer_mistral_7b_instruct):
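+    # chatml-style setup: <|im_end|> replaces the EOS token and <|im_start|> is registered as a regular added token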
+ tokenizer_mistral_7b_instruct.add_special_tokens(
+ {
+ "eos_token": AddedToken(
+ "<|im_end|>", rstrip=False, lstrip=False, normalized=False
+ )
+ }
+ )
+ tokenizer_mistral_7b_instruct.add_tokens(
+ [
+ AddedToken("<|im_start|>", rstrip=False, lstrip=False, normalized=False),
+ ]
+ )
+ return tokenizer_mistral_7b_instruct
+
+
@pytest.fixture
def temp_dir():
# Create a temporary directory
@@ -357,6 +418,62 @@ def cleanup_monkeypatches():
globals().pop(module_global, None)
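+# Fixtures below read datasets from the pre-downloaded bundle; each subdirectory is
+# assumed to be a save_to_disk() export of the corresponding hub dataset.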
+@pytest.fixture
+def dataset_winglian_tiny_shakespeare(
+ download_ds_fixture_bundle: Path,
+): # pylint: disable=redefined-outer-name
+ ds_path = download_ds_fixture_bundle / "winglian__tiny-shakespeare"
+ return datasets.load_from_disk(ds_path)
+
+
+@pytest.fixture
+def dataset_tatsu_lab_alpaca(
+ download_ds_fixture_bundle: Path,
+): # pylint: disable=redefined-outer-name
+ ds_path = download_ds_fixture_bundle / "tatsu-lab__alpaca"
+ return datasets.load_from_disk(ds_path)["train"]
+
+
+@pytest.fixture
+def dataset_mhenrichsen_alpaca_2k_test(
+ download_ds_fixture_bundle: Path,
+): # pylint: disable=redefined-outer-name
+ ds_path = download_ds_fixture_bundle / "mhenrichsen__alpaca_2k_test"
+ return datasets.load_from_disk(ds_path)["train"]
+
+
+@pytest.fixture
+def dataset_argilla_ultrafeedback_binarized_preferences_cleaned(
+ download_ds_fixture_bundle: Path,
+): # pylint: disable=redefined-outer-name
+ ds_path = (
+ download_ds_fixture_bundle
+ / "argilla__ultrafeedback-binarized-preferences-cleaned"
+ )
+ return datasets.load_from_disk(ds_path)["train"]
+
+
+@pytest.fixture
+def dataset_fozziethebeat_alpaca_messages_2k_dpo_test(
+ download_ds_fixture_bundle: Path,
+): # pylint: disable=redefined-outer-name
+ ds_path = download_ds_fixture_bundle / "fozziethebeat__alpaca_messages_2k_dpo_test"
+ return datasets.load_from_disk(ds_path)["train"]
+
+
+@pytest.fixture
+def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
+ download_ds_fixture_bundle: Path,
+): # pylint: disable=redefined-outer-name
+ ds_path = (
+ download_ds_fixture_bundle
+ / "fozziethebeat__alpaca_messages_2k_dpo_test__rev_ea82cff"
+ )
+ return datasets.load_from_disk(ds_path)["train"]
+
+
# # pylint: disable=redefined-outer-name,unused-argument
# def test_load_fixtures(
# download_smollm2_135m_model,
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index a82f2f381..ded82869f 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -324,7 +324,7 @@ class TestDatasetPreparation:
@enable_hf_offline
def test_load_hub_with_revision_with_dpo(
- self, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
+ self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
):
"""Verify that processing dpo data from the hub works with a specific revision"""
@@ -339,12 +339,10 @@ class TestDatasetPreparation:
)
# pylint: disable=duplicate-code
- with patch(
- "axolotl.utils.data.shared.load_dataset_w_config"
- ) as mock_load_dataset:
+ with patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset:
# Set up the mock to return different values on successive calls
mock_load_dataset.return_value = (
- dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff
+ dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
)
train_dataset, _ = load_prepare_preference_datasets(cfg)
@@ -354,7 +352,9 @@ class TestDatasetPreparation:
@enable_hf_offline
@pytest.mark.skip("datasets bug with local datasets when offline")
- def test_load_local_hub_with_revision(self, tokenizer):
+ def test_load_local_hub_with_revision(
+ self, dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff, tokenizer
+ ):
"""Verify that a local copy of a hub dataset can be loaded with a specific revision"""
with tempfile.TemporaryDirectory() as tmp_dir:
tmp_ds_path = Path(tmp_dir) / "mhenrichsen/alpaca_2k_test"
@@ -386,13 +386,23 @@ class TestDatasetPreparation:
}
)
- dataset, _ = load_tokenized_prepared_datasets(tokenizer, cfg, prepared_path)
+ with patch(
+ "axolotl.utils.data.shared.load_dataset_w_config"
+ ) as mock_load_dataset:
+ # Set up the mock to return different values on successive calls
+ mock_load_dataset.return_value = (
+ dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff
+ )
- assert len(dataset) == 2000
- assert "input_ids" in dataset.features
- assert "attention_mask" in dataset.features
- assert "labels" in dataset.features
- shutil.rmtree(tmp_ds_path)
+ dataset, _ = load_tokenized_prepared_datasets(
+ tokenizer, cfg, prepared_path
+ )
+
+ assert len(dataset) == 2000
+ assert "input_ids" in dataset.features
+ assert "attention_mask" in dataset.features
+ assert "labels" in dataset.features
+ shutil.rmtree(tmp_ds_path)
@enable_hf_offline
def test_loading_local_dataset_folder(self, tokenizer):
diff --git a/tests/test_exact_deduplication.py b/tests/test_exact_deduplication.py
index 7430352a4..a75f97f78 100644
--- a/tests/test_exact_deduplication.py
+++ b/tests/test_exact_deduplication.py
@@ -238,21 +238,22 @@ class TestDeduplicateRLDataset:
@enable_hf_offline
def test_load_with_deduplication(
- self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
+ self,
+ cfg,
+ dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
+ tokenizer_huggyllama,
):
"""Verify that loading with deduplication removes duplicates."""
# pylint: disable=duplicate-code
with (
- patch(
- "axolotl.utils.data.shared.load_dataset_w_config"
- ) as mock_load_dataset,
+ patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
mock_load_dataset.side_effect = [
- dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
- dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
+ dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
+ dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
]
mock_load_tokenizer.return_value = tokenizer_huggyllama
@@ -263,19 +264,20 @@ class TestDeduplicateRLDataset:
@enable_hf_offline
def test_load_without_deduplication(
- self, cfg, dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff, tokenizer_huggyllama
+ self,
+ cfg,
+ dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
+ tokenizer_huggyllama,
):
# pylint: disable=duplicate-code
with (
- patch(
- "axolotl.utils.data.shared.load_dataset_w_config"
- ) as mock_load_dataset,
+ patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
):
# Set up the mock to return different values on successive calls
mock_load_dataset.side_effect = [
- dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
- dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff,
+ dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
+ dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
]
mock_load_tokenizer.return_value = tokenizer_huggyllama
diff --git a/tests/test_packed_batch_sampler.py b/tests/test_packed_batch_sampler.py
index faba86931..dd0386e58 100644
--- a/tests/test_packed_batch_sampler.py
+++ b/tests/test_packed_batch_sampler.py
@@ -1,7 +1,7 @@
"""Module for testing streaming dataset sequence packing"""
import pytest
-from datasets import concatenate_datasets, load_dataset
+from datasets import concatenate_datasets
from torch.utils.data import DataLoader, RandomSampler
from transformers import AutoTokenizer
@@ -27,7 +27,6 @@ class TestBatchedSamplerPacking:
Test class for packing streaming dataset sequences
"""
- @pytest.mark.skip(reason="TODO: fix hf offline mode for CI rate limits")
@pytest.mark.parametrize(
"batch_size, num_workers",
[
@@ -41,14 +40,18 @@
@pytest.mark.parametrize("sequential", [True, False])
@enable_hf_offline
def test_packing(
- self, batch_size, num_workers, tokenizer, max_seq_length, sequential
+ self,
+ dataset_winglian_tiny_shakespeare,
+ batch_size,
+ num_workers,
+ tokenizer,
+ max_seq_length,
+ sequential,
):
import axolotl.monkeypatch.data.batch_dataset_fetcher # pylint: disable=unused-import # noqa: F401
- dataset = load_dataset(
- "winglian/tiny-shakespeare",
- split="train",
- )
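+        # the bundle fixture yields a DatasetDict, so select the train split explicitly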
+ dataset = dataset_winglian_tiny_shakespeare["train"]
cfg = DictDefault(
{
@@ -58,7 +60,7 @@ class TestBatchedSamplerPacking:
)
ds_cfg = DictDefault(
{
- "field": "Text",
+ "field": "text",
}
)
completion_strategy = load(tokenizer, cfg, ds_cfg)
diff --git a/tests/test_prompt_tokenizers.py b/tests/test_prompt_tokenizers.py
index 65eee7ddb..3f16bc917 100644
--- a/tests/test_prompt_tokenizers.py
+++ b/tests/test_prompt_tokenizers.py
@@ -2,13 +2,8 @@
import json
import logging
-import unittest
from pathlib import Path
-import pytest
-from datasets import load_dataset
-from transformers import AddedToken, AutoTokenizer, LlamaTokenizer
-
from axolotl.prompt_strategies.alpaca_chat import NoSystemPrompter
from axolotl.prompt_strategies.alpaca_w_system import (
InstructionWSystemPromptTokenizingStrategy,
@@ -61,24 +56,13 @@ test_data = {
}
-class TestPromptTokenizationStrategies(unittest.TestCase):
+class TestPromptTokenizationStrategies:
"""
Test class for prompt tokenization strategies.
"""
@enable_hf_offline
- def setUp(self) -> None:
- # pylint: disable=duplicate-code
- self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
- self.tokenizer.add_special_tokens(
- {
-                "bos_token": "<s>",
-                "eos_token": "</s>",
-                "unk_token": "<unk>",
- }
- )
-
- def test_no_sys_prompt(self):
+ def test_no_sys_prompt(self, tokenizer_huggyllama_w_special_tokens):
"""
tests the interface between the user and assistant parts
"""
@@ -86,7 +70,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
# pylint: disable=duplicate-code
strat = AlpacaPromptTokenizingStrategy(
prompter,
- self.tokenizer,
+ tokenizer_huggyllama_w_special_tokens,
False,
2048,
)
@@ -99,7 +83,8 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
assert example["labels"][world_idx] == 3186
assert example["labels"][world_idx - 1] == -100
- def test_alpaca(self):
+ @enable_hf_offline
+ def test_alpaca(self, tokenizer_huggyllama_w_special_tokens):
"""
tests the interface between the user and assistant parts
"""
@@ -107,7 +92,7 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
prompter = AlpacaPrompter()
strat = AlpacaPromptTokenizingStrategy(
prompter,
- self.tokenizer,
+ tokenizer_huggyllama_w_special_tokens,
False,
2048,
)
@@ -118,28 +103,17 @@ class TestPromptTokenizationStrategies(unittest.TestCase):
assert example["labels"][world_idx - 1] == -100
-class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
+class TestInstructionWSystemPromptTokenizingStrategy:
"""
Test class for prompt tokenization strategies with sys prompt from the dataset
"""
@enable_hf_offline
- def setUp(self) -> None:
- # pylint: disable=duplicate-code
- self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
- self.tokenizer.add_special_tokens(
- {
-                "bos_token": "<s>",
-                "eos_token": "</s>",
-                "unk_token": "<unk>",
- }
- )
-
- def test_system_alpaca(self):
+ def test_system_alpaca(self, tokenizer_huggyllama_w_special_tokens):
prompter = SystemDataPrompter(PromptStyle.CHAT.value)
strat = InstructionWSystemPromptTokenizingStrategy(
prompter,
- self.tokenizer,
+ tokenizer_huggyllama_w_special_tokens,
False,
2048,
)
@@ -160,18 +134,13 @@ class InstructionWSystemPromptTokenizingStrategyTest(unittest.TestCase):
assert example["input_ids"][8] == 11889 # USER
-class Llama2ChatTokenizationTest(unittest.TestCase):
+class TestLlama2ChatTokenization:
"""
Test class for prompt tokenization strategies with sys prompt from the dataset
"""
@enable_hf_offline
- def setUp(self) -> None:
- # pylint: disable=duplicate-code
- self.tokenizer = LlamaTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
- # woraround because official Meta repos are not open
-
- def test_llama2_chat_integration(self):
+ def test_llama2_chat_integration(self, tokenizer_llama2_7b):
with open(
Path(__file__).parent / "fixtures/conversation.json", encoding="utf-8"
) as fin:
@@ -186,16 +155,18 @@ class Llama2ChatTokenizationTest(unittest.TestCase):
prompter = Llama2ChatPrompter()
strat = LLama2ChatTokenizingStrategy(
prompter,
- self.tokenizer,
+ tokenizer_llama2_7b,
False,
4096,
)
example = strat.tokenize_prompt(conversation)
for fields in ["input_ids", "attention_mask", "labels"]:
- self.assertEqual(len(example[fields]), len(tokenized_conversation[fields]))
- self.assertEqual(example[fields], tokenized_conversation[fields])
+            # pytest assert equals
+            assert len(example[fields]) == len(tokenized_conversation[fields])
+            assert example[fields] == tokenized_conversation[fields]
+
-    def compare_with_transformers_integration(self):
+    def compare_with_transformers_integration(self, tokenizer_llama2_7b):
# this needs transformers >= v4.31.0
from transformers.models.llama.tokenization_llama import B_SYS, E_SYS
from transformers.pipelines.conversational import Conversation
@@ -234,49 +205,27 @@ If a question does not make any sense, or is not factually coherent, explain why
generated_responses=answers,
)
# pylint: disable=W0212
- hf_tokens = self.tokenizer._build_conversation_input_ids(hf_conf)
+ hf_tokens = tokenizer_llama2_7b._build_conversation_input_ids(hf_conf)
- self.assertEqual(
- hf_tokens, tokenized_conversation["input_ids"][: len(hf_tokens)]
- )
+ assert hf_tokens == tokenized_conversation["input_ids"][: len(hf_tokens)]
-class OrpoTokenizationTest(unittest.TestCase):
+class TestOrpoTokenization:
"""test case for the ORPO tokenization"""
@enable_hf_offline
- def setUp(self) -> None:
- # pylint: disable=duplicate-code
- tokenizer = LlamaTokenizer.from_pretrained(
- "casperhansen/mistral-7b-instruct-v0.1-awq"
- )
- tokenizer.add_special_tokens(
- {
- "eos_token": AddedToken(
- "<|im_end|>", rstrip=False, lstrip=False, normalized=False
- )
- }
- )
- tokenizer.add_tokens(
- [
- AddedToken(
- "<|im_start|>", rstrip=False, lstrip=False, normalized=False
- ),
- ]
- )
- self.tokenizer = tokenizer
- self.dataset = load_dataset(
- "argilla/ultrafeedback-binarized-preferences-cleaned", split="train"
- ).select([0])
-
- @pytest.mark.skip(reason="TODO: fix hf hub offline to work with HF rate limits")
- def test_orpo_integration(self):
+ def test_orpo_integration(
+ self,
+ tokenizer_mistral_7b_instruct_chatml,
+ dataset_argilla_ultrafeedback_binarized_preferences_cleaned,
+ ):
+ ds = dataset_argilla_ultrafeedback_binarized_preferences_cleaned.select([0])
strat = load(
- self.tokenizer,
+ tokenizer_mistral_7b_instruct_chatml,
DictDefault({"train_on_inputs": False}),
DictDefault({"chat_template": "chatml"}),
)
- res = strat.tokenize_prompt(self.dataset[0])
+ res = strat.tokenize_prompt(ds[0])
assert "rejected_input_ids" in res
assert "rejected_labels" in res
assert "input_ids" in res
@@ -295,7 +244,3 @@ class OrpoTokenizationTest(unittest.TestCase):
assert res["prompt_attention_mask"][0] == 1
assert res["prompt_attention_mask"][-1] == 0
-
-
-if __name__ == "__main__":
- unittest.main()