diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 7d9796345..a0e4d3081 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -68,7 +68,7 @@ You can skip certain CI checks by including specific keywords in your commit mes ### Code Style -axolotl uses [{codestyle}]({URLofCodestyle}) as its code style guide. Please ensure that your code follows these guidelines. +axolotl uses [Ruff](https://docs.astral.sh/ruff/) as its code style guide. Please ensure that your code follows these guidelines. Use the pre-commit linter to ensure that your code is formatted consistently. ```bash @@ -83,6 +83,6 @@ Write clear and concise commit messages that briefly describe the changes made i - [GitHub Help](https://help.github.com/) - [GitHub Pull Request Documentation](https://docs.github.com/en/github/collaborating-with-issues-and-pull-requests) -- [{codestyle}]({URLofCodestyle}) +- [Ruff](https://docs.astral.sh/ruff/) Thank you once again for your interest in contributing to axolotl. We look forward to collaborating with you and creating an even better project together! 
diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index b6f79c74c..f81ba0b2e 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -300,7 +300,7 @@ def load_cfg( try: device_props = torch.cuda.get_device_properties("cuda") gpu_version = "sm_" + str(device_props.major) + str(device_props.minor) - except: + except (RuntimeError, AssertionError): gpu_version = None prepare_plugins(cfg) diff --git a/src/axolotl/convert.py b/src/axolotl/convert.py index 9e09b37dc..f8d8b25f4 100644 --- a/src/axolotl/convert.py +++ b/src/axolotl/convert.py @@ -67,7 +67,7 @@ class JsonToJsonlConverter: self.json_parser = json_parser self.jsonl_serializer = jsonl_serializer - def convert(self, input_file_path, output_file_path): + def convert(self, input_file_path): content = self.file_reader.read(input_file_path) data = self.json_parser.parse(content) # data = [r for r in data if r["conversations"]] # vicuna cleaned has rows with empty conversations diff --git a/src/axolotl/utils/quantization.py b/src/axolotl/utils/quantization.py index 43af858b1..3a244d6d9 100644 --- a/src/axolotl/utils/quantization.py +++ b/src/axolotl/utils/quantization.py @@ -29,7 +29,7 @@ if version.parse(torch.__version__) >= version.parse("2.8.0"): from torchao.prototype.mx_formats import NVFP4InferenceConfig quantization_config_to_str[NVFP4InferenceConfig] = "nvfp4" - except: + except (ImportError, RuntimeError): pass # int4 weight config imports will fail on machines with fbgemm-gpu installed @@ -38,7 +38,7 @@ if version.parse(torch.__version__) >= version.parse("2.8.0"): from torchao.quantization.quant_api import Int4WeightOnlyConfig quantization_config_to_str[Int4WeightOnlyConfig] = "int4" - except: + except (ImportError, RuntimeError): pass try: diff --git a/tests/test_convert.py b/tests/test_convert.py new file mode 100644 index 000000000..bfe0b603c --- /dev/null +++ b/tests/test_convert.py @@ -0,0 +1,91 @@ +"""Unit tests for src/axolotl/convert.py""" + +import json + 
+import pytest
+
+from axolotl.convert import (
+    FileReader,
+    FileWriter,
+    JsonlSerializer,
+    JsonParser,
+    JsonToJsonlConverter,
+    StdoutWriter,
+)
+
+
+class TestJsonParser:  # JsonParser.parse: valid array, valid object, invalid text
+    def test_parse_valid_json_array(self):
+        parser = JsonParser()
+        result = parser.parse('[{"key": "value"}]')
+        assert result == [{"key": "value"}]  # array text parses to a list of dicts
+
+    def test_parse_valid_json_object(self):
+        parser = JsonParser()
+        result = parser.parse('{"key": "value"}')
+        assert result == {"key": "value"}  # a top-level object parses to a dict
+
+    def test_parse_invalid_json_raises(self):
+        parser = JsonParser()
+        with pytest.raises(json.JSONDecodeError):  # the error must reach the caller
+            parser.parse("not valid json")
+
+
+class TestJsonlSerializer:  # JsonlSerializer.serialize: one, many and zero items
+    def test_serialize_single_item(self):
+        serializer = JsonlSerializer()
+        result = serializer.serialize([{"a": 1}])
+        assert result == '{"a": 1}'  # one item: exactly its JSON, no trailing newline
+
+    def test_serialize_multiple_items(self):
+        serializer = JsonlSerializer()
+        result = serializer.serialize([{"a": 1}, {"b": 2}])
+        lines = result.split("\n")
+        assert len(lines) == 2  # two items give two lines, so no trailing "\n"
+        assert json.loads(lines[0]) == {"a": 1}  # input order is preserved
+        assert json.loads(lines[1]) == {"b": 2}
+
+    def test_serialize_empty_list(self):
+        serializer = JsonlSerializer()
+        result = serializer.serialize([])
+        assert result == ""  # no items serialize to the empty string
+
+
+class TestFileReaderWriter:  # FileWriter output must be readable by FileReader
+    def test_read_write_roundtrip(self, tmp_path):
+        test_file = tmp_path / "test.txt"
+        content = '{"hello": "world"}'
+        writer = FileWriter(str(test_file))  # target path is fixed at construction
+        writer.write(content)
+
+        reader = FileReader()
+        result = reader.read(str(test_file))  # path is given per read call
+        assert result == content  # the text must come back unchanged
+
+
+class TestStdoutWriter:  # StdoutWriter.write, captured through pytest's capsys
+    def test_write_to_stdout(self, capsys):
+        writer = StdoutWriter()
+        writer.write("hello")
+        captured = capsys.readouterr()
+        assert captured.out == "hello\n"  # expects the content plus one newline
+
+
+class TestJsonToJsonlConverter:  # end-to-end: JSON array file in, JSONL file out
+    def test_convert_json_to_jsonl(self, tmp_path):
+        input_data = [{"name": "Alice"}, {"name": "Bob"}]
+        input_file = tmp_path / "input.json"
+        output_file = tmp_path / "output.jsonl"
+
+        input_file.write_text(json.dumps(input_data), encoding="utf-8")
+
+        converter = JsonToJsonlConverter(
+            FileReader(), FileWriter(str(output_file)), JsonParser(), JsonlSerializer()
+        )
+        converter.convert(str(input_file))  # only the input path is passed here
+
+        result = output_file.read_text(encoding="utf-8")  # path given to FileWriter
+        lines = result.split("\n")
+        assert len(lines) == 2  # one output line per input record
+        assert json.loads(lines[0]) == {"name": "Alice"}
+        assert json.loads(lines[1]) == {"name": "Bob"}