Pydantic 2.x cfg (#1239)
* WIP conversion to use pydantic for config validation * wip, more fields, add capabilities * wip * update pydantic validation to match existing tests * tweak requirements * setup deprecated paams pydantic model * more validations * wrap up rest of the validations * flesh out the rest of the options from the readme into pydantic * fix model validators as class methods remember to return in validator missing return add missing relora attributes fix test for DictDefault change fix sys template for mistral from fastchat change in PR 2872 fix test for batch size warning * more missing attributes for cfg * updates from PR feedback * fix validation for datasets and pretrain datasets * fix test for lora check
This commit is contained in:
@@ -24,11 +24,13 @@ from art import text2art
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.utils import LocalTokenNotFoundError
|
||||
from transformers import GenerationConfig, TextIteratorStreamer, TextStreamer
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
|
||||
from axolotl.common.cli import TrainerCliArgs, load_model_and_tokenizer
|
||||
from axolotl.logging_config import configure_logging
|
||||
from axolotl.train import TrainDatasetMeta
|
||||
from axolotl.utils.config import (
|
||||
GPUCapabilities,
|
||||
normalize_cfg_datasets,
|
||||
normalize_config,
|
||||
validate_config,
|
||||
@@ -328,7 +330,6 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||
# load the config from the yaml file
|
||||
with open(config, encoding="utf-8") as file:
|
||||
cfg: DictDefault = DictDefault(yaml.safe_load(file))
|
||||
cfg.axolotl_config_path = config
|
||||
# if there are any options passed in the cli, if it is something that seems valid from the yaml,
|
||||
# then overwrite the value
|
||||
cfg_keys = cfg.keys()
|
||||
@@ -341,7 +342,21 @@ def load_cfg(config: Union[str, Path] = Path("examples/"), **kwargs):
|
||||
else:
|
||||
cfg[k] = kwargs[k]
|
||||
|
||||
validate_config(cfg)
|
||||
cfg.axolotl_config_path = config
|
||||
|
||||
try:
|
||||
device_props = torch.cuda.get_device_properties("cuda")
|
||||
gpu_version = "sm_" + str(device_props.major) + str(device_props.minor)
|
||||
except: # pylint: disable=bare-except # noqa: E722
|
||||
gpu_version = None
|
||||
|
||||
capabilities = GPUCapabilities(
|
||||
bf16=is_torch_bf16_gpu_available(),
|
||||
n_gpu=os.environ.get("WORLD_SIZE", 1),
|
||||
compute_capability=gpu_version,
|
||||
)
|
||||
|
||||
cfg = validate_config(cfg, capabilities=capabilities)
|
||||
|
||||
prepare_optim_env(cfg)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user