fix for protected model_ namespace w pydantic (#1345)

This commit is contained in:
Wing Lian
2024-02-28 15:07:49 -05:00
committed by GitHub
parent 3a5a2d2f34
commit 6b3b271925
5 changed files with 76 additions and 22 deletions

View File

@@ -546,7 +546,7 @@ base_model_ignore_patterns:
# You can set that here, or leave this empty to default to base_model # You can set that here, or leave this empty to default to base_model
base_model_config: ./llama-7b-hf base_model_config: ./llama-7b-hf
# You can specify to choose a specific model revision from huggingface hub # You can specify to choose a specific model revision from huggingface hub
model_revision: revision_of_model:
# Optional tokenizer configuration path in case you want to use a different tokenizer # Optional tokenizer configuration path in case you want to use a different tokenizer
# than the one defined in the base model # than the one defined in the base model
tokenizer_config: tokenizer_config:
@@ -573,7 +573,7 @@ is_qwen_derived_model:
is_mistral_derived_model: is_mistral_derived_model:
# optional overrides to the base model configuration # optional overrides to the base model configuration
model_config_overrides: overrides_of_model_config:
# RoPE Scaling https://github.com/huggingface/transformers/pull/24653 # RoPE Scaling https://github.com/huggingface/transformers/pull/24653
rope_scaling: rope_scaling:
type: # linear | dynamic type: # linear | dynamic

View File

@@ -124,7 +124,7 @@ def normalize_config(cfg):
(hasattr(model_config, "model_type") and model_config.model_type == "llama") (hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model or cfg.is_llama_derived_model
or "llama" in cfg.base_model.lower() or "llama" in cfg.base_model.lower()
or (cfg.model_type and "llama" in cfg.model_type.lower()) or (cfg.type_of_model and "llama" in cfg.type_of_model.lower())
) )
# figure out if the model is falcon # figure out if the model is falcon
@@ -140,7 +140,7 @@ def normalize_config(cfg):
) )
or cfg.is_falcon_derived_model or cfg.is_falcon_derived_model
or "falcon" in cfg.base_model.lower() or "falcon" in cfg.base_model.lower()
or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower()) or (cfg.type_of_model and "rwforcausallm" in cfg.type_of_model.lower())
) )
cfg.is_mistral_derived_model = ( cfg.is_mistral_derived_model = (
@@ -153,7 +153,7 @@ def normalize_config(cfg):
) )
or cfg.is_mistral_derived_model or cfg.is_mistral_derived_model
or "mistral" in cfg.base_model.lower().split("/")[-1] or "mistral" in cfg.base_model.lower().split("/")[-1]
or (cfg.model_type and "mistral" in cfg.model_type.lower()) or (cfg.type_of_model and "mistral" in cfg.type_of_model.lower())
) )
cfg.is_qwen_derived_model = ( cfg.is_qwen_derived_model = (
@@ -379,11 +379,11 @@ def legacy_validate_config(cfg):
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch." "hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
) )
if cfg.gptq and cfg.model_revision: if cfg.gptq and cfg.revision_of_model:
raise ValueError( raise ValueError(
"model_revision is not supported for GPTQ models. " "revision_of_model is not supported for GPTQ models. "
+ "Please download the model from HuggingFace Hub manually for correct branch, " + "Please download the model from HuggingFace Hub manually for correct branch, "
+ "point to its path, and remove model_revision from the config." + "point to its path, and remove revision_of_model from the config."
) )
# if cfg.sample_packing and cfg.sdp_attention: # if cfg.sample_packing and cfg.sdp_attention:

View File

