diff --git a/src/axolotl/cli/integrations/convert_differential_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py similarity index 87% rename from src/axolotl/cli/integrations/convert_differential_transformer.py rename to src/axolotl/cli/integrations/convert_diff_transformer.py index b50dd43dd..d91278fed 100644 --- a/src/axolotl/cli/integrations/convert_differential_transformer.py +++ b/src/axolotl/cli/integrations/convert_diff_transformer.py @@ -14,7 +14,8 @@ from transformers import HfArgumentParser from axolotl.cli import load_cfg, print_axolotl_text_art from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer -from axolotl.integrations.differential_transformer.convert import convert_to_diff_attn +from axolotl.integrations.diff_transformer.convert import convert_to_diff_attn +from axolotl.utils.yaml import dump_yaml_preserved_order LOG = logging.getLogger(__name__) @@ -51,7 +52,7 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"): raise -def convert_differential_transformer(cfg, cli_args, config_path): +def convert_diff_transformer(cfg, cli_args, config_path): debug_info = {} # Load model and tokenizer @@ -114,16 +115,23 @@ def convert_differential_transformer(cfg, cli_args, config_path): LOG.info("Saving updated config to %s", output_config_path) with open(config_path, "r", encoding="utf-8") as file: - data = yaml.safe_load(file) or {} + modified_cfg = yaml.safe_load(file) or {} - data["base_model"] = cfg.output_dir - data["differential_attention"] = True - data["plugins"] = [ - "axolotl.integrations.differential_transformer.DifferentialTransformerPlugin" - ] + modified_cfg["base_model"] = cfg.output_dir + modified_cfg["diff_attention"] = True + plugin_class = ( + "axolotl.integrations.diff_transformer.DifferentialTransformerPlugin" + ) + if "plugins" in modified_cfg: + modified_cfg["plugins"].append(plugin_class) + else: + modified_cfg["plugins"] = [plugin_class] - with open(output_config_path, "w", encoding="utf-8") as file: - yaml.dump(data, file) + dump_yaml_preserved_order( + data=modified_cfg, + reference_yaml_path=config_path, + output_path=output_config_path, + ) else: LOG.info("Not saving converted model to disk") LOG.info("Pass --output-dir path/to/save to save model") @@ -191,7 +199,7 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs): parser = HfArgumentParser(ConvertDiffTransformerCliArgs) cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True) - convert_differential_transformer(cfg, cli_args, config) + convert_diff_transformer(cfg, cli_args, config) if __name__ == "__main__": diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py index 41354dcb0..00d075286 100644 --- a/src/axolotl/cli/main.py +++ b/src/axolotl/cli/main.py @@ -252,11 +252,11 @@ def merge_lora( @click.argument("config", type=click.Path(exists=True, path_type=str)) @add_options_from_dataclass(ConvertDiffTransformerCliArgs) @add_options_from_config(AxolotlInputConfig) -def convert_differential_transformer(config: str, **kwargs): +def convert_diff_transformer(config: str, **kwargs): """Convert model attention layers to differential attention layers.""" kwargs = {k: v for k, v in kwargs.items() if v is not None} - from axolotl.cli.integrations.convert_differential_transformer import do_cli + from axolotl.cli.integrations.convert_diff_transformer import do_cli do_cli(config=config, **kwargs) diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py index c51c4e2ab..ea3b91c0c 100644 
--- a/src/axolotl/common/cli.py +++ b/src/axolotl/common/cli.py @@ -57,7 +57,7 @@ class EvaluateCliArgs: @dataclass class ConvertDiffTransformerCliArgs: """ - dataclass with arguments for convert-differential-transformer CLI + dataclass with arguments for convert-diff-transformer CLI """ debug: bool = field(default=False) diff --git a/src/axolotl/integrations/diff_transformer/README.md b/src/axolotl/integrations/diff_transformer/README.md new file mode 100644 index 000000000..14473f753 --- /dev/null +++ b/src/axolotl/integrations/diff_transformer/README.md @@ -0,0 +1,10 @@ +# Differential Transformer + +### Usage + +```yaml +plugins: + - axolotl.integrations.diff_transformer.DifferentialTransformerPlugin + +diff_attention: true +``` diff --git a/src/axolotl/integrations/differential_transformer/__init__.py b/src/axolotl/integrations/diff_transformer/__init__.py similarity index 81% rename from src/axolotl/integrations/differential_transformer/__init__.py rename to src/axolotl/integrations/diff_transformer/__init__.py index 63741793c..70459e026 100644 --- a/src/axolotl/integrations/differential_transformer/__init__.py +++ b/src/axolotl/integrations/diff_transformer/__init__.py @@ -13,11 +13,11 @@ class DifferentialTransformerPlugin(BasePlugin): """ def get_input_args(self): - return "axolotl.integrations.differential_transformer.args.DifferentialTransformerArgs" + return "axolotl.integrations.diff_transformer.args.DifferentialTransformerArgs" def pre_model_load(self, cfg): """Apply differential attention patch before model loading if enabled.""" - if cfg.differential_attention: + if cfg.diff_attention: from axolotl.monkeypatch.attention.differential import ( patch_llama_attention_classes, ) diff --git a/src/axolotl/integrations/differential_transformer/args.py b/src/axolotl/integrations/diff_transformer/args.py similarity index 84% rename from src/axolotl/integrations/differential_transformer/args.py rename to src/axolotl/integrations/diff_transformer/args.py index bd6e01520..47c1fe110 100644 --- a/src/axolotl/integrations/differential_transformer/args.py +++ b/src/axolotl/integrations/diff_transformer/args.py @@ -11,4 +11,4 @@ LOG = logging.getLogger(__name__) class DifferentialTransformerArgs(BaseModel): """Input args for differential transformer.""" - differential_attention: Optional[bool] = None + diff_attention: Optional[bool] = None diff --git a/src/axolotl/integrations/differential_transformer/convert.py b/src/axolotl/integrations/diff_transformer/convert.py similarity index 99% rename from src/axolotl/integrations/differential_transformer/convert.py rename to src/axolotl/integrations/diff_transformer/convert.py index 4beaea7ae..5c10f2137 100644 --- a/src/axolotl/integrations/differential_transformer/convert.py +++ b/src/axolotl/integrations/diff_transformer/convert.py @@ -11,7 +11,7 @@ from transformers.models.llama.modeling_llama import ( LlamaSdpaAttention, ) -from .differential_attention import ( +from .diff_attn import ( LlamaDifferentialAttention, LlamaDifferentialFlashAttention2, LlamaDifferentialSdpaAttention, diff --git a/src/axolotl/integrations/diff_transformer/diff_attn.py b/src/axolotl/integrations/diff_transformer/diff_attn.py new file mode 100644 index 000000000..edf532c41 --- /dev/null +++ b/src/axolotl/integrations/diff_transformer/diff_attn.py @@ -0,0 +1,375 @@ +"""Re-implementation of differential attention.""" +# pylint: disable=invalid-name + +import logging +import math +from typing import Any, Optional, Tuple + +import torch +import torch.nn.functional as F +from
flash_attn.flash_attn_interface import flash_attn_func +from torch import nn +from transformers.cache_utils import Cache +from transformers.models.llama.modeling_llama import ( + LlamaRMSNorm, + LlamaRotaryEmbedding, + apply_rotary_pos_emb, +) + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" + batch_size, n_kv_heads, slen, head_dim = x.shape + if n_rep == 1: + return x + return ( + x[:, :, None, :, :] + .expand(batch_size, n_kv_heads, n_rep, slen, head_dim) + .reshape(batch_size, n_kv_heads * n_rep, slen, head_dim) + ) + + +def lambda_init_fn(depth): + return 0.8 - 0.6 * math.exp(-0.3 * depth) + + +class DifferentialAttentionBase(nn.Module): + """Base class for differential attention implementations.""" + + def __init__(self, config: Any, layer_idx: int): + super().__init__() + self._init_config(config, layer_idx) + self._init_projections() + self._init_differential_params() + self._init_normalization(config) + + def _init_config(self, config: Any, layer_idx: int): + """Initialize configuration parameters.""" + self.attention_dropout = config.attention_dropout + self.hidden_size = config.hidden_size + self.base_num_heads = config.num_attention_heads + self.base_num_kv_heads = config.num_key_value_heads + self.layer_idx = layer_idx + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.split_heads = config.split_heads + + if config.split_heads: + # Split heads mode - single projections + self.head_dim = config.hidden_size // config.num_attention_heads // 2 + # NOTE: This rounds down `base_num_heads / 2` as opposed to the original + # implementation, which asserts `self.base_num_heads` is even. 
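+ # For example, with hypothetical shapes hidden_size=2048 and num_attention_heads=16, + # split-heads mode gives head_dim=64 and heads_per_component=8; an odd head count + # would floor to (n - 1) / 2 components per side instead of raising an error.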
+ self.heads_per_component = self.base_num_heads // 2 + self.value_head_dim = 2 * self.head_dim + else: + # Double projection mode + self.head_dim = config.hidden_size // config.num_attention_heads + self.heads_per_component = self.base_num_heads + self.value_head_dim = self.head_dim + + def _init_projections(self): + """Initialize Q, K, V projections.""" + if self.split_heads: + # Split heads mode - single projections + q_out_dim = self.hidden_size + k_out_dim = self.hidden_size // self.base_num_heads * self.base_num_kv_heads + else: + # Double projection mode + q_out_dim = self.hidden_size * 2 + k_out_dim = ( + self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2 + ) + + self.q_proj = nn.Linear(self.hidden_size, q_out_dim, bias=False) + self.k_proj = nn.Linear(self.hidden_size, k_out_dim, bias=False) + self.v_proj = nn.Linear( + self.hidden_size, + self.hidden_size // self.base_num_heads * self.base_num_kv_heads, + bias=False, + ) + self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False) + + def _init_differential_params(self): + """Initialize differential attention parameters.""" + self.lambda_init = nn.Parameter( + torch.full((), lambda_init_fn(self.layer_idx)), + requires_grad=False, + ) + self.lambda_q1 = nn.Parameter( + torch.zeros(self.head_dim).normal_(mean=0, std=0.1) + ) + self.lambda_k1 = nn.Parameter( + torch.zeros(self.head_dim).normal_(mean=0, std=0.1) + ) + self.lambda_q2 = nn.Parameter( + torch.zeros(self.head_dim).normal_(mean=0, std=0.1) + ) + self.lambda_k2 = nn.Parameter( + torch.zeros(self.head_dim).normal_(mean=0, std=0.1) + ) + self.rotary_emb = LlamaRotaryEmbedding( + self.max_position_embeddings, self.head_dim, self.rope_theta + ) + + def _init_normalization(self, config): + """Initialize normalization layers.""" + sublayer_norm = getattr(config, "sublayer_norm", True) + self.subln = ( + LlamaRMSNorm(self.value_head_dim, eps=1e-5) + if sublayer_norm + else nn.Identity() + ) + + def _prepare_attention_inputs(self, hidden_states: torch.Tensor): + """Prepare inputs for attention computation.""" + bsz, q_len, _ = hidden_states.size() + + # Project and split + qp = self.q_proj(hidden_states) + kp = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + q1, q2 = qp.chunk(2, dim=-1) + k1, k2 = kp.chunk(2, dim=-1) + + # Reshape + q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) + v = v.view(bsz, q_len, -1, self.value_head_dim).transpose(1, 2) + + return q1, q2, k1, k2, v + + def _apply_rotary_embeddings( + self, q1, q2, k1, k2, position_ids, position_embeddings + ): + """Apply rotary embeddings to queries and keys.""" + if position_embeddings is None: + if position_ids is None: + position_ids = torch.arange(q1.size(-2), device=q1.device) + cos, sin = self.rotary_emb(q1, position_ids) + else: + cos, sin = position_embeddings + + if self.split_heads: + cos, _ = cos.chunk(2, dim=2) + sin, _ = sin.chunk(2, dim=2) + + q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin) + q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin) + + return q1, q2, k1, k2, cos, sin + + def _handle_cache(self, k1, k2, v, past_key_value, cache_kwargs): + """Handle caching for autoregressive generation.""" + if past_key_value is not None: + k = torch.stack([k1, k2], dim=1) + k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs) + k1, k2 = k.unbind(dim=1) + + # Repeat KV 
heads + k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads) + k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads) + v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads) + + return k1, k2, v + + def _compute_lambda(self, q1): + """Compute lambda values for differential attention.""" + lambda_1 = torch.exp( + torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float() + ).type_as(q1) + lambda_2 = torch.exp( + torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float() + ).type_as(q1) + return lambda_1 - lambda_2 + self.lambda_init + + def _process_attention_output(self, attn, bsz, q_len): + """Process and project attention output.""" + attn = self.subln(attn) + attn = attn * (1 - self.lambda_init) + attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size) + return self.o_proj(attn) + + +class LlamaDifferentialAttention(DifferentialAttentionBase): + """Standard implementation of differential attention.""" + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, # pylint: disable=unused-argument + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, # pylint: disable=unused-argument + ): + bsz, q_len, _ = hidden_states.size() + q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states) + q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings( + q1, q2, k1, k2, position_ids, position_embeddings + ) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs) + + # Standard attention computation + attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim) + attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim) + + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : k1.shape[-2]] + attn1 = attn1 + causal_mask + attn2 = attn2 + causal_mask + + attn1 = F.softmax(attn1, dim=-1, dtype=torch.float32).type_as(attn1) + attn2 = F.softmax(attn2, dim=-1, dtype=torch.float32).type_as(attn2) + + dropout_p = self.attention_dropout if self.training else 0.0 + attn1 = F.dropout(attn1, p=dropout_p, training=self.training) + attn2 = F.dropout(attn2, p=dropout_p, training=self.training) + + lambda_full = self._compute_lambda(q1) + attn = torch.matmul(attn1, v) - lambda_full * torch.matmul(attn2, v) + + attn = self._process_attention_output(attn, bsz, q_len) + + if output_attentions: + return attn, attn1 - lambda_full * attn2, past_key_value + return attn, None, past_key_value + + +class LlamaDifferentialSdpaAttention(DifferentialAttentionBase): + """SDPA-based implementation of differential attention.""" + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, # pylint: disable=unused-argument + ): + if output_attentions: + return LlamaDifferentialAttention.forward( + self, + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + 
position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states) + q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings( + q1, q2, k1, k2, position_ids, position_embeddings + ) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs) + + # SDPA-specific attention computation + causal_mask = ( + None if attention_mask is None else attention_mask[:, :, :, : k1.shape[-2]] + ) + is_causal = attention_mask is None and q_len > 1 + dropout_p = self.attention_dropout if self.training else 0.0 + + if q1.device.type == "cuda" and causal_mask is not None: + q1, q2 = q1.contiguous(), q2.contiguous() + k1, k2 = k1.contiguous(), k2.contiguous() + v = v.contiguous() + + attn1 = F.scaled_dot_product_attention( + q1, k1, v, attn_mask=causal_mask, dropout_p=dropout_p, is_causal=is_causal + ) + attn2 = F.scaled_dot_product_attention( + q2, k2, v, attn_mask=causal_mask, dropout_p=dropout_p, is_causal=is_causal + ) + + lambda_full = self._compute_lambda(q1) + attn = attn1 - lambda_full * attn2 + + attn = self._process_attention_output(attn, bsz, q_len) + return attn, None, past_key_value + + +class LlamaDifferentialFlashAttention2(DifferentialAttentionBase): + """Flash Attention 2-based implementation of differential attention.""" + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + cache_position: Optional[torch.LongTensor] = None, + position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, + **kwargs, # pylint: disable=unused-argument + ): + if output_attentions: + return LlamaDifferentialAttention.forward( + self, + hidden_states, + attention_mask, + position_ids, + past_key_value, + output_attentions, + use_cache, + cache_position, + position_embeddings, + ) + + bsz, q_len, _ = hidden_states.size() + q1, q2, k1, k2, v = self._prepare_attention_inputs(hidden_states) + q1, q2, k1, k2, cos, sin = self._apply_rotary_embeddings( + q1, q2, k1, k2, position_ids, position_embeddings + ) + + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + k1, k2, v = self._handle_cache(k1, k2, v, past_key_value, cache_kwargs) + + # Flash Attention specific processing + q1, q2 = q1.transpose(1, 2), q2.transpose(1, 2) + k1, k2 = k1.transpose(1, 2), k2.transpose(1, 2) + v = v.transpose(1, 2) + + dropout_p = self.attention_dropout if self.training else 0.0 + + if self.split_heads: + v1, v2 = v.chunk(2, dim=-1) + attn11 = flash_attn_func(q1, k1, v1, dropout_p=dropout_p, causal=True) + attn12 = flash_attn_func(q1, k1, v2, dropout_p=dropout_p, causal=True) + attn1 = torch.cat([attn11, attn12], dim=-1) + + attn21 = flash_attn_func(q2, k2, v1, dropout_p=dropout_p, causal=True) + attn22 = flash_attn_func(q2, k2, v2, dropout_p=dropout_p, causal=True) + attn2 = torch.cat([attn21, attn22], dim=-1) + else: + attn1 = flash_attn_func(q1, k1, v, dropout_p=dropout_p, causal=True) + attn2 = flash_attn_func(q2, k2, v, dropout_p=dropout_p, causal=True) + + attn1, attn2 = attn1.transpose(1, 2), attn2.transpose(1, 2) + + lambda_full = self._compute_lambda(q1) + attn = attn1 - lambda_full * attn2 + + attn = self._process_attention_output(attn, bsz, q_len) + return attn, None, past_key_value diff --git 
a/src/axolotl/integrations/differential_transformer/README.md b/src/axolotl/integrations/differential_transformer/README.md deleted file mode 100644 index f7bd74cbd..000000000 --- a/src/axolotl/integrations/differential_transformer/README.md +++ /dev/null @@ -1,10 +0,0 @@ -# Differential Transformer - -### Usage - -```yaml -plugins: - - axolotl.integrations.differential_transformer.DifferentialTransformerPlugin - -differential_attention: true -``` diff --git a/src/axolotl/integrations/differential_transformer/differential_attention.py b/src/axolotl/integrations/differential_transformer/differential_attention.py deleted file mode 100644 index af7473436..000000000 --- a/src/axolotl/integrations/differential_transformer/differential_attention.py +++ /dev/null @@ -1,641 +0,0 @@ -"""Re-implemention of differential attention.""" -# pylint: disable=invalid-name -import logging -import math -from typing import Any, Optional, Tuple - -import torch -import torch.nn.functional as F -import transformers -from flash_attn.flash_attn_interface import flash_attn_func -from torch import nn -from transformers.cache_utils import Cache -from transformers.models.llama.modeling_llama import ( - LlamaRMSNorm, - LlamaRotaryEmbedding, - apply_rotary_pos_emb, -) - -logging.basicConfig(level=logging.INFO) -logger = logging.getLogger(__name__) - - -def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: - """torch.repeat_interleave(x, dim=1, repeats=n_rep)""" - batch_size, n_kv_heads, slen, head_dim = x.shape - if n_rep == 1: - return x - return ( - x[:, :, None, :, :] - .expand(batch_size, n_kv_heads, n_rep, slen, head_dim) - .reshape(batch_size, n_kv_heads * n_rep, slen, head_dim) - ) - - -def lambda_init_fn(depth): - return 0.8 - 0.6 * math.exp(-0.3 * depth) - - -class LlamaDifferentialAttention(nn.Module): - """Differential Attention implementation as described in the Diff Transformer paper. - - This implements a modified attention mechanism that computes the difference between - two attention patterns, scaled by learned lambda parameters. The mechanism helps - reduce noise in the attention weights for irrelevant / less relevant tokens. - - Key components: - - Split head dimension for differential computation - - Learned lambda parameters that control attention scaling - - Sublayer normalization on the attention output - - See: - - https://arxiv.org/abs/2410.05258 - - https://github.com/microsoft/unilm/tree/master/Diff-Transformer - - Args: - config: Model configuration object containing hidden size, number of heads etc. 
- layer_idx: Index of this layer in the transformer stack - dtype: Data type for the layer parameters - """ - - def __init__( - self, - config: Any, - layer_idx: int, - ): - super().__init__() - - # Base model config - self.attention_dropout = config.attention_dropout - self.hidden_size = config.hidden_size - self.base_num_heads = config.num_attention_heads - self.base_num_kv_heads = config.num_key_value_heads - - if config.split_heads: - self.head_dim = config.hidden_size // config.num_attention_heads // 2 - else: - self.head_dim = config.hidden_size // config.num_attention_heads - - self.layer_idx = layer_idx - self.max_position_embeddings = config.max_position_embeddings - self.rope_theta = config.rope_theta - self.is_causal = True - self.split_heads = config.split_heads - - if config.split_heads: - # Split heads mode - # assert ( - # self.base_num_heads % 2 == 0 - # ), "Number of heads must be even for splitting" - self.heads_per_component = self.base_num_heads // 2 - - # Single projections - self.q_proj = nn.Linear( - self.hidden_size, - self.hidden_size, - bias=False, - ) - self.k_proj = nn.Linear( - self.hidden_size, - self.hidden_size // self.base_num_heads * self.base_num_kv_heads, - bias=False, - ) - else: - # Double projection mode - self.heads_per_component = self.base_num_heads - - # Double-sized projections - self.q_proj = nn.Linear( - self.hidden_size, - self.hidden_size * 2, - bias=False, - ) - self.k_proj = nn.Linear( - self.hidden_size, - self.hidden_size // self.base_num_heads * self.base_num_kv_heads * 2, - bias=False, - ) - - # Single V projection - self.v_proj = nn.Linear( - self.hidden_size, - self.hidden_size // self.base_num_heads * self.base_num_kv_heads, - bias=False, - ) - - # Output projection - self.o_proj = nn.Linear( - self.hidden_size, - self.hidden_size, - bias=False, - ) - - # Initialize differential attention parameters - self.lambda_init = nn.Parameter( - torch.full((), lambda_init_fn(self.layer_idx)), - requires_grad=False, - ) - self.lambda_q1 = nn.Parameter( - torch.zeros(self.head_dim).normal_(mean=0, std=0.1) - ) - self.lambda_k1 = nn.Parameter( - torch.zeros(self.head_dim).normal_(mean=0, std=0.1) - ) - self.lambda_q2 = nn.Parameter( - torch.zeros(self.head_dim).normal_(mean=0, std=0.1) - ) - self.lambda_k2 = nn.Parameter( - torch.zeros(self.head_dim).normal_(mean=0, std=0.1) - ) - - self.rotary_emb = LlamaRotaryEmbedding(config=config) - sublayer_norm = getattr(config, "sublayer_norm", True) - - if self.split_heads: - subln_dim = 2 * self.head_dim - else: - subln_dim = self.head_dim - - self.subln = ( - LlamaRMSNorm(hidden_size=subln_dim, eps=1e-5) - if sublayer_norm - else nn.Identity() - ) - - def forward( - self, - hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, # pylint: disable=unused-argument - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, # pylint: disable=unused-argument - ) -> tuple[ - torch.Tensor, - Optional[torch.Tensor], - Optional[tuple[torch.Tensor, torch.Tensor]], - ]: - bsz, q_len, _ = hidden_states.size() - - # Project to Q1,Q2 and K1,K2 - qp = self.q_proj(hidden_states) - kp = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) - - # Split into Q1,Q2 and K1,K2 - q1, q2 = qp.chunk(2, dim=-1) - k1, k2 = kp.chunk(2, dim=-1) - - # Reshape Q1,Q2 for 
attention - q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - # Reshape K1,K2 for attention - k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - # Reshape V - if self.split_heads: - v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2) - else: - v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - # Apply rotary embeddings - if position_embeddings is None: - if position_ids is None: - position_ids = torch.arange(q_len, device=q1.device) - cos, sin = self.rotary_emb(q1, position_ids) - else: - cos, sin = position_embeddings - - if self.split_heads: - cos, _ = cos.chunk(2, dim=2) - sin, _ = sin.chunk(2, dim=2) - - q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin) - q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - k = torch.stack([k1, k2], dim=1) - k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs) - k1, k2 = k.unbind(dim=1) - - # Repeat KV heads to match Q heads - k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads) - k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads) - v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads) - - # Calculate attention scores for both parts - attn1 = torch.matmul(q1, k1.transpose(-1, -2)) / math.sqrt(self.head_dim) - attn2 = torch.matmul(q2, k2.transpose(-1, -2)) / math.sqrt(self.head_dim) - - if attention_mask is not None: - causal_mask = attention_mask[:, :, :, : k1.shape[-2]] - attn1 = attn1 + causal_mask - attn2 = attn2 + causal_mask - - # Apply softmax - attn1 = F.softmax(attn1, dim=-1, dtype=torch.float32).type_as(attn1) - attn2 = F.softmax(attn2, dim=-1, dtype=torch.float32).type_as(attn2) - - # Apply dropout - attn1 = F.dropout(attn1, p=self.attention_dropout, training=self.training) - attn2 = F.dropout(attn2, p=self.attention_dropout, training=self.training) - - # Calculate lambda - lambda_1 = torch.exp( - torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float() - ).type_as(q1) - lambda_2 = torch.exp( - torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float() - ).type_as(q1) - lambda_full = lambda_1 - lambda_2 + self.lambda_init - - # Compute differential attention (following paper's formula) - attn_weights = attn1 - lambda_full * attn2 - - # Apply attention weights to values - attn = torch.matmul(attn_weights, v) - - # Apply sublayer norm and scaling - attn = self.subln(attn) - attn = attn * (1 - self.lambda_init) - - # Reshape to output - attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size) - attn = self.o_proj(attn) - - if output_attentions: - return attn, attn_weights, past_key_value - return attn, None, past_key_value - - -class LlamaDifferentialSdpaAttention(LlamaDifferentialAttention): - """Differential Attention implementation as described in the Diff Transformer paper. - This implements the same logic as `LlamaDifferentialAttention`, but uses - `scaled_dot_product_attention` instead of "manually" computing it under the hood. - - This implements a modified attention mechanism that computes the difference between - two attention patterns, scaled by learned lambda parameters. The mechanism helps - reduce noise in the attention weights for irrelevant / less relevant tokens. 
- - Key components: - - Split head dimension for differential computation - - Learned lambda parameters that control attention scaling - - Sublayer normalization on the attention output - - See: - - https://arxiv.org/abs/2410.05258 - - https://github.com/microsoft/unilm/tree/master/Diff-Transformer - - Args: - config: Model configuration object containing hidden size, number of heads etc. - layer_idx: Index of this layer in the transformer stack - dtype: Data type for the layer parameters - """ - - # pylint: disable=duplicate-code - def forward( - self, - hidden_states: torch.Tensor, # [bsz, seq_len, hidden_size] - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, # pylint: disable=unused-argument - ) -> tuple[ - torch.Tensor, - Optional[torch.Tensor], - Optional[tuple[torch.Tensor, torch.Tensor]], - ]: - if output_attentions: - transformers.logger.warning_once( - "LlamaModel is using LlamaSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " - 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' - ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - # Project to Q1,Q2 and K1,K2 - qp = self.q_proj(hidden_states) - kp = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) - - # Split into Q1,Q2 and K1,K2 - q1, q2 = qp.chunk(2, dim=-1) - k1, k2 = kp.chunk(2, dim=-1) - - # Reshape Q1,Q2 for attention - q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - # Reshape K1,K2 for attention - k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - # Reshape V - if self.split_heads: - v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2) - else: - v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - # Apply rotary embeddings - if position_embeddings is None: - if position_ids is None: - position_ids = torch.arange(q_len, device=q1.device) - cos, sin = self.rotary_emb(q1, position_ids) - else: - cos, sin = position_embeddings - - if self.split_heads: - cos, _ = cos.chunk(2, dim=2) - sin, _ = sin.chunk(2, dim=2) - - q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin) - q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - k = torch.stack([k1, k2], dim=1) - k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs) - k1, k2 = k.unbind(dim=1) - - # Repeat KV heads to match Q heads - k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads) - k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads) - v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads) - - causal_mask = None - if 
attention_mask is not None: - causal_mask = attention_mask - causal_mask = causal_mask[:, :, :, : k1.shape[-2]] - - # SDPA with memory-efficient backend requires contiguous inputs on CUDA - if q1.device.type == "cuda" and causal_mask is not None: - q1, q2 = q1.contiguous(), q2.contiguous() - k1, k2 = k1.contiguous(), k2.contiguous() - v = v.contiguous() - - # Calculate attention using SDPA - is_causal = attention_mask is None and q_len > 1 - - dropout_p = self.attention_dropout if self.training else 0.0 - attn1 = F.scaled_dot_product_attention( - q1, - k1, - v, - attn_mask=causal_mask, - dropout_p=dropout_p, - is_causal=is_causal, - ) - attn2 = F.scaled_dot_product_attention( - q2, - k2, - v, - attn_mask=causal_mask, - dropout_p=dropout_p, - is_causal=is_causal, - ) - - # Calculate lambda - lambda_1 = torch.exp( - torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float() - ).type_as(q1) - lambda_2 = torch.exp( - torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float() - ).type_as(q1) - lambda_full = lambda_1 - lambda_2 + self.lambda_init - - # Combine the attention outputs - attn = attn1 - lambda_full * attn2 - - # Apply sublayer norm and scaling - attn = self.subln(attn) - attn = attn * (1 - self.lambda_init) - - # Reshape to output - attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size) - attn = self.o_proj(attn) - - if output_attentions: - return ( - attn, - None, - past_key_value, - ) # Note: can't return attn_weights with SDPA - return attn, None, past_key_value - - -class LlamaDifferentialFlashAttention2(LlamaDifferentialAttention): - """Differential Attention implementation using Flash Attention 2. - This implements the same logic as `LlamaDifferentialAttention`, but uses - Flash Attention 2 for more efficient computation. - - This implements a modified attention mechanism that computes the difference between - two attention patterns, scaled by learned lambda parameters. The mechanism helps - reduce noise in the attention weights for irrelevant / less relevant tokens. - - Key components: - - Split head dimension for differential computation - - Learned lambda parameters that control attention scaling - - Sublayer normalization on the attention output - - Flash Attention 2 for efficient attention computation - - See: - - https://arxiv.org/abs/2410.05258 - - https://github.com/microsoft/unilm/tree/master/Diff-Transformer - - Args: - config: Model configuration object containing hidden size, number of heads etc. - layer_idx: Index of this layer in the transformer stack - dtype: Data type for the layer parameters - """ - - # pylint: disable=duplicate-code - def forward( - self, - hidden_states: torch.Tensor, # [bsz, seq_len, hidden_size] - attention_mask: Optional[torch.Tensor] = None, - position_ids: Optional[torch.LongTensor] = None, - past_key_value: Optional[Cache] = None, - output_attentions: bool = False, - use_cache: bool = False, - cache_position: Optional[torch.LongTensor] = None, - position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, - **kwargs, - ) -> tuple[ - torch.Tensor, - Optional[torch.Tensor], - Optional[tuple[torch.Tensor, torch.Tensor]], - ]: - if output_attentions: - transformers.logger.warning_once( - "LlamaModel is using LlamaFlashAttention, but Flash Attention does not support `output_attentions=True`. " - "Falling back to the manual attention implementation." 
- ) - return super().forward( - hidden_states=hidden_states, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_value=past_key_value, - output_attentions=output_attentions, - use_cache=use_cache, - cache_position=cache_position, - position_embeddings=position_embeddings, - ) - - bsz, q_len, _ = hidden_states.size() - - # Project to Q1,Q2 and K1,K2 - qp = self.q_proj(hidden_states) - kp = self.k_proj(hidden_states) - v = self.v_proj(hidden_states) - - # Split into Q1,Q2 and K1,K2 - q1, q2 = qp.chunk(2, dim=-1) - k1, k2 = kp.chunk(2, dim=-1) - - # Reshape Q1,Q2 for attention - q1 = q1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - q2 = q2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - # Reshape K1,K2 for attention - k1 = k1.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - k2 = k2.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - # Reshape V - if self.split_heads: - v = v.view(bsz, q_len, -1, 2 * self.head_dim).transpose(1, 2) - else: - v = v.view(bsz, q_len, -1, self.head_dim).transpose(1, 2) - - # Apply rotary embeddings - if position_embeddings is None: - if position_ids is None: - position_ids = torch.arange(q_len, device=q1.device) - cos, sin = self.rotary_emb(q1, position_ids) - else: - cos, sin = position_embeddings - - if self.split_heads: - cos, _ = cos.chunk(2, dim=2) - sin, _ = sin.chunk(2, dim=2) - - q1, k1 = apply_rotary_pos_emb(q1, k1, cos, sin) - q2, k2 = apply_rotary_pos_emb(q2, k2, cos, sin) - - if past_key_value is not None: - cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} - k = torch.stack([k1, k2], dim=1) - k, v = past_key_value.update(k, v, self.layer_idx, cache_kwargs) - k1, k2 = k.unbind(dim=1) - - # Repeat KV heads to match Q heads - k1 = repeat_kv(k1, self.base_num_heads // self.base_num_kv_heads) - k2 = repeat_kv(k2, self.base_num_heads // self.base_num_kv_heads) - v = repeat_kv(v, self.base_num_heads // self.base_num_kv_heads) - - q1 = q1.transpose(1, 2) - q2 = q2.transpose(1, 2) - k1 = k1.transpose(1, 2) - k2 = k2.transpose(1, 2) - v = v.transpose(1, 2) - - # Calculate attention using Flash Attention - dropout_p = self.attention_dropout if self.training else 0.0 - if self.split_heads: - v1, v2 = v.chunk(2, dim=-1) - attn11 = flash_attn_func( - q1, - k1, - v1, - dropout_p=dropout_p, - causal=True, - ) - attn12 = flash_attn_func( - q1, - k1, - v2, - dropout_p=dropout_p, - causal=True, - ) - attn1 = torch.cat([attn11, attn12], dim=-1) - - attn21 = flash_attn_func( - q2, - k2, - v1, - dropout_p=dropout_p, - causal=True, - ) - attn22 = flash_attn_func( - q2, - k2, - v2, - dropout_p=dropout_p, - causal=True, - ) - attn2 = torch.cat([attn21, attn22], dim=-1) - else: - attn1 = flash_attn_func( - q1, - k1, - v, - dropout_p=dropout_p, - causal=True, - ) - attn2 = flash_attn_func( - q2, - k2, - v, - dropout_p=dropout_p, - causal=True, - ) - - attn1 = attn1.transpose(1, 2) - attn2 = attn2.transpose(1, 2) - - # Calculate lambda - lambda_1 = torch.exp( - torch.sum(self.lambda_q1 * self.lambda_k1, dim=-1).float() - ).type_as(q1) - lambda_2 = torch.exp( - torch.sum(self.lambda_q2 * self.lambda_k2, dim=-1).float() - ).type_as(q1) - lambda_full = lambda_1 - lambda_2 + self.lambda_init - - # Combine the attention outputs - attn = attn1 - lambda_full * attn2 - - # Apply sublayer norm and scaling - attn = self.subln(attn) - attn = attn * (1 - self.lambda_init) - - # Reshape to output - attn = attn.transpose(1, 2).reshape(bsz, q_len, self.hidden_size) - attn = self.o_proj(attn) - - if output_attentions: - return 
( - attn, - None, - past_key_value, - ) # Note: can't return attn_weights with Flash Attention - return attn, None, past_key_value diff --git a/src/axolotl/monkeypatch/attention/differential.py b/src/axolotl/monkeypatch/attention/differential.py index a07b629b6..635573a4b 100644 --- a/src/axolotl/monkeypatch/attention/differential.py +++ b/src/axolotl/monkeypatch/attention/differential.py @@ -3,7 +3,7 @@ from transformers import PreTrainedModel from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES -from axolotl.integrations.differential_transformer.differential_attention import ( +from axolotl.integrations.diff_transformer.diff_attn import ( LlamaDifferentialAttention, LlamaDifferentialFlashAttention2, LlamaDifferentialSdpaAttention, diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py index 6eaa020da..37cbc0871 100644 --- a/src/axolotl/utils/models.py +++ b/src/axolotl/utils/models.py @@ -714,7 +714,7 @@ class ModelLoader: if not self.cfg.sample_packing and self.cfg.s2_attention: pass - if self.cfg.differential_attention: + if self.cfg.diff_attention: self.model_kwargs[ "attn_implementation" ] = "differential_flash_attention_2" @@ -727,7 +727,7 @@ class ModelLoader: "flash_attention_2" ) elif self.cfg.sdp_attention: - if self.cfg.differential_attention: + if self.cfg.diff_attention: self.model_kwargs["attn_implementation"] = "differential_sdpa" self.model_config._attn_implementation = ( # pylint: disable=protected-access "differential_sdpa" @@ -738,7 +738,7 @@ class ModelLoader: "sdpa" ) elif self.cfg.eager_attention: - if self.cfg.differential_attention: + if self.cfg.diff_attention: self.model_kwargs["attn_implementation"] = "differential_eager" self.model_config._attn_implementation = ( # pylint: disable=protected-access "differential_eager" @@ -748,7 +748,7 @@ class ModelLoader: self.model_config._attn_implementation = ( # pylint: disable=protected-access "eager" ) - elif self.cfg.differential_attention: + elif self.cfg.diff_attention: self.model_kwargs["attn_implementation"] = "differential_eager" self.model_config._attn_implementation = ( # pylint: disable=protected-access "differential_eager" diff --git a/src/axolotl/utils/yaml.py b/src/axolotl/utils/yaml.py new file mode 100644 index 000000000..107afafcf --- /dev/null +++ b/src/axolotl/utils/yaml.py @@ -0,0 +1,151 @@ +"""Utilities for YAML files.""" + +from collections import OrderedDict +from typing import Any, Dict, List, Set, Tuple, Union + +import yaml + + +class YAMLOrderTracker: + """Tracks the order of keys and section breaks in YAML files.""" + + def __init__(self, yaml_path: str): + self.yaml_path = yaml_path + self.structure, self.needs_break = self._parse_yaml_structure() + + def _get_indentation_level(self, line: str) -> int: + """Get the indentation level of a line.""" + return len(line) - len(line.lstrip()) + + def _parse_yaml_structure( + self, + ) -> Tuple[Dict[str, Union[List[str], Dict]], Set[str]]: + """Parse the YAML file to extract structure and identify section breaks.""" + with open(self.yaml_path, "r", encoding="utf-8") as file: + contents = file.readlines() + + structure: OrderedDict = OrderedDict() + needs_break = set() # Track which keys should have a break before them + current_path = [] + last_indentation = -1 + had_empty_line = False + + for line in contents: + # Track empty lines and comments + if not line.strip() or line.strip().startswith("#"): + had_empty_line = True + continue + + # Get indentation level and content + indentation =
self._get_indentation_level(line) + content = line.strip() + + # Skip lines that don't define keys + if ":" not in content: + continue + + # Extract key + key = content.split(":")[0].strip() + + # If this is a top-level key and we had an empty line, mark it + if indentation == 0: + if had_empty_line: + needs_break.add(key) + had_empty_line = False + + # Handle indentation changes + if indentation > last_indentation: + current_path.append(key) + elif indentation < last_indentation: + levels_up = (last_indentation - indentation) // 2 + current_path = current_path[:-levels_up] + current_path[-1] = key + else: + if current_path: + current_path[-1] = key + + # Update structure + current_dict = structure + for path_key in current_path[:-1]: + if path_key not in current_dict: + current_dict[path_key] = OrderedDict() + current_dict = current_dict[path_key] + + if current_path: + if current_path[-1] not in current_dict: + current_dict[current_path[-1]] = OrderedDict() + + last_indentation = indentation + + return structure, needs_break + + +class OrderedDumper(yaml.SafeDumper): + """Custom YAML dumper that maintains dictionary order.""" + + +def ordered_dict_representer(dumper: OrderedDumper, data: Dict) -> Any: + """Custom representer for dictionaries that maintains order.""" + return dumper.represent_mapping("tag:yaml.org,2002:map", data.items()) + + +def reorder_dict(data: Dict, reference_structure: Dict) -> OrderedDict: + """Reorder a dictionary based on a reference structure.""" + ordered = OrderedDict() + + # First add keys that are in the reference order + for key in reference_structure: + if key in data: + if isinstance(reference_structure[key], dict) and isinstance( + data[key], dict + ): + ordered[key] = reorder_dict(data[key], reference_structure[key]) + else: + ordered[key] = data[key] + + # Then add any remaining keys that weren't in the reference + for key in data: + if key not in ordered: + ordered[key] = data[key] + + return ordered + + +def dump_yaml_preserved_order( + data: Dict, reference_yaml_path: str, output_path: str +) -> None: + """Dump YAML file while preserving nested order and normalized spacing.""" + # Get reference structure and spacing + tracker = YAMLOrderTracker(reference_yaml_path) + + # Reorder the data + ordered_data = reorder_dict(data, tracker.structure) + + # Register the custom representer + OrderedDumper.add_representer(dict, ordered_dict_representer) + OrderedDumper.add_representer(OrderedDict, ordered_dict_representer) + + # First dump to string + yaml_str = yaml.dump( + ordered_data, Dumper=OrderedDumper, sort_keys=False, default_flow_style=False + ) + + # Add spacing according to reference + lines = yaml_str.split("\n") + result_lines: List[str] = [] + current_line = 0 + + while current_line < len(lines): + line = lines[current_line] + if line.strip() and ":" in line and not line.startswith(" "): # Top-level key + key = line.split(":")[0].strip() + if key in tracker.needs_break: + # Add single empty line before this key + if result_lines and result_lines[-1] != "": + result_lines.append("") + result_lines.append(line) + current_line += 1 + + # Write the final result + with open(output_path, "w", encoding="utf-8") as file: + file.write("\n".join(result_lines))
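A minimal usage sketch for the new `dump_yaml_preserved_order` helper, mirroring how `convert_diff_transformer` calls it above (the file paths and config values here are hypothetical placeholders, not part of the diff): keys in the modified config are re-emitted in the order they appear in the reference YAML, and top-level keys that were preceded by a blank line there keep that blank line.

```python
# Sketch only: assumes the module above is importable as axolotl.utils.yaml;
# "config.yml" and "out/config.yml" are hypothetical paths.
import yaml

from axolotl.utils.yaml import dump_yaml_preserved_order

with open("config.yml", "r", encoding="utf-8") as file:
    modified_cfg = yaml.safe_load(file) or {}

# Edit the config the way the converter does (placeholder values)
modified_cfg["base_model"] = "out"
modified_cfg["diff_attention"] = True

# Re-dump with key order and blank-line grouping taken from config.yml
dump_yaml_preserved_order(
    data=modified_cfg,
    reference_yaml_path="config.yml",
    output_path="out/config.yml",
)
```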
diff --git a/tests/e2e/integrations/convert_differential_transformer/__init__.py b/tests/e2e/integrations/convert_diff_transformer/__init__.py similarity index 100% rename from tests/e2e/integrations/convert_differential_transformer/__init__.py rename to tests/e2e/integrations/convert_diff_transformer/__init__.py diff --git a/tests/e2e/integrations/convert_differential_transformer/conftest.py b/tests/e2e/integrations/convert_diff_transformer/conftest.py similarity index 85% rename from tests/e2e/integrations/convert_differential_transformer/conftest.py rename to tests/e2e/integrations/convert_diff_transformer/conftest.py index ed1eb3f36..d4ffeb759 100644 --- a/tests/e2e/integrations/convert_differential_transformer/conftest.py +++ b/tests/e2e/integrations/convert_diff_transformer/conftest.py @@ -9,9 +9,6 @@ def base_config(): """Basic config for testing.""" return { "base_model": "HuggingFaceTB/SmolLM2-135M", - "plugins": [ - "axolotl.integrations.differential_transformer.DifferentialTransformerPlugin", - ], "datasets": [ { "path": "axolotl-ai-co/alpaca_100_test", diff --git a/tests/e2e/integrations/convert_differential_transformer/test_convert_and_evaluate.py b/tests/e2e/integrations/convert_diff_transformer/test_convert_and_evaluate.py similarity index 89% rename from tests/e2e/integrations/convert_differential_transformer/test_convert_and_evaluate.py rename to tests/e2e/integrations/convert_diff_transformer/test_convert_and_evaluate.py index 1cf569693..d5915f8a5 100644 --- a/tests/e2e/integrations/convert_differential_transformer/test_convert_and_evaluate.py +++ b/tests/e2e/integrations/convert_diff_transformer/test_convert_and_evaluate.py @@ -8,9 +8,7 @@ from pytest import approx from axolotl.cli import load_cfg from axolotl.cli.evaluate import do_evaluate -from axolotl.cli.integrations.convert_differential_transformer import ( - convert_differential_transformer, -) +from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs @@ -26,7 +24,7 @@ def test_conversion_and_eval_cli(tmp_path: Path, base_config): cli_args = ConvertDiffTransformerCliArgs( debug=True, zero_init=True, sublayer_norm=False ) - _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path)) + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) assert debug_info["generations_match"] is True assert (output_dir / "model.safetensors").exists() diff --git a/tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py b/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py similarity index 79% rename from tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py rename to tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py index 42ce3e612..e616a8ef1 100644 --- a/tests/e2e/integrations/convert_differential_transformer/test_convert_differential_transformer.py +++ b/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py @@ -10,23 +10,19 @@ import pytest import yaml from axolotl.cli import load_cfg -from axolotl.cli.integrations.convert_differential_transformer import ( - convert_differential_transformer, -) +from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer from axolotl.cli.main import cli from axolotl.common.cli import ConvertDiffTransformerCliArgs def test_cli_validation(cli_runner): # Test missing config file - result = cli_runner.invoke(cli, ["convert-differential-transformer"]) + result = cli_runner.invoke(cli, ["convert-diff-transformer"]) assert result.exit_code != 0 assert "Error: Missing argument 'CONFIG'."
in result.output # Test non-existent config file - result = cli_runner.invoke( - cli, ["convert-differential-transformer", "nonexistent.yml"] - ) + result = cli_runner.invoke(cli, ["convert-diff-transformer", "nonexistent.yml"]) assert result.exit_code != 0 assert "Error: Invalid value for 'CONFIG'" in result.output @@ -37,11 +33,9 @@ def test_basic_execution(cli_runner, tmp_path: Path, base_config): yaml.dump(base_config, file) with patch( - "axolotl.cli.integrations.convert_differential_transformer.do_cli" + "axolotl.cli.integrations.convert_diff_transformer.do_cli" ) as mock_do_cli: - result = cli_runner.invoke( - cli, ["convert-differential-transformer", str(config_path)] - ) + result = cli_runner.invoke(cli, ["convert-diff-transformer", str(config_path)]) assert result.exit_code == 0 mock_do_cli.assert_called_once() @@ -56,14 +50,9 @@ def test_conversion_cli_basic(tmp_path: Path, base_config): with open(config_path, "w", encoding="utf-8") as file: yaml.dump(base_config, file) - # Load config the same way do_cli does cfg = load_cfg(str(config_path)) - - # Create CLI args cli_args = ConvertDiffTransformerCliArgs() - - # Call convert_differential_transformer directly - _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path)) + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) assert not debug_info assert (output_dir / "model.safetensors").exists() @@ -79,14 +68,9 @@ def test_conversion_cli_debug(tmp_path: Path, base_config): with open(config_path, "w", encoding="utf-8") as file: yaml.dump(base_config, file) - # Load config the same way do_cli does cfg = load_cfg(str(config_path)) - - # Create CLI args cli_args = ConvertDiffTransformerCliArgs(debug=True) - - # Call convert_differential_transformer directly - _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path)) + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) assert not debug_info["generations_match"] assert not debug_info["match_expected"] @@ -107,7 +91,7 @@ def test_conversion_cli_reproduce(tmp_path: Path, base_config): cli_args = ConvertDiffTransformerCliArgs( debug=True, zero_init=True, sublayer_norm=False ) - _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path)) + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) assert debug_info["generations_match"] is True assert (output_dir / "model.safetensors").exists() @@ -133,7 +117,7 @@ def test_conversion_cli_repoduce_attentions( cli_args = ConvertDiffTransformerCliArgs( debug=True, zero_init=True, sublayer_norm=False ) - _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path)) + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) assert debug_info["generations_match"] is True assert (output_dir / "model.safetensors").exists() @@ -155,7 +139,7 @@ def test_conversion_cli_split_heads(tmp_path: Path, base_config, attention: str) cfg = load_cfg(str(config_path)) cli_args = ConvertDiffTransformerCliArgs(debug=True, split_heads=True) - _, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path)) + _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path)) assert debug_info["generations_match"] is False assert (output_dir / "model.safetensors").exists()
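As a reference for the attention math above, here is a minimal sketch (not part of the change set) of the differential attention combination that `diff_attn.py` implements: two softmax attention maps are formed from separate query/key projections, and the second is subtracted from the first using the learned scalar `lambda_full = exp(sum(lambda_q1 * lambda_k1)) - exp(sum(lambda_q2 * lambda_k2)) + lambda_init`, with `lambda_init = 0.8 - 0.6 * exp(-0.3 * depth)`. All shapes and tensor values are toy assumptions; RoPE, masking, GQA, and KV caching are omitted.

```python
"""Toy sketch of the differential attention combination (single head)."""
import math

import torch
import torch.nn.functional as F


def lambda_init_fn(depth: int) -> float:
    # Depth-dependent initialization from the Diff Transformer paper
    return 0.8 - 0.6 * math.exp(-0.3 * depth)


bsz, q_len, head_dim, layer_idx = 2, 4, 8, 0

# Hypothetical activations standing in for the q/k/v projections
q1, q2 = torch.randn(2, bsz, q_len, head_dim)
k1, k2 = torch.randn(2, bsz, q_len, head_dim)
v = torch.randn(bsz, q_len, head_dim)

# Learned reparameterization of the subtraction weight
lambda_q1, lambda_k1, lambda_q2, lambda_k2 = torch.randn(4, head_dim) * 0.1
lambda_init = lambda_init_fn(layer_idx)
lambda_full = (
    torch.exp(torch.sum(lambda_q1 * lambda_k1))
    - torch.exp(torch.sum(lambda_q2 * lambda_k2))
    + lambda_init
)

attn1 = F.softmax(q1 @ k1.transpose(-1, -2) / math.sqrt(head_dim), dim=-1)
attn2 = F.softmax(q2 @ k2.transpose(-1, -2) / math.sqrt(head_dim), dim=-1)

# The second attention map acts as a learned noise-cancellation term
out = attn1 @ v - lambda_full * (attn2 @ v)
out = out * (1 - lambda_init)  # the real code applies sublayer RMSNorm first
```

With the lambda vectors zero-initialized, `lambda_full` reduces to `lambda_init`; the tests above combine `zero_init=True` with `sublayer_norm=False` precisely so that the converted model reproduces the original model's generations.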