keep some softmax layers

Author: Wing Lian
Date: 2025-01-15 15:08:32 -05:00
parent 12aade921a
commit 6d3f4b9ab5
3 changed files with 39 additions and 6 deletions

View File

@@ -1,5 +1,12 @@
"""
Rala config class
"""
from transformers import LlamaConfig
class LlamaRalaConfig(LlamaConfig):
pass
"""
Configuration for LlamaRala model
"""
softmax_every: int = 6 # every 8th layer applies softmax
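
For reference, a quick sketch of how the new attribute behaves (hypothetical usage, not part of this commit; the import path is an assumption since the diff does not show filenames):

# Hypothetical usage; the configuration module path below is assumed, not shown in this diff.
from axolotl.integrations.rala.auto.llama.configuration_rala import LlamaRalaConfig

cfg = LlamaRalaConfig(num_hidden_layers=32)
print(cfg.softmax_every)  # 6, the class-level default
cfg.softmax_every = 4     # can be overridden per model before conversion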

View File

@@ -333,6 +333,26 @@ class LlamaRalaDecoderLayer(nn.Module):
            config.hidden_size, eps=config.rms_norm_eps
        )

+    @classmethod
+    def is_layer_idx_softmax(
+        cls, num_hidden_layers: int, layer_idx: int, softmax_every: int
+    ) -> bool:
+        inner_layers = num_hidden_layers - 2
+        if 1 + softmax_every * (inner_layers // softmax_every) == inner_layers:
+            softmax_start_idx = 1
+        elif 1 + softmax_every * (inner_layers // softmax_every) > inner_layers:
+            layer_group_size = 1 + softmax_every * ((inner_layers // softmax_every) - 1)
+            softmax_start_idx = 1 + (inner_layers - layer_group_size) // 2
+        elif 1 + softmax_every * (inner_layers // softmax_every) < inner_layers:
+            layer_group_size = 1 + softmax_every * (inner_layers // softmax_every)
+            softmax_start_idx = 1 + (inner_layers - layer_group_size) // 2
+        softmax_layers = set(range(softmax_start_idx, num_hidden_layers, softmax_every))
+        softmax_layers.add(0)
+        softmax_layers.add(num_hidden_layers - 1)
+        return layer_idx in softmax_layers
+
    def forward(
        self,
        hidden_states: torch.Tensor,
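
To make the selection rule concrete, here is a small worked check (illustrative only, not part of the commit): the first and last layers always keep softmax, and the remaining softmax layers are spaced every softmax_every layers, roughly centered among the inner layers.

# Which layers keep softmax attention for a 32-layer model with softmax_every=6?
from axolotl.integrations.rala.auto.llama.modeling_rala import LlamaRalaDecoderLayer

kept = [
    idx for idx in range(32)
    if LlamaRalaDecoderLayer.is_layer_idx_softmax(32, idx, 6)
]
print(kept)  # [0, 3, 9, 15, 21, 27, 31] -> 7 of 32 layers keep softmax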

View File

@@ -8,6 +8,7 @@ from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import LlamaAttention
from axolotl.integrations.rala import LlamaRALAAttention
+from axolotl.integrations.rala.auto.llama.modeling_rala import LlamaRalaDecoderLayer

logger = logging.getLogger(__name__)
@@ -46,18 +47,22 @@ def copy_attention_weights(
def convert_to_rala(
-    model: PreTrainedModel,
-    zero_init: bool = False,
+    model: PreTrainedModel, zero_init: bool = False, softmax_every_n: int = 6
) -> PreTrainedModel:
    """Convert a pre-trained model's attention layers to differential attention"""
    layer_idx = 0

-    def convert_module(module):
+    def convert_module(module, softmax_every, num_hidden_layers):
        nonlocal layer_idx
        # Iterate through module children, convert any attn layers to diff attn
        for name, child in module.named_children():
            if isinstance(child, tuple(ATTENTION_MAPPING.keys())):
+                decoder_layer_idx = child.layer_idx
+                if LlamaRalaDecoderLayer.is_layer_idx_softmax(
+                    num_hidden_layers, decoder_layer_idx, softmax_every
+                ):
+                    continue
                # Choose appropriate differential attention class
                # pylint: disable=duplicate-code
                attention_class = ATTENTION_MAPPING[type(child)]
@@ -81,9 +86,10 @@ def convert_to_rala(
                setattr(module, name, new_attention)
                layer_idx += 1
            elif len(list(child.children())) > 0:
-                convert_module(child)
+                convert_module(child, softmax_every, num_hidden_layers)

-    convert_module(model)
+    model.config.softmax_every = softmax_every_n
+    convert_module(model, softmax_every_n, model.config.num_hidden_layers)

    logger.info(f"Converted {layer_idx} attention layers to RALA attention")

    model.config.architectures = [
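
For illustration, a hypothetical end-to-end call (the convert_to_rala import path and the checkpoint name are assumptions; only the function change itself appears in this diff):

# Hypothetical usage sketch; the convert_to_rala import path is assumed, not shown in this diff.
from transformers import AutoModelForCausalLM
from axolotl.integrations.rala.convert import convert_to_rala  # hypothetical path

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B")  # example checkpoint
model = convert_to_rala(model, softmax_every_n=6)
# Layers selected by is_layer_idx_softmax (always including the first and last)
# are skipped and keep standard softmax attention; the rest are swapped to RALA attention.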