Add ruff, remove black, isort, flake8, pylint (#3092)
* black, isort, flake8 -> ruff * remove unused * add back needed import * fix
This commit is contained in:
@@ -1,7 +1,5 @@
|
||||
"""Tests for evaluate CLI command."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.cli.main import cli
|
||||
@@ -31,7 +29,6 @@ class TestEvaluateCommand(BaseCliTest):
|
||||
config_path = tmp_path / "config.yml"
|
||||
config_path.write_text(valid_test_config)
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate:
|
||||
result = cli_runner.invoke(
|
||||
cli,
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""pytest tests for axolotl CLI inference command."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.cli.main import cli
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from axolotl.cli.main import cli
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Tests for train CLI command."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from axolotl.cli.main import cli
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""pytest tests for axolotl CLI utils."""
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
import json
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
@@ -25,7 +23,7 @@ MOCK_TREE_RESPONSE = {
|
||||
def mock_responses():
|
||||
"""Mock responses for API and file downloads"""
|
||||
|
||||
def mock_get(url, timeout=None): # pylint: disable=unused-argument
|
||||
def mock_get(url, timeout=None):
|
||||
response = Mock()
|
||||
if "api.github.com" in url:
|
||||
response.text = json.dumps(MOCK_TREE_RESPONSE)
|
||||
@@ -93,21 +91,21 @@ def assert_launcher_args_in_command(
|
||||
called_cmd = mock_subprocess_call.call_args.args[0]
|
||||
|
||||
# Verify launcher
|
||||
assert (
|
||||
called_cmd[0] == launcher
|
||||
), f"Expected launcher {launcher}, got {called_cmd[0]}"
|
||||
assert called_cmd[0] == launcher, (
|
||||
f"Expected launcher {launcher}, got {called_cmd[0]}"
|
||||
)
|
||||
|
||||
# Verify launcher args are present
|
||||
for arg in expected_launcher_args:
|
||||
assert (
|
||||
arg in called_cmd
|
||||
), f"Expected launcher arg '{arg}' not found in command: {called_cmd}"
|
||||
assert arg in called_cmd, (
|
||||
f"Expected launcher arg '{arg}' not found in command: {called_cmd}"
|
||||
)
|
||||
|
||||
# Verify module is present
|
||||
assert "-m" in called_cmd, "Expected -m flag for module execution"
|
||||
assert (
|
||||
command_module in called_cmd
|
||||
), f"Expected module {command_module} not found in command: {called_cmd}"
|
||||
assert command_module in called_cmd, (
|
||||
f"Expected module {command_module} not found in command: {called_cmd}"
|
||||
)
|
||||
|
||||
|
||||
def assert_no_launcher_args_contamination(mock_subprocess_call, launcher: str):
|
||||
@@ -126,17 +124,17 @@ def assert_no_launcher_args_contamination(mock_subprocess_call, launcher: str):
|
||||
launch_idx = called_cmd.index("launch")
|
||||
m_idx = called_cmd.index("-m")
|
||||
launcher_section = called_cmd[launch_idx + 1 : m_idx]
|
||||
assert (
|
||||
len(launcher_section) == 0
|
||||
), f"Unexpected launcher args found: {launcher_section}"
|
||||
assert len(launcher_section) == 0, (
|
||||
f"Unexpected launcher args found: {launcher_section}"
|
||||
)
|
||||
elif launcher == "torchrun":
|
||||
# For torchrun, launcher args should be between 'torchrun' and '-m'
|
||||
torchrun_idx = called_cmd.index("torchrun")
|
||||
m_idx = called_cmd.index("-m")
|
||||
launcher_section = called_cmd[torchrun_idx + 1 : m_idx]
|
||||
assert (
|
||||
len(launcher_section) == 0
|
||||
), f"Unexpected launcher args found: {launcher_section}"
|
||||
assert len(launcher_section) == 0, (
|
||||
f"Unexpected launcher args found: {launcher_section}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -33,10 +33,9 @@ logging.getLogger("filelock").setLevel(logging.CRITICAL)
|
||||
|
||||
|
||||
def retry_on_request_exceptions(max_retries=3, delay=1):
|
||||
# pylint: disable=duplicate-code
|
||||
def decorator(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
|
||||
def wrapper(*args, **kwargs):
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
@@ -171,7 +170,7 @@ def download_argilla_distilabel_intel_orca_dpo_dataset():
|
||||
# @disable_hf_offline
|
||||
# def dataset_fozzie_alpaca_dpo_dataset(
|
||||
# download_fozzie_alpaca_dpo_dataset,
|
||||
# ): # pylint: disable=unused-argument,redefined-outer-name
|
||||
# ):
|
||||
# return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
|
||||
#
|
||||
#
|
||||
@@ -179,7 +178,7 @@ def download_argilla_distilabel_intel_orca_dpo_dataset():
|
||||
# @disable_hf_offline
|
||||
# def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
|
||||
# download_fozzie_alpaca_dpo_dataset,
|
||||
# ): # pylint: disable=unused-argument,redefined-outer-name
|
||||
# ):
|
||||
# return load_dataset(
|
||||
# "fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
|
||||
# )
|
||||
@@ -359,7 +358,7 @@ def download_llama32_1b_model_fixture():
|
||||
@enable_hf_offline
|
||||
def tokenizer_huggyllama(
|
||||
download_huggyllama_model_fixture,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
tokenizer.pad_token = "</s>"
|
||||
|
||||
@@ -370,7 +369,7 @@ def tokenizer_huggyllama(
|
||||
@enable_hf_offline
|
||||
def tokenizer_huggyllama_w_special_tokens(
|
||||
tokenizer_huggyllama,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
tokenizer_huggyllama.add_special_tokens(
|
||||
{
|
||||
"bos_token": "<s>",
|
||||
@@ -386,7 +385,7 @@ def tokenizer_huggyllama_w_special_tokens(
|
||||
@enable_hf_offline
|
||||
def tokenizer_llama2_7b(
|
||||
download_llama2_model_fixture,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
|
||||
|
||||
return tokenizer
|
||||
@@ -396,7 +395,7 @@ def tokenizer_llama2_7b(
|
||||
@enable_hf_offline
|
||||
def tokenizer_mistral_7b_instruct(
|
||||
download_mlx_mistral_7b_model_fixture,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
):
|
||||
return AutoTokenizer.from_pretrained("casperhansen/mistral-7b-instruct-v0.1-awq")
|
||||
|
||||
|
||||
@@ -442,9 +441,7 @@ def cleanup_monkeypatches():
|
||||
# original_fa2_forward = LlamaFlashAttention2.forward
|
||||
original_llama_attn_forward = LlamaAttention.forward
|
||||
original_llama_forward = LlamaForCausalLM.forward
|
||||
original_trainer_inner_training_loop = (
|
||||
Trainer._inner_training_loop # pylint: disable=protected-access
|
||||
)
|
||||
original_trainer_inner_training_loop = Trainer._inner_training_loop
|
||||
original_trainer_training_step = Trainer.training_step
|
||||
# monkey patches can happen inside the tests
|
||||
yield
|
||||
@@ -452,9 +449,7 @@ def cleanup_monkeypatches():
|
||||
# LlamaFlashAttention2.forward = original_fa2_forward
|
||||
LlamaAttention.forward = original_llama_attn_forward
|
||||
LlamaForCausalLM.forward = original_llama_forward
|
||||
Trainer._inner_training_loop = ( # pylint: disable=protected-access
|
||||
original_trainer_inner_training_loop
|
||||
)
|
||||
Trainer._inner_training_loop = original_trainer_inner_training_loop
|
||||
Trainer.training_step = original_trainer_training_step
|
||||
|
||||
# Reset other known monkeypatches
|
||||
@@ -490,7 +485,7 @@ def cleanup_monkeypatches():
|
||||
@pytest.fixture
|
||||
def dataset_winglian_tiny_shakespeare(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
ds_path = download_ds_fixture_bundle / "winglian__tiny-shakespeare"
|
||||
return datasets.load_from_disk(ds_path)
|
||||
|
||||
@@ -498,7 +493,7 @@ def dataset_winglian_tiny_shakespeare(
|
||||
@pytest.fixture
|
||||
def dataset_tatsu_lab_alpaca(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
ds_path = download_ds_fixture_bundle / "tatsu-lab__alpaca"
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
@@ -506,7 +501,7 @@ def dataset_tatsu_lab_alpaca(
|
||||
@pytest.fixture
|
||||
def dataset_mhenrichsen_alpaca_2k_test(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
ds_path = download_ds_fixture_bundle / "mhenrichsen__alpaca_2k_test"
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
@@ -514,7 +509,7 @@ def dataset_mhenrichsen_alpaca_2k_test(
|
||||
@pytest.fixture
|
||||
def dataset_argilla_ultrafeedback_binarized_preferences_cleaned(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
ds_path = (
|
||||
download_ds_fixture_bundle
|
||||
/ "argilla__ultrafeedback-binarized-preferences-cleaned"
|
||||
@@ -525,7 +520,7 @@ def dataset_argilla_ultrafeedback_binarized_preferences_cleaned(
|
||||
@pytest.fixture
|
||||
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
ds_path = download_ds_fixture_bundle / "fozziethebeat__alpaca_messages_2k_dpo_test"
|
||||
return datasets.load_from_disk(ds_path)["train"]
|
||||
|
||||
@@ -533,7 +528,7 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test(
|
||||
@pytest.fixture
|
||||
def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
|
||||
download_ds_fixture_bundle: Path,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
ds_path = (
|
||||
download_ds_fixture_bundle
|
||||
/ "fozziethebeat__alpaca_messages_2k_dpo_test__rev_ea82cff"
|
||||
@@ -557,7 +552,7 @@ def fixture_min_base_cfg():
|
||||
)
|
||||
|
||||
|
||||
# # pylint: disable=redefined-outer-name,unused-argument
|
||||
#
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("AXOLOTL_IS_CI_CACHE_PRELOAD", "-1") != "1",
|
||||
reason="Not running in CI cache preload",
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
This module contains constants and configuration dictionaries used for
|
||||
datasets and other utilities in the Axolotl project, specifically for testing.
|
||||
"""
|
||||
|
||||
# Configuration for Alpaca Messages Dataset
|
||||
ALPACA_MESSAGES_CONFIG_OG = {
|
||||
"path": "fozziethebeat/alpaca_messages_2k_dpo_test",
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Unit tests for axolotl.core.builders"""
|
||||
|
||||
# pylint: disable=protected-access
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
@@ -330,7 +328,6 @@ def rand_reward_func(prompts, completions) -> list[float]:
|
||||
)
|
||||
|
||||
def test_grpo_training_arguments(self, grpo_cfg, model, tokenizer, tmp_path):
|
||||
|
||||
rewards_dir = tmp_path / "rewards_test"
|
||||
self._write_rewards_file(rewards_dir)
|
||||
|
||||
@@ -477,7 +474,7 @@ def rand_reward_func(prompts, completions) -> list[float]:
|
||||
|
||||
assert trainer.optimizer_cls_and_kwargs is not None
|
||||
|
||||
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
|
||||
from axolotl.contribs.mit.muon import (
|
||||
Muon,
|
||||
MuonOptimizerFactory,
|
||||
)
|
||||
@@ -559,7 +556,7 @@ class TestHFCausalTrainerBuilder:
|
||||
|
||||
assert trainer.optimizer_cls_and_kwargs is not None
|
||||
|
||||
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
|
||||
from axolotl.contribs.mit.muon import (
|
||||
Muon,
|
||||
MuonOptimizerFactory,
|
||||
)
|
||||
@@ -599,6 +596,6 @@ class TestTrainerClsPlugin:
|
||||
except TypeError as e:
|
||||
# Error raised if trainer_cls is None
|
||||
assert "'tuple' object has no attribute 'config'" not in str(e)
|
||||
except Exception: # pylint: disable=broad-exception-caught
|
||||
except Exception:
|
||||
# Another error happens, so we passed trainer_cls to builder
|
||||
pass
|
||||
|
||||
@@ -12,8 +12,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from ..utils import check_model_output_exists
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def min_cfg(temp_dir):
|
||||
@@ -53,7 +51,6 @@ class TestCutCrossEntropyIntegration:
|
||||
e2e tests for cut_cross_entropy integration with Axolotl
|
||||
"""
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
def test_llama_w_cce(self, min_cfg, temp_dir):
|
||||
cfg = DictDefault(min_cfg)
|
||||
cfg = validate_config(cfg)
|
||||
@@ -69,7 +66,6 @@ class TestCutCrossEntropyIntegration:
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
def test_qwen2_w_cce(self, temp_dir):
|
||||
cfg = DictDefault(
|
||||
{
|
||||
|
||||
@@ -18,7 +18,7 @@ class FP8IntegrationTestCase:
|
||||
@require_torch_2_7_0
|
||||
def test_fp8_single_gpu_smoke(self, temp_dir):
|
||||
"""Smoke test for single GPU FP8 + torch.compile training"""
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -53,7 +53,6 @@ class FP8IntegrationTestCase:
|
||||
}
|
||||
)
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
dataset_meta = load_datasets(cfg=cfg)
|
||||
|
||||
@@ -28,85 +28,81 @@ class LogHooksPlugin(BasePlugin):
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
|
||||
def post_trainer_create(self, cfg, trainer):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("post_trainer_create\n")
|
||||
|
||||
def pre_model_load(self, cfg): # pylint: disable=unused-argument
|
||||
def pre_model_load(self, cfg):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("pre_model_load\n")
|
||||
|
||||
def post_model_build(self, cfg, model): # pylint: disable=unused-argument
|
||||
def post_model_build(self, cfg, model):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("post_model_build\n")
|
||||
|
||||
def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
||||
def pre_lora_load(self, cfg, model):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("pre_lora_load\n")
|
||||
|
||||
def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
|
||||
def post_lora_load(self, cfg, model):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("post_lora_load\n")
|
||||
|
||||
def post_model_load(self, cfg, model): # pylint: disable=unused-argument
|
||||
def post_model_load(self, cfg, model):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("post_model_load\n")
|
||||
|
||||
def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
|
||||
def create_optimizer(self, cfg, trainer):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("create_optimizer\n")
|
||||
|
||||
def get_trainer_cls(self, cfg): # pylint: disable=unused-argument
|
||||
def get_trainer_cls(self, cfg):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("get_trainer_cls\n")
|
||||
|
||||
def create_lr_scheduler(
|
||||
self, cfg, trainer, optimizer, num_training_steps
|
||||
): # pylint: disable=unused-argument
|
||||
def create_lr_scheduler(self, cfg, trainer, optimizer, num_training_steps):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("create_lr_scheduler\n")
|
||||
|
||||
def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
|
||||
def add_callbacks_pre_trainer(self, cfg, model):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("add_callbacks_pre_trainer\n")
|
||||
return []
|
||||
|
||||
def add_callbacks_post_trainer(
|
||||
self, cfg, trainer
|
||||
): # pylint: disable=unused-argument
|
||||
def add_callbacks_post_trainer(self, cfg, trainer):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("add_callbacks_post_trainer\n")
|
||||
return []
|
||||
|
||||
def post_train(self, cfg, model): # pylint: disable=unused-argument
|
||||
def post_train(self, cfg, model):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
f.write("post_train\n")
|
||||
|
||||
def post_train_unload(self, cfg): # pylint: disable=unused-argument
|
||||
def post_train_unload(self, cfg):
|
||||
with open(
|
||||
self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
|
||||
) as f:
|
||||
@@ -119,7 +115,6 @@ class TestPluginHooks:
|
||||
"""
|
||||
|
||||
def test_plugin_hooks(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -81,7 +81,7 @@ class TestKnowledgeDistillation:
|
||||
@require_torch_2_5_1
|
||||
def test_llama_kd(self, temp_dir, kd_min_cfg):
|
||||
cfg = DictDefault(kd_min_cfg)
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
@@ -123,7 +123,7 @@ class TestKnowledgeDistillation:
|
||||
}
|
||||
| kd_min_cfg
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
# write cfg to yaml file
|
||||
Path(temp_dir).mkdir(parents=True, exist_ok=True)
|
||||
with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
|
||||
|
||||
@@ -17,7 +17,6 @@ class LigerIntegrationTestCase:
|
||||
|
||||
@require_torch_2_4_1
|
||||
def test_llama_wo_flce(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -53,7 +52,7 @@ class LigerIntegrationTestCase:
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
@@ -64,7 +63,6 @@ class LigerIntegrationTestCase:
|
||||
|
||||
@require_torch_2_4_1
|
||||
def test_llama_w_flce(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -100,7 +98,7 @@ class LigerIntegrationTestCase:
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
prepare_plugins(cfg)
|
||||
normalize_config(cfg)
|
||||
|
||||
@@ -85,6 +85,6 @@ def test_geglu_inplace_preservation():
|
||||
|
||||
assert not torch.equal(gate, gate_copy), "Gate should be modified in-place"
|
||||
assert not torch.equal(up, up_copy), "Up should be modified in-place"
|
||||
assert not torch.equal(
|
||||
grad_output, grad_copy
|
||||
), "Grad output should be modified in-place"
|
||||
assert not torch.equal(grad_output, grad_copy), (
|
||||
"Grad output should be modified in-place"
|
||||
)
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Tests for LoRA custom autograd."""
|
||||
|
||||
# pylint: disable=invalid-name,redefined-outer-name
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from bitsandbytes.functional import QuantState
|
||||
@@ -333,7 +331,7 @@ def test_lora_qkv(sample_tensors):
|
||||
X.requires_grad = True
|
||||
|
||||
# Test without LoRA adapters
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
Q1, K1, V1 = LoRA_QKV.apply(
|
||||
X,
|
||||
q_weight,
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Tests for quantization utility functions."""
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
|
||||
import torch
|
||||
from bitsandbytes.functional import QuantState
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Tests for SwiGLU activation function Triton kernels."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
@@ -74,6 +72,6 @@ def test_swiglu_inplace_preservation():
|
||||
|
||||
assert not torch.equal(gate, gate_copy), "Gate should be modified in-place"
|
||||
assert not torch.equal(up, up_copy), "Up should be modified in-place"
|
||||
assert not torch.equal(
|
||||
grad_output, grad_copy
|
||||
), "Grad output should be modified in-place"
|
||||
assert not torch.equal(grad_output, grad_copy), (
|
||||
"Grad output should be modified in-place"
|
||||
)
|
||||
|
||||
@@ -31,7 +31,6 @@ class TestPackedFlex:
|
||||
|
||||
@require_torch_2_6_0
|
||||
def test_loss_llama(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -80,7 +80,7 @@ def start_vllm(
|
||||
cmd_env = env.copy()
|
||||
cmd_env.update({"VLLM_LOGGING_CONFIG_PATH": vllm_logging_json})
|
||||
# start `trl vllm-serve` command in the background and capture the process id
|
||||
process = subprocess.Popen( # pylint: disable=consider-using-with
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
env=cmd_env,
|
||||
stdout=subprocess.DEVNULL if quiet else subprocess.PIPE,
|
||||
|
||||
@@ -21,7 +21,6 @@ class TestMultiGPUEval:
|
||||
"""
|
||||
|
||||
def test_eval_sample_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -93,7 +92,6 @@ class TestMultiGPUEval:
|
||||
check_tensorboard(temp_dir + "/runs", "eval/loss", 2.5, "Eval Loss is too high")
|
||||
|
||||
def test_eval(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Test module for FP8 mixed precision with FSDP2 multi-GPU functionality."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
@@ -28,9 +26,9 @@ def verify_fp8_training_success(temp_dir):
|
||||
assert len(model_files) > 0, "No model files found - training may have failed"
|
||||
|
||||
checkpoint_files = list(output_path.glob("checkpoint-*"))
|
||||
assert (
|
||||
len(checkpoint_files) > 0
|
||||
), "No checkpoint files found - training may have failed"
|
||||
assert len(checkpoint_files) > 0, (
|
||||
"No checkpoint files found - training may have failed"
|
||||
)
|
||||
|
||||
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||
if tb_log_path:
|
||||
@@ -42,9 +40,9 @@ def verify_fp8_training_success(temp_dir):
|
||||
train_loss_df = df[df.tag == "train/train_loss"]
|
||||
if len(train_loss_df) > 0:
|
||||
final_loss = train_loss_df.value.values[-1]
|
||||
assert not torch.isnan(
|
||||
torch.tensor(final_loss)
|
||||
), f"Training loss is NaN: {final_loss}"
|
||||
assert not torch.isnan(torch.tensor(final_loss)), (
|
||||
f"Training loss is NaN: {final_loss}"
|
||||
)
|
||||
|
||||
|
||||
class TestFP8FSDP2:
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Test module for FSDP1 multi-GPU functionality."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
@@ -29,9 +27,9 @@ def verify_training_success(temp_dir):
|
||||
assert len(model_files) > 0, "No model files found - training may have failed"
|
||||
|
||||
checkpoint_files = list(output_path.glob("checkpoint-*"))
|
||||
assert (
|
||||
len(checkpoint_files) > 0
|
||||
), "No checkpoint files found - training may have failed"
|
||||
assert len(checkpoint_files) > 0, (
|
||||
"No checkpoint files found - training may have failed"
|
||||
)
|
||||
|
||||
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||
if tb_log_path:
|
||||
@@ -43,9 +41,9 @@ def verify_training_success(temp_dir):
|
||||
train_loss_df = df[df.tag == "train/train_loss"]
|
||||
if len(train_loss_df) > 0:
|
||||
final_loss = train_loss_df.value.values[-1]
|
||||
assert not torch.isnan(
|
||||
torch.tensor(final_loss)
|
||||
), f"Training loss is NaN: {final_loss}"
|
||||
assert not torch.isnan(torch.tensor(final_loss)), (
|
||||
f"Training loss is NaN: {final_loss}"
|
||||
)
|
||||
|
||||
|
||||
class TestFSDP1:
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Test module for FSDP2 multi-GPU functionality."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
@@ -29,9 +27,9 @@ def verify_training_success(temp_dir):
|
||||
assert len(model_files) > 0, "No model files found - training may have failed"
|
||||
|
||||
checkpoint_files = list(output_path.glob("checkpoint-*"))
|
||||
assert (
|
||||
len(checkpoint_files) > 0
|
||||
), "No checkpoint files found - training may have failed"
|
||||
assert len(checkpoint_files) > 0, (
|
||||
"No checkpoint files found - training may have failed"
|
||||
)
|
||||
|
||||
tb_log_path = most_recent_subdir(temp_dir + "/runs")
|
||||
if tb_log_path:
|
||||
@@ -43,9 +41,9 @@ def verify_training_success(temp_dir):
|
||||
train_loss_df = df[df.tag == "train/train_loss"]
|
||||
if len(train_loss_df) > 0:
|
||||
final_loss = train_loss_df.value.values[-1]
|
||||
assert not torch.isnan(
|
||||
torch.tensor(final_loss)
|
||||
), f"Training loss is NaN: {final_loss}"
|
||||
assert not torch.isnan(torch.tensor(final_loss)), (
|
||||
f"Training loss is NaN: {final_loss}"
|
||||
)
|
||||
|
||||
|
||||
class TestFSDP2:
|
||||
|
||||
@@ -29,7 +29,6 @@ class TestMultiGPUGemma3:
|
||||
"""
|
||||
|
||||
def test_lora_ddp_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-mirrors/gemma-3-4b-pt",
|
||||
|
||||
@@ -35,7 +35,6 @@ class TestMultiGPULlama:
|
||||
"""
|
||||
|
||||
def test_lora_ddp(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -99,7 +98,6 @@ class TestMultiGPULlama:
|
||||
[1, 2],
|
||||
)
|
||||
def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -162,7 +160,6 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
def test_dpo_lora_ddp(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -242,7 +239,6 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
def test_dpo_qlora_ddp(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -326,7 +322,6 @@ class TestMultiGPULlama:
|
||||
[1, 2],
|
||||
)
|
||||
def test_fsdp(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -402,7 +397,6 @@ class TestMultiGPULlama:
|
||||
],
|
||||
)
|
||||
def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -484,7 +478,6 @@ class TestMultiGPULlama:
|
||||
def test_fsdp2_packed(
|
||||
self, temp_dir, attention_backend, fsdp_reshard_after_forward
|
||||
):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -556,7 +549,6 @@ class TestMultiGPULlama:
|
||||
)
|
||||
|
||||
def test_fsdp_qlora_prequant_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/SmolLM2-135M-bnb-nf4-bf16",
|
||||
@@ -656,7 +648,6 @@ class TestMultiGPULlama:
|
||||
def test_ds_zero3_packed(
|
||||
self, temp_dir, gradient_accumulation_steps, deepspeed, qlora
|
||||
):
|
||||
# pylint: disable=duplicate-code
|
||||
if qlora:
|
||||
adapter = {
|
||||
"adapter": "qlora",
|
||||
@@ -732,7 +723,6 @@ class TestMultiGPULlama:
|
||||
[True, False],
|
||||
)
|
||||
def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps, qlora):
|
||||
# pylint: disable=duplicate-code
|
||||
if qlora:
|
||||
adapter = {
|
||||
"adapter": "qlora",
|
||||
@@ -809,7 +799,6 @@ class TestMultiGPULlama:
|
||||
[True, False],
|
||||
)
|
||||
def test_ds_zero1_packed(self, temp_dir, gradient_accumulation_steps, qlora):
|
||||
# pylint: disable=duplicate-code
|
||||
if qlora:
|
||||
adapter = {
|
||||
"adapter": "qlora",
|
||||
@@ -880,7 +869,6 @@ class TestMultiGPULlama:
|
||||
reason="fix untrained tokens brittle with lots of edge cases in latest transformers"
|
||||
)
|
||||
def test_fix_untrained_tokens(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -26,7 +26,6 @@ class TestMultiGPURay:
|
||||
|
||||
@require_torch_lt_2_6_0
|
||||
def test_lora_ddp(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -90,7 +89,6 @@ class TestMultiGPURay:
|
||||
[1, 2],
|
||||
)
|
||||
def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -150,7 +148,6 @@ class TestMultiGPURay:
|
||||
[1, 2],
|
||||
)
|
||||
def test_sft_fsdp2_packed(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -19,7 +19,6 @@ class TestTensorParallel:
|
||||
)
|
||||
@require_torch_2_7_0
|
||||
def test_fft_sft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""Integration tests for LoRA activation and attention kernels."""
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
@@ -88,7 +86,7 @@ def test_attention_patching_integration(model_name, attention_cls):
|
||||
cfg = DictDefault({"base_model": model_name})
|
||||
|
||||
# Store the original implementation
|
||||
original_forward = getattr(attention_cls, "forward")
|
||||
original_forward = attention_cls.forward
|
||||
|
||||
# Apply patch
|
||||
patch_self_attn_lora(cfg)
|
||||
@@ -104,7 +102,7 @@ def test_attention_patching_integration(model_name, attention_cls):
|
||||
assert hasattr(attention_cls, "_original_forward")
|
||||
|
||||
# Clean up
|
||||
setattr(attention_cls, "forward", original_forward)
|
||||
attention_cls.forward = original_forward
|
||||
delattr(attention_cls, "_original_forward")
|
||||
|
||||
|
||||
@@ -379,9 +377,9 @@ def test_model_architecture(model_config):
|
||||
|
||||
# Verify correct activation function
|
||||
layer = patched_model.model.model.layers[0]
|
||||
assert (
|
||||
layer.mlp.forward.__func__ is model_config["expected_activation"]
|
||||
), f"Wrong activation for {model_config['name']}"
|
||||
assert layer.mlp.forward.__func__ is model_config["expected_activation"], (
|
||||
f"Wrong activation for {model_config['name']}"
|
||||
)
|
||||
|
||||
# Test forward pass
|
||||
inputs = get_test_inputs(model)
|
||||
@@ -390,12 +388,11 @@ def test_model_architecture(model_config):
|
||||
patched_output = patched_model(inputs).logits
|
||||
|
||||
# Check outputs match
|
||||
assert torch.allclose(
|
||||
original_output, patched_output, rtol=1e-4
|
||||
), f"Outputs don't match for {model_config['name']}"
|
||||
assert torch.allclose(original_output, patched_output, rtol=1e-4), (
|
||||
f"Outputs don't match for {model_config['name']}"
|
||||
)
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
def test_kernel_training_integration(temp_dir):
|
||||
"""Test model loading with kernel patches enabled."""
|
||||
from axolotl.cli.utils import load_model_and_tokenizer
|
||||
@@ -563,15 +560,13 @@ def test_kernel_training_integration_dropout_non_zero(temp_dir):
|
||||
model_loader = ModelLoader(cfg, tokenizer)
|
||||
|
||||
# Apply patch
|
||||
model_loader.patch_manager._apply_self_attention_lora_patch() # pylint: disable=protected-access
|
||||
model_loader.patch_manager._apply_self_attention_lora_patch()
|
||||
|
||||
# Verify patch was not applied
|
||||
assert attention_cls.forward == original_forward_method
|
||||
|
||||
# Apply apply_lora_kernel_patches
|
||||
model_loader.patch_manager._apply_lora_kernel_patch( # pylint: disable=protected-access
|
||||
model
|
||||
)
|
||||
model_loader.patch_manager._apply_lora_kernel_patch(model)
|
||||
|
||||
# Verify patch was not applied
|
||||
layers = get_layers(model)
|
||||
|
||||
@@ -19,7 +19,6 @@ class Test4dMultipackLlama(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_sdp_lora_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -67,7 +66,6 @@ class Test4dMultipackLlama(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_torch_lora_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -32,10 +32,9 @@ class TestActivationCheckpointing:
|
||||
def test_activation_checkpointing_offload(
|
||||
self,
|
||||
temp_dir,
|
||||
fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name
|
||||
fix_checkpoint_after_test,
|
||||
gradient_checkpointing,
|
||||
):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -10,7 +10,6 @@ from axolotl.cli.config import load_cfg
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
class TestPluginArgs:
|
||||
"""
|
||||
test class for plugin args loaded from the config file
|
||||
|
||||
@@ -23,7 +23,6 @@ class TestFAXentropyLlama:
|
||||
[1, 4],
|
||||
)
|
||||
def test_lora_packing_fa_cross_entropy(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -22,7 +22,6 @@ class TestFalconPatched(unittest.TestCase):
|
||||
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
|
||||
@with_temp_dir
|
||||
def test_qlora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "illuin/tiny-random-FalconForCausalLM",
|
||||
@@ -71,7 +70,6 @@ class TestFalconPatched(unittest.TestCase):
|
||||
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "illuin/tiny-random-FalconForCausalLM",
|
||||
|
||||
@@ -23,7 +23,6 @@ class TestFAFlattening:
|
||||
[1, 4],
|
||||
)
|
||||
def test_lora_packing_flattening(self, temp_dir, gradient_accumulation_steps):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -15,7 +15,6 @@ class TestFSDPPatchIntegration:
|
||||
apply_init_unsharded_param_patch,
|
||||
)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
original_init_sharded = FSDPParam._init_sharded_param
|
||||
original_init_unsharded = FSDPParam.init_unsharded_param
|
||||
|
||||
@@ -23,11 +22,9 @@ class TestFSDPPatchIntegration:
|
||||
apply_init_sharded_param_patch()
|
||||
apply_init_unsharded_param_patch()
|
||||
|
||||
assert (
|
||||
# pylint: disable=protected-access
|
||||
FSDPParam._init_sharded_param
|
||||
!= original_init_sharded
|
||||
), "_init_sharded_param was not patched"
|
||||
assert (
|
||||
FSDPParam.init_unsharded_param != original_init_unsharded
|
||||
), "init_unsharded_param was not patched"
|
||||
assert FSDPParam._init_sharded_param != original_init_sharded, (
|
||||
"_init_sharded_param was not patched"
|
||||
)
|
||||
assert FSDPParam.init_unsharded_param != original_init_unsharded, (
|
||||
"init_unsharded_param was not patched"
|
||||
)
|
||||
|
||||
@@ -23,7 +23,6 @@ class TestFusedLlama(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_fft_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -22,7 +22,6 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora_s2_attn(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -71,7 +70,6 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_fft_s2_attn(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -22,7 +22,6 @@ class TestLoraLlama(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -73,7 +72,6 @@ class TestLoraLlama(unittest.TestCase):
|
||||
@pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
|
||||
@with_temp_dir
|
||||
def test_lora_gptq_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "lilmeaty/SmolLM2-135M-Instruct-GPTQ",
|
||||
|
||||
@@ -20,7 +20,6 @@ class TestMistral(unittest.TestCase):
|
||||
@require_torch_2_6_0
|
||||
@with_temp_dir
|
||||
def test_lora_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
@@ -68,7 +67,6 @@ class TestMistral(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_ft_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
|
||||
@@ -19,7 +19,6 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_qlora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||
@@ -64,7 +63,6 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||
|
||||
@@ -89,5 +89,5 @@ class TestModelPatches(unittest.TestCase):
|
||||
|
||||
assert (
|
||||
"torch.jit"
|
||||
in transformers.modeling_flash_attention_utils._get_unpad_data.__module__ # pylint: disable=protected-access
|
||||
in transformers.modeling_flash_attention_utils._get_unpad_data.__module__
|
||||
)
|
||||
|
||||
@@ -15,7 +15,6 @@ class TestLlamaPeftEmbeddings:
|
||||
"""
|
||||
|
||||
def test_peft_embeddings_upcast(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -19,7 +19,6 @@ class TestPhiMultipack(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_ft_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "microsoft/phi-1_5",
|
||||
@@ -67,7 +66,6 @@ class TestPhiMultipack(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_qlora_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "microsoft/phi-1_5",
|
||||
|
||||
@@ -22,7 +22,6 @@ class TestResumeLlama:
|
||||
|
||||
@require_torch_2_6_0
|
||||
def test_resume_lora_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -12,7 +12,6 @@ from axolotl.utils.dict import DictDefault
|
||||
from ..utils import check_model_output_exists, check_tensorboard
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
@pytest.mark.skip(
|
||||
reason="Unsloth integration will be broken going into latest transformers"
|
||||
)
|
||||
|
||||
@@ -22,7 +22,6 @@ class TestPackedFlex(unittest.TestCase):
|
||||
@require_torch_2_6_0
|
||||
@with_temp_dir
|
||||
def test_loss_llama(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -20,7 +20,6 @@ class TestReLoraLlama(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_relora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -76,9 +75,9 @@ class TestReLoraLlama(unittest.TestCase):
|
||||
|
||||
train(cfg=cfg, dataset_meta=dataset_meta)
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
|
||||
assert (
|
||||
Path(temp_dir) / "checkpoint-100/relora/model.safetensors"
|
||||
).exists(), "Relora model checkpoint not found"
|
||||
assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists(), (
|
||||
"Relora model checkpoint not found"
|
||||
)
|
||||
|
||||
check_tensorboard(
|
||||
temp_dir + "/runs", "train/grad_norm", 0.2, "grad_norm is too high"
|
||||
|
||||
@@ -11,8 +11,6 @@ from axolotl.utils.dict import DictDefault
|
||||
|
||||
from .utils import check_model_output_exists
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
|
||||
class TestActivationOffloading:
|
||||
"""
|
||||
@@ -28,7 +26,6 @@ class TestActivationOffloading:
|
||||
temp_dir,
|
||||
adapter,
|
||||
):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -25,7 +25,6 @@ class TestDeepseekV3:
|
||||
[True, False],
|
||||
)
|
||||
def test_lora_deepseekv3(self, temp_dir, sample_packing):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/DeepSeek-V3-11M",
|
||||
@@ -83,7 +82,6 @@ class TestDeepseekV3:
|
||||
[True, False],
|
||||
)
|
||||
def test_fft_deepseekv3(self, temp_dir, sample_packing):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/DeepSeek-V3-11M",
|
||||
|
||||
@@ -21,7 +21,6 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_dpo_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -70,7 +69,6 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_dpo_nll_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -120,7 +118,6 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_dpo_use_weighting(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -171,7 +168,6 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
@pytest.mark.skip("kto_pair no longer supported in trl")
|
||||
@with_temp_dir
|
||||
def test_kto_pair_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -220,7 +216,6 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_ipo_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -269,7 +264,6 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_orpo_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -322,7 +316,6 @@ class TestDPOLlamaLora(unittest.TestCase):
|
||||
@pytest.mark.skip(reason="Fix the implementation")
|
||||
@with_temp_dir
|
||||
def test_kto_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -19,7 +19,6 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_train_w_embedding_lr_scale(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -65,7 +64,6 @@ class TestEmbeddingsLrScale(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_train_w_embedding_lr(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -13,7 +13,6 @@ class TestE2eEvaluate:
|
||||
"""Test cases for evaluate CLI"""
|
||||
|
||||
def test_evaluate(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -22,7 +22,6 @@ class TestFalcon(unittest.TestCase):
|
||||
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
|
||||
@with_temp_dir
|
||||
def test_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "illuin/tiny-random-FalconForCausalLM",
|
||||
@@ -74,7 +73,6 @@ class TestFalcon(unittest.TestCase):
|
||||
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
|
||||
@with_temp_dir
|
||||
def test_lora_added_vocab(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "illuin/tiny-random-FalconForCausalLM",
|
||||
@@ -130,7 +128,6 @@ class TestFalcon(unittest.TestCase):
|
||||
@pytest.mark.skip(reason="no tiny models for testing with safetensors")
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "illuin/tiny-random-FalconForCausalLM",
|
||||
|
||||
@@ -22,7 +22,6 @@ class TestGemma2:
|
||||
[True, False],
|
||||
)
|
||||
def test_lora_gemma2(self, temp_dir, sample_packing):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/gemma-2-33M",
|
||||
@@ -78,7 +77,6 @@ class TestGemma2:
|
||||
[True, False],
|
||||
)
|
||||
def test_fft_gemma2(self, temp_dir, sample_packing):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/gemma-2-33M",
|
||||
|
||||
@@ -22,7 +22,6 @@ class TestGemma3Text:
|
||||
[True, False],
|
||||
)
|
||||
def test_lora_gemma3_text(self, temp_dir, sample_packing):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/gemma-3-34M",
|
||||
@@ -78,7 +77,6 @@ class TestGemma3Text:
|
||||
[True, False],
|
||||
)
|
||||
def test_fft_gemma3_text(self, temp_dir, sample_packing):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/gemma-3-34M",
|
||||
|
||||
@@ -11,11 +11,7 @@ class TestImports(unittest.TestCase):
|
||||
"""
|
||||
|
||||
def test_import_causal_trainer(self):
|
||||
from axolotl.core.builders import ( # pylint: disable=unused-import # noqa: F401
|
||||
HFCausalTrainerBuilder,
|
||||
)
|
||||
pass
|
||||
|
||||
def test_import_rl_trainer(self):
|
||||
from axolotl.core.builders import ( # pylint: disable=unused-import # noqa: F401
|
||||
HFRLTrainerBuilder,
|
||||
)
|
||||
pass
|
||||
|
||||
@@ -16,7 +16,6 @@ class TestLlama:
|
||||
"""
|
||||
|
||||
def test_fft_trust_remote_code(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -57,7 +56,6 @@ class TestLlama:
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
def test_fix_untrained_tokens(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -105,7 +103,6 @@ class TestLlama:
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
def test_fix_untrained_tokens_already_trained(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -150,7 +147,6 @@ class TestLlama:
|
||||
check_model_output_exists(temp_dir, cfg)
|
||||
|
||||
def test_batch_flattening(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -22,7 +22,6 @@ class TestPretrainLlama:
|
||||
],
|
||||
)
|
||||
def test_pretrain(self, temp_dir, sample_packing, pretrain_multipack_attn):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -19,7 +19,6 @@ class TestLlamaVision(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora_llama_vision_text_only_dataset(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/Llama-3.2-39M-Vision",
|
||||
@@ -67,7 +66,6 @@ class TestLlamaVision(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora_llama_vision_multimodal_dataset(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "axolotl-ai-co/Llama-3.2-39M-Vision",
|
||||
|
||||
@@ -56,13 +56,11 @@ class TestLoadModelUtils:
|
||||
"context_parallel_size": 1,
|
||||
}
|
||||
)
|
||||
self.model_loader = ( # pylint: disable=attribute-defined-outside-init
|
||||
ModelLoader(
|
||||
cfg=self.cfg,
|
||||
tokenizer="",
|
||||
inference=False,
|
||||
reference_model=True,
|
||||
)
|
||||
self.model_loader = ModelLoader(
|
||||
cfg=self.cfg,
|
||||
tokenizer="",
|
||||
inference=False,
|
||||
reference_model=True,
|
||||
)
|
||||
|
||||
@pytest.mark.parametrize("embedding_modules", ["embed_tokens", "lm_head"])
|
||||
@@ -74,7 +72,7 @@ class TestLoadModelUtils:
|
||||
self, temp_dir, embedding_modules, dist_dtype, before_kbit_train_or_finetune
|
||||
):
|
||||
self.cfg.output_dir = temp_dir
|
||||
self.model_loader.tokenizer = load_tokenizer(self.cfg) # pylint: disable=all
|
||||
self.model_loader.tokenizer = load_tokenizer(self.cfg)
|
||||
self.model_loader.load()
|
||||
self.model_loader._convert_embedding_modules_dtype(
|
||||
embedding_modules, dist_dtype, before_kbit_train_or_finetune
|
||||
|
||||
@@ -19,7 +19,6 @@ class TestLoraLlama(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -22,7 +22,6 @@ class TestMamba(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_fft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "state-spaces/mamba-130m",
|
||||
|
||||
@@ -21,7 +21,6 @@ class TestMistral(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
@@ -68,7 +67,6 @@ class TestMistral(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
|
||||
|
||||
@@ -22,7 +22,6 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_qlora_w_fa2(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||
@@ -78,7 +77,6 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_qlora_wo_fa2(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||
@@ -134,7 +132,6 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_16bit_lora_w_fa2(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||
@@ -193,7 +190,6 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_16bit_lora_wo_fa2(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||
@@ -252,7 +248,6 @@ class TestMixtral(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_ft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "hf-internal-testing/Mixtral-tiny",
|
||||
|
||||
@@ -25,7 +25,6 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_optimi_adamw(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -71,7 +70,6 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
@with_temp_dir
|
||||
@require_torch_2_5_1
|
||||
def test_adopt_adamw(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -117,7 +115,6 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
@with_temp_dir
|
||||
@require_torch_2_5_1
|
||||
def test_muon(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -164,7 +161,6 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
@with_temp_dir
|
||||
@require_torch_2_7_0
|
||||
def test_dion(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -206,7 +202,6 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_fft_schedule_free_adamw(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -234,7 +229,6 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
"save_first_step": False,
|
||||
}
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
cfg = validate_config(cfg)
|
||||
normalize_config(cfg)
|
||||
@@ -246,7 +240,6 @@ class TestCustomOptimizers(unittest.TestCase):
|
||||
@with_temp_dir
|
||||
@require_torch_2_6_0
|
||||
def test_came_pytorch(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "JackFram/llama-68m",
|
||||
|
||||
@@ -21,7 +21,6 @@ class TestPackedLlama(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_loss_packed(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -19,7 +19,6 @@ class TestPhi(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_phi_ft(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "microsoft/phi-1_5",
|
||||
@@ -65,7 +64,6 @@ class TestPhi(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_phi_qlora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "microsoft/phi-1_5",
|
||||
|
||||
@@ -15,7 +15,7 @@ class TestPreprocess:
|
||||
|
||||
def test_w_deepspeed(self, temp_dir):
|
||||
"""make sure preproces doesn't choke when using deepspeed in the config"""
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "Qwen/Qwen2.5-0.5B",
|
||||
|
||||
@@ -19,7 +19,6 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_prm(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -18,7 +18,6 @@ class TestQATLlama:
|
||||
"""
|
||||
|
||||
def test_qat(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -68,7 +67,6 @@ class TestQATLlama:
|
||||
check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg)
|
||||
|
||||
def test_qat_dpo(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -131,7 +131,7 @@ class TestQuantization:
|
||||
@require_torch_2_6_0
|
||||
def test_prepare_model_for_qat(
|
||||
self, model, weight_dtype, activation_dtype, group_size, quantize_embedding
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
prepare_model_for_qat(
|
||||
model, weight_dtype, group_size, activation_dtype, quantize_embedding
|
||||
)
|
||||
@@ -175,7 +175,7 @@ class TestQuantization:
|
||||
group_size,
|
||||
quantize_embedding,
|
||||
expected_exception,
|
||||
): # pylint: disable=redefined-outer-name
|
||||
):
|
||||
if expected_exception:
|
||||
with pytest.raises(expected_exception):
|
||||
quantize_model_for_ptq(
|
||||
@@ -198,11 +198,13 @@ class TestQuantization:
|
||||
if activation_dtype:
|
||||
assert isinstance(
|
||||
child.weight, LinearActivationQuantizedTensor
|
||||
), "Linear weight should be quantized with activation quantization"
|
||||
), (
|
||||
"Linear weight should be quantized with activation quantization"
|
||||
)
|
||||
else:
|
||||
assert isinstance(
|
||||
child.weight, AffineQuantizedTensor
|
||||
), "Linear weight should be quantized without activation quantization"
|
||||
assert isinstance(child.weight, AffineQuantizedTensor), (
|
||||
"Linear weight should be quantized without activation quantization"
|
||||
)
|
||||
|
||||
|
||||
class TestQuantizationCallback:
|
||||
@@ -217,9 +219,7 @@ class TestQuantizationCallback:
|
||||
)
|
||||
|
||||
@require_torch_2_6_0
|
||||
def test_qat_callback_fake_quant_after_n_steps(
|
||||
self, model, trainer_state
|
||||
): # pylint: disable=redefined-outer-name
|
||||
def test_qat_callback_fake_quant_after_n_steps(self, model, trainer_state):
|
||||
cfg = QATConfig(
|
||||
weight_dtype="int8",
|
||||
activation_dtype="int8",
|
||||
@@ -269,9 +269,7 @@ class TestQuantizationCallback:
|
||||
assert model.lm_head.weight_fake_quantizer.enabled
|
||||
|
||||
@require_torch_2_6_0
|
||||
def test_qat_callback_fake_quant_after_n_steps_is_none(
|
||||
self, model, trainer_state
|
||||
): # pylint: disable=redefined-outer-name
|
||||
def test_qat_callback_fake_quant_after_n_steps_is_none(self, model, trainer_state):
|
||||
cfg = QATConfig(
|
||||
weight_dtype="int8",
|
||||
activation_dtype="int8",
|
||||
@@ -314,9 +312,7 @@ class TestConvertQATModelForPTQ:
|
||||
"""
|
||||
|
||||
@require_torch_2_6_0
|
||||
def test_convert_qat_model_for_ptq(
|
||||
self, model
|
||||
): # pylint: disable=redefined-outer-name
|
||||
def test_convert_qat_model_for_ptq(self, model):
|
||||
config = QATConfig(
|
||||
weight_dtype="int8",
|
||||
activation_dtype="int8",
|
||||
|
||||
@@ -19,7 +19,6 @@ class TestE2eQwen:
|
||||
|
||||
@pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"])
|
||||
def test_dpo(self, base_model, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": base_model,
|
||||
|
||||
@@ -19,7 +19,6 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_rm_lora(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -20,7 +20,6 @@ class TestSaveFirstStepCallback(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_save_first_step(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -61,7 +60,6 @@ class TestSaveFirstStepCallback(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_no_save_first_step(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -19,7 +19,6 @@ class TestCustomSchedulers(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_rex_scheduler(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
helper utils for tests
|
||||
"""
|
||||
|
||||
import importlib.util
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
@@ -107,12 +108,7 @@ def require_vllm(test_case):
|
||||
"""
|
||||
|
||||
def is_vllm_installed():
|
||||
try:
|
||||
import vllm # pylint: disable=unused-import # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
return importlib.util.find_spec("vllm") is not None
|
||||
|
||||
return unittest.skipUnless(
|
||||
is_vllm_installed(), "test requires vllm to be installed"
|
||||
@@ -125,12 +121,7 @@ def require_llmcompressor(test_case):
|
||||
"""
|
||||
|
||||
def is_llmcompressor_installed():
|
||||
try:
|
||||
import llmcompressor # pylint: disable=unused-import # noqa: F401
|
||||
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
return importlib.util.find_spec("llmcompressor") is not None
|
||||
|
||||
return unittest.skipUnless(
|
||||
is_llmcompressor_installed(), "test requires llmcompressor to be installed"
|
||||
@@ -159,8 +150,8 @@ def check_tensorboard(
|
||||
tb_log_path = most_recent_subdir(temp_run_dir)
|
||||
event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
|
||||
reader = SummaryReader(event_file)
|
||||
df = reader.scalars # pylint: disable=invalid-name
|
||||
df = df[(df.tag == tag)] # pylint: disable=invalid-name
|
||||
df = reader.scalars
|
||||
df = df[(df.tag == tag)]
|
||||
lt_val = (1 + rtol) * lt_val
|
||||
if "%s" in assertion_err:
|
||||
assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
|
||||
|
||||
@@ -20,7 +20,7 @@ def reload_modules(hf_hub_offline):
|
||||
importlib.reload(huggingface_hub.constants)
|
||||
huggingface_hub.constants.HF_HUB_OFFLINE = hf_hub_offline
|
||||
importlib.reload(datasets.config)
|
||||
setattr(datasets.config, "HF_HUB_OFFLINE", hf_hub_offline)
|
||||
datasets.config.HF_HUB_OFFLINE = hf_hub_offline
|
||||
reset_sessions()
|
||||
|
||||
|
||||
|
||||
@@ -10,7 +10,6 @@ from axolotl.utils.config import prepare_plugins, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
@pytest.fixture(name="minimal_liger_cfg")
|
||||
def fixture_cfg():
|
||||
return DictDefault(
|
||||
@@ -30,7 +29,6 @@ def fixture_cfg():
|
||||
)
|
||||
|
||||
|
||||
# pylint: disable=too-many-public-methods
|
||||
class TestValidation:
|
||||
"""
|
||||
Test the validation module for liger
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# pylint: disable=too-many-lines
|
||||
"""Module for testing the validation module"""
|
||||
|
||||
import os
|
||||
@@ -49,7 +48,6 @@ class BaseValidation:
|
||||
self._caplog = caplog
|
||||
|
||||
|
||||
# pylint: disable=too-many-public-methods
|
||||
class TestValidation(BaseValidation):
|
||||
"""
|
||||
Test the validation module
|
||||
@@ -241,7 +239,7 @@ class TestValidation(BaseValidation):
|
||||
|
||||
def test_lr_as_float(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
DictDefault(
|
||||
{
|
||||
"learning_rate": "5e-5",
|
||||
}
|
||||
@@ -303,7 +301,7 @@ class TestValidation(BaseValidation):
|
||||
)
|
||||
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
DictDefault(
|
||||
{
|
||||
"load_in_8bit": True,
|
||||
}
|
||||
@@ -315,7 +313,7 @@ class TestValidation(BaseValidation):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
DictDefault(
|
||||
{
|
||||
"gptq": True,
|
||||
}
|
||||
@@ -327,7 +325,7 @@ class TestValidation(BaseValidation):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
DictDefault(
|
||||
{
|
||||
"load_in_4bit": False,
|
||||
}
|
||||
@@ -339,7 +337,7 @@ class TestValidation(BaseValidation):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
DictDefault(
|
||||
{
|
||||
"load_in_4bit": True,
|
||||
}
|
||||
@@ -361,7 +359,7 @@ class TestValidation(BaseValidation):
|
||||
)
|
||||
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
DictDefault(
|
||||
{
|
||||
"load_in_8bit": True,
|
||||
}
|
||||
@@ -373,7 +371,7 @@ class TestValidation(BaseValidation):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
DictDefault(
|
||||
{
|
||||
"gptq": True,
|
||||
}
|
||||
@@ -385,7 +383,7 @@ class TestValidation(BaseValidation):
|
||||
validate_config(cfg)
|
||||
|
||||
cfg = (
|
||||
DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
DictDefault(
|
||||
{
|
||||
"load_in_4bit": True,
|
||||
}
|
||||
|
||||
@@ -30,7 +30,6 @@ def fixture_assistant_dataset():
|
||||
|
||||
@pytest.fixture(name="sharegpt_dataset")
|
||||
def fixture_sharegpt_dataset():
|
||||
# pylint: disable=duplicate-code
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
@@ -47,7 +46,6 @@ def fixture_sharegpt_dataset():
|
||||
|
||||
@pytest.fixture(name="basic_dataset")
|
||||
def fixture_basic_dataset():
|
||||
# pylint: disable=duplicate-code
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
@@ -65,7 +63,6 @@ def fixture_basic_dataset():
|
||||
|
||||
@pytest.fixture(name="toolcalling_dataset")
|
||||
def fixture_toolcalling_dataset():
|
||||
# pylint: disable=duplicate-code
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
@@ -112,7 +109,7 @@ def fixture_toolcalling_dataset():
|
||||
@enable_hf_offline
|
||||
def fixture_llama3_tokenizer(
|
||||
download_llama3_8b_instruct_model_fixture,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
|
||||
|
||||
return tokenizer
|
||||
@@ -129,7 +126,7 @@ def fixture_smollm2_tokenizer():
|
||||
@enable_hf_offline
|
||||
def fixture_mistralv03_tokenizer(
|
||||
download_mlx_mistral_7b_model_fixture,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"mlx-community/Mistral-7B-Instruct-v0.3-4bit"
|
||||
)
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
tests for chat_template prompt strategy
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
import unittest
|
||||
|
||||
from axolotl.prompt_strategies.messages.chat import load
|
||||
@@ -53,9 +52,9 @@ class TestMessagesChatLlama3:
|
||||
# fmt: on
|
||||
LOG.debug(f"Expected input_ids: {expected_input_ids}")
|
||||
LOG.debug(f"Actual input_ids: {input_ids}")
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
assert input_ids == expected_input_ids, (
|
||||
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -30,7 +30,6 @@ def fixture_alpaca_dataset():
|
||||
@pytest.fixture(name="tokenizer")
|
||||
@enable_hf_offline
|
||||
def fixture_tokenizer():
|
||||
# pylint: disable=all
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"casperhansen/mistral-7b-instruct-v0.1-awq"
|
||||
)
|
||||
|
||||
@@ -18,9 +18,7 @@ def fixture_messages_w_tools():
|
||||
{"messages":[{"role":"user","content":"move to (0, 1)"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"move","arguments":{"x":0,"y":1}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
|
||||
{"messages":[{"role":"user","content":"turn 270 degree"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"turn","arguments":{"theta": 270}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
|
||||
{"messages":[{"role":"user","content":"jump high"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"invalid_prompt","arguments":{"message": "jump is not a valid action"}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
|
||||
""".strip().split(
|
||||
"\n"
|
||||
)
|
||||
""".strip().split("\n")
|
||||
rows = [json.loads(row) for row in jsons]
|
||||
return Dataset.from_list(rows)
|
||||
|
||||
@@ -28,7 +26,7 @@ def fixture_messages_w_tools():
|
||||
@pytest.fixture(name="qwen3_tokenizer")
|
||||
def qwen3_tokenizer_fixture(
|
||||
download_qwen3_half_billion_model,
|
||||
): # pylint: disable=unused-argument
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
||||
|
||||
return tokenizer
|
||||
|
||||
@@ -67,9 +67,9 @@ class TestAssistantChatTemplateLlama3:
|
||||
# fmt: on
|
||||
LOG.debug(f"Expected input_ids: {expected_input_ids}")
|
||||
LOG.debug(f"Actual input_ids: {input_ids}")
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
assert input_ids == expected_input_ids, (
|
||||
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
)
|
||||
|
||||
def test_llama3(self, llama3_tokenizer, assistant_dataset):
|
||||
LOG.info("Testing llama-3 with assistant dataset")
|
||||
@@ -109,9 +109,9 @@ class TestAssistantChatTemplateLlama3:
|
||||
# fmt: on
|
||||
LOG.debug(f"Expected input_ids: {expected_input_ids}")
|
||||
LOG.debug(f"Actual input_ids: {input_ids}")
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
assert input_ids == expected_input_ids, (
|
||||
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
)
|
||||
|
||||
def test_phi35(self, phi35_tokenizer, assistant_dataset):
|
||||
LOG.info("Testing phi-3.5 with assistant dataset")
|
||||
@@ -161,15 +161,15 @@ class TestAssistantChatTemplateLlama3:
|
||||
# fmt: on
|
||||
LOG.debug(f"Expected input_ids: {expected_input_ids}")
|
||||
LOG.debug(f"Actual input_ids: {input_ids}")
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
assert input_ids == expected_input_ids, (
|
||||
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
)
|
||||
|
||||
LOG.debug(f"Expected labels : {expected_labels}")
|
||||
LOG.debug(f"Actual labels : {labels}")
|
||||
assert (
|
||||
labels == expected_labels
|
||||
), f"Input IDs mismatch: {labels} != {expected_labels}"
|
||||
assert labels == expected_labels, (
|
||||
f"Input IDs mismatch: {labels} != {expected_labels}"
|
||||
)
|
||||
|
||||
def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset):
|
||||
LOG.info("Testing llama-3 with assistant dataset including training data")
|
||||
@@ -234,7 +234,7 @@ class TestSharegptChatTemplateLlama3:
|
||||
|
||||
def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):
|
||||
LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts")
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
@@ -285,16 +285,16 @@ class TestSharegptChatTemplateLlama3:
|
||||
LOG.debug(f"Expected labels: {expected_labels}")
|
||||
LOG.debug(f"Actual labels: {labels}")
|
||||
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
assert (
|
||||
labels == expected_labels
|
||||
), f"Labels mismatch: {labels} != {expected_labels}"
|
||||
assert input_ids == expected_input_ids, (
|
||||
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
)
|
||||
assert labels == expected_labels, (
|
||||
f"Labels mismatch: {labels} != {expected_labels}"
|
||||
)
|
||||
|
||||
def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):
|
||||
LOG.info("Testing ShareGPT style datasets with llama-3 human prompts")
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
@@ -345,16 +345,16 @@ class TestSharegptChatTemplateLlama3:
|
||||
LOG.debug(f"Expected labels: {expected_labels}")
|
||||
LOG.debug(f"Actual labels: {labels}")
|
||||
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
assert (
|
||||
labels == expected_labels
|
||||
), f"Labels mismatch: {labels} != {expected_labels}"
|
||||
assert input_ids == expected_input_ids, (
|
||||
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
)
|
||||
assert labels == expected_labels, (
|
||||
f"Labels mismatch: {labels} != {expected_labels}"
|
||||
)
|
||||
|
||||
def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):
|
||||
LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts")
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
@@ -409,12 +409,12 @@ class TestSharegptChatTemplateLlama3:
|
||||
LOG.debug(f"Expected labels: {expected_labels}")
|
||||
LOG.debug(f"Actual labels: {labels}")
|
||||
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
assert (
|
||||
labels == expected_labels
|
||||
), f"Labels mismatch: {labels} != {expected_labels}"
|
||||
assert input_ids == expected_input_ids, (
|
||||
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
)
|
||||
assert labels == expected_labels, (
|
||||
f"Labels mismatch: {labels} != {expected_labels}"
|
||||
)
|
||||
|
||||
|
||||
class TestAssistantToolCallingChatTemplateLlama32Vision:
|
||||
@@ -481,13 +481,13 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
assert input_ids == expected_input_ids, (
|
||||
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
)
|
||||
|
||||
assert (
|
||||
labels == expected_labels
|
||||
), f"Labels mismatch: {labels} != {expected_labels}"
|
||||
assert labels == expected_labels, (
|
||||
f"Labels mismatch: {labels} != {expected_labels}"
|
||||
)
|
||||
|
||||
def test_llama32vision_train_on_tools(
|
||||
self, llama3_tokenizer, toolcalling_dataset, llama3_2_vision_chat_template_jinja
|
||||
@@ -495,7 +495,6 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
|
||||
LOG.info(
|
||||
"Testing assistant style datasets with tool_calling with llama-32 chat template, training on tools"
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
@@ -549,13 +548,13 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
assert input_ids == expected_input_ids, (
|
||||
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
)
|
||||
|
||||
assert (
|
||||
labels == expected_labels
|
||||
), f"Labels mismatch: {labels} != {expected_labels}"
|
||||
assert labels == expected_labels, (
|
||||
f"Labels mismatch: {labels} != {expected_labels}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
tests for chat_template prompt strategy
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
@@ -96,9 +94,9 @@ class TestChatTemplateConfigurations:
|
||||
and turn.get("from") in ["system", "context"]
|
||||
and ("mistral" in tokenizer.name_or_path.lower())
|
||||
):
|
||||
assert (
|
||||
start_idx == -1 and end_idx == -1
|
||||
), "Expected system message to be skipped"
|
||||
assert start_idx == -1 and end_idx == -1, (
|
||||
"Expected system message to be skipped"
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -155,7 +153,9 @@ class TestChatTemplateConfigurations:
|
||||
|
||||
assert all(
|
||||
label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
|
||||
), f"Expected labels for input '{response}' to be ignored, but got {labels[start_idx:end_idx]}"
|
||||
), (
|
||||
f"Expected labels for input '{response}' to be ignored, but got {labels[start_idx:end_idx]}"
|
||||
)
|
||||
|
||||
LOG.debug("Full labels: %s", labels)
|
||||
LOG.debug("Full input_ids: %s", input_ids)
|
||||
@@ -215,11 +215,15 @@ class TestChatTemplateConfigurations:
|
||||
if is_assistant:
|
||||
assert all(
|
||||
label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
|
||||
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}"
|
||||
), (
|
||||
f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}"
|
||||
)
|
||||
else:
|
||||
assert all(
|
||||
label == IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
|
||||
), f"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}"
|
||||
), (
|
||||
f"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}"
|
||||
)
|
||||
|
||||
def test_roles_to_train_human_assistant_only(
|
||||
self,
|
||||
@@ -276,11 +280,15 @@ class TestChatTemplateConfigurations:
|
||||
if should_be_labelled:
|
||||
assert all(
|
||||
label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
|
||||
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}"
|
||||
), (
|
||||
f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}"
|
||||
)
|
||||
else:
|
||||
assert all(
|
||||
label == IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
|
||||
), f"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}"
|
||||
), (
|
||||
f"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}"
|
||||
)
|
||||
|
||||
def test_roles_to_train_all(
|
||||
self,
|
||||
@@ -327,13 +335,15 @@ class TestChatTemplateConfigurations:
|
||||
continue
|
||||
|
||||
decoded_response = tokenizer.decode(input_ids[start_idx:end_idx])
|
||||
assert (
|
||||
response in decoded_response
|
||||
), f"Response {response} not found in index {start_idx}:{end_idx} decoded:{decoded_response}"
|
||||
assert response in decoded_response, (
|
||||
f"Response {response} not found in index {start_idx}:{end_idx} decoded:{decoded_response}"
|
||||
)
|
||||
|
||||
assert all(
|
||||
label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
|
||||
), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:end_idx]}"
|
||||
), (
|
||||
f"Expected labels for response '{response}' to be set, but got {labels[start_idx:end_idx]}"
|
||||
)
|
||||
|
||||
def test_empty_roles_to_train(
|
||||
self,
|
||||
@@ -371,9 +381,9 @@ class TestChatTemplateConfigurations:
|
||||
|
||||
# Verify that no labels are set when roles_to_train is empty
|
||||
LOG.debug("Full labels: %s", labels)
|
||||
assert all(
|
||||
label == IGNORE_TOKEN_ID for label in labels
|
||||
), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty"
|
||||
assert all(label == IGNORE_TOKEN_ID for label in labels), (
|
||||
"Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty"
|
||||
)
|
||||
|
||||
def test_train_on_eos_all(
|
||||
self,
|
||||
@@ -417,9 +427,9 @@ class TestChatTemplateConfigurations:
|
||||
|
||||
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
|
||||
for eos_idx in eos_indices:
|
||||
assert (
|
||||
labels[eos_idx] != IGNORE_TOKEN_ID
|
||||
), f"Expected EOS token at index {eos_idx} to be labeled"
|
||||
assert labels[eos_idx] != IGNORE_TOKEN_ID, (
|
||||
f"Expected EOS token at index {eos_idx} to be labeled"
|
||||
)
|
||||
|
||||
def test_train_on_eos_turn(
|
||||
self,
|
||||
@@ -477,9 +487,9 @@ class TestChatTemplateConfigurations:
|
||||
while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
|
||||
eos_idx += 1
|
||||
|
||||
assert eos_idx < len(
|
||||
input_ids
|
||||
), f"Could not find EOS token after '{response}'"
|
||||
assert eos_idx < len(input_ids), (
|
||||
f"Could not find EOS token after '{response}'"
|
||||
)
|
||||
|
||||
LOG.debug(
|
||||
f"Turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}, eos_idx={eos_idx}"
|
||||
@@ -492,13 +502,13 @@ class TestChatTemplateConfigurations:
|
||||
# Verify EOS token labeling based on role
|
||||
is_assistant = turn["from"] == "assistant"
|
||||
if is_assistant:
|
||||
assert (
|
||||
labels[eos_idx] != IGNORE_TOKEN_ID
|
||||
), f"Expected EOS token after assistant response '{response}' to be labeled"
|
||||
assert labels[eos_idx] != IGNORE_TOKEN_ID, (
|
||||
f"Expected EOS token after assistant response '{response}' to be labeled"
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
labels[eos_idx] == IGNORE_TOKEN_ID
|
||||
), f"Expected EOS token after non-assistant input '{response}' to not be labeled"
|
||||
assert labels[eos_idx] == IGNORE_TOKEN_ID, (
|
||||
f"Expected EOS token after non-assistant input '{response}' to not be labeled"
|
||||
)
|
||||
|
||||
def test_train_on_eos_last(
|
||||
self,
|
||||
@@ -545,12 +555,12 @@ class TestChatTemplateConfigurations:
|
||||
|
||||
# Check that only the last EOS token is labeled
|
||||
for idx in eos_indices[:-1]:
|
||||
assert (
|
||||
labels[idx] == IGNORE_TOKEN_ID
|
||||
), f"Expected EOS token at index {idx} to not be labeled"
|
||||
assert (
|
||||
labels[last_eos_idx] != IGNORE_TOKEN_ID
|
||||
), f"Expected last EOS token at index {last_eos_idx} to be labeled"
|
||||
assert labels[idx] == IGNORE_TOKEN_ID, (
|
||||
f"Expected EOS token at index {idx} to not be labeled"
|
||||
)
|
||||
assert labels[last_eos_idx] != IGNORE_TOKEN_ID, (
|
||||
f"Expected last EOS token at index {last_eos_idx} to be labeled"
|
||||
)
|
||||
|
||||
def test_train_on_eos_none(
|
||||
self,
|
||||
@@ -594,9 +604,9 @@ class TestChatTemplateConfigurations:
|
||||
|
||||
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
|
||||
for eos_idx in eos_indices:
|
||||
assert (
|
||||
labels[eos_idx] == IGNORE_TOKEN_ID
|
||||
), f"Expected EOS token at index {eos_idx} to not be labeled"
|
||||
assert labels[eos_idx] == IGNORE_TOKEN_ID, (
|
||||
f"Expected EOS token at index {eos_idx} to not be labeled"
|
||||
)
|
||||
|
||||
def test_drop_system_message(
|
||||
self,
|
||||
@@ -634,9 +644,9 @@ class TestChatTemplateConfigurations:
|
||||
# Check if system message is not present in input_ids
|
||||
system_message = "You are an AI assistant."
|
||||
decoded_message = tokenizer.decode(input_ids)
|
||||
assert (
|
||||
system_message not in decoded_message
|
||||
), "Expected system message to be dropped"
|
||||
assert system_message not in decoded_message, (
|
||||
"Expected system message to be dropped"
|
||||
)
|
||||
|
||||
def test_custom_roles(
|
||||
self,
|
||||
@@ -711,7 +721,9 @@ class TestChatTemplateConfigurations:
|
||||
else:
|
||||
assert all(
|
||||
label == IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
|
||||
), f"Expected labels for non-AI message '{response}' to be IGNORE_TOKEN_ID"
|
||||
), (
|
||||
f"Expected labels for non-AI message '{response}' to be IGNORE_TOKEN_ID"
|
||||
)
|
||||
|
||||
def test_message_field_training(
|
||||
self,
|
||||
@@ -776,13 +788,13 @@ class TestChatTemplateConfigurations:
|
||||
def verify_labels(labels_span, should_train, context_message):
|
||||
"""Helper to verify if a span of labels matches expected training state"""
|
||||
if should_train:
|
||||
assert all(
|
||||
label != IGNORE_TOKEN_ID for label in labels_span
|
||||
), f"Expected all labels for {context_message} to be set, but got {labels_span}"
|
||||
assert all(label != IGNORE_TOKEN_ID for label in labels_span), (
|
||||
f"Expected all labels for {context_message} to be set, but got {labels_span}"
|
||||
)
|
||||
else:
|
||||
assert all(
|
||||
label == IGNORE_TOKEN_ID for label in labels_span
|
||||
), f"Expected all labels for {context_message} to be {IGNORE_TOKEN_ID}, but got {labels_span}"
|
||||
assert all(label == IGNORE_TOKEN_ID for label in labels_span), (
|
||||
f"Expected all labels for {context_message} to be {IGNORE_TOKEN_ID}, but got {labels_span}"
|
||||
)
|
||||
|
||||
# Process all turns and verify labeling
|
||||
for i, turn in enumerate(modified_dataset[0]["messages"]):
|
||||
@@ -861,9 +873,9 @@ class TestChatTemplateConfigurations:
|
||||
actual_labels = labels[
|
||||
start_idx : start_idx + len(token_offsets_masked)
|
||||
]
|
||||
assert (
|
||||
actual_labels == expected_labels
|
||||
), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}"
|
||||
assert actual_labels == expected_labels, (
|
||||
f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}"
|
||||
)
|
||||
|
||||
# Verify each detail section
|
||||
for detail in adjusted_train_details:
|
||||
@@ -958,7 +970,7 @@ class TestChatTemplateConfigurations:
|
||||
chat_template,
|
||||
chat_template_jinja,
|
||||
eos_token,
|
||||
basic_dataset, # pylint: disable=unused-argument
|
||||
basic_dataset,
|
||||
request,
|
||||
):
|
||||
"""Test that an error is raised when eot_tokens contains eos_token and train_on_eot/train_on_eos conflict"""
|
||||
@@ -1005,7 +1017,7 @@ class TestChatTemplateConfigurations:
|
||||
chat_template,
|
||||
chat_template_jinja,
|
||||
eos_token,
|
||||
basic_dataset, # pylint: disable=unused-argument
|
||||
basic_dataset,
|
||||
request,
|
||||
):
|
||||
"""Test that eot_tokens inherits from eos_token when not specified"""
|
||||
@@ -1032,12 +1044,12 @@ class TestChatTemplateConfigurations:
|
||||
)
|
||||
|
||||
# In backward compatibility mode, eot_tokens should be derived from eos_token
|
||||
assert strategy.eot_tokens == [
|
||||
tokenizer.eos_token
|
||||
], f"Expected eot_tokens to inherit from eos_token, got {strategy.eot_tokens}"
|
||||
assert (
|
||||
strategy.train_on_eot == "turn"
|
||||
), f"Expected train_on_eot to inherit from train_on_eos, got {strategy.train_on_eot}"
|
||||
assert strategy.eot_tokens == [tokenizer.eos_token], (
|
||||
f"Expected eot_tokens to inherit from eos_token, got {strategy.eot_tokens}"
|
||||
)
|
||||
assert strategy.train_on_eot == "turn", (
|
||||
f"Expected train_on_eot to inherit from train_on_eos, got {strategy.train_on_eot}"
|
||||
)
|
||||
|
||||
def test_token_not_in_template(
|
||||
self,
|
||||
@@ -1091,7 +1103,7 @@ class TestChatTemplateConfigurations:
|
||||
tokenizer,
|
||||
chat_template,
|
||||
chat_template_jinja,
|
||||
eos_token, # pylint: disable=unused-argument
|
||||
eos_token,
|
||||
basic_dataset,
|
||||
request,
|
||||
):
|
||||
@@ -1157,13 +1169,13 @@ class TestChatTemplateConfigurations:
|
||||
)
|
||||
|
||||
if is_after_assistant:
|
||||
assert (
|
||||
labels[eot_idx] != IGNORE_TOKEN_ID
|
||||
), f"Expected EOT token after assistant turn at index {eot_idx} to be labeled"
|
||||
assert labels[eot_idx] != IGNORE_TOKEN_ID, (
|
||||
f"Expected EOT token after assistant turn at index {eot_idx} to be labeled"
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
labels[eot_idx] == IGNORE_TOKEN_ID
|
||||
), f"Expected EOT token not after assistant turn at index {eot_idx} to not be labeled"
|
||||
assert labels[eot_idx] == IGNORE_TOKEN_ID, (
|
||||
f"Expected EOT token not after assistant turn at index {eot_idx} to not be labeled"
|
||||
)
|
||||
|
||||
def test_multiple_train_on_eot_settings(
|
||||
self,
|
||||
@@ -1224,9 +1236,9 @@ class TestChatTemplateConfigurations:
|
||||
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
|
||||
]
|
||||
|
||||
assert (
|
||||
len(eos_indices) > 0
|
||||
), "Expected at least one EOS/EOT token in the input"
|
||||
assert len(eos_indices) > 0, (
|
||||
"Expected at least one EOS/EOT token in the input"
|
||||
)
|
||||
|
||||
# Check labeling for each EOS/EOT token
|
||||
for idx, eos_idx in enumerate(eos_indices):
|
||||
@@ -1252,13 +1264,13 @@ class TestChatTemplateConfigurations:
|
||||
)
|
||||
|
||||
if expected_label:
|
||||
assert (
|
||||
labels[eos_idx] == IGNORE_TOKEN_ID
|
||||
), f"Expected EOT token at index {eos_idx} to not be labeled with train_on_eot='{setting}'"
|
||||
assert labels[eos_idx] == IGNORE_TOKEN_ID, (
|
||||
f"Expected EOT token at index {eos_idx} to not be labeled with train_on_eot='{setting}'"
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
labels[eos_idx] != IGNORE_TOKEN_ID
|
||||
), f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'"
|
||||
assert labels[eos_idx] != IGNORE_TOKEN_ID, (
|
||||
f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'"
|
||||
)
|
||||
|
||||
|
||||
class TestChatTemplateToolCalling:
|
||||
@@ -1378,29 +1390,27 @@ class TestChatTemplateToolCalling:
|
||||
decoded_conversation = tokenizer.decode(input_ids)
|
||||
|
||||
# Verify tool calling structure is present in the decoded conversation
|
||||
assert (
|
||||
'"type": "function",' in decoded_conversation
|
||||
), "Tool type function should be in conversation"
|
||||
assert (
|
||||
'"name": "multiples",' in decoded_conversation
|
||||
), "Tool function name should be in conversation"
|
||||
assert '"type": "function",' in decoded_conversation, (
|
||||
"Tool type function should be in conversation"
|
||||
)
|
||||
assert '"name": "multiples",' in decoded_conversation, (
|
||||
"Tool function name should be in conversation"
|
||||
)
|
||||
|
||||
assert (
|
||||
'<|python_start|><|python_end|>{"name": "multiples", "parameters": {"number": 5, "limit": 20}}<|eot|>'
|
||||
in decoded_conversation
|
||||
), "Assistant tool call should be in conversation"
|
||||
assert (
|
||||
"<|header_start|>ipython<|header_end|>" in decoded_conversation
|
||||
), "IPython header should be in conversation"
|
||||
assert (
|
||||
'"5,10,15"' in decoded_conversation
|
||||
), "Tool response should be in conversation"
|
||||
assert "<|header_start|>ipython<|header_end|>" in decoded_conversation, (
|
||||
"IPython header should be in conversation"
|
||||
)
|
||||
assert '"5,10,15"' in decoded_conversation, (
|
||||
"Tool response should be in conversation"
|
||||
)
|
||||
|
||||
# Get conversation turns to verify labeling
|
||||
turns = strategy.get_conversation_thread(tool_calling_dataset[0])
|
||||
tools = strategy._get_tools( # pylint: disable=protected-access
|
||||
tool_calling_dataset[0]
|
||||
)
|
||||
tools = strategy._get_tools(tool_calling_dataset[0])
|
||||
|
||||
# Check that assistant responses are properly labeled
|
||||
for i, turn in enumerate(tool_calling_dataset[0]["messages"]):
|
||||
@@ -1409,12 +1419,12 @@ class TestChatTemplateToolCalling:
|
||||
turns=turns, turn_idx=i, tools=tools
|
||||
)
|
||||
|
||||
assert (
|
||||
start_idx != -1 and end_idx != -1
|
||||
), f"Assistant turn {i} should be found"
|
||||
assert start_idx != -1 and end_idx != -1, (
|
||||
f"Assistant turn {i} should be found"
|
||||
)
|
||||
|
||||
# Verify that assistant responses have proper labels
|
||||
turn_labels = labels[start_idx:end_idx]
|
||||
assert all(
|
||||
label != IGNORE_TOKEN_ID for label in turn_labels
|
||||
), f"Assistant turn {i} should be unmasked"
|
||||
assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
|
||||
f"Assistant turn {i} should be unmasked"
|
||||
)
|
||||
|
||||
@@ -28,7 +28,7 @@ def test_mistral_chat_template(
|
||||
request: pytest.FixtureRequest,
|
||||
):
|
||||
"""Test chat template with the Magistral/Devstral tokenizer"""
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy
|
||||
|
||||
tokenizer: HFMistralTokenizer = request.getfixturevalue(tokenizer_str)
|
||||
|
||||
@@ -59,7 +59,7 @@ def messages_w_reasoning_fixture():
|
||||
@pytest.fixture(name="qwen3_tokenizer")
|
||||
def qwen3_tokenizer_fixture(
|
||||
download_qwen3_half_billion_model,
|
||||
): # pylint: disable=unused-argument
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
||||
|
||||
return tokenizer
|
||||
@@ -71,7 +71,6 @@ class TestSplitThinking:
|
||||
"""
|
||||
|
||||
def test_splits_think(self, messages_w_reasoning, qwen3_tokenizer):
|
||||
# pylint: disable=duplicate-code
|
||||
strategy = load(
|
||||
qwen3_tokenizer,
|
||||
DictDefault(
|
||||
@@ -130,6 +129,6 @@ class TestSplitThinking:
|
||||
198, # \n
|
||||
]
|
||||
# fmt: on
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
assert input_ids == expected_input_ids, (
|
||||
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
)
|
||||
|
||||
@@ -16,7 +16,6 @@ from tests.hf_offline_utils import enable_hf_offline
|
||||
|
||||
@pytest.fixture(name="assistant_dataset")
|
||||
def fixture_assistant_dataset():
|
||||
# pylint: disable=duplicate-code
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
@@ -49,7 +48,6 @@ def fixture_assistant_dataset():
|
||||
|
||||
@pytest.fixture(name="custom_assistant_dataset")
|
||||
def fixture_custom_assistant_dataset():
|
||||
# pylint: disable=duplicate-code
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
@@ -102,7 +100,6 @@ class TestAssistantDPOChatTemplateLlama3:
|
||||
"""
|
||||
|
||||
def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
transform_fn, _ = default(
|
||||
DictDefault(
|
||||
{
|
||||
@@ -127,7 +124,6 @@ class TestAssistantDPOChatTemplateLlama3:
|
||||
assert result["rejected"] == "party on<|eot_id|>"
|
||||
|
||||
def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
transform_fn, _ = default(
|
||||
DictDefault(
|
||||
{
|
||||
@@ -168,7 +164,6 @@ class TestAssistantDPOChatTemplatePhi3:
|
||||
"""
|
||||
|
||||
def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
transform_fn, _ = default(
|
||||
DictDefault(
|
||||
{
|
||||
@@ -198,7 +193,6 @@ class TestAssistantDPOChatTemplateGemma:
|
||||
"""
|
||||
|
||||
def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
transform_fn, _ = default(
|
||||
DictDefault(
|
||||
{
|
||||
|
||||
@@ -20,7 +20,6 @@ class TestStepWiseSupervisedPromptTokenizingStrategy:
|
||||
|
||||
@pytest.fixture()
|
||||
def stepwise_supervised_dataset(self):
|
||||
# pylint: disable=duplicate-code
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
|
||||
@@ -22,7 +22,7 @@ def chunked_fixtures():
|
||||
return lm_head, hidden_state, labels, vocab_size
|
||||
|
||||
|
||||
def test_chunked_forward(chunked_fixtures): # pylint: disable=redefined-outer-name
|
||||
def test_chunked_forward(chunked_fixtures):
|
||||
lm_head, hidden_state, labels, vocab_size = chunked_fixtures
|
||||
lm_loss = get_causal_lm_loss()
|
||||
|
||||
|
||||
@@ -374,7 +374,6 @@ class TestDatasetPreparation:
|
||||
}
|
||||
)
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
with patch(
|
||||
"axolotl.utils.data.rl.load_dataset_with_config"
|
||||
) as mock_load_dataset:
|
||||
|
||||
@@ -21,26 +21,26 @@ class DictDefaultTest(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
assert (
|
||||
cfg.key_a.key_b == "value_a"
|
||||
), "DictDefault should return value for existing nested keys"
|
||||
assert cfg.key_a.key_b == "value_a", (
|
||||
"DictDefault should return value for existing nested keys"
|
||||
)
|
||||
|
||||
assert (
|
||||
cfg.key_c == "value_c"
|
||||
), "DictDefault should return value for existing keys"
|
||||
assert cfg.key_c == "value_c", (
|
||||
"DictDefault should return value for existing keys"
|
||||
)
|
||||
|
||||
assert (
|
||||
cfg.key_d[0] == "value_d"
|
||||
), "DictDefault should return value for existing keys in list"
|
||||
assert cfg.key_d[0] == "value_d", (
|
||||
"DictDefault should return value for existing keys in list"
|
||||
)
|
||||
|
||||
assert (
|
||||
"value_e" in cfg.key_d
|
||||
), "DictDefault should support in operator for existing keys in list"
|
||||
assert "value_e" in cfg.key_d, (
|
||||
"DictDefault should support in operator for existing keys in list"
|
||||
)
|
||||
|
||||
def test_dict_or_operator(self):
|
||||
cfg = DictDefault({"key_a": {"key_b": "value_b"}, "key_f": "value_g"})
|
||||
|
||||
cfg = cfg | DictDefault( # pylint: disable=unsupported-binary-operation
|
||||
cfg = cfg | DictDefault(
|
||||
{
|
||||
"key_a": {"key_b": "value_a"},
|
||||
"key_c": "value_c",
|
||||
@@ -49,9 +49,9 @@ class DictDefaultTest(unittest.TestCase):
|
||||
}
|
||||
)
|
||||
|
||||
assert (
|
||||
cfg.key_a.key_b == "value_b"
|
||||
), "DictDefault should support OR operator for existing nested keys"
|
||||
assert cfg.key_a.key_b == "value_b", (
|
||||
"DictDefault should support OR operator for existing nested keys"
|
||||
)
|
||||
|
||||
assert cfg.key_c == "value_c", "DictDefault should not delete existing key"
|
||||
|
||||
@@ -60,9 +60,9 @@ class DictDefaultTest(unittest.TestCase):
|
||||
"value_e",
|
||||
], "DictDefault should not overwrite existing keys in list"
|
||||
|
||||
assert (
|
||||
cfg.key_f == "value_g"
|
||||
), "DictDefault should support OR operator for existing key"
|
||||
assert cfg.key_f == "value_g", (
|
||||
"DictDefault should support OR operator for existing key"
|
||||
)
|
||||
|
||||
def test_dict_missingkey(self):
|
||||
cfg = DictDefault({})
|
||||
@@ -72,9 +72,9 @@ class DictDefaultTest(unittest.TestCase):
|
||||
def test_dict_or(self):
|
||||
cfg = DictDefault({}) | DictDefault({})
|
||||
|
||||
assert (
|
||||
cfg.random_key is None
|
||||
), "DictDefault should return None for missing keys after | operation"
|
||||
assert cfg.random_key is None, (
|
||||
"DictDefault should return None for missing keys after | operation"
|
||||
)
|
||||
|
||||
def test_dict_nested_missingparentkey(self):
|
||||
"""
|
||||
|
||||
@@ -41,9 +41,9 @@ def verify_deduplication(actual_dataset, expected_dataset, dataset_name):
|
||||
assert actual_rows == expected_rows, f"Mismatch in {dataset_name} dataset"
|
||||
|
||||
# Verify size consistency
|
||||
assert len(actual_rows) == len(
|
||||
actual_dataset
|
||||
), f"Size mismatch in {dataset_name} dataset after deduplication"
|
||||
assert len(actual_rows) == len(actual_dataset), (
|
||||
f"Size mismatch in {dataset_name} dataset after deduplication"
|
||||
)
|
||||
|
||||
|
||||
class TestDeduplicateIndividualFunctions(unittest.TestCase):
|
||||
@@ -224,7 +224,6 @@ class TestDeduplicateRLDataset:
|
||||
):
|
||||
"""Verify that loading with deduplication removes duplicates."""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
with (
|
||||
patch(
|
||||
"axolotl.utils.data.rl.load_dataset_with_config"
|
||||
@@ -251,7 +250,6 @@ class TestDeduplicateRLDataset:
|
||||
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
|
||||
tokenizer_huggyllama,
|
||||
):
|
||||
# pylint: disable=duplicate-code
|
||||
with (
|
||||
patch(
|
||||
"axolotl.utils.data.rl.load_dataset_with_config"
|
||||
@@ -271,9 +269,9 @@ class TestDeduplicateRLDataset:
|
||||
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
|
||||
|
||||
# Verify that the dataset retains duplicates
|
||||
assert (
|
||||
len(train_dataset) == 1800 * 2
|
||||
), "Dataset deduplication occurred when it should not have"
|
||||
assert len(train_dataset) == 1800 * 2, (
|
||||
"Dataset deduplication occurred when it should not have"
|
||||
)
|
||||
|
||||
|
||||
class TestDeduplicateNonRL(unittest.TestCase):
|
||||
|
||||
@@ -17,7 +17,7 @@ class TestModelsUtils:
|
||||
|
||||
def setup_method(self) -> None:
|
||||
# load config
|
||||
self.cfg = DictDefault( # pylint: disable=attribute-defined-outside-init
|
||||
self.cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
"model_type": "AutoModelForCausalLM",
|
||||
@@ -30,20 +30,16 @@ class TestModelsUtils:
|
||||
"device_map": "auto",
|
||||
}
|
||||
)
|
||||
self.tokenizer = MagicMock( # pylint: disable=attribute-defined-outside-init
|
||||
spec=PreTrainedTokenizerBase
|
||||
)
|
||||
self.inference = False # pylint: disable=attribute-defined-outside-init
|
||||
self.reference_model = True # pylint: disable=attribute-defined-outside-init
|
||||
self.tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
|
||||
self.inference = False
|
||||
self.reference_model = True
|
||||
|
||||
# init ModelLoader
|
||||
self.model_loader = ( # pylint: disable=attribute-defined-outside-init
|
||||
ModelLoader(
|
||||
cfg=self.cfg,
|
||||
tokenizer=self.tokenizer,
|
||||
inference=self.inference,
|
||||
reference_model=self.reference_model,
|
||||
)
|
||||
self.model_loader = ModelLoader(
|
||||
cfg=self.cfg,
|
||||
tokenizer=self.tokenizer,
|
||||
inference=self.inference,
|
||||
reference_model=self.reference_model,
|
||||
)
|
||||
|
||||
def test_set_device_map_config(self):
|
||||
@@ -51,7 +47,7 @@ class TestModelsUtils:
|
||||
device_map = self.cfg.device_map
|
||||
if is_torch_mps_available():
|
||||
device_map = "mps"
|
||||
# pylint: disable=protected-access
|
||||
|
||||
self.model_loader._set_device_map_config()
|
||||
if is_deepspeed_zero3_enabled():
|
||||
assert "device_map" not in self.model_loader.model_kwargs
|
||||
@@ -78,7 +74,6 @@ class TestModelsUtils:
|
||||
self.cfg.gptq = gptq
|
||||
self.cfg.adapter = adapter
|
||||
|
||||
# pylint: disable=protected-access
|
||||
self.model_loader._set_quantization_config()
|
||||
if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq:
|
||||
assert not (
|
||||
@@ -194,7 +189,7 @@ class TestModelsUtils:
|
||||
is_fsdp,
|
||||
expected,
|
||||
):
|
||||
res = _get_parallel_config_kwargs( # pylint: disable=protected-access
|
||||
res = _get_parallel_config_kwargs(
|
||||
world_size,
|
||||
tensor_parallel_size,
|
||||
context_parallel_size,
|
||||
|
||||
@@ -6,7 +6,6 @@ from axolotl.loaders import ModelLoader, load_tokenizer
|
||||
from axolotl.utils.config import normalize_config, validate_config
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
minimal_config = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
|
||||
@@ -93,7 +93,7 @@ class TestBatchedSamplerPacking:
|
||||
loader = DataLoader(
|
||||
train_dataset,
|
||||
batch_sampler=batch_sampler,
|
||||
collate_fn=V2BatchSamplerDataCollatorForSeq2Seq( # pylint: disable=unexpected-keyword-arg
|
||||
collate_fn=V2BatchSamplerDataCollatorForSeq2Seq(
|
||||
tokenizer=tokenizer,
|
||||
padding=True,
|
||||
pad_to_multiple_of=max_seq_length,
|
||||
|
||||
@@ -26,7 +26,6 @@ class TestPacking(unittest.TestCase):
|
||||
|
||||
@enable_hf_offline
|
||||
def setUp(self) -> None:
|
||||
# pylint: disable=duplicate-code
|
||||
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
|
||||
self.tokenizer.add_special_tokens(
|
||||
{
|
||||
@@ -75,7 +74,6 @@ class TestPacking(unittest.TestCase):
|
||||
|
||||
@with_temp_dir
|
||||
def test_lora_packing(self, temp_dir):
|
||||
# pylint: disable=duplicate-code
|
||||
cfg = DictDefault(
|
||||
{
|
||||
"base_model": "HuggingFaceTB/SmolLM2-135M",
|
||||
@@ -127,9 +125,7 @@ class TestPacking(unittest.TestCase):
|
||||
_,
|
||||
) = setup_model_and_trainer(cfg, dataset_meta)
|
||||
|
||||
sampler = trainer._get_eval_sampler( # pylint: disable=protected-access
|
||||
trainer.eval_dataset
|
||||
)
|
||||
sampler = trainer._get_eval_sampler(trainer.eval_dataset)
|
||||
assert "MultipackBatchSampler" in sampler.__class__.__name__
|
||||
assert (
|
||||
"V2BatchSamplerDataCollatorForSeq2Seq"
|
||||
@@ -140,9 +136,7 @@ class TestPacking(unittest.TestCase):
|
||||
batch = next(dataloader_iter)
|
||||
assert batch["input_ids"].shape == (1, 8192)
|
||||
|
||||
sampler = trainer._get_train_sampler( # pylint: disable=protected-access
|
||||
trainer.train_dataset
|
||||
)
|
||||
sampler = trainer._get_train_sampler(trainer.train_dataset)
|
||||
assert "MultipackBatchSampler" in sampler.__class__.__name__
|
||||
assert (
|
||||
"V2BatchSamplerDataCollatorForSeq2Seq"
|
||||
|
||||
@@ -76,7 +76,6 @@ class TestPretrainingPacking:
|
||||
cfg.pretraining_dataset[0]["type"] or "pretrain",
|
||||
)
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
original_bsz = cfg.micro_batch_size
|
||||
train_dataset = wrap_pretraining_dataset(
|
||||
dataset,
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
"""unit tests for perplexity eval callback"""
|
||||
|
||||
# pylint: disable=redefined-outer-name
|
||||
|
||||
from pytest import fixture
|
||||
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
|
||||
from transformers.models.auto.tokenization_auto import AutoTokenizer
|
||||
|
||||
@@ -64,7 +64,7 @@ class TestPromptTokenizationStrategies:
|
||||
tests the interface between the user and assistant parts
|
||||
"""
|
||||
prompter = NoSystemPrompter()
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
strat = AlpacaPromptTokenizingStrategy(
|
||||
prompter,
|
||||
tokenizer_huggyllama_w_special_tokens,
|
||||
@@ -85,7 +85,7 @@ class TestPromptTokenizationStrategies:
|
||||
"""
|
||||
tests the interface between the user and assistant parts
|
||||
"""
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
prompter = AlpacaPrompter()
|
||||
strat = AlpacaPromptTokenizingStrategy(
|
||||
prompter,
|
||||
@@ -171,7 +171,7 @@ class Llama2ChatTokenizationTest:
|
||||
# from transformers.models.llama.tokenization_llama import DEFAULT_SYSTEM_PROMPT
|
||||
# broken as of 23/7/20
|
||||
# see https://github.com/huggingface/transformers/pull/24935
|
||||
# pylint: disable=C0103
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = """\
|
||||
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
|
||||
|
||||
@@ -201,7 +201,7 @@ If a question does not make any sense, or is not factually coherent, explain why
|
||||
+ user_input[1:-1],
|
||||
generated_responses=answers,
|
||||
)
|
||||
# pylint: disable=W0212
|
||||
|
||||
hf_tokens = tokenizer_llama2_7b._build_conversation_input_ids(hf_conf)
|
||||
|
||||
assert hf_tokens == tokenized_conversation["input_ids"][: len(hf_tokens)]
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user