adding split_heads argument for retaining original (Q, K) dimensionanlity

This commit is contained in:
Dan Saunders
2024-12-18 05:56:29 +00:00
parent 505321ac95
commit 66176b3e07
5 changed files with 171 additions and 48 deletions

View File

@@ -14,9 +14,7 @@ from transformers import HfArgumentParser
from axolotl.cli import load_cfg, print_axolotl_text_art from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
from axolotl.integrations.differential_transformer.convert import ( from axolotl.integrations.differential_transformer.convert import convert_to_diff_attn
convert_to_differential_attention,
)
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@@ -78,11 +76,19 @@ def convert_differential_transformer(cfg, cli_args, config_path):
# Convert attention # Convert attention
LOG.info("Converting to differential attention...") LOG.info("Converting to differential attention...")
if cli_args.split_heads and cli_args.zero_init:
LOG.warning(
Fore.YELLOW
+ "Warning: Using split_heads with zero_init is not recommended; "
+ "split_heads will preclude the effects of zero_init"
+ Fore.RESET
)
try: try:
model = convert_to_differential_attention( model = convert_to_diff_attn(
model=model, model=model,
zero_init=cli_args.zero_init, zero_init=cli_args.zero_init,
sublayer_norm=cli_args.sublayer_norm, sublayer_norm=cli_args.sublayer_norm,
split_heads=cli_args.split_heads,
) )
model.to(cfg.device, dtype=cfg.torch_dtype) model.to(cfg.device, dtype=cfg.torch_dtype)
except Exception as exc: except Exception as exc:

View File

@@ -63,6 +63,7 @@ class ConvertDiffTransformerCliArgs:
debug: bool = field(default=False) debug: bool = field(default=False)
zero_init: bool = field(default=False) zero_init: bool = field(default=False)
sublayer_norm: bool = field(default=True) sublayer_norm: bool = field(default=True)
split_heads: bool = field(default=False)
def load_model_and_tokenizer( def load_model_and_tokenizer(

View File

@@ -80,14 +80,18 @@ def copy_attention_weights(
) )
def convert_to_differential_attention( def convert_to_diff_attn(
model: PreTrainedModel, zero_init: bool = False, sublayer_norm: bool = True model: PreTrainedModel,
zero_init: bool = False,
sublayer_norm: bool = True,
split_heads: bool = True,
) -> PreTrainedModel: ) -> PreTrainedModel:
"""Convert a pre-trained model's attention layers to differential attention""" """Convert a pre-trained model's attention layers to differential attention"""
layer_idx = 0 layer_idx = 0
# Set sublayer norm as config on the model. # Set sublayer norm as config on the model.
model.config.sublayer_norm = sublayer_norm model.config.sublayer_norm = sublayer_norm
model.config.split_heads = split_heads
def convert_module(module): def convert_module(module):
nonlocal layer_idx nonlocal layer_idx
@@ -111,7 +115,8 @@ def convert_to_differential_attention(
# Copy weights from old attention to new attention # Copy weights from old attention to new attention
new_attention.to(child.q_proj.weight.device) new_attention.to(child.q_proj.weight.device)
copy_attention_weights(child, new_attention, zero_init=zero_init) if not split_heads:
copy_attention_weights(child, new_attention, zero_init=zero_init)
# Replace the layer # Replace the layer
setattr(module, name, new_attention) setattr(module, name, new_attention)

View File

@@ -70,26 +70,51 @@ class LlamaDifferentialAttention(nn.Module):
self.hidden_size = config.hidden_size self.hidden_size = config.hidden_size
self.base_num_heads = config.num_attention_heads self.base_num_heads = config.num_attention_heads
self.base_num_kv_heads = config.num_key_value_heads self.base_num_kv_heads = config.num_key_value_heads
self.head_dim = config.hidden_size // config.num_attention_heads
if config.split_heads:
self.head_dim = config.hidden_size // config.num_attention_heads // 2
else:
self.head_dim = config.hidden_size // config.num_attention_heads
self.layer_idx = layer_idx self.layer_idx = layer_idx
self.max_position_embeddings = config.max_position_embeddings self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta self.rope_theta = config.rope_theta
self.is_causal = True self.is_causal = True
self.split_heads = config.split_heads
# For Q1 and Q2 if config.split_heads:
self.q_proj = nn.Linear( # Split heads mode
self.hidden_size, assert (
self.hidden_size * 2, self.base_num_heads % 2 == 0
bias=False, ), "Number of heads must be even for splitting"
) self.heads_per_component = self.base_num_heads // 2
# For K1 and K2 # Single projections
self.k_proj = nn.Linear( self.q_proj = nn.Linear(
self.hidden_size, self.hidden_size,
self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2, self.hidden_size,
bias=False, bias=False,
) )
self.k_proj = nn.Linear(
self.hidden_size,
self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
bias=False,
)
else:
# Double projection mode
self.heads_per_component = self.base_num_heads
# Double-sized projections
self.q_proj = nn.Linear(
self.hidden_size,
self.hidden_size * 2,
bias=False,
)
self.k_proj = nn.Linear(
self.hidden_size,
self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2,
bias=False,
)
# Single V projection # Single V projection
self.v_proj = nn.Linear( self.v_proj = nn.Linear(
@@ -125,8 +150,14 @@ class LlamaDifferentialAttention(nn.Module):
self.rotary_emb = LlamaRotaryEmbedding(config=config) self.rotary_emb = LlamaRotaryEmbedding(config=config)
sublayer_norm = getattr(config, "sublayer_norm", True) sublayer_norm = getattr(config, "sublayer_norm", True)
if self.split_heads:
subln_dim = 2 * self.head_dim
else:
subln_dim = self.head_dim
self.subln = ( self.subln = (
LlamaRMSNorm(hidden_size=self.head_dim, eps=1e-5) LlamaRMSNorm(hidden_size=subln_dim, eps=1e-5)
if sublayer_norm if sublayer_norm
else nn.Identity() else nn.Identity()
) )
@@ -167,7 +198,10 @@ class LlamaDifferentialAttention(nn.Module):
k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
# Reshape V # Reshape V
v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) if self.split_heads:
v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
else:
v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
# Apply rotary embeddings # Apply rotary embeddings
if position_embeddings is None: if position_embeddings is None:
@@ -177,6 +211,10 @@ class LlamaDifferentialAttention(nn.Module):
else: else:
cos, sin = position_embeddings cos, sin = position_embeddings
if self.split_heads:
cos, _ = cos.chunk(2, dim=2)
sin, _ = sin.chunk(2, dim=2)
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin) q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin) q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
@@ -192,8 +230,6 @@ class LlamaDifferentialAttention(nn.Module):
v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads) v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads)
# Calculate attention scores for both parts # Calculate attention scores for both parts
# NOTE(Dan): the Differential Transformers paper scales by a constant scaling factor
# instead of sqrt(head_dim). This could be set on the class as `self.scaling`.
attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim) attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim)
attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim) attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim)
@@ -307,13 +343,18 @@ class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention):
k1, k2 = kp.chunk(2, dim=-1) k1, k2 = kp.chunk(2, dim=-1)
# Reshape Q1,Q2 for attention # Reshape Q1,Q2 for attention
q1 = q1.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2) q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
q2 = q2.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2) q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
# Reshape K1,K2 for attention # Reshape K1,K2 for attention
k1 = k1.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2) k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
k2 = k2.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2) k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
# Reshape V # Reshape V
v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2) if self.split_heads:
v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
else:
v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
# Apply rotary embeddings # Apply rotary embeddings
if position_embeddings is None: if position_embeddings is None:
@@ -323,6 +364,10 @@ class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention):
else: else:
cos, sin = position_embeddings cos, sin = position_embeddings
if self.split_heads:
cos, _ = cos.chunk(2, dim=2)
sin, _ = sin.chunk(2, dim=2)
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin) q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin) q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
@@ -468,13 +513,18 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttention):
k1, k2 = kp.chunk(2, dim=-1) k1, k2 = kp.chunk(2, dim=-1)
# Reshape Q1,Q2 for attention # Reshape Q1,Q2 for attention
q1 = q1.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2) q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
q2 = q2.view(bsz, q_len, self.base_num_heads, self.head_dim).transpose(1, 2) q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
# Reshape K1,K2 for attention # Reshape K1,K2 for attention
k1 = k1.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2) k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
k2 = k2.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2) k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
# Reshape V # Reshape V
v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2) if self.split_heads:
v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2)
else:
v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
# Apply rotary embeddings # Apply rotary embeddings
if position_embeddings is None: if position_embeddings is None:
@@ -484,6 +534,10 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttention):
else: else:
cos, sin = position_embeddings cos, sin = position_embeddings
if self.split_heads:
cos, _ = cos.chunk(2, dim=2)
sin, _ = sin.chunk(2, dim=2)
q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin) q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin)
q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin) q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin)
@@ -506,20 +560,54 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttention):
# Calculate attention using Flash Attention # Calculate attention using Flash Attention
dropout_p = self.attention_dropout if self.training else 0.0 dropout_p = self.attention_dropout if self.training else 0.0
attn1 = flash_attn_func( if self.split_heads:
q1, v1, v2 = v.chunk(2, dim=-1)
k1, attn11 = flash_attn_func(
v, q1,
dropout_p=dropout_p, k1,
causal=True, v1,
) dropout_p=dropout_p,
attn2 = flash_attn_func( causal=True,
q2, )
k2, attn12 = flash_attn_func(
v, q1,
dropout_p=dropout_p, k1,
causal=True, v2,
) dropout_p=dropout_p,
causal=True,
)
attn1 = torch.cat([attn11, attn12], dim=-1)
attn21 = flash_attn_func(
q2,
k2,
v1,
dropout_p=dropout_p,
causal=True,
)
attn22 = flash_attn_func(
q2,
k2,
v2,
dropout_p=dropout_p,
causal=True,
)
attn2 = torch.cat([attn21, attn22], dim=-1)
else:
attn1 = flash_attn_func(
q1,
k1,
v,
dropout_p=dropout_p,
causal=True,
)
attn2 = flash_attn_func(
q2,
k2,
v,
dropout_p=dropout_p,
causal=True,
)
attn1 = attn1.transpose(1, 2) attn1 = attn1.transpose(1, 2)
attn2 = attn2.transpose(1, 2) attn2 = attn2.transpose(1, 2)

View File

@@ -106,3 +106,26 @@ def test_conversion_cli_repoduce_attentions(
assert (output_dir / "model.safetensors").exists() assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists() assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists() assert (output_dir / "axolotl_config.yml").exists()
@pytest.mark.parametrize(
"attention", ["eager_attention", "sdp_attention", "flash_attention"]
)
def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
output_dir = tmp_path / "converted"
base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B"
base_config["output_dir"] = str(output_dir)
base_config[attention] = True
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True)
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is False
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()