fixes and cleanup

This commit is contained in:
Dan Saunders
2024-12-28 01:10:56 +00:00
parent e5fa842ff8
commit 332ce0ae85
6 changed files with 83 additions and 56 deletions

View File

@@ -0,0 +1,6 @@
metric,training,validation
loss,15.633337020874023,15.604033470153809
model_preparation_time,0.0058,0.0058
runtime,77.8124,8.4643
samples_per_second,23.133,23.629
steps_per_second,23.133,23.629
1 metric training validation
2 loss 15.633337020874023 15.604033470153809
3 model_preparation_time 0.0058 0.0058
4 runtime 77.8124 8.4643
5 samples_per_second 23.133 23.629
6 steps_per_second 23.133 23.629

View File

@@ -18,7 +18,6 @@ from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tok
from axolotl.integrations.diff_transformer.modeling_diff_attn import (
LlamaDifferentialConfig,
LlamaDifferentialForCausalLM,
register_diff_attn,
)
from axolotl.utils.yaml import dump_yaml_preserved_order
@@ -51,7 +50,6 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"):
def convert_diff_transformer(cfg, cli_args, config_path):
register_diff_attn()
debug_info = {}
# Load model and tokenizer

View File

@@ -10,5 +10,10 @@ LOG = logging.getLogger(__name__)
class DifferentialTransformerPlugin(BasePlugin):
"""Plugin for differential transformer integration with Axolotl."""
def __init__(self):
from .modeling_diff_attn import register_diff_attn
register_diff_attn()
def get_input_args(self):
return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs"

View File

@@ -46,27 +46,22 @@ class LlamaDifferentialAttentionBase(nn.Module):
def __init__(self, config: Any, layer_idx: int):
super().__init__()
self.config = config
self._init_config(config, layer_idx)
self._init_config(layer_idx)
self._init_projections()
self._init_differential_params()
self._init_normalization(config)
self._init_normalization()
def _init_config(self, config: Any, layer_idx: int):
def _init_config(self, layer_idx: int):
"""Initialize configuration parameters."""
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.head_dim = config.hidden_size // config.num_attention_heads
self.base_num_heads = config.num_attention_heads
self.base_num_kv_heads = config.num_key_value_heads
self.head_dim = self.config.hidden_size // self.config.num_attention_heads
self.base_num_heads = self.config.num_attention_heads
self.base_num_kv_heads = self.config.num_key_value_heads
self.num_key_value_groups = self.base_num_heads // self.base_num_kv_heads
self.layer_idx = layer_idx
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
self.is_causal = True
self.split_heads = config.split_heads
if config.split_heads:
if self.config.split_heads:
# Split heads mode - single projections
# NOTE: This rounds down `base_num_heads / 2` as opposed to the original
# implementation, which asserts `self.base_num_heads` is even
@@ -81,31 +76,29 @@ class LlamaDifferentialAttentionBase(nn.Module):
def _init_projections(self):
"""Initialize Q, K, V projections."""
if self.split_heads:
if self.config.split_heads:
# Split heads mode - single projections
q_out_dim = self.hidden_size
k_out_dim = self.hidden_size // self.base_num_heads * self.base_num_kv_heads
q_out_dim = self.config.hidden_size
k_out_dim = self.head_dim * self.base_num_kv_heads
else:
# Double projection mode
q_out_dim = self.hidden_size * 2
k_out_dim = (
self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2
)
q_out_dim = self.config.hidden_size * 2
k_out_dim = self.head_dim * self.base_num_kv_heads * 2
self.q_proj = nn.Linear(
self.hidden_size, q_out_dim, bias=self.config.attention_bias
self.config.hidden_size, q_out_dim, bias=self.config.attention_bias
)
self.k_proj = nn.Linear(
self.hidden_size, k_out_dim, bias=self.config.attention_bias
self.config.hidden_size, k_out_dim, bias=self.config.attention_bias
)
self.v_proj = nn.Linear(
self.hidden_size,
self.hidden_size // self.base_num_heads * self.base_num_kv_heads,
self.config.hidden_size,
self.head_dim * self.base_num_kv_heads,
bias=self.config.attention_bias,
)
self.o_proj = nn.Linear(
self.base_num_heads * self.head_dim,
self.hidden_size,
self.config.hidden_size,
bias=self.config.attention_bias,
)
@@ -129,11 +122,11 @@ class LlamaDifferentialAttentionBase(nn.Module):
)
self.rotary_emb = LlamaRotaryEmbedding(config=self.config)
def _init_normalization(self, config):
def _init_normalization(self):
"""Initialize normalization layers."""
sublayer_norm = getattr(config, "sublayer_norm", True)
sublayer_norm = getattr(self.config, "sublayer_norm", True)
if sublayer_norm:
self.subln = LlamaRMSNorm(self.value_head_dim, eps=config.rms_norm_eps)
self.subln = LlamaRMSNorm(self.value_head_dim, eps=self.config.rms_norm_eps)
else:
self.subln = nn.Identity()
@@ -148,7 +141,6 @@ class LlamaDifferentialAttentionBase(nn.Module):
q1, q2 = q.chunk(2, dim=-1)
k1, k2 = k.chunk(2, dim=-1)
# Reshape
q1 = q1.view(bsz, q_len, self.heads_per_component, self.head_dim).transpose(
1, 2
)
@@ -161,9 +153,7 @@ class LlamaDifferentialAttentionBase(nn.Module):
k2 = k2.view(bsz, q_len, self.kv_heads_per_component, self.head_dim).transpose(
1, 2
)
v = v.view(bsz, q_len, self.base_num_kv_heads, self.value_head_dim).transpose(
1, 2
)
v = v.view(bsz, q_len, self.base_num_kv_heads, self.head_dim).transpose(1, 2)
return q1, q2, k1, k2, v
@@ -198,6 +188,8 @@ class LlamaDifferentialAttentionBase(nn.Module):
k1 = repeat_kv(k1, self.num_key_value_groups)
k2 = repeat_kv(k2, self.num_key_value_groups)
v = repeat_kv(v, self.num_key_value_groups)
if self.config.split_heads:
v = torch.cat(torch.chunk(v, 2, dim=1), dim=-1)
return k1, k2, v
@@ -215,7 +207,7 @@ class LlamaDifferentialAttentionBase(nn.Module):
"""Process and project attention output."""
attn = self.subln(attn)
attn = attn * (1 - self.lambda_init)
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
attn = attn.transpose(1, 2).reshape(bsz, q_len, self.config.hidden_size)
return self.o_proj(attn)
@@ -255,7 +247,7 @@ class LlamaDifferentialAttention(LlamaDifferentialAttentionBase):
attn1 = F.softmax(attn1, dim=-1, dtype=torch.float32).type_as(attn1)
attn2 = F.softmax(attn2, dim=-1, dtype=torch.float32).type_as(attn2)
dropout_p = self.attention_dropout if self.training else 0.0
dropout_p = self.config.attention_dropout if self.training else 0.0
attn1 = F.dropout(attn1, p=dropout_p, training=self.training)
attn2 = F.dropout(attn2, p=dropout_p, training=self.training)
@@ -318,7 +310,7 @@ class LlamaDifferentialSdpaAttention(LlamaDifferentialAttentionBase):
None if attention_mask is None else attention_mask[:, :, :, : k1.shape[-2]]
)
is_causal = attention_mask is None and q_len > 1
dropout_p = self.attention_dropout if self.training else 0.0
dropout_p = self.config.attention_dropout if self.training else 0.0
if q1.device.type == "cuda" and causal_mask is not None:
q1, q2 = q1.contiguous(), q2.contiguous()
@@ -396,9 +388,9 @@ class LlamaDifferentialFlashAttention2(LlamaDifferentialAttentionBase):
k1, k2 = k1.transpose(1, 2), k2.transpose(1, 2)
v = v.transpose(1, 2)
dropout_p = self.attention_dropout if self.training else 0.0
dropout_p = self.config.attention_dropout if self.training else 0.0
if self.split_heads:
if self.config.split_heads:
v1, v2 = v.chunk(2, dim=-1)
attn11 = flash_attn_func(q1, k1, v1, dropout_p=dropout_p, causal=True)
attn12 = flash_attn_func(q1, k1, v2, dropout_p=dropout_p, causal=True)

