fixes and cleanup
This commit is contained in:
6
model-out/eval_summary.csv
Normal file
6
model-out/eval_summary.csv
Normal file
@@ -0,0 +1,6 @@
|
||||
metric,training,validation
|
||||
loss,15.633337020874023,15.604033470153809
|
||||
model_preparation_time,0.0058,0.0058
|
||||
runtime,77.8124,8.4643
|
||||
samples_per_second,23.133,23.629
|
||||
steps_per_second,23.133,23.629
|
||||
|
@@ -18,7 +18,6 @@ from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tok
|
||||
from axolotl.integrations.diff_transformer.modeling_diff_attn import (
|
||||
LlamaDifferentialConfig,
|
||||
LlamaDifferentialForCausalLM,
|
||||
register_diff_attn,
|
||||
)
|
||||
from axolotl.utils.yaml import dump_yaml_preserved_order
|
||||
|
||||
@@ -51,7 +50,6 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"):
|
||||
|
||||
|
||||
def convert_diff_transformer(cfg, cli_args, config_path):
|
||||
register_diff_attn()
|
||||
debug_info = {}
|
||||
|
||||
# Load model and tokenizer
|
||||
|
||||
@@ -10,5 +10,10 @@ LOG = logging.getLogger(__name__)
|
||||
class DifferentialTransformerPlugin(BasePlugin):
|
||||
"""Plugin for differential transformer integration with Axolotl."""
|
||||
|
||||
def __init__(self):
|
||||
from .modeling_diff_attn import register_diff_attn
|
||||
|
||||
register_diff_attn()
|
||||
|
||||
def get_input_args(self):
|
||||
return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"
|
||||
|
||||
@@ -46,27 +46,22 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
||||
|
||||
def __init__(self, config: Any, layer_idx: int):
|
||||
super().__init__()
|
||||
|
||||
self.config = config
|
||||
self._init_config(config, layer_idx)
|
||||
self._init_config(layer_idx)
|
||||
self._init_projections()
|
||||
self._init_differential_params()
|
||||
self._init_normalization(config)
|
||||
self._init_normalization()
|
||||
|
||||
def _init_config(self, config: Any, layer_idx: int):
|
||||
def _init_config(self, layer_idx: int):
|
||||
"""Initialize configuration parameters."""
|
||||
self.attention_dropout = config.attention_dropout
|
||||
self.hidden_size = config.hidden_size
|
||||
self.head_dim = config.hidden_size // config.num_attention_heads
|
||||
self.base_num_heads = config.num_attention_heads
|
||||
self.base_num_kv_heads = config.num_key_value_heads
|
||||
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
|
||||
self.base_num_heads = self.config.num_attention_heads
|
||||
self.base_num_kv_heads = self.config.num_key_value_heads
|
||||
self.num_key_value_groups = self.base_num_heads // self.base_num_kv_heads
|
||||
self.layer_idx = layer_idx
|
||||
self.max_position_embeddings = config.max_position_embeddings
|
||||
self.rope_theta = config.rope_theta
|
||||
self.is_causal = True
|
||||
self.split_heads = config.split_heads
|
||||
|
||||
if config.split_heads:
|
||||
if self.config.split_heads:
|
||||
# Split heads mode - single projections
|
||||
# NOTE: This rounds down `base_num_heads / 2` as opposed to the original
|
||||
# implementation, which asserts `self.base_num_heads` is even
|
||||
@@ -81,31 +76,29 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
||||
|
||||
def _init_projections(self):
|
||||
"""Initialize Q, K, V projections."""
|
||||
if self.split_heads:
|
||||
if self.config.split_heads:
|
||||
# Split heads mode - single projections
|
||||
q_out_dim = self.hidden_size
|
||||
k_out_dim = self.hidden_size // self.base_num_heads * self.base_num_kv_heads
|
||||
q_out_dim = self.config.hidden_size
|
||||
k_out_dim = self.head_dim * self.base_num_kv_heads
|
||||
else:
|
||||
# Double projection mode
|
||||
q_out_dim = self.hidden_size * 2
|
||||
k_out_dim = (
|
||||
self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2
|
||||
)
|
||||
q_out_dim = self.config.hidden_size * 2
|
||||
k_out_dim = self.head_dim * self.base_num_kv_heads * 2
|
||||
|
||||
self.q_proj = nn.Linear(
|
||||
self.hidden_size, q_out_dim, bias=self.config.attention_bias
|
||||
self.config.hidden_size, q_out_dim, bias=self.config.attention_bias
|
||||
)
|
||||
self.k_proj = nn.Linear(
|
||||
self.hidden_size, k_out_dim, bias=self.config.attention_bias
|
||||
self.config.hidden_size, k_out_dim, bias=self.config.attention_bias
|
||||
)
|
||||
self.v_proj = nn.Linear(
|
||||
self.hidden_size,
|
||||
self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
|
||||
self.config.hidden_size,
|
||||
self.head_dim * self.base_num_kv_heads,
|
||||
bias=self.config.attention_bias,
|
||||
)
|
||||
self.o_proj = nn.Linear(
|
||||
self.base_num_heads * self.head_dim,
|
||||
self.hidden_size,
|
||||
self.config.hidden_size,
|
||||
bias=self.config.attention_bias,
|
||||
)
|
||||
|
||||
@@ -129,11 +122,11 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
||||
)
|
||||
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
|
||||
|
||||
def _init_normalization(self, config):
|
||||
def _init_normalization(self):
|
||||
"""Initialize normalization layers."""
|
||||
sublayer_norm = getattr(config, "sublayer_norm", True)
|
||||
sublayer_norm = getattr(self.config, "sublayer_norm", True)
|
||||
if sublayer_norm:
|
||||
self.subln = LlamaRMSNorm(self.value_head_dim, eps=config.rms_norm_eps)
|
||||
self.subln = LlamaRMSNorm(self.value_head_dim, eps=self.config.rms_norm_eps)
|
||||
else:
|
||||
self.subln = nn.Identity()
|
||||
|
||||
@@ -148,7 +141,6 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
||||
q1, q2 = q.chunk(2, dim=-1)
|
||||
k1, k2 = k.chunk(2, dim=-1)
|
||||
|
||||
# Reshape
|
||||
q1 = q1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
|
||||
1, 2
|
||||
)
|
||||
@@ -161,9 +153,7 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
||||
k2 = k2.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
|
||||
1, 2
|
||||
)
|
||||
v = v.view(bsz, q_len, self.base_num_kv_heads, self.value_head_dim).transpose(
|
||||
1, 2
|
||||
)
|
||||
v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
|
||||
|
||||
return q1, q2, k1, k2, v
|
||||
|
||||
@@ -198,6 +188,8 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
||||
k1 = repeat_kv(k1, self.num_key_value_groups)
|
||||
k2 = repeat_kv(k2, self.num_key_value_groups)
|
||||
v = repeat_kv(v, self.num_key_value_groups)
|
||||
if self.config.split_heads:
|
||||
v = torch.cat(torch.chunk(v, 2, dim=1), dim=-1)
|
||||
|
||||
return k1, k2, v
|
||||
|
||||
@@ -215,7 +207,7 @@ class LlamaDifferentialAttentionBase(nn.Module):
|
||||
"""Process and project attention output."""
|
||||
attn = self.subln(attn)
|
||||
attn = attn * (1 - self.lambda_init)
|
||||
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
|
||||
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)
|
||||
return self.o_proj(attn)
|
||||
|
||||
|
||||
@@ -255,7 +247,7 @@ class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
|
||||
attn1 = F.softmax(attn1, dim=-1, dtype=torch.float32).type_as(attn1)
|
||||
attn2 = F.softmax(attn2, dim=-1, dtype=torch.float32).type_as(attn2)
|
||||
|
||||
dropout_p = self.attention_dropout if self.training else 0.0
|
||||
dropout_p = self.config.attention_dropout if self.training else 0.0
|
||||
attn1 = F.dropout(attn1, p=dropout_p, training=self.training)
|
||||
attn2 = F.dropout(attn2, p=dropout_p, training=self.training)
|
||||
|
||||
@@ -318,7 +310,7 @@ class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase):
|
||||
None if attention_mask is None else attention_mask[:, :, :, : k1.shape[-2]]
|
||||
)
|
||||
is_causal = attention_mask is None and q_len > 1
|
||||
dropout_p = self.attention_dropout if self.training else 0.0
|
||||
dropout_p = self.config.attention_dropout if self.training else 0.0
|
||||
|
||||
if q1.device.type == "cuda" and causal_mask is not None:
|
||||
q1, q2 = q1.contiguous(), q2.contiguous()
|
||||
@@ -396,9 +388,9 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
|
||||
k1, k2 = k1.transpose(1, 2), k2.transpose(1, 2)
|
||||
v = v.transpose(1, 2)
|
||||
|
||||
dropout_p = self.attention_dropout if self.training else 0.0
|
||||
dropout_p = self.config.attention_dropout if self.training else 0.0
|
||||
|
||||
if self.split_heads:
|
||||
if self.config.split_heads:
|
||||
v1, v2 = v.chunk(2, dim=-1)
|
||||
attn11 = flash_attn_func(q1, k1, v1, dropout_p=dropout_p, causal=True)
|
||||
attn12 = flash_attn_func(q1, k1, v2, dropout_p=dropout_p, causal=True)
|
||||
|
||||
@@ -6,15 +6,10 @@ from typing import Optional, Union
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
||||
from transformers.models.llama.configuration_llama import LlamaConfig
|
||||
from transformers.models.llama.modeling_llama import (
|
||||
LlamaForCausalLM,
|
||||
LlamaModel,
|
||||
LlamaPreTrainedModel,
|
||||
)
|
||||
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel
|
||||
|
||||
from .diff_attn import (
|
||||
LlamaDifferentialAttention,
|
||||
LlamaDifferentialAttentionBase,
|
||||
LlamaDifferentialFlashAttention2,
|
||||
LlamaDifferentialSdpaAttention,
|
||||
)
|
||||
@@ -46,17 +41,6 @@ class LlamaDifferentialConfig(LlamaConfig):
|
||||
}
|
||||
|
||||
|
||||
class LlamaDifferentialPreTrainedModel(LlamaPreTrainedModel):
|
||||
"""Base class for differential LLaMA models."""
|
||||
|
||||
config_class = LlamaDifferentialConfig
|
||||
base_model_prefix = "llama_differential"
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (LlamaDifferentialAttentionBase, LlamaModel)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
|
||||
class LlamaDifferentialModel(LlamaModel):
|
||||
"""LlamaModel with differential attention."""
|
||||
|
||||
@@ -222,6 +206,37 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM):
|
||||
super().__init__(config)
|
||||
self.model = LlamaDifferentialModel(config)
|
||||
|
||||
# pylint: disable=protected-access
|
||||
@classmethod
|
||||
def _autoset_attn_implementation(
|
||||
cls, config, **kwargs
|
||||
): # pylint: disable=unused-argument
|
||||
config._attn_implementation_autoset = True
|
||||
attn_implementation = getattr(config, "_attn_implementation", None)
|
||||
|
||||
# Map standard types to differential types if mapping exists
|
||||
if attn_implementation in config._attn_implementations:
|
||||
config._attn_implementation = config._attn_implementations[
|
||||
attn_implementation
|
||||
]
|
||||
return config
|
||||
|
||||
# If no mapping, validate it's a valid differential type
|
||||
valid_impls = [
|
||||
None,
|
||||
"differential_eager",
|
||||
"differential_sdpa",
|
||||
"differential_flash_attention_2",
|
||||
]
|
||||
if attn_implementation not in valid_impls:
|
||||
message = (
|
||||
f"Specified `attn_implementation={attn_implementation}` is not supported. "
|
||||
f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
|
||||
)
|
||||
raise ValueError(message)
|
||||
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def from_llama(
|
||||
cls, model: LlamaForCausalLM, config: Optional[LlamaDifferentialConfig] = None
|
||||
@@ -257,3 +272,11 @@ def register_diff_attn():
|
||||
# Register models
|
||||
AutoModel.register(LlamaDifferentialConfig, LlamaDifferentialModel)
|
||||
AutoModelForCausalLM.register(LlamaDifferentialConfig, LlamaDifferentialForCausalLM)
|
||||
|
||||
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
|
||||
|
||||
LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
|
||||
LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention
|
||||
LLAMA_ATTENTION_CLASSES[
|
||||
"differential_flash_attention_2"
|
||||
] = LlamaDifferentialFlashAttention2
|
||||
|
||||
@@ -130,6 +130,9 @@ def test_conversion_cli_repoduce_attentions(
|
||||
)
|
||||
def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
|
||||
output_dir = tmp_path / "converted"
|
||||
|
||||
# Smallest model with an even number of attention heads
|
||||
base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B"
|
||||
base_config["output_dir"] = str(output_dir)
|
||||
base_config[attention] = True
|
||||
|
||||
|
||||
Reference in New Issue
Block a user