fix for protected model_ namespace with pydantic (#1345)
This commit is contained in:
@@ -546,7 +546,7 @@ base_model_ignore_patterns:
|
||||
# You can set that here, or leave this empty to default to base_model
|
||||
base_model_config: ./llama-7b-hf
|
||||
# You can specify a particular model revision from the Hugging Face Hub
|
||||
model_revision:
|
||||
revision_of_model:
|
||||
# Optional tokenizer configuration path in case you want to use a different tokenizer
|
||||
# than the one defined in the base model
|
||||
tokenizer_config:
|
||||
@@ -573,7 +573,7 @@ is_qwen_derived_model:
|
||||
is_mistral_derived_model:
|
||||
|
||||
# optional overrides to the base model configuration
|
||||
model_config_overrides:
|
||||
overrides_of_model_config:
|
||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||
rope_scaling:
|
||||
type: # linear | dynamic
|
||||
|
||||
@@ -124,7 +124,7 @@ def normalize_config(cfg):
|
||||
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
|
||||
or cfg.is_llama_derived_model
|
||||
or "llama" in cfg.base_model.lower()
|
||||
or (cfg.model_type and "llama" in cfg.model_type.lower())
|
||||
or (cfg.type_of_model and "llama" in cfg.type_of_model.lower())
|
||||
)
|
||||
|
||||
# figure out if the model is falcon
|
||||
@@ -140,7 +140,7 @@ def normalize_config(cfg):
|
||||
)
|
||||
or cfg.is_falcon_derived_model
|
||||
or "falcon" in cfg.base_model.lower()
|
||||
or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
|
||||
or (cfg.type_of_model and "rwforcausallm" in cfg.type_of_model.lower())
|
||||
)
|
||||
|
||||
cfg.is_mistral_derived_model = (
|
||||
@@ -153,7 +153,7 @@ def normalize_config(cfg):
|
||||
)
|
||||
or cfg.is_mistral_derived_model
|
||||
or "mistral" in cfg.base_model.lower().split("/")[-1]
|
||||
or (cfg.model_type and "mistral" in cfg.model_type.lower())
|
||||
or (cfg.type_of_model and "mistral" in cfg.type_of_model.lower())
|
||||
)
|
||||
|
||||
cfg.is_qwen_derived_model = (
|
||||
@@ -379,11 +379,11 @@ def legacy_validate_config(cfg):
|
||||
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
|
||||
)
|
||||
|
||||
if cfg.gptq and cfg.model_revision:
|
||||
if cfg.gptq and cfg.revision_of_model:
|
||||
raise ValueError(
|
||||
"model_revision is not supported for GPTQ models. "
|
||||
"revision_of_model is not supported for GPTQ models. "
|
||||
+ "Please download the model from HuggingFace Hub manually for correct branch, "
|
||||
+ "point to its path, and remove model_revision from the config."
|
||||
+ "point to its path, and remove revision_of_model from the config."
|
||||
)
|
||||
|
||||
# if cfg.sample_packing and cfg.sdp_attention:
|
||||
|
||||
@@ -47,6 +47,16 @@ class DeprecatedParameters(BaseModel):
|
||||
return noisy_embedding_alpha
|
||||
|
||||
|
||||
class RemappedParameters(BaseModel):
|
||||
"""parameters that have been remapped to other names"""
|
||||
|
||||
overrides_of_model_config: Optional[Dict[str, Any]] = Field(
|
||||
default=None, alias="model_config"
|
||||
)
|
||||
type_of_model: Optional[str] = Field(default=None, alias="model_type")
|
||||
revision_of_model: Optional[str] = Field(default=None, alias="model_revision")
|
||||
|
||||
|
||||
class PretrainingDataset(BaseModel):
|
||||
"""pretraining dataset configuration subset"""
|
||||
|
||||
@@ -234,12 +244,8 @@ class ModelInputConfig(BaseModel):
|
||||
tokenizer_type: Optional[str] = Field(
|
||||
default=None, metadata={"help": "transformers tokenizer class"}
|
||||
)
|
||||
model_type: Optional[str] = Field(default=None)
|
||||
model_revision: Optional[str] = None
|
||||
trust_remote_code: Optional[bool] = None
|
||||
|
||||
model_config_overrides: Optional[Dict[str, Any]] = None
|
||||
|
||||
@field_validator("trust_remote_code")
|
||||
@classmethod
|
||||
def hint_trust_remote_code(cls, trust_remote_code):
|
||||
@@ -362,11 +368,17 @@ class AxolotlInputConfig(
|
||||
HyperparametersConfig,
|
||||
WandbConfig,
|
||||
MLFlowConfig,
|
||||
RemappedParameters,
|
||||
DeprecatedParameters,
|
||||
BaseModel,
|
||||
):
|
||||
"""wrapper of all config options"""
|
||||
|
||||
class Config:
|
||||
"""Config for alias"""
|
||||
|
||||
populate_by_name = True
|
||||
|
||||
strict: Optional[bool] = Field(default=False)
|
||||
resume_from_checkpoint: Optional[str] = None
|
||||
auto_resume_from_checkpoints: Optional[bool] = None
|
||||
@@ -550,11 +562,11 @@ class AxolotlInputConfig(
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_gptq_w_revision(cls, data):
|
||||
if data.get("gptq") and data.get("model_revision"):
|
||||
if data.get("gptq") and data.get("revision_of_model"):
|
||||
raise ValueError(
|
||||
"model_revision is not supported for GPTQ models. "
|
||||
"revision_of_model is not supported for GPTQ models. "
|
||||
+ "Please download the model from HuggingFace Hub manually for correct branch, "
|
||||
+ "point to its path, and remove model_revision from the config."
|
||||
+ "point to its path, and remove revision_of_model from the config."
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@@ -86,8 +86,8 @@ def load_model_config(cfg):
|
||||
model_config_name = cfg.tokenizer_config
|
||||
trust_remote_code = cfg.trust_remote_code is True
|
||||
config_kwargs = {}
|
||||
if cfg.model_revision:
|
||||
config_kwargs["revision"] = cfg.model_revision
|
||||
if cfg.revision_of_model:
|
||||
config_kwargs["revision"] = cfg.revision_of_model
|
||||
|
||||
try:
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
@@ -104,8 +104,8 @@ def load_model_config(cfg):
|
||||
)
|
||||
raise err
|
||||
|
||||
if cfg.model_config_overrides:
|
||||
for key, val in cfg.model_config_overrides.items():
|
||||
if cfg.overrides_of_model_config:
|
||||
for key, val in cfg.overrides_of_model_config.items():
|
||||
setattr(model_config, key, val)
|
||||
|
||||
check_model_config(cfg, model_config)
|
||||
@@ -272,7 +272,7 @@ def load_model(
|
||||
Load a model for a given configuration and tokenizer.
|
||||
"""
|
||||
base_model = cfg.base_model
|
||||
model_type = cfg.model_type
|
||||
model_type = cfg.type_of_model
|
||||
model_config = load_model_config(cfg)
|
||||
|
||||
# TODO refactor as a kwarg
|
||||
@@ -426,8 +426,8 @@ def load_model(
|
||||
if is_deepspeed_zero3_enabled():
|
||||
del model_kwargs["device_map"]
|
||||
|
||||
if cfg.model_revision:
|
||||
model_kwargs["revision"] = cfg.model_revision
|
||||
if cfg.revision_of_model:
|
||||
model_kwargs["revision"] = cfg.revision_of_model
|
||||
if cfg.gptq:
|
||||
if not hasattr(model_config, "quantization_config"):
|
||||
LOG.warning("model config does not contain quantization_config information")
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from typing import Optional
|
||||
|
||||
import pytest
|
||||
@@ -14,6 +15,8 @@ from axolotl.utils.dict import DictDefault
|
||||
from axolotl.utils.models import check_model_config
|
||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||
|
||||
warnings.filterwarnings("error")
|
||||
|
||||
|
||||
@pytest.fixture(name="minimal_cfg")
|
||||
def fixture_cfg():
|
||||
@@ -190,6 +193,45 @@ class TestValidation(BaseValidation):
|
||||
|
||||
assert new_cfg.learning_rate == 0.00005
|
||||
|
||||
def test_model_config_remap(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"model_config": {"model_type": "mistral"},
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg.overrides_of_model_config["model_type"] == "mistral"
|
||||
|
||||
def test_model_type_remap(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"model_type": "AutoModelForCausalLM",
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg.type_of_model == "AutoModelForCausalLM"
|
||||
|
||||
def test_model_revision_remap(self, minimal_cfg):
|
||||
cfg = (
|
||||
DictDefault(
|
||||
{
|
||||
"model_revision": "main",
|
||||
}
|
||||
)
|
||||
| minimal_cfg
|
||||
)
|
||||
|
||||
new_cfg = validate_config(cfg)
|
||||
assert new_cfg.revision_of_model == "main"
|
||||
|
||||
def test_qlora(self, minimal_cfg):
|
||||
base_cfg = (
|
||||
DictDefault(
|
||||
|
||||
Reference in New Issue
Block a user