Fix for protected model_ namespace with pydantic (#1345)
This commit is contained in:
@@ -546,7 +546,7 @@ base_model_ignore_patterns:
|
|||||||
# You can set that here, or leave this empty to default to base_model
|
# You can set that here, or leave this empty to default to base_model
|
||||||
base_model_config: ./llama-7b-hf
|
base_model_config: ./llama-7b-hf
|
||||||
# You can specify to choose a specific model revision from huggingface hub
|
# You can specify to choose a specific model revision from huggingface hub
|
||||||
model_revision:
|
revision_of_model:
|
||||||
# Optional tokenizer configuration path in case you want to use a different tokenizer
|
# Optional tokenizer configuration path in case you want to use a different tokenizer
|
||||||
# than the one defined in the base model
|
# than the one defined in the base model
|
||||||
tokenizer_config:
|
tokenizer_config:
|
||||||
@@ -573,7 +573,7 @@ is_qwen_derived_model:
|
|||||||
is_mistral_derived_model:
|
is_mistral_derived_model:
|
||||||
|
|
||||||
# optional overrides to the base model configuration
|
# optional overrides to the base model configuration
|
||||||
model_config_overrides:
|
overrides_of_model_config:
|
||||||
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653
|
||||||
rope_scaling:
|
rope_scaling:
|
||||||
type: # linear | dynamic
|
type: # linear | dynamic
|
||||||
|
|||||||
@@ -124,7 +124,7 @@ def normalize_config(cfg):
|
|||||||
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
|
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
|
||||||
or cfg.is_llama_derived_model
|
or cfg.is_llama_derived_model
|
||||||
or "llama" in cfg.base_model.lower()
|
or "llama" in cfg.base_model.lower()
|
||||||
or (cfg.model_type and "llama" in cfg.model_type.lower())
|
or (cfg.type_of_model and "llama" in cfg.type_of_model.lower())
|
||||||
)
|
)
|
||||||
|
|
||||||
# figure out if the model is falcon
|
# figure out if the model is falcon
|
||||||
@@ -140,7 +140,7 @@ def normalize_config(cfg):
|
|||||||
)
|
)
|
||||||
or cfg.is_falcon_derived_model
|
or cfg.is_falcon_derived_model
|
||||||
or "falcon" in cfg.base_model.lower()
|
or "falcon" in cfg.base_model.lower()
|
||||||
or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
|
or (cfg.type_of_model and "rwforcausallm" in cfg.type_of_model.lower())
|
||||||
)
|
)
|
||||||
|
|
||||||
cfg.is_mistral_derived_model = (
|
cfg.is_mistral_derived_model = (
|
||||||
@@ -153,7 +153,7 @@ def normalize_config(cfg):
|
|||||||
)
|
)
|
||||||
or cfg.is_mistral_derived_model
|
or cfg.is_mistral_derived_model
|
||||||
or "mistral" in cfg.base_model.lower().split("/")[-1]
|
or "mistral" in cfg.base_model.lower().split("/")[-1]
|
||||||
or (cfg.model_type and "mistral" in cfg.model_type.lower())
|
or (cfg.type_of_model and "mistral" in cfg.type_of_model.lower())
|
||||||
)
|
)
|
||||||
|
|
||||||
cfg.is_qwen_derived_model = (
|
cfg.is_qwen_derived_model = (
|
||||||
@@ -379,11 +379,11 @@ def legacy_validate_config(cfg):
|
|||||||
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
|
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
|
||||||
)
|
)
|
||||||
|
|
||||||
if cfg.gptq and cfg.model_revision:
|
if cfg.gptq and cfg.revision_of_model:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"model_revision is not supported for GPTQ models. "
|
"revision_of_model is not supported for GPTQ models. "
|
||||||
+ "Please download the model from HuggingFace Hub manually for correct branch, "
|
+ "Please download the model from HuggingFace Hub manually for correct branch, "
|
||||||
+ "point to its path, and remove model_revision from the config."
|
+ "point to its path, and remove revision_of_model from the config."
|
||||||
)
|
)
|
||||||
|
|
||||||
# if cfg.sample_packing and cfg.sdp_attention:
|
# if cfg.sample_packing and cfg.sdp_attention:
|
||||||
|
|||||||
@@ -47,6 +47,16 @@ class DeprecatedParameters(BaseModel):
|
|||||||
return noisy_embedding_alpha
|
return noisy_embedding_alpha
|
||||||
|
|
||||||
|
|
||||||
|
class RemappedParameters(BaseModel):
|
||||||
|
"""parameters that have been remapped to other names"""
|
||||||
|
|
||||||
|
overrides_of_model_config: Optional[Dict[str, Any]] = Field(
|
||||||
|
default=None, alias="model_config"
|
||||||
|
)
|
||||||
|
type_of_model: Optional[str] = Field(default=None, alias="model_type")
|
||||||
|
revision_of_model: Optional[str] = Field(default=None, alias="model_revision")
|
||||||
|
|
||||||
|
|
||||||
class PretrainingDataset(BaseModel):
|
class PretrainingDataset(BaseModel):
|
||||||
"""pretraining dataset configuration subset"""
|
"""pretraining dataset configuration subset"""
|
||||||
|
|
||||||
@@ -234,12 +244,8 @@ class ModelInputConfig(BaseModel):
|
|||||||
tokenizer_type: Optional[str] = Field(
|
tokenizer_type: Optional[str] = Field(
|
||||||
default=None, metadata={"help": "transformers tokenizer class"}
|
default=None, metadata={"help": "transformers tokenizer class"}
|
||||||
)
|
)
|
||||||
model_type: Optional[str] = Field(default=None)
|
|
||||||
model_revision: Optional[str] = None
|
|
||||||
trust_remote_code: Optional[bool] = None
|
trust_remote_code: Optional[bool] = None
|
||||||
|
|
||||||
model_config_overrides: Optional[Dict[str, Any]] = None
|
|
||||||
|
|
||||||
@field_validator("trust_remote_code")
|
@field_validator("trust_remote_code")
|
||||||
@classmethod
|
@classmethod
|
||||||
def hint_trust_remote_code(cls, trust_remote_code):
|
def hint_trust_remote_code(cls, trust_remote_code):
|
||||||
@@ -362,11 +368,17 @@ class AxolotlInputConfig(
|
|||||||
HyperparametersConfig,
|
HyperparametersConfig,
|
||||||
WandbConfig,
|
WandbConfig,
|
||||||
MLFlowConfig,
|
MLFlowConfig,
|
||||||
|
RemappedParameters,
|
||||||
DeprecatedParameters,
|
DeprecatedParameters,
|
||||||
BaseModel,
|
BaseModel,
|
||||||
):
|
):
|
||||||
"""wrapper of all config options"""
|
"""wrapper of all config options"""
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Config for alias"""
|
||||||
|
|
||||||
|
populate_by_name = True
|
||||||
|
|
||||||
strict: Optional[bool] = Field(default=False)
|
strict: Optional[bool] = Field(default=False)
|
||||||
resume_from_checkpoint: Optional[str] = None
|
resume_from_checkpoint: Optional[str] = None
|
||||||
auto_resume_from_checkpoints: Optional[bool] = None
|
auto_resume_from_checkpoints: Optional[bool] = None
|
||||||
@@ -550,11 +562,11 @@ class AxolotlInputConfig(
|
|||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def check_gptq_w_revision(cls, data):
|
def check_gptq_w_revision(cls, data):
|
||||||
if data.get("gptq") and data.get("model_revision"):
|
if data.get("gptq") and data.get("revision_of_model"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"model_revision is not supported for GPTQ models. "
|
"revision_of_model is not supported for GPTQ models. "
|
||||||
+ "Please download the model from HuggingFace Hub manually for correct branch, "
|
+ "Please download the model from HuggingFace Hub manually for correct branch, "
|
||||||
+ "point to its path, and remove model_revision from the config."
|
+ "point to its path, and remove revision_of_model from the config."
|
||||||
)
|
)
|
||||||
return data
|
return data
|
||||||
|
|
||||||
|
|||||||
@@ -86,8 +86,8 @@ def load_model_config(cfg):
|
|||||||
model_config_name = cfg.tokenizer_config
|
model_config_name = cfg.tokenizer_config
|
||||||
trust_remote_code = cfg.trust_remote_code is True
|
trust_remote_code = cfg.trust_remote_code is True
|
||||||
config_kwargs = {}
|
config_kwargs = {}
|
||||||
if cfg.model_revision:
|
if cfg.revision_of_model:
|
||||||
config_kwargs["revision"] = cfg.model_revision
|
config_kwargs["revision"] = cfg.revision_of_model
|
||||||
|
|
||||||
try:
|
try:
|
||||||
model_config = AutoConfig.from_pretrained(
|
model_config = AutoConfig.from_pretrained(
|
||||||
@@ -104,8 +104,8 @@ def load_model_config(cfg):
|
|||||||
)
|
)
|
||||||
raise err
|
raise err
|
||||||
|
|
||||||
if cfg.model_config_overrides:
|
if cfg.overrides_of_model_config:
|
||||||
for key, val in cfg.model_config_overrides.items():
|
for key, val in cfg.overrides_of_model_config.items():
|
||||||
setattr(model_config, key, val)
|
setattr(model_config, key, val)
|
||||||
|
|
||||||
check_model_config(cfg, model_config)
|
check_model_config(cfg, model_config)
|
||||||
@@ -272,7 +272,7 @@ def load_model(
|
|||||||
Load a model for a given configuration and tokenizer.
|
Load a model for a given configuration and tokenizer.
|
||||||
"""
|
"""
|
||||||
base_model = cfg.base_model
|
base_model = cfg.base_model
|
||||||
model_type = cfg.model_type
|
model_type = cfg.type_of_model
|
||||||
model_config = load_model_config(cfg)
|
model_config = load_model_config(cfg)
|
||||||
|
|
||||||
# TODO refactor as a kwarg
|
# TODO refactor as a kwarg
|
||||||
@@ -426,8 +426,8 @@ def load_model(
|
|||||||
if is_deepspeed_zero3_enabled():
|
if is_deepspeed_zero3_enabled():
|
||||||
del model_kwargs["device_map"]
|
del model_kwargs["device_map"]
|
||||||
|
|
||||||
if cfg.model_revision:
|
if cfg.revision_of_model:
|
||||||
model_kwargs["revision"] = cfg.model_revision
|
model_kwargs["revision"] = cfg.revision_of_model
|
||||||
if cfg.gptq:
|
if cfg.gptq:
|
||||||
if not hasattr(model_config, "quantization_config"):
|
if not hasattr(model_config, "quantization_config"):
|
||||||
LOG.warning("model config does not contain quantization_config information")
|
LOG.warning("model config does not contain quantization_config information")
|
||||||
|
|||||||
@@ -3,6 +3,7 @@
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import warnings
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -14,6 +15,8 @@ from axolotl.utils.dict import DictDefault
|
|||||||
from axolotl.utils.models import check_model_config
|
from axolotl.utils.models import check_model_config
|
||||||
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
from axolotl.utils.wandb_ import setup_wandb_env_vars
|
||||||
|
|
||||||
|
warnings.filterwarnings("error")
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(name="minimal_cfg")
|
@pytest.fixture(name="minimal_cfg")
|
||||||
def fixture_cfg():
|
def fixture_cfg():
|
||||||
@@ -190,6 +193,45 @@ class TestValidation(BaseValidation):
|
|||||||
|
|
||||||
assert new_cfg.learning_rate == 0.00005
|
assert new_cfg.learning_rate == 0.00005
|
||||||
|
|
||||||
|
def test_model_config_remap(self, minimal_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
{
|
||||||
|
"model_config": {"model_type": "mistral"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
| minimal_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
new_cfg = validate_config(cfg)
|
||||||
|
assert new_cfg.overrides_of_model_config["model_type"] == "mistral"
|
||||||
|
|
||||||
|
def test_model_type_remap(self, minimal_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
{
|
||||||
|
"model_type": "AutoModelForCausalLM",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
| minimal_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
new_cfg = validate_config(cfg)
|
||||||
|
assert new_cfg.type_of_model == "AutoModelForCausalLM"
|
||||||
|
|
||||||
|
def test_model_revision_remap(self, minimal_cfg):
|
||||||
|
cfg = (
|
||||||
|
DictDefault(
|
||||||
|
{
|
||||||
|
"model_revision": "main",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
| minimal_cfg
|
||||||
|
)
|
||||||
|
|
||||||
|
new_cfg = validate_config(cfg)
|
||||||
|
assert new_cfg.revision_of_model == "main"
|
||||||
|
|
||||||
def test_qlora(self, minimal_cfg):
|
def test_qlora(self, minimal_cfg):
|
||||||
base_cfg = (
|
base_cfg = (
|
||||||
DictDefault(
|
DictDefault(
|
||||||
|
|||||||
Reference in New Issue
Block a user