adding split_heads argument for retaining original (Q, K) dimensionality

Dan Saunders
2024-12-18 05:56:29 +00:00
parent 505321ac95
commit 66176b3e07
5 changed files with 171 additions and 48 deletions
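
In shape terms, the new flag controls how the differential attention gets its two (Q, K) sets: with split_heads disabled, the converted layer doubles the Q/K projections; with split_heads enabled, the existing heads are split in half so the projections keep their original width. A minimal sketch of that arithmetic (the config numbers below are illustrative, not taken from this commit):

hidden_size, num_heads = 2048, 32              # illustrative Llama-style values

# split_heads=False: two full-sized Q/K sets (doubled projections)
head_dim = hidden_size // num_heads            # 64
q_proj_out = hidden_size * 2                   # 4096

# split_heads=True: original projection width, halved head_dim
head_dim = hidden_size // num_heads // 2       # 32
q_proj_out = hidden_size                       # 2048, i.e. the original (Q, K) dimensionality

assert num_heads % 2 == 0                      # mirrors the assert added in this commit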

View File

@@ -14,9 +14,7 @@ from transformers import HfArgumentParser
from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
-from axolotl.integrations.differential_transformer.convert import (
-    convert_to_differential_attention,
-)
+from axolotl.integrations.differential_transformer.convert import convert_to_diff_attn
LOG = logging.getLogger(__name__)
@@ -78,11 +76,19 @@ def convert_differential_transformer(cfg, cli_args, config_path):
    # Convert attention
    LOG.info("Converting to differential attention...")
+    if cli_args.split_heads and cli_args.zero_init:
+        LOG.warning(
+            Fore.YELLOW
+            + "Warning: Using split_heads with zero_init is not recommended; "
+            + "split_heads will preclude the effects of zero_init"
+            + Fore.RESET
+        )
    try:
-        model = convert_to_differential_attention(
+        model = convert_to_diff_attn(
            model=model,
            zero_init=cli_args.zero_init,
            sublayer_norm=cli_args.sublayer_norm,
+            split_heads=cli_args.split_heads,
        )
        model.to(cfg.device, dtype=cfg.torch_dtype)
    except Exception as exc:

View File

@@ -63,6 +63,7 @@ class ConvertDiffTransformerCliArgs:
    debug: bool = field(default=False)
    zero_init: bool = field(default=False)
    sublayer_norm: bool = field(default=True)
+    split_heads: bool = field(default=False)


def load_model_and_tokenizer(
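
For reference, a hypothetical programmatic use of the new field, mirroring how the test added at the bottom of this commit constructs the arguments (the import path is the one already used by the CLI script above):

from axolotl.common.cli import ConvertDiffTransformerCliArgs

# zero_init is left at its default (False); combining it with split_heads
# triggers the warning added to the CLI script above
cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)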

View File

@@ -80,14 +80,18 @@ def copy_attention_weights(
    )


-def convert_to_differential_attention(
-    model: PreTrainedModel, zero_init: bool = False, sublayer_norm: bool = True
+def convert_to_diff_attn(
+    model: PreTrainedModel,
+    zero_init: bool = False,
+    sublayer_norm: bool = True,
+    split_heads: bool = True,
) -> PreTrainedModel:
    """Convert a pre-trained model's attention layers to differential attention"""
    layer_idx = 0

    # Set sublayer norm as config on the model.
    model.config.sublayer_norm = sublayer_norm
+    model.config.split_heads = split_heads

    def convert_module(module):
        nonlocal layer_idx
@@ -111,7 +115,8 @@ def convert_to_differential_attention(
            # Copy weights from old attention to new attention
            new_attention.to(child.q_proj.weight.device)
-            copy_attention_weights(child, new_attention, zero_init=zero_init)
+            if not split_heads:
+                copy_attention_weights(child, new_attention, zero_init=zero_init)

            # Replace the layer
            setattr(module, name, new_attention)
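
Taken together with the CLI change above, calling the renamed entry point directly looks roughly like this; a sketch using only the signature shown in this hunk, with model loading elided (e.g. via load_model_and_tokenizer):

from axolotl.integrations.differential_transformer.convert import convert_to_diff_attn

# model: an already-loaded PreTrainedModel
model = convert_to_diff_attn(
    model=model,
    zero_init=False,
    sublayer_norm=True,
    split_heads=True,  # keep the original (Q, K) dimensionality; weights are not copied in this mode
)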

View File

@@ -70,26 +70,51 @@ class LlamaDifferentialAttention(nn.Module):
        self.hidden_size = config.hidden_size
        self.base_num_heads = config.num_attention_heads
        self.base_num_kv_heads = config.num_key_value_heads
-        self.head_dim = config.hidden_size // config.num_attention_heads
+        if config.split_heads:
+            self.head_dim = config.hidden_size // config.num_attention_heads // 2
+        else:
+            self.head_dim = config.hidden_size // config.num_attention_heads
        self.layer_idx = layer_idx
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
+        self.split_heads = config.split_heads

-        # For Q1 and Q2
-        self.q_proj = nn.Linear(
-            self.hidden_size,
-            self.hidden_size * 2,
-            bias=False,
-        )
+        if config.split_heads:
+            # Split heads mode
+            assert (
+                self.base_num_heads % 2 == 0
+            ), "Number of heads must be even for splitting"
+            self.heads_per_component = self.base_num_heads // 2
-        # For K1 and K2
-        self.k_proj = nn.Linear(
-            self.hidden_size,
-            self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2,
-            bias=False,
-        )
+            # Single projections
+            self.q_proj = nn.Linear(
+                self.hidden_size,
+                self.hidden_size,
+                bias=False,
+            )
+            self.k_proj = nn.Linear(
+                self.hidden_size,
+                self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
+                bias=False,
+            )
+        else:
+            # Double projection mode
+            self.heads_per_component = self.base_num_heads
+            # Double-sized projections
+            self.q_proj = nn.Linear(
+                self.hidden_size,
+                self.hidden_size * 2,
+                bias=False,
+            )
+            self.k_proj = nn.Linear(
+                self.hidden_size,
+                self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2,
+                bias=False,
+            )

        # Single V projection
        self.v_proj = nn.Linear(
@@ -125,8 +150,14 @@ class LlamaDifferentialAttention(nn.Module):
        self.rotary_emb = LlamaRotaryEmbedding(config=config)

        sublayer_norm = getattr(config, "sublayer_norm", True)
+        if self.split_heads:
+            subln_dim = 2 * self.head_dim
+        else:
+            subln_dim = self.head_dim
        self.subln = (
-            LlamaRMSNorm(hidden_size=self.head_dim, eps=1e-5)
+            LlamaRMSNorm(hidden_size=subln_dim, eps=1e-5)
            if sublayer_norm
            else nn.Identity()
        )
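
A quick sanity check of the normalization width chosen above, assuming each differential head's output is the concatenation of two halves of size head_dim (as in the flash-attention path further down, where attn1 = torch.cat([attn11, attn12], dim=-1)); torch.nn.RMSNorm stands in for LlamaRMSNorm here:

import torch
from torch import nn

head_dim = 32
half_a = torch.randn(2, 5, head_dim)      # illustrative (batch, seq, head_dim)
half_b = torch.randn(2, 5, head_dim)
attn = torch.cat([half_a, half_b], dim=-1)

subln = nn.RMSNorm(2 * head_dim, eps=1e-5)
assert subln(attn).shape == attn.shape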
@@ -167,7 +198,10 @@ class LlamaDifferentialAttention(nn.Module):
        k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Reshape V
-        v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        if self.split_heads:
+            v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
+        else:
+            v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings
        if position_embeddings is None:
@@ -177,6 +211,10 @@ class LlamaDifferentialAttention(nn.Module):
        else:
            cos, sin = position_embeddings

+        if self.split_heads:
+            cos, _ = cos.chunk(2, dim=2)
+            sin, _ = sin.chunk(2, dim=2)
+
        q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
        q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
@@ -192,8 +230,6 @@ class LlamaDifferentialAttention(nn.Module):
        v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)

        # Calculate attention scores for both parts
-        # NOTE(Dan): the Differential Transformers paper scales by a constant scaling factor
-        # instead of sqrt(head_dim). This could be set on the class as `self.scaling`.
        attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim)
        attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim)
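
For context on where the two score tensors go (the combination itself is outside this hunk): following the Differential Transformer paper, the second attention map is subtracted from the first with a learned weight λ before being applied to V. A self-contained sketch of that combination, with the causal mask omitted and λ replaced by a fixed number:

import math
import torch

bsz, heads, seq, head_dim = 1, 2, 4, 8                     # illustrative sizes
q1, q2, k1, k2, v = (torch.randn(bsz, heads, seq, head_dim) for _ in range(5))
lambda_full = 0.5                                          # stand-in for the learned lambda

attn1 = torch.softmax(q1 @ k1.transpose(-1, -2) / math.sqrt(head_dim), dim=-1)
attn2 = torch.softmax(q2 @ k2.transpose(-1, -2) / math.sqrt(head_dim), dim=-1)
out = (attn1 - lambda_full * attn2) @ v                    # differential attention output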
@@ -307,13 +343,18 @@ class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention):
        k1, k2 = kp.chunk(2, dim=-1)

        # Reshape Q1,Q2 for attention
-        q1 = q1.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
-        q2 = q2.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
+        q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Reshape K1,K2 for attention
-        k1 = k1.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
-        k2 = k2.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
+        k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Reshape V
-        v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
+        if self.split_heads:
+            v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
+        else:
+            v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings
        if position_embeddings is None:
@@ -323,6 +364,10 @@ class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention):
        else:
            cos, sin = position_embeddings

+        if self.split_heads:
+            cos, _ = cos.chunk(2, dim=2)
+            sin, _ = sin.chunk(2, dim=2)
+
        q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
        q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
@@ -468,13 +513,18 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttention):
        k1, k2 = kp.chunk(2, dim=-1)

        # Reshape Q1,Q2 for attention
-        q1 = q1.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
-        q2 = q2.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2)
+        q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Reshape K1,K2 for attention
-        k1 = k1.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
-        k2 = k2.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
+        k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Reshape V
-        v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
+        if self.split_heads:
+            v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
+        else:
+            v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings
        if position_embeddings is None:
@@ -484,6 +534,10 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttention):
        else:
            cos, sin = position_embeddings

+        if self.split_heads:
+            cos, _ = cos.chunk(2, dim=2)
+            sin, _ = sin.chunk(2, dim=2)
+
        q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
        q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
@@ -506,20 +560,54 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttention):
        # Calculate attention using Flash Attention
        dropout_p = self.attention_dropout if self.training else 0.0

-        attn1 = flash_attn_func(
-            q1,
-            k1,
-            v,
-            dropout_p=dropout_p,
-            causal=True,
-        )
-        attn2 = flash_attn_func(
-            q2,
-            k2,
-            v,
-            dropout_p=dropout_p,
-            causal=True,
-        )
+        if self.split_heads:
+            v1, v2 = v.chunk(2, dim=-1)
+            attn11 = flash_attn_func(
+                q1,
+                k1,
+                v1,
+                dropout_p=dropout_p,
+                causal=True,
+            )
+            attn12 = flash_attn_func(
+                q1,
+                k1,
+                v2,
+                dropout_p=dropout_p,
+                causal=True,
+            )
+            attn1 = torch.cat([attn11, attn12], dim=-1)
+
+            attn21 = flash_attn_func(
+                q2,
+                k2,
+                v1,
+                dropout_p=dropout_p,
+                causal=True,
+            )
+            attn22 = flash_attn_func(
+                q2,
+                k2,
+                v2,
+                dropout_p=dropout_p,
+                causal=True,
+            )
+            attn2 = torch.cat([attn21, attn22], dim=-1)
+        else:
+            attn1 = flash_attn_func(
+                q1,
+                k1,
+                v,
+                dropout_p=dropout_p,
+                causal=True,
+            )
+            attn2 = flash_attn_func(
+                q2,
+                k2,
+                v,
+                dropout_p=dropout_p,
+                causal=True,
+            )

        attn1 = attn1.transpose(1, 2)
        attn2 = attn2.transpose(1, 2)
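
The four flash_attn_func calls in the split-heads branch exist because V carries 2 * head_dim per head in that mode, while (presumably) flash_attn_func expects Q, K, and V to share a single head dimension; chunking V and concatenating the two outputs is exact, since the same attention weights are applied to each half of V. A small check of that equivalence with plain matmuls (illustrative shapes, no flash-attention dependency):

import torch

bsz, heads, seq, head_dim = 1, 2, 4, 8
probs = torch.softmax(torch.randn(bsz, heads, seq, seq), dim=-1)   # attention weights
v = torch.randn(bsz, heads, seq, 2 * head_dim)                     # split-heads V

v1, v2 = v.chunk(2, dim=-1)
joint = probs @ v                                                  # one matmul on the full V
split = torch.cat([probs @ v1, probs @ v2], dim=-1)                # two halves, concatenated
assert torch.allclose(joint, split, atol=1e-6)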

View File

@@ -106,3 +106,26 @@ def test_conversion_cli_repoduce_attentions(
    assert (output_dir / "model.safetensors").exists()
    assert (output_dir / "config.json").exists()
    assert (output_dir / "axolotl_config.yml").exists()


+@pytest.mark.parametrize(
+    "attention", ["eager_attention", "sdp_attention", "flash_attention"]
+)
+def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
+    output_dir = tmp_path / "converted"
+    base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B"
+    base_config["output_dir"] = str(output_dir)
+    base_config[attention] = True
+
+    config_path = tmp_path / "config.yml"
+    with open(config_path, "w", encoding="utf-8") as file:
+        yaml.dump(base_config, file)
+
+    cfg = load_cfg(str(config_path))
+    cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
+    _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
+
+    assert debug_info["generations_match"] is False
+    assert (output_dir / "model.safetensors").exists()
+    assert (output_dir / "config.json").exists()
+    assert (output_dir / "axolotl_config.yml").exists()