@@ -47,6 +47,16 @@ class DeprecatedParameters(BaseModel):
return noisy_embedding_alpha return noisy_embedding_alpha
class RemappedParameters(BaseModel):
"""parameters that have been remapped to other names"""
overrides_of_model_config: Optional[Dict[str, Any]] = Field(
default=None, alias="model_config"
)
type_of_model: Optional[str] = Field(default=None, alias="model_type")
revision_of_model: Optional[str] = Field(default=None, alias="model_revision")
class PretrainingDataset(BaseModel): class PretrainingDataset(BaseModel):
"""pretraining dataset configuration subset""" """pretraining dataset configuration subset"""
@@ -234,12 +244,8 @@ class ModelInputConfig(BaseModel):
tokenizer_type: Optional[str] = Field( tokenizer_type: Optional[str] = Field(
default=None, metadata={"help": "transformers tokenizer class"} default=None, metadata={"help": "transformers tokenizer class"}
) )
model_type: Optional[str] = Field(default=None)
model_revision: Optional[str] = None
trust_remote_code: Optional[bool] = None trust_remote_code: Optional[bool] = None
model_config_overrides: Optional[Dict[str, Any]] = None
@field_validator("trust_remote_code") @field_validator("trust_remote_code")
@classmethod @classmethod
def hint_trust_remote_code(cls, trust_remote_code): def hint_trust_remote_code(cls, trust_remote_code):
@@ -362,11 +368,17 @@ class AxolotlInputConfig(
HyperparametersConfig, HyperparametersConfig,
WandbConfig, WandbConfig,
MLFlowConfig, MLFlowConfig,
RemappedParameters,
DeprecatedParameters, DeprecatedParameters,
BaseModel, BaseModel,
): ):
"""wrapper of all config options""" """wrapper of all config options"""
class Config:
"""Config for alias"""
populate_by_name = True
strict: Optional[bool] = Field(default=False) strict: Optional[bool] = Field(default=False)
resume_from_checkpoint: Optional[str] = None resume_from_checkpoint: Optional[str] = None
auto_resume_from_checkpoints: Optional[bool] = None auto_resume_from_checkpoints: Optional[bool] = None
@@ -550,11 +562,11 @@ class AxolotlInputConfig(
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def check_gptq_w_revision(cls, data): def check_gptq_w_revision(cls, data):
if data.get("gptq") and data.get("model_revision"): if data.get("gptq") and data.get("revision_of_model"):
raise ValueError( raise ValueError(
"model_revision is not supported for GPTQ models. " "revision_of_model is not supported for GPTQ models. "
+ "Please download the model from HuggingFace Hub manually for correct branch, " + "Please download the model from HuggingFace Hub manually for correct branch, "
+ "point to its path, and remove model_revision from the config." + "point to its path, and remove revision_of_model from the config."
) )
return data return data

View File

@@ -86,8 +86,8 @@ def load_model_config(cfg):
model_config_name = cfg.tokenizer_config model_config_name = cfg.tokenizer_config
trust_remote_code = cfg.trust_remote_code is True trust_remote_code = cfg.trust_remote_code is True
config_kwargs = {} config_kwargs = {}
if cfg.model_revision: if cfg.revision_of_model:
config_kwargs["revision"] = cfg.model_revision config_kwargs["revision"] = cfg.revision_of_model
try: try:
model_config = AutoConfig.from_pretrained( model_config = AutoConfig.from_pretrained(
@@ -104,8 +104,8 @@ def load_model_config(cfg):
) )
raise err raise err
if cfg.model_config_overrides: if cfg.overrides_of_model_config:
for key, val in cfg.model_config_overrides.items(): for key, val in cfg.overrides_of_model_config.items():
setattr(model_config, key, val) setattr(model_config, key, val)
check_model_config(cfg, model_config) check_model_config(cfg, model_config)
@@ -272,7 +272,7 @@ def load_model(
Load a model for a given configuration and tokenizer. Load a model for a given configuration and tokenizer.
""" """
base_model = cfg.base_model base_model = cfg.base_model
model_type = cfg.model_type model_type = cfg.type_of_model
model_config = load_model_config(cfg) model_config = load_model_config(cfg)
# TODO refactor as a kwarg # TODO refactor as a kwarg
@@ -426,8 +426,8 @@ def load_model(
if is_deepspeed_zero3_enabled(): if is_deepspeed_zero3_enabled():
del model_kwargs["device_map"] del model_kwargs["device_map"]
if cfg.model_revision: if cfg.revision_of_model:
model_kwargs["revision"] = cfg.model_revision model_kwargs["revision"] = cfg.revision_of_model
if cfg.gptq: if cfg.gptq:
if not hasattr(model_config, "quantization_config"): if not hasattr(model_config, "quantization_config"):
LOG.warning("model config does not contain quantization_config information") LOG.warning("model config does not contain quantization_config information")

View File

@@ -3,6 +3,7 @@
import logging import logging
import os import os
import warnings
from typing import Optional from typing import Optional
import pytest import pytest
@@ -14,6 +15,8 @@ from axolotl.utils.dict import DictDefault
from axolotl.utils.models import check_model_config from axolotl.utils.models import check_model_config
from axolotl.utils.wandb_ import setup_wandb_env_vars from axolotl.utils.wandb_ import setup_wandb_env_vars
warnings.filterwarnings("error")
@pytest.fixture(name="minimal_cfg") @pytest.fixture(name="minimal_cfg")
def fixture_cfg(): def fixture_cfg():
@@ -190,6 +193,45 @@ class TestValidation(BaseValidation):
assert new_cfg.learning_rate == 0.00005 assert new_cfg.learning_rate == 0.00005
def test_model_config_remap(self, minimal_cfg):
cfg = (
DictDefault(
{
"model_config": {"model_type": "mistral"},
}
)
| minimal_cfg
)
new_cfg = validate_config(cfg)
assert new_cfg.overrides_of_model_config["model_type"] == "mistral"
def test_model_type_remap(self, minimal_cfg):
cfg = (
DictDefault(
{
"model_type": "AutoModelForCausalLM",
}
)
| minimal_cfg
)
new_cfg = validate_config(cfg)
assert new_cfg.type_of_model == "AutoModelForCausalLM"
def test_model_revision_remap(self, minimal_cfg):
cfg = (
DictDefault(
{
"model_revision": "main",
}
)
| minimal_cfg
)
new_cfg = validate_config(cfg)
assert new_cfg.revision_of_model == "main"
def test_qlora(self, minimal_cfg): def test_qlora(self, minimal_cfg):
base_cfg = ( base_cfg = (
DictDefault( DictDefault(