fix softmax class check
This commit is contained in:
@@ -17,4 +17,5 @@ class RalaPlugin(BasePlugin):
|
|||||||
return "axolotl.integrations.rala.args.RalaArgs"
|
return "axolotl.integrations.rala.args.RalaArgs"
|
||||||
|
|
||||||
def register(self):
|
def register(self):
|
||||||
|
LOG.info("Registering RALA model with AutoConfig & AutoModel")
|
||||||
register_rala_model()
|
register_rala_model()
|
||||||
|
|||||||
@@ -333,7 +333,7 @@ class LlamaRalaDecoderLayer(nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.hidden_size = config.hidden_size
|
self.hidden_size = config.hidden_size
|
||||||
|
|
||||||
if LlamaRalaConfig.is_layer_idx_softmax(
|
if LlamaRalaDecoderLayer.is_layer_idx_softmax(
|
||||||
config.num_hidden_layers, layer_idx, config.softmax_every
|
config.num_hidden_layers, layer_idx, config.softmax_every
|
||||||
):
|
):
|
||||||
self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](
|
self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](
|
||||||
|
|||||||
Reference in New Issue
Block a user