From 1c5b78621ccf869b0750063ccb75f3a23780f7b6 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Sat, 21 Dec 2024 00:27:59 -0500 Subject: [PATCH] fix forward sig more fixes --- src/axolotl/cli/integrations/convert_rala.py | 197 +++++++++++++ src/axolotl/cli/main.py | 13 + src/axolotl/core/trainer_builder.py | 3 +- src/axolotl/integrations/base.py | 27 ++ src/axolotl/integrations/rala/__init__.py | 34 +++ src/axolotl/integrations/rala/args.py | 14 + src/axolotl/integrations/rala/convert.py | 88 ++++++ src/axolotl/integrations/rala/rala_attn.py | 278 ++++++++++++++++++ .../monkeypatch/attention/differential.py | 1 + src/axolotl/utils/models.py | 6 +- 10 files changed, 658 insertions(+), 3 deletions(-) create mode 100644 src/axolotl/cli/integrations/convert_rala.py create mode 100644 src/axolotl/integrations/rala/__init__.py create mode 100644 src/axolotl/integrations/rala/args.py create mode 100644 src/axolotl/integrations/rala/convert.py create mode 100644 src/axolotl/integrations/rala/rala_attn.py diff --git a/src/axolotl/cli/integrations/convert_rala.py b/src/axolotl/cli/integrations/convert_rala.py new file mode 100644 index 000000000..149cf70cb --- /dev/null +++ b/src/axolotl/cli/integrations/convert_rala.py @@ -0,0 +1,197 @@ +"""CLI to convert a transformers model's attns to rala attns.""" +import logging +import warnings +from pathlib import Path +from time import time +from typing import Union + +import fire +import torch +import yaml +from colorama import Fore +from dotenv import load_dotenv +from transformers import HfArgumentParser + +from axolotl.cli import load_cfg, print_axolotl_text_art +from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer +from axolotl.integrations.rala.convert import convert_to_rala +from axolotl.utils.yaml import dump_yaml_preserved_order + +LOG = logging.getLogger(__name__) + + +def test_inference(model, tokenizer, prompt="The quick brown fox"): + """Run test inference and return generation time""" + try: + inputs = tokenizer(prompt, return_tensors="pt") + inputs = { + k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items() + } + + start = time() + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=20, + num_beams=1, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + use_cache=False, + ) + elapsed = time() - start + + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + LOG.info("Prompt: %s", prompt) + LOG.info("Generated: %s", generated_text) + LOG.info("Generation time: %.2fs", elapsed) + + return elapsed, generated_text + + except Exception as exc: + LOG.error("Inference failed: %s", str(exc)) + raise + + +def convert_rala(cfg, cli_args, config_path): + debug_info = {} + + # Load model and tokenizer + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) + model.to(cfg.device, dtype=cfg.torch_dtype) + + # Log original model info + LOG.info( + "Original model config:\n\t- Hidden size: %d\n\t- Num attention heads: %d", + model.config.hidden_size, + model.config.num_attention_heads, + ) + + # Test original model + if cli_args.debug: + LOG.info("attention layers to RALA attention") + debug_info["orig_time"], debug_info["orig_text"] = test_inference( + model, tokenizer + ) + + # Convert attention + try: + model = convert_to_rala( + model=model, + zero_init=cli_args.zero_init, + ) + model.to(cfg.device, dtype=cfg.torch_dtype) + except Exception as exc: + 
LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc)) + raise + + # Test converted model + if cli_args.debug: + LOG.info("Testing converted model...") + debug_info["conv_time"], debug_info["conv_text"] = test_inference( + model, tokenizer + ) + + # Save if requested + if cfg.output_dir: + # Save model and tokenizer + LOG.info("Saving converted model to %s", cfg.output_dir) + model.save_pretrained(cfg.output_dir) + tokenizer.save_pretrained(cfg.output_dir) + + # Modify config to reflect new path / differential attention + output_config_path = Path(cfg.output_dir) / "axolotl_config.yml" + LOG.info("Saving updated config to %s", output_config_path) + + with open(config_path, "r", encoding="utf-8") as file: + modified_cfg = yaml.safe_load(file) or {} + + modified_cfg["base_model"] = cfg.output_dir + modified_cfg["rala_attention"] = True + plugin_class = ( + "axolotl.integrations.rala.RalaPlugin" + ) + if "plugins" in modified_cfg: + modified_cfg["plugins"].append(plugin_class) + else: + modified_cfg["plugins"] = [plugin_class] + + dump_yaml_preserved_order( + data=modified_cfg, + reference_yaml_path=config_path, + output_path=output_config_path, + ) + else: + LOG.info("Not saving converted model to disk") + LOG.info("Pass --output-dir path/to/save to save model") + + if cli_args.debug: + LOG.info( + Fore.GREEN + + "Conversion successful!\n" + + f"Original generation time: {debug_info['orig_time']:.2f}s\n" + + f"Converted generation time: {debug_info['conv_time']:.2f}s" + + Fore.RESET + ) + + if debug_info["orig_text"] == debug_info["conv_text"]: + LOG.info( + Fore.GREEN + + "Generations match!\n" + + "Model generation:\n" + + "*" * 50 + + "\n" + + f"{debug_info['orig_text']}\n" + + "*" * 50 + + "\n" + + Fore.RESET + ) + debug_info["generations_match"] = True + else: + message = ( + "Generations do not match.\n" + + "Original generation:\n" + + "*" * 50 + + "\n" + + f"{debug_info['orig_text']}\n" + + "*" * 50 + + "\n" + + "Converted generation:\n" + + "*" * 50 + + "\n" + + f"{debug_info['conv_text']}\n" + + "*" * 50 + + "\n" + ) + debug_info["generations_match"] = False + + if cli_args.zero_init and not cli_args.sublayer_norm: + LOG.info(Fore.RED + message + Fore.RESET) + debug_info["match_expected"] = True + else: + LOG.info( + Fore.YELLOW + + message + + "However, this is expected since --zero-init" + + " and --no-sublayer-norm were not passed." 
+ + Fore.RESET + ) + debug_info["match_expected"] = False + + return model, debug_info + + +def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): + print_axolotl_text_art() + + cfg = load_cfg(config, **kwargs) + parser = HfArgumentParser(ConvertDiffTransformerCliArgs) + cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True) + + convert_rala(cfg, cli_args, config) + + +if __name__ == "__main__": + load_dotenv() + fire.Fire(do_cli) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 00d075286..7dfe2094f 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -261,6 +261,19 @@ def convert_diff_transformer(config: str, **kwargs): do_cli(config=config, **kwargs) +@cli.command() +@click.argument("config", type=click.Path(exists=True, path_type=str)) +@add_options_from_dataclass(ConvertDiffTransformerCliArgs) +@add_options_from_config(AxolotlInputConfig) +def convert_rala(config: str, **kwargs): + """Convert model attention layers to RALA attention layers.""" + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + from axolotl.cli.integrations.convert_rala import do_cli + + do_cli(config=config, **kwargs) + + @cli.command() @click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"])) @click.option("--dest", help="Destination directory") diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 0743d4c92..39d3b631f 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -466,6 +466,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer): self.args.loraplus_lr_ratio is None and self.args.embedding_lr_scale is None and self.args.embedding_lr is None + and self.args.lr_groups is None and self.args.alternate_optimizer not in [ "optimi_adamw", @@ -481,7 +482,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer): if self.optimizer is None: # pylint: disable=access-member-before-definition decay_parameters = self.get_decay_parameter_names(opt_model) params = { - "to_weight_decay": {}, # LayerNorm and bias + "to_weight_decay": {}, # LayerNorm except bias "embeddings": {}, # lm_head, embed_tokens, "no_weight_decay": {}, } diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index a271c59d1..6068fdc34 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -75,6 +75,21 @@ class BasePlugin: None """ + def set_attn_config( + self, cfg, model_kwargs, model_config + ): # pylint: disable=unused-argument + """ + Sets attention configuration for the model. + + Parameters: + cfg (dict): The configuration for the plugin. + model_kwargs (dict): The model kwargs for the plugin. + model_config (object): The model configuration. + + Returns: + None + """ + def post_model_load(self, cfg, model): # pylint: disable=unused-argument """ Performs actions after the model is loaded. @@ -304,6 +319,18 @@ class PluginManager: for plugin in self.plugins.values(): plugin.pre_model_load(cfg) + def set_attn_config(self, cfg, model_kwargs, model_config): + """ + modifies the attention configuration of the model kwargs for loading + + Parameters: + cfg (dict): The configuration for the plugins. + model_kwargs (dict): The model's kwargs for construction the model + model_config (dict): The model's configuration. 
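+
+        Note: the model loader calls this hook while it assembles the model kwargs
+        (see ``ModelLoader`` in ``axolotl/utils/models.py`` below), so a plugin
+        such as the RALA plugin can override ``attn_implementation`` before the
+        model is instantiated.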
+ """ + for plugin in self.plugins.values(): + plugin.set_attn_config(cfg, model_kwargs, model_config) + def post_model_load(self, cfg, model): """ Calls the post_model_load method of all registered plugins. diff --git a/src/axolotl/integrations/rala/__init__.py b/src/axolotl/integrations/rala/__init__.py new file mode 100644 index 000000000..527f9d910 --- /dev/null +++ b/src/axolotl/integrations/rala/__init__.py @@ -0,0 +1,34 @@ +"""Definition of RALA plugin.""" + +import logging + +from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES + +from axolotl.integrations.base import BasePlugin +from axolotl.integrations.rala.rala_attn import LlamaRALAAttention + +LOG = logging.getLogger(__name__) + + +class RalaPlugin(BasePlugin): + """ + Plugin for Rala integration with Axolotl. + """ + + def get_input_args(self): + return "axolotl.integrations.rala.args.RalaArgs" + + def pre_model_load(self, cfg): + """Apply differential attention patch before model loading if enabled.""" + if cfg.rala_attention: + LLAMA_ATTENTION_CLASSES["rala"] = LlamaRALAAttention + + from axolotl.monkeypatch.attention.differential import ( + patch_llama_attention_classes, + ) + + patch_llama_attention_classes() + + def set_attn_config(self, cfg, model_kwargs, model_config): + if cfg.rala_attention: + model_kwargs["attn_implementation"] = "rala" diff --git a/src/axolotl/integrations/rala/args.py b/src/axolotl/integrations/rala/args.py new file mode 100644 index 000000000..384e68b33 --- /dev/null +++ b/src/axolotl/integrations/rala/args.py @@ -0,0 +1,14 @@ +"""Module for handling RALA input arguments.""" + +import logging +from typing import Optional + +from pydantic import BaseModel + +LOG = logging.getLogger(__name__) + + +class RalaArgs(BaseModel): + """Input args for RALA.""" + + rala_attention: Optional[bool] = None diff --git a/src/axolotl/integrations/rala/convert.py b/src/axolotl/integrations/rala/convert.py new file mode 100644 index 000000000..a9a0bb956 --- /dev/null +++ b/src/axolotl/integrations/rala/convert.py @@ -0,0 +1,88 @@ +""" +conversion for llama models to use RALA attention +""" +import logging + +from torch import nn +from transformers import PreTrainedModel +from transformers.models.llama.modeling_llama import LlamaAttention + +from axolotl.integrations.rala import LlamaRALAAttention + +logger = logging.getLogger(__name__) + +ATTENTION_MAPPING = { + LlamaAttention: LlamaRALAAttention, +} + + +def copy_attention_weights( + old_attn, + new_attn, + zero_init: bool = False, +) -> None: + """ + Copy weights from old attention layer to new RALA layer. 
+    Copies q/k/v/o projections and re-initializes the RALA-specific phi projection.
+    """
+    new_attn.q_proj.weight.data.copy_(old_attn.q_proj.weight.data)
+    new_attn.k_proj.weight.data.copy_(old_attn.k_proj.weight.data)
+    new_attn.v_proj.weight.data.copy_(old_attn.v_proj.weight.data)
+    new_attn.o_proj.weight.data.copy_(old_attn.o_proj.weight.data)
+
+    # Initialize the RALA-specific phi projection (weight zeroed when zero_init is set)
+    if zero_init:
+        nn.init.zeros_(new_attn.phi.weight)
+    else:
+        nn.init.normal_(new_attn.phi.weight)
+        nn.init.zeros_(new_attn.phi.bias)
+
+    logger.debug(
+        "Copied attention weights from %s to %s",
+        type(old_attn).__name__,
+        type(new_attn).__name__,
+    )
+
+
+def convert_to_rala(
+    model: PreTrainedModel,
+    zero_init: bool = False,
+) -> PreTrainedModel:
+    """Convert a pre-trained model's attention layers to RALA attention"""
+    layer_idx = 0

+    def convert_module(module):
+        nonlocal layer_idx
+
+        # Iterate through module children, convert any attention layers to RALA attention
+        for name, child in module.named_children():
+            if isinstance(child, tuple(ATTENTION_MAPPING.keys())):
+                # Choose the appropriate RALA attention class
+                # pylint: disable=duplicate-code
+                attention_class = ATTENTION_MAPPING[type(child)]
+
+                layer_type = type(child).__name__
+                logger.info(
+                    f"Converting attention layer {layer_idx}: {layer_type} to {attention_class.__name__}"
+                )
+
+                # Create new RALA attention layer
+                new_attention = attention_class(
+                    config=module.config if hasattr(module, "config") else model.config,
+                    layer_idx=layer_idx,
+                )
+
+                # Copy weights from old attention to new attention
+                new_attention.to(child.q_proj.weight.device)
+                copy_attention_weights(child, new_attention, zero_init=zero_init)
+
+                # Replace the layer
+                setattr(module, name, new_attention)
+                layer_idx += 1
+            elif len(list(child.children())) > 0:
+                convert_module(child)
+
+    convert_module(model)
+    logger.info(f"Converted {layer_idx} attention layers to RALA attention")
+
+    return model
diff --git a/src/axolotl/integrations/rala/rala_attn.py b/src/axolotl/integrations/rala/rala_attn.py
new file mode 100644
index 000000000..d73806202
--- /dev/null
+++ b/src/axolotl/integrations/rala/rala_attn.py
@@ -0,0 +1,278 @@
+from typing import Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import Cache
+from transformers.models.llama.modeling_llama import (
+    LlamaDynamicNTKScalingRotaryEmbedding,
+    LlamaLinearScalingRotaryEmbedding,
+    LlamaRotaryEmbedding,
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
+
+
+def kappa(x: torch.Tensor) -> torch.Tensor:
+    """
+    The paper uses κ(x) = ELU(x) + 1.
+    x is assumed to be [batch, n_heads, seq_len, head_dim].
+    """
+    return F.elu(x) + 1
+
+
+class LlamaRALAAttention(nn.Module):
+    """
+    LlamaAttention replaced with Rank-Augmented Linear Attention (RALA).
+    Adapted from the standard LlamaAttention for demonstration.
+    **Not** a fully drop-in replacement if you need caching/TP.
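+
+    Roughly, the forward pass below computes, per head (N = seq_len,
+    kappa(x) = ELU(x) + 1, RoPE applied to Q/K first):
+
+        Q_g   = mean_i kappa(Q_i)                         # global query
+        alpha = N * softmax_j(Q_g . kappa(K_j))           # per-key weights
+        M     = sum_j alpha_j * (kappa(K_j) outer V_j)    # head_dim x head_dim
+        Y_i   = phi(X_i) * (kappa(Q_i) @ M)               # elementwise gate
+
+    followed by the usual o_proj.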
+ """ + + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + # Same Q, K, V, output projections + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + self.hidden_size, self.hidden_size, bias=config.attention_bias + ) + + # We will preserve rope usage + self._init_rope() + + # A simple φ-projection for RALA: + # The paper uses φ(x) as a linear transform or identity. We'll do a linear: + self.phi = nn.Linear(self.hidden_size, self.hidden_size, bias=True) + + def _init_rope(self): + # Standard Llama rope logic + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, # pylint: disable=unused-argument + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, # pylint: disable=unused-argument + ): + """ + RALA forward pass. + This version omits incremental decoding with `past_key_value` for simplicity + (linear attention caching is non-trivial). 
+ """ + bsz, q_len, _ = hidden_states.size() + + # Standard Q, K, V + query_states = self.q_proj(hidden_states) # [b, seq, n_heads*dim] + key_states = self.k_proj(hidden_states) # [b, seq, n_kv_heads*dim] + value_states = self.v_proj(hidden_states) # [b, seq, n_kv_heads*dim] + + # Reshape to [b, n_heads, seq_len, head_dim] + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + # Apply RoPE (rotary embeddings) just as in standard Llama + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + # If you still want to handle the repeated KV for multi-group setups: + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # Now we apply RALA. + + # 1) Apply κ(.) to Q,K: shape [b, n_heads, seq_len, head_dim] + Q_kappa = kappa(query_states) + K_kappa = kappa(key_states) + + # 2) Compute global query Q_g = average of Q_kappa across seq_len => [b, n_heads, head_dim] + # The paper denotes Q_g = (1/N) Σ_i Q_kappa_i + seq_len_float = float(q_len) # for scaling + Q_g = Q_kappa.mean(dim=2) # [b, n_heads, head_dim] + + # 3) Compute alpha_j for each token j in [0..seq_len-1] + # alpha_j = N * softmax( Q_g · K_kappa_j^T ), shape => [b, n_heads, seq_len] + # Dot product over head_dim + # K_kappa is [b, n_heads, seq_len, head_dim], Q_g is [b, n_heads, head_dim] + # We'll do an einsum or transpose to produce logits [b, n_heads, seq_len] + + # Dot product across the last dimension (d_head), resulting in shape [b, n_heads, seq_len] + # logits = torch.einsum("bnh, bnsh -> bns", Q_g, K_kappa) # [b, n_heads, seq_len] + logits = (Q_g.unsqueeze(2) * K_kappa).sum(dim=-1) # -> [b, n_heads, seq_len] # identical to above but torch.compile should work + + # 4) Incorporate causal or padding mask if provided. + # In standard Llama, attention_mask is broadcast as [b, 1, seq_len, seq_len] or similar. + # For RALA, we only do a single softmax over "j" dimension. We can add the mask to logits. + # Caution: This might not replicate strict causal linear attention. It's a best-effort approach. + if attention_mask is not None: + # Usually Llama's causal mask is [b, 1, q_len, kv_len] with 0 or -inf + # We want shape [b, n_heads, seq_len], so we can broadcast accordingly: + # e.g., attention_mask: [b, 1, q_len, seq_len] + # We pick the slice that corresponds to q_len vs. kv_len. + # Typically the last two dims are (q_len, kv_len). We want the kv_len dimension to be `seq_len`. + # We'll do something like: + if attention_mask.dim() == 4: + # attention_mask: [b, 1, q_len, kv_len] + # if q_len == kv_len, we can do attention_mask[:, :, :, :seq_len], then squeeze dims + mask_2d = attention_mask[:, 0, :, :q_len] # [b, q_len, seq_len] + # we only want [b, n_heads, seq_len], so we must broadcast over q_len if needed + # but in this snippet, we do a single alpha_j for each j *per head*, + # ignoring per-token Q_i. So there's a mismatch. + # A simpler approach is to apply the mask for the entire sequence if a token j is invalid for ANY i. + # That is approximate. We'll just pick the first row of q_len, or do min across i dimension... 
+ # For demonstration, let's sum or min across i dimension to see if j is valid for ANY i. + # Or we do a "causal" approach: all tokens j>i get masked. But there's no direct i index here in alpha_j. + # We'll just do a rough approach, e.g. mask = min across the q_len dimension: + mask_1d = torch.min(mask_2d, dim=1)[ + 0 + ] # [b, seq_len], picking the worst mask across query positions + # broadcast for n_heads + mask_1d = mask_1d.unsqueeze(1).expand( + -1, self.num_heads, -1 + ) # [b, n_heads, seq_len] + logits = logits + mask_1d + else: + # Possibly it's [b, seq_len]. Then we just broadcast to [b,n_heads,seq_len]. + mask_1d = attention_mask # [b, seq_len] + mask_1d = mask_1d.unsqueeze(1).expand(-1, self.num_heads, -1) + logits = logits + mask_1d + + alpha = F.softmax(logits, dim=-1) # [b, n_heads, seq_len] + # multiply by seq_len per the formula + alpha = alpha * seq_len_float + + # 5) Construct the outer-sum: Σ_j alpha_j * (K_kappa_j^T V_j) + # The paper shows a d×d matrix formed per head. + # K_kappa: [b, n_heads, seq_len, head_dim], V: [b, n_heads, seq_len, head_dim] + # For each j, do outer product K_kappa_j (d×1) × V_j^T (1×d) => d×d + # Then multiply by alpha_j and sum over j. + # We'll do an einsum for that: [b,n_heads,seq_len,d] outer [b,n_heads,seq_len,d] => [b,n_heads,d,d] + # alpha: [b, n_heads, seq_len]. + value_states_ = value_states # [b, n_heads, seq_len, head_dim] + outer_sum = torch.einsum("bns,bnsd,bnsf->bndf", alpha, K_kappa, value_states_) + + # Explanation: + # - 'bnhs' is alpha (batch, n_heads, seq_len) + # - 'bnhsd' is K_kappa (b,n_heads,seq_len, d) + # - 'bnhsf' is V (b,n_heads,seq_len, d) + # We want [b,n_heads,d,f], which is the d×d matrix per head. + # Actually we need an outer product (K_kappa_j^T × V_j). That is [d, d]. + # The call above is not quite correct if we want K_kappa_j^T × V_j as [d,d]. + # Let's do a simpler approach: + # outer_sum = sum_j alpha_j * (K_kappa_j^T outer V_j). + # = "bnhs,bnhsd,bnhsf -> bnhdf" + # means: alpha has shape (b,n,h,s), K_kappa has shape (b,n,h,s,d), V has shape (b,n,h,s,d) + # We want to produce (b,n,h,d,d). + # So the correct einsum string is 'bnhs,bnhsd,bnhsf->bnhdf': + # alpha indexes b,n,h,s + # K_kappa indexes b,n,h,s,d => K_kappa_j + # V indexes b,n,h,s,f => V_j + # The resulting shape is (b,n,h,d,f). Great. + + # 6) For each token i, Y_i = φ(X_i) ∘ [ κ(Q_i) × outer_sum ] + # Here κ(Q_i) is shape [b,n,h,d], outer_sum is shape [b,n,h,d,d]. + # We'll do a batch matmul: result_attn = Q_kappa_i × outer_sum => [b,n,h,d] + # Then multiply elementwise by φ(X_i). + # But φ(X_i) is a single [b,seq_len,d_model], so we reshape to [b,seq_len,n,h_dim]. + # We'll do per-token i in a loop or broadcast. Let's do it in a single operation with einsum: + + # first, compute φ(X): + # X is the original hidden_states: [b, seq_len, d_model] + X_phi = self.phi(hidden_states) # [b, seq_len, d_model] + X_phi = X_phi.view(bsz, q_len, self.num_heads, self.head_dim) # [b, s, n, d] + X_phi = X_phi.transpose(1, 2) # [b, n, s, d] + + # Now for each i in [0..q_len-1], we do a matrix multiply: + # result_attn_i = Q_kappa_i [b,n,s,d] × outer_sum [b,n,d,d] => we want [b,n,s,d]. 
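+        # (Both building `outer_sum` above and the contraction below cost
+        # O(N * head_dim^2) per head -- linear in sequence length -- versus
+        # O(N^2 * head_dim) for standard softmax attention.)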
+ # We'll do: + result_attn = torch.einsum("bnsd,bndf->bnsf", Q_kappa, outer_sum) # [b,n,s,d] + + # Then elementwise multiply by φ(X_i): + context_layer = X_phi * result_attn # [b,n,s,d] + + # Finally, reorder to [b, s, n, d] -> [b, s, n*d] + context_layer = context_layer.transpose(1, 2).contiguous() # [b, s, n, d] + context_layer = context_layer.view(bsz, q_len, self.hidden_size) + + # One last linear projection: + attn_output = self.o_proj(context_layer) + + # Not returning a standard attn_weights. + # If you want to return alpha as "attention," we can do so: + if output_attentions: + # alpha: [b, n_heads, seq_len], but note it's only the "global" weighting of each key, + # not a (q_len x kv_len) map like standard attention. + attn_weights = alpha + else: + attn_weights = None + + # We omit cache / past_key_value returns to keep it simpler. + return attn_output, attn_weights, None diff --git a/src/axolotl/monkeypatch/attention/differential.py b/src/axolotl/monkeypatch/attention/differential.py index 635573a4b..c5e843e24 100644 --- a/src/axolotl/monkeypatch/attention/differential.py +++ b/src/axolotl/monkeypatch/attention/differential.py @@ -32,6 +32,7 @@ def patch_llama_attention_classes(): "differential_eager", "differential_sdpa", "differential_flash_attention_2", + "rala", ] if attn_implementation not in valid_impls: message = ( diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 37cbc0871..427f93473 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -48,6 +48,7 @@ from transformers.integrations.deepspeed import ( ) from axolotl.common.architectures import MOE_ARCH_BLOCK +from axolotl.integrations.base import PluginManager from axolotl.models.mamba import fix_mamba_attn_for_loss from axolotl.monkeypatch.multipack import ( SUPPORTED_MULTIPACK_MODEL_TYPES, @@ -375,8 +376,6 @@ class ModelLoader: def apply_patches(self) -> None: # load any patches from plugins - from axolotl.integrations.base import PluginManager - plugin_manager = PluginManager.get_instance() plugin_manager.pre_model_load(self.cfg) @@ -757,6 +756,9 @@ class ModelLoader: if self.cfg.low_cpu_mem_usage: self.model_kwargs["low_cpu_mem_usage"] = True + plugin_manager = PluginManager.get_instance() + plugin_manager.set_attn_config(self.cfg, self.model_kwargs, self.model_config) + def build_model(self, qlora_fsdp) -> bool: def _configure_zero3_memory_efficient_loading(): """