pytest fixes
This commit is contained in:
@@ -12,6 +12,7 @@ from axolotl.cli.utils import (
|
|||||||
add_options_from_dataclass,
|
add_options_from_dataclass,
|
||||||
build_command,
|
build_command,
|
||||||
fetch_from_github,
|
fetch_from_github,
|
||||||
|
filter_none_kwargs,
|
||||||
)
|
)
|
||||||
from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
from axolotl.common.cli import EvaluateCliArgs, PreprocessCliArgs, TrainerCliArgs
|
||||||
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
from axolotl.utils import set_pytorch_cuda_alloc_conf
|
||||||
@@ -28,6 +29,7 @@ def cli():
|
|||||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||||
@add_options_from_dataclass(PreprocessCliArgs)
|
@add_options_from_dataclass(PreprocessCliArgs)
|
||||||
@add_options_from_config(AxolotlInputConfig)
|
@add_options_from_config(AxolotlInputConfig)
|
||||||
|
@filter_none_kwargs
|
||||||
def preprocess(config: str, **kwargs) -> None:
|
def preprocess(config: str, **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
Preprocess datasets before training.
|
Preprocess datasets before training.
|
||||||
@@ -37,8 +39,6 @@ def preprocess(config: str, **kwargs) -> None:
|
|||||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||||
config options.
|
config options.
|
||||||
"""
|
"""
|
||||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
|
||||||
|
|
||||||
from axolotl.cli.preprocess import do_cli
|
from axolotl.cli.preprocess import do_cli
|
||||||
|
|
||||||
do_cli(config=config, **kwargs)
|
do_cli(config=config, **kwargs)
|
||||||
@@ -53,6 +53,7 @@ def preprocess(config: str, **kwargs) -> None:
|
|||||||
)
|
)
|
||||||
@add_options_from_dataclass(TrainerCliArgs)
|
@add_options_from_dataclass(TrainerCliArgs)
|
||||||
@add_options_from_config(AxolotlInputConfig)
|
@add_options_from_config(AxolotlInputConfig)
|
||||||
|
@filter_none_kwargs
|
||||||
def train(config: str, accelerate: bool, **kwargs) -> None:
|
def train(config: str, accelerate: bool, **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
Train or fine-tune a model.
|
Train or fine-tune a model.
|
||||||
@@ -63,8 +64,6 @@ def train(config: str, accelerate: bool, **kwargs) -> None:
|
|||||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||||
config options.
|
config options.
|
||||||
"""
|
"""
|
||||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
|
||||||
|
|
||||||
# Enable expandable segments for cuda allocation to improve VRAM usage
|
# Enable expandable segments for cuda allocation to improve VRAM usage
|
||||||
set_pytorch_cuda_alloc_conf()
|
set_pytorch_cuda_alloc_conf()
|
||||||
|
|
||||||
@@ -89,6 +88,7 @@ def train(config: str, accelerate: bool, **kwargs) -> None:
|
|||||||
)
|
)
|
||||||
@add_options_from_dataclass(EvaluateCliArgs)
|
@add_options_from_dataclass(EvaluateCliArgs)
|
||||||
@add_options_from_config(AxolotlInputConfig)
|
@add_options_from_config(AxolotlInputConfig)
|
||||||
|
@filter_none_kwargs
|
||||||
def evaluate(config: str, accelerate: bool, **kwargs) -> None:
|
def evaluate(config: str, accelerate: bool, **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
Evaluate a model.
|
Evaluate a model.
|
||||||
@@ -99,8 +99,6 @@ def evaluate(config: str, accelerate: bool, **kwargs) -> None:
|
|||||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||||
config options.
|
config options.
|
||||||
"""
|
"""
|
||||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
|
||||||
|
|
||||||
if accelerate:
|
if accelerate:
|
||||||
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
|
base_cmd = ["accelerate", "launch", "-m", "axolotl.cli.evaluate"]
|
||||||
if config:
|
if config:
|
||||||
@@ -123,6 +121,7 @@ def evaluate(config: str, accelerate: bool, **kwargs) -> None:
|
|||||||
@click.option("--gradio", is_flag=True, help="Launch Gradio interface")
|
@click.option("--gradio", is_flag=True, help="Launch Gradio interface")
|
||||||
@add_options_from_dataclass(TrainerCliArgs)
|
@add_options_from_dataclass(TrainerCliArgs)
|
||||||
@add_options_from_config(AxolotlInputConfig)
|
@add_options_from_config(AxolotlInputConfig)
|
||||||
|
@filter_none_kwargs
|
||||||
def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
|
def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
Run inference with a trained model.
|
Run inference with a trained model.
|
||||||
@@ -157,6 +156,7 @@ def inference(config: str, accelerate: bool, gradio: bool, **kwargs) -> None:
|
|||||||
)
|
)
|
||||||
@add_options_from_dataclass(TrainerCliArgs)
|
@add_options_from_dataclass(TrainerCliArgs)
|
||||||
@add_options_from_config(AxolotlInputConfig)
|
@add_options_from_config(AxolotlInputConfig)
|
||||||
|
@filter_none_kwargs
|
||||||
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
|
def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
Merge sharded FSDP model weights.
|
Merge sharded FSDP model weights.
|
||||||
@@ -167,8 +167,6 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
|
|||||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||||
config options.
|
config options.
|
||||||
"""
|
"""
|
||||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
|
||||||
|
|
||||||
if accelerate:
|
if accelerate:
|
||||||
base_cmd = [
|
base_cmd = [
|
||||||
"accelerate",
|
"accelerate",
|
||||||
@@ -190,6 +188,7 @@ def merge_sharded_fsdp_weights(config: str, accelerate: bool, **kwargs) -> None:
|
|||||||
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
@click.argument("config", type=click.Path(exists=True, path_type=str))
|
||||||
@add_options_from_dataclass(TrainerCliArgs)
|
@add_options_from_dataclass(TrainerCliArgs)
|
||||||
@add_options_from_config(AxolotlInputConfig)
|
@add_options_from_config(AxolotlInputConfig)
|
||||||
|
@filter_none_kwargs
|
||||||
def merge_lora(config: str, **kwargs) -> None:
|
def merge_lora(config: str, **kwargs) -> None:
|
||||||
"""
|
"""
|
||||||
Merge trained LoRA adapters into a base model.
|
Merge trained LoRA adapters into a base model.
|
||||||
@@ -200,8 +199,6 @@ def merge_lora(config: str, **kwargs) -> None:
|
|||||||
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
kwargs: Additional keyword arguments which correspond to CLI args or `axolotl`
|
||||||
config options.
|
config options.
|
||||||
"""
|
"""
|
||||||
kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
|
||||||
|
|
||||||
from axolotl.cli.merge_lora import do_cli
|
from axolotl.cli.merge_lora import do_cli
|
||||||
|
|
||||||
do_cli(config=config, **kwargs)
|
do_cli(config=config, **kwargs)
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import dataclasses
|
|||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from functools import wraps
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from types import NoneType
|
from types import NoneType
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_args, get_origin
|
from typing import Any, Dict, List, Optional, Tuple, Type, Union, get_args, get_origin
|
||||||
@@ -16,6 +17,26 @@ from pydantic import BaseModel
|
|||||||
LOG = logging.getLogger(__name__)
|
LOG = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def filter_none_kwargs(func):
|
||||||
|
"""
|
||||||
|
Wraps function to remove `None`-valued `kwargs`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
func: Function to wrap.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Wrapped function.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@wraps(func)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
"""Filters out `None`-valued `kwargs`."""
|
||||||
|
filtered_kwargs = {k: v for k, v in kwargs.items() if v is not None}
|
||||||
|
return func(*args, **filtered_kwargs)
|
||||||
|
|
||||||
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
def add_options_from_dataclass(config_class: Type[Any]):
|
def add_options_from_dataclass(config_class: Type[Any]):
|
||||||
"""
|
"""
|
||||||
Create Click options from the fields of a dataclass.
|
Create Click options from the fields of a dataclass.
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Shared pytest fixtures for cli module."""
|
"""Shared pytest fixtures for cli module."""
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from click.testing import CliRunner
|
from click.testing import CliRunner
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""pytest tests for axolotl CLI fetch command."""
|
"""pytest tests for axolotl CLI fetch command."""
|
||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from axolotl.cli.main import fetch
|
from axolotl.cli.main import fetch
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""pytest tests for axolotl CLI inference command."""
|
"""pytest tests for axolotl CLI inference command."""
|
||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from axolotl.cli.main import cli
|
from axolotl.cli.main import cli
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""General pytest tests for axolotl.cli.main interface."""
|
"""General pytest tests for axolotl.cli.main interface."""
|
||||||
|
|
||||||
from axolotl.cli.main import build_command, cli
|
from axolotl.cli.main import build_command, cli
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""pytest tests for axolotl CLI merge_lora command."""
|
"""pytest tests for axolotl CLI merge_lora command."""
|
||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from axolotl.cli.main import cli
|
from axolotl.cli.main import cli
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
|
"""pytest tests for axolotl CLI merge_sharded_fsdp_weights command."""
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
|
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
from axolotl.cli.main import cli
|
from axolotl.cli.main import cli
|
||||||
@@ -15,46 +16,3 @@ def test_merge_sharded_fsdp_weights_no_accelerate(cli_runner, config_path):
|
|||||||
assert mock.called
|
assert mock.called
|
||||||
assert mock.call_args.kwargs["config"] == str(config_path)
|
assert mock.call_args.kwargs["config"] == str(config_path)
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
|
|
||||||
|
|
||||||
def test_merge_sharded_fsdp_weights_with_model_dir(cli_runner, config_path, tmp_path):
|
|
||||||
"""Test merge_sharded_fsdp_weights command with model_dir option"""
|
|
||||||
model_dir = tmp_path / "model"
|
|
||||||
model_dir.mkdir()
|
|
||||||
|
|
||||||
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
|
|
||||||
result = cli_runner.invoke(
|
|
||||||
cli,
|
|
||||||
[
|
|
||||||
"merge-sharded-fsdp-weights",
|
|
||||||
str(config_path),
|
|
||||||
"--no-accelerate",
|
|
||||||
"--model-dir",
|
|
||||||
str(model_dir),
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert mock.called
|
|
||||||
assert mock.call_args.kwargs["config"] == str(config_path)
|
|
||||||
assert mock.call_args.kwargs["model_dir"] == str(model_dir)
|
|
||||||
assert result.exit_code == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_merge_sharded_fsdp_weights_with_save_path(cli_runner, config_path):
|
|
||||||
"""Test merge_sharded_fsdp_weights command with save_path option"""
|
|
||||||
with patch("axolotl.cli.merge_sharded_fsdp_weights.do_cli") as mock:
|
|
||||||
result = cli_runner.invoke(
|
|
||||||
cli,
|
|
||||||
[
|
|
||||||
"merge-sharded-fsdp-weights",
|
|
||||||
str(config_path),
|
|
||||||
"--no-accelerate",
|
|
||||||
"--save-path",
|
|
||||||
"/path/to/save",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert mock.called
|
|
||||||
assert mock.call_args.kwargs["config"] == str(config_path)
|
|
||||||
assert mock.call_args.kwargs["save_path"] == "/path/to/save"
|
|
||||||
assert result.exit_code == 0
|
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""pytest tests for axolotl CLI preprocess command."""
|
"""pytest tests for axolotl CLI preprocess command."""
|
||||||
|
|
||||||
import shutil
|
import shutil
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""pytest tests for axolotl CLI --version"""
|
"""pytest tests for axolotl CLI --version"""
|
||||||
|
|
||||||
from axolotl.cli.main import cli
|
from axolotl.cli.main import cli
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
"""pytest tests for axolotl CLI utils."""
|
"""pytest tests for axolotl CLI utils."""
|
||||||
# pylint: disable=redefined-outer-name
|
# pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user