diff --git a/src/axolotl/integrations/rala/__init__.py b/src/axolotl/integrations/rala/__init__.py
index ccdcc42b5..6c0bea951 100644
--- a/src/axolotl/integrations/rala/__init__.py
+++ b/src/axolotl/integrations/rala/__init__.py
@@ -17,4 +17,5 @@ class RalaPlugin(BasePlugin):
         return "axolotl.integrations.rala.args.RalaArgs"
 
     def register(self):
+        LOG.info("Registering RALA model with AutoConfig & AutoModel")
         register_rala_model()
diff --git a/src/axolotl/integrations/rala/auto/llama/modeling_rala.py b/src/axolotl/integrations/rala/auto/llama/modeling_rala.py
index 60964a343..c27026658 100644
--- a/src/axolotl/integrations/rala/auto/llama/modeling_rala.py
+++ b/src/axolotl/integrations/rala/auto/llama/modeling_rala.py
@@ -333,7 +333,7 @@ class LlamaRalaDecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
 
-        if LlamaRalaConfig.is_layer_idx_softmax(
+        if LlamaRalaDecoderLayer.is_layer_idx_softmax(
             config.num_hidden_layers, layer_idx, config.softmax_every
         ):
             self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](
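
Context on the second hunk: the original code resolved `is_layer_idx_softmax` on `LlamaRalaConfig`, where the helper is evidently not defined; the fix calls it on `LlamaRalaDecoderLayer` instead. Below is a minimal sketch of what such a layer-interleaving predicate could look like. Only the class name, method name, and call signature come from the diff; the method body, the `@staticmethod` decorator, and the "every Nth layer keeps softmax attention" rule are assumptions for illustration.

```python
class LlamaRalaDecoderLayer:
    """Hypothetical stub: only the method name and signature appear in the diff."""

    @staticmethod
    def is_layer_idx_softmax(
        num_hidden_layers: int, layer_idx: int, softmax_every: int
    ) -> bool:
        # Assumption: a non-positive interval means no layer keeps softmax attention.
        if softmax_every <= 0:
            return False
        # Assumption: the final layer always retains softmax attention (one
        # plausible reason `num_hidden_layers` is part of the signature).
        if layer_idx == num_hidden_layers - 1:
            return True
        # Assumption: every `softmax_every`-th layer keeps standard softmax
        # attention, while the remaining layers use the RALA attention variant.
        return (layer_idx + 1) % softmax_every == 0


# Example: with 32 layers and softmax_every=8, layers 7, 15, 23, 31
# (0-indexed) would keep softmax attention under this assumed rule.
assert LlamaRalaDecoderLayer.is_layer_idx_softmax(32, 7, 8)
assert not LlamaRalaDecoderLayer.is_layer_idx_softmax(32, 6, 8)
```

Defining the predicate on the decoder layer rather than the config keeps the layer-selection policy next to the code that consumes it, which matches the direction of the fix.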