diff --git a/src/axolotl/cli/config.py b/src/axolotl/cli/config.py index 986167f02..b6f79c74c 100644 --- a/src/axolotl/cli/config.py +++ b/src/axolotl/cli/config.py @@ -5,7 +5,7 @@ import os import tempfile from pathlib import Path from tempfile import NamedTemporaryFile -from typing import Union +from typing import Any, Optional, Union from urllib.parse import urlparse import requests @@ -32,6 +32,63 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars LOG = get_logger(__name__) + +def _coerce_value(value: Any, existing: Optional[Any] = None) -> Any: + """Coerce a string CLI value to its most likely Python type. + + If an existing value is present in the config, its type is used to guide + casting. Otherwise, YAML-style inference is applied: booleans, ints, + floats, and None literals are recognised automatically. + + Args: + value: The raw value (typically a string from the CLI). + existing: An optional existing config value whose type guides coercion. + + Returns: + The value cast to the inferred or expected type. 
+ """ + if not isinstance(value, str): + return value + + # If the config already has a typed value, cast to match + if existing is not None: + if isinstance(existing, bool): + return value.lower() in ("true", "1", "yes") + if isinstance(existing, int): + try: + return int(value) + except (ValueError, TypeError): + return value + if isinstance(existing, float): + try: + return float(value) + except (ValueError, TypeError): + return value + # For other types (str, list, dict, etc.), return as-is + return value + + # No existing value -- use YAML-style inference + lower = value.lower() + if lower in ("true", "yes"): + return True + if lower in ("false", "no"): + return False + if lower in ("null", "none", "~"): + return None + + # Try int then float + try: + return int(value) + except ValueError: + pass + try: + return float(value) + except ValueError: + pass + + return value + + API_KEY_FIELDS = {"comet_api_key"} TELEMETRY_MANAGER = TelemetryManager.get_instance() @@ -208,13 +265,37 @@ def load_cfg( # If there are any options passed in the cli, if it is something that seems valid # from the yaml, then overwrite the value cfg_keys = cfg.keys() + + # Separate nested (dot-notation) kwargs from flat kwargs + nested_kwargs: dict[str, dict[str, Any]] = {} + flat_kwargs: dict[str, Any] = {} for key, value in kwargs.items(): + if "__" in key: + parent, child = key.split("__", 1) + nested_kwargs.setdefault(parent, {})[child] = value + else: + flat_kwargs[key] = value + + # Apply flat kwargs + for key, value in flat_kwargs.items(): # If not strict, allow writing to cfg even if it's not in the yml already if key in cfg_keys or not cfg.strict: - if isinstance(cfg[key], bool): - cfg[key] = bool(value) - else: - cfg[key] = value + cfg[key] = _coerce_value(value, cfg.get(key)) + + # Apply nested kwargs (e.g., trl__beta -> cfg.trl.beta) + for parent, children in nested_kwargs.items(): + if parent not in cfg_keys and cfg.strict: + continue + if cfg[parent] is None: + cfg[parent] = 
def _is_pydantic_model(field_type: type) -> bool:
    """Report whether ``field_type`` is a class derived from ``BaseModel``."""
    if not isinstance(field_type, type):
        return False
    try:
        return issubclass(field_type, BaseModel)
    except TypeError:
        # issubclass rejected the object even though it passed the class
        # check above; treat it as "not a model".
        return False


def _get_field_description(field) -> str | None:
    """
    Look up the human-readable description of a Pydantic field.

    The field's own ``description`` wins; otherwise a ``"description"`` entry in
    a dict-valued ``json_schema_extra`` is used.

    Args:
        field: Field object exposing ``description`` and ``json_schema_extra``.

    Returns:
        The description, or ``None`` when neither source provides one.
    """
    direct = field.description
    if direct:
        return direct

    extra = field.json_schema_extra
    if extra and isinstance(extra, dict):
        return extra.get("description")

    return None


def _add_nested_model_options(
    function: Callable, parent_name: str, model_class: Type[BaseModel]
) -> Callable:
    """
    Register one Click option per field of a nested Pydantic model.

    Options are exposed with dot-notation (``--parent.sub-field``) and delivered
    to the command under the kwarg name ``parent__sub_field``.

    Note: Only single-level nesting is supported (e.g., ``--trl.beta``).
    Deeper nesting (e.g., ``--trl.scheduler.warmup``) is not handled.

    Args:
        function: Click command function to add options to.
        parent_name: Parent field name (e.g., "trl").
        model_class: Nested Pydantic model class.

    Returns:
        Function with added Click options.
    """
    # Types that get an explicit Click type; anything else is passed as None.
    scalar_types = {str: str, int: int, float: float}

    for child_name, child_field in reversed(model_class.model_fields.items()):
        annotation = _strip_optional_type(child_field.annotation)
        # CLI spelling uses dots between levels and dashes inside names.
        flag = f"{parent_name}.{child_name}".replace("_", "-")
        # The kwarg name joins parent and child with a double underscore.
        dest = f"{parent_name}__{child_name}"
        help_text = _get_field_description(child_field)

        if annotation is bool:
            # Booleans become an on/off flag pair.
            decorate = click.option(
                f"--{flag}/--no-{flag}", dest, default=None, help=help_text
            )
        else:
            decorate = click.option(
                f"--{flag}",
                dest,
                default=None,
                type=scalar_types.get(annotation),
                help=help_text,
            )
        function = decorate(function)

    return function
"""Tests for nested config option handling via CLI dot-notation."""

import click
from click.testing import CliRunner
from pydantic import BaseModel, Field

from axolotl.cli.utils.args import add_options_from_config, filter_none_kwargs


class InnerConfig(BaseModel):
    """A nested config model for testing."""

    beta: float | None = Field(
        default=None,
        description="Beta parameter.",
    )
    host: str | None = Field(
        default=None,
        description="Server host.",
    )
    use_feature: bool = Field(
        default=False,
        description="Whether to use the feature.",
    )


class OuterConfig(BaseModel):
    """A top-level config model for testing."""

    learning_rate: float | None = Field(
        default=None,
        description="Learning rate.",
    )
    inner: InnerConfig | None = Field(
        default=None,
        description="Inner config.",
    )
    name: str | None = Field(
        default=None,
        description="Model name.",
    )


def _build_echo_command():
    """Build a Click command over ``OuterConfig`` that prints ``key=value`` lines.

    Shared by the option-registration tests so each test only states the
    arguments it passes and the output it expects.
    """

    @click.command()
    @add_options_from_config(OuterConfig)
    @filter_none_kwargs
    def cmd(**kwargs):
        for key, value in sorted(kwargs.items()):
            click.echo(f"{key}={value}")

    return cmd


class TestAddOptionsFromConfigNested:
    """Test that add_options_from_config handles nested BaseModel fields."""

    def setup_method(self):
        self.runner = CliRunner()

    def test_nested_dot_notation_options_are_registered(self):
        """Nested model fields should create --parent.child CLI options."""
        result = self.runner.invoke(
            _build_echo_command(), ["--inner.beta=0.5", "--inner.host=localhost"]
        )
        assert result.exit_code == 0, result.output
        assert "inner__beta=0.5" in result.output
        assert "inner__host=localhost" in result.output

    def test_nested_bool_option(self):
        """Nested bool fields should support --parent.field/--no-parent.field."""
        result = self.runner.invoke(_build_echo_command(), ["--inner.use-feature"])
        assert result.exit_code == 0, result.output
        assert "inner__use_feature=True" in result.output

    def test_flat_and_nested_options_together(self):
        """Flat and nested options should work together."""
        result = self.runner.invoke(
            _build_echo_command(),
            ["--learning-rate=0.001", "--inner.beta=0.1", "--name=test"],
        )
        assert result.exit_code == 0, result.output
        assert "learning_rate=0.001" in result.output
        assert "inner__beta=0.1" in result.output
        assert "name=test" in result.output

    def test_no_nested_options_passed(self):
        """When no nested options are passed, they should not appear in kwargs."""

        @click.command()
        @add_options_from_config(OuterConfig)
        @filter_none_kwargs
        def cmd(**kwargs):
            click.echo(f"keys={sorted(kwargs.keys())}")

        result = self.runner.invoke(cmd, ["--learning-rate=0.01"])
        assert result.exit_code == 0, result.output
        assert "inner__" not in result.output


class TestLoadCfgNestedKwargs:
    """Test that load_cfg correctly applies nested (double-underscore) kwargs."""

    @staticmethod
    def _apply_nested_kwargs(cfg, kwargs):
        """Helper that mirrors the nested kwargs handling from load_cfg,
        including type coercion for string CLI values.

        NOTE(review): this is a hand-written copy of the override logic, not a
        call into load_cfg itself, so it must be kept in sync manually.
        """
        from axolotl.cli.config import _coerce_value

        nested_kwargs: dict = {}
        flat_kwargs: dict = {}
        for key, value in kwargs.items():
            if "__" in key:
                parent, child = key.split("__", 1)
                nested_kwargs.setdefault(parent, {})[child] = value
            else:
                flat_kwargs[key] = value

        cfg_keys = cfg.keys()
        for key, value in flat_kwargs.items():
            if key in cfg_keys:
                cfg[key] = _coerce_value(value, cfg.get(key))

        for parent, children in nested_kwargs.items():
            if cfg[parent] is None:
                cfg[parent] = {}
            if not isinstance(cfg[parent], dict):
                cfg[parent] = {}
            for child_key, child_value in children.items():
                existing = cfg[parent].get(child_key)
                cfg[parent][child_key] = _coerce_value(child_value, existing)

        return cfg

    def test_nested_kwargs_applied_to_cfg(self):
        """Double-underscore kwargs should set nested config values."""
        from axolotl.utils.dict import DictDefault

        cfg = DictDefault({"trl": {"beta": 0.1}, "learning_rate": 0.01})
        # CLI passes strings, so simulate that
        kwargs = {
            "trl__beta": "0.5",
            "trl__host": "192.168.1.1",
            "learning_rate": "0.02",
        }

        cfg = self._apply_nested_kwargs(cfg, kwargs)

        assert cfg["learning_rate"] == 0.02
        assert isinstance(cfg["learning_rate"], float)
        assert cfg["trl"]["beta"] == 0.5
        assert isinstance(cfg["trl"]["beta"], float)
        assert cfg["trl"]["host"] == "192.168.1.1"

    def test_nested_kwargs_creates_parent_if_none(self):
        """If the parent key is None, nested kwargs should create the dict."""
        from axolotl.utils.dict import DictDefault

        cfg = DictDefault({"trl": None, "learning_rate": 0.01})
        cfg = self._apply_nested_kwargs(cfg, {"trl__beta": "0.5"})

        # No existing value, YAML-style inference: "0.5" -> 0.5
        assert cfg["trl"]["beta"] == 0.5
        assert isinstance(cfg["trl"]["beta"], float)

    def test_nested_kwargs_overwrites_string_parent(self):
        """If the parent key is a string, it should be replaced with a dict."""
        from axolotl.utils.dict import DictDefault

        cfg = DictDefault({"trl": "some_string", "learning_rate": 0.01})
        cfg = self._apply_nested_kwargs(cfg, {"trl__beta": "0.5"})

        assert cfg["trl"]["beta"] == 0.5


class TestCoerceValue:
    """Test YAML-style type coercion for CLI string values."""

    def test_coerce_with_existing_float(self):
        from axolotl.cli.config import _coerce_value

        assert _coerce_value("0.5", 0.1) == 0.5
        assert isinstance(_coerce_value("0.5", 0.1), float)

    def test_coerce_with_existing_int(self):
        from axolotl.cli.config import _coerce_value

        assert _coerce_value("42", 10) == 42
        assert isinstance(_coerce_value("42", 10), int)

    def test_coerce_with_existing_bool(self):
        from axolotl.cli.config import _coerce_value

        assert _coerce_value("true", False) is True
        assert _coerce_value("false", True) is False
        assert _coerce_value("1", False) is True
        assert _coerce_value("0", True) is False

    def test_coerce_yaml_inference_no_existing(self):
        """Without an existing value, use YAML-style inference."""
        from axolotl.cli.config import _coerce_value

        assert _coerce_value("true", None) is True
        assert _coerce_value("false", None) is False
        assert _coerce_value("42", None) == 42
        assert isinstance(_coerce_value("42", None), int)
        assert _coerce_value("3.14", None) == 3.14
        assert isinstance(_coerce_value("3.14", None), float)
        assert _coerce_value("null", None) is None
        assert _coerce_value("hello", None) == "hello"

    def test_coerce_non_string_passthrough(self):
        """Non-string values should pass through unchanged."""
        from axolotl.cli.config import _coerce_value

        assert _coerce_value(0.5, 0.1) == 0.5
        assert _coerce_value(True, False) is True