feat: support dot-notation CLI args for nested config options (#3419)

* feat: support dot-notation CLI args for nested config options

Add support for overriding nested config fields (like TRL config) via
CLI using dot-notation, e.g.:
  axolotl train grpo.yaml --trl.vllm-server-host=10.0.0.1 --trl.beta=0.1

Changes:
- args.py: Detect BaseModel subclass fields and generate dot-notation
  CLI options (--parent.child) that map to double-underscore kwargs
  (parent__child). Also fix _strip_optional_type for Python 3.10+
  union syntax (X | None).
- config.py: Handle double-underscore kwargs in load_cfg by setting
  nested dict values on the config.
- Add tests for nested option handling.

Fixes #2702

* Address CodeRabbit review: fix string parent bug, add type hints and docstring

Signed-off-by: Manas Vardhan <manasvardhan@gmail.com>

* Add type coercion for CLI kwargs and fix pre-commit issues

- Add _coerce_value() for YAML-style type inference on string CLI args
- When existing config value has a type (int/float/bool), cast to match
- When no existing value, infer type from string (true/false, ints, floats, null)
- Apply coercion to both flat and nested (dot-notation) kwargs
- Fix unused pytest import (pre-commit/ruff)
- Update tests to pass string values (matching real CLI behavior)
- Add dedicated TestCoerceValue test class

Addresses maintainer feedback on type casting for nested kwargs.

---------

Signed-off-by: Manas Vardhan <manasvardhan@gmail.com>
This commit is contained in:
Manas Vardhan
2026-02-23 07:10:06 -08:00
committed by GitHub
parent 3f30572d4a
commit 5ed455715e
3 changed files with 381 additions and 7 deletions

View File

@@ -5,7 +5,7 @@ import os
import tempfile
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Union
from typing import Any, Optional, Union
from urllib.parse import urlparse
import requests
@@ -32,6 +32,63 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
LOG = get_logger(__name__)
def _coerce_value(value: Any, existing: Optional[Any] = None) -> Any:
"""Coerce a string CLI value to its most likely Python type.
If an existing value is present in the config, its type is used to guide
casting. Otherwise, YAML-style inference is applied: booleans, ints,
floats, and None literals are recognised automatically.
Args:
value: The raw value (typically a string from the CLI).
existing: An optional existing config value whose type guides coercion.
Returns:
The value cast to the inferred or expected type.
"""
if not isinstance(value, str):
return value
# If the config already has a typed value, cast to match
if existing is not None:
if isinstance(existing, bool):
return value.lower() in ("true", "1", "yes")
if isinstance(existing, int):
try:
return int(value)
except (ValueError, TypeError):
return value
if isinstance(existing, float):
try:
return float(value)
except (ValueError, TypeError):
return value
# For other types (str, list, dict, etc.), return as-is
return value
# No existing value -- use YAML-style inference
lower = value.lower()
if lower in ("true", "yes"):
return True
if lower in ("false", "no"):
return False
if lower in ("null", "none", "~"):
return None
# Try int then float
try:
return int(value)
except ValueError:
pass
try:
return float(value)
except ValueError:
pass
return value
API_KEY_FIELDS = {"comet_api_key"}
TELEMETRY_MANAGER = TelemetryManager.get_instance()
@@ -208,13 +265,37 @@ def load_cfg(
# If there are any options passed in the cli, if it is something that seems valid
# from the yaml, then overwrite the value
cfg_keys = cfg.keys()
# Separate nested (dot-notation) kwargs from flat kwargs
nested_kwargs: dict[str, dict[str, Any]] = {}
flat_kwargs: dict[str, Any] = {}
for key, value in kwargs.items():
if "__" in key:
parent, child = key.split("__", 1)
nested_kwargs.setdefault(parent, {})[child] = value
else:
flat_kwargs[key] = value
# Apply flat kwargs
for key, value in flat_kwargs.items():
# If not strict, allow writing to cfg even if it's not in the yml already
if key in cfg_keys or not cfg.strict:
if isinstance(cfg[key], bool):
cfg[key] = bool(value)
else:
cfg[key] = value
cfg[key] = _coerce_value(value, cfg.get(key))
# Apply nested kwargs (e.g., trl__beta -> cfg.trl.beta)
for parent, children in nested_kwargs.items():
if parent not in cfg_keys and cfg.strict:
continue
if cfg[parent] is None:
cfg[parent] = {}
if not isinstance(cfg[parent], dict):
LOG.warning(
"Overwriting non-dict value for '%s' with nested CLI overrides", parent
)
cfg[parent] = {}
for child_key, child_value in children.items():
existing_child = cfg[parent].get(child_key)
cfg[parent][child_key] = _coerce_value(child_value, existing_child)
try:
device_props = torch.cuda.get_device_properties("cuda")

View File

@@ -2,7 +2,7 @@
import dataclasses
from functools import wraps
from types import NoneType
from types import NoneType, UnionType
from typing import Any, Callable, Type, Union, get_args, get_origin
import click
@@ -20,7 +20,8 @@ def _strip_optional_type(field_type: type | str | None):
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
returns the input type unchanged.
"""
if get_origin(field_type) is Union and type(None) in get_args(field_type):
is_union = get_origin(field_type) is Union or isinstance(field_type, UnionType)
if is_union and type(None) in get_args(field_type):
field_type = next(
t for t in get_args(field_type) if not isinstance(t, NoneType)
)
@@ -87,10 +88,70 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
return decorator
def _is_pydantic_model(field_type: type) -> bool:
"""Check if a type is a Pydantic BaseModel subclass."""
try:
return isinstance(field_type, type) and issubclass(field_type, BaseModel)
except TypeError:
return False
def _get_field_description(field) -> str | None:
"""Get description from a Pydantic field, checking both .description and json_schema_extra."""
if field.description:
return field.description
if field.json_schema_extra and isinstance(field.json_schema_extra, dict):
return field.json_schema_extra.get("description")
return None
def _add_nested_model_options(
function: Callable, parent_name: str, model_class: Type[BaseModel]
) -> Callable:
"""
Add Click options for all fields of a nested Pydantic model using dot-notation.
Note: Only single-level nesting is supported (e.g., ``--trl.beta``).
Deeper nesting (e.g., ``--trl.scheduler.warmup``) is not handled.
Args:
function: Click command function to add options to.
parent_name: Parent field name (e.g., "trl").
model_class: Nested Pydantic model class.
Returns:
Function with added Click options.
"""
for sub_name, sub_field in reversed(model_class.model_fields.items()):
sub_type = _strip_optional_type(sub_field.annotation)
# Use dot notation: --parent.sub_field
cli_name = f"{parent_name}.{sub_name}".replace("_", "-")
# The kwarg name uses double-underscore as separator
param_name = f"{parent_name}__{sub_name}"
description = _get_field_description(sub_field)
if sub_type is bool:
option_name = f"--{cli_name}/--no-{cli_name}"
function = click.option(
option_name, param_name, default=None, help=description
)(function)
else:
option_name = f"--{cli_name}"
click_type = {str: str, int: int, float: float}.get(sub_type)
function = click.option(
option_name, param_name, default=None, type=click_type, help=description
)(function)
return function
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
"""
Create Click options from the fields of a Pydantic model.
For fields whose type is itself a Pydantic BaseModel, dot-notation CLI options are
generated for each sub-field (e.g., ``--trl.beta=0.1``).
Args:
config_class: PyDantic model with fields to parse from the CLI
@@ -103,6 +164,11 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
for name, field in reversed(config_class.model_fields.items()):
field_type = _strip_optional_type(field.annotation)
# Handle nested Pydantic models with dot-notation options
if _is_pydantic_model(field_type):
function = _add_nested_model_options(function, name, field_type)
continue
if field_type is bool:
field_name = name.replace("_", "-")
option_name = f"--{field_name}/--no-{field_name}"

View File

@@ -0,0 +1,227 @@
"""Tests for nested config option handling via CLI dot-notation."""
import click
from click.testing import CliRunner
from pydantic import BaseModel, Field
from axolotl.cli.utils.args import add_options_from_config, filter_none_kwargs
class InnerConfig(BaseModel):
"""A nested config model for testing."""
beta: float | None = Field(
default=None,
description="Beta parameter.",
)
host: str | None = Field(
default=None,
description="Server host.",
)
use_feature: bool = Field(
default=False,
description="Whether to use the feature.",
)
class OuterConfig(BaseModel):
"""A top-level config model for testing."""
learning_rate: float | None = Field(
default=None,
description="Learning rate.",
)
inner: InnerConfig | None = Field(
default=None,
description="Inner config.",
)
name: str | None = Field(
default=None,
description="Model name.",
)
class TestAddOptionsFromConfigNested:
"""Test that add_options_from_config handles nested BaseModel fields."""
def setup_method(self):
self.runner = CliRunner()
def test_nested_dot_notation_options_are_registered(self):
"""Nested model fields should create --parent.child CLI options."""
@click.command()
@add_options_from_config(OuterConfig)
@filter_none_kwargs
def cmd(**kwargs):
for k, v in sorted(kwargs.items()):
click.echo(f"{k}={v}")
result = self.runner.invoke(cmd, ["--inner.beta=0.5", "--inner.host=localhost"])
assert result.exit_code == 0, result.output
assert "inner__beta=0.5" in result.output
assert "inner__host=localhost" in result.output
def test_nested_bool_option(self):
"""Nested bool fields should support --parent.field/--no-parent.field."""
@click.command()
@add_options_from_config(OuterConfig)
@filter_none_kwargs
def cmd(**kwargs):
for k, v in sorted(kwargs.items()):
click.echo(f"{k}={v}")
result = self.runner.invoke(cmd, ["--inner.use-feature"])
assert result.exit_code == 0, result.output
assert "inner__use_feature=True" in result.output
def test_flat_and_nested_options_together(self):
"""Flat and nested options should work together."""
@click.command()
@add_options_from_config(OuterConfig)
@filter_none_kwargs
def cmd(**kwargs):
for k, v in sorted(kwargs.items()):
click.echo(f"{k}={v}")
result = self.runner.invoke(
cmd, ["--learning-rate=0.001", "--inner.beta=0.1", "--name=test"]
)
assert result.exit_code == 0, result.output
assert "learning_rate=0.001" in result.output
assert "inner__beta=0.1" in result.output
assert "name=test" in result.output
def test_no_nested_options_passed(self):
"""When no nested options are passed, they should not appear in kwargs."""
@click.command()
@add_options_from_config(OuterConfig)
@filter_none_kwargs
def cmd(**kwargs):
click.echo(f"keys={sorted(kwargs.keys())}")
result = self.runner.invoke(cmd, ["--learning-rate=0.01"])
assert result.exit_code == 0, result.output
assert "inner__" not in result.output
class TestLoadCfgNestedKwargs:
"""Test that load_cfg correctly applies nested (double-underscore) kwargs."""
@staticmethod
def _apply_nested_kwargs(cfg, kwargs):
"""Helper that mirrors the nested kwargs handling from load_cfg,
including type coercion for string CLI values."""
from axolotl.cli.config import _coerce_value
nested_kwargs: dict = {}
flat_kwargs: dict = {}
for key, value in kwargs.items():
if "__" in key:
parent, child = key.split("__", 1)
nested_kwargs.setdefault(parent, {})[child] = value
else:
flat_kwargs[key] = value
cfg_keys = cfg.keys()
for key, value in flat_kwargs.items():
if key in cfg_keys:
cfg[key] = _coerce_value(value, cfg.get(key))
for parent, children in nested_kwargs.items():
if cfg[parent] is None:
cfg[parent] = {}
if not isinstance(cfg[parent], dict):
cfg[parent] = {}
for child_key, child_value in children.items():
existing = cfg[parent].get(child_key)
cfg[parent][child_key] = _coerce_value(child_value, existing)
return cfg
def test_nested_kwargs_applied_to_cfg(self, tmp_path):
"""Double-underscore kwargs should set nested config values."""
from axolotl.utils.dict import DictDefault
cfg = DictDefault({"trl": {"beta": 0.1}, "learning_rate": 0.01})
# CLI passes strings, so simulate that
kwargs = {
"trl__beta": "0.5",
"trl__host": "192.168.1.1",
"learning_rate": "0.02",
}
cfg = self._apply_nested_kwargs(cfg, kwargs)
assert cfg["learning_rate"] == 0.02
assert isinstance(cfg["learning_rate"], float)
assert cfg["trl"]["beta"] == 0.5
assert isinstance(cfg["trl"]["beta"], float)
assert cfg["trl"]["host"] == "192.168.1.1"
def test_nested_kwargs_creates_parent_if_none(self):
"""If the parent key is None, nested kwargs should create the dict."""
from axolotl.utils.dict import DictDefault
cfg = DictDefault({"trl": None, "learning_rate": 0.01})
cfg = self._apply_nested_kwargs(cfg, {"trl__beta": "0.5"})
# No existing value, YAML-style inference: "0.5" -> 0.5
assert cfg["trl"]["beta"] == 0.5
assert isinstance(cfg["trl"]["beta"], float)
def test_nested_kwargs_overwrites_string_parent(self):
"""If the parent key is a string, it should be replaced with a dict."""
from axolotl.utils.dict import DictDefault
cfg = DictDefault({"trl": "some_string", "learning_rate": 0.01})
cfg = self._apply_nested_kwargs(cfg, {"trl__beta": "0.5"})
assert cfg["trl"]["beta"] == 0.5
class TestCoerceValue:
"""Test YAML-style type coercion for CLI string values."""
def test_coerce_with_existing_float(self):
from axolotl.cli.config import _coerce_value
assert _coerce_value("0.5", 0.1) == 0.5
assert isinstance(_coerce_value("0.5", 0.1), float)
def test_coerce_with_existing_int(self):
from axolotl.cli.config import _coerce_value
assert _coerce_value("42", 10) == 42
assert isinstance(_coerce_value("42", 10), int)
def test_coerce_with_existing_bool(self):
from axolotl.cli.config import _coerce_value
assert _coerce_value("true", False) is True
assert _coerce_value("false", True) is False
assert _coerce_value("1", False) is True
assert _coerce_value("0", True) is False
def test_coerce_yaml_inference_no_existing(self):
"""Without an existing value, use YAML-style inference."""
from axolotl.cli.config import _coerce_value
assert _coerce_value("true", None) is True
assert _coerce_value("false", None) is False
assert _coerce_value("42", None) == 42
assert isinstance(_coerce_value("42", None), int)
assert _coerce_value("3.14", None) == 3.14
assert isinstance(_coerce_value("3.14", None), float)
assert _coerce_value("null", None) is None
assert _coerce_value("hello", None) == "hello"
def test_coerce_non_string_passthrough(self):
"""Non-string values should pass through unchanged."""
from axolotl.cli.config import _coerce_value
assert _coerce_value(0.5, 0.1) == 0.5
assert _coerce_value(True, False) is True