fix for protected model_ namespace w pydantic (#1345)

This commit is contained in:
Wing Lian
2024-02-28 15:07:49 -05:00
committed by GitHub
parent 3a5a2d2f34
commit 6b3b271925
5 changed files with 76 additions and 22 deletions

View File

@@ -124,7 +124,7 @@ def normalize_config(cfg):
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
or cfg.is_llama_derived_model
or "llama" in cfg.base_model.lower()
or (cfg.model_type and "llama" in cfg.model_type.lower())
or (cfg.type_of_model and "llama" in cfg.type_of_model.lower())
)
# figure out if the model is falcon
@@ -140,7 +140,7 @@ def normalize_config(cfg):
)
or cfg.is_falcon_derived_model
or "falcon" in cfg.base_model.lower()
or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
or (cfg.type_of_model and "rwforcausallm" in cfg.type_of_model.lower())
)
cfg.is_mistral_derived_model = (
@@ -153,7 +153,7 @@ def normalize_config(cfg):
)
or cfg.is_mistral_derived_model
or "mistral" in cfg.base_model.lower().split("/")[-1]
or (cfg.model_type and "mistral" in cfg.model_type.lower())
or (cfg.type_of_model and "mistral" in cfg.type_of_model.lower())
)
cfg.is_qwen_derived_model = (
@@ -379,11 +379,11 @@ def legacy_validate_config(cfg):
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
)
if cfg.gptq and cfg.model_revision:
if cfg.gptq and cfg.revision_of_model:
raise ValueError(
"model_revision is not supported for GPTQ models. "
"revision_of_model is not supported for GPTQ models. "
+ "Please download the model from HuggingFace Hub manually for correct branch, "
+ "point to its path, and remove model_revision from the config."
+ "point to its path, and remove revision_of_model from the config."
)
# if cfg.sample_packing and cfg.sdp_attention: