From d28fee7609f7becc74aa13425f3862ca0f990059 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Wed, 15 Jan 2025 23:14:47 -0500
Subject: [PATCH] use autoconfig w rala

---
 src/axolotl/cli/integrations/convert_rala.py  |  1 +
 src/axolotl/integrations/rala/__init__.py     | 11 +++----
 .../rala/auto/llama/configuration_rala.py     |  1 +
 .../rala/auto/llama/modeling_rala.py          | 33 +++++++++++++++++--
 src/axolotl/integrations/rala/convert.py      |  6 ++--
 5 files changed, 40 insertions(+), 12 deletions(-)

diff --git a/src/axolotl/cli/integrations/convert_rala.py b/src/axolotl/cli/integrations/convert_rala.py
index b2d7fa1d3..47eed7917 100644
--- a/src/axolotl/cli/integrations/convert_rala.py
+++ b/src/axolotl/cli/integrations/convert_rala.py
@@ -82,6 +82,7 @@ def convert_rala(cfg, cli_args, config_path):
             zero_init=cli_args.zero_init,
         )
         model.to(cfg.device, dtype=cfg.torch_dtype)
+        model.config.model_type = "llama-rala"
     except Exception as exc:
         LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
         raise
diff --git a/src/axolotl/integrations/rala/__init__.py b/src/axolotl/integrations/rala/__init__.py
index fddc3e0bf..235d0cc22 100644
--- a/src/axolotl/integrations/rala/__init__.py
+++ b/src/axolotl/integrations/rala/__init__.py
@@ -2,10 +2,8 @@
 
 import logging
 
-# from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
-
 from axolotl.integrations.base import BasePlugin
-from axolotl.integrations.rala.auto.llama.modeling_rala import LlamaRALAAttention
+from axolotl.integrations.rala.auto.llama.modeling_rala import register_rala_model
 
 LOG = logging.getLogger(__name__)
 
@@ -18,7 +16,6 @@ class RalaPlugin(BasePlugin):
     def get_input_args(self):
         return "axolotl.integrations.rala.args.RalaArgs"
 
-    def set_attn_config(self, cfg, model_kwargs, model_config):
-        # if cfg.rala_attention:
-        #     model_kwargs["attn_implementation"] = "rala"
-        ...
+    def pre_model_load(self, cfg):
+        if cfg.rala_attention:
+            register_rala_model()
diff --git a/src/axolotl/integrations/rala/auto/llama/configuration_rala.py b/src/axolotl/integrations/rala/auto/llama/configuration_rala.py
index 04d485a15..afb0c29e5 100644
--- a/src/axolotl/integrations/rala/auto/llama/configuration_rala.py
+++ b/src/axolotl/integrations/rala/auto/llama/configuration_rala.py
@@ -9,4 +9,5 @@ class LlamaRalaConfig(LlamaConfig):
     Configuration for LlamaRala model
     """
 
+    model_type = "llama-rala"
     softmax_every: int = 6  # every N-th layer applies softmax
diff --git a/src/axolotl/integrations/rala/auto/llama/modeling_rala.py b/src/axolotl/integrations/rala/auto/llama/modeling_rala.py
index b951f47c4..60964a343 100644
--- a/src/axolotl/integrations/rala/auto/llama/modeling_rala.py
+++ b/src/axolotl/integrations/rala/auto/llama/modeling_rala.py
@@ -19,11 +19,18 @@ from typing import List, Optional, Tuple, Union, Unpack
 import torch
 import torch.nn.functional as F
 from torch import nn
-from transformers import Cache, GenerationMixin, LlamaModel
+from transformers import (
+    AutoConfig,
+    AutoModel,
+    AutoModelForCausalLM,
+    Cache,
+    GenerationMixin,
+    LlamaModel,
+)
 from transformers.modeling_outputs import CausalLMOutputWithPast
 from transformers.models.llama.modeling_llama import (
+    LLAMA_ATTENTION_CLASSES,
     KwargsForCausalLM,
-    LlamaAttention,
     LlamaDynamicNTKScalingRotaryEmbedding,
     LlamaLinearScalingRotaryEmbedding,
     LlamaMLP,
@@ -329,7 +336,10 @@ class LlamaRalaDecoderLayer(nn.Module):
         if LlamaRalaConfig.is_layer_idx_softmax(
             config.num_hidden_layers, layer_idx, config.softmax_every
         ):
-            self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
+            self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](
+                config=config, layer_idx=layer_idx
+            )
+            # self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
         else:
             self.self_attn = LlamaRALAAttention(config=config, layer_idx=layer_idx)
 
@@ -594,3 +604,20 @@ class LlamaRalaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
             hidden_states=outputs.hidden_states,
             attentions=outputs.attentions,
         )
+
+
+def register_rala_model() -> None:
+    """
+    Register the RALA attention components with the transformers library.
+    This function registers the RALA configuration and model classes with the
+    Auto* classes from `transformers`, making them available through the
+    standard model loading pipeline.
+    """
+    # Register configs
+    AutoConfig.register("llama-rala", LlamaRalaConfig)
+
+    # Register models
+    AutoModel.register(LlamaRalaConfig, LlamaRalaModel)
+    AutoModelForCausalLM.register(LlamaRalaConfig, LlamaRalaForCausalLM)
+
+    LLAMA_ATTENTION_CLASSES["rala"] = LlamaRALAAttention
diff --git a/src/axolotl/integrations/rala/convert.py b/src/axolotl/integrations/rala/convert.py
index 2f081d8e8..da57074bf 100644
--- a/src/axolotl/integrations/rala/convert.py
+++ b/src/axolotl/integrations/rala/convert.py
@@ -7,8 +7,10 @@
 from torch import nn
 from transformers import PreTrainedModel
 from transformers.models.llama.modeling_llama import LlamaAttention
 
-from axolotl.integrations.rala import LlamaRALAAttention
-from axolotl.integrations.rala.auto.llama.modeling_rala import LlamaRalaDecoderLayer
+from axolotl.integrations.rala.auto.llama.modeling_rala import (
+    LlamaRALAAttention,
+    LlamaRalaDecoderLayer,
+)
 
 logger = logging.getLogger(__name__)
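
Usage sketch, assuming a converted checkpoint directory at ./outputs/llama-rala (hypothetical path): once register_rala_model() has run, a saved config whose model_type is "llama-rala" resolves through the standard Auto* loading path.

from transformers import AutoConfig, AutoModelForCausalLM

from axolotl.integrations.rala.auto.llama.modeling_rala import register_rala_model

# Register the "llama-rala" config and model classes with AutoConfig / AutoModel /
# AutoModelForCausalLM, and add the RALA attention class to LLAMA_ATTENTION_CLASSES.
register_rala_model()

# Hypothetical path to a checkpoint produced by the conversion CLI; its config.json
# carries model_type = "llama-rala", so the Auto* classes pick the RALA classes.
config = AutoConfig.from_pretrained("./outputs/llama-rala")
model = AutoModelForCausalLM.from_pretrained("./outputs/llama-rala")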