updates and cleanup
@@ -1,6 +0,0 @@
-metric,training,validation
-loss,15.633337020874023,15.604033470153809
-model_preparation_time,0.0058,0.0058
-runtime,77.8124,8.4643
-samples_per_second,23.133,23.629
-steps_per_second,23.133,23.629

@@ -87,6 +87,8 @@ def convert_diff_transformer(cfg, cli_args, config_path):
         zero_init=cli_args.zero_init,
         sublayer_norm=cli_args.sublayer_norm,
         split_heads=cli_args.split_heads,
+        init_scale=cli_args.init_scale,
+        reinit_lambda_init=cli_args.reinit_lambda_init,
     )
     model = LlamaDifferentialForCausalLM.from_llama(model, config)
     model.to(cfg.device, dtype=cfg.torch_dtype)

@@ -54,8 +54,10 @@ class ConvertDiffTransformerCliArgs:

     debug: bool = field(default=False)
     zero_init: bool = field(default=False)
-    sublayer_norm: bool = field(default=True)
+    sublayer_norm: bool = field(default=False)
     split_heads: bool = field(default=False)
+    init_scale: float = field(default=1e-6)
+    reinit_lambda_init: bool = field(default=True)


 def load_model_and_tokenizer(
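
For reference, a minimal sketch of driving the conversion programmatically with the new options. Only the names `load_cfg`, `ConvertDiffTransformerCliArgs`, and `convert_diff_transformer` appear in this diff; the import paths and the config path below are assumptions.

```python
# Import locations are hypothetical; only the names are confirmed by this diff.
from axolotl.cli import load_cfg
from axolotl.cli.integrations.convert_diff_transformer import (
    ConvertDiffTransformerCliArgs,
    convert_diff_transformer,
)

cfg = load_cfg("config.yaml")  # placeholder config path
cli_args = ConvertDiffTransformerCliArgs(
    zero_init=False,           # keep the random-init path
    sublayer_norm=False,       # the new default after this change
    split_heads=False,
    init_scale=1e-6,           # std-dev of the normal init for new projections
    reinit_lambda_init=True,   # re-draw lambda_init as |N(0, init_scale)|
)
# The test further down unpacks two return values, so this presumably
# returns (model, debug_info).
model, debug_info = convert_diff_transformer(cfg, cli_args, "config.yaml")
```
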
@@ -2,12 +2,11 @@

 ### Usage

+**Note:** The following will be set in the model config output by the `axolotl convert-diff-transformer` command.
+
 ```yaml
 plugins:
   - axolotl.integrations.diff_transformer.DifferentialTransformerPlugin

 diff_attention: true
-diff_attn_zero_init: false
-diff_attn_sublayer_norm: true
-diff_attn_split_heads: false
 ```
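
Judging by the function signature `convert_diff_transformer(cfg, cli_args, config_path)` shown above, the `axolotl convert-diff-transformer` command presumably takes a config path argument; the CLI flag spellings for the dataclass fields above are not shown in this diff.
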
@@ -12,6 +12,3 @@ class DifferentialTransformerArgs(BaseModel):
     """Input args for differential transformer."""

     diff_attention: Optional[bool] = None
-    diff_attn_zero_init: Optional[bool] = None
-    diff_attn_sublayer_norm: Optional[bool] = None
-    diff_attn_split_heads: Optional[bool] = None
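
After this change, `diff_attention` is the only field left on the plugin's input args, matching the trimmed YAML example above; the other three knobs live on the convert-time CLI instead. A self-contained sketch of validating that YAML (the pydantic round-trip here is illustrative, not the repo's config loader):

```python
from typing import Optional

import yaml
from pydantic import BaseModel


class DifferentialTransformerArgs(BaseModel):
    """Input args for differential transformer."""

    diff_attention: Optional[bool] = None


raw = yaml.safe_load("diff_attention: true")
print(DifferentialTransformerArgs(**raw))  # diff_attention=True
```
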
@@ -191,6 +191,26 @@ class LlamaDifferentialModel(LlamaModel):
                     new_layer.self_attn.lambda_q2.zero_()
                     new_layer.self_attn.lambda_k2.zero_()
                     new_layer.self_attn.lambda_init.zero_()
+            else:
+                logger.debug(
+                    f"Layer {layer_idx}: Initializing with scale {config.init_scale}"
+                )
+                # Initialize with small random values
+                with torch.no_grad():
+                    new_layer.self_attn.q_proj.weight.data[old_q_size:].normal_(
+                        0, config.init_scale
+                    )
+                    new_layer.self_attn.k_proj.weight.data[old_k_size:].normal_(
+                        0, config.init_scale
+                    )
+                    new_layer.self_attn.lambda_q1.normal_(0, config.init_scale)
+                    new_layer.self_attn.lambda_k1.normal_(0, config.init_scale)
+                    new_layer.self_attn.lambda_q2.normal_(0, config.init_scale)
+                    new_layer.self_attn.lambda_k2.normal_(0, config.init_scale)
+                    if config.reinit_lambda_init:
+                        new_layer.self_attn.lambda_init.normal_(
+                            0, config.init_scale
+                        ).abs_()

         logger.info("Conversion complete")
         return new_model
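
The `.abs_()` on the re-drawn `lambda_init` makes sense in light of the Differential Transformer formulation, where the per-head mixing weight is λ = exp(λ_q1 · λ_k1) − exp(λ_q2 · λ_k2) + λ_init and the two attention maps combine as softmax(A1) − λ · softmax(A2). A self-contained sketch of that combination (the helper and shapes are illustrative, not this repo's implementation):

```python
import torch


def effective_lambda(lq1, lk1, lq2, lk2, lambda_init):
    """lambda = exp(lq1 . lk1) - exp(lq2 . lk2) + lambda_init (per head)."""
    return (
        torch.exp(torch.dot(lq1, lk1))
        - torch.exp(torch.dot(lq2, lk2))
        + lambda_init
    )


head_dim, scale = 64, 1e-6
lq1, lk1, lq2, lk2 = (torch.randn(head_dim) * scale for _ in range(4))
lambda_init = (torch.randn(()) * scale).abs()

# Both exp terms start near exp(0) = 1 and cancel, so lambda ~ lambda_init.
print(effective_lambda(lq1, lk1, lq2, lk2, lambda_init))
```

With `zero_init`, λ is exactly 1 − 1 + 0 = 0, so the subtractive branch starts disabled and the converted model should reproduce the base model's attention; that is presumably why the debug test below only asserts diverging generations once `init_scale=0.1` is passed.
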
@@ -710,25 +710,6 @@ class ModelLoader:
         """
         sample packing uses custom FA2 patch
         """
-        # if self.cfg.flash_attention:
-        #     if not self.cfg.sample_packing and self.cfg.s2_attention:
-        #         pass
-
-        #     self.model_kwargs["attn_implementation"] = "flash_attention_2"
-        #     self.model_config._attn_implementation = (  # pylint: disable=protected-access
-        #         "flash_attention_2"
-        #     )
-        # elif self.cfg.sdp_attention:
-        #     self.model_kwargs["attn_implementation"] = "sdpa"
-        #     self.model_config._attn_implementation = (  # pylint: disable=protected-access
-        #         "sdpa"
-        #     )
-        # elif self.cfg.eager_attention:
-        #     self.model_kwargs["attn_implementation"] = "eager"
-        #     self.model_config._attn_implementation = (  # pylint: disable=protected-access
-        #         "eager"
-        #     )
-
         if self.cfg.flash_attention:
             if not self.cfg.sample_packing and self.cfg.s2_attention:
                 pass
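
The deleted block was a commented-out copy of the attention-implementation selection that the live code (only partially visible in this hunk) evidently performs. As a standalone sketch of the mapping those comments documented, ignoring the `sample_packing`/`s2_attention` special case:

```python
from types import SimpleNamespace
from typing import Optional


def resolve_attn_implementation(cfg) -> Optional[str]:
    """Mapping documented by the removed comments: cfg flag -> HF attn backend."""
    if cfg.flash_attention:
        return "flash_attention_2"
    if cfg.sdp_attention:
        return "sdpa"
    if cfg.eager_attention:
        return "eager"
    return None


cfg = SimpleNamespace(flash_attention=False, sdp_attention=True, eager_attention=False)
print(resolve_attn_implementation(cfg))  # -> "sdpa"
```
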
@@ -69,7 +69,7 @@ def test_conversion_cli_debug(tmp_path: Path, base_config):
         yaml.dump(base_config, file)

     cfg = load_cfg(str(config_path))
-    cli_args = ConvertDiffTransformerCliArgs(debug=True)
+    cli_args = ConvertDiffTransformerCliArgs(debug=True, init_scale=0.1)
     _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))

     assert not debug_info["generations_match"]