Add validation for TP with models with tied embeddings (#2999)
* add validation for tp + tied embeddings models
* fix logic and messaging
* add additional guard for null tp size
This commit is contained in:
@@ -131,6 +131,17 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig):
|
||||
f"Please include [{lora_modules_to_save_joined}] in `lora_modules_to_save`."
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.tensor_parallel_size
|
||||
and cfg.tensor_parallel_size > 1
|
||||
and hasattr(model_config, "tie_word_embeddings")
|
||||
and model_config.tie_word_embeddings
|
||||
):
|
||||
raise ValueError(
|
||||
"Tensor parallelism is incompatible with models configured with `tie_word_embeddings` enabled. "
|
||||
"Please use a model without `tie_word_embeddings`, or disable tensor parallelism."
|
||||
)
|
||||
|
||||
|
||||
def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict:
|
||||
"""Loads and configures a model configuration from HuggingFace or local sources.
|
||||
|
||||
Reference in New Issue
Block a user