From 7c3b428f2344da7888dfe3cc87e2ee9cab8d49f6 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Fri, 1 Aug 2025 13:58:16 -0400 Subject: [PATCH] Add validation for TP with models with tied embeddings (#2999) * add validation for tp + tied embeddings models * fix logic and messaging * add additional guard for null tp size --- src/axolotl/loaders/utils.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/src/axolotl/loaders/utils.py b/src/axolotl/loaders/utils.py index 4b93d14ac..240e00da7 100644 --- a/src/axolotl/loaders/utils.py +++ b/src/axolotl/loaders/utils.py @@ -131,6 +131,17 @@ def check_model_config(cfg: DictDefault, model_config: PretrainedConfig): f"Please include [{lora_modules_to_save_joined}] in `lora_modules_to_save`." ) + if ( + cfg.tensor_parallel_size + and cfg.tensor_parallel_size > 1 + and hasattr(model_config, "tie_word_embeddings") + and model_config.tie_word_embeddings + ): + raise ValueError( + "Tensor parallelism is incompatible with models configured with `tie_word_embeddings` enabled. " + "Please use a model without `tie_word_embeddings`, or disable tensor parallelism." + ) + def load_model_config(cfg: DictDefault) -> PretrainedConfig | addict.Dict: """Loads and configures a model configuration from HuggingFace or local sources.