plugin implementation

commit ea07a7086e (parent d22e1136bc)
Author: Dan Saunders
Date: 2024-12-18 01:26:41 +00:00

13 changed files with 118 additions and 30 deletions

View File

@@ -3,7 +3,7 @@ CLI to run training on a model
 """
 import logging
 from pathlib import Path
-from typing import Union
+from typing import Dict, Union
 
 import fire
 from dotenv import load_dotenv
@@ -23,7 +23,7 @@ from axolotl.evaluate import evaluate
 LOG = logging.getLogger("axolotl.cli.evaluate")
 
 
-def do_evaluate(cfg, cli_args) -> None:
+def do_evaluate(cfg, cli_args) -> Dict[str, float]:
     # pylint: disable=duplicate-code
     print_axolotl_text_art()
     check_accelerate_default_config()
@@ -34,7 +34,7 @@ def do_evaluate(cfg, cli_args) -> None:
     else:
         dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
 
-    evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
+    return evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
 
 
 def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:
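With this change, `do_evaluate` surfaces the metrics dict instead of discarding it, so scripted callers can consume evaluation results directly. A minimal sketch of that use (the config path, the `cli_args` value, and the `eval_loss` key are placeholders, not guaranteed by this diff):

```python
from pathlib import Path

from axolotl.cli import load_cfg
from axolotl.cli.evaluate import do_evaluate

cfg = load_cfg(Path("config.yml"))  # placeholder path
# cli_args as parsed by the CLI layer; None is only for illustration
metrics = do_evaluate(cfg, cli_args=None)
print(metrics.get("eval_loss"))  # key name is an assumption
```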

View File

@@ -15,7 +15,7 @@ from transformers import HfArgumentParser
 from axolotl.cli import load_cfg, print_axolotl_text_art
 from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
 from axolotl.integrations.differential_transformer.convert import (
-    convert_to_diff_attention,
+    convert_to_differential_attention,
 )
 
 LOG = logging.getLogger(__name__)
@@ -79,7 +79,7 @@ def convert_differential_transformer(cfg, cli_args, config_path):
     # Convert attention
     LOG.info("Converting to differential attention...")
     try:
-        model = convert_to_diff_attention(
+        model = convert_to_differential_attention(
             model=model,
             zero_init=cli_args.zero_init,
             sublayer_norm=cli_args.sublayer_norm,
@@ -111,7 +111,10 @@ def convert_differential_transformer(cfg, cli_args, config_path):
data = yaml.safe_load(file) or {}
data["base_model"] = cfg.output_dir
data["diff_attention"] = True
data["differential_attention"] = True
data["plugins"] = [
"axolotl.integrations.differential_transformer.DifferentialTransformerPlugin"
]
with open(output_config_path, "w", encoding="utf-8") as file:
yaml.dump(data, file)
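For reference, the config the converter writes back out would look roughly like this; the keys come straight from the hunk above, while the `base_model` value stands in for whatever `cfg.output_dir` is:

```yaml
base_model: ./outputs/converted-model  # cfg.output_dir; value illustrative
differential_attention: true
plugins:
  - axolotl.integrations.differential_transformer.DifferentialTransformerPlugin
```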

View File

@@ -43,10 +43,12 @@ def merge_input_args():
     input_args: List[str] = plugin_manager.get_input_args()
     plugin_classes = []
     dynamic_input = ""
     for plugin_args in input_args:
         plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
         dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
         plugin_classes.append(plugin_cls)
-    dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n    pass\n"
-    dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n    pass\n"
+    if dynamic_input:
+        dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n    pass\n"
+        dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n    pass\n"
@@ -62,4 +64,5 @@ def merge_input_args():
             "AxolotlConfigWCapabilities"
         ]
         return AxolotlConfigWCapabilities, AxolotlInputConfig
 
+    return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase
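The new guard matters because the generated class statements are only valid when at least one plugin contributed a mixin; with no plugins registered, the base config classes are now returned unchanged. The mechanism itself is plain generate-and-exec: plugin pydantic models become mixins of the base input config. A self-contained sketch of the same pattern with toy names (not axolotl's actual classes):

```python
from typing import Optional

from pydantic import BaseModel


class BaseConfig(BaseModel):
    """Stand-in for AxolotlInputConfigBase."""

    learning_rate: float = 1e-4


class PluginArgs(BaseModel):
    """Stand-in for a plugin's input-args model."""

    differential_attention: Optional[bool] = None


# Generate a class statement that mixes the plugin model into the base,
# then exec it into a namespace -- the same shape as merge_input_args.
namespace = {"BaseConfig": BaseConfig, "PluginArgs": PluginArgs}
exec("class MergedConfig(BaseConfig, PluginArgs):\n    pass\n", namespace)

MergedConfig = namespace["MergedConfig"]
cfg = MergedConfig(learning_rate=2e-5, differential_attention=True)
print(cfg.differential_attention)  # -> True
```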

View File

@@ -0,0 +1,10 @@
+# Differential Transformer
+
+### Usage
+
+```yaml
+plugins:
+  - axolotl.integrations.differential_transformer.DifferentialTransformerPlugin
+
+differential_attention: true
+```
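With a config like the one above saved to disk, the plugin is imported at config-parse time and its hooks run during model loading; at the time of this commit the usual entry points were the `axolotl.cli` modules (e.g. `accelerate launch -m axolotl.cli.train config.yml`, invocation shown as an assumption, not taken from this diff).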

View File

@@ -0,0 +1,25 @@
+"""Definition of differential transformer plugin."""
+
+import logging
+
+from axolotl.integrations.base import BasePlugin
+
+LOG = logging.getLogger(__name__)
+
+
+class DifferentialTransformerPlugin(BasePlugin):
+    """
+    Plugin for differential transformer integration with Axolotl.
+    """
+
+    def get_input_args(self):
+        return "axolotl.integrations.differential_transformer.args.DifferentialTransformerArgs"
+
+    def pre_model_load(self, cfg):
+        """Apply differential attention patch before model loading if enabled."""
+        if cfg.differential_attention:
+            from axolotl.monkeypatch.attention.differential import (
+                patch_llama_attention_classes,
+            )
+
+            patch_llama_attention_classes()
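The two hooks divide the work: `get_input_args` hands the config layer a dotted path to the pydantic mixin (keeping this module import-light), while `pre_model_load` applies the monkeypatch before any weights are materialized. A toy sketch of the dispatch shape a manager would use (illustrative only; axolotl's actual PluginManager may differ):

```python
class ToyPluginManager:
    """Calls each registered plugin's hooks at the right lifecycle point."""

    def __init__(self):
        self.plugins = []

    def register(self, plugin):
        self.plugins.append(plugin)

    def pre_model_load(self, cfg):
        # Run before weight loading so monkeypatches take effect first.
        for plugin in self.plugins:
            plugin.pre_model_load(cfg)
```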

View File

@@ -0,0 +1,14 @@
+"""Module for handling differential transformer input arguments."""
+
+import logging
+from typing import Optional
+
+from pydantic import BaseModel
+
+LOG = logging.getLogger(__name__)
+
+
+class DifferentialTransformerArgs(BaseModel):
+    """Input args for differential transformer."""
+
+    differential_attention: Optional[bool] = None
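Standalone, the mixin validates like any pydantic model; it only affects training once `merge_input_args` folds it into the generated config class:

```python
from axolotl.integrations.differential_transformer.args import (
    DifferentialTransformerArgs,
)

args = DifferentialTransformerArgs(differential_attention=True)
assert args.differential_attention is True
```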

View File

@@ -80,7 +80,7 @@ def copy_attention_weights(
 )
 
 
-def convert_to_diff_attention(
+def convert_to_differential_attention(
     model: PreTrainedModel, zero_init: bool = False, sublayer_norm: bool = True
 ) -> PreTrainedModel:
     """Convert a pre-trained model's attention layers to differential attention"""

View File

@@ -727,8 +727,6 @@ class AxolotlInputConfig(
     eager_attention: Optional[bool] = None
 
-    diff_attention: Optional[bool] = None
-
     unsloth_cross_entropy_loss: Optional[bool] = None
     unsloth_lora_mlp: Optional[bool] = None
     unsloth_lora_qkv: Optional[bool] = None

View File

@@ -444,13 +444,6 @@ class ModelLoader:
             patch_mistral_cross_entropy()
 
-        if self.cfg.diff_attention:
-            from axolotl.monkeypatch.attention.differential import (
-                patch_llama_attention_classes,
-            )
-
-            patch_llama_attention_classes()
-
     def patch_attention(self) -> None:
         if hasattr(self.model_config, "model_type"):
             if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
@@ -721,7 +714,7 @@ class ModelLoader:
             if not self.cfg.sample_packing and self.cfg.s2_attention:
                 pass
 
-            if self.cfg.diff_attention:
+            if self.cfg.differential_attention:
                 self.model_kwargs[
                     "attn_implementation"
                 ] = "differential_flash_attention_2"
@@ -734,7 +727,7 @@ class ModelLoader:
                     "flash_attention_2"
                 )
         elif self.cfg.sdp_attention:
-            if self.cfg.diff_attention:
+            if self.cfg.differential_attention:
                 self.model_kwargs["attn_implementation"] = "differential_sdpa"
                 self.model_config._attn_implementation = (  # pylint: disable=protected-access
                     "differential_sdpa"
@@ -745,7 +738,7 @@ class ModelLoader:
                     "sdpa"
                 )
         elif self.cfg.eager_attention:
-            if self.cfg.diff_attention:
+            if self.cfg.differential_attention:
                 self.model_kwargs["attn_implementation"] = "differential_eager"
                 self.model_config._attn_implementation = (  # pylint: disable=protected-access
                     "differential_eager"
@@ -755,7 +748,7 @@ class ModelLoader:
             self.model_config._attn_implementation = (  # pylint: disable=protected-access
                 "eager"
             )
-        elif self.cfg.diff_attention:
+        elif self.cfg.differential_attention:
             self.model_kwargs["attn_implementation"] = "differential_eager"
             self.model_config._attn_implementation = (  # pylint: disable=protected-access
                 "differential_eager"