fix for protected model_ namespace w pydantic (#1345)

This commit is contained in:
Wing Lian
2024-02-28 15:07:49 -05:00
committed by GitHub
parent 3a5a2d2f34
commit 6b3b271925
5 changed files with 76 additions and 22 deletions

View File

@@ -124,7 +124,7 @@ def normalize_config(cfg):
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model
or "llama" in cfg.base_model.lower()
or (cfg.model_type and "llama" in cfg.model_type.lower())
or (cfg.type_of_model and "llama" in cfg.type_of_model.lower())
)
# figure out if the model is falcon
@@ -140,7 +140,7 @@ def normalize_config(cfg):
)
or cfg.is_falcon_derived_model
or "falcon" in cfg.base_model.lower()
or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
or (cfg.type_of_model and "rwforcausallm" in cfg.type_of_model.lower())
)
cfg.is_mistral_derived_model = (
@@ -153,7 +153,7 @@ def normalize_config(cfg):
)
or cfg.is_mistral_derived_model
or "mistral" in cfg.base_model.lower().split("/")[-1]
or (cfg.model_type and "mistral" in cfg.model_type.lower())
or (cfg.type_of_model and "mistral" in cfg.type_of_model.lower())
)
cfg.is_qwen_derived_model = (
@@ -379,11 +379,11 @@ def legacy_validate_config(cfg):
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
)
if cfg.gptq and cfg.model_revision:
if cfg.gptq and cfg.revision_of_model:
raise ValueError(
"model_revision is not supported for GPTQ models. "
"revision_of_model is not supported for GPTQ models. "
+ "Please download the model from HuggingFace Hub manually for correct branch, "
+ "point to its path, and remove model_revision from the config."
+ "point to its path, and remove revision_of_model from the config."
)
# if cfg.sample_packing and cfg.sdp_attention:

View File

@@ -47,6 +47,16 @@ class DeprecatedParameters(BaseModel):
return noisy_embedding_alpha
class RemappedParameters(BaseModel):
"""parameters that have been remapped to other names"""
overrides_of_model_config: Optional[Dict[str, Any]] = Field(
default=None, alias="model_config"
)
type_of_model: Optional[str] = Field(default=None, alias="model_type")
revision_of_model: Optional[str] = Field(default=None, alias="model_revision")
class PretrainingDataset(BaseModel):
"""pretraining dataset configuration subset"""
@@ -234,12 +244,8 @@ class ModelInputConfig(BaseModel):
tokenizer_type: Optional[str] = Field(
default=None, metadata={"help": "transformers tokenizer class"}
)
model_type: Optional[str] = Field(default=None)
model_revision: Optional[str] = None
trust_remote_code: Optional[bool] = None
model_config_overrides: Optional[Dict[str, Any]] = None
@field_validator("trust_remote_code")
@classmethod
def hint_trust_remote_code(cls, trust_remote_code):
@@ -362,11 +368,17 @@ class AxolotlInputConfig(
HyperparametersConfig,
WandbConfig,
MLFlowConfig,
RemappedParameters,
DeprecatedParameters,
BaseModel,
):
"""wrapper of all config options"""
class Config:
"""Config for alias"""
populate_by_name = True
strict: Optional[bool] = Field(default=False)
resume_from_checkpoint: Optional[str] = None
auto_resume_from_checkpoints: Optional[bool] = None
@@ -550,11 +562,11 @@ class AxolotlInputConfig(
@model_validator(mode="before")
@classmethod
def check_gptq_w_revision(cls, data):
if data.get("gptq") and data.get("model_revision"):
if data.get("gptq") and data.get("revision_of_model"):
raise ValueError(
"model_revision is not supported for GPTQ models. "
"revision_of_model is not supported for GPTQ models. "
+ "Please download the model from HuggingFace Hub manually for correct branch, "
+ "point to its path, and remove model_revision from the config."
+ "point to its path, and remove revision_of_model from the config."
)
return data

View File

@@ -86,8 +86,8 @@ def load_model_config(cfg):
model_config_name = cfg.tokenizer_config
trust_remote_code = cfg.trust_remote_code is True
config_kwargs = {}
if cfg.model_revision:
config_kwargs["revision"] = cfg.model_revision
if cfg.revision_of_model:
config_kwargs["revision"] = cfg.revision_of_model
try:
model_config = AutoConfig.from_pretrained(
@@ -104,8 +104,8 @@ def load_model_config(cfg):
)
raise err
if cfg.model_config_overrides:
for key, val in cfg.model_config_overrides.items():
if cfg.overrides_of_model_config:
for key, val in cfg.overrides_of_model_config.items():
setattr(model_config, key, val)
check_model_config(cfg, model_config)
@@ -272,7 +272,7 @@ def load_model(
Load a model for a given configuration and tokenizer.
"""
base_model = cfg.base_model
model_type = cfg.model_type
model_type = cfg.type_of_model
model_config = load_model_config(cfg)
# TODO refactor as a kwarg
@@ -426,8 +426,8 @@ def load_model(
if is_deepspeed_zero3_enabled():
del model_kwargs["device_map"]
if cfg.model_revision:
model_kwargs["revision"] = cfg.model_revision
if cfg.revision_of_model:
model_kwargs["revision"] = cfg.revision_of_model
if cfg.gptq:
if not hasattr(model_config, "quantization_config"):
LOG.warning("model config does not contain quantization_config information")