diff --git a/src/axolotl/integrations/rala/auto/llama/configuration_rala.py b/src/axolotl/integrations/rala/auto/llama/configuration_rala.py
index 3b06f264a..378627c6d 100644
--- a/src/axolotl/integrations/rala/auto/llama/configuration_rala.py
+++ b/src/axolotl/integrations/rala/auto/llama/configuration_rala.py
@@ -1,5 +1,12 @@
+"""
+Rala config class
+"""
 from transformers import LlamaConfig
 
 
 class LlamaRalaConfig(LlamaConfig):
-    pass
+    """
+    Configuration for LlamaRala model
+    """
+
+    softmax_every: int = 6  # every 6th layer applies softmax
diff --git a/src/axolotl/integrations/rala/auto/llama/modeling_rala.py b/src/axolotl/integrations/rala/auto/llama/modeling_rala.py
index 759ef37f8..9ed7f2940 100644
--- a/src/axolotl/integrations/rala/auto/llama/modeling_rala.py
+++ b/src/axolotl/integrations/rala/auto/llama/modeling_rala.py
@@ -333,6 +333,30 @@ class LlamaRalaDecoderLayer(nn.Module):
             config.hidden_size, eps=config.rms_norm_eps
         )
 
+    @classmethod
+    def is_layer_idx_softmax(
+        cls, num_hidden_layers: int, layer_idx: int, softmax_every: int
+    ) -> bool:
+        """Return True if the layer at `layer_idx` keeps standard softmax
+        attention: always the first and last layers, plus interior layers
+        every `softmax_every` layers, centered within the interior when
+        the stride does not tile it evenly."""
+        inner_layers = num_hidden_layers - 2
+        if 1 + softmax_every * (inner_layers // softmax_every) == inner_layers:
+            softmax_start_idx = 1
+        elif 1 + softmax_every * (inner_layers // softmax_every) > inner_layers:
+            layer_group_size = 1 + softmax_every * ((inner_layers // softmax_every) - 1)
+            softmax_start_idx = 1 + (inner_layers - layer_group_size) // 2
+        else:
+            layer_group_size = 1 + softmax_every * (inner_layers // softmax_every)
+            softmax_start_idx = 1 + (inner_layers - layer_group_size) // 2
+
+        softmax_layers = set(range(softmax_start_idx, num_hidden_layers, softmax_every))
+        softmax_layers.add(0)
+        softmax_layers.add(num_hidden_layers - 1)
+
+        return layer_idx in softmax_layers
+
     def forward(
         self,
         hidden_states: torch.Tensor,
diff --git a/src/axolotl/integrations/rala/convert.py b/src/axolotl/integrations/rala/convert.py
index 7dbf84a33..5523ba9d3 100644
--- a/src/axolotl/integrations/rala/convert.py
+++ b/src/axolotl/integrations/rala/convert.py
@@ -8,6 +8,7 @@
 from transformers import PreTrainedModel
 from transformers.models.llama.modeling_llama import LlamaAttention
 
 from axolotl.integrations.rala import LlamaRALAAttention
+from axolotl.integrations.rala.auto.llama.modeling_rala import LlamaRalaDecoderLayer
 
 logger = logging.getLogger(__name__)
@@ -46,18 +47,22 @@ def copy_attention_weights(
 
 
 def convert_to_rala(
-    model: PreTrainedModel,
-    zero_init: bool = False,
+    model: PreTrainedModel, zero_init: bool = False, softmax_every_n: int = 6
 ) -> PreTrainedModel:
     """Convert a pre-trained model's attention layers to differential attention"""
     layer_idx = 0
 
-    def convert_module(module):
+    def convert_module(module, softmax_every, num_hidden_layers):
         nonlocal layer_idx
 
         # Iterate through module children, convert any attn layers to diff attn
         for name, child in module.named_children():
             if isinstance(child, tuple(ATTENTION_MAPPING.keys())):
+                decoder_layer_idx = child.layer_idx
+                if LlamaRalaDecoderLayer.is_layer_idx_softmax(
+                    num_hidden_layers, decoder_layer_idx, softmax_every
+                ):
+                    continue  # this layer keeps standard softmax attention
                 # Choose appropriate differential attention class
                 # pylint: disable=duplicate-code
                 attention_class = ATTENTION_MAPPING[type(child)]
@@ -81,9 +86,10 @@ def convert_to_rala(
                 setattr(module, name, new_attention)
                 layer_idx += 1
             elif len(list(child.children())) > 0:
-                convert_module(child)
+                convert_module(child, softmax_every, num_hidden_layers)
 
-    convert_module(model)
+    model.config.softmax_every = softmax_every_n
+    convert_module(model, softmax_every_n, model.config.num_hidden_layers)
     logger.info(f"Converted {layer_idx} attention layers to RALA attention")
 
     model.config.architectures = [
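As a quick sanity check (not part of the patch), here is a minimal sketch of how the pieces above fit together: which layers `is_layer_idx_softmax` keeps on standard softmax attention for a 32-layer Llama with the default `softmax_every=6`, and how `convert_to_rala` would be invoked. The import paths are the ones introduced in this diff; the checkpoint id is a placeholder.

```python
# Hypothetical usage sketch, assuming the import paths from this patch.
from transformers import AutoModelForCausalLM

from axolotl.integrations.rala.auto.llama.modeling_rala import LlamaRalaDecoderLayer
from axolotl.integrations.rala.convert import convert_to_rala

# Layers that keep standard softmax attention for a 32-layer model with
# softmax_every=6: the first and last layers, plus a centered stride-6
# pattern over the interior layers.
softmax_layers = [
    idx
    for idx in range(32)
    if LlamaRalaDecoderLayer.is_layer_idx_softmax(32, idx, 6)
]
print(softmax_layers)  # [0, 3, 9, 15, 21, 27, 31]

# Convert every remaining attention layer to RALA attention.
# "meta-llama/Llama-2-7b-hf" is a placeholder checkpoint, not part of the PR.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = convert_to_rala(model, zero_init=False, softmax_every_n=6)
```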