checks
@@ -216,15 +216,27 @@ class ModelLoader:
     def _resize_token_embeddings(self):
         """Resize token embeddings if needed."""
+        # Skip if model doesn't have the necessary methods
+        if not hasattr(self.model, "get_input_embeddings"):
+            LOG.warning("Model does not have get_input_embeddings method, skipping token embedding resize")
+            return
+
+        # Check if get_input_embeddings returns None
+        input_embeddings = self.model.get_input_embeddings()
+        if input_embeddings is None:
+            LOG.warning("Model's get_input_embeddings returned None, skipping token embedding resize")
+            return
+
         embeddings_len = (
             math.ceil(len(self.tokenizer) / 32) * 32
             if self.cfg.resize_token_embeddings_to_32x
             else len(self.tokenizer)
         )
-        if hasattr(self.model, "get_input_embeddings") and (
-            self.model.get_input_embeddings().num_embeddings < embeddings_len
+        if hasattr(input_embeddings, "num_embeddings") and (
+            input_embeddings.num_embeddings < embeddings_len
             or (
-                self.model.get_input_embeddings().num_embeddings > embeddings_len
+                input_embeddings.num_embeddings > embeddings_len
                 and self.cfg.shrink_embeddings
             )
         ):
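For reference, a minimal standalone sketch of the embeddings_len arithmetic this hunk keeps: the tokenizer length is optionally rounded up to the next multiple of 32 before deciding whether the embedding matrix should grow (or, with shrink_embeddings, shrink). target_vocab_size is a hypothetical helper for illustration, not part of the codebase.

import math

def target_vocab_size(tokenizer_len: int, round_to_32x: bool) -> int:
    # Optionally round up to the next multiple of 32, mirroring
    # cfg.resize_token_embeddings_to_32x in the hunk above.
    return math.ceil(tokenizer_len / 32) * 32 if round_to_32x else tokenizer_len

print(target_vocab_size(32000, False))  # 32000
print(target_vocab_size(32001, True))   # 32032 -> embeddings would be grown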
@@ -233,15 +245,24 @@ class ModelLoader:
                 self.model_config.model_type != "llava"
             ):
                 resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
-            self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
+
+            if hasattr(self.model, "resize_token_embeddings"):
+                self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
+            else:
+                LOG.warning("Model does not have resize_token_embeddings method, skipping resize")
         else:
-            self.model.tie_weights()
+            if hasattr(self.model, "tie_weights"):
+                self.model.tie_weights()

     def _adjust_model_config(self):
+        # Skip if model doesn't have config attribute
+        if not hasattr(self.model, "config"):
+            LOG.warning("Model does not have config attribute, skipping model config adjustments")
+            return
+
+        # Handle max_position_embeddings
         if (
-            hasattr(self.model, "config")
-            and hasattr(self.model.config, "max_position_embeddings")
+            hasattr(self.model.config, "max_position_embeddings")
             and self.model.config.max_position_embeddings
             and self.cfg.sequence_len > self.model.config.max_position_embeddings
         ):
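A hedged usage sketch of the guarded resize-and-retie sequence outside ModelLoader; it assumes a transformers release new enough for resize_token_embeddings to accept mean_resizing, and uses gpt2 purely as a placeholder checkpoint.

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Grow the vocab by one token, then resize the embedding matrix to match.
tokenizer.add_special_tokens({"pad_token": "<|pad|>"})
resize_kwargs = {"mean_resizing": True}  # new rows initialized around the mean of existing rows

if hasattr(model, "resize_token_embeddings"):
    model.resize_token_embeddings(len(tokenizer), **resize_kwargs)
if hasattr(model, "tie_weights"):
    model.tie_weights()  # re-tie input/output embeddings after the resize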
@@ -251,17 +272,17 @@ class ModelLoader:
             )
             self.model.config.max_position_embeddings = self.cfg.sequence_len

+        # Handle bos_token_id
         if (
-            hasattr(self.model, "config")
-            and hasattr(self.model.config, "bos_token_id")
+            hasattr(self.model.config, "bos_token_id")
             and self.model.config.bos_token_id
             and self.model.config.bos_token_id != self.tokenizer.bos_token_id
         ):
             self.model.config.bos_token_id = self.tokenizer.bos_token_id

+        # Handle eos_token_id
         if (
-            hasattr(self.model, "config")
-            and hasattr(self.model.config, "eos_token_id")
+            hasattr(self.model.config, "eos_token_id")
             and self.model.config.eos_token_id
             and self.model.config.eos_token_id != self.tokenizer.eos_token_id
         ):
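The bos/eos handling above simply copies the tokenizer's ids into the model config when they disagree; a compact sketch of the same sync on a standalone config (gpt2 is only a placeholder checkpoint):

from transformers import AutoConfig, AutoTokenizer

config = AutoConfig.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")

for attr in ("bos_token_id", "eos_token_id"):
    tok_id = getattr(tokenizer, attr, None)
    cfg_id = getattr(config, attr, None)
    if cfg_id and tok_id is not None and cfg_id != tok_id:
        setattr(config, attr, tok_id)  # keep the model config in sync with the tokenizer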
@@ -293,9 +314,12 @@ class ModelLoader:
         if self.cfg.adapter in ["lora", "qlora"]:
             needs_fa2_dtype = True
         if self.cfg.gradient_checkpointing:
-            self.model.gradient_checkpointing_enable(
-                gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs
-            )
+            if hasattr(self.model, "gradient_checkpointing_enable"):
+                self.model.gradient_checkpointing_enable(
+                    gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs
+                )
+            else:
+                LOG.warning("Model does not have gradient_checkpointing_enable method, skipping gradient checkpointing")

         self._prepare_model_for_quantization()

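A minimal sketch of the guarded gradient-checkpointing call; the use_reentrant flag is a typical payload for gradient_checkpointing_kwargs but is an assumption here, not taken from this commit, and gpt2 is a placeholder model.

from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained("gpt2")
gc_kwargs = {"use_reentrant": False}  # forwarded to torch.utils.checkpoint

if hasattr(model, "gradient_checkpointing_enable"):
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gc_kwargs)
else:
    print("Model does not support gradient checkpointing, skipping")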
@@ -372,11 +396,14 @@ class ModelLoader:
             self.model.is_parallelizable = True
             self.model.model_parallel = True

-        if not any(
-            param.requires_grad
-            for _, param in self.model.named_parameters(recurse=True)
-        ):
-            LOG.warning("There are no parameters that require gradient updates")
+        if hasattr(self.model, "named_parameters"):
+            if not any(
+                param.requires_grad
+                for _, param in self.model.named_parameters(recurse=True)
+            ):
+                LOG.warning("There are no parameters that require gradient updates")
+        else:
+            LOG.warning("Model does not have named_parameters attribute, skipping gradient check")

         if self.cfg.flash_optimum:
             from optimum.bettertransformer import BetterTransformer
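The trainable-parameter audit is plain PyTorch; a self-contained sketch on a toy module that is deliberately frozen so the warning path fires:

import torch.nn as nn

model = nn.Linear(8, 8)
for param in model.parameters():
    param.requires_grad = False  # simulate a fully frozen model

if hasattr(model, "named_parameters"):
    if not any(p.requires_grad for _, p in model.named_parameters(recurse=True)):
        print("There are no parameters that require gradient updates")
else:
    print("Model does not have named_parameters attribute, skipping gradient check")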
@@ -384,7 +411,10 @@ class ModelLoader:
             self.model = BetterTransformer.transform(self.model)

         if self.cfg.adapter is not None:
-            log_gpu_memory_usage(LOG, "after adapters", self.model.device)
+            if hasattr(self.model, "device"):
+                log_gpu_memory_usage(LOG, "after adapters", self.model.device)
+            else:
+                LOG.warning("Model does not have device attribute, skipping memory usage logging")

         for _ in range(3):
             gc.collect()
@@ -792,6 +822,12 @@ class ModelLoader:
         dest = {"dtype": dist_dtype}
         if self.cfg.lora_on_cpu:
             dest["device"] = "cpu"
+
+        # Check if the model has named_modules attribute
+        if not hasattr(self.model, "named_modules"):
+            LOG.warning("Model does not have named_modules attribute, skipping embedding dtype conversion")
+            return
+
         for name, module in self.model.named_modules():
             if "norm" in name:
                 module.to(dist_dtype)
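The final hunk guards the pass that walks named_modules and moves anything with "norm" in its name to dist_dtype; a toy reproduction of that pass (module names and dtype chosen only for illustration):

import torch
import torch.nn as nn

model = nn.Sequential()
model.add_module("proj", nn.Linear(16, 16))
model.add_module("input_layernorm", nn.LayerNorm(16))

dist_dtype = torch.float32  # e.g. keep norm layers in fp32 while the rest is lower precision

for name, module in model.named_modules():
    if "norm" in name:
        module.to(dist_dtype)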