fix softmax class check

This commit is contained in:
Wing Lian
2025-01-15 23:23:13 -05:00
parent 36b71f34d7
commit 8c4f89745a
2 changed files with 2 additions and 1 deletion

View File

@@ -17,4 +17,5 @@ class RalaPlugin(BasePlugin):
return "axolotl.integrations.rala.args.RalaArgs"
def register(self):
LOG.info("Registering RALA model with AutoConfig & AutoModel")
register_rala_model()

View File

@@ -333,7 +333,7 @@ class LlamaRalaDecoderLayer(nn.Module):
super().__init__()
self.hidden_size = config.hidden_size
if LlamaRalaConfig.is_layer_idx_softmax(
if LlamaRalaDecoderLayer.is_layer_idx_softmax(
config.num_hidden_layers, layer_idx, config.softmax_every
):
self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](