Use AutoConfig with RALA

This commit is contained in:
Wing Lian
2025-01-15 23:14:47 -05:00
parent c196776996
commit d28fee7609
5 changed files with 40 additions and 12 deletions

View File

@@ -82,6 +82,7 @@ def convert_rala(cfg, cli_args, config_path):
zero_init=cli_args.zero_init, zero_init=cli_args.zero_init,
) )
model.to(cfg.device, dtype=cfg.torch_dtype) model.to(cfg.device, dtype=cfg.torch_dtype)
model.config.model_type = "llama-rala"
except Exception as exc: except Exception as exc:
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc)) LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc))
raise raise

View File

@@ -2,10 +2,8 @@
import logging import logging
# from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
from axolotl.integrations.base import BasePlugin from axolotl.integrations.base import BasePlugin
from axolotl.integrations.rala.auto.llama.modeling_rala import LlamaRALAAttention from axolotl.integrations.rala.auto.llama.modeling_rala import register_rala_model
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@@ -18,7 +16,6 @@ class RalaPlugin(BasePlugin):
def get_input_args(self): def get_input_args(self):
return "axolotl.integrations.rala.args.RalaArgs" return "axolotl.integrations.rala.args.RalaArgs"
def set_attn_config(self, cfg, model_kwargs, model_config): def pre_model_load(self, cfg):
# if cfg.rala_attention: if cfg.rala_attention:
# model_kwargs["attn_implementation"] = "rala" register_rala_model()
...

View File

@@ -9,4 +9,5 @@ class LlamaRalaConfig(LlamaConfig):
Configuration for LlamaRala model Configuration for LlamaRala model
""" """
model_type = "llama-rala"
softmax_every: int = 6 # every N-th layer applies softmax softmax_every: int = 6 # every N-th layer applies softmax

View File

@@ -19,11 +19,18 @@ from typing import List, Optional, Tuple, Union, Unpack
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch import nn from torch import nn
from transformers import Cache, GenerationMixin, LlamaModel from transformers import (
AutoConfig,
AutoModel,
AutoModelForCausalLM,
Cache,
GenerationMixin,
LlamaModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import ( from transformers.models.llama.modeling_llama import (
LLAMA_ATTENTION_CLASSES,
KwargsForCausalLM, KwargsForCausalLM,
LlamaAttention,
LlamaDynamicNTKScalingRotaryEmbedding, LlamaDynamicNTKScalingRotaryEmbedding,
LlamaLinearScalingRotaryEmbedding, LlamaLinearScalingRotaryEmbedding,
LlamaMLP, LlamaMLP,
@@ -329,7 +336,10 @@ class LlamaRalaDecoderLayer(nn.Module):
if LlamaRalaConfig.is_layer_idx_softmax( if LlamaRalaConfig.is_layer_idx_softmax(
config.num_hidden_layers, layer_idx, config.softmax_every config.num_hidden_layers, layer_idx, config.softmax_every
): ):
self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx) self.self_attn = LLAMA_ATTENTION_CLASSES[config._attn_implementation](
config=config, layer_idx=layer_idx
)
# self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx)
else: else:
self.self_attn = LlamaRALAAttention(config=config, layer_idx=layer_idx) self.self_attn = LlamaRALAAttention(config=config, layer_idx=layer_idx)
@@ -594,3 +604,20 @@ class LlamaRalaForCausalLM(LlamaPreTrainedModel, GenerationMixin):
hidden_states=outputs.hidden_states, hidden_states=outputs.hidden_states,
attentions=outputs.attentions, attentions=outputs.attentions,
) )
def register_rala_model() -> None:
    """
    Register the RALA Llama config, models and attention class with `transformers`.

    Performs four registrations, in this order:

    1. `AutoConfig`: maps the model type string "llama-rala" to `LlamaRalaConfig`.
    2. `AutoModel`: maps `LlamaRalaConfig` to `LlamaRalaModel`.
    3. `AutoModelForCausalLM`: maps `LlamaRalaConfig` to `LlamaRalaForCausalLM`.
    4. `LLAMA_ATTENTION_CLASSES`: adds `LlamaRALAAttention` under the key "rala".

    After this call the Auto* classes can resolve the "llama-rala" model type
    to the classes above.
    """
    # The config must be known to AutoConfig before models are keyed on it.
    AutoConfig.register("llama-rala", LlamaRalaConfig)

    # Each Auto* model factory gets its RALA counterpart for the same config.
    auto_model_pairs = (
        (AutoModel, LlamaRalaModel),
        (AutoModelForCausalLM, LlamaRalaForCausalLM),
    )
    for auto_cls, rala_cls in auto_model_pairs:
        auto_cls.register(LlamaRalaConfig, rala_cls)

    # Expose the RALA attention implementation in the Llama attention lookup table.
    LLAMA_ATTENTION_CLASSES["rala"] = LlamaRALAAttention

View File

@@ -7,8 +7,10 @@ from torch import nn
from transformers import PreTrainedModel from transformers import PreTrainedModel
from transformers.models.llama.modeling_llama import LlamaAttention from transformers.models.llama.modeling_llama import LlamaAttention
from axolotl.integrations.rala import LlamaRALAAttention from axolotl.integrations.rala.auto.llama.modeling_rala import (
from axolotl.integrations.rala.auto.llama.modeling_rala import LlamaRalaDecoderLayer LlamaRALAAttention,
LlamaRalaDecoderLayer,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)