Enhance model loading logic to include support for GraniteSpeechConfig, allowing for the use of the specific model class for Granite Speech.

This commit is contained in:
mhenrhcsen
2025-07-17 19:45:23 +02:00
parent 738adb2258
commit ea234afa8a

View File

@@ -747,6 +747,16 @@ class ModelLoader:
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
elif self.model_type == "GraniteSpeechConfig":
# Use the actual model class for Granite Speech
self.model = transformers.GraniteSpeechForCausalLM.from_pretrained(
self.base_model,
config=self.model_config,
trust_remote_code=self.cfg.trust_remote_code or False,
**self.model_kwargs,
)
else:
self.model = getattr(transformers, self.model_type).from_pretrained(
self.base_model,