From 7aca08ff60443f0ae68eb45202ba694bcada6d5e Mon Sep 17 00:00:00 2001
From: Dan Saunders
Date: Fri, 10 Jan 2025 16:39:21 +0000
Subject: [PATCH] adding guard statements

---
 .../integrations/convert_diff_transformer.py | 20 ++++++++++---------
 src/axolotl/common/cli.py                    |  1 +
 .../diff_transformer/modeling_diff_attn.py   |  6 +++++-
 3 files changed, 17 insertions(+), 10 deletions(-)

diff --git a/src/axolotl/cli/integrations/convert_diff_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py
index ecde82251..3b0f16ca9 100644
--- a/src/axolotl/cli/integrations/convert_diff_transformer.py
+++ b/src/axolotl/cli/integrations/convert_diff_transformer.py
@@ -50,6 +50,13 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"):
 
 
 def convert_diff_transformer(cfg, cli_args, config_path):
+    assert not (
+        cli_args.split_heads and cli_args.zero_init
+    ), "Both `split_heads` and `zero_init` cannot be `True`"
+    assert not (
+        cli_args.zero_init and cli_args.mirror_weights
+    ), "Both `zero_init` and `mirror_weights` cannot be `True`"
+
     debug_info = {}
 
     # Load model and tokenizer
@@ -72,21 +79,16 @@ def convert_diff_transformer(cfg, cli_args, config_path):
             model, tokenizer
         )
 
-    # Convert attention
-    LOG.info("Converting to differential attention...")
-    if cli_args.split_heads and cli_args.zero_init:
-        LOG.warning(
-            Fore.YELLOW
-            + "Warning: Using split_heads with zero_init is not recommended; "
-            + "split_heads will preclude the effects of zero_init"
-            + Fore.RESET
-        )
     try:
+        # Convert attention
+        LOG.info("Converting to differential attention...")
+
         config = LlamaDifferentialConfig(
             **model.config.__dict__,
             zero_init=cli_args.zero_init,
             sublayer_norm=cli_args.sublayer_norm,
             split_heads=cli_args.split_heads,
+            mirror_weights=cli_args.mirror_weights,
         )
         model = LlamaDifferentialForCausalLM.from_llama(model, config)
         model.to(cfg.device, dtype=cfg.torch_dtype)
diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py
index ebe098ca6..8b31b52b5 100644
--- a/src/axolotl/common/cli.py
+++ b/src/axolotl/common/cli.py
@@ -56,6 +56,7 @@ class ConvertDiffTransformerCliArgs:
     zero_init: bool = field(default=False)
     sublayer_norm: bool = field(default=True)
     split_heads: bool = field(default=False)
+    mirror_weights: bool = field(default=False)
 
 
 def load_model_and_tokenizer(
diff --git a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
index c8a663cb3..90e5d838b 100644
--- a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
+++ b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
@@ -36,6 +36,7 @@ class LlamaDifferentialConfig(LlamaConfig):
         split_heads: bool = False,
         sublayer_norm: bool = True,
         zero_init: bool = False,
+        mirror_weights: bool = False,
         **kwargs,
     ):
         """
@@ -45,12 +46,15 @@ class LlamaDifferentialConfig(LlamaConfig):
             split_heads: Whether to use split heads mode for attention computation.
             sublayer_norm: Whether to apply normalization to sublayers.
             zero_init: Whether to initialize new weights to zero.
+            mirror_weights: Whether to copy the positive attention component weights to
+                the negative attention component.
             **kwargs: Additional arguments passed to LlamaConfig.
         """
         super().__init__(**kwargs)
         self.split_heads = split_heads
         self.sublayer_norm = sublayer_norm
         self.zero_init = zero_init
+        self.mirror_weights = mirror_weights
         self.architectures = ["LlamaDifferentialModel"]
         self._attn_implementations = {
             "eager": "differential_eager",
@@ -250,7 +254,7 @@ class LlamaDifferentialModel(LlamaModel):
                 new_layer.self_attn.lambda_q2.zero_()
                 new_layer.self_attn.lambda_k2.zero_()
                 new_layer.self_attn.lambda_init.zero_()
-            else:
+            elif config.mirror_weights:
                 # Mirror weights for second component
                 new_layer.self_attn.q_proj.weight.data[old_q_size:].copy_(
                     old_layer.self_attn.q_proj.weight.data