fix for protected model_ namespace w pydantic (#1345)
This commit is contained in:
@@ -124,7 +124,7 @@ def normalize_config(cfg):
|
||||
(hasattr(model_config, "model_type") and model_config.model_type == "llama")
|
||||
or cfg.is_llama_derived_model
|
||||
or "llama" in cfg.base_model.lower()
|
||||
or (cfg.model_type and "llama" in cfg.model_type.lower())
|
||||
or (cfg.type_of_model and "llama" in cfg.type_of_model.lower())
|
||||
)
|
||||
|
||||
# figure out if the model is falcon
|
||||
@@ -140,7 +140,7 @@ def normalize_config(cfg):
|
||||
)
|
||||
or cfg.is_falcon_derived_model
|
||||
or "falcon" in cfg.base_model.lower()
|
||||
or (cfg.model_type and "rwforcausallm" in cfg.model_type.lower())
|
||||
or (cfg.type_of_model and "rwforcausallm" in cfg.type_of_model.lower())
|
||||
)
|
||||
|
||||
cfg.is_mistral_derived_model = (
|
||||
@@ -153,7 +153,7 @@ def normalize_config(cfg):
|
||||
)
|
||||
or cfg.is_mistral_derived_model
|
||||
or "mistral" in cfg.base_model.lower().split("/")[-1]
|
||||
or (cfg.model_type and "mistral" in cfg.model_type.lower())
|
||||
or (cfg.type_of_model and "mistral" in cfg.type_of_model.lower())
|
||||
)
|
||||
|
||||
cfg.is_qwen_derived_model = (
|
||||
@@ -379,11 +379,11 @@ def legacy_validate_config(cfg):
|
||||
"hub_model_id is set without any models being saved. To save a model, set either save_steps or saves_per_epoch."
|
||||
)
|
||||
|
||||
if cfg.gptq and cfg.model_revision:
|
||||
if cfg.gptq and cfg.revision_of_model:
|
||||
raise ValueError(
|
||||
"model_revision is not supported for GPTQ models. "
|
||||
"revision_of_model is not supported for GPTQ models. "
|
||||
+ "Please download the model from HuggingFace Hub manually for correct branch, "
|
||||
+ "point to its path, and remove model_revision from the config."
|
||||
+ "point to its path, and remove revision_of_model from the config."
|
||||
)
|
||||
|
||||
# if cfg.sample_packing and cfg.sdp_attention:
|
||||
|
||||
Reference in New Issue
Block a user