keep some softmax layers
@@ -1,5 +1,12 @@
 """
 Rala config class
 """
+from transformers import LlamaConfig
+
+
 class LlamaRalaConfig(LlamaConfig):
-    pass
+    """
+    Configuration for LlamaRala model
+    """
+
+    softmax_every: int = 6  # every softmax_every-th layer keeps standard softmax attention
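The new softmax_every field is a plain class attribute, so it is available on any LlamaRalaConfig instance and can be overridden per model. A minimal sketch, assuming the config module sits next to modeling_rala under axolotl.integrations.rala.auto.llama (the commit does not show the file path):

    # Sketch only: the import path is an assumption; adjust to wherever LlamaRalaConfig lives.
    from axolotl.integrations.rala.auto.llama.configuration_rala import LlamaRalaConfig

    cfg = LlamaRalaConfig()
    print(cfg.softmax_every)  # 6: first, last, and every 6th inner layer keep softmax
    cfg.softmax_every = 4     # denser softmax interleave, if desired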
@@ -333,6 +333,26 @@ class LlamaRalaDecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )

+    @classmethod
+    def is_layer_idx_softmax(
+        cls, num_hidden_layers: int, layer_idx: int, softmax_every: int
+    ) -> bool:
+        inner_layers = num_hidden_layers - 2
+        if 1 + softmax_every * (inner_layers // softmax_every) == inner_layers:
+            softmax_start_idx = 1
+        elif 1 + softmax_every * (inner_layers // softmax_every) > inner_layers:
+            layer_group_size = 1 + softmax_every * ((inner_layers // softmax_every) - 1)
+            softmax_start_idx = 1 + (inner_layers - layer_group_size) // 2
+        elif 1 + softmax_every * (inner_layers // softmax_every) < inner_layers:
+            layer_group_size = 1 + softmax_every * (inner_layers // softmax_every)
+            softmax_start_idx = 1 + (inner_layers - layer_group_size) // 2
+
+        softmax_layers = set(range(softmax_start_idx, num_hidden_layers, softmax_every))
+        softmax_layers.add(0)
+        softmax_layers.add(num_hidden_layers - 1)
+
+        return layer_idx in softmax_layers
+
     def forward(
         self,
         hidden_states: torch.Tensor,
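The classmethod added above decides which decoder layers keep standard softmax attention: layer 0 and the last layer always do, and the remaining softmax layers are spaced softmax_every apart, with the start offset chosen so the run sits roughly centered among the inner layers. A quick sketch of the resulting indices, assuming the integration is importable (illustration only, not part of the commit):

    # Enumerate the layers that keep softmax for a hypothetical 32-layer model.
    from axolotl.integrations.rala.auto.llama.modeling_rala import LlamaRalaDecoderLayer

    num_hidden_layers, softmax_every = 32, 6
    softmax_idxs = [
        idx
        for idx in range(num_hidden_layers)
        if LlamaRalaDecoderLayer.is_layer_idx_softmax(num_hidden_layers, idx, softmax_every)
    ]
    print(softmax_idxs)  # [0, 3, 9, 15, 21, 27, 31]: first, last, and every 6th in between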
@@ -8,6 +8,7 @@ from transformers import PreTrainedModel
 from transformers.models.llama.modeling_llama import LlamaAttention

 from axolotl.integrations.rala import LlamaRALAAttention
+from axolotl.integrations.rala.auto.llama.modeling_rala import LlamaRalaDecoderLayer

 logger = logging.getLogger(__name__)

@@ -46,18 +47,22 @@ def copy_attention_weights(


 def convert_to_rala(
-    model: PreTrainedModel,
-    zero_init: bool = False,
+    model: PreTrainedModel, zero_init: bool = False, softmax_every_n: int = 6
 ) -> PreTrainedModel:
     """Convert a pre-trained model's attention layers to RALA attention"""
     layer_idx = 0

-    def convert_module(module):
+    def convert_module(module, softmax_every, num_hidden_layers):
         nonlocal layer_idx

         # Iterate through module children, convert any attn layers to RALA attn
         for name, child in module.named_children():
             if isinstance(child, tuple(ATTENTION_MAPPING.keys())):
+                decoder_layer_idx = child.layer_idx
+                if LlamaRalaDecoderLayer.is_layer_idx_softmax(
+                    num_hidden_layers, decoder_layer_idx, softmax_every
+                ):
+                    continue
                 # Choose appropriate RALA attention class
                 # pylint: disable=duplicate-code
                 attention_class = ATTENTION_MAPPING[type(child)]
@@ -81,9 +86,10 @@ def convert_to_rala(
                 setattr(module, name, new_attention)
                 layer_idx += 1
             elif len(list(child.children())) > 0:
-                convert_module(child)
+                convert_module(child, softmax_every, num_hidden_layers)

-    convert_module(model)
+    model.config.softmax_every = softmax_every_n
+    convert_module(model, softmax_every_n, model.config.num_hidden_layers)
     logger.info(f"Converted {layer_idx} attention layers to RALA attention")

     model.config.architectures = [
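For context, a hypothetical end-to-end call, assuming convert_to_rala is exposed by the integration's conversion module (its file path is not shown here) and returns the converted model as its signature indicates:

    # Usage sketch only: the import path and checkpoint id are illustrative assumptions.
    from transformers import AutoModelForCausalLM

    from axolotl.integrations.rala.convert import convert_to_rala  # path assumed

    model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")
    model = convert_to_rala(model, zero_init=False, softmax_every_n=6)

    # Layers flagged by is_layer_idx_softmax are skipped during conversion, so they
    # keep softmax attention; the chosen interleave is recorded on the config.
    print(model.config.softmax_every)  # 6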