Add ruff, remove black, isort, flake8, pylint (#3092)

* black, isort, flake8 -> ruff

* remove unused

* add back needed import

* fix
This commit is contained in:
Dan Saunders
2025-08-23 23:37:33 -04:00
committed by GitHub
parent eea7a006e1
commit 79ddaebe9a
286 changed files with 10979 additions and 11435 deletions

View File

@@ -30,7 +30,6 @@ def fixture_assistant_dataset():
@pytest.fixture(name="sharegpt_dataset")
def fixture_sharegpt_dataset():
# pylint: disable=duplicate-code
return Dataset.from_list(
[
{
@@ -47,7 +46,6 @@ def fixture_sharegpt_dataset():
@pytest.fixture(name="basic_dataset")
def fixture_basic_dataset():
# pylint: disable=duplicate-code
return Dataset.from_list(
[
{
@@ -65,7 +63,6 @@ def fixture_basic_dataset():
@pytest.fixture(name="toolcalling_dataset")
def fixture_toolcalling_dataset():
# pylint: disable=duplicate-code
return Dataset.from_list(
[
{
@@ -112,7 +109,7 @@ def fixture_toolcalling_dataset():
@enable_hf_offline
def fixture_llama3_tokenizer(
download_llama3_8b_instruct_model_fixture,
): # pylint: disable=unused-argument,redefined-outer-name
):
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
return tokenizer
@@ -129,7 +126,7 @@ def fixture_smollm2_tokenizer():
@enable_hf_offline
def fixture_mistralv03_tokenizer(
download_mlx_mistral_7b_model_fixture,
): # pylint: disable=unused-argument,redefined-outer-name
):
tokenizer = AutoTokenizer.from_pretrained(
"mlx-community/Mistral-7B-Instruct-v0.3-4bit"
)

View File

@@ -2,7 +2,6 @@
tests for chat_template prompt strategy
"""
# pylint: disable=duplicate-code
import unittest
from axolotl.prompt_strategies.messages.chat import load
@@ -53,9 +52,9 @@ class TestMessagesChatLlama3:
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert input_ids == expected_input_ids, (
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
)
if __name__ == "__main__":

View File

@@ -30,7 +30,6 @@ def fixture_alpaca_dataset():
@pytest.fixture(name="tokenizer")
@enable_hf_offline
def fixture_tokenizer():
# pylint: disable=all
tokenizer = AutoTokenizer.from_pretrained(
"casperhansen/mistral-7b-instruct-v0.1-awq"
)

View File

@@ -18,9 +18,7 @@ def fixture_messages_w_tools():
{"messages":[{"role":"user","content":"move to (0, 1)"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"move","arguments":{"x":0,"y":1}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
{"messages":[{"role":"user","content":"turn 270 degree"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"turn","arguments":{"theta": 270}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
{"messages":[{"role":"user","content":"jump high"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"invalid_prompt","arguments":{"message": "jump is not a valid action"}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
""".strip().split(
"\n"
)
""".strip().split("\n")
rows = [json.loads(row) for row in jsons]
return Dataset.from_list(rows)
@@ -28,7 +26,7 @@ def fixture_messages_w_tools():
@pytest.fixture(name="qwen3_tokenizer")
def qwen3_tokenizer_fixture(
download_qwen3_half_billion_model,
): # pylint: disable=unused-argument
):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
return tokenizer

View File

@@ -67,9 +67,9 @@ class TestAssistantChatTemplateLlama3:
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert input_ids == expected_input_ids, (
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
)
def test_llama3(self, llama3_tokenizer, assistant_dataset):
LOG.info("Testing llama-3 with assistant dataset")
@@ -109,9 +109,9 @@ class TestAssistantChatTemplateLlama3:
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert input_ids == expected_input_ids, (
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
)
def test_phi35(self, phi35_tokenizer, assistant_dataset):
LOG.info("Testing phi-3.5 with assistant dataset")
@@ -161,15 +161,15 @@ class TestAssistantChatTemplateLlama3:
# fmt: on
LOG.debug(f"Expected input_ids: {expected_input_ids}")
LOG.debug(f"Actual input_ids: {input_ids}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert input_ids == expected_input_ids, (
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
)
LOG.debug(f"Expected labels : {expected_labels}")
LOG.debug(f"Actual labels : {labels}")
assert (
labels == expected_labels
), f"Input IDs mismatch: {labels} != {expected_labels}"
assert labels == expected_labels, (
f"Input IDs mismatch: {labels} != {expected_labels}"
)
def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset):
LOG.info("Testing llama-3 with assistant dataset including training data")
@@ -234,7 +234,7 @@ class TestSharegptChatTemplateLlama3:
def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):
LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts")
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
@@ -285,16 +285,16 @@ class TestSharegptChatTemplateLlama3:
LOG.debug(f"Expected labels: {expected_labels}")
LOG.debug(f"Actual labels: {labels}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert (
labels == expected_labels
), f"Labels mismatch: {labels} != {expected_labels}"
assert input_ids == expected_input_ids, (
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
)
assert labels == expected_labels, (
f"Labels mismatch: {labels} != {expected_labels}"
)
def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):
LOG.info("Testing ShareGPT style datasets with llama-3 human prompts")
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
@@ -345,16 +345,16 @@ class TestSharegptChatTemplateLlama3:
LOG.debug(f"Expected labels: {expected_labels}")
LOG.debug(f"Actual labels: {labels}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert (
labels == expected_labels
), f"Labels mismatch: {labels} != {expected_labels}"
assert input_ids == expected_input_ids, (
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
)
assert labels == expected_labels, (
f"Labels mismatch: {labels} != {expected_labels}"
)
def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):
LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts")
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
llama3_tokenizer,
@@ -409,12 +409,12 @@ class TestSharegptChatTemplateLlama3:
LOG.debug(f"Expected labels: {expected_labels}")
LOG.debug(f"Actual labels: {labels}")
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert (
labels == expected_labels
), f"Labels mismatch: {labels} != {expected_labels}"
assert input_ids == expected_input_ids, (
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
)
assert labels == expected_labels, (
f"Labels mismatch: {labels} != {expected_labels}"
)
class TestAssistantToolCallingChatTemplateLlama32Vision:
@@ -481,13 +481,13 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
]
# fmt: on
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert input_ids == expected_input_ids, (
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
)
assert (
labels == expected_labels
), f"Labels mismatch: {labels} != {expected_labels}"
assert labels == expected_labels, (
f"Labels mismatch: {labels} != {expected_labels}"
)
def test_llama32vision_train_on_tools(
self, llama3_tokenizer, toolcalling_dataset, llama3_2_vision_chat_template_jinja
@@ -495,7 +495,6 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
LOG.info(
"Testing assistant style datasets with tool_calling with llama-32 chat template, training on tools"
)
# pylint: disable=duplicate-code
strategy = ChatTemplateStrategy(
ChatTemplatePrompter(
@@ -549,13 +548,13 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
]
# fmt: on
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert input_ids == expected_input_ids, (
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
)
assert (
labels == expected_labels
), f"Labels mismatch: {labels} != {expected_labels}"
assert labels == expected_labels, (
f"Labels mismatch: {labels} != {expected_labels}"
)
if __name__ == "__main__":

View File

@@ -2,8 +2,6 @@
tests for chat_template prompt strategy
"""
# pylint: disable=too-many-lines
from copy import deepcopy
import pytest
@@ -96,9 +94,9 @@ class TestChatTemplateConfigurations:
and turn.get("from") in ["system", "context"]
and ("mistral" in tokenizer.name_or_path.lower())
):
assert (
start_idx == -1 and end_idx == -1
), "Expected system message to be skipped"
assert start_idx == -1 and end_idx == -1, (
"Expected system message to be skipped"
)
return True
return False
@@ -155,7 +153,9 @@ class TestChatTemplateConfigurations:
assert all(
label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
), f"Expected labels for input '{response}' to be ignored, but got {labels[start_idx:end_idx]}"
), (
f"Expected labels for input '{response}' to be ignored, but got {labels[start_idx:end_idx]}"
)
LOG.debug("Full labels: %s", labels)
LOG.debug("Full input_ids: %s", input_ids)
@@ -215,11 +215,15 @@ class TestChatTemplateConfigurations:
if is_assistant:
assert all(
label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}"
), (
f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}"
)
else:
assert all(
label == IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
), f"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}"
), (
f"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}"
)
def test_roles_to_train_human_assistant_only(
self,
@@ -276,11 +280,15 @@ class TestChatTemplateConfigurations:
if should_be_labelled:
assert all(
label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}"
), (
f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}"
)
else:
assert all(
label == IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
), f"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}"
), (
f"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}"
)
def test_roles_to_train_all(
self,
@@ -327,13 +335,15 @@ class TestChatTemplateConfigurations:
continue
decoded_response = tokenizer.decode(input_ids[start_idx:end_idx])
assert (
response in decoded_response
), f"Response {response} not found in index {start_idx}:{end_idx} decoded:{decoded_response}"
assert response in decoded_response, (
f"Response {response} not found in index {start_idx}:{end_idx} decoded:{decoded_response}"
)
assert all(
label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:end_idx]}"
), (
f"Expected labels for response '{response}' to be set, but got {labels[start_idx:end_idx]}"
)
def test_empty_roles_to_train(
self,
@@ -371,9 +381,9 @@ class TestChatTemplateConfigurations:
# Verify that no labels are set when roles_to_train is empty
LOG.debug("Full labels: %s", labels)
assert all(
label == IGNORE_TOKEN_ID for label in labels
), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty"
assert all(label == IGNORE_TOKEN_ID for label in labels), (
"Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty"
)
def test_train_on_eos_all(
self,
@@ -417,9 +427,9 @@ class TestChatTemplateConfigurations:
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
for eos_idx in eos_indices:
assert (
labels[eos_idx] != IGNORE_TOKEN_ID
), f"Expected EOS token at index {eos_idx} to be labeled"
assert labels[eos_idx] != IGNORE_TOKEN_ID, (
f"Expected EOS token at index {eos_idx} to be labeled"
)
def test_train_on_eos_turn(
self,
@@ -477,9 +487,9 @@ class TestChatTemplateConfigurations:
while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
eos_idx += 1
assert eos_idx < len(
input_ids
), f"Could not find EOS token after '{response}'"
assert eos_idx < len(input_ids), (
f"Could not find EOS token after '{response}'"
)
LOG.debug(
f"Turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}, eos_idx={eos_idx}"
@@ -492,13 +502,13 @@ class TestChatTemplateConfigurations:
# Verify EOS token labeling based on role
is_assistant = turn["from"] == "assistant"
if is_assistant:
assert (
labels[eos_idx] != IGNORE_TOKEN_ID
), f"Expected EOS token after assistant response '{response}' to be labeled"
assert labels[eos_idx] != IGNORE_TOKEN_ID, (
f"Expected EOS token after assistant response '{response}' to be labeled"
)
else:
assert (
labels[eos_idx] == IGNORE_TOKEN_ID
), f"Expected EOS token after non-assistant input '{response}' to not be labeled"
assert labels[eos_idx] == IGNORE_TOKEN_ID, (
f"Expected EOS token after non-assistant input '{response}' to not be labeled"
)
def test_train_on_eos_last(
self,
@@ -545,12 +555,12 @@ class TestChatTemplateConfigurations:
# Check that only the last EOS token is labeled
for idx in eos_indices[:-1]:
assert (
labels[idx] == IGNORE_TOKEN_ID
), f"Expected EOS token at index {idx} to not be labeled"
assert (
labels[last_eos_idx] != IGNORE_TOKEN_ID
), f"Expected last EOS token at index {last_eos_idx} to be labeled"
assert labels[idx] == IGNORE_TOKEN_ID, (
f"Expected EOS token at index {idx} to not be labeled"
)
assert labels[last_eos_idx] != IGNORE_TOKEN_ID, (
f"Expected last EOS token at index {last_eos_idx} to be labeled"
)
def test_train_on_eos_none(
self,
@@ -594,9 +604,9 @@ class TestChatTemplateConfigurations:
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
for eos_idx in eos_indices:
assert (
labels[eos_idx] == IGNORE_TOKEN_ID
), f"Expected EOS token at index {eos_idx} to not be labeled"
assert labels[eos_idx] == IGNORE_TOKEN_ID, (
f"Expected EOS token at index {eos_idx} to not be labeled"
)
def test_drop_system_message(
self,
@@ -634,9 +644,9 @@ class TestChatTemplateConfigurations:
# Check if system message is not present in input_ids
system_message = "You are an AI assistant."
decoded_message = tokenizer.decode(input_ids)
assert (
system_message not in decoded_message
), "Expected system message to be dropped"
assert system_message not in decoded_message, (
"Expected system message to be dropped"
)
def test_custom_roles(
self,
@@ -711,7 +721,9 @@ class TestChatTemplateConfigurations:
else:
assert all(
label == IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
), f"Expected labels for non-AI message '{response}' to be IGNORE_TOKEN_ID"
), (
f"Expected labels for non-AI message '{response}' to be IGNORE_TOKEN_ID"
)
def test_message_field_training(
self,
@@ -776,13 +788,13 @@ class TestChatTemplateConfigurations:
def verify_labels(labels_span, should_train, context_message):
"""Helper to verify if a span of labels matches expected training state"""
if should_train:
assert all(
label != IGNORE_TOKEN_ID for label in labels_span
), f"Expected all labels for {context_message} to be set, but got {labels_span}"
assert all(label != IGNORE_TOKEN_ID for label in labels_span), (
f"Expected all labels for {context_message} to be set, but got {labels_span}"
)
else:
assert all(
label == IGNORE_TOKEN_ID for label in labels_span
), f"Expected all labels for {context_message} to be {IGNORE_TOKEN_ID}, but got {labels_span}"
assert all(label == IGNORE_TOKEN_ID for label in labels_span), (
f"Expected all labels for {context_message} to be {IGNORE_TOKEN_ID}, but got {labels_span}"
)
# Process all turns and verify labeling
for i, turn in enumerate(modified_dataset[0]["messages"]):
@@ -861,9 +873,9 @@ class TestChatTemplateConfigurations:
actual_labels = labels[
start_idx : start_idx + len(token_offsets_masked)
]
assert (
actual_labels == expected_labels
), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}"
assert actual_labels == expected_labels, (
f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}"
)
# Verify each detail section
for detail in adjusted_train_details:
@@ -958,7 +970,7 @@ class TestChatTemplateConfigurations:
chat_template,
chat_template_jinja,
eos_token,
basic_dataset, # pylint: disable=unused-argument
basic_dataset,
request,
):
"""Test that an error is raised when eot_tokens contains eos_token and train_on_eot/train_on_eos conflict"""
@@ -1005,7 +1017,7 @@ class TestChatTemplateConfigurations:
chat_template,
chat_template_jinja,
eos_token,
basic_dataset, # pylint: disable=unused-argument
basic_dataset,
request,
):
"""Test that eot_tokens inherits from eos_token when not specified"""
@@ -1032,12 +1044,12 @@ class TestChatTemplateConfigurations:
)
# In backward compatibility mode, eot_tokens should be derived from eos_token
assert strategy.eot_tokens == [
tokenizer.eos_token
], f"Expected eot_tokens to inherit from eos_token, got {strategy.eot_tokens}"
assert (
strategy.train_on_eot == "turn"
), f"Expected train_on_eot to inherit from train_on_eos, got {strategy.train_on_eot}"
assert strategy.eot_tokens == [tokenizer.eos_token], (
f"Expected eot_tokens to inherit from eos_token, got {strategy.eot_tokens}"
)
assert strategy.train_on_eot == "turn", (
f"Expected train_on_eot to inherit from train_on_eos, got {strategy.train_on_eot}"
)
def test_token_not_in_template(
self,
@@ -1091,7 +1103,7 @@ class TestChatTemplateConfigurations:
tokenizer,
chat_template,
chat_template_jinja,
eos_token, # pylint: disable=unused-argument
eos_token,
basic_dataset,
request,
):
@@ -1157,13 +1169,13 @@ class TestChatTemplateConfigurations:
)
if is_after_assistant:
assert (
labels[eot_idx] != IGNORE_TOKEN_ID
), f"Expected EOT token after assistant turn at index {eot_idx} to be labeled"
assert labels[eot_idx] != IGNORE_TOKEN_ID, (
f"Expected EOT token after assistant turn at index {eot_idx} to be labeled"
)
else:
assert (
labels[eot_idx] == IGNORE_TOKEN_ID
), f"Expected EOT token not after assistant turn at index {eot_idx} to not be labeled"
assert labels[eot_idx] == IGNORE_TOKEN_ID, (
f"Expected EOT token not after assistant turn at index {eot_idx} to not be labeled"
)
def test_multiple_train_on_eot_settings(
self,
@@ -1224,9 +1236,9 @@ class TestChatTemplateConfigurations:
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
]
assert (
len(eos_indices) > 0
), "Expected at least one EOS/EOT token in the input"
assert len(eos_indices) > 0, (
"Expected at least one EOS/EOT token in the input"
)
# Check labeling for each EOS/EOT token
for idx, eos_idx in enumerate(eos_indices):
@@ -1252,13 +1264,13 @@ class TestChatTemplateConfigurations:
)
if expected_label:
assert (
labels[eos_idx] == IGNORE_TOKEN_ID
), f"Expected EOT token at index {eos_idx} to not be labeled with train_on_eot='{setting}'"
assert labels[eos_idx] == IGNORE_TOKEN_ID, (
f"Expected EOT token at index {eos_idx} to not be labeled with train_on_eot='{setting}'"
)
else:
assert (
labels[eos_idx] != IGNORE_TOKEN_ID
), f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'"
assert labels[eos_idx] != IGNORE_TOKEN_ID, (
f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'"
)
class TestChatTemplateToolCalling:
@@ -1378,29 +1390,27 @@ class TestChatTemplateToolCalling:
decoded_conversation = tokenizer.decode(input_ids)
# Verify tool calling structure is present in the decoded conversation
assert (
'"type": "function",' in decoded_conversation
), "Tool type function should be in conversation"
assert (
'"name": "multiples",' in decoded_conversation
), "Tool function name should be in conversation"
assert '"type": "function",' in decoded_conversation, (
"Tool type function should be in conversation"
)
assert '"name": "multiples",' in decoded_conversation, (
"Tool function name should be in conversation"
)
assert (
'<|python_start|><|python_end|>{"name": "multiples", "parameters": {"number": 5, "limit": 20}}<|eot|>'
in decoded_conversation
), "Assistant tool call should be in conversation"
assert (
"<|header_start|>ipython<|header_end|>" in decoded_conversation
), "IPython header should be in conversation"
assert (
'"5,10,15"' in decoded_conversation
), "Tool response should be in conversation"
assert "<|header_start|>ipython<|header_end|>" in decoded_conversation, (
"IPython header should be in conversation"
)
assert '"5,10,15"' in decoded_conversation, (
"Tool response should be in conversation"
)
# Get conversation turns to verify labeling
turns = strategy.get_conversation_thread(tool_calling_dataset[0])
tools = strategy._get_tools( # pylint: disable=protected-access
tool_calling_dataset[0]
)
tools = strategy._get_tools(tool_calling_dataset[0])
# Check that assistant responses are properly labeled
for i, turn in enumerate(tool_calling_dataset[0]["messages"]):
@@ -1409,12 +1419,12 @@ class TestChatTemplateToolCalling:
turns=turns, turn_idx=i, tools=tools
)
assert (
start_idx != -1 and end_idx != -1
), f"Assistant turn {i} should be found"
assert start_idx != -1 and end_idx != -1, (
f"Assistant turn {i} should be found"
)
# Verify that assistant responses have proper labels
turn_labels = labels[start_idx:end_idx]
assert all(
label != IGNORE_TOKEN_ID for label in turn_labels
), f"Assistant turn {i} should be unmasked"
assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
f"Assistant turn {i} should be unmasked"
)

