use AutoConfig with RALA
@@ -82,6 +82,7 @@ def convert_rala(cfg, cli_args, config_path):
            zero_init=cli_args.zero_init,
        )
        model.to(cfg.device, dtype=cfg.torch_dtype)
        model.config.model_type = "llama-rala"
    except Exception as exc:
        LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
        raise

@@ -2,10 +2,8 @@

import logging

# from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES

from axolotl.integrations.base import BasePlugin
from axolotl.integrations.rala.auto.llama.modeling_rala import LlamaRALAAttention
from axolotl.integrations.rala.auto.llama.modeling_rala import register_rala_model

LOG = logging.getLogger(__name__)

@@ -18,7 +16,6 @@ class RalaPlugin(BasePlugin):
    def get_input_args(self):
        return "axolotl.integrations.rala.args.RalaArgs"

    def set_attn_config(self, cfg, model_kwargs, model_config):
        # if cfg.rala_attention:
        #     model_kwargs["attn_implementation"] = "rala"
        ...

    def pre_model_load(self, cfg):
        if cfg.rala_attention:
            register_rala_model()

@@ -9,4 +9,5 @@ class LlamaRalaConfig(LlamaConfig):
    Configuration for LlamaRala model
    """

    model_type = "llama-rala"
    softmax_every: int = 6  # every N-th layer applies softmax

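The new softmax_every field drives the layer interleaving used in the decoder-layer hunk further down, where LlamaRalaConfig.is_layer_idx_softmax(num_hidden_layers, layer_idx, softmax_every) decides which layers keep standard attention. That helper is not part of this diff; the following is only a minimal sketch of one plausible schedule (a plain modulo rule), and the real staticmethod in modeling_rala.py may differ:

def is_layer_idx_softmax(num_hidden_layers: int, layer_idx: int, softmax_every: int) -> bool:
    # Hypothetical sketch: every softmax_every-th layer keeps softmax attention, the rest use RALA.
    # num_hidden_layers is accepted because the caller passes it; this sketch leaves it unused.
    del num_hidden_layers
    return softmax_every > 0 and (layer_idx + 1) % softmax_every == 0

With softmax_every = 6 this would mark layers 5, 11, 17, 23, 29, ... (0-based) as softmax layers.
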
@@ -19,11 +19,18 @@ from typing import List, Optional, Tuple, Union, Unpack
import torch
import torch.nn.functional as F
from torch import nn
from transformers import Cache, GenerationMixin, LlamaModel
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    Cache,
    GenerationMixin,
    LlamaModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import (
    LLAMA_ATTENTION_CLASSES,
    KwargsForCausalLM,
    LlamaAttention,
    LlamaDynamicNTKScalingRotaryEmbedding,
    LlamaLinearScalingRotaryEmbedding,
    LlamaMLP,

@@ -329,7 +336,10 @@ class LlamaRalaDecoderLayer(nn.Module):
        if LlamaRalaConfig.is_layer_idx_softmax(
            config.num_hidden_layers, layer_idx, config.softmax_every
        ):
            self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
            self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](
                config=config, layer_idx=layer_idx
            )
            # self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
        else:
            self.self_attn = LlamaRALAAttention(config=config, layer_idx=layer_idx)

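Replacing the hard-coded LlamaAttention with a lookup through LLAMA_ATTENTION_CLASSES[config._attn_implementation] lets the softmax layers honor whichever attention backend the config requests (eager, sdpa, flash_attention_2); it is also the same mapping that register_rala_model() below extends with a "rala" entry. A small illustration, assuming a transformers version that still exposes this dict (newer releases have removed it):

from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES

# The decoder layer performs this lookup with config._attn_implementation as the key.
print(sorted(LLAMA_ATTENTION_CLASSES))      # e.g. ['eager', 'flash_attention_2', 'sdpa']
attn_cls = LLAMA_ATTENTION_CLASSES["sdpa"]  # -> LlamaSdpaAttention on those versions
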
@@ -594,3 +604,20 @@ class LlamaRalaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )


def register_rala_model() -> None:
    """
    Register RALA attention components with the transformers library.

    This function registers the RALA configuration and model classes
    with the Auto* classes from `transformers`, making them available through the
    standard model loading pipeline.
    """
    # Register configs
    AutoConfig.register("llama-rala", LlamaRalaConfig)

    # Register models
    AutoModel.register(LlamaRalaConfig, LlamaRalaModel)
    AutoModelForCausalLM.register(LlamaRalaConfig, LlamaRalaForCausalLM)

    LLAMA_ATTENTION_CLASSES["rala"] = LlamaRALAAttention

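Because register_rala_model() wires the config and model classes into the Auto* registries, a checkpoint whose config.json declares model_type "llama-rala" (as convert_rala now sets on the converted model) resolves through the standard loading path. A rough usage sketch; the checkpoint path is a placeholder:

from transformers import AutoConfig, AutoModelForCausalLM

# register_rala_model is the function added above; import it from modeling_rala in a standalone script.
register_rala_model()

# "path/to/rala-checkpoint" is a placeholder for a directory whose config.json
# contains "model_type": "llama-rala", e.g. one produced by convert_rala.
config = AutoConfig.from_pretrained("path/to/rala-checkpoint")            # -> LlamaRalaConfig
model = AutoModelForCausalLM.from_pretrained("path/to/rala-checkpoint")   # -> LlamaRalaForCausalLM
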
@@ -7,8 +7,10 @@ from torch import nn
from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import LlamaAttention

from axolotl.integrations.rala import LlamaRALAAttention
from axolotl.integrations.rala.auto.llama.modeling_rala import LlamaRalaDecoderLayer
from axolotl.integrations.rala.auto.llama.modeling_rala import (
    LlamaRALAAttention,
    LlamaRalaDecoderLayer,
)

logger = logging.getLogger(__name__)