pytest fixes

This commit is contained in:
Dan Saunders
2025-01-08 20:55:10 +00:00
parent 6e72baf287
commit 2b7b37413d
11 changed files with 37 additions and 53 deletions

View File

@@ -12,6 +12,7 @@ from axolotl.cli.utils import (
add_options_from_dataclass, add_options_from_dataclass,
build_command, build_command,
fetch_from_github, fetch_from_github,
filter_none_kwargs,
) )
from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
from axolotl.utils import set_pytorch_cuda_alloc_conf from axolotl.utils import set_pytorch_cuda_alloc_conf
@@ -28,6 +29,7 @@ def cli():
@click.argument("config", type=click.Path(exists=True, path_type=str)) @click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(PreprocessCliArgs) @add_options_from_dataclass(PreprocessCliArgs)
@add_options_from_config(AxolotlInputConfig) @add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def preprocess(config: str, **kwargs) -> None: def preprocess(config: str, **kwargs) -> None:
""" """
Preprocess datasets before training. Preprocess datasets before training.
@@ -37,8 +39,6 @@ def preprocess(config: str, **kwargs) -> None:
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options. config options.
""" """
kwargs = {k: v for k, v in kwargs.items() if v is not None}
from axolotl.cli.preprocess import do_cli from axolotl.cli.preprocess import do_cli
do_cli(config=config, **kwargs) do_cli(config=config, **kwargs)
@@ -53,6 +53,7 @@ def preprocess(config: str, **kwargs) -> None:
) )
@add_options_from_dataclass(TrainerCliArgs) @add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig) @add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def train(config: str, accelerate: bool, **kwargs) -> None: def train(config: str, accelerate: bool, **kwargs) -> None:
""" """
Train or fine-tune a model. Train or fine-tune a model.
@@ -63,8 +64,6 @@ def train(config: str, accelerate: bool, **kwargs) -> None:
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options. config options.
""" """
kwargs = {k: v for k, v in kwargs.items() if v is not None}
# Enable expandable segments for cuda allocation to improve VRAM usage # Enable expandable segments for cuda allocation to improve VRAM usage
set_pytorch_cuda_alloc_conf() set_pytorch_cuda_alloc_conf()
@@ -89,6 +88,7 @@ def train(config: str, accelerate: bool, **kwargs) -> None:
) )
@add_options_from_dataclass(EvaluateCliArgs) @add_options_from_dataclass(EvaluateCliArgs)
@add_options_from_config(AxolotlInputConfig) @add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def evaluate(config: str, accelerate: bool, **kwargs) -> None: def evaluate(config: str, accelerate: bool, **kwargs) -> None:
""" """
Evaluate a model. Evaluate a model.
@@ -99,8 +99,6 @@ def evaluate(config: str, accelerate: bool, **kwargs) -> None:
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options. config options.
""" """
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if accelerate: if accelerate:
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"] base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
if config: if config:
@@ -123,6 +121,7 @@ def evaluate(config: str, accelerate: bool, **kwargs) -> None:
@click.option("--gradio", is_flag=True, help="Launch Gradio interface") @click.option("--gradio", is_flag=True, help="Launch Gradio interface")
@add_options_from_dataclass(TrainerCliArgs) @add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig) @add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None: def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
""" """
Run inference with a trained model. Run inference with a trained model.
@@ -157,6 +156,7 @@ def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
) )
@add_options_from_dataclass(TrainerCliArgs) @add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig) @add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None: def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
""" """
Merge sharded FSDP model weights. Merge sharded FSDP model weights.
@@ -167,8 +167,6 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options. config options.
""" """
kwargs = {k: v for k, v in kwargs.items() if v is not None}
if accelerate: if accelerate:
base_cmd = [ base_cmd = [
"accelerate", "accelerate",
@@ -190,6 +188,7 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
@click.argument("config", type=click.Path(exists=True, path_type=str)) @click.argument("config", type=click.Path(exists=True, path_type=str))
@add_options_from_dataclass(TrainerCliArgs) @add_options_from_dataclass(TrainerCliArgs)
@add_options_from_config(AxolotlInputConfig) @add_options_from_config(AxolotlInputConfig)
@filter_none_kwargs
def merge_lora(config: str, **kwargs) -> None: def merge_lora(config: str, **kwargs) -> None:
""" """
Merge trained LoRA adapters into a base model. Merge trained LoRA adapters into a base model.
@@ -200,8 +199,6 @@ def merge_lora(config: str, **kwargs) -> None:
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl` kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
config options. config options.
""" """
kwargs = {k: v for k, v in kwargs.items() if v is not None}
from axolotl.cli.merge_lora import do_cli from axolotl.cli.merge_lora import do_cli
do_cli(config=config, **kwargs) do_cli(config=config, **kwargs)

View File

@@ -5,6 +5,7 @@ import dataclasses
import hashlib import hashlib
import json import json
import logging import logging
from functools import wraps
from pathlib import Path from pathlib import Path
from types import NoneType from types import NoneType
from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_args, get_origin from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_args, get_origin
@@ -16,6 +17,26 @@ from pydantic import BaseModel
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
def filter_none_kwargs(func):
"""
Wraps function to remove `None`-valued `kwargs`.
Args:
func: Function to wrap.
Returns:
Wrapped function.
"""
@wraps(func)
def wrapper(*args, **kwargs):
"""Filters out `None`-valued `kwargs`."""
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
return func(*args, **filtered_kwargs)
return wrapper
def add_options_from_dataclass(config_class: Type[Any]): def add_options_from_dataclass(config_class: Type[Any]):
""" """
Create Click options from the fields of a dataclass. Create Click options from the fields of a dataclass.

View File

@@ -1,4 +1,5 @@
"""Shared pytest fixtures for cli module.""" """Shared pytest fixtures for cli module."""
import pytest import pytest
from click.testing import CliRunner from click.testing import CliRunner

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI fetch command.""" """pytest tests for axolotl CLI fetch command."""
from unittest.mock import patch from unittest.mock import patch
from axolotl.cli.main import fetch from axolotl.cli.main import fetch

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI inference command.""" """pytest tests for axolotl CLI inference command."""
from unittest.mock import patch from unittest.mock import patch
from axolotl.cli.main import cli from axolotl.cli.main import cli

View File

@@ -1,4 +1,5 @@
"""General pytest tests for axolotl.cli.main interface.""" """General pytest tests for axolotl.cli.main interface."""
from axolotl.cli.main import build_command, cli from axolotl.cli.main import build_command, cli

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI merge_lora command.""" """pytest tests for axolotl CLI merge_lora command."""
from unittest.mock import patch from unittest.mock import patch
from axolotl.cli.main import cli from axolotl.cli.main import cli

View File

@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command.""" """pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
from unittest.mock import patch from unittest.mock import patch
from axolotl.cli.main import cli from axolotl.cli.main import cli
@@ -15,46 +16,3 @@ def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path):
assert mock.called assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path) assert mock.call_args.kwargs["config"] == str(config_path)
assert result.exit_code == 0 assert result.exit_code == 0
def test_merge_sharded_fsdp_weights_with_model_dir(cli_runner, config_path, tmp_path):
"""Test merge_sharded_fsdp_weights command with model_dir option"""
model_dir = tmp_path / "model"
model_dir.mkdir()
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--no-accelerate",
"--model-dir",
str(model_dir),
],
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["model_dir"] == str(model_dir)
assert result.exit_code == 0
def test_merge_sharded_fsdp_weights_with_save_path(cli_runner, config_path):
"""Test merge_sharded_fsdp_weights command with save_path option"""
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
result = cli_runner.invoke(
cli,
[
"merge-sharded-fsdp-weights",
str(config_path),
"--no-accelerate",
"--save-path",
"/path/to/save",
],
)
assert mock.called
assert mock.call_args.kwargs["config"] == str(config_path)
assert mock.call_args.kwargs["save_path"] == "/path/to/save"
assert result.exit_code == 0

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI preprocess command.""" """pytest tests for axolotl CLI preprocess command."""
import shutil import shutil
from pathlib import Path from pathlib import Path
from unittest.mock import patch from unittest.mock import patch

View File

@@ -1,4 +1,5 @@
"""pytest tests for axolotl CLI --version""" """pytest tests for axolotl CLI --version"""
from axolotl.cli.main import cli from axolotl.cli.main import cli

View File

@@ -1,5 +1,6 @@
"""pytest tests for axolotl CLI utils.""" """pytest tests for axolotl CLI utils."""
# pylint: disable=redefined-outer-name # pylint: disable=redefined-outer-name
import json import json
from unittest.mock import Mock, patch from unittest.mock import Mock, patch