View File

@@ -28,7 +28,7 @@ def test_mistral_chat_template(
request: pytest.FixtureRequest,
):
"""Test chat template with the Magistral/Devstral tokenizer"""
# pylint: disable=duplicate-code
from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy
tokenizer: HFMistralTokenizer = request.getfixturevalue(tokenizer_str)

View File

@@ -59,7 +59,7 @@ def messages_w_reasoning_fixture():
@pytest.fixture(name="qwen3_tokenizer")
def qwen3_tokenizer_fixture(
download_qwen3_half_billion_model,
): # pylint: disable=unused-argument
):
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
return tokenizer
@@ -71,7 +71,6 @@ class TestSplitThinking:
"""
def test_splits_think(self, messages_w_reasoning, qwen3_tokenizer):
# pylint: disable=duplicate-code
strategy = load(
qwen3_tokenizer,
DictDefault(
@@ -130,6 +129,6 @@ class TestSplitThinking:
198, # \n
]
# fmt: on
assert (
input_ids == expected_input_ids
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
assert input_ids == expected_input_ids, (
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
)

View File

@@ -16,7 +16,6 @@ from tests.hf_offline_utils import enable_hf_offline
@pytest.fixture(name="assistant_dataset")
def fixture_assistant_dataset():
# pylint: disable=duplicate-code
return Dataset.from_list(
[
{
@@ -49,7 +48,6 @@ def fixture_assistant_dataset():
@pytest.fixture(name="custom_assistant_dataset")
def fixture_custom_assistant_dataset():
# pylint: disable=duplicate-code
return Dataset.from_list(
[
{
@@ -102,7 +100,6 @@ class TestAssistantDPOChatTemplateLlama3:
"""
def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
transform_fn, _ = default(
DictDefault(
{
@@ -127,7 +124,6 @@ class TestAssistantDPOChatTemplateLlama3:
assert result["rejected"] == "party on<|eot_id|>"
def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset):
# pylint: disable=duplicate-code
transform_fn, _ = default(
DictDefault(
{
@@ -168,7 +164,6 @@ class TestAssistantDPOChatTemplatePhi3:
"""
def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
transform_fn, _ = default(
DictDefault(
{
@@ -198,7 +193,6 @@ class TestAssistantDPOChatTemplateGemma:
"""
def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset):
# pylint: disable=duplicate-code
transform_fn, _ = default(
DictDefault(
{

View File

@@ -20,7 +20,6 @@ class TestStepWiseSupervisedPromptTokenizingStrategy:
@pytest.fixture()
def stepwise_supervised_dataset(self):
# pylint: disable=duplicate-code
return Dataset.from_list(
[
{