diff --git a/src/axolotl/loaders/model.py b/src/axolotl/loaders/model.py index 3ebf291ad..a530dc9d6 100644 --- a/src/axolotl/loaders/model.py +++ b/src/axolotl/loaders/model.py @@ -747,6 +747,16 @@ class ModelLoader: trust_remote_code=self.cfg.trust_remote_code or False, **self.model_kwargs, ) + + elif self.model_type == "GraniteSpeechConfig": + # Use the actual model class for Granite Speech + self.model = transformers.GraniteSpeechForCausalLM.from_pretrained( + self.base_model, + config=self.model_config, + trust_remote_code=self.cfg.trust_remote_code or False, + **self.model_kwargs, + ) + else: self.model = getattr(transformers, self.model_type).from_pretrained( self.base_model,