Add validation for TP with models with tied embeddings (#2999)
* add validation for tp + tied embeddings models * fix logic and messaging * add additional guard for null tp size
This commit is contained in:
@@ -131,6 +131,17 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
|
|||||||
f"Please include [{lora_modules_to_save_joined}] in `lora_modules_to_save`."
|
f"Please include [{lora_modules_to_save_joined}] in `lora_modules_to_save`."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
cfg.tensor_parallel_size
|
||||||
|
and cfg.tensor_parallel_size > 1
|
||||||
|
and hasattr(model_config, "tie_word_embeddings")
|
||||||
|
and model_config.tie_word_embeddings
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Tensor parallelism is incompatible with models configured with `tie_word_embeddings` enabled. "
|
||||||
|
"Please use a model without `tie_word_embeddings`, or disable tensor parallelism."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
|
def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
|
||||||
"""Loads and configures a model configuration from HuggingFace or local sources.
|
"""Loads and configures a model configuration from HuggingFace or local sources.
|
||||||
|
|||||||
Reference in New Issue
Block a user