Add guard statements

Dan Saunders
2025-01-10 16:39:21 +00:00
parent 4f804f6d88
commit 7aca08ff60
3 changed files with 17 additions and 10 deletions


@@ -50,6 +50,13 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"):
 def convert_diff_transformer(cfg, cli_args, config_path):
+    assert not (
+        cli_args.split_heads and cli_args.zero_init
+    ), "Both `split_heads` and `zero_init` cannot be `True`"
+    assert not (
+        cli_args.zero_init and cli_args.mirror_weights
+    ), "Both `zero_init` and `mirror_weights` cannot be `True`"
+
     debug_info = {}

     # Load model and tokenizer
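In effect, flag combinations that previously only triggered a warning now fail fast. A minimal sketch of the new behavior, using a hypothetical stand-in for the CLI args object (not the real class):

    from dataclasses import dataclass

    # Hypothetical stand-in for ConvertDiffTransformerCliArgs, for illustration only.
    @dataclass
    class FakeArgs:
        split_heads: bool = False
        zero_init: bool = False
        mirror_weights: bool = False

    def check_guards(cli_args):
        # Same mutual-exclusion guards as added above.
        assert not (
            cli_args.split_heads and cli_args.zero_init
        ), "Both `split_heads` and `zero_init` cannot be `True`"
        assert not (
            cli_args.zero_init and cli_args.mirror_weights
        ), "Both `zero_init` and `mirror_weights` cannot be `True`"

    check_guards(FakeArgs(mirror_weights=True))                  # passes
    check_guards(FakeArgs(zero_init=True, mirror_weights=True))  # raises AssertionError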
@@ -72,21 +79,16 @@ def convert_diff_transformer(cfg, cli_args, config_path):
         model, tokenizer
     )

-    # Convert attention
-    LOG.info("Converting to differential attention...")
-    if cli_args.split_heads and cli_args.zero_init:
-        LOG.warning(
-            Fore.YELLOW
-            + "Warning: Using split_heads with zero_init is not recommended; "
-            + "split_heads will preclude the effects of zero_init"
-            + Fore.RESET
-        )

     try:
+        # Convert attention
+        LOG.info("Converting to differential attention...")
         config = LlamaDifferentialConfig(
             **model.config.__dict__,
             zero_init=cli_args.zero_init,
             sublayer_norm=cli_args.sublayer_norm,
+            split_heads=cli_args.split_heads,
+            mirror_weights=cli_args.mirror_weights,
         )
         model = LlamaDifferentialForCausalLM.from_llama(model, config)
         model.to(cfg.device, dtype=cfg.torch_dtype)


@@ -56,6 +56,7 @@ class ConvertDiffTransformerCliArgs:
     zero_init: bool = field(default=False)
     sublayer_norm: bool = field(default=True)
     split_heads: bool = field(default=False)
+    mirror_weights: bool = field(default=False)


 def load_model_and_tokenizer(
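The new flag defaults to False, so existing invocations keep their behavior. A self-contained sketch of just the fields shown here (the real class may define more):

    from dataclasses import dataclass, field

    @dataclass
    class ConvertDiffTransformerCliArgs:
        # Only the flags visible in this diff; illustrative, not the full class.
        zero_init: bool = field(default=False)
        sublayer_norm: bool = field(default=True)
        split_heads: bool = field(default=False)
        mirror_weights: bool = field(default=False)

    args = ConvertDiffTransformerCliArgs(mirror_weights=True)
    assert not args.zero_init  # default False, so the new guard above is satisfied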


@@ -36,6 +36,7 @@ class LlamaDifferentialConfig(LlamaConfig):
         split_heads: bool = False,
         sublayer_norm: bool = True,
         zero_init: bool = False,
+        mirror_weights: bool = False,
         **kwargs,
     ):
         """
@@ -45,12 +46,15 @@ class LlamaDifferentialConfig(LlamaConfig):
             split_heads: Whether to use split heads mode for attention computation.
             sublayer_norm: Whether to apply normalization to sublayers.
             zero_init: Whether to initialize new weights to zero.
+            mirror_weights: Whether to copy the positive attention component weights to
+                the negative attention component.
             **kwargs: Additional arguments passed to LlamaConfig.
         """
         super().__init__(**kwargs)
         self.split_heads = split_heads
         self.sublayer_norm = sublayer_norm
         self.zero_init = zero_init
+        self.mirror_weights = mirror_weights
         self.architectures = ["LlamaDifferentialModel"]
         self._attn_implementations = {
             "eager": "differential_eager",
@@ -250,7 +254,7 @@ class LlamaDifferentialModel(LlamaModel):
                 new_layer.self_attn.lambda_q2.zero_()
                 new_layer.self_attn.lambda_k2.zero_()
                 new_layer.self_attn.lambda_init.zero_()
-            else:
+            elif config.mirror_weights:
                 # Mirror weights for second component
                 new_layer.self_attn.q_proj.weight.data[old_q_size:].copy_(
                     old_layer.self_attn.q_proj.weight.data
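The elif copies the original projection weights into the second (negative) attention component instead of leaving that half at its default initialization. A toy illustration of the mirroring copy (shapes invented for the example; the real code does this per decoder layer for each projection):

    import torch
    from torch import nn

    old_q_size = 64
    old_q = nn.Linear(32, old_q_size, bias=False)      # stands in for the original q_proj
    new_q = nn.Linear(32, 2 * old_q_size, bias=False)  # doubled rows: positive + negative components

    with torch.no_grad():
        # The positive component keeps the original weights...
        new_q.weight[:old_q_size].copy_(old_q.weight)
        # ...and mirror_weights copies those same weights into the negative component.
        new_q.weight[old_q_size:].copy_(old_q.weight)

    assert torch.equal(new_q.weight[:old_q_size], new_q.weight[old_q_size:])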