Enhance model loading logic to include support for GraniteSpeechConfig, allowing for the use of the specific model class for Granite Speech.
This commit is contained in:
@@ -747,6 +747,16 @@ class ModelLoader:
|
|||||||
trust_remote_code=self.cfg.trust_remote_code or False,
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
**self.model_kwargs,
|
**self.model_kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif self.model_type == "GraniteSpeechConfig":
|
||||||
|
# Use the actual model class for Granite Speech
|
||||||
|
self.model = transformers.GraniteSpeechForCausalLM.from_pretrained(
|
||||||
|
self.base_model,
|
||||||
|
config=self.model_config,
|
||||||
|
trust_remote_code=self.cfg.trust_remote_code or False,
|
||||||
|
**self.model_kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.model = getattr(transformers, self.model_type).from_pretrained(
|
self.model = getattr(transformers, self.model_type).from_pretrained(
|
||||||
self.base_model,
|
self.base_model,
|
||||||
|
|||||||
Reference in New Issue
Block a user