keep some softmax layers
This commit is contained in:
@@ -1,5 +1,12 @@
|
|||||||
|
"""
|
||||||
|
Rala config class
|
||||||
|
"""
|
||||||
from transformers import LlamaConfig
|
from transformers import LlamaConfig
|
||||||
|
|
||||||
|
|
||||||
class LlamaRalaConfig(LlamaConfig):
|
class LlamaRalaConfig(LlamaConfig):
|
||||||
pass
|
"""
|
||||||
|
Configuration for LlamaRala model
|
||||||
|
"""
|
||||||
|
|
||||||
|
softmax_every: int = 6 # every 6th layer applies softmax
|
||||||
|
|||||||
@@ -333,6 +333,26 @@ class LlamaRalaDecoderLayer(nn.Module):
|
|||||||
config.hidden_size, eps=config.rms_norm_eps
|
config.hidden_size, eps=config.rms_norm_eps
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@classmethod
def is_layer_idx_softmax(
    cls, num_hidden_layers: int, layer_idx: int, softmax_every: int
) -> bool:
    """
    Report whether the decoder layer at ``layer_idx`` is a softmax layer.

    The first layer (index ``0``) and the last layer (index
    ``num_hidden_layers - 1``) always count as softmax layers. In addition,
    every ``softmax_every``-th layer counts, starting from an offset chosen
    from the inner layers (all layers except the first and last): the inner
    layers that do not fit into the longest run of the form
    ``1 + k * softmax_every`` are split between the head and the tail, and
    the strided run starts after the head share.

    Args:
        num_hidden_layers: total number of decoder layers in the model.
        layer_idx: index of the layer being queried.
        softmax_every: distance between consecutive softmax layers; must be
            at least 1.

    Returns:
        True if ``layer_idx`` is a softmax layer, False otherwise.

    Raises:
        ValueError: if ``softmax_every`` is smaller than 1.
    """
    if softmax_every < 1:
        # A zero period used to fail with a bare ZeroDivisionError and a
        # negative one silently produced an empty stride; reject both.
        raise ValueError(
            f"softmax_every must be a positive integer, got {softmax_every}"
        )

    # Inner layers left over once the longest run of the form
    # 1 + k * softmax_every has been placed among them. With
    # inner_layers = num_hidden_layers - 2 this equals
    # (inner_layers - 1) % softmax_every.
    leftover = (num_hidden_layers - 3) % softmax_every
    # Skip layer 0, then half of the leftover (rounded down).
    softmax_start_idx = 1 + leftover // 2

    if layer_idx in (0, num_hidden_layers - 1):
        return True
    # range membership avoids materialising the full set of indices.
    return layer_idx in range(softmax_start_idx, num_hidden_layers, softmax_every)
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ from transformers import PreTrainedModel
|
|||||||
from transformers.models.llama.modeling_llama import LlamaAttention
|
from transformers.models.llama.modeling_llama import LlamaAttention
|
||||||
|
|
||||||
from axolotl.integrations.rala import LlamaRALAAttention
|
from axolotl.integrations.rala import LlamaRALAAttention
|
||||||
|
from axolotl.integrations.rala.auto.llama.modeling_rala import LlamaRalaDecoderLayer
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -46,18 +47,22 @@ def copy_attention_weights(
|
|||||||
|
|
||||||
|
|
||||||
def convert_to_rala(
|
def convert_to_rala(
|
||||||
model: PreTrainedModel,
|
model: PreTrainedModel, zero_init: bool = False, softmax_every_n: int = 6
|
||||||
zero_init: bool = False,
|
|
||||||
) -> PreTrainedModel:
|
) -> PreTrainedModel:
|
||||||
"""Convert a pre-trained model's attention layers to RALA attention"""
|
"""Convert a pre-trained model's attention layers to RALA attention"""
|
||||||
layer_idx = 0
|
layer_idx = 0
|
||||||
|
|
||||||
def convert_module(module):
|
def convert_module(module, softmax_every, num_hidden_layers):
|
||||||
nonlocal layer_idx
|
nonlocal layer_idx
|
||||||
|
|
||||||
# Iterate through module children, convert any attn layers to RALA attn
|
# Iterate through module children, convert any attn layers to RALA attn
|
||||||
for name, child in module.named_children():
|
for name, child in module.named_children():
|
||||||
if isinstance(child, tuple(ATTENTION_MAPPING.keys())):
|
if isinstance(child, tuple(ATTENTION_MAPPING.keys())):
|
||||||
|
decoder_layer_idx = child.layer_idx
|
||||||
|
if LlamaRalaDecoderLayer.is_layer_idx_softmax(
|
||||||
|
num_hidden_layers, decoder_layer_idx, softmax_every
|
||||||
|
):
|
||||||
|
continue
|
||||||
# Choose appropriate RALA attention class
|
# Choose appropriate RALA attention class
|
||||||
# pylint: disable=duplicate-code
|
# pylint: disable=duplicate-code
|
||||||
attention_class = ATTENTION_MAPPING[type(child)]
|
attention_class = ATTENTION_MAPPING[type(child)]
|
||||||
@@ -81,9 +86,10 @@ def convert_to_rala(
|
|||||||
setattr(module, name, new_attention)
|
setattr(module, name, new_attention)
|
||||||
layer_idx += 1
|
layer_idx += 1
|
||||||
elif len(list(child.children())) > 0:
|
elif len(list(child.children())) > 0:
|
||||||
convert_module(child)
|
convert_module(child, softmax_every, num_hidden_layers)
|
||||||
|
|
||||||
convert_module(model)
|
model.config.softmax_every = softmax_every_n
|
||||||
|
convert_module(model, softmax_every_n, model.config.num_hidden_layers)
|
||||||
logger.info(f"Converted {layer_idx} attention layers to RALA attention")
|
logger.info(f"Converted {layer_idx} attention layers to RALA attention")
|
||||||
|
|
||||||
model.config.architectures = [
|
model.config.architectures = [
|
||||||
|
|||||||
Reference in New Issue
Block a user