fix for protected model_ namespace w pydantic (#1345)
This commit is contained in:
@@ -124,7 +124,7 @@ def normalize_config(cfg):
|
||||
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
|
||||
or cfg.is_llama_derived_model
|
||||
or "llama" in cfg.base_model.lower()
|
||||
or (cfg.model_type and "llama" in cfg.model_type.lower())
|
||||
or (cfg.type_of_model and "llama" in cfg.type_of_model.lower())
|
||||
)
|
||||
|
||||
# figure out if the model is falcon
|
||||
@@ -140,7 +140,7 @@ def normalize_config(cfg):
|
||||
)
|
||||
or cfg.is_falcon_derived_model
|
||||
or "falcon" in cfg.base_model.lower()
|
||||
or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
|
||||
or (cfg.type_of_model and "rwforcausallm" in cfg.type_of_model.lower())
|
||||
)
|
||||
|
||||
cfg.is_mistral_derived_model = (
|
||||
@@ -153,7 +153,7 @@ def normalize_config(cfg):
|
||||
)
|
||||
or cfg.is_mistral_derived_model
|
||||
or "mistral" in cfg.base_model.lower().split("/")[-1]
|
||||
or (cfg.model_type and "mistral" in cfg.model_type.lower())
|
||||
or (cfg.type_of_model and "mistral" in cfg.type_of_model.lower())
|
||||
)
|
||||
|
||||
cfg.is_qwen_derived_model = (
|
||||
@@ -379,11 +379,11 @@ def legacy_validate_config(cfg):
|
||||
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
|
||||
)
|
||||
|
||||
if cfg.gptq and cfg.model_revision:
|
||||
if cfg.gptq and cfg.revision_of_model:
|
||||
raise ValueError(
|
||||
"model_revision is not supported for GPTQ models. "
|
||||
"revision_of_model is not supported for GPTQ models. "
|
||||
+ "Please download the model from HuggingFace Hub manually for correct branch, "
|
||||
+ "point to its path, and remove model_revision from the config."
|
||||
+ "point to its path, and remove revision_of_model from the config."
|
||||
)
|
||||
|
||||
# if cfg.sample_packing and cfg.sdp_attention:
|
||||
|
||||
@@ -47,6 +47,16 @@ class DeprecatedParameters(BaseModel):
|
||||
return noisy_embedding_alpha
|
||||
|
||||
|
||||
class RemappedParameters(BaseModel):
|
||||
"""parameters that have been remapped to other names"""
|
||||
|
||||
overrides_of_model_config: Optional[Dict[str, Any]] = Field(
|
||||
default=None, alias="model_config"
|
||||
)
|
||||
type_of_model: Optional[str] = Field(default=None, alias="model_type")
|
||||
revision_of_model: Optional[str] = Field(default=None, alias="model_revision")
|
||||
|
||||
|
||||
class PretrainingDataset(BaseModel):
|
||||
"""pretraining dataset configuration subset"""
|
||||
|
||||
@@ -234,12 +244,8 @@ class ModelInputConfig(BaseModel):
|
||||
tokenizer_type: Optional[str] = Field(
|
||||
default=None, metadata={"help": "transformers tokenizer class"}
|
||||
)
|
||||
model_type: Optional[str] = Field(default=None)
|
||||
model_revision: Optional[str] = None
|
||||
trust_remote_code: Optional[bool] = None
|
||||
|
||||
model_config_overrides: Optional[Dict[str, Any]] = None
|
||||
|
||||
@field_validator("trust_remote_code")
|
||||
@classmethod
|
||||
def hint_trust_remote_code(cls, trust_remote_code):
|
||||
@@ -362,11 +368,17 @@ class AxolotlInputConfig(
|
||||
HyperparametersConfig,
|
||||
WandbConfig,
|
||||
MLFlowConfig,
|
||||
RemappedParameters,
|
||||
DeprecatedParameters,
|
||||
BaseModel,
|
||||
):
|
||||
"""wrapper of all config options"""
|
||||
|
||||
class Config:
|
||||
"""Config for alias"""
|
||||
|
||||
populate_by_name = True
|
||||
|
||||
strict: Optional[bool] = Field(default=False)
|
||||
resume_from_checkpoint: Optional[str] = None
|
||||
auto_resume_from_checkpoints: Optional[bool] = None
|
||||
@@ -550,11 +562,11 @@ class AxolotlInputConfig(
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_gptq_w_revision(cls, data):
|
||||
if data.get("gptq") and data.get("model_revision"):
|
||||
if data.get("gptq") and data.get("revision_of_model"):
|
||||
raise ValueError(
|
||||
"model_revision is not supported for GPTQ models. "
|
||||
"revision_of_model is not supported for GPTQ models. "
|
||||
+ "Please download the model from HuggingFace Hub manually for correct branch, "
|
||||
+ "point to its path, and remove model_revision from the config."
|
||||
+ "point to its path, and remove revision_of_model from the config."
|
||||
)
|
||||
return data
|
||||
|
||||
|
||||
@@ -86,8 +86,8 @@ def load_model_config(cfg):
|
||||
model_config_name = cfg.tokenizer_config
|
||||
trust_remote_code = cfg.trust_remote_code is True
|
||||
config_kwargs = {}
|
||||
if cfg.model_revision:
|
||||
config_kwargs["revision"] = cfg.model_revision
|
||||
if cfg.revision_of_model:
|
||||
config_kwargs["revision"] = cfg.revision_of_model
|
||||
|
||||
try:
|
||||
model_config = AutoConfig.from_pretrained(
|
||||
@@ -104,8 +104,8 @@ def load_model_config(cfg):
|
||||
)
|
||||
raise err
|
||||
|
||||
if cfg.model_config_overrides:
|
||||
for key, val in cfg.model_config_overrides.items():
|
||||
if cfg.overrides_of_model_config:
|
||||
for key, val in cfg.overrides_of_model_config.items():
|
||||
setattr(model_config, key, val)
|
||||
|
||||
check_model_config(cfg, model_config)
|
||||
@@ -272,7 +272,7 @@ def load_model(
|
||||
Load a model for a given configuration and tokenizer.
|
||||
"""
|
||||
base_model = cfg.base_model
|
||||
model_type = cfg.model_type
|
||||
model_type = cfg.type_of_model
|
||||
model_config = load_model_config(cfg)
|
||||
|
||||
# TODO refactor as a kwarg
|
||||
@@ -426,8 +426,8 @@ def load_model(
|
||||
if is_deepspeed_zero3_enabled():
|
||||
del model_kwargs["device_map"]
|
||||
|
||||
if cfg.model_revision:
|
||||
model_kwargs["revision"] = cfg.model_revision
|
||||
if cfg.revision_of_model:
|
||||
model_kwargs["revision"] = cfg.revision_of_model
|
||||
if cfg.gptq:
|
||||
if not hasattr(model_config, "quantization_config"):
|
||||
LOG.warning("model config does not contain quantization_config information")
|
||||
|
||||
Reference in New Issue
Block a user