Add ruff, remove black, isort, flake8, pylint (#3092)

* black, isort, flake8 -> ruff

* remove unused

* add back needed import

* fix

Dan Saunders authored 2025-08-23 23:37:33 -04:00, committed by GitHub
parent eea7a006e1
commit 79ddaebe9a
286 changed files with 10979 additions and 11435 deletions
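
Most of the churn below is mechanical and falls in two buckets: pylint disable pragmas that ruff no longer needs are deleted, and multi-line asserts are reflowed from black's layout to the one ruff's formatter emits. Both shapes of the assert pattern, taken verbatim from hunks in this diff (the stand-in value just makes the snippet runnable):

checkpoint_files = ["checkpoint-1"]  # stand-in value so the snippet runs

# before (black): the condition is wrapped, the message trails the closing paren
assert (
    len(checkpoint_files) > 0
), "No checkpoint files found - training may have failed"

# after (ruff format): the condition stays inline, the message is wrapped
assert len(checkpoint_files) > 0, (
    "No checkpoint files found - training may have failed"
)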

View File

@@ -1,7 +1,5 @@
"""Tests for evaluate CLI command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli
@@ -31,7 +29,6 @@ class TestEvaluateCommand(BaseCliTest):
 config_path = tmp_path / "config.yml"
 config_path.write_text(valid_test_config)
-# pylint: disable=duplicate-code
 with patch("axolotl.cli.evaluate.do_evaluate") as mock_evaluate:
 result = cli_runner.invoke(
 cli,

View File

@@ -1,7 +1,5 @@
"""pytest tests for axolotl CLI inference command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli

View File

@@ -1,7 +1,5 @@
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
# pylint: disable=duplicate-code
from unittest.mock import patch
from axolotl.cli.main import cli

View File

@@ -1,7 +1,5 @@
"""Tests for train CLI command."""
# pylint: disable=duplicate-code
from unittest.mock import MagicMock, patch
from axolotl.cli.main import cli

View File

@@ -1,7 +1,5 @@
"""pytest tests for axolotl CLI utils."""
# pylint: disable=redefined-outer-name
import json
from unittest.mock import Mock, patch
@@ -25,7 +23,7 @@ MOCK_TREE_RESPONSE = {
def mock_responses():
"""Mock responses for API and file downloads"""
def mock_get(url, timeout=None): # pylint: disable=unused-argument
def mock_get(url, timeout=None):
response = Mock()
if "api.github.com" in url:
response.text = json.dumps(MOCK_TREE_RESPONSE)
@@ -93,21 +91,21 @@ def assert_launcher_args_in_command(
called_cmd = mock_subprocess_call.call_args.args[0]
# Verify launcher
assert (
called_cmd[0] == launcher
), f"Expected launcher {launcher}, got {called_cmd[0]}"
assert called_cmd[0] == launcher, (
f"Expected launcher {launcher}, got {called_cmd[0]}"
)
# Verify launcher args are present
for arg in expected_launcher_args:
assert (
arg in called_cmd
), f"Expected launcher arg '{arg}' not found in command: {called_cmd}"
assert arg in called_cmd, (
f"Expected launcher arg '{arg}' not found in command: {called_cmd}"
)
# Verify module is present
assert "-m" in called_cmd, "Expected -m flag for module execution"
assert (
command_module in called_cmd
), f"Expected module {command_module} not found in command: {called_cmd}"
assert command_module in called_cmd, (
f"Expected module {command_module} not found in command: {called_cmd}"
)
def assert_no_launcher_args_contamination(mock_subprocess_call, launcher: str):
@@ -126,17 +124,17 @@ def assert_no_launcher_args_contamination(mock_subprocess_call, launcher: str):
launch_idx = called_cmd.index("launch")
m_idx = called_cmd.index("-m")
launcher_section = called_cmd[launch_idx + 1 : m_idx]
assert (
len(launcher_section) == 0
), f"Unexpected launcher args found: {launcher_section}"
assert len(launcher_section) == 0, (
f"Unexpected launcher args found: {launcher_section}"
)
elif launcher == "torchrun":
# For torchrun, launcher args should be between 'torchrun' and '-m'
torchrun_idx = called_cmd.index("torchrun")
m_idx = called_cmd.index("-m")
launcher_section = called_cmd[torchrun_idx + 1 : m_idx]
assert (
len(launcher_section) == 0
), f"Unexpected launcher args found: {launcher_section}"
assert len(launcher_section) == 0, (
f"Unexpected launcher args found: {launcher_section}"
)
@pytest.fixture

View File

@@ -33,10 +33,9 @@ logging.getLogger("filelock").setLevel(logging.CRITICAL)
 def retry_on_request_exceptions(max_retries=3, delay=1):
-# pylint: disable=duplicate-code
 def decorator(func):
 @functools.wraps(func)
-def wrapper(*args, **kwargs): # pylint: disable=inconsistent-return-statements
+def wrapper(*args, **kwargs):
 for attempt in range(max_retries):
 try:
 return func(*args, **kwargs)
@@ -171,7 +170,7 @@ def download_argilla_distilabel_intel_orca_dpo_dataset():
 # @disable_hf_offline
 # def dataset_fozzie_alpaca_dpo_dataset(
 # download_fozzie_alpaca_dpo_dataset,
-# ): # pylint: disable=unused-argument,redefined-outer-name
+# ):
 # return load_dataset("fozziethebeat/alpaca_messages_2k_dpo_test", split="train")
 #
 #
@@ -179,7 +178,7 @@ def download_argilla_distilabel_intel_orca_dpo_dataset():
 # @disable_hf_offline
 # def dataset_fozzie_alpaca_dpo_dataset_rev_ea82cff(
 # download_fozzie_alpaca_dpo_dataset,
-# ): # pylint: disable=unused-argument,redefined-outer-name
+# ):
 # return load_dataset(
 # "fozziethebeat/alpaca_messages_2k_dpo_test", split="train", revision="ea82cff"
 # )
@@ -359,7 +358,7 @@ def download_llama32_1b_model_fixture():
 @enable_hf_offline
 def tokenizer_huggyllama(
 download_huggyllama_model_fixture,
-): # pylint: disable=unused-argument,redefined-outer-name
+):
 tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
 tokenizer.pad_token = "</s>"
@@ -370,7 +369,7 @@ def tokenizer_huggyllama(
 @enable_hf_offline
 def tokenizer_huggyllama_w_special_tokens(
 tokenizer_huggyllama,
-): # pylint: disable=redefined-outer-name
+):
 tokenizer_huggyllama.add_special_tokens(
 {
 "bos_token": "<s>",
@@ -386,7 +385,7 @@ def tokenizer_huggyllama_w_special_tokens(
 @enable_hf_offline
 def tokenizer_llama2_7b(
 download_llama2_model_fixture,
-): # pylint: disable=unused-argument,redefined-outer-name
+):
 tokenizer = AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-hf")
 return tokenizer
@@ -396,7 +395,7 @@ def tokenizer_llama2_7b(
 @enable_hf_offline
 def tokenizer_mistral_7b_instruct(
 download_mlx_mistral_7b_model_fixture,
-): # pylint: disable=unused-argument,redefined-outer-name
+):
 return AutoTokenizer.from_pretrained("casperhansen/mistral-7b-instruct-v0.1-awq")
@@ -442,9 +441,7 @@ def cleanup_monkeypatches():
 # original_fa2_forward = LlamaFlashAttention2.forward
 original_llama_attn_forward = LlamaAttention.forward
 original_llama_forward = LlamaForCausalLM.forward
-original_trainer_inner_training_loop = (
-Trainer._inner_training_loop # pylint: disable=protected-access
-)
+original_trainer_inner_training_loop = Trainer._inner_training_loop
 original_trainer_training_step = Trainer.training_step
 # monkey patches can happen inside the tests
 yield
@@ -452,9 +449,7 @@ def cleanup_monkeypatches():
 # LlamaFlashAttention2.forward = original_fa2_forward
 LlamaAttention.forward = original_llama_attn_forward
 LlamaForCausalLM.forward = original_llama_forward
-Trainer._inner_training_loop = ( # pylint: disable=protected-access
-original_trainer_inner_training_loop
-)
+Trainer._inner_training_loop = original_trainer_inner_training_loop
 Trainer.training_step = original_trainer_training_step
 # Reset other known monkeypatches
@@ -490,7 +485,7 @@ def cleanup_monkeypatches():
 @pytest.fixture
 def dataset_winglian_tiny_shakespeare(
 download_ds_fixture_bundle: Path,
-): # pylint: disable=redefined-outer-name
+):
 ds_path = download_ds_fixture_bundle / "winglian__tiny-shakespeare"
 return datasets.load_from_disk(ds_path)
@@ -498,7 +493,7 @@ def dataset_winglian_tiny_shakespeare(
 @pytest.fixture
 def dataset_tatsu_lab_alpaca(
 download_ds_fixture_bundle: Path,
-): # pylint: disable=redefined-outer-name
+):
 ds_path = download_ds_fixture_bundle / "tatsu-lab__alpaca"
 return datasets.load_from_disk(ds_path)["train"]
@@ -506,7 +501,7 @@ def dataset_tatsu_lab_alpaca(
 @pytest.fixture
 def dataset_mhenrichsen_alpaca_2k_test(
 download_ds_fixture_bundle: Path,
-): # pylint: disable=redefined-outer-name
+):
 ds_path = download_ds_fixture_bundle / "mhenrichsen__alpaca_2k_test"
 return datasets.load_from_disk(ds_path)["train"]
@@ -514,7 +509,7 @@ def dataset_mhenrichsen_alpaca_2k_test(
 @pytest.fixture
 def dataset_argilla_ultrafeedback_binarized_preferences_cleaned(
 download_ds_fixture_bundle: Path,
-): # pylint: disable=redefined-outer-name
+):
 ds_path = (
 download_ds_fixture_bundle
 / "argilla__ultrafeedback-binarized-preferences-cleaned"
@@ -525,7 +520,7 @@ def dataset_argilla_ultrafeedback_binarized_preferences_cleaned(
 @pytest.fixture
 def dataset_fozziethebeat_alpaca_messages_2k_dpo_test(
 download_ds_fixture_bundle: Path,
-): # pylint: disable=redefined-outer-name
+):
 ds_path = download_ds_fixture_bundle / "fozziethebeat__alpaca_messages_2k_dpo_test"
 return datasets.load_from_disk(ds_path)["train"]
@@ -533,7 +528,7 @@ def dataset_fozziethebeat_alpaca_messages_2k_dpo_test(
 @pytest.fixture
 def dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff(
 download_ds_fixture_bundle: Path,
-): # pylint: disable=redefined-outer-name
+):
 ds_path = (
 download_ds_fixture_bundle
 / "fozziethebeat__alpaca_messages_2k_dpo_test__rev_ea82cff"
@@ -557,7 +552,7 @@ def fixture_min_base_cfg():
 )
-# # pylint: disable=redefined-outer-name,unused-argument
 #
 @pytest.mark.skipif(
 os.environ.get("AXOLOTL_IS_CI_CACHE_PRELOAD", "-1") != "1",
 reason="Not running in CI cache preload",

View File

@@ -3,6 +3,7 @@
 This module contains constants and configuration dictionaries used for
 datasets and other utilities in the Axolotl project, specifically for testing.
 """
+
 # Configuration for Alpaca Messages Dataset
 ALPACA_MESSAGES_CONFIG_OG = {
 "path": "fozziethebeat/alpaca_messages_2k_dpo_test",

View File

@@ -1,7 +1,5 @@
"""Unit tests for axolotl.core.builders"""
# pylint: disable=protected-access
import sys
from pathlib import Path
from unittest.mock import patch
@@ -330,7 +328,6 @@ def rand_reward_func(prompts, completions) -> list[float]:
)
def test_grpo_training_arguments(self, grpo_cfg, model, tokenizer, tmp_path):
rewards_dir = tmp_path / "rewards_test"
self._write_rewards_file(rewards_dir)
@@ -477,7 +474,7 @@ def rand_reward_func(prompts, completions) -> list[float]:
assert trainer.optimizer_cls_and_kwargs is not None
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
from axolotl.contribs.mit.muon import (
Muon,
MuonOptimizerFactory,
)
@@ -559,7 +556,7 @@ class TestHFCausalTrainerBuilder:
assert trainer.optimizer_cls_and_kwargs is not None
from axolotl.contribs.mit.muon import ( # pylint: disable=no-name-in-module
from axolotl.contribs.mit.muon import (
Muon,
MuonOptimizerFactory,
)
@@ -599,6 +596,6 @@ class TestTrainerClsPlugin:
except TypeError as e:
# Error raised if trainer_cls is None
assert "'tuple' object has no attribute 'config'" not in str(e)
except Exception: # pylint: disable=broad-exception-caught
except Exception:
# Another error happens, so we passed trainer_cls to builder
pass

View File

@@ -12,8 +12,6 @@ from axolotl.utils.dict import DictDefault
 from ..utils import check_model_output_exists
-# pylint: disable=duplicate-code
 @pytest.fixture()
 def min_cfg(temp_dir):
@@ -53,7 +51,6 @@ class TestCutCrossEntropyIntegration:
 e2e tests for cut_cross_entropy integration with Axolotl
 """
-# pylint: disable=redefined-outer-name
 def test_llama_w_cce(self, min_cfg, temp_dir):
 cfg = DictDefault(min_cfg)
 cfg = validate_config(cfg)
@@ -69,7 +66,6 @@ class TestCutCrossEntropyIntegration:
 train(cfg=cfg, dataset_meta=dataset_meta)
 check_model_output_exists(temp_dir, cfg)
-# pylint: disable=redefined-outer-name
 def test_qwen2_w_cce(self, temp_dir):
 cfg = DictDefault(
 {

View File

@@ -18,7 +18,7 @@ class FP8IntegrationTestCase:
 @require_torch_2_7_0
 def test_fp8_single_gpu_smoke(self, temp_dir):
 """Smoke test for single GPU FP8 + torch.compile training"""
-# pylint: disable=duplicate-code
+
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -53,7 +53,6 @@ class FP8IntegrationTestCase:
 }
 )
-# pylint: disable=duplicate-code
 cfg = validate_config(cfg)
 normalize_config(cfg)
 dataset_meta = load_datasets(cfg=cfg)

View File

@@ -28,85 +28,81 @@ class LogHooksPlugin(BasePlugin):
 except FileNotFoundError:
 pass
-def post_trainer_create(self, cfg, trainer): # pylint: disable=unused-argument
+def post_trainer_create(self, cfg, trainer):
 with open(
 self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
 ) as f:
 f.write("post_trainer_create\n")
-def pre_model_load(self, cfg): # pylint: disable=unused-argument
+def pre_model_load(self, cfg):
 with open(
 self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
 ) as f:
 f.write("pre_model_load\n")
-def post_model_build(self, cfg, model): # pylint: disable=unused-argument
+def post_model_build(self, cfg, model):
 with open(
 self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
 ) as f:
 f.write("post_model_build\n")
-def pre_lora_load(self, cfg, model): # pylint: disable=unused-argument
+def pre_lora_load(self, cfg, model):
 with open(
 self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
 ) as f:
 f.write("pre_lora_load\n")
-def post_lora_load(self, cfg, model): # pylint: disable=unused-argument
+def post_lora_load(self, cfg, model):
 with open(
 self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
 ) as f:
 f.write("post_lora_load\n")
-def post_model_load(self, cfg, model): # pylint: disable=unused-argument
+def post_model_load(self, cfg, model):
 with open(
 self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
 ) as f:
 f.write("post_model_load\n")
-def create_optimizer(self, cfg, trainer): # pylint: disable=unused-argument
+def create_optimizer(self, cfg, trainer):
 with open(
 self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
 ) as f:
 f.write("create_optimizer\n")
-def get_trainer_cls(self, cfg): # pylint: disable=unused-argument
+def get_trainer_cls(self, cfg):
 with open(
 self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
 ) as f:
 f.write("get_trainer_cls\n")
-def create_lr_scheduler(
-self, cfg, trainer, optimizer, num_training_steps
-): # pylint: disable=unused-argument
+def create_lr_scheduler(self, cfg, trainer, optimizer, num_training_steps):
 with open(
 self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
 ) as f:
 f.write("create_lr_scheduler\n")
-def add_callbacks_pre_trainer(self, cfg, model): # pylint: disable=unused-argument
+def add_callbacks_pre_trainer(self, cfg, model):
 with open(
 self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
 ) as f:
 f.write("add_callbacks_pre_trainer\n")
 return []
-def add_callbacks_post_trainer(
-self, cfg, trainer
-): # pylint: disable=unused-argument
+def add_callbacks_post_trainer(self, cfg, trainer):
 with open(
 self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
 ) as f:
 f.write("add_callbacks_post_trainer\n")
 return []
-def post_train(self, cfg, model): # pylint: disable=unused-argument
+def post_train(self, cfg, model):
 with open(
 self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
 ) as f:
 f.write("post_train\n")
-def post_train_unload(self, cfg): # pylint: disable=unused-argument
+def post_train_unload(self, cfg):
 with open(
 self.base_dir.joinpath("plugin_hooks.log"), "a", encoding="utf-8"
 ) as f:
@@ -119,7 +115,6 @@ class TestPluginHooks:
 """
 def test_plugin_hooks(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",
View File

@@ -81,7 +81,7 @@ class TestKnowledgeDistillation:
 @require_torch_2_5_1
 def test_llama_kd(self, temp_dir, kd_min_cfg):
 cfg = DictDefault(kd_min_cfg)
-# pylint: disable=duplicate-code
+
 # write cfg to yaml file
 Path(temp_dir).mkdir(parents=True, exist_ok=True)
 with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:
@@ -123,7 +123,7 @@ class TestKnowledgeDistillation:
 }
 | kd_min_cfg
 )
-# pylint: disable=duplicate-code
+
 # write cfg to yaml file
 Path(temp_dir).mkdir(parents=True, exist_ok=True)
 with open(Path(temp_dir) / "config.yaml", "w", encoding="utf-8") as fout:

View File

@@ -17,7 +17,6 @@ class LigerIntegrationTestCase:
 @require_torch_2_4_1
 def test_llama_wo_flce(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -53,7 +52,7 @@ class LigerIntegrationTestCase:
 "save_first_step": False,
 }
 )
-# pylint: disable=duplicate-code
+
 cfg = validate_config(cfg)
 prepare_plugins(cfg)
 normalize_config(cfg)
@@ -64,7 +63,6 @@ class LigerIntegrationTestCase:
 @require_torch_2_4_1
 def test_llama_w_flce(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -100,7 +98,7 @@ class LigerIntegrationTestCase:
 "save_first_step": False,
 }
 )
-# pylint: disable=duplicate-code
+
 cfg = validate_config(cfg)
 prepare_plugins(cfg)
 normalize_config(cfg)

View File

@@ -85,6 +85,6 @@ def test_geglu_inplace_preservation():
 assert not torch.equal(gate, gate_copy), "Gate should be modified in-place"
 assert not torch.equal(up, up_copy), "Up should be modified in-place"
-assert not torch.equal(
-grad_output, grad_copy
-), "Grad output should be modified in-place"
+assert not torch.equal(grad_output, grad_copy), (
+"Grad output should be modified in-place"
+)

View File

@@ -1,7 +1,5 @@
"""Tests for LoRA custom autograd."""
# pylint: disable=invalid-name,redefined-outer-name
import pytest
import torch
from bitsandbytes.functional import QuantState
@@ -333,7 +331,7 @@ def test_lora_qkv(sample_tensors):
X.requires_grad = True
# Test without LoRA adapters
# pylint: disable=duplicate-code
Q1, K1, V1 = LoRA_QKV.apply(
X,
q_weight,

View File

@@ -1,7 +1,5 @@
"""Tests for quantization utility functions."""
# pylint: disable=invalid-name
import torch
from bitsandbytes.functional import QuantState

View File

@@ -1,7 +1,5 @@
"""Tests for SwiGLU activation function Triton kernels."""
# pylint: disable=duplicate-code
import torch
import torch.nn.functional as F
@@ -74,6 +72,6 @@ def test_swiglu_inplace_preservation():
assert not torch.equal(gate, gate_copy), "Gate should be modified in-place"
assert not torch.equal(up, up_copy), "Up should be modified in-place"
assert not torch.equal(
grad_output, grad_copy
), "Grad output should be modified in-place"
assert not torch.equal(grad_output, grad_copy), (
"Grad output should be modified in-place"
)

View File

@@ -31,7 +31,6 @@ class TestPackedFlex:
 @require_torch_2_6_0
 def test_loss_llama(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -80,7 +80,7 @@ def start_vllm(
 cmd_env = env.copy()
 cmd_env.update({"VLLM_LOGGING_CONFIG_PATH": vllm_logging_json})
 # start `trl vllm-serve` command in the background and capture the process id
-process = subprocess.Popen( # pylint: disable=consider-using-with
+process = subprocess.Popen(
 cmd,
 env=cmd_env,
 stdout=subprocess.DEVNULL if quiet else subprocess.PIPE,

View File

@@ -21,7 +21,6 @@ class TestMultiGPUEval:
"""
def test_eval_sample_packing(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -93,7 +92,6 @@ class TestMultiGPUEval:
check_tensorboard(temp_dir + "/runs", "eval/loss", 2.5, "Eval Loss is too high")
def test_eval(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -1,7 +1,5 @@
"""Test module for FP8 mixed precision with FSDP2 multi-GPU functionality."""
# pylint: disable=duplicate-code
import os
from pathlib import Path
@@ -28,9 +26,9 @@ def verify_fp8_training_success(temp_dir):
assert len(model_files) > 0, "No model files found - training may have failed"
checkpoint_files = list(output_path.glob("checkpoint-*"))
assert (
len(checkpoint_files) > 0
), "No checkpoint files found - training may have failed"
assert len(checkpoint_files) > 0, (
"No checkpoint files found - training may have failed"
)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
if tb_log_path:
@@ -42,9 +40,9 @@ def verify_fp8_training_success(temp_dir):
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(
torch.tensor(final_loss)
), f"Training loss is NaN: {final_loss}"
assert not torch.isnan(torch.tensor(final_loss)), (
f"Training loss is NaN: {final_loss}"
)
class TestFP8FSDP2:

View File

@@ -1,7 +1,5 @@
"""Test module for FSDP1 multi-GPU functionality."""
# pylint: disable=duplicate-code
import os
from pathlib import Path
@@ -29,9 +27,9 @@ def verify_training_success(temp_dir):
assert len(model_files) > 0, "No model files found - training may have failed"
checkpoint_files = list(output_path.glob("checkpoint-*"))
assert (
len(checkpoint_files) > 0
), "No checkpoint files found - training may have failed"
assert len(checkpoint_files) > 0, (
"No checkpoint files found - training may have failed"
)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
if tb_log_path:
@@ -43,9 +41,9 @@ def verify_training_success(temp_dir):
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(
torch.tensor(final_loss)
), f"Training loss is NaN: {final_loss}"
assert not torch.isnan(torch.tensor(final_loss)), (
f"Training loss is NaN: {final_loss}"
)
class TestFSDP1:

View File

@@ -1,7 +1,5 @@
"""Test module for FSDP2 multi-GPU functionality."""
# pylint: disable=duplicate-code
import os
from pathlib import Path
@@ -29,9 +27,9 @@ def verify_training_success(temp_dir):
assert len(model_files) > 0, "No model files found - training may have failed"
checkpoint_files = list(output_path.glob("checkpoint-*"))
assert (
len(checkpoint_files) > 0
), "No checkpoint files found - training may have failed"
assert len(checkpoint_files) > 0, (
"No checkpoint files found - training may have failed"
)
tb_log_path = most_recent_subdir(temp_dir + "/runs")
if tb_log_path:
@@ -43,9 +41,9 @@ def verify_training_success(temp_dir):
train_loss_df = df[df.tag == "train/train_loss"]
if len(train_loss_df) > 0:
final_loss = train_loss_df.value.values[-1]
assert not torch.isnan(
torch.tensor(final_loss)
), f"Training loss is NaN: {final_loss}"
assert not torch.isnan(torch.tensor(final_loss)), (
f"Training loss is NaN: {final_loss}"
)
class TestFSDP2:

View File

@@ -29,7 +29,6 @@ class TestMultiGPUGemma3:
"""
def test_lora_ddp_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-mirrors/gemma-3-4b-pt",

View File

@@ -35,7 +35,6 @@ class TestMultiGPULlama:
"""
def test_lora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -99,7 +98,6 @@ class TestMultiGPULlama:
[1, 2],
)
def test_lora_ddp_packed(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -162,7 +160,6 @@ class TestMultiGPULlama:
)
def test_dpo_lora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -242,7 +239,6 @@ class TestMultiGPULlama:
)
def test_dpo_qlora_ddp(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -326,7 +322,6 @@ class TestMultiGPULlama:
[1, 2],
)
def test_fsdp(self, temp_dir, gradient_accumulation_steps):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -402,7 +397,6 @@ class TestMultiGPULlama:
],
)
def test_fsdp_packed(self, temp_dir, fsdp_state_dict_type):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -484,7 +478,6 @@ class TestMultiGPULlama:
def test_fsdp2_packed(
self, temp_dir, attention_backend, fsdp_reshard_after_forward
):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -556,7 +549,6 @@ class TestMultiGPULlama:
)
def test_fsdp_qlora_prequant_packed(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "axolotl-ai-co/SmolLM2-135M-bnb-nf4-bf16",
@@ -656,7 +648,6 @@ class TestMultiGPULlama:
def test_ds_zero3_packed(
self, temp_dir, gradient_accumulation_steps, deepspeed, qlora
):
# pylint: disable=duplicate-code
if qlora:
adapter = {
"adapter": "qlora",
@@ -732,7 +723,6 @@ class TestMultiGPULlama:
[True, False],
)
def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps, qlora):
# pylint: disable=duplicate-code
if qlora:
adapter = {
"adapter": "qlora",
@@ -809,7 +799,6 @@ class TestMultiGPULlama:
[True, False],
)
def test_ds_zero1_packed(self, temp_dir, gradient_accumulation_steps, qlora):
# pylint: disable=duplicate-code
if qlora:
adapter = {
"adapter": "qlora",
@@ -880,7 +869,6 @@ class TestMultiGPULlama:
reason="fix untrained tokens brittle with lots of edge cases in latest transformers"
)
def test_fix_untrained_tokens(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -26,7 +26,6 @@ class TestMultiGPURay:
 @require_torch_lt_2_6_0
 def test_lora_ddp(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -90,7 +89,6 @@ class TestMultiGPURay:
 [1, 2],
 )
 def test_ds_zero2_packed(self, temp_dir, gradient_accumulation_steps):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -150,7 +148,6 @@ class TestMultiGPURay:
 [1, 2],
 )
 def test_sft_fsdp2_packed(self, temp_dir, gradient_accumulation_steps):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -19,7 +19,6 @@ class TestTensorParallel:
 )
 @require_torch_2_7_0
 def test_fft_sft(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "Qwen/Qwen2.5-0.5B",

View File

@@ -1,7 +1,5 @@
"""Integration tests for LoRA activation and attention kernels."""
# pylint: disable=redefined-outer-name
from pathlib import Path
import pytest
@@ -88,7 +86,7 @@ def test_attention_patching_integration(model_name, attention_cls):
cfg = DictDefault({"base_model": model_name})
# Store the original implementation
original_forward = getattr(attention_cls, "forward")
original_forward = attention_cls.forward
# Apply patch
patch_self_attn_lora(cfg)
@@ -104,7 +102,7 @@ def test_attention_patching_integration(model_name, attention_cls):
assert hasattr(attention_cls, "_original_forward")
# Clean up
setattr(attention_cls, "forward", original_forward)
attention_cls.forward = original_forward
delattr(attention_cls, "_original_forward")
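
The two one-line rewrites in the hunks above match flake8-bugbear's B009/B010 rules, which ruff implements: getattr and setattr with a constant attribute name can be plain attribute access, and that is presumably what drove these call-site changes. A minimal sketch of the rewrite (the class below is illustrative, not from the repo):

class Attention:
    def forward(self):
        return None

# flagged: constant attribute names passed to getattr/setattr (B009/B010)
original_forward = getattr(Attention, "forward")
setattr(Attention, "forward", original_forward)

# equivalent direct attribute access, which is what the diff switches to
original_forward = Attention.forward
Attention.forward = original_forward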
@@ -379,9 +377,9 @@ def test_model_architecture(model_config):
 # Verify correct activation function
 layer = patched_model.model.model.layers[0]
-assert (
-layer.mlp.forward.__func__ is model_config["expected_activation"]
-), f"Wrong activation for {model_config['name']}"
+assert layer.mlp.forward.__func__ is model_config["expected_activation"], (
+f"Wrong activation for {model_config['name']}"
+)
 # Test forward pass
 inputs = get_test_inputs(model)
@@ -390,12 +388,11 @@ def test_model_architecture(model_config):
 patched_output = patched_model(inputs).logits
 # Check outputs match
-assert torch.allclose(
-original_output, patched_output, rtol=1e-4
-), f"Outputs don't match for {model_config['name']}"
+assert torch.allclose(original_output, patched_output, rtol=1e-4), (
+f"Outputs don't match for {model_config['name']}"
+)
-# pylint: disable=duplicate-code
 def test_kernel_training_integration(temp_dir):
 """Test model loading with kernel patches enabled."""
 from axolotl.cli.utils import load_model_and_tokenizer
@@ -563,15 +560,13 @@ def test_kernel_training_integration_dropout_non_zero(temp_dir):
 model_loader = ModelLoader(cfg, tokenizer)
 # Apply patch
-model_loader.patch_manager._apply_self_attention_lora_patch() # pylint: disable=protected-access
+model_loader.patch_manager._apply_self_attention_lora_patch()
 # Verify patch was not applied
 assert attention_cls.forward == original_forward_method
 # Apply apply_lora_kernel_patches
-model_loader.patch_manager._apply_lora_kernel_patch( # pylint: disable=protected-access
-model
-)
+model_loader.patch_manager._apply_lora_kernel_patch(model)
 # Verify patch was not applied
 layers = get_layers(model)

View File

@@ -19,7 +19,6 @@ class Test4dMultipackLlama(unittest.TestCase):
 @with_temp_dir
 def test_sdp_lora_packing(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -67,7 +66,6 @@ class Test4dMultipackLlama(unittest.TestCase):
 @with_temp_dir
 def test_torch_lora_packing(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -32,10 +32,9 @@ class TestActivationCheckpointing:
 def test_activation_checkpointing_offload(
 self,
 temp_dir,
-fix_checkpoint_after_test, # pylint: disable=unused-argument,redefined-outer-name
+fix_checkpoint_after_test,
 gradient_checkpointing,
 ):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -10,7 +10,6 @@ from axolotl.cli.config import load_cfg
 from axolotl.utils.dict import DictDefault
-# pylint: disable=duplicate-code
 class TestPluginArgs:
 """
 test class for plugin args loaded from the config file

View File

@@ -23,7 +23,6 @@ class TestFAXentropyLlama:
 [1, 4],
 )
 def test_lora_packing_fa_cross_entropy(self, temp_dir, gradient_accumulation_steps):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -22,7 +22,6 @@ class TestFalconPatched(unittest.TestCase):
 @pytest.mark.skip(reason="no tiny models for testing with safetensors")
 @with_temp_dir
 def test_qlora(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "illuin/tiny-random-FalconForCausalLM",
@@ -71,7 +70,6 @@ class TestFalconPatched(unittest.TestCase):
 @pytest.mark.skip(reason="no tiny models for testing with safetensors")
 @with_temp_dir
 def test_ft(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "illuin/tiny-random-FalconForCausalLM",

View File

@@ -23,7 +23,6 @@ class TestFAFlattening:
 [1, 4],
 )
 def test_lora_packing_flattening(self, temp_dir, gradient_accumulation_steps):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -15,7 +15,6 @@ class TestFSDPPatchIntegration:
 apply_init_unsharded_param_patch,
 )
-# pylint: disable=protected-access
 original_init_sharded = FSDPParam._init_sharded_param
 original_init_unsharded = FSDPParam.init_unsharded_param
@@ -23,11 +22,9 @@ class TestFSDPPatchIntegration:
 apply_init_sharded_param_patch()
 apply_init_unsharded_param_patch()
-assert (
-# pylint: disable=protected-access
-FSDPParam._init_sharded_param
-!= original_init_sharded
-), "_init_sharded_param was not patched"
-assert (
-FSDPParam.init_unsharded_param != original_init_unsharded
-), "init_unsharded_param was not patched"
+assert FSDPParam._init_sharded_param != original_init_sharded, (
+"_init_sharded_param was not patched"
+)
+assert FSDPParam.init_unsharded_param != original_init_unsharded, (
+"init_unsharded_param was not patched"
+)

View File

@@ -23,7 +23,6 @@ class TestFusedLlama(unittest.TestCase):
 @with_temp_dir
 def test_fft_packing(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -22,7 +22,6 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
 @with_temp_dir
 def test_lora_s2_attn(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -71,7 +70,6 @@ class TestLlamaShiftedSparseAttention(unittest.TestCase):
 @with_temp_dir
 def test_fft_s2_attn(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -22,7 +22,6 @@ class TestLoraLlama(unittest.TestCase):
 @with_temp_dir
 def test_lora_packing(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -73,7 +72,6 @@ class TestLoraLlama(unittest.TestCase):
 @pytest.mark.skipif(not is_auto_gptq_available(), reason="auto-gptq not available")
 @with_temp_dir
 def test_lora_gptq_packed(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "lilmeaty/SmolLM2-135M-Instruct-GPTQ",

View File

@@ -20,7 +20,6 @@ class TestMistral(unittest.TestCase):
 @require_torch_2_6_0
 @with_temp_dir
 def test_lora_packing(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
@@ -68,7 +67,6 @@ class TestMistral(unittest.TestCase):
 @with_temp_dir
 def test_ft_packing(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",

View File

@@ -19,7 +19,6 @@ class TestMixtral(unittest.TestCase):
 @with_temp_dir
 def test_qlora(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "hf-internal-testing/Mixtral-tiny",
@@ -64,7 +63,6 @@ class TestMixtral(unittest.TestCase):
 @with_temp_dir
 def test_ft(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "hf-internal-testing/Mixtral-tiny",

View File

@@ -89,5 +89,5 @@ class TestModelPatches(unittest.TestCase):
 assert (
 "torch.jit"
-in transformers.modeling_flash_attention_utils._get_unpad_data.__module__ # pylint: disable=protected-access
+in transformers.modeling_flash_attention_utils._get_unpad_data.__module__
 )

View File

@@ -15,7 +15,6 @@ class TestLlamaPeftEmbeddings:
"""
def test_peft_embeddings_upcast(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -19,7 +19,6 @@ class TestPhiMultipack(unittest.TestCase):
 @with_temp_dir
 def test_ft_packed(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "microsoft/phi-1_5",
@@ -67,7 +66,6 @@ class TestPhiMultipack(unittest.TestCase):
 @with_temp_dir
 def test_qlora_packed(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "microsoft/phi-1_5",

View File

@@ -22,7 +22,6 @@ class TestResumeLlama:
 @require_torch_2_6_0
 def test_resume_lora_packed(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -12,7 +12,6 @@ from axolotl.utils.dict import DictDefault
 from ..utils import check_model_output_exists, check_tensorboard
-# pylint: disable=duplicate-code
 @pytest.mark.skip(
 reason="Unsloth integration will be broken going into latest transformers"
 )

View File

@@ -22,7 +22,6 @@ class TestPackedFlex(unittest.TestCase):
 @require_torch_2_6_0
 @with_temp_dir
 def test_loss_llama(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -20,7 +20,6 @@ class TestReLoraLlama(unittest.TestCase):
 @with_temp_dir
 def test_relora(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -76,9 +75,9 @@ class TestReLoraLlama(unittest.TestCase):
 train(cfg=cfg, dataset_meta=dataset_meta)
 check_model_output_exists(Path(temp_dir) / "checkpoint-100/adapter", cfg)
-assert (
-Path(temp_dir) / "checkpoint-100/relora/model.safetensors"
-).exists(), "Relora model checkpoint not found"
+assert (Path(temp_dir) / "checkpoint-100/relora/model.safetensors").exists(), (
+"Relora model checkpoint not found"
+)
 check_tensorboard(
 temp_dir + "/runs", "train/grad_norm", 0.2, "grad_norm is too high"

View File

@@ -11,8 +11,6 @@ from axolotl.utils.dict import DictDefault
 from .utils import check_model_output_exists
-# pylint: disable=duplicate-code
 class TestActivationOffloading:
 """
@@ -28,7 +26,6 @@ class TestActivationOffloading:
 temp_dir,
 adapter,
 ):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -25,7 +25,6 @@ class TestDeepseekV3:
 [True, False],
 )
 def test_lora_deepseekv3(self, temp_dir, sample_packing):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "axolotl-ai-co/DeepSeek-V3-11M",
@@ -83,7 +82,6 @@ class TestDeepseekV3:
 [True, False],
 )
 def test_fft_deepseekv3(self, temp_dir, sample_packing):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "axolotl-ai-co/DeepSeek-V3-11M",

View File

@@ -21,7 +21,6 @@ class TestDPOLlamaLora(unittest.TestCase):
 @with_temp_dir
 def test_dpo_lora(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -70,7 +69,6 @@ class TestDPOLlamaLora(unittest.TestCase):
 @with_temp_dir
 def test_dpo_nll_lora(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -120,7 +118,6 @@ class TestDPOLlamaLora(unittest.TestCase):
 @with_temp_dir
 def test_dpo_use_weighting(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -171,7 +168,6 @@ class TestDPOLlamaLora(unittest.TestCase):
 @pytest.mark.skip("kto_pair no longer supported in trl")
 @with_temp_dir
 def test_kto_pair_lora(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -220,7 +216,6 @@ class TestDPOLlamaLora(unittest.TestCase):
 @with_temp_dir
 def test_ipo_lora(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -269,7 +264,6 @@ class TestDPOLlamaLora(unittest.TestCase):
 @with_temp_dir
 def test_orpo_lora(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -322,7 +316,6 @@ class TestDPOLlamaLora(unittest.TestCase):
 @pytest.mark.skip(reason="Fix the implementation")
 @with_temp_dir
 def test_kto_lora(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -19,7 +19,6 @@ class TestEmbeddingsLrScale(unittest.TestCase):
 @with_temp_dir
 def test_train_w_embedding_lr_scale(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -65,7 +64,6 @@ class TestEmbeddingsLrScale(unittest.TestCase):
 @with_temp_dir
 def test_train_w_embedding_lr(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -13,7 +13,6 @@ class TestE2eEvaluate:
"""Test cases for evaluate CLI"""
def test_evaluate(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -22,7 +22,6 @@ class TestFalcon(unittest.TestCase):
 @pytest.mark.skip(reason="no tiny models for testing with safetensors")
 @with_temp_dir
 def test_lora(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "illuin/tiny-random-FalconForCausalLM",
@@ -74,7 +73,6 @@ class TestFalcon(unittest.TestCase):
 @pytest.mark.skip(reason="no tiny models for testing with safetensors")
 @with_temp_dir
 def test_lora_added_vocab(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "illuin/tiny-random-FalconForCausalLM",
@@ -130,7 +128,6 @@ class TestFalcon(unittest.TestCase):
 @pytest.mark.skip(reason="no tiny models for testing with safetensors")
 @with_temp_dir
 def test_ft(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "illuin/tiny-random-FalconForCausalLM",

View File

@@ -22,7 +22,6 @@ class TestGemma2:
 [True, False],
 )
 def test_lora_gemma2(self, temp_dir, sample_packing):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "axolotl-ai-co/gemma-2-33M",
@@ -78,7 +77,6 @@ class TestGemma2:
 [True, False],
 )
 def test_fft_gemma2(self, temp_dir, sample_packing):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "axolotl-ai-co/gemma-2-33M",

View File

@@ -22,7 +22,6 @@ class TestGemma3Text:
 [True, False],
 )
 def test_lora_gemma3_text(self, temp_dir, sample_packing):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "axolotl-ai-co/gemma-3-34M",
@@ -78,7 +77,6 @@ class TestGemma3Text:
 [True, False],
 )
 def test_fft_gemma3_text(self, temp_dir, sample_packing):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "axolotl-ai-co/gemma-3-34M",

View File

@@ -11,11 +11,7 @@ class TestImports(unittest.TestCase):
"""
def test_import_causal_trainer(self):
from axolotl.core.builders import ( # pylint: disable=unused-import # noqa: F401
HFCausalTrainerBuilder,
)
pass
def test_import_rl_trainer(self):
from axolotl.core.builders import ( # pylint: disable=unused-import # noqa: F401
HFRLTrainerBuilder,
)
pass

View File

@@ -16,7 +16,6 @@ class TestLlama:
"""
def test_fft_trust_remote_code(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -57,7 +56,6 @@ class TestLlama:
check_model_output_exists(temp_dir, cfg)
def test_fix_untrained_tokens(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -105,7 +103,6 @@ class TestLlama:
check_model_output_exists(temp_dir, cfg)
def test_fix_untrained_tokens_already_trained(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -150,7 +147,6 @@ class TestLlama:
check_model_output_exists(temp_dir, cfg)
def test_batch_flattening(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -22,7 +22,6 @@ class TestPretrainLlama:
 ],
 )
 def test_pretrain(self, temp_dir, sample_packing, pretrain_multipack_attn):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -19,7 +19,6 @@ class TestLlamaVision(unittest.TestCase):
 @with_temp_dir
 def test_lora_llama_vision_text_only_dataset(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "axolotl-ai-co/Llama-3.2-39M-Vision",
@@ -67,7 +66,6 @@ class TestLlamaVision(unittest.TestCase):
 @with_temp_dir
 def test_lora_llama_vision_multimodal_dataset(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "axolotl-ai-co/Llama-3.2-39M-Vision",

View File

@@ -56,13 +56,11 @@ class TestLoadModelUtils:
"context_parallel_size": 1,
}
)
self.model_loader = ( # pylint: disable=attribute-defined-outside-init
ModelLoader(
cfg=self.cfg,
tokenizer="",
inference=False,
reference_model=True,
)
self.model_loader = ModelLoader(
cfg=self.cfg,
tokenizer="",
inference=False,
reference_model=True,
)
@pytest.mark.parametrize("embedding_modules", ["embed_tokens", "lm_head"])
@@ -74,7 +72,7 @@ class TestLoadModelUtils:
self, temp_dir, embedding_modules, dist_dtype, before_kbit_train_or_finetune
):
self.cfg.output_dir = temp_dir
self.model_loader.tokenizer = load_tokenizer(self.cfg) # pylint: disable=all
self.model_loader.tokenizer = load_tokenizer(self.cfg)
self.model_loader.load()
self.model_loader._convert_embedding_modules_dtype(
embedding_modules, dist_dtype, before_kbit_train_or_finetune

View File

@@ -19,7 +19,6 @@ class TestLoraLlama(unittest.TestCase):
 @with_temp_dir
 def test_lora(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -22,7 +22,6 @@ class TestMamba(unittest.TestCase):
 @with_temp_dir
 def test_fft(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "state-spaces/mamba-130m",

View File

@@ -21,7 +21,6 @@ class TestMistral(unittest.TestCase):
 @with_temp_dir
 def test_lora(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",
@@ -68,7 +67,6 @@ class TestMistral(unittest.TestCase):
 @with_temp_dir
 def test_ft(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "trl-internal-testing/tiny-MistralForCausalLM-0.2",

View File

@@ -22,7 +22,6 @@ class TestMixtral(unittest.TestCase):
 @with_temp_dir
 def test_qlora_w_fa2(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "hf-internal-testing/Mixtral-tiny",
@@ -78,7 +77,6 @@ class TestMixtral(unittest.TestCase):
 @with_temp_dir
 def test_qlora_wo_fa2(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "hf-internal-testing/Mixtral-tiny",
@@ -134,7 +132,6 @@ class TestMixtral(unittest.TestCase):
 @with_temp_dir
 def test_16bit_lora_w_fa2(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "hf-internal-testing/Mixtral-tiny",
@@ -193,7 +190,6 @@ class TestMixtral(unittest.TestCase):
 @with_temp_dir
 def test_16bit_lora_wo_fa2(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "hf-internal-testing/Mixtral-tiny",
@@ -252,7 +248,6 @@ class TestMixtral(unittest.TestCase):
 @with_temp_dir
 def test_ft(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "hf-internal-testing/Mixtral-tiny",

View File

@@ -25,7 +25,6 @@ class TestCustomOptimizers(unittest.TestCase):
 @with_temp_dir
 def test_optimi_adamw(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -71,7 +70,6 @@ class TestCustomOptimizers(unittest.TestCase):
 @with_temp_dir
 @require_torch_2_5_1
 def test_adopt_adamw(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -117,7 +115,6 @@ class TestCustomOptimizers(unittest.TestCase):
 @with_temp_dir
 @require_torch_2_5_1
 def test_muon(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -164,7 +161,6 @@ class TestCustomOptimizers(unittest.TestCase):
 @with_temp_dir
 @require_torch_2_7_0
 def test_dion(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -206,7 +202,6 @@ class TestCustomOptimizers(unittest.TestCase):
 @with_temp_dir
 def test_fft_schedule_free_adamw(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -234,7 +229,6 @@ class TestCustomOptimizers(unittest.TestCase):
 "save_first_step": False,
 }
 )
-# pylint: disable=duplicate-code
 cfg = validate_config(cfg)
 normalize_config(cfg)
@@ -246,7 +240,6 @@ class TestCustomOptimizers(unittest.TestCase):
 @with_temp_dir
 @require_torch_2_6_0
 def test_came_pytorch(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "JackFram/llama-68m",

View File

@@ -21,7 +21,6 @@ class TestPackedLlama(unittest.TestCase):
 @with_temp_dir
 def test_loss_packed(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -19,7 +19,6 @@ class TestPhi(unittest.TestCase):
 @with_temp_dir
 def test_phi_ft(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "microsoft/phi-1_5",
@@ -65,7 +64,6 @@ class TestPhi(unittest.TestCase):
 @with_temp_dir
 def test_phi_qlora(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "microsoft/phi-1_5",

View File

@@ -15,7 +15,7 @@ class TestPreprocess:
 def test_w_deepspeed(self, temp_dir):
 """make sure preproces doesn't choke when using deepspeed in the config"""
-# pylint: disable=duplicate-code
+
 cfg = DictDefault(
 {
 "base_model": "Qwen/Qwen2.5-0.5B",

View File

@@ -19,7 +19,6 @@ class TestProcessRewardSmolLM2(unittest.TestCase):
 @with_temp_dir
 def test_prm(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -18,7 +18,6 @@ class TestQATLlama:
"""
def test_qat(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -68,7 +67,6 @@ class TestQATLlama:
check_model_output_exists(Path(temp_dir) / "checkpoint-5", cfg)
def test_qat_dpo(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -131,7 +131,7 @@ class TestQuantization:
 @require_torch_2_6_0
 def test_prepare_model_for_qat(
 self, model, weight_dtype, activation_dtype, group_size, quantize_embedding
-): # pylint: disable=redefined-outer-name
+):
 prepare_model_for_qat(
 model, weight_dtype, group_size, activation_dtype, quantize_embedding
 )
@@ -175,7 +175,7 @@ class TestQuantization:
 group_size,
 quantize_embedding,
 expected_exception,
-): # pylint: disable=redefined-outer-name
+):
 if expected_exception:
 with pytest.raises(expected_exception):
 quantize_model_for_ptq(
@@ -198,11 +198,13 @@ class TestQuantization:
 if activation_dtype:
 assert isinstance(
 child.weight, LinearActivationQuantizedTensor
-), "Linear weight should be quantized with activation quantization"
+), (
+"Linear weight should be quantized with activation quantization"
+)
 else:
-assert isinstance(
-child.weight, AffineQuantizedTensor
-), "Linear weight should be quantized without activation quantization"
+assert isinstance(child.weight, AffineQuantizedTensor), (
+"Linear weight should be quantized without activation quantization"
+)
 class TestQuantizationCallback:
@@ -217,9 +219,7 @@ class TestQuantizationCallback:
 )
 @require_torch_2_6_0
-def test_qat_callback_fake_quant_after_n_steps(
-self, model, trainer_state
-): # pylint: disable=redefined-outer-name
+def test_qat_callback_fake_quant_after_n_steps(self, model, trainer_state):
 cfg = QATConfig(
 weight_dtype="int8",
 activation_dtype="int8",
@@ -269,9 +269,7 @@ class TestQuantizationCallback:
 assert model.lm_head.weight_fake_quantizer.enabled
 @require_torch_2_6_0
-def test_qat_callback_fake_quant_after_n_steps_is_none(
-self, model, trainer_state
-): # pylint: disable=redefined-outer-name
+def test_qat_callback_fake_quant_after_n_steps_is_none(self, model, trainer_state):
 cfg = QATConfig(
 weight_dtype="int8",
 activation_dtype="int8",
@@ -314,9 +312,7 @@ class TestConvertQATModelForPTQ:
 """
 @require_torch_2_6_0
-def test_convert_qat_model_for_ptq(
-self, model
-): # pylint: disable=redefined-outer-name
+def test_convert_qat_model_for_ptq(self, model):
 config = QATConfig(
 weight_dtype="int8",
 activation_dtype="int8",

View File

@@ -19,7 +19,6 @@ class TestE2eQwen:
 @pytest.mark.parametrize("base_model", ["Qwen/Qwen2-0.5B", "Qwen/Qwen2.5-0.5B"])
 def test_dpo(self, base_model, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": base_model,

View File

@@ -19,7 +19,6 @@ class TestRewardModelLoraSmolLM2(unittest.TestCase):
 @with_temp_dir
 def test_rm_lora(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -20,7 +20,6 @@ class TestSaveFirstStepCallback(unittest.TestCase):
 @with_temp_dir
 def test_save_first_step(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -61,7 +60,6 @@ class TestSaveFirstStepCallback(unittest.TestCase):
 @with_temp_dir
 def test_no_save_first_step(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -19,7 +19,6 @@ class TestCustomSchedulers(unittest.TestCase):
 @with_temp_dir
 def test_rex_scheduler(self, temp_dir):
-# pylint: disable=duplicate-code
 cfg = DictDefault(
 {
 "base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -2,6 +2,7 @@
 helper utils for tests
 """
+import importlib.util
 import os
 import shutil
 import tempfile
@@ -107,12 +108,7 @@ def require_vllm(test_case):
 """
 def is_vllm_installed():
-try:
-import vllm # pylint: disable=unused-import # noqa: F401
-return True
-except ImportError:
-return False
+return importlib.util.find_spec("vllm") is not None
 return unittest.skipUnless(
 is_vllm_installed(), "test requires vllm to be installed"
@@ -125,12 +121,7 @@ def require_llmcompressor(test_case):
 """
 def is_llmcompressor_installed():
-try:
-import llmcompressor # pylint: disable=unused-import # noqa: F401
-return True
-except ImportError:
-return False
+return importlib.util.find_spec("llmcompressor") is not None
 return unittest.skipUnless(
 is_llmcompressor_installed(), "test requires llmcompressor to be installed"
@@ -159,8 +150,8 @@ def check_tensorboard(
 tb_log_path = most_recent_subdir(temp_run_dir)
 event_file = os.path.join(tb_log_path, sorted(os.listdir(tb_log_path))[0])
 reader = SummaryReader(event_file)
-df = reader.scalars # pylint: disable=invalid-name
-df = df[(df.tag == tag)] # pylint: disable=invalid-name
+df = reader.scalars
+df = df[(df.tag == tag)]
 lt_val = (1 + rtol) * lt_val
 if "%s" in assertion_err:
 assert df.value.values[-1] < lt_val, assertion_err % df.value.values[-1]
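
The require_vllm and require_llmcompressor hunks above swap a try/except import probe for importlib.util.find_spec, which reports availability without importing (and therefore without executing) the package. A generalized sketch of the idiom; require_package is a hypothetical helper, not something this commit adds:

import importlib.util
import unittest


def require_package(package: str):
    """Skip the decorated test unless the named package is importable."""
    # find_spec only consults the import machinery, so the probe is cheap
    # and has no import side effects
    installed = importlib.util.find_spec(package) is not None
    return unittest.skipUnless(installed, f"test requires {package} to be installed")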

View File

@@ -20,7 +20,7 @@ def reload_modules(hf_hub_offline):
 importlib.reload(huggingface_hub.constants)
 huggingface_hub.constants.HF_HUB_OFFLINE = hf_hub_offline
 importlib.reload(datasets.config)
-setattr(datasets.config, "HF_HUB_OFFLINE", hf_hub_offline)
+datasets.config.HF_HUB_OFFLINE = hf_hub_offline
 reset_sessions()

View File

@@ -10,7 +10,6 @@ from axolotl.utils.config import prepare_plugins, validate_config
 from axolotl.utils.dict import DictDefault
-# pylint: disable=duplicate-code
 @pytest.fixture(name="minimal_liger_cfg")
 def fixture_cfg():
 return DictDefault(
@@ -30,7 +29,6 @@ def fixture_cfg():
 )
-# pylint: disable=too-many-public-methods
 class TestValidation:
 """
 Test the validation module for liger

View File

@@ -1,4 +1,3 @@
-# pylint: disable=too-many-lines
 """Module for testing the validation module"""
 import os
@@ -49,7 +48,6 @@ class BaseValidation:
 self._caplog = caplog
-# pylint: disable=too-many-public-methods
 class TestValidation(BaseValidation):
 """
 Test the validation module
@@ -241,7 +239,7 @@ class TestValidation(BaseValidation):
 def test_lr_as_float(self, minimal_cfg):
 cfg = (
-DictDefault( # pylint: disable=unsupported-binary-operation
+DictDefault(
 {
 "learning_rate": "5e-5",
 }
@@ -303,7 +301,7 @@ class TestValidation(BaseValidation):
 )
 cfg = (
-DictDefault( # pylint: disable=unsupported-binary-operation
+DictDefault(
 {
 "load_in_8bit": True,
 }
@@ -315,7 +313,7 @@ class TestValidation(BaseValidation):
 validate_config(cfg)
 cfg = (
-DictDefault( # pylint: disable=unsupported-binary-operation
+DictDefault(
 {
 "gptq": True,
 }
@@ -327,7 +325,7 @@ class TestValidation(BaseValidation):
 validate_config(cfg)
 cfg = (
-DictDefault( # pylint: disable=unsupported-binary-operation
+DictDefault(
 {
 "load_in_4bit": False,
 }
@@ -339,7 +337,7 @@ class TestValidation(BaseValidation):
 validate_config(cfg)
 cfg = (
-DictDefault( # pylint: disable=unsupported-binary-operation
+DictDefault(
 {
 "load_in_4bit": True,
 }
@@ -361,7 +359,7 @@ class TestValidation(BaseValidation):
 )
 cfg = (
-DictDefault( # pylint: disable=unsupported-binary-operation
+DictDefault(
 {
 "load_in_8bit": True,
 }
@@ -373,7 +371,7 @@ class TestValidation(BaseValidation):
 validate_config(cfg)
 cfg = (
-DictDefault( # pylint: disable=unsupported-binary-operation
+DictDefault(
 {
 "gptq": True,
 }
@@ -385,7 +383,7 @@ class TestValidation(BaseValidation):
 validate_config(cfg)
 cfg = (
-DictDefault( # pylint: disable=unsupported-binary-operation
+DictDefault(
 {
 "load_in_4bit": True,
 }
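
All of the DictDefault hunks in this file silence the same complaint: pylint's inference could not see that DictDefault supports the | merge operator, hence the old unsupported-binary-operation pragmas. Judging by the hunk context, each call site presumably has this shape, with the merge sitting just past the visible lines (the stand-in fixture value is hypothetical):

from axolotl.utils.dict import DictDefault

minimal_cfg = DictDefault({"base_model": "HuggingFaceTB/SmolLM2-135M"})  # stand-in for the fixture

cfg = (
    DictDefault(
        {
            "learning_rate": "5e-5",
        }
    )
    | minimal_cfg  # DictDefault overloads | for merging, which pylint could not infer
)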

View File

@@ -30,7 +30,6 @@ def fixture_assistant_dataset():
 @pytest.fixture(name="sharegpt_dataset")
 def fixture_sharegpt_dataset():
-# pylint: disable=duplicate-code
 return Dataset.from_list(
 [
 {
@@ -47,7 +46,6 @@ def fixture_sharegpt_dataset():
 @pytest.fixture(name="basic_dataset")
 def fixture_basic_dataset():
-# pylint: disable=duplicate-code
 return Dataset.from_list(
 [
 {
@@ -65,7 +63,6 @@ def fixture_basic_dataset():
 @pytest.fixture(name="toolcalling_dataset")
 def fixture_toolcalling_dataset():
-# pylint: disable=duplicate-code
 return Dataset.from_list(
 [
 {
@@ -112,7 +109,7 @@ def fixture_toolcalling_dataset():
 @enable_hf_offline
 def fixture_llama3_tokenizer(
 download_llama3_8b_instruct_model_fixture,
-): # pylint: disable=unused-argument,redefined-outer-name
+):
 tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
 return tokenizer
@@ -129,7 +126,7 @@ def fixture_smollm2_tokenizer():
 @enable_hf_offline
 def fixture_mistralv03_tokenizer(
 download_mlx_mistral_7b_model_fixture,
-): # pylint: disable=unused-argument,redefined-outer-name
+):
 tokenizer = AutoTokenizer.from_pretrained(
 "mlx-community/Mistral-7B-Instruct-v0.3-4bit"
 )

View File

@@ -2,7 +2,6 @@
 tests for chat_template prompt strategy
 """
-# pylint: disable=duplicate-code
 import unittest
 from axolotl.prompt_strategies.messages.chat import load
@@ -53,9 +52,9 @@ class TestMessagesChatLlama3:
 # fmt: on
 LOG.debug(f"Expected input_ids: {expected_input_ids}")
 LOG.debug(f"Actual input_ids: {input_ids}")
-assert (
-input_ids == expected_input_ids
-), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
+assert input_ids == expected_input_ids, (
+f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
+)
 if __name__ == "__main__":

View File

@@ -30,7 +30,6 @@ def fixture_alpaca_dataset():
 @pytest.fixture(name="tokenizer")
 @enable_hf_offline
 def fixture_tokenizer():
-# pylint: disable=all
 tokenizer = AutoTokenizer.from_pretrained(
 "casperhansen/mistral-7b-instruct-v0.1-awq"
 )

View File
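The fixture below stores its rows as JSON Lines inside a triple-quoted string and parses them into a Dataset. A stripped-down sketch of the same parse, with a shortened two-row payload:

import json

from datasets import Dataset

jsons = """
{"messages": [{"role": "user", "content": "move to (0, 1)"}]}
{"messages": [{"role": "user", "content": "turn 270 degree"}]}
""".strip().split("\n")

rows = [json.loads(row) for row in jsons]
dataset = Dataset.from_list(rows)
assert len(dataset) == 2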

@@ -18,9 +18,7 @@ def fixture_messages_w_tools():
{"messages":[{"role":"user","content":"move to (0, 1)"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"move","arguments":{"x":0,"y":1}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
{"messages":[{"role":"user","content":"turn 270 degree"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"turn","arguments":{"theta": 270}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
{"messages":[{"role":"user","content":"jump high"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"invalid_prompt","arguments":{"message": "jump is not a valid action"}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
""".strip().split(
"\n"
)
""".strip().split("\n")
rows = [json.loads(row) for row in jsons]
return Dataset.from_list(rows)
@@ -28,7 +26,7 @@ def fixture_messages_w_tools():
@pytest.fixture(name="qwen3_tokenizer")
def qwen3_tokenizer_fixture(
download_qwen3_half_billion_model,
): # pylint: disable=unused-argument
):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
return tokenizer

View File

@@ -67,9 +67,9 @@ class TestAssistantChatTemplateLlama3:
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert input_ids == expected_input_ids, (
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
)
def test_llama3(self, llama3_tokenizer, assistant_dataset):
LOG.info("Testing llama-3 with assistant dataset")
@@ -109,9 +109,9 @@ class TestAssistantChatTemplateLlama3:
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert input_ids == expected_input_ids, (
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
)
def test_phi35(self, phi35_tokenizer, assistant_dataset):
LOG.info("Testing phi-3.5 with assistant dataset")
@@ -161,15 +161,15 @@ class TestAssistantChatTemplateLlama3:
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert input_ids == expected_input_ids, (
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
)
LOG.debug(f"Expected labels : {expected_labels}")
LOG.debug(f"Actual labels : {labels}")
assert (
labels == expected_labels
), f"Labels mismatch: {labels} != {expected_labels}"
assert labels == expected_labels, (
f"Labels mismatch: {labels} != {expected_labels}"
)
def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset):
LOG.info("Testing llama-3 with assistant dataset including training data")
@@ -234,7 +234,7 @@ class TestSharegptChatTemplateLlama3:
def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):
LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts")
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
@@ -285,16 +285,16 @@ class TestSharegptChatTemplateLlama3:
LOG.debug(f"Expected labels: {expected_labels}")
LOG.debug(f"Actual labels: {labels}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert (
labels == expected_labels
), f"Labels mismatch: {labels} != {expected_labels}"
assert input_ids == expected_input_ids, (
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
)
assert labels == expected_labels, (
f"Labels mismatch: {labels} != {expected_labels}"
)
def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):
LOG.info("Testing ShareGPT style datasets with llama-3 human prompts")
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
@@ -345,16 +345,16 @@ class TestSharegptChatTemplateLlama3:
LOG.debug(f"Expected labels: {expected_labels}")
LOG.debug(f"Actual labels: {labels}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert (
labels == expected_labels
), f"Labels mismatch: {labels} != {expected_labels}"
assert input_ids == expected_input_ids, (
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
)
assert labels == expected_labels, (
f"Labels mismatch: {labels} != {expected_labels}"
)
def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts")
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
@@ -409,12 +409,12 @@ class TestSharegptChatTemplateLlama3:
LOG.debug(f"Expected labels: {expected_labels}")
LOG.debug(f"Actual labels: {labels}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert (
labels == expected_labels
), f"Labels mismatch: {labels} != {expected_labels}"
assert input_ids == expected_input_ids, (
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
)
assert labels == expected_labels, (
f"Labels mismatch: {labels} != {expected_labels}"
)
class TestAssistantToolCallingChatTemplateLlama32Vision:
@@ -481,13 +481,13 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
]
# fmt: on
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert input_ids == expected_input_ids, (
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
)
assert (
labels == expected_labels
), f"Labels mismatch: {labels} != {expected_labels}"
assert labels == expected_labels, (
f"Labels mismatch: {labels} != {expected_labels}"
)
def test_llama32vision_train_on_tools(
self, llama3_tokenizer, toolcalling_dataset, llama3_2_vision_chat_template_jinja
@@ -495,7 +495,6 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
LOG.info(
"Testing assistant style datasets with tool_calling with llama-32 chat template, training on tools"
)
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
@@ -549,13 +548,13 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
]
# fmt: on
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert input_ids == expected_input_ids, (
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
)
assert (
labels == expected_labels
), f"Labels mismatch: {labels} != {expected_labels}"
assert labels == expected_labels, (
f"Labels mismatch: {labels} != {expected_labels}"
)
if __name__ == "__main__":

View File
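The configuration tests in the next file all reduce to one check: a turn's label span is either fully trained or fully masked. A self-contained sketch of that check; IGNORE_TOKEN_ID is assumed here to be -100, the conventional Hugging Face masking value (the tests import it rather than hard-coding it):

IGNORE_TOKEN_ID = -100  # assumed value for illustration

labels = [-100, -100, 3923, 374, -100]  # illustrative label ids
start_idx, end_idx = 2, 4  # span covering an assistant turn

# a trained span contains no masked positions...
assert all(label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]), (
    f"Expected labels to be set, but got {labels[start_idx:end_idx]}"
)

# ...and everything outside the span stays masked
outside = labels[:start_idx] + labels[end_idx:]
assert all(label == IGNORE_TOKEN_ID for label in outside), (
    "Expected labels outside the span to be IGNORE_TOKEN_ID"
)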

@@ -2,8 +2,6 @@
tests for chat_template prompt strategy
"""
# pylint: disable=too-many-lines
from copy import deepcopy
import pytest
@@ -96,9 +94,9 @@ class TestChatTemplateConfigurations:
and turn.get("from") in ["system", "context"]
and ("mistral" in tokenizer.name_or_path.lower())
):
assert (
start_idx == -1 and end_idx == -1
), "Expected system message to be skipped"
assert start_idx == -1 and end_idx == -1, (
"Expected system message to be skipped"
)
return True
return False
@@ -155,7 +153,9 @@ class TestChatTemplateConfigurations:
assert all(
label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
), f"Expected labels for input '{response}' to be set, but got {labels[start_idx:end_idx]}"
), (
f"Expected labels for input '{response}' to be set, but got {labels[start_idx:end_idx]}"
)
LOG.debug("Full labels: %s", labels)
LOG.debug("Full input_ids: %s", input_ids)
@@ -215,11 +215,15 @@ class TestChatTemplateConfigurations:
if is_assistant:
assert all(
label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}"
), (
f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}"
)
else:
assert all(
label == IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
), f"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}"
), (
f"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}"
)
def test_roles_to_train_human_assistant_only(
self,
@@ -276,11 +280,15 @@ class TestChatTemplateConfigurations:
if should_be_labelled:
assert all(
label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}"
), (
f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}"
)
else:
assert all(
label == IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
), f"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}"
), (
f"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}"
)
def test_roles_to_train_all(
self,
@@ -327,13 +335,15 @@ class TestChatTemplateConfigurations:
continue
decoded_response = tokenizer.decode(input_ids[start_idx:end_idx])
assert (
response in decoded_response
), f"Response {response} not found at index {start_idx}:{end_idx}, decoded: {decoded_response}"
assert response in decoded_response, (
f"Response {response} not found at index {start_idx}:{end_idx}, decoded: {decoded_response}"
)
assert all(
label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:end_idx]}"
), (
f"Expected labels for response '{response}' to be set, but got {labels[start_idx:end_idx]}"
)
def test_empty_roles_to_train(
self,
@@ -371,9 +381,9 @@ class TestChatTemplateConfigurations:
# Verify that no labels are set when roles_to_train is empty
LOG.debug("Full labels: %s", labels)
assert all(
label == IGNORE_TOKEN_ID for label in labels
), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty"
assert all(label == IGNORE_TOKEN_ID for label in labels), (
"Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty"
)
def test_train_on_eos_all(
self,
@@ -417,9 +427,9 @@ class TestChatTemplateConfigurations:
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
for eos_idx in eos_indices:
assert (
labels[eos_idx] != IGNORE_TOKEN_ID
), f"Expected EOS token at index {eos_idx} to be labeled"
assert labels[eos_idx] != IGNORE_TOKEN_ID, (
f"Expected EOS token at index {eos_idx} to be labeled"
)
def test_train_on_eos_turn(
self,
@@ -477,9 +487,9 @@ class TestChatTemplateConfigurations:
while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
eos_idx += 1
assert eos_idx < len(
input_ids
), f"Could not find EOS token after '{response}'"
assert eos_idx < len(input_ids), (
f"Could not find EOS token after '{response}'"
)
LOG.debug(
f"Turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}, eos_idx={eos_idx}"
@@ -492,13 +502,13 @@ class TestChatTemplateConfigurations:
# Verify EOS token labeling based on role
is_assistant = turn["from"] == "assistant"
if is_assistant:
assert (
labels[eos_idx] != IGNORE_TOKEN_ID
), f"Expected EOS token after assistant response '{response}' to be labeled"
assert labels[eos_idx] != IGNORE_TOKEN_ID, (
f"Expected EOS token after assistant response '{response}' to be labeled"
)
else:
assert (
labels[eos_idx] == IGNORE_TOKEN_ID
), f"Expected EOS token after non-assistant input '{response}' to not be labeled"
assert labels[eos_idx] == IGNORE_TOKEN_ID, (
f"Expected EOS token after non-assistant input '{response}' to not be labeled"
)
def test_train_on_eos_last(
self,
@@ -545,12 +555,12 @@ class TestChatTemplateConfigurations:
# Check that only the last EOS token is labeled
for idx in eos_indices[:-1]:
assert (
labels[idx] == IGNORE_TOKEN_ID
), f"Expected EOS token at index {idx} to not be labeled"
assert (
labels[last_eos_idx] != IGNORE_TOKEN_ID
), f"Expected last EOS token at index {last_eos_idx} to be labeled"
assert labels[idx] == IGNORE_TOKEN_ID, (
f"Expected EOS token at index {idx} to not be labeled"
)
assert labels[last_eos_idx] != IGNORE_TOKEN_ID, (
f"Expected last EOS token at index {last_eos_idx} to be labeled"
)
def test_train_on_eos_none(
self,
@@ -594,9 +604,9 @@ class TestChatTemplateConfigurations:
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
for eos_idx in eos_indices:
assert (
labels[eos_idx] == IGNORE_TOKEN_ID
), f"Expected EOS token at index {eos_idx} to not be labeled"
assert labels[eos_idx] == IGNORE_TOKEN_ID, (
f"Expected EOS token at index {eos_idx} to not be labeled"
)
def test_drop_system_message(
self,
@@ -634,9 +644,9 @@ class TestChatTemplateConfigurations:
# Check if system message is not present in input_ids
system_message = "You are an AI assistant."
decoded_message = tokenizer.decode(input_ids)
assert (
system_message not in decoded_message
), "Expected system message to be dropped"
assert system_message not in decoded_message, (
"Expected system message to be dropped"
)
def test_custom_roles(
self,
@@ -711,7 +721,9 @@ class TestChatTemplateConfigurations:
else:
assert all(
label == IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
), f"Expected labels for non-AI message '{response}' to be IGNORE_TOKEN_ID"
), (
f"Expected labels for non-AI message '{response}' to be IGNORE_TOKEN_ID"
)
def test_message_field_training(
self,
@@ -776,13 +788,13 @@ class TestChatTemplateConfigurations:
def verify_labels(labels_span, should_train, context_message):
"""Helper to verify if a span of labels matches expected training state"""
if should_train:
assert all(
label != IGNORE_TOKEN_ID for label in labels_span
), f"Expected all labels for {context_message} to be set, but got {labels_span}"
assert all(label != IGNORE_TOKEN_ID for label in labels_span), (
f"Expected all labels for {context_message} to be set, but got {labels_span}"
)
else:
assert all(
label == IGNORE_TOKEN_ID for label in labels_span
), f"Expected all labels for {context_message} to be {IGNORE_TOKEN_ID}, but got {labels_span}"
assert all(label == IGNORE_TOKEN_ID for label in labels_span), (
f"Expected all labels for {context_message} to be {IGNORE_TOKEN_ID}, but got {labels_span}"
)
# Process all turns and verify labeling
for i, turn in enumerate(modified_dataset[0]["messages"]):
@@ -861,9 +873,9 @@ class TestChatTemplateConfigurations:
actual_labels = labels[
start_idx : start_idx + len(token_offsets_masked)
]
assert (
actual_labels == expected_labels
), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}"
assert actual_labels == expected_labels, (
f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}"
)
# Verify each detail section
for detail in adjusted_train_details:
@@ -958,7 +970,7 @@ class TestChatTemplateConfigurations:
chat_template,
chat_template_jinja,
eos_token,
basic_dataset, # pylint: disable=unused-argument
basic_dataset,
request,
):
"""Test that an error is raised when eot_tokens contains eos_token and train_on_eot/train_on_eos conflict"""
@@ -1005,7 +1017,7 @@ class TestChatTemplateConfigurations:
chat_template,
chat_template_jinja,
eos_token,
basic_dataset, # pylint: disable=unused-argument
basic_dataset,
request,
):
"""Test that eot_tokens inherits from eos_token when not specified"""
@@ -1032,12 +1044,12 @@ class TestChatTemplateConfigurations:
)
# In backward compatibility mode, eot_tokens should be derived from eos_token
assert strategy.eot_tokens == [
tokenizer.eos_token
], f"Expected eot_tokens to inherit from eos_token, got {strategy.eot_tokens}"
assert (
strategy.train_on_eot == "turn"
), f"Expected train_on_eot to inherit from train_on_eos, got {strategy.train_on_eot}"
assert strategy.eot_tokens == [tokenizer.eos_token], (
f"Expected eot_tokens to inherit from eos_token, got {strategy.eot_tokens}"
)
assert strategy.train_on_eot == "turn", (
f"Expected train_on_eot to inherit from train_on_eos, got {strategy.train_on_eot}"
)
def test_token_not_in_template(
self,
@@ -1091,7 +1103,7 @@ class TestChatTemplateConfigurations:
tokenizer,
chat_template,
chat_template_jinja,
eos_token, # pylint: disable=unused-argument
eos_token,
basic_dataset,
request,
):
@@ -1157,13 +1169,13 @@ class TestChatTemplateConfigurations:
)
if is_after_assistant:
assert (
labels[eot_idx] != IGNORE_TOKEN_ID
), f"Expected EOT token after assistant turn at index {eot_idx} to be labeled"
assert labels[eot_idx] != IGNORE_TOKEN_ID, (
f"Expected EOT token after assistant turn at index {eot_idx} to be labeled"
)
else:
assert (
labels[eot_idx] == IGNORE_TOKEN_ID
), f"Expected EOT token not after assistant turn at index {eot_idx} to not be labeled"
assert labels[eot_idx] == IGNORE_TOKEN_ID, (
f"Expected EOT token not after assistant turn at index {eot_idx} to not be labeled"
)
def test_multiple_train_on_eot_settings(
self,
@@ -1224,9 +1236,9 @@ class TestChatTemplateConfigurations:
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
]
assert (
len(eos_indices) > 0
), "Expected at least one EOS/EOT token in the input"
assert len(eos_indices) > 0, (
"Expected at least one EOS/EOT token in the input"
)
# Check labeling for each EOS/EOT token
for idx, eos_idx in enumerate(eos_indices):
@@ -1252,13 +1264,13 @@ class TestChatTemplateConfigurations:
)
if expected_label:
assert (
labels[eos_idx] == IGNORE_TOKEN_ID
), f"Expected EOT token at index {eos_idx} to not be labeled with train_on_eot='{setting}'"
assert labels[eos_idx] == IGNORE_TOKEN_ID, (
f"Expected EOT token at index {eos_idx} to not be labeled with train_on_eot='{setting}'"
)
else:
assert (
labels[eos_idx] != IGNORE_TOKEN_ID
), f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'"
assert labels[eos_idx] != IGNORE_TOKEN_ID, (
f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'"
)
class TestChatTemplateToolCalling:
@@ -1378,29 +1390,27 @@ class TestChatTemplateToolCalling:
decoded_conversation = tokenizer.decode(input_ids)
# Verify tool calling structure is present in the decoded conversation
assert (
'"type": "function",' in decoded_conversation
), "Tool type function should be in conversation"
assert (
'"name": "multiples",' in decoded_conversation
), "Tool function name should be in conversation"
assert '"type": "function",' in decoded_conversation, (
"Tool type function should be in conversation"
)
assert '"name": "multiples",' in decoded_conversation, (
"Tool function name should be in conversation"
)
assert (
'<|python_start|><|python_end|>{"name": "multiples", "parameters": {"number": 5, "limit": 20}}<|eot|>'
in decoded_conversation
), "Assistant tool call should be in conversation"
assert (
"<|header_start|>ipython<|header_end|>" in decoded_conversation
), "IPython header should be in conversation"
assert (
'"5,10,15"' in decoded_conversation
), "Tool response should be in conversation"
assert "<|header_start|>ipython<|header_end|>" in decoded_conversation, (
"IPython header should be in conversation"
)
assert '"5,10,15"' in decoded_conversation, (
"Tool response should be in conversation"
)
# Get conversation turns to verify labeling
turns = strategy.get_conversation_thread(tool_calling_dataset[0])
tools = strategy._get_tools( # pylint: disable=protected-access
tool_calling_dataset[0]
)
tools = strategy._get_tools(tool_calling_dataset[0])
# Check that assistant responses are properly labeled
for i, turn in enumerate(tool_calling_dataset[0]["messages"]):
@@ -1409,12 +1419,12 @@ class TestChatTemplateToolCalling:
turns=turns, turn_idx=i, tools=tools
)
assert (
start_idx != -1 and end_idx != -1
), f"Assistant turn {i} should be found"
assert start_idx != -1 and end_idx != -1, (
f"Assistant turn {i} should be found"
)
# Verify that assistant responses have proper labels
turn_labels = labels[start_idx:end_idx]
assert all(
label != IGNORE_TOKEN_ID for label in turn_labels
), f"Assistant turn {i} should be unmasked"
assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
f"Assistant turn {i} should be unmasked"
)

View File

@@ -28,7 +28,7 @@ def test_mistral_chat_template(
request: pytest.FixtureRequest,
):
"""Test chat template with the Magistral/Devstral tokenizer"""
# pylint: disable=duplicate-code
from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy
tokenizer: HFMistralTokenizer = request.getfixturevalue(tokenizer_str)

View File

@@ -59,7 +59,7 @@ def messages_w_reasoning_fixture():
@pytest.fixture(name="qwen3_tokenizer")
def qwen3_tokenizer_fixture(
download_qwen3_half_billion_model,
): # pylint: disable=unused-argument
):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
return tokenizer
@@ -71,7 +71,6 @@ class TestSplitThinking:
"""
def test_splits_think(self, messages_w_reasoning, qwen3_tokenizer):
# pylint: disable=duplicate-code
strategy = load(
qwen3_tokenizer,
DictDefault(
@@ -130,6 +129,6 @@ class TestSplitThinking:
198, # \n
]
# fmt: on
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert input_ids == expected_input_ids, (
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
)

View File

@@ -16,7 +16,6 @@ from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset():
# pylint: disable=duplicate-code
return Dataset.from_list(
[
{
@@ -49,7 +48,6 @@ def fixture_assistant_dataset():
@pytest.fixture(name="custom_assistant_dataset")
def fixture_custom_assistant_dataset():
# pylint: disable=duplicate-code
return Dataset.from_list(
[
{
@@ -102,7 +100,6 @@ class TestAssistantDPOChatTemplateLlama3:
"""
def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
transform_fn, _ = default(
DictDefault(
{
@@ -127,7 +124,6 @@ class TestAssistantDPOChatTemplateLlama3:
assert result["rejected"] == "party on<|eot_id|>"
def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset):
# pylint: disable=duplicate-code
transform_fn, _ = default(
DictDefault(
{
@@ -168,7 +164,6 @@ class TestAssistantDPOChatTemplatePhi3:
"""
def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
transform_fn, _ = default(
DictDefault(
{
@@ -198,7 +193,6 @@ class TestAssistantDPOChatTemplateGemma:
"""
def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
transform_fn, _ = default(
DictDefault(
{

View File

@@ -20,7 +20,6 @@ class TestStepWiseSupervisedPromptTokenizingStrategy:
@pytest.fixture()
def stepwise_supervised_dataset(self):
# pylint: disable=duplicate-code
return Dataset.from_list(
[
{

View File

@@ -22,7 +22,7 @@ def chunked_fixtures():
return lm_head, hidden_state, labels, vocab_size
def test_chunked_forward(chunked_fixtures): # pylint: disable=redefined-outer-name
def test_chunked_forward(chunked_fixtures):
lm_head, hidden_state, labels, vocab_size = chunked_fixtures
lm_loss = get_causal_lm_loss()

View File

@@ -374,7 +374,6 @@ class TestDatasetPreparation:
}
)
# pylint: disable=duplicate-code
with patch(
"axolotl.utils.data.rl.load_dataset_with_config"
) as mock_load_dataset:

View File

@@ -21,26 +21,26 @@ class DictDefaultTest(unittest.TestCase):
}
)
assert (
cfg.key_a.key_b == "value_a"
), "DictDefault should return value for existing nested keys"
assert cfg.key_a.key_b == "value_a", (
"DictDefault should return value for existing nested keys"
)
assert (
cfg.key_c == "value_c"
), "DictDefault should return value for existing keys"
assert cfg.key_c == "value_c", (
"DictDefault should return value for existing keys"
)
assert (
cfg.key_d[0] == "value_d"
), "DictDefault should return value for existing keys in list"
assert cfg.key_d[0] == "value_d", (
"DictDefault should return value for existing keys in list"
)
assert (
"value_e" in cfg.key_d
), "DictDefault should support in operator for existing keys in list"
assert "value_e" in cfg.key_d, (
"DictDefault should support in operator for existing keys in list"
)
def test_dict_or_operator(self):
cfg = DictDefault({"key_a": {"key_b": "value_b"}, "key_f": "value_g"})
cfg = cfg | DictDefault( # pylint: disable=unsupported-binary-operation
cfg = cfg | DictDefault(
{
"key_a": {"key_b": "value_a"},
"key_c": "value_c",
@@ -49,9 +49,9 @@ class DictDefaultTest(unittest.TestCase):
}
)
assert (
cfg.key_a.key_b == "value_b"
), "DictDefault should support OR operator for existing nested keys"
assert cfg.key_a.key_b == "value_b", (
"DictDefault should support OR operator for existing nested keys"
)
assert cfg.key_c == "value_c", "DictDefault should not delete existing key"
@@ -60,9 +60,9 @@ class DictDefaultTest(unittest.TestCase):
"value_e",
], "DictDefault should not overwrite existing keys in list"
assert (
cfg.key_f == "value_g"
), "DictDefault should support OR operator for existing key"
assert cfg.key_f == "value_g", (
"DictDefault should support OR operator for existing key"
)
def test_dict_missingkey(self):
cfg = DictDefault({})
@@ -72,9 +72,9 @@ class DictDefaultTest(unittest.TestCase):
def test_dict_or(self):
cfg = DictDefault({}) | DictDefault({})
assert (
cfg.random_key is None
), "DictDefault should return None for missing keys after | operation"
assert cfg.random_key is None, (
"DictDefault should return None for missing keys after | operation"
)
def test_dict_nested_missingparentkey(self):
"""

View File

@@ -41,9 +41,9 @@ def verify_deduplication(actual_dataset, expected_dataset, dataset_name):
assert actual_rows == expected_rows, f"Mismatch in {dataset_name} dataset"
# Verify size consistency
assert len(actual_rows) == len(
actual_dataset
), f"Size mismatch in {dataset_name} dataset after deduplication"
assert len(actual_rows) == len(actual_dataset), (
f"Size mismatch in {dataset_name} dataset after deduplication"
)
class TestDeduplicateIndividualFunctions(unittest.TestCase):
@@ -224,7 +224,6 @@ class TestDeduplicateRLDataset:
):
"""Verify that loading with deduplication removes duplicates."""
# pylint: disable=duplicate-code
with (
patch(
"axolotl.utils.data.rl.load_dataset_with_config"
@@ -251,7 +250,6 @@ class TestDeduplicateRLDataset:
dataset_fozziethebeat_alpaca_messages_2k_dpo_test_rev_ea82cff,
tokenizer_huggyllama,
):
# pylint: disable=duplicate-code
with (
patch(
"axolotl.utils.data.rl.load_dataset_with_config"
@@ -271,9 +269,9 @@ class TestDeduplicateRLDataset:
train_dataset, _ = prepare_preference_datasets(cfg, tokenizer)
# Verify that the dataset retains duplicates
assert (
len(train_dataset) == 1800 * 2
), "Dataset deduplication occurred when it should not have"
assert len(train_dataset) == 1800 * 2, (
"Dataset deduplication occurred when it should not have"
)
class TestDeduplicateNonRL(unittest.TestCase):

View File
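The next file exercises ModelLoader construction; the removed attribute-defined-outside-init pragmas existed only because pytest's setup_method assigns attributes outside __init__. A sketch of the same wiring with a trimmed config, mirroring the test's setup (the real test sets more keys, e.g. model_type; missing keys resolve to None in DictDefault):

from unittest.mock import MagicMock

from transformers import PreTrainedTokenizerBase

from axolotl.loaders import ModelLoader
from axolotl.utils.dict import DictDefault

# trimmed-down config for illustration
cfg = DictDefault({"base_model": "HuggingFaceTB/SmolLM2-135M", "device_map": "auto"})
tokenizer = MagicMock(spec=PreTrainedTokenizerBase)

loader = ModelLoader(
    cfg=cfg,
    tokenizer=tokenizer,
    inference=False,
    reference_model=True,
)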

@@ -17,7 +17,7 @@ class TestModelsUtils:
def setup_method(self) -> None:
# load config
self.cfg = DictDefault( # pylint: disable=attribute-defined-outside-init
self.cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
"model_type": "AutoModelForCausalLM",
@@ -30,20 +30,16 @@ class TestModelsUtils:
"device_map": "auto",
}
)
self.tokenizer = MagicMock( # pylint: disable=attribute-defined-outside-init
spec=PreTrainedTokenizerBase
)
self.inference = False # pylint: disable=attribute-defined-outside-init
self.reference_model = True # pylint: disable=attribute-defined-outside-init
self.tokenizer = MagicMock(spec=PreTrainedTokenizerBase)
self.inference = False
self.reference_model = True
# init ModelLoader
self.model_loader = ( # pylint: disable=attribute-defined-outside-init
ModelLoader(
cfg=self.cfg,
tokenizer=self.tokenizer,
inference=self.inference,
reference_model=self.reference_model,
)
self.model_loader = ModelLoader(
cfg=self.cfg,
tokenizer=self.tokenizer,
inference=self.inference,
reference_model=self.reference_model,
)
def test_set_device_map_config(self):
@@ -51,7 +47,7 @@ class TestModelsUtils:
device_map = self.cfg.device_map
if is_torch_mps_available():
device_map = "mps"
# pylint: disable=protected-access
self.model_loader._set_device_map_config()
if is_deepspeed_zero3_enabled():
assert "device_map" not in self.model_loader.model_kwargs
@@ -78,7 +74,6 @@ class TestModelsUtils:
self.cfg.gptq = gptq
self.cfg.adapter = adapter
# pylint: disable=protected-access
self.model_loader._set_quantization_config()
if "quantization_config" in self.model_loader.model_kwargs or self.cfg.gptq:
assert not (
@@ -194,7 +189,7 @@ class TestModelsUtils:
is_fsdp,
expected,
):
res = _get_parallel_config_kwargs( # pylint: disable=protected-access
res = _get_parallel_config_kwargs(
world_size,
tensor_parallel_size,
context_parallel_size,

View File

@@ -6,7 +6,6 @@ from axolotl.loaders import ModelLoader, load_tokenizer
from axolotl.utils.config import normalize_config, validate_config
from axolotl.utils.dict import DictDefault
# pylint: disable=duplicate-code
minimal_config = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",

View File

@@ -93,7 +93,7 @@ class TestBatchedSamplerPacking:
loader = DataLoader(
train_dataset,
batch_sampler=batch_sampler,
collate_fn=V2BatchSamplerDataCollatorForSeq2Seq( # pylint: disable=unexpected-keyword-arg
collate_fn=V2BatchSamplerDataCollatorForSeq2Seq(
tokenizer=tokenizer,
padding=True,
pad_to_multiple_of=max_seq_length,

View File

@@ -26,7 +26,6 @@ class TestPacking(unittest.TestCase):
@enable_hf_offline
def setUp(self) -> None:
# pylint: disable=duplicate-code
self.tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b")
self.tokenizer.add_special_tokens(
{
@@ -75,7 +74,6 @@ class TestPacking(unittest.TestCase):
@with_temp_dir
def test_lora_packing(self, temp_dir):
# pylint: disable=duplicate-code
cfg = DictDefault(
{
"base_model": "HuggingFaceTB/SmolLM2-135M",
@@ -127,9 +125,7 @@ class TestPacking(unittest.TestCase):
_,
) = setup_model_and_trainer(cfg, dataset_meta)
sampler = trainer._get_eval_sampler( # pylint: disable=protected-access
trainer.eval_dataset
)
sampler = trainer._get_eval_sampler(trainer.eval_dataset)
assert "MultipackBatchSampler" in sampler.__class__.__name__
assert (
"V2BatchSamplerDataCollatorForSeq2Seq"
@@ -140,9 +136,7 @@ class TestPacking(unittest.TestCase):
batch = next(dataloader_iter)
assert batch["input_ids"].shape == (1, 8192)
sampler = trainer._get_train_sampler( # pylint: disable=protected-access
trainer.train_dataset
)
sampler = trainer._get_train_sampler(trainer.train_dataset)
assert "MultipackBatchSampler" in sampler.__class__.__name__
assert (
"V2BatchSamplerDataCollatorForSeq2Seq"

View File

@@ -76,7 +76,6 @@ class TestPretrainingPacking:
cfg.pretraining_dataset[0]["type"] or "pretrain",
)
# pylint: disable=duplicate-code
original_bsz = cfg.micro_batch_size
train_dataset = wrap_pretraining_dataset(
dataset,

View File

@@ -1,7 +1,5 @@
"""unit tests for perplexity eval callback"""
# pylint: disable=redefined-outer-name
from pytest import fixture
from transformers.models.auto.modeling_auto import AutoModelForCausalLM
from transformers.models.auto.tokenization_auto import AutoTokenizer

View File

@@ -64,7 +64,7 @@ class TestPromptTokenizationStrategies:
tests the interface between the user and assistant parts
"""
prompter = NoSystemPrompter()
# pylint: disable=duplicate-code
strat = AlpacaPromptTokenizingStrategy(
prompter,
tokenizer_huggyllama_w_special_tokens,
@@ -85,7 +85,7 @@ class TestPromptTokenizationStrategies:
"""
tests the interface between the user and assistant parts
"""
# pylint: disable=duplicate-code
prompter = AlpacaPrompter()
strat = AlpacaPromptTokenizingStrategy(
prompter,
@@ -171,7 +171,7 @@ class Llama2ChatTokenizationTest:
# from transformers.models.llama.tokenization_llama import DEFAULT_SYSTEM_PROMPT
# broken as of 23/7/20
# see https://github.com/huggingface/transformers/pull/24935
# pylint: disable=C0103
DEFAULT_SYSTEM_PROMPT = """\
You are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.
@@ -201,7 +201,7 @@ If a question does not make any sense, or is not factually coherent, explain why
+ user_input[1:-1],
generated_responses=answers,
)
# pylint: disable=W0212
hf_tokens = tokenizer_llama2_7b._build_conversation_input_ids(hf_conf)
assert hf_tokens == tokenized_conversation["input_ids"][: len(hf_tokens)]

Some files were not shown because too many files have changed in this diff.