use autoconfig w rala

This commit is contained in:
Wing Lian
2025-01-15 23:14:47 -05:00
parent c196776996
commit d28fee7609
5 changed files with 40 additions and 12 deletions

View File

@@ -82,6 +82,7 @@ def convert_rala(cfg, cli_args, config_path):
            zero_init=cli_args.zero_init,
        )
        model.to(cfg.device, dtype=cfg.torch_dtype)
        model.config.model_type = "llama-rala"
    except Exception as exc:
        LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
        raise
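Overwriting model.config.model_type means the converted checkpoint's config.json advertises "llama-rala" instead of "llama", which is the key the Auto* registration below hooks into. A minimal check, assuming model is the converted model from this hunk and "./rala-out" is an arbitrary, illustrative output directory:

# Hedged sketch: "./rala-out" is an illustrative path, not part of this commit.
import json
import os

model.save_pretrained("./rala-out")
with open(os.path.join("./rala-out", "config.json")) as f:
    print(json.load(f)["model_type"])  # -> "llama-rala"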

View File

@@ -2,10 +2,8 @@
import logging
# from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
from axolotl.integrations.base import BasePlugin
from axolotl.integrations.rala.auto.llama.modeling_rala import LlamaRALAAttention
from axolotl.integrations.rala.auto.llama.modeling_rala import register_rala_model
LOG = logging.getLogger(__name__)
@@ -18,7 +16,6 @@ class RalaPlugin(BasePlugin):
    def get_input_args(self):
        return "axolotl.integrations.rala.args.RalaArgs"

    def set_attn_config(self, cfg, model_kwargs, model_config):
        # if cfg.rala_attention:
        #     model_kwargs["attn_implementation"] = "rala"
        ...

    def pre_model_load(self, cfg):
        if cfg.rala_attention:
            register_rala_model()
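With the commented-out attn_implementation override dropped, the plugin now relies entirely on Auto registration: pre_model_load calls register_rala_model() whenever cfg.rala_attention is set, so any later from_pretrained call can resolve the "llama-rala" model type. A rough sketch of that ordering, substituting a plain namespace for axolotl's real config object and assuming RalaPlugin is exported from axolotl.integrations.rala:

# Hedged sketch: SimpleNamespace stands in for axolotl's real config object,
# and the RalaPlugin import path is assumed from the surrounding diff.
from types import SimpleNamespace

from axolotl.integrations.rala import RalaPlugin

cfg = SimpleNamespace(rala_attention=True)
plugin = RalaPlugin()
plugin.pre_model_load(cfg)  # registers the RALA Auto* classes before the model is built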

View File

@@ -9,4 +9,5 @@ class LlamaRalaConfig(LlamaConfig):
    Configuration for LlamaRala model
    """

    model_type = "llama-rala"
    softmax_every: int = 6  # every N-th layer applies softmax
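softmax_every controls how often a vanilla softmax attention layer is interleaved among RALA layers; is_layer_idx_softmax (used in the decoder layer hunk below) applies it per layer index. Its exact rule is not part of this diff, so the helper below is only an illustrative guess:

# Illustrative guess only: the real LlamaRalaConfig.is_layer_idx_softmax may use a different rule.
def is_layer_idx_softmax_guess(num_hidden_layers: int, layer_idx: int, softmax_every: int) -> bool:
    # num_hidden_layers is unused in this guess; it only mirrors the call in the decoder layer.
    # e.g. with softmax_every=6, every sixth layer keeps standard softmax attention
    return (layer_idx + 1) % softmax_every == 0

softmax_layers = [i for i in range(32) if is_layer_idx_softmax_guess(32, i, 6)]
print(softmax_layers)  # [5, 11, 17, 23, 29] for a 32-layer model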

View File

@@ -19,11 +19,18 @@ from typing import List, Optional, Tuple, Union, Unpack
import torch
import torch.nn.functional as F
from torch import nn
from transformers import Cache, GenerationMixin, LlamaModel
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    Cache,
    GenerationMixin,
    LlamaModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import (
    LLAMA_ATTENTION_CLASSES,
    KwargsForCausalLM,
    LlamaAttention,
    LlamaDynamicNTKScalingRotaryEmbedding,
    LlamaLinearScalingRotaryEmbedding,
    LlamaMLP,
@@ -329,7 +336,10 @@ class LlamaRalaDecoderLayer(nn.Module):
        if LlamaRalaConfig.is_layer_idx_softmax(
            config.num_hidden_layers, layer_idx, config.softmax_every
        ):
            self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
            self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](
                config=config, layer_idx=layer_idx
            )
            # self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
        else:
            self.self_attn = LlamaRALAAttention(config=config, layer_idx=layer_idx)
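Looking up LLAMA_ATTENTION_CLASSES[config._attn_implementation] lets the interleaved softmax layers honor whichever backend the run is configured for (eager, sdpa, flash_attention_2) instead of being pinned to the eager LlamaAttention. The mapping only exists in transformers releases that still export LLAMA_ATTENTION_CLASSES; a quick way to inspect what the lookup resolves to under that assumption:

# Assumes a transformers release that still exposes LLAMA_ATTENTION_CLASSES.
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES

print(sorted(LLAMA_ATTENTION_CLASSES))  # e.g. ['eager', 'flash_attention_2', 'sdpa'] ("rala" is added after registration)
print(LLAMA_ATTENTION_CLASSES["sdpa"])  # the class a config with _attn_implementation="sdpa" selects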
@@ -594,3 +604,20 @@ class LlamaRalaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


def register_rala_model() -> None:
    """
    Register RALA attention components with the transformers library.

    This function registers the RALA configuration and model classes with the
    Auto* classes from `transformers`, making them available through the
    standard model loading pipeline.
    """
    # Register configs
    AutoConfig.register("llama-rala", LlamaRalaConfig)

    # Register models
    AutoModel.register(LlamaRalaConfig, LlamaRalaModel)
    AutoModelForCausalLM.register(LlamaRalaConfig, LlamaRalaForCausalLM)

    LLAMA_ATTENTION_CLASSES["rala"] = LlamaRALAAttention

View File

@@ -7,8 +7,10 @@ from torch import nn
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import LlamaAttention
from axolotl.integrations.rala import LlamaRALAAttention
from axolotl.integrations.rala.auto.llama.modeling_rala import LlamaRalaDecoderLayer
from axolotl.integrations.rala.auto.llama.modeling_rala import (
    LlamaRALAAttention,
    LlamaRalaDecoderLayer,
)
logger = logging.getLogger(__name__)