From f6623c34cc1171bed125fc70a1422f8c77c48ed9 Mon Sep 17 00:00:00 2001 From: mhenrhcsen Date: Mon, 12 May 2025 22:53:30 +0200 Subject: [PATCH] Linting fix --- tests/test_trainer_utils.py | 58 +++++++++++++++++-------------------- 1 file changed, 27 insertions(+), 31 deletions(-) diff --git a/tests/test_trainer_utils.py b/tests/test_trainer_utils.py index f912304b9..f03f62f3e 100644 --- a/tests/test_trainer_utils.py +++ b/tests/test_trainer_utils.py @@ -1,9 +1,6 @@ import unittest from functools import partial -import pytest - -# Assuming the function is in axolotl.utils.trainer from axolotl.utils.trainer import truncate_or_drop_long_seq @@ -58,7 +55,7 @@ class TestTruncateOrDropLongSeq(unittest.TestCase): # Let's refine this test - the function *itself* returns the sample if too short when truncating. sample_short = {"input_ids": [1, 2], "labels": [1, 2]} result_short = handler(sample_short) - self.assertEqual(result_short["input_ids"], [1, 2]) # Unchanged + self.assertEqual(result_short["input_ids"], [1, 2]) # Unchanged # Too long original_long = list(range(self.sequence_len + 5)) @@ -69,18 +66,19 @@ class TestTruncateOrDropLongSeq(unittest.TestCase): self.assertEqual(len(result_long["labels"]), self.sequence_len) self.assertEqual(result_long["labels"], list(range(self.sequence_len))) - # Just right - sample_ok = {"input_ids": list(range(self.min_sequence_len)), "labels": list(range(self.min_sequence_len))} + sample_ok = { + "input_ids": list(range(self.min_sequence_len)), + "labels": list(range(self.min_sequence_len)), + } result_ok = handler(sample_ok) self.assertEqual(len(result_ok["input_ids"]), self.min_sequence_len) - self.assertEqual(result_ok, sample_ok) # Should be unchanged + self.assertEqual(result_ok, sample_ok) # Should be unchanged # Empty sample_empty = {"input_ids": [], "labels": []} result_empty = handler(sample_empty) - self.assertEqual(result_empty, sample_empty) # Unchanged - + self.assertEqual(result_empty, sample_empty) # Unchanged def test_drop_mode_batched(self): """Test drop mode with batched examples.""" @@ -94,15 +92,14 @@ class TestTruncateOrDropLongSeq(unittest.TestCase): "input_ids": [ [1, 2], # Too short list(range(self.sequence_len + 1)), # Too long - list(range(self.sequence_len)), # OK (len = 10) - list(range(self.min_sequence_len)), # OK (len = 3) - [], # Empty + list(range(self.sequence_len)), # OK (len = 10) + list(range(self.min_sequence_len)), # OK (len = 3) + [], # Empty ] } expected = [False, False, True, True, False] self.assertEqual(handler(sample), expected) - def test_truncate_mode_batched(self): """Test truncate mode with batched examples.""" handler = partial( @@ -113,13 +110,13 @@ class TestTruncateOrDropLongSeq(unittest.TestCase): ) sample = { "input_ids": [ - [1, 2], # Too short - list(range(self.sequence_len + 5)), # Too long - list(range(self.sequence_len)), # OK - list(range(self.min_sequence_len)), # OK - [], # Empty + [1, 2], # Too short + list(range(self.sequence_len + 5)), # Too long + list(range(self.sequence_len)), # OK + list(range(self.min_sequence_len)), # OK + [], # Empty ], - "labels": [ # Add labels to test truncation + "labels": [ # Add labels to test truncation [1, 2], list(range(self.sequence_len + 5)), list(range(self.sequence_len)), @@ -132,24 +129,23 @@ class TestTruncateOrDropLongSeq(unittest.TestCase): # Expected results after truncation (too short and empty remain unchanged by this function) expected_input_ids = [ - [1, 2], # Unchanged (too short) - list(range(self.sequence_len)), # Truncated - list(range(self.sequence_len)), # Unchanged (OK) - list(range(self.min_sequence_len)), # Unchanged (OK) - [], # Unchanged (Empty) + [1, 2], # Unchanged (too short) + list(range(self.sequence_len)), # Truncated + list(range(self.sequence_len)), # Unchanged (OK) + list(range(self.min_sequence_len)), # Unchanged (OK) + [], # Unchanged (Empty) ] expected_labels = [ - [1, 2], # Unchanged (too short) - list(range(self.sequence_len)), # Truncated - list(range(self.sequence_len)), # Unchanged (OK) - list(range(self.min_sequence_len)), # Unchanged (OK) - [], # Unchanged (Empty) + [1, 2], # Unchanged (too short) + list(range(self.sequence_len)), # Truncated + list(range(self.sequence_len)), # Unchanged (OK) + list(range(self.min_sequence_len)), # Unchanged (OK) + [], # Unchanged (Empty) ] - self.assertEqual(result["input_ids"], expected_input_ids) self.assertEqual(result["labels"], expected_labels) if __name__ == "__main__": - unittest.main() \ No newline at end of file + unittest.main()