From 6b3b271925b2b0f0c98a33cebdc90788e31ffc29 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 28 Feb 2024 15:07:49 -0500 Subject: [PATCH] fix for protected model_ namespace w pydantic (#1345) --- README.md | 4 +- src/axolotl/utils/config/__init__.py | 12 +++--- .../config/models/input/v0_4_1/__init__.py | 26 ++++++++---- src/axolotl/utils/models.py | 14 +++---- tests/test_validation.py | 42 +++++++++++++++++++ 5 files changed, 76 insertions(+), 22 deletions(-) diff --git a/README.md b/README.md index 2d03968fd..4bd496423 100644 --- a/README.md +++ b/README.md @@ -546,7 +546,7 @@ base_model_ignore_patterns: # You can set that here, or leave this empty to default to base_model base_model_config: ./llama-7b-hf # You can specify to choose a specific model revision from huggingface hub -model_revision: +revision_of_model: # Optional tokenizer configuration path in case you want to use a different tokenizer # than the one defined in the base model tokenizer_config: @@ -573,7 +573,7 @@ is_qwen_derived_model: is_mistral_derived_model: # optional overrides to the base model configuration -model_config_overrides: +overrides_of_model_config: # RoPE Scaling https://github.com/huggingface/transformers/pull/24653 rope_scaling: type: # linear | dynamic diff --git a/src/axolotl/utils/config/__init__.py b/src/axolotl/utils/config/__init__.py index 99ce27321..9151f288a 100644 --- a/src/axolotl/utils/config/__init__.py +++ b/src/axolotl/utils/config/__init__.py @@ -124,7 +124,7 @@ def normalize_config(cfg): (hasattr(model_config, "model_type") and model_config.model_type == "llama") or cfg.is_llama_derived_model or "llama" in cfg.base_model.lower() - or (cfg.model_type and "llama" in cfg.model_type.lower()) + or (cfg.type_of_model and "llama" in cfg.type_of_model.lower()) ) # figure out if the model is falcon @@ -140,7 +140,7 @@ def normalize_config(cfg): ) or cfg.is_falcon_derived_model or "falcon" in cfg.base_model.lower() - or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower()) + or (cfg.type_of_model and "rwforcausallm" in cfg.type_of_model.lower()) ) cfg.is_mistral_derived_model = ( @@ -153,7 +153,7 @@ def normalize_config(cfg): ) or cfg.is_mistral_derived_model or "mistral" in cfg.base_model.lower().split("/")[-1] - or (cfg.model_type and "mistral" in cfg.model_type.lower()) + or (cfg.type_of_model and "mistral" in cfg.type_of_model.lower()) ) cfg.is_qwen_derived_model = ( @@ -379,11 +379,11 @@ def legacy_validate_config(cfg): "hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch." ) - if cfg.gptq and cfg.model_revision: + if cfg.gptq and cfg.revision_of_model: raise ValueError( - "model_revision is not supported for GPTQ models. " + "revision_of_model is not supported for GPTQ models. " + "Please download the model from HuggingFace Hub manually for correct branch, " - + "point to its path, and remove model_revision from the config." + + "point to its path, and remove revision_of_model from the config." ) # if cfg.sample_packing and cfg.sdp_attention: diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py index 262e8928a..6eaf34c54 100644 --- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py +++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py @@ -47,6 +47,16 @@ class DeprecatedParameters(BaseModel): return noisy_embedding_alpha +class RemappedParameters(BaseModel): + """parameters that have been remapped to other names""" + + overrides_of_model_config: Optional[Dict[str, Any]] = Field( + default=None, alias="model_config" + ) + type_of_model: Optional[str] = Field(default=None, alias="model_type") + revision_of_model: Optional[str] = Field(default=None, alias="model_revision") + + class PretrainingDataset(BaseModel): """pretraining dataset configuration subset""" @@ -234,12 +244,8 @@ class ModelInputConfig(BaseModel): tokenizer_type: Optional[str] = Field( default=None, metadata={"help": "transformers tokenizer class"} ) - model_type: Optional[str] = Field(default=None) - model_revision: Optional[str] = None trust_remote_code: Optional[bool] = None - model_config_overrides: Optional[Dict[str, Any]] = None - @field_validator("trust_remote_code") @classmethod def hint_trust_remote_code(cls, trust_remote_code): @@ -362,11 +368,17 @@ class AxolotlInputConfig( HyperparametersConfig, WandbConfig, MLFlowConfig, + RemappedParameters, DeprecatedParameters, BaseModel, ): """wrapper of all config options""" + class Config: + """Config for alias""" + + populate_by_name = True + strict: Optional[bool] = Field(default=False) resume_from_checkpoint: Optional[str] = None auto_resume_from_checkpoints: Optional[bool] = None @@ -550,11 +562,11 @@ class AxolotlInputConfig( @model_validator(mode="before") @classmethod def check_gptq_w_revision(cls, data): - if data.get("gptq") and data.get("model_revision"): + if data.get("gptq") and data.get("revision_of_model"): raise ValueError( - "model_revision is not supported for GPTQ models. " + "revision_of_model is not supported for GPTQ models. " + "Please download the model from HuggingFace Hub manually for correct branch, " - + "point to its path, and remove model_revision from the config." + + "point to its path, and remove revision_of_model from the config." ) return data diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index c94908f3d..aa2e9539b 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -86,8 +86,8 @@ def load_model_config(cfg): model_config_name = cfg.tokenizer_config trust_remote_code = cfg.trust_remote_code is True config_kwargs = {} - if cfg.model_revision: - config_kwargs["revision"] = cfg.model_revision + if cfg.revision_of_model: + config_kwargs["revision"] = cfg.revision_of_model try: model_config = AutoConfig.from_pretrained( @@ -104,8 +104,8 @@ def load_model_config(cfg): ) raise err - if cfg.model_config_overrides: - for key, val in cfg.model_config_overrides.items(): + if cfg.overrides_of_model_config: + for key, val in cfg.overrides_of_model_config.items(): setattr(model_config, key, val) check_model_config(cfg, model_config) @@ -272,7 +272,7 @@ def load_model( Load a model for a given configuration and tokenizer. """ base_model = cfg.base_model - model_type = cfg.model_type + model_type = cfg.type_of_model model_config = load_model_config(cfg) # TODO refactor as a kwarg @@ -426,8 +426,8 @@ def load_model( if is_deepspeed_zero3_enabled(): del model_kwargs["device_map"] - if cfg.model_revision: - model_kwargs["revision"] = cfg.model_revision + if cfg.revision_of_model: + model_kwargs["revision"] = cfg.revision_of_model if cfg.gptq: if not hasattr(model_config, "quantization_config"): LOG.warning("model config does not contain quantization_config information") diff --git a/tests/test_validation.py b/tests/test_validation.py index 790a4b171..70dbc750e 100644 --- a/tests/test_validation.py +++ b/tests/test_validation.py @@ -3,6 +3,7 @@ import logging import os +import warnings from typing import Optional import pytest @@ -14,6 +15,8 @@ from axolotl.utils.dict import DictDefault from axolotl.utils.models import check_model_config from axolotl.utils.wandb_ import setup_wandb_env_vars +warnings.filterwarnings("error") + @pytest.fixture(name="minimal_cfg") def fixture_cfg(): @@ -190,6 +193,45 @@ class TestValidation(BaseValidation): assert new_cfg.learning_rate == 0.00005 + def test_model_config_remap(self, minimal_cfg): + cfg = ( + DictDefault( + { + "model_config": {"model_type": "mistral"}, + } + ) + | minimal_cfg + ) + + new_cfg = validate_config(cfg) + assert new_cfg.overrides_of_model_config["model_type"] == "mistral" + + def test_model_type_remap(self, minimal_cfg): + cfg = ( + DictDefault( + { + "model_type": "AutoModelForCausalLM", + } + ) + | minimal_cfg + ) + + new_cfg = validate_config(cfg) + assert new_cfg.type_of_model == "AutoModelForCausalLM" + + def test_model_revision_remap(self, minimal_cfg): + cfg = ( + DictDefault( + { + "model_revision": "main", + } + ) + | minimal_cfg + ) + + new_cfg = validate_config(cfg) + assert new_cfg.revision_of_model == "main" + def test_qlora(self, minimal_cfg): base_cfg = ( DictDefault(