feat: support dot-notation CLI args for nested config options (#3419)
* feat: support dot-notation CLI args for nested config options Add support for overriding nested config fields (like TRL config) via CLI using dot-notation, e.g.: axolotl train grpo.yaml --trl.vllm-server-host=10.0.0.1 --trl.beta=0.1 Changes: - args.py: Detect BaseModel subclass fields and generate dot-notation CLI options (--parent.child) that map to double-underscore kwargs (parent__child). Also fix _strip_optional_type for Python 3.10+ union syntax (X | None). - config.py: Handle double-underscore kwargs in load_cfg by setting nested dict values on the config. - Add tests for nested option handling. Fixes #2702 * Address CodeRabbit review: fix string parent bug, add type hints and docstring Signed-off-by: Manas Vardhan <manasvardhan@gmail.com> * Add type coercion for CLI kwargs and fix pre-commit issues - Add _coerce_value() for YAML-style type inference on string CLI args - When existing config value has a type (int/float/bool), cast to match - When no existing value, infer type from string (true/false, ints, floats, null) - Apply coercion to both flat and nested (dot-notation) kwargs - Fix unused pytest import (pre-commit/ruff) - Update tests to pass string values (matching real CLI behavior) - Add dedicated TestCoerceValue test class Addresses maintainer feedback on type casting for nested kwargs. --------- Signed-off-by: Manas Vardhan <manasvardhan@gmail.com>
This commit is contained in:
@@ -5,7 +5,7 @@ import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Union
|
||||
from typing import Any, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
@@ -32,6 +32,63 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def _coerce_value(value: Any, existing: Optional[Any] = None) -> Any:
|
||||
"""Coerce a string CLI value to its most likely Python type.
|
||||
|
||||
If an existing value is present in the config, its type is used to guide
|
||||
casting. Otherwise, YAML-style inference is applied: booleans, ints,
|
||||
floats, and None literals are recognised automatically.
|
||||
|
||||
Args:
|
||||
value: The raw value (typically a string from the CLI).
|
||||
existing: An optional existing config value whose type guides coercion.
|
||||
|
||||
Returns:
|
||||
The value cast to the inferred or expected type.
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
|
||||
# If the config already has a typed value, cast to match
|
||||
if existing is not None:
|
||||
if isinstance(existing, bool):
|
||||
return value.lower() in ("true", "1", "yes")
|
||||
if isinstance(existing, int):
|
||||
try:
|
||||
return int(value)
|
||||
except (ValueError, TypeError):
|
||||
return value
|
||||
if isinstance(existing, float):
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
return value
|
||||
# For other types (str, list, dict, etc.), return as-is
|
||||
return value
|
||||
|
||||
# No existing value -- use YAML-style inference
|
||||
lower = value.lower()
|
||||
if lower in ("true", "yes"):
|
||||
return True
|
||||
if lower in ("false", "no"):
|
||||
return False
|
||||
if lower in ("null", "none", "~"):
|
||||
return None
|
||||
|
||||
# Try int then float
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return value
|
||||
|
||||
|
||||
API_KEY_FIELDS = {"comet_api_key"}
|
||||
|
||||
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
||||
@@ -208,13 +265,37 @@ def load_cfg(
|
||||
# If there are any options passed in the cli, if it is something that seems valid
|
||||
# from the yaml, then overwrite the value
|
||||
cfg_keys = cfg.keys()
|
||||
|
||||
# Separate nested (dot-notation) kwargs from flat kwargs
|
||||
nested_kwargs: dict[str, dict[str, Any]] = {}
|
||||
flat_kwargs: dict[str, Any] = {}
|
||||
for key, value in kwargs.items():
|
||||
if "__" in key:
|
||||
parent, child = key.split("__", 1)
|
||||
nested_kwargs.setdefault(parent, {})[child] = value
|
||||
else:
|
||||
flat_kwargs[key] = value
|
||||
|
||||
# Apply flat kwargs
|
||||
for key, value in flat_kwargs.items():
|
||||
# If not strict, allow writing to cfg even if it's not in the yml already
|
||||
if key in cfg_keys or not cfg.strict:
|
||||
if isinstance(cfg[key], bool):
|
||||
cfg[key] = bool(value)
|
||||
else:
|
||||
cfg[key] = value
|
||||
cfg[key] = _coerce_value(value, cfg.get(key))
|
||||
|
||||
# Apply nested kwargs (e.g., trl__beta -> cfg.trl.beta)
|
||||
for parent, children in nested_kwargs.items():
|
||||
if parent not in cfg_keys and cfg.strict:
|
||||
continue
|
||||
if cfg[parent] is None:
|
||||
cfg[parent] = {}
|
||||
if not isinstance(cfg[parent], dict):
|
||||
LOG.warning(
|
||||
"Overwriting non-dict value for '%s' with nested CLI overrides", parent
|
||||
)
|
||||
cfg[parent] = {}
|
||||
for child_key, child_value in children.items():
|
||||
existing_child = cfg[parent].get(child_key)
|
||||
cfg[parent][child_key] = _coerce_value(child_value, existing_child)
|
||||
|
||||
try:
|
||||
device_props = torch.cuda.get_device_properties("cuda")
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import dataclasses
|
||||
from functools import wraps
|
||||
from types import NoneType
|
||||
from types import NoneType, UnionType
|
||||
from typing import Any, Callable, Type, Union, get_args, get_origin
|
||||
|
||||
import click
|
||||
@@ -20,7 +20,8 @@ def _strip_optional_type(field_type: type | str | None):
|
||||
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
|
||||
returns the input type unchanged.
|
||||
"""
|
||||
if get_origin(field_type) is Union and type(None) in get_args(field_type):
|
||||
is_union = get_origin(field_type) is Union or isinstance(field_type, UnionType)
|
||||
if is_union and type(None) in get_args(field_type):
|
||||
field_type = next(
|
||||
t for t in get_args(field_type) if not isinstance(t, NoneType)
|
||||
)
|
||||
@@ -87,10 +88,70 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
|
||||
return decorator
|
||||
|
||||
|
||||
def _is_pydantic_model(field_type: type) -> bool:
|
||||
"""Check if a type is a Pydantic BaseModel subclass."""
|
||||
try:
|
||||
return isinstance(field_type, type) and issubclass(field_type, BaseModel)
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
||||
def _get_field_description(field) -> str | None:
|
||||
"""Get description from a Pydantic field, checking both .description and json_schema_extra."""
|
||||
if field.description:
|
||||
return field.description
|
||||
if field.json_schema_extra and isinstance(field.json_schema_extra, dict):
|
||||
return field.json_schema_extra.get("description")
|
||||
return None
|
||||
|
||||
|
||||
def _add_nested_model_options(
|
||||
function: Callable, parent_name: str, model_class: Type[BaseModel]
|
||||
) -> Callable:
|
||||
"""
|
||||
Add Click options for all fields of a nested Pydantic model using dot-notation.
|
||||
|
||||
Note: Only single-level nesting is supported (e.g., ``--trl.beta``).
|
||||
Deeper nesting (e.g., ``--trl.scheduler.warmup``) is not handled.
|
||||
|
||||
Args:
|
||||
function: Click command function to add options to.
|
||||
parent_name: Parent field name (e.g., "trl").
|
||||
model_class: Nested Pydantic model class.
|
||||
|
||||
Returns:
|
||||
Function with added Click options.
|
||||
"""
|
||||
for sub_name, sub_field in reversed(model_class.model_fields.items()):
|
||||
sub_type = _strip_optional_type(sub_field.annotation)
|
||||
# Use dot notation: --parent.sub_field
|
||||
cli_name = f"{parent_name}.{sub_name}".replace("_", "-")
|
||||
# The kwarg name uses double-underscore as separator
|
||||
param_name = f"{parent_name}__{sub_name}"
|
||||
description = _get_field_description(sub_field)
|
||||
|
||||
if sub_type is bool:
|
||||
option_name = f"--{cli_name}/--no-{cli_name}"
|
||||
function = click.option(
|
||||
option_name, param_name, default=None, help=description
|
||||
)(function)
|
||||
else:
|
||||
option_name = f"--{cli_name}"
|
||||
click_type = {str: str, int: int, float: float}.get(sub_type)
|
||||
function = click.option(
|
||||
option_name, param_name, default=None, type=click_type, help=description
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
|
||||
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
|
||||
"""
|
||||
Create Click options from the fields of a Pydantic model.
|
||||
|
||||
For fields whose type is itself a Pydantic BaseModel, dot-notation CLI options are
|
||||
generated for each sub-field (e.g., ``--trl.beta=0.1``).
|
||||
|
||||
Args:
|
||||
config_class: PyDantic model with fields to parse from the CLI
|
||||
|
||||
@@ -103,6 +164,11 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
|
||||
for name, field in reversed(config_class.model_fields.items()):
|
||||
field_type = _strip_optional_type(field.annotation)
|
||||
|
||||
# Handle nested Pydantic models with dot-notation options
|
||||
if _is_pydantic_model(field_type):
|
||||
function = _add_nested_model_options(function, name, field_type)
|
||||
continue
|
||||
|
||||
if field_type is bool:
|
||||
field_name = name.replace("_", "-")
|
||||
option_name = f"--{field_name}/--no-{field_name}"
|
||||
|
||||
227
tests/cli/test_nested_options.py
Normal file
227
tests/cli/test_nested_options.py
Normal file
@@ -0,0 +1,227 @@
|
||||
"""Tests for nested config option handling via CLI dot-notation."""
|
||||
|
||||
import click
|
||||
from click.testing import CliRunner
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from axolotl.cli.utils.args import add_options_from_config, filter_none_kwargs
|
||||
|
||||
|
||||
class InnerConfig(BaseModel):
|
||||
"""A nested config model for testing."""
|
||||
|
||||
beta: float | None = Field(
|
||||
default=None,
|
||||
description="Beta parameter.",
|
||||
)
|
||||
host: str | None = Field(
|
||||
default=None,
|
||||
description="Server host.",
|
||||
)
|
||||
use_feature: bool = Field(
|
||||
default=False,
|
||||
description="Whether to use the feature.",
|
||||
)
|
||||
|
||||
|
||||
class OuterConfig(BaseModel):
|
||||
"""A top-level config model for testing."""
|
||||
|
||||
learning_rate: float | None = Field(
|
||||
default=None,
|
||||
description="Learning rate.",
|
||||
)
|
||||
inner: InnerConfig | None = Field(
|
||||
default=None,
|
||||
description="Inner config.",
|
||||
)
|
||||
name: str | None = Field(
|
||||
default=None,
|
||||
description="Model name.",
|
||||
)
|
||||
|
||||
|
||||
class TestAddOptionsFromConfigNested:
|
||||
"""Test that add_options_from_config handles nested BaseModel fields."""
|
||||
|
||||
def setup_method(self):
|
||||
self.runner = CliRunner()
|
||||
|
||||
def test_nested_dot_notation_options_are_registered(self):
|
||||
"""Nested model fields should create --parent.child CLI options."""
|
||||
|
||||
@click.command()
|
||||
@add_options_from_config(OuterConfig)
|
||||
@filter_none_kwargs
|
||||
def cmd(**kwargs):
|
||||
for k, v in sorted(kwargs.items()):
|
||||
click.echo(f"{k}={v}")
|
||||
|
||||
result = self.runner.invoke(cmd, ["--inner.beta=0.5", "--inner.host=localhost"])
|
||||
assert result.exit_code == 0, result.output
|
||||
assert "inner__beta=0.5" in result.output
|
||||
assert "inner__host=localhost" in result.output
|
||||
|
||||
def test_nested_bool_option(self):
|
||||
"""Nested bool fields should support --parent.field/--no-parent.field."""
|
||||
|
||||
@click.command()
|
||||
@add_options_from_config(OuterConfig)
|
||||
@filter_none_kwargs
|
||||
def cmd(**kwargs):
|
||||
for k, v in sorted(kwargs.items()):
|
||||
click.echo(f"{k}={v}")
|
||||
|
||||
result = self.runner.invoke(cmd, ["--inner.use-feature"])
|
||||
assert result.exit_code == 0, result.output
|
||||
assert "inner__use_feature=True" in result.output
|
||||
|
||||
def test_flat_and_nested_options_together(self):
|
||||
"""Flat and nested options should work together."""
|
||||
|
||||
@click.command()
|
||||
@add_options_from_config(OuterConfig)
|
||||
@filter_none_kwargs
|
||||
def cmd(**kwargs):
|
||||
for k, v in sorted(kwargs.items()):
|
||||
click.echo(f"{k}={v}")
|
||||
|
||||
result = self.runner.invoke(
|
||||
cmd, ["--learning-rate=0.001", "--inner.beta=0.1", "--name=test"]
|
||||
)
|
||||
assert result.exit_code == 0, result.output
|
||||
assert "learning_rate=0.001" in result.output
|
||||
assert "inner__beta=0.1" in result.output
|
||||
assert "name=test" in result.output
|
||||
|
||||
def test_no_nested_options_passed(self):
|
||||
"""When no nested options are passed, they should not appear in kwargs."""
|
||||
|
||||
@click.command()
|
||||
@add_options_from_config(OuterConfig)
|
||||
@filter_none_kwargs
|
||||
def cmd(**kwargs):
|
||||
click.echo(f"keys={sorted(kwargs.keys())}")
|
||||
|
||||
result = self.runner.invoke(cmd, ["--learning-rate=0.01"])
|
||||
assert result.exit_code == 0, result.output
|
||||
assert "inner__" not in result.output
|
||||
|
||||
|
||||
class TestLoadCfgNestedKwargs:
|
||||
"""Test that load_cfg correctly applies nested (double-underscore) kwargs."""
|
||||
|
||||
@staticmethod
|
||||
def _apply_nested_kwargs(cfg, kwargs):
|
||||
"""Helper that mirrors the nested kwargs handling from load_cfg,
|
||||
including type coercion for string CLI values."""
|
||||
from axolotl.cli.config import _coerce_value
|
||||
|
||||
nested_kwargs: dict = {}
|
||||
flat_kwargs: dict = {}
|
||||
for key, value in kwargs.items():
|
||||
if "__" in key:
|
||||
parent, child = key.split("__", 1)
|
||||
nested_kwargs.setdefault(parent, {})[child] = value
|
||||
else:
|
||||
flat_kwargs[key] = value
|
||||
|
||||
cfg_keys = cfg.keys()
|
||||
for key, value in flat_kwargs.items():
|
||||
if key in cfg_keys:
|
||||
cfg[key] = _coerce_value(value, cfg.get(key))
|
||||
|
||||
for parent, children in nested_kwargs.items():
|
||||
if cfg[parent] is None:
|
||||
cfg[parent] = {}
|
||||
if not isinstance(cfg[parent], dict):
|
||||
cfg[parent] = {}
|
||||
for child_key, child_value in children.items():
|
||||
existing = cfg[parent].get(child_key)
|
||||
cfg[parent][child_key] = _coerce_value(child_value, existing)
|
||||
|
||||
return cfg
|
||||
|
||||
def test_nested_kwargs_applied_to_cfg(self, tmp_path):
|
||||
"""Double-underscore kwargs should set nested config values."""
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
cfg = DictDefault({"trl": {"beta": 0.1}, "learning_rate": 0.01})
|
||||
# CLI passes strings, so simulate that
|
||||
kwargs = {
|
||||
"trl__beta": "0.5",
|
||||
"trl__host": "192.168.1.1",
|
||||
"learning_rate": "0.02",
|
||||
}
|
||||
|
||||
cfg = self._apply_nested_kwargs(cfg, kwargs)
|
||||
|
||||
assert cfg["learning_rate"] == 0.02
|
||||
assert isinstance(cfg["learning_rate"], float)
|
||||
assert cfg["trl"]["beta"] == 0.5
|
||||
assert isinstance(cfg["trl"]["beta"], float)
|
||||
assert cfg["trl"]["host"] == "192.168.1.1"
|
||||
|
||||
def test_nested_kwargs_creates_parent_if_none(self):
|
||||
"""If the parent key is None, nested kwargs should create the dict."""
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
cfg = DictDefault({"trl": None, "learning_rate": 0.01})
|
||||
cfg = self._apply_nested_kwargs(cfg, {"trl__beta": "0.5"})
|
||||
|
||||
# No existing value, YAML-style inference: "0.5" -> 0.5
|
||||
assert cfg["trl"]["beta"] == 0.5
|
||||
assert isinstance(cfg["trl"]["beta"], float)
|
||||
|
||||
def test_nested_kwargs_overwrites_string_parent(self):
|
||||
"""If the parent key is a string, it should be replaced with a dict."""
|
||||
from axolotl.utils.dict import DictDefault
|
||||
|
||||
cfg = DictDefault({"trl": "some_string", "learning_rate": 0.01})
|
||||
cfg = self._apply_nested_kwargs(cfg, {"trl__beta": "0.5"})
|
||||
|
||||
assert cfg["trl"]["beta"] == 0.5
|
||||
|
||||
|
||||
class TestCoerceValue:
|
||||
"""Test YAML-style type coercion for CLI string values."""
|
||||
|
||||
def test_coerce_with_existing_float(self):
|
||||
from axolotl.cli.config import _coerce_value
|
||||
|
||||
assert _coerce_value("0.5", 0.1) == 0.5
|
||||
assert isinstance(_coerce_value("0.5", 0.1), float)
|
||||
|
||||
def test_coerce_with_existing_int(self):
|
||||
from axolotl.cli.config import _coerce_value
|
||||
|
||||
assert _coerce_value("42", 10) == 42
|
||||
assert isinstance(_coerce_value("42", 10), int)
|
||||
|
||||
def test_coerce_with_existing_bool(self):
|
||||
from axolotl.cli.config import _coerce_value
|
||||
|
||||
assert _coerce_value("true", False) is True
|
||||
assert _coerce_value("false", True) is False
|
||||
assert _coerce_value("1", False) is True
|
||||
assert _coerce_value("0", True) is False
|
||||
|
||||
def test_coerce_yaml_inference_no_existing(self):
|
||||
"""Without an existing value, use YAML-style inference."""
|
||||
from axolotl.cli.config import _coerce_value
|
||||
|
||||
assert _coerce_value("true", None) is True
|
||||
assert _coerce_value("false", None) is False
|
||||
assert _coerce_value("42", None) == 42
|
||||
assert isinstance(_coerce_value("42", None), int)
|
||||
assert _coerce_value("3.14", None) == 3.14
|
||||
assert isinstance(_coerce_value("3.14", None), float)
|
||||
assert _coerce_value("null", None) is None
|
||||
assert _coerce_value("hello", None) == "hello"
|
||||
|
||||
def test_coerce_non_string_passthrough(self):
|
||||
"""Non-string values should pass through unchanged."""
|
||||
from axolotl.cli.config import _coerce_value
|
||||
|
||||
assert _coerce_value(0.5, 0.1) == 0.5
|
||||
assert _coerce_value(True, False) is True
|
||||
Reference in New Issue
Block a user