adding guard statements
@@ -50,6 +50,13 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"):
 def convert_diff_transformer(cfg, cli_args, config_path):
+    assert not (
+        cli_args.split_heads and cli_args.zero_init
+    ), "Both `split_heads` and `zero_init` cannot be `True`"
+    assert not (
+        cli_args.zero_init and cli_args.mirror_weights
+    ), "Both `zero_init` and `mirror_weights` cannot be `True`"
+
     debug_info = {}

     # Load model and tokenizer
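A quick sketch of how the new guards are expected to behave (illustrative only, not part of the commit; `Args` is a stand-in for the real `ConvertDiffTransformerCliArgs` dataclass):

from dataclasses import dataclass

@dataclass
class Args:  # stand-in for ConvertDiffTransformerCliArgs
    zero_init: bool = False
    sublayer_norm: bool = True
    split_heads: bool = False
    mirror_weights: bool = False

def check_flags(cli_args):
    # Same guard statements as the ones added to convert_diff_transformer
    assert not (
        cli_args.split_heads and cli_args.zero_init
    ), "Both `split_heads` and `zero_init` cannot be `True`"
    assert not (
        cli_args.zero_init and cli_args.mirror_weights
    ), "Both `zero_init` and `mirror_weights` cannot be `True`"

check_flags(Args(split_heads=True))  # allowed
try:
    check_flags(Args(zero_init=True, mirror_weights=True))
except AssertionError as err:
    print(err)  # Both `zero_init` and `mirror_weights` cannot be `True`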
@@ -72,21 +79,16 @@ def convert_diff_transformer(cfg, cli_args, config_path):
         model, tokenizer
     )

-    # Convert attention
-    LOG.info("Converting to differential attention...")
-    if cli_args.split_heads and cli_args.zero_init:
-        LOG.warning(
-            Fore.YELLOW
-            + "Warning: Using split_heads with zero_init is not recommended; "
-            + "split_heads will preclude the effects of zero_init"
-            + Fore.RESET
-        )
     try:
         # Convert attention
         LOG.info("Converting to differential attention...")

         config = LlamaDifferentialConfig(
             **model.config.__dict__,
             zero_init=cli_args.zero_init,
             sublayer_norm=cli_args.sublayer_norm,
             split_heads=cli_args.split_heads,
+            mirror_weights=cli_args.mirror_weights,
         )
         model = LlamaDifferentialForCausalLM.from_llama(model, config)
         model.to(cfg.device, dtype=cfg.torch_dtype)
@@ -56,6 +56,7 @@ class ConvertDiffTransformerCliArgs:
     zero_init: bool = field(default=False)
     sublayer_norm: bool = field(default=True)
     split_heads: bool = field(default=False)
+    mirror_weights: bool = field(default=False)


 def load_model_and_tokenizer(
@@ -36,6 +36,7 @@ class LlamaDifferentialConfig(LlamaConfig):
         split_heads: bool = False,
         sublayer_norm: bool = True,
         zero_init: bool = False,
+        mirror_weights: bool = False,
         **kwargs,
     ):
         """
@@ -45,12 +46,15 @@ class LlamaDifferentialConfig(LlamaConfig):
             split_heads: Whether to use split heads mode for attention computation.
             sublayer_norm: Whether to apply normalization to sublayers.
             zero_init: Whether to initialize new weights to zero.
+            mirror_weights: Whether to copy the positive attention component weights to
+                the negative attention component.
             **kwargs: Additional arguments passed to LlamaConfig.
         """
         super().__init__(**kwargs)
         self.split_heads = split_heads
         self.sublayer_norm = sublayer_norm
         self.zero_init = zero_init
+        self.mirror_weights = mirror_weights
         self.architectures = ["LlamaDifferentialModel"]
         self._attn_implementations = {
             "eager": "differential_eager",
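For context, a hedged usage sketch of the new config flag. The class body below is a minimal re-statement of the constructor shown in this hunk (docstring and attention-implementation mapping omitted); the real import path is not shown in the excerpt, and the base config values are made up:

from transformers import LlamaConfig

class LlamaDifferentialConfig(LlamaConfig):
    # Minimal re-statement of the __init__ added/extended in this commit.
    def __init__(self, split_heads=False, sublayer_norm=True, zero_init=False,
                 mirror_weights=False, **kwargs):
        super().__init__(**kwargs)
        self.split_heads = split_heads
        self.sublayer_norm = sublayer_norm
        self.zero_init = zero_init
        self.mirror_weights = mirror_weights
        self.architectures = ["LlamaDifferentialModel"]

base = LlamaConfig(num_attention_heads=8, hidden_size=256)  # stand-in for model.config
diff_config = LlamaDifferentialConfig(
    **base.__dict__,          # same pattern convert_diff_transformer uses above
    mirror_weights=True,      # new flag: copy positive-component weights to the negative component
)
assert diff_config.architectures == ["LlamaDifferentialModel"]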
@@ -250,7 +254,7 @@ class LlamaDifferentialModel(LlamaModel):
                 new_layer.self_attn.lambda_q2.zero_()
                 new_layer.self_attn.lambda_k2.zero_()
                 new_layer.self_attn.lambda_init.zero_()
-            else:
+            elif config.mirror_weights:
                 # Mirror weights for second component
                 new_layer.self_attn.q_proj.weight.data[old_q_size:].copy_(
                     old_layer.self_attn.q_proj.weight.data
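To illustrate what the `elif config.mirror_weights:` branch does, here is a small standalone sketch. It assumes the differential layer's new projection has twice the old output size and that `old_q_size` is the original projection's output dimension; the shapes are invented for the example:

import torch
import torch.nn as nn

hidden_size, old_q_size = 4, 8
old_q_proj = nn.Linear(hidden_size, old_q_size, bias=False)      # original (positive) component
new_q_proj = nn.Linear(hidden_size, 2 * old_q_size, bias=False)  # doubled projection in the differential layer

with torch.no_grad():
    # First half keeps the original positive-component weights.
    new_q_proj.weight.data[:old_q_size].copy_(old_q_proj.weight.data)
    # mirror_weights: the second (negative-component) half starts as an exact copy,
    # mirroring new_layer.self_attn.q_proj.weight.data[old_q_size:].copy_(...) above.
    new_q_proj.weight.data[old_q_size:].copy_(old_q_proj.weight.data)

assert torch.equal(new_q_proj.weight[:old_q_size], new_q_proj.weight[old_q_size:])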