feat: support dot-notation CLI args for nested config options (#3419)
* feat: support dot-notation CLI args for nested config options Add support for overriding nested config fields (like TRL config) via CLI using dot-notation, e.g.: axolotl train grpo.yaml --trl.vllm-server-host=10.0.0.1 --trl.beta=0.1 Changes: - args.py: Detect BaseModel subclass fields and generate dot-notation CLI options (--parent.child) that map to double-underscore kwargs (parent__child). Also fix _strip_optional_type for Python 3.10+ union syntax (X | None). - config.py: Handle double-underscore kwargs in load_cfg by setting nested dict values on the config. - Add tests for nested option handling. Fixes #2702 * Address CodeRabbit review: fix string parent bug, add type hints and docstring Signed-off-by: Manas Vardhan <manasvardhan@gmail.com> * Add type coercion for CLI kwargs and fix pre-commit issues - Add _coerce_value() for YAML-style type inference on string CLI args - When existing config value has a type (int/float/bool), cast to match - When no existing value, infer type from string (true/false, ints, floats, null) - Apply coercion to both flat and nested (dot-notation) kwargs - Fix unused pytest import (pre-commit/ruff) - Update tests to pass string values (matching real CLI behavior) - Add dedicated TestCoerceValue test class Addresses maintainer feedback on type casting for nested kwargs. --------- Signed-off-by: Manas Vardhan <manasvardhan@gmail.com>
This commit is contained in:
@@ -5,7 +5,7 @@ import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from tempfile import NamedTemporaryFile
|
||||
from typing import Union
|
||||
from typing import Any, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
@@ -32,6 +32,63 @@ from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
LOG = get_logger(__name__)
|
||||
|
||||
|
||||
def _coerce_value(value: Any, existing: Optional[Any] = None) -> Any:
|
||||
"""Coerce a string CLI value to its most likely Python type.
|
||||
|
||||
If an existing value is present in the config, its type is used to guide
|
||||
casting. Otherwise, YAML-style inference is applied: booleans, ints,
|
||||
floats, and None literals are recognised automatically.
|
||||
|
||||
Args:
|
||||
value: The raw value (typically a string from the CLI).
|
||||
existing: An optional existing config value whose type guides coercion.
|
||||
|
||||
Returns:
|
||||
The value cast to the inferred or expected type.
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
return value
|
||||
|
||||
# If the config already has a typed value, cast to match
|
||||
if existing is not None:
|
||||
if isinstance(existing, bool):
|
||||
return value.lower() in ("true", "1", "yes")
|
||||
if isinstance(existing, int):
|
||||
try:
|
||||
return int(value)
|
||||
except (ValueError, TypeError):
|
||||
return value
|
||||
if isinstance(existing, float):
|
||||
try:
|
||||
return float(value)
|
||||
except (ValueError, TypeError):
|
||||
return value
|
||||
# For other types (str, list, dict, etc.), return as-is
|
||||
return value
|
||||
|
||||
# No existing value -- use YAML-style inference
|
||||
lower = value.lower()
|
||||
if lower in ("true", "yes"):
|
||||
return True
|
||||
if lower in ("false", "no"):
|
||||
return False
|
||||
if lower in ("null", "none", "~"):
|
||||
return None
|
||||
|
||||
# Try int then float
|
||||
try:
|
||||
return int(value)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
return float(value)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return value
|
||||
|
||||
|
||||
API_KEY_FIELDS = {"comet_api_key"}
|
||||
|
||||
TELEMETRY_MANAGER = TelemetryManager.get_instance()
|
||||
@@ -208,13 +265,37 @@ def load_cfg(
|
||||
# If there are any options passed in the cli, if it is something that seems valid
|
||||
# from the yaml, then overwrite the value
|
||||
cfg_keys = cfg.keys()
|
||||
|
||||
# Separate nested (dot-notation) kwargs from flat kwargs
|
||||
nested_kwargs: dict[str, dict[str, Any]] = {}
|
||||
flat_kwargs: dict[str, Any] = {}
|
||||
for key, value in kwargs.items():
|
||||
if "__" in key:
|
||||
parent, child = key.split("__", 1)
|
||||
nested_kwargs.setdefault(parent, {})[child] = value
|
||||
else:
|
||||
flat_kwargs[key] = value
|
||||
|
||||
# Apply flat kwargs
|
||||
for key, value in flat_kwargs.items():
|
||||
# If not strict, allow writing to cfg even if it's not in the yml already
|
||||
if key in cfg_keys or not cfg.strict:
|
||||
if isinstance(cfg[key], bool):
|
||||
cfg[key] = bool(value)
|
||||
else:
|
||||
cfg[key] = value
|
||||
cfg[key] = _coerce_value(value, cfg.get(key))
|
||||
|
||||
# Apply nested kwargs (e.g., trl__beta -> cfg.trl.beta)
|
||||
for parent, children in nested_kwargs.items():
|
||||
if parent not in cfg_keys and cfg.strict:
|
||||
continue
|
||||
if cfg[parent] is None:
|
||||
cfg[parent] = {}
|
||||
if not isinstance(cfg[parent], dict):
|
||||
LOG.warning(
|
||||
"Overwriting non-dict value for '%s' with nested CLI overrides", parent
|
||||
)
|
||||
cfg[parent] = {}
|
||||
for child_key, child_value in children.items():
|
||||
existing_child = cfg[parent].get(child_key)
|
||||
cfg[parent][child_key] = _coerce_value(child_value, existing_child)
|
||||
|
||||
try:
|
||||
device_props = torch.cuda.get_device_properties("cuda")
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import dataclasses
|
||||
from functools import wraps
|
||||
from types import NoneType
|
||||
from types import NoneType, UnionType
|
||||
from typing import Any, Callable, Type, Union, get_args, get_origin
|
||||
|
||||
import click
|
||||
@@ -20,7 +20,8 @@ def _strip_optional_type(field_type: type | str | None):
|
||||
If the input type is `Union[T, None]` or `Optional[T]`, returns `T`. Otherwise
|
||||
returns the input type unchanged.
|
||||
"""
|
||||
if get_origin(field_type) is Union and type(None) in get_args(field_type):
|
||||
is_union = get_origin(field_type) is Union or isinstance(field_type, UnionType)
|
||||
if is_union and type(None) in get_args(field_type):
|
||||
field_type = next(
|
||||
t for t in get_args(field_type) if not isinstance(t, NoneType)
|
||||
)
|
||||
@@ -87,10 +88,70 @@ def add_options_from_dataclass(config_class: Type[Any]) -> Callable:
|
||||
return decorator
|
||||
|
||||
|
||||
def _is_pydantic_model(field_type: type) -> bool:
|
||||
"""Check if a type is a Pydantic BaseModel subclass."""
|
||||
try:
|
||||
return isinstance(field_type, type) and issubclass(field_type, BaseModel)
|
||||
except TypeError:
|
||||
return False
|
||||
|
||||
|
||||
def _get_field_description(field) -> str | None:
|
||||
"""Get description from a Pydantic field, checking both .description and json_schema_extra."""
|
||||
if field.description:
|
||||
return field.description
|
||||
if field.json_schema_extra and isinstance(field.json_schema_extra, dict):
|
||||
return field.json_schema_extra.get("description")
|
||||
return None
|
||||
|
||||
|
||||
def _add_nested_model_options(
|
||||
function: Callable, parent_name: str, model_class: Type[BaseModel]
|
||||
) -> Callable:
|
||||
"""
|
||||
Add Click options for all fields of a nested Pydantic model using dot-notation.
|
||||
|
||||
Note: Only single-level nesting is supported (e.g., ``--trl.beta``).
|
||||
Deeper nesting (e.g., ``--trl.scheduler.warmup``) is not handled.
|
||||
|
||||
Args:
|
||||
function: Click command function to add options to.
|
||||
parent_name: Parent field name (e.g., "trl").
|
||||
model_class: Nested Pydantic model class.
|
||||
|
||||
Returns:
|
||||
Function with added Click options.
|
||||
"""
|
||||
for sub_name, sub_field in reversed(model_class.model_fields.items()):
|
||||
sub_type = _strip_optional_type(sub_field.annotation)
|
||||
# Use dot notation: --parent.sub_field
|
||||
cli_name = f"{parent_name}.{sub_name}".replace("_", "-")
|
||||
# The kwarg name uses double-underscore as separator
|
||||
param_name = f"{parent_name}__{sub_name}"
|
||||
description = _get_field_description(sub_field)
|
||||
|
||||
if sub_type is bool:
|
||||
option_name = f"--{cli_name}/--no-{cli_name}"
|
||||
function = click.option(
|
||||
option_name, param_name, default=None, help=description
|
||||
)(function)
|
||||
else:
|
||||
option_name = f"--{cli_name}"
|
||||
click_type = {str: str, int: int, float: float}.get(sub_type)
|
||||
function = click.option(
|
||||
option_name, param_name, default=None, type=click_type, help=description
|
||||
)(function)
|
||||
|
||||
return function
|
||||
|
||||
|
||||
def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
|
||||
"""
|
||||
Create Click options from the fields of a Pydantic model.
|
||||
|
||||
For fields whose type is itself a Pydantic BaseModel, dot-notation CLI options are
|
||||
generated for each sub-field (e.g., ``--trl.beta=0.1``).
|
||||
|
||||
Args:
|
||||
config_class: PyDantic model with fields to parse from the CLI
|
||||
|
||||
@@ -103,6 +164,11 @@ def add_options_from_config(config_class: Type[BaseModel]) -> Callable:
|
||||
for name, field in reversed(config_class.model_fields.items()):
|
||||
field_type = _strip_optional_type(field.annotation)
|
||||
|
||||
# Handle nested Pydantic models with dot-notation options
|
||||
if _is_pydantic_model(field_type):
|
||||
function = _add_nested_model_options(function, name, field_type)
|
||||
continue
|
||||
|
||||
if field_type is bool:
|
||||
field_name = name.replace("_", "-")
|
||||
option_name = f"--{field_name}/--no-{field_name}"
|
||||
|
||||
Reference in New Issue
Block a user