fix softmax class check
This commit is contained in:
@@ -17,4 +17,5 @@ class RalaPlugin(BasePlugin):
|
||||
return "axolotl.integrations.rala.args.RalaArgs"
|
||||
|
||||
def register(self):
|
||||
LOG.info("Registering RALA model with AutoConfig & AutoModel")
|
||||
register_rala_model()
|
||||
|
||||
@@ -333,7 +333,7 @@ class LlamaRalaDecoderLayer(nn.Module):
|
||||
super().__init__()
|
||||
self.hidden_size = config.hidden_size
|
||||
|
||||
if LlamaRalaConfig.is_layer_idx_softmax(
|
||||
if LlamaRalaDecoderLayer.is_layer_idx_softmax(
|
||||
config.num_hidden_layers, layer_idx, config.softmax_every
|
||||
):
|
||||
self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](
|
||||
|
||||
Reference in New Issue
Block a user