diff --git a/model-out/eval_summary.csv b/model-out/eval_summary.csv
deleted file mode 100644
index 6a8f78af7..000000000
--- a/model-out/eval_summary.csv
+++ /dev/null
@@ -1,6 +0,0 @@
-metric,training,validation
-loss,15.633337020874023,15.604033470153809
-model_preparation_time,0.0058,0.0058
-runtime,77.8124,8.4643
-samples_per_second,23.133,23.629
-steps_per_second,23.133,23.629
diff --git a/src/axolotl/cli/integrations/convert_diff_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py
index ecde82251..03126f3bf 100644
--- a/src/axolotl/cli/integrations/convert_diff_transformer.py
+++ b/src/axolotl/cli/integrations/convert_diff_transformer.py
@@ -87,6 +87,8 @@ def convert_diff_transformer(cfg, cli_args, config_path):
         zero_init=cli_args.zero_init,
         sublayer_norm=cli_args.sublayer_norm,
         split_heads=cli_args.split_heads,
+        init_scale=cli_args.init_scale,
+        reinit_lambda_init=cli_args.reinit_lambda_init,
     )
     model = LlamaDifferentialForCausalLM.from_llama(model, config)
     model.to(cfg.device, dtype=cfg.torch_dtype)
diff --git a/src/axolotl/common/cli.py b/src/axolotl/common/cli.py
index ebe098ca6..8f0d6bb77 100644
--- a/src/axolotl/common/cli.py
+++ b/src/axolotl/common/cli.py
@@ -54,8 +54,10 @@ class ConvertDiffTransformerCliArgs:
 
     debug: bool = field(default=False)
     zero_init: bool = field(default=False)
-    sublayer_norm: bool = field(default=True)
+    sublayer_norm: bool = field(default=False)
     split_heads: bool = field(default=False)
+    init_scale: float = field(default=1e-6)
+    reinit_lambda_init: bool = field(default=True)
 
 
 def load_model_and_tokenizer(
diff --git a/src/axolotl/integrations/diff_transformer/README.md b/src/axolotl/integrations/diff_transformer/README.md
index a683fdf1d..efba1fc39 100644
--- a/src/axolotl/integrations/diff_transformer/README.md
+++ b/src/axolotl/integrations/diff_transformer/README.md
@@ -2,12 +2,11 @@
 
 ### Usage
 
+**Note:** The following will be set in the model config output by the `axolotl convert-diff-transformer` command.
+
 ```yaml
 plugins:
   - axolotl.integrations.diff_transformer.DifferentialTransformerPlugin
 
 diff_attention: true
-diff_attn_zero_init: false
-diff_attn_sublayer_norm: true
-diff_attn_split_heads: false
 ```
diff --git a/src/axolotl/integrations/diff_transformer/args.py b/src/axolotl/integrations/diff_transformer/args.py
index 332c0b4aa..47c1fe110 100644
--- a/src/axolotl/integrations/diff_transformer/args.py
+++ b/src/axolotl/integrations/diff_transformer/args.py
@@ -12,6 +12,3 @@ class DifferentialTransformerArgs(BaseModel):
     """Input args for differential transformer."""
 
     diff_attention: Optional[bool] = None
-    diff_attn_zero_init: Optional[bool] = None
-    diff_attn_sublayer_norm: Optional[bool] = None
-    diff_attn_split_heads: Optional[bool] = None
diff --git a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
index e41fd1fdb..e0f4eebc1 100644
--- a/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
+++ b/src/axolotl/integrations/diff_transformer/modeling_diff_attn.py
@@ -191,6 +191,26 @@ class LlamaDifferentialModel(LlamaModel):
                 new_layer.self_attn.lambda_q2.zero_()
                 new_layer.self_attn.lambda_k2.zero_()
                 new_layer.self_attn.lambda_init.zero_()
+            else:
+                logger.debug(
+                    f"Layer {layer_idx}: Initializing with scale {config.init_scale}"
+                )
+                # Initialize with small random values
+                with torch.no_grad():
+                    new_layer.self_attn.q_proj.weight.data[old_q_size:].normal_(
+                        0, config.init_scale
+                    )
+                    new_layer.self_attn.k_proj.weight.data[old_k_size:].normal_(
+                        0, config.init_scale
+                    )
+                    new_layer.self_attn.lambda_q1.normal_(0, config.init_scale)
+                    new_layer.self_attn.lambda_k1.normal_(0, config.init_scale)
+                    new_layer.self_attn.lambda_q2.normal_(0, config.init_scale)
+                    new_layer.self_attn.lambda_k2.normal_(0, config.init_scale)
+                    if config.reinit_lambda_init:
+                        new_layer.self_attn.lambda_init.normal_(
+                            0, config.init_scale
+                        ).abs_()
 
         logger.info("Conversion complete")
         return new_model
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 2c4d2513d..d16db7613 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -710,25 +710,6 @@ class ModelLoader:
         """
         sample packing uses custom FA2 patch
         """
-        # if self.cfg.flash_attention:
-        #     if not self.cfg.sample_packing and self.cfg.s2_attention:
-        #         pass
-
-        #     self.model_kwargs["attn_implementation"] = "flash_attention_2"
-        #     self.model_config._attn_implementation = (  # pylint: disable=protected-access
-        #         "flash_attention_2"
-        #     )
-        # elif self.cfg.sdp_attention:
-        #     self.model_kwargs["attn_implementation"] = "sdpa"
-        #     self.model_config._attn_implementation = (  # pylint: disable=protected-access
-        #         "sdpa"
-        #     )
-        # elif self.cfg.eager_attention:
-        #     self.model_kwargs["attn_implementation"] = "eager"
-        #     self.model_config._attn_implementation = (  # pylint: disable=protected-access
-        #         "eager"
-        #     )
-
         if self.cfg.flash_attention:
             if not self.cfg.sample_packing and self.cfg.s2_attention:
                 pass
diff --git a/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py b/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py
index e1ad31fdd..92c8053c0 100644
--- a/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py
+++ b/tests/e2e/integrations/convert_diff_transformer/test_convert_diff_transformer.py
@@ -69,7 +69,7 @@ def test_conversion_cli_debug(tmp_path: Path, base_config):
         yaml.dump(base_config, file)
 
     cfg = load_cfg(str(config_path))
-    cli_args = ConvertDiffTransformerCliArgs(debug=True)
+    cli_args = ConvertDiffTransformerCliArgs(debug=True, init_scale=0.1)
     _, debug_info = convert_diff_transformer(cfg, cli_args, str(config_path))
 
     assert not debug_info["generations_match"]
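
For reference, a minimal sketch of driving the conversion programmatically with the new `init_scale` and `reinit_lambda_init` options, mirroring the updated test; the `load_cfg` import location and the example config path are assumptions, not part of this change.

```python
# Minimal sketch, assuming `load_cfg` is importable from `axolotl.cli` and that
# "config.yaml" points at a valid axolotl config.
from axolotl.cli import load_cfg
from axolotl.cli.integrations.convert_diff_transformer import convert_diff_transformer
from axolotl.common.cli import ConvertDiffTransformerCliArgs

config_path = "config.yaml"  # hypothetical path
cfg = load_cfg(config_path)

# init_scale is the std-dev of the normal init applied to the new q/k projection
# rows and the lambda parameters; reinit_lambda_init additionally re-draws
# lambda_init (kept positive via .abs_()). Defaults: init_scale=1e-6,
# reinit_lambda_init=True.
cli_args = ConvertDiffTransformerCliArgs(
    debug=True,
    init_scale=1e-6,
    reinit_lambda_init=True,
)
_, debug_info = convert_diff_transformer(cfg, cli_args, config_path)
```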