Compare commits
8 Commits
coderabbit
...
fix/granit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
380921ee56 | ||
|
|
6e71819560 | ||
|
|
ea234afa8a | ||
|
|
738adb2258 | ||
|
|
f40e8caa28 | ||
|
|
f9bdf1fb44 | ||
|
|
2f670a5988 | ||
|
|
84ad69afad |
@@ -163,6 +163,15 @@ class ModelLoader:
|
|||||||
# Build the model
|
# Build the model
|
||||||
PLUGIN_MANAGER.pre_model_load(self.cfg)
|
PLUGIN_MANAGER.pre_model_load(self.cfg)
|
||||||
skip_move_to_device = self._build_model()
|
skip_move_to_device = self._build_model()
|
||||||
|
|
||||||
|
# Check if the model is a GraniteConfig object
|
||||||
|
if hasattr(self, 'model') and self.model.__class__.__name__ == "GraniteConfig":
|
||||||
|
LOG.error("The model loaded is a GraniteConfig object, not a proper model.")
|
||||||
|
LOG.error("This is likely because the model type 'GraniteConfig' is not supported.")
|
||||||
|
LOG.error("Please use a different model type or ensure the model is properly configured.")
|
||||||
|
LOG.error("Setting trust_remote_code=True might help if the model requires custom code.")
|
||||||
|
raise ValueError("Model loaded is a GraniteConfig object, not a proper model. Use a supported model type or set trust_remote_code=True.")
|
||||||
|
|
||||||
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
|
PLUGIN_MANAGER.post_model_build(self.cfg, self.model)
|
||||||
|
|
||||||
# Post-build model configuration
|
# Post-build model configuration
|
||||||
@@ -216,15 +225,27 @@ class ModelLoader:
|
|||||||
|
|
||||||
def _resize_token_embeddings(self):
|
def _resize_token_embeddings(self):
|
||||||
"""Resize token embeddings if needed."""
|
"""Resize token embeddings if needed."""
|
||||||
|
# Skip if model doesn't have the necessary methods
|
||||||
|
if not hasattr(self.model, "get_input_embeddings"):
|
||||||
|
LOG.warning("Model does not have get_input_embeddings method, skipping token embedding resize")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Check if get_input_embeddings returns None
|
||||||
|
input_embeddings = self.model.get_input_embeddings()
|
||||||
|
if input_embeddings is None:
|
||||||
|
LOG.warning("Model's get_input_embeddings returned None, skipping token embedding resize")
|
||||||
|
return
|
||||||
|
|
||||||
embeddings_len = (
|
embeddings_len = (
|
||||||
math.ceil(len(self.tokenizer) / 32) * 32
|
math.ceil(len(self.tokenizer) / 32) * 32
|
||||||
if self.cfg.resize_token_embeddings_to_32x
|
if self.cfg.resize_token_embeddings_to_32x
|
||||||
else len(self.tokenizer)
|
else len(self.tokenizer)
|
||||||
)
|
)
|
||||||
if hasattr(self.model, "get_input_embeddings") and (
|
|
||||||
self.model.get_input_embeddings().num_embeddings < embeddings_len
|
if hasattr(input_embeddings, "num_embeddings") and (
|
||||||
|
input_embeddings.num_embeddings < embeddings_len
|
||||||
or (
|
or (
|
||||||
self.model.get_input_embeddings().num_embeddings > embeddings_len
|
input_embeddings.num_embeddings > embeddings_len
|
||||||
and self.cfg.shrink_embeddings
|
and self.cfg.shrink_embeddings
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
@@ -233,14 +254,24 @@ class ModelLoader:
|
|||||||
self.model_config.model_type != "llava"
|
self.model_config.model_type != "llava"
|
||||||
):
|
):
|
||||||
resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
|
resize_kwargs["mean_resizing"] = self.cfg.mean_resizing_embeddings
|
||||||
self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
|
|
||||||
|
if hasattr(self.model, "resize_token_embeddings"):
|
||||||
|
self.model.resize_token_embeddings(embeddings_len, **resize_kwargs)
|
||||||
|
else:
|
||||||
|
LOG.warning("Model does not have resize_token_embeddings method, skipping resize")
|
||||||
else:
|
else:
|
||||||
self.model.tie_weights()
|
if hasattr(self.model, "tie_weights"):
|
||||||
|
self.model.tie_weights()
|
||||||
|
|
||||||
def _adjust_model_config(self):
|
def _adjust_model_config(self):
|
||||||
|
# Skip if model doesn't have config attribute
|
||||||
|
if not hasattr(self.model, "config"):
|
||||||
|
LOG.warning("Model does not have config attribute, skipping model config adjustments")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Handle max_position_embeddings
|
||||||
if (
|
if (
|
||||||
hasattr(self.model, "config")
|
hasattr(self.model.config, "max_position_embeddings")
|
||||||
and hasattr(self.model.config, "max_position_embeddings")
|
|
||||||
and self.model.config.max_position_embeddings
|
and self.model.config.max_position_embeddings
|
||||||
and self.cfg.sequence_len > self.model.config.max_position_embeddings
|
and self.cfg.sequence_len > self.model.config.max_position_embeddings
|
||||||
):
|
):
|
||||||
@@ -250,17 +281,17 @@ class ModelLoader:
|
|||||||
)
|
)
|
||||||
self.model.config.max_position_embeddings = self.cfg.sequence_len
|
self.model.config.max_position_embeddings = self.cfg.sequence_len
|
||||||
|
|
||||||
|
# Handle bos_token_id
|
||||||
if (
|
if (
|
||||||
hasattr(self.model, "config")
|
hasattr(self.model.config, "bos_token_id")
|
||||||
and hasattr(self.model.config, "bos_token_id")
|
|
||||||
and self.model.config.bos_token_id
|
and self.model.config.bos_token_id
|
||||||
and self.model.config.bos_token_id != self.tokenizer.bos_token_id
|
and self.model.config.bos_token_id != self.tokenizer.bos_token_id
|
||||||
):
|
):
|
||||||
self.model.config.bos_token_id = self.tokenizer.bos_token_id
|
self.model.config.bos_token_id = self.tokenizer.bos_token_id
|
||||||
|
|
||||||
|
# Handle eos_token_id
|
||||||
if (
|
if (
|
||||||
hasattr(self.model, "config")
|
hasattr(self.model.config, "eos_token_id")
|
||||||
and hasattr(self.model.config, "eos_token_id")
|
|
||||||
and self.model.config.eos_token_id
|
and self.model.config.eos_token_id
|
||||||
and self.model.config.eos_token_id != self.tokenizer.eos_token_id
|
and self.model.config.eos_token_id != self.tokenizer.eos_token_id
|
||||||
):
|
):
|
||||||
@@ -292,9 +323,12 @@ class ModelLoader:
|
|||||||
if self.cfg.adapter in ["lora", "qlora"]:
|
if self.cfg.adapter in ["lora", "qlora"]:
|
||||||
needs_fa2_dtype = True
|
needs_fa2_dtype = True
|
||||||
if self.cfg.gradient_checkpointing:
|
if self.cfg.gradient_checkpointing:
|
||||||
self.model.gradient_checkpointing_enable(
|
if hasattr(self.model, "gradient_checkpointing_enable"):
|
||||||
gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs
|
self.model.gradient_checkpointing_enable(
|
||||||
)
|
gradient_checkpointing_kwargs=self.cfg.gradient_checkpointing_kwargs
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOG.warning("Model does not have gradient_checkpointing_enable method, skipping gradient checkpointing")
|
||||||
|
|
||||||
self._prepare_model_for_quantization()
|
self._prepare_model_for_quantization()
|
||||||
|
|
||||||
@@ -371,11 +405,14 @@ class ModelLoader:
|
|||||||
self.model.is_parallelizable = True
|
self.model.is_parallelizable = True
|
||||||
self.model.model_parallel = True
|
self.model.model_parallel = True
|
||||||
|
|
||||||
if not any(
|
if hasattr(self.model, "named_parameters"):
|
||||||
param.requires_grad
|
if not any(
|
||||||
for _, param in self.model.named_parameters(recurse=True)
|
param.requires_grad
|
||||||
):
|
for _, param in self.model.named_parameters(recurse=True)
|
||||||
LOG.warning("There are no parameters that require gradient updates")
|
):
|
||||||
|
LOG.warning("There are no parameters that require gradient updates")
|
||||||
|
else:
|
||||||
|
LOG.warning("Model does not have named_parameters attribute, skipping gradient check")
|
||||||
|
|
||||||
if self.cfg.flash_optimum:
|
if self.cfg.flash_optimum:
|
||||||
from optimum.bettertransformer import BetterTransformer
|
from optimum.bettertransformer import BetterTransformer
|
||||||
@@ -383,7 +420,10 @@ class ModelLoader:
|
|||||||
self.model = BetterTransformer.transform(self.model)
|
self.model = BetterTransformer.transform(self.model)
|
||||||
|
|
||||||
if self.cfg.adapter is not None:
|
if self.cfg.adapter is not None:
|
||||||
log_gpu_memory_usage(LOG, "after adapters", self.model.device)
|
if hasattr(self.model, "device"):
|
||||||
|
log_gpu_memory_usage(LOG, "after adapters", self.model.device)
|
||||||
|
else:
|
||||||
|
LOG.warning("Model does not have device attribute, skipping memory usage logging")
|
||||||
|
|
||||||
for _ in range(3):
|
for _ in range(3):
|
||||||
gc.collect()
|
gc.collect()
|
||||||
@@ -700,6 +740,10 @@ class ModelLoader:
|
|||||||
and self.model_type != "AutoModelForCausalLM"
|
and self.model_type != "AutoModelForCausalLM"
|
||||||
and not self.cfg.trust_remote_code
|
and not self.cfg.trust_remote_code
|
||||||
):
|
):
|
||||||
|
if self.model_type == "GraniteSpeechConfig" and not hasattr(self.model_config, 'vocab_size'):
|
||||||
|
# Set vocab_size from tokenizer or use a reasonable default
|
||||||
|
self.model_config.vocab_size = getattr(self.model_config, 'vocab_size', 50257)
|
||||||
|
|
||||||
if self.cfg.gptq:
|
if self.cfg.gptq:
|
||||||
self.model = self.auto_model_loader.from_pretrained(
|
self.model = self.auto_model_loader.from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
@@ -707,7 +751,21 @@ class ModelLoader:
|
|||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
**self.model_kwargs,
|
**self.model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif self.model_type == "GraniteSpeechConfig":
|
||||||
|
# Use the actual model class for Granite Speech
|
||||||
|
self.model = transformers.GraniteSpeechForCausalLM.from_pretrained(
|
||||||
|
self.base_model,
|
||||||
|
config=self.model_config,
|
||||||
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
|
**self.model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
if not hasattr(self.model_config, 'vocab_size'):
|
||||||
|
LOG.warning("Model config does not have vocab_size attribute, setting to 50257")
|
||||||
|
self.model_config.vocab_size = 50257
|
||||||
|
|
||||||
self.model = getattr(transformers, self.model_type).from_pretrained(
|
self.model = getattr(transformers, self.model_type).from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
config=self.model_config,
|
config=self.model_config,
|
||||||
@@ -791,13 +849,19 @@ class ModelLoader:
|
|||||||
dest = {"dtype": dist_dtype}
|
dest = {"dtype": dist_dtype}
|
||||||
if self.cfg.lora_on_cpu:
|
if self.cfg.lora_on_cpu:
|
||||||
dest["device"] = "cpu"
|
dest["device"] = "cpu"
|
||||||
|
|
||||||
|
# Check if the model has named_modules attribute
|
||||||
|
if not hasattr(self.model, "named_modules"):
|
||||||
|
LOG.warning("Model does not have named_modules attribute, skipping embedding dtype conversion")
|
||||||
|
return
|
||||||
|
|
||||||
for name, module in self.model.named_modules():
|
for name, module in self.model.named_modules():
|
||||||
if "norm" in name:
|
if "norm" in name:
|
||||||
module.to(dist_dtype)
|
module.to(dist_dtype)
|
||||||
if before_kbit_train_or_finetune:
|
if before_kbit_train_or_finetune:
|
||||||
if name.endswith(".gate"):
|
if name.endswith(".gate"):
|
||||||
module.to(dist_dtype)
|
module.to(dist_dtype)
|
||||||
if self.model_config.model_type == "btlm":
|
if self.model_config.model_type == "btlm" and "lm_head" in name:
|
||||||
# don't upcast lm_head for btlm
|
# don't upcast lm_head for btlm
|
||||||
continue
|
continue
|
||||||
if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
|
if any(m in name for m in embedding_modules) and hasattr(module, "weight"):
|
||||||
|
|||||||
@@ -80,7 +80,15 @@ def setup_model_and_tokenizer(
|
|||||||
|
|
||||||
model_loader = ModelLoader(cfg, tokenizer, processor=processor)
|
model_loader = ModelLoader(cfg, tokenizer, processor=processor)
|
||||||
model, peft_config = model_loader.load()
|
model, peft_config = model_loader.load()
|
||||||
if model.generation_config is not None:
|
|
||||||
|
# Check if model is actually a GraniteConfig object
|
||||||
|
if model.__class__.__name__ == "GraniteConfig":
|
||||||
|
LOG.error("The model loaded is a GraniteConfig object, not a proper model.")
|
||||||
|
LOG.error("This is likely because the model type 'GraniteConfig' is not supported.")
|
||||||
|
LOG.error("Please use a different model type or ensure the model is properly configured.")
|
||||||
|
raise ValueError("Model loaded is a GraniteConfig object, not a proper model. Use a supported model type.")
|
||||||
|
|
||||||
|
if hasattr(model, "generation_config") and model.generation_config is not None:
|
||||||
model.generation_config.do_sample = True
|
model.generation_config.do_sample = True
|
||||||
|
|
||||||
# Apply freezing if specified
|
# Apply freezing if specified
|
||||||
@@ -90,7 +98,10 @@ def setup_model_and_tokenizer(
|
|||||||
any(embed in param for embed in ["lm_head", "embed_tokens"])
|
any(embed in param for embed in ["lm_head", "embed_tokens"])
|
||||||
for param in cfg.unfrozen_parameters
|
for param in cfg.unfrozen_parameters
|
||||||
):
|
):
|
||||||
model.enable_input_require_grads()
|
if hasattr(model, "enable_input_require_grads"):
|
||||||
|
model.enable_input_require_grads()
|
||||||
|
else:
|
||||||
|
LOG.warning("Model does not have enable_input_require_grads method, skipping")
|
||||||
|
|
||||||
return model, tokenizer, peft_config, processor
|
return model, tokenizer, peft_config, processor
|
||||||
|
|
||||||
@@ -246,9 +257,12 @@ def save_trained_model(
|
|||||||
LOG.info(f"Training completed! Saving trained model to {cfg.output_dir}.")
|
LOG.info(f"Training completed! Saving trained model to {cfg.output_dir}.")
|
||||||
|
|
||||||
# Post training module hooks
|
# Post training module hooks
|
||||||
for name, module in model.named_modules():
|
if hasattr(model, "named_modules"):
|
||||||
if hasattr(module, "_post_training"):
|
for name, module in model.named_modules():
|
||||||
module._post_training(model, name) # pylint: disable=protected-access
|
if hasattr(module, "_post_training"):
|
||||||
|
module._post_training(model, name) # pylint: disable=protected-access
|
||||||
|
else:
|
||||||
|
LOG.warning("Model does not have named_modules attribute, skipping post training hooks")
|
||||||
|
|
||||||
# handle QAT
|
# handle QAT
|
||||||
if cfg.qat:
|
if cfg.qat:
|
||||||
@@ -308,11 +322,17 @@ def save_trained_model(
|
|||||||
model = BetterTransformer.reverse(model)
|
model = BetterTransformer.reverse(model)
|
||||||
|
|
||||||
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
|
if cfg.rl and cfg.adapter and not cfg.rl_adapter_ref_model:
|
||||||
trainer.model.save_pretrained(
|
if hasattr(trainer.model, "save_pretrained"):
|
||||||
cfg.output_dir, safe_serialization=safe_serialization
|
trainer.model.save_pretrained(
|
||||||
)
|
cfg.output_dir, safe_serialization=safe_serialization
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOG.warning("Trainer model does not have save_pretrained method, skipping save")
|
||||||
|
|
||||||
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
if hasattr(model, "save_pretrained"):
|
||||||
|
model.save_pretrained(cfg.output_dir, safe_serialization=safe_serialization)
|
||||||
|
else:
|
||||||
|
LOG.warning("Model does not have save_pretrained method, skipping save")
|
||||||
|
|
||||||
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
|
if hasattr(cfg, "llmcompressor") and cfg.llmcompressor:
|
||||||
# TODO: add integration support so this can be implemented completely within the plugin
|
# TODO: add integration support so this can be implemented completely within the plugin
|
||||||
@@ -398,7 +418,10 @@ def save_initial_configs(
|
|||||||
tokenizer.save_pretrained(str(output_dir))
|
tokenizer.save_pretrained(str(output_dir))
|
||||||
if hasattr(model, "config"):
|
if hasattr(model, "config"):
|
||||||
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
|
LOG.info(f"Pre-saving model config to {cfg.output_dir}...")
|
||||||
model.config.save_pretrained(str(output_dir))
|
if hasattr(model.config, "save_pretrained"):
|
||||||
|
model.config.save_pretrained(str(output_dir))
|
||||||
|
else:
|
||||||
|
LOG.warning("Model config does not have save_pretrained method, skipping config save")
|
||||||
|
|
||||||
if processor:
|
if processor:
|
||||||
LOG.info(f"Pre-saving processor to {cfg.output_dir}...")
|
LOG.info(f"Pre-saving processor to {cfg.output_dir}...")
|
||||||
@@ -461,9 +484,12 @@ def handle_untrained_tokens_fix(
|
|||||||
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
|
fix_untrained_tokens(model, tokenizer, train_dataset, **fix_kwargs)
|
||||||
|
|
||||||
if cfg.local_rank == 0:
|
if cfg.local_rank == 0:
|
||||||
model.save_pretrained(
|
if hasattr(model, "save_pretrained"):
|
||||||
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
|
model.save_pretrained(
|
||||||
)
|
str(Path(cfg.output_dir)), safe_serialization=safe_serialization
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
LOG.warning("Model does not have save_pretrained method, skipping save")
|
||||||
|
|
||||||
|
|
||||||
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
|
def setup_model_and_trainer(cfg: DictDefault, dataset_meta: TrainDatasetMeta) -> tuple[
|
||||||
|
|||||||
Reference in New Issue
Block a user