updates and cleanup

Dan Saunders
2025-01-06 17:04:05 +00:00
parent 2a7f139ad2
commit 70c4e6fbe6
8 changed files with 28 additions and 33 deletions

View File

@@ -87,6 +87,8 @@ def convert_diff_transformer(cfg, cli_args, config_path):
         zero_init=cli_args.zero_init,
         sublayer_norm=cli_args.sublayer_norm,
         split_heads=cli_args.split_heads,
+        init_scale=cli_args.init_scale,
+        reinit_lambda_init=cli_args.reinit_lambda_init,
     )
     model = LlamaDifferentialForCausalLM.from_llama(model, config)
     model.to(cfg.device, dtype=cfg.torch_dtype)

View File

@@ -54,8 +54,10 @@ class ConvertDiffTransformerCliArgs:
     debug: bool = field(default=False)
     zero_init: bool = field(default=False)
-    sublayer_norm: bool = field(default=True)
+    sublayer_norm: bool = field(default=False)
     split_heads: bool = field(default=False)
+    init_scale: float = field(default=1e-6)
+    reinit_lambda_init: bool = field(default=True)


 def load_model_and_tokenizer(

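For orientation, here is a minimal sketch of how the two new fields travel from the CLI dataclass above into the conversion call in the first hunk. The import paths and the `LlamaDifferentialConfig` class name are assumptions for illustration; only `ConvertDiffTransformerCliArgs`, the config keyword arguments, and `LlamaDifferentialForCausalLM.from_llama` appear in this diff.

```python
# Sketch only: ConvertDiffTransformerCliArgs, LlamaDifferentialConfig, and
# LlamaDifferentialForCausalLM come from the files changed in this commit;
# their import paths (and the config class name) are assumed here.
from transformers import AutoModelForCausalLM

cli_args = ConvertDiffTransformerCliArgs(
    zero_init=False,          # keep the random-init path (see the modeling hunk below)
    sublayer_norm=False,      # new default in this commit
    split_heads=False,
    init_scale=1e-6,          # std-dev for the new projection rows and lambda parameters
    reinit_lambda_init=True,  # re-draw lambda_init as |N(0, init_scale)|
)

config = LlamaDifferentialConfig(  # assumed class name
    zero_init=cli_args.zero_init,
    sublayer_norm=cli_args.sublayer_norm,
    split_heads=cli_args.split_heads,
    init_scale=cli_args.init_scale,
    reinit_lambda_init=cli_args.reinit_lambda_init,
)

model = AutoModelForCausalLM.from_pretrained("path/to/llama-checkpoint")
model = LlamaDifferentialForCausalLM.from_llama(model, config)
```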
View File

@@ -2,12 +2,11 @@
### Usage
**Note:** The following will be set in the model config output by the `axolotl convert-diff-transformer` command.
```yaml
plugins:
- axolotl.integrations.diff_transformer.DifferentialTransformerPlugin
diff_attention: true
diff_attn_zero_init: false
diff_attn_sublayer_norm: true
diff_attn_split_heads: false
```

View File

@@ -12,6 +12,3 @@ class DifferentialTransformerArgs(BaseModel):
     """Input args for differential transformer."""

     diff_attention: Optional[bool] = None
-    diff_attn_zero_init: Optional[bool] = None
-    diff_attn_sublayer_norm: Optional[bool] = None
-    diff_attn_split_heads: Optional[bool] = None

View File

@@ -191,6 +191,26 @@ class LlamaDifferentialModel(LlamaModel):
                     new_layer.self_attn.lambda_q2.zero_()
                     new_layer.self_attn.lambda_k2.zero_()
                     new_layer.self_attn.lambda_init.zero_()
+            else:
+                logger.debug(
+                    f"Layer {layer_idx}: Initializing with scale {config.init_scale}"
+                )
+                # Initialize with small random values
+                with torch.no_grad():
+                    new_layer.self_attn.q_proj.weight.data[old_q_size:].normal_(
+                        0, config.init_scale
+                    )
+                    new_layer.self_attn.k_proj.weight.data[old_k_size:].normal_(
+                        0, config.init_scale
+                    )
+                    new_layer.self_attn.lambda_q1.normal_(0, config.init_scale)
+                    new_layer.self_attn.lambda_k1.normal_(0, config.init_scale)
+                    new_layer.self_attn.lambda_q2.normal_(0, config.init_scale)
+                    new_layer.self_attn.lambda_k2.normal_(0, config.init_scale)
+                    if config.reinit_lambda_init:
+                        new_layer.self_attn.lambda_init.normal_(
+                            0, config.init_scale
+                        ).abs_()

         logger.info("Conversion complete")
         return new_model

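As a standalone illustration (not the project's code) of what the added branch does: every new parameter is drawn from N(0, init_scale), and lambda_init is optionally re-drawn and forced non-negative with `.abs_()`. Tensor shapes below are placeholders, not the real attention dimensions.

```python
import torch

init_scale = 1e-6  # matches the new ConvertDiffTransformerCliArgs default

# Placeholder tensors standing in for the parameters touched above.
new_q_rows = torch.empty(128, 4096)  # rows appended to q_proj.weight
lambda_q2 = torch.empty(64)
lambda_init = torch.empty(())

with torch.no_grad():
    new_q_rows.normal_(0, init_scale)          # tiny values: near-zero contribution at start
    lambda_q2.normal_(0, init_scale)
    lambda_init.normal_(0, init_scale).abs_()  # re-drawn, then made non-negative in place

print(float(new_q_rows.abs().max()))  # on the order of a few multiples of init_scale
```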
View File

@@ -710,25 +710,6 @@ class ModelLoader:
         """
         sample packing uses custom FA2 patch
         """
-        # if self.cfg.flash_attention:
-        #     if not self.cfg.sample_packing and self.cfg.s2_attention:
-        #         pass
-        #     self.model_kwargs["attn_implementation"] = "flash_attention_2"
-        #     self.model_config._attn_implementation = (  # pylint: disable=protected-access
-        #         "flash_attention_2"
-        #     )
-        # elif self.cfg.sdp_attention:
-        #     self.model_kwargs["attn_implementation"] = "sdpa"
-        #     self.model_config._attn_implementation = (  # pylint: disable=protected-access
-        #         "sdpa"
-        #     )
-        # elif self.cfg.eager_attention:
-        #     self.model_kwargs["attn_implementation"] = "eager"
-        #     self.model_config._attn_implementation = (  # pylint: disable=protected-access
-        #         "eager"
-        #     )
         if self.cfg.flash_attention:
             if not self.cfg.sample_packing and self.cfg.s2_attention:
                 pass
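
For reference, the `attn_implementation` value that the deleted commented-out block used to set is the standard Hugging Face Transformers switch for the attention backend; a minimal sketch (the model path is illustrative, and "flash_attention_2" additionally requires the flash-attn package):

```python
from transformers import AutoModelForCausalLM

# "sdpa" and "eager" are built in; "flash_attention_2" needs the flash-attn package.
model = AutoModelForCausalLM.from_pretrained(
    "path/to/llama-checkpoint",
    attn_implementation="sdpa",
)
print(model.config._attn_implementation)  # "sdpa"
```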