From 2b7b37413d833df77b2b1448e0f9ce325876e424 Mon Sep 17 00:00:00 2001 From: Dan Saunders Date: Wed, 8 Jan 2025 20:55:10 +0000 Subject: [PATCH] pytest fixes --- src/axolotl/cli/main.py | 17 +++---- src/axolotl/cli/utils.py | 21 +++++++++ tests/cli/conftest.py | 1 + tests/cli/test_cli_fetch.py | 1 + tests/cli/test_cli_inference.py | 1 + tests/cli/test_cli_interface.py | 1 + tests/cli/test_cli_merge_lora.py | 1 + .../test_cli_merge_sharded_fsdp_weights.py | 44 +------------------ tests/cli/test_cli_preprocess.py | 1 + tests/cli/test_cli_version.py | 1 + tests/cli/test_utils.py | 1 + 11 files changed, 37 insertions(+), 53 deletions(-) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index f82d623ce..ac55501a4 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -12,6 +12,7 @@ from axolotl.cli.utils import ( add_options_from_dataclass, build_command, fetch_from_github, + filter_none_kwargs, ) from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs from axolotl.utils import set_pytorch_cuda_alloc_conf @@ -28,6 +29,7 @@ def cli(): @click.argument("config", type=click.Path(exists=True, path_type=str)) @add_options_from_dataclass(PreprocessCliArgs) @add_options_from_config(AxolotlInputConfig) +@filter_none_kwargs def preprocess(config: str, **kwargs) -> None: """ Preprocess datasets before training. @@ -37,8 +39,6 @@ def preprocess(config: str, **kwargs) -> None: kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` config options. """ - kwargs = {k: v for k, v in kwargs.items() if v is not None} - from axolotl.cli.preprocess import do_cli do_cli(config=config, **kwargs) @@ -53,6 +53,7 @@ def preprocess(config: str, **kwargs) -> None: ) @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) +@filter_none_kwargs def train(config: str, accelerate: bool, **kwargs) -> None: """ Train or fine-tune a model. @@ -63,8 +64,6 @@ def train(config: str, accelerate: bool, **kwargs) -> None: kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` config options. """ - kwargs = {k: v for k, v in kwargs.items() if v is not None} - # Enable expandable segments for cuda allocation to improve VRAM usage set_pytorch_cuda_alloc_conf() @@ -89,6 +88,7 @@ def train(config: str, accelerate: bool, **kwargs) -> None: ) @add_options_from_dataclass(EvaluateCliArgs) @add_options_from_config(AxolotlInputConfig) +@filter_none_kwargs def evaluate(config: str, accelerate: bool, **kwargs) -> None: """ Evaluate a model. @@ -99,8 +99,6 @@ def evaluate(config: str, accelerate: bool, **kwargs) -> None: kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` config options. """ - kwargs = {k: v for k, v in kwargs.items() if v is not None} - if accelerate: base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"] if config: @@ -123,6 +121,7 @@ def evaluate(config: str, accelerate: bool, **kwargs) -> None: @click.option("--gradio", is_flag=True, help="Launch Gradio interface") @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) +@filter_none_kwargs def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None: """ Run inference with a trained model. @@ -157,6 +156,7 @@ def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None: ) @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) +@filter_none_kwargs def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None: """ Merge sharded FSDP model weights. @@ -167,8 +167,6 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None: kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` config options. """ - kwargs = {k: v for k, v in kwargs.items() if v is not None} - if accelerate: base_cmd = [ "accelerate", @@ -190,6 +188,7 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None: @click.argument("config", type=click.Path(exists=True, path_type=str)) @add_options_from_dataclass(TrainerCliArgs) @add_options_from_config(AxolotlInputConfig) +@filter_none_kwargs def merge_lora(config: str, **kwargs) -> None: """ Merge trained LoRA adapters into a base model. @@ -200,8 +199,6 @@ def merge_lora(config: str, **kwargs) -> None: kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` config options. """ - kwargs = {k: v for k, v in kwargs.items() if v is not None} - from axolotl.cli.merge_lora import do_cli do_cli(config=config, **kwargs) diff --git a/src/axolotl/cli/utils.py b/src/axolotl/cli/utils.py index ecd943f24..f04304759 100644 --- a/src/axolotl/cli/utils.py +++ b/src/axolotl/cli/utils.py @@ -5,6 +5,7 @@ import dataclasses import hashlib import json import logging +from functools import wraps from pathlib import Path from types import NoneType from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_args, get_origin @@ -16,6 +17,26 @@ from pydantic import BaseModel LOG = logging.getLogger(__name__) +def filter_none_kwargs(func): + """ + Wraps function to remove `None`-valued `kwargs`. + + Args: + func: Function to wrap. + + Returns: + Wrapped function. + """ + + @wraps(func) + def wrapper(*args, **kwargs): + """Filters out `None`-valued `kwargs`.""" + filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None} + return func(*args, **filtered_kwargs) + + return wrapper + + def add_options_from_dataclass(config_class: Type[Any]): """ Create Click options from the fields of a dataclass. diff --git a/tests/cli/conftest.py b/tests/cli/conftest.py index 78b090e19..d360e29d6 100644 --- a/tests/cli/conftest.py +++ b/tests/cli/conftest.py @@ -1,4 +1,5 @@ """Shared pytest fixtures for cli module.""" + import pytest from click.testing import CliRunner diff --git a/tests/cli/test_cli_fetch.py b/tests/cli/test_cli_fetch.py index 0df87b029..f06f06717 100644 --- a/tests/cli/test_cli_fetch.py +++ b/tests/cli/test_cli_fetch.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI fetch command.""" + from unittest.mock import patch from axolotl.cli.main import fetch diff --git a/tests/cli/test_cli_inference.py b/tests/cli/test_cli_inference.py index 7cb163d25..b8effa3d2 100644 --- a/tests/cli/test_cli_inference.py +++ b/tests/cli/test_cli_inference.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI inference command.""" + from unittest.mock import patch from axolotl.cli.main import cli diff --git a/tests/cli/test_cli_interface.py b/tests/cli/test_cli_interface.py index ed8335b76..8b5fec17f 100644 --- a/tests/cli/test_cli_interface.py +++ b/tests/cli/test_cli_interface.py @@ -1,4 +1,5 @@ """General pytest tests for axolotl.cli.main interface.""" + from axolotl.cli.main import build_command, cli diff --git a/tests/cli/test_cli_merge_lora.py b/tests/cli/test_cli_merge_lora.py index 165a64e98..aac016760 100644 --- a/tests/cli/test_cli_merge_lora.py +++ b/tests/cli/test_cli_merge_lora.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI merge_lora command.""" + from unittest.mock import patch from axolotl.cli.main import cli diff --git a/tests/cli/test_cli_merge_sharded_fsdp_weights.py b/tests/cli/test_cli_merge_sharded_fsdp_weights.py index cff0f3b77..18589a80d 100644 --- a/tests/cli/test_cli_merge_sharded_fsdp_weights.py +++ b/tests/cli/test_cli_merge_sharded_fsdp_weights.py @@ -1,5 +1,6 @@ """pytest tests for axolotl CLI merge_sharded_fsdp_weights command.""" # pylint: disable=duplicate-code + from unittest.mock import patch from axolotl.cli.main import cli @@ -15,46 +16,3 @@ def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path): assert mock.called assert mock.call_args.kwargs["config"] == str(config_path) assert result.exit_code == 0 - - -def test_merge_sharded_fsdp_weights_with_model_dir(cli_runner, config_path, tmp_path): - """Test merge_sharded_fsdp_weights command with model_dir option""" - model_dir = tmp_path / "model" - model_dir.mkdir() - - with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock: - result = cli_runner.invoke( - cli, - [ - "merge-sharded-fsdp-weights", - str(config_path), - "--no-accelerate", - "--model-dir", - str(model_dir), - ], - ) - - assert mock.called - assert mock.call_args.kwargs["config"] == str(config_path) - assert mock.call_args.kwargs["model_dir"] == str(model_dir) - assert result.exit_code == 0 - - -def test_merge_sharded_fsdp_weights_with_save_path(cli_runner, config_path): - """Test merge_sharded_fsdp_weights command with save_path option""" - with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock: - result = cli_runner.invoke( - cli, - [ - "merge-sharded-fsdp-weights", - str(config_path), - "--no-accelerate", - "--save-path", - "/path/to/save", - ], - ) - - assert mock.called - assert mock.call_args.kwargs["config"] == str(config_path) - assert mock.call_args.kwargs["save_path"] == "/path/to/save" - assert result.exit_code == 0 diff --git a/tests/cli/test_cli_preprocess.py b/tests/cli/test_cli_preprocess.py index 4719461aa..e2dd3a6c3 100644 --- a/tests/cli/test_cli_preprocess.py +++ b/tests/cli/test_cli_preprocess.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI preprocess command.""" + import shutil from pathlib import Path from unittest.mock import patch diff --git a/tests/cli/test_cli_version.py b/tests/cli/test_cli_version.py index 819780e94..533dd5c0e 100644 --- a/tests/cli/test_cli_version.py +++ b/tests/cli/test_cli_version.py @@ -1,4 +1,5 @@ """pytest tests for axolotl CLI --version""" + from axolotl.cli.main import cli diff --git a/tests/cli/test_utils.py b/tests/cli/test_utils.py index b88e4ac72..ecb0025e4 100644 --- a/tests/cli/test_utils.py +++ b/tests/cli/test_utils.py @@ -1,5 +1,6 @@ """pytest tests for axolotl CLI utils.""" # pylint: disable=redefined-outer-name + import json from unittest.mock import Mock, patch