diff --git a/src/axolotl/cli/integrations/convert_diff_transformer.py b/src/axolotl/cli/integrations/convert_diff_transformer.py
index 8886c4946..116a60480 100644
--- a/src/axolotl/cli/integrations/convert_diff_transformer.py
+++ b/src/axolotl/cli/integrations/convert_diff_transformer.py
@@ -7,6 +7,7 @@ from typing import Union
 
 import fire
 import torch
+import yaml
 from colorama import Fore
 from dotenv import load_dotenv
 from transformers import HfArgumentParser
@@ -50,13 +51,7 @@ def test_inference(model, tokenizer, prompt="The quick brown fox"):
         raise
 
 
-def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
-    print_axolotl_text_art()
-
-    cfg = load_cfg(config, **kwargs)
-    parser = HfArgumentParser(TrainerCliArgs)
-    cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
-
+def convert_diff_transformer(cfg, cli_args, config_path):
     try:
         # Load model and tokenizer
         with warnings.catch_warnings():
@@ -90,8 +85,26 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
 
         # Save if requested
         if cfg.output_dir:
+            # Save model and tokenizer
             LOG.info("Saving converted model to %s", cfg.output_dir)
             model.save_pretrained(cfg.output_dir)
+            tokenizer.save_pretrained(cfg.output_dir)
+
+            # Modify config to reflect new path / differential attention
+            output_config_path = Path(cfg.output_dir) / "axolotl_config.yml"
+            LOG.info("Saving updated config to %s", output_config_path)
+
+            with open(config_path, "r", encoding="utf-8") as file:
+                data = yaml.safe_load(file) or {}
+
+            data["base_model"] = cfg.output_dir
+            data["diff_attention"] = True
+
+            with open(output_config_path, "w", encoding="utf-8") as file:
+                yaml.dump(data, file)
+        else:
+            LOG.info("Not saving converted model to disk")
+            LOG.info("Pass --output-dir path/to/save to save model")
 
         LOG.info(
             Fore.GREEN
@@ -122,10 +135,16 @@ def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
         raise
 
 
+def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs):
+    print_axolotl_text_art()
+
+    cfg = load_cfg(config, **kwargs)
+    parser = HfArgumentParser(TrainerCliArgs)
+    cli_args, _ = parser.parse_args_into_dataclasses(return_remaining_strings=True)
+
+    convert_diff_transformer(cfg, cli_args, config)
+
+
 if __name__ == "__main__":
     load_dotenv()
-    logging.basicConfig(
-        level=logging.INFO,
-        format="%(asctime)s - %(levelname)s - %(message)s",
-    )
     fire.Fire(do_cli)
diff --git a/src/axolotl/cli/main.py b/src/axolotl/cli/main.py
index aae9d55d0..c29874590 100644
--- a/src/axolotl/cli/main.py
+++ b/src/axolotl/cli/main.py
@@ -242,11 +242,6 @@ def merge_lora(
 
 @cli.command()
 @click.argument("config", type=click.Path(exists=True, path_type=str))
-@click.option(
-    "--output-dir",
-    type=click.Path(path_type=str),
-    help="Directory to save converted model",
-)
 @add_options_from_dataclass(TrainerCliArgs)
 @add_options_from_config(AxolotlInputConfig)
 def convert_diff_transformer(config: str, **kwargs):
diff --git a/src/axolotl/core/trainer_builder.py b/src/axolotl/core/trainer_builder.py
index 54ee19536..0743d4c92 100755
--- a/src/axolotl/core/trainer_builder.py
+++ b/src/axolotl/core/trainer_builder.py
@@ -293,7 +293,7 @@ class AxolotlTrainingArguments(AxolotlTrainingMixins, TrainingArguments):
     """
     Training arguments for Causal trainer
 
-    This code is duplicated due to HF TrainingArguments not setting output_dir with a defaujlt value
+    This code is duplicated due to HF TrainingArguments not setting output_dir with a default value
     so it can't be used as a mixin.
""" diff --git a/src/axolotl/integrations/diff_transformer/convert.py b/src/axolotl/integrations/diff_transformer/convert.py index 36d97037b..584c19d5f 100644 --- a/src/axolotl/integrations/diff_transformer/convert.py +++ b/src/axolotl/integrations/diff_transformer/convert.py @@ -88,7 +88,6 @@ def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel: for name, child in module.named_children(): if isinstance(child, attention_patterns): layer_type = type(child).__name__ - logger.info(f"Converting attention layer {layer_idx}: {layer_type}") # Choose appropriate differential attention class if isinstance(child, LlamaSdpaAttention): @@ -96,6 +95,10 @@ def convert_to_diff_attention(model: PreTrainedModel) -> PreTrainedModel: else: attention_class = LlamaDifferentialAttention + logger.info( + f"Converting attention layer {layer_idx}: {layer_type} to {attention_class.__name__}" + ) + # Create new diff attn layer new_attention = attention_class( config=module.config if hasattr(module, "config") else model.config, diff --git a/src/axolotl/integrations/diff_transformer/patches.py b/src/axolotl/integrations/diff_transformer/patches.py new file mode 100644 index 000000000..7ff35633c --- /dev/null +++ b/src/axolotl/integrations/diff_transformer/patches.py @@ -0,0 +1,46 @@ +"""Patches related to differential transformers implementation.""" +from transformers import PreTrainedModel +from transformers.models.llama.modeling_llama import LLAMA_ATTENTION_CLASSES + +from .multihead_diffattn import ( + LlamaDifferentialAttention, + LlamaDifferentialSdpaAttention, +) + + +def patch_transformers(): + """Patch transformers to support differential attention""" + + # Add our attention class to the registry + LLAMA_ATTENTION_CLASSES["differential_eager"] = LlamaDifferentialAttention + LLAMA_ATTENTION_CLASSES["differential_sdpa"] = LlamaDifferentialSdpaAttention + + # Store original method for use in our patch + # original_autoset = PreTrainedModel._autoset_attn_implementation + + @classmethod + def new_autoset(_, config, **kwargs): # pylint: disable=unused-argument + config._attn_implementation_autoset = True # pylint: disable=protected-access + attn_implementation = getattr(config, "_attn_implementation", None) + + valid_impls = [ + None, + "eager", + "sdpa", + "flash_attention_2", + "differential_eager", + "differential_sdpa", + ] + if attn_implementation not in valid_impls: + message = ( + f"Specified `attn_implementation={attn_implementation}` is not supported. " + f"The only possible arguments are: {', '.join(repr(x) for x in valid_impls if x)}" + ) + raise ValueError(message + ".") + + return config + + # Apply patch + PreTrainedModel._autoset_attn_implementation = ( # pylint: disable=protected-access + new_autoset + ) diff --git a/src/axolotl/train.py b/src/axolotl/train.py index c5cf7ad6d..7f51175bf 100644 --- a/src/axolotl/train.py +++ b/src/axolotl/train.py @@ -87,7 +87,7 @@ def train( ) resume_from_checkpoint = cfg.resume_from_checkpoint - # Load the model and tokenizer + # Load the model msg = "loading model" if cfg.adapter: msg += " and peft_config..." 
diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 5ddf04811..cab61c148 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -724,6 +724,8 @@ class AxolotlInputConfig(
 
     eager_attention: Optional[bool] = None
 
+    diff_attention: Optional[bool] = None
+
     unsloth_cross_entropy_loss: Optional[bool] = None
     unsloth_lora_mlp: Optional[bool] = None
     unsloth_lora_qkv: Optional[bool] = None
diff --git a/src/axolotl/utils/models.py b/src/axolotl/utils/models.py
index 523fd76fe..7d8599c3c 100644
--- a/src/axolotl/utils/models.py
+++ b/src/axolotl/utils/models.py
@@ -710,24 +710,60 @@ class ModelLoader:
         """
         sample packing uses custom FA2 patch
         """
+        print(
+            self.cfg.flash_attention,
+            self.cfg.sdp_attention,
+            self.cfg.eager_attention,
+            self.cfg.diff_attention,
+        )
+
         if self.cfg.flash_attention:
             if not self.cfg.sample_packing and self.cfg.s2_attention:
                 pass
-            self.model_kwargs["attn_implementation"] = "flash_attention_2"
-            self.model_config._attn_implementation = (  # pylint: disable=protected-access
-                "flash_attention_2"
-            )
+
+            if self.cfg.diff_attention:
+                self.model_kwargs[
+                    "attn_implementation"
+                ] = "differential_flash_attention_2"
+                self.model_config._attn_implementation = (  # pylint: disable=protected-access
+                    "differential_flash_attention_2"
+                )
+            else:
+                self.model_kwargs["attn_implementation"] = "flash_attention_2"
+                self.model_config._attn_implementation = (  # pylint: disable=protected-access
+                    "flash_attention_2"
+                )
         elif self.cfg.sdp_attention:
-            self.model_kwargs["attn_implementation"] = "sdpa"
-            self.model_config._attn_implementation = (  # pylint: disable=protected-access
-                "sdpa"
-            )
+            if self.cfg.diff_attention:
+                self.model_kwargs["attn_implementation"] = "differential_sdpa"
+                self.model_config._attn_implementation = (  # pylint: disable=protected-access
+                    "differential_sdpa"
+                )
+            else:
+                self.model_kwargs["attn_implementation"] = "sdpa"
+                self.model_config._attn_implementation = (  # pylint: disable=protected-access
+                    "sdpa"
+                )
         elif self.cfg.eager_attention:
-            self.model_kwargs["attn_implementation"] = "eager"
+            if self.cfg.diff_attention:
+                self.model_kwargs["attn_implementation"] = "differential_eager"
+                self.model_config._attn_implementation = (  # pylint: disable=protected-access
+                    "differential_eager"
+                )
+            else:
+                self.model_kwargs["attn_implementation"] = "eager"
+                self.model_config._attn_implementation = (  # pylint: disable=protected-access
+                    "eager"
+                )
+        elif self.cfg.diff_attention:
+            self.model_kwargs["attn_implementation"] = "differential_eager"
             self.model_config._attn_implementation = (  # pylint: disable=protected-access
-                "eager"
+                "differential_eager"
             )
 
+        if "attn_implementation" in self.model_kwargs:
+            print(self.model_kwargs["attn_implementation"])
+
         if self.cfg.low_cpu_mem_usage:
             self.model_kwargs["low_cpu_mem_usage"] = True
@@ -816,6 +852,8 @@ class ModelLoader:
 
         if self.cfg.is_multimodal:
             self.model_config.text_config = self.text_model_config
+
+        # self.model._attn_implementation_autoset = False
         self.model = self.AutoModelLoader.from_pretrained(
             self.base_model,
             config=self.model_config,
@@ -1030,6 +1068,10 @@ class ModelLoader:
             integrate_rope_embeddings()
 
     def load_model(self) -> Tuple[PreTrainedModel, Optional[PeftConfig]]:
+        from axolotl.integrations.diff_transformer.patches import patch_transformers
+
+        patch_transformers()
+
         self.apply_patches()
         self.set_auto_model_loader()
         self.set_device_map_config()
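
For readability, the attention-implementation selection added in the models.py hunk above can be condensed as follows; this is a standalone sketch for review purposes, not code from this diff, and it assumes cfg exposes the boolean flags as attributes:

    # Condensed view of the ModelLoader branching above (illustrative only).
    from typing import Optional

    def pick_attn_implementation(cfg) -> Optional[str]:
        if cfg.flash_attention:
            return "differential_flash_attention_2" if cfg.diff_attention else "flash_attention_2"
        if cfg.sdp_attention:
            return "differential_sdpa" if cfg.diff_attention else "sdpa"
        if cfg.eager_attention:
            return "differential_eager" if cfg.diff_attention else "eager"
        if cfg.diff_attention:
            # diff_attention alone falls back to the eager differential variant
            return "differential_eager"
        return None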