View File

@@ -6,15 +6,10 @@ from typing import Optional, Union
import torch
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
LlamaForCausalLM,
LlamaModel,
LlamaPreTrainedModel,
)
from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaModel
from .diff_attn import (
LlamaDifferentialAttention,
LlamaDifferentialAttentionBase,
LlamaDifferentialFlashAttention2,
LlamaDifferentialSdpaAttention,
)
@@ -46,17 +41,6 @@ class LlamaDifferentialConfig(LlamaConfig):
}
class LlamaDifferentialPreTrainedModel(LlamaPreTrainedModel):
"""Base class for differential LLaMA models."""
config_class = LlamaDifferentialConfig
base_model_prefix = "llama_differential"
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (LlamaDifferentialAttentionBase, LlamaModel)):
module.gradient_checkpointing = value
class LlamaDifferentialModel(LlamaModel):
"""LlamaModel with differential attention."""
@@ -222,6 +206,37 @@ class LlamaDifferentialForCausalLM(LlamaForCausalLM):
super().__init__(config)
self.model = LlamaDifferentialModel(config)
# pylint: disable=protected-access
@classmethod
def _autoset_attn_implementation(
cls, config, **kwargs
): # pylint: disable=unused-argument
config._attn_implementation_autoset = True
attn_implementation = getattr(config, "_attn_implementation", None)
# Map standard types to differential types if mapping exists
if attn_implementation in config._attn_implementations:
config._attn_implementation = config._attn_implementations[
attn_implementation
]
return config
# If no mapping, validate it's a valid differential type
valid_impls = [
None,
"differential_eager",
"differential_sdpa",
"differential_flash_attention_2",
]
if attn_implementation not in valid_impls:
message = (
f"Specified `attn_implementation={attn_implementation}` is not supported. "
f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}"
)
raise ValueError(message)
return config
@classmethod
def from_llama(
cls, model: LlamaForCausalLM, config: Optional[LlamaDifferentialConfig] = None
@@ -257,3 +272,11 @@ def register_diff_attn():
# Register models
AutoModel.register(LlamaDifferentialConfig, LlamaDifferentialModel)
AutoModelForCausalLM.register(LlamaDifferentialConfig, LlamaDifferentialForCausalLM)
from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES
LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention
LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention
LLAMA_ATTENTION_CLASSES[
"differential_flash_attention_2"
] = LlamaDifferentialFlashAttention2

View File

@@ -130,6 +130,9 @@ def test_conversion_cli_repoduce_attentions(
)
def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str):
output_dir = tmp_path / "converted"
# Smallest model with an even number of attention heads
base_config["base_model"] = "HuggingFaceTB/SmolLM2-1.7B"
base_config["output_dir"] = str(output_dir)
base_config[attention] = True