updates and cleanup

This commit is contained in:
Dan Saunders
2025-01-06 17:04:05 +00:00
parent 2a7f139ad2
commit 70c4e6fbe6
8 changed files with 28 additions and 33 deletions

View File

@@ -1,6 +0,0 @@
metric,training,validation
loss,15.633337020874023,15.604033470153809
model_preparation_time,0.0058,0.0058
runtime,77.8124,8.4643
samples_per_second,23.133,23.629
steps_per_second,23.133,23.629
1 metric training validation
2 loss 15.633337020874023 15.604033470153809
3 model_preparation_time 0.0058 0.0058
4 runtime 77.8124 8.4643
5 samples_per_second 23.133 23.629
6 steps_per_second 23.133 23.629

View File

@@ -87,6 +87,8 @@ def convert_diff_transformer(cfg, cli_args, config_path):
zero_init=cli_args.zero_init,
sublayer_norm=cli_args.sublayer_norm,
split_heads=cli_args.split_heads,
init_scale=cli_args.init_scale,
reinit_lambda_init=cli_args.reinit_lambda_init,
)
model = LlamaDifferentialForCausalLM.from_llama(model, config)
model.to(cfg.device, dtype=cfg.torch_dtype)

View File

@@ -54,8 +54,10 @@ class ConvertDiffTransformerCliArgs:
debug: bool = field(default=False)
zero_init: bool = field(default=False)
sublayer_norm: bool = field(default=True)
sublayer_norm: bool = field(default=False)
split_heads: bool = field(default=False)
init_scale: float = field(default=1e-6)
reinit_lambda_init: bool = field(default=True)
def load_model_and_tokenizer(

View File

@@ -2,12 +2,11 @@
### Usage
**Note:** The following with be set in the model config output by the `axolotl convert-diff-transformer` command.
```yaml
plugins:
- axolotl.integrations.diff_transformer.DifferentialTransformerPlugin
diff_attention: true
diff_attn_zero_init: false
diff_attn_sublayer_norm: true
diff_attn_split_heads: false
```

View File

@@ -12,6 +12,3 @@ class DifferentialTransformerArgs(BaseModel):
"""Input args for differential transformer."""
diff_attention: Optional[bool] = None
diff_attn_zero_init: Optional[bool] = None
diff_attn_sublayer_norm: Optional[bool] = None
diff_attn_split_heads: Optional[bool] = None

View File

@@ -191,6 +191,26 @@ class LlamaDifferentialModel(LlamaModel):
new_layer.self_attn.lambda_q2.zero_()
new_layer.self_attn.lambda_k2.zero_()
new_layer.self_attn.lambda_init.zero_()
else:
logger.debug(
f"Layer {layer_idx}: Initializing with scale {config.init_scale}"
)
# Initialize with small random values
with torch.no_grad():
new_layer.self_attn.q_proj.weight.data[old_q_size:].normal_(
0, config.init_scale
)
new_layer.self_attn.k_proj.weight.data[old_k_size:].normal_(
0, config.init_scale
)
new_layer.self_attn.lambda_q1.normal_(0, config.init_scale)
new_layer.self_attn.lambda_k1.normal_(0, config.init_scale)
new_layer.self_attn.lambda_q2.normal_(0, config.init_scale)
new_layer.self_attn.lambda_k2.normal_(0, config.init_scale)
if config.reinit_lambda_init:
new_layer.self_attn.lambda_init.normal_(
0, config.init_scale
).abs_()
logger.info("Conversion complete")
return new_model

View File

@@ -710,25 +710,6 @@ class ModelLoader:
"""
sample packing uses custom FA2 patch
"""
# if self.cfg.flash_attention:
# if not self.cfg.sample_packing and self.cfg.s2_attention:
# pass
# self.model_kwargs["attn_implementation"] = "flash_attention_2"
# self.model_config._attn_implementation = ( # pylint: disable=protected-access
# "flash_attention_2"
# )
# elif self.cfg.sdp_attention:
# self.model_kwargs["attn_implementation"] = "sdpa"
# self.model_config._attn_implementation = ( # pylint: disable=protected-access
# "sdpa"
# )
# elif self.cfg.eager_attention:
# self.model_kwargs["attn_implementation"] = "eager"
# self.model_config._attn_implementation = ( # pylint: disable=protected-access
# "eager"
# )
if self.cfg.flash_attention:
if not self.cfg.sample_packing and self.cfg.s2_attention:
pass

View File

@@ -69,7 +69,7 @@ def test_conversion_cli_debug(tmp_path: Path, base_config):
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(debug=True)
cli_args = ConvertDiffTransformerCliArgs(debug=True, init_scale=0.1)
_, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
assert not debug_info["generations_match"]