Add validation for TP with models with tied embeddings (#2999)

* add validation for tp + tied embeddings models

* fix logic and messaging

* add an additional guard for an unset (None) tensor_parallel_size
This commit is contained in:
Wing Lian
2025-08-01 13:58:16 -04:00
committed by GitHub
parent 01a6bd1a0e
commit 7c3b428f23

View File

@@ -131,6 +131,17 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
f"Please include [{lora_modules_to_save_joined}] in `lora_modules_to_save`."
)
if (
cfg.tensor_parallel_size
and cfg.tensor_parallel_size > 1
and hasattr(model_config, "tie_word_embeddings")
and model_config.tie_word_embeddings
):
raise ValueError(
"Tensor parallelism is incompatible with models configured with `tie_word_embeddings` enabled. "
"Please use a model without `tie_word_embeddings`, or disable tensor parallelism."
)
def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
"""Loads and configures a model configuration from HuggingFace or local sources.