Merge branch 'main' into 775-option-to-drop-vs-truncate-on-rows-longer-than-context-length
This commit is contained in:
@@ -1,13 +1,12 @@
|
||||
"""
|
||||
unit tests for axolotl.core.trainer_builder
|
||||
"""
|
||||
"""Unit tests for axolotl.core.trainer_builder"""
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.core.trainer_builder import HFRLTrainerBuilder
|
||||
from axolotl.loaders import ModelLoader, load_tokenizer
|
||||
from axolotl.utils.config import normalize_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
from axolotl.utils.schemas.enums import RLType
|
||||
|
||||
|
||||
@pytest.fixture(name="cfg")
|
||||
@@ -49,7 +48,7 @@ def fixture_tokenizer(cfg):
|
||||
|
||||
@pytest.fixture(name="model")
|
||||
def fixture_model(cfg, tokenizer):
|
||||
return load_model(cfg, tokenizer)
|
||||
return ModelLoader(cfg, tokenizer).load()
|
||||
|
||||
|
||||
class TestHFRLTrainerBuilder:
|
||||
@@ -65,3 +64,27 @@ class TestHFRLTrainerBuilder:
|
||||
assert training_arguments.adam_epsilon == 0.00001
|
||||
assert training_arguments.dataloader_num_workers == 1
|
||||
assert training_arguments.dataloader_pin_memory is True
|
||||
|
||||
|
||||
class TestTrainerClsPlugin:
|
||||
"""
|
||||
TestCase class for trainer builder with plugin
|
||||
"""
|
||||
|
||||
def test_trainer_cls_is_not_none_with_plugin(self, cfg, model, tokenizer):
|
||||
"""
|
||||
Test that the trainer cls is not none with plugin
|
||||
|
||||
Fixes #2693
|
||||
"""
|
||||
cfg.plugins = ["axolotl.integrations.liger.LigerPlugin"]
|
||||
cfg.rl = RLType.KTO
|
||||
|
||||
# Expected AttributeError as we don't pass regular model configs to RL trainer builder
|
||||
# If it throws `TypeError: None is not a callable object`, trainer_cls could be None
|
||||
with pytest.raises(
|
||||
AttributeError, match=r".*'tuple' object has no attribute 'config'.*"
|
||||
):
|
||||
builder = HFRLTrainerBuilder(cfg, model, tokenizer)
|
||||
|
||||
builder.build(100)
|
||||
|
||||
@@ -166,7 +166,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"""
|
||||
)
|
||||
|
||||
@pytest.mark.skip(reason="flaky test")
|
||||
@pytest.mark.parametrize(
|
||||
"num_gpus",
|
||||
[1, 2],
|
||||
@@ -231,8 +230,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"NCCL_P2P_LEVEL": "LOC",
|
||||
**current_env,
|
||||
"CUDA_VISIBLE_DEVICES": "1",
|
||||
"VLLM_DISABLE_COMPILE_CACHE": "1",
|
||||
# "VLLM_USE_V1": "0",
|
||||
}
|
||||
vllm_process = start_vllm(
|
||||
cfg.base_model,
|
||||
@@ -266,7 +263,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
finally:
|
||||
recursive_kill(vllm_process)
|
||||
|
||||
@pytest.mark.skip(reason="flaky test")
|
||||
@pytest.mark.parametrize(
|
||||
"num_gpus",
|
||||
[1, 2],
|
||||
@@ -325,8 +321,6 @@ def oai_gsm8k_transform(cfg, *args, **kwargs):
|
||||
"NCCL_P2P_LEVEL": "LOC", # nccl can be brittle, assume P2P isn't reliable
|
||||
**current_env,
|
||||
"CUDA_VISIBLE_DEVICES": "1",
|
||||
"VLLM_DISABLE_COMPILE_CACHE": "1",
|
||||
# "VLLM_USE_V1": "0",
|
||||
}
|
||||
vllm_process = start_vllm(
|
||||
cfg.base_model,
|
||||
|
||||
@@ -6,9 +6,9 @@ import unittest
|
||||
|
||||
import transformers
|
||||
|
||||
from axolotl.loaders import ModelLoader, load_tokenizer
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
|
||||
from ..utils import with_temp_dir
|
||||
|
||||
@@ -50,7 +50,7 @@ class TestModelPatches(unittest.TestCase):
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
load_model(cfg, tokenizer, inference=False)
|
||||
ModelLoader(cfg, tokenizer, inference=False).load()
|
||||
|
||||
@with_temp_dir
|
||||
def test_mistral_multipack(self, temp_dir):
|
||||
@@ -83,7 +83,7 @@ class TestModelPatches(unittest.TestCase):
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
load_model(cfg, tokenizer, inference=False)
|
||||
ModelLoader(cfg, tokenizer, inference=False).load()
|
||||
|
||||
assert (
|
||||
"torch.jit"
|
||||
|
||||
@@ -10,7 +10,7 @@ import pytest
|
||||
import torch
|
||||
from accelerate.state import PartialState
|
||||
|
||||
from axolotl.monkeypatch.attention.ring_attn import (
|
||||
from axolotl.monkeypatch.ring_attn import (
|
||||
get_ring_attn_group,
|
||||
register_ring_attn,
|
||||
set_ring_attn_group,
|
||||
@@ -84,16 +84,16 @@ class TestRingAttention:
|
||||
def test_get_ring_attn_group_no_registration(
|
||||
self, mock_world_size, mock_rank, partial_state
|
||||
):
|
||||
"""Test that get_ring_attn_group returns None when no group has been registered."""
|
||||
"""Test that get_ring_attn_group raises RuntimeError when no group has been registered."""
|
||||
# Setup mocks
|
||||
mock_world_size.return_value = 4
|
||||
mock_rank.return_value = 0
|
||||
|
||||
# Get the group without registration
|
||||
group = get_ring_attn_group()
|
||||
|
||||
# Verify that None was returned
|
||||
assert group is None
|
||||
# Verify that RuntimeError is raised when no group is registered
|
||||
with pytest.raises(
|
||||
RuntimeError, match="register_ring_attn\\(\\) not yet called"
|
||||
):
|
||||
get_ring_attn_group()
|
||||
|
||||
@patch("torch.distributed.new_group")
|
||||
@patch("torch.distributed.get_rank")
|
||||
@@ -313,18 +313,21 @@ class TestApplySequenceParallelism:
|
||||
|
||||
# Mock the process group
|
||||
monkeypatch.setattr(
|
||||
"axolotl.monkeypatch.attention.ring_attn.get_ring_attn_group",
|
||||
"axolotl.monkeypatch.ring_attn.get_ring_attn_group",
|
||||
MagicMock,
|
||||
)
|
||||
|
||||
# Mock update_ring_attn_params
|
||||
monkeypatch.setattr(
|
||||
"axolotl.monkeypatch.attention.ring_attn.update_ring_attn_params",
|
||||
"axolotl.monkeypatch.ring_attn.update_ring_attn_params",
|
||||
lambda **kwargs: None,
|
||||
)
|
||||
|
||||
def test_world_size_one(self, sequence_parallel_batch):
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_world_size_one(self, mock_get_ring_attn_group, sequence_parallel_batch):
|
||||
"""Test that function returns original batch when world size is 1."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
result, _, _ = apply_sequence_parallelism(
|
||||
batch=sequence_parallel_batch,
|
||||
local_rank=0,
|
||||
@@ -336,8 +339,11 @@ class TestApplySequenceParallelism:
|
||||
# Should return the original batch unchanged
|
||||
assert result == sequence_parallel_batch
|
||||
|
||||
def test_batch_ring_rank0(self, sequence_parallel_batch):
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_batch_ring_rank0(self, mock_get_ring_attn_group, sequence_parallel_batch):
|
||||
"""Test BATCH_RING sharding for rank 0 in a 2-process group."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
batch = sequence_parallel_batch
|
||||
seq_len = batch["input_ids"].size(1)
|
||||
|
||||
@@ -359,8 +365,11 @@ class TestApplySequenceParallelism:
|
||||
result["position_ids"], batch["position_ids"][:, : seq_len // 2]
|
||||
)
|
||||
|
||||
def test_batch_ring_rank1(self, sequence_parallel_batch):
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_batch_ring_rank1(self, mock_get_ring_attn_group, sequence_parallel_batch):
|
||||
"""Test BATCH_RING sharding for rank 1 in a 2-process group."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
batch = sequence_parallel_batch
|
||||
seq_len = batch["input_ids"].size(1)
|
||||
original_input_ids = batch["input_ids"].clone()
|
||||
@@ -419,8 +428,13 @@ class TestApplySequenceParallelism:
|
||||
# assert torch.equal(result_rank0["input_ids"], rank0_expected)
|
||||
# assert torch.equal(result_rank1["input_ids"], rank1_expected)
|
||||
|
||||
def test_partial_application(self, sequence_parallel_batch):
|
||||
@patch("axolotl.monkeypatch.ring_attn.patch.get_ring_attn_group")
|
||||
def test_partial_application(
|
||||
self, mock_get_ring_attn_group, sequence_parallel_batch
|
||||
):
|
||||
"""Test that we can create a partially applied version of the function."""
|
||||
mock_get_ring_attn_group.return_value = 0
|
||||
|
||||
batch = sequence_parallel_batch
|
||||
original_input_ids = batch["input_ids"].clone()
|
||||
|
||||
|
||||
@@ -6,8 +6,8 @@ import tempfile
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from axolotl.loaders import ModelLoader, load_tokenizer
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import ModelLoader, load_model, load_tokenizer
|
||||
|
||||
|
||||
@pytest.fixture(name="temp_dir")
|
||||
@@ -58,6 +58,8 @@ class TestLoadModelUtils:
|
||||
ModelLoader(
|
||||
cfg=self.cfg,
|
||||
tokenizer="",
|
||||
inference=False,
|
||||
reference_model=True,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -71,13 +73,8 @@ class TestLoadModelUtils:
|
||||
):
|
||||
self.cfg.output_dir = temp_dir
|
||||
self.model_loader.tokenizer = load_tokenizer(self.cfg) # pylint: disable=all
|
||||
self.model_loader.model, _ = load_model(
|
||||
self.cfg,
|
||||
self.model_loader.tokenizer,
|
||||
inference=False,
|
||||
reference_model=True,
|
||||
)
|
||||
self.model_loader.convert_embedding_modules_dtype(
|
||||
self.model_loader.load()
|
||||
self.model_loader._convert_embedding_modules_dtype(
|
||||
embedding_modules, dist_dtype, before_kbit_train_or_finetune
|
||||
)
|
||||
for name, module in self.model_loader.model.named_modules():
|
||||
|
||||
@@ -9,11 +9,11 @@ from typing import Optional
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from axolotl.loaders.utils import check_model_config
|
||||
from axolotl.utils import is_comet_available
|
||||
from axolotl.utils.config import validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.mlflow_ import setup_mlflow_env_vars
|
||||
from axolotl.utils.models import check_model_config
|
||||
from axolotl.utils.schemas.config import AxolotlConfigWCapabilities
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
@@ -1215,6 +1215,20 @@ class TestValidation(BaseValidation):
|
||||
cfg, capabilities=capabilities, env_capabilities=env_capabilities
|
||||
)
|
||||
|
||||
def test_cfg_throws_error_with_s2_attention_and_sample_packing(self, minimal_cfg):
|
||||
test_cfg = DictDefault(
|
||||
{
|
||||
"s2_attention": True,
|
||||
"sample_packing": True,
|
||||
}
|
||||
| minimal_cfg
|
||||
)
|
||||
with pytest.raises(
|
||||
ValidationError,
|
||||
match=r".*shifted-sparse attention does not currently support sample packing*",
|
||||
):
|
||||
validate_config(test_cfg)
|
||||
|
||||
|
||||
class TestTorchCompileValidation(BaseValidation):
|
||||
"""
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
"""
|
||||
Test suite for functions in the axolotl.utils.data.utils module, focusing on the deduplicate_and_log_datasets function.
|
||||
"""Test suite for functions in the `axolotl.utils.data.utils` module, focusing on the
|
||||
`deduplicate_and_log_datasets` function.
|
||||
|
||||
Additionally, this test suite includes tests for functions that indirectly call deduplicate_and_log_datasets during the execution of the preprocess command.
|
||||
Additionally, this test suite includes tests for functions that indirectly call
|
||||
`deduplicate_and_log_datasets` during the execution of the preprocess command.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
@@ -11,20 +12,19 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
|
||||
from axolotl.loaders import load_processor, load_tokenizer
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.data import prepare_dataset
|
||||
from axolotl.utils.data.rl import load_prepare_preference_datasets
|
||||
from axolotl.utils.data.utils import deduplicate_and_log_datasets
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_processor, load_tokenizer
|
||||
|
||||
from tests.constants import ALPACA_MESSAGES_CONFIG_REVISION
|
||||
from tests.hf_offline_utils import enable_hf_offline
|
||||
|
||||
|
||||
def verify_deduplication(actual_dataset, expected_dataset, dataset_name):
|
||||
"""
|
||||
Validates deduplication results and size consistency.
|
||||
"""Validates deduplication results and size consistency.
|
||||
|
||||
Parameters:
|
||||
- actual_dataset: Deduplicated dataset.
|
||||
@@ -49,9 +49,7 @@ def verify_deduplication(actual_dataset, expected_dataset, dataset_name):
|
||||
|
||||
|
||||
class TestDeduplicateIndividualFunctions(unittest.TestCase):
|
||||
"""
|
||||
test class for deduplication function in data utils
|
||||
"""
|
||||
"""Test class for deduplication function in data utils"""
|
||||
|
||||
def setUp(self):
|
||||
# Sample data with duplicates
|
||||
@@ -248,7 +246,7 @@ class TestDeduplicateRLDataset:
|
||||
# pylint: disable=duplicate-code
|
||||
with (
|
||||
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
|
||||
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
|
||||
patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
|
||||
):
|
||||
# Set up the mock to return different values on successive calls
|
||||
mock_load_dataset.side_effect = [
|
||||
@@ -272,7 +270,7 @@ class TestDeduplicateRLDataset:
|
||||
# pylint: disable=duplicate-code
|
||||
with (
|
||||
patch("axolotl.utils.data.rl.load_dataset_w_config") as mock_load_dataset,
|
||||
patch("axolotl.utils.models.load_tokenizer") as mock_load_tokenizer,
|
||||
patch("axolotl.loaders.load_tokenizer") as mock_load_tokenizer,
|
||||
):
|
||||
# Set up the mock to return different values on successive calls
|
||||
mock_load_dataset.side_effect = [
|
||||
@@ -411,7 +409,7 @@ class TestDeduplicateNonRL(unittest.TestCase):
|
||||
|
||||
|
||||
class TestWrongCollisions(unittest.TestCase):
|
||||
"""Creating mock datasets for testing wrong collisions"""
|
||||
"""Creating mock datasets for testing wrong collisions."""
|
||||
|
||||
def setUp(self):
|
||||
self.train_data = {"text": ["sample 5", "sample 6"], "label": [1, 2]}
|
||||
|
||||
@@ -1,18 +1,18 @@
|
||||
"""Module for testing models utils file."""
|
||||
"""Module for `axolotl.loaders`."""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from transformers import BitsAndBytesConfig, PreTrainedTokenizerBase
|
||||
from transformers.integrations.deepspeed import is_deepspeed_zero3_enabled
|
||||
from transformers.utils.import_utils import is_torch_mps_available
|
||||
|
||||
from axolotl.loaders import ModelLoader
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import ModelLoader, load_model
|
||||
|
||||
|
||||
class TestModelsUtils:
|
||||
"""Testing module for models utils."""
|
||||
"""Testing module for `axolotl.loaders`."""
|
||||
|
||||
def setup_method(self) -> None:
|
||||
# load config
|
||||
@@ -50,7 +50,8 @@ class TestModelsUtils:
|
||||
device_map = self.cfg.device_map
|
||||
if is_torch_mps_available():
|
||||
device_map = "mps"
|
||||
self.model_loader.set_device_map_config()
|
||||
# pylint: disable=protected-access
|
||||
self.model_loader._set_device_map_config()
|
||||
if is_deepspeed_zero3_enabled():
|
||||
assert "device_map" not in self.model_loader.model_kwargs
|
||||
else:
|
||||
@@ -59,29 +60,6 @@ class TestModelsUtils:
|
||||
# check torch_dtype
|
||||
assert self.cfg.torch_dtype == self.model_loader.model_kwargs["torch_dtype"]
|
||||
|
||||
def test_cfg_throws_error_with_s2_attention_and_sample_packing(self):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"s2_attention": True,
|
||||
"sample_packing": True,
|
||||
"base_model": "",
|
||||
"model_type": "AutoModelForCausalLM",
|
||||
}
|
||||
)
|
||||
|
||||
# Mock out call to HF hub
|
||||
with patch(
|
||||
"axolotl.utils.models.load_model_config"
|
||||
) as mocked_load_model_config:
|
||||
mocked_load_model_config.return_value = {}
|
||||
with pytest.raises(ValueError) as exc:
|
||||
# Should error before hitting tokenizer, so we pass in an empty str
|
||||
load_model(cfg, tokenizer="") # type: ignore
|
||||
assert (
|
||||
"shifted-sparse attention does not currently support sample packing"
|
||||
in str(exc.value)
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("adapter", ["lora", "qlora", None])
|
||||
@pytest.mark.parametrize("load_in_8bit", [True, False])
|
||||
@pytest.mark.parametrize("load_in_4bit", [True, False])
|
||||
@@ -99,7 +77,8 @@ class TestModelsUtils:
|
||||
self.cfg.gptq = gptq
|
||||
self.cfg.adapter = adapter
|
||||
|
||||
self.model_loader.set_quantization_config()
|
||||
# pylint: disable=protected-access
|
||||
self.model_loader._set_quantization_config()
|
||||
if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq:
|
||||
assert not (
|
||||
hasattr(self.model_loader.model_kwargs, "load_in_8bit")
|
||||
@@ -2,9 +2,9 @@
|
||||
tests for loading loras
|
||||
"""
|
||||
|
||||
from axolotl.loaders import ModelLoader, load_tokenizer
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_model, load_tokenizer
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
minimal_config = DictDefault(
|
||||
@@ -46,7 +46,7 @@ class TestLoRALoad:
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
load_model(cfg, tokenizer)
|
||||
ModelLoader(cfg, tokenizer).load()
|
||||
|
||||
def test_load_lora_weights_empty_dropout(self):
|
||||
cfg = DictDefault(
|
||||
@@ -67,4 +67,4 @@ class TestLoRALoad:
|
||||
normalize_config(cfg)
|
||||
assert cfg.lora_dropout == 0.0
|
||||
tokenizer = load_tokenizer(cfg)
|
||||
load_model(cfg, tokenizer)
|
||||
ModelLoader(cfg, tokenizer).load()
|
||||
|
||||
@@ -6,8 +6,8 @@ import unittest
|
||||
|
||||
import pytest
|
||||
|
||||
from axolotl.loaders import load_tokenizer
|
||||
from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import load_tokenizer
|
||||
|
||||
from tests.hf_offline_utils import enable_hf_offline
|
||||
|
||||
|
||||
Reference in New Issue
Block a user