adding guard statements
This commit is contained in:
@@ -50,6 +50,13 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"):
|
|||||||
|
|
||||||
|
|
||||||
def convert_diff_transformer(cfg, cli_args, config_path):
|
def convert_diff_transformer(cfg, cli_args, config_path):
|
||||||
|
assert not (
|
||||||
|
cli_args.split_heads and cli_args.zero_init
|
||||||
|
), "Both `split_heads` and `zero_init` cannot be `True`"
|
||||||
|
assert not (
|
||||||
|
cli_args.zero_init and cli_args.mirror_weights
|
||||||
|
), "Both `zero_init` and `mirror_weights` cannot be `True`"
|
||||||
|
|
||||||
debug_info = {}
|
debug_info = {}
|
||||||
|
|
||||||
# Load model and tokenizer
|
# Load model and tokenizer
|
||||||
@@ -72,21 +79,16 @@ def convert_diff_transformer(cfg, cli_args, config_path):
|
|||||||
model, tokenizer
|
model, tokenizer
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert attention
|
|
||||||
LOG.info("Converting to differential attention...")
|
|
||||||
if cli_args.split_heads and cli_args.zero_init:
|
|
||||||
LOG.warning(
|
|
||||||
Fore.YELLOW
|
|
||||||
+ "Warning: Using split_heads with zero_init is not recommended; "
|
|
||||||
+ "split_heads will preclude the effects of zero_init"
|
|
||||||
+ Fore.RESET
|
|
||||||
)
|
|
||||||
try:
|
try:
|
||||||
|
# Convert attention
|
||||||
|
LOG.info("Converting to differential attention...")
|
||||||
|
|
||||||
config = LlamaDifferentialConfig(
|
config = LlamaDifferentialConfig(
|
||||||
**model.config.__dict__,
|
**model.config.__dict__,
|
||||||
zero_init=cli_args.zero_init,
|
zero_init=cli_args.zero_init,
|
||||||
sublayer_norm=cli_args.sublayer_norm,
|
sublayer_norm=cli_args.sublayer_norm,
|
||||||
split_heads=cli_args.split_heads,
|
split_heads=cli_args.split_heads,
|
||||||
|
mirror_weights=cli_args.mirror_weights,
|
||||||
)
|
)
|
||||||
model = LlamaDifferentialForCausalLM.from_llama(model, config)
|
model = LlamaDifferentialForCausalLM.from_llama(model, config)
|
||||||
model.to(cfg.device, dtype=cfg.torch_dtype)
|
model.to(cfg.device, dtype=cfg.torch_dtype)
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ class ConvertDiffTransformerCliArgs:
|
|||||||
zero_init: bool = field(default=False)
|
zero_init: bool = field(default=False)
|
||||||
sublayer_norm: bool = field(default=True)
|
sublayer_norm: bool = field(default=True)
|
||||||
split_heads: bool = field(default=False)
|
split_heads: bool = field(default=False)
|
||||||
|
mirror_weights: bool = field(default=False)
|
||||||
|
|
||||||
|
|
||||||
def load_model_and_tokenizer(
|
def load_model_and_tokenizer(
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ class LlamaDifferentialConfig(LlamaConfig):
|
|||||||
split_heads: bool = False,
|
split_heads: bool = False,
|
||||||
sublayer_norm: bool = True,
|
sublayer_norm: bool = True,
|
||||||
zero_init: bool = False,
|
zero_init: bool = False,
|
||||||
|
mirror_weights: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -45,12 +46,15 @@ class LlamaDifferentialConfig(LlamaConfig):
|
|||||||
split_heads: Whether to use split heads mode for attention computation.
|
split_heads: Whether to use split heads mode for attention computation.
|
||||||
sublayer_norm: Whether to apply normalization to sublayers.
|
sublayer_norm: Whether to apply normalization to sublayers.
|
||||||
zero_init: Whether to initialize new weights to zero.
|
zero_init: Whether to initialize new weights to zero.
|
||||||
|
mirror_weights: Whether to copy the positive attention component weights to
|
||||||
|
the negative attention component.
|
||||||
**kwargs: Additional arguments passed to LlamaConfig.
|
**kwargs: Additional arguments passed to LlamaConfig.
|
||||||
"""
|
"""
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.split_heads = split_heads
|
self.split_heads = split_heads
|
||||||
self.sublayer_norm = sublayer_norm
|
self.sublayer_norm = sublayer_norm
|
||||||
self.zero_init = zero_init
|
self.zero_init = zero_init
|
||||||
|
self.mirror_weights = mirror_weights
|
||||||
self.architectures = ["LlamaDifferentialModel"]
|
self.architectures = ["LlamaDifferentialModel"]
|
||||||
self._attn_implementations = {
|
self._attn_implementations = {
|
||||||
"eager": "differential_eager",
|
"eager": "differential_eager",
|
||||||
@@ -250,7 +254,7 @@ class LlamaDifferentialModel(LlamaModel):
|
|||||||
new_layer.self_attn.lambda_q2.zero_()
|
new_layer.self_attn.lambda_q2.zero_()
|
||||||
new_layer.self_attn.lambda_k2.zero_()
|
new_layer.self_attn.lambda_k2.zero_()
|
||||||
new_layer.self_attn.lambda_init.zero_()
|
new_layer.self_attn.lambda_init.zero_()
|
||||||
else:
|
elif config.mirror_weights:
|
||||||
# Mirror weights for second component
|
# Mirror weights for second component
|
||||||
new_layer.self_attn.q_proj.weight.data[old_q_size:].copy_(
|
new_layer.self_attn.q_proj.weight.data[old_q_size:].copy_(
|
||||||
old_layer.self_attn.q_proj.weight.data
|
old_layer.self_attn.q_proj.weight.data
|
||||||
|
|||||||
Reference in New Issue
Block a user