plugin implementation

This commit is contained in:
Dan Saunders
2024-12-18 01:26:41 +00:00
parent d22e1136bc
commit ea07a7086e
13 changed files with 118 additions and 30 deletions

3
.gitignore vendored
View File

@@ -186,3 +186,6 @@ out/
# vim # vim
*.swp *.swp
# symlinked to axolotl-artifacts in docker containers
outputs

View File

@@ -1,6 +0,0 @@
metric,training,validation
loss,1.8773103952407837,1.915901780128479
model_preparation_time,0.0051,0.0051
runtime,89.7635,8.9565
samples_per_second,20.053,22.33
steps_per_second,20.053,22.33
1 metric training validation
2 loss 1.8773103952407837 1.915901780128479
3 model_preparation_time 0.0051 0.0051
4 runtime 89.7635 8.9565
5 samples_per_second 20.053 22.33
6 steps_per_second 20.053 22.33

View File

@@ -1 +0,0 @@
/workspace/data/axolotl-artifacts

View File

@@ -3,7 +3,7 @@ CLI to run training on a model
""" """
import logging import logging
from pathlib import Path from pathlib import Path
from typing import Union from typing import Dict, Union
import fire import fire
from dotenv import load_dotenv from dotenv import load_dotenv
@@ -23,7 +23,7 @@ from axolotl.evaluate import evaluate
LOG = logging.getLogger("axolotl.cli.evaluate") LOG = logging.getLogger("axolotl.cli.evaluate")
def do_evaluate(cfg, cli_args) -> None: def do_evaluate(cfg, cli_args) -> Dict[str, float]:
# pylint: disable=duplicate-code # pylint: disable=duplicate-code
print_axolotl_text_art() print_axolotl_text_art()
check_accelerate_default_config() check_accelerate_default_config()
@@ -34,7 +34,7 @@ def do_evaluate(cfg, cli_args) -> None:
else: else:
dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args) dataset_meta = load_datasets(cfg=cfg, cli_args=cli_args)
evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta) return evaluate(cfg=cfg, cli_args=cli_args, dataset_meta=dataset_meta)
def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None: def do_cli(config: Union[Path, str] = Path("examples/"), **kwargs) -> None:

View File

@@ -15,7 +15,7 @@ from transformers import HfArgumentParser
from axolotl.cli import load_cfg, print_axolotl_text_art from axolotl.cli import load_cfg, print_axolotl_text_art
from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer from axolotl.common.cli import ConvertDiffTransformerCliArgs, load_model_and_tokenizer
from axolotl.integrations.differential_transformer.convert import ( from axolotl.integrations.differential_transformer.convert import (
convert_to_diff_attention, convert_to_differential_attention,
) )
LOG = logging.getLogger(__name__) LOG = logging.getLogger(__name__)
@@ -79,7 +79,7 @@ def convert_differential_transformer(cfg, cli_args, config_path):
# Convert attention # Convert attention
LOG.info("Converting to differential attention...") LOG.info("Converting to differential attention...")
try: try:
model = convert_to_diff_attention( model = convert_to_differential_attention(
model=model, model=model,
zero_init=cli_args.zero_init, zero_init=cli_args.zero_init,
sublayer_norm=cli_args.sublayer_norm, sublayer_norm=cli_args.sublayer_norm,
@@ -111,7 +111,10 @@ def convert_differential_transformer(cfg, cli_args, config_path):
data = yaml.safe_load(file) or {} data = yaml.safe_load(file) or {}
data["base_model"] = cfg.output_dir data["base_model"] = cfg.output_dir
data["diff_attention"] = True data["differential_attention"] = True
data["plugins"] = [
"axolotl.integrations.differential_transformer.DifferentialTransformerPlugin"
]
with open(output_config_path, "w", encoding="utf-8") as file: with open(output_config_path, "w", encoding="utf-8") as file:
yaml.dump(data, file) yaml.dump(data, file)

View File

@@ -43,10 +43,12 @@ def merge_input_args():
input_args: List[str] = plugin_manager.get_input_args() input_args: List[str] = plugin_manager.get_input_args()
plugin_classes = [] plugin_classes = []
dynamic_input = "" dynamic_input = ""
for plugin_args in input_args: for plugin_args in input_args:
plugin_module, plugin_cls = plugin_args.rsplit(".", 1) plugin_module, plugin_cls = plugin_args.rsplit(".", 1)
dynamic_input += f"from {plugin_module} import {plugin_cls}\n" dynamic_input += f"from {plugin_module} import {plugin_cls}\n"
plugin_classes.append(plugin_cls) plugin_classes.append(plugin_cls)
if dynamic_input: if dynamic_input:
dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n" dynamic_input += f"class AxolotlConfigWCapabilities(AxolotlConfigWCapabilitiesBase, {', '.join(plugin_classes)}):\n pass\n"
dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n" dynamic_input += f"class AxolotlInputConfig(AxolotlInputConfigBase, {', '.join(plugin_classes)}):\n pass\n"
@@ -62,4 +64,5 @@ def merge_input_args():
"AxolotlConfigWCapabilities" "AxolotlConfigWCapabilities"
] ]
return AxolotlConfigWCapabilities, AxolotlInputConfig return AxolotlConfigWCapabilities, AxolotlInputConfig
return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase return AxolotlConfigWCapabilitiesBase, AxolotlInputConfigBase

View File

@@ -0,0 +1,10 @@
# Differential Transformer
### Usage
```yaml
plugins:
- axolotl.integrations.differential_transformer.DifferentialTransformerPlugin
differential_attention: true
```

View File

@@ -0,0 +1,25 @@
"""Definition of differential transformer plugin."""
import logging
from axolotl.integrations.base import BasePlugin
LOG = logging.getLogger(__name__)
class DifferentialTransformerPlugin(BasePlugin):
"""
Plugin for differential transformer integration with Axolotl.
"""
def get_input_args(self):
return "axolotl.integrations.differential_transformer.args.DifferentialTransformerArgs"
def pre_model_load(self, cfg):
"""Apply differential attention patch before model loading if enabled."""
if cfg.differential_attention:
from axolotl.monkeypatch.attention.differential import (
patch_llama_attention_classes,
)
patch_llama_attention_classes()

View File

@@ -0,0 +1,14 @@
"""Module for handling differential transfomer input arguments."""
import logging
from typing import Optional
from pydantic import BaseModel
LOG = logging.getLogger(__name__)
class DifferentialTransformerArgs(BaseModel):
"""Input args for differential transformer."""
differential_attention: Optional[bool] = None

View File

@@ -80,7 +80,7 @@ def copy_attention_weights(
) )
def convert_to_diff_attention( def convert_to_differential_attention(
model: PreTrainedModel, zero_init: bool = False, sublayer_norm: bool = True model: PreTrainedModel, zero_init: bool = False, sublayer_norm: bool = True
) -> PreTrainedModel: ) -> PreTrainedModel:
"""Convert a pre-trained model's attention layers to differential attention""" """Convert a pre-trained model's attention layers to differential attention"""

View File

@@ -727,8 +727,6 @@ class AxolotlInputConfig(
eager_attention: Optional[bool] = None eager_attention: Optional[bool] = None
diff_attention: Optional[bool] = None
unsloth_cross_entropy_loss: Optional[bool] = None unsloth_cross_entropy_loss: Optional[bool] = None
unsloth_lora_mlp: Optional[bool] = None unsloth_lora_mlp: Optional[bool] = None
unsloth_lora_qkv: Optional[bool] = None unsloth_lora_qkv: Optional[bool] = None

View File

@@ -444,13 +444,6 @@ class ModelLoader:
patch_mistral_cross_entropy() patch_mistral_cross_entropy()
if self.cfg.diff_attention:
from axolotl.monkeypatch.attention.differential import (
patch_llama_attention_classes,
)
patch_llama_attention_classes()
def patch_attention(self) -> None: def patch_attention(self) -> None:
if hasattr(self.model_config, "model_type"): if hasattr(self.model_config, "model_type"):
if self.model_config.model_type == "mllama" and self.cfg.flash_attention: if self.model_config.model_type == "mllama" and self.cfg.flash_attention:
@@ -721,7 +714,7 @@ class ModelLoader:
if not self.cfg.sample_packing and self.cfg.s2_attention: if not self.cfg.sample_packing and self.cfg.s2_attention:
pass pass
if self.cfg.diff_attention: if self.cfg.differential_attention:
self.model_kwargs[ self.model_kwargs[
"attn_implementation" "attn_implementation"
] = "differential_flash_attention_2" ] = "differential_flash_attention_2"
@@ -734,7 +727,7 @@ class ModelLoader:
"flash_attention_2" "flash_attention_2"
) )
elif self.cfg.sdp_attention: elif self.cfg.sdp_attention:
if self.cfg.diff_attention: if self.cfg.differential_attention:
self.model_kwargs["attn_implementation"] = "differential_sdpa" self.model_kwargs["attn_implementation"] = "differential_sdpa"
self.model_config._attn_implementation = ( # pylint: disable=protected-access self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_sdpa" "differential_sdpa"
@@ -745,7 +738,7 @@ class ModelLoader:
"sdpa" "sdpa"
) )
elif self.cfg.eager_attention: elif self.cfg.eager_attention:
if self.cfg.diff_attention: if self.cfg.differential_attention:
self.model_kwargs["attn_implementation"] = "differential_eager" self.model_kwargs["attn_implementation"] = "differential_eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_eager" "differential_eager"
@@ -755,7 +748,7 @@ class ModelLoader:
self.model_config._attn_implementation = ( # pylint: disable=protected-access self.model_config._attn_implementation = ( # pylint: disable=protected-access
"eager" "eager"
) )
elif self.cfg.diff_attention: elif self.cfg.differential_attention:
self.model_kwargs["attn_implementation"] = "differential_eager" self.model_kwargs["attn_implementation"] = "differential_eager"
self.model_config._attn_implementation = ( # pylint: disable=protected-access self.model_config._attn_implementation = ( # pylint: disable=protected-access
"differential_eager" "differential_eager"

View File

@@ -6,12 +6,14 @@ from typing import Optional
import pytest import pytest
import yaml import yaml
from pytest import approx
from axolotl.cli import load_cfg from axolotl.cli import load_cfg
from axolotl.cli.evaluate import do_evaluate
from axolotl.cli.integrations.convert_differential_transformer import ( from axolotl.cli.integrations.convert_differential_transformer import (
convert_differential_transformer, convert_differential_transformer,
) )
from axolotl.common.cli import ConvertDiffTransformerCliArgs from axolotl.common.cli import ConvertDiffTransformerCliArgs, EvaluateCliArgs
@pytest.fixture() @pytest.fixture()
@@ -19,9 +21,12 @@ def base_config():
"""Basic config for testing.""" """Basic config for testing."""
return { return {
"base_model": "HuggingFaceTB/SmolLM2-135M", "base_model": "HuggingFaceTB/SmolLM2-135M",
"plugins": [
"axolotl.integrations.differential_transformer.DifferentialTransformerPlugin",
],
"datasets": [ "datasets": [
{ {
"path": "mhenrichsen/alpaca_2k_test", "path": "axolotl-ai-co/alpaca_100_test",
"type": "alpaca", "type": "alpaca",
}, },
], ],
@@ -103,7 +108,9 @@ def test_conversion_cli_reproduce(tmp_path: Path, base_config):
assert (output_dir / "axolotl_config.yml").exists() assert (output_dir / "axolotl_config.yml").exists()
@pytest.mark.parametrize("attention", ["sdp_attention", "flash_attention"]) @pytest.mark.parametrize(
"attention", ["eager_attention", "sdp_attention", "flash_attention"]
)
def test_conversion_cli_repoduce_attentions( def test_conversion_cli_repoduce_attentions(
tmp_path: Path, base_config, attention: Optional[str] tmp_path: Path, base_config, attention: Optional[str]
): ):
@@ -125,3 +132,42 @@ def test_conversion_cli_repoduce_attentions(
assert (output_dir / "model.safetensors").exists() assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists() assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists() assert (output_dir / "axolotl_config.yml").exists()
def test_conversion_and_eval_cli(tmp_path: Path, base_config):
output_dir = tmp_path / "converted"
base_config["output_dir"] = str(output_dir)
config_path = tmp_path / "config.yml"
with open(config_path, "w", encoding="utf-8") as file:
yaml.dump(base_config, file)
cfg = load_cfg(str(config_path))
cli_args = ConvertDiffTransformerCliArgs(
debug=True, zero_init=True, sublayer_norm=False
)
_, debug_info = convert_differential_transformer(cfg, cli_args, str(config_path))
assert debug_info["generations_match"] is True
assert (output_dir / "model.safetensors").exists()
assert (output_dir / "config.json").exists()
assert (output_dir / "axolotl_config.yml").exists()
eval_cfg = load_cfg(str(output_dir))
eval_cli_args = EvaluateCliArgs()
all_metrics = do_evaluate(eval_cfg, eval_cli_args)
assert list(all_metrics.keys()) == [
"train_loss",
"train_model_preparation_time",
"train_runtime",
"train_samples_per_second",
"train_steps_per_second",
"eval_loss",
"eval_model_preparation_time",
"eval_runtime",
"eval_samples_per_second",
"eval_steps_per_second",
]
assert all_metrics["train_loss"] == approx(1.7307, rel=1e-4)
assert all_metrics["eval_loss"] == approx(1.8387, rel=1e-4)