Add ruff, remove black, isort, flake8, pylint (#3092)
* black, isort, flake8 -> ruff * remove unused * add back needed import * fix
This commit is contained in:
@@ -30,7 +30,6 @@ def fixture_assistant_dataset():
|
||||
|
||||
@pytest.fixture(name="sharegpt_dataset")
|
||||
def fixture_sharegpt_dataset():
|
||||
# pylint: disable=duplicate-code
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
@@ -47,7 +46,6 @@ def fixture_sharegpt_dataset():
|
||||
|
||||
@pytest.fixture(name="basic_dataset")
|
||||
def fixture_basic_dataset():
|
||||
# pylint: disable=duplicate-code
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
@@ -65,7 +63,6 @@ def fixture_basic_dataset():
|
||||
|
||||
@pytest.fixture(name="toolcalling_dataset")
|
||||
def fixture_toolcalling_dataset():
|
||||
# pylint: disable=duplicate-code
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
@@ -112,7 +109,7 @@ def fixture_toolcalling_dataset():
|
||||
@enable_hf_offline
|
||||
def fixture_llama3_tokenizer(
|
||||
download_llama3_8b_instruct_model_fixture,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained("NousResearch/Meta-Llama-3-8B-Instruct")
|
||||
|
||||
return tokenizer
|
||||
@@ -129,7 +126,7 @@ def fixture_smollm2_tokenizer():
|
||||
@enable_hf_offline
|
||||
def fixture_mistralv03_tokenizer(
|
||||
download_mlx_mistral_7b_model_fixture,
|
||||
): # pylint: disable=unused-argument,redefined-outer-name
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"mlx-community/Mistral-7B-Instruct-v0.3-4bit"
|
||||
)
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
tests for chat_template prompt strategy
|
||||
"""
|
||||
|
||||
# pylint: disable=duplicate-code
|
||||
import unittest
|
||||
|
||||
from axolotl.prompt_strategies.messages.chat import load
|
||||
@@ -53,9 +52,9 @@ class TestMessagesChatLlama3:
|
||||
# fmt: on
|
||||
LOG.debug(f"Expected input_ids: {expected_input_ids}")
|
||||
LOG.debug(f"Actual input_ids: {input_ids}")
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
assert input_ids == expected_input_ids, (
|
||||
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -30,7 +30,6 @@ def fixture_alpaca_dataset():
|
||||
@pytest.fixture(name="tokenizer")
|
||||
@enable_hf_offline
|
||||
def fixture_tokenizer():
|
||||
# pylint: disable=all
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
"casperhansen/mistral-7b-instruct-v0.1-awq"
|
||||
)
|
||||
|
||||
@@ -18,9 +18,7 @@ def fixture_messages_w_tools():
|
||||
{"messages":[{"role":"user","content":"move to (0, 1)"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"move","arguments":{"x":0,"y":1}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
|
||||
{"messages":[{"role":"user","content":"turn 270 degree"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"turn","arguments":{"theta": 270}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
|
||||
{"messages":[{"role":"user","content":"jump high"},{"role":"assistant","content":"","tool_calls":[{"function":{"name":"invalid_prompt","arguments":{"message": "jump is not a valid action"}}}]}],"tools":[{"type":"function","function":{"name":"move","description":"Move to a given location measured in meters","parameters":{"type":"object","properties":{"x":{"type":"number","description":"The x coordinate of the location, negative values are to the left, positive values are to the right"},"y":{"type":"number","description":"The y coordinate of the location, negative values are backward, positive values are forward"}},"required":["x","y"]}}},{"type":"function","function":{"name":"turn","description":"Turn the robot to a given direction","parameters":{"type":"object","properties":{"theta":{"type":"integer","description":"The angle to turn to, in degrees, positive values are counter-clockwise, negative values are clockwise"}},"required":["theta"]}}},{"type":"function","function":{"name":"invalid_prompt","description":"call when the user's prompt is invalid","parameters":{"type":"object","properties":{"message":{"type":"string","description":"why the prompt is invalid"}},"required":["message"]}}}],"add_generation_prompt":false}
|
||||
""".strip().split(
|
||||
"\n"
|
||||
)
|
||||
""".strip().split("\n")
|
||||
rows = [json.loads(row) for row in jsons]
|
||||
return Dataset.from_list(rows)
|
||||
|
||||
@@ -28,7 +26,7 @@ def fixture_messages_w_tools():
|
||||
@pytest.fixture(name="qwen3_tokenizer")
|
||||
def qwen3_tokenizer_fixture(
|
||||
download_qwen3_half_billion_model,
|
||||
): # pylint: disable=unused-argument
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
||||
|
||||
return tokenizer
|
||||
|
||||
@@ -67,9 +67,9 @@ class TestAssistantChatTemplateLlama3:
|
||||
# fmt: on
|
||||
LOG.debug(f"Expected input_ids: {expected_input_ids}")
|
||||
LOG.debug(f"Actual input_ids: {input_ids}")
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
assert input_ids == expected_input_ids, (
|
||||
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
)
|
||||
|
||||
def test_llama3(self, llama3_tokenizer, assistant_dataset):
|
||||
LOG.info("Testing llama-3 with assistant dataset")
|
||||
@@ -109,9 +109,9 @@ class TestAssistantChatTemplateLlama3:
|
||||
# fmt: on
|
||||
LOG.debug(f"Expected input_ids: {expected_input_ids}")
|
||||
LOG.debug(f"Actual input_ids: {input_ids}")
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
assert input_ids == expected_input_ids, (
|
||||
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
)
|
||||
|
||||
def test_phi35(self, phi35_tokenizer, assistant_dataset):
|
||||
LOG.info("Testing phi-3.5 with assistant dataset")
|
||||
@@ -161,15 +161,15 @@ class TestAssistantChatTemplateLlama3:
|
||||
# fmt: on
|
||||
LOG.debug(f"Expected input_ids: {expected_input_ids}")
|
||||
LOG.debug(f"Actual input_ids: {input_ids}")
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
assert input_ids == expected_input_ids, (
|
||||
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
)
|
||||
|
||||
LOG.debug(f"Expected labels : {expected_labels}")
|
||||
LOG.debug(f"Actual labels : {labels}")
|
||||
assert (
|
||||
labels == expected_labels
|
||||
), f"Input IDs mismatch: {labels} != {expected_labels}"
|
||||
assert labels == expected_labels, (
|
||||
f"Input IDs mismatch: {labels} != {expected_labels}"
|
||||
)
|
||||
|
||||
def test_llama3_with_training_data(self, llama3_tokenizer, assistant_dataset):
|
||||
LOG.info("Testing llama-3 with assistant dataset including training data")
|
||||
@@ -234,7 +234,7 @@ class TestSharegptChatTemplateLlama3:
|
||||
|
||||
def test_llama3_assistant(self, llama3_tokenizer, sharegpt_dataset):
|
||||
LOG.info("Testing ShareGPT style datasets with llama-3 assistant prompts")
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
@@ -285,16 +285,16 @@ class TestSharegptChatTemplateLlama3:
|
||||
LOG.debug(f"Expected labels: {expected_labels}")
|
||||
LOG.debug(f"Actual labels: {labels}")
|
||||
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
assert (
|
||||
labels == expected_labels
|
||||
), f"Labels mismatch: {labels} != {expected_labels}"
|
||||
assert input_ids == expected_input_ids, (
|
||||
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
)
|
||||
assert labels == expected_labels, (
|
||||
f"Labels mismatch: {labels} != {expected_labels}"
|
||||
)
|
||||
|
||||
def test_llama3_human(self, llama3_tokenizer, sharegpt_dataset):
|
||||
LOG.info("Testing ShareGPT style datasets with llama-3 human prompts")
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
@@ -345,16 +345,16 @@ class TestSharegptChatTemplateLlama3:
|
||||
LOG.debug(f"Expected labels: {expected_labels}")
|
||||
LOG.debug(f"Actual labels: {labels}")
|
||||
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
assert (
|
||||
labels == expected_labels
|
||||
), f"Labels mismatch: {labels} != {expected_labels}"
|
||||
assert input_ids == expected_input_ids, (
|
||||
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
)
|
||||
assert labels == expected_labels, (
|
||||
f"Labels mismatch: {labels} != {expected_labels}"
|
||||
)
|
||||
|
||||
def test_llama3_system_human(self, llama3_tokenizer, basic_dataset):
|
||||
LOG.info("Testing ShareGPT style datasets with llama-3 system/human prompts")
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
llama3_tokenizer,
|
||||
@@ -409,12 +409,12 @@ class TestSharegptChatTemplateLlama3:
|
||||
LOG.debug(f"Expected labels: {expected_labels}")
|
||||
LOG.debug(f"Actual labels: {labels}")
|
||||
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
assert (
|
||||
labels == expected_labels
|
||||
), f"Labels mismatch: {labels} != {expected_labels}"
|
||||
assert input_ids == expected_input_ids, (
|
||||
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
)
|
||||
assert labels == expected_labels, (
|
||||
f"Labels mismatch: {labels} != {expected_labels}"
|
||||
)
|
||||
|
||||
|
||||
class TestAssistantToolCallingChatTemplateLlama32Vision:
|
||||
@@ -481,13 +481,13 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
assert input_ids == expected_input_ids, (
|
||||
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
)
|
||||
|
||||
assert (
|
||||
labels == expected_labels
|
||||
), f"Labels mismatch: {labels} != {expected_labels}"
|
||||
assert labels == expected_labels, (
|
||||
f"Labels mismatch: {labels} != {expected_labels}"
|
||||
)
|
||||
|
||||
def test_llama32vision_train_on_tools(
|
||||
self, llama3_tokenizer, toolcalling_dataset, llama3_2_vision_chat_template_jinja
|
||||
@@ -495,7 +495,6 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
|
||||
LOG.info(
|
||||
"Testing assistant style datasets with tool_calling with llama-32 chat template, training on tools"
|
||||
)
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
strategy = ChatTemplateStrategy(
|
||||
ChatTemplatePrompter(
|
||||
@@ -549,13 +548,13 @@ class TestAssistantToolCallingChatTemplateLlama32Vision:
|
||||
]
|
||||
# fmt: on
|
||||
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
assert input_ids == expected_input_ids, (
|
||||
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
)
|
||||
|
||||
assert (
|
||||
labels == expected_labels
|
||||
), f"Labels mismatch: {labels} != {expected_labels}"
|
||||
assert labels == expected_labels, (
|
||||
f"Labels mismatch: {labels} != {expected_labels}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -2,8 +2,6 @@
|
||||
tests for chat_template prompt strategy
|
||||
"""
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
|
||||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
@@ -96,9 +94,9 @@ class TestChatTemplateConfigurations:
|
||||
and turn.get("from") in ["system", "context"]
|
||||
and ("mistral" in tokenizer.name_or_path.lower())
|
||||
):
|
||||
assert (
|
||||
start_idx == -1 and end_idx == -1
|
||||
), "Expected system message to be skipped"
|
||||
assert start_idx == -1 and end_idx == -1, (
|
||||
"Expected system message to be skipped"
|
||||
)
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -155,7 +153,9 @@ class TestChatTemplateConfigurations:
|
||||
|
||||
assert all(
|
||||
label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
|
||||
), f"Expected labels for input '{response}' to be ignored, but got {labels[start_idx:end_idx]}"
|
||||
), (
|
||||
f"Expected labels for input '{response}' to be ignored, but got {labels[start_idx:end_idx]}"
|
||||
)
|
||||
|
||||
LOG.debug("Full labels: %s", labels)
|
||||
LOG.debug("Full input_ids: %s", input_ids)
|
||||
@@ -215,11 +215,15 @@ class TestChatTemplateConfigurations:
|
||||
if is_assistant:
|
||||
assert all(
|
||||
label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
|
||||
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}"
|
||||
), (
|
||||
f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}"
|
||||
)
|
||||
else:
|
||||
assert all(
|
||||
label == IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
|
||||
), f"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}"
|
||||
), (
|
||||
f"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}"
|
||||
)
|
||||
|
||||
def test_roles_to_train_human_assistant_only(
|
||||
self,
|
||||
@@ -276,11 +280,15 @@ class TestChatTemplateConfigurations:
|
||||
if should_be_labelled:
|
||||
assert all(
|
||||
label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
|
||||
), f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}"
|
||||
), (
|
||||
f"Expected labels for assistant response '{response}' to be set, but got {labels[start_idx:end_idx]}"
|
||||
)
|
||||
else:
|
||||
assert all(
|
||||
label == IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
|
||||
), f"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}"
|
||||
), (
|
||||
f"Expected labels for human input '{response}' to be IGNORE_TOKEN_ID, but got {labels[start_idx:end_idx]}"
|
||||
)
|
||||
|
||||
def test_roles_to_train_all(
|
||||
self,
|
||||
@@ -327,13 +335,15 @@ class TestChatTemplateConfigurations:
|
||||
continue
|
||||
|
||||
decoded_response = tokenizer.decode(input_ids[start_idx:end_idx])
|
||||
assert (
|
||||
response in decoded_response
|
||||
), f"Response {response} not found in index {start_idx}:{end_idx} decoded:{decoded_response}"
|
||||
assert response in decoded_response, (
|
||||
f"Response {response} not found in index {start_idx}:{end_idx} decoded:{decoded_response}"
|
||||
)
|
||||
|
||||
assert all(
|
||||
label != IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
|
||||
), f"Expected labels for response '{response}' to be set, but got {labels[start_idx:end_idx]}"
|
||||
), (
|
||||
f"Expected labels for response '{response}' to be set, but got {labels[start_idx:end_idx]}"
|
||||
)
|
||||
|
||||
def test_empty_roles_to_train(
|
||||
self,
|
||||
@@ -371,9 +381,9 @@ class TestChatTemplateConfigurations:
|
||||
|
||||
# Verify that no labels are set when roles_to_train is empty
|
||||
LOG.debug("Full labels: %s", labels)
|
||||
assert all(
|
||||
label == IGNORE_TOKEN_ID for label in labels
|
||||
), "Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty"
|
||||
assert all(label == IGNORE_TOKEN_ID for label in labels), (
|
||||
"Expected all labels to be IGNORE_TOKEN_ID when roles_to_train is empty"
|
||||
)
|
||||
|
||||
def test_train_on_eos_all(
|
||||
self,
|
||||
@@ -417,9 +427,9 @@ class TestChatTemplateConfigurations:
|
||||
|
||||
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
|
||||
for eos_idx in eos_indices:
|
||||
assert (
|
||||
labels[eos_idx] != IGNORE_TOKEN_ID
|
||||
), f"Expected EOS token at index {eos_idx} to be labeled"
|
||||
assert labels[eos_idx] != IGNORE_TOKEN_ID, (
|
||||
f"Expected EOS token at index {eos_idx} to be labeled"
|
||||
)
|
||||
|
||||
def test_train_on_eos_turn(
|
||||
self,
|
||||
@@ -477,9 +487,9 @@ class TestChatTemplateConfigurations:
|
||||
while eos_idx < len(input_ids) and input_ids[eos_idx] != eos_token_id:
|
||||
eos_idx += 1
|
||||
|
||||
assert eos_idx < len(
|
||||
input_ids
|
||||
), f"Could not find EOS token after '{response}'"
|
||||
assert eos_idx < len(input_ids), (
|
||||
f"Could not find EOS token after '{response}'"
|
||||
)
|
||||
|
||||
LOG.debug(
|
||||
f"Turn {i}: role={turn['from']}, content='{turn['value']}', start_idx={start_idx}, end_idx={end_idx}, eos_idx={eos_idx}"
|
||||
@@ -492,13 +502,13 @@ class TestChatTemplateConfigurations:
|
||||
# Verify EOS token labeling based on role
|
||||
is_assistant = turn["from"] == "assistant"
|
||||
if is_assistant:
|
||||
assert (
|
||||
labels[eos_idx] != IGNORE_TOKEN_ID
|
||||
), f"Expected EOS token after assistant response '{response}' to be labeled"
|
||||
assert labels[eos_idx] != IGNORE_TOKEN_ID, (
|
||||
f"Expected EOS token after assistant response '{response}' to be labeled"
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
labels[eos_idx] == IGNORE_TOKEN_ID
|
||||
), f"Expected EOS token after non-assistant input '{response}' to not be labeled"
|
||||
assert labels[eos_idx] == IGNORE_TOKEN_ID, (
|
||||
f"Expected EOS token after non-assistant input '{response}' to not be labeled"
|
||||
)
|
||||
|
||||
def test_train_on_eos_last(
|
||||
self,
|
||||
@@ -545,12 +555,12 @@ class TestChatTemplateConfigurations:
|
||||
|
||||
# Check that only the last EOS token is labeled
|
||||
for idx in eos_indices[:-1]:
|
||||
assert (
|
||||
labels[idx] == IGNORE_TOKEN_ID
|
||||
), f"Expected EOS token at index {idx} to not be labeled"
|
||||
assert (
|
||||
labels[last_eos_idx] != IGNORE_TOKEN_ID
|
||||
), f"Expected last EOS token at index {last_eos_idx} to be labeled"
|
||||
assert labels[idx] == IGNORE_TOKEN_ID, (
|
||||
f"Expected EOS token at index {idx} to not be labeled"
|
||||
)
|
||||
assert labels[last_eos_idx] != IGNORE_TOKEN_ID, (
|
||||
f"Expected last EOS token at index {last_eos_idx} to be labeled"
|
||||
)
|
||||
|
||||
def test_train_on_eos_none(
|
||||
self,
|
||||
@@ -594,9 +604,9 @@ class TestChatTemplateConfigurations:
|
||||
|
||||
assert len(eos_indices) > 0, "Expected at least one EOS token in the input"
|
||||
for eos_idx in eos_indices:
|
||||
assert (
|
||||
labels[eos_idx] == IGNORE_TOKEN_ID
|
||||
), f"Expected EOS token at index {eos_idx} to not be labeled"
|
||||
assert labels[eos_idx] == IGNORE_TOKEN_ID, (
|
||||
f"Expected EOS token at index {eos_idx} to not be labeled"
|
||||
)
|
||||
|
||||
def test_drop_system_message(
|
||||
self,
|
||||
@@ -634,9 +644,9 @@ class TestChatTemplateConfigurations:
|
||||
# Check if system message is not present in input_ids
|
||||
system_message = "You are an AI assistant."
|
||||
decoded_message = tokenizer.decode(input_ids)
|
||||
assert (
|
||||
system_message not in decoded_message
|
||||
), "Expected system message to be dropped"
|
||||
assert system_message not in decoded_message, (
|
||||
"Expected system message to be dropped"
|
||||
)
|
||||
|
||||
def test_custom_roles(
|
||||
self,
|
||||
@@ -711,7 +721,9 @@ class TestChatTemplateConfigurations:
|
||||
else:
|
||||
assert all(
|
||||
label == IGNORE_TOKEN_ID for label in labels[start_idx:end_idx]
|
||||
), f"Expected labels for non-AI message '{response}' to be IGNORE_TOKEN_ID"
|
||||
), (
|
||||
f"Expected labels for non-AI message '{response}' to be IGNORE_TOKEN_ID"
|
||||
)
|
||||
|
||||
def test_message_field_training(
|
||||
self,
|
||||
@@ -776,13 +788,13 @@ class TestChatTemplateConfigurations:
|
||||
def verify_labels(labels_span, should_train, context_message):
|
||||
"""Helper to verify if a span of labels matches expected training state"""
|
||||
if should_train:
|
||||
assert all(
|
||||
label != IGNORE_TOKEN_ID for label in labels_span
|
||||
), f"Expected all labels for {context_message} to be set, but got {labels_span}"
|
||||
assert all(label != IGNORE_TOKEN_ID for label in labels_span), (
|
||||
f"Expected all labels for {context_message} to be set, but got {labels_span}"
|
||||
)
|
||||
else:
|
||||
assert all(
|
||||
label == IGNORE_TOKEN_ID for label in labels_span
|
||||
), f"Expected all labels for {context_message} to be {IGNORE_TOKEN_ID}, but got {labels_span}"
|
||||
assert all(label == IGNORE_TOKEN_ID for label in labels_span), (
|
||||
f"Expected all labels for {context_message} to be {IGNORE_TOKEN_ID}, but got {labels_span}"
|
||||
)
|
||||
|
||||
# Process all turns and verify labeling
|
||||
for i, turn in enumerate(modified_dataset[0]["messages"]):
|
||||
@@ -861,9 +873,9 @@ class TestChatTemplateConfigurations:
|
||||
actual_labels = labels[
|
||||
start_idx : start_idx + len(token_offsets_masked)
|
||||
]
|
||||
assert (
|
||||
actual_labels == expected_labels
|
||||
), f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}"
|
||||
assert actual_labels == expected_labels, (
|
||||
f"Labels mismatch for turn: {turn['value']}\nExpected: {expected_labels}\nActual: {actual_labels}"
|
||||
)
|
||||
|
||||
# Verify each detail section
|
||||
for detail in adjusted_train_details:
|
||||
@@ -958,7 +970,7 @@ class TestChatTemplateConfigurations:
|
||||
chat_template,
|
||||
chat_template_jinja,
|
||||
eos_token,
|
||||
basic_dataset, # pylint: disable=unused-argument
|
||||
basic_dataset,
|
||||
request,
|
||||
):
|
||||
"""Test that an error is raised when eot_tokens contains eos_token and train_on_eot/train_on_eos conflict"""
|
||||
@@ -1005,7 +1017,7 @@ class TestChatTemplateConfigurations:
|
||||
chat_template,
|
||||
chat_template_jinja,
|
||||
eos_token,
|
||||
basic_dataset, # pylint: disable=unused-argument
|
||||
basic_dataset,
|
||||
request,
|
||||
):
|
||||
"""Test that eot_tokens inherits from eos_token when not specified"""
|
||||
@@ -1032,12 +1044,12 @@ class TestChatTemplateConfigurations:
|
||||
)
|
||||
|
||||
# In backward compatibility mode, eot_tokens should be derived from eos_token
|
||||
assert strategy.eot_tokens == [
|
||||
tokenizer.eos_token
|
||||
], f"Expected eot_tokens to inherit from eos_token, got {strategy.eot_tokens}"
|
||||
assert (
|
||||
strategy.train_on_eot == "turn"
|
||||
), f"Expected train_on_eot to inherit from train_on_eos, got {strategy.train_on_eot}"
|
||||
assert strategy.eot_tokens == [tokenizer.eos_token], (
|
||||
f"Expected eot_tokens to inherit from eos_token, got {strategy.eot_tokens}"
|
||||
)
|
||||
assert strategy.train_on_eot == "turn", (
|
||||
f"Expected train_on_eot to inherit from train_on_eos, got {strategy.train_on_eot}"
|
||||
)
|
||||
|
||||
def test_token_not_in_template(
|
||||
self,
|
||||
@@ -1091,7 +1103,7 @@ class TestChatTemplateConfigurations:
|
||||
tokenizer,
|
||||
chat_template,
|
||||
chat_template_jinja,
|
||||
eos_token, # pylint: disable=unused-argument
|
||||
eos_token,
|
||||
basic_dataset,
|
||||
request,
|
||||
):
|
||||
@@ -1157,13 +1169,13 @@ class TestChatTemplateConfigurations:
|
||||
)
|
||||
|
||||
if is_after_assistant:
|
||||
assert (
|
||||
labels[eot_idx] != IGNORE_TOKEN_ID
|
||||
), f"Expected EOT token after assistant turn at index {eot_idx} to be labeled"
|
||||
assert labels[eot_idx] != IGNORE_TOKEN_ID, (
|
||||
f"Expected EOT token after assistant turn at index {eot_idx} to be labeled"
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
labels[eot_idx] == IGNORE_TOKEN_ID
|
||||
), f"Expected EOT token not after assistant turn at index {eot_idx} to not be labeled"
|
||||
assert labels[eot_idx] == IGNORE_TOKEN_ID, (
|
||||
f"Expected EOT token not after assistant turn at index {eot_idx} to not be labeled"
|
||||
)
|
||||
|
||||
def test_multiple_train_on_eot_settings(
|
||||
self,
|
||||
@@ -1224,9 +1236,9 @@ class TestChatTemplateConfigurations:
|
||||
i for i, token_id in enumerate(input_ids) if token_id == eos_token_id
|
||||
]
|
||||
|
||||
assert (
|
||||
len(eos_indices) > 0
|
||||
), "Expected at least one EOS/EOT token in the input"
|
||||
assert len(eos_indices) > 0, (
|
||||
"Expected at least one EOS/EOT token in the input"
|
||||
)
|
||||
|
||||
# Check labeling for each EOS/EOT token
|
||||
for idx, eos_idx in enumerate(eos_indices):
|
||||
@@ -1252,13 +1264,13 @@ class TestChatTemplateConfigurations:
|
||||
)
|
||||
|
||||
if expected_label:
|
||||
assert (
|
||||
labels[eos_idx] == IGNORE_TOKEN_ID
|
||||
), f"Expected EOT token at index {eos_idx} to not be labeled with train_on_eot='{setting}'"
|
||||
assert labels[eos_idx] == IGNORE_TOKEN_ID, (
|
||||
f"Expected EOT token at index {eos_idx} to not be labeled with train_on_eot='{setting}'"
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
labels[eos_idx] != IGNORE_TOKEN_ID
|
||||
), f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'"
|
||||
assert labels[eos_idx] != IGNORE_TOKEN_ID, (
|
||||
f"Expected EOT token at index {eos_idx} to be labeled with train_on_eot='{setting}'"
|
||||
)
|
||||
|
||||
|
||||
class TestChatTemplateToolCalling:
|
||||
@@ -1378,29 +1390,27 @@ class TestChatTemplateToolCalling:
|
||||
decoded_conversation = tokenizer.decode(input_ids)
|
||||
|
||||
# Verify tool calling structure is present in the decoded conversation
|
||||
assert (
|
||||
'"type": "function",' in decoded_conversation
|
||||
), "Tool type function should be in conversation"
|
||||
assert (
|
||||
'"name": "multiples",' in decoded_conversation
|
||||
), "Tool function name should be in conversation"
|
||||
assert '"type": "function",' in decoded_conversation, (
|
||||
"Tool type function should be in conversation"
|
||||
)
|
||||
assert '"name": "multiples",' in decoded_conversation, (
|
||||
"Tool function name should be in conversation"
|
||||
)
|
||||
|
||||
assert (
|
||||
'<|python_start|><|python_end|>{"name": "multiples", "parameters": {"number": 5, "limit": 20}}<|eot|>'
|
||||
in decoded_conversation
|
||||
), "Assistant tool call should be in conversation"
|
||||
assert (
|
||||
"<|header_start|>ipython<|header_end|>" in decoded_conversation
|
||||
), "IPython header should be in conversation"
|
||||
assert (
|
||||
'"5,10,15"' in decoded_conversation
|
||||
), "Tool response should be in conversation"
|
||||
assert "<|header_start|>ipython<|header_end|>" in decoded_conversation, (
|
||||
"IPython header should be in conversation"
|
||||
)
|
||||
assert '"5,10,15"' in decoded_conversation, (
|
||||
"Tool response should be in conversation"
|
||||
)
|
||||
|
||||
# Get conversation turns to verify labeling
|
||||
turns = strategy.get_conversation_thread(tool_calling_dataset[0])
|
||||
tools = strategy._get_tools( # pylint: disable=protected-access
|
||||
tool_calling_dataset[0]
|
||||
)
|
||||
tools = strategy._get_tools(tool_calling_dataset[0])
|
||||
|
||||
# Check that assistant responses are properly labeled
|
||||
for i, turn in enumerate(tool_calling_dataset[0]["messages"]):
|
||||
@@ -1409,12 +1419,12 @@ class TestChatTemplateToolCalling:
|
||||
turns=turns, turn_idx=i, tools=tools
|
||||
)
|
||||
|
||||
assert (
|
||||
start_idx != -1 and end_idx != -1
|
||||
), f"Assistant turn {i} should be found"
|
||||
assert start_idx != -1 and end_idx != -1, (
|
||||
f"Assistant turn {i} should be found"
|
||||
)
|
||||
|
||||
# Verify that assistant responses have proper labels
|
||||
turn_labels = labels[start_idx:end_idx]
|
||||
assert all(
|
||||
label != IGNORE_TOKEN_ID for label in turn_labels
|
||||
), f"Assistant turn {i} should be unmasked"
|
||||
assert all(label != IGNORE_TOKEN_ID for label in turn_labels), (
|
||||
f"Assistant turn {i} should be unmasked"
|
||||
)
|
||||
|
||||
@@ -28,7 +28,7 @@ def test_mistral_chat_template(
|
||||
request: pytest.FixtureRequest,
|
||||
):
|
||||
"""Test chat template with the Magistral/Devstral tokenizer"""
|
||||
# pylint: disable=duplicate-code
|
||||
|
||||
from axolotl.prompt_strategies.chat_template import MistralPrompter, MistralStrategy
|
||||
|
||||
tokenizer: HFMistralTokenizer = request.getfixturevalue(tokenizer_str)
|
||||
|
||||
@@ -59,7 +59,7 @@ def messages_w_reasoning_fixture():
|
||||
@pytest.fixture(name="qwen3_tokenizer")
|
||||
def qwen3_tokenizer_fixture(
|
||||
download_qwen3_half_billion_model,
|
||||
): # pylint: disable=unused-argument
|
||||
):
|
||||
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-0.6B")
|
||||
|
||||
return tokenizer
|
||||
@@ -71,7 +71,6 @@ class TestSplitThinking:
|
||||
"""
|
||||
|
||||
def test_splits_think(self, messages_w_reasoning, qwen3_tokenizer):
|
||||
# pylint: disable=duplicate-code
|
||||
strategy = load(
|
||||
qwen3_tokenizer,
|
||||
DictDefault(
|
||||
@@ -130,6 +129,6 @@ class TestSplitThinking:
|
||||
198, # \n
|
||||
]
|
||||
# fmt: on
|
||||
assert (
|
||||
input_ids == expected_input_ids
|
||||
), f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
assert input_ids == expected_input_ids, (
|
||||
f"Input IDs mismatch: {input_ids} != {expected_input_ids}"
|
||||
)
|
||||
|
||||
@@ -16,7 +16,6 @@ from tests.hf_offline_utils import enable_hf_offline
|
||||
|
||||
@pytest.fixture(name="assistant_dataset")
|
||||
def fixture_assistant_dataset():
|
||||
# pylint: disable=duplicate-code
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
@@ -49,7 +48,6 @@ def fixture_assistant_dataset():
|
||||
|
||||
@pytest.fixture(name="custom_assistant_dataset")
|
||||
def fixture_custom_assistant_dataset():
|
||||
# pylint: disable=duplicate-code
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
@@ -102,7 +100,6 @@ class TestAssistantDPOChatTemplateLlama3:
|
||||
"""
|
||||
|
||||
def test_llama3_defaults(self, llama3_tokenizer, assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
transform_fn, _ = default(
|
||||
DictDefault(
|
||||
{
|
||||
@@ -127,7 +124,6 @@ class TestAssistantDPOChatTemplateLlama3:
|
||||
assert result["rejected"] == "party on<|eot_id|>"
|
||||
|
||||
def test_llama3_configured(self, llama3_tokenizer, custom_assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
transform_fn, _ = default(
|
||||
DictDefault(
|
||||
{
|
||||
@@ -168,7 +164,6 @@ class TestAssistantDPOChatTemplatePhi3:
|
||||
"""
|
||||
|
||||
def test_phi3_defaults(self, phi3_tokenizer, assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
transform_fn, _ = default(
|
||||
DictDefault(
|
||||
{
|
||||
@@ -198,7 +193,6 @@ class TestAssistantDPOChatTemplateGemma:
|
||||
"""
|
||||
|
||||
def test_gemma_defaults(self, gemma_tokenizer, assistant_dataset):
|
||||
# pylint: disable=duplicate-code
|
||||
transform_fn, _ = default(
|
||||
DictDefault(
|
||||
{
|
||||
|
||||
@@ -20,7 +20,6 @@ class TestStepWiseSupervisedPromptTokenizingStrategy:
|
||||
|
||||
@pytest.fixture()
|
||||
def stepwise_supervised_dataset(self):
|
||||
# pylint: disable=duplicate-code
|
||||
return Dataset.from_list(
|
||||
[
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user