From f9bdf1fb4454423169bdaa989098812bbb4f04f5 Mon Sep 17 00:00:00 2001
From: mhenrhcsen
Date: Wed, 16 Jul 2025 21:23:25 +0200
Subject: [PATCH] checks

---
 src/axolotl/loaders/model.py | 74 +++++++++++++++++++++++++++---------
 1 file changed, 55 insertions(+), 19 deletions(-)

diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py
index 3d11601ba..fbf1117d8 100644
--- a/src/axolotl/loaders/model.py
+++ b/src/axolotl/loaders/model.py
@@ -216,15 +216,27 @@ class ModelLoader:
 
     def _resize_token_embeddings(self):
         """Resize token embeddings if needed."""
+        # Skip if model doesn't have the necessary methods
+        if not hasattr(self.model, "get_input_embeddings"):
+            LOG.warning("Model does not have get_input_embeddings method, skipping token embedding resize")
+            return
+
+        # Check if get_input_embeddings returns None
+        input_embeddings = self.model.get_input_embeddings()
+        if input_embeddings is None:
+            LOG.warning("Model's get_input_embeddings returned None, skipping token embedding resize")
+            return
+
         embeddings_len = (
             math.ceil(len(self.tokenizer) / 32) * 32
             if self.cfg.resize_token_embeddings_to_32x
             else len(self.tokenizer)
         )
-        if hasattr(self.model, "get_input_embeddings") and (
-            self.model.get_input_embeddings().num_embeddings < embeddings_len
+
+        if hasattr(input_embeddings, "num_embeddings") and (
+            input_embeddings.num_embeddings < embeddings_len
             or (
-                self.model.get_input_embeddings().num_embeddings > embeddings_len
+                input_embeddings.num_embeddings > embeddings_len
                 and self.cfg.shrink_embeddings
             )
         ):
@@ -233,15 +245,24 @@ class ModelLoader:
                 self.model_config.model_type != "llava"
             ):
                 resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
-            self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
+
+            if hasattr(self.model, "resize_token_embeddings"):
+                self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
+            else:
+                LOG.warning("Model does not have resize_token_embeddings method, skipping resize")
         else:
             if hasattr(self.model, "tie_weights"):
                 self.model.tie_weights()
 
     def _adjust_model_config(self):
+        # Skip if model doesn't have config attribute
+        if not hasattr(self.model, "config"):
+            LOG.warning("Model does not have config attribute, skipping model config adjustments")
+            return
+
+        # Handle max_position_embeddings
         if (
-            hasattr(self.model, "config")
-            and hasattr(self.model.config, "max_position_embeddings")
+            hasattr(self.model.config, "max_position_embeddings")
             and self.model.config.max_position_embeddings
             and self.cfg.sequence_len > self.model.config.max_position_embeddings
         ):
@@ -251,17 +272,17 @@ class ModelLoader:
             )
             self.model.config.max_position_embeddings = self.cfg.sequence_len
 
+        # Handle bos_token_id
         if (
-            hasattr(self.model, "config")
-            and hasattr(self.model.config, "bos_token_id")
+            hasattr(self.model.config, "bos_token_id")
             and self.model.config.bos_token_id
             and self.model.config.bos_token_id != self.tokenizer.bos_token_id
         ):
             self.model.config.bos_token_id = self.tokenizer.bos_token_id
 
+        # Handle eos_token_id
         if (
-            hasattr(self.model, "config")
-            and hasattr(self.model.config, "eos_token_id")
+            hasattr(self.model.config, "eos_token_id")
             and self.model.config.eos_token_id
             and self.model.config.eos_token_id != self.tokenizer.eos_token_id
         ):
@@ -293,9 +314,12 @@ class ModelLoader:
         if self.cfg.adapter in ["lora", "qlora"]:
             needs_fa2_dtype = True
             if self.cfg.gradient_checkpointing:
-                self.model.gradient_checkpointing_enable(
-                    gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs
-                )
+                if hasattr(self.model, "gradient_checkpointing_enable"):
+                    self.model.gradient_checkpointing_enable(
+                        gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs
+                    )
+                else:
+                    LOG.warning("Model does not have gradient_checkpointing_enable method, skipping gradient checkpointing")
 
         self._prepare_model_for_quantization()
 
@@ -372,11 +396,14 @@ class ModelLoader:
             self.model.is_parallelizable = True
             self.model.model_parallel = True
 
-        if not any(
-            param.requires_grad
-            for _, param in self.model.named_parameters(recurse=True)
-        ):
-            LOG.warning("There are no parameters that require gradient updates")
+        if hasattr(self.model, "named_parameters"):
+            if not any(
+                param.requires_grad
+                for _, param in self.model.named_parameters(recurse=True)
+            ):
+                LOG.warning("There are no parameters that require gradient updates")
+        else:
+            LOG.warning("Model does not have named_parameters attribute, skipping gradient check")
 
         if self.cfg.flash_optimum:
             from optimum.bettertransformer import BetterTransformer
@@ -384,7 +411,10 @@ class ModelLoader:
             self.model = BetterTransformer.transform(self.model)
 
         if self.cfg.adapter is not None:
-            log_gpu_memory_usage(LOG, "after adapters", self.model.device)
+            if hasattr(self.model, "device"):
+                log_gpu_memory_usage(LOG, "after adapters", self.model.device)
+            else:
+                LOG.warning("Model does not have device attribute, skipping memory usage logging")
 
         for _ in range(3):
             gc.collect()
@@ -792,6 +822,12 @@ class ModelLoader:
             dest = {"dtype": dist_dtype}
             if self.cfg.lora_on_cpu:
                 dest["device"] = "cpu"
+
+            # Check if the model has named_modules attribute
+            if not hasattr(self.model, "named_modules"):
+                LOG.warning("Model does not have named_modules attribute, skipping embedding dtype conversion")
+                return
+
             for name, module in self.model.named_modules():
                 if "norm" in name:
                     module.to(dist_dtype)