From 145664d82ccdc5b3c4fd31136225f8dfae52a935 Mon Sep 17 00:00:00 2001 From: Wing Lian Date: Wed, 15 Jan 2025 21:27:12 -0500 Subject: [PATCH] more fixups --- src/axolotl/cli/integrations/convert_rala.py | 197 ++++++ src/axolotl/cli/main.py | 13 + src/axolotl/core/trainer_builder.py | 2 +- src/axolotl/integrations/base.py | 24 + src/axolotl/integrations/rala/__init__.py | 24 + src/axolotl/integrations/rala/args.py | 14 + .../integrations/rala/auto/__init__.py | 0 .../integrations/rala/auto/llama/__init__.py | 0 .../rala/auto/llama/configuration_rala.py | 12 + .../rala/auto/llama/modeling_rala.py | 596 ++++++++++++++++++ src/axolotl/integrations/rala/convert.py | 104 +++ src/axolotl/utils/models.py | 6 +- 12 files changed, 989 insertions(+), 3 deletions(-) create mode 100644 src/axolotl/cli/integrations/convert_rala.py create mode 100644 src/axolotl/integrations/rala/__init__.py create mode 100644 src/axolotl/integrations/rala/args.py create mode 100644 src/axolotl/integrations/rala/auto/__init__.py create mode 100644 src/axolotl/integrations/rala/auto/llama/__init__.py create mode 100644 src/axolotl/integrations/rala/auto/llama/configuration_rala.py create mode 100644 src/axolotl/integrations/rala/auto/llama/modeling_rala.py create mode 100644 src/axolotl/integrations/rala/convert.py diff --git a/src/axolotl/cli/integrations/convert_rala.py b/src/axolotl/cli/integrations/convert_rala.py new file mode 100644 index 000000000..b2d7fa1d3 --- /dev/null +++ b/src/axolotl/cli/integrations/convert_rala.py @@ -0,0 +1,197 @@ +"""CLI to convert a transformers model's attns to rala attns.""" +import logging +import warnings +from pathlib import Path +from time import time +from typing import Union + +import fire +import torch +import yaml +from colorama import Fore +from dotenv import load_dotenv +from transformers import HfArgumentParser + +from axolotl.cli import load_cfg, print_axolotl_text_art +from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer +from axolotl.integrations.rala.convert import convert_to_rala +from axolotl.utils.yaml import dump_yaml_preserved_order + +LOG = logging.getLogger(__name__) + + +def test_inference(model, tokenizer, prompt="The quick brown fox"): + """Run test inference and return generation time""" + try: + inputs = tokenizer(prompt, return_tensors="pt") + inputs = { + k: v.to(device=model.device, dtype=torch.long) for k, v in inputs.items() + } + + start = time() + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=20, + num_beams=1, + do_sample=False, + pad_token_id=tokenizer.pad_token_id, + use_cache=False, + ) + elapsed = time() - start + + generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True) + LOG.info("Prompt: %s", prompt) + LOG.info("Generated: %s", generated_text) + LOG.info("Generation time: %.2fs", elapsed) + + return elapsed, generated_text + + except Exception as exc: + LOG.error("Inference failed: %s", str(exc)) + raise + + +def convert_rala(cfg, cli_args, config_path): + debug_info = {} + + # Load model and tokenizer + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + model, tokenizer = load_model_and_tokenizer(cfg=cfg, cli_args=cli_args) + model.to(cfg.device, dtype=cfg.torch_dtype) + + # Log original model info + LOG.info( + "Original model config:\n\t- Hidden size: %d\n\t- Num attention heads: %d", + model.config.hidden_size, + model.config.num_attention_heads, + ) + + # Test original model + if cli_args.debug: + LOG.info("attention layers to RALA attention") + debug_info["orig_time"], debug_info["orig_text"] = test_inference( + model, tokenizer + ) + + # Convert attention + try: + model = convert_to_rala( + model=model, + zero_init=cli_args.zero_init, + ) + model.to(cfg.device, dtype=cfg.torch_dtype) + except Exception as exc: + LOG.error(Fore.RED + "Conversion failed: %s" + Fore.RESET, str(exc)) + raise + + # Test converted model + if cli_args.debug: + LOG.info("Testing converted model...") + debug_info["conv_time"], debug_info["conv_text"] = test_inference( + model, tokenizer + ) + + # Save if requested + if cfg.output_dir: + # Save model and tokenizer + LOG.info("Saving converted model to %s", cfg.output_dir) + model.save_pretrained(cfg.output_dir) + tokenizer.save_pretrained(cfg.output_dir) + + # Modify config to reflect new path / differential attention + output_config_path = Path(cfg.output_dir) / "axolotl_config.yml" + LOG.info("Saving updated config to %s", output_config_path) + + with open(config_path, "r", encoding="utf-8") as file: + modified_cfg = yaml.safe_load(file) or {} + + modified_cfg["base_model"] = cfg.output_dir + modified_cfg["rala_attention"] = True + plugin_class = "axolotl.integrations.rala.RalaPlugin" + if "plugins" in modified_cfg: + modified_cfg["plugins"].append(plugin_class) + else: + modified_cfg["plugins"] = [plugin_class] + + dump_yaml_preserved_order( + data=modified_cfg, + reference_yaml_path=config_path, + output_path=output_config_path, + ) + else: + LOG.info("Not saving converted model to disk") + LOG.info("Pass --output-dir path/to/save to save model") + + if cli_args.debug: + LOG.info( + Fore.GREEN + + "Conversion successful!\n" + + f"Original generation time: {debug_info['orig_time']:.2f}s\n" + + f"Converted generation time: {debug_info['conv_time']:.2f}s" + + Fore.RESET + ) + + if debug_info["orig_text"] == debug_info["conv_text"]: + LOG.info( + Fore.GREEN + + "Generations match!\n" + + "Model generation:\n" + + "*" * 50 + + "\n" + + f"{debug_info['orig_text']}\n" + + "*" * 50 + + "\n" + + Fore.RESET + ) + debug_info["generations_match"] = True + else: + message = ( + "Generations do not match.\n" + + "Original generation:\n" + + "*" * 50 + + "\n" + + f"{debug_info['orig_text']}\n" + + "*" * 50 + + "\n" + + "Converted generation:\n" + + "*" * 50 + + "\n" + + f"{debug_info['conv_text']}\n" + + "*" * 50 + + "\n" + ) + debug_info["generations_match"] = False + + if cli_args.zero_init and not cli_args.sublayer_norm: + LOG.info(Fore.RED + message + Fore.RESET) + debug_info["match_expected"] = True + else: + LOG.info( + Fore.YELLOW + + message + + "However, this is expected since --zero-init" + + " and --no-sublayer-norm were not passed." + + Fore.RESET + ) + debug_info["match_expected"] = False + + return model, debug_info + + +def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): + print_axolotl_text_art() + + cfg = load_cfg(config, **kwargs) + if cfg.rala_attention: + cfg.rala_attention = False + parser = HfArgumentParser(ConvertDiffTransformerCliArgs) + cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True) + + convert_rala(cfg, cli_args, config) + + +if __name__ == "__main__": + load_dotenv() + fire.Fire(do_cli) diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index d9d3a2135..bc2af94a1 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -261,6 +261,19 @@ def convert_diff_transformer(config: str, **kwargs): do_cli(config=config, **kwargs) +@cli.command() +@click.argument("config", type=click.Path(exists=True, path_type=str)) +@add_options_from_dataclass(ConvertDiffTransformerCliArgs) +@add_options_from_config(AxolotlInputConfig) +def convert_rala(config: str, **kwargs): + """Convert model attention layers to RALA attention layers.""" + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + from axolotl.cli.integrations.convert_rala import do_cli + + do_cli(config=config, **kwargs) + + @cli.command() @click.argument("directory", type=click.Choice(["examples", "deepspeed_configs"])) @click.option("--dest", help="Destination directory") diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py index 3c6a7026a..b52dc73a3 100755 --- a/src/axolotl/core/trainer_builder.py +++ b/src/axolotl/core/trainer_builder.py @@ -481,7 +481,7 @@ class AxolotlTrainer(SchedulerMixin, Trainer): if self.optimizer is None: # pylint: disable=access-member-before-definition decay_parameters = self.get_decay_parameter_names(opt_model) params = { - "to_weight_decay": {}, # LayerNorm and bias + "to_weight_decay": {}, # LayerNorm except bias "embeddings": {}, # lm_head, embed_tokens, "no_weight_decay": {}, } diff --git a/src/axolotl/integrations/base.py b/src/axolotl/integrations/base.py index a271c59d1..6c422c0b8 100644 --- a/src/axolotl/integrations/base.py +++ b/src/axolotl/integrations/base.py @@ -75,6 +75,19 @@ class BasePlugin: None """ + def set_attn_config( + self, cfg, model_kwargs, model_config + ): # pylint: disable=unused-argument + """ + Sets attention configuration for the model. + Parameters: + cfg (dict): The configuration for the plugin. + model_kwargs (dict): The model kwargs for the plugin. + model_config (object): The model configuration. + Returns: + None + """ + def post_model_load(self, cfg, model): # pylint: disable=unused-argument """ Performs actions after the model is loaded. @@ -304,6 +317,17 @@ class PluginManager: for plugin in self.plugins.values(): plugin.pre_model_load(cfg) + def set_attn_config(self, cfg, model_kwargs, model_config): + """ + modifies the attention configuration of the model kwargs for loading + Parameters: + cfg (dict): The configuration for the plugins. + model_kwargs (dict): The model's kwargs for construction the model + model_config (dict): The model's configuration. + """ + for plugin in self.plugins.values(): + plugin.set_attn_config(cfg, model_kwargs, model_config) + def post_model_load(self, cfg, model): """ Calls the post_model_load method of all registered plugins. diff --git a/src/axolotl/integrations/rala/__init__.py b/src/axolotl/integrations/rala/__init__.py new file mode 100644 index 000000000..fddc3e0bf --- /dev/null +++ b/src/axolotl/integrations/rala/__init__.py @@ -0,0 +1,24 @@ +"""Definition of RALA plugin.""" + +import logging + +# from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES + +from axolotl.integrations.base import BasePlugin +from axolotl.integrations.rala.auto.llama.modeling_rala import LlamaRALAAttention + +LOG = logging.getLogger(__name__) + + +class RalaPlugin(BasePlugin): + """ + Plugin for Rala integration with Axolotl. + """ + + def get_input_args(self): + return "axolotl.integrations.rala.args.RalaArgs" + + def set_attn_config(self, cfg, model_kwargs, model_config): + # if cfg.rala_attention: + # model_kwargs["attn_implementation"] = "rala" + ... diff --git a/src/axolotl/integrations/rala/args.py b/src/axolotl/integrations/rala/args.py new file mode 100644 index 000000000..384e68b33 --- /dev/null +++ b/src/axolotl/integrations/rala/args.py @@ -0,0 +1,14 @@ +"""Module for handling RALA input arguments.""" + +import logging +from typing import Optional + +from pydantic import BaseModel + +LOG = logging.getLogger(__name__) + + +class RalaArgs(BaseModel): + """Input args for RALA.""" + + rala_attention: Optional[bool] = None diff --git a/src/axolotl/integrations/rala/auto/__init__.py b/src/axolotl/integrations/rala/auto/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/integrations/rala/auto/llama/__init__.py b/src/axolotl/integrations/rala/auto/llama/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/axolotl/integrations/rala/auto/llama/configuration_rala.py b/src/axolotl/integrations/rala/auto/llama/configuration_rala.py new file mode 100644 index 000000000..04d485a15 --- /dev/null +++ b/src/axolotl/integrations/rala/auto/llama/configuration_rala.py @@ -0,0 +1,12 @@ +""" +Rala config class +""" +from transformers import LlamaConfig + + +class LlamaRalaConfig(LlamaConfig): + """ + Configuration for LlamaRala model + """ + + softmax_every: int = 6 # every N-th layer applies softmax diff --git a/src/axolotl/integrations/rala/auto/llama/modeling_rala.py b/src/axolotl/integrations/rala/auto/llama/modeling_rala.py new file mode 100644 index 000000000..b951f47c4 --- /dev/null +++ b/src/axolotl/integrations/rala/auto/llama/modeling_rala.py @@ -0,0 +1,596 @@ +# Copyright 2024-2025 Axolotl AI. All rights reserved. +# +# This software may be used and distributed according to +# the terms of the Apache License 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. + +""" +Custom modeling code for RALA Llama +""" + +from typing import List, Optional, Tuple, Union, Unpack + +import torch +import torch.nn.functional as F +from torch import nn +from transformers import Cache, GenerationMixin, LlamaModel +from transformers.modeling_outputs import CausalLMOutputWithPast +from transformers.models.llama.modeling_llama import ( + KwargsForCausalLM, + LlamaAttention, + LlamaDynamicNTKScalingRotaryEmbedding, + LlamaLinearScalingRotaryEmbedding, + LlamaMLP, + LlamaPreTrainedModel, + LlamaRMSNorm, + LlamaRotaryEmbedding, + apply_rotary_pos_emb, + repeat_kv, +) + +from .configuration_rala import LlamaRalaConfig + + +def kappa(x: torch.Tensor) -> torch.Tensor: # pylint: disable=invalid-name + """ + The paper uses κ(x) = ELU(x) + 1. + x is assumed to be [batch, n_heads, seq_len, head_dim]. + """ + return F.elu(x) + 1 + + +class LlamaRALAAttention(nn.Module): + """ + LlamaAttention replaced with Rank-Augmented Linear Attention (RALA). + Adapted from the standard LlamaAttention for demonstration. + **Not** a fully drop-in replacement if you need caching/TP. + """ + + def __init__(self, config, layer_idx: Optional[int] = None): + super().__init__() + self.config = config + self.layer_idx = layer_idx + + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + + # Same Q, K, V, output projections + self.q_proj = nn.Linear( + self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias + ) + self.k_proj = nn.Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.v_proj = nn.Linear( + self.hidden_size, + self.num_key_value_heads * self.head_dim, + bias=config.attention_bias, + ) + self.o_proj = nn.Linear( + self.hidden_size, self.hidden_size, bias=config.attention_bias + ) + + # We will preserve rope usage + self._init_rope() + + # A simple φ-projection for RALA: + # The paper uses φ(x) as a linear transform or identity. We'll do a linear: + self.phi = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + + def _init_rope(self): + # Standard Llama rope logic + if self.config.rope_scaling is None: + self.rotary_emb = LlamaRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + else: + scaling_type = self.config.rope_scaling["type"] + scaling_factor = self.config.rope_scaling["factor"] + if scaling_type == "linear": + self.rotary_emb = LlamaLinearScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + elif scaling_type == "dynamic": + self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + scaling_factor=scaling_factor, + base=self.rope_theta, + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, # pylint: disable=unused-argument + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, # pylint: disable=unused-argument + ): + """ + RALA forward pass. + This version omits incremental decoding with `past_key_value` for simplicity + (linear attention caching is non-trivial). + """ + bsz, q_len, _ = hidden_states.size() + + # Standard Q, K, V + query_states = self.q_proj(hidden_states) # [b, seq, n_heads*dim] + key_states = self.k_proj(hidden_states) # [b, seq, n_kv_heads*dim] + value_states = self.v_proj(hidden_states) # [b, seq, n_kv_heads*dim] + + # Reshape to [b, n_heads, seq_len, head_dim] + query_states = query_states.view( + bsz, q_len, self.num_heads, self.head_dim + ).transpose(1, 2) + key_states = key_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + value_states = value_states.view( + bsz, q_len, self.num_key_value_heads, self.head_dim + ).transpose(1, 2) + + # Apply RoPE (rotary embeddings) just as in standard Llama + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb( + query_states, key_states, cos, sin + ) + + # 4. If we have a past_key_value (Cache object), let it update / append + if past_key_value is not None: + # This is the normal Llama pattern + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + # The .update() method returns updated (key_states, value_states) + # and typically updates internal buffers. It may also store `layer_idx` data. + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, cache_kwargs + ) + + # If you still want to handle the repeated KV for multi-group setups: + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # Now we apply RALA. + + # 1) Apply κ(.) to Q,K: shape [b, n_heads, seq_len, head_dim] + Q_kappa = kappa(query_states) # pylint: disable=invalid-name + K_kappa = kappa(key_states) # pylint: disable=invalid-name + + # 2) Compute global query Q_g = average of Q_kappa across seq_len => [b, n_heads, head_dim] + # The paper denotes Q_g = (1/N) Σ_i Q_kappa_i + seq_len_float = float(q_len) # for scaling + Q_g = Q_kappa.mean( # pylint: disable=invalid-name + dim=2 + ) # [b, n_heads, head_dim] + + # 3) Compute alpha_j for each token j in [0..seq_len-1] + # alpha_j = N * softmax( Q_g · K_kappa_j^T ), shape => [b, n_heads, seq_len] + # Dot product over head_dim + # K_kappa is [b, n_heads, seq_len, head_dim], Q_g is [b, n_heads, head_dim] + # We'll do an einsum or transpose to produce logits [b, n_heads, seq_len] + + # Dot product across the last dimension (d_head), resulting in shape [b, n_heads, seq_len] + # logits = torch.einsum("bnh, bnsh -> bns", Q_g, K_kappa) # [b, n_heads, seq_len] + logits = (Q_g.unsqueeze(2) * K_kappa).sum( + dim=-1 + ) # -> [b, n_heads, seq_len] # identical to above but torch.compile should work + + # 4) Incorporate causal or padding mask if provided. + # In standard Llama, attention_mask is broadcast as [b, 1, seq_len, seq_len] or similar. + # For RALA, we only do a single softmax over "j" dimension. We can add the mask to logits. + # Caution: This might not replicate strict causal linear attention. It's a best-effort approach. + if attention_mask is not None: + # Usually Llama's causal mask is [b, 1, q_len, kv_len] with 0 or -inf + # We want shape [b, n_heads, seq_len], so we can broadcast accordingly: + # e.g., attention_mask: [b, 1, q_len, seq_len] + # We pick the slice that corresponds to q_len vs. kv_len. + # Typically the last two dims are (q_len, kv_len). We want the kv_len dimension to be `seq_len`. + # We'll do something like: + if attention_mask.dim() == 4: + # attention_mask: [b, 1, q_len, kv_len] + # if q_len == kv_len, we can do attention_mask[:, :, :, :seq_len], then squeeze dims + mask_2d = attention_mask[:, 0, :, :q_len] # [b, q_len, seq_len] + # we only want [b, n_heads, seq_len], so we must broadcast over q_len if needed + # but in this snippet, we do a single alpha_j for each j *per head*, + # ignoring per-token Q_i. So there's a mismatch. + # A simpler approach is to apply the mask for the entire sequence if a token j is invalid for ANY i. + # That is approximate. We'll just pick the first row of q_len, or do min across i dimension... + # For demonstration, let's sum or min across i dimension to see if j is valid for ANY i. + # Or we do a "causal" approach: all tokens j>i get masked. But there's no direct i index here in alpha_j. + # We'll just do a rough approach, e.g. mask = min across the q_len dimension: + mask_1d = torch.min(mask_2d, dim=1)[ + 0 + ] # [b, seq_len], picking the worst mask across query positions + # broadcast for n_heads + mask_1d = mask_1d.unsqueeze(1).expand( + -1, self.num_heads, -1 + ) # [b, n_heads, seq_len] + logits = logits + mask_1d + else: + # Possibly it's [b, seq_len]. Then we just broadcast to [b,n_heads,seq_len]. + mask_1d = attention_mask # [b, seq_len] + mask_1d = mask_1d.unsqueeze(1).expand(-1, self.num_heads, -1) + logits = logits + mask_1d + + alpha = F.softmax(logits, dim=-1) # [b, n_heads, seq_len] + # multiply by seq_len per the formula + alpha = alpha * seq_len_float + + # 5) Construct the outer-sum: Σ_j alpha_j * (K_kappa_j^T V_j) + # The paper shows a d×d matrix formed per head. + # K_kappa: [b, n_heads, seq_len, head_dim], V: [b, n_heads, seq_len, head_dim] + # For each j, do outer product K_kappa_j (d×1) × V_j^T (1×d) => d×d + # Then multiply by alpha_j and sum over j. + # We'll do an einsum for that: [b,n_heads,seq_len,d] outer [b,n_heads,seq_len,d] => [b,n_heads,d,d] + # alpha: [b, n_heads, seq_len]. + value_states_ = value_states # [b, n_heads, seq_len, head_dim] + outer_sum = torch.einsum("bns,bnsd,bnsf->bndf", alpha, K_kappa, value_states_) + + # Explanation: + # - 'bnhs' is alpha (batch, n_heads, seq_len) + # - 'bnhsd' is K_kappa (b,n_heads,seq_len, d) + # - 'bnhsf' is V (b,n_heads,seq_len, d) + # We want [b,n_heads,d,f], which is the d×d matrix per head. + # Actually we need an outer product (K_kappa_j^T × V_j). That is [d, d]. + # The call above is not quite correct if we want K_kappa_j^T × V_j as [d,d]. + # Let's do a simpler approach: + # outer_sum = sum_j alpha_j * (K_kappa_j^T outer V_j). + # = "bnhs,bnhsd,bnhsf -> bnhdf" + # means: alpha has shape (b,n,h,s), K_kappa has shape (b,n,h,s,d), V has shape (b,n,h,s,d) + # We want to produce (b,n,h,d,d). + # So the correct einsum string is 'bnhs,bnhsd,bnhsf->bnhdf': + # alpha indexes b,n,h,s + # K_kappa indexes b,n,h,s,d => K_kappa_j + # V indexes b,n,h,s,f => V_j + # The resulting shape is (b,n,h,d,f). Great. + + # 6) For each token i, Y_i = φ(X_i) ∘ [ κ(Q_i) × outer_sum ] + # Here κ(Q_i) is shape [b,n,h,d], outer_sum is shape [b,n,h,d,d]. + # We'll do a batch matmul: result_attn = Q_kappa_i × outer_sum => [b,n,h,d] + # Then multiply elementwise by φ(X_i). + # But φ(X_i) is a single [b,seq_len,d_model], so we reshape to [b,seq_len,n,h_dim]. + # We'll do per-token i in a loop or broadcast. Let's do it in a single operation with einsum: + + # first, compute φ(X): + # X is the original hidden_states: [b, seq_len, d_model] + X_phi = self.phi( # pylint: disable=invalid-name + hidden_states + ) # [b, seq_len, d_model] + X_phi = X_phi.view( # pylint: disable=invalid-name + bsz, q_len, self.num_heads, self.head_dim + ) # [b, s, n, d] + X_phi = X_phi.transpose(1, 2) # [b, n, s, d] # pylint: disable=invalid-name + + # Now for each i in [0..q_len-1], we do a matrix multiply: + # result_attn_i = Q_kappa_i [b,n,s,d] × outer_sum [b,n,d,d] => we want [b,n,s,d]. + # We'll do: + result_attn = torch.einsum("bnsd,bndf->bnsf", Q_kappa, outer_sum) # [b,n,s,d] + + # Then elementwise multiply by φ(X_i): + context_layer = X_phi * result_attn # [b,n,s,d] + + # Finally, reorder to [b, s, n, d] -> [b, s, n*d] + context_layer = context_layer.transpose(1, 2).contiguous() # [b, s, n, d] + context_layer = context_layer.view(bsz, q_len, self.hidden_size) + + # One last linear projection: + attn_output = self.o_proj(context_layer) + + if output_attentions: + # alpha => [b, n_heads, (past_len + q_len)] + attn_weights = alpha + else: + attn_weights = None + + # Return 3-tuple: (attn_output, attn_weights, past_key_value) + return attn_output, attn_weights, past_key_value + + +class LlamaRalaDecoderLayer(nn.Module): + """ + LlamaDecoderLayer with RALA support + """ + + def __init__(self, config: LlamaRalaConfig, layer_idx: int): + super().__init__() + self.hidden_size = config.hidden_size + + if LlamaRalaConfig.is_layer_idx_softmax( + config.num_hidden_layers, layer_idx, config.softmax_every + ): + self.self_attn = LlamaAttention(config=config, layer_idx=layer_idx) + else: + self.self_attn = LlamaRALAAttention(config=config, layer_idx=layer_idx) + + self.mlp = LlamaMLP(config) + self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = LlamaRMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) + + @classmethod + def is_layer_idx_softmax( + cls, num_hidden_layers: int, layer_idx: int, softmax_every: int + ) -> bool: + inner_layers = num_hidden_layers - 2 + if 1 + softmax_every * (inner_layers // softmax_every) == inner_layers: + softmax_start_idx = 1 + elif 1 + softmax_every * (inner_layers // softmax_every) > inner_layers: + layer_group_size = 1 + softmax_every * ((inner_layers // softmax_every) - 1) + softmax_start_idx = 1 + (inner_layers - layer_group_size) // 2 + elif 1 + softmax_every * (inner_layers // softmax_every) < inner_layers: + layer_group_size = 1 + softmax_every * (inner_layers // softmax_every) + softmax_start_idx = 1 + (inner_layers - layer_group_size) // 2 + + softmax_layers = set(range(softmax_start_idx, num_hidden_layers, softmax_every)) + softmax_layers.add(0) + softmax_layers.add(num_hidden_layers - 1) + + return layer_idx in softmax_layers + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[ + Tuple[torch.Tensor, torch.Tensor] + ] = None, # will become mandatory in v4.46 + **kwargs, + ) -> Tuple[ + torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] + ]: + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): + attention mask of size `(batch_size, sequence_length)` if flash attention is used or `(batch_size, 1, + query_sequence_length, key_sequence_length)` if default attention is used. + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence + position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*): + Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`, + with `head_dim` being the embedding dimension of each attention head. + kwargs (`dict`, *optional*): + Arbitrary kwargs to be ignored, used for FSDP and other methods that injects code + into the model + """ + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + cache_position=cache_position, + position_embeddings=position_embeddings, + **kwargs, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) # type: ignore + + if use_cache: + outputs += (present_key_value,) # type: ignore + + return outputs # type: ignore + + +class LlamaRalaModel(LlamaModel): + """ + LlamaModel with RALA support + """ + + config_class = LlamaRalaConfig + + def __init__(self, config: LlamaRalaConfig): + LlamaPreTrainedModel.__init__(self, config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding( + config.vocab_size, config.hidden_size, self.padding_idx + ) + + self.layers = nn.ModuleList( + [ + LlamaRalaDecoderLayer(config, layer_idx) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.rotary_emb = LlamaRotaryEmbedding(config=config) + + self.gradient_checkpointing = False + + # Initialize weights and apply final processing + self.post_init() + + +class LlamaRalaForCausalLM(LlamaPreTrainedModel, GenerationMixin): + """ + LlamaForCausalLM with RALA support + """ + + config_class = LlamaRalaConfig + _no_split_modules = ["LlamaRalaDecoderLayer"] + + _tied_weights_keys = ["lm_head.weight"] + _tp_plan = {"lm_head": "colwise_rep"} + + def __init__(self, config): + super().__init__(config) + self.model = LlamaRalaModel(config) + self.vocab_size = config.vocab_size + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, + num_logits_to_keep: int = 0, + **kwargs: Unpack[KwargsForCausalLM], # type: ignore + ) -> Union[Tuple, CausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + num_logits_to_keep (`int`, *optional*): + Calculate logits for the last `num_logits_to_keep` tokens. If `0`, calculate logits for all + `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that + token can save memory, which becomes pretty significant for long sequences or large vocabulary size. + Returns: + Example: + ```python + >>> from transformers import AutoTokenizer, LlamaForCausalLM + >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf") + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + output_attentions = ( + output_attentions + if output_attentions is not None + else self.config.output_attentions + ) + output_hidden_states = ( + output_hidden_states + if output_hidden_states is not None + else self.config.output_hidden_states + ) + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + cache_position=cache_position, + **kwargs, + ) + + hidden_states = outputs[0] + # Only compute necessary logits, and do not upcast them to float if we are not computing the loss + logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, + labels=labels, + vocab_size=self.config.vocab_size, + **kwargs, + ) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) diff --git a/src/axolotl/integrations/rala/convert.py b/src/axolotl/integrations/rala/convert.py new file mode 100644 index 000000000..5523ba9d3 --- /dev/null +++ b/src/axolotl/integrations/rala/convert.py @@ -0,0 +1,104 @@ +""" +conversion for llama models to use RALA attention +""" +import logging + +from torch import nn +from transformers import PreTrainedModel +from transformers.models.llama.modeling_llama import LlamaAttention + +from axolotl.integrations.rala import LlamaRALAAttention +from axolotl.integrations.rala.auto.llama.modeling_rala import LlamaRalaDecoderLayer + +logger = logging.getLogger(__name__) + +ATTENTION_MAPPING = { + LlamaAttention: LlamaRALAAttention, +} + + +def copy_attention_weights( + old_attn, + new_attn, + zero_init: bool = False, +) -> None: + """ + Copy weights from old attention layer to new RALA layer. + Copies q, k, v, o + """ + new_attn.q_proj.weight.data.copy_(old_attn.q_proj.weight.data) + new_attn.k_proj.weight.data.copy_(old_attn.k_proj.weight.data) + new_attn.v_proj.weight.data.copy_(old_attn.v_proj.weight.data) + new_attn.o_proj.weight.data.copy_(old_attn.o_proj.weight.data) + + # Zero out lambda parameters for exact equivalence + if zero_init: + nn.init.zeros_(new_attn.phi.weight) + else: + nn.init.normal_(new_attn.phi.weight) + if new_attn.phi.bias: + nn.init.normal_(new_attn.phi.bias) + + logger.debug( + "Copied positive attention weights from %s to %s", + type(old_attn).__name__, + type(new_attn).__name__, + ) + + +def convert_to_rala( + model: PreTrainedModel, zero_init: bool = False, softmax_every_n: int = 6 +) -> PreTrainedModel: + """Convert a pre-trained model's attention layers to differential attention""" + layer_idx = 0 + + def convert_module(module, softmax_every, num_hidden_layers): + nonlocal layer_idx + + # Iterate through module children, convert any attn layers to diff attn + for name, child in module.named_children(): + if isinstance(child, tuple(ATTENTION_MAPPING.keys())): + decoder_layer_idx = child.layer_idx + if LlamaRalaDecoderLayer.is_layer_idx_softmax( + num_hidden_layers, decoder_layer_idx, softmax_every + ): + continue + # Choose appropriate differential attention class + # pylint: disable=duplicate-code + attention_class = ATTENTION_MAPPING[type(child)] + + layer_type = type(child).__name__ + logger.info( + f"Converting attention layer {layer_idx}: {layer_type} to {attention_class.__name__}" + ) + + # Create new diff attn layer + new_attention = attention_class( + config=module.config if hasattr(module, "config") else model.config, + layer_idx=layer_idx, + ) + + # Copy weights from old attention to new attention + new_attention.to(child.q_proj.weight.device) + copy_attention_weights(child, new_attention, zero_init=zero_init) + + # Replace the layer + setattr(module, name, new_attention) + layer_idx += 1 + elif len(list(child.children())) > 0: + convert_module(child, softmax_every, num_hidden_layers) + + model.config.softmax_every = softmax_every_n + convert_module(model, softmax_every_n, model.config.num_hidden_layers) + logger.info(f"Converted {layer_idx} attention layers to RALA attention") + + model.config.architectures = [ + "LlamaRalaForCausalLM", + ] + model.config.model_type = "llama_rala" + model.config.auto_map = { + "AutoConfig": "llama.configuration_rala.LlamaRalaConfig", + "AutoModel": "llama.modeling_rala.LlamaRalaModel", + "AutoModelForCausalLM": "llama.modeling_rala.LlamaRalaForCausalLM", + } + return model diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index d16db7613..a3a2bdff2 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -48,6 +48,7 @@ from transformers.integrations.deepspeed import ( ) from axolotl.common.architectures import MOE_ARCH_BLOCK +from axolotl.integrations.base import PluginManager from axolotl.models.mamba import fix_mamba_attn_for_loss from axolotl.monkeypatch.multipack import ( SUPPORTED_MULTIPACK_MODEL_TYPES, @@ -375,8 +376,6 @@ class ModelLoader: def apply_patches(self) -> None: # load any patches from plugins - from axolotl.integrations.base import PluginManager - plugin_manager = PluginManager.get_instance() plugin_manager.pre_model_load(self.cfg) @@ -757,6 +756,9 @@ class ModelLoader: if self.cfg.low_cpu_mem_usage: self.model_kwargs["low_cpu_mem_usage"] = True + plugin_manager = PluginManager.get_instance() + plugin_manager.set_attn_config(self.cfg, self.model_kwargs, self.model_config) + def build_model(self, qlora_fsdp) -> bool: def _configure_zero3_memory_efficient_loading